chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
A2A Protocol endpoints for LiteLLM Proxy.
|
||||
|
||||
Allows clients to invoke agents through LiteLLM using the A2A protocol.
|
||||
The A2A SDK can point to LiteLLM's URL and invoke agents registered with LiteLLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.agent_endpoints.utils import merge_agent_headers
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.utils import all_litellm_params
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _jsonrpc_error(
|
||||
request_id: Optional[str],
|
||||
code: int,
|
||||
message: str,
|
||||
status_code: int = 400,
|
||||
) -> JSONResponse:
|
||||
"""Create a JSON-RPC 2.0 error response."""
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {"code": code, "message": message},
|
||||
},
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
|
||||
def _get_agent(agent_id: str):
|
||||
"""Look up an agent by ID or name. Returns None if not found."""
|
||||
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
|
||||
|
||||
agent = global_agent_registry.get_agent_by_id(agent_id=agent_id)
|
||||
if agent is None:
|
||||
agent = global_agent_registry.get_agent_by_name(agent_name=agent_id)
|
||||
return agent
|
||||
|
||||
|
||||
def _enforce_inbound_trace_id(agent: Any, request: Request) -> None:
|
||||
"""Raise 400 if agent requires x-litellm-trace-id on inbound calls and it is missing."""
|
||||
agent_litellm_params = agent.litellm_params or {}
|
||||
if not agent_litellm_params.get("require_trace_id_on_calls_to_agent"):
|
||||
return
|
||||
|
||||
from litellm.proxy.litellm_pre_call_utils import get_chain_id_from_headers
|
||||
|
||||
headers_dict = dict(request.headers)
|
||||
trace_id = get_chain_id_from_headers(headers_dict)
|
||||
if not trace_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Agent '{agent.agent_id}' requires x-litellm-trace-id header "
|
||||
"on all inbound requests."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_stream_message(
|
||||
api_base: Optional[str],
|
||||
request_id: str,
|
||||
params: dict,
|
||||
litellm_params: Optional[dict] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
proxy_server_request: Optional[dict] = None,
|
||||
*,
|
||||
agent_extra_headers: Optional[Dict[str, str]] = None,
|
||||
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
proxy_logging_obj: Optional[Any] = None,
|
||||
) -> StreamingResponse:
|
||||
"""Handle message/stream method via SDK functions.
|
||||
|
||||
When user_api_key_dict, request_data, and proxy_logging_obj are provided,
|
||||
uses common_request_processing.async_streaming_data_generator with NDJSON
|
||||
serializers so proxy hooks and cost injection apply.
|
||||
"""
|
||||
from litellm.a2a_protocol import asend_message_streaming
|
||||
from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
|
||||
|
||||
if not A2A_SDK_AVAILABLE:
|
||||
|
||||
async def _error_stream():
|
||||
yield json.dumps(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": "Server error: 'a2a' package not installed",
|
||||
},
|
||||
}
|
||||
) + "\n"
|
||||
|
||||
return StreamingResponse(_error_stream(), media_type="application/x-ndjson")
|
||||
|
||||
from a2a.types import MessageSendParams, SendStreamingMessageRequest
|
||||
|
||||
use_proxy_hooks = (
|
||||
user_api_key_dict is not None
|
||||
and request_data is not None
|
||||
and proxy_logging_obj is not None
|
||||
)
|
||||
|
||||
async def stream_response():
|
||||
try:
|
||||
a2a_request = SendStreamingMessageRequest(
|
||||
id=request_id,
|
||||
params=MessageSendParams(**params),
|
||||
)
|
||||
a2a_stream = asend_message_streaming(
|
||||
request=a2a_request,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
agent_id=agent_id,
|
||||
metadata=metadata,
|
||||
proxy_server_request=proxy_server_request,
|
||||
agent_extra_headers=agent_extra_headers,
|
||||
)
|
||||
|
||||
if (
|
||||
use_proxy_hooks
|
||||
and user_api_key_dict is not None
|
||||
and request_data is not None
|
||||
and proxy_logging_obj is not None
|
||||
):
|
||||
from litellm.proxy.common_request_processing import (
|
||||
ProxyBaseLLMRequestProcessing,
|
||||
)
|
||||
|
||||
def _ndjson_chunk(chunk: Any) -> str:
|
||||
if hasattr(chunk, "model_dump"):
|
||||
obj = chunk.model_dump(mode="json", exclude_none=True)
|
||||
else:
|
||||
obj = chunk
|
||||
return json.dumps(obj) + "\n"
|
||||
|
||||
def _ndjson_error(proxy_exc: Any) -> str:
|
||||
return (
|
||||
json.dumps(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": getattr(
|
||||
proxy_exc,
|
||||
"message",
|
||||
f"Streaming error: {proxy_exc!s}",
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
async for (
|
||||
line
|
||||
) in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
|
||||
response=a2a_stream,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=request_data,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
serialize_chunk=_ndjson_chunk,
|
||||
serialize_error=_ndjson_error,
|
||||
):
|
||||
yield line
|
||||
else:
|
||||
async for chunk in a2a_stream:
|
||||
if hasattr(chunk, "model_dump"):
|
||||
yield json.dumps(
|
||||
chunk.model_dump(mode="json", exclude_none=True)
|
||||
) + "\n"
|
||||
else:
|
||||
yield json.dumps(chunk) + "\n"
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error streaming A2A response: {e}")
|
||||
if (
|
||||
use_proxy_hooks
|
||||
and proxy_logging_obj is not None
|
||||
and user_api_key_dict is not None
|
||||
and request_data is not None
|
||||
):
|
||||
transformed_exception = await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=request_data,
|
||||
)
|
||||
if transformed_exception is not None:
|
||||
e = transformed_exception
|
||||
if isinstance(e, HTTPException):
|
||||
raise
|
||||
yield json.dumps(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {"code": -32603, "message": f"Streaming error: {str(e)}"},
|
||||
}
|
||||
) + "\n"
|
||||
|
||||
return StreamingResponse(stream_response(), media_type="application/x-ndjson")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/a2a/{agent_id}/.well-known/agent-card.json",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
@router.get(
|
||||
"/a2a/{agent_id}/.well-known/agent.json",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def get_agent_card(
|
||||
agent_id: str,
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Get the agent card for an agent (A2A discovery endpoint).
|
||||
|
||||
Supports both standard paths:
|
||||
- /.well-known/agent-card.json
|
||||
- /.well-known/agent.json
|
||||
|
||||
The URL in the agent card is rewritten to point to the LiteLLM proxy,
|
||||
so all subsequent A2A calls go through LiteLLM for logging and cost tracking.
|
||||
"""
|
||||
from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
|
||||
AgentRequestHandler,
|
||||
)
|
||||
|
||||
try:
|
||||
agent = _get_agent(agent_id)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
# Check agent permission (skip for admin users)
|
||||
is_allowed = await AgentRequestHandler.is_agent_allowed(
|
||||
agent_id=agent.agent_id,
|
||||
user_api_key_auth=user_api_key_dict,
|
||||
)
|
||||
if not is_allowed:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
|
||||
)
|
||||
|
||||
# Copy and rewrite URL to point to LiteLLM proxy
|
||||
agent_card = dict(agent.agent_card_params)
|
||||
agent_card["url"] = f"{str(request.base_url).rstrip('/')}/a2a/{agent_id}"
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Returning agent card for '{agent_id}' with proxy URL: {agent_card['url']}"
|
||||
)
|
||||
return JSONResponse(content=agent_card)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting agent card: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/a2a/{agent_id}",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
@router.post(
|
||||
"/a2a/{agent_id}/message/send",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
@router.post(
|
||||
"/v1/a2a/{agent_id}/message/send",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def invoke_agent_a2a( # noqa: PLR0915
|
||||
agent_id: str,
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Invoke an agent using the A2A protocol (JSON-RPC 2.0).
|
||||
|
||||
Supported methods:
|
||||
- message/send: Send a message and get a response
|
||||
- message/stream: Send a message and stream the response
|
||||
"""
|
||||
from litellm.a2a_protocol import asend_message
|
||||
from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
|
||||
from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
|
||||
AgentRequestHandler,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
)
|
||||
|
||||
body = {}
|
||||
try:
|
||||
body = await request.json()
|
||||
|
||||
verbose_proxy_logger.debug(f"A2A request for agent '{agent_id}': {body}")
|
||||
|
||||
# Validate JSON-RPC format
|
||||
if body.get("jsonrpc") != "2.0":
|
||||
return _jsonrpc_error(
|
||||
body.get("id"), -32600, "Invalid Request: jsonrpc must be '2.0'"
|
||||
)
|
||||
|
||||
request_id = body.get("id")
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
|
||||
if params:
|
||||
# extract any litellm params from the params - eg. 'guardrails'
|
||||
params_to_remove = []
|
||||
for key, value in params.items():
|
||||
if key in all_litellm_params:
|
||||
params_to_remove.append(key)
|
||||
body[key] = value
|
||||
for key in params_to_remove:
|
||||
params.pop(key)
|
||||
|
||||
if not A2A_SDK_AVAILABLE:
|
||||
return _jsonrpc_error(
|
||||
request_id,
|
||||
-32603,
|
||||
"Server error: 'a2a' package not installed. Please install 'a2a-sdk'.",
|
||||
500,
|
||||
)
|
||||
|
||||
# Find the agent
|
||||
agent = _get_agent(agent_id)
|
||||
if agent is None:
|
||||
return _jsonrpc_error(
|
||||
request_id, -32000, f"Agent '{agent_id}' not found", 404
|
||||
)
|
||||
|
||||
is_allowed = await AgentRequestHandler.is_agent_allowed(
|
||||
agent_id=agent.agent_id,
|
||||
user_api_key_auth=user_api_key_dict,
|
||||
)
|
||||
if not is_allowed:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
|
||||
)
|
||||
|
||||
_enforce_inbound_trace_id(agent, request)
|
||||
|
||||
# Get backend URL and agent name
|
||||
agent_url = agent.agent_card_params.get("url")
|
||||
agent_name = agent.agent_card_params.get("name", agent_id)
|
||||
|
||||
# Get litellm_params (may include custom_llm_provider for completion bridge)
|
||||
litellm_params = agent.litellm_params or {}
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
|
||||
# URL is required unless using completion bridge with a provider that derives endpoint from model
|
||||
# (e.g., bedrock/agentcore derives endpoint from ARN in model string)
|
||||
if not agent_url and not custom_llm_provider:
|
||||
return _jsonrpc_error(
|
||||
request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"Proxying A2A request to agent '{agent_id}' at {agent_url or 'completion-bridge'}"
|
||||
)
|
||||
|
||||
# Set up data dict for litellm processing
|
||||
if "metadata" not in body:
|
||||
body["metadata"] = {}
|
||||
body["metadata"]["agent_id"] = agent.agent_id
|
||||
|
||||
body.update(
|
||||
{
|
||||
"model": f"a2a_agent/{agent_name}",
|
||||
"custom_llm_provider": "a2a_agent",
|
||||
}
|
||||
)
|
||||
|
||||
# Add litellm data (user_api_key, user_id, team_id, etc.)
|
||||
from litellm.proxy.common_request_processing import (
|
||||
ProxyBaseLLMRequestProcessing,
|
||||
)
|
||||
|
||||
processor = ProxyBaseLLMRequestProcessing(data=body)
|
||||
data, logging_obj = await processor.common_processing_pre_call_logic(
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
proxy_config=proxy_config,
|
||||
route_type="asend_message",
|
||||
version=version,
|
||||
)
|
||||
|
||||
# Build merged headers for the backend agent
|
||||
static_headers: Dict[str, str] = dict(agent.static_headers or {})
|
||||
|
||||
raw_headers = dict(request.headers)
|
||||
normalized = {k.lower(): v for k, v in raw_headers.items()}
|
||||
|
||||
dynamic_headers: Dict[str, str] = {}
|
||||
|
||||
# 1. Admin-configured extra_headers: forward named headers from client request
|
||||
if agent.extra_headers:
|
||||
for header_name in agent.extra_headers:
|
||||
val = normalized.get(header_name.lower())
|
||||
if val is not None:
|
||||
dynamic_headers[header_name] = val
|
||||
|
||||
# 2. Convention-based forwarding: x-a2a-{agent_id_or_name}-{header_name}
|
||||
# Matches both agent_id (UUID) and agent_name (alias), case-insensitive.
|
||||
for alias in (agent.agent_id.lower(), agent.agent_name.lower()):
|
||||
prefix = f"x-a2a-{alias}-"
|
||||
for key, val in normalized.items():
|
||||
if key.startswith(prefix):
|
||||
header_name = key[len(prefix) :]
|
||||
if header_name:
|
||||
dynamic_headers[header_name] = val
|
||||
|
||||
agent_extra_headers = merge_agent_headers(
|
||||
dynamic_headers=dynamic_headers or None,
|
||||
static_headers=static_headers or None,
|
||||
)
|
||||
|
||||
# Route through SDK functions
|
||||
if method == "message/send":
|
||||
from a2a.types import MessageSendParams, SendMessageRequest
|
||||
|
||||
a2a_request = SendMessageRequest(
|
||||
id=request_id,
|
||||
params=MessageSendParams(**params),
|
||||
)
|
||||
response = await asend_message(
|
||||
request=a2a_request,
|
||||
api_base=agent_url,
|
||||
litellm_params=litellm_params,
|
||||
agent_id=agent.agent_id,
|
||||
metadata=data.get("metadata", {}),
|
||||
proxy_server_request=data.get("proxy_server_request"),
|
||||
litellm_logging_obj=logging_obj,
|
||||
agent_extra_headers=agent_extra_headers,
|
||||
)
|
||||
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
data=data,
|
||||
response=response,
|
||||
)
|
||||
return JSONResponse(
|
||||
content=(
|
||||
response.model_dump(mode="json", exclude_none=True) # type: ignore
|
||||
if hasattr(response, "model_dump")
|
||||
else response
|
||||
)
|
||||
)
|
||||
|
||||
elif method == "message/stream":
|
||||
return await _handle_stream_message(
|
||||
api_base=agent_url,
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
litellm_params=litellm_params,
|
||||
agent_id=agent.agent_id,
|
||||
metadata=data.get("metadata", {}),
|
||||
proxy_server_request=data.get("proxy_server_request"),
|
||||
agent_extra_headers=agent_extra_headers,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
else:
|
||||
return _jsonrpc_error(request_id, -32601, f"Method '{method}' not found")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error invoking agent: {e}")
|
||||
return _jsonrpc_error(body.get("id"), -32603, f"Internal error: {str(e)}", 500)
|
||||
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
A2A Agent Routing
|
||||
|
||||
Handles routing for A2A agents (models with "a2a/<agent-name>" prefix).
|
||||
Looks up agents in the registry and injects their API base URL.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
def route_a2a_agent_request(data: dict, route_type: str) -> Optional[Any]:
|
||||
"""
|
||||
Route A2A agent requests directly to litellm with injected API base.
|
||||
|
||||
Returns None if not an A2A request (allows normal routing to continue).
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
|
||||
from litellm.proxy.route_llm_request import (
|
||||
ROUTE_ENDPOINT_MAPPING,
|
||||
ProxyModelNotFoundError,
|
||||
)
|
||||
|
||||
model_name = data.get("model", "")
|
||||
|
||||
# Check if this is an A2A agent request
|
||||
if not isinstance(model_name, str) or not model_name.startswith("a2a/"):
|
||||
return None
|
||||
|
||||
# Extract agent name (e.g., "a2a/my-agent" -> "my-agent")
|
||||
agent_name = model_name[4:]
|
||||
|
||||
# Look up agent in registry
|
||||
agent = global_agent_registry.get_agent_by_name(agent_name)
|
||||
if agent is None:
|
||||
verbose_proxy_logger.error(f"[A2A] Agent '{agent_name}' not found in registry")
|
||||
route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
|
||||
raise ProxyModelNotFoundError(route=route_name, model_name=model_name)
|
||||
|
||||
# Get API base URL from agent config
|
||||
if not agent.agent_card_params or "url" not in agent.agent_card_params:
|
||||
verbose_proxy_logger.error(f"[A2A] Agent '{agent_name}' has no URL configured")
|
||||
route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
|
||||
raise ProxyModelNotFoundError(route=route_name, model_name=model_name)
|
||||
|
||||
# Inject API base and route to litellm
|
||||
data["api_base"] = agent.agent_card_params["url"]
|
||||
verbose_proxy_logger.debug(f"[A2A] Routing {model_name} to {data['api_base']}")
|
||||
|
||||
return getattr(litellm, f"{route_type}")(**data)
|
||||
@@ -0,0 +1,458 @@
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.management_helpers.object_permission_utils import (
|
||||
handle_update_object_permission_common,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
def __init__(self):
|
||||
self.agent_list: List[AgentResponse] = []
|
||||
|
||||
def reset_agent_list(self):
|
||||
self.agent_list = []
|
||||
|
||||
def register_agent(self, agent_config: AgentResponse):
|
||||
self.agent_list.append(agent_config)
|
||||
|
||||
def deregister_agent(self, agent_name: str):
|
||||
self.agent_list = [
|
||||
agent for agent in self.agent_list if agent.agent_name != agent_name
|
||||
]
|
||||
|
||||
def get_agent_list(self, agent_names: Optional[List[str]] = None):
|
||||
if agent_names is not None:
|
||||
return [
|
||||
agent for agent in self.agent_list if agent.agent_name in agent_names
|
||||
]
|
||||
return self.agent_list
|
||||
|
||||
def get_public_agent_list(self) -> List[AgentResponse]:
|
||||
public_agent_list: List[AgentResponse] = []
|
||||
if litellm.public_agent_groups is None:
|
||||
return public_agent_list
|
||||
for agent in self.agent_list:
|
||||
if agent.agent_id in litellm.public_agent_groups:
|
||||
public_agent_list.append(agent)
|
||||
return public_agent_list
|
||||
|
||||
def _create_agent_id(self, agent_config: AgentConfig) -> str:
|
||||
return hashlib.sha256(
|
||||
json.dumps(agent_config, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
def load_agents_from_config(self, agent_config: Optional[List[AgentConfig]] = None):
|
||||
if agent_config is None:
|
||||
return None
|
||||
|
||||
for agent_config_item in agent_config:
|
||||
if not isinstance(agent_config_item, dict):
|
||||
raise ValueError("agent_config must be a list of dictionaries")
|
||||
|
||||
agent_name = agent_config_item.get("agent_name")
|
||||
agent_card_params = agent_config_item.get("agent_card_params")
|
||||
if not all([agent_name, agent_card_params]):
|
||||
continue
|
||||
|
||||
# create a stable hash id for config item
|
||||
config_hash = self._create_agent_id(agent_config_item)
|
||||
|
||||
self.register_agent(agent_config=AgentResponse(agent_id=config_hash, **agent_config_item)) # type: ignore
|
||||
|
||||
def load_agents_from_db_and_config(
|
||||
self,
|
||||
agent_config: Optional[List[AgentConfig]] = None,
|
||||
db_agents: Optional[List[Dict[str, Any]]] = None,
|
||||
):
|
||||
self.reset_agent_list()
|
||||
|
||||
if agent_config:
|
||||
for agent_config_item in agent_config:
|
||||
if not isinstance(agent_config_item, dict):
|
||||
raise ValueError("agent_config must be a list of dictionaries")
|
||||
|
||||
self.register_agent(agent_config=AgentResponse(agent_id=self._create_agent_id(agent_config_item), **agent_config_item)) # type: ignore
|
||||
|
||||
if db_agents:
|
||||
for db_agent in db_agents:
|
||||
if not isinstance(db_agent, dict):
|
||||
raise ValueError("db_agents must be a list of dictionaries")
|
||||
|
||||
self.register_agent(agent_config=AgentResponse(**db_agent)) # type: ignore
|
||||
return self.agent_list
|
||||
|
||||
###########################################################
|
||||
########### DB management helpers for agents ###########
|
||||
############################################################
|
||||
async def add_agent_to_db(
|
||||
self, agent: AgentConfig, prisma_client: PrismaClient, created_by: str
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
Add an agent to the database
|
||||
"""
|
||||
try:
|
||||
agent_name = agent.get("agent_name")
|
||||
|
||||
# Serialize litellm_params
|
||||
litellm_params_obj: Any = agent.get("litellm_params", {})
|
||||
if hasattr(litellm_params_obj, "model_dump"):
|
||||
litellm_params_dict = litellm_params_obj.model_dump()
|
||||
else:
|
||||
litellm_params_dict = (
|
||||
dict(litellm_params_obj) if litellm_params_obj else {}
|
||||
)
|
||||
litellm_params: str = safe_dumps(litellm_params_dict)
|
||||
|
||||
# Serialize agent_card_params
|
||||
agent_card_params_obj: Any = agent.get("agent_card_params", {})
|
||||
if hasattr(agent_card_params_obj, "model_dump"):
|
||||
agent_card_params_dict = agent_card_params_obj.model_dump()
|
||||
else:
|
||||
agent_card_params_dict = (
|
||||
dict(agent_card_params_obj) if agent_card_params_obj else {}
|
||||
)
|
||||
agent_card_params: str = safe_dumps(agent_card_params_dict)
|
||||
|
||||
# Handle object_permission (MCP tool access for agent)
|
||||
object_permission_id: Optional[str] = None
|
||||
if agent.get("object_permission") is not None:
|
||||
agent_copy = dict(agent)
|
||||
object_permission_id = await handle_update_object_permission_common(
|
||||
agent_copy, None, prisma_client
|
||||
)
|
||||
|
||||
# Serialize static_headers
|
||||
static_headers_obj = agent.get("static_headers")
|
||||
static_headers_val: Optional[str] = (
|
||||
safe_dumps(dict(static_headers_obj)) if static_headers_obj else None
|
||||
)
|
||||
|
||||
extra_headers_val: Optional[List[str]] = agent.get("extra_headers")
|
||||
|
||||
create_data: Dict[str, Any] = {
|
||||
"agent_name": agent_name,
|
||||
"litellm_params": litellm_params,
|
||||
"agent_card_params": agent_card_params,
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
if static_headers_val is not None:
|
||||
create_data["static_headers"] = static_headers_val
|
||||
if extra_headers_val is not None:
|
||||
create_data["extra_headers"] = extra_headers_val
|
||||
if object_permission_id is not None:
|
||||
create_data["object_permission_id"] = object_permission_id
|
||||
|
||||
for rate_field in (
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
"session_tpm_limit",
|
||||
"session_rpm_limit",
|
||||
):
|
||||
_val = agent.get(rate_field)
|
||||
if _val is not None:
|
||||
create_data[rate_field] = _val
|
||||
|
||||
# Create agent in DB
|
||||
created_agent = await prisma_client.db.litellm_agentstable.create(
|
||||
data=create_data,
|
||||
include={"object_permission": True},
|
||||
)
|
||||
|
||||
created_agent_dict = created_agent.model_dump()
|
||||
if created_agent.object_permission is not None:
|
||||
try:
|
||||
created_agent_dict[
|
||||
"object_permission"
|
||||
] = created_agent.object_permission.model_dump()
|
||||
except Exception:
|
||||
created_agent_dict[
|
||||
"object_permission"
|
||||
] = created_agent.object_permission.dict()
|
||||
return AgentResponse(**created_agent_dict) # type: ignore
|
||||
except Exception as e:
|
||||
raise Exception(f"Error adding agent to DB: {str(e)}")
|
||||
|
||||
async def delete_agent_from_db(
|
||||
self, agent_id: str, prisma_client: PrismaClient
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Delete an agent from the database
|
||||
"""
|
||||
try:
|
||||
deleted_agent = await prisma_client.db.litellm_agentstable.delete(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
return dict(deleted_agent)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error deleting agent from DB: {str(e)}")
|
||||
|
||||
async def patch_agent_in_db(
|
||||
self,
|
||||
agent_id: str,
|
||||
agent: PatchAgentRequest,
|
||||
prisma_client: PrismaClient,
|
||||
updated_by: str,
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
Patch an agent in the database.
|
||||
|
||||
Get the existing agent from the database and patch it with the new values.
|
||||
|
||||
Args:
|
||||
agent_id: The ID of the agent to patch
|
||||
agent: The new agent values to patch
|
||||
prisma_client: The Prisma client to use
|
||||
updated_by: The user ID of the user who is patching the agent
|
||||
|
||||
Returns:
|
||||
The patched agent
|
||||
"""
|
||||
try:
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if existing_agent is not None:
|
||||
existing_agent = dict(existing_agent)
|
||||
|
||||
if existing_agent is None:
|
||||
raise Exception(f"Agent with ID {agent_id} not found")
|
||||
|
||||
augment_agent = {**existing_agent, **agent}
|
||||
update_data: Dict[str, Any] = {}
|
||||
if augment_agent.get("agent_name"):
|
||||
update_data["agent_name"] = augment_agent.get("agent_name")
|
||||
if augment_agent.get("litellm_params"):
|
||||
update_data["litellm_params"] = safe_dumps(
|
||||
augment_agent.get("litellm_params")
|
||||
)
|
||||
if augment_agent.get("agent_card_params"):
|
||||
update_data["agent_card_params"] = safe_dumps(
|
||||
augment_agent.get("agent_card_params")
|
||||
)
|
||||
|
||||
for rate_field in (
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
"session_tpm_limit",
|
||||
"session_rpm_limit",
|
||||
):
|
||||
if rate_field in agent:
|
||||
update_data[rate_field] = agent.get(rate_field)
|
||||
if "static_headers" in agent:
|
||||
headers_value = agent.get("static_headers")
|
||||
update_data["static_headers"] = safe_dumps(
|
||||
dict(headers_value) if headers_value is not None else {}
|
||||
)
|
||||
if "extra_headers" in agent:
|
||||
extra_headers_value = agent.get("extra_headers")
|
||||
update_data["extra_headers"] = (
|
||||
extra_headers_value if extra_headers_value is not None else []
|
||||
)
|
||||
if agent.get("object_permission") is not None:
|
||||
agent_copy = dict(augment_agent)
|
||||
existing_object_permission_id = existing_agent.get(
|
||||
"object_permission_id"
|
||||
)
|
||||
object_permission_id = await handle_update_object_permission_common(
|
||||
agent_copy,
|
||||
existing_object_permission_id,
|
||||
prisma_client,
|
||||
)
|
||||
if object_permission_id is not None:
|
||||
update_data["object_permission_id"] = object_permission_id
|
||||
# Patch agent in DB
|
||||
patched_agent = await prisma_client.db.litellm_agentstable.update(
|
||||
where={"agent_id": agent_id},
|
||||
data={
|
||||
**update_data,
|
||||
"updated_by": updated_by,
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
patched_agent_dict = patched_agent.model_dump()
|
||||
if patched_agent.object_permission is not None:
|
||||
try:
|
||||
patched_agent_dict[
|
||||
"object_permission"
|
||||
] = patched_agent.object_permission.model_dump()
|
||||
except Exception:
|
||||
patched_agent_dict[
|
||||
"object_permission"
|
||||
] = patched_agent.object_permission.dict()
|
||||
return AgentResponse(**patched_agent_dict) # type: ignore
|
||||
except Exception as e:
|
||||
raise Exception(f"Error patching agent in DB: {str(e)}")
|
||||
|
||||
async def update_agent_in_db(
|
||||
self,
|
||||
agent_id: str,
|
||||
agent: AgentConfig,
|
||||
prisma_client: PrismaClient,
|
||||
updated_by: str,
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
Update an agent in the database
|
||||
"""
|
||||
try:
|
||||
agent_name = agent.get("agent_name")
|
||||
|
||||
# Serialize litellm_params
|
||||
litellm_params_obj: Any = agent.get("litellm_params", {})
|
||||
if hasattr(litellm_params_obj, "model_dump"):
|
||||
litellm_params_dict = litellm_params_obj.model_dump()
|
||||
else:
|
||||
litellm_params_dict = (
|
||||
dict(litellm_params_obj) if litellm_params_obj else {}
|
||||
)
|
||||
litellm_params: str = safe_dumps(litellm_params_dict)
|
||||
|
||||
# Serialize agent_card_params
|
||||
agent_card_params_obj: Any = agent.get("agent_card_params", {})
|
||||
if hasattr(agent_card_params_obj, "model_dump"):
|
||||
agent_card_params_dict = agent_card_params_obj.model_dump()
|
||||
else:
|
||||
agent_card_params_dict = (
|
||||
dict(agent_card_params_obj) if agent_card_params_obj else {}
|
||||
)
|
||||
agent_card_params: str = safe_dumps(agent_card_params_dict)
|
||||
|
||||
# Serialize static_headers for update
|
||||
static_headers_obj_u = agent.get("static_headers")
|
||||
static_headers_val_u: str = (
|
||||
safe_dumps(dict(static_headers_obj_u))
|
||||
if static_headers_obj_u is not None
|
||||
else safe_dumps({})
|
||||
)
|
||||
extra_headers_val_u: List[str] = agent.get("extra_headers") or []
|
||||
|
||||
update_data: Dict[str, Any] = {
|
||||
"agent_name": agent_name,
|
||||
"litellm_params": litellm_params,
|
||||
"agent_card_params": agent_card_params,
|
||||
"static_headers": static_headers_val_u,
|
||||
"extra_headers": extra_headers_val_u,
|
||||
"updated_by": updated_by,
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
|
||||
for rate_field in (
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
"session_tpm_limit",
|
||||
"session_rpm_limit",
|
||||
):
|
||||
_val = agent.get(rate_field)
|
||||
if _val is not None:
|
||||
update_data[rate_field] = _val
|
||||
|
||||
if agent.get("object_permission") is not None:
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
existing_object_permission_id = (
|
||||
existing_agent.object_permission_id
|
||||
if existing_agent is not None
|
||||
else None
|
||||
)
|
||||
agent_copy = dict(agent)
|
||||
object_permission_id = await handle_update_object_permission_common(
|
||||
agent_copy,
|
||||
existing_object_permission_id,
|
||||
prisma_client,
|
||||
)
|
||||
if object_permission_id is not None:
|
||||
update_data["object_permission_id"] = object_permission_id
|
||||
|
||||
# Update agent in DB
|
||||
updated_agent = await prisma_client.db.litellm_agentstable.update(
|
||||
where={"agent_id": agent_id},
|
||||
data=update_data,
|
||||
include={"object_permission": True},
|
||||
)
|
||||
|
||||
updated_agent_dict = updated_agent.model_dump()
|
||||
if updated_agent.object_permission is not None:
|
||||
try:
|
||||
updated_agent_dict[
|
||||
"object_permission"
|
||||
] = updated_agent.object_permission.model_dump()
|
||||
except Exception:
|
||||
updated_agent_dict[
|
||||
"object_permission"
|
||||
] = updated_agent.object_permission.dict()
|
||||
return AgentResponse(**updated_agent_dict) # type: ignore
|
||||
except Exception as e:
|
||||
raise Exception(f"Error updating agent in DB: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
async def get_all_agents_from_db(
|
||||
prisma_client: PrismaClient,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all agents from the database
|
||||
"""
|
||||
try:
|
||||
agents_from_db = await prisma_client.db.litellm_agentstable.find_many(
|
||||
order={"created_at": "desc"},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
|
||||
agents: List[Dict[str, Any]] = []
|
||||
for agent in agents_from_db:
|
||||
agent_dict = dict(agent)
|
||||
# object_permission is eagerly loaded via include above
|
||||
if agent.object_permission is not None:
|
||||
try:
|
||||
agent_dict[
|
||||
"object_permission"
|
||||
] = agent.object_permission.model_dump()
|
||||
except Exception:
|
||||
agent_dict["object_permission"] = agent.object_permission.dict()
|
||||
agents.append(agent_dict)
|
||||
|
||||
return agents
|
||||
except Exception as e:
|
||||
raise Exception(f"Error getting agents from DB: {str(e)}")
|
||||
|
||||
def get_agent_by_id(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> Optional[AgentResponse]:
|
||||
"""
|
||||
Get an agent by its ID from the database
|
||||
"""
|
||||
try:
|
||||
for agent in self.agent_list:
|
||||
if agent.agent_id == agent_id:
|
||||
return agent
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise Exception(f"Error getting agent from DB: {str(e)}")
|
||||
|
||||
def get_agent_by_name(self, agent_name: str) -> Optional[AgentResponse]:
|
||||
"""
|
||||
Get an agent by its name from the database
|
||||
"""
|
||||
try:
|
||||
for agent in self.agent_list:
|
||||
if agent.agent_name == agent_name:
|
||||
return agent
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise Exception(f"Error getting agent from DB: {str(e)}")
|
||||
|
||||
|
||||
global_agent_registry = AgentRegistry()
|
||||
@@ -0,0 +1,451 @@
|
||||
"""
|
||||
Agent Permission Handler for LiteLLM Proxy.
|
||||
|
||||
Handles agent permission checking for keys and teams using object_permission_id.
|
||||
Follows the same pattern as MCP permission handling.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_TeamTable,
|
||||
UI_TEAM_ID,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
|
||||
|
||||
class AgentRequestHandler:
|
||||
"""
|
||||
Class to handle agent permission checking, including:
|
||||
1. Key-level agent permissions
|
||||
2. Team-level agent permissions
|
||||
3. Agent access group resolution
|
||||
|
||||
Follows the same inheritance logic as MCP:
|
||||
- If team has restrictions and key has restrictions: use intersection
|
||||
- If team has restrictions and key has none: inherit from team
|
||||
- If team has no restrictions: use key restrictions
|
||||
- If no restrictions: allow all agents
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_allowed_agents(
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get list of allowed agent IDs for the given user/key based on permissions.
|
||||
|
||||
Returns:
|
||||
List[str]: List of allowed agent IDs. Empty list means no restrictions (allow all).
|
||||
"""
|
||||
try:
|
||||
allowed_agents: List[str] = []
|
||||
allowed_agents_for_key = (
|
||||
await AgentRequestHandler._get_allowed_agents_for_key(user_api_key_auth)
|
||||
)
|
||||
allowed_agents_for_team = (
|
||||
await AgentRequestHandler._get_allowed_agents_for_team(
|
||||
user_api_key_auth
|
||||
)
|
||||
)
|
||||
|
||||
# If team has agent restrictions, handle inheritance and intersection logic
|
||||
if len(allowed_agents_for_team) > 0:
|
||||
if len(allowed_agents_for_key) > 0:
|
||||
# Key has its own agent permissions - use intersection with team permissions
|
||||
for agent_id in allowed_agents_for_key:
|
||||
if agent_id in allowed_agents_for_team:
|
||||
allowed_agents.append(agent_id)
|
||||
else:
|
||||
# Key has no agent permissions - inherit from team
|
||||
allowed_agents = allowed_agents_for_team
|
||||
else:
|
||||
allowed_agents = allowed_agents_for_key
|
||||
|
||||
return list(set(allowed_agents))
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to get allowed agents: {str(e)}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def is_agent_allowed(
|
||||
agent_id: str,
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a specific agent is allowed for the given user/key.
|
||||
|
||||
Args:
|
||||
agent_id: The agent ID to check
|
||||
user_api_key_auth: User authentication info
|
||||
|
||||
Returns:
|
||||
bool: True if agent is allowed, False otherwise
|
||||
"""
|
||||
allowed_agents = await AgentRequestHandler.get_allowed_agents(user_api_key_auth)
|
||||
|
||||
# Empty list means no restrictions - allow all
|
||||
if len(allowed_agents) == 0:
|
||||
return True
|
||||
|
||||
return agent_id in allowed_agents
|
||||
|
||||
@staticmethod
|
||||
def _get_key_object_permission(
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> Optional[LiteLLM_ObjectPermissionTable]:
|
||||
"""
|
||||
Get key object_permission - already loaded by get_key_object() in main auth flow.
|
||||
|
||||
Note: object_permission is automatically populated when the key is fetched via
|
||||
get_key_object() in litellm/proxy/auth/auth_checks.py
|
||||
"""
|
||||
if not user_api_key_auth:
|
||||
return None
|
||||
|
||||
return user_api_key_auth.object_permission
|
||||
|
||||
@staticmethod
|
||||
async def _get_team_object_permission(
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> Optional[LiteLLM_ObjectPermissionTable]:
|
||||
"""
|
||||
Get team object_permission - automatically loaded by get_team_object() in main auth flow.
|
||||
|
||||
Note: object_permission is automatically populated when the team is fetched via
|
||||
get_team_object() in litellm/proxy/auth/auth_checks.py
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import get_team_object
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
if not user_api_key_auth or not user_api_key_auth.team_id or not prisma_client:
|
||||
return None
|
||||
|
||||
# Get the team object (which has object_permission already loaded)
|
||||
team_obj: Optional[LiteLLM_TeamTable] = await get_team_object(
|
||||
team_id=user_api_key_auth.team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=user_api_key_auth.parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if not team_obj:
|
||||
return None
|
||||
|
||||
return team_obj.object_permission
|
||||
|
||||
@staticmethod
|
||||
async def _get_allowed_agents_for_key(
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get allowed agents for a key.
|
||||
|
||||
1. First checks native key-level agent permissions (object_permission)
|
||||
2. Also includes agents from key's access_group_ids (unified access groups)
|
||||
|
||||
Note: object_permission is already loaded by get_key_object() in main auth flow.
|
||||
"""
|
||||
if user_api_key_auth is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
all_agents: List[str] = []
|
||||
|
||||
# 1. Get agents from object_permission (native permissions)
|
||||
key_object_permission = AgentRequestHandler._get_key_object_permission(
|
||||
user_api_key_auth
|
||||
)
|
||||
if key_object_permission is not None:
|
||||
# Get direct agents
|
||||
direct_agents = key_object_permission.agents or []
|
||||
|
||||
# Get agents from access groups
|
||||
access_group_agents = (
|
||||
await AgentRequestHandler._get_agents_from_access_groups(
|
||||
key_object_permission.agent_access_groups or []
|
||||
)
|
||||
)
|
||||
|
||||
all_agents = direct_agents + access_group_agents
|
||||
|
||||
# 2. Fallback: get agent IDs from key's access_group_ids (unified access groups)
|
||||
key_access_group_ids = user_api_key_auth.access_group_ids or []
|
||||
if key_access_group_ids:
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
_get_agent_ids_from_access_groups,
|
||||
)
|
||||
|
||||
unified_agents = await _get_agent_ids_from_access_groups(
|
||||
access_group_ids=key_access_group_ids,
|
||||
)
|
||||
all_agents.extend(unified_agents)
|
||||
|
||||
return list(set(all_agents))
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to get allowed agents for key: {str(e)}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def _get_allowed_agents_for_team(
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get allowed agents for a team.
|
||||
|
||||
1. First checks native team-level agent permissions (object_permission)
|
||||
2. Also includes agents from team's access_group_ids (unified access groups)
|
||||
|
||||
Fetches the team object once and reuses it for both permission sources.
|
||||
"""
|
||||
if user_api_key_auth is None:
|
||||
return []
|
||||
|
||||
if user_api_key_auth.team_id is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
from litellm.proxy.auth.auth_checks import get_team_object
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
if not prisma_client:
|
||||
return []
|
||||
|
||||
# Fetch the team object once for both permission sources
|
||||
team_obj = await get_team_object(
|
||||
team_id=user_api_key_auth.team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=user_api_key_auth.parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if team_obj is None:
|
||||
return []
|
||||
|
||||
all_agents: List[str] = []
|
||||
|
||||
# 1. Get agents from object_permission (native permissions)
|
||||
object_permissions = team_obj.object_permission
|
||||
if object_permissions is not None:
|
||||
# Get direct agents
|
||||
direct_agents = object_permissions.agents or []
|
||||
|
||||
# Get agents from access groups
|
||||
access_group_agents = (
|
||||
await AgentRequestHandler._get_agents_from_access_groups(
|
||||
object_permissions.agent_access_groups or []
|
||||
)
|
||||
)
|
||||
|
||||
all_agents = direct_agents + access_group_agents
|
||||
|
||||
# 2. Also include agents from team's access_group_ids (unified access groups)
|
||||
team_access_group_ids = team_obj.access_group_ids or []
|
||||
if team_access_group_ids:
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
_get_agent_ids_from_access_groups,
|
||||
)
|
||||
|
||||
unified_agents = await _get_agent_ids_from_access_groups(
|
||||
access_group_ids=team_access_group_ids,
|
||||
)
|
||||
all_agents.extend(unified_agents)
|
||||
|
||||
return list(set(all_agents))
|
||||
except Exception as e:
|
||||
# litellm-dashboard is the default UI team and will never have agents;
|
||||
# skip noisy warnings for it.
|
||||
if user_api_key_auth.team_id != UI_TEAM_ID:
|
||||
verbose_logger.warning(
|
||||
f"Failed to get allowed agents for team: {str(e)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_config_agent_ids_for_access_groups(
|
||||
config_agents: List, access_groups: List[str]
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Helper to get agent_ids from config-loaded agents that match any of the given access groups.
|
||||
"""
|
||||
server_ids: Set[str] = set()
|
||||
for agent in config_agents:
|
||||
agent_access_groups = getattr(agent, "agent_access_groups", None)
|
||||
if agent_access_groups:
|
||||
if any(group in agent_access_groups for group in access_groups):
|
||||
server_ids.add(agent.agent_id)
|
||||
return server_ids
|
||||
|
||||
@staticmethod
|
||||
async def _get_db_agent_ids_for_access_groups(
|
||||
prisma_client, access_groups: List[str]
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Helper to get agent_ids from DB agents that match any of the given access groups.
|
||||
"""
|
||||
agent_ids: Set[str] = set()
|
||||
if access_groups and prisma_client is not None:
|
||||
try:
|
||||
agents = await prisma_client.db.litellm_agentstable.find_many(
|
||||
where={"agent_access_groups": {"hasSome": access_groups}}
|
||||
)
|
||||
for agent in agents:
|
||||
agent_ids.add(agent.agent_id)
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error getting agents from access groups: {e}")
|
||||
return agent_ids
|
||||
|
||||
@staticmethod
|
||||
async def _get_agents_from_access_groups(
|
||||
access_groups: List[str],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Resolve agent access groups to agent IDs by querying BOTH the agent table (DB) AND config-loaded agents.
|
||||
"""
|
||||
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
try:
|
||||
# Use the helper for config-loaded agents
|
||||
agent_ids = AgentRequestHandler._get_config_agent_ids_for_access_groups(
|
||||
global_agent_registry.agent_list, access_groups
|
||||
)
|
||||
|
||||
# Use the helper for DB agents
|
||||
db_agent_ids = (
|
||||
await AgentRequestHandler._get_db_agent_ids_for_access_groups(
|
||||
prisma_client, access_groups
|
||||
)
|
||||
)
|
||||
agent_ids.update(db_agent_ids)
|
||||
|
||||
return list(agent_ids)
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to get agents from access groups: {str(e)}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def get_agent_access_groups(
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get list of agent access groups for the given user/key based on permissions.
|
||||
"""
|
||||
access_groups: List[str] = []
|
||||
access_groups_for_key = (
|
||||
await AgentRequestHandler._get_agent_access_groups_for_key(
|
||||
user_api_key_auth
|
||||
)
|
||||
)
|
||||
access_groups_for_team = (
|
||||
await AgentRequestHandler._get_agent_access_groups_for_team(
|
||||
user_api_key_auth
|
||||
)
|
||||
)
|
||||
|
||||
# If team has access groups, then key must have a subset of the team's access groups
|
||||
if len(access_groups_for_team) > 0:
|
||||
for access_group in access_groups_for_key:
|
||||
if access_group in access_groups_for_team:
|
||||
access_groups.append(access_group)
|
||||
else:
|
||||
access_groups = access_groups_for_key
|
||||
|
||||
return list(set(access_groups))
|
||||
|
||||
@staticmethod
|
||||
async def _get_agent_access_groups_for_key(
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> List[str]:
|
||||
"""Get agent access groups for the key."""
|
||||
from litellm.proxy.auth.auth_checks import get_object_permission
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
if user_api_key_auth is None:
|
||||
return []
|
||||
|
||||
if user_api_key_auth.object_permission_id is None:
|
||||
return []
|
||||
|
||||
if prisma_client is None:
|
||||
verbose_logger.debug("prisma_client is None")
|
||||
return []
|
||||
|
||||
try:
|
||||
key_object_permission = await get_object_permission(
|
||||
object_permission_id=user_api_key_auth.object_permission_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=user_api_key_auth.parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if key_object_permission is None:
|
||||
return []
|
||||
|
||||
return key_object_permission.agent_access_groups or []
|
||||
except Exception as e:
|
||||
verbose_logger.warning(
|
||||
f"Failed to get agent access groups for key: {str(e)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def _get_agent_access_groups_for_team(
|
||||
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
|
||||
) -> List[str]:
|
||||
"""Get agent access groups for the team."""
|
||||
from litellm.proxy.auth.auth_checks import get_team_object
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
if user_api_key_auth is None:
|
||||
return []
|
||||
|
||||
if user_api_key_auth.team_id is None:
|
||||
return []
|
||||
|
||||
if prisma_client is None:
|
||||
verbose_logger.debug("prisma_client is None")
|
||||
return []
|
||||
|
||||
try:
|
||||
team_obj: Optional[LiteLLM_TeamTable] = await get_team_object(
|
||||
team_id=user_api_key_auth.team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=user_api_key_auth.parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if team_obj is None:
|
||||
verbose_logger.debug("team_obj is None")
|
||||
return []
|
||||
|
||||
object_permissions = team_obj.object_permission
|
||||
if object_permissions is None:
|
||||
return []
|
||||
|
||||
return object_permissions.agent_access_groups or []
|
||||
except Exception as e:
|
||||
verbose_logger.warning(
|
||||
f"Failed to get agent access groups for team: {str(e)}"
|
||||
)
|
||||
return []
|
||||
@@ -0,0 +1,944 @@
|
||||
"""
|
||||
Agent endpoints for registering + discovering agents via LiteLLM.
|
||||
|
||||
Follows the A2A Spec.
|
||||
|
||||
1. Register an agent via POST `/v1/agents`
|
||||
2. Discover agents via GET `/v1/agents`
|
||||
3. Get specific agent via GET `/v1/agents/{agent_id}`
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.rbac_utils import check_feature_access_for_user
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
|
||||
from litellm.types.agents import (
|
||||
AgentConfig,
|
||||
AgentMakePublicResponse,
|
||||
AgentResponse,
|
||||
MakeAgentsPublicRequest,
|
||||
PatchAgentRequest,
|
||||
)
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _check_agent_management_permission(user_api_key_dict: UserAPIKeyAuth) -> None:
|
||||
"""
|
||||
Raises HTTP 403 if the caller does not have permission to create, update,
|
||||
or delete agents. Only PROXY_ADMIN users are allowed to perform these
|
||||
write operations.
|
||||
"""
|
||||
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Only proxy admins can create, update, or delete agents. Your role={}".format(
|
||||
user_api_key_dict.user_role
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
AGENT_HEALTH_CHECK_TIMEOUT_SECONDS = float(
|
||||
os.environ.get("LITELLM_AGENT_HEALTH_CHECK_TIMEOUT", "5.0")
|
||||
)
|
||||
AGENT_HEALTH_CHECK_GATHER_TIMEOUT_SECONDS = float(
|
||||
os.environ.get("LITELLM_AGENT_HEALTH_CHECK_GATHER_TIMEOUT", "30.0")
|
||||
)
|
||||
|
||||
|
||||
async def _check_agent_url_health(
|
||||
agent: AgentResponse,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform a GET request against the agent's URL and return the health result.
|
||||
|
||||
Returns a dict with ``agent_id``, ``healthy`` (bool), and an optional
|
||||
``error`` message.
|
||||
"""
|
||||
url = (agent.agent_card_params or {}).get("url")
|
||||
if not url:
|
||||
return {"agent_id": agent.agent_id, "healthy": True}
|
||||
|
||||
try:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.AgentHealthCheck,
|
||||
params={"timeout": AGENT_HEALTH_CHECK_TIMEOUT_SECONDS},
|
||||
)
|
||||
response = await client.get(url)
|
||||
if response.status_code >= 500:
|
||||
return {
|
||||
"agent_id": agent.agent_id,
|
||||
"healthy": False,
|
||||
"error": f"HTTP {response.status_code}",
|
||||
}
|
||||
return {"agent_id": agent.agent_id, "healthy": True}
|
||||
except Exception as exc:
|
||||
return {
|
||||
"agent_id": agent.agent_id,
|
||||
"healthy": False,
|
||||
"error": str(exc),
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/agents",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=List[AgentResponse],
|
||||
)
|
||||
async def get_agents(
|
||||
request: Request,
|
||||
health_check: bool = Query(
|
||||
False,
|
||||
description="When true, performs a GET request to each agent's URL. Agents with reachable URLs (HTTP status < 500) and agents without a URL are returned; unreachable agents are filtered out.",
|
||||
),
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), # Used for auth
|
||||
):
|
||||
"""
|
||||
Example usage:
|
||||
```
|
||||
curl -X GET "http://localhost:4000/v1/agents" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer your-key" \
|
||||
```
|
||||
|
||||
Pass `?health_check=true` to filter out agents whose URL is unreachable:
|
||||
```
|
||||
curl -X GET "http://localhost:4000/v1/agents?health_check=true" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer your-key" \
|
||||
```
|
||||
|
||||
Returns: List[AgentResponse]
|
||||
|
||||
"""
|
||||
await check_feature_access_for_user(user_api_key_dict, "agents")
|
||||
|
||||
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
|
||||
from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
|
||||
AgentRequestHandler,
|
||||
)
|
||||
|
||||
try:
|
||||
returned_agents: List[AgentResponse] = []
|
||||
|
||||
# Admin users get all agents
|
||||
if (
|
||||
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
|
||||
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
):
|
||||
returned_agents = global_agent_registry.get_agent_list()
|
||||
else:
|
||||
# Get allowed agents from object_permission (key/team level)
|
||||
allowed_agent_ids = await AgentRequestHandler.get_allowed_agents(
|
||||
user_api_key_auth=user_api_key_dict
|
||||
)
|
||||
|
||||
# If no restrictions (empty list), return all agents
|
||||
if len(allowed_agent_ids) == 0:
|
||||
returned_agents = global_agent_registry.get_agent_list()
|
||||
else:
|
||||
# Filter agents by allowed IDs
|
||||
all_agents = global_agent_registry.get_agent_list()
|
||||
returned_agents = [
|
||||
agent for agent in all_agents if agent.agent_id in allowed_agent_ids
|
||||
]
|
||||
|
||||
# Fetch current spend from DB for all returned agents
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is not None:
|
||||
agent_ids = [agent.agent_id for agent in returned_agents]
|
||||
if agent_ids:
|
||||
db_agents = await prisma_client.db.litellm_agentstable.find_many(
|
||||
where={"agent_id": {"in": agent_ids}},
|
||||
)
|
||||
spend_map = {a.agent_id: a.spend for a in db_agents}
|
||||
for agent in returned_agents:
|
||||
if agent.agent_id in spend_map:
|
||||
agent.spend = spend_map[agent.agent_id]
|
||||
|
||||
# add is_public field to each agent - we do it this way, to allow setting config agents as public
|
||||
for agent in returned_agents:
|
||||
if agent.litellm_params is None:
|
||||
agent.litellm_params = {}
|
||||
agent.litellm_params[
|
||||
"is_public"
|
||||
] = litellm.public_agent_groups is not None and (
|
||||
agent.agent_id in litellm.public_agent_groups
|
||||
)
|
||||
|
||||
if health_check:
|
||||
agents_with_url = [
|
||||
agent
|
||||
for agent in returned_agents
|
||||
if (agent.agent_card_params or {}).get("url")
|
||||
]
|
||||
agents_without_url = [
|
||||
agent
|
||||
for agent in returned_agents
|
||||
if not (agent.agent_card_params or {}).get("url")
|
||||
]
|
||||
try:
|
||||
health_results = await asyncio.wait_for(
|
||||
asyncio.gather(
|
||||
*[_check_agent_url_health(agent) for agent in agents_with_url]
|
||||
),
|
||||
timeout=AGENT_HEALTH_CHECK_GATHER_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
verbose_proxy_logger.warning(
|
||||
"Agent health check gather timed out after %s seconds",
|
||||
AGENT_HEALTH_CHECK_GATHER_TIMEOUT_SECONDS,
|
||||
)
|
||||
health_results = [
|
||||
{
|
||||
"agent_id": agent.agent_id,
|
||||
"healthy": False,
|
||||
"error": "Health check timed out",
|
||||
}
|
||||
for agent in agents_with_url
|
||||
]
|
||||
healthy_ids = {
|
||||
result["agent_id"] for result in health_results if result["healthy"]
|
||||
}
|
||||
returned_agents = [
|
||||
agent for agent in agents_with_url if agent.agent_id in healthy_ids
|
||||
] + agents_without_url
|
||||
|
||||
return returned_agents
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.agent_endpoints.get_agents(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail={"error": f"Internal server error: {str(e)}"}
|
||||
)
|
||||
|
||||
|
||||
#### CRUD ENDPOINTS FOR AGENTS ####
|
||||
|
||||
from litellm.proxy.agent_endpoints.agent_registry import (
|
||||
global_agent_registry as AGENT_REGISTRY,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/agents",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=AgentResponse,
|
||||
)
|
||||
async def create_agent(
|
||||
request: AgentConfig,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Create a new agent
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/agents" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"agent": {
|
||||
"agent_name": "my-custom-agent",
|
||||
"agent_card_params": {
|
||||
"protocolVersion": "1.0",
|
||||
"name": "Hello World Agent",
|
||||
"description": "Just a hello world agent",
|
||||
"url": "http://localhost:9999/",
|
||||
"version": "1.0.0",
|
||||
"defaultInputModes": ["text"],
|
||||
"defaultOutputModes": ["text"],
|
||||
"capabilities": {
|
||||
"streaming": true
|
||||
},
|
||||
"skills": [
|
||||
{
|
||||
"id": "hello_world",
|
||||
"name": "Returns hello world",
|
||||
"description": "just returns hello world",
|
||||
"tags": ["hello world"],
|
||||
"examples": ["hi", "hello world"]
|
||||
}
|
||||
]
|
||||
},
|
||||
"litellm_params": {
|
||||
"make_public": true
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
"""
|
||||
await check_feature_access_for_user(user_api_key_dict, "agents")
|
||||
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
_check_agent_management_permission(user_api_key_dict)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
# Get the user ID from the API key auth
|
||||
created_by = user_api_key_dict.user_id or "unknown"
|
||||
|
||||
# check for naming conflicts
|
||||
existing_agent = AGENT_REGISTRY.get_agent_by_name(
|
||||
agent_name=request.get("agent_name") # type: ignore
|
||||
)
|
||||
if existing_agent is not None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Agent with name {request.get('agent_name')} already exists",
|
||||
)
|
||||
|
||||
result = await AGENT_REGISTRY.add_agent_to_db(
|
||||
agent=request, prisma_client=prisma_client, created_by=created_by
|
||||
)
|
||||
|
||||
agent_name = result.agent_name
|
||||
agent_id = result.agent_id
|
||||
|
||||
# Also register in memory
|
||||
try:
|
||||
AGENT_REGISTRY.register_agent(agent_config=result)
|
||||
verbose_proxy_logger.info(
|
||||
f"Successfully registered agent '{agent_name}' (ID: {agent_id}) in memory"
|
||||
)
|
||||
except Exception as reg_error:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Failed to register agent '{agent_name}' (ID: {agent_id}) in memory: {reg_error}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error adding agent to db: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/agents/{agent_id}",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=AgentResponse,
|
||||
)
|
||||
async def get_agent_by_id(
|
||||
agent_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Get a specific agent by ID
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
"""
|
||||
await check_feature_access_for_user(user_api_key_dict, "agents")
|
||||
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
|
||||
if agent is None:
|
||||
agent_row = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
if agent_row is not None:
|
||||
agent_dict = agent_row.model_dump()
|
||||
if agent_row.object_permission is not None:
|
||||
try:
|
||||
agent_dict[
|
||||
"object_permission"
|
||||
] = agent_row.object_permission.model_dump()
|
||||
except Exception:
|
||||
agent_dict[
|
||||
"object_permission"
|
||||
] = agent_row.object_permission.dict()
|
||||
agent = AgentResponse(**agent_dict) # type: ignore
|
||||
else:
|
||||
# Agent found in memory — refresh spend from DB
|
||||
db_row = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if db_row is not None:
|
||||
agent.spend = db_row.spend
|
||||
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
|
||||
return agent
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting agent from db: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put(
|
||||
"/v1/agents/{agent_id}",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=AgentResponse,
|
||||
)
|
||||
async def update_agent(
|
||||
agent_id: str,
|
||||
request: AgentConfig,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Update an existing agent
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X PUT "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"agent": {
|
||||
"agent_name": "updated-agent",
|
||||
"agent_card_params": {
|
||||
"protocolVersion": "1.0",
|
||||
"name": "Updated Agent",
|
||||
"description": "Updated description",
|
||||
"url": "http://localhost:9999/",
|
||||
"version": "1.1.0",
|
||||
"defaultInputModes": ["text"],
|
||||
"defaultOutputModes": ["text"],
|
||||
"capabilities": {
|
||||
"streaming": true
|
||||
},
|
||||
"skills": []
|
||||
},
|
||||
"litellm_params": {
|
||||
"make_public": false
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
"""
|
||||
await check_feature_access_for_user(user_api_key_dict, "agents")
|
||||
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
_check_agent_management_permission(user_api_key_dict)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if agent exists
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if existing_agent is not None:
|
||||
existing_agent = dict(existing_agent)
|
||||
|
||||
if existing_agent is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
|
||||
# Get the user ID from the API key auth
|
||||
updated_by = user_api_key_dict.user_id or "unknown"
|
||||
|
||||
result = await AGENT_REGISTRY.update_agent_in_db(
|
||||
agent_id=agent_id,
|
||||
agent=request,
|
||||
prisma_client=prisma_client,
|
||||
updated_by=updated_by,
|
||||
)
|
||||
|
||||
# deregister in memory
|
||||
AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name")) # type: ignore
|
||||
# register in memory
|
||||
AGENT_REGISTRY.register_agent(agent_config=result)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"Successfully updated agent '{existing_agent.get('agent_name')}' (ID: {agent_id}) in memory"
|
||||
)
|
||||
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating agent: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/v1/agents/{agent_id}",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=AgentResponse,
|
||||
)
|
||||
async def patch_agent(
|
||||
agent_id: str,
|
||||
request: PatchAgentRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Update an existing agent
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X PUT "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"agent": {
|
||||
"agent_name": "updated-agent",
|
||||
"agent_card_params": {
|
||||
"protocolVersion": "1.0",
|
||||
"name": "Updated Agent",
|
||||
"description": "Updated description",
|
||||
"url": "http://localhost:9999/",
|
||||
"version": "1.1.0",
|
||||
"defaultInputModes": ["text"],
|
||||
"defaultOutputModes": ["text"],
|
||||
"capabilities": {
|
||||
"streaming": true
|
||||
},
|
||||
"skills": []
|
||||
},
|
||||
"litellm_params": {
|
||||
"make_public": false
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
"""
|
||||
await check_feature_access_for_user(user_api_key_dict, "agents")
|
||||
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
_check_agent_management_permission(user_api_key_dict)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if agent exists
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if existing_agent is not None:
|
||||
existing_agent = dict(existing_agent)
|
||||
|
||||
if existing_agent is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
|
||||
# Get the user ID from the API key auth
|
||||
updated_by = user_api_key_dict.user_id or "unknown"
|
||||
|
||||
result = await AGENT_REGISTRY.patch_agent_in_db(
|
||||
agent_id=agent_id,
|
||||
agent=request,
|
||||
prisma_client=prisma_client,
|
||||
updated_by=updated_by,
|
||||
)
|
||||
|
||||
# deregister in memory
|
||||
AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name")) # type: ignore
|
||||
# register in memory
|
||||
AGENT_REGISTRY.register_agent(agent_config=result)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"Successfully updated agent '{existing_agent.get('agent_name')}' (ID: {agent_id}) in memory"
|
||||
)
|
||||
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating agent: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/v1/agents/{agent_id}",
|
||||
tags=["Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def delete_agent(
|
||||
agent_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Delete an agent
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X DELETE "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"message": "Agent 123e4567-e89b-12d3-a456-426614174000 deleted successfully"
|
||||
}
|
||||
```
|
||||
"""
|
||||
await check_feature_access_for_user(user_api_key_dict, "agents")
|
||||
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
_check_agent_management_permission(user_api_key_dict)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
# Check if agent exists
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if existing_agent is not None:
|
||||
existing_agent = dict[Any, Any](existing_agent)
|
||||
|
||||
if existing_agent is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found in DB."
|
||||
)
|
||||
|
||||
await AGENT_REGISTRY.delete_agent_from_db(
|
||||
agent_id=agent_id, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name")) # type: ignore
|
||||
|
||||
return {"message": f"Agent {agent_id} deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting agent: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/agents/{agent_id}/make_public",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=AgentMakePublicResponse,
|
||||
)
|
||||
async def make_agent_public(
|
||||
agent_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Make an agent publicly discoverable
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/v1/agents/123e4567-e89b-12d3-a456-426614174000/make_public" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"agent_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"agent_name": "my-custom-agent",
|
||||
"litellm_params": {
|
||||
"make_public": true
|
||||
},
|
||||
"agent_card_params": {...},
|
||||
"created_at": "2025-11-15T10:30:00Z",
|
||||
"updated_at": "2025-11-15T10:35:00Z",
|
||||
"created_by": "user123",
|
||||
"updated_by": "user123"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
try:
|
||||
# Update the public model groups
|
||||
import litellm
|
||||
from litellm.proxy.agent_endpoints.agent_registry import (
|
||||
global_agent_registry as AGENT_REGISTRY,
|
||||
)
|
||||
from litellm.proxy.proxy_server import proxy_config
|
||||
|
||||
# Check if user has admin permissions
|
||||
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Only proxy admins can update public model groups. Your role={}".format(
|
||||
user_api_key_dict.user_role
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
|
||||
if agent is None:
|
||||
# check if agent exists in DB
|
||||
agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if agent is not None:
|
||||
agent = AgentResponse(**agent.model_dump()) # type: ignore
|
||||
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
|
||||
if litellm.public_agent_groups is None:
|
||||
litellm.public_agent_groups = []
|
||||
# handle duplicates
|
||||
if agent.agent_id in litellm.public_agent_groups:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Agent with name {agent.agent_name} already in public agent groups",
|
||||
)
|
||||
litellm.public_agent_groups.append(agent.agent_id)
|
||||
|
||||
# Load existing config
|
||||
config = await proxy_config.get_config()
|
||||
|
||||
# Update config with new settings
|
||||
if "litellm_settings" not in config or config["litellm_settings"] is None:
|
||||
config["litellm_settings"] = {}
|
||||
|
||||
config["litellm_settings"]["public_agent_groups"] = litellm.public_agent_groups
|
||||
|
||||
# Save the updated config
|
||||
await proxy_config.save_config(new_config=config)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Updated public agent groups to: {litellm.public_agent_groups} by user: {user_api_key_dict.user_id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Successfully updated public agent groups",
|
||||
"public_agent_groups": litellm.public_agent_groups,
|
||||
"updated_by": user_api_key_dict.user_id,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error making agent public: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/agents/make_public",
|
||||
tags=["[beta] A2A Agents"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=AgentMakePublicResponse,
|
||||
)
|
||||
async def make_agents_public(
|
||||
request: MakeAgentsPublicRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Make multiple agents publicly discoverable
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/v1/agents/make_public" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"agent_ids": ["123e4567-e89b-12d3-a456-426614174000", "123e4567-e89b-12d3-a456-426614174001"]
|
||||
}'
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"agent_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"agent_name": "my-custom-agent",
|
||||
"litellm_params": {
|
||||
"make_public": true
|
||||
},
|
||||
"agent_card_params": {...},
|
||||
"created_at": "2025-11-15T10:30:00Z",
|
||||
"updated_at": "2025-11-15T10:35:00Z",
|
||||
"created_by": "user123",
|
||||
"updated_by": "user123"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
try:
|
||||
# Update the public model groups
|
||||
import litellm
|
||||
from litellm.proxy.agent_endpoints.agent_registry import (
|
||||
global_agent_registry as AGENT_REGISTRY,
|
||||
)
|
||||
from litellm.proxy.proxy_server import proxy_config
|
||||
|
||||
# Load existing config
|
||||
config = await proxy_config.get_config()
|
||||
# Check if user has admin permissions
|
||||
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Only proxy admins can update public model groups. Your role={}".format(
|
||||
user_api_key_dict.user_role
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if litellm.public_agent_groups is None:
|
||||
litellm.public_agent_groups = []
|
||||
|
||||
for agent_id in request.agent_ids:
|
||||
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
|
||||
if agent is None:
|
||||
# check if agent exists in DB
|
||||
agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if agent is not None:
|
||||
agent = AgentResponse(**agent.model_dump()) # type: ignore
|
||||
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
|
||||
litellm.public_agent_groups = request.agent_ids
|
||||
|
||||
# Update config with new settings
|
||||
if "litellm_settings" not in config or config["litellm_settings"] is None:
|
||||
config["litellm_settings"] = {}
|
||||
|
||||
config["litellm_settings"]["public_agent_groups"] = litellm.public_agent_groups
|
||||
|
||||
# Save the updated config
|
||||
await proxy_config.save_config(new_config=config)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Updated public agent groups to: {litellm.public_agent_groups} by user: {user_api_key_dict.user_id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Successfully updated public agent groups",
|
||||
"public_agent_groups": litellm.public_agent_groups,
|
||||
"updated_by": user_api_key_dict.user_id,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error making agent public: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agent/daily/activity",
|
||||
tags=["Agent Management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
async def get_agent_daily_activity(
|
||||
agent_ids: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
exclude_agent_ids: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Get daily activity for specific agents or all accessible agents.
|
||||
"""
|
||||
await check_feature_access_for_user(user_api_key_dict, "agents")
|
||||
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
agent_ids_list = agent_ids.split(",") if agent_ids else None
|
||||
exclude_agent_ids_list: Optional[List[str]] = None
|
||||
if exclude_agent_ids:
|
||||
exclude_agent_ids_list = (
|
||||
exclude_agent_ids.split(",") if exclude_agent_ids else None
|
||||
)
|
||||
|
||||
where_condition = {}
|
||||
if agent_ids_list:
|
||||
where_condition["agent_id"] = {"in": list(agent_ids_list)}
|
||||
|
||||
agent_records = await prisma_client.db.litellm_agentstable.find_many(
|
||||
where=where_condition
|
||||
)
|
||||
agent_metadata = {
|
||||
agent.agent_id: {"agent_name": agent.agent_name} for agent in agent_records
|
||||
}
|
||||
|
||||
return await get_daily_activity(
|
||||
prisma_client=prisma_client,
|
||||
table_name="litellm_dailyagentspend",
|
||||
entity_id_field="agent_id",
|
||||
entity_id=agent_ids_list,
|
||||
entity_metadata_field=agent_metadata,
|
||||
exclude_entity_ids=exclude_agent_ids_list,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Helper functions for appending A2A agents to model lists.
|
||||
|
||||
Used by proxy model endpoints to make agents appear in UI alongside models.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
|
||||
ModelGroupInfoProxy,
|
||||
)
|
||||
|
||||
|
||||
async def append_agents_to_model_group(
|
||||
model_groups: List[ModelGroupInfoProxy],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
) -> List[ModelGroupInfoProxy]:
|
||||
"""
|
||||
Append A2A agents to model groups list for UI display.
|
||||
|
||||
Converts agents to model format with "a2a/<agent-name>" naming
|
||||
so they appear in playground and work with LiteLLM routing.
|
||||
"""
|
||||
try:
|
||||
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
|
||||
from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
|
||||
AgentRequestHandler,
|
||||
)
|
||||
|
||||
allowed_agent_ids = await AgentRequestHandler.get_allowed_agents(
|
||||
user_api_key_auth=user_api_key_dict
|
||||
)
|
||||
|
||||
for agent_id in allowed_agent_ids:
|
||||
agent = global_agent_registry.get_agent_by_id(agent_id)
|
||||
if agent is not None:
|
||||
model_groups.append(
|
||||
ModelGroupInfoProxy(
|
||||
model_group=f"a2a/{agent.agent_name}",
|
||||
mode="chat",
|
||||
providers=["a2a"],
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error appending agents to model_group/info: {e}")
|
||||
|
||||
return model_groups
|
||||
|
||||
|
||||
async def append_agents_to_model_info(
|
||||
models: List[dict],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Append A2A agents to model info list for UI display.
|
||||
|
||||
Converts agents to model format with "a2a/<agent-name>" naming
|
||||
so they appear in models page and work with LiteLLM routing.
|
||||
"""
|
||||
try:
|
||||
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
|
||||
from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
|
||||
AgentRequestHandler,
|
||||
)
|
||||
|
||||
allowed_agent_ids = await AgentRequestHandler.get_allowed_agents(
|
||||
user_api_key_auth=user_api_key_dict
|
||||
)
|
||||
|
||||
for agent_id in allowed_agent_ids:
|
||||
agent = global_agent_registry.get_agent_by_id(agent_id)
|
||||
if agent is not None:
|
||||
models.append(
|
||||
{
|
||||
"model_name": f"a2a/{agent.agent_name}",
|
||||
"litellm_params": {
|
||||
"model": f"a2a/{agent.agent_name}",
|
||||
"custom_llm_provider": "a2a",
|
||||
},
|
||||
"model_info": {
|
||||
"id": agent.agent_id,
|
||||
"mode": "chat",
|
||||
"db_model": True,
|
||||
"created_by": agent.created_by,
|
||||
"created_at": agent.created_at,
|
||||
"updated_at": agent.updated_at,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error appending agents to v2/model/info: {e}")
|
||||
|
||||
return models
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Utility helpers for A2A agent endpoints."""
|
||||
|
||||
from typing import Dict, Mapping, Optional
|
||||
|
||||
|
||||
def merge_agent_headers(
|
||||
*,
|
||||
dynamic_headers: Optional[Mapping[str, str]] = None,
|
||||
static_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""Merge outbound HTTP headers for A2A agent calls.
|
||||
|
||||
Merge rules:
|
||||
- Start with ``dynamic_headers`` (values extracted from the incoming client request).
|
||||
- Overlay ``static_headers`` (admin-configured per agent).
|
||||
|
||||
If both contain the same key, ``static_headers`` wins.
|
||||
"""
|
||||
merged: Dict[str, str] = {}
|
||||
|
||||
if dynamic_headers:
|
||||
merged.update({str(k): str(v) for k, v in dynamic_headers.items()})
|
||||
|
||||
if static_headers:
|
||||
merged.update({str(k): str(v) for k, v in static_headers.items()})
|
||||
|
||||
return merged or None
|
||||
Reference in New Issue
Block a user