Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/google_genai/adapters/transformation.py
2026-03-26 20:06:14 +08:00

784 lines
31 KiB
Python

import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
from litellm import verbose_logger
from litellm.litellm_core_utils.json_validation_rule import normalize_tool_schema
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionAssistantToolCall,
ChatCompletionImageObject,
ChatCompletionRequest,
ChatCompletionSystemMessage,
ChatCompletionTextObject,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolChoiceValues,
ChatCompletionToolMessage,
ChatCompletionToolParam,
ChatCompletionUserMessage,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import (
AdapterCompletionStreamWrapper,
Choices,
ModelResponse,
ModelResponseStream,
StreamingChoices,
)
class GoogleGenAIStreamWrapper(AdapterCompletionStreamWrapper):
    """
    Wrapper for streaming Google GenAI generate_content responses.

    Transforms OpenAI streaming chunks to Google GenAI format. Supports both
    sync (``__next__``) and async (``__anext__``) iteration, and exposes SSE
    helpers that frame each transformed chunk as a ``data: ...`` event.
    """

    # True once the first chunk has been emitted.
    sent_first_chunk: bool = False
    # State for accumulating partial tool calls across chunks, keyed by the
    # streaming tool call ``index`` from the OpenAI delta (an int).
    accumulated_tool_calls: Dict[int, Dict[str, Any]]

    def __init__(self, completion_stream: Any):
        self.sent_first_chunk = False
        self.accumulated_tool_calls = {}
        # Guards the non-iterable fallback so a plain (non-streaming)
        # response is translated and returned exactly once.
        self._returned_response = False
        super().__init__(completion_stream)

    def _flush_accumulated_tool_calls(self) -> Optional[Dict[str, Any]]:
        """
        Build a final Google GenAI chunk from any tool calls still buffered
        after the underlying stream is exhausted.

        Returns:
            A Google GenAI chunk dict with the buffered function calls, or
            None when nothing is buffered or nothing could be parsed.
            The accumulator is always cleared to prevent memory leaks.
        """
        if not self.accumulated_tool_calls:
            return None
        try:
            parts = []
            for (
                tool_call_index,
                tool_call_data,
            ) in self.accumulated_tool_calls.items():
                try:
                    # For tool calls with no arguments, the accumulated args
                    # will be "", which is not valid JSON. Default to an
                    # empty JSON object in that case.
                    parsed_args = json.loads(tool_call_data["arguments"] or "{}")
                    parts.append(
                        {
                            "functionCall": {
                                "name": tool_call_data["name"]
                                or "undefined_tool_name",
                                "args": parsed_args,
                            }
                        }
                    )
                except json.JSONDecodeError:
                    # This can happen if the stream is abruptly cut off
                    # mid-argument string.
                    verbose_logger.warning(
                        f"Could not parse tool call arguments at end of stream for index {tool_call_index}. "
                        f"Name: {tool_call_data['name']}. "
                        f"Partial args: {tool_call_data['arguments']}"
                    )
            if not parts:
                return None
            return {
                "candidates": [
                    {
                        "content": {"parts": parts, "role": "model"},
                        "finishReason": "STOP",
                        "index": 0,
                        "safetyRatings": [],
                    }
                ]
            }
        finally:
            # Ensure the accumulator is always cleared to prevent memory leaks
            self.accumulated_tool_calls.clear()

    def __next__(self):
        try:
            if not hasattr(self.completion_stream, "__iter__"):
                # Non-streaming fallback: translate the whole response once.
                if self._returned_response:
                    raise StopIteration
                self._returned_response = True
                return GoogleGenAIAdapter().translate_completion_to_generate_content(
                    self.completion_stream
                )
            for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    continue
                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )
                if transformed_chunk:
                    return transformed_chunk
            # Stream exhausted: emit any tool calls still buffered. This keeps
            # the sync path consistent with __anext__, which already flushed.
            final_chunk = self._flush_accumulated_tool_calls()
            if final_chunk is not None:
                return final_chunk
            raise StopIteration
        except StopIteration:
            raise
        except Exception as e:
            # Best-effort: end the stream, but leave a trace instead of
            # swallowing the error silently.
            verbose_logger.debug(f"GoogleGenAIStreamWrapper.__next__ error: {e}")
            raise StopIteration

    async def __anext__(self):
        try:
            if not hasattr(self.completion_stream, "__aiter__"):
                # Non-streaming fallback: translate the whole response once.
                if self._returned_response:
                    raise StopAsyncIteration
                self._returned_response = True
                return GoogleGenAIAdapter().translate_completion_to_generate_content(
                    self.completion_stream
                )
            async for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    continue
                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )
                if transformed_chunk:
                    return transformed_chunk
            # After the stream is exhausted, check for any remaining
            # accumulated tool calls.
            final_chunk = self._flush_accumulated_tool_calls()
            if final_chunk is not None:
                return final_chunk
            raise StopAsyncIteration
        except StopAsyncIteration:
            raise
        except Exception as e:
            verbose_logger.debug(f"GoogleGenAIStreamWrapper.__anext__ error: {e}")
            raise StopAsyncIteration

    def google_genai_sse_wrapper(self) -> Iterator[bytes]:
        """
        Convert Google GenAI streaming chunks to Server-Sent Events format.
        """
        for chunk in self.completion_stream:
            if isinstance(chunk, dict):
                payload = f"data: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            elif isinstance(chunk, ModelResponseStream):
                # Transform OpenAI streaming chunk to Google GenAI format.
                # Mirrors the async SSE wrapper; previously raw
                # ModelResponseStream objects leaked through untransformed.
                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )
                if isinstance(transformed_chunk, dict):  # Only yield non-empty chunks
                    payload = f"data: {json.dumps(transformed_chunk)}\n\n"
                    yield payload.encode()
            else:
                yield chunk

    async def async_google_genai_sse_wrapper(self) -> AsyncIterator[bytes]:
        """
        Async version of google_genai_sse_wrapper.
        """
        async for chunk in self.completion_stream:
            if isinstance(chunk, dict):
                payload = f"data: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            elif isinstance(chunk, ModelResponseStream):
                # Transform OpenAI streaming chunk to Google GenAI format
                transformed_chunk = GoogleGenAIAdapter().translate_streaming_completion_to_generate_content(
                    chunk, self
                )
                if isinstance(transformed_chunk, dict):  # Only return non-empty chunks
                    payload = f"data: {json.dumps(transformed_chunk)}\n\n"
                    yield payload.encode()
            else:
                # For other chunk types, yield them directly as bytes
                if hasattr(chunk, "encode"):
                    yield chunk.encode()
                else:
                    yield str(chunk).encode()
