chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
# LiteLLM MCP Client
|
||||
|
||||
LiteLLM MCP Client is a client that allows you to use MCP tools with LiteLLM.
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tools import call_openai_tool, load_mcp_tools
|
||||
|
||||
__all__ = ["load_mcp_tools", "call_openai_tool"]
|
||||
@@ -0,0 +1,697 @@
|
||||
"""
|
||||
LiteLLM Proxy uses this MCP Client to connect to other MCP servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession, ReadResourceResult, Resource, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
# Resolve the optional streamable-HTTP transport at import time. Older mcp
# packages lack this module/attribute, so the name stays None on failure and
# HTTP-transport users get an explicit ImportError later instead of a crash here.
streamable_http_client: Optional[Any] = None
try:
    import mcp.client.streamable_http as streamable_http_module  # type: ignore

    # getattr() (not a direct import) so a partially-matching module version
    # that lacks the symbol degrades to None rather than raising.
    streamable_http_client = getattr(
        streamable_http_module, "streamable_http_client", None
    )
except ImportError:
    pass
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
from mcp.types import CallToolResult as MCPCallToolResult
|
||||
from mcp.types import (
|
||||
GetPromptRequestParams,
|
||||
GetPromptResult,
|
||||
Prompt,
|
||||
ResourceTemplate,
|
||||
TextContent,
|
||||
)
|
||||
from mcp.types import Tool as MCPTool
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import MCP_CLIENT_TIMEOUT
|
||||
from litellm.llms.custom_httpx.http_handler import get_ssl_configuration
|
||||
from litellm.types.llms.custom_http import VerifyTypes
|
||||
from litellm.types.mcp import (
|
||||
MCPAuth,
|
||||
MCPAuthType,
|
||||
MCPStdioConfig,
|
||||
MCPTransport,
|
||||
MCPTransportType,
|
||||
)
|
||||
|
||||
|
||||
def to_basic_auth(auth_value: str) -> str:
    """Convert auth value to Basic Auth format."""
    raw_bytes = auth_value.encode("utf-8")
    return base64.b64encode(raw_bytes).decode()
|
||||
|
||||
|
||||
TSessionResult = TypeVar("TSessionResult")
|
||||
|
||||
|
||||
class MCPSigV4Auth(httpx.Auth):
    """
    httpx Auth class that signs each request with AWS SigV4.

    This is used for MCP servers that require AWS SigV4 authentication,
    such as AWS Bedrock AgentCore MCP servers. httpx calls auth_flow()
    for every outgoing request, enabling per-request signature computation.
    """

    # Tell httpx to materialize the request body before calling auth_flow():
    # the SigV4 canonical request includes a hash of the payload.
    requires_request_body = True

    def __init__(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_service_name: Optional[str] = None,
    ):
        """Resolve botocore Credentials from explicit keys or the default chain.

        Raises:
            ImportError: if botocore is not installed.
            ValueError: if no credentials can be resolved from any source.
        """
        # botocore is an optional dependency, only needed for SigV4 servers.
        try:
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError(
                "Missing botocore to use AWS SigV4 authentication. "
                "Run 'pip install boto3'."
            )

        self.service_name = aws_service_name or "bedrock-agentcore"
        self.region_name = aws_region_name or "us-east-1"

        # Note: os.environ/ prefixed values are already resolved by
        # ProxyConfig._check_for_os_environ_vars() at config load time.
        # Values arrive here as plain strings.
        if aws_access_key_id and aws_secret_access_key:
            # Static credentials supplied explicitly by the caller/config.
            self.credentials = Credentials(
                access_key=aws_access_key_id,
                secret_key=aws_secret_access_key,
                token=aws_session_token,
            )
        else:
            # Fall back to default boto3 credential chain
            import botocore.session

            session = botocore.session.get_session()
            self.credentials = session.get_credentials()
            if self.credentials is None:
                raise ValueError(
                    "No AWS credentials found. Provide aws_access_key_id and "
                    "aws_secret_access_key, or configure default credentials "
                    "(env vars, ~/.aws/credentials, instance profile)."
                )

    def auth_flow(
        self, request: httpx.Request
    ) -> Generator[httpx.Request, httpx.Response, None]:
        """Sign the outgoing request in place with SigV4, then yield it to httpx."""
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest

        # Build AWSRequest from the httpx Request.
        # Pass all request headers so the canonical SigV4 signature covers them.
        aws_request = AWSRequest(
            method=request.method,
            url=str(request.url),
            data=request.content,
            headers=dict(request.headers),
        )

        # Sign the request — SigV4Auth.add_auth() adds Authorization,
        # X-Amz-Date, and X-Amz-Security-Token (if session token present).
        # Host header is derived automatically from the URL.
        sigv4 = SigV4Auth(self.credentials, self.service_name, self.region_name)
        sigv4.add_auth(aws_request)

        # Copy SigV4 headers back to the httpx request
        for header_name, header_value in aws_request.headers.items():
            request.headers[header_name] = header_value

        yield request
|
||||
|
||||
|
||||
class MCPClient:
    """
    MCP Client supporting:
      SSE and HTTP transports
      Authentication via Bearer token, Basic Auth, or API Key
      Tool calling with error handling and result parsing
    """

    def __init__(
        self,
        server_url: str = "",
        transport_type: MCPTransportType = MCPTransport.http,
        auth_type: MCPAuthType = None,
        auth_value: Optional[Union[str, Dict[str, str]]] = None,
        timeout: Optional[float] = None,
        stdio_config: Optional[MCPStdioConfig] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        ssl_verify: Optional[VerifyTypes] = None,
        aws_auth: Optional[httpx.Auth] = None,
    ):
        """
        Args:
            server_url: MCP server URL (ignored for stdio transport).
            transport_type: stdio, sse, or http (default).
            auth_type: one of the MCPAuth schemes, or None.
            auth_value: token string or a dict of headers; a "user:pass"
                string is converted to Basic Auth when auth_type is basic.
            timeout: per-request timeout; falls back to MCP_CLIENT_TIMEOUT.
            stdio_config: command/args/env, required for stdio transport.
            extra_headers: additional headers merged over the auth headers.
            ssl_verify: SSLContext, bool, or CA bundle path.
            aws_auth: httpx.Auth (e.g. MCPSigV4Auth) used for per-request signing.
        """
        self.server_url: str = server_url
        self.transport_type: MCPTransport = transport_type
        self.auth_type: MCPAuthType = auth_type
        self.timeout: float = timeout if timeout is not None else MCP_CLIENT_TIMEOUT
        self._mcp_auth_value: Optional[Union[str, Dict[str, str]]] = None
        self.stdio_config: Optional[MCPStdioConfig] = stdio_config
        self.extra_headers: Optional[Dict[str, str]] = extra_headers
        self.ssl_verify: Optional[VerifyTypes] = ssl_verify
        self._aws_auth: Optional[httpx.Auth] = aws_auth
        # handle the basic auth value if provided
        if auth_value:
            self.update_auth_value(auth_value)

    def _create_transport_context(
        self,
    ) -> Tuple[Any, Optional[httpx.AsyncClient]]:
        """
        Create the appropriate transport context based on transport type.

        Returns:
            Tuple of (transport_context, http_client).
            http_client is only set for HTTP transport and needs cleanup.
        """
        http_client: Optional[httpx.AsyncClient] = None

        if self.transport_type == MCPTransport.stdio:
            if not self.stdio_config:
                raise ValueError("stdio_config is required for stdio transport")
            server_params = StdioServerParameters(
                command=self.stdio_config.get("command", ""),
                args=self.stdio_config.get("args", []),
                env=self.stdio_config.get("env", {}),
            )
            return stdio_client(server_params), None

        if self.transport_type == MCPTransport.sse:
            headers = self._get_auth_headers()
            httpx_client_factory = self._create_httpx_client_factory()
            # sse_client manages its own httpx client via the factory,
            # so no client is returned for cleanup.
            return (
                sse_client(
                    url=self.server_url,
                    timeout=self.timeout,
                    headers=headers,
                    httpx_client_factory=httpx_client_factory,
                ),
                None,
            )

        # HTTP transport (default)
        if streamable_http_client is None:
            raise ImportError(
                "streamable_http_client is not available. "
                "Please install mcp with HTTP support."
            )

        headers = self._get_auth_headers()
        httpx_client_factory = self._create_httpx_client_factory()
        verbose_logger.debug("litellm headers for streamable_http_client: %s", headers)
        # For HTTP we build the client ourselves and hand it to the transport;
        # the caller (run_with_session) is responsible for closing it.
        http_client = httpx_client_factory(
            headers=headers,
            timeout=httpx.Timeout(self.timeout),
        )
        transport_ctx = streamable_http_client(
            url=self.server_url,
            http_client=http_client,
        )
        return transport_ctx, http_client

    async def _execute_session_operation(
        self,
        transport_ctx: Any,
        operation: Callable[[ClientSession], Awaitable[TSessionResult]],
    ) -> TSessionResult:
        """
        Execute an operation within a transport and session context.

        Handles entering/exiting contexts and running the operation.
        Context-exit errors are logged and suppressed so they never mask
        the operation's own result or exception.
        """
        transport = await transport_ctx.__aenter__()
        try:
            read_stream, write_stream = transport[0], transport[1]
            session_ctx = ClientSession(read_stream, write_stream)
            session = await session_ctx.__aenter__()
            try:
                await session.initialize()
                return await operation(session)
            finally:
                try:
                    await session_ctx.__aexit__(None, None, None)
                except BaseException as e:
                    verbose_logger.debug(f"Error during session context exit: {e}")
        finally:
            try:
                await transport_ctx.__aexit__(None, None, None)
            except BaseException as e:
                verbose_logger.debug(f"Error during transport context exit: {e}")

    async def run_with_session(
        self, operation: Callable[[ClientSession], Awaitable[TSessionResult]]
    ) -> TSessionResult:
        """Open a session, run the provided coroutine, and clean up."""
        http_client: Optional[httpx.AsyncClient] = None
        try:
            transport_ctx, http_client = self._create_transport_context()
            return await self._execute_session_operation(transport_ctx, operation)
        except Exception:
            verbose_logger.warning(
                "MCP client run_with_session failed for %s", self.server_url or "stdio"
            )
            raise
        finally:
            # Only set for HTTP transport; close it even when the operation raised.
            if http_client is not None:
                try:
                    await http_client.aclose()
                except BaseException as e:
                    verbose_logger.debug(f"Error during http_client cleanup: {e}")

    def update_auth_value(self, mcp_auth_value: Union[str, Dict[str, str]]):
        """
        Set the authentication header for the MCP client.
        """
        if isinstance(mcp_auth_value, dict):
            # Dict values are used verbatim as headers in _get_auth_headers().
            self._mcp_auth_value = mcp_auth_value
        else:
            if self.auth_type == MCPAuth.basic:
                # Assuming mcp_auth_value is in format "username:password", convert it when updating
                mcp_auth_value = to_basic_auth(mcp_auth_value)
            self._mcp_auth_value = mcp_auth_value

    def _get_auth_headers(self) -> dict:
        """Generate authentication headers based on auth type."""
        headers = {}

        if self._mcp_auth_value:
            if isinstance(self._mcp_auth_value, str):
                if self.auth_type == MCPAuth.bearer_token:
                    headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
                elif self.auth_type == MCPAuth.basic:
                    headers["Authorization"] = f"Basic {self._mcp_auth_value}"
                elif self.auth_type == MCPAuth.api_key:
                    headers["X-API-Key"] = self._mcp_auth_value
                elif self.auth_type == MCPAuth.authorization:
                    headers["Authorization"] = self._mcp_auth_value
                elif self.auth_type == MCPAuth.oauth2:
                    headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
                elif self.auth_type == MCPAuth.token:
                    headers["Authorization"] = f"token {self._mcp_auth_value}"
            elif isinstance(self._mcp_auth_value, dict):
                headers.update(self._mcp_auth_value)
        # Note: aws_sigv4 auth is not handled here — SigV4 requires per-request
        # signing (including the body hash), so it uses httpx.Auth flow instead
        # of static headers. See MCPSigV4Auth and _create_httpx_client_factory().

        # update the headers with the extra headers
        if self.extra_headers:
            headers.update(self.extra_headers)

        return headers

    def _create_httpx_client_factory(self) -> Callable[..., httpx.AsyncClient]:
        """
        Create a custom httpx client factory that uses LiteLLM's SSL configuration.

        This factory follows the same CA bundle path logic as http_handler.py:
        1. Check ssl_verify parameter (can be SSLContext, bool, or path to CA bundle)
        2. Check SSL_VERIFY environment variable
        3. Check SSL_CERT_FILE environment variable
        4. Fall back to certifi CA bundle
        """

        def factory(
            *,
            headers: Optional[Dict[str, str]] = None,
            timeout: Optional[httpx.Timeout] = None,
            auth: Optional[httpx.Auth] = None,
        ) -> httpx.AsyncClient:
            """Create an httpx.AsyncClient with LiteLLM's SSL configuration."""
            # Get unified SSL configuration using the same logic as http_handler.py
            ssl_config = get_ssl_configuration(self.ssl_verify)

            verbose_logger.debug(
                f"MCP client using SSL configuration: {type(ssl_config).__name__}"
            )

            # Use SigV4 auth if configured and no explicit auth provided.
            # The MCP SDK's sse_client and streamable_http_client call this
            # factory without passing auth=, so self._aws_auth is used.
            # For non-SigV4 clients, self._aws_auth is None — no behavior change.
            effective_auth = auth if auth is not None else self._aws_auth

            return httpx.AsyncClient(
                headers=headers,
                timeout=timeout,
                auth=effective_auth,
                verify=ssl_config,
                follow_redirects=True,
            )

        return factory

    async def list_tools(self) -> List[MCPTool]:
        """List available tools from the server.

        Returns an empty list on failure (graceful degradation).
        """
        verbose_logger.debug(
            f"MCP client listing tools from {self.server_url or 'stdio'}"
        )

        async def _list_tools_operation(session: ClientSession):
            return await session.list_tools()

        try:
            result = await self.run_with_session(_list_tools_operation)
            tool_count = len(result.tools)
            tool_names = [tool.name for tool in result.tools]
            verbose_logger.info(
                f"MCP client listed {tool_count} tools from {self.server_url or 'stdio'}: {tool_names}"
            )
            return result.tools
        except asyncio.CancelledError:
            # Propagate cancellation — never swallow it into the fallback path.
            verbose_logger.warning("MCP client list_tools was cancelled")
            raise
        except Exception as e:
            error_type = type(e).__name__
            verbose_logger.exception(
                f"MCP client list_tools failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream during list_tools - "
                    "the MCP server may have crashed, disconnected, or timed out"
                )

            # Return empty list instead of raising to allow graceful degradation
            return []

    async def call_tool(
        self,
        call_tool_request_params: MCPCallToolRequestParams,
        host_progress_callback: Optional[Callable] = None,
    ) -> MCPCallToolResult:
        """
        Call an MCP Tool.

        Args:
            call_tool_request_params: tool name and arguments.
            host_progress_callback: optional async callback (progress, total)
                invoked as the server reports progress.

        Returns:
            The tool result; on failure, a result with isError=True and the
            error text as content (never raises, except on cancellation).
        """
        verbose_logger.info(
            f"MCP client calling tool '{call_tool_request_params.name}' with arguments: {call_tool_request_params.arguments}"
        )

        async def on_progress(
            progress: float, total: float | None, message: str | None
        ):
            percentage = (progress / total * 100) if total else 0
            verbose_logger.info(
                f"MCP Tool '{call_tool_request_params.name}' progress: "
                f"{progress}/{total} ({percentage:.0f}%) - {message or ''}"
            )

            # Forward to Host if callback provided
            if host_progress_callback:
                try:
                    await host_progress_callback(progress, total)
                except Exception as e:
                    verbose_logger.warning(f"Failed to forward to Host: {e}")

        async def _call_tool_operation(session: ClientSession):
            verbose_logger.debug("MCP client sending tool call to session")
            return await session.call_tool(
                name=call_tool_request_params.name,
                arguments=call_tool_request_params.arguments,
                progress_callback=on_progress,
            )

        try:
            tool_result = await self.run_with_session(_call_tool_operation)
            verbose_logger.info(
                f"MCP client tool call '{call_tool_request_params.name}' completed successfully"
            )
            return tool_result
        except asyncio.CancelledError:
            verbose_logger.warning("MCP client tool call was cancelled")
            raise
        except Exception as e:
            import traceback

            error_trace = traceback.format_exc()
            verbose_logger.debug(f"MCP client tool call traceback:\n{error_trace}")

            # Log detailed error information
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client call_tool failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Tool: {call_tool_request_params.name}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream - "
                    "the MCP server may have crashed, disconnected, or timed out."
                )

            # Return a default error result instead of raising
            return MCPCallToolResult(
                content=[
                    TextContent(type="text", text=f"{error_type}: {str(e)}")
                ],  # Error description as content
                isError=True,
            )

    async def list_prompts(self) -> List[Prompt]:
        """List available prompts from the server.

        Returns an empty list on failure (graceful degradation).
        """
        # Fixed: log messages previously said "tools"/"list_tools" (copy-paste).
        verbose_logger.debug(
            f"MCP client listing prompts from {self.server_url or 'stdio'}"
        )

        async def _list_prompts_operation(session: ClientSession):
            return await session.list_prompts()

        try:
            result = await self.run_with_session(_list_prompts_operation)
            prompt_count = len(result.prompts)
            prompt_names = [prompt.name for prompt in result.prompts]
            verbose_logger.info(
                f"MCP client listed {prompt_count} prompts from {self.server_url or 'stdio'}: {prompt_names}"
            )
            return result.prompts
        except asyncio.CancelledError:
            verbose_logger.warning("MCP client list_prompts was cancelled")
            raise
        except Exception as e:
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client list_prompts failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream during list_prompts - "
                    "the MCP server may have crashed, disconnected, or timed out"
                )

            # Return empty list instead of raising to allow graceful degradation
            return []

    async def get_prompt(
        self, get_prompt_request_params: GetPromptRequestParams
    ) -> GetPromptResult:
        """Fetch a prompt definition from the MCP server.

        Unlike the list_* methods, failures here are re-raised: a missing
        prompt cannot be degraded to an empty placeholder.
        """
        verbose_logger.info(
            f"MCP client fetching prompt '{get_prompt_request_params.name}' with arguments: {get_prompt_request_params.arguments}"
        )

        async def _get_prompt_operation(session: ClientSession):
            verbose_logger.debug("MCP client sending get_prompt request to session")
            return await session.get_prompt(
                name=get_prompt_request_params.name,
                arguments=get_prompt_request_params.arguments,
            )

        try:
            get_prompt_result = await self.run_with_session(_get_prompt_operation)
            verbose_logger.info(
                f"MCP client get_prompt '{get_prompt_request_params.name}' completed successfully"
            )
            return get_prompt_result
        except asyncio.CancelledError:
            verbose_logger.warning("MCP client get_prompt was cancelled")
            raise
        except Exception as e:
            import traceback

            error_trace = traceback.format_exc()
            verbose_logger.debug(f"MCP client get_prompt traceback:\n{error_trace}")

            # Log detailed error information
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client get_prompt failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Prompt: {get_prompt_request_params.name}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream during get_prompt - "
                    "the MCP server may have crashed, disconnected, or timed out."
                )

            raise

    async def list_resources(self) -> list[Resource]:
        """List available resources from the server.

        Returns an empty list on failure (graceful degradation).
        """
        verbose_logger.debug(
            f"MCP client listing resources from {self.server_url or 'stdio'}"
        )

        async def _list_resources_operation(session: ClientSession):
            return await session.list_resources()

        try:
            result = await self.run_with_session(_list_resources_operation)
            resource_count = len(result.resources)
            resource_names = [resource.name for resource in result.resources]
            verbose_logger.info(
                f"MCP client listed {resource_count} resources from {self.server_url or 'stdio'}: {resource_names}"
            )
            return result.resources
        except asyncio.CancelledError:
            verbose_logger.warning("MCP client list_resources was cancelled")
            raise
        except Exception as e:
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client list_resources failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream during list_resources - "
                    "the MCP server may have crashed, disconnected, or timed out"
                )

            # Return empty list instead of raising to allow graceful degradation
            return []

    async def list_resource_templates(self) -> list[ResourceTemplate]:
        """List available resource templates from the server.

        Returns an empty list on failure (graceful degradation).
        """
        verbose_logger.debug(
            f"MCP client listing resource templates from {self.server_url or 'stdio'}"
        )

        async def _list_resource_templates_operation(session: ClientSession):
            return await session.list_resource_templates()

        try:
            result = await self.run_with_session(_list_resource_templates_operation)
            resource_template_count = len(result.resourceTemplates)
            resource_template_names = [
                resourceTemplate.name for resourceTemplate in result.resourceTemplates
            ]
            verbose_logger.info(
                f"MCP client listed {resource_template_count} resource templates from {self.server_url or 'stdio'}: {resource_template_names}"
            )
            return result.resourceTemplates
        except asyncio.CancelledError:
            verbose_logger.warning("MCP client list_resource_templates was cancelled")
            raise
        except Exception as e:
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client list_resource_templates failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream during list_resource_templates - "
                    "the MCP server may have crashed, disconnected, or timed out"
                )

            # Return empty list instead of raising to allow graceful degradation
            return []

    async def read_resource(self, url: AnyUrl) -> ReadResourceResult:
        """Fetch resource contents from the MCP server.

        Failures are re-raised (after logging) — resource contents cannot
        be degraded to an empty placeholder.
        """
        verbose_logger.info(f"MCP client fetching resource '{url}'")

        async def _read_resource_operation(session: ClientSession):
            verbose_logger.debug("MCP client sending read_resource request to session")
            return await session.read_resource(url)

        try:
            read_resource_result = await self.run_with_session(_read_resource_operation)
            verbose_logger.info(
                f"MCP client read_resource '{url}' completed successfully"
            )
            return read_resource_result
        except asyncio.CancelledError:
            verbose_logger.warning("MCP client read_resource was cancelled")
            raise
        except Exception as e:
            import traceback

            error_trace = traceback.format_exc()
            verbose_logger.debug(f"MCP client read_resource traceback:\n{error_trace}")

            # Log detailed error information
            error_type = type(e).__name__
            verbose_logger.error(
                f"MCP client read_resource failed - "
                f"Error Type: {error_type}, "
                f"Error: {str(e)}, "
                f"Url: {url}, "
                f"Server: {self.server_url or 'stdio'}, "
                f"Transport: {self.transport_type}"
            )

            # Check if it's a stream/connection error
            if "BrokenResourceError" in error_type or "Broken" in error_type:
                verbose_logger.error(
                    "MCP client detected broken connection/stream during read_resource - "
                    "the MCP server may have crashed, disconnected, or timed out."
                )

            raise
