chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
"""
Vertex AI Agent Engine (Reasoning Engines) Provider
Supports Vertex AI Reasoning Engines via the :query and :streamQuery endpoints.
"""
from litellm.llms.vertex_ai.agent_engine.transformation import (
VertexAgentEngineConfig,
VertexAgentEngineError,
)
__all__ = ["VertexAgentEngineConfig", "VertexAgentEngineError"]

View File

@@ -0,0 +1,90 @@
"""
SSE Stream Iterator for Vertex AI Agent Engine.
Handles Server-Sent Events (SSE) streaming responses from Vertex AI Reasoning Engines.
"""
from typing import Any, Union
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.openai import ChatCompletionUsageBlock
from litellm.types.utils import (
Delta,
GenericStreamingChunk,
ModelResponseStream,
StreamingChoices,
)
class VertexAgentEngineResponseIterator(BaseModelResponseIterator):
"""
Iterator for Vertex Agent Engine SSE streaming responses.
Uses BaseModelResponseIterator which handles sync/async iteration.
We just need to implement chunk_parser to parse Vertex Agent Engine response format.
"""
def __init__(self, streaming_response: Any, sync_stream: bool) -> None:
super().__init__(streaming_response=streaming_response, sync_stream=sync_stream)
def chunk_parser(
self, chunk: dict
) -> Union[GenericStreamingChunk, ModelResponseStream]:
"""
Parse a Vertex Agent Engine response chunk into ModelResponseStream.
Vertex Agent Engine response format:
{
"content": {
"parts": [{"text": "..."}],
"role": "model"
},
"finish_reason": "STOP",
"usage_metadata": {
"prompt_token_count": 100,
"candidates_token_count": 50,
"total_token_count": 150
}
}
"""
# Extract text from content.parts
text = None
content = chunk.get("content", {})
parts = content.get("parts", [])
for part in parts:
if isinstance(part, dict) and "text" in part:
text = part["text"]
break
# Extract finish_reason
finish_reason = None
raw_finish_reason = chunk.get("finish_reason")
if raw_finish_reason == "STOP":
finish_reason = "stop"
elif raw_finish_reason:
finish_reason = raw_finish_reason.lower()
# Extract usage from usage_metadata
usage = None
usage_metadata = chunk.get("usage_metadata", {})
if usage_metadata:
usage = ChatCompletionUsageBlock(
prompt_tokens=usage_metadata.get("prompt_token_count", 0),
completion_tokens=usage_metadata.get("candidates_token_count", 0),
total_tokens=usage_metadata.get("total_token_count", 0),
)
# Return ModelResponseStream (OpenAI-compatible chunk)
return ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=finish_reason,
index=0,
delta=Delta(
content=text,
role="assistant" if text else None,
),
)
],
usage=usage,
)

View File

