# NOTE(review): the lines below are web-page scraping residue (file-browser
# chrome: filename, size, view links) — not source code. Kept as comments so
# the module remains importable instead of raising a SyntaxError.
# Files
# lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/proxy/agent_endpoints/a2a_endpoints.py
# 501 lines
# 18 KiB
# Python
# Raw Normal View History

"""
A2A Protocol endpoints for LiteLLM Proxy.
Allows clients to invoke agents through LiteLLM using the A2A protocol.
The A2A SDK can point to LiteLLM's URL and invoke agents registered with LiteLLM.
"""
import json
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.agent_endpoints.utils import merge_agent_headers
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import all_litellm_params
# Shared FastAPI router for all A2A endpoints in this module; the proxy app
# mounts it to expose the routes declared below.
router = APIRouter()
def _jsonrpc_error(
    request_id: Optional[str],
    code: int,
    message: str,
    status_code: int = 400,
) -> JSONResponse:
    """Build a JSON-RPC 2.0 error envelope and wrap it in a JSONResponse.

    Args:
        request_id: JSON-RPC id echoed back to the caller (may be None).
        code: JSON-RPC error code (e.g. -32600 invalid request).
        message: Human-readable error description.
        status_code: HTTP status to send alongside the JSON-RPC error.
    """
    payload = {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {"code": code, "message": message},
    }
    return JSONResponse(content=payload, status_code=status_code)
def _get_agent(agent_id: str):
    """Resolve an agent from the global registry by ID, falling back to name.

    Returns the registered agent object, or None when neither lookup matches.
    """
    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry

    found = global_agent_registry.get_agent_by_id(agent_id=agent_id)
    if found is not None:
        return found
    # The path parameter may be an agent alias rather than a UUID.
    return global_agent_registry.get_agent_by_name(agent_name=agent_id)
def _enforce_inbound_trace_id(agent: Any, request: Request) -> None:
    """Raise 400 if agent requires x-litellm-trace-id on inbound calls and it is missing."""
    configured_params = agent.litellm_params or {}
    if not configured_params.get("require_trace_id_on_calls_to_agent"):
        # Enforcement is opt-in per agent; nothing to check otherwise.
        return

    from litellm.proxy.litellm_pre_call_utils import get_chain_id_from_headers

    if get_chain_id_from_headers(dict(request.headers)):
        return
    raise HTTPException(
        status_code=400,
        detail=(
            f"Agent '{agent.agent_id}' requires x-litellm-trace-id header "
            "on all inbound requests."
        ),
    )
