chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
LiteLLM A2A - Wrapper for invoking A2A protocol agents.
|
||||
|
||||
This module provides a thin wrapper around the official `a2a` SDK that:
|
||||
- Handles httpx client creation and agent card resolution
|
||||
- Adds LiteLLM logging via @client decorator
|
||||
- Matches the A2A SDK interface (SendMessageRequest, SendMessageResponse, etc.)
|
||||
|
||||
Example usage (standalone functions with @client decorator):
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message
|
||||
from a2a.types import SendMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
base_url="http://localhost:10001",
|
||||
request=request,
|
||||
)
|
||||
print(response.model_dump(mode='json', exclude_none=True))
|
||||
```
|
||||
|
||||
Example usage (class-based):
|
||||
```python
|
||||
from litellm.a2a_protocol import A2AClient
|
||||
|
||||
client = A2AClient(base_url="http://localhost:10001")
|
||||
response = await client.send_message(request)
|
||||
```
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.client import A2AClient
|
||||
from litellm.a2a_protocol.exceptions import (
|
||||
A2AAgentCardError,
|
||||
A2AConnectionError,
|
||||
A2AError,
|
||||
A2ALocalhostURLError,
|
||||
)
|
||||
from litellm.a2a_protocol.main import (
|
||||
aget_agent_card,
|
||||
asend_message,
|
||||
asend_message_streaming,
|
||||
create_a2a_client,
|
||||
send_message,
|
||||
)
|
||||
from litellm.types.agents import LiteLLMSendMessageResponse
|
||||
|
||||
__all__ = [
|
||||
# Client
|
||||
"A2AClient",
|
||||
# Functions
|
||||
"asend_message",
|
||||
"send_message",
|
||||
"asend_message_streaming",
|
||||
"aget_agent_card",
|
||||
"create_a2a_client",
|
||||
# Response types
|
||||
"LiteLLMSendMessageResponse",
|
||||
# Exceptions
|
||||
"A2AError",
|
||||
"A2AConnectionError",
|
||||
"A2AAgentCardError",
|
||||
"A2ALocalhostURLError",
|
||||
]
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Custom A2A Card Resolver for LiteLLM.
|
||||
|
||||
Extends the A2A SDK's card resolver to support multiple well-known paths.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import LOCALHOST_URL_PATTERNS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard
|
||||
|
||||
# Runtime imports with availability check
|
||||
_A2ACardResolver: Any = None
|
||||
AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent-card.json"
|
||||
PREV_AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent.json"
|
||||
|
||||
try:
|
||||
from a2a.client import A2ACardResolver as _A2ACardResolver # type: ignore[no-redef]
|
||||
from a2a.utils.constants import ( # type: ignore[no-redef]
|
||||
AGENT_CARD_WELL_KNOWN_PATH,
|
||||
PREV_AGENT_CARD_WELL_KNOWN_PATH,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def is_localhost_or_internal_url(url: Optional[str]) -> bool:
|
||||
"""
|
||||
Check if a URL is a localhost or internal URL.
|
||||
|
||||
This detects common development URLs that are accidentally left in
|
||||
agent cards when deploying to production.
|
||||
|
||||
Args:
|
||||
url: The URL to check
|
||||
|
||||
Returns:
|
||||
True if the URL is localhost/internal
|
||||
"""
|
||||
if not url:
|
||||
return False
|
||||
|
||||
url_lower = url.lower()
|
||||
|
||||
return any(pattern in url_lower for pattern in LOCALHOST_URL_PATTERNS)
|
||||
|
||||
|
||||
def fix_agent_card_url(agent_card: "AgentCard", base_url: str) -> "AgentCard":
|
||||
"""
|
||||
Fix the agent card URL if it contains a localhost/internal address.
|
||||
|
||||
Many A2A agents are deployed with agent cards that contain internal URLs
|
||||
like "http://0.0.0.0:8001/" or "http://localhost:8000/". This function
|
||||
replaces such URLs with the provided base_url.
|
||||
|
||||
Args:
|
||||
agent_card: The agent card to fix
|
||||
base_url: The base URL to use as replacement
|
||||
|
||||
Returns:
|
||||
The agent card with the URL fixed if necessary
|
||||
"""
|
||||
card_url = getattr(agent_card, "url", None)
|
||||
|
||||
if card_url and is_localhost_or_internal_url(card_url):
|
||||
# Normalize base_url to ensure it ends with /
|
||||
fixed_url = base_url.rstrip("/") + "/"
|
||||
agent_card.url = fixed_url
|
||||
|
||||
return agent_card
|
||||
|
||||
|
||||
class LiteLLMA2ACardResolver(_A2ACardResolver): # type: ignore[misc]
|
||||
"""
|
||||
Custom A2A card resolver that supports multiple well-known paths.
|
||||
|
||||
Extends the base A2ACardResolver to try both:
|
||||
- /.well-known/agent-card.json (standard)
|
||||
- /.well-known/agent.json (previous/alternative)
|
||||
"""
|
||||
|
||||
async def get_agent_card(
|
||||
self,
|
||||
relative_card_path: Optional[str] = None,
|
||||
http_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> "AgentCard":
|
||||
"""
|
||||
Fetch the agent card, trying multiple well-known paths.
|
||||
|
||||
First tries the standard path, then falls back to the previous path.
|
||||
|
||||
Args:
|
||||
relative_card_path: Optional path to the agent card endpoint.
|
||||
If None, tries both well-known paths.
|
||||
http_kwargs: Optional dictionary of keyword arguments to pass to httpx.get
|
||||
|
||||
Returns:
|
||||
AgentCard from the A2A agent
|
||||
|
||||
Raises:
|
||||
A2AClientHTTPError or A2AClientJSONError if both paths fail
|
||||
"""
|
||||
# If a specific path is provided, use the parent implementation
|
||||
if relative_card_path is not None:
|
||||
return await super().get_agent_card(
|
||||
relative_card_path=relative_card_path,
|
||||
http_kwargs=http_kwargs,
|
||||
)
|
||||
|
||||
# Try both well-known paths
|
||||
paths = [
|
||||
AGENT_CARD_WELL_KNOWN_PATH,
|
||||
PREV_AGENT_CARD_WELL_KNOWN_PATH,
|
||||
]
|
||||
|
||||
last_error = None
|
||||
for path in paths:
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
f"Attempting to fetch agent card from {self.base_url}{path}"
|
||||
)
|
||||
return await super().get_agent_card(
|
||||
relative_card_path=path,
|
||||
http_kwargs=http_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Failed to fetch agent card from {self.base_url}{path}: {e}"
|
||||
)
|
||||
last_error = e
|
||||
continue
|
||||
|
||||
# If we get here, all paths failed - re-raise the last error
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
|
||||
# This shouldn't happen, but just in case
|
||||
raise Exception(
|
||||
f"Failed to fetch agent card from {self.base_url}. "
|
||||
f"Tried paths: {', '.join(paths)}"
|
||||
)
|
||||
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
LiteLLM A2A Client class.
|
||||
|
||||
Provides a class-based interface for A2A agent invocation.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Dict, Optional
|
||||
|
||||
from litellm.types.agents import LiteLLMSendMessageResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
SendMessageRequest,
|
||||
SendStreamingMessageRequest,
|
||||
SendStreamingMessageResponse,
|
||||
)
|
||||
|
||||
|
||||
class A2AClient:
|
||||
"""
|
||||
LiteLLM wrapper for A2A agent invocation.
|
||||
|
||||
Creates the underlying A2A client once on first use and reuses it.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from litellm.a2a_protocol import A2AClient
|
||||
from a2a.types import SendMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
client = A2AClient(base_url="http://localhost:10001")
|
||||
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
)
|
||||
)
|
||||
response = await client.send_message(request)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
timeout: float = 60.0,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the A2A client wrapper.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
|
||||
timeout: Request timeout in seconds (default: 60.0)
|
||||
extra_headers: Optional additional headers to include in requests
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.extra_headers = extra_headers
|
||||
self._a2a_client: Optional["A2AClientType"] = None
|
||||
|
||||
async def _get_client(self) -> "A2AClientType":
|
||||
"""Get or create the underlying A2A client."""
|
||||
if self._a2a_client is None:
|
||||
from litellm.a2a_protocol.main import create_a2a_client
|
||||
|
||||
self._a2a_client = await create_a2a_client(
|
||||
base_url=self.base_url,
|
||||
timeout=self.timeout,
|
||||
extra_headers=self.extra_headers,
|
||||
)
|
||||
return self._a2a_client
|
||||
|
||||
async def get_agent_card(self) -> "AgentCard":
|
||||
"""Fetch the agent card from the server."""
|
||||
from litellm.a2a_protocol.main import aget_agent_card
|
||||
|
||||
return await aget_agent_card(
|
||||
base_url=self.base_url,
|
||||
timeout=self.timeout,
|
||||
extra_headers=self.extra_headers,
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self, request: "SendMessageRequest"
|
||||
) -> LiteLLMSendMessageResponse:
|
||||
"""Send a message to the A2A agent."""
|
||||
from litellm.a2a_protocol.main import asend_message
|
||||
|
||||
a2a_client = await self._get_client()
|
||||
return await asend_message(a2a_client=a2a_client, request=request)
|
||||
|
||||
async def send_message_streaming(
|
||||
self, request: "SendStreamingMessageRequest"
|
||||
) -> AsyncIterator["SendStreamingMessageResponse"]:
|
||||
"""Send a streaming message to the A2A agent."""
|
||||
from litellm.a2a_protocol.main import asend_message_streaming
|
||||
|
||||
a2a_client = await self._get_client()
|
||||
async for chunk in asend_message_streaming(
|
||||
a2a_client=a2a_client, request=request
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Cost calculator for A2A (Agent-to-Agent) calls.
|
||||
|
||||
Supports dynamic cost parameters that allow platform owners
|
||||
to define custom costs per agent query or per token.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
Logging as LitellmLoggingObject,
|
||||
)
|
||||
else:
|
||||
LitellmLoggingObject = Any
|
||||
|
||||
|
||||
class A2ACostCalculator:
|
||||
@staticmethod
|
||||
def calculate_a2a_cost(
|
||||
litellm_logging_obj: Optional[LitellmLoggingObject],
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the cost of an A2A send_message call.
|
||||
|
||||
Supports multiple cost parameters for platform owners:
|
||||
- cost_per_query: Fixed cost per query
|
||||
- input_cost_per_token + output_cost_per_token: Token-based pricing
|
||||
|
||||
Priority order:
|
||||
1. response_cost - if set directly (backward compatibility)
|
||||
2. cost_per_query - fixed cost per query
|
||||
3. input_cost_per_token + output_cost_per_token - token-based cost
|
||||
4. Default to 0.0
|
||||
|
||||
Args:
|
||||
litellm_logging_obj: The LiteLLM logging object containing call details
|
||||
|
||||
Returns:
|
||||
float: The cost of the A2A call
|
||||
"""
|
||||
if litellm_logging_obj is None:
|
||||
return 0.0
|
||||
|
||||
model_call_details = litellm_logging_obj.model_call_details
|
||||
|
||||
# Check if user set a custom response cost (backward compatibility)
|
||||
response_cost = model_call_details.get("response_cost", None)
|
||||
if response_cost is not None:
|
||||
return float(response_cost)
|
||||
|
||||
# Get litellm_params for cost parameters
|
||||
litellm_params = model_call_details.get("litellm_params", {}) or {}
|
||||
|
||||
# Check for cost_per_query (fixed cost per query)
|
||||
if litellm_params.get("cost_per_query") is not None:
|
||||
return float(litellm_params["cost_per_query"])
|
||||
|
||||
# Check for token-based pricing
|
||||
input_cost_per_token = litellm_params.get("input_cost_per_token")
|
||||
output_cost_per_token = litellm_params.get("output_cost_per_token")
|
||||
|
||||
if input_cost_per_token is not None or output_cost_per_token is not None:
|
||||
return A2ACostCalculator._calculate_token_based_cost(
|
||||
model_call_details=model_call_details,
|
||||
input_cost_per_token=input_cost_per_token,
|
||||
output_cost_per_token=output_cost_per_token,
|
||||
)
|
||||
|
||||
# Default to 0.0 for A2A calls
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _calculate_token_based_cost(
|
||||
model_call_details: dict,
|
||||
input_cost_per_token: Optional[float],
|
||||
output_cost_per_token: Optional[float],
|
||||
) -> float:
|
||||
"""
|
||||
Calculate cost based on token usage and per-token pricing.
|
||||
|
||||
Args:
|
||||
model_call_details: The model call details containing usage
|
||||
input_cost_per_token: Cost per input token (can be None, defaults to 0)
|
||||
output_cost_per_token: Cost per output token (can be None, defaults to 0)
|
||||
|
||||
Returns:
|
||||
float: The calculated cost
|
||||
"""
|
||||
# Get usage from model_call_details
|
||||
usage = model_call_details.get("usage")
|
||||
if usage is None:
|
||||
return 0.0
|
||||
|
||||
# Get token counts
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
|
||||
# Calculate costs
|
||||
input_cost = prompt_tokens * (
|
||||
float(input_cost_per_token) if input_cost_per_token else 0.0
|
||||
)
|
||||
output_cost = completion_tokens * (
|
||||
float(output_cost_per_token) if output_cost_per_token else 0.0
|
||||
)
|
||||
|
||||
return input_cost + output_cost
|
||||
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
A2A Protocol Exception Mapping Utils.
|
||||
|
||||
Maps A2A SDK exceptions to LiteLLM A2A exception types.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.card_resolver import (
|
||||
fix_agent_card_url,
|
||||
is_localhost_or_internal_url,
|
||||
)
|
||||
from litellm.a2a_protocol.exceptions import (
|
||||
A2AAgentCardError,
|
||||
A2AConnectionError,
|
||||
A2AError,
|
||||
A2ALocalhostURLError,
|
||||
)
|
||||
from litellm.constants import CONNECTION_ERROR_PATTERNS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
|
||||
|
||||
# Runtime import
|
||||
A2A_SDK_AVAILABLE = False
|
||||
try:
|
||||
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
|
||||
|
||||
A2A_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
_A2AClient = None # type: ignore[assignment, misc]
|
||||
|
||||
|
||||
class A2AExceptionCheckers:
|
||||
"""
|
||||
Helper class for checking various A2A error conditions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def is_connection_error(error_str: str) -> bool:
|
||||
"""
|
||||
Check if an error string indicates a connection error.
|
||||
|
||||
Args:
|
||||
error_str: The error string to check
|
||||
|
||||
Returns:
|
||||
True if the error indicates a connection issue
|
||||
"""
|
||||
if not isinstance(error_str, str):
|
||||
return False
|
||||
|
||||
error_str_lower = error_str.lower()
|
||||
return any(pattern in error_str_lower for pattern in CONNECTION_ERROR_PATTERNS)
|
||||
|
||||
@staticmethod
|
||||
def is_localhost_url(url: Optional[str]) -> bool:
|
||||
"""
|
||||
Check if a URL is a localhost/internal URL.
|
||||
|
||||
Args:
|
||||
url: The URL to check
|
||||
|
||||
Returns:
|
||||
True if the URL is localhost/internal
|
||||
"""
|
||||
return is_localhost_or_internal_url(url)
|
||||
|
||||
@staticmethod
|
||||
def is_agent_card_error(error_str: str) -> bool:
|
||||
"""
|
||||
Check if an error string indicates an agent card error.
|
||||
|
||||
Args:
|
||||
error_str: The error string to check
|
||||
|
||||
Returns:
|
||||
True if the error is related to agent card fetching/parsing
|
||||
"""
|
||||
if not isinstance(error_str, str):
|
||||
return False
|
||||
|
||||
error_str_lower = error_str.lower()
|
||||
agent_card_patterns = [
|
||||
"agent card",
|
||||
"agent-card",
|
||||
".well-known",
|
||||
"card not found",
|
||||
"invalid agent",
|
||||
]
|
||||
return any(pattern in error_str_lower for pattern in agent_card_patterns)
|
||||
|
||||
|
||||
def map_a2a_exception(
|
||||
original_exception: Exception,
|
||||
card_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> Exception:
|
||||
"""
|
||||
Map an A2A SDK exception to a LiteLLM A2A exception type.
|
||||
|
||||
Args:
|
||||
original_exception: The original exception from the A2A SDK
|
||||
card_url: The URL from the agent card (if available)
|
||||
api_base: The original API base URL
|
||||
model: The model/agent name
|
||||
|
||||
Returns:
|
||||
A mapped LiteLLM A2A exception
|
||||
|
||||
Raises:
|
||||
A2ALocalhostURLError: If the error is a connection error to a localhost URL
|
||||
A2AConnectionError: If the error is a general connection error
|
||||
A2AAgentCardError: If the error is related to agent card issues
|
||||
A2AError: For other A2A-related errors
|
||||
"""
|
||||
error_str = str(original_exception)
|
||||
|
||||
# Check for localhost URL connection error (special case - retryable)
|
||||
if (
|
||||
card_url
|
||||
and api_base
|
||||
and A2AExceptionCheckers.is_localhost_url(card_url)
|
||||
and A2AExceptionCheckers.is_connection_error(error_str)
|
||||
):
|
||||
raise A2ALocalhostURLError(
|
||||
localhost_url=card_url,
|
||||
base_url=api_base,
|
||||
original_error=original_exception,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Check for agent card errors
|
||||
if A2AExceptionCheckers.is_agent_card_error(error_str):
|
||||
raise A2AAgentCardError(
|
||||
message=error_str,
|
||||
url=api_base,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Check for general connection errors
|
||||
if A2AExceptionCheckers.is_connection_error(error_str):
|
||||
raise A2AConnectionError(
|
||||
message=error_str,
|
||||
url=card_url or api_base,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Default: wrap in generic A2AError
|
||||
raise A2AError(
|
||||
message=error_str,
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
def handle_a2a_localhost_retry(
|
||||
error: A2ALocalhostURLError,
|
||||
agent_card: Any,
|
||||
a2a_client: "A2AClientType",
|
||||
is_streaming: bool = False,
|
||||
) -> "A2AClientType":
|
||||
"""
|
||||
Handle A2ALocalhostURLError by fixing the URL and creating a new client.
|
||||
|
||||
This is called when we catch an A2ALocalhostURLError and want to retry
|
||||
with the corrected URL.
|
||||
|
||||
Args:
|
||||
error: The localhost URL error
|
||||
agent_card: The agent card object to fix
|
||||
a2a_client: The current A2A client
|
||||
is_streaming: Whether this is a streaming request (for logging)
|
||||
|
||||
Returns:
|
||||
A new A2A client with the fixed URL
|
||||
|
||||
Raises:
|
||||
ImportError: If the A2A SDK is not installed
|
||||
"""
|
||||
if not A2A_SDK_AVAILABLE or _A2AClient is None:
|
||||
raise ImportError(
|
||||
"A2A SDK is required for localhost retry handling. "
|
||||
"Install it with: pip install a2a"
|
||||
)
|
||||
|
||||
request_type = "streaming " if is_streaming else ""
|
||||
verbose_logger.warning(
|
||||
f"A2A {request_type}request to '{error.localhost_url}' failed: {error.original_error}. "
|
||||
f"Agent card contains localhost/internal URL. "
|
||||
f"Retrying with base_url '{error.base_url}'."
|
||||
)
|
||||
|
||||
# Fix the agent card URL
|
||||
fix_agent_card_url(agent_card, error.base_url)
|
||||
|
||||
# Create a new client with the fixed agent card (transport caches URL)
|
||||
return _A2AClient(
|
||||
httpx_client=a2a_client._transport.httpx_client, # type: ignore[union-attr]
|
||||
agent_card=agent_card,
|
||||
)
|
||||
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
A2A Protocol Exceptions.
|
||||
|
||||
Custom exception types for A2A protocol operations, following LiteLLM's exception pattern.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class A2AError(Exception):
|
||||
"""
|
||||
Base exception for A2A protocol errors.
|
||||
|
||||
Follows the same pattern as LiteLLM's main exceptions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
status_code: int = 500,
|
||||
llm_provider: str = "a2a_agent",
|
||||
model: Optional[str] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = f"litellm.A2AError: {message}"
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
self.response = response or httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(method="POST", url="https://litellm.ai"),
|
||||
)
|
||||
super().__init__(self.message)
|
||||
|
||||
def __str__(self) -> str:
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class A2AConnectionError(A2AError):
|
||||
"""
|
||||
Raised when connection to an A2A agent fails.
|
||||
|
||||
This typically occurs when:
|
||||
- The agent is unreachable
|
||||
- The agent card contains a localhost/internal URL
|
||||
- Network issues prevent connection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.url = url
|
||||
super().__init__(
|
||||
message=message,
|
||||
status_code=503,
|
||||
llm_provider="a2a_agent",
|
||||
model=model,
|
||||
response=response,
|
||||
litellm_debug_info=litellm_debug_info,
|
||||
max_retries=max_retries,
|
||||
num_retries=num_retries,
|
||||
)
|
||||
|
||||
|
||||
class A2AAgentCardError(A2AError):
|
||||
"""
|
||||
Raised when there's an issue with the agent card.
|
||||
|
||||
This includes:
|
||||
- Failed to fetch agent card
|
||||
- Invalid agent card format
|
||||
- Missing required fields
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.url = url
|
||||
super().__init__(
|
||||
message=message,
|
||||
status_code=404,
|
||||
llm_provider="a2a_agent",
|
||||
model=model,
|
||||
response=response,
|
||||
litellm_debug_info=litellm_debug_info,
|
||||
)
|
||||
|
||||
|
||||
class A2ALocalhostURLError(A2AConnectionError):
|
||||
"""
|
||||
Raised when an agent card contains a localhost/internal URL.
|
||||
|
||||
Many A2A agents are deployed with agent cards that contain internal URLs
|
||||
like "http://0.0.0.0:8001/" or "http://localhost:8000/". This error
|
||||
indicates that the URL needs to be corrected and the request should be retried.
|
||||
|
||||
Attributes:
|
||||
localhost_url: The localhost/internal URL found in the agent card
|
||||
base_url: The public base URL that should be used instead
|
||||
original_error: The original connection error that was raised
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
localhost_url: str,
|
||||
base_url: str,
|
||||
original_error: Optional[Exception] = None,
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
self.localhost_url = localhost_url
|
||||
self.base_url = base_url
|
||||
self.original_error = original_error
|
||||
|
||||
message = (
|
||||
f"Agent card contains localhost/internal URL '{localhost_url}'. "
|
||||
f"Retrying with base URL '{base_url}'."
|
||||
)
|
||||
super().__init__(
|
||||
message=message,
|
||||
url=localhost_url,
|
||||
model=model,
|
||||
)
|
||||
@@ -0,0 +1,74 @@
|
||||
# A2A to LiteLLM Completion Bridge
|
||||
|
||||
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
|
||||
|
||||
## Flow
|
||||
|
||||
```
|
||||
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
|
||||
```
|
||||
|
||||
## SDK Usage
|
||||
|
||||
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
|
||||
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message, asend_message_streaming
|
||||
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
# Non-streaming
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
request=request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
)
|
||||
|
||||
# Streaming
|
||||
stream_request = SendStreamingMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
async for chunk in asend_message_streaming(
|
||||
request=stream_request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
):
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
## Proxy Usage
|
||||
|
||||
Configure an agent with `custom_llm_provider` in `litellm_params`:
|
||||
|
||||
```yaml
|
||||
agents:
|
||||
- agent_name: my-langgraph-agent
|
||||
agent_card_params:
|
||||
name: "LangGraph Agent"
|
||||
url: "http://localhost:2024" # Used as api_base
|
||||
litellm_params:
|
||||
custom_llm_provider: langgraph
|
||||
model: agent
|
||||
```
|
||||
|
||||
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
|
||||
|
||||
1. Detects `custom_llm_provider` in agent's `litellm_params`
|
||||
2. Transforms A2A message → OpenAI messages
|
||||
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
|
||||
4. Transforms response → A2A format
|
||||
|
||||
## Classes
|
||||
|
||||
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
|
||||
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
A2A to LiteLLM Completion Bridge.
|
||||
|
||||
This module provides transformation between A2A protocol messages and
|
||||
LiteLLM completion API, enabling any LiteLLM-supported provider to be
|
||||
invoked via the A2A protocol.
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
|
||||
A2ACompletionBridgeHandler,
|
||||
handle_a2a_completion,
|
||||
handle_a2a_completion_streaming,
|
||||
)
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
|
||||
A2ACompletionBridgeTransformation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"A2ACompletionBridgeTransformation",
|
||||
"A2ACompletionBridgeHandler",
|
||||
"handle_a2a_completion",
|
||||
"handle_a2a_completion_streaming",
|
||||
]
|
||||
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Handler for A2A to LiteLLM completion bridge.
|
||||
|
||||
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
|
||||
|
||||
A2A Streaming Events (in order):
|
||||
1. Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
2. Status update (kind: "status-update") - Status change to "working"
|
||||
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
4. Status update (kind: "status-update") - Final status "completed" with final=true
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
|
||||
A2ACompletionBridgeTransformation,
|
||||
A2AStreamingContext,
|
||||
)
|
||||
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
|
||||
|
||||
|
||||
class A2ACompletionBridgeHandler:
|
||||
"""
|
||||
Static methods for handling A2A requests via LiteLLM completion.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def handle_non_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle non-streaming A2A request via litellm.acompletion.
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
|
||||
api_base: API base URL from agent_card_params
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
# Get provider config for custom_llm_provider
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
a2a_provider_config = A2AProviderConfigManager.get_provider_config(
|
||||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
# If provider config exists, use it
|
||||
if a2a_provider_config is not None:
|
||||
if api_base is None:
|
||||
raise ValueError(f"api_base is required for {custom_llm_provider}")
|
||||
|
||||
verbose_logger.info(f"A2A: Using provider config for {custom_llm_provider}")
|
||||
|
||||
response_data = await a2a_provider_config.handle_non_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
# Extract message from params
|
||||
message = params.get("message", {})
|
||||
|
||||
# Transform A2A message to OpenAI format
|
||||
openai_messages = (
|
||||
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
|
||||
)
|
||||
|
||||
# Get completion params
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
model = litellm_params.get("model", "agent")
|
||||
|
||||
# Build full model string if provider specified
|
||||
# Skip prepending if model already starts with the provider prefix
|
||||
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
|
||||
full_model = f"{custom_llm_provider}/{model}"
|
||||
else:
|
||||
full_model = model
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A completion bridge: model={full_model}, api_base={api_base}"
|
||||
)
|
||||
|
||||
# Build completion params dict
|
||||
completion_params = {
|
||||
"model": full_model,
|
||||
"messages": openai_messages,
|
||||
"api_base": api_base,
|
||||
"stream": False,
|
||||
}
|
||||
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
|
||||
litellm_params_to_add = {
|
||||
k: v
|
||||
for k, v in litellm_params.items()
|
||||
if k not in ("model", "custom_llm_provider")
|
||||
}
|
||||
completion_params.update(litellm_params_to_add)
|
||||
|
||||
# Call litellm.acompletion
|
||||
response = await litellm.acompletion(**completion_params)
|
||||
|
||||
# Transform response to A2A format
|
||||
a2a_response = (
|
||||
A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
|
||||
response=response,
|
||||
request_id=request_id,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")
|
||||
|
||||
return a2a_response
|
||||
|
||||
@staticmethod
|
||||
async def handle_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""
|
||||
Handle streaming A2A request via litellm.acompletion with stream=True.
|
||||
|
||||
Emits proper A2A streaming events:
|
||||
1. Task event (kind: "task") - Initial task with status "submitted"
|
||||
2. Status update (kind: "status-update") - Status "working"
|
||||
3. Artifact update (kind: "artifact-update") - Content delivery
|
||||
4. Status update (kind: "status-update") - Final "completed" status
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
|
||||
api_base: API base URL from agent_card_params
|
||||
|
||||
Yields:
|
||||
A2A streaming response events
|
||||
"""
|
||||
# Get provider config for custom_llm_provider
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
a2a_provider_config = A2AProviderConfigManager.get_provider_config(
|
||||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
# If provider config exists, use it
|
||||
if a2a_provider_config is not None:
|
||||
if api_base is None:
|
||||
raise ValueError(f"api_base is required for {custom_llm_provider}")
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A: Using provider config for {custom_llm_provider} (streaming)"
|
||||
)
|
||||
|
||||
async for chunk in a2a_provider_config.handle_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
api_base=api_base,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return
|
||||
|
||||
# Extract message from params
|
||||
message = params.get("message", {})
|
||||
|
||||
# Create streaming context
|
||||
ctx = A2AStreamingContext(
|
||||
request_id=request_id,
|
||||
input_message=message,
|
||||
)
|
||||
|
||||
# Transform A2A message to OpenAI format
|
||||
openai_messages = (
|
||||
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
|
||||
)
|
||||
|
||||
# Get completion params
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
model = litellm_params.get("model", "agent")
|
||||
|
||||
# Build full model string if provider specified
|
||||
# Skip prepending if model already starts with the provider prefix
|
||||
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
|
||||
full_model = f"{custom_llm_provider}/{model}"
|
||||
else:
|
||||
full_model = model
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
|
||||
)
|
||||
|
||||
# Build completion params dict
|
||||
completion_params = {
|
||||
"model": full_model,
|
||||
"messages": openai_messages,
|
||||
"api_base": api_base,
|
||||
"stream": True,
|
||||
}
|
||||
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
|
||||
litellm_params_to_add = {
|
||||
k: v
|
||||
for k, v in litellm_params.items()
|
||||
if k not in ("model", "custom_llm_provider")
|
||||
}
|
||||
completion_params.update(litellm_params_to_add)
|
||||
|
||||
# 1. Emit initial task event (kind: "task", status: "submitted")
|
||||
task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
|
||||
yield task_event
|
||||
|
||||
# 2. Emit status update (kind: "status-update", status: "working")
|
||||
working_event = A2ACompletionBridgeTransformation.create_status_update_event(
|
||||
ctx=ctx,
|
||||
state="working",
|
||||
final=False,
|
||||
message_text="Processing request...",
|
||||
)
|
||||
yield working_event
|
||||
|
||||
# Call litellm.acompletion with streaming
|
||||
response = await litellm.acompletion(**completion_params)
|
||||
|
||||
# 3. Accumulate content and emit artifact update
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
async for chunk in response: # type: ignore[union-attr]
|
||||
chunk_count += 1
|
||||
|
||||
# Extract delta content
|
||||
content = ""
|
||||
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
content = choice.delta.content or ""
|
||||
|
||||
if content:
|
||||
accumulated_text += content
|
||||
|
||||
# Emit artifact update with accumulated content
|
||||
if accumulated_text:
|
||||
artifact_event = (
|
||||
A2ACompletionBridgeTransformation.create_artifact_update_event(
|
||||
ctx=ctx,
|
||||
text=accumulated_text,
|
||||
)
|
||||
)
|
||||
yield artifact_event
|
||||
|
||||
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
|
||||
completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
|
||||
ctx=ctx,
|
||||
state="completed",
|
||||
final=True,
|
||||
)
|
||||
yield completed_event
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
|
||||
)
|
||||
|
||||
|
||||
# Convenience functions that delegate to the class methods
|
||||
async def handle_a2a_completion(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Convenience function for non-streaming A2A completion."""
|
||||
return await A2ACompletionBridgeHandler.handle_non_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
|
||||
async def handle_a2a_completion_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Convenience function for streaming A2A completion."""
|
||||
async for chunk in A2ACompletionBridgeHandler.handle_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Transformation utilities for A2A <-> OpenAI message format conversion.
|
||||
|
||||
A2A Message Format:
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": "abc123"
|
||||
}
|
||||
|
||||
OpenAI Message Format:
|
||||
{"role": "user", "content": "Hello!"}
|
||||
|
||||
A2A Streaming Events:
|
||||
- Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
- Status update (kind: "status-update") - Status changes (working, completed)
|
||||
- Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class A2AStreamingContext:
|
||||
"""
|
||||
Context holder for A2A streaming state.
|
||||
Tracks task_id, context_id, and message accumulation.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str, input_message: Dict[str, Any]):
|
||||
self.request_id = request_id
|
||||
self.task_id = str(uuid4())
|
||||
self.context_id = str(uuid4())
|
||||
self.input_message = input_message
|
||||
self.accumulated_text = ""
|
||||
self.has_emitted_task = False
|
||||
self.has_emitted_working = False
|
||||
|
||||
|
||||
class A2ACompletionBridgeTransformation:
|
||||
"""
|
||||
Static methods for transforming between A2A and OpenAI message formats.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def a2a_message_to_openai_messages(
|
||||
a2a_message: Dict[str, Any],
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Transform an A2A message to OpenAI message format.
|
||||
|
||||
Args:
|
||||
a2a_message: A2A message with role, parts, and messageId
|
||||
|
||||
Returns:
|
||||
List of OpenAI-format messages
|
||||
"""
|
||||
role = a2a_message.get("role", "user")
|
||||
parts = a2a_message.get("parts", [])
|
||||
|
||||
# Map A2A roles to OpenAI roles
|
||||
openai_role = role
|
||||
if role == "user":
|
||||
openai_role = "user"
|
||||
elif role == "assistant":
|
||||
openai_role = "assistant"
|
||||
elif role == "system":
|
||||
openai_role = "system"
|
||||
|
||||
# Extract text content from parts
|
||||
content_parts = []
|
||||
for part in parts:
|
||||
kind = part.get("kind", "")
|
||||
if kind == "text":
|
||||
text = part.get("text", "")
|
||||
content_parts.append(text)
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
verbose_logger.debug(
|
||||
f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
|
||||
)
|
||||
|
||||
return [{"role": openai_role, "content": content}]
|
||||
|
||||
@staticmethod
|
||||
def openai_response_to_a2a_response(
|
||||
response: Any,
|
||||
request_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.
|
||||
|
||||
Args:
|
||||
response: LiteLLM ModelResponse object
|
||||
request_id: Original A2A request ID
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
# Extract content from response
|
||||
content = ""
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message") and choice.message:
|
||||
content = choice.message.content or ""
|
||||
|
||||
# Build A2A message
|
||||
a2a_message = {
|
||||
"role": "agent",
|
||||
"parts": [{"kind": "text", "text": content}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
|
||||
# Build A2A response
|
||||
a2a_response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"message": a2a_message,
|
||||
},
|
||||
}
|
||||
|
||||
verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")
|
||||
|
||||
return a2a_response
|
||||
|
||||
@staticmethod
|
||||
def _get_timestamp() -> str:
|
||||
"""Get current timestamp in ISO format with timezone."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
@staticmethod
|
||||
def create_task_event(
|
||||
ctx: A2AStreamingContext,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create the initial task event with status 'submitted'.
|
||||
|
||||
This is the first event emitted in an A2A streaming response.
|
||||
"""
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"contextId": ctx.context_id,
|
||||
"history": [
|
||||
{
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "message",
|
||||
"messageId": ctx.input_message.get("messageId", uuid4().hex),
|
||||
"parts": ctx.input_message.get("parts", []),
|
||||
"role": ctx.input_message.get("role", "user"),
|
||||
"taskId": ctx.task_id,
|
||||
}
|
||||
],
|
||||
"id": ctx.task_id,
|
||||
"kind": "task",
|
||||
"status": {
|
||||
"state": "submitted",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_status_update_event(
|
||||
ctx: A2AStreamingContext,
|
||||
state: str,
|
||||
final: bool = False,
|
||||
message_text: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a status update event.
|
||||
|
||||
Args:
|
||||
ctx: Streaming context
|
||||
state: Status state ('working', 'completed')
|
||||
final: Whether this is the final event
|
||||
message_text: Optional message text for 'working' status
|
||||
"""
|
||||
status: Dict[str, Any] = {
|
||||
"state": state,
|
||||
"timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
|
||||
}
|
||||
|
||||
# Add message for 'working' status
|
||||
if state == "working" and message_text:
|
||||
status["message"] = {
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "message",
|
||||
"messageId": str(uuid4()),
|
||||
"parts": [{"kind": "text", "text": message_text}],
|
||||
"role": "agent",
|
||||
"taskId": ctx.task_id,
|
||||
}
|
||||
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"contextId": ctx.context_id,
|
||||
"final": final,
|
||||
"kind": "status-update",
|
||||
"status": status,
|
||||
"taskId": ctx.task_id,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_artifact_update_event(
|
||||
ctx: A2AStreamingContext,
|
||||
text: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create an artifact update event with content.
|
||||
|
||||
Args:
|
||||
ctx: Streaming context
|
||||
text: The text content for the artifact
|
||||
"""
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"artifact": {
|
||||
"artifactId": str(uuid4()),
|
||||
"name": "response",
|
||||
"parts": [{"kind": "text", "text": text}],
|
||||
},
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "artifact-update",
|
||||
"taskId": ctx.task_id,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def openai_chunk_to_a2a_chunk(
|
||||
chunk: Any,
|
||||
request_id: Optional[str] = None,
|
||||
is_final: bool = False,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Transform a LiteLLM streaming chunk to A2A streaming format.
|
||||
|
||||
NOTE: This method is deprecated for streaming. Use the event-based
|
||||
methods (create_task_event, create_status_update_event,
|
||||
create_artifact_update_event) instead for proper A2A streaming.
|
||||
|
||||
Args:
|
||||
chunk: LiteLLM ModelResponse chunk
|
||||
request_id: Original A2A request ID
|
||||
is_final: Whether this is the final chunk
|
||||
|
||||
Returns:
|
||||
A2A streaming chunk dict or None if no content
|
||||
"""
|
||||
# Extract delta content
|
||||
content = ""
|
||||
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
content = choice.delta.content or ""
|
||||
|
||||
if not content and not is_final:
|
||||
return None
|
||||
|
||||
# Build A2A streaming chunk (legacy format)
|
||||
a2a_chunk = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"kind": "text", "text": content}],
|
||||
"messageId": uuid4().hex,
|
||||
},
|
||||
"final": is_final,
|
||||
},
|
||||
}
|
||||
|
||||
return a2a_chunk
|
||||
@@ -0,0 +1,744 @@
|
||||
"""
|
||||
LiteLLM A2A SDK functions.
|
||||
|
||||
Provides standalone functions with @client decorator for LiteLLM logging integration.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Coroutine, Dict, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.a2a_protocol.streaming_iterator import A2AStreamingIterator
|
||||
from litellm.a2a_protocol.utils import A2ARequestUtils
|
||||
from litellm.constants import DEFAULT_A2A_AGENT_TIMEOUT
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.agents import LiteLLMSendMessageResponse
|
||||
from litellm.utils import client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
from a2a.types import AgentCard, SendMessageRequest, SendStreamingMessageRequest
|
||||
|
||||
# Runtime imports with availability check
|
||||
A2A_SDK_AVAILABLE = False
|
||||
A2ACardResolver: Any = None
|
||||
_A2AClient: Any = None
|
||||
|
||||
try:
|
||||
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
|
||||
|
||||
A2A_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Import our custom card resolver that supports multiple well-known paths
|
||||
from litellm.a2a_protocol.card_resolver import LiteLLMA2ACardResolver
|
||||
from litellm.a2a_protocol.exception_mapping_utils import (
|
||||
handle_a2a_localhost_retry,
|
||||
map_a2a_exception,
|
||||
)
|
||||
from litellm.a2a_protocol.exceptions import A2ALocalhostURLError
|
||||
|
||||
# Use our custom resolver instead of the default A2A SDK resolver
|
||||
A2ACardResolver = LiteLLMA2ACardResolver
|
||||
|
||||
|
||||
def _set_usage_on_logging_obj(
|
||||
kwargs: Dict[str, Any],
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Set usage on litellm_logging_obj for standard logging payload.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict containing litellm_logging_obj
|
||||
prompt_tokens: Number of input tokens
|
||||
completion_tokens: Number of output tokens
|
||||
"""
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
if litellm_logging_obj is not None:
|
||||
usage = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
litellm_logging_obj.model_call_details["usage"] = usage
|
||||
|
||||
|
||||
def _set_agent_id_on_logging_obj(
|
||||
kwargs: Dict[str, Any],
|
||||
agent_id: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Set agent_id on litellm_logging_obj for SpendLogs tracking.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict containing litellm_logging_obj
|
||||
agent_id: The A2A agent ID
|
||||
"""
|
||||
if agent_id is None:
|
||||
return
|
||||
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
if litellm_logging_obj is not None:
|
||||
# Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
|
||||
litellm_logging_obj.model_call_details["agent_id"] = agent_id
|
||||
|
||||
|
||||
def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract agent info and set model/custom_llm_provider for cost tracking.
|
||||
|
||||
Sets model info on the litellm_logging_obj if available.
|
||||
Returns the agent name for logging.
|
||||
"""
|
||||
agent_name = "unknown"
|
||||
|
||||
# Try to get agent card from our stored attribute first, then fallback to SDK attribute
|
||||
agent_card = getattr(a2a_client, "_litellm_agent_card", None)
|
||||
if agent_card is None:
|
||||
agent_card = getattr(a2a_client, "agent_card", None)
|
||||
|
||||
if agent_card is not None:
|
||||
agent_name = getattr(agent_card, "name", "unknown") or "unknown"
|
||||
|
||||
# Build model string
|
||||
model = f"a2a_agent/{agent_name}"
|
||||
custom_llm_provider = "a2a_agent"
|
||||
|
||||
# Set on litellm_logging_obj if available (for standard logging payload)
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
if litellm_logging_obj is not None:
|
||||
litellm_logging_obj.model = model
|
||||
litellm_logging_obj.custom_llm_provider = custom_llm_provider
|
||||
litellm_logging_obj.model_call_details["model"] = model
|
||||
litellm_logging_obj.model_call_details[
|
||||
"custom_llm_provider"
|
||||
] = custom_llm_provider
|
||||
|
||||
return agent_name
|
||||
|
||||
|
||||
async def _send_message_via_completion_bridge(
|
||||
request: "SendMessageRequest",
|
||||
custom_llm_provider: str,
|
||||
api_base: Optional[str],
|
||||
litellm_params: Dict[str, Any],
|
||||
) -> LiteLLMSendMessageResponse:
|
||||
"""
|
||||
Route a send_message through the LiteLLM completion bridge (e.g. LangGraph, Bedrock AgentCore).
|
||||
|
||||
Requires request; api_base is optional for providers that derive endpoint from model.
|
||||
"""
|
||||
verbose_logger.info(
|
||||
f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
|
||||
)
|
||||
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
|
||||
A2ACompletionBridgeHandler,
|
||||
)
|
||||
|
||||
params = (
|
||||
request.params.model_dump(mode="json")
|
||||
if hasattr(request.params, "model_dump")
|
||||
else dict(request.params)
|
||||
)
|
||||
|
||||
response_dict = await A2ACompletionBridgeHandler.handle_non_streaming(
|
||||
request_id=str(request.id),
|
||||
params=params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
return LiteLLMSendMessageResponse.from_dict(response_dict)
|
||||
|
||||
|
||||
async def _execute_a2a_send_with_retry(
|
||||
a2a_client: Any,
|
||||
request: Any,
|
||||
agent_card: Any,
|
||||
card_url: Optional[str],
|
||||
api_base: Optional[str],
|
||||
agent_name: Optional[str],
|
||||
) -> Any:
|
||||
"""Send an A2A message with retry logic for localhost URL errors."""
|
||||
a2a_response = None
|
||||
for _ in range(2): # max 2 attempts: original + 1 retry
|
||||
try:
|
||||
a2a_response = await a2a_client.send_message(request)
|
||||
break # success, exit retry loop
|
||||
except A2ALocalhostURLError as e:
|
||||
a2a_client = handle_a2a_localhost_retry(
|
||||
error=e,
|
||||
agent_card=agent_card,
|
||||
a2a_client=a2a_client,
|
||||
is_streaming=False,
|
||||
)
|
||||
card_url = agent_card.url if agent_card else None
|
||||
except Exception as e:
|
||||
try:
|
||||
map_a2a_exception(e, card_url, api_base, model=agent_name)
|
||||
except A2ALocalhostURLError as localhost_err:
|
||||
a2a_client = handle_a2a_localhost_retry(
|
||||
error=localhost_err,
|
||||
agent_card=agent_card,
|
||||
a2a_client=a2a_client,
|
||||
is_streaming=False,
|
||||
)
|
||||
card_url = agent_card.url if agent_card else None
|
||||
continue
|
||||
except Exception:
|
||||
raise
|
||||
if a2a_response is None:
|
||||
raise RuntimeError(
|
||||
"A2A send_message failed: no response received after retry attempts."
|
||||
)
|
||||
return a2a_response
|
||||
|
||||
|
||||
@client
|
||||
async def asend_message(
|
||||
a2a_client: Optional["A2AClientType"] = None,
|
||||
request: Optional["SendMessageRequest"] = None,
|
||||
api_base: Optional[str] = None,
|
||||
litellm_params: Optional[Dict[str, Any]] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
agent_extra_headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LiteLLMSendMessageResponse:
|
||||
"""
|
||||
Async: Send a message to an A2A agent.
|
||||
|
||||
Uses the @client decorator for LiteLLM logging and tracking.
|
||||
If litellm_params contains custom_llm_provider, routes through the completion bridge.
|
||||
|
||||
Args:
|
||||
a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
|
||||
request: SendMessageRequest from a2a.types (optional if using completion bridge with api_base)
|
||||
api_base: API base URL (required for completion bridge, optional for standard A2A)
|
||||
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
|
||||
agent_id: Optional agent ID for tracking in SpendLogs
|
||||
**kwargs: Additional arguments passed to the client decorator
|
||||
|
||||
Returns:
|
||||
LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)
|
||||
|
||||
Example (standard A2A):
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message, create_a2a_client
|
||||
from a2a.types import SendMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
a2a_client = await create_a2a_client(base_url="http://localhost:10001")
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
response = await asend_message(a2a_client=a2a_client, request=request)
|
||||
```
|
||||
|
||||
Example (completion bridge with LangGraph):
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message
|
||||
from a2a.types import SendMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
request=request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
)
|
||||
```
|
||||
"""
|
||||
litellm_params = litellm_params or {}
|
||||
logging_obj = kwargs.get("litellm_logging_obj")
|
||||
trace_id = getattr(logging_obj, "litellm_trace_id", None) if logging_obj else None
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
|
||||
# Route through completion bridge if custom_llm_provider is set
|
||||
if custom_llm_provider:
|
||||
if request is None:
|
||||
raise ValueError("request is required for completion bridge")
|
||||
return await _send_message_via_completion_bridge(
|
||||
request=request,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
# Standard A2A client flow
|
||||
if request is None:
|
||||
raise ValueError("request is required")
|
||||
|
||||
# Create A2A client if not provided but api_base is available
|
||||
if a2a_client is None:
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"Either a2a_client or api_base is required for standard A2A flow"
|
||||
)
|
||||
trace_id = trace_id or str(uuid.uuid4())
|
||||
extra_headers: Dict[str, str] = {"X-LiteLLM-Trace-Id": trace_id}
|
||||
if agent_id:
|
||||
extra_headers["X-LiteLLM-Agent-Id"] = agent_id
|
||||
# Overlay agent-level headers (agent headers take precedence over LiteLLM internal ones)
|
||||
if agent_extra_headers:
|
||||
extra_headers.update(agent_extra_headers)
|
||||
a2a_client = await create_a2a_client(
|
||||
base_url=api_base, extra_headers=extra_headers
|
||||
)
|
||||
|
||||
# Type assertion: a2a_client is guaranteed to be non-None here
|
||||
assert a2a_client is not None
|
||||
|
||||
agent_name = _get_a2a_model_info(a2a_client, kwargs)
|
||||
|
||||
verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")
|
||||
|
||||
# Get agent card URL for localhost retry logic
|
||||
agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
|
||||
a2a_client, "agent_card", None
|
||||
)
|
||||
card_url = getattr(agent_card, "url", None) if agent_card else None
|
||||
|
||||
context_id = trace_id or str(uuid.uuid4())
|
||||
message = request.params.message
|
||||
if isinstance(message, dict):
|
||||
if message.get("context_id") is None:
|
||||
message["context_id"] = context_id
|
||||
else:
|
||||
if getattr(message, "context_id", None) is None:
|
||||
message.context_id = context_id
|
||||
|
||||
a2a_response = await _execute_a2a_send_with_retry(
|
||||
a2a_client=a2a_client,
|
||||
request=request,
|
||||
agent_card=agent_card,
|
||||
card_url=card_url,
|
||||
api_base=api_base,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
verbose_logger.info(f"A2A send_message completed, request_id={request.id}")
|
||||
|
||||
# Wrap in LiteLLM response type for _hidden_params support
|
||||
response = LiteLLMSendMessageResponse.from_a2a_response(a2a_response)
|
||||
|
||||
# Calculate token usage from request and response
|
||||
response_dict = a2a_response.model_dump(mode="json", exclude_none=True)
|
||||
(
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
_,
|
||||
) = A2ARequestUtils.calculate_usage_from_request_response(
|
||||
request=request,
|
||||
response_dict=response_dict,
|
||||
)
|
||||
|
||||
# Set usage on logging obj for standard logging payload
|
||||
_set_usage_on_logging_obj(
|
||||
kwargs=kwargs,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
# Set agent_id on logging obj for SpendLogs tracking
|
||||
_set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=agent_id)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@client
|
||||
def send_message(
|
||||
a2a_client: "A2AClientType",
|
||||
request: "SendMessageRequest",
|
||||
**kwargs: Any,
|
||||
) -> Union[LiteLLMSendMessageResponse, Coroutine[Any, Any, LiteLLMSendMessageResponse]]:
|
||||
"""
|
||||
Sync: Send a message to an A2A agent.
|
||||
|
||||
Uses the @client decorator for LiteLLM logging and tracking.
|
||||
|
||||
Args:
|
||||
a2a_client: An initialized a2a.client.A2AClient instance
|
||||
request: SendMessageRequest from a2a.types
|
||||
**kwargs: Additional arguments passed to the client decorator
|
||||
|
||||
Returns:
|
||||
LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop is not None:
|
||||
return asend_message(a2a_client=a2a_client, request=request, **kwargs)
|
||||
else:
|
||||
return asyncio.run(
|
||||
asend_message(a2a_client=a2a_client, request=request, **kwargs)
|
||||
)
|
||||
|
||||
|
||||
def _build_streaming_logging_obj(
|
||||
request: "SendStreamingMessageRequest",
|
||||
agent_name: str,
|
||||
agent_id: Optional[str],
|
||||
litellm_params: Optional[Dict[str, Any]],
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
proxy_server_request: Optional[Dict[str, Any]],
|
||||
) -> Logging:
|
||||
"""Build logging object for streaming A2A requests."""
|
||||
start_time = datetime.datetime.now()
|
||||
model = f"a2a_agent/{agent_name}"
|
||||
|
||||
logging_obj = Logging(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "streaming-request"}],
|
||||
stream=False,
|
||||
call_type="asend_message_streaming",
|
||||
start_time=start_time,
|
||||
litellm_call_id=str(request.id),
|
||||
function_id=str(request.id),
|
||||
)
|
||||
logging_obj.model = model
|
||||
logging_obj.custom_llm_provider = "a2a_agent"
|
||||
logging_obj.model_call_details["model"] = model
|
||||
logging_obj.model_call_details["custom_llm_provider"] = "a2a_agent"
|
||||
if agent_id:
|
||||
logging_obj.model_call_details["agent_id"] = agent_id
|
||||
|
||||
_litellm_params = litellm_params.copy() if litellm_params else {}
|
||||
if metadata:
|
||||
_litellm_params["metadata"] = metadata
|
||||
if proxy_server_request:
|
||||
_litellm_params["proxy_server_request"] = proxy_server_request
|
||||
|
||||
logging_obj.litellm_params = _litellm_params
|
||||
logging_obj.optional_params = _litellm_params
|
||||
logging_obj.model_call_details["litellm_params"] = _litellm_params
|
||||
logging_obj.model_call_details["metadata"] = metadata or {}
|
||||
|
||||
return logging_obj
|
||||
|
||||
|
||||
async def asend_message_streaming( # noqa: PLR0915
|
||||
a2a_client: Optional["A2AClientType"] = None,
|
||||
request: Optional["SendStreamingMessageRequest"] = None,
|
||||
api_base: Optional[str] = None,
|
||||
litellm_params: Optional[Dict[str, Any]] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
proxy_server_request: Optional[Dict[str, Any]] = None,
|
||||
agent_extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> AsyncIterator[Any]:
|
||||
"""
|
||||
Async: Send a streaming message to an A2A agent.
|
||||
|
||||
If litellm_params contains custom_llm_provider, routes through the completion bridge.
|
||||
|
||||
Args:
|
||||
a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
|
||||
request: SendStreamingMessageRequest from a2a.types
|
||||
api_base: API base URL (required for completion bridge)
|
||||
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
|
||||
agent_id: Optional agent ID for tracking in SpendLogs
|
||||
metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
|
||||
proxy_server_request: Optional proxy server request data
|
||||
|
||||
Yields:
|
||||
SendStreamingMessageResponse chunks from the agent
|
||||
|
||||
Example (completion bridge with LangGraph):
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message_streaming
|
||||
from a2a.types import SendStreamingMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
request = SendStreamingMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
async for chunk in asend_message_streaming(
|
||||
request=request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
):
|
||||
print(chunk)
|
||||
```
|
||||
"""
|
||||
litellm_params = litellm_params or {}
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
|
||||
# Route through completion bridge if custom_llm_provider is set
|
||||
if custom_llm_provider:
|
||||
if request is None:
|
||||
raise ValueError("request is required for completion bridge")
|
||||
# api_base is optional for providers that derive endpoint from model (e.g., bedrock/agentcore)
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A streaming using completion bridge: provider={custom_llm_provider}"
|
||||
)
|
||||
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
|
||||
A2ACompletionBridgeHandler,
|
||||
)
|
||||
|
||||
# Extract params from request
|
||||
params = (
|
||||
request.params.model_dump(mode="json")
|
||||
if hasattr(request.params, "model_dump")
|
||||
else dict(request.params)
|
||||
)
|
||||
|
||||
async for chunk in A2ACompletionBridgeHandler.handle_streaming(
|
||||
request_id=str(request.id),
|
||||
params=params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Standard A2A client flow
|
||||
if request is None:
|
||||
raise ValueError("request is required")
|
||||
|
||||
# Create A2A client if not provided but api_base is available
|
||||
if a2a_client is None:
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"Either a2a_client or api_base is required for standard A2A flow"
|
||||
)
|
||||
# Mirror the non-streaming path: always include trace and agent-id headers
|
||||
streaming_extra_headers: Dict[str, str] = {
|
||||
"X-LiteLLM-Trace-Id": str(request.id),
|
||||
}
|
||||
if agent_id:
|
||||
streaming_extra_headers["X-LiteLLM-Agent-Id"] = agent_id
|
||||
if agent_extra_headers:
|
||||
streaming_extra_headers.update(agent_extra_headers)
|
||||
a2a_client = await create_a2a_client(
|
||||
base_url=api_base, extra_headers=streaming_extra_headers
|
||||
)
|
||||
|
||||
# Type assertion: a2a_client is guaranteed to be non-None here
|
||||
assert a2a_client is not None
|
||||
|
||||
verbose_logger.info(f"A2A send_message_streaming request_id={request.id}")
|
||||
|
||||
# Build logging object for streaming completion callbacks
|
||||
agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
|
||||
a2a_client, "agent_card", None
|
||||
)
|
||||
card_url = getattr(agent_card, "url", None) if agent_card else None
|
||||
agent_name = getattr(agent_card, "name", "unknown") if agent_card else "unknown"
|
||||
|
||||
logging_obj = _build_streaming_logging_obj(
|
||||
request=request,
|
||||
agent_name=agent_name,
|
||||
agent_id=agent_id,
|
||||
litellm_params=litellm_params,
|
||||
metadata=metadata,
|
||||
proxy_server_request=proxy_server_request,
|
||||
)
|
||||
|
||||
# Retry loop: if connection fails due to localhost URL in agent card, retry with fixed URL
|
||||
# Connection errors in streaming typically occur on first chunk iteration
|
||||
first_chunk = True
|
||||
for attempt in range(2): # max 2 attempts: original + 1 retry
|
||||
stream = a2a_client.send_message_streaming(request)
|
||||
iterator = A2AStreamingIterator(
|
||||
stream=stream,
|
||||
request=request,
|
||||
logging_obj=logging_obj,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
try:
|
||||
first_chunk = True
|
||||
async for chunk in iterator:
|
||||
if first_chunk:
|
||||
first_chunk = False # connection succeeded
|
||||
yield chunk
|
||||
return # stream completed successfully
|
||||
except A2ALocalhostURLError as e:
|
||||
# Only retry on first chunk, not mid-stream
|
||||
if first_chunk and attempt == 0:
|
||||
a2a_client = handle_a2a_localhost_retry(
|
||||
error=e,
|
||||
agent_card=agent_card,
|
||||
a2a_client=a2a_client,
|
||||
is_streaming=True,
|
||||
)
|
||||
card_url = agent_card.url if agent_card else None
|
||||
else:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Only map exception on first chunk
|
||||
if first_chunk and attempt == 0:
|
||||
try:
|
||||
map_a2a_exception(e, card_url, api_base, model=agent_name)
|
||||
except A2ALocalhostURLError as localhost_err:
|
||||
# Localhost URL error - fix and retry
|
||||
a2a_client = handle_a2a_localhost_retry(
|
||||
error=localhost_err,
|
||||
agent_card=agent_card,
|
||||
a2a_client=a2a_client,
|
||||
is_streaming=True,
|
||||
)
|
||||
card_url = agent_card.url if agent_card else None
|
||||
continue
|
||||
except Exception:
|
||||
# Re-raise the mapped exception
|
||||
raise
|
||||
raise
|
||||
|
||||
|
||||
async def create_a2a_client(
|
||||
base_url: str,
|
||||
timeout: float = 60.0,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> "A2AClientType":
|
||||
"""
|
||||
Create an A2A client for the given agent URL.
|
||||
|
||||
This resolves the agent card and returns a ready-to-use A2A client.
|
||||
The client can be reused for multiple requests.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
|
||||
timeout: Request timeout in seconds (default: 60.0)
|
||||
extra_headers: Optional additional headers to include in requests
|
||||
|
||||
Returns:
|
||||
An initialized a2a.client.A2AClient instance
|
||||
|
||||
Example:
|
||||
```python
|
||||
from litellm.a2a_protocol import create_a2a_client, asend_message
|
||||
|
||||
# Create client once
|
||||
client = await create_a2a_client(base_url="http://localhost:10001")
|
||||
|
||||
# Reuse for multiple requests
|
||||
response1 = await asend_message(a2a_client=client, request=request1)
|
||||
response2 = await asend_message(a2a_client=client, request=request2)
|
||||
```
|
||||
"""
|
||||
if not A2A_SDK_AVAILABLE:
|
||||
raise ImportError(
|
||||
"The 'a2a' package is required for A2A agent invocation. "
|
||||
"Install it with: pip install a2a-sdk"
|
||||
)
|
||||
|
||||
verbose_logger.info(f"Creating A2A client for {base_url}")
|
||||
|
||||
# Use get_async_httpx_client with per-agent params so that different agents
|
||||
# (with different extra_headers) get separate cached clients. The params
|
||||
# dict is hashed into the cache key, keeping agent auth isolated while
|
||||
# still reusing connections within the same agent.
|
||||
#
|
||||
# Only pass params that AsyncHTTPHandler.__init__ accepts (e.g. timeout).
|
||||
# Use "disable_aiohttp_transport" key for cache-key-only data (it's
|
||||
# filtered out before reaching the constructor).
|
||||
_client_params: dict = {"timeout": timeout}
|
||||
if extra_headers:
|
||||
# Encode headers into a cache-key-only param so each unique header
|
||||
# set produces a distinct cache key.
|
||||
_client_params["disable_aiohttp_transport"] = str(sorted(extra_headers.items()))
|
||||
_async_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.A2AProvider,
|
||||
params=_client_params,
|
||||
)
|
||||
httpx_client = _async_handler.client
|
||||
if extra_headers:
|
||||
httpx_client.headers.update(extra_headers)
|
||||
verbose_proxy_logger.debug(
|
||||
f"A2A client created with extra_headers={list(extra_headers.keys())}"
|
||||
)
|
||||
|
||||
# Resolve agent card
|
||||
resolver = A2ACardResolver(
|
||||
httpx_client=httpx_client,
|
||||
base_url=base_url,
|
||||
)
|
||||
agent_card = await resolver.get_agent_card()
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Resolved agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
|
||||
)
|
||||
|
||||
# Create A2A client
|
||||
a2a_client = _A2AClient(
|
||||
httpx_client=httpx_client,
|
||||
agent_card=agent_card,
|
||||
)
|
||||
|
||||
# Store agent_card on client for later retrieval (SDK doesn't expose it)
|
||||
a2a_client._litellm_agent_card = agent_card # type: ignore[attr-defined]
|
||||
|
||||
verbose_logger.info(f"A2A client created for {base_url}")
|
||||
|
||||
return a2a_client
|
||||
|
||||
|
||||
async def aget_agent_card(
|
||||
base_url: str,
|
||||
timeout: float = DEFAULT_A2A_AGENT_TIMEOUT,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> "AgentCard":
|
||||
"""
|
||||
Fetch the agent card from an A2A agent.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
|
||||
timeout: Request timeout in seconds (default: 60.0)
|
||||
extra_headers: Optional additional headers to include in requests
|
||||
|
||||
Returns:
|
||||
AgentCard from the A2A agent
|
||||
"""
|
||||
if not A2A_SDK_AVAILABLE:
|
||||
raise ImportError(
|
||||
"The 'a2a' package is required for A2A agent invocation. "
|
||||
"Install it with: pip install a2a-sdk"
|
||||
)
|
||||
|
||||
verbose_logger.info(f"Fetching agent card from {base_url}")
|
||||
|
||||
# Use LiteLLM's cached httpx client
|
||||
http_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.A2A,
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
httpx_client = http_handler.client
|
||||
|
||||
resolver = A2ACardResolver(
|
||||
httpx_client=httpx_client,
|
||||
base_url=base_url,
|
||||
)
|
||||
agent_card = await resolver.get_agent_card()
|
||||
|
||||
verbose_logger.info(
|
||||
f"Fetched agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
|
||||
)
|
||||
return agent_card
|
||||
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
A2A Protocol Providers.
|
||||
|
||||
This module contains provider-specific implementations for the A2A protocol.
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
|
||||
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
|
||||
|
||||
__all__ = ["BaseA2AProviderConfig", "A2AProviderConfigManager"]
|
||||
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Base configuration for A2A protocol providers.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
|
||||
class BaseA2AProviderConfig(ABC):
|
||||
"""
|
||||
Base configuration class for A2A protocol providers.
|
||||
|
||||
Each provider should implement this interface to define how to handle
|
||||
A2A requests for their specific agent type.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle_non_streaming(
|
||||
self,
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
api_base: str,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle non-streaming A2A request.
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
api_base: Base URL of the agent
|
||||
**kwargs: Additional provider-specific parameters
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def handle_streaming(
|
||||
self,
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
api_base: str,
|
||||
**kwargs,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""
|
||||
Handle streaming A2A request.
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
api_base: Base URL of the agent
|
||||
**kwargs: Additional provider-specific parameters
|
||||
|
||||
Yields:
|
||||
A2A streaming response events
|
||||
"""
|
||||
# This is an abstract method - subclasses must implement
|
||||
# The yield is here to make this a generator function
|
||||
if False: # pragma: no cover
|
||||
yield {}
|
||||
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
A2A Provider Config Manager.
|
||||
|
||||
Manages provider-specific configurations for A2A protocol.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
|
||||
|
||||
|
||||
class A2AProviderConfigManager:
|
||||
"""
|
||||
Manager for A2A provider configurations.
|
||||
|
||||
Similar to ProviderConfigManager in litellm.utils but specifically for A2A providers.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_provider_config(
|
||||
custom_llm_provider: Optional[str],
|
||||
) -> Optional[BaseA2AProviderConfig]:
|
||||
"""
|
||||
Get the provider configuration for a given custom_llm_provider.
|
||||
|
||||
Args:
|
||||
custom_llm_provider: The provider identifier (e.g., "pydantic_ai_agents")
|
||||
|
||||
Returns:
|
||||
Provider configuration instance or None if not found
|
||||
"""
|
||||
if custom_llm_provider is None:
|
||||
return None
|
||||
|
||||
if custom_llm_provider == "pydantic_ai_agents":
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
|
||||
PydanticAIProviderConfig,
|
||||
)
|
||||
|
||||
return PydanticAIProviderConfig()
|
||||
|
||||
# Add more providers here as needed
|
||||
# elif custom_llm_provider == "another_provider":
|
||||
# from litellm.a2a_protocol.providers.another_provider.config import AnotherProviderConfig
|
||||
# return AnotherProviderConfig()
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,74 @@
|
||||
# A2A to LiteLLM Completion Bridge
|
||||
|
||||
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
|
||||
|
||||
## Flow
|
||||
|
||||
```
|
||||
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
|
||||
```
|
||||
|
||||
## SDK Usage
|
||||
|
||||
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
|
||||
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message, asend_message_streaming
|
||||
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
# Non-streaming
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
request=request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
)
|
||||
|
||||
# Streaming
|
||||
stream_request = SendStreamingMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
async for chunk in asend_message_streaming(
|
||||
request=stream_request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
):
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
## Proxy Usage
|
||||
|
||||
Configure an agent with `custom_llm_provider` in `litellm_params`:
|
||||
|
||||
```yaml
|
||||
agents:
|
||||
- agent_name: my-langgraph-agent
|
||||
agent_card_params:
|
||||
name: "LangGraph Agent"
|
||||
url: "http://localhost:2024" # Used as api_base
|
||||
litellm_params:
|
||||
custom_llm_provider: langgraph
|
||||
model: agent
|
||||
```
|
||||
|
||||
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
|
||||
|
||||
1. Detects `custom_llm_provider` in agent's `litellm_params`
|
||||
2. Transforms A2A message → OpenAI messages
|
||||
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
|
||||
4. Transforms response → A2A format
|
||||
|
||||
## Classes
|
||||
|
||||
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
|
||||
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
LiteLLM Completion bridge provider for A2A protocol.
|
||||
|
||||
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
|
||||
"""
|
||||
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Handler for A2A to LiteLLM completion bridge.
|
||||
|
||||
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
|
||||
|
||||
A2A Streaming Events (in order):
|
||||
1. Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
2. Status update (kind: "status-update") - Status change to "working"
|
||||
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
4. Status update (kind: "status-update") - Final status "completed" with final=true
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.pydantic_ai_transformation import (
|
||||
PydanticAITransformation,
|
||||
)
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
|
||||
A2ACompletionBridgeTransformation,
|
||||
A2AStreamingContext,
|
||||
)
|
||||
|
||||
|
||||
class A2ACompletionBridgeHandler:
|
||||
"""
|
||||
Static methods for handling A2A requests via LiteLLM completion.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def handle_non_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle non-streaming A2A request via litellm.acompletion.
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
|
||||
api_base: API base URL from agent_card_params
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
# Check if this is a Pydantic AI agent request
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
if custom_llm_provider == "pydantic_ai_agents":
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required for Pydantic AI agents")
|
||||
|
||||
verbose_logger.info(
|
||||
f"Pydantic AI: Routing to Pydantic AI agent at {api_base}"
|
||||
)
|
||||
|
||||
# Send request directly to Pydantic AI agent
|
||||
response_data = await PydanticAITransformation.send_non_streaming_request(
|
||||
api_base=api_base,
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
# Extract message from params
|
||||
message = params.get("message", {})
|
||||
|
||||
# Transform A2A message to OpenAI format
|
||||
openai_messages = (
|
||||
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
|
||||
)
|
||||
|
||||
# Get completion params
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
model = litellm_params.get("model", "agent")
|
||||
|
||||
# Build full model string if provider specified
|
||||
# Skip prepending if model already starts with the provider prefix
|
||||
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
|
||||
full_model = f"{custom_llm_provider}/{model}"
|
||||
else:
|
||||
full_model = model
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A completion bridge: model={full_model}, api_base={api_base}"
|
||||
)
|
||||
|
||||
# Build completion params dict
|
||||
completion_params = {
|
||||
"model": full_model,
|
||||
"messages": openai_messages,
|
||||
"api_base": api_base,
|
||||
"stream": False,
|
||||
}
|
||||
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
|
||||
litellm_params_to_add = {
|
||||
k: v
|
||||
for k, v in litellm_params.items()
|
||||
if k not in ("model", "custom_llm_provider")
|
||||
}
|
||||
completion_params.update(litellm_params_to_add)
|
||||
|
||||
# Call litellm.acompletion
|
||||
response = await litellm.acompletion(**completion_params)
|
||||
|
||||
# Transform response to A2A format
|
||||
a2a_response = (
|
||||
A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
|
||||
response=response,
|
||||
request_id=request_id,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")
|
||||
|
||||
return a2a_response
|
||||
|
||||
@staticmethod
|
||||
async def handle_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""
|
||||
Handle streaming A2A request via litellm.acompletion with stream=True.
|
||||
|
||||
Emits proper A2A streaming events:
|
||||
1. Task event (kind: "task") - Initial task with status "submitted"
|
||||
2. Status update (kind: "status-update") - Status "working"
|
||||
3. Artifact update (kind: "artifact-update") - Content delivery
|
||||
4. Status update (kind: "status-update") - Final "completed" status
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
|
||||
api_base: API base URL from agent_card_params
|
||||
|
||||
Yields:
|
||||
A2A streaming response events
|
||||
"""
|
||||
# Check if this is a Pydantic AI agent request
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
if custom_llm_provider == "pydantic_ai_agents":
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required for Pydantic AI agents")
|
||||
|
||||
verbose_logger.info(
|
||||
f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
|
||||
)
|
||||
|
||||
# Get non-streaming response first
|
||||
response_data = await PydanticAITransformation.send_non_streaming_request(
|
||||
api_base=api_base,
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
)
|
||||
|
||||
# Convert to fake streaming
|
||||
async for chunk in PydanticAITransformation.fake_streaming_from_response(
|
||||
response_data=response_data,
|
||||
request_id=request_id,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return
|
||||
|
||||
# Extract message from params
|
||||
message = params.get("message", {})
|
||||
|
||||
# Create streaming context
|
||||
ctx = A2AStreamingContext(
|
||||
request_id=request_id,
|
||||
input_message=message,
|
||||
)
|
||||
|
||||
# Transform A2A message to OpenAI format
|
||||
openai_messages = (
|
||||
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
|
||||
)
|
||||
|
||||
# Get completion params
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
model = litellm_params.get("model", "agent")
|
||||
|
||||
# Build full model string if provider specified
|
||||
# Skip prepending if model already starts with the provider prefix
|
||||
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
|
||||
full_model = f"{custom_llm_provider}/{model}"
|
||||
else:
|
||||
full_model = model
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
|
||||
)
|
||||
|
||||
# Build completion params dict
|
||||
completion_params = {
|
||||
"model": full_model,
|
||||
"messages": openai_messages,
|
||||
"api_base": api_base,
|
||||
"stream": True,
|
||||
}
|
||||
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
|
||||
litellm_params_to_add = {
|
||||
k: v
|
||||
for k, v in litellm_params.items()
|
||||
if k not in ("model", "custom_llm_provider")
|
||||
}
|
||||
completion_params.update(litellm_params_to_add)
|
||||
|
||||
# 1. Emit initial task event (kind: "task", status: "submitted")
|
||||
task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
|
||||
yield task_event
|
||||
|
||||
# 2. Emit status update (kind: "status-update", status: "working")
|
||||
working_event = A2ACompletionBridgeTransformation.create_status_update_event(
|
||||
ctx=ctx,
|
||||
state="working",
|
||||
final=False,
|
||||
message_text="Processing request...",
|
||||
)
|
||||
yield working_event
|
||||
|
||||
# Call litellm.acompletion with streaming
|
||||
response = await litellm.acompletion(**completion_params)
|
||||
|
||||
# 3. Accumulate content and emit artifact update
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
async for chunk in response: # type: ignore[union-attr]
|
||||
chunk_count += 1
|
||||
|
||||
# Extract delta content
|
||||
content = ""
|
||||
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
content = choice.delta.content or ""
|
||||
|
||||
if content:
|
||||
accumulated_text += content
|
||||
|
||||
# Emit artifact update with accumulated content
|
||||
if accumulated_text:
|
||||
artifact_event = (
|
||||
A2ACompletionBridgeTransformation.create_artifact_update_event(
|
||||
ctx=ctx,
|
||||
text=accumulated_text,
|
||||
)
|
||||
)
|
||||
yield artifact_event
|
||||
|
||||
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
|
||||
completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
|
||||
ctx=ctx,
|
||||
state="completed",
|
||||
final=True,
|
||||
)
|
||||
yield completed_event
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
|
||||
)
|
||||
|
||||
|
||||
# Convenience functions that delegate to the class methods
|
||||
async def handle_a2a_completion(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Convenience function for non-streaming A2A completion."""
|
||||
return await A2ACompletionBridgeHandler.handle_non_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
|
||||
async def handle_a2a_completion_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Convenience function for streaming A2A completion."""
|
||||
async for chunk in A2ACompletionBridgeHandler.handle_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Transformation utilities for A2A <-> OpenAI message format conversion.
|
||||
|
||||
A2A Message Format:
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": "abc123"
|
||||
}
|
||||
|
||||
OpenAI Message Format:
|
||||
{"role": "user", "content": "Hello!"}
|
||||
|
||||
A2A Streaming Events:
|
||||
- Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
- Status update (kind: "status-update") - Status changes (working, completed)
|
||||
- Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class A2AStreamingContext:
|
||||
"""
|
||||
Context holder for A2A streaming state.
|
||||
Tracks task_id, context_id, and message accumulation.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str, input_message: Dict[str, Any]):
|
||||
self.request_id = request_id
|
||||
self.task_id = str(uuid4())
|
||||
self.context_id = str(uuid4())
|
||||
self.input_message = input_message
|
||||
self.accumulated_text = ""
|
||||
self.has_emitted_task = False
|
||||
self.has_emitted_working = False
|
||||
|
||||
|
||||
class A2ACompletionBridgeTransformation:
|
||||
"""
|
||||
Static methods for transforming between A2A and OpenAI message formats.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def a2a_message_to_openai_messages(
|
||||
a2a_message: Dict[str, Any],
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Transform an A2A message to OpenAI message format.
|
||||
|
||||
Args:
|
||||
a2a_message: A2A message with role, parts, and messageId
|
||||
|
||||
Returns:
|
||||
List of OpenAI-format messages
|
||||
"""
|
||||
role = a2a_message.get("role", "user")
|
||||
parts = a2a_message.get("parts", [])
|
||||
|
||||
# Map A2A roles to OpenAI roles
|
||||
openai_role = role
|
||||
if role == "user":
|
||||
openai_role = "user"
|
||||
elif role == "assistant":
|
||||
openai_role = "assistant"
|
||||
elif role == "system":
|
||||
openai_role = "system"
|
||||
|
||||
# Extract text content from parts
|
||||
content_parts = []
|
||||
for part in parts:
|
||||
kind = part.get("kind", "")
|
||||
if kind == "text":
|
||||
text = part.get("text", "")
|
||||
content_parts.append(text)
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
verbose_logger.debug(
|
||||
f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
|
||||
)
|
||||
|
||||
return [{"role": openai_role, "content": content}]
|
||||
|
||||
@staticmethod
|
||||
def openai_response_to_a2a_response(
|
||||
response: Any,
|
||||
request_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.
|
||||
|
||||
Args:
|
||||
response: LiteLLM ModelResponse object
|
||||
request_id: Original A2A request ID
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
# Extract content from response
|
||||
content = ""
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message") and choice.message:
|
||||
content = choice.message.content or ""
|
||||
|
||||
# Build A2A message
|
||||
a2a_message = {
|
||||
"role": "agent",
|
||||
"parts": [{"kind": "text", "text": content}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
|
||||
# Build A2A response
|
||||
a2a_response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"message": a2a_message,
|
||||
},
|
||||
}
|
||||
|
||||
verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")
|
||||
|
||||
return a2a_response
|
||||
|
||||
@staticmethod
|
||||
def _get_timestamp() -> str:
|
||||
"""Get current timestamp in ISO format with timezone."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
@staticmethod
|
||||
def create_task_event(
|
||||
ctx: A2AStreamingContext,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create the initial task event with status 'submitted'.
|
||||
|
||||
This is the first event emitted in an A2A streaming response.
|
||||
"""
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"contextId": ctx.context_id,
|
||||
"history": [
|
||||
{
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "message",
|
||||
"messageId": ctx.input_message.get("messageId", uuid4().hex),
|
||||
"parts": ctx.input_message.get("parts", []),
|
||||
"role": ctx.input_message.get("role", "user"),
|
||||
"taskId": ctx.task_id,
|
||||
}
|
||||
],
|
||||
"id": ctx.task_id,
|
||||
"kind": "task",
|
||||
"status": {
|
||||
"state": "submitted",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_status_update_event(
|
||||
ctx: A2AStreamingContext,
|
||||
state: str,
|
||||
final: bool = False,
|
||||
message_text: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a status update event.
|
||||
|
||||
Args:
|
||||
ctx: Streaming context
|
||||
state: Status state ('working', 'completed')
|
||||
final: Whether this is the final event
|
||||
message_text: Optional message text for 'working' status
|
||||
"""
|
||||
status: Dict[str, Any] = {
|
||||
"state": state,
|
||||
"timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
|
||||
}
|
||||
|
||||
# Add message for 'working' status
|
||||
if state == "working" and message_text:
|
||||
status["message"] = {
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "message",
|
||||
"messageId": str(uuid4()),
|
||||
"parts": [{"kind": "text", "text": message_text}],
|
||||
"role": "agent",
|
||||
"taskId": ctx.task_id,
|
||||
}
|
||||
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"contextId": ctx.context_id,
|
||||
"final": final,
|
||||
"kind": "status-update",
|
||||
"status": status,
|
||||
"taskId": ctx.task_id,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_artifact_update_event(
|
||||
ctx: A2AStreamingContext,
|
||||
text: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create an artifact update event with content.
|
||||
|
||||
Args:
|
||||
ctx: Streaming context
|
||||
text: The text content for the artifact
|
||||
"""
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"artifact": {
|
||||
"artifactId": str(uuid4()),
|
||||
"name": "response",
|
||||
"parts": [{"kind": "text", "text": text}],
|
||||
},
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "artifact-update",
|
||||
"taskId": ctx.task_id,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def openai_chunk_to_a2a_chunk(
|
||||
chunk: Any,
|
||||
request_id: Optional[str] = None,
|
||||
is_final: bool = False,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Transform a LiteLLM streaming chunk to A2A streaming format.
|
||||
|
||||
NOTE: This method is deprecated for streaming. Use the event-based
|
||||
methods (create_task_event, create_status_update_event,
|
||||
create_artifact_update_event) instead for proper A2A streaming.
|
||||
|
||||
Args:
|
||||
chunk: LiteLLM ModelResponse chunk
|
||||
request_id: Original A2A request ID
|
||||
is_final: Whether this is the final chunk
|
||||
|
||||
Returns:
|
||||
A2A streaming chunk dict or None if no content
|
||||
"""
|
||||
# Extract delta content
|
||||
content = ""
|
||||
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
content = choice.delta.content or ""
|
||||
|
||||
if not content and not is_final:
|
||||
return None
|
||||
|
||||
# Build A2A streaming chunk (legacy format)
|
||||
a2a_chunk = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"kind": "text", "text": content}],
|
||||
"messageId": uuid4().hex,
|
||||
},
|
||||
"final": is_final,
|
||||
},
|
||||
}
|
||||
|
||||
return a2a_chunk
|
||||
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Pydantic AI agent provider for A2A protocol.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming natively.
|
||||
This provider handles fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
|
||||
PydanticAIProviderConfig,
|
||||
)
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
|
||||
PydanticAITransformation,
|
||||
)
|
||||
|
||||
__all__ = ["PydanticAIHandler", "PydanticAITransformation", "PydanticAIProviderConfig"]
|
||||
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Pydantic AI provider configuration.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
|
||||
|
||||
|
||||
class PydanticAIProviderConfig(BaseA2AProviderConfig):
|
||||
"""
|
||||
Provider configuration for Pydantic AI agents.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming natively.
|
||||
This config provides fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
async def handle_non_streaming(
|
||||
self,
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
api_base: str,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""Handle non-streaming request to Pydantic AI agent."""
|
||||
return await PydanticAIHandler.handle_non_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
api_base=api_base,
|
||||
timeout=kwargs.get("timeout", 60.0),
|
||||
)
|
||||
|
||||
async def handle_streaming(
|
||||
self,
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
api_base: str,
|
||||
**kwargs,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Handle streaming request with fake streaming."""
|
||||
async for chunk in PydanticAIHandler.handle_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
api_base=api_base,
|
||||
timeout=kwargs.get("timeout", 60.0),
|
||||
chunk_size=kwargs.get("chunk_size", 50),
|
||||
delay_ms=kwargs.get("delay_ms", 10),
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Handler for Pydantic AI agents.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming natively.
|
||||
This handler provides fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
|
||||
PydanticAITransformation,
|
||||
)
|
||||
|
||||
|
||||
class PydanticAIHandler:
|
||||
"""
|
||||
Handler for Pydantic AI agent requests.
|
||||
|
||||
Provides:
|
||||
- Direct non-streaming requests to Pydantic AI agents
|
||||
- Fake streaming by converting non-streaming responses into streaming chunks
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def handle_non_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
api_base: str,
|
||||
timeout: float = 60.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle non-streaming request to Pydantic AI agent.
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
api_base: Base URL of the Pydantic AI agent
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
verbose_logger.info(f"Pydantic AI: Routing to Pydantic AI agent at {api_base}")
|
||||
|
||||
# Send request directly to Pydantic AI agent
|
||||
response_data = await PydanticAITransformation.send_non_streaming_request(
|
||||
api_base=api_base,
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
@staticmethod
|
||||
async def handle_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
api_base: str,
|
||||
timeout: float = 60.0,
|
||||
chunk_size: int = 50,
|
||||
delay_ms: int = 10,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""
|
||||
Handle streaming request to Pydantic AI agent with fake streaming.
|
||||
|
||||
Since Pydantic AI agents don't support streaming natively, this method:
|
||||
1. Makes a non-streaming request
|
||||
2. Converts the response into streaming chunks
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
api_base: Base URL of the Pydantic AI agent
|
||||
timeout: Request timeout in seconds
|
||||
chunk_size: Number of characters per chunk
|
||||
delay_ms: Delay between chunks in milliseconds
|
||||
|
||||
Yields:
|
||||
A2A streaming response events
|
||||
"""
|
||||
verbose_logger.info(
|
||||
f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
|
||||
)
|
||||
|
||||
# Get raw task response first (not the transformed A2A format)
|
||||
raw_response = await PydanticAITransformation.send_and_get_raw_response(
|
||||
api_base=api_base,
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Convert raw task response to fake streaming chunks
|
||||
async for chunk in PydanticAITransformation.fake_streaming_from_response(
|
||||
response_data=raw_response,
|
||||
request_id=request_id,
|
||||
chunk_size=chunk_size,
|
||||
delay_ms=delay_ms,
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,530 @@
|
||||
"""
|
||||
Transformation layer for Pydantic AI agents.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming.
|
||||
This module provides fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncIterator, Dict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
|
||||
|
||||
class PydanticAITransformation:
|
||||
"""
|
||||
Transformation layer for Pydantic AI agents.
|
||||
|
||||
Handles:
|
||||
- Direct A2A requests to Pydantic AI endpoints
|
||||
- Polling for task completion (since Pydantic AI doesn't support streaming)
|
||||
- Fake streaming by chunking non-streaming responses
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _remove_none_values(obj: Any) -> Any:
|
||||
"""
|
||||
Recursively remove None values from a dict/list structure.
|
||||
|
||||
FastA2A/Pydantic AI servers don't accept None values for optional fields -
|
||||
they expect those fields to be omitted entirely.
|
||||
|
||||
Args:
|
||||
obj: Dict, list, or other value to clean
|
||||
|
||||
Returns:
|
||||
Cleaned object with None values removed
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
k: PydanticAITransformation._remove_none_values(v)
|
||||
for k, v in obj.items()
|
||||
if v is not None
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [
|
||||
PydanticAITransformation._remove_none_values(item)
|
||||
for item in obj
|
||||
if item is not None
|
||||
]
|
||||
else:
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def _params_to_dict(params: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert params to a dict, handling Pydantic models.
|
||||
|
||||
Args:
|
||||
params: Dict or Pydantic model
|
||||
|
||||
Returns:
|
||||
Dict representation of params
|
||||
"""
|
||||
if hasattr(params, "model_dump"):
|
||||
# Pydantic v2 model
|
||||
return params.model_dump(mode="python", exclude_none=True)
|
||||
elif hasattr(params, "dict"):
|
||||
# Pydantic v1 model
|
||||
return params.dict(exclude_none=True)
|
||||
elif isinstance(params, dict):
|
||||
return params
|
||||
else:
|
||||
# Try to convert to dict
|
||||
return dict(params)
|
||||
|
||||
@staticmethod
|
||||
async def _poll_for_completion(
|
||||
client: AsyncHTTPHandler,
|
||||
endpoint: str,
|
||||
task_id: str,
|
||||
request_id: str,
|
||||
max_attempts: int = 30,
|
||||
poll_interval: float = 0.5,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll for task completion using tasks/get method.
|
||||
|
||||
Args:
|
||||
client: HTTPX async client
|
||||
endpoint: API endpoint URL
|
||||
task_id: Task ID to poll for
|
||||
request_id: JSON-RPC request ID
|
||||
max_attempts: Maximum polling attempts
|
||||
poll_interval: Seconds between poll attempts
|
||||
|
||||
Returns:
|
||||
Completed task response
|
||||
"""
|
||||
for attempt in range(max_attempts):
|
||||
poll_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"{request_id}-poll-{attempt}",
|
||||
"method": "tasks/get",
|
||||
"params": {"id": task_id},
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
endpoint,
|
||||
json=poll_request,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
poll_data = response.json()
|
||||
|
||||
result = poll_data.get("result", {})
|
||||
status = result.get("status", {})
|
||||
state = status.get("state", "")
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Pydantic AI: Poll attempt {attempt + 1}/{max_attempts}, state={state}"
|
||||
)
|
||||
|
||||
if state == "completed":
|
||||
return poll_data
|
||||
elif state in ("failed", "canceled"):
|
||||
raise Exception(f"Task {task_id} ended with state: {state}")
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(
|
||||
f"Task {task_id} did not complete within {max_attempts * poll_interval} seconds"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _send_and_poll_raw(
|
||||
api_base: str,
|
||||
request_id: str,
|
||||
params: Any,
|
||||
timeout: float = 60.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Send a request to Pydantic AI agent and return the raw task response.
|
||||
|
||||
This is an internal method used by both non-streaming and streaming handlers.
|
||||
Returns the raw Pydantic AI task format with history/artifacts.
|
||||
|
||||
Args:
|
||||
api_base: Base URL of the Pydantic AI agent
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Raw Pydantic AI task response (with history/artifacts)
|
||||
"""
|
||||
# Convert params to dict if it's a Pydantic model
|
||||
params_dict = PydanticAITransformation._params_to_dict(params)
|
||||
|
||||
# Remove None values - FastA2A doesn't accept null for optional fields
|
||||
params_dict = PydanticAITransformation._remove_none_values(params_dict)
|
||||
|
||||
# Ensure the message has 'kind': 'message' as required by FastA2A/Pydantic AI
|
||||
if "message" in params_dict:
|
||||
params_dict["message"]["kind"] = "message"
|
||||
|
||||
# Build A2A JSON-RPC request using message/send method for FastA2A compatibility
|
||||
a2a_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": "message/send",
|
||||
"params": params_dict,
|
||||
}
|
||||
|
||||
# FastA2A uses root endpoint (/) not /messages
|
||||
endpoint = api_base.rstrip("/")
|
||||
|
||||
verbose_logger.info(f"Pydantic AI: Sending non-streaming request to {endpoint}")
|
||||
|
||||
# Send request to Pydantic AI agent using shared async HTTP client
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=cast(Any, "pydantic_ai_agent"),
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
response = await client.post(
|
||||
endpoint,
|
||||
json=a2a_request,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
|
||||
# Check if task is already completed
|
||||
result = response_data.get("result", {})
|
||||
status = result.get("status", {})
|
||||
state = status.get("state", "")
|
||||
|
||||
if state != "completed":
|
||||
# Need to poll for completion
|
||||
task_id = result.get("id")
|
||||
if task_id:
|
||||
verbose_logger.info(
|
||||
f"Pydantic AI: Task {task_id} submitted, polling for completion..."
|
||||
)
|
||||
response_data = await PydanticAITransformation._poll_for_completion(
|
||||
client=client,
|
||||
endpoint=endpoint,
|
||||
task_id=task_id,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
verbose_logger.info(
|
||||
f"Pydantic AI: Received completed response for request_id={request_id}"
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
@staticmethod
|
||||
async def send_non_streaming_request(
|
||||
api_base: str,
|
||||
request_id: str,
|
||||
params: Any,
|
||||
timeout: float = 60.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Send a non-streaming A2A request to Pydantic AI agent and wait for completion.
|
||||
|
||||
Args:
|
||||
api_base: Base URL of the Pydantic AI agent (e.g., "http://localhost:9999")
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message (dict or Pydantic model)
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Standard A2A non-streaming response format with message
|
||||
"""
|
||||
# Get raw task response
|
||||
raw_response = await PydanticAITransformation._send_and_poll_raw(
|
||||
api_base=api_base,
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Transform to standard A2A non-streaming format
|
||||
return PydanticAITransformation._transform_to_a2a_response(
|
||||
response_data=raw_response,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_and_get_raw_response(
|
||||
api_base: str,
|
||||
request_id: str,
|
||||
params: Any,
|
||||
timeout: float = 60.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Send a request to Pydantic AI agent and return the raw task response.
|
||||
|
||||
Used by streaming handler to get raw response for fake streaming.
|
||||
|
||||
Args:
|
||||
api_base: Base URL of the Pydantic AI agent
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Raw Pydantic AI task response (with history/artifacts)
|
||||
"""
|
||||
return await PydanticAITransformation._send_and_poll_raw(
|
||||
api_base=api_base,
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _transform_to_a2a_response(
|
||||
response_data: Dict[str, Any],
|
||||
request_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform Pydantic AI task response to standard A2A non-streaming format.
|
||||
|
||||
Pydantic AI returns a task with history/artifacts, but the standard A2A
|
||||
non-streaming format expects:
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "...",
|
||||
"result": {
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"kind": "text", "text": "..."}],
|
||||
"messageId": "..."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
response_data: Pydantic AI task response
|
||||
request_id: Original request ID
|
||||
|
||||
Returns:
|
||||
Standard A2A non-streaming response format
|
||||
"""
|
||||
# Extract the agent response text
|
||||
full_text, message_id, parts = PydanticAITransformation._extract_response_text(
|
||||
response_data
|
||||
)
|
||||
|
||||
# Build standard A2A message
|
||||
a2a_message = {
|
||||
"role": "agent",
|
||||
"parts": parts if parts else [{"kind": "text", "text": full_text}],
|
||||
"messageId": message_id,
|
||||
}
|
||||
|
||||
# Return standard A2A non-streaming format
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"message": a2a_message,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_text(response_data: Dict[str, Any]) -> tuple[str, str, list]:
|
||||
"""
|
||||
Extract response text from completed task response.
|
||||
|
||||
Pydantic AI returns completed tasks with:
|
||||
- history: list of messages (user and agent)
|
||||
- artifacts: list of result artifacts
|
||||
|
||||
Args:
|
||||
response_data: Completed task response
|
||||
|
||||
Returns:
|
||||
Tuple of (full_text, message_id, parts)
|
||||
"""
|
||||
result = response_data.get("result", {})
|
||||
|
||||
# Try to extract from artifacts first (preferred for results)
|
||||
artifacts = result.get("artifacts", [])
|
||||
if artifacts:
|
||||
for artifact in artifacts:
|
||||
parts = artifact.get("parts", [])
|
||||
for part in parts:
|
||||
if part.get("kind") == "text":
|
||||
text = part.get("text", "")
|
||||
if text:
|
||||
return text, str(uuid4()), parts
|
||||
|
||||
# Fall back to history - get the last agent message
|
||||
history = result.get("history", [])
|
||||
for msg in reversed(history):
|
||||
if msg.get("role") == "agent":
|
||||
parts = msg.get("parts", [])
|
||||
message_id = msg.get("messageId", str(uuid4()))
|
||||
full_text = ""
|
||||
for part in parts:
|
||||
if part.get("kind") == "text":
|
||||
full_text += part.get("text", "")
|
||||
if full_text:
|
||||
return full_text, message_id, parts
|
||||
|
||||
# Fall back to message field (original format)
|
||||
message = result.get("message", {})
|
||||
if message:
|
||||
parts = message.get("parts", [])
|
||||
message_id = message.get("messageId", str(uuid4()))
|
||||
full_text = ""
|
||||
for part in parts:
|
||||
if part.get("kind") == "text":
|
||||
full_text += part.get("text", "")
|
||||
return full_text, message_id, parts
|
||||
|
||||
return "", str(uuid4()), []
|
||||
|
||||
@staticmethod
|
||||
async def fake_streaming_from_response(
|
||||
response_data: Dict[str, Any],
|
||||
request_id: str,
|
||||
chunk_size: int = 50,
|
||||
delay_ms: int = 10,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""
|
||||
Convert a non-streaming A2A response into fake streaming chunks.
|
||||
|
||||
Emits proper A2A streaming events:
|
||||
1. Task event (kind: "task") - Initial task with status "submitted"
|
||||
2. Status update (kind: "status-update") - Status "working"
|
||||
3. Artifact update chunks (kind: "artifact-update") - Content delivery in chunks
|
||||
4. Status update (kind: "status-update") - Final "completed" status
|
||||
|
||||
Args:
|
||||
response_data: Non-streaming A2A response dict (completed task)
|
||||
request_id: A2A JSON-RPC request ID
|
||||
chunk_size: Number of characters per chunk (default: 50)
|
||||
delay_ms: Delay between chunks in milliseconds (default: 10)
|
||||
|
||||
Yields:
|
||||
A2A streaming response events
|
||||
"""
|
||||
# Extract the response text from completed task
|
||||
full_text, message_id, parts = PydanticAITransformation._extract_response_text(
|
||||
response_data
|
||||
)
|
||||
|
||||
# Extract input message from raw response for history
|
||||
result = response_data.get("result", {})
|
||||
history = result.get("history", [])
|
||||
input_message = {}
|
||||
for msg in history:
|
||||
if msg.get("role") == "user":
|
||||
input_message = msg
|
||||
break
|
||||
|
||||
# Generate IDs for streaming events
|
||||
task_id = str(uuid4())
|
||||
context_id = str(uuid4())
|
||||
artifact_id = str(uuid4())
|
||||
input_message_id = input_message.get("messageId", str(uuid4()))
|
||||
|
||||
# 1. Emit initial task event (kind: "task", status: "submitted")
|
||||
# Format matches A2ACompletionBridgeTransformation.create_task_event
|
||||
task_event = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"contextId": context_id,
|
||||
"history": [
|
||||
{
|
||||
"contextId": context_id,
|
||||
"kind": "message",
|
||||
"messageId": input_message_id,
|
||||
"parts": input_message.get(
|
||||
"parts", [{"kind": "text", "text": ""}]
|
||||
),
|
||||
"role": "user",
|
||||
"taskId": task_id,
|
||||
}
|
||||
],
|
||||
"id": task_id,
|
||||
"kind": "task",
|
||||
"status": {
|
||||
"state": "submitted",
|
||||
},
|
||||
},
|
||||
}
|
||||
yield task_event
|
||||
|
||||
# 2. Emit status update (kind: "status-update", status: "working")
|
||||
# Format matches A2ACompletionBridgeTransformation.create_status_update_event
|
||||
working_event = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"contextId": context_id,
|
||||
"final": False,
|
||||
"kind": "status-update",
|
||||
"status": {
|
||||
"state": "working",
|
||||
},
|
||||
"taskId": task_id,
|
||||
},
|
||||
}
|
||||
yield working_event
|
||||
|
||||
# Small delay to simulate processing
|
||||
await asyncio.sleep(delay_ms / 1000.0)
|
||||
|
||||
# 3. Emit artifact update chunks (kind: "artifact-update")
|
||||
# Format matches A2ACompletionBridgeTransformation.create_artifact_update_event
|
||||
if full_text:
|
||||
# Split text into chunks
|
||||
for i in range(0, len(full_text), chunk_size):
|
||||
chunk_text = full_text[i : i + chunk_size]
|
||||
is_last_chunk = (i + chunk_size) >= len(full_text)
|
||||
|
||||
artifact_event = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"contextId": context_id,
|
||||
"kind": "artifact-update",
|
||||
"taskId": task_id,
|
||||
"artifact": {
|
||||
"artifactId": artifact_id,
|
||||
"parts": [
|
||||
{
|
||||
"kind": "text",
|
||||
"text": chunk_text,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
yield artifact_event
|
||||
|
||||
# Add delay between chunks (except for last chunk)
|
||||
if not is_last_chunk:
|
||||
await asyncio.sleep(delay_ms / 1000.0)
|
||||
|
||||
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
|
||||
completed_event = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"contextId": context_id,
|
||||
"final": True,
|
||||
"kind": "status-update",
|
||||
"status": {
|
||||
"state": "completed",
|
||||
},
|
||||
"taskId": task_id,
|
||||
},
|
||||
}
|
||||
yield completed_event
|
||||
|
||||
verbose_logger.info(
|
||||
f"Pydantic AI: Fake streaming completed for request_id={request_id}"
|
||||
)
|
||||
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
A2A Streaming Iterator with token tracking and logging support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.cost_calculator import A2ACostCalculator
|
||||
from litellm.a2a_protocol.utils import A2ARequestUtils
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.thread_pool_executor import executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import SendStreamingMessageRequest, SendStreamingMessageResponse
|
||||
|
||||
|
||||
class A2AStreamingIterator:
|
||||
"""
|
||||
Async iterator for A2A streaming responses with token tracking.
|
||||
|
||||
Collects chunks, extracts text, and logs usage on completion.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: AsyncIterator["SendStreamingMessageResponse"],
|
||||
request: "SendStreamingMessageRequest",
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
agent_name: str = "unknown",
|
||||
):
|
||||
self.stream = stream
|
||||
self.request = request
|
||||
self.logging_obj = logging_obj
|
||||
self.agent_name = agent_name
|
||||
self.start_time = datetime.now()
|
||||
|
||||
# Collect chunks for token counting
|
||||
self.chunks: List[Any] = []
|
||||
self.collected_text_parts: List[str] = []
|
||||
self.final_chunk: Optional[Any] = None
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> "SendStreamingMessageResponse":
|
||||
try:
|
||||
chunk = await self.stream.__anext__()
|
||||
|
||||
# Store chunk
|
||||
self.chunks.append(chunk)
|
||||
|
||||
# Extract text from chunk for token counting
|
||||
self._collect_text_from_chunk(chunk)
|
||||
|
||||
# Check if this is the final chunk (completed status)
|
||||
if self._is_completed_chunk(chunk):
|
||||
self.final_chunk = chunk
|
||||
|
||||
return chunk
|
||||
|
||||
except StopAsyncIteration:
|
||||
# Stream ended - handle logging
|
||||
if self.final_chunk is None and self.chunks:
|
||||
self.final_chunk = self.chunks[-1]
|
||||
await self._handle_stream_complete()
|
||||
raise
|
||||
|
||||
def _collect_text_from_chunk(self, chunk: Any) -> None:
|
||||
"""Extract text from a streaming chunk and add to collected parts."""
|
||||
try:
|
||||
chunk_dict = (
|
||||
chunk.model_dump(mode="json", exclude_none=True)
|
||||
if hasattr(chunk, "model_dump")
|
||||
else {}
|
||||
)
|
||||
text = A2ARequestUtils.extract_text_from_response(chunk_dict)
|
||||
if text:
|
||||
self.collected_text_parts.append(text)
|
||||
except Exception:
|
||||
verbose_logger.debug("Failed to extract text from A2A streaming chunk")
|
||||
|
||||
def _is_completed_chunk(self, chunk: Any) -> bool:
|
||||
"""Check if chunk indicates stream completion."""
|
||||
try:
|
||||
chunk_dict = (
|
||||
chunk.model_dump(mode="json", exclude_none=True)
|
||||
if hasattr(chunk, "model_dump")
|
||||
else {}
|
||||
)
|
||||
result = chunk_dict.get("result", {})
|
||||
if isinstance(result, dict):
|
||||
status = result.get("status", {})
|
||||
if isinstance(status, dict):
|
||||
return status.get("state") == "completed"
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
async def _handle_stream_complete(self) -> None:
|
||||
"""Handle logging and token counting when stream completes."""
|
||||
try:
|
||||
end_time = datetime.now()
|
||||
|
||||
# Calculate tokens from collected text
|
||||
input_message = A2ARequestUtils.get_input_message_from_request(self.request)
|
||||
input_text = A2ARequestUtils.extract_text_from_message(input_message)
|
||||
prompt_tokens = A2ARequestUtils.count_tokens(input_text)
|
||||
|
||||
# Use the last (most complete) text from chunks
|
||||
output_text = (
|
||||
self.collected_text_parts[-1] if self.collected_text_parts else ""
|
||||
)
|
||||
completion_tokens = A2ARequestUtils.count_tokens(output_text)
|
||||
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
# Create usage object
|
||||
usage = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
# Set usage on logging obj
|
||||
self.logging_obj.model_call_details["usage"] = usage
|
||||
# Mark stream flag for downstream callbacks
|
||||
self.logging_obj.model_call_details["stream"] = False
|
||||
|
||||
# Calculate cost using A2ACostCalculator
|
||||
response_cost = A2ACostCalculator.calculate_a2a_cost(self.logging_obj)
|
||||
self.logging_obj.model_call_details["response_cost"] = response_cost
|
||||
|
||||
# Build result for logging
|
||||
result = self._build_logging_result(usage)
|
||||
|
||||
# Call success handlers - they will build standard_logging_object
|
||||
asyncio.create_task(
|
||||
self.logging_obj.async_success_handler(
|
||||
result=result,
|
||||
start_time=self.start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=None,
|
||||
)
|
||||
)
|
||||
|
||||
executor.submit(
|
||||
self.logging_obj.success_handler,
|
||||
result=result,
|
||||
cache_hit=None,
|
||||
start_time=self.start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A streaming completed: prompt_tokens={prompt_tokens}, "
|
||||
f"completion_tokens={completion_tokens}, total_tokens={total_tokens}, "
|
||||
f"response_cost={response_cost}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error in A2A streaming completion handler: {e}")
|
||||
|
||||
def _build_logging_result(self, usage: litellm.Usage) -> Dict[str, Any]:
|
||||
"""Build a result dict for logging."""
|
||||
result: Dict[str, Any] = {
|
||||
"id": getattr(self.request, "id", "unknown"),
|
||||
"jsonrpc": "2.0",
|
||||
"usage": usage.model_dump()
|
||||
if hasattr(usage, "model_dump")
|
||||
else dict(usage),
|
||||
}
|
||||
|
||||
# Add final chunk result if available
|
||||
if self.final_chunk:
|
||||
try:
|
||||
chunk_dict = self.final_chunk.model_dump(mode="json", exclude_none=True)
|
||||
result["result"] = chunk_dict.get("result", {})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Utility functions for A2A protocol.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import SendMessageRequest, SendStreamingMessageRequest
|
||||
|
||||
|
||||
class A2ARequestUtils:
|
||||
"""Utility class for A2A request/response processing."""
|
||||
|
||||
@staticmethod
|
||||
def extract_text_from_message(message: Any) -> str:
|
||||
"""
|
||||
Extract text content from A2A message parts.
|
||||
|
||||
Args:
|
||||
message: A2A message dict or object with 'parts' containing text parts
|
||||
|
||||
Returns:
|
||||
Concatenated text from all text parts
|
||||
"""
|
||||
if message is None:
|
||||
return ""
|
||||
|
||||
# Handle both dict and object access
|
||||
if isinstance(message, dict):
|
||||
parts = message.get("parts", [])
|
||||
else:
|
||||
parts = getattr(message, "parts", []) or []
|
||||
|
||||
text_parts: List[str] = []
|
||||
for part in parts:
|
||||
if isinstance(part, dict):
|
||||
if part.get("kind") == "text":
|
||||
text_parts.append(part.get("text", ""))
|
||||
else:
|
||||
if getattr(part, "kind", None) == "text":
|
||||
text_parts.append(getattr(part, "text", ""))
|
||||
|
||||
return " ".join(text_parts)
|
||||
|
||||
@staticmethod
|
||||
def extract_text_from_response(response_dict: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract text content from A2A response result.
|
||||
|
||||
Args:
|
||||
response_dict: A2A response dict with 'result' containing message
|
||||
|
||||
Returns:
|
||||
Text from response message parts
|
||||
"""
|
||||
result = response_dict.get("result", {})
|
||||
if not isinstance(result, dict):
|
||||
return ""
|
||||
|
||||
message = result.get("message", {})
|
||||
return A2ARequestUtils.extract_text_from_message(message)
|
||||
|
||||
@staticmethod
|
||||
def get_input_message_from_request(
|
||||
request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
|
||||
) -> Any:
|
||||
"""
|
||||
Extract the input message from an A2A request.
|
||||
|
||||
Args:
|
||||
request: The A2A SendMessageRequest or SendStreamingMessageRequest
|
||||
|
||||
Returns:
|
||||
The message object/dict or None
|
||||
"""
|
||||
params = getattr(request, "params", None)
|
||||
if params is None:
|
||||
return None
|
||||
return getattr(params, "message", None)
|
||||
|
||||
@staticmethod
|
||||
def count_tokens(text: str) -> int:
|
||||
"""
|
||||
Count tokens in text using litellm.token_counter.
|
||||
|
||||
Args:
|
||||
text: Text to count tokens for
|
||||
|
||||
Returns:
|
||||
Token count, or 0 if counting fails
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
try:
|
||||
return litellm.token_counter(text=text)
|
||||
except Exception:
|
||||
verbose_logger.debug("Failed to count tokens")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def calculate_usage_from_request_response(
|
||||
request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
|
||||
response_dict: Dict[str, Any],
|
||||
) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Calculate token usage from A2A request and response.
|
||||
|
||||
Args:
|
||||
request: The A2A SendMessageRequest or SendStreamingMessageRequest
|
||||
response_dict: The A2A response as a dict
|
||||
|
||||
Returns:
|
||||
Tuple of (prompt_tokens, completion_tokens, total_tokens)
|
||||
"""
|
||||
# Count input tokens
|
||||
input_message = A2ARequestUtils.get_input_message_from_request(request)
|
||||
input_text = A2ARequestUtils.extract_text_from_message(input_message)
|
||||
prompt_tokens = A2ARequestUtils.count_tokens(input_text)
|
||||
|
||||
# Count output tokens
|
||||
output_text = A2ARequestUtils.extract_text_from_response(response_dict)
|
||||
completion_tokens = A2ARequestUtils.count_tokens(output_text)
|
||||
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
return prompt_tokens, completion_tokens, total_tokens
|
||||
|
||||
|
||||
# Backwards compatibility aliases
|
||||
def extract_text_from_a2a_message(message: Any) -> str:
|
||||
return A2ARequestUtils.extract_text_from_message(message)
|
||||
|
||||
|
||||
def extract_text_from_a2a_response(response_dict: Dict[str, Any]) -> str:
|
||||
return A2ARequestUtils.extract_text_from_response(response_dict)
|
||||
Reference in New Issue
Block a user