Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/responses/mcp/chat_completions_handler.py
2026-03-26 16:04:46 +08:00

671 lines
28 KiB
Python

"""Helpers for handling MCP-aware `/chat/completions` requests."""
from typing import (
Any,
List,
Optional,
Union,
cast,
)
from litellm.responses.mcp.litellm_proxy_mcp_handler import (
LiteLLM_Proxy_MCP_Handler,
)
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
def _add_mcp_metadata_to_response(
    response: Union[ModelResponse, CustomStreamWrapper],
    openai_tools: Optional[List],
    tool_calls: Optional[List] = None,
    tool_results: Optional[List] = None,
) -> None:
    """
    Attach MCP metadata to a response's provider_specific_fields.

    Records which MCP tools were available (``mcp_list_tools``), which were
    invoked (``mcp_tool_calls``), and what they returned (``mcp_call_results``)
    so clients can inspect the MCP round-trip.

    For ModelResponse: written onto choices[].message.provider_specific_fields.
    For CustomStreamWrapper: stashed in _hidden_params; the wrapper's
    _add_mcp_metadata_to_final_chunk() later surfaces it on the final chunk's
    delta.provider_specific_fields.
    """
    # The three metadata keys share the same "only set when truthy" rule,
    # so drive both branches off a single (key, value) table.
    metadata_pairs = (
        ("mcp_list_tools", openai_tools),
        ("mcp_tool_calls", tool_calls),
        ("mcp_call_results", tool_results),
    )

    if isinstance(response, CustomStreamWrapper):
        # Streaming: park the metadata in _hidden_params for the wrapper to
        # inject into the final chunk's delta.provider_specific_fields.
        if not hasattr(response, "_hidden_params"):
            response._hidden_params = {}
        mcp_metadata = {key: value for key, value in metadata_pairs if value}
        if mcp_metadata:
            response._hidden_params["mcp_metadata"] = mcp_metadata
        return

    if not isinstance(response, ModelResponse):
        return
    choices = getattr(response, "choices", None)
    if not choices:
        return

    # Non-streaming: annotate every choice's message in place.
    for choice in choices:
        message = getattr(choice, "message", None)
        if message is None:
            continue
        # Reuse any existing provider_specific_fields dict, else start fresh.
        provider_fields = getattr(message, "provider_specific_fields", None) or {}
        for key, value in metadata_pairs:
            if value:
                provider_fields[key] = value
        setattr(message, "provider_specific_fields", provider_fields)