class GoogleGenAIAdapter:
    """Adapter for transforming Google GenAI generate_content requests to/from litellm.completion format"""

    def __init__(self) -> None:
        pass

    def translate_generate_content_to_completion(
        self,
        model: str,
        contents: Union[List[Dict[str, Any]], Dict[str, Any]],
        config: Optional[Dict[str, Any]] = None,
        litellm_params: Optional[GenericLiteLLMParams] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Transform generate_content request to litellm completion format

        Args:
            model: The model name
            contents: Generate content contents (can be list or single dict)
            config: Optional config parameters (temperature, maxOutputTokens, ...)
            litellm_params: Optional litellm-specific params (api_base, api_key, ...)
            **kwargs: Additional parameters from the original request; reads
                systemInstruction/system_instruction, tools, toolConfig/tool_config

        Returns:
            Dict in OpenAI chat-completion request format
        """
        # Extract top-level fields from kwargs (accept both camelCase and
        # snake_case spellings).
        system_instruction = kwargs.get("systemInstruction") or kwargs.get(
            "system_instruction"
        )
        tools = kwargs.get("tools")
        tool_config = kwargs.get("toolConfig") or kwargs.get("tool_config")

        # Normalize contents to list format
        if isinstance(contents, dict):
            contents_list = [contents]
        else:
            contents_list = contents

        # Transform contents to OpenAI messages format
        messages = self._transform_contents_to_messages(
            contents_list, system_instruction=system_instruction
        )

        # Create base request as dict (which is compatible with ChatCompletionRequest)
        completion_request: ChatCompletionRequest = {
            "model": model,
            "messages": messages,
        }

        #########################################################
        # Supported OpenAI chat completion params
        # - temperature
        # - max_tokens
        # - top_p
        # - frequency_penalty
        # - presence_penalty
        # - stop
        # - tools
        # - tool_choice
        #########################################################
        # Add config parameters if provided
        if config:
            # Map common Google GenAI config parameters to OpenAI equivalents
            if "temperature" in config:
                completion_request["temperature"] = config["temperature"]
            if "maxOutputTokens" in config:
                completion_request["max_tokens"] = config["maxOutputTokens"]
            if "topP" in config:
                completion_request["top_p"] = config["topP"]
            if "topK" in config:
                # OpenAI doesn't have a direct topK equivalent; intentionally dropped
                pass
            if "stopSequences" in config:
                completion_request["stop"] = config["stopSequences"]

        # Handle tools transformation
        if tools:
            # Tools arrive in Google GenAI format; transform to OpenAI format
            if isinstance(tools, list) and len(tools) > 0:
                openai_tools = self._transform_google_genai_tools_to_openai(tools)
                if openai_tools:
                    completion_request["tools"] = openai_tools

        # Handle tool_config (tool choice)
        if tool_config:
            tool_choice = self._transform_google_genai_tool_config_to_openai(
                tool_config
            )
            if tool_choice:
                completion_request["tool_choice"] = tool_choice

        #########################################################
        # forward any litellm specific params
        #########################################################
        completion_request_dict = dict(completion_request)
        if litellm_params:
            completion_request_dict = self._add_generic_litellm_params_to_request(
                completion_request_dict=completion_request_dict,
                litellm_params=litellm_params,
            )
        return completion_request_dict

    def _add_generic_litellm_params_to_request(
        self,
        completion_request_dict: Dict[str, Any],
        litellm_params: Optional[GenericLiteLLMParams] = None,
    ) -> dict:
        """Add generic litellm params to request. e.g add api_base, api_key, api_version, etc.

        Args:
            completion_request_dict: request dict to update in place
            litellm_params: GenericLiteLLMParams whose non-None fields are copied over

        Returns:
            The updated request dict
        """
        # Only forward fields declared on GenericLiteLLMParams; everything
        # else in the dump is ignored.
        allowed_fields = GenericLiteLLMParams.model_fields.keys()
        if litellm_params:
            litellm_dict = litellm_params.model_dump(exclude_none=True)
            for key, value in litellm_dict.items():
                if key in allowed_fields:
                    completion_request_dict[key] = value
        return completion_request_dict

    def translate_completion_output_params_streaming(
        self,
        completion_stream: Any,
    ) -> Union[AsyncIterator[bytes], None]:
        """Transform streaming completion output to Google GenAI SSE format."""
        google_genai_wrapper = GoogleGenAIStreamWrapper(
            completion_stream=completion_stream
        )
        # Return the SSE-wrapped version for proper event formatting
        return google_genai_wrapper.async_google_genai_sse_wrapper()

    def _transform_google_genai_tools_to_openai(
        self,
        tools: List[Dict[str, Any]],
    ) -> List[ChatCompletionToolParam]:
        """Transform Google GenAI tools (functionDeclarations) to OpenAI tools format."""
        openai_tools: List[Dict[str, Any]] = []
        for tool in tools:
            if "functionDeclarations" in tool:
                for func_decl in tool["functionDeclarations"]:
                    function_chunk: Dict[str, Any] = {
                        "name": func_decl.get("name", ""),
                    }
                    if "description" in func_decl:
                        function_chunk["description"] = func_decl["description"]
                    if "parametersJsonSchema" in func_decl:
                        function_chunk["parameters"] = func_decl[
                            "parametersJsonSchema"
                        ]
                    elif "parameters" in func_decl:
                        # Google GenAI function declarations may carry an
                        # OpenAPI-style `parameters` schema instead of
                        # `parametersJsonSchema`; fall back to it so the tool
                        # schema is not silently dropped.
                        function_chunk["parameters"] = func_decl["parameters"]
                    openai_tools.append(
                        {"type": "function", "function": function_chunk}
                    )
        # normalize the tool schemas
        normalized_tools = [normalize_tool_schema(tool) for tool in openai_tools]
        return cast(List[ChatCompletionToolParam], normalized_tools)

    def _transform_google_genai_tool_config_to_openai(
        self,
        tool_config: Dict[str, Any],
    ) -> Optional[ChatCompletionToolChoiceValues]:
        """Transform Google GenAI tool_config (functionCallingConfig.mode) to OpenAI tool_choice."""
        function_calling_config = tool_config.get("functionCallingConfig", {})
        mode = function_calling_config.get("mode", "AUTO")
        # AUTO -> auto, ANY -> required, NONE -> none; unknown modes -> auto
        mode_mapping = {"AUTO": "auto", "ANY": "required", "NONE": "none"}
        tool_choice = mode_mapping.get(mode, "auto")
        return cast(ChatCompletionToolChoiceValues, tool_choice)

    def _transform_contents_to_messages(
        self,
        contents: List[Dict[str, Any]],
        system_instruction: Optional[Dict[str, Any]] = None,
    ) -> List[AllMessageValues]:
        """Transform Google GenAI contents to OpenAI messages format.

        Args:
            contents: list of Google GenAI content dicts ({"role", "parts"})
            system_instruction: optional system instruction content dict

        Returns:
            List of OpenAI-format messages (system/user/tool/assistant)
        """
        messages: List[AllMessageValues] = []

        # Handle system instruction (only the first text part is used)
        if system_instruction:
            system_parts = system_instruction.get("parts", [])
            if system_parts and "text" in system_parts[0]:
                messages.append(
                    ChatCompletionSystemMessage(
                        role="system", content=system_parts[0]["text"]
                    )
                )

        for content in contents:
            role = content.get("role", "user")
            parts = content.get("parts", [])

            if role == "user":
                # Handle user messages with potential function responses
                content_parts: List[
                    Union[ChatCompletionTextObject, ChatCompletionImageObject]
                ] = []
                tool_messages: List[ChatCompletionToolMessage] = []

                for part in parts:
                    if isinstance(part, dict):
                        if "text" in part:
                            content_parts.append(
                                cast(
                                    ChatCompletionTextObject,
                                    {"type": "text", "text": part["text"]},
                                )
                            )
                        elif "inline_data" in part or "inlineData" in part:
                            # Handle Base64 image data. Accept both snake_case
                            # (Python SDK) and camelCase (REST wire format).
                            inline_data = (
                                part.get("inline_data")
                                or part.get("inlineData")
                                or {}
                            )
                            mime_type = (
                                inline_data.get("mime_type")
                                or inline_data.get("mimeType")
                                or "image/jpeg"
                            )
                            data = inline_data.get("data", "")
                            content_parts.append(
                                cast(
                                    ChatCompletionImageObject,
                                    {
                                        "type": "image_url",
                                        "image_url": {
                                            "url": f"data:{mime_type};base64,{data}"
                                        },
                                    },
                                )
                            )
                        elif "functionResponse" in part:
                            # Transform function response to tool message
                            func_response = part["functionResponse"]
                            tool_messages.append(
                                ChatCompletionToolMessage(
                                    role="tool",
                                    tool_call_id=f"call_{func_response.get('name', 'unknown')}",
                                    content=json.dumps(
                                        func_response.get("response", {})
                                    ),
                                )
                            )
                    elif isinstance(part, str):
                        content_parts.append(
                            cast(
                                ChatCompletionTextObject,
                                {"type": "text", "text": part},
                            )
                        )

                # Add user message if there's content
                if content_parts:
                    # If only one text part, use simple string format for
                    # backward compatibility
                    if (
                        len(content_parts) == 1
                        and isinstance(content_parts[0], dict)
                        and content_parts[0].get("type") == "text"
                    ):
                        text_part = cast(ChatCompletionTextObject, content_parts[0])
                        messages.append(
                            ChatCompletionUserMessage(
                                role="user", content=text_part["text"]
                            )
                        )
                    else:
                        # Use multimodal format (array of content parts)
                        messages.append(
                            ChatCompletionUserMessage(
                                role="user", content=content_parts
                            )
                        )

                # Add tool messages
                messages.extend(tool_messages)

            elif role == "model":
                # Handle assistant messages with potential function calls
                combined_text = ""
                tool_calls: List[ChatCompletionAssistantToolCall] = []

                for part in parts:
                    if isinstance(part, dict):
                        if "text" in part:
                            combined_text += part["text"]
                        elif "functionCall" in part:
                            # Transform function call to tool call
                            func_call = part["functionCall"]
                            tool_calls.append(
                                ChatCompletionAssistantToolCall(
                                    id=f"call_{func_call.get('name', 'unknown')}",
                                    type="function",
                                    function=ChatCompletionToolCallFunctionChunk(
                                        name=func_call.get("name", ""),
                                        arguments=json.dumps(
                                            func_call.get("args", {})
                                        ),
                                    ),
                                )
                            )
                    elif isinstance(part, str):
                        combined_text += part

                # Create assistant message
                if tool_calls:
                    assistant_message = ChatCompletionAssistantMessage(
                        role="assistant",
                        content=combined_text if combined_text else None,
                        tool_calls=tool_calls,
                    )
                else:
                    assistant_message = ChatCompletionAssistantMessage(
                        role="assistant",
                        content=combined_text if combined_text else None,
                    )
                messages.append(assistant_message)

        return messages

    def translate_completion_to_generate_content(
        self,
        response: ModelResponse,
    ) -> Dict[str, Any]:
        """
        Transform litellm completion response to Google GenAI generate_content format

        Args:
            response: ModelResponse from litellm.completion

        Returns:
            Dict in Google GenAI generate_content response format

        Raises:
            ValueError: when the response has no choices or no message
        """
        # Extract the main response content
        choice = response.choices[0] if response.choices else None
        if not choice:
            raise ValueError("Invalid completion response: no choices found")

        # Handle different choice types (Choices vs StreamingChoices)
        if isinstance(choice, Choices):
            if not choice.message:
                raise ValueError(
                    "Invalid completion response: no message found in choice"
                )
            parts = self._transform_openai_message_to_google_genai_parts(choice.message)
        else:
            # Fallback for generic choice objects.
            # NOTE(review): assumes dict-like `message`/`delta` attributes in
            # this fallback path — confirm against callers.
            message_content = getattr(choice, "message", {}).get(
                "content", ""
            ) or getattr(choice, "delta", {}).get("content", "")
            parts = [{"text": message_content}] if message_content else []

        # Create Google GenAI format response
        generate_content_response: Dict[str, Any] = {
            "candidates": [
                {
                    "content": {"parts": parts, "role": "model"},
                    "finishReason": self._map_finish_reason(
                        getattr(choice, "finish_reason", None)
                    ),
                    "index": 0,
                    "safetyRatings": [],
                }
            ],
            "usageMetadata": (
                self._map_usage(getattr(response, "usage", None))
                if hasattr(response, "usage") and getattr(response, "usage", None)
                else {
                    "promptTokenCount": 0,
                    "candidatesTokenCount": 0,
                    "totalTokenCount": 0,
                }
            ),
        }

        # Add text field for convenience (common in Google GenAI responses)
        text_content = ""
        for part in parts:
            if isinstance(part, dict) and "text" in part:
                text_content += part["text"]
        if text_content:
            generate_content_response["text"] = text_content

        return generate_content_response

    def translate_streaming_completion_to_generate_content(
        self,
        response: Union[ModelResponse, ModelResponseStream],
        wrapper: GoogleGenAIStreamWrapper,
    ) -> Optional[Dict[str, Any]]:
        """
        Transform streaming litellm completion chunk to Google GenAI generate_content format

        Args:
            response: Streaming ModelResponse chunk from litellm.completion
            wrapper: GoogleGenAIStreamWrapper instance (holds tool-call accumulation state)

        Returns:
            Dict in Google GenAI streaming generate_content response format,
            or None for chunks that carry neither parts nor a finish reason
        """
        # Extract the main response content from streaming chunk
        choice = response.choices[0] if response.choices else None
        if not choice:
            # Return empty chunk if no choices
            return None

        # Handle streaming choice
        if isinstance(choice, StreamingChoices):
            if choice.delta:
                parts = self._transform_openai_delta_to_google_genai_parts_with_accumulation(
                    choice.delta, wrapper
                )
            else:
                parts = []
            finish_reason = getattr(choice, "finish_reason", None)
        else:
            # Fallback for generic choice objects
            message_content = getattr(choice, "delta", {}).get("content", "")
            parts = [{"text": message_content}] if message_content else []
            finish_reason = getattr(choice, "finish_reason", None)

        # Only create response chunk if we have parts or it's the final chunk
        if not parts and not finish_reason:
            return None

        # Create Google GenAI streaming format response
        streaming_chunk: Dict[str, Any] = {
            "candidates": [
                {
                    "content": {"parts": parts, "role": "model"},
                    "finishReason": (
                        self._map_finish_reason(finish_reason)
                        if finish_reason
                        else None
                    ),
                    "index": 0,
                    "safetyRatings": [],
                }
            ]
        }

        # Add usage metadata only in the final chunk (when finish_reason is present)
        if finish_reason:
            usage_metadata = (
                self._map_usage(getattr(response, "usage", None))
                if hasattr(response, "usage") and getattr(response, "usage", None)
                else {
                    "promptTokenCount": 0,
                    "candidatesTokenCount": 0,
                    "totalTokenCount": 0,
                }
            )
            streaming_chunk["usageMetadata"] = usage_metadata

        # Add text field for convenience (common in Google GenAI responses)
        text_content = ""
        for part in parts:
            if isinstance(part, dict) and "text" in part:
                text_content += part["text"]
        if text_content:
            streaming_chunk["text"] = text_content

        return streaming_chunk

    def _transform_openai_message_to_google_genai_parts(
        self,
        message: Any,
    ) -> List[Dict[str, Any]]:
        """Transform an OpenAI assistant message to Google GenAI parts format."""
        parts: List[Dict[str, Any]] = []

        # Add text content if present
        if hasattr(message, "content") and message.content:
            parts.append({"text": message.content})

        # Add tool calls if present
        if hasattr(message, "tool_calls") and message.tool_calls:
            for tool_call in message.tool_calls:
                if hasattr(tool_call, "function") and tool_call.function:
                    try:
                        args = (
                            json.loads(tool_call.function.arguments)
                            if tool_call.function.arguments
                            else {}
                        )
                    except json.JSONDecodeError:
                        # Unparseable arguments degrade to an empty args dict
                        args = {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": tool_call.function.name
                                or "undefined_tool_name",
                                "args": args,
                            }
                        }
                    )

        # Google GenAI candidates always carry at least one part
        return parts if parts else [{"text": ""}]

    def _transform_openai_delta_to_google_genai_parts_with_accumulation(
        self, delta: Any, wrapper: GoogleGenAIStreamWrapper
    ) -> List[Dict[str, Any]]:
        """Transforms OpenAI delta to Google GenAI parts, accumulating streaming tool calls.

        Tool-call name/argument fragments are buffered on the wrapper (keyed by
        the delta's tool call ``index``) until the arguments form valid JSON
        and a name has arrived, at which point a functionCall part is emitted.
        """
        # 1. Initialize wrapper state if it doesn't exist
        if not hasattr(wrapper, "accumulated_tool_calls"):
            wrapper.accumulated_tool_calls = {}

        parts: List[Dict[str, Any]] = []
        if hasattr(delta, "content") and delta.content:
            parts.append({"text": delta.content})

        # 2. Ensure tool_calls is iterable even when the delta lacks the
        #    attribute entirely (previously this raised AttributeError).
        tool_calls = getattr(delta, "tool_calls", None) or []
        for tool_call in tool_calls:
            if not hasattr(tool_call, "function"):
                continue

            # 3. Use `index` as the primary key for accumulation
            tool_call_index = getattr(tool_call, "index", None)
            if tool_call_index is None:
                continue  # Index is essential for tracking streaming tool calls

            # Initialize accumulator for this index if it's new
            if tool_call_index not in wrapper.accumulated_tool_calls:
                wrapper.accumulated_tool_calls[tool_call_index] = {
                    "name": "",
                    "arguments": "",
                }

            # Accumulate name and arguments
            function_name = getattr(tool_call.function, "name", None)
            args_chunk = getattr(tool_call.function, "arguments", None)

            # Optimization: Skip chunks that have no new data
            if not function_name and not args_chunk:
                verbose_logger.debug(
                    f"Skipping empty tool call chunk for index: {tool_call_index}"
                )
                continue

            if function_name:
                wrapper.accumulated_tool_calls[tool_call_index]["name"] = function_name
            if args_chunk:
                wrapper.accumulated_tool_calls[tool_call_index][
                    "arguments"
                ] += args_chunk

            # Attempt to parse and emit a complete tool call
            accumulated_data = wrapper.accumulated_tool_calls[tool_call_index]
            accumulated_name = accumulated_data["name"]
            accumulated_args = accumulated_data["arguments"]

            # 5. Attempt to parse arguments even if name hasn't arrived.
            try:
                parsed_args = json.loads(accumulated_args)
                # If parsing succeeds but we don't have a name yet, wait: the
                # part will be created by a later chunk that brings the name.
                if accumulated_name:
                    parts.append(
                        {
                            "functionCall": {
                                "name": accumulated_name,
                                "args": parsed_args,
                            }
                        }
                    )
                    # Remove the completed tool call from the accumulator
                    del wrapper.accumulated_tool_calls[tool_call_index]
            except json.JSONDecodeError:
                # The JSON for arguments is still incomplete; keep
                # accumulating and wait for more chunks.
                pass

        return parts

    def _map_finish_reason(self, finish_reason: Optional[str]) -> str:
        """Map OpenAI finish reasons to Google GenAI finish reasons (default STOP)."""
        if not finish_reason:
            return "STOP"
        mapping = {
            "stop": "STOP",
            "length": "MAX_TOKENS",
            "content_filter": "SAFETY",
            "tool_calls": "STOP",
            "function_call": "STOP",
        }
        return mapping.get(finish_reason, "STOP")

    def _map_usage(self, usage: Any) -> Dict[str, int]:
        """Map OpenAI usage to Google GenAI usageMetadata format (missing fields -> 0)."""
        return {
            "promptTokenCount": getattr(usage, "prompt_tokens", 0) or 0,
            "candidatesTokenCount": getattr(usage, "completion_tokens", 0) or 0,
            "totalTokenCount": getattr(usage, "total_tokens", 0) or 0,
        }