|
||||
@@ -0,0 +1,159 @@
|
||||
import json
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
from mcp.types import CallToolResult as MCPCallToolResult
|
||||
from mcp.types import Tool as MCPTool
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from openai.types.responses.function_tool_param import FunctionToolParam
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
|
||||
from litellm.types.utils import ChatCompletionMessageToolCall
|
||||
|
||||
|
||||
########################################################
|
||||
# List MCP Tool functions
|
||||
########################################################
|
||||
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
    """Convert an MCP tool definition into an OpenAI chat-completions tool."""
    parameters = _normalize_mcp_input_schema(mcp_tool.inputSchema)
    function_definition = FunctionDefinition(
        name=mcp_tool.name,
        description=mcp_tool.description or "",
        parameters=parameters,
        strict=False,
    )
    return ChatCompletionToolParam(type="function", function=function_definition)
|
||||
|
||||
|
||||
def _normalize_mcp_input_schema(input_schema: dict) -> dict:
|
||||
"""
|
||||
Normalize MCP input schema to ensure it's valid for OpenAI function calling.
|
||||
|
||||
OpenAI requires that function parameters have:
|
||||
- type: 'object'
|
||||
- properties: dict (can be empty)
|
||||
- additionalProperties: false (recommended)
|
||||
"""
|
||||
if not input_schema:
|
||||
return {"type": "object", "properties": {}, "additionalProperties": False}
|
||||
|
||||
# Make a copy to avoid modifying the original
|
||||
normalized_schema = dict(input_schema)
|
||||
|
||||
# Ensure type is 'object'
|
||||
if "type" not in normalized_schema:
|
||||
normalized_schema["type"] = "object"
|
||||
|
||||
# Ensure properties exists (can be empty)
|
||||
if "properties" not in normalized_schema:
|
||||
normalized_schema["properties"] = {}
|
||||
|
||||
# Add additionalProperties if not present (recommended by OpenAI)
|
||||
if "additionalProperties" not in normalized_schema:
|
||||
normalized_schema["additionalProperties"] = False
|
||||
|
||||
return normalized_schema
|
||||
|
||||
|
||||
def transform_mcp_tool_to_openai_responses_api_tool(
    mcp_tool: MCPTool,
) -> FunctionToolParam:
    """Convert an MCP tool definition into an OpenAI Responses API function tool."""
    parameters = _normalize_mcp_input_schema(mcp_tool.inputSchema)
    return FunctionToolParam(
        type="function",
        name=mcp_tool.name,
        description=mcp_tool.description or "",
        parameters=parameters,
        strict=False,
    )
