"""Helpers for handling MCP-aware `/chat/completions` requests.""" from typing import ( Any, List, Optional, Union, cast, ) from litellm.responses.mcp.litellm_proxy_mcp_handler import ( LiteLLM_Proxy_MCP_Handler, ) from litellm.responses.utils import ResponsesAPIRequestUtils from litellm.types.utils import ModelResponse from litellm.utils import CustomStreamWrapper def _add_mcp_metadata_to_response( response: Union[ModelResponse, CustomStreamWrapper], openai_tools: Optional[List], tool_calls: Optional[List] = None, tool_results: Optional[List] = None, ) -> None: """ Add MCP metadata to response's provider_specific_fields. This function adds MCP-related information to the response so that clients can access which tools were available, which were called, and what results were returned. For ModelResponse: adds to choices[].message.provider_specific_fields For CustomStreamWrapper: stores in _hidden_params and automatically adds to final chunk's delta.provider_specific_fields via CustomStreamWrapper._add_mcp_metadata_to_final_chunk() """ if isinstance(response, CustomStreamWrapper): # For streaming, store MCP metadata in _hidden_params # CustomStreamWrapper._add_mcp_metadata_to_final_chunk() will automatically # add it to the final chunk's delta.provider_specific_fields if not hasattr(response, "_hidden_params"): response._hidden_params = {} mcp_metadata = {} if openai_tools: mcp_metadata["mcp_list_tools"] = openai_tools if tool_calls: mcp_metadata["mcp_tool_calls"] = tool_calls if tool_results: mcp_metadata["mcp_call_results"] = tool_results if mcp_metadata: response._hidden_params["mcp_metadata"] = mcp_metadata return if not isinstance(response, ModelResponse): return if not hasattr(response, "choices") or not response.choices: return # Add MCP metadata to all choices' messages for choice in response.choices: message = getattr(choice, "message", None) if message is not None: # Get existing provider_specific_fields or create new dict provider_fields = getattr(message, "provider_specific_fields", None) or {} # Add MCP metadata if openai_tools: provider_fields["mcp_list_tools"] = openai_tools if tool_calls: provider_fields["mcp_tool_calls"] = tool_calls if tool_results: provider_fields["mcp_call_results"] = tool_results # Set the provider_specific_fields setattr(message, "provider_specific_fields", provider_fields) async def acompletion_with_mcp( # noqa: PLR0915 model: str, messages: List, tools: Optional[List] = None, **kwargs: Any, ) -> Union[ModelResponse, CustomStreamWrapper]: """ Async completion with MCP integration. This function handles MCP tool integration following the same pattern as aresponses_api_with_mcp. It's designed to be called from the synchronous completion() function and return a coroutine. When MCP tools with server_url="litellm_proxy" are provided, this function will: 1. Get available tools from the MCP server manager 2. Transform them to OpenAI format 3. Call acompletion with the transformed tools 4. If require_approval="never" and tool calls are returned, automatically execute them 5. 
async def acompletion_with_mcp(  # noqa: PLR0915
    model: str,
    messages: List,
    tools: Optional[List] = None,
    **kwargs: Any,
) -> Union[ModelResponse, CustomStreamWrapper]:
    """
    Async completion with MCP integration.

    This function handles MCP tool integration following the same pattern as
    aresponses_api_with_mcp. It's designed to be called from the synchronous
    completion() function and return a coroutine.

    When MCP tools with server_url="litellm_proxy" are provided, this function will:
    1. Get available tools from the MCP server manager
    2. Transform them to OpenAI format
    3. Call acompletion with the transformed tools
    4. If require_approval="never" and tool calls are returned, automatically execute them
    5. Make a follow-up call with the tool results
    """
    from litellm import acompletion as litellm_acompletion

    # Parse MCP tools and separate from other tools
    (
        mcp_tools_with_litellm_proxy,
        other_tools,
    ) = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)

    if not mcp_tools_with_litellm_proxy:
        # No MCP tools, proceed with regular completion
        return await litellm_acompletion(
            model=model,
            messages=messages,
            tools=tools,
            **kwargs,
        )

    # Extract user_api_key_auth from metadata or kwargs
    user_api_key_auth = kwargs.get("user_api_key_auth") or (
        (kwargs.get("metadata", {}) or {}).get("user_api_key_auth")
    )

    # Extract MCP auth headers before fetching tools (needed for dynamic auth)
    (
        mcp_auth_header,
        mcp_server_auth_headers,
        oauth2_headers,
        raw_headers,
    ) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
        secret_fields=kwargs.get("secret_fields"),
        tools=tools,
    )

    # Process MCP tools (pass auth headers for dynamic auth)
    (
        deduplicated_mcp_tools,
        tool_server_map,
    ) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_without_openai_transform(
        user_api_key_auth=user_api_key_auth,
        mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy,
        litellm_trace_id=kwargs.get("litellm_trace_id"),
        mcp_auth_header=mcp_auth_header,
        mcp_server_auth_headers=mcp_server_auth_headers,
    )

    openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai(
        deduplicated_mcp_tools,
        target_format="chat",
    )

    # Combine with other tools
    all_tools = openai_tools + other_tools if (openai_tools or other_tools) else None

    # Determine if we should auto-execute tools
    should_auto_execute = LiteLLM_Proxy_MCP_Handler._should_auto_execute_tools(
        mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy
    )

    # Prepare call parameters
    # Remove keys that shouldn't be passed to acompletion
    clean_kwargs = {k: v for k, v in kwargs.items() if k not in ["acompletion"]}

    base_call_args = {
        "model": model,
        "messages": messages,
        "tools": all_tools,
        "_skip_mcp_handler": True,  # Prevent recursion
        **clean_kwargs,
    }

    # If not auto-executing, just make the call with transformed tools
    if not should_auto_execute:
        response = await litellm_acompletion(**base_call_args)
        if isinstance(response, (ModelResponse, CustomStreamWrapper)):
            _add_mcp_metadata_to_response(
                response=response,
                openai_tools=openai_tools,
            )
        return response

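    # A hypothetical `tools` argument that reaches the auto-execute path below:
    # an MCP entry targeting the proxy with approval disabled, alongside a plain
    # function tool that is passed through untouched (field names beyond
    # server_url/require_approval follow the MCP tool spec and are illustrative):
    #
    #     tools = [
    #         {"type": "mcp", "server_url": "litellm_proxy", "require_approval": "never"},
    #         {"type": "function", "function": {"name": "my_local_tool", "parameters": {}}},
    #     ]
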
    # For auto-execute: handle streaming vs non-streaming differently
    stream = kwargs.get("stream", False)
    mock_tool_calls = base_call_args.pop("mock_tool_calls", None)

    if stream:
        # Streaming mode: make initial call with streaming, collect chunks,
        # detect tool calls
        initial_call_args = dict(base_call_args)
        initial_call_args["stream"] = True
        if mock_tool_calls is not None:
            initial_call_args["mock_tool_calls"] = mock_tool_calls

        # Make initial streaming call
        initial_stream = await litellm_acompletion(**initial_call_args)

        if not isinstance(initial_stream, CustomStreamWrapper):
            # Not a stream, return as-is
            if isinstance(initial_stream, ModelResponse):
                _add_mcp_metadata_to_response(
                    response=initial_stream,
                    openai_tools=openai_tools,
                )
            return initial_stream

        # Create a custom async iterator that collects chunks and handles tool execution
        from litellm.main import stream_chunk_builder
        from litellm.types.utils import ModelResponseStream

        class MCPStreamingIterator:
            """Custom iterator that collects chunks, detects tool calls, and adds MCP metadata to the final chunk."""

            def __init__(
                self,
                stream_wrapper,
                messages,
                tool_server_map,
                user_api_key_auth,
                mcp_auth_header,
                mcp_server_auth_headers,
                oauth2_headers,
                raw_headers,
                litellm_call_id,
                litellm_trace_id,
                openai_tools,
                base_call_args,
            ):
                self.stream_wrapper = stream_wrapper
                self.messages = messages
                self.tool_server_map = tool_server_map
                self.user_api_key_auth = user_api_key_auth
                self.mcp_auth_header = mcp_auth_header
                self.mcp_server_auth_headers = mcp_server_auth_headers
                self.oauth2_headers = oauth2_headers
                self.raw_headers = raw_headers
                self.litellm_call_id = litellm_call_id
                self.litellm_trace_id = litellm_trace_id
                self.openai_tools = openai_tools
                self.base_call_args = base_call_args
                self.collected_chunks: List[ModelResponseStream] = []
                self.tool_calls: Optional[List] = None
                self.tool_results: Optional[List] = None
                self.complete_response: Optional[ModelResponse] = None
                self.stream_exhausted = False
                self.tool_execution_done = False
                self.follow_up_stream = None
                self.follow_up_iterator = None
                self.follow_up_exhausted = False

            def __aiter__(self):
                # Must be a plain (non-async) method returning the iterator itself;
                # an ``async def __aiter__`` breaks ``async for``.
                return self

            def _add_mcp_list_tools_to_chunk(
                self, chunk: ModelResponseStream
            ) -> ModelResponseStream:
                """Add mcp_list_tools to the first chunk."""
                from litellm.types.utils import (
                    StreamingChoices,
                    add_provider_specific_fields,
                )

                if not self.openai_tools:
                    return chunk

                if hasattr(chunk, "choices") and chunk.choices:
                    for choice in chunk.choices:
                        if (
                            isinstance(choice, StreamingChoices)
                            and hasattr(choice, "delta")
                            and choice.delta
                        ):
                            # Get existing provider_specific_fields or create new dict
                            existing_fields = (
                                getattr(choice.delta, "provider_specific_fields", None)
                                or {}
                            )
                            # Create a copy to avoid mutating the original
                            provider_fields = dict(existing_fields)

                            # Add only mcp_list_tools to first chunk
                            provider_fields["mcp_list_tools"] = self.openai_tools

                            # Use add_provider_specific_fields to ensure proper setting
                            # This function handles Pydantic model attribute setting correctly
                            add_provider_specific_fields(choice.delta, provider_fields)

                return chunk

            def _add_mcp_tool_metadata_to_final_chunk(
                self, chunk: ModelResponseStream
            ) -> ModelResponseStream:
                """Add mcp_tool_calls and mcp_call_results to the final chunk."""
                from litellm.types.utils import (
                    StreamingChoices,
                    add_provider_specific_fields,
                )

                if hasattr(chunk, "choices") and chunk.choices:
                    for choice in chunk.choices:
                        if (
                            isinstance(choice, StreamingChoices)
                            and hasattr(choice, "delta")
                            and choice.delta
                        ):
                            # Get existing provider_specific_fields or create new dict
                            # Access the attribute directly to handle Pydantic model
                            # attributes correctly
                            existing_fields = {}
                            if hasattr(choice.delta, "provider_specific_fields"):
                                attr_value = getattr(
                                    choice.delta, "provider_specific_fields", None
                                )
                                if attr_value is not None:
                                    # Create a copy to avoid mutating the original
                                    existing_fields = (
                                        dict(attr_value)
                                        if isinstance(attr_value, dict)
                                        else {}
                                    )

                            provider_fields = existing_fields

                            # Add tool_calls and tool_results if available
                            if self.tool_calls:
                                provider_fields["mcp_tool_calls"] = self.tool_calls
                            if self.tool_results:
                                provider_fields["mcp_call_results"] = self.tool_results

                            # Use add_provider_specific_fields to ensure proper setting
                            # This function handles Pydantic model attribute setting correctly
                            add_provider_specific_fields(choice.delta, provider_fields)

                return chunk

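            # Iteration below proceeds in two phases:
            #   Phase 1: drain the initial stream, tagging the first chunk with
            #            mcp_list_tools and the final chunk with mcp_tool_calls /
            #            mcp_call_results, then execute any detected tool calls.
            #   Phase 2: if tools were executed, stream the follow-up
            #            completion's chunks until it is exhausted.
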
            async def __anext__(self):
                # Phase 1: Collect and yield initial stream chunks
                if not self.stream_exhausted:
                    # Get the iterator from the stream wrapper
                    if not hasattr(self, "_stream_iterator"):
                        self._stream_iterator = self.stream_wrapper.__aiter__()
                        # Store mcp_list_tools in the wrapper's _hidden_params up
                        # front (available before any chunks arrive)
                        _add_mcp_metadata_to_response(
                            response=self.stream_wrapper,
                            openai_tools=self.openai_tools,
                        )

                    try:
                        chunk = await self._stream_iterator.__anext__()
                        self.collected_chunks.append(chunk)

                        # Add mcp_list_tools to the first chunk
                        if len(self.collected_chunks) == 1:
                            chunk = self._add_mcp_list_tools_to_chunk(chunk)

                        # Check if this is the final chunk (has finish_reason)
                        is_final = (
                            hasattr(chunk, "choices")
                            and chunk.choices
                            and hasattr(chunk.choices[0], "finish_reason")
                            and chunk.choices[0].finish_reason is not None
                        )

                        if is_final:
                            # This is the final chunk, mark stream as exhausted
                            self.stream_exhausted = True
                            # Process tool calls after we've collected all chunks
                            await self._process_tool_calls()
                            # Apply MCP metadata (tool_calls and tool_results) to final chunk
                            chunk = self._add_mcp_tool_metadata_to_final_chunk(chunk)

                            # If we have tool results, prepare follow-up call immediately
                            if self.tool_results and self.complete_response:
                                await self._prepare_follow_up_call()

                        return chunk
                    except StopAsyncIteration:
                        self.stream_exhausted = True
                        # Process tool calls after stream is exhausted
                        await self._process_tool_calls()

                        # If we have chunks, yield the final one with metadata
                        if self.collected_chunks:
                            final_chunk = self.collected_chunks[-1]
                            final_chunk = self._add_mcp_tool_metadata_to_final_chunk(
                                final_chunk
                            )
                            # If we have tool results, prepare follow-up call
                            if self.tool_results and self.complete_response:
                                await self._prepare_follow_up_call()
                            return final_chunk

                # Phase 2: Yield follow-up stream chunks if available
                if self.follow_up_stream and not self.follow_up_exhausted:
                    from litellm._logging import verbose_logger

                    if not self.follow_up_iterator:
                        self.follow_up_iterator = self.follow_up_stream.__aiter__()
                        verbose_logger.debug("Follow-up stream iterator created")

                    try:
                        chunk = await self.follow_up_iterator.__anext__()
                        verbose_logger.debug(f"Follow-up chunk yielded: {chunk}")
                        return chunk
                    except StopAsyncIteration:
                        self.follow_up_exhausted = True
                        verbose_logger.debug("Follow-up stream exhausted")
                        # Follow-up stream is exhausted, so stop iteration
                        raise StopAsyncIteration

                # If we're here and follow_up_stream is None but we expected it, log a warning
                if (
                    self.stream_exhausted
                    and self.tool_results
                    and self.complete_response
                    and self.follow_up_stream is None
                ):
                    from litellm._logging import verbose_logger

                    verbose_logger.warning(
                        "Follow-up stream was not created despite having tool results"
                    )

                raise StopAsyncIteration

            async def _process_tool_calls(self):
                """Process tool calls after streaming completes."""
                if self.tool_execution_done:
                    return

                self.tool_execution_done = True

                if not self.collected_chunks:
                    return

                # Build complete response from chunks
                complete_response = stream_chunk_builder(
                    chunks=self.collected_chunks,
                    messages=self.messages,
                )

                if isinstance(complete_response, ModelResponse):
                    self.complete_response = complete_response

                    # Extract tool calls from complete response
                    self.tool_calls = (
                        LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
                            response=complete_response
                        )
                    )

                    if self.tool_calls:
                        # Execute tool calls
                        self.tool_results = (
                            await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
                                tool_server_map=self.tool_server_map,
                                tool_calls=self.tool_calls,
                                user_api_key_auth=self.user_api_key_auth,
                                mcp_auth_header=self.mcp_auth_header,
                                mcp_server_auth_headers=self.mcp_server_auth_headers,
                                oauth2_headers=self.oauth2_headers,
                                raw_headers=self.raw_headers,
                                litellm_call_id=self.litellm_call_id,
                                litellm_trace_id=self.litellm_trace_id,
                            )
                        )

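            # The follow-up conversation built below follows the standard
            # chat-completions tool protocol; roughly (shapes illustrative):
            #
            #     [
            #         *original_messages,
            #         {"role": "assistant", "tool_calls": [...]},  # from the model
            #         {"role": "tool", "tool_call_id": "...", "content": "..."},  # per result
            #     ]
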
            async def _prepare_follow_up_call(self):
                """Prepare and initiate follow-up call with tool results."""
                if self.follow_up_stream is not None:
                    return  # Already prepared

                if not self.tool_results or not self.complete_response:
                    return

                # Create follow-up messages with tool results
                follow_up_messages = (
                    LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
                        original_messages=self.messages,
                        response=self.complete_response,
                        tool_results=self.tool_results,
                    )
                )

                # Make follow-up call with streaming
                follow_up_call_args = dict(self.base_call_args)
                follow_up_call_args["messages"] = follow_up_messages
                follow_up_call_args["stream"] = True
                # Ensure follow-up call doesn't trigger MCP handler again
                follow_up_call_args["_skip_mcp_handler"] = True

                # Import litellm here to ensure we get the patched version
                # This ensures the patch works correctly in tests
                import litellm

                follow_up_response = await litellm.acompletion(**follow_up_call_args)

                # Ensure follow-up response is a CustomStreamWrapper
                from litellm._logging import verbose_logger

                if isinstance(follow_up_response, CustomStreamWrapper):
                    self.follow_up_stream = follow_up_response
                    verbose_logger.debug("Follow-up stream created successfully")
                else:
                    # Unexpected response type - log and set to None
                    verbose_logger.warning(
                        f"Follow-up response is not a CustomStreamWrapper: {type(follow_up_response)}"
                    )
                    self.follow_up_stream = None

        # Create the custom iterator
        iterator = MCPStreamingIterator(
            stream_wrapper=initial_stream,
            messages=messages,
            tool_server_map=tool_server_map,
            user_api_key_auth=user_api_key_auth,
            mcp_auth_header=mcp_auth_header,
            mcp_server_auth_headers=mcp_server_auth_headers,
            oauth2_headers=oauth2_headers,
            raw_headers=raw_headers,
            litellm_call_id=kwargs.get("litellm_call_id"),
            litellm_trace_id=kwargs.get("litellm_trace_id"),
            openai_tools=openai_tools,
            base_call_args=base_call_args,
        )

        # Create a wrapper class that delegates to our custom iterator
        # We'll use a simple approach: just replace the __aiter__ method
        class MCPStreamWrapper(CustomStreamWrapper):
            def __init__(self, original_wrapper, custom_iterator):
                # Initialize with the same parameters as original wrapper
                super().__init__(
                    completion_stream=None,
                    model=getattr(original_wrapper, "model", "unknown"),
                    logging_obj=getattr(original_wrapper, "logging_obj", None),
                    custom_llm_provider=getattr(
                        original_wrapper, "custom_llm_provider", None
                    ),
                    stream_options=getattr(original_wrapper, "stream_options", None),
                    make_call=getattr(original_wrapper, "make_call", None),
                    _response_headers=getattr(
                        original_wrapper, "_response_headers", None
                    ),
                )
                self._original_wrapper = original_wrapper
                self._custom_iterator = custom_iterator

                # Copy important attributes from original wrapper
                if hasattr(original_wrapper, "_hidden_params"):
                    self._hidden_params = original_wrapper._hidden_params

                # For synchronous iteration, we need to run the async iterator
                self._sync_iterator = None
                self._sync_loop = None

            def __aiter__(self):
                return self._custom_iterator

            def __iter__(self):
                # For synchronous iteration, create a sync wrapper
                if self._sync_iterator is None:
                    import asyncio

                    try:
                        self._sync_loop = asyncio.get_event_loop()
                    except RuntimeError:
                        self._sync_loop = asyncio.new_event_loop()
                        asyncio.set_event_loop(self._sync_loop)
                    self._sync_iterator = _SyncIteratorWrapper(
                        self._custom_iterator, self._sync_loop
                    )
                return self._sync_iterator

            def __next__(self):
                # Delegate to sync iterator
                if self._sync_iterator is None:
                    self.__iter__()
                return next(self._sync_iterator)

            def __getattr__(self, name):
                # Guard against infinite recursion if this is reached before
                # __init__ has set _original_wrapper
                if name == "_original_wrapper":
                    raise AttributeError(name)
                # Delegate all other attributes to original wrapper
                return getattr(self._original_wrapper, name)

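        # The helper below bridges async iteration into the sync iterator
        # protocol by driving each __anext__ through run_until_complete. A
        # generic sketch of the pattern (names hypothetical):
        #
        #     loop = asyncio.new_event_loop()
        #     try:
        #         while True:
        #             chunk = loop.run_until_complete(aiterator.__anext__())
        #     except StopAsyncIteration:
        #         pass
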
        # Helper class to wrap async iterator for sync iteration
        class _SyncIteratorWrapper:
            def __init__(self, async_iterator, loop):
                self._async_iterator = async_iterator
                self._loop = loop
                self._iterator = None

            def __iter__(self):
                return self

            def __next__(self):
                if self._iterator is None:
                    # __aiter__ might be async, so we need to await it
                    aiter_result = self._async_iterator.__aiter__()
                    if hasattr(aiter_result, "__await__"):
                        # It's a coroutine, await it
                        self._iterator = self._loop.run_until_complete(aiter_result)
                    else:
                        # It's already an iterator
                        self._iterator = aiter_result
                try:
                    return self._loop.run_until_complete(self._iterator.__anext__())
                except StopAsyncIteration:
                    raise StopIteration

        return cast(CustomStreamWrapper, MCPStreamWrapper(initial_stream, iterator))

    # Non-streaming mode: use existing logic
    initial_call_args = dict(base_call_args)
    initial_call_args["stream"] = False
    if mock_tool_calls is not None:
        initial_call_args["mock_tool_calls"] = mock_tool_calls

    # Make initial call
    initial_response = await litellm_acompletion(**initial_call_args)

    if not isinstance(initial_response, ModelResponse):
        return initial_response

    # Extract tool calls from response
    tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
        response=initial_response
    )

    if not tool_calls:
        _add_mcp_metadata_to_response(
            response=initial_response,
            openai_tools=openai_tools,
        )
        return initial_response

    # Execute tool calls
    tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
        tool_server_map=tool_server_map,
        tool_calls=tool_calls,
        user_api_key_auth=user_api_key_auth,
        mcp_auth_header=mcp_auth_header,
        mcp_server_auth_headers=mcp_server_auth_headers,
        oauth2_headers=oauth2_headers,
        raw_headers=raw_headers,
        litellm_call_id=kwargs.get("litellm_call_id"),
        litellm_trace_id=kwargs.get("litellm_trace_id"),
    )

    if not tool_results:
        _add_mcp_metadata_to_response(
            response=initial_response,
            openai_tools=openai_tools,
            tool_calls=tool_calls,
        )
        return initial_response

    # Create follow-up messages with tool results
    follow_up_messages = LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
        original_messages=messages,
        response=initial_response,
        tool_results=tool_results,
    )

    # Make follow-up call with original stream setting
    follow_up_call_args = dict(base_call_args)
    follow_up_call_args["messages"] = follow_up_messages
    follow_up_call_args["stream"] = stream

    response = await litellm_acompletion(**follow_up_call_args)

    if isinstance(response, (ModelResponse, CustomStreamWrapper)):
        _add_mcp_metadata_to_response(
            response=response,
            openai_tools=openai_tools,
            tool_calls=tool_calls,
            tool_results=tool_results,
        )

    return response
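
# A minimal end-to-end usage sketch, assuming the main completion entrypoint
# dispatches MCP tools to this handler (as the `_skip_mcp_handler` guard
# suggests); model name and prompt are hypothetical:
#
#     import asyncio
#     import litellm
#
#     async def main():
#         response = await litellm.acompletion(
#             model="gpt-4o",
#             messages=[{"role": "user", "content": "List the open tickets"}],
#             tools=[{"type": "mcp", "server_url": "litellm_proxy",
#                     "require_approval": "never"}],
#         )
#         print(response.choices[0].message.provider_specific_fields)
#
#     asyncio.run(main())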