async def _handle_stream_message(
    api_base: Optional[str],
    request_id: str,
    params: dict,
    litellm_params: Optional[dict] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[dict] = None,
    proxy_server_request: Optional[dict] = None,
    *,
    agent_extra_headers: Optional[Dict[str, str]] = None,
    user_api_key_dict: Optional[UserAPIKeyAuth] = None,
    request_data: Optional[dict] = None,
    proxy_logging_obj: Optional[Any] = None,
) -> StreamingResponse:
    """Handle message/stream method via SDK functions.

    Streams the agent response as NDJSON (one JSON object per line).

    When user_api_key_dict, request_data, and proxy_logging_obj are ALL
    provided, uses common_request_processing.async_streaming_data_generator
    with NDJSON serializers so proxy hooks and cost injection apply;
    otherwise chunks are serialized and yielded directly.

    Args:
        api_base: Backend agent URL; may be None for providers that derive
            the endpoint from the model string (completion bridge).
        request_id: JSON-RPC request id echoed in every error payload.
        params: JSON-RPC params used to build MessageSendParams.
        litellm_params: Per-agent litellm configuration passed through to the SDK.
        agent_id: Canonical agent id, forwarded for logging/attribution.
        metadata: Request metadata forwarded to the SDK call.
        proxy_server_request: Original proxy request context, if available.
        agent_extra_headers: Merged static + dynamic headers for the backend.
        user_api_key_dict / request_data / proxy_logging_obj: When all present,
            enables the proxy-hook streaming path.

    Returns:
        StreamingResponse with media type application/x-ndjson. Errors are
        emitted in-band as JSON-RPC error lines (except HTTPExceptions, which
        propagate).
    """
    from litellm.a2a_protocol import asend_message_streaming
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE

    if not A2A_SDK_AVAILABLE:
        # Missing optional dependency: report in-band as a single NDJSON line
        # so streaming clients still receive a well-formed JSON-RPC error.
        async def _error_stream():
            yield json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {
                        "code": -32603,
                        "message": "Server error: 'a2a' package not installed",
                    },
                }
            ) + "\n"

        return StreamingResponse(_error_stream(), media_type="application/x-ndjson")

    from a2a.types import MessageSendParams, SendStreamingMessageRequest

    use_proxy_hooks = (
        user_api_key_dict is not None
        and request_data is not None
        and proxy_logging_obj is not None
    )

    async def stream_response():
        try:
            a2a_request = SendStreamingMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            a2a_stream = asend_message_streaming(
                request=a2a_request,
                api_base=api_base,
                litellm_params=litellm_params,
                agent_id=agent_id,
                metadata=metadata,
                proxy_server_request=proxy_server_request,
                agent_extra_headers=agent_extra_headers,
            )
            # The redundant None checks mirror use_proxy_hooks but narrow the
            # Optional types for the type checker.
            if (
                use_proxy_hooks
                and user_api_key_dict is not None
                and request_data is not None
                and proxy_logging_obj is not None
            ):
                from litellm.proxy.common_request_processing import (
                    ProxyBaseLLMRequestProcessing,
                )

                def _ndjson_chunk(chunk: Any) -> str:
                    # Pydantic chunks serialize via model_dump; plain dicts
                    # pass straight through to json.dumps.
                    if hasattr(chunk, "model_dump"):
                        obj = chunk.model_dump(mode="json", exclude_none=True)
                    else:
                        obj = chunk
                    return json.dumps(obj) + "\n"

                def _ndjson_error(proxy_exc: Any) -> str:
                    return (
                        json.dumps(
                            {
                                "jsonrpc": "2.0",
                                "id": request_id,
                                "error": {
                                    "code": -32603,
                                    "message": getattr(
                                        proxy_exc,
                                        "message",
                                        f"Streaming error: {proxy_exc!s}",
                                    ),
                                },
                            }
                        )
                        + "\n"
                    )

                async for (
                    line
                ) in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
                    response=a2a_stream,
                    user_api_key_dict=user_api_key_dict,
                    request_data=request_data,
                    proxy_logging_obj=proxy_logging_obj,
                    serialize_chunk=_ndjson_chunk,
                    serialize_error=_ndjson_error,
                ):
                    yield line
            else:
                # Hook-free path: serialize each chunk directly as NDJSON.
                async for chunk in a2a_stream:
                    if hasattr(chunk, "model_dump"):
                        yield json.dumps(
                            chunk.model_dump(mode="json", exclude_none=True)
                        ) + "\n"
                    else:
                        yield json.dumps(chunk) + "\n"
        except Exception as e:
            verbose_proxy_logger.exception(f"Error streaming A2A response: {e}")
            if (
                use_proxy_hooks
                and proxy_logging_obj is not None
                and user_api_key_dict is not None
                and request_data is not None
            ):
                # Give proxy failure hooks a chance to transform the exception
                # (e.g. into a sanitized HTTPException for the client).
                transformed_exception = await proxy_logging_obj.post_call_failure_hook(
                    user_api_key_dict=user_api_key_dict,
                    original_exception=e,
                    request_data=request_data,
                )
                if transformed_exception is not None:
                    e = transformed_exception
            if isinstance(e, HTTPException):
                # BUGFIX: this was a bare `raise`, which re-raises the ORIGINAL
                # caught exception and silently discards any transformed
                # exception assigned to `e` above. `raise e` propagates the
                # (possibly transformed) HTTPException as intended.
                raise e
            # Non-HTTP errors are reported in-band as a JSON-RPC error line.
            yield json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {"code": -32603, "message": f"Streaming error: {str(e)}"},
                }
            ) + "\n"

    return StreamingResponse(stream_response(), media_type="application/x-ndjson")
