chore: initial public snapshot for github upload
@@ -0,0 +1,140 @@
"""
Handler for transforming responses api requests to litellm.completion requests
"""

from typing import Any, Coroutine, Dict, Optional, Union

import litellm
from litellm.responses.litellm_completion_transformation.streaming_iterator import (
    LiteLLMCompletionStreamingIterator,
)
from litellm.responses.litellm_completion_transformation.transformation import (
    LiteLLMCompletionResponsesConfig,
)
from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterator
from litellm.types.llms.openai import (
    ResponseInputParam,
    ResponsesAPIOptionalRequestParams,
    ResponsesAPIResponse,
)
from litellm.types.utils import ModelResponse


class LiteLLMCompletionTransformationHandler:
    def response_api_handler(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        responses_api_request: ResponsesAPIOptionalRequestParams,
        custom_llm_provider: Optional[str] = None,
        _is_async: bool = False,
        stream: Optional[bool] = None,
        extra_headers: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[
        ResponsesAPIResponse,
        BaseResponsesAPIStreamingIterator,
        Coroutine[
            Any, Any, Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]
        ],
    ]:
        litellm_completion_request: dict = LiteLLMCompletionResponsesConfig.transform_responses_api_request_to_chat_completion_request(
            model=model,
            input=input,
            responses_api_request=responses_api_request,
            custom_llm_provider=custom_llm_provider,
            stream=stream,
            extra_headers=extra_headers,
            **kwargs,
        )

        if _is_async:
            return self.async_response_api_handler(
                litellm_completion_request=litellm_completion_request,
                request_input=input,
                responses_api_request=responses_api_request,
                **kwargs,
            )

        completion_args = {}
        completion_args.update(kwargs)
        completion_args.update(litellm_completion_request)
        # Pass the merged args (mirrors the async path below); unpacking
        # litellm_completion_request and kwargs separately would raise on
        # duplicate keyword arguments.
        litellm_completion_response: Union[
            ModelResponse, litellm.CustomStreamWrapper
        ] = litellm.completion(
            **completion_args,
        )

        if isinstance(litellm_completion_response, ModelResponse):
            responses_api_response: ResponsesAPIResponse = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
                chat_completion_response=litellm_completion_response,
                request_input=input,
                responses_api_request=responses_api_request,
            )

            return responses_api_response

        elif isinstance(litellm_completion_response, litellm.CustomStreamWrapper):
            return LiteLLMCompletionStreamingIterator(
                model=model,
                litellm_custom_stream_wrapper=litellm_completion_response,
                request_input=input,
                responses_api_request=responses_api_request,
                custom_llm_provider=custom_llm_provider,
                litellm_metadata=kwargs.get("litellm_metadata", {}),
            )
        raise ValueError(
            f"Unexpected response type: {type(litellm_completion_response)}"
        )
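
    # Illustrative usage sketch (not part of the original file; model name and
    # parameter values are hypothetical): callers invoke this handler to serve a
    # Responses API request over a chat-completions-only model, e.g.
    #
    #   handler = LiteLLMCompletionTransformationHandler()
    #   result = handler.response_api_handler(
    #       model="gpt-4o-mini",
    #       input="What is the capital of France?",
    #       responses_api_request={"temperature": 0.2},
    #       _is_async=False,
    #   )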

    async def async_response_api_handler(
        self,
        litellm_completion_request: dict,
        request_input: Union[str, ResponseInputParam],
        responses_api_request: ResponsesAPIOptionalRequestParams,
        **kwargs,
    ) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
        previous_response_id: Optional[str] = responses_api_request.get(
            "previous_response_id"
        )
        if previous_response_id:
            litellm_completion_request = await LiteLLMCompletionResponsesConfig.async_responses_api_session_handler(
                previous_response_id=previous_response_id,
                litellm_completion_request=litellm_completion_request,
            )

        acompletion_args = {}
        acompletion_args.update(kwargs)
        acompletion_args.update(litellm_completion_request)

        litellm_completion_response: Union[
            ModelResponse, litellm.CustomStreamWrapper
        ] = await litellm.acompletion(
            **acompletion_args,
        )

        if isinstance(litellm_completion_response, ModelResponse):
            responses_api_response: ResponsesAPIResponse = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
                chat_completion_response=litellm_completion_response,
                request_input=request_input,
                responses_api_request=responses_api_request,
            )

            return responses_api_response

        elif isinstance(litellm_completion_response, litellm.CustomStreamWrapper):
            return LiteLLMCompletionStreamingIterator(
                model=litellm_completion_request.get("model") or "",
                litellm_custom_stream_wrapper=litellm_completion_response,
                request_input=request_input,
                responses_api_request=responses_api_request,
                custom_llm_provider=litellm_completion_request.get(
                    "custom_llm_provider"
                ),
                litellm_metadata=kwargs.get("litellm_metadata", {}),
            )
        raise ValueError(
            f"Unexpected response type: {type(litellm_completion_response)}"
        )
@@ -0,0 +1,315 @@
import json
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpendLogsPayload
from litellm.proxy.spend_tracking.cold_storage_handler import ColdStorageHandler
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionResponseMessage,
    GenericChatCompletionMessage,
    ResponseInputParam,
)
from litellm.types.utils import ChatCompletionMessageToolCall, Message, ModelResponse

if TYPE_CHECKING:
    from litellm.responses.litellm_completion_transformation.transformation import (
        ChatCompletionSession,
    )
else:
    ChatCompletionSession = Any

########################################################
# Cold Storage Handler
########################################################
COLD_STORAGE_HANDLER = ColdStorageHandler()
########################################################


class ResponsesSessionHandler:
    @staticmethod
    async def get_chat_completion_message_history_for_previous_response_id(
        previous_response_id: str,
    ) -> ChatCompletionSession:
        """
        Return the chat completion message history for a previous response id
        """
        from litellm.responses.litellm_completion_transformation.transformation import (
            ChatCompletionSession,
        )

        verbose_proxy_logger.debug(
            "inside get_chat_completion_message_history_for_previous_response_id"
        )
        all_spend_logs: List[
            SpendLogsPayload
        ] = await ResponsesSessionHandler.get_all_spend_logs_for_previous_response_id(
            previous_response_id
        )
        verbose_proxy_logger.debug(
            "found %s spend logs for this response id", len(all_spend_logs)
        )

        litellm_session_id: Optional[str] = None
        if len(all_spend_logs) > 0:
            litellm_session_id = all_spend_logs[0].get("session_id")

        chat_completion_message_history: List[
            Union[
                AllMessageValues,
                GenericChatCompletionMessage,
                ChatCompletionMessageToolCall,
                ChatCompletionResponseMessage,
                Message,
            ]
        ] = []
        for spend_log in all_spend_logs:
            chat_completion_message_history = await ResponsesSessionHandler.extend_chat_completion_message_with_spend_log_payload(
                spend_log=spend_log,
                chat_completion_message_history=chat_completion_message_history,
            )

        verbose_proxy_logger.debug(
            "chat_completion_message_history %s",
            json.dumps(chat_completion_message_history, indent=4, default=str),
        )
        return ChatCompletionSession(
            messages=chat_completion_message_history,
            litellm_session_id=litellm_session_id,
        )

    @staticmethod
    async def extend_chat_completion_message_with_spend_log_payload(
        spend_log: SpendLogsPayload,
        chat_completion_message_history: List[
            Union[
                AllMessageValues,
                GenericChatCompletionMessage,
                ChatCompletionMessageToolCall,
                ChatCompletionResponseMessage,
                Message,
            ]
        ],
    ):
        """
        Extend the chat completion message history with the spend log payload
        """
        from litellm.responses.litellm_completion_transformation.transformation import (
            LiteLLMCompletionResponsesConfig,
        )

        proxy_server_request_dict = (
            await ResponsesSessionHandler.get_proxy_server_request_from_spend_log(
                spend_log=spend_log,
            )
        )
        response_input_param: Optional[Union[str, ResponseInputParam]] = None
        _messages: Optional[Union[str, ResponseInputParam]] = None

        ############################################################
        # Add Input messages for this Spend Log
        ############################################################
        if proxy_server_request_dict:
            _response_input_param = proxy_server_request_dict.get("input", None)
            _messages = proxy_server_request_dict.get("messages", None)
            if isinstance(_response_input_param, str):
                response_input_param = _response_input_param
            elif isinstance(_response_input_param, dict):
                response_input_param = cast(ResponseInputParam, _response_input_param)

        if response_input_param:
            chat_completion_messages = LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
                input=response_input_param,
                responses_api_request=proxy_server_request_dict or {},
            )
            chat_completion_message_history.extend(chat_completion_messages)

        ############################################################
        # Check if `messages` field is present in the proxy server request dict
        ############################################################
        elif _messages:
            # ensure all messages are /chat/completions/messages
            # certain requests can be stored as Responses API format - this ensures they are transformed to /chat/completions/messages
            chat_completion_messages = LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
                input=_messages,
                responses_api_request=proxy_server_request_dict or {},
            )
            chat_completion_message_history.extend(chat_completion_messages)

        ############################################################
        # Add Output messages for this Spend Log
        ############################################################
        _response_output = spend_log.get("response", "{}")
        if (
            isinstance(_response_output, dict)
            and _response_output
            and _response_output != {}
        ):
            # transform `ChatCompletion Response` to `ResponsesAPIResponse`
            model_response = ModelResponse(**_response_output)
            for choice in model_response.choices:
                if hasattr(choice, "message"):
                    chat_completion_message_history.append(getattr(choice, "message"))
        return chat_completion_message_history
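
    # Illustrative note (not from the original file): for a single spend log this
    # typically appends the request's input messages followed by the assistant
    # output, e.g.
    #   [{"role": "user", "content": "..."}, Message(role="assistant", content="...")]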

    @staticmethod
    async def get_proxy_server_request_from_spend_log(
        spend_log: SpendLogsPayload,
    ) -> Optional[dict]:
        """
        Get the parsed proxy server request from the spend log
        """
        proxy_server_request: Union[str, dict] = (
            spend_log.get("proxy_server_request") or "{}"
        )
        proxy_server_request_dict: Optional[dict] = None
        if isinstance(proxy_server_request, dict):
            proxy_server_request_dict = proxy_server_request
        else:
            proxy_server_request_dict = json.loads(proxy_server_request)

        ############################################################
        # Check if user has setup cold storage for session handling
        ############################################################
        if ResponsesSessionHandler._should_check_cold_storage_for_full_payload(
            proxy_server_request_dict
        ):
            # Try to get cold storage object key from spend log metadata
            _proxy_server_request_dict: Optional[dict] = None
            cold_storage_object_key = (
                ResponsesSessionHandler._get_cold_storage_object_key_from_spend_log(
                    spend_log
                )
            )
            if cold_storage_object_key:
                # Use the object key directly from metadata
                _proxy_server_request_dict = await ResponsesSessionHandler.get_proxy_server_request_from_cold_storage_with_object_key(
                    object_key=cold_storage_object_key,
                )
            if _proxy_server_request_dict:
                proxy_server_request_dict = _proxy_server_request_dict

        return proxy_server_request_dict

    @staticmethod
    def _get_cold_storage_object_key_from_spend_log(
        spend_log: SpendLogsPayload,
    ) -> Optional[str]:
        """
        Extract the cold storage object key from spend log metadata.

        Args:
            spend_log: The spend log payload containing metadata

        Returns:
            Optional[str]: The cold storage object key if found, None otherwise
        """
        try:
            metadata_str = spend_log.get("metadata", "{}")
            if isinstance(metadata_str, str):
                metadata_dict = json.loads(metadata_str)
                return metadata_dict.get("cold_storage_object_key")
            elif isinstance(metadata_str, dict):
                return metadata_str.get("cold_storage_object_key")
            return None
        except (json.JSONDecodeError, TypeError, AttributeError):
            verbose_proxy_logger.debug(
                "Failed to parse metadata from spend log to extract cold storage object key"
            )
            return None
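
    # Hypothetical example (not from the original file): a spend log whose
    # metadata is '{"cold_storage_object_key": "2024-01-01/abc123.json"}' returns
    # "2024-01-01/abc123.json"; logs without the key (or with unparsable
    # metadata) return None.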

    @staticmethod
    async def get_proxy_server_request_from_cold_storage_with_object_key(
        object_key: str,
    ) -> Optional[dict]:
        """
        Get the proxy server request from cold storage using the object key directly.

        Args:
            object_key: The S3/GCS object key to retrieve

        Returns:
            Optional[dict]: The proxy server request dict or None if not found
        """
        verbose_proxy_logger.debug(
            "inside get_proxy_server_request_from_cold_storage_with_object_key..."
        )

        proxy_server_request_dict = await COLD_STORAGE_HANDLER.get_proxy_server_request_from_cold_storage_with_object_key(
            object_key=object_key,
        )

        return proxy_server_request_dict

    @staticmethod
    def _should_check_cold_storage_for_full_payload(
        proxy_server_request_dict: Optional[dict],
    ) -> bool:
        """
        Only check cold storage when both are true:
        1. `LITELLM_TRUNCATED_PAYLOAD_FIELD` is in the proxy server request dict
        2. `litellm.cold_storage_custom_logger` is not None
        """
        from litellm.constants import LITELLM_TRUNCATED_PAYLOAD_FIELD

        configured_cold_storage_custom_logger = litellm.cold_storage_custom_logger
        if configured_cold_storage_custom_logger is None:
            return False
        if proxy_server_request_dict is None:
            return True
        if len(proxy_server_request_dict) == 0:
            return True
        if LITELLM_TRUNCATED_PAYLOAD_FIELD in str(proxy_server_request_dict):
            return True
        return False

    @staticmethod
    async def get_all_spend_logs_for_previous_response_id(
        previous_response_id: str,
    ) -> List[SpendLogsPayload]:
        """
        Get all spend logs for a previous response id.

        Equivalent SQL:
            SELECT session_id FROM spend_logs WHERE response_id = previous_response_id;
            then SELECT * FROM spend_logs WHERE session_id = that session_id.
        """
        from litellm.proxy.proxy_server import prisma_client

        verbose_proxy_logger.debug("decoding response id=%s", previous_response_id)

        decoded_response_id = (
            ResponsesAPIRequestUtils._decode_responses_api_response_id(
                previous_response_id
            )
        )
        previous_response_id = decoded_response_id.get(
            "response_id", previous_response_id
        )
        if prisma_client is None:
            return []

        query = """
        WITH matching_session AS (
            SELECT session_id
            FROM "LiteLLM_SpendLogs"
            WHERE request_id = $1
        )
        SELECT *
        FROM "LiteLLM_SpendLogs"
        WHERE session_id IN (SELECT session_id FROM matching_session)
        ORDER BY "endTime" ASC;
        """

        spend_logs = await prisma_client.db.query_raw(query, previous_response_id)

        verbose_proxy_logger.debug(
            "Found the following spend logs for previous response id %s: %s",
            previous_response_id,
            json.dumps(spend_logs, indent=4, default=str),
        )

        return spend_logs
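
    # Illustrative usage sketch (not from the original file; the response id is
    # hypothetical): the Responses-to-chat-completion transformation can rebuild a
    # conversation by calling
    #   session = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
    #       previous_response_id="resp_abc123"
    #   )
    # and replaying session.messages as the chat history for the follow-up request.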
File diffs suppressed because they are too large, including llm-gateway-competitors/litellm-wheel-src/litellm/responses/main.py (1778 lines, new file).
@@ -0,0 +1,670 @@
"""Helpers for handling MCP-aware `/chat/completions` requests."""

from typing import (
    Any,
    List,
    Optional,
    Union,
    cast,
)

from litellm.responses.mcp.litellm_proxy_mcp_handler import (
    LiteLLM_Proxy_MCP_Handler,
)
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper


def _add_mcp_metadata_to_response(
    response: Union[ModelResponse, CustomStreamWrapper],
    openai_tools: Optional[List],
    tool_calls: Optional[List] = None,
    tool_results: Optional[List] = None,
) -> None:
    """
    Add MCP metadata to response's provider_specific_fields.

    This function adds MCP-related information to the response so that
    clients can access which tools were available, which were called, and
    what results were returned.

    For ModelResponse: adds to choices[].message.provider_specific_fields
    For CustomStreamWrapper: stores in _hidden_params and automatically adds to
    final chunk's delta.provider_specific_fields via CustomStreamWrapper._add_mcp_metadata_to_final_chunk()
    """
    if isinstance(response, CustomStreamWrapper):
        # For streaming, store MCP metadata in _hidden_params
        # CustomStreamWrapper._add_mcp_metadata_to_final_chunk() will automatically
        # add it to the final chunk's delta.provider_specific_fields
        if not hasattr(response, "_hidden_params"):
            response._hidden_params = {}

        mcp_metadata = {}
        if openai_tools:
            mcp_metadata["mcp_list_tools"] = openai_tools
        if tool_calls:
            mcp_metadata["mcp_tool_calls"] = tool_calls
        if tool_results:
            mcp_metadata["mcp_call_results"] = tool_results

        if mcp_metadata:
            response._hidden_params["mcp_metadata"] = mcp_metadata
        return

    if not isinstance(response, ModelResponse):
        return

    if not hasattr(response, "choices") or not response.choices:
        return

    # Add MCP metadata to all choices' messages
    for choice in response.choices:
        message = getattr(choice, "message", None)
        if message is not None:
            # Get existing provider_specific_fields or create new dict
            provider_fields = getattr(message, "provider_specific_fields", None) or {}

            # Add MCP metadata
            if openai_tools:
                provider_fields["mcp_list_tools"] = openai_tools
            if tool_calls:
                provider_fields["mcp_tool_calls"] = tool_calls
            if tool_results:
                provider_fields["mcp_call_results"] = tool_results

            # Set the provider_specific_fields
            setattr(message, "provider_specific_fields", provider_fields)
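

# Illustrative shape (not from the original file; values are hypothetical) of the
# metadata this attaches to a non-streaming response:
#   message.provider_specific_fields = {
#       "mcp_list_tools": [{"type": "function", "function": {"name": "search_docs"}}],
#       "mcp_tool_calls": [...],      # tool calls the model emitted
#       "mcp_call_results": [...],    # results returned by the MCP server
#   }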


async def acompletion_with_mcp(  # noqa: PLR0915
    model: str,
    messages: List,
    tools: Optional[List] = None,
    **kwargs: Any,
) -> Union[ModelResponse, CustomStreamWrapper]:
    """
    Async completion with MCP integration.

    This function handles MCP tool integration following the same pattern as aresponses_api_with_mcp.
    It's designed to be called from the synchronous completion() function and return a coroutine.

    When MCP tools with server_url="litellm_proxy" are provided, this function will:
    1. Get available tools from the MCP server manager
    2. Transform them to OpenAI format
    3. Call acompletion with the transformed tools
    4. If require_approval="never" and tool calls are returned, automatically execute them
    5. Make a follow-up call with the tool results
    """
    from litellm import acompletion as litellm_acompletion

    # Parse MCP tools and separate from other tools
    (
        mcp_tools_with_litellm_proxy,
        other_tools,
    ) = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)

    if not mcp_tools_with_litellm_proxy:
        # No MCP tools, proceed with regular completion
        return await litellm_acompletion(
            model=model,
            messages=messages,
            tools=tools,
            **kwargs,
        )

    # Extract user_api_key_auth from metadata or kwargs
    user_api_key_auth = kwargs.get("user_api_key_auth") or (
        (kwargs.get("metadata", {}) or {}).get("user_api_key_auth")
    )

    # Extract MCP auth headers before fetching tools (needed for dynamic auth)
    (
        mcp_auth_header,
        mcp_server_auth_headers,
        oauth2_headers,
        raw_headers,
    ) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
        secret_fields=kwargs.get("secret_fields"),
        tools=tools,
    )

    # Process MCP tools (pass auth headers for dynamic auth)
    (
        deduplicated_mcp_tools,
        tool_server_map,
    ) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_without_openai_transform(
        user_api_key_auth=user_api_key_auth,
        mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy,
        litellm_trace_id=kwargs.get("litellm_trace_id"),
        mcp_auth_header=mcp_auth_header,
        mcp_server_auth_headers=mcp_server_auth_headers,
    )

    openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai(
        deduplicated_mcp_tools,
        target_format="chat",
    )

    # Combine with other tools
    all_tools = openai_tools + other_tools if (openai_tools or other_tools) else None

    # Determine if we should auto-execute tools
    should_auto_execute = LiteLLM_Proxy_MCP_Handler._should_auto_execute_tools(
        mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy
    )

    # Prepare call parameters
    # Remove keys that shouldn't be passed to acompletion
    clean_kwargs = {k: v for k, v in kwargs.items() if k not in ["acompletion"]}

    base_call_args = {
        "model": model,
        "messages": messages,
        "tools": all_tools,
        "_skip_mcp_handler": True,  # Prevent recursion
        **clean_kwargs,
    }

    # If not auto-executing, just make the call with transformed tools
    if not should_auto_execute:
        response = await litellm_acompletion(**base_call_args)
        if isinstance(response, (ModelResponse, CustomStreamWrapper)):
            _add_mcp_metadata_to_response(
                response=response,
                openai_tools=openai_tools,
            )
        return response

    # For auto-execute: handle streaming vs non-streaming differently
    stream = kwargs.get("stream", False)
    mock_tool_calls = base_call_args.pop("mock_tool_calls", None)

    if stream:
        # Streaming mode: make initial call with streaming, collect chunks, detect tool calls
        initial_call_args = dict(base_call_args)
        initial_call_args["stream"] = True
        if mock_tool_calls is not None:
            initial_call_args["mock_tool_calls"] = mock_tool_calls

        # Make initial streaming call
        initial_stream = await litellm_acompletion(**initial_call_args)

        if not isinstance(initial_stream, CustomStreamWrapper):
            # Not a stream, return as-is
            if isinstance(initial_stream, ModelResponse):
                _add_mcp_metadata_to_response(
                    response=initial_stream,
                    openai_tools=openai_tools,
                )
            return initial_stream

        # Create a custom async generator that collects chunks and handles tool execution
        from litellm.main import stream_chunk_builder
        from litellm.types.utils import ModelResponseStream

        class MCPStreamingIterator:
            """Custom iterator that collects chunks, detects tool calls, and adds MCP metadata to final chunk."""

            def __init__(
                self,
                stream_wrapper,
                messages,
                tool_server_map,
                user_api_key_auth,
                mcp_auth_header,
                mcp_server_auth_headers,
                oauth2_headers,
                raw_headers,
                litellm_call_id,
                litellm_trace_id,
                openai_tools,
                base_call_args,
            ):
                self.stream_wrapper = stream_wrapper
                self.messages = messages
                self.tool_server_map = tool_server_map
                self.user_api_key_auth = user_api_key_auth
                self.mcp_auth_header = mcp_auth_header
                self.mcp_server_auth_headers = mcp_server_auth_headers
                self.oauth2_headers = oauth2_headers
                self.raw_headers = raw_headers
                self.litellm_call_id = litellm_call_id
                self.litellm_trace_id = litellm_trace_id
                self.openai_tools = openai_tools
                self.base_call_args = base_call_args
                self.collected_chunks: List[ModelResponseStream] = []
                self.tool_calls: Optional[List] = None
                self.tool_results: Optional[List] = None
                self.complete_response: Optional[ModelResponse] = None
                self.stream_exhausted = False
                self.tool_execution_done = False
                self.follow_up_stream = None
                self.follow_up_iterator = None
                self.follow_up_exhausted = False

            async def __aiter__(self):
                return self

            def _add_mcp_list_tools_to_chunk(
                self, chunk: ModelResponseStream
            ) -> ModelResponseStream:
                """Add mcp_list_tools to the first chunk."""
                from litellm.types.utils import (
                    StreamingChoices,
                    add_provider_specific_fields,
                )

                if not self.openai_tools:
                    return chunk

                if hasattr(chunk, "choices") and chunk.choices:
                    for choice in chunk.choices:
                        if (
                            isinstance(choice, StreamingChoices)
                            and hasattr(choice, "delta")
                            and choice.delta
                        ):
                            # Get existing provider_specific_fields or create new dict
                            existing_fields = (
                                getattr(choice.delta, "provider_specific_fields", None)
                                or {}
                            )
                            provider_fields = dict(
                                existing_fields
                            )  # Create a copy to avoid mutating the original

                            # Add only mcp_list_tools to first chunk
                            provider_fields["mcp_list_tools"] = self.openai_tools

                            # Use add_provider_specific_fields to ensure proper setting
                            # This function handles Pydantic model attribute setting correctly
                            add_provider_specific_fields(choice.delta, provider_fields)

                return chunk

            def _add_mcp_tool_metadata_to_final_chunk(
                self, chunk: ModelResponseStream
            ) -> ModelResponseStream:
                """Add mcp_tool_calls and mcp_call_results to the final chunk."""
                from litellm.types.utils import (
                    StreamingChoices,
                    add_provider_specific_fields,
                )

                if hasattr(chunk, "choices") and chunk.choices:
                    for choice in chunk.choices:
                        if (
                            isinstance(choice, StreamingChoices)
                            and hasattr(choice, "delta")
                            and choice.delta
                        ):
                            # Get existing provider_specific_fields or create new dict
                            # Access the attribute directly to handle Pydantic model attributes correctly
                            existing_fields = {}
                            if hasattr(choice.delta, "provider_specific_fields"):
                                attr_value = getattr(
                                    choice.delta, "provider_specific_fields", None
                                )
                                if attr_value is not None:
                                    # Create a copy to avoid mutating the original
                                    existing_fields = (
                                        dict(attr_value)
                                        if isinstance(attr_value, dict)
                                        else {}
                                    )

                            provider_fields = existing_fields

                            # Add tool_calls and tool_results if available
                            if self.tool_calls:
                                provider_fields["mcp_tool_calls"] = self.tool_calls
                            if self.tool_results:
                                provider_fields["mcp_call_results"] = self.tool_results

                            # Use add_provider_specific_fields to ensure proper setting
                            # This function handles Pydantic model attribute setting correctly
                            add_provider_specific_fields(choice.delta, provider_fields)

                return chunk

            async def __anext__(self):
                # Phase 1: Collect and yield initial stream chunks
                if not self.stream_exhausted:
                    # Get the iterator from the stream wrapper
                    if not hasattr(self, "_stream_iterator"):
                        self._stream_iterator = self.stream_wrapper.__aiter__()
                        # Add mcp_list_tools to the first chunk (available from the start)
                        _add_mcp_metadata_to_response(
                            response=self.stream_wrapper,
                            openai_tools=self.openai_tools,
                        )

                    try:
                        chunk = await self._stream_iterator.__anext__()
                        self.collected_chunks.append(chunk)

                        # Add mcp_list_tools to the first chunk
                        if len(self.collected_chunks) == 1:
                            chunk = self._add_mcp_list_tools_to_chunk(chunk)

                        # Check if this is the final chunk (has finish_reason)
                        is_final = (
                            hasattr(chunk, "choices")
                            and chunk.choices
                            and hasattr(chunk.choices[0], "finish_reason")
                            and chunk.choices[0].finish_reason is not None
                        )

                        if is_final:
                            # This is the final chunk, mark stream as exhausted
                            self.stream_exhausted = True
                            # Process tool calls after we've collected all chunks
                            await self._process_tool_calls()
                            # Apply MCP metadata (tool_calls and tool_results) to final chunk
                            chunk = self._add_mcp_tool_metadata_to_final_chunk(chunk)
                            # If we have tool results, prepare follow-up call immediately
                            if self.tool_results and self.complete_response:
                                await self._prepare_follow_up_call()

                        return chunk
                    except StopAsyncIteration:
                        self.stream_exhausted = True
                        # Process tool calls after stream is exhausted
                        await self._process_tool_calls()
                        # If we have chunks, yield the final one with metadata
                        if self.collected_chunks:
                            final_chunk = self.collected_chunks[-1]
                            final_chunk = self._add_mcp_tool_metadata_to_final_chunk(
                                final_chunk
                            )
                            # If we have tool results, prepare follow-up call
                            if self.tool_results and self.complete_response:
                                await self._prepare_follow_up_call()
                            return final_chunk

                # Phase 2: Yield follow-up stream chunks if available
                if self.follow_up_stream and not self.follow_up_exhausted:
                    if not self.follow_up_iterator:
                        self.follow_up_iterator = self.follow_up_stream.__aiter__()
                        from litellm._logging import verbose_logger

                        verbose_logger.debug("Follow-up stream iterator created")

                    try:
                        chunk = await self.follow_up_iterator.__anext__()
                        from litellm._logging import verbose_logger

                        verbose_logger.debug(f"Follow-up chunk yielded: {chunk}")
                        return chunk
                    except StopAsyncIteration:
                        self.follow_up_exhausted = True
                        from litellm._logging import verbose_logger

                        verbose_logger.debug("Follow-up stream exhausted")
                        # After follow-up stream is exhausted, check if we need to raise StopAsyncIteration
                        raise StopAsyncIteration

                # If we're here and follow_up_stream is None but we expected it, log a warning
                if (
                    self.stream_exhausted
                    and self.tool_results
                    and self.complete_response
                    and self.follow_up_stream is None
                ):
                    from litellm._logging import verbose_logger

                    verbose_logger.warning(
                        "Follow-up stream was not created despite having tool results"
                    )

                raise StopAsyncIteration

            async def _process_tool_calls(self):
                """Process tool calls after streaming completes."""
                if self.tool_execution_done:
                    return

                self.tool_execution_done = True

                if not self.collected_chunks:
                    return

                # Build complete response from chunks
                complete_response = stream_chunk_builder(
                    chunks=self.collected_chunks,
                    messages=self.messages,
                )

                if isinstance(complete_response, ModelResponse):
                    self.complete_response = complete_response
                    # Extract tool calls from complete response
                    self.tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
                        response=complete_response
                    )

                    if self.tool_calls:
                        # Execute tool calls
                        self.tool_results = (
                            await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
                                tool_server_map=self.tool_server_map,
                                tool_calls=self.tool_calls,
                                user_api_key_auth=self.user_api_key_auth,
                                mcp_auth_header=self.mcp_auth_header,
                                mcp_server_auth_headers=self.mcp_server_auth_headers,
                                oauth2_headers=self.oauth2_headers,
                                raw_headers=self.raw_headers,
                                litellm_call_id=self.litellm_call_id,
                                litellm_trace_id=self.litellm_trace_id,
                            )
                        )

            async def _prepare_follow_up_call(self):
                """Prepare and initiate follow-up call with tool results."""
                if self.follow_up_stream is not None:
                    return  # Already prepared

                if not self.tool_results or not self.complete_response:
                    return

                # Create follow-up messages with tool results
                follow_up_messages = (
                    LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
                        original_messages=self.messages,
                        response=self.complete_response,
                        tool_results=self.tool_results,
                    )
                )

                # Make follow-up call with streaming
                follow_up_call_args = dict(self.base_call_args)
                follow_up_call_args["messages"] = follow_up_messages
                follow_up_call_args["stream"] = True
                # Ensure follow-up call doesn't trigger MCP handler again
                follow_up_call_args["_skip_mcp_handler"] = True

                # Import litellm here to ensure we get the patched version
                # This ensures the patch works correctly in tests
                import litellm

                follow_up_response = await litellm.acompletion(**follow_up_call_args)

                # Ensure follow-up response is a CustomStreamWrapper
                if isinstance(follow_up_response, CustomStreamWrapper):
                    self.follow_up_stream = follow_up_response
                    from litellm._logging import verbose_logger

                    verbose_logger.debug("Follow-up stream created successfully")
                else:
                    # Unexpected response type - log and set to None
                    from litellm._logging import verbose_logger

                    verbose_logger.warning(
                        f"Follow-up response is not a CustomStreamWrapper: {type(follow_up_response)}"
                    )
                    self.follow_up_stream = None

        # Create the custom iterator
        iterator = MCPStreamingIterator(
            stream_wrapper=initial_stream,
            messages=messages,
            tool_server_map=tool_server_map,
            user_api_key_auth=user_api_key_auth,
            mcp_auth_header=mcp_auth_header,
            mcp_server_auth_headers=mcp_server_auth_headers,
            oauth2_headers=oauth2_headers,
            raw_headers=raw_headers,
            litellm_call_id=kwargs.get("litellm_call_id"),
            litellm_trace_id=kwargs.get("litellm_trace_id"),
            openai_tools=openai_tools,
            base_call_args=base_call_args,
        )

        # Create a wrapper class that delegates to our custom iterator
        # We'll use a simple approach: just replace the __aiter__ method
        class MCPStreamWrapper(CustomStreamWrapper):
            def __init__(self, original_wrapper, custom_iterator):
                # Initialize with the same parameters as original wrapper
                super().__init__(
                    completion_stream=None,
                    model=getattr(original_wrapper, "model", "unknown"),
                    logging_obj=getattr(original_wrapper, "logging_obj", None),
                    custom_llm_provider=getattr(
                        original_wrapper, "custom_llm_provider", None
                    ),
                    stream_options=getattr(original_wrapper, "stream_options", None),
                    make_call=getattr(original_wrapper, "make_call", None),
                    _response_headers=getattr(
                        original_wrapper, "_response_headers", None
                    ),
                )
                self._original_wrapper = original_wrapper
                self._custom_iterator = custom_iterator
                # Copy important attributes from original wrapper
                if hasattr(original_wrapper, "_hidden_params"):
                    self._hidden_params = original_wrapper._hidden_params
                # For synchronous iteration, we need to run the async iterator
                self._sync_iterator = None
                self._sync_loop = None

            def __aiter__(self):
                return self._custom_iterator

            def __iter__(self):
                # For synchronous iteration, create a sync wrapper
                if self._sync_iterator is None:
                    import asyncio

                    try:
                        self._sync_loop = asyncio.get_event_loop()
                    except RuntimeError:
                        self._sync_loop = asyncio.new_event_loop()
                        asyncio.set_event_loop(self._sync_loop)
                    self._sync_iterator = _SyncIteratorWrapper(
                        self._custom_iterator, self._sync_loop
                    )
                return self._sync_iterator

            def __next__(self):
                # Delegate to sync iterator
                if self._sync_iterator is None:
                    self.__iter__()
                return next(self._sync_iterator)

            def __getattr__(self, name):
                # Delegate all other attributes to original wrapper
                return getattr(self._original_wrapper, name)

        # Helper class to wrap async iterator for sync iteration
        class _SyncIteratorWrapper:
            def __init__(self, async_iterator, loop):
                self._async_iterator = async_iterator
                self._loop = loop
                self._iterator = None

            def __iter__(self):
                return self

            def __next__(self):
                if self._iterator is None:
                    # __aiter__ might be async, so we need to await it
                    aiter_result = self._async_iterator.__aiter__()
                    if hasattr(aiter_result, "__await__"):
                        # It's a coroutine, await it
                        self._iterator = self._loop.run_until_complete(aiter_result)
                    else:
                        # It's already an iterator
                        self._iterator = aiter_result
                try:
                    return self._loop.run_until_complete(self._iterator.__anext__())
                except StopAsyncIteration:
                    raise StopIteration

        return cast(CustomStreamWrapper, MCPStreamWrapper(initial_stream, iterator))

    # Non-streaming mode: use existing logic
    initial_call_args = dict(base_call_args)
    initial_call_args["stream"] = False
    if mock_tool_calls is not None:
        initial_call_args["mock_tool_calls"] = mock_tool_calls

    # Make initial call
    initial_response = await litellm_acompletion(**initial_call_args)

    if not isinstance(initial_response, ModelResponse):
        return initial_response

    # Extract tool calls from response
    tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
        response=initial_response
    )

    if not tool_calls:
        _add_mcp_metadata_to_response(
            response=initial_response,
            openai_tools=openai_tools,
        )
        return initial_response

    # Execute tool calls
    tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
        tool_server_map=tool_server_map,
        tool_calls=tool_calls,
        user_api_key_auth=user_api_key_auth,
        mcp_auth_header=mcp_auth_header,
        mcp_server_auth_headers=mcp_server_auth_headers,
        oauth2_headers=oauth2_headers,
        raw_headers=raw_headers,
        litellm_call_id=kwargs.get("litellm_call_id"),
        litellm_trace_id=kwargs.get("litellm_trace_id"),
    )

    if not tool_results:
        _add_mcp_metadata_to_response(
            response=initial_response,
            openai_tools=openai_tools,
            tool_calls=tool_calls,
        )
        return initial_response

    # Create follow-up messages with tool results
    follow_up_messages = LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
        original_messages=messages,
        response=initial_response,
        tool_results=tool_results,
    )

    # Make follow-up call with original stream setting
    follow_up_call_args = dict(base_call_args)
    follow_up_call_args["messages"] = follow_up_messages
    follow_up_call_args["stream"] = stream

    response = await litellm_acompletion(**follow_up_call_args)
    if isinstance(response, (ModelResponse, CustomStreamWrapper)):
        _add_mcp_metadata_to_response(
            response=response,
            openai_tools=openai_tools,
            tool_calls=tool_calls,
            tool_results=tool_results,
        )
    return response
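

# Illustrative usage sketch (not from the original file; model name, server label
# and message content are hypothetical; the tool shape follows the MCP tool dicts
# referenced above):
#
#   response = await acompletion_with_mcp(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Search the docs for rate limits"}],
#       tools=[{
#           "type": "mcp",
#           "server_url": "litellm_proxy",
#           "server_label": "docs_mcp",
#           "require_approval": "never",   # enables auto-execution of tool calls
#       }],
#   )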
File diff suppressed because it is too large.
@@ -0,0 +1,798 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast

from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterator
from litellm.types.llms.openai import (
    BaseLiteLLMOpenAIResponseObject,
    MCPCallArgumentsDeltaEvent,
    MCPCallArgumentsDoneEvent,
    MCPCallCompletedEvent,
    MCPCallFailedEvent,
    MCPCallInProgressEvent,
    MCPListToolsCompletedEvent,
    MCPListToolsFailedEvent,
    MCPListToolsInProgressEvent,
    ResponsesAPIResponse,
    ResponsesAPIStreamEvents,
    ResponsesAPIStreamingResponse,
    ToolParam,
)

if TYPE_CHECKING:
    from mcp.types import Tool as MCPTool
else:
    MCPTool = Any


async def create_mcp_list_tools_events(
    mcp_tools_with_litellm_proxy: List[ToolParam],
    user_api_key_auth: Any,
    base_item_id: str,
    pre_processed_mcp_tools: List[Any],
) -> List[ResponsesAPIStreamingResponse]:
    """Create MCP discovery events using pre-processed tools from the parent"""

    events: List[ResponsesAPIStreamingResponse] = []

    try:
        # Extract MCP server names
        mcp_servers = []
        for tool in mcp_tools_with_litellm_proxy:
            if isinstance(tool, dict) and "server_url" in tool:
                server_url = tool.get("server_url")
                if isinstance(server_url, str) and server_url.startswith(
                    "litellm_proxy/mcp/"
                ):
                    server_name = server_url.split("/")[-1]
                    mcp_servers.append(server_name)

        # Emit list tools in progress event
        in_progress_event = MCPListToolsInProgressEvent(
            type=ResponsesAPIStreamEvents.MCP_LIST_TOOLS_IN_PROGRESS,
            sequence_number=1,
            output_index=0,
            item_id=base_item_id,
        )
        events.append(in_progress_event)

        # Use the pre-processed MCP tools that were already fetched, filtered, and deduplicated by the parent
        filtered_mcp_tools = pre_processed_mcp_tools

        # Convert tools to dict format for the event
        mcp_tools_dict = []
        for tool in filtered_mcp_tools:
            if hasattr(tool, "model_dump") and callable(getattr(tool, "model_dump")):
                # Type cast to help mypy understand this is safe after hasattr check
                mcp_tools_dict.append(cast(Any, tool).model_dump())
            elif hasattr(tool, "__dict__"):
                mcp_tools_dict.append(tool.__dict__)
            else:
                mcp_tools_dict.append({"name": getattr(tool, "name", str(tool))})

        # Emit list tools completed event
        completed_event = MCPListToolsCompletedEvent(
            type=ResponsesAPIStreamEvents.MCP_LIST_TOOLS_COMPLETED,
            sequence_number=2,
            output_index=0,
            item_id=base_item_id,
        )
        events.append(completed_event)

        # Add output_item.done event with the actual tools list (matching OpenAI format)
        from litellm.types.llms.openai import OutputItemDoneEvent

        # Extract server label from the first MCP tool config
        server_label = ""
        if mcp_tools_with_litellm_proxy:
            first_tool = mcp_tools_with_litellm_proxy[0]
            if isinstance(first_tool, dict):
                server_label_value = first_tool.get("server_label", "")
                server_label = (
                    str(server_label_value) if server_label_value is not None else ""
                )

        # Format tools for OpenAI output_item.done format
        formatted_tools = []
        for tool in filtered_mcp_tools:
            tool_dict = {
                "name": getattr(tool, "name", "unknown"),
                "description": getattr(tool, "description", ""),
                "annotations": {"read_only": False},
            }

            # Add input_schema if available
            if hasattr(tool, "inputSchema"):
                tool_dict["input_schema"] = getattr(tool, "inputSchema")
            elif hasattr(tool, "input_schema"):
                tool_dict["input_schema"] = getattr(tool, "input_schema")

            formatted_tools.append(tool_dict)

        # Create the output_item.done event with MCP tools list
        output_item_done_event = OutputItemDoneEvent(
            type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
            output_index=0,
            item=BaseLiteLLMOpenAIResponseObject(
                **{
                    "id": base_item_id,
                    "type": "mcp_list_tools",
                    "server_label": server_label,
                    "tools": formatted_tools,
                }
            ),
        )
        events.append(output_item_done_event)

        verbose_logger.debug(f"Created {len(events)} MCP discovery events")

    except Exception as e:
        verbose_logger.error(f"Error creating MCP list tools events: {e}")
        import traceback

        traceback.print_exc()

        # Emit failed event on error
        failed_event = MCPListToolsFailedEvent(
            type=ResponsesAPIStreamEvents.MCP_LIST_TOOLS_FAILED,
            sequence_number=2,
            output_index=0,
            item_id=base_item_id,
        )
        events.append(failed_event)

        # Still emit output_item.done event even on failure (with empty tools list)
        from litellm.types.llms.openai import OutputItemDoneEvent

        output_item_done_event = OutputItemDoneEvent(
            type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
            output_index=0,
            item=BaseLiteLLMOpenAIResponseObject(
                **{
                    "id": base_item_id,
                    "type": "mcp_list_tools",
                    "server_label": "",
                    "tools": [],
                }
            ),
        )
        events.append(output_item_done_event)

    return events


def create_mcp_call_events(
    tool_name: str,
    tool_call_id: str,
    arguments: str,
    result: Optional[str] = None,
    base_item_id: Optional[str] = None,
    sequence_start: int = 1,
) -> List[ResponsesAPIStreamingResponse]:
    """Create MCP call events following OpenAI's specification"""
    events: List[ResponsesAPIStreamingResponse] = []
    item_id = base_item_id or f"mcp_{uuid.uuid4().hex[:8]}"

    # MCP call in progress event
    in_progress_event = MCPCallInProgressEvent(
        type=ResponsesAPIStreamEvents.MCP_CALL_IN_PROGRESS,
        sequence_number=sequence_start,
        output_index=0,
        item_id=item_id,
    )
    events.append(in_progress_event)

    # MCP call arguments delta event (streaming the arguments)
    arguments_delta_event = MCPCallArgumentsDeltaEvent(
        type=ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DELTA,
        output_index=0,
        item_id=item_id,
        delta=arguments,  # JSON string with arguments
        sequence_number=sequence_start + 1,
    )
    events.append(arguments_delta_event)

    # MCP call arguments done event
    arguments_done_event = MCPCallArgumentsDoneEvent(
        type=ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DONE,
        output_index=0,
        item_id=item_id,
        arguments=arguments,  # Complete JSON string with finalized arguments
        sequence_number=sequence_start + 2,
    )
    events.append(arguments_done_event)

    # MCP call completed event (or failed if result indicates failure)
    if result is not None:
        completed_event = MCPCallCompletedEvent(
            type=ResponsesAPIStreamEvents.MCP_CALL_COMPLETED,
            sequence_number=sequence_start + 3,
            item_id=item_id,
            output_index=0,
        )
        events.append(completed_event)

        # Add output_item.done event with the tool call result
        from litellm.types.llms.openai import OutputItemDoneEvent

        output_item_done_event = OutputItemDoneEvent(
            type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
            output_index=0,
            item=BaseLiteLLMOpenAIResponseObject(
                **{
                    "id": item_id,
                    "type": "mcp_call",
                    "approval_request_id": f"mcpr_{uuid.uuid4().hex[:8]}",
                    "arguments": arguments,
                    "error": None,
                    "name": tool_name,
                    "output": result,
                    "server_label": "litellm",
                }
            ),
        )
        events.append(output_item_done_event)
    else:
        failed_event = MCPCallFailedEvent(
            type=ResponsesAPIStreamEvents.MCP_CALL_FAILED,
            sequence_number=sequence_start + 3,
            item_id=item_id,
            output_index=0,
        )
        events.append(failed_event)

    return events
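

# Illustrative note (not from the original file): for a successful call this helper
# yields events in the order
#   mcp_call.in_progress -> mcp_call_arguments.delta -> mcp_call_arguments.done ->
#   mcp_call.completed -> output_item.done (carrying the tool output);
# when `result` is None the final two are replaced by a single mcp_call.failed event.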
|
||||
|
||||
|
||||
class MCPEnhancedStreamingIterator(BaseResponsesAPIStreamingIterator):
|
||||
"""
|
||||
A complete MCP streaming iterator that handles the entire flow:
|
||||
1. Immediately emits MCP discovery events
|
||||
2. Makes the first LLM call and streams its response
|
||||
3. Handles tool execution and follow-up calls for auto-execute tools
|
||||
4. Emits tool execution events in the stream
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_iterator: Any, # Can be None - will be created internally
|
||||
mcp_events: List[ResponsesAPIStreamingResponse],
|
||||
tool_server_map: dict[str, str],
|
||||
mcp_tools_with_litellm_proxy: Optional[List[Any]] = None,
|
||||
user_api_key_auth: Any = None,
|
||||
original_request_params: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
# MCP setup
|
||||
self.mcp_tools_with_litellm_proxy = mcp_tools_with_litellm_proxy or []
|
||||
self.user_api_key_auth = user_api_key_auth
|
||||
self.original_request_params = original_request_params or {}
|
||||
self.should_auto_execute = self._should_auto_execute_tools()
|
||||
|
||||
# Streaming state management
|
||||
self.phase = "initial_response" # initial_response -> mcp_discovery -> tool_execution -> follow_up_response -> finished
|
||||
self.finished = False
|
||||
|
||||
# Event queues and generation flags
|
||||
self.mcp_discovery_events: List[
|
||||
ResponsesAPIStreamingResponse
|
||||
] = mcp_events # Pre-generated MCP discovery events
|
||||
self.tool_execution_events: List[ResponsesAPIStreamingResponse] = []
|
||||
self.mcp_discovery_generated = True # Events are already generated
|
||||
self.mcp_events = (
|
||||
mcp_events # Store the initial MCP events for backward compatibility
|
||||
)
|
||||
self.tool_server_map = tool_server_map
|
||||
|
||||
# Iterator references
|
||||
self.base_iterator: Optional[
|
||||
Union[Any, ResponsesAPIResponse]
|
||||
] = base_iterator # Will be created when needed
|
||||
self.follow_up_iterator: Optional[Any] = None
|
||||
|
||||
# Response collection for tool execution
|
||||
self.collected_response: Optional[ResponsesAPIResponse] = None
|
||||
|
||||
# Set up model metadata (will be updated when we get the real iterator)
|
||||
self.model = self.original_request_params.get("model", "unknown")
|
||||
self.litellm_metadata = {}
|
||||
self.custom_llm_provider = self.original_request_params.get(
|
||||
"custom_llm_provider", None
|
||||
)
|
||||
self.litellm_call_id = self.original_request_params.get("litellm_call_id")
|
||||
self.litellm_trace_id = self.original_request_params.get("litellm_trace_id")
|
||||
|
||||
self._extract_mcp_headers_from_params()
|
||||
|
||||
# Mark as async iterator
|
||||
self.is_async = True
|
||||
|
||||
# Track if we've emitted initial OpenAI lifecycle events
|
||||
self.initial_events_emitted = False
|
||||
|
||||
# Cache the response ID to ensure consistency across all events
|
||||
self._cached_response_id: Optional[str] = None
|
||||
|
||||
def _extract_mcp_headers_from_params(self) -> None:
|
||||
"""Extract MCP headers from original request params to pass to tool calls"""
|
||||
from typing import Dict, Optional
|
||||
from starlette.datastructures import Headers
|
||||
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
|
||||
MCPRequestHandler,
|
||||
)
|
||||
|
||||
# Extract headers from secret_fields in original_request_params
|
||||
raw_headers_from_request: Optional[Dict[str, str]] = None
|
||||
secret_fields = self.original_request_params.get("secret_fields")
|
||||
if secret_fields and isinstance(secret_fields, dict):
|
||||
raw_headers_from_request = secret_fields.get("raw_headers")
|
||||
|
||||
# Extract MCP-specific headers
|
||||
self.mcp_auth_header: Optional[str] = None
|
||||
self.mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None
|
||||
self.oauth2_headers: Optional[Dict[str, str]] = None
|
||||
self.raw_headers: Optional[Dict[str, str]] = raw_headers_from_request
|
||||
|
||||
if raw_headers_from_request:
|
||||
headers_obj = Headers(raw_headers_from_request)
|
||||
self.mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers(
|
||||
headers_obj
|
||||
)
|
||||
self.mcp_server_auth_headers = (
|
||||
MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj)
|
||||
)
|
||||
self.oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(
|
||||
headers_obj
|
||||
)
|
||||
|
||||
# Also check if headers are provided in tools array (from request body)
|
||||
tools = self.original_request_params.get("tools")
|
||||
if tools:
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict) and tool.get("type") == "mcp":
|
||||
tool_headers = tool.get("headers", {})
|
||||
if tool_headers and isinstance(tool_headers, dict):
|
||||
# Merge tool headers into mcp_server_auth_headers
|
||||
headers_obj_from_tool = Headers(tool_headers)
|
||||
tool_mcp_server_auth_headers = (
|
||||
MCPRequestHandler._get_mcp_server_auth_headers_from_headers(
|
||||
headers_obj_from_tool
|
||||
)
|
||||
)
|
||||
|
||||
if tool_mcp_server_auth_headers:
|
||||
if self.mcp_server_auth_headers is None:
|
||||
self.mcp_server_auth_headers = {}
|
||||
# Merge the headers from tool into existing headers
|
||||
for (
|
||||
server_alias,
|
||||
headers_dict,
|
||||
) in tool_mcp_server_auth_headers.items():
|
||||
if server_alias not in self.mcp_server_auth_headers:
|
||||
self.mcp_server_auth_headers[server_alias] = {}
|
||||
self.mcp_server_auth_headers[server_alias].update(
|
||||
headers_dict
|
||||
)
|
||||
|
||||
# Also merge raw headers
|
||||
if self.raw_headers is None:
|
||||
self.raw_headers = {}
|
||||
self.raw_headers.update(tool_headers)
|
||||
|
||||
    def _should_auto_execute_tools(self) -> bool:
        """Check if tools should be auto-executed"""
        from litellm.responses.mcp.litellm_proxy_mcp_handler import (
            LiteLLM_Proxy_MCP_Handler,
        )

        return LiteLLM_Proxy_MCP_Handler._should_auto_execute_tools(
            self.mcp_tools_with_litellm_proxy
        )

    def __aiter__(self):
        return self

    async def __anext__(self) -> ResponsesAPIStreamingResponse:
        """
        Phase-based streaming:
        1. initial_response - Stream the first LLM response (includes response.created, response.in_progress, response.output_item.added)
        2. mcp_discovery - Emit MCP discovery events (after response.output_item.added)
        3. continue_initial_response - Continue streaming the initial response content
        4. tool_execution - Emit tool execution events
        5. follow_up_response - Stream the follow-up response
        6. finished - End iteration
        """

        # Phase 1: Initial Response Stream (emit standard OpenAI events first)
        if self.phase == "initial_response":
            result = await self._handle_initial_response_phase()
            if result is not None:
                return result

        # Phase 2: MCP Discovery Events (after response.output_item.added)
        if self.phase == "mcp_discovery":
            # Emit MCP discovery events
            if self.mcp_discovery_events:
                return self.mcp_discovery_events.pop(0)
            self.phase = "continue_initial_response"
            # Fall through to continue processing the initial response

        # Phase 3: Continue Initial Response (after MCP discovery events)
        if self.phase == "continue_initial_response":
            try:
                return await self._process_base_iterator_chunk()
            except StopAsyncIteration:
                # Initial response ended, move to next phase
                if self.should_auto_execute and self.collected_response:
                    self.phase = "tool_execution"
                    await self._generate_tool_execution_events()
                else:
                    self.phase = "finished"
                    raise

        # Phase 4: Tool Execution Events
        if self.phase == "tool_execution":
            # Emit any queued tool execution events
            if self.tool_execution_events:
                return self.tool_execution_events.pop(0)

            # Move to follow-up response phase
            self.phase = "follow_up_response"
            await self._create_follow_up_iterator()

        # Phase 5: Follow-up Response Stream
        if self.phase == "follow_up_response":
            if self.follow_up_iterator:
                try:
                    return await cast(Any, self.follow_up_iterator).__anext__()  # type: ignore[attr-defined]
                except StopAsyncIteration:
                    self.phase = "finished"
                    raise
            else:
                self.phase = "finished"
                raise StopAsyncIteration

        # Phase 6: Finished
        if self.phase == "finished":
            raise StopAsyncIteration

        # Should not reach here
        raise StopAsyncIteration

    async def _handle_initial_response_phase(
        self,
    ) -> Optional[ResponsesAPIStreamingResponse]:
        """
        Handle Phase 1: Initial Response Stream.

        Returns a chunk to emit, or None to fall through to the next phase.
        Raises StopAsyncIteration when the stream is exhausted with no auto-execution.
        """
        if self.base_iterator is None:
            await self._create_initial_response_iterator()

        if self.base_iterator is None:
            # LLM call failed - still emit MCP discovery events before finishing
            if self.mcp_discovery_events:
                self.phase = "mcp_discovery"
            else:
                self.phase = "finished"
                raise StopAsyncIteration
            return None

        if self.base_iterator:
            if hasattr(self.base_iterator, "__anext__"):
                try:
                    chunk = await cast(Any, self.base_iterator).__anext__()  # type: ignore[attr-defined]

                    # Capture the response ID from the first event to ensure consistency
                    if self._cached_response_id is None and hasattr(chunk, "response"):
                        response_obj = getattr(chunk, "response", None)
                        if response_obj and hasattr(response_obj, "id"):
                            self._cached_response_id = response_obj.id
                            verbose_logger.debug(
                                f"Cached response ID: {self._cached_response_id}"
                            )

                    # After emitting response.output_item.added, transition to MCP discovery
                    if not self.initial_events_emitted and hasattr(chunk, "type"):
                        chunk_type = getattr(chunk, "type", None)
                        if chunk_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED:
                            self.initial_events_emitted = True
                            self.phase = "mcp_discovery"
                            return chunk

                    # If auto-execution is enabled, check for completed responses
                    if self.should_auto_execute and self._is_response_completed(chunk):
                        response_obj = getattr(chunk, "response", None)
                        if isinstance(response_obj, ResponsesAPIResponse):
                            self.collected_response = response_obj
                            self.phase = "tool_execution"
                            await self._generate_tool_execution_events()

                    return chunk
                except StopAsyncIteration:
                    if self.should_auto_execute and self.collected_response:
                        self.phase = "tool_execution"
                        await self._generate_tool_execution_events()
                    else:
                        self.phase = "finished"
                        raise
            else:
                # base_iterator is not async iterable (likely a ResponsesAPIResponse)
                if self.should_auto_execute and isinstance(
                    self.base_iterator, ResponsesAPIResponse
                ):
                    self.collected_response = self.base_iterator
                    self.phase = "tool_execution"
                    await self._generate_tool_execution_events()
                else:
                    self.phase = "finished"
                    raise StopAsyncIteration
        return None

    def _is_response_completed(self, chunk: ResponsesAPIStreamingResponse) -> bool:
        """Check if this chunk indicates the response is completed"""
        from litellm.types.llms.openai import ResponsesAPIStreamEvents

        return (
            getattr(chunk, "type", None) == ResponsesAPIStreamEvents.RESPONSE_COMPLETED
        )

    async def _process_base_iterator_chunk(self) -> ResponsesAPIStreamingResponse:
        """
        Process a chunk from the base iterator with response ID consistency enforcement.
        """
        if not self.base_iterator or not hasattr(self.base_iterator, "__anext__"):
            raise StopAsyncIteration

        chunk = await cast(Any, self.base_iterator).__anext__()  # type: ignore[attr-defined]

        # Ensure response ID consistency - update chunk if needed
        if self._cached_response_id and hasattr(chunk, "response"):
            response_obj = getattr(chunk, "response", None)
            if response_obj and hasattr(response_obj, "id"):
                if response_obj.id != self._cached_response_id:
                    verbose_logger.debug(
                        f"Updating response ID from {response_obj.id} to {self._cached_response_id}"
                    )
                    response_obj.id = self._cached_response_id

        # If auto-execution is enabled, check for completed responses
        if self.should_auto_execute and self._is_response_completed(chunk):
            # Collect the response for tool execution
            response_obj = getattr(chunk, "response", None)
            if isinstance(response_obj, ResponsesAPIResponse):
                self.collected_response = response_obj
                # Move to tool execution phase after emitting this chunk
                self.phase = "tool_execution"
                await self._generate_tool_execution_events()

        return chunk

    async def _create_initial_response_iterator(self) -> None:
        """Create the initial response iterator by making the first LLM call"""
        try:
            # Import the core aresponses function that doesn't have MCP logic
            from litellm.responses.main import aresponses

            # Make the initial response API call - but avoid the MCP wrapper
            params = self.original_request_params.copy()
            params["stream"] = True  # Ensure streaming

            # Use the pre-fetched all_tools from original_request_params (no re-processing needed)
            params_for_llm = {}
            for key, value in params.items():
                # Copy all params as-is since tools are already processed
                params_for_llm[key] = value

            tools_count = (
                len(params_for_llm.get("tools", []))
                if params_for_llm.get("tools")
                else 0
            )
            verbose_logger.debug(f"Making LLM call with {tools_count} tools")
            response = await aresponses(**params_for_llm)

            # Set the base iterator
            if hasattr(response, "__aiter__") or hasattr(response, "__iter__"):
                self.base_iterator = response
                # Copy metadata from the real iterator
                self.model = getattr(response, "model", self.model)
                self.litellm_metadata = getattr(response, "litellm_metadata", {})
                self.custom_llm_provider = getattr(
                    response, "custom_llm_provider", self.custom_llm_provider
                )
                verbose_logger.debug(
                    f"Created base iterator: {type(self.base_iterator)}"
                )
            else:
                # Non-streaming response - this shouldn't happen but handle it
                verbose_logger.warning(f"Got non-streaming response: {type(response)}")
                self.base_iterator = None
                self.phase = "finished"

        except Exception as e:
            verbose_logger.error(f"Error creating initial response iterator: {e}")
            import traceback

            traceback.print_exc()
            self.base_iterator = None
            # Don't set phase to "finished" here; let __anext__ emit any
            # pre-generated MCP discovery events before ending the iteration.

    async def _generate_tool_execution_events(self) -> None:
        """Generate tool execution events and execute tools"""
        if not self.collected_response:
            return
        from litellm.responses.mcp.litellm_proxy_mcp_handler import (
            LiteLLM_Proxy_MCP_Handler,
        )

        try:
            # Extract tool calls from the response
            if self.collected_response is not None:
                tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_response(self.collected_response)  # type: ignore[arg-type]
            else:
                tool_calls = []
            if not tool_calls:
                return

            for tool_call in tool_calls:
                (
                    tool_name,
                    tool_arguments,
                    tool_call_id,
                ) = LiteLLM_Proxy_MCP_Handler._extract_tool_call_details(tool_call)
                if tool_name and tool_call_id:
                    # Create MCP call events for this tool execution
                    call_events = create_mcp_call_events(
                        tool_name=tool_name,
                        tool_call_id=tool_call_id,
                        arguments=tool_arguments or "{}",  # JSON string with arguments
                        result=None,  # Will be set after execution
                        base_item_id=f"mcp_{uuid.uuid4().hex[:8]}",
                        sequence_start=len(self.tool_execution_events) + 1,
                    )
                    # Add the in_progress and arguments events (not the completed event yet)
                    self.tool_execution_events.extend(call_events[:-1])

            # Execute the tools
            tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
                tool_server_map=self.tool_server_map,
                tool_calls=tool_calls,
                user_api_key_auth=self.user_api_key_auth,
                mcp_auth_header=self.mcp_auth_header,
                mcp_server_auth_headers=self.mcp_server_auth_headers,
                oauth2_headers=self.oauth2_headers,
                raw_headers=self.raw_headers,
                litellm_call_id=self.litellm_call_id,
                litellm_trace_id=self.litellm_trace_id,
            )

            # Create completion events and output_item.done events for tool execution
            for tool_result in tool_results:
                tool_call_id = tool_result.get("tool_call_id", "unknown")
                result_text = tool_result.get("result", "")

                # Find matching tool name and arguments
                tool_name = "unknown"
                tool_arguments = "{}"
                for tool_call in tool_calls:
                    (
                        name,
                        args,
                        call_id,
                    ) = LiteLLM_Proxy_MCP_Handler._extract_tool_call_details(tool_call)
                    if call_id == tool_call_id:
                        tool_name = name or "unknown"
                        tool_arguments = args or "{}"
                        break

                item_id = f"mcp_{uuid.uuid4().hex[:8]}"

                # Create the completion event
                completed_event = MCPCallCompletedEvent(
                    type=ResponsesAPIStreamEvents.MCP_CALL_COMPLETED,
                    sequence_number=len(self.tool_execution_events) + 1,
                    item_id=item_id,
                    output_index=0,
                )
                self.tool_execution_events.append(completed_event)

                # Create output_item.done event with the tool call result
                from litellm.types.llms.openai import OutputItemDoneEvent

                output_item_done_event = OutputItemDoneEvent(
                    type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
                    output_index=0,
                    item=BaseLiteLLMOpenAIResponseObject(
                        **{
                            "id": item_id,
                            "type": "mcp_call",
                            "approval_request_id": f"mcpr_{uuid.uuid4().hex[:8]}",
                            "arguments": tool_arguments,
                            "error": None,
                            "name": tool_name,
                            "output": result_text,
                            "server_label": "litellm",  # or extract from tool config
                        }
                    ),
                )
                self.tool_execution_events.append(output_item_done_event)

            # Store tool results for follow-up call
            self.tool_results = tool_results

        except Exception as e:
            verbose_logger.error(f"Error in tool execution: {e}")
            import traceback

            traceback.print_exc()
            self.tool_results = []

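    # Illustrative note (added; not part of the original logic): for each executed
    # tool call the queued stream, in order, is roughly:
    #   1. the in_progress/arguments events produced by create_mcp_call_events()
    #      (everything except its final "completed" event, see call_events[:-1]),
    #   2. an MCP_CALL_COMPLETED event once the tool result is available,
    #   3. an OUTPUT_ITEM_DONE event whose item carries the tool name, arguments,
    #      and output text.
    # __anext__ drains these from self.tool_execution_events during the
    # "tool_execution" phase before starting the follow-up response.
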
    async def _create_follow_up_iterator(self) -> None:
        """Create the follow-up response iterator with tool results"""
        if not self.collected_response or not hasattr(self, "tool_results"):
            return

        from litellm.responses.main import aresponses
        from litellm.responses.mcp.litellm_proxy_mcp_handler import (
            LiteLLM_Proxy_MCP_Handler,
        )

        try:
            # Create follow-up input
            if self.collected_response is not None:
                follow_up_input = LiteLLM_Proxy_MCP_Handler._create_follow_up_input(
                    response=self.collected_response,  # type: ignore[arg-type]
                    tool_results=self.tool_results,
                    original_input=self.original_request_params.get("input"),
                )

                # Make follow-up call with streaming
                follow_up_params = self.original_request_params.copy()
                follow_up_params.update(
                    {
                        "input": follow_up_input,
                        "stream": True,
                    }
                )
            else:
                return
            # Remove tool_choice to avoid forcing more tool calls
            follow_up_params.pop("tool_choice", None)

            follow_up_response = await aresponses(**follow_up_params)

            # Set up the follow-up iterator
            if hasattr(follow_up_response, "__aiter__"):
                self.follow_up_iterator = follow_up_response

        except Exception as e:
            verbose_logger.error(f"Error creating follow-up iterator: {e}")
            import traceback

            traceback.print_exc()
            self.follow_up_iterator = None

    def __iter__(self):
        return self

    def __next__(self) -> ResponsesAPIStreamingResponse:
        # First, emit any queued MCP events
        if self.mcp_events:  # type: ignore[attr-defined]
            return self.mcp_events.pop(0)  # type: ignore[attr-defined]

        # Then delegate to the base iterator
        if not self.is_async:
            try:
                if self.base_iterator and hasattr(self.base_iterator, "__next__"):
                    return next(cast(Any, self.base_iterator))  # type: ignore[arg-type]
                else:
                    raise StopIteration
            except StopIteration:
                self.finished = True
                raise
        else:
            raise RuntimeError("Cannot use sync iteration on async iterator")
@@ -0,0 +1,726 @@
import base64
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Type,
    Union,
    cast,
    get_type_hints,
    overload,
)

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.types.llms.openai import (
    ResponseAPIUsage,
    ResponsesAPIOptionalRequestParams,
    ResponsesAPIResponse,
    ResponseText,
)
from litellm.types.responses.main import DecodedResponseId
from litellm.types.utils import (
    CompletionTokensDetailsWrapper,
    PromptTokensDetailsWrapper,
    SpecialEnums,
    Usage,
)


class ResponsesAPIRequestUtils:
    """Helper utils for constructing ResponseAPI requests"""

    @staticmethod
    def _check_valid_arg(
        supported_params: Optional[List[str]],
        non_default_params: Dict,
        drop_params: Optional[bool],
        custom_llm_provider: Optional[str],
        model: str,
    ):
        if supported_params is None:
            return
        unsupported_params = {}
        for k in non_default_params.keys():
            if k not in supported_params:
                unsupported_params[k] = non_default_params[k]
        if unsupported_params:
            if litellm.drop_params is True or (
                drop_params is not None and drop_params is True
            ):
                pass
            else:
                raise litellm.UnsupportedParamsError(
                    status_code=500,
                    message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
                )

    @staticmethod
    def get_optional_params_responses_api(
        model: str,
        responses_api_provider_config: BaseResponsesAPIConfig,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        allowed_openai_params: Optional[List[str]] = None,
    ) -> Dict:
        """
        Get optional parameters for the responses API.

        Args:
            model: The model name
            responses_api_provider_config: The provider configuration for the responses API
            response_api_optional_params: The requested optional parameters
            allowed_openai_params: Extra OpenAI params to pass through even if the provider does not list them as supported

        Returns:
            A dictionary of supported parameters for the responses API
        """
        from litellm.utils import _apply_openai_param_overrides

        # Get supported parameters for the model
        supported_params = responses_api_provider_config.get_supported_openai_params(
            model
        )

        non_default_params = cast(Dict, response_api_optional_params)
        # Check for unsupported parameters
        ResponsesAPIRequestUtils._check_valid_arg(
            supported_params=supported_params + (allowed_openai_params or []),
            non_default_params=non_default_params,
            drop_params=litellm.drop_params,
            custom_llm_provider=responses_api_provider_config.custom_llm_provider,
            model=model,
        )

        # Map parameters to provider-specific format
        mapped_params = responses_api_provider_config.map_openai_params(
            response_api_optional_params=response_api_optional_params,
            model=model,
            drop_params=litellm.drop_params,
        )

        # add any allowed_openai_params to the mapped_params
        mapped_params = _apply_openai_param_overrides(
            optional_params=mapped_params,
            non_default_params=non_default_params,
            allowed_openai_params=allowed_openai_params or [],
        )

        return mapped_params

    @staticmethod
    def get_requested_response_api_optional_param(
        params: Dict[str, Any],
    ) -> ResponsesAPIOptionalRequestParams:
        """
        Filter parameters to only include those defined in ResponsesAPIOptionalRequestParams.

        Args:
            params: Dictionary of parameters to filter

        Returns:
            ResponsesAPIOptionalRequestParams instance with only the valid parameters
        """
        from litellm.utils import PreProcessNonDefaultParams

        valid_keys = get_type_hints(ResponsesAPIOptionalRequestParams).keys()
        custom_llm_provider = params.pop("custom_llm_provider", None)
        special_params = params.pop("kwargs", {})

        additional_drop_params = params.pop("additional_drop_params", None)
        non_default_params = (
            PreProcessNonDefaultParams.base_pre_process_non_default_params(
                passed_params=params,
                special_params=special_params,
                custom_llm_provider=custom_llm_provider,
                additional_drop_params=additional_drop_params,
                default_param_values={k: None for k in valid_keys},
                additional_endpoint_specific_params=["input"],
            )
        )

        # decode previous_response_id if it's a litellm encoded id
        if "previous_response_id" in non_default_params:
            decoded_previous_response_id = ResponsesAPIRequestUtils.decode_previous_response_id_to_original_previous_response_id(
                non_default_params["previous_response_id"]
            )
            non_default_params["previous_response_id"] = decoded_previous_response_id

        if "metadata" in non_default_params:
            from litellm.utils import add_openai_metadata

            converted_metadata = add_openai_metadata(non_default_params["metadata"])
            if converted_metadata is not None:
                non_default_params["metadata"] = converted_metadata
            else:
                non_default_params.pop("metadata", None)

        return cast(ResponsesAPIOptionalRequestParams, non_default_params)

    # fmt: off
    @overload
    @staticmethod
    def _update_responses_api_response_id_with_model_id(
        responses_api_response: ResponsesAPIResponse,
        custom_llm_provider: Optional[str],
        litellm_metadata: Optional[Dict[str, Any]] = None,
    ) -> ResponsesAPIResponse:
        ...

    @overload
    @staticmethod
    def _update_responses_api_response_id_with_model_id(
        responses_api_response: Dict[str, Any],
        custom_llm_provider: Optional[str],
        litellm_metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        ...

    # fmt: on

    @staticmethod
    def _update_responses_api_response_id_with_model_id(
        responses_api_response: Union[ResponsesAPIResponse, Dict[str, Any]],
        custom_llm_provider: Optional[str],
        litellm_metadata: Optional[Dict[str, Any]] = None,
    ) -> Union[ResponsesAPIResponse, Dict[str, Any]]:
        """Update the responses_api_response_id with model_id and custom_llm_provider.

        Handles both ``ResponsesAPIResponse`` objects and plain dictionaries returned
        by some streaming providers.
        """
        litellm_metadata = litellm_metadata or {}
        model_info: Dict[str, Any] = litellm_metadata.get("model_info", {}) or {}
        model_id = model_info.get("id")

        # access the response id based on the object type
        if isinstance(responses_api_response, dict):
            response_id = responses_api_response.get("id")
        else:
            response_id = getattr(responses_api_response, "id", None)

        # If no response_id, return the response as-is (likely an error response)
        if response_id is None:
            return responses_api_response

        updated_id = ResponsesAPIRequestUtils._build_responses_api_response_id(
            model_id=model_id,
            custom_llm_provider=custom_llm_provider,
            response_id=response_id,
        )

        if isinstance(responses_api_response, dict):
            responses_api_response["id"] = updated_id
        else:
            responses_api_response.id = updated_id

        if litellm_metadata.get("encrypted_content_affinity_enabled"):
            responses_api_response = (
                ResponsesAPIRequestUtils._update_encrypted_content_item_ids_in_response(
                    response=responses_api_response,
                    model_id=model_id,
                )
            )

        return responses_api_response

    @staticmethod
    def _build_encrypted_item_id(model_id: str, item_id: str) -> str:
        """Encode model_id into an output item ID for encrypted-content items.

        Format: ``encitem_{base64("litellm:model_id:{model_id};item_id:{original_id}")}``
        """
        assembled = f"litellm:model_id:{model_id};item_id:{item_id}"
        encoded = base64.b64encode(assembled.encode("utf-8")).decode("utf-8")
        return f"encitem_{encoded}"

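    # Illustrative note (added; values are hypothetical): for model_id "deployment-1"
    # and item_id "msg_123", the assembled plaintext is
    # "litellm:model_id:deployment-1;item_id:msg_123" and the returned ID is
    # "encitem_" followed by the base64 encoding of that string.
    # _decode_encrypted_item_id() below reverses this transformation.
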
    @staticmethod
    def _decode_encrypted_item_id(encoded_id: str) -> Optional[Dict[str, str]]:
        """Decode a litellm-encoded encrypted-content item ID.

        Returns a dict with ``model_id`` and ``item_id`` keys, or ``None`` if
        the string is not a litellm-encoded item ID.
        """
        if not encoded_id.startswith("encitem_"):
            return None
        try:
            cleaned = encoded_id[len("encitem_") :]
            # Restore any padding that may have been stripped in transit
            missing = len(cleaned) % 4
            if missing:
                cleaned += "=" * (4 - missing)
            decoded = base64.b64decode(cleaned.encode("utf-8")).decode("utf-8")
            # Split on first ";" only so that semicolons inside item_id are preserved
            parts = decoded.split(";", 1)
            if len(parts) < 2:
                return None
            model_id = parts[0].replace("litellm:model_id:", "")
            item_id = parts[1].replace("item_id:", "")
            return {"model_id": model_id, "item_id": item_id}
        except Exception:
            return None

    @staticmethod
    def _wrap_encrypted_content_with_model_id(
        encrypted_content: str, model_id: str
    ) -> str:
        """Wrap encrypted_content with model_id metadata for affinity routing.

        When Codex or other clients send items with encrypted_content but no ID,
        we encode the model_id directly into the encrypted_content itself.

        Format: ``litellm_enc:{base64("model_id:{model_id}")};{original_encrypted_content}``
        """
        metadata = f"model_id:{model_id}"
        encoded_metadata = base64.b64encode(metadata.encode("utf-8")).decode("utf-8")
        return f"litellm_enc:{encoded_metadata};{encrypted_content}"

    @staticmethod
    def _unwrap_encrypted_content_with_model_id(
        wrapped_content: str,
    ) -> tuple[Optional[str], str]:
        """Unwrap encrypted_content to extract model_id and original content.

        Returns:
            Tuple of (model_id, original_encrypted_content).
            If not wrapped, returns (None, original_content).
        """
        if not wrapped_content.startswith("litellm_enc:"):
            return None, wrapped_content

        try:
            # Split on first ";" to separate metadata from content
            parts = wrapped_content.split(";", 1)
            if len(parts) < 2:
                return None, wrapped_content

            metadata_b64 = parts[0].replace("litellm_enc:", "")
            original_content = parts[1]

            # Restore padding if needed
            missing = len(metadata_b64) % 4
            if missing:
                metadata_b64 += "=" * (4 - missing)

            decoded_metadata = base64.b64decode(metadata_b64.encode("utf-8")).decode(
                "utf-8"
            )
            model_id = decoded_metadata.replace("model_id:", "")
            return model_id, original_content
        except Exception:
            return None, wrapped_content

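    # Illustrative round trip (added; values are hypothetical):
    #   wrapped = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id(
    #       "ciphertext-blob", "deployment-1"
    #   )
    #   # wrapped == "litellm_enc:<base64 of 'model_id:deployment-1'>;ciphertext-blob"
    #   model_id, original = (
    #       ResponsesAPIRequestUtils._unwrap_encrypted_content_with_model_id(wrapped)
    #   )
    #   # model_id == "deployment-1", original == "ciphertext-blob"
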
    @staticmethod
    def _update_encrypted_content_item_ids_in_response(
        response: Union["ResponsesAPIResponse", Dict[str, Any]],
        model_id: Optional[str],
    ) -> Union["ResponsesAPIResponse", Dict[str, Any]]:
        """Rewrite item IDs for output items that contain ``encrypted_content``.

        Encodes ``model_id`` into the item ID so that follow-up requests can be
        routed back to the originating deployment without any cache lookup.

        For items without an ID (e.g., from Codex), encodes model_id directly
        into the encrypted_content itself.
        """
        if not model_id:
            return response

        output: Optional[list] = None
        if isinstance(response, dict):
            output = response.get("output")
        else:
            output = getattr(response, "output", None)

        if not isinstance(output, list):
            return response

        for item in output:
            if isinstance(item, dict):
                item_id = item.get("id")
                encrypted_content = item.get("encrypted_content")

                if encrypted_content and isinstance(encrypted_content, str):
                    # Always wrap encrypted_content with model_id for redundancy
                    item["encrypted_content"] = (
                        ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id(
                            encrypted_content, model_id
                        )
                    )
                    # Also encode the ID if present
                    if item_id and isinstance(item_id, str):
                        item["id"] = ResponsesAPIRequestUtils._build_encrypted_item_id(
                            model_id, item_id
                        )
            else:
                item_id = getattr(item, "id", None)
                encrypted_content = getattr(item, "encrypted_content", None)

                if encrypted_content and isinstance(encrypted_content, str):
                    # Always wrap encrypted_content with model_id for redundancy
                    try:
                        item.encrypted_content = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id(
                            encrypted_content, model_id
                        )
                    except AttributeError:
                        pass
                    # Also encode the ID if present
                    if item_id and isinstance(item_id, str):
                        try:
                            item.id = ResponsesAPIRequestUtils._build_encrypted_item_id(
                                model_id, item_id
                            )
                        except AttributeError:
                            pass

        return response

    @staticmethod
    def _restore_encrypted_content_item_ids_in_input(request_input: Any) -> Any:
        """Decode litellm-encoded item IDs in request input back to original IDs.

        Called before forwarding the request to the upstream provider so the
        provider receives the original item IDs and unwrapped encrypted_content.

        Handles both:
        1. Items with encoded IDs (encitem_...)
        2. Items with wrapped encrypted_content (litellm_enc:...)
        """
        if not isinstance(request_input, list):
            return request_input

        for item in request_input:
            if isinstance(item, dict):
                item_id = item.get("id")
                if item_id and isinstance(item_id, str):
                    decoded = ResponsesAPIRequestUtils._decode_encrypted_item_id(
                        item_id
                    )
                    if decoded:
                        item["id"] = decoded["item_id"]

                encrypted_content = item.get("encrypted_content")
                if encrypted_content and isinstance(encrypted_content, str):
                    (
                        _,
                        unwrapped,
                    ) = ResponsesAPIRequestUtils._unwrap_encrypted_content_with_model_id(
                        encrypted_content
                    )
                    if unwrapped != encrypted_content:
                        item["encrypted_content"] = unwrapped

        return request_input

    @staticmethod
    def _build_responses_api_response_id(
        custom_llm_provider: Optional[str],
        model_id: Optional[str],
        response_id: str,
    ) -> str:
        """Build the responses_api_response_id"""
        assembled_id: str = str(
            SpecialEnums.LITELLM_MANAGED_RESPONSE_COMPLETE_STR.value
        ).format(custom_llm_provider, model_id, response_id)
        base64_encoded_id: str = base64.b64encode(assembled_id.encode("utf-8")).decode(
            "utf-8"
        )
        return f"resp_{base64_encoded_id}"

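    # Illustrative note (added; assumes LITELLM_MANAGED_RESPONSE_COMPLETE_STR matches
    # the template documented in _decode_responses_api_response_id below): for
    # provider "openai", model_id "deployment-1", and upstream id "resp_abc", the
    # assembled plaintext is roughly
    #   "litellm:custom_llm_provider:openai;model_id:deployment-1;response_id:resp_abc"
    # and the value returned to clients is "resp_" plus the base64 encoding of it.
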
    @staticmethod
    def _decode_responses_api_response_id(
        response_id: str,
    ) -> DecodedResponseId:
        """
        Decode the responses_api_response_id

        Returns:
            DecodedResponseId: Typed dict with custom_llm_provider, model_id, and response_id
        """
        try:
            # Remove prefix and decode
            cleaned_id = response_id.replace("resp_", "")
            decoded_id = base64.b64decode(cleaned_id.encode("utf-8")).decode("utf-8")

            # Parse components using known prefixes
            if ";" not in decoded_id:
                return DecodedResponseId(
                    custom_llm_provider=None,
                    model_id=None,
                    response_id=response_id,
                )

            parts = decoded_id.split(";")

            # Format: litellm:custom_llm_provider:{};model_id:{};response_id:{}
            custom_llm_provider = None
            model_id = None

            if (
                len(parts) >= 3
            ):  # Full format with custom_llm_provider, model_id, and response_id
                custom_llm_provider_part = parts[0]
                model_id_part = parts[1]
                response_part = parts[2]

                custom_llm_provider = custom_llm_provider_part.replace(
                    "litellm:custom_llm_provider:", ""
                )
                model_id = model_id_part.replace("model_id:", "")
                decoded_response_id = response_part.replace("response_id:", "")
            else:
                decoded_response_id = response_id

            return DecodedResponseId(
                custom_llm_provider=custom_llm_provider,
                model_id=model_id,
                response_id=decoded_response_id,
            )
        except Exception as e:
            verbose_logger.debug(f"Error decoding response_id '{response_id}': {e}")
            return DecodedResponseId(
                custom_llm_provider=None,
                model_id=None,
                response_id=response_id,
            )

    @staticmethod
    def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
        """Get the model_id from the response_id"""
        if response_id is None:
            return None
        decoded_response_id = (
            ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
        )
        return decoded_response_id.get("model_id") or None

    @staticmethod
    def decode_previous_response_id_to_original_previous_response_id(
        previous_response_id: str,
    ) -> str:
        """
        Decode the previous_response_id to the original previous_response_id

        Why?
        - LiteLLM encodes the `custom_llm_provider` and `model_id` into the `previous_response_id`; this helps maintain session consistency when load balancing multiple deployments of the same model.
        - We cannot send the litellm-encoded b64 id to the upstream llm api, hence we decode it back to the original `previous_response_id`.

        Args:
            previous_response_id: The previous_response_id to decode

        Returns:
            The original previous_response_id
        """
        decoded_response_id = (
            ResponsesAPIRequestUtils._decode_responses_api_response_id(
                previous_response_id
            )
        )
        return decoded_response_id.get("response_id", previous_response_id)

    @staticmethod
    def convert_text_format_to_text_param(
        text_format: Optional[Union[Type["BaseModel"], dict]],
        text: Optional["ResponseText"] = None,
    ) -> Optional["ResponseText"]:
        """
        Convert text_format parameter to text parameter for the responses API.

        Args:
            text_format: Pydantic model class or dict to convert to response format
            text: Existing text parameter (if provided, text_format is ignored)

        Returns:
            ResponseText object with the converted format, or None if conversion fails
        """
        if text_format is not None and text is None:
            from litellm.llms.base_llm.base_utils import type_to_response_format_param

            # Convert Pydantic model to response format
            response_format = type_to_response_format_param(text_format)
            if response_format is not None:
                # Create ResponseText object with the format
                # The responses API expects the format to have name at the top level
                text = {
                    "format": {
                        "type": response_format["type"],
                        "name": response_format["json_schema"]["name"],
                        "schema": response_format["json_schema"]["schema"],
                        "strict": response_format["json_schema"]["strict"],
                    }
                }
                return text
        return text

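    # Illustrative note (added; the exact output of type_to_response_format_param is
    # an assumption here): for a Pydantic model such as
    #   class Step(BaseModel):
    #       title: str
    # the chat-completions style response_format is typically
    #   {"type": "json_schema", "json_schema": {"name": "Step", "schema": {...}, "strict": True}}
    # which this helper flattens into the responses-API shape
    #   {"format": {"type": "json_schema", "name": "Step", "schema": {...}, "strict": True}}
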
    @staticmethod
    def extract_mcp_headers_from_request(
        secret_fields: Optional[Dict[str, Any]],
        tools: Optional[Iterable[Any]],
    ) -> tuple[
        Optional[str],
        Optional[Dict[str, Dict[str, str]]],
        Optional[Dict[str, str]],
        Optional[Dict[str, str]],
    ]:
        """
        Extract MCP auth headers from the request to pass to the MCP server.
        Headers from tools.headers in the request body should be passed to the MCP server.
        """
        from starlette.datastructures import Headers

        from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
            MCPRequestHandler,
        )

        # Extract headers from secret_fields which contains the original request headers
        raw_headers_from_request: Optional[Dict[str, str]] = None
        if secret_fields and isinstance(secret_fields, dict):
            raw_headers_from_request = secret_fields.get("raw_headers")

        # Extract MCP-specific headers using MCPRequestHandler methods
        mcp_auth_header: Optional[str] = None
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None
        oauth2_headers: Optional[Dict[str, str]] = None

        if raw_headers_from_request:
            headers_obj = Headers(raw_headers_from_request)
            mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers(
                headers_obj
            )
            mcp_server_auth_headers = (
                MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj)
            )
            oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(
                headers_obj
            )

        if tools:
            for tool in tools:
                if isinstance(tool, dict) and tool.get("type") == "mcp":
                    tool_headers = tool.get("headers", {})
                    if tool_headers and isinstance(tool_headers, dict):
                        # Merge tool headers into mcp_server_auth_headers
                        # Extract server-specific headers from tool.headers
                        headers_obj_from_tool = Headers(tool_headers)
                        tool_mcp_server_auth_headers = (
                            MCPRequestHandler._get_mcp_server_auth_headers_from_headers(
                                headers_obj_from_tool
                            )
                        )
                        if tool_mcp_server_auth_headers:
                            if mcp_server_auth_headers is None:
                                mcp_server_auth_headers = {}
                            # Merge the headers from tool into existing headers
                            for (
                                server_alias,
                                headers_dict,
                            ) in tool_mcp_server_auth_headers.items():
                                if server_alias not in mcp_server_auth_headers:
                                    mcp_server_auth_headers[server_alias] = {}
                                mcp_server_auth_headers[server_alias].update(
                                    headers_dict
                                )
                        # Also merge raw headers (non-prefixed headers from tool.headers)
                        if raw_headers_from_request is None:
                            raw_headers_from_request = {}
                        raw_headers_from_request.update(tool_headers)

        return (
            mcp_auth_header,
            mcp_server_auth_headers,
            oauth2_headers,
            raw_headers_from_request,
        )

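    # Illustrative usage (added; the calling context is an assumption): the caller
    # unpacks the 4-tuple in this order:
    #   (
    #       mcp_auth_header,
    #       mcp_server_auth_headers,
    #       oauth2_headers,
    #       raw_headers,
    #   ) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
    #       secret_fields=request_kwargs.get("secret_fields"),
    #       tools=request_kwargs.get("tools"),
    #   )
    # where request_kwargs is a hypothetical dict of the incoming request params.
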

class ResponseAPILoggingUtils:
    @staticmethod
    def _is_response_api_usage(usage: Union[dict, ResponseAPIUsage]) -> bool:
        """Returns True if usage is from the OpenAI Responses API"""
        if isinstance(usage, ResponseAPIUsage):
            return True
        if "input_tokens" in usage and "output_tokens" in usage:
            return True
        return False

    @staticmethod
    def _transform_response_api_usage_to_chat_usage(
        usage_input: Optional[Union[dict, ResponseAPIUsage]],
    ) -> Usage:
        """
        Transforms ResponseAPIUsage or ImageUsage to a Usage object.

        Both have the same spec with input_tokens, output_tokens, and
        input_tokens_details (text_tokens, image_tokens).
        """
        if usage_input is None:
            return Usage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
            )
        response_api_usage: ResponseAPIUsage
        if isinstance(usage_input, dict):
            total_tokens = usage_input.get("total_tokens")
            if total_tokens is None:
                input_tokens = usage_input.get("input_tokens")
                output_tokens = usage_input.get("output_tokens")
                if input_tokens is not None and output_tokens is not None:
                    total_tokens = input_tokens + output_tokens
                    usage_input["total_tokens"] = total_tokens
            response_api_usage = ResponseAPIUsage(**usage_input)
        else:
            response_api_usage = usage_input
        prompt_tokens: int = response_api_usage.input_tokens or 0
        completion_tokens: int = response_api_usage.output_tokens or 0
        prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
        if response_api_usage.input_tokens_details:
            if isinstance(response_api_usage.input_tokens_details, dict):
                prompt_tokens_details = PromptTokensDetailsWrapper(
                    **response_api_usage.input_tokens_details
                )
            else:
                prompt_tokens_details = PromptTokensDetailsWrapper(
                    cached_tokens=getattr(
                        response_api_usage.input_tokens_details, "cached_tokens", None
                    ),
                    audio_tokens=getattr(
                        response_api_usage.input_tokens_details, "audio_tokens", None
                    ),
                    text_tokens=getattr(
                        response_api_usage.input_tokens_details, "text_tokens", None
                    ),
                    image_tokens=getattr(
                        response_api_usage.input_tokens_details, "image_tokens", None
                    ),
                )
        completion_tokens_details: Optional[CompletionTokensDetailsWrapper] = None
        output_tokens_details = getattr(
            response_api_usage, "output_tokens_details", None
        )
        if output_tokens_details:
            completion_tokens_details = CompletionTokensDetailsWrapper(
                reasoning_tokens=getattr(
                    output_tokens_details, "reasoning_tokens", None
                ),
                image_tokens=getattr(output_tokens_details, "image_tokens", None),
                text_tokens=getattr(output_tokens_details, "text_tokens", None),
            )

        chat_usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            prompt_tokens_details=prompt_tokens_details,
            completion_tokens_details=completion_tokens_details,
        )

        # Preserve cost attribute if it exists on ResponseAPIUsage
        if hasattr(response_api_usage, "cost") and response_api_usage.cost is not None:
            setattr(chat_usage, "cost", response_api_usage.cost)

        return chat_usage
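
    # Illustrative note (added; values are hypothetical): a Responses API usage dict
    # such as {"input_tokens": 10, "output_tokens": 5} maps to
    # Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15); cached and
    # reasoning token details are carried over into the prompt/completion details
    # wrappers when the provider reports them.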