async def acompletion_with_mcp(  # noqa: PLR0915
    model: str,
    messages: List,
    tools: Optional[List] = None,
    **kwargs: Any,
) -> Union[ModelResponse, CustomStreamWrapper]:
    """
    Async completion with MCP integration.

    This function handles MCP tool integration following the same pattern as
    aresponses_api_with_mcp. It's designed to be called from the synchronous
    completion() function and return a coroutine.

    When MCP tools with server_url="litellm_proxy" are provided, this function will:
    1. Get available tools from the MCP server manager
    2. Transform them to OpenAI format
    3. Call acompletion with the transformed tools
    4. If require_approval="never" and tool calls are returned, automatically execute them
    5. Make a follow-up call with the tool results

    Args:
        model: Model name forwarded to acompletion.
        messages: Chat messages forwarded to acompletion.
        tools: Mixed list of MCP and regular OpenAI tools (may be None).
        **kwargs: All remaining acompletion kwargs (stream, metadata, ...).

    Returns:
        A ModelResponse (non-streaming) or CustomStreamWrapper (streaming),
        annotated with MCP metadata in provider_specific_fields/_hidden_params.
    """
    from litellm import acompletion as litellm_acompletion

    # Parse MCP tools and separate from other tools
    (
        mcp_tools_with_litellm_proxy,
        other_tools,
    ) = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)
    if not mcp_tools_with_litellm_proxy:
        # No MCP tools, proceed with regular completion
        return await litellm_acompletion(
            model=model,
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # Extract user_api_key_auth from metadata or kwargs
    user_api_key_auth = kwargs.get("user_api_key_auth") or (
        (kwargs.get("metadata", {}) or {}).get("user_api_key_auth")
    )
    # Extract MCP auth headers before fetching tools (needed for dynamic auth)
    (
        mcp_auth_header,
        mcp_server_auth_headers,
        oauth2_headers,
        raw_headers,
    ) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
        secret_fields=kwargs.get("secret_fields"),
        tools=tools,
    )
    # Process MCP tools (pass auth headers for dynamic auth)
    (
        deduplicated_mcp_tools,
        tool_server_map,
    ) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_without_openai_transform(
        user_api_key_auth=user_api_key_auth,
        mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy,
        litellm_trace_id=kwargs.get("litellm_trace_id"),
        mcp_auth_header=mcp_auth_header,
        mcp_server_auth_headers=mcp_server_auth_headers,
    )
    openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai(
        deduplicated_mcp_tools,
        target_format="chat",
    )
    # Combine with other tools
    all_tools = openai_tools + other_tools if (openai_tools or other_tools) else None
    # Determine if we should auto-execute tools
    should_auto_execute = LiteLLM_Proxy_MCP_Handler._should_auto_execute_tools(
        mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy
    )
    # Prepare call parameters
    # Remove keys that shouldn't be passed to acompletion
    clean_kwargs = {k: v for k, v in kwargs.items() if k not in ["acompletion"]}
    base_call_args = {
        "model": model,
        "messages": messages,
        "tools": all_tools,
        "_skip_mcp_handler": True,  # Prevent recursion
        **clean_kwargs,
    }
    # If not auto-executing, just make the call with transformed tools
    if not should_auto_execute:
        response = await litellm_acompletion(**base_call_args)
        if isinstance(response, (ModelResponse, CustomStreamWrapper)):
            _add_mcp_metadata_to_response(
                response=response,
                openai_tools=openai_tools,
            )
        return response
    # For auto-execute: handle streaming vs non-streaming differently
    stream = kwargs.get("stream", False)
    mock_tool_calls = base_call_args.pop("mock_tool_calls", None)
    if stream:
        # Streaming mode: make initial call with streaming, collect chunks, detect tool calls
        initial_call_args = dict(base_call_args)
        initial_call_args["stream"] = True
        if mock_tool_calls is not None:
            initial_call_args["mock_tool_calls"] = mock_tool_calls
        # Make initial streaming call
        initial_stream = await litellm_acompletion(**initial_call_args)
        if not isinstance(initial_stream, CustomStreamWrapper):
            # Not a stream, return as-is
            if isinstance(initial_stream, ModelResponse):
                _add_mcp_metadata_to_response(
                    response=initial_stream,
                    openai_tools=openai_tools,
                )
            return initial_stream
        # Create a custom async generator that collects chunks and handles tool execution
        from litellm.main import stream_chunk_builder
        from litellm.types.utils import ModelResponseStream

        class MCPStreamingIterator:
            """Custom iterator that collects chunks, detects tool calls, and adds MCP metadata to final chunk."""

            def __init__(
                self,
                stream_wrapper,
                messages,
                tool_server_map,
                user_api_key_auth,
                mcp_auth_header,
                mcp_server_auth_headers,
                oauth2_headers,
                raw_headers,
                litellm_call_id,
                litellm_trace_id,
                openai_tools,
                base_call_args,
            ):
                self.stream_wrapper = stream_wrapper
                self.messages = messages
                self.tool_server_map = tool_server_map
                self.user_api_key_auth = user_api_key_auth
                self.mcp_auth_header = mcp_auth_header
                self.mcp_server_auth_headers = mcp_server_auth_headers
                self.oauth2_headers = oauth2_headers
                self.raw_headers = raw_headers
                self.litellm_call_id = litellm_call_id
                self.litellm_trace_id = litellm_trace_id
                self.openai_tools = openai_tools
                self.base_call_args = base_call_args
                self.collected_chunks: List[ModelResponseStream] = []
                self.tool_calls: Optional[List] = None
                self.tool_results: Optional[List] = None
                self.complete_response: Optional[ModelResponse] = None
                self.stream_exhausted = False
                self.tool_execution_done = False
                self.follow_up_stream = None
                self.follow_up_iterator = None
                self.follow_up_exhausted = False

            def __aiter__(self):
                # BUGFIX: this must be a plain (non-async) method. Declaring it
                # ``async def`` makes ``__aiter__()`` return a coroutine, which
                # ``async for`` rejects with TypeError on Python 3.8+ (PEP 525).
                # _SyncIteratorWrapper below still tolerates either shape.
                return self

            def _add_mcp_list_tools_to_chunk(
                self, chunk: ModelResponseStream
            ) -> ModelResponseStream:
                """Add mcp_list_tools to the first chunk."""
                from litellm.types.utils import (
                    StreamingChoices,
                    add_provider_specific_fields,
                )

                if not self.openai_tools:
                    return chunk
                if hasattr(chunk, "choices") and chunk.choices:
                    for choice in chunk.choices:
                        if (
                            isinstance(choice, StreamingChoices)
                            and hasattr(choice, "delta")
                            and choice.delta
                        ):
                            # Get existing provider_specific_fields or create new dict
                            existing_fields = (
                                getattr(choice.delta, "provider_specific_fields", None)
                                or {}
                            )
                            provider_fields = dict(
                                existing_fields
                            )  # Create a copy to avoid mutating the original
                            # Add only mcp_list_tools to first chunk
                            provider_fields["mcp_list_tools"] = self.openai_tools
                            # Use add_provider_specific_fields to ensure proper setting
                            # This function handles Pydantic model attribute setting correctly
                            add_provider_specific_fields(choice.delta, provider_fields)
                return chunk

            def _add_mcp_tool_metadata_to_final_chunk(
                self, chunk: ModelResponseStream
            ) -> ModelResponseStream:
                """Add mcp_tool_calls and mcp_call_results to the final chunk."""
                from litellm.types.utils import (
                    StreamingChoices,
                    add_provider_specific_fields,
                )

                if hasattr(chunk, "choices") and chunk.choices:
                    for choice in chunk.choices:
                        if (
                            isinstance(choice, StreamingChoices)
                            and hasattr(choice, "delta")
                            and choice.delta
                        ):
                            # Get existing provider_specific_fields or create new dict
                            # Access the attribute directly to handle Pydantic model attributes correctly
                            existing_fields = {}
                            if hasattr(choice.delta, "provider_specific_fields"):
                                attr_value = getattr(
                                    choice.delta, "provider_specific_fields", None
                                )
                                if attr_value is not None:
                                    # Create a copy to avoid mutating the original
                                    existing_fields = (
                                        dict(attr_value)
                                        if isinstance(attr_value, dict)
                                        else {}
                                    )
                            provider_fields = existing_fields
                            # Add tool_calls and tool_results if available
                            if self.tool_calls:
                                provider_fields["mcp_tool_calls"] = self.tool_calls
                            if self.tool_results:
                                provider_fields["mcp_call_results"] = self.tool_results
                            # Use add_provider_specific_fields to ensure proper setting
                            # This function handles Pydantic model attribute setting correctly
                            add_provider_specific_fields(choice.delta, provider_fields)
                return chunk

            async def __anext__(self):
                # Phase 1: Collect and yield initial stream chunks
                if not self.stream_exhausted:
                    # Get the iterator from the stream wrapper
                    if not hasattr(self, "_stream_iterator"):
                        self._stream_iterator = self.stream_wrapper.__aiter__()
                        # Add mcp_list_tools to the first chunk (available from the start)
                        _add_mcp_metadata_to_response(
                            response=self.stream_wrapper,
                            openai_tools=self.openai_tools,
                        )
                    try:
                        chunk = await self._stream_iterator.__anext__()
                        self.collected_chunks.append(chunk)
                        # Add mcp_list_tools to the first chunk
                        if len(self.collected_chunks) == 1:
                            chunk = self._add_mcp_list_tools_to_chunk(chunk)
                        # Check if this is the final chunk (has finish_reason)
                        is_final = (
                            hasattr(chunk, "choices")
                            and chunk.choices
                            and hasattr(chunk.choices[0], "finish_reason")
                            and chunk.choices[0].finish_reason is not None
                        )
                        if is_final:
                            # This is the final chunk, mark stream as exhausted
                            self.stream_exhausted = True
                            # Process tool calls after we've collected all chunks
                            await self._process_tool_calls()
                            # Apply MCP metadata (tool_calls and tool_results) to final chunk
                            chunk = self._add_mcp_tool_metadata_to_final_chunk(chunk)
                            # If we have tool results, prepare follow-up call immediately
                            if self.tool_results and self.complete_response:
                                await self._prepare_follow_up_call()
                        return chunk
                    except StopAsyncIteration:
                        self.stream_exhausted = True
                        # Process tool calls after stream is exhausted
                        await self._process_tool_calls()
                        # If we have chunks, yield the final one with metadata
                        # NOTE: this re-emits the already-yielded last chunk so
                        # MCP metadata is delivered even when the stream ended
                        # without a finish_reason chunk.
                        if self.collected_chunks:
                            final_chunk = self.collected_chunks[-1]
                            final_chunk = self._add_mcp_tool_metadata_to_final_chunk(
                                final_chunk
                            )
                            # If we have tool results, prepare follow-up call
                            if self.tool_results and self.complete_response:
                                await self._prepare_follow_up_call()
                            return final_chunk
                # Phase 2: Yield follow-up stream chunks if available
                if self.follow_up_stream and not self.follow_up_exhausted:
                    if not self.follow_up_iterator:
                        self.follow_up_iterator = self.follow_up_stream.__aiter__()
                        from litellm._logging import verbose_logger

                        verbose_logger.debug("Follow-up stream iterator created")
                    try:
                        chunk = await self.follow_up_iterator.__anext__()
                        from litellm._logging import verbose_logger

                        verbose_logger.debug(f"Follow-up chunk yielded: {chunk}")
                        return chunk
                    except StopAsyncIteration:
                        self.follow_up_exhausted = True
                        from litellm._logging import verbose_logger

                        verbose_logger.debug("Follow-up stream exhausted")
                        # After follow-up stream is exhausted, check if we need to raise StopAsyncIteration
                        raise StopAsyncIteration
                # If we're here and follow_up_stream is None but we expected it, log a warning
                if (
                    self.stream_exhausted
                    and self.tool_results
                    and self.complete_response
                    and self.follow_up_stream is None
                ):
                    from litellm._logging import verbose_logger

                    verbose_logger.warning(
                        "Follow-up stream was not created despite having tool results"
                    )
                raise StopAsyncIteration

            async def _process_tool_calls(self):
                """Process tool calls after streaming completes."""
                if self.tool_execution_done:
                    return
                self.tool_execution_done = True
                if not self.collected_chunks:
                    return
                # Build complete response from chunks
                complete_response = stream_chunk_builder(
                    chunks=self.collected_chunks,
                    messages=self.messages,
                )
                if isinstance(complete_response, ModelResponse):
                    self.complete_response = complete_response
                    # Extract tool calls from complete response
                    self.tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
                        response=complete_response
                    )
                    if self.tool_calls:
                        # Execute tool calls
                        self.tool_results = (
                            await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
                                tool_server_map=self.tool_server_map,
                                tool_calls=self.tool_calls,
                                user_api_key_auth=self.user_api_key_auth,
                                mcp_auth_header=self.mcp_auth_header,
                                mcp_server_auth_headers=self.mcp_server_auth_headers,
                                oauth2_headers=self.oauth2_headers,
                                raw_headers=self.raw_headers,
                                litellm_call_id=self.litellm_call_id,
                                litellm_trace_id=self.litellm_trace_id,
                            )
                        )

            async def _prepare_follow_up_call(self):
                """Prepare and initiate follow-up call with tool results."""
                if self.follow_up_stream is not None:
                    return  # Already prepared
                if not self.tool_results or not self.complete_response:
                    return
                # Create follow-up messages with tool results
                follow_up_messages = (
                    LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
                        original_messages=self.messages,
                        response=self.complete_response,
                        tool_results=self.tool_results,
                    )
                )
                # Make follow-up call with streaming
                follow_up_call_args = dict(self.base_call_args)
                follow_up_call_args["messages"] = follow_up_messages
                follow_up_call_args["stream"] = True
                # Ensure follow-up call doesn't trigger MCP handler again
                follow_up_call_args["_skip_mcp_handler"] = True
                # Import litellm here to ensure we get the patched version
                # This ensures the patch works correctly in tests
                import litellm

                follow_up_response = await litellm.acompletion(**follow_up_call_args)
                # Ensure follow-up response is a CustomStreamWrapper
                if isinstance(follow_up_response, CustomStreamWrapper):
                    self.follow_up_stream = follow_up_response
                    from litellm._logging import verbose_logger

                    verbose_logger.debug("Follow-up stream created successfully")
                else:
                    # Unexpected response type - log and set to None
                    from litellm._logging import verbose_logger

                    verbose_logger.warning(
                        f"Follow-up response is not a CustomStreamWrapper: {type(follow_up_response)}"
                    )
                    self.follow_up_stream = None

        # Create the custom iterator
        iterator = MCPStreamingIterator(
            stream_wrapper=initial_stream,
            messages=messages,
            tool_server_map=tool_server_map,
            user_api_key_auth=user_api_key_auth,
            mcp_auth_header=mcp_auth_header,
            mcp_server_auth_headers=mcp_server_auth_headers,
            oauth2_headers=oauth2_headers,
            raw_headers=raw_headers,
            litellm_call_id=kwargs.get("litellm_call_id"),
            litellm_trace_id=kwargs.get("litellm_trace_id"),
            openai_tools=openai_tools,
            base_call_args=base_call_args,
        )

        # Create a wrapper class that delegates to our custom iterator
        # We'll use a simple approach: just replace the __aiter__ method
        class MCPStreamWrapper(CustomStreamWrapper):
            def __init__(self, original_wrapper, custom_iterator):
                # Initialize with the same parameters as original wrapper
                super().__init__(
                    completion_stream=None,
                    model=getattr(original_wrapper, "model", "unknown"),
                    logging_obj=getattr(original_wrapper, "logging_obj", None),
                    custom_llm_provider=getattr(
                        original_wrapper, "custom_llm_provider", None
                    ),
                    stream_options=getattr(original_wrapper, "stream_options", None),
                    make_call=getattr(original_wrapper, "make_call", None),
                    _response_headers=getattr(
                        original_wrapper, "_response_headers", None
                    ),
                )
                self._original_wrapper = original_wrapper
                self._custom_iterator = custom_iterator
                # Copy important attributes from original wrapper
                if hasattr(original_wrapper, "_hidden_params"):
                    self._hidden_params = original_wrapper._hidden_params
                # For synchronous iteration, we need to run the async iterator
                self._sync_iterator = None
                self._sync_loop = None

            def __aiter__(self):
                # Hand back the MCP-aware iterator instead of self so async
                # consumers get chunk collection + tool execution for free.
                return self._custom_iterator

            def __iter__(self):
                # For synchronous iteration, create a sync wrapper
                # NOTE(review): asyncio.get_event_loop() is deprecated outside a
                # running loop and run_until_complete fails if a loop is already
                # running — sync iteration assumes a non-async caller. Confirm.
                if self._sync_iterator is None:
                    import asyncio

                    try:
                        self._sync_loop = asyncio.get_event_loop()
                    except RuntimeError:
                        self._sync_loop = asyncio.new_event_loop()
                        asyncio.set_event_loop(self._sync_loop)
                    self._sync_iterator = _SyncIteratorWrapper(
                        self._custom_iterator, self._sync_loop
                    )
                return self._sync_iterator

            def __next__(self):
                # Delegate to sync iterator
                if self._sync_iterator is None:
                    self.__iter__()
                return next(self._sync_iterator)

            def __getattr__(self, name):
                # Delegate all other attributes to original wrapper
                # (only called for attributes not found on self)
                return getattr(self._original_wrapper, name)

        # Helper class to wrap async iterator for sync iteration
        class _SyncIteratorWrapper:
            def __init__(self, async_iterator, loop):
                self._async_iterator = async_iterator
                self._loop = loop
                self._iterator = None

            def __iter__(self):
                return self

            def __next__(self):
                if self._iterator is None:
                    # __aiter__ might be async (legacy shape), so tolerate an
                    # awaitable as well as a plain iterator return value.
                    aiter_result = self._async_iterator.__aiter__()
                    if hasattr(aiter_result, "__await__"):
                        # It's a coroutine, await it
                        self._iterator = self._loop.run_until_complete(aiter_result)
                    else:
                        # It's already an iterator
                        self._iterator = aiter_result
                try:
                    return self._loop.run_until_complete(self._iterator.__anext__())
                except StopAsyncIteration:
                    raise StopIteration

        return cast(CustomStreamWrapper, MCPStreamWrapper(initial_stream, iterator))
    # Non-streaming mode: use existing logic
    initial_call_args = dict(base_call_args)
    initial_call_args["stream"] = False
    if mock_tool_calls is not None:
        initial_call_args["mock_tool_calls"] = mock_tool_calls
    # Make initial call
    initial_response = await litellm_acompletion(**initial_call_args)
    if not isinstance(initial_response, ModelResponse):
        return initial_response
    # Extract tool calls from response
    tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
        response=initial_response
    )
    if not tool_calls:
        _add_mcp_metadata_to_response(
            response=initial_response,
            openai_tools=openai_tools,
        )
        return initial_response
    # Execute tool calls
    tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
        tool_server_map=tool_server_map,
        tool_calls=tool_calls,
        user_api_key_auth=user_api_key_auth,
        mcp_auth_header=mcp_auth_header,
        mcp_server_auth_headers=mcp_server_auth_headers,
        oauth2_headers=oauth2_headers,
        raw_headers=raw_headers,
        litellm_call_id=kwargs.get("litellm_call_id"),
        litellm_trace_id=kwargs.get("litellm_trace_id"),
    )
    if not tool_results:
        _add_mcp_metadata_to_response(
            response=initial_response,
            openai_tools=openai_tools,
            tool_calls=tool_calls,
        )
        return initial_response
    # Create follow-up messages with tool results
    follow_up_messages = LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
        original_messages=messages,
        response=initial_response,
        tool_results=tool_results,
    )
    # Make follow-up call with original stream setting
    follow_up_call_args = dict(base_call_args)
    follow_up_call_args["messages"] = follow_up_messages
    follow_up_call_args["stream"] = stream
    response = await litellm_acompletion(**follow_up_call_args)
    if isinstance(response, (ModelResponse, CustomStreamWrapper)):
        _add_mcp_metadata_to_response(
            response=response,
            openai_tools=openai_tools,
            tool_calls=tool_calls,
            tool_results=tool_results,
        )
    return response