@router.get(
    "/a2a/{agent_id}/.well-known/agent-card.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.get(
    "/a2a/{agent_id}/.well-known/agent.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_agent_card(
    agent_id: str,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get the agent card for an agent (A2A discovery endpoint).

    Supports both standard paths:
    - /.well-known/agent-card.json
    - /.well-known/agent.json

    The URL in the agent card is rewritten to point to the LiteLLM proxy,
    so all subsequent A2A calls go through LiteLLM for logging and cost tracking.

    Raises 404 when the agent is unknown, 403 when the caller's key/team may
    not use it, and 500 for unexpected failures.
    """
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )

    try:
        agent = _get_agent(agent_id)
        if agent is None:
            raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")

        # Permission check (admin users are skipped inside the handler).
        allowed = await AgentRequestHandler.is_agent_allowed(
            agent_id=agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )

        # Shallow-copy the card so the registry's copy keeps its original URL,
        # then point the advertised URL at this proxy.
        card = dict(agent.agent_card_params)
        card["url"] = f"{str(request.base_url).rstrip('/')}/a2a/{agent_id}"
        verbose_proxy_logger.debug(
            f"Returning agent card for '{agent_id}' with proxy URL: {card['url']}"
        )
        return JSONResponse(content=card)
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting agent card: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post(
    "/a2a/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/v1/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def invoke_agent_a2a(  # noqa: PLR0915
    agent_id: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Invoke an agent using the A2A protocol (JSON-RPC 2.0).

    Supported methods:
    - message/send: Send a message and get a response
    - message/stream: Send a message and stream the response

    Flow: validate the JSON-RPC envelope, hoist litellm params out of
    ``params``, resolve and authorize the agent, run proxy pre-call logic,
    build backend headers, then dispatch to the SDK send/stream functions.
    All errors (other than HTTPExceptions) are returned as JSON-RPC error
    responses rather than raised.
    """
    from litellm.a2a_protocol import asend_message
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )
    from litellm.proxy.proxy_server import (
        general_settings,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    # Initialized before the try so the final except can still read body.get("id")
    # even if request.json() itself fails.
    body = {}
    try:
        body = await request.json()
        verbose_proxy_logger.debug(f"A2A request for agent '{agent_id}': {body}")
        # Validate JSON-RPC format
        if body.get("jsonrpc") != "2.0":
            return _jsonrpc_error(
                body.get("id"), -32600, "Invalid Request: jsonrpc must be '2.0'"
            )
        request_id = body.get("id")
        method = body.get("method")
        params = body.get("params", {})
        if params:
            # extract any litellm params from the params - eg. 'guardrails' -
            # move them to the top-level body so proxy pre-call logic sees them,
            # and strip them from params before they reach the A2A SDK.
            params_to_remove = []
            for key, value in params.items():
                if key in all_litellm_params:
                    params_to_remove.append(key)
                    body[key] = value
            for key in params_to_remove:
                params.pop(key)
        if not A2A_SDK_AVAILABLE:
            return _jsonrpc_error(
                request_id,
                -32603,
                "Server error: 'a2a' package not installed. Please install 'a2a-sdk'.",
                500,
            )
        # Find the agent (by id, falling back to name)
        agent = _get_agent(agent_id)
        if agent is None:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' not found", 404
            )
        is_allowed = await AgentRequestHandler.is_agent_allowed(
            agent_id=agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )
        # 400 if the agent mandates an inbound x-litellm-trace-id header.
        _enforce_inbound_trace_id(agent, request)
        # Get backend URL and agent name
        agent_url = agent.agent_card_params.get("url")
        agent_name = agent.agent_card_params.get("name", agent_id)
        # Get litellm_params (may include custom_llm_provider for completion bridge)
        litellm_params = agent.litellm_params or {}
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        # URL is required unless using completion bridge with a provider that
        # derives the endpoint from the model string
        # (e.g., bedrock/agentcore derives endpoint from ARN in model string)
        if not agent_url and not custom_llm_provider:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500
            )
        verbose_proxy_logger.info(
            f"Proxying A2A request to agent '{agent_id}' at {agent_url or 'completion-bridge'}"
        )
        # Set up data dict for litellm processing: tag the request with the
        # agent id and a synthetic model name for logging/cost tracking.
        if "metadata" not in body:
            body["metadata"] = {}
        body["metadata"]["agent_id"] = agent.agent_id
        body.update(
            {
                "model": f"a2a_agent/{agent_name}",
                "custom_llm_provider": "a2a_agent",
            }
        )
        # Add litellm data (user_api_key, user_id, team_id, etc.)
        from litellm.proxy.common_request_processing import (
            ProxyBaseLLMRequestProcessing,
        )

        processor = ProxyBaseLLMRequestProcessing(data=body)
        data, logging_obj = await processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="asend_message",
            version=version,
        )
        # Build merged headers for the backend agent. Header names are
        # lower-cased for case-insensitive matching.
        static_headers: Dict[str, str] = dict(agent.static_headers or {})
        raw_headers = dict(request.headers)
        normalized = {k.lower(): v for k, v in raw_headers.items()}
        dynamic_headers: Dict[str, str] = {}
        # 1. Admin-configured extra_headers: forward named headers from client request
        if agent.extra_headers:
            for header_name in agent.extra_headers:
                val = normalized.get(header_name.lower())
                if val is not None:
                    dynamic_headers[header_name] = val
        # 2. Convention-based forwarding: x-a2a-{agent_id_or_name}-{header_name}
        #    Matches both agent_id (UUID) and agent_name (alias), case-insensitive.
        for alias in (agent.agent_id.lower(), agent.agent_name.lower()):
            prefix = f"x-a2a-{alias}-"
            for key, val in normalized.items():
                if key.startswith(prefix):
                    header_name = key[len(prefix) :]
                    if header_name:
                        dynamic_headers[header_name] = val
        agent_extra_headers = merge_agent_headers(
            dynamic_headers=dynamic_headers or None,
            static_headers=static_headers or None,
        )
        # Route through SDK functions
        if method == "message/send":
            from a2a.types import MessageSendParams, SendMessageRequest

            a2a_request = SendMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            response = await asend_message(
                request=a2a_request,
                api_base=agent_url,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                litellm_logging_obj=logging_obj,
                agent_extra_headers=agent_extra_headers,
            )
            # Post-call hooks may rewrite the response (e.g. guardrails).
            response = await proxy_logging_obj.post_call_success_hook(
                user_api_key_dict=user_api_key_dict,
                data=data,
                response=response,
            )
            return JSONResponse(
                content=(
                    response.model_dump(mode="json", exclude_none=True)  # type: ignore
                    if hasattr(response, "model_dump")
                    else response
                )
            )
        elif method == "message/stream":
            # NDJSON streaming path; proxy hooks apply inside the helper.
            return await _handle_stream_message(
                api_base=agent_url,
                request_id=request_id,
                params=params,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                agent_extra_headers=agent_extra_headers,
                user_api_key_dict=user_api_key_dict,
                request_data=data,
                proxy_logging_obj=proxy_logging_obj,
            )
        else:
            return _jsonrpc_error(request_id, -32601, f"Method '{method}' not found")
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error invoking agent: {e}")
        return _jsonrpc_error(body.get("id"), -32603, f"Internal error: {str(e)}", 500)