@@ -0,0 +1,517 @@
"""
Transformation for Vertex AI Agent Engine (Reasoning Engines)
Handles the transformation between LiteLLM's OpenAI-compatible format and
Vertex AI Reasoning Engine's API format.
API Reference:
- :query endpoint - for session management (create, get, list, delete)
- :streamQuery endpoint - for actual queries (stream_query method)
"""
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import httpx
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.vertex_ai.agent_engine.sse_iterator import (
VertexAgentEngineResponseIterator,
)
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
HTTPHandler = Any
AsyncHTTPHandler = Any
CustomStreamWrapper = Any
class VertexAgentEngineError(BaseLLMException):
"""Exception for Vertex Agent Engine errors."""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
self.message = message
super().__init__(message=message, status_code=status_code)
class VertexAgentEngineConfig(BaseConfig, VertexBase):
"""
Configuration for Vertex AI Agent Engine (Reasoning Engines).
Model format: vertex_ai/agent_engine/<resource_id>
Where resource_id is the numeric ID of the reasoning engine.
"""
def __init__(self, **kwargs):
BaseConfig.__init__(self, **kwargs)
VertexBase.__init__(self)
def get_supported_openai_params(self, model: str) -> List[str]:
"""Vertex Agent Engine has limited OpenAI compatible params."""
return ["user"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""Map OpenAI params to Agent Engine params."""
# Map 'user' to 'user_id' for session management
if "user" in non_default_params:
optional_params["user_id"] = non_default_params["user"]
return optional_params
def _parse_model_string(self, model: str) -> Tuple[str, str]:
"""
Parse model string to extract resource ID.
Model format: agent_engine/<project_number>/<location>/<engine_id>
Or: agent_engine/<engine_id> (uses default project/location)
Returns: (resource_path, engine_id)
"""
# Remove 'agent_engine/' prefix if present
if model.startswith("agent_engine/"):
model = model[len("agent_engine/") :]
# Check if it's a full resource path
if model.startswith("projects/"):
# Full path: projects/123/locations/us-central1/reasoningEngines/456
return model, model.split("/")[-1]
# Just the engine ID
return model, model
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for the request.
For Vertex Agent Engine:
- Non-streaming: :query endpoint (for session management)
- Streaming: :streamQuery endpoint (for actual queries)
"""
resource_path, engine_id = self._parse_model_string(model)
# Get project and location from litellm_params or environment
vertex_project = self.safe_get_vertex_ai_project(litellm_params)
vertex_location = (
self.safe_get_vertex_ai_location(litellm_params) or "us-central1"
)
# Build the full resource path if only engine_id was provided
if not resource_path.startswith("projects/"):
if not vertex_project:
raise ValueError(
"vertex_project is required for Vertex Agent Engine. "
"Set via litellm_params['vertex_project'] or VERTEXAI_PROJECT env var."
)
resource_path = f"projects/{vertex_project}/locations/{vertex_location}/reasoningEngines/{engine_id}"
base_url = get_vertex_base_url(vertex_location)
# Always use :streamQuery endpoint for actual queries
# The :query endpoint only supports session management methods
# (create_session, get_session, list_sessions, delete_session, etc.)
endpoint = f"{base_url}/v1beta1/{resource_path}:streamQuery"
verbose_logger.debug(f"Vertex Agent Engine URL: {endpoint}")
return endpoint
def _get_auth_headers(
self,
optional_params: dict,
litellm_params: dict,
) -> Dict[str, str]:
"""Get authentication headers using Google Cloud credentials."""
vertex_credentials = self.safe_get_vertex_ai_credentials(litellm_params)
vertex_project = self.safe_get_vertex_ai_project(litellm_params)
# Get access token using VertexBase
access_token, project_id = self.get_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
)
verbose_logger.debug(
f"Vertex Agent Engine: Authenticated for project {project_id}"
)
return {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
def _get_user_id(self, optional_params: dict) -> str:
"""Get or generate user ID for session management."""
user_id = optional_params.get("user_id") or optional_params.get("user")
if user_id:
return user_id
# Generate a user ID
return f"litellm-user-{str(uuid.uuid4())[:8]}"
def _get_session_id(self, optional_params: dict) -> Optional[str]:
"""Get session ID if provided."""
return optional_params.get("session_id")
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
Transform the request to Vertex Agent Engine format.
The API expects:
{
"class_method": "stream_query",
"input": {
"message": "...",
"user_id": "...",
"session_id": "..." (optional)
}
}
"""
# Use the last message content as the prompt
prompt = convert_content_list_to_str(messages[-1])
# Get user_id and session_id
user_id = self._get_user_id(optional_params)
session_id = self._get_session_id(optional_params)
# Build the input
input_data: Dict[str, Any] = {
"message": prompt,
"user_id": user_id,
}
if session_id:
input_data["session_id"] = session_id
# Build the request payload
# Note: stream_query is used for both streaming and non-streaming
# The difference is the endpoint (:streamQuery vs :query)
payload = {
"class_method": "stream_query",
"input": input_data,
}
verbose_logger.debug(f"Vertex Agent Engine payload: {payload}")
return payload
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
"""Validate environment and set up authentication headers."""
auth_headers = self._get_auth_headers(optional_params, litellm_params)
headers.update(auth_headers)
return headers
def _extract_text_from_response(self, response_data: dict) -> str:
"""Extract text content from the response."""
# Try to get from content.parts
content = response_data.get("content", {})
parts = content.get("parts", [])
for part in parts:
if "text" in part:
return part["text"]
# Try actions.state_delta
actions = response_data.get("actions", {})
state_delta = actions.get("state_delta", {})
for key, value in state_delta.items():
if isinstance(value, str) and value:
return value
return ""
def _calculate_usage(
self, model: str, messages: List[AllMessageValues], content: str
) -> Optional[Usage]:
"""Calculate token usage using LiteLLM's token counter."""
try:
from litellm.utils import token_counter
prompt_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
completion_tokens = token_counter(
model="gpt-3.5-turbo", text=content, count_response_tokens=True
)
total_tokens = prompt_tokens + completion_tokens
return Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
except Exception as e:
verbose_logger.warning(f"Failed to calculate token usage: {str(e)}")
return None
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
"""
Transform Vertex Agent Engine response to LiteLLM ModelResponse format.
The response is a streaming SSE format even for non-streaming requests.
We need to collect all the chunks and extract the final response.
"""
try:
content_type = raw_response.headers.get("content-type", "").lower()
verbose_logger.debug(
f"Vertex Agent Engine response Content-Type: {content_type}"
)
# Parse the SSE response
response_text = raw_response.text
verbose_logger.debug(f"Response (first 500 chars): {response_text[:500]}")
# Extract content from SSE stream
content = ""
for line in response_text.strip().split("\n"):
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
if isinstance(data, dict):
text = self._extract_text_from_response(data)
if text:
content = text # Use the last non-empty text
except json.JSONDecodeError:
continue
# Create the message
message = Message(content=content, role="assistant")
# Create choices
choice = Choices(finish_reason="stop", index=0, message=message)
# Update model response
model_response.choices = [choice]
model_response.model = model
# Calculate usage
calculated_usage = self._calculate_usage(model, messages, content)
if calculated_usage:
setattr(model_response, "usage", calculated_usage)
return model_response
except Exception as e:
verbose_logger.error(
f"Error processing Vertex Agent Engine response: {str(e)}"
)
raise VertexAgentEngineError(
message=f"Error processing response: {str(e)}",
status_code=raw_response.status_code,
)
def get_streaming_response(
self,
model: str,
raw_response: httpx.Response,
) -> VertexAgentEngineResponseIterator:
"""Return a streaming iterator for SSE responses."""
return VertexAgentEngineResponseIterator(
streaming_response=raw_response.iter_lines(),
sync_stream=True,
)
def get_sync_custom_stream_wrapper(
self,
model: str,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
api_base: str,
headers: dict,
data: dict,
messages: list,
client: Optional[Union[HTTPHandler, "AsyncHTTPHandler"]] = None,
json_mode: Optional[bool] = None,
signed_json_body: Optional[bytes] = None,
) -> "CustomStreamWrapper":
"""Get a CustomStreamWrapper for synchronous streaming."""
from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
_get_httpx_client,
)
from litellm.utils import CustomStreamWrapper
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client(params={})
# Avoid logging sensitive api_base directly
verbose_logger.debug("Making sync streaming request to Vertex AI endpoint.")
# Make streaming request
response = client.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=True,
logging_obj=logging_obj,
)
if response.status_code != 200:
raise VertexAgentEngineError(
status_code=response.status_code, message=str(response.read())
)
# Create iterator for SSE stream
completion_stream = self.get_streaming_response(
model=model, raw_response=response
)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return streaming_response
async def get_async_custom_stream_wrapper(
self,
model: str,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
api_base: str,
headers: dict,
data: dict,
messages: list,
client: Optional["AsyncHTTPHandler"] = None,
json_mode: Optional[bool] = None,
signed_json_body: Optional[bytes] = None,
) -> "CustomStreamWrapper":
"""Get a CustomStreamWrapper for asynchronous streaming."""
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper
if client is None or not isinstance(client, AsyncHTTPHandler):
client = get_async_httpx_client(
llm_provider=cast(Any, "vertex_ai"), params={}
)
# Avoid logging sensitive api_base directly
verbose_logger.debug("Making async streaming request to Vertex AI endpoint.")
# Make async streaming request
response = await client.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=True,
logging_obj=logging_obj,
)
if response.status_code != 200:
raise VertexAgentEngineError(
status_code=response.status_code, message=str(await response.aread())
)
# Create iterator for SSE stream (async)
completion_stream = VertexAgentEngineResponseIterator(
streaming_response=response.aiter_lines(),
sync_stream=False,
)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return streaming_response
@property
def has_custom_stream_wrapper(self) -> bool:
"""Indicates that this config has custom streaming support."""
return True
@property
def supports_stream_param_in_request_body(self) -> bool:
"""Agent Engine does not allow passing `stream` in the request body."""
return False
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return VertexAgentEngineError(status_code=status_code, message=error_message)
def should_fake_stream(
self,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
"""Agent Engine always returns SSE streams, so we use real streaming."""
return False