|
||||
|
||||
|
||||
async def load_mcp_tools(
    session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
    """
    Load all available MCP tools

    Args:
        session: The MCP session to use
        format: The format to convert the tools to
        By default, the tools are returned in MCP format.

        If format is set to "openai", the tools are converted to OpenAI API compatible tools.
    """
    listing = await session.list_tools()
    mcp_tools = listing.tools
    if format != "openai":
        return mcp_tools
    return [transform_mcp_tool_to_openai_tool(mcp_tool=tool) for tool in mcp_tools]
|
||||
|
||||
|
||||
########################################################
|
||||
# Call MCP Tool functions
|
||||
########################################################
|
||||
|
||||
|
||||
async def call_mcp_tool(
    session: ClientSession,
    call_tool_request_params: MCPCallToolRequestParams,
) -> MCPCallToolResult:
    """Invoke the named tool on the MCP server and return its raw result."""
    return await session.call_tool(
        name=call_tool_request_params.name,
        arguments=call_tool_request_params.arguments,
    )
|
||||
|
||||
|
||||
def _get_function_arguments(function: FunctionDefinition) -> dict:
|
||||
"""Helper to safely get and parse function arguments."""
|
||||
arguments = function.get("arguments", {})
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
return arguments if isinstance(arguments, dict) else {}
|
||||
|
||||
|
||||
def transform_openai_tool_call_request_to_mcp_tool_call_request(
    openai_tool: Union[ChatCompletionMessageToolCall, Dict],
) -> MCPCallToolRequestParams:
    """Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
    fn = openai_tool["function"]
    return MCPCallToolRequestParams(
        name=fn["name"],
        arguments=_get_function_arguments(fn),
    )
|
||||
|
||||
|
||||
async def call_openai_tool(
    session: ClientSession,
    openai_tool: ChatCompletionMessageToolCall,
) -> MCPCallToolResult:
    """
    Call an OpenAI tool using MCP client.

    Args:
        session: The MCP session to use
        openai_tool: The OpenAI tool to call. You can get this from the `choices[0].message.tool_calls[0]` of the response from the OpenAI API.
    Returns:
        The result of the MCP tool call.
    """
    request_params = transform_openai_tool_call_request_to_mcp_tool_call_request(
        openai_tool=openai_tool,
    )
    return await call_mcp_tool(
        session=session,
        call_tool_request_params=request_params,
    )
|
||||
Reference in New Issue
Block a user