chore: initial public snapshot for github upload

Author: Your Name
Date: 2026-03-26 20:06:14 +08:00
Commit: 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions


@@ -0,0 +1,73 @@
"""
LiteLLM A2A - Wrapper for invoking A2A protocol agents.
This module provides a thin wrapper around the official `a2a` SDK that:
- Handles httpx client creation and agent card resolution
- Adds LiteLLM logging via @client decorator
- Matches the A2A SDK interface (SendMessageRequest, SendMessageResponse, etc.)
Example usage (standalone functions with @client decorator):
```python
from litellm.a2a_protocol import asend_message
from a2a.types import SendMessageRequest, MessageSendParams
from uuid import uuid4
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": uuid4().hex,
}
)
)
response = await asend_message(
base_url="http://localhost:10001",
request=request,
)
print(response.model_dump(mode='json', exclude_none=True))
```
Example usage (class-based):
```python
from litellm.a2a_protocol import A2AClient
client = A2AClient(base_url="http://localhost:10001")
response = await client.send_message(request)
```
"""
from litellm.a2a_protocol.client import A2AClient
from litellm.a2a_protocol.exceptions import (
A2AAgentCardError,
A2AConnectionError,
A2AError,
A2ALocalhostURLError,
)
from litellm.a2a_protocol.main import (
aget_agent_card,
asend_message,
asend_message_streaming,
create_a2a_client,
send_message,
)
from litellm.types.agents import LiteLLMSendMessageResponse
__all__ = [
# Client
"A2AClient",
# Functions
"asend_message",
"send_message",
"asend_message_streaming",
"aget_agent_card",
"create_a2a_client",
# Response types
"LiteLLMSendMessageResponse",
# Exceptions
"A2AError",
"A2AConnectionError",
"A2AAgentCardError",
"A2ALocalhostURLError",
]


@@ -0,0 +1,144 @@
"""
Custom A2A Card Resolver for LiteLLM.
Extends the A2A SDK's card resolver to support multiple well-known paths.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from litellm._logging import verbose_logger
from litellm.constants import LOCALHOST_URL_PATTERNS
if TYPE_CHECKING:
from a2a.types import AgentCard
# Runtime imports with availability check
_A2ACardResolver: Any = None
AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent-card.json"
PREV_AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent.json"
try:
from a2a.client import A2ACardResolver as _A2ACardResolver # type: ignore[no-redef]
from a2a.utils.constants import ( # type: ignore[no-redef]
AGENT_CARD_WELL_KNOWN_PATH,
PREV_AGENT_CARD_WELL_KNOWN_PATH,
)
except ImportError:
pass
def is_localhost_or_internal_url(url: Optional[str]) -> bool:
"""
Check if a URL is a localhost or internal URL.
This detects common development URLs that are accidentally left in
agent cards when deploying to production.
Args:
url: The URL to check
Returns:
True if the URL is localhost/internal
"""
if not url:
return False
url_lower = url.lower()
return any(pattern in url_lower for pattern in LOCALHOST_URL_PATTERNS)
def fix_agent_card_url(agent_card: "AgentCard", base_url: str) -> "AgentCard":
"""
Fix the agent card URL if it contains a localhost/internal address.
Many A2A agents are deployed with agent cards that contain internal URLs
like "http://0.0.0.0:8001/" or "http://localhost:8000/". This function
replaces such URLs with the provided base_url.
Args:
agent_card: The agent card to fix
base_url: The base URL to use as replacement
Returns:
The agent card with the URL fixed if necessary
"""
card_url = getattr(agent_card, "url", None)
if card_url and is_localhost_or_internal_url(card_url):
# Normalize base_url to ensure it ends with /
fixed_url = base_url.rstrip("/") + "/"
agent_card.url = fixed_url
return agent_card
class LiteLLMA2ACardResolver(_A2ACardResolver): # type: ignore[misc]
"""
Custom A2A card resolver that supports multiple well-known paths.
Extends the base A2ACardResolver to try both:
- /.well-known/agent-card.json (standard)
- /.well-known/agent.json (previous/alternative)
"""
async def get_agent_card(
self,
relative_card_path: Optional[str] = None,
http_kwargs: Optional[Dict[str, Any]] = None,
) -> "AgentCard":
"""
Fetch the agent card, trying multiple well-known paths.
First tries the standard path, then falls back to the previous path.
Args:
relative_card_path: Optional path to the agent card endpoint.
If None, tries both well-known paths.
http_kwargs: Optional dictionary of keyword arguments to pass to httpx.get
Returns:
AgentCard from the A2A agent
Raises:
A2AClientHTTPError or A2AClientJSONError if both paths fail
"""
# If a specific path is provided, use the parent implementation
if relative_card_path is not None:
return await super().get_agent_card(
relative_card_path=relative_card_path,
http_kwargs=http_kwargs,
)
# Try both well-known paths
paths = [
AGENT_CARD_WELL_KNOWN_PATH,
PREV_AGENT_CARD_WELL_KNOWN_PATH,
]
last_error = None
for path in paths:
try:
verbose_logger.debug(
f"Attempting to fetch agent card from {self.base_url}{path}"
)
return await super().get_agent_card(
relative_card_path=path,
http_kwargs=http_kwargs,
)
except Exception as e:
verbose_logger.debug(
f"Failed to fetch agent card from {self.base_url}{path}: {e}"
)
last_error = e
continue
# If we get here, all paths failed - re-raise the last error
if last_error is not None:
raise last_error
# This shouldn't happen, but just in case
raise Exception(
f"Failed to fetch agent card from {self.base_url}. "
f"Tried paths: {', '.join(paths)}"
)
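
For orientation, here is a minimal sketch of using the resolver directly; the agent URL is a placeholder, and the `a2a` SDK must be installed for the base class to be available:

```python
# Minimal sketch: resolve an agent card, trying both well-known paths.
# "http://localhost:10001" is a placeholder URL.
import asyncio

import httpx

from litellm.a2a_protocol.card_resolver import LiteLLMA2ACardResolver


async def main() -> None:
    async with httpx.AsyncClient(timeout=10.0) as http_client:
        resolver = LiteLLMA2ACardResolver(
            httpx_client=http_client,
            base_url="http://localhost:10001",
        )
        # Tries /.well-known/agent-card.json first, then /.well-known/agent.json
        card = await resolver.get_agent_card()
        print(card.name, getattr(card, "url", None))


asyncio.run(main())
```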


@@ -0,0 +1,109 @@
"""
LiteLLM A2A Client class.
Provides a class-based interface for A2A agent invocation.
"""
from typing import TYPE_CHECKING, AsyncIterator, Dict, Optional
from litellm.types.agents import LiteLLMSendMessageResponse
if TYPE_CHECKING:
from a2a.client import A2AClient as A2AClientType
from a2a.types import (
AgentCard,
SendMessageRequest,
SendStreamingMessageRequest,
SendStreamingMessageResponse,
)
class A2AClient:
"""
LiteLLM wrapper for A2A agent invocation.
Creates the underlying A2A client once on first use and reuses it.
Example:
```python
from litellm.a2a_protocol import A2AClient
from a2a.types import SendMessageRequest, MessageSendParams
from uuid import uuid4
client = A2AClient(base_url="http://localhost:10001")
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": uuid4().hex,
}
)
)
response = await client.send_message(request)
```
"""
def __init__(
self,
base_url: str,
timeout: float = 60.0,
extra_headers: Optional[Dict[str, str]] = None,
):
"""
Initialize the A2A client wrapper.
Args:
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
timeout: Request timeout in seconds (default: 60.0)
extra_headers: Optional additional headers to include in requests
"""
self.base_url = base_url
self.timeout = timeout
self.extra_headers = extra_headers
self._a2a_client: Optional["A2AClientType"] = None
async def _get_client(self) -> "A2AClientType":
"""Get or create the underlying A2A client."""
if self._a2a_client is None:
from litellm.a2a_protocol.main import create_a2a_client
self._a2a_client = await create_a2a_client(
base_url=self.base_url,
timeout=self.timeout,
extra_headers=self.extra_headers,
)
return self._a2a_client
async def get_agent_card(self) -> "AgentCard":
"""Fetch the agent card from the server."""
from litellm.a2a_protocol.main import aget_agent_card
return await aget_agent_card(
base_url=self.base_url,
timeout=self.timeout,
extra_headers=self.extra_headers,
)
async def send_message(
self, request: "SendMessageRequest"
) -> LiteLLMSendMessageResponse:
"""Send a message to the A2A agent."""
from litellm.a2a_protocol.main import asend_message
a2a_client = await self._get_client()
return await asend_message(a2a_client=a2a_client, request=request)
async def send_message_streaming(
self, request: "SendStreamingMessageRequest"
) -> AsyncIterator["SendStreamingMessageResponse"]:
"""Send a streaming message to the A2A agent."""
from litellm.a2a_protocol.main import asend_message_streaming
a2a_client = await self._get_client()
async for chunk in asend_message_streaming(
a2a_client=a2a_client, request=request
):
yield chunk


@@ -0,0 +1,107 @@
"""
Cost calculator for A2A (Agent-to-Agent) calls.
Supports dynamic cost parameters that allow platform owners
to define custom costs per agent query or per token.
"""
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import (
Logging as LitellmLoggingObject,
)
else:
LitellmLoggingObject = Any
class A2ACostCalculator:
@staticmethod
def calculate_a2a_cost(
litellm_logging_obj: Optional[LitellmLoggingObject],
) -> float:
"""
Calculate the cost of an A2A send_message call.
Supports multiple cost parameters for platform owners:
- cost_per_query: Fixed cost per query
- input_cost_per_token + output_cost_per_token: Token-based pricing
Priority order:
1. response_cost - if set directly (backward compatibility)
2. cost_per_query - fixed cost per query
3. input_cost_per_token + output_cost_per_token - token-based cost
4. Default to 0.0
Args:
litellm_logging_obj: The LiteLLM logging object containing call details
Returns:
float: The cost of the A2A call
"""
if litellm_logging_obj is None:
return 0.0
model_call_details = litellm_logging_obj.model_call_details
# Check if user set a custom response cost (backward compatibility)
response_cost = model_call_details.get("response_cost", None)
if response_cost is not None:
return float(response_cost)
# Get litellm_params for cost parameters
litellm_params = model_call_details.get("litellm_params", {}) or {}
# Check for cost_per_query (fixed cost per query)
if litellm_params.get("cost_per_query") is not None:
return float(litellm_params["cost_per_query"])
# Check for token-based pricing
input_cost_per_token = litellm_params.get("input_cost_per_token")
output_cost_per_token = litellm_params.get("output_cost_per_token")
if input_cost_per_token is not None or output_cost_per_token is not None:
return A2ACostCalculator._calculate_token_based_cost(
model_call_details=model_call_details,
input_cost_per_token=input_cost_per_token,
output_cost_per_token=output_cost_per_token,
)
# Default to 0.0 for A2A calls
return 0.0
@staticmethod
def _calculate_token_based_cost(
model_call_details: dict,
input_cost_per_token: Optional[float],
output_cost_per_token: Optional[float],
) -> float:
"""
Calculate cost based on token usage and per-token pricing.
Args:
model_call_details: The model call details containing usage
input_cost_per_token: Cost per input token (can be None, defaults to 0)
output_cost_per_token: Cost per output token (can be None, defaults to 0)
Returns:
float: The calculated cost
"""
# Get usage from model_call_details
usage = model_call_details.get("usage")
if usage is None:
return 0.0
# Get token counts
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
# Calculate costs
input_cost = prompt_tokens * (
float(input_cost_per_token) if input_cost_per_token else 0.0
)
output_cost = completion_tokens * (
float(output_cost_per_token) if output_cost_per_token else 0.0
)
return input_cost + output_cost
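
As a worked example of the token-based branch (hypothetical prices; the calculator reads the prices from `litellm_params` and the token counts from `usage`):

```python
# Worked example of the token-based pricing arithmetic (hypothetical prices).
input_cost_per_token = 0.000001   # $ per prompt token
output_cost_per_token = 0.000002  # $ per completion token
prompt_tokens = 100
completion_tokens = 50

cost = (
    prompt_tokens * input_cost_per_token
    + completion_tokens * output_cost_per_token
)
print(cost)  # 0.0002 == 0.0001 (input) + 0.0001 (output)
```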


@@ -0,0 +1,203 @@
"""
A2A Protocol Exception Mapping Utils.
Maps A2A SDK exceptions to LiteLLM A2A exception types.
"""
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_logger
from litellm.a2a_protocol.card_resolver import (
fix_agent_card_url,
is_localhost_or_internal_url,
)
from litellm.a2a_protocol.exceptions import (
A2AAgentCardError,
A2AConnectionError,
A2AError,
A2ALocalhostURLError,
)
from litellm.constants import CONNECTION_ERROR_PATTERNS
if TYPE_CHECKING:
from a2a.client import A2AClient as A2AClientType
# Runtime import
A2A_SDK_AVAILABLE = False
try:
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
A2A_SDK_AVAILABLE = True
except ImportError:
_A2AClient = None # type: ignore[assignment, misc]
class A2AExceptionCheckers:
"""
Helper class for checking various A2A error conditions.
"""
@staticmethod
def is_connection_error(error_str: str) -> bool:
"""
Check if an error string indicates a connection error.
Args:
error_str: The error string to check
Returns:
True if the error indicates a connection issue
"""
if not isinstance(error_str, str):
return False
error_str_lower = error_str.lower()
return any(pattern in error_str_lower for pattern in CONNECTION_ERROR_PATTERNS)
@staticmethod
def is_localhost_url(url: Optional[str]) -> bool:
"""
Check if a URL is a localhost/internal URL.
Args:
url: The URL to check
Returns:
True if the URL is localhost/internal
"""
return is_localhost_or_internal_url(url)
@staticmethod
def is_agent_card_error(error_str: str) -> bool:
"""
Check if an error string indicates an agent card error.
Args:
error_str: The error string to check
Returns:
True if the error is related to agent card fetching/parsing
"""
if not isinstance(error_str, str):
return False
error_str_lower = error_str.lower()
agent_card_patterns = [
"agent card",
"agent-card",
".well-known",
"card not found",
"invalid agent",
]
return any(pattern in error_str_lower for pattern in agent_card_patterns)
def map_a2a_exception(
original_exception: Exception,
card_url: Optional[str] = None,
api_base: Optional[str] = None,
model: Optional[str] = None,
) -> Exception:
"""
Map an A2A SDK exception to a LiteLLM A2A exception type.
Args:
original_exception: The original exception from the A2A SDK
card_url: The URL from the agent card (if available)
api_base: The original API base URL
model: The model/agent name
Returns:
Never returns normally; one of the exceptions below is always raised
Raises:
A2ALocalhostURLError: If the error is a connection error to a localhost URL
A2AConnectionError: If the error is a general connection error
A2AAgentCardError: If the error is related to agent card issues
A2AError: For other A2A-related errors
"""
error_str = str(original_exception)
# Check for localhost URL connection error (special case - retryable)
if (
card_url
and api_base
and A2AExceptionCheckers.is_localhost_url(card_url)
and A2AExceptionCheckers.is_connection_error(error_str)
):
raise A2ALocalhostURLError(
localhost_url=card_url,
base_url=api_base,
original_error=original_exception,
model=model,
)
# Check for agent card errors
if A2AExceptionCheckers.is_agent_card_error(error_str):
raise A2AAgentCardError(
message=error_str,
url=api_base,
model=model,
)
# Check for general connection errors
if A2AExceptionCheckers.is_connection_error(error_str):
raise A2AConnectionError(
message=error_str,
url=card_url or api_base,
model=model,
)
# Default: wrap in generic A2AError
raise A2AError(
message=error_str,
model=model,
)
def handle_a2a_localhost_retry(
error: A2ALocalhostURLError,
agent_card: Any,
a2a_client: "A2AClientType",
is_streaming: bool = False,
) -> "A2AClientType":
"""
Handle A2ALocalhostURLError by fixing the URL and creating a new client.
This is called when we catch an A2ALocalhostURLError and want to retry
with the corrected URL.
Args:
error: The localhost URL error
agent_card: The agent card object to fix
a2a_client: The current A2A client
is_streaming: Whether this is a streaming request (for logging)
Returns:
A new A2A client with the fixed URL
Raises:
ImportError: If the A2A SDK is not installed
"""
if not A2A_SDK_AVAILABLE or _A2AClient is None:
raise ImportError(
"A2A SDK is required for localhost retry handling. "
"Install it with: pip install a2a"
)
request_type = "streaming " if is_streaming else ""
verbose_logger.warning(
f"A2A {request_type}request to '{error.localhost_url}' failed: {error.original_error}. "
f"Agent card contains localhost/internal URL. "
f"Retrying with base_url '{error.base_url}'."
)
# Fix the agent card URL
fix_agent_card_url(agent_card, error.base_url)
# Create a new client with the fixed agent card (transport caches URL)
return _A2AClient(
httpx_client=a2a_client._transport.httpx_client, # type: ignore[union-attr]
agent_card=agent_card,
)
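
A sketch of how a caller typically uses the mapper; since `map_a2a_exception` always raises, the call sits inside an exception handler (the client, request, URLs, and model name here are placeholders):

```python
# Sketch: translate a raw SDK error into a typed LiteLLM A2A exception.
from litellm.a2a_protocol.exception_mapping_utils import map_a2a_exception


async def send_with_mapping(a2a_client, request, card_url, api_base):
    try:
        return await a2a_client.send_message(request)
    except Exception as e:
        # Re-raises as A2ALocalhostURLError, A2AAgentCardError,
        # A2AConnectionError, or A2AError depending on the error text.
        map_a2a_exception(e, card_url=card_url, api_base=api_base, model="my-agent")
```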


@@ -0,0 +1,150 @@
"""
A2A Protocol Exceptions.
Custom exception types for A2A protocol operations, following LiteLLM's exception pattern.
"""
from typing import Optional
import httpx
class A2AError(Exception):
"""
Base exception for A2A protocol errors.
Follows the same pattern as LiteLLM's main exceptions.
"""
def __init__(
self,
message: str,
status_code: int = 500,
llm_provider: str = "a2a_agent",
model: Optional[str] = None,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = status_code
self.message = f"litellm.A2AError: {message}"
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
self.response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(method="POST", url="https://litellm.ai"),
)
super().__init__(self.message)
def __str__(self) -> str:
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self) -> str:
return self.__str__()
class A2AConnectionError(A2AError):
"""
Raised when connection to an A2A agent fails.
This typically occurs when:
- The agent is unreachable
- The agent card contains a localhost/internal URL
- Network issues prevent connection
"""
def __init__(
self,
message: str,
url: Optional[str] = None,
model: Optional[str] = None,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.url = url
super().__init__(
message=message,
status_code=503,
llm_provider="a2a_agent",
model=model,
response=response,
litellm_debug_info=litellm_debug_info,
max_retries=max_retries,
num_retries=num_retries,
)
class A2AAgentCardError(A2AError):
"""
Raised when there's an issue with the agent card.
This includes:
- Failed to fetch agent card
- Invalid agent card format
- Missing required fields
"""
def __init__(
self,
message: str,
url: Optional[str] = None,
model: Optional[str] = None,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
):
self.url = url
super().__init__(
message=message,
status_code=404,
llm_provider="a2a_agent",
model=model,
response=response,
litellm_debug_info=litellm_debug_info,
)
class A2ALocalhostURLError(A2AConnectionError):
"""
Raised when an agent card contains a localhost/internal URL.
Many A2A agents are deployed with agent cards that contain internal URLs
like "http://0.0.0.0:8001/" or "http://localhost:8000/". This error
indicates that the URL needs to be corrected and the request should be retried.
Attributes:
localhost_url: The localhost/internal URL found in the agent card
base_url: The public base URL that should be used instead
original_error: The original connection error that was raised
"""
def __init__(
self,
localhost_url: str,
base_url: str,
original_error: Optional[Exception] = None,
model: Optional[str] = None,
):
self.localhost_url = localhost_url
self.base_url = base_url
self.original_error = original_error
message = (
f"Agent card contains localhost/internal URL '{localhost_url}'. "
f"Retrying with base URL '{base_url}'."
)
super().__init__(
message=message,
url=localhost_url,
model=model,
)
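
A brief sketch of handling these exception types in calling code, using only the attributes defined above (note that `A2ALocalhostURLError` must be caught before its parent `A2AConnectionError`):

```python
# Sketch: catch A2A errors and inspect the attributes defined above.
from litellm.a2a_protocol.exceptions import (
    A2AConnectionError,
    A2AError,
    A2ALocalhostURLError,
)


async def call_agent(a2a_client, request):  # placeholders
    try:
        return await a2a_client.send_message(request)
    except A2ALocalhostURLError as e:
        # Retryable: the agent card pointed at an internal URL.
        print(e.localhost_url, "->", e.base_url)
        raise
    except A2AConnectionError as e:
        print("connection failed:", e.url, e.status_code)  # status_code == 503
        raise
    except A2AError as e:
        print("a2a error:", e.status_code, e.message)
        raise
```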


@@ -0,0 +1,74 @@
# A2A to LiteLLM Completion Bridge
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
## Flow
```
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
```
## SDK Usage
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
```python
from litellm.a2a_protocol import asend_message, asend_message_streaming
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
from uuid import uuid4
# Non-streaming
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
response = await asend_message(
request=request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
)
# Streaming
stream_request = SendStreamingMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
async for chunk in asend_message_streaming(
request=stream_request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
):
print(chunk)
```
## Proxy Usage
Configure an agent with `custom_llm_provider` in `litellm_params`:
```yaml
agents:
- agent_name: my-langgraph-agent
agent_card_params:
name: "LangGraph Agent"
url: "http://localhost:2024" # Used as api_base
litellm_params:
custom_llm_provider: langgraph
model: agent
```
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
1. Detects `custom_llm_provider` in agent's `litellm_params`
2. Transforms A2A message → OpenAI messages
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
4. Transforms response → A2A format
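For example, a direct request against the proxy endpoint might look like the sketch below (assumptions: the proxy listens on its default port 4000, the JSON-RPC method name follows the A2A `message/send` convention, and the agent name and API key are placeholders):
```python
# Sketch: POST an A2A JSON-RPC request to the proxy endpoint.
from uuid import uuid4

import httpx

payload = {
    "jsonrpc": "2.0",
    "id": str(uuid4()),
    "method": "message/send",
    "params": {
        "message": {
            "role": "user",
            "parts": [{"kind": "text", "text": "Hello!"}],
            "messageId": uuid4().hex,
        }
    },
}
resp = httpx.post(
    "http://localhost:4000/a2a/my-langgraph-agent/message/send",
    json=payload,
    headers={"Authorization": "Bearer sk-1234"},  # placeholder key
)
print(resp.json())
```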
## Classes
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)


@@ -0,0 +1,23 @@
"""
A2A to LiteLLM Completion Bridge.
This module provides transformation between A2A protocol messages and
LiteLLM completion API, enabling any LiteLLM-supported provider to be
invoked via the A2A protocol.
"""
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
A2ACompletionBridgeHandler,
handle_a2a_completion,
handle_a2a_completion_streaming,
)
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
A2ACompletionBridgeTransformation,
)
__all__ = [
"A2ACompletionBridgeTransformation",
"A2ACompletionBridgeHandler",
"handle_a2a_completion",
"handle_a2a_completion_streaming",
]


@@ -0,0 +1,299 @@
"""
Handler for A2A to LiteLLM completion bridge.
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
A2A Streaming Events (in order):
1. Task event (kind: "task") - Initial task creation with status "submitted"
2. Status update (kind: "status-update") - Status change to "working"
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
4. Status update (kind: "status-update") - Final status "completed" with final=true
"""
from typing import Any, AsyncIterator, Dict, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
A2ACompletionBridgeTransformation,
A2AStreamingContext,
)
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
class A2ACompletionBridgeHandler:
"""
Static methods for handling A2A requests via LiteLLM completion.
"""
@staticmethod
async def handle_non_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> Dict[str, Any]:
"""
Handle non-streaming A2A request via litellm.acompletion.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
api_base: API base URL from agent_card_params
Returns:
A2A SendMessageResponse dict
"""
# Get provider config for custom_llm_provider
custom_llm_provider = litellm_params.get("custom_llm_provider")
a2a_provider_config = A2AProviderConfigManager.get_provider_config(
custom_llm_provider=custom_llm_provider
)
# If provider config exists, use it
if a2a_provider_config is not None:
if api_base is None:
raise ValueError(f"api_base is required for {custom_llm_provider}")
verbose_logger.info(f"A2A: Using provider config for {custom_llm_provider}")
response_data = await a2a_provider_config.handle_non_streaming(
request_id=request_id,
params=params,
api_base=api_base,
)
return response_data
# Extract message from params
message = params.get("message", {})
# Transform A2A message to OpenAI format
openai_messages = (
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
)
# Get completion params
custom_llm_provider = litellm_params.get("custom_llm_provider")
model = litellm_params.get("model", "agent")
# Build full model string if provider specified
# Skip prepending if model already starts with the provider prefix
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
full_model = f"{custom_llm_provider}/{model}"
else:
full_model = model
verbose_logger.info(
f"A2A completion bridge: model={full_model}, api_base={api_base}"
)
# Build completion params dict
completion_params = {
"model": full_model,
"messages": openai_messages,
"api_base": api_base,
"stream": False,
}
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
litellm_params_to_add = {
k: v
for k, v in litellm_params.items()
if k not in ("model", "custom_llm_provider")
}
completion_params.update(litellm_params_to_add)
# Call litellm.acompletion
response = await litellm.acompletion(**completion_params)
# Transform response to A2A format
a2a_response = (
A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
response=response,
request_id=request_id,
)
)
verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")
return a2a_response
@staticmethod
async def handle_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""
Handle streaming A2A request via litellm.acompletion with stream=True.
Emits proper A2A streaming events:
1. Task event (kind: "task") - Initial task with status "submitted"
2. Status update (kind: "status-update") - Status "working"
3. Artifact update (kind: "artifact-update") - Content delivery
4. Status update (kind: "status-update") - Final "completed" status
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
api_base: API base URL from agent_card_params
Yields:
A2A streaming response events
"""
# Get provider config for custom_llm_provider
custom_llm_provider = litellm_params.get("custom_llm_provider")
a2a_provider_config = A2AProviderConfigManager.get_provider_config(
custom_llm_provider=custom_llm_provider
)
# If provider config exists, use it
if a2a_provider_config is not None:
if api_base is None:
raise ValueError(f"api_base is required for {custom_llm_provider}")
verbose_logger.info(
f"A2A: Using provider config for {custom_llm_provider} (streaming)"
)
async for chunk in a2a_provider_config.handle_streaming(
request_id=request_id,
params=params,
api_base=api_base,
):
yield chunk
return
# Extract message from params
message = params.get("message", {})
# Create streaming context
ctx = A2AStreamingContext(
request_id=request_id,
input_message=message,
)
# Transform A2A message to OpenAI format
openai_messages = (
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
)
# Get completion params
custom_llm_provider = litellm_params.get("custom_llm_provider")
model = litellm_params.get("model", "agent")
# Build full model string if provider specified
# Skip prepending if model already starts with the provider prefix
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
full_model = f"{custom_llm_provider}/{model}"
else:
full_model = model
verbose_logger.info(
f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
)
# Build completion params dict
completion_params = {
"model": full_model,
"messages": openai_messages,
"api_base": api_base,
"stream": True,
}
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
litellm_params_to_add = {
k: v
for k, v in litellm_params.items()
if k not in ("model", "custom_llm_provider")
}
completion_params.update(litellm_params_to_add)
# 1. Emit initial task event (kind: "task", status: "submitted")
task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
yield task_event
# 2. Emit status update (kind: "status-update", status: "working")
working_event = A2ACompletionBridgeTransformation.create_status_update_event(
ctx=ctx,
state="working",
final=False,
message_text="Processing request...",
)
yield working_event
# Call litellm.acompletion with streaming
response = await litellm.acompletion(**completion_params)
# 3. Accumulate content and emit artifact update
accumulated_text = ""
chunk_count = 0
async for chunk in response: # type: ignore[union-attr]
chunk_count += 1
# Extract delta content
content = ""
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta") and choice.delta:
content = choice.delta.content or ""
if content:
accumulated_text += content
# Emit artifact update with accumulated content
if accumulated_text:
artifact_event = (
A2ACompletionBridgeTransformation.create_artifact_update_event(
ctx=ctx,
text=accumulated_text,
)
)
yield artifact_event
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
ctx=ctx,
state="completed",
final=True,
)
yield completed_event
verbose_logger.info(
f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
)
# Convenience functions that delegate to the class methods
async def handle_a2a_completion(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> Dict[str, Any]:
"""Convenience function for non-streaming A2A completion."""
return await A2ACompletionBridgeHandler.handle_non_streaming(
request_id=request_id,
params=params,
litellm_params=litellm_params,
api_base=api_base,
)
async def handle_a2a_completion_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Convenience function for streaming A2A completion."""
async for chunk in A2ACompletionBridgeHandler.handle_streaming(
request_id=request_id,
params=params,
litellm_params=litellm_params,
api_base=api_base,
):
yield chunk
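
For illustration, a sketch of consuming the streaming events in the documented order, dispatching on the `kind` field (provider, model, and `api_base` are placeholders):

```python
# Sketch: consume the bridge's streaming events in order.
from uuid import uuid4

from litellm.a2a_protocol.litellm_completion_bridge.handler import (
    handle_a2a_completion_streaming,
)


async def run() -> None:
    params = {
        "message": {
            "role": "user",
            "parts": [{"kind": "text", "text": "Hello!"}],
            "messageId": uuid4().hex,
        }
    }
    async for event in handle_a2a_completion_streaming(
        request_id=str(uuid4()),
        params=params,
        litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
        api_base="http://localhost:2024",
    ):
        result = event["result"]
        if result["kind"] == "task":
            print("task submitted:", result["id"])
        elif result["kind"] == "status-update":
            print("status:", result["status"]["state"])
        elif result["kind"] == "artifact-update":
            print("content:", result["artifact"]["parts"][0]["text"])
```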


@@ -0,0 +1,284 @@
"""
Transformation utilities for A2A <-> OpenAI message format conversion.
A2A Message Format:
{
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": "abc123"
}
OpenAI Message Format:
{"role": "user", "content": "Hello!"}
A2A Streaming Events:
- Task event (kind: "task") - Initial task creation with status "submitted"
- Status update (kind: "status-update") - Status changes (working, completed)
- Artifact update (kind: "artifact-update") - Content/artifact delivery
"""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from uuid import uuid4
from litellm._logging import verbose_logger
class A2AStreamingContext:
"""
Context holder for A2A streaming state.
Tracks task_id, context_id, and message accumulation.
"""
def __init__(self, request_id: str, input_message: Dict[str, Any]):
self.request_id = request_id
self.task_id = str(uuid4())
self.context_id = str(uuid4())
self.input_message = input_message
self.accumulated_text = ""
self.has_emitted_task = False
self.has_emitted_working = False
class A2ACompletionBridgeTransformation:
"""
Static methods for transforming between A2A and OpenAI message formats.
"""
@staticmethod
def a2a_message_to_openai_messages(
a2a_message: Dict[str, Any],
) -> List[Dict[str, str]]:
"""
Transform an A2A message to OpenAI message format.
Args:
a2a_message: A2A message with role, parts, and messageId
Returns:
List of OpenAI-format messages
"""
role = a2a_message.get("role", "user")
parts = a2a_message.get("parts", [])
# Map A2A roles to OpenAI roles ("user" and "system" pass through;
# A2A uses "agent" for the assistant side)
openai_role = "assistant" if role == "agent" else role
# Extract text content from parts
content_parts = []
for part in parts:
kind = part.get("kind", "")
if kind == "text":
text = part.get("text", "")
content_parts.append(text)
content = "\n".join(content_parts) if content_parts else ""
verbose_logger.debug(
f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
)
return [{"role": openai_role, "content": content}]
@staticmethod
def openai_response_to_a2a_response(
response: Any,
request_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.
Args:
response: LiteLLM ModelResponse object
request_id: Original A2A request ID
Returns:
A2A SendMessageResponse dict
"""
# Extract content from response
content = ""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and choice.message:
content = choice.message.content or ""
# Build A2A message
a2a_message = {
"role": "agent",
"parts": [{"kind": "text", "text": content}],
"messageId": uuid4().hex,
}
# Build A2A response
a2a_response = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": a2a_message,
},
}
verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")
return a2a_response
@staticmethod
def _get_timestamp() -> str:
"""Get current timestamp in ISO format with timezone."""
return datetime.now(timezone.utc).isoformat()
@staticmethod
def create_task_event(
ctx: A2AStreamingContext,
) -> Dict[str, Any]:
"""
Create the initial task event with status 'submitted'.
This is the first event emitted in an A2A streaming response.
"""
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"contextId": ctx.context_id,
"history": [
{
"contextId": ctx.context_id,
"kind": "message",
"messageId": ctx.input_message.get("messageId", uuid4().hex),
"parts": ctx.input_message.get("parts", []),
"role": ctx.input_message.get("role", "user"),
"taskId": ctx.task_id,
}
],
"id": ctx.task_id,
"kind": "task",
"status": {
"state": "submitted",
},
},
}
@staticmethod
def create_status_update_event(
ctx: A2AStreamingContext,
state: str,
final: bool = False,
message_text: Optional[str] = None,
) -> Dict[str, Any]:
"""
Create a status update event.
Args:
ctx: Streaming context
state: Status state ('working', 'completed')
final: Whether this is the final event
message_text: Optional message text for 'working' status
"""
status: Dict[str, Any] = {
"state": state,
"timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
}
# Add message for 'working' status
if state == "working" and message_text:
status["message"] = {
"contextId": ctx.context_id,
"kind": "message",
"messageId": str(uuid4()),
"parts": [{"kind": "text", "text": message_text}],
"role": "agent",
"taskId": ctx.task_id,
}
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"contextId": ctx.context_id,
"final": final,
"kind": "status-update",
"status": status,
"taskId": ctx.task_id,
},
}
@staticmethod
def create_artifact_update_event(
ctx: A2AStreamingContext,
text: str,
) -> Dict[str, Any]:
"""
Create an artifact update event with content.
Args:
ctx: Streaming context
text: The text content for the artifact
"""
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"artifact": {
"artifactId": str(uuid4()),
"name": "response",
"parts": [{"kind": "text", "text": text}],
},
"contextId": ctx.context_id,
"kind": "artifact-update",
"taskId": ctx.task_id,
},
}
@staticmethod
def openai_chunk_to_a2a_chunk(
chunk: Any,
request_id: Optional[str] = None,
is_final: bool = False,
) -> Optional[Dict[str, Any]]:
"""
Transform a LiteLLM streaming chunk to A2A streaming format.
NOTE: This method is deprecated for streaming. Use the event-based
methods (create_task_event, create_status_update_event,
create_artifact_update_event) instead for proper A2A streaming.
Args:
chunk: LiteLLM ModelResponse chunk
request_id: Original A2A request ID
is_final: Whether this is the final chunk
Returns:
A2A streaming chunk dict or None if no content
"""
# Extract delta content
content = ""
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta") and choice.delta:
content = choice.delta.content or ""
if not content and not is_final:
return None
# Build A2A streaming chunk (legacy format)
a2a_chunk = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": {
"role": "agent",
"parts": [{"kind": "text", "text": content}],
"messageId": uuid4().hex,
},
"final": is_final,
},
}
return a2a_chunk
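
As a concrete illustration of the message transform (a pure function, so no network is needed):

```python
# Worked example: A2A message -> OpenAI messages (pure transformation).
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
    A2ACompletionBridgeTransformation,
)

a2a_message = {
    "role": "user",
    "parts": [
        {"kind": "text", "text": "Hello!"},
        {"kind": "text", "text": "How are you?"},
    ],
    "messageId": "abc123",
}
messages = A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(
    a2a_message
)
print(messages)
# [{'role': 'user', 'content': 'Hello!\nHow are you?'}]
```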


@@ -0,0 +1,744 @@
"""
LiteLLM A2A SDK functions.
Provides standalone functions with @client decorator for LiteLLM logging integration.
"""
import asyncio
import datetime
import uuid
from typing import TYPE_CHECKING, Any, AsyncIterator, Coroutine, Dict, Optional, Union
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.a2a_protocol.streaming_iterator import A2AStreamingIterator
from litellm.a2a_protocol.utils import A2ARequestUtils
from litellm.constants import DEFAULT_A2A_AGENT_TIMEOUT
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.agents import LiteLLMSendMessageResponse
from litellm.utils import client
if TYPE_CHECKING:
from a2a.client import A2AClient as A2AClientType
from a2a.types import AgentCard, SendMessageRequest, SendStreamingMessageRequest
# Runtime imports with availability check
A2A_SDK_AVAILABLE = False
A2ACardResolver: Any = None
_A2AClient: Any = None
try:
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
A2A_SDK_AVAILABLE = True
except ImportError:
pass
# Import our custom card resolver that supports multiple well-known paths
from litellm.a2a_protocol.card_resolver import LiteLLMA2ACardResolver
from litellm.a2a_protocol.exception_mapping_utils import (
handle_a2a_localhost_retry,
map_a2a_exception,
)
from litellm.a2a_protocol.exceptions import A2ALocalhostURLError
# Use our custom resolver instead of the default A2A SDK resolver
A2ACardResolver = LiteLLMA2ACardResolver
def _set_usage_on_logging_obj(
kwargs: Dict[str, Any],
prompt_tokens: int,
completion_tokens: int,
) -> None:
"""
Set usage on litellm_logging_obj for standard logging payload.
Args:
kwargs: The kwargs dict containing litellm_logging_obj
prompt_tokens: Number of input tokens
completion_tokens: Number of output tokens
"""
litellm_logging_obj = kwargs.get("litellm_logging_obj")
if litellm_logging_obj is not None:
usage = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
litellm_logging_obj.model_call_details["usage"] = usage
def _set_agent_id_on_logging_obj(
kwargs: Dict[str, Any],
agent_id: Optional[str],
) -> None:
"""
Set agent_id on litellm_logging_obj for SpendLogs tracking.
Args:
kwargs: The kwargs dict containing litellm_logging_obj
agent_id: The A2A agent ID
"""
if agent_id is None:
return
litellm_logging_obj = kwargs.get("litellm_logging_obj")
if litellm_logging_obj is not None:
# Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
litellm_logging_obj.model_call_details["agent_id"] = agent_id
def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
"""
Extract agent info and set model/custom_llm_provider for cost tracking.
Sets model info on the litellm_logging_obj if available.
Returns the agent name for logging.
"""
agent_name = "unknown"
# Try to get agent card from our stored attribute first, then fallback to SDK attribute
agent_card = getattr(a2a_client, "_litellm_agent_card", None)
if agent_card is None:
agent_card = getattr(a2a_client, "agent_card", None)
if agent_card is not None:
agent_name = getattr(agent_card, "name", "unknown") or "unknown"
# Build model string
model = f"a2a_agent/{agent_name}"
custom_llm_provider = "a2a_agent"
# Set on litellm_logging_obj if available (for standard logging payload)
litellm_logging_obj = kwargs.get("litellm_logging_obj")
if litellm_logging_obj is not None:
litellm_logging_obj.model = model
litellm_logging_obj.custom_llm_provider = custom_llm_provider
litellm_logging_obj.model_call_details["model"] = model
litellm_logging_obj.model_call_details[
"custom_llm_provider"
] = custom_llm_provider
return agent_name
async def _send_message_via_completion_bridge(
request: "SendMessageRequest",
custom_llm_provider: str,
api_base: Optional[str],
litellm_params: Dict[str, Any],
) -> LiteLLMSendMessageResponse:
"""
Route a send_message through the LiteLLM completion bridge (e.g. LangGraph, Bedrock AgentCore).
Requires request; api_base is optional for providers that derive endpoint from model.
"""
verbose_logger.info(
f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
)
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
A2ACompletionBridgeHandler,
)
params = (
request.params.model_dump(mode="json")
if hasattr(request.params, "model_dump")
else dict(request.params)
)
response_dict = await A2ACompletionBridgeHandler.handle_non_streaming(
request_id=str(request.id),
params=params,
litellm_params=litellm_params,
api_base=api_base,
)
return LiteLLMSendMessageResponse.from_dict(response_dict)
async def _execute_a2a_send_with_retry(
a2a_client: Any,
request: Any,
agent_card: Any,
card_url: Optional[str],
api_base: Optional[str],
agent_name: Optional[str],
) -> Any:
"""Send an A2A message with retry logic for localhost URL errors."""
a2a_response = None
for _ in range(2): # max 2 attempts: original + 1 retry
try:
a2a_response = await a2a_client.send_message(request)
break # success, exit retry loop
except A2ALocalhostURLError as e:
a2a_client = handle_a2a_localhost_retry(
error=e,
agent_card=agent_card,
a2a_client=a2a_client,
is_streaming=False,
)
card_url = agent_card.url if agent_card else None
except Exception as e:
try:
map_a2a_exception(e, card_url, api_base, model=agent_name)
except A2ALocalhostURLError as localhost_err:
a2a_client = handle_a2a_localhost_retry(
error=localhost_err,
agent_card=agent_card,
a2a_client=a2a_client,
is_streaming=False,
)
card_url = agent_card.url if agent_card else None
continue
except Exception:
raise
if a2a_response is None:
raise RuntimeError(
"A2A send_message failed: no response received after retry attempts."
)
return a2a_response
@client
async def asend_message(
a2a_client: Optional["A2AClientType"] = None,
request: Optional["SendMessageRequest"] = None,
api_base: Optional[str] = None,
litellm_params: Optional[Dict[str, Any]] = None,
agent_id: Optional[str] = None,
agent_extra_headers: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> LiteLLMSendMessageResponse:
"""
Async: Send a message to an A2A agent.
Uses the @client decorator for LiteLLM logging and tracking.
If litellm_params contains custom_llm_provider, routes through the completion bridge.
Args:
a2a_client: An initialized a2a.client.A2AClient instance (optional if using the completion bridge, or if api_base is provided)
request: SendMessageRequest from a2a.types (always required)
api_base: API base URL (required for completion bridge, optional for standard A2A)
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
agent_id: Optional agent ID for tracking in SpendLogs
agent_extra_headers: Optional agent-level headers; these take precedence over LiteLLM's internal headers
**kwargs: Additional arguments passed to the client decorator
Returns:
LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)
Example (standard A2A):
```python
from litellm.a2a_protocol import asend_message, create_a2a_client
from a2a.types import SendMessageRequest, MessageSendParams
from uuid import uuid4
a2a_client = await create_a2a_client(base_url="http://localhost:10001")
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
response = await asend_message(a2a_client=a2a_client, request=request)
```
Example (completion bridge with LangGraph):
```python
from litellm.a2a_protocol import asend_message
from a2a.types import SendMessageRequest, MessageSendParams
from uuid import uuid4
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
response = await asend_message(
request=request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
)
```
"""
litellm_params = litellm_params or {}
logging_obj = kwargs.get("litellm_logging_obj")
trace_id = getattr(logging_obj, "litellm_trace_id", None) if logging_obj else None
custom_llm_provider = litellm_params.get("custom_llm_provider")
# Route through completion bridge if custom_llm_provider is set
if custom_llm_provider:
if request is None:
raise ValueError("request is required for completion bridge")
return await _send_message_via_completion_bridge(
request=request,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
litellm_params=litellm_params,
)
# Standard A2A client flow
if request is None:
raise ValueError("request is required")
# Create A2A client if not provided but api_base is available
if a2a_client is None:
if api_base is None:
raise ValueError(
"Either a2a_client or api_base is required for standard A2A flow"
)
trace_id = trace_id or str(uuid.uuid4())
extra_headers: Dict[str, str] = {"X-LiteLLM-Trace-Id": trace_id}
if agent_id:
extra_headers["X-LiteLLM-Agent-Id"] = agent_id
# Overlay agent-level headers (agent headers take precedence over LiteLLM internal ones)
if agent_extra_headers:
extra_headers.update(agent_extra_headers)
a2a_client = await create_a2a_client(
base_url=api_base, extra_headers=extra_headers
)
# Type assertion: a2a_client is guaranteed to be non-None here
assert a2a_client is not None
agent_name = _get_a2a_model_info(a2a_client, kwargs)
verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")
# Get agent card URL for localhost retry logic
agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
a2a_client, "agent_card", None
)
card_url = getattr(agent_card, "url", None) if agent_card else None
context_id = trace_id or str(uuid.uuid4())
message = request.params.message
if isinstance(message, dict):
if message.get("context_id") is None:
message["context_id"] = context_id
else:
if getattr(message, "context_id", None) is None:
message.context_id = context_id
a2a_response = await _execute_a2a_send_with_retry(
a2a_client=a2a_client,
request=request,
agent_card=agent_card,
card_url=card_url,
api_base=api_base,
agent_name=agent_name,
)
verbose_logger.info(f"A2A send_message completed, request_id={request.id}")
# Wrap in LiteLLM response type for _hidden_params support
response = LiteLLMSendMessageResponse.from_a2a_response(a2a_response)
# Calculate token usage from request and response
response_dict = a2a_response.model_dump(mode="json", exclude_none=True)
(
prompt_tokens,
completion_tokens,
_,
) = A2ARequestUtils.calculate_usage_from_request_response(
request=request,
response_dict=response_dict,
)
# Set usage on logging obj for standard logging payload
_set_usage_on_logging_obj(
kwargs=kwargs,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
# Set agent_id on logging obj for SpendLogs tracking
_set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=agent_id)
return response
@client
def send_message(
a2a_client: "A2AClientType",
request: "SendMessageRequest",
**kwargs: Any,
) -> Union[LiteLLMSendMessageResponse, Coroutine[Any, Any, LiteLLMSendMessageResponse]]:
"""
Sync: Send a message to an A2A agent.
Uses the @client decorator for LiteLLM logging and tracking.
Args:
a2a_client: An initialized a2a.client.A2AClient instance
request: SendMessageRequest from a2a.types
**kwargs: Additional arguments passed to the client decorator
Returns:
LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)
"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None:
return asend_message(a2a_client=a2a_client, request=request, **kwargs)
else:
return asyncio.run(
asend_message(a2a_client=a2a_client, request=request, **kwargs)
)
def _build_streaming_logging_obj(
request: "SendStreamingMessageRequest",
agent_name: str,
agent_id: Optional[str],
litellm_params: Optional[Dict[str, Any]],
metadata: Optional[Dict[str, Any]],
proxy_server_request: Optional[Dict[str, Any]],
) -> Logging:
"""Build logging object for streaming A2A requests."""
start_time = datetime.datetime.now()
model = f"a2a_agent/{agent_name}"
logging_obj = Logging(
model=model,
messages=[{"role": "user", "content": "streaming-request"}],
stream=False,
call_type="asend_message_streaming",
start_time=start_time,
litellm_call_id=str(request.id),
function_id=str(request.id),
)
logging_obj.model = model
logging_obj.custom_llm_provider = "a2a_agent"
logging_obj.model_call_details["model"] = model
logging_obj.model_call_details["custom_llm_provider"] = "a2a_agent"
if agent_id:
logging_obj.model_call_details["agent_id"] = agent_id
_litellm_params = litellm_params.copy() if litellm_params else {}
if metadata:
_litellm_params["metadata"] = metadata
if proxy_server_request:
_litellm_params["proxy_server_request"] = proxy_server_request
logging_obj.litellm_params = _litellm_params
logging_obj.optional_params = _litellm_params
logging_obj.model_call_details["litellm_params"] = _litellm_params
logging_obj.model_call_details["metadata"] = metadata or {}
return logging_obj
async def asend_message_streaming( # noqa: PLR0915
a2a_client: Optional["A2AClientType"] = None,
request: Optional["SendStreamingMessageRequest"] = None,
api_base: Optional[str] = None,
litellm_params: Optional[Dict[str, Any]] = None,
agent_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
proxy_server_request: Optional[Dict[str, Any]] = None,
agent_extra_headers: Optional[Dict[str, str]] = None,
) -> AsyncIterator[Any]:
"""
Async: Send a streaming message to an A2A agent.
If litellm_params contains custom_llm_provider, routes through the completion bridge.
Args:
a2a_client: An initialized a2a.client.A2AClient instance (optional if using the completion bridge, or if api_base is provided)
request: SendStreamingMessageRequest from a2a.types (always required)
api_base: API base URL (required for most completion-bridge providers; some, e.g. bedrock/agentcore, derive the endpoint from the model)
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
agent_id: Optional agent ID for tracking in SpendLogs
metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
proxy_server_request: Optional proxy server request data
agent_extra_headers: Optional agent-level headers; these take precedence over LiteLLM's internal headers
Yields:
SendStreamingMessageResponse chunks from the agent
Example (completion bridge with LangGraph):
```python
from litellm.a2a_protocol import asend_message_streaming
from a2a.types import SendStreamingMessageRequest, MessageSendParams
from uuid import uuid4
request = SendStreamingMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
async for chunk in asend_message_streaming(
request=request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
):
print(chunk)
```
"""
litellm_params = litellm_params or {}
custom_llm_provider = litellm_params.get("custom_llm_provider")
# Route through completion bridge if custom_llm_provider is set
if custom_llm_provider:
if request is None:
raise ValueError("request is required for completion bridge")
# api_base is optional for providers that derive endpoint from model (e.g., bedrock/agentcore)
verbose_logger.info(
f"A2A streaming using completion bridge: provider={custom_llm_provider}"
)
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
A2ACompletionBridgeHandler,
)
# Extract params from request
params = (
request.params.model_dump(mode="json")
if hasattr(request.params, "model_dump")
else dict(request.params)
)
async for chunk in A2ACompletionBridgeHandler.handle_streaming(
request_id=str(request.id),
params=params,
litellm_params=litellm_params,
api_base=api_base,
):
yield chunk
return
# Standard A2A client flow
if request is None:
raise ValueError("request is required")
# Create A2A client if not provided but api_base is available
if a2a_client is None:
if api_base is None:
raise ValueError(
"Either a2a_client or api_base is required for standard A2A flow"
)
# Mirror the non-streaming path: always include trace and agent-id headers
streaming_extra_headers: Dict[str, str] = {
"X-LiteLLM-Trace-Id": str(request.id),
}
if agent_id:
streaming_extra_headers["X-LiteLLM-Agent-Id"] = agent_id
if agent_extra_headers:
streaming_extra_headers.update(agent_extra_headers)
a2a_client = await create_a2a_client(
base_url=api_base, extra_headers=streaming_extra_headers
)
# Type assertion: a2a_client is guaranteed to be non-None here
assert a2a_client is not None
verbose_logger.info(f"A2A send_message_streaming request_id={request.id}")
# Build logging object for streaming completion callbacks
agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
a2a_client, "agent_card", None
)
card_url = getattr(agent_card, "url", None) if agent_card else None
agent_name = getattr(agent_card, "name", "unknown") if agent_card else "unknown"
logging_obj = _build_streaming_logging_obj(
request=request,
agent_name=agent_name,
agent_id=agent_id,
litellm_params=litellm_params,
metadata=metadata,
proxy_server_request=proxy_server_request,
)
# Retry loop: if connection fails due to localhost URL in agent card, retry with fixed URL
# Connection errors in streaming typically occur on first chunk iteration
first_chunk = True
for attempt in range(2): # max 2 attempts: original + 1 retry
stream = a2a_client.send_message_streaming(request)
iterator = A2AStreamingIterator(
stream=stream,
request=request,
logging_obj=logging_obj,
agent_name=agent_name,
)
try:
first_chunk = True
async for chunk in iterator:
if first_chunk:
first_chunk = False # connection succeeded
yield chunk
return # stream completed successfully
except A2ALocalhostURLError as e:
# Only retry on first chunk, not mid-stream
if first_chunk and attempt == 0:
a2a_client = handle_a2a_localhost_retry(
error=e,
agent_card=agent_card,
a2a_client=a2a_client,
is_streaming=True,
)
card_url = agent_card.url if agent_card else None
else:
raise
except Exception as e:
# Only map exception on first chunk
if first_chunk and attempt == 0:
try:
map_a2a_exception(e, card_url, api_base, model=agent_name)
except A2ALocalhostURLError as localhost_err:
# Localhost URL error - fix and retry
a2a_client = handle_a2a_localhost_retry(
error=localhost_err,
agent_card=agent_card,
a2a_client=a2a_client,
is_streaming=True,
)
card_url = agent_card.url if agent_card else None
continue
except Exception:
# Re-raise the mapped exception
raise
raise
async def create_a2a_client(
base_url: str,
timeout: float = 60.0,
extra_headers: Optional[Dict[str, str]] = None,
) -> "A2AClientType":
"""
Create an A2A client for the given agent URL.
This resolves the agent card and returns a ready-to-use A2A client.
The client can be reused for multiple requests.
Args:
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
timeout: Request timeout in seconds (default: 60.0)
extra_headers: Optional additional headers to include in requests
Returns:
An initialized a2a.client.A2AClient instance
Example:
```python
from litellm.a2a_protocol import create_a2a_client, asend_message
# Create client once
client = await create_a2a_client(base_url="http://localhost:10001")
# Reuse for multiple requests
response1 = await asend_message(a2a_client=client, request=request1)
response2 = await asend_message(a2a_client=client, request=request2)
```
"""
if not A2A_SDK_AVAILABLE:
raise ImportError(
"The 'a2a' package is required for A2A agent invocation. "
"Install it with: pip install a2a-sdk"
)
verbose_logger.info(f"Creating A2A client for {base_url}")
# Use get_async_httpx_client with per-agent params so that different agents
# (with different extra_headers) get separate cached clients. The params
# dict is hashed into the cache key, keeping agent auth isolated while
# still reusing connections within the same agent.
#
# Only pass params that AsyncHTTPHandler.__init__ accepts (e.g. timeout).
# Use "disable_aiohttp_transport" key for cache-key-only data (it's
# filtered out before reaching the constructor).
_client_params: dict = {"timeout": timeout}
if extra_headers:
# Encode headers into a cache-key-only param so each unique header
# set produces a distinct cache key.
_client_params["disable_aiohttp_transport"] = str(sorted(extra_headers.items()))
_async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.A2AProvider,
params=_client_params,
)
httpx_client = _async_handler.client
if extra_headers:
httpx_client.headers.update(extra_headers)
        verbose_logger.debug(
f"A2A client created with extra_headers={list(extra_headers.keys())}"
)
# Resolve agent card
resolver = A2ACardResolver(
httpx_client=httpx_client,
base_url=base_url,
)
agent_card = await resolver.get_agent_card()
verbose_logger.debug(
f"Resolved agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
)
# Create A2A client
a2a_client = _A2AClient(
httpx_client=httpx_client,
agent_card=agent_card,
)
# Store agent_card on client for later retrieval (SDK doesn't expose it)
a2a_client._litellm_agent_card = agent_card # type: ignore[attr-defined]
verbose_logger.info(f"A2A client created for {base_url}")
return a2a_client
async def aget_agent_card(
base_url: str,
timeout: float = DEFAULT_A2A_AGENT_TIMEOUT,
extra_headers: Optional[Dict[str, str]] = None,
) -> "AgentCard":
"""
Fetch the agent card from an A2A agent.
Args:
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
        timeout: Request timeout in seconds (default: DEFAULT_A2A_AGENT_TIMEOUT)
extra_headers: Optional additional headers to include in requests
Returns:
AgentCard from the A2A agent
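    Example (assumes an A2A agent is reachable at the given URL):
        ```python
        card = await aget_agent_card(base_url="http://localhost:10001")
        print(card.name)
        ```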
"""
if not A2A_SDK_AVAILABLE:
raise ImportError(
"The 'a2a' package is required for A2A agent invocation. "
"Install it with: pip install a2a-sdk"
)
verbose_logger.info(f"Fetching agent card from {base_url}")
# Use LiteLLM's cached httpx client
http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.A2A,
params={"timeout": timeout},
)
httpx_client = http_handler.client
resolver = A2ACardResolver(
httpx_client=httpx_client,
base_url=base_url,
)
agent_card = await resolver.get_agent_card()
verbose_logger.info(
f"Fetched agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
)
return agent_card

View File

@@ -0,0 +1,10 @@
"""
A2A Protocol Providers.
This module contains provider-specific implementations for the A2A protocol.
"""
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
__all__ = ["BaseA2AProviderConfig", "A2AProviderConfigManager"]

View File

@@ -0,0 +1,62 @@
"""
Base configuration for A2A protocol providers.
"""
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict
class BaseA2AProviderConfig(ABC):
"""
Base configuration class for A2A protocol providers.
Each provider should implement this interface to define how to handle
A2A requests for their specific agent type.
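    Example (illustrative sketch of a subclass; `MyProviderConfig` is hypothetical):
        ```python
        class MyProviderConfig(BaseA2AProviderConfig):
            async def handle_non_streaming(self, request_id, params, api_base, **kwargs):
                return {"jsonrpc": "2.0", "id": request_id, "result": {}}
            async def handle_streaming(self, request_id, params, api_base, **kwargs):
                yield {"jsonrpc": "2.0", "id": request_id, "result": {}}
        ```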
"""
@abstractmethod
async def handle_non_streaming(
self,
request_id: str,
params: Dict[str, Any],
api_base: str,
**kwargs,
) -> Dict[str, Any]:
"""
Handle non-streaming A2A request.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
api_base: Base URL of the agent
**kwargs: Additional provider-specific parameters
Returns:
A2A SendMessageResponse dict
"""
pass
@abstractmethod
async def handle_streaming(
self,
request_id: str,
params: Dict[str, Any],
api_base: str,
**kwargs,
) -> AsyncIterator[Dict[str, Any]]:
"""
Handle streaming A2A request.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
api_base: Base URL of the agent
**kwargs: Additional provider-specific parameters
Yields:
A2A streaming response events
"""
# This is an abstract method - subclasses must implement
# The yield is here to make this a generator function
if False: # pragma: no cover
yield {}

View File

@@ -0,0 +1,47 @@
"""
A2A Provider Config Manager.
Manages provider-specific configurations for A2A protocol.
"""
from typing import Optional
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
class A2AProviderConfigManager:
"""
Manager for A2A provider configurations.
Similar to ProviderConfigManager in litellm.utils but specifically for A2A providers.
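    Example:
        ```python
        config = A2AProviderConfigManager.get_provider_config("pydantic_ai_agents")
        if config is not None:
            response = await config.handle_non_streaming(
                request_id="req-1",
                params={"message": {...}},
                api_base="http://localhost:9999",
            )
        ```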
"""
@staticmethod
def get_provider_config(
custom_llm_provider: Optional[str],
) -> Optional[BaseA2AProviderConfig]:
"""
Get the provider configuration for a given custom_llm_provider.
Args:
custom_llm_provider: The provider identifier (e.g., "pydantic_ai_agents")
Returns:
Provider configuration instance or None if not found
"""
if custom_llm_provider is None:
return None
if custom_llm_provider == "pydantic_ai_agents":
from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
PydanticAIProviderConfig,
)
return PydanticAIProviderConfig()
# Add more providers here as needed
# elif custom_llm_provider == "another_provider":
# from litellm.a2a_protocol.providers.another_provider.config import AnotherProviderConfig
# return AnotherProviderConfig()
return None

View File

@@ -0,0 +1,74 @@
# A2A to LiteLLM Completion Bridge
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
## Flow
```
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
```
## SDK Usage
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
```python
from litellm.a2a_protocol import asend_message, asend_message_streaming
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
from uuid import uuid4
# Non-streaming
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
response = await asend_message(
request=request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
)
# Streaming
stream_request = SendStreamingMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
async for chunk in asend_message_streaming(
request=stream_request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
):
print(chunk)
```
## Proxy Usage
Configure an agent with `custom_llm_provider` in `litellm_params`:
```yaml
agents:
- agent_name: my-langgraph-agent
agent_card_params:
name: "LangGraph Agent"
url: "http://localhost:2024" # Used as api_base
litellm_params:
custom_llm_provider: langgraph
model: agent
```
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
1. Detects `custom_llm_provider` in agent's `litellm_params`
2. Transforms A2A message → OpenAI messages
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
4. Transforms response → A2A format (streaming event order sketched below)
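As a rough sketch (reusing `stream_request` from the SDK example above; assumes the bridge path, which yields plain dicts), the `kind` fields arrive in a fixed order:
```python
# Expected "kind" order: task -> status-update (working)
# -> artifact-update -> status-update (completed, final=True)
async for event in asend_message_streaming(
    request=stream_request,
    api_base="http://localhost:2024",
    litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
):
    print(event["result"]["kind"])
```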
## Classes
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)

View File

@@ -0,0 +1,5 @@
"""
LiteLLM Completion bridge provider for A2A protocol.
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
"""

View File

@@ -0,0 +1,301 @@
"""
Handler for A2A to LiteLLM completion bridge.
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
A2A Streaming Events (in order):
1. Task event (kind: "task") - Initial task creation with status "submitted"
2. Status update (kind: "status-update") - Status change to "working"
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
4. Status update (kind: "status-update") - Final status "completed" with final=true
"""
from typing import Any, AsyncIterator, Dict, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.a2a_protocol.litellm_completion_bridge.pydantic_ai_transformation import (
PydanticAITransformation,
)
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
A2ACompletionBridgeTransformation,
A2AStreamingContext,
)
class A2ACompletionBridgeHandler:
"""
Static methods for handling A2A requests via LiteLLM completion.
"""
@staticmethod
async def handle_non_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> Dict[str, Any]:
"""
Handle non-streaming A2A request via litellm.acompletion.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
api_base: API base URL from agent_card_params
Returns:
A2A SendMessageResponse dict
"""
# Check if this is a Pydantic AI agent request
custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider == "pydantic_ai_agents":
if api_base is None:
raise ValueError("api_base is required for Pydantic AI agents")
verbose_logger.info(
f"Pydantic AI: Routing to Pydantic AI agent at {api_base}"
)
# Send request directly to Pydantic AI agent
response_data = await PydanticAITransformation.send_non_streaming_request(
api_base=api_base,
request_id=request_id,
params=params,
)
return response_data
# Extract message from params
message = params.get("message", {})
# Transform A2A message to OpenAI format
openai_messages = (
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
)
        # Get completion params (custom_llm_provider was already read above)
model = litellm_params.get("model", "agent")
# Build full model string if provider specified
# Skip prepending if model already starts with the provider prefix
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
full_model = f"{custom_llm_provider}/{model}"
else:
full_model = model
verbose_logger.info(
f"A2A completion bridge: model={full_model}, api_base={api_base}"
)
# Build completion params dict
completion_params = {
"model": full_model,
"messages": openai_messages,
"api_base": api_base,
"stream": False,
}
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
litellm_params_to_add = {
k: v
for k, v in litellm_params.items()
if k not in ("model", "custom_llm_provider")
}
completion_params.update(litellm_params_to_add)
# Call litellm.acompletion
response = await litellm.acompletion(**completion_params)
# Transform response to A2A format
a2a_response = (
A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
response=response,
request_id=request_id,
)
)
verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")
return a2a_response
@staticmethod
async def handle_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""
Handle streaming A2A request via litellm.acompletion with stream=True.
Emits proper A2A streaming events:
1. Task event (kind: "task") - Initial task with status "submitted"
2. Status update (kind: "status-update") - Status "working"
3. Artifact update (kind: "artifact-update") - Content delivery
4. Status update (kind: "status-update") - Final "completed" status
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
api_base: API base URL from agent_card_params
Yields:
A2A streaming response events
"""
# Check if this is a Pydantic AI agent request
custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider == "pydantic_ai_agents":
if api_base is None:
raise ValueError("api_base is required for Pydantic AI agents")
verbose_logger.info(
f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
)
# Get non-streaming response first
response_data = await PydanticAITransformation.send_non_streaming_request(
api_base=api_base,
request_id=request_id,
params=params,
)
# Convert to fake streaming
async for chunk in PydanticAITransformation.fake_streaming_from_response(
response_data=response_data,
request_id=request_id,
):
yield chunk
return
# Extract message from params
message = params.get("message", {})
# Create streaming context
ctx = A2AStreamingContext(
request_id=request_id,
input_message=message,
)
# Transform A2A message to OpenAI format
openai_messages = (
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
)
        # Get completion params (custom_llm_provider was already read above)
model = litellm_params.get("model", "agent")
# Build full model string if provider specified
# Skip prepending if model already starts with the provider prefix
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
full_model = f"{custom_llm_provider}/{model}"
else:
full_model = model
verbose_logger.info(
f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
)
# Build completion params dict
completion_params = {
"model": full_model,
"messages": openai_messages,
"api_base": api_base,
"stream": True,
}
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
litellm_params_to_add = {
k: v
for k, v in litellm_params.items()
if k not in ("model", "custom_llm_provider")
}
completion_params.update(litellm_params_to_add)
# 1. Emit initial task event (kind: "task", status: "submitted")
task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
yield task_event
# 2. Emit status update (kind: "status-update", status: "working")
working_event = A2ACompletionBridgeTransformation.create_status_update_event(
ctx=ctx,
state="working",
final=False,
message_text="Processing request...",
)
yield working_event
# Call litellm.acompletion with streaming
response = await litellm.acompletion(**completion_params)
        # 3. Accumulate streamed content, then emit a single artifact update once the stream ends
accumulated_text = ""
chunk_count = 0
async for chunk in response: # type: ignore[union-attr]
chunk_count += 1
# Extract delta content
content = ""
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta") and choice.delta:
content = choice.delta.content or ""
if content:
accumulated_text += content
# Emit artifact update with accumulated content
if accumulated_text:
artifact_event = (
A2ACompletionBridgeTransformation.create_artifact_update_event(
ctx=ctx,
text=accumulated_text,
)
)
yield artifact_event
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
ctx=ctx,
state="completed",
final=True,
)
yield completed_event
verbose_logger.info(
f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
)
# Convenience functions that delegate to the class methods
async def handle_a2a_completion(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> Dict[str, Any]:
"""Convenience function for non-streaming A2A completion."""
return await A2ACompletionBridgeHandler.handle_non_streaming(
request_id=request_id,
params=params,
litellm_params=litellm_params,
api_base=api_base,
)
async def handle_a2a_completion_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Convenience function for streaming A2A completion."""
async for chunk in A2ACompletionBridgeHandler.handle_streaming(
request_id=request_id,
params=params,
litellm_params=litellm_params,
api_base=api_base,
):
yield chunk

View File

@@ -0,0 +1,284 @@
"""
Transformation utilities for A2A <-> OpenAI message format conversion.
A2A Message Format:
{
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": "abc123"
}
OpenAI Message Format:
{"role": "user", "content": "Hello!"}
A2A Streaming Events:
- Task event (kind: "task") - Initial task creation with status "submitted"
- Status update (kind: "status-update") - Status changes (working, completed)
- Artifact update (kind: "artifact-update") - Content/artifact delivery
"""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from uuid import uuid4
from litellm._logging import verbose_logger
class A2AStreamingContext:
"""
Context holder for A2A streaming state.
Tracks task_id, context_id, and message accumulation.
"""
def __init__(self, request_id: str, input_message: Dict[str, Any]):
self.request_id = request_id
self.task_id = str(uuid4())
self.context_id = str(uuid4())
self.input_message = input_message
self.accumulated_text = ""
self.has_emitted_task = False
self.has_emitted_working = False
class A2ACompletionBridgeTransformation:
"""
Static methods for transforming between A2A and OpenAI message formats.
"""
@staticmethod
def a2a_message_to_openai_messages(
a2a_message: Dict[str, Any],
) -> List[Dict[str, str]]:
"""
Transform an A2A message to OpenAI message format.
Args:
a2a_message: A2A message with role, parts, and messageId
Returns:
List of OpenAI-format messages
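        Example:
            ```python
            msg = {"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": "m1"}
            A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(msg)
            # -> [{"role": "user", "content": "Hello!"}]
            ```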
"""
role = a2a_message.get("role", "user")
parts = a2a_message.get("parts", [])
        # Map A2A roles to OpenAI roles ("agent" is A2A's assistant-equivalent role)
        openai_role = "assistant" if role == "agent" else role
# Extract text content from parts
content_parts = []
for part in parts:
kind = part.get("kind", "")
if kind == "text":
text = part.get("text", "")
content_parts.append(text)
content = "\n".join(content_parts) if content_parts else ""
verbose_logger.debug(
f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
)
return [{"role": openai_role, "content": content}]
@staticmethod
def openai_response_to_a2a_response(
response: Any,
request_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.
Args:
response: LiteLLM ModelResponse object
request_id: Original A2A request ID
Returns:
A2A SendMessageResponse dict
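        Example (shape of the returned dict):
            ```python
            {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "message": {"role": "agent", "parts": [{"kind": "text", "text": "..."}], "messageId": "..."}
                },
            }
            ```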
"""
# Extract content from response
content = ""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and choice.message:
content = choice.message.content or ""
# Build A2A message
a2a_message = {
"role": "agent",
"parts": [{"kind": "text", "text": content}],
"messageId": uuid4().hex,
}
# Build A2A response
a2a_response = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": a2a_message,
},
}
verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")
return a2a_response
@staticmethod
def _get_timestamp() -> str:
"""Get current timestamp in ISO format with timezone."""
return datetime.now(timezone.utc).isoformat()
@staticmethod
def create_task_event(
ctx: A2AStreamingContext,
) -> Dict[str, Any]:
"""
Create the initial task event with status 'submitted'.
This is the first event emitted in an A2A streaming response.
"""
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"contextId": ctx.context_id,
"history": [
{
"contextId": ctx.context_id,
"kind": "message",
"messageId": ctx.input_message.get("messageId", uuid4().hex),
"parts": ctx.input_message.get("parts", []),
"role": ctx.input_message.get("role", "user"),
"taskId": ctx.task_id,
}
],
"id": ctx.task_id,
"kind": "task",
"status": {
"state": "submitted",
},
},
}
@staticmethod
def create_status_update_event(
ctx: A2AStreamingContext,
state: str,
final: bool = False,
message_text: Optional[str] = None,
) -> Dict[str, Any]:
"""
Create a status update event.
Args:
ctx: Streaming context
state: Status state ('working', 'completed')
final: Whether this is the final event
message_text: Optional message text for 'working' status
"""
status: Dict[str, Any] = {
"state": state,
"timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
}
# Add message for 'working' status
if state == "working" and message_text:
status["message"] = {
"contextId": ctx.context_id,
"kind": "message",
"messageId": str(uuid4()),
"parts": [{"kind": "text", "text": message_text}],
"role": "agent",
"taskId": ctx.task_id,
}
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"contextId": ctx.context_id,
"final": final,
"kind": "status-update",
"status": status,
"taskId": ctx.task_id,
},
}
@staticmethod
def create_artifact_update_event(
ctx: A2AStreamingContext,
text: str,
) -> Dict[str, Any]:
"""
Create an artifact update event with content.
Args:
ctx: Streaming context
text: The text content for the artifact
"""
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"artifact": {
"artifactId": str(uuid4()),
"name": "response",
"parts": [{"kind": "text", "text": text}],
},
"contextId": ctx.context_id,
"kind": "artifact-update",
"taskId": ctx.task_id,
},
}
@staticmethod
def openai_chunk_to_a2a_chunk(
chunk: Any,
request_id: Optional[str] = None,
is_final: bool = False,
) -> Optional[Dict[str, Any]]:
"""
Transform a LiteLLM streaming chunk to A2A streaming format.
NOTE: This method is deprecated for streaming. Use the event-based
methods (create_task_event, create_status_update_event,
create_artifact_update_event) instead for proper A2A streaming.
Args:
chunk: LiteLLM ModelResponse chunk
request_id: Original A2A request ID
is_final: Whether this is the final chunk
Returns:
A2A streaming chunk dict or None if no content
"""
# Extract delta content
content = ""
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta") and choice.delta:
content = choice.delta.content or ""
if not content and not is_final:
return None
# Build A2A streaming chunk (legacy format)
a2a_chunk = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": {
"role": "agent",
"parts": [{"kind": "text", "text": content}],
"messageId": uuid4().hex,
},
"final": is_final,
},
}
return a2a_chunk

View File

@@ -0,0 +1,16 @@
"""
Pydantic AI agent provider for A2A protocol.
Pydantic AI agents follow A2A protocol but don't support streaming natively.
This provider handles fake streaming by converting non-streaming responses into streaming chunks.
"""
from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
PydanticAIProviderConfig,
)
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
PydanticAITransformation,
)
__all__ = ["PydanticAIHandler", "PydanticAITransformation", "PydanticAIProviderConfig"]

View File

@@ -0,0 +1,50 @@
"""
Pydantic AI provider configuration.
"""
from typing import Any, AsyncIterator, Dict
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
class PydanticAIProviderConfig(BaseA2AProviderConfig):
"""
Provider configuration for Pydantic AI agents.
Pydantic AI agents follow A2A protocol but don't support streaming natively.
This config provides fake streaming by converting non-streaming responses into streaming chunks.
"""
async def handle_non_streaming(
self,
request_id: str,
params: Dict[str, Any],
api_base: str,
**kwargs,
) -> Dict[str, Any]:
"""Handle non-streaming request to Pydantic AI agent."""
return await PydanticAIHandler.handle_non_streaming(
request_id=request_id,
params=params,
api_base=api_base,
timeout=kwargs.get("timeout", 60.0),
)
async def handle_streaming(
self,
request_id: str,
params: Dict[str, Any],
api_base: str,
**kwargs,
) -> AsyncIterator[Dict[str, Any]]:
"""Handle streaming request with fake streaming."""
async for chunk in PydanticAIHandler.handle_streaming(
request_id=request_id,
params=params,
api_base=api_base,
timeout=kwargs.get("timeout", 60.0),
chunk_size=kwargs.get("chunk_size", 50),
delay_ms=kwargs.get("delay_ms", 10),
):
yield chunk

View File

@@ -0,0 +1,102 @@
"""
Handler for Pydantic AI agents.
Pydantic AI agents follow A2A protocol but don't support streaming natively.
This handler provides fake streaming by converting non-streaming responses into streaming chunks.
"""
from typing import Any, AsyncIterator, Dict
from litellm._logging import verbose_logger
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
PydanticAITransformation,
)
class PydanticAIHandler:
"""
Handler for Pydantic AI agent requests.
Provides:
- Direct non-streaming requests to Pydantic AI agents
- Fake streaming by converting non-streaming responses into streaming chunks
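    Example (illustrative; assumes a Pydantic AI agent at the given URL):
        ```python
        async for event in PydanticAIHandler.handle_streaming(
            request_id="req-1",
            params={"message": {"role": "user", "parts": [{"kind": "text", "text": "Hi"}], "messageId": "m1"}},
            api_base="http://localhost:9999",
        ):
            print(event["result"]["kind"])
        ```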
"""
@staticmethod
async def handle_non_streaming(
request_id: str,
params: Dict[str, Any],
api_base: str,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Handle non-streaming request to Pydantic AI agent.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
api_base: Base URL of the Pydantic AI agent
timeout: Request timeout in seconds
Returns:
A2A SendMessageResponse dict
"""
verbose_logger.info(f"Pydantic AI: Routing to Pydantic AI agent at {api_base}")
# Send request directly to Pydantic AI agent
response_data = await PydanticAITransformation.send_non_streaming_request(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
return response_data
@staticmethod
async def handle_streaming(
request_id: str,
params: Dict[str, Any],
api_base: str,
timeout: float = 60.0,
chunk_size: int = 50,
delay_ms: int = 10,
) -> AsyncIterator[Dict[str, Any]]:
"""
Handle streaming request to Pydantic AI agent with fake streaming.
Since Pydantic AI agents don't support streaming natively, this method:
1. Makes a non-streaming request
2. Converts the response into streaming chunks
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
api_base: Base URL of the Pydantic AI agent
timeout: Request timeout in seconds
chunk_size: Number of characters per chunk
delay_ms: Delay between chunks in milliseconds
Yields:
A2A streaming response events
"""
verbose_logger.info(
f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
)
# Get raw task response first (not the transformed A2A format)
raw_response = await PydanticAITransformation.send_and_get_raw_response(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
# Convert raw task response to fake streaming chunks
async for chunk in PydanticAITransformation.fake_streaming_from_response(
response_data=raw_response,
request_id=request_id,
chunk_size=chunk_size,
delay_ms=delay_ms,
):
yield chunk

View File

@@ -0,0 +1,530 @@
"""
Transformation layer for Pydantic AI agents.
Pydantic AI agents follow A2A protocol but don't support streaming.
This module provides fake streaming by converting non-streaming responses into streaming chunks.
"""
import asyncio
from typing import Any, AsyncIterator, Dict, List, Tuple, cast
from uuid import uuid4
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
class PydanticAITransformation:
"""
Transformation layer for Pydantic AI agents.
Handles:
- Direct A2A requests to Pydantic AI endpoints
- Polling for task completion (since Pydantic AI doesn't support streaming)
- Fake streaming by chunking non-streaming responses
"""
@staticmethod
def _remove_none_values(obj: Any) -> Any:
"""
Recursively remove None values from a dict/list structure.
FastA2A/Pydantic AI servers don't accept None values for optional fields -
they expect those fields to be omitted entirely.
Args:
obj: Dict, list, or other value to clean
Returns:
Cleaned object with None values removed
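        Example:
            {"a": 1, "b": None, "c": [None, 2]} -> {"a": 1, "c": [2]}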
"""
if isinstance(obj, dict):
return {
k: PydanticAITransformation._remove_none_values(v)
for k, v in obj.items()
if v is not None
}
elif isinstance(obj, list):
return [
PydanticAITransformation._remove_none_values(item)
for item in obj
if item is not None
]
else:
return obj
@staticmethod
def _params_to_dict(params: Any) -> Dict[str, Any]:
"""
Convert params to a dict, handling Pydantic models.
Args:
params: Dict or Pydantic model
Returns:
Dict representation of params
"""
if hasattr(params, "model_dump"):
# Pydantic v2 model
return params.model_dump(mode="python", exclude_none=True)
elif hasattr(params, "dict"):
# Pydantic v1 model
return params.dict(exclude_none=True)
elif isinstance(params, dict):
return params
else:
# Try to convert to dict
return dict(params)
@staticmethod
async def _poll_for_completion(
client: AsyncHTTPHandler,
endpoint: str,
task_id: str,
request_id: str,
max_attempts: int = 30,
poll_interval: float = 0.5,
) -> Dict[str, Any]:
"""
Poll for task completion using tasks/get method.
Args:
client: HTTPX async client
endpoint: API endpoint URL
task_id: Task ID to poll for
request_id: JSON-RPC request ID
max_attempts: Maximum polling attempts
poll_interval: Seconds between poll attempts
Returns:
Completed task response
"""
for attempt in range(max_attempts):
poll_request = {
"jsonrpc": "2.0",
"id": f"{request_id}-poll-{attempt}",
"method": "tasks/get",
"params": {"id": task_id},
}
response = await client.post(
endpoint,
json=poll_request,
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
poll_data = response.json()
result = poll_data.get("result", {})
status = result.get("status", {})
state = status.get("state", "")
verbose_logger.debug(
f"Pydantic AI: Poll attempt {attempt + 1}/{max_attempts}, state={state}"
)
if state == "completed":
return poll_data
elif state in ("failed", "canceled"):
raise Exception(f"Task {task_id} ended with state: {state}")
await asyncio.sleep(poll_interval)
raise TimeoutError(
f"Task {task_id} did not complete within {max_attempts * poll_interval} seconds"
)
@staticmethod
async def _send_and_poll_raw(
api_base: str,
request_id: str,
params: Any,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Send a request to Pydantic AI agent and return the raw task response.
This is an internal method used by both non-streaming and streaming handlers.
Returns the raw Pydantic AI task format with history/artifacts.
Args:
api_base: Base URL of the Pydantic AI agent
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
timeout: Request timeout in seconds
Returns:
Raw Pydantic AI task response (with history/artifacts)
"""
# Convert params to dict if it's a Pydantic model
params_dict = PydanticAITransformation._params_to_dict(params)
# Remove None values - FastA2A doesn't accept null for optional fields
params_dict = PydanticAITransformation._remove_none_values(params_dict)
# Ensure the message has 'kind': 'message' as required by FastA2A/Pydantic AI
if "message" in params_dict:
params_dict["message"]["kind"] = "message"
# Build A2A JSON-RPC request using message/send method for FastA2A compatibility
a2a_request = {
"jsonrpc": "2.0",
"id": request_id,
"method": "message/send",
"params": params_dict,
}
# FastA2A uses root endpoint (/) not /messages
endpoint = api_base.rstrip("/")
verbose_logger.info(f"Pydantic AI: Sending non-streaming request to {endpoint}")
# Send request to Pydantic AI agent using shared async HTTP client
client = get_async_httpx_client(
llm_provider=cast(Any, "pydantic_ai_agent"),
params={"timeout": timeout},
)
response = await client.post(
endpoint,
json=a2a_request,
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
response_data = response.json()
# Check if task is already completed
result = response_data.get("result", {})
status = result.get("status", {})
state = status.get("state", "")
if state != "completed":
# Need to poll for completion
task_id = result.get("id")
if task_id:
verbose_logger.info(
f"Pydantic AI: Task {task_id} submitted, polling for completion..."
)
response_data = await PydanticAITransformation._poll_for_completion(
client=client,
endpoint=endpoint,
task_id=task_id,
request_id=request_id,
)
verbose_logger.info(
f"Pydantic AI: Received completed response for request_id={request_id}"
)
return response_data
@staticmethod
async def send_non_streaming_request(
api_base: str,
request_id: str,
params: Any,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Send a non-streaming A2A request to Pydantic AI agent and wait for completion.
Args:
api_base: Base URL of the Pydantic AI agent (e.g., "http://localhost:9999")
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message (dict or Pydantic model)
timeout: Request timeout in seconds
Returns:
Standard A2A non-streaming response format with message
"""
# Get raw task response
raw_response = await PydanticAITransformation._send_and_poll_raw(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
# Transform to standard A2A non-streaming format
return PydanticAITransformation._transform_to_a2a_response(
response_data=raw_response,
request_id=request_id,
)
@staticmethod
async def send_and_get_raw_response(
api_base: str,
request_id: str,
params: Any,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Send a request to Pydantic AI agent and return the raw task response.
Used by streaming handler to get raw response for fake streaming.
Args:
api_base: Base URL of the Pydantic AI agent
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
timeout: Request timeout in seconds
Returns:
Raw Pydantic AI task response (with history/artifacts)
"""
return await PydanticAITransformation._send_and_poll_raw(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
@staticmethod
def _transform_to_a2a_response(
response_data: Dict[str, Any],
request_id: str,
) -> Dict[str, Any]:
"""
Transform Pydantic AI task response to standard A2A non-streaming format.
Pydantic AI returns a task with history/artifacts, but the standard A2A
non-streaming format expects:
{
"jsonrpc": "2.0",
"id": "...",
"result": {
"message": {
"role": "agent",
"parts": [{"kind": "text", "text": "..."}],
"messageId": "..."
}
}
}
Args:
response_data: Pydantic AI task response
request_id: Original request ID
Returns:
Standard A2A non-streaming response format
"""
# Extract the agent response text
full_text, message_id, parts = PydanticAITransformation._extract_response_text(
response_data
)
# Build standard A2A message
a2a_message = {
"role": "agent",
"parts": parts if parts else [{"kind": "text", "text": full_text}],
"messageId": message_id,
}
# Return standard A2A non-streaming format
return {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": a2a_message,
},
}
@staticmethod
    def _extract_response_text(response_data: Dict[str, Any]) -> Tuple[str, str, List[Any]]:
"""
Extract response text from completed task response.
Pydantic AI returns completed tasks with:
- history: list of messages (user and agent)
- artifacts: list of result artifacts
Args:
response_data: Completed task response
Returns:
Tuple of (full_text, message_id, parts)
"""
result = response_data.get("result", {})
# Try to extract from artifacts first (preferred for results)
artifacts = result.get("artifacts", [])
if artifacts:
for artifact in artifacts:
parts = artifact.get("parts", [])
for part in parts:
if part.get("kind") == "text":
text = part.get("text", "")
if text:
return text, str(uuid4()), parts
# Fall back to history - get the last agent message
history = result.get("history", [])
for msg in reversed(history):
if msg.get("role") == "agent":
parts = msg.get("parts", [])
message_id = msg.get("messageId", str(uuid4()))
full_text = ""
for part in parts:
if part.get("kind") == "text":
full_text += part.get("text", "")
if full_text:
return full_text, message_id, parts
# Fall back to message field (original format)
message = result.get("message", {})
if message:
parts = message.get("parts", [])
message_id = message.get("messageId", str(uuid4()))
full_text = ""
for part in parts:
if part.get("kind") == "text":
full_text += part.get("text", "")
return full_text, message_id, parts
return "", str(uuid4()), []
@staticmethod
async def fake_streaming_from_response(
response_data: Dict[str, Any],
request_id: str,
chunk_size: int = 50,
delay_ms: int = 10,
) -> AsyncIterator[Dict[str, Any]]:
"""
Convert a non-streaming A2A response into fake streaming chunks.
Emits proper A2A streaming events:
1. Task event (kind: "task") - Initial task with status "submitted"
2. Status update (kind: "status-update") - Status "working"
3. Artifact update chunks (kind: "artifact-update") - Content delivery in chunks
4. Status update (kind: "status-update") - Final "completed" status
Args:
response_data: Non-streaming A2A response dict (completed task)
request_id: A2A JSON-RPC request ID
chunk_size: Number of characters per chunk (default: 50)
delay_ms: Delay between chunks in milliseconds (default: 10)
Yields:
A2A streaming response events
"""
# Extract the response text from completed task
full_text, message_id, parts = PydanticAITransformation._extract_response_text(
response_data
)
# Extract input message from raw response for history
result = response_data.get("result", {})
history = result.get("history", [])
input_message = {}
for msg in history:
if msg.get("role") == "user":
input_message = msg
break
# Generate IDs for streaming events
task_id = str(uuid4())
context_id = str(uuid4())
artifact_id = str(uuid4())
input_message_id = input_message.get("messageId", str(uuid4()))
# 1. Emit initial task event (kind: "task", status: "submitted")
# Format matches A2ACompletionBridgeTransformation.create_task_event
task_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"history": [
{
"contextId": context_id,
"kind": "message",
"messageId": input_message_id,
"parts": input_message.get(
"parts", [{"kind": "text", "text": ""}]
),
"role": "user",
"taskId": task_id,
}
],
"id": task_id,
"kind": "task",
"status": {
"state": "submitted",
},
},
}
yield task_event
# 2. Emit status update (kind: "status-update", status: "working")
# Format matches A2ACompletionBridgeTransformation.create_status_update_event
working_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"final": False,
"kind": "status-update",
"status": {
"state": "working",
},
"taskId": task_id,
},
}
yield working_event
# Small delay to simulate processing
await asyncio.sleep(delay_ms / 1000.0)
# 3. Emit artifact update chunks (kind: "artifact-update")
# Format matches A2ACompletionBridgeTransformation.create_artifact_update_event
if full_text:
# Split text into chunks
for i in range(0, len(full_text), chunk_size):
chunk_text = full_text[i : i + chunk_size]
is_last_chunk = (i + chunk_size) >= len(full_text)
artifact_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"kind": "artifact-update",
"taskId": task_id,
"artifact": {
"artifactId": artifact_id,
"parts": [
{
"kind": "text",
"text": chunk_text,
}
],
},
},
}
yield artifact_event
# Add delay between chunks (except for last chunk)
if not is_last_chunk:
await asyncio.sleep(delay_ms / 1000.0)
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
completed_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"final": True,
"kind": "status-update",
"status": {
"state": "completed",
},
"taskId": task_id,
},
}
yield completed_event
verbose_logger.info(
f"Pydantic AI: Fake streaming completed for request_id={request_id}"
)

View File

@@ -0,0 +1,184 @@
"""
A2A Streaming Iterator with token tracking and logging support.
"""
import asyncio
from datetime import datetime
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.a2a_protocol.cost_calculator import A2ACostCalculator
from litellm.a2a_protocol.utils import A2ARequestUtils
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.thread_pool_executor import executor
if TYPE_CHECKING:
from a2a.types import SendStreamingMessageRequest, SendStreamingMessageResponse
class A2AStreamingIterator:
"""
Async iterator for A2A streaming responses with token tracking.
Collects chunks, extracts text, and logs usage on completion.
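    Example (illustrative):
        ```python
        iterator = A2AStreamingIterator(
            stream=stream, request=request, logging_obj=logging_obj, agent_name="my-agent"
        )
        async for chunk in iterator:
            ...  # forward chunks to the caller; usage is logged when the stream ends
        ```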
"""
def __init__(
self,
stream: AsyncIterator["SendStreamingMessageResponse"],
request: "SendStreamingMessageRequest",
logging_obj: LiteLLMLoggingObj,
agent_name: str = "unknown",
):
self.stream = stream
self.request = request
self.logging_obj = logging_obj
self.agent_name = agent_name
self.start_time = datetime.now()
# Collect chunks for token counting
self.chunks: List[Any] = []
self.collected_text_parts: List[str] = []
self.final_chunk: Optional[Any] = None
def __aiter__(self):
return self
async def __anext__(self) -> "SendStreamingMessageResponse":
try:
chunk = await self.stream.__anext__()
# Store chunk
self.chunks.append(chunk)
# Extract text from chunk for token counting
self._collect_text_from_chunk(chunk)
# Check if this is the final chunk (completed status)
if self._is_completed_chunk(chunk):
self.final_chunk = chunk
return chunk
except StopAsyncIteration:
# Stream ended - handle logging
if self.final_chunk is None and self.chunks:
self.final_chunk = self.chunks[-1]
await self._handle_stream_complete()
raise
def _collect_text_from_chunk(self, chunk: Any) -> None:
"""Extract text from a streaming chunk and add to collected parts."""
try:
chunk_dict = (
chunk.model_dump(mode="json", exclude_none=True)
if hasattr(chunk, "model_dump")
else {}
)
text = A2ARequestUtils.extract_text_from_response(chunk_dict)
if text:
self.collected_text_parts.append(text)
except Exception:
verbose_logger.debug("Failed to extract text from A2A streaming chunk")
def _is_completed_chunk(self, chunk: Any) -> bool:
"""Check if chunk indicates stream completion."""
try:
chunk_dict = (
chunk.model_dump(mode="json", exclude_none=True)
if hasattr(chunk, "model_dump")
else {}
)
result = chunk_dict.get("result", {})
if isinstance(result, dict):
status = result.get("status", {})
if isinstance(status, dict):
return status.get("state") == "completed"
except Exception:
pass
return False
async def _handle_stream_complete(self) -> None:
"""Handle logging and token counting when stream completes."""
try:
end_time = datetime.now()
# Calculate tokens from collected text
input_message = A2ARequestUtils.get_input_message_from_request(self.request)
input_text = A2ARequestUtils.extract_text_from_message(input_message)
prompt_tokens = A2ARequestUtils.count_tokens(input_text)
# Use the last (most complete) text from chunks
output_text = (
self.collected_text_parts[-1] if self.collected_text_parts else ""
)
completion_tokens = A2ARequestUtils.count_tokens(output_text)
total_tokens = prompt_tokens + completion_tokens
# Create usage object
usage = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
# Set usage on logging obj
self.logging_obj.model_call_details["usage"] = usage
            # Mark the call as non-streaming so downstream callbacks treat the
            # aggregated result dict as a complete response
            self.logging_obj.model_call_details["stream"] = False
# Calculate cost using A2ACostCalculator
response_cost = A2ACostCalculator.calculate_a2a_cost(self.logging_obj)
self.logging_obj.model_call_details["response_cost"] = response_cost
# Build result for logging
result = self._build_logging_result(usage)
# Call success handlers - they will build standard_logging_object
asyncio.create_task(
self.logging_obj.async_success_handler(
result=result,
start_time=self.start_time,
end_time=end_time,
cache_hit=None,
)
)
executor.submit(
self.logging_obj.success_handler,
result=result,
cache_hit=None,
start_time=self.start_time,
end_time=end_time,
)
verbose_logger.info(
f"A2A streaming completed: prompt_tokens={prompt_tokens}, "
f"completion_tokens={completion_tokens}, total_tokens={total_tokens}, "
f"response_cost={response_cost}"
)
except Exception as e:
verbose_logger.debug(f"Error in A2A streaming completion handler: {e}")
def _build_logging_result(self, usage: litellm.Usage) -> Dict[str, Any]:
"""Build a result dict for logging."""
result: Dict[str, Any] = {
"id": getattr(self.request, "id", "unknown"),
"jsonrpc": "2.0",
"usage": usage.model_dump()
if hasattr(usage, "model_dump")
else dict(usage),
}
# Add final chunk result if available
if self.final_chunk:
try:
chunk_dict = self.final_chunk.model_dump(mode="json", exclude_none=True)
result["result"] = chunk_dict.get("result", {})
except Exception:
pass
return result

View File

@@ -0,0 +1,138 @@
"""
Utility functions for A2A protocol.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
import litellm
from litellm._logging import verbose_logger
if TYPE_CHECKING:
from a2a.types import SendMessageRequest, SendStreamingMessageRequest
class A2ARequestUtils:
"""Utility class for A2A request/response processing."""
@staticmethod
def extract_text_from_message(message: Any) -> str:
"""
Extract text content from A2A message parts.
Args:
message: A2A message dict or object with 'parts' containing text parts
Returns:
Concatenated text from all text parts
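        Example:
            {"parts": [{"kind": "text", "text": "Hello"}, {"kind": "text", "text": "world"}]} -> "Hello world"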
"""
if message is None:
return ""
# Handle both dict and object access
if isinstance(message, dict):
parts = message.get("parts", [])
else:
parts = getattr(message, "parts", []) or []
text_parts: List[str] = []
for part in parts:
if isinstance(part, dict):
if part.get("kind") == "text":
text_parts.append(part.get("text", ""))
else:
if getattr(part, "kind", None) == "text":
text_parts.append(getattr(part, "text", ""))
return " ".join(text_parts)
@staticmethod
def extract_text_from_response(response_dict: Dict[str, Any]) -> str:
"""
Extract text content from A2A response result.
Args:
response_dict: A2A response dict with 'result' containing message
Returns:
Text from response message parts
"""
result = response_dict.get("result", {})
if not isinstance(result, dict):
return ""
message = result.get("message", {})
return A2ARequestUtils.extract_text_from_message(message)
@staticmethod
def get_input_message_from_request(
request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
) -> Any:
"""
Extract the input message from an A2A request.
Args:
request: The A2A SendMessageRequest or SendStreamingMessageRequest
Returns:
The message object/dict or None
"""
params = getattr(request, "params", None)
if params is None:
return None
return getattr(params, "message", None)
@staticmethod
def count_tokens(text: str) -> int:
"""
Count tokens in text using litellm.token_counter.
Args:
text: Text to count tokens for
Returns:
Token count, or 0 if counting fails
"""
if not text:
return 0
try:
return litellm.token_counter(text=text)
except Exception:
verbose_logger.debug("Failed to count tokens")
return 0
@staticmethod
def calculate_usage_from_request_response(
request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
response_dict: Dict[str, Any],
) -> Tuple[int, int, int]:
"""
Calculate token usage from A2A request and response.
Args:
request: The A2A SendMessageRequest or SendStreamingMessageRequest
response_dict: The A2A response as a dict
Returns:
Tuple of (prompt_tokens, completion_tokens, total_tokens)
"""
# Count input tokens
input_message = A2ARequestUtils.get_input_message_from_request(request)
input_text = A2ARequestUtils.extract_text_from_message(input_message)
prompt_tokens = A2ARequestUtils.count_tokens(input_text)
# Count output tokens
output_text = A2ARequestUtils.extract_text_from_response(response_dict)
completion_tokens = A2ARequestUtils.count_tokens(output_text)
total_tokens = prompt_tokens + completion_tokens
return prompt_tokens, completion_tokens, total_tokens
# Backwards compatibility aliases
def extract_text_from_a2a_message(message: Any) -> str:
    """Alias for A2ARequestUtils.extract_text_from_message."""
    return A2ARequestUtils.extract_text_from_message(message)
def extract_text_from_a2a_response(response_dict: Dict[str, Any]) -> str:
    """Alias for A2ARequestUtils.extract_text_from_response."""
    return A2ARequestUtils.extract_text_from_response(response_dict)