chore: initial public snapshot for github upload

commit 0e5ecd930e
Author: Your Name
Date: 2026-03-26 20:06:14 +08:00
3497 changed files with 1586236 additions and 0 deletions

Submodule llm-gateway-competitors/litellm added at cd37ee1459

Submodule llm-gateway-competitors/litellm-sparse added at 58e74a631c


@@ -0,0 +1,26 @@
Portions of this software are licensed as follows:
* All content that resides under the "enterprise/" directory of this repository, if that directory exists, is licensed under the license defined in "enterprise/LICENSE".
* Content outside of the above mentioned directories or restrictions above is available under the MIT license as defined below.
---
MIT License
Copyright (c) 2023 Berri AI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,555 @@
Metadata-Version: 2.1
Name: litellm
Version: 1.82.2
Summary: Library to easily interface with LLM API providers
License: MIT
Author: BerriAI
Requires-Python: >=3.9,<4.0
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Provides-Extra: caching
Provides-Extra: extra-proxy
Provides-Extra: google
Provides-Extra: grpc
Provides-Extra: mlflow
Provides-Extra: proxy
Provides-Extra: semantic-router
Provides-Extra: utils
Requires-Dist: PyJWT (>=2.10.1,<3.0.0) ; (python_version >= "3.9") and (extra == "proxy")
Requires-Dist: a2a-sdk (>=0.3.22,<0.4.0) ; (python_version >= "3.10") and (extra == "extra-proxy")
Requires-Dist: aiohttp (>=3.10)
Requires-Dist: apscheduler (>=3.10.4,<4.0.0) ; extra == "proxy"
Requires-Dist: azure-identity (>=1.15.0,<2.0.0) ; (python_version >= "3.9") and (extra == "proxy" or extra == "extra-proxy")
Requires-Dist: azure-keyvault-secrets (>=4.8.0,<5.0.0) ; extra == "extra-proxy"
Requires-Dist: azure-storage-blob (>=12.25.1,<13.0.0) ; extra == "proxy"
Requires-Dist: backoff ; extra == "proxy"
Requires-Dist: boto3 (>=1.40.76,<2.0.0) ; extra == "proxy"
Requires-Dist: click
Requires-Dist: cryptography ; extra == "proxy"
Requires-Dist: diskcache (>=5.6.1,<6.0.0) ; extra == "caching"
Requires-Dist: fastapi (>=0.120.1) ; extra == "proxy"
Requires-Dist: fastapi-sso (>=0.16.0,<0.17.0) ; extra == "proxy"
Requires-Dist: fastuuid (>=0.13.0)
Requires-Dist: google-cloud-aiplatform (>=1.38.0) ; extra == "google"
Requires-Dist: google-cloud-iam (>=2.19.1,<3.0.0) ; extra == "extra-proxy"
Requires-Dist: google-cloud-kms (>=2.21.3,<3.0.0) ; extra == "extra-proxy"
Requires-Dist: grpcio (>=1.62.3,!=1.68.*,!=1.69.*,!=1.70.*,!=1.71.0,!=1.71.1,!=1.72.0,!=1.72.1,!=1.73.0) ; (python_version < "3.14") and (extra == "grpc")
Requires-Dist: grpcio (>=1.75.0) ; (python_version >= "3.14") and (extra == "grpc")
Requires-Dist: gunicorn (>=23.0.0,<24.0.0) ; extra == "proxy"
Requires-Dist: httpx (>=0.23.0)
Requires-Dist: importlib-metadata (>=6.8.0)
Requires-Dist: jinja2 (>=3.1.2,<4.0.0)
Requires-Dist: jsonschema (>=4.23.0,<5.0.0)
Requires-Dist: litellm-enterprise (>=0.1.33,<0.2.0) ; extra == "proxy"
Requires-Dist: litellm-proxy-extras (>=0.4.56,<0.5.0) ; extra == "proxy"
Requires-Dist: mcp (>=1.25.0,<2.0.0) ; (python_version >= "3.10") and (extra == "proxy")
Requires-Dist: mlflow (>3.1.4) ; (python_version >= "3.10") and (extra == "mlflow")
Requires-Dist: numpydoc ; extra == "utils"
Requires-Dist: openai (>=2.8.0)
Requires-Dist: orjson (>=3.9.7,<4.0.0) ; extra == "proxy"
Requires-Dist: polars (>=1.31.0,<2.0.0) ; (python_version >= "3.10") and (extra == "proxy")
Requires-Dist: prisma (>=0.11.0,<0.12.0) ; extra == "extra-proxy"
Requires-Dist: pydantic (>=2.5.0,<3.0.0)
Requires-Dist: pynacl (>=1.5.0,<2.0.0) ; extra == "proxy"
Requires-Dist: pyroscope-io (>=0.8,<0.9) ; (sys_platform != "win32") and (extra == "proxy")
Requires-Dist: python-dotenv (>=0.2.0)
Requires-Dist: python-multipart (>=0.0.20) ; extra == "proxy"
Requires-Dist: pyyaml (>=6.0.1,<7.0.0) ; extra == "proxy"
Requires-Dist: redisvl (>=0.4.1,<0.5.0) ; (python_version >= "3.9" and python_version < "3.14") and (extra == "extra-proxy")
Requires-Dist: resend (>=0.8.0) ; extra == "extra-proxy"
Requires-Dist: rich (>=13.7.1,<14.0.0) ; extra == "proxy"
Requires-Dist: rq ; extra == "proxy"
Requires-Dist: semantic-router (>=0.1.12) ; (python_version >= "3.9" and python_version < "3.14") and (extra == "semantic-router")
Requires-Dist: soundfile (>=0.12.1,<0.13.0) ; extra == "proxy"
Requires-Dist: tiktoken (>=0.7.0)
Requires-Dist: tokenizers
Requires-Dist: uvicorn (>=0.32.1,<1.0.0) ; extra == "proxy"
Requires-Dist: uvloop (>=0.21.0,<0.22.0) ; (sys_platform != "win32") and (extra == "proxy")
Requires-Dist: websockets (>=15.0.1,<16.0.0) ; extra == "proxy"
Project-URL: Documentation, https://docs.litellm.ai
Project-URL: Homepage, https://litellm.ai
Project-URL: Repository, https://github.com/BerriAI/litellm
Project-URL: documentation, https://docs.litellm.ai
Project-URL: homepage, https://litellm.ai
Project-URL: repository, https://github.com/BerriAI/litellm
Description-Content-Type: text/markdown
<h1 align="center">
🚅 LiteLLM
</h1>
<p align="center">
<p align="center">Call 100+ LLMs in OpenAI format. [Bedrock, Azure, OpenAI, VertexAI, Anthropic, Groq, etc.]
</p>
<p align="center">
<a href="https://render.com/deploy?repo=https://github.com/BerriAI/litellm" target="_blank" rel="nofollow"><img src="https://render.com/images/deploy-to-render-button.svg" alt="Deploy to Render"></a>
<a href="https://railway.app/template/HLP0Ub?referralCode=jch2ME">
<img src="https://railway.app/button.svg" alt="Deploy on Railway">
</a>
</p>
</p>
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">LiteLLM Proxy Server (AI Gateway)</a> | <a href="https://docs.litellm.ai/docs/enterprise#hosted-litellm-proxy" target="_blank"> Hosted Proxy</a> | <a href="https://docs.litellm.ai/docs/enterprise" target="_blank">Enterprise Tier</a></h4>
<h4 align="center">
<a href="https://pypi.org/project/litellm/" target="_blank">
<img src="https://img.shields.io/pypi/v/litellm.svg" alt="PyPI Version">
</a>
<a href="https://www.ycombinator.com/companies/berriai">
<img src="https://img.shields.io/badge/Y%20Combinator-W23-orange?style=flat-square" alt="Y Combinator W23">
</a>
<a href="https://wa.link/huol9n">
<img src="https://img.shields.io/static/v1?label=Chat%20on&message=WhatsApp&color=success&logo=WhatsApp&style=flat-square" alt="Whatsapp">
</a>
<a href="https://discord.gg/wuPM9dRgDw">
<img src="https://img.shields.io/static/v1?label=Chat%20on&message=Discord&color=blue&logo=Discord&style=flat-square" alt="Discord">
</a>
<a href="https://www.litellm.ai/support">
<img src="https://img.shields.io/static/v1?label=Chat%20on&message=Slack&color=black&logo=Slack&style=flat-square" alt="Slack">
</a>
</h4>
<img width="2688" height="1600" alt="Group 7154 (1)" src="https://github.com/user-attachments/assets/c5ee0412-6fb5-4fb6-ab5b-bafae4209ca6" />
## Use LiteLLM for
<details open>
<summary><b>LLMs</b> - Call 100+ LLMs (Python SDK + AI Gateway)</summary>
[**All Supported Endpoints**](https://docs.litellm.ai/docs/supported_endpoints) - `/chat/completions`, `/responses`, `/embeddings`, `/images`, `/audio`, `/batches`, `/rerank`, `/a2a`, `/messages` and more.
### Python SDK
```shell
pip install litellm
```
```python
from litellm import completion
import os
os.environ["OPENAI_API_KEY"] = "your-openai-key"
os.environ["ANTHROPIC_API_KEY"] = "your-anthropic-key"
# OpenAI
response = completion(model="openai/gpt-4o", messages=[{"role": "user", "content": "Hello!"}])
# Anthropic
response = completion(model="anthropic/claude-sonnet-4-20250514", messages=[{"role": "user", "content": "Hello!"}])
```
### AI Gateway (Proxy Server)
[**Getting Started - E2E Tutorial**](https://docs.litellm.ai/docs/proxy/docker_quick_start) - Setup virtual keys, make your first request
```shell
pip install 'litellm[proxy]'
litellm --model gpt-4o
```
```python
import openai
client = openai.OpenAI(api_key="anything", base_url="http://0.0.0.0:4000")
response = client.chat.completions.create(
model="gpt-4o",
messages=[{"role": "user", "content": "Hello!"}]
)
```
[**Docs: LLM Providers**](https://docs.litellm.ai/docs/providers)
</details>
<details>
<summary><b>Agents</b> - Invoke A2A Agents (Python SDK + AI Gateway)</summary>
[**Supported Providers**](https://docs.litellm.ai/docs/a2a#add-a2a-agents) - LangGraph, Vertex AI Agent Engine, Azure AI Foundry, Bedrock AgentCore, Pydantic AI
### Python SDK - A2A Protocol
```python
from litellm.a2a_protocol import A2AClient
from a2a.types import SendMessageRequest, MessageSendParams
from uuid import uuid4
client = A2AClient(base_url="http://localhost:10001")
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": uuid4().hex,
}
)
)
response = await client.send_message(request)
```
### AI Gateway (Proxy Server)
**Step 1.** [Add your Agent to the AI Gateway](https://docs.litellm.ai/docs/a2a#adding-your-agent)
**Step 2.** Call Agent via A2A SDK
```python
from a2a.client import A2ACardResolver, A2AClient
from a2a.types import MessageSendParams, SendMessageRequest
from uuid import uuid4
import httpx
base_url = "http://localhost:4000/a2a/my-agent" # LiteLLM proxy + agent name
headers = {"Authorization": "Bearer sk-1234"} # LiteLLM Virtual Key
async with httpx.AsyncClient(headers=headers) as httpx_client:
resolver = A2ACardResolver(httpx_client=httpx_client, base_url=base_url)
agent_card = await resolver.get_agent_card()
client = A2AClient(httpx_client=httpx_client, agent_card=agent_card)
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": uuid4().hex,
}
)
)
response = await client.send_message(request)
```
[**Docs: A2A Agent Gateway**](https://docs.litellm.ai/docs/a2a)
</details>
<details>
<summary><b>MCP Tools</b> - Connect MCP servers to any LLM (Python SDK + AI Gateway)</summary>
### Python SDK - MCP Bridge
```python
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from litellm import experimental_mcp_client
import litellm
server_params = StdioServerParameters(command="python", args=["mcp_server.py"])
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()
# Load MCP tools in OpenAI format
tools = await experimental_mcp_client.load_mcp_tools(session=session, format="openai")
# Use with any LiteLLM model
response = await litellm.acompletion(
model="gpt-4o",
messages=[{"role": "user", "content": "What's 3 + 5?"}],
tools=tools
)
```
### AI Gateway - MCP Gateway
**Step 1.** [Add your MCP Server to the AI Gateway](https://docs.litellm.ai/docs/mcp#adding-your-mcp)
**Step 2.** Call MCP tools via `/chat/completions`
```bash
curl -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Summarize the latest open PR"}],
"tools": [{
"type": "mcp",
"server_url": "litellm_proxy/mcp/github",
"server_label": "github_mcp",
"require_approval": "never"
}]
}'
```
### Use with Cursor IDE
```json
{
"mcpServers": {
"LiteLLM": {
"url": "http://localhost:4000/mcp/",
"headers": {
"x-litellm-api-key": "Bearer sk-1234"
}
}
}
}
```
[**Docs: MCP Gateway**](https://docs.litellm.ai/docs/mcp)
</details>
---
## How to use LiteLLM
You can use LiteLLM through either the Proxy Server or the Python SDK. Both give you a unified interface to 100+ LLMs. Choose the option that best fits your needs:
<table style={{width: '100%', tableLayout: 'fixed'}}>
<thead>
<tr>
<th style={{width: '14%'}}></th>
<th style={{width: '43%'}}><strong><a href="https://docs.litellm.ai/docs/simple_proxy">LiteLLM AI Gateway</a></strong></th>
<th style={{width: '43%'}}><strong><a href="https://docs.litellm.ai/docs/">LiteLLM Python SDK</a></strong></th>
</tr>
</thead>
<tbody>
<tr>
<td style={{width: '14%'}}><strong>Use Case</strong></td>
<td style={{width: '43%'}}>Central service (LLM Gateway) to access multiple LLMs</td>
<td style={{width: '43%'}}>Use LiteLLM directly in your Python code</td>
</tr>
<tr>
<td style={{width: '14%'}}><strong>Who Uses It?</strong></td>
<td style={{width: '43%'}}>Gen AI Enablement / ML Platform Teams</td>
<td style={{width: '43%'}}>Developers building LLM projects</td>
</tr>
<tr>
<td style={{width: '14%'}}><strong>Key Features</strong></td>
<td style={{width: '43%'}}>Centralized API gateway with authentication and authorization, multi-tenant cost tracking and spend management per project/user, per-project customization (logging, guardrails, caching), virtual keys for secure access control, admin dashboard UI for monitoring and management</td>
<td style={{width: '43%'}}>Direct Python library integration in your codebase, Router with retry/fallback logic across multiple deployments (e.g. Azure/OpenAI) - <a href="https://docs.litellm.ai/docs/routing">Router</a>, application-level load balancing and cost tracking, exception handling with OpenAI-compatible errors, observability callbacks (Lunary, MLflow, Langfuse, etc.)</td>
</tr>
</tbody>
</table>
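The Router's retry/fallback behavior mentioned in the table above can be sketched in plain Python. This is an illustration of the concept only, not the actual `litellm.Router` API (see the Router docs for the real interface); the function and deployment names here are hypothetical.

```python
from typing import Callable, Optional, Sequence

# Illustrative retry-with-fallback dispatch, mirroring the Router concept.
# NOT the litellm.Router API -- names and signatures are invented for the sketch.
def call_with_fallbacks(
    deployments: Sequence[Callable[[str], str]],
    prompt: str,
    retries_per_deployment: int = 2,
) -> str:
    last_err: Optional[Exception] = None
    for deployment in deployments:  # try deployments in priority order
        for _ in range(retries_per_deployment):
            try:
                return deployment(prompt)
            except Exception as err:
                last_err = err  # retry, then fall through to the next deployment
    raise RuntimeError("all deployments failed") from last_err

def flaky(prompt: str) -> str:
    raise TimeoutError("primary deployment down")

def healthy(prompt: str) -> str:
    return f"echo: {prompt}"

print(call_with_fallbacks([flaky, healthy], "Hello!"))  # falls back to healthy
```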
LiteLLM Performance: **8ms P95 latency** at 1k RPS (See benchmarks [here](https://docs.litellm.ai/docs/benchmarks))
[**Jump to LiteLLM Proxy (LLM Gateway) Docs**](https://docs.litellm.ai/docs/simple_proxy) <br>
[**Jump to Supported LLM Providers**](https://docs.litellm.ai/docs/providers)
**Stable Release:** Use docker images with the `-stable` tag. These have undergone 12-hour load tests before being published. [More information about the release cycle here](https://docs.litellm.ai/docs/proxy/release_cycle)
Missing a provider or LLM platform? Raise a [feature request](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeature%5D%3A+).
## OSS Adopters
<table>
<tr>
<td><img height="60" alt="Stripe" src="https://github.com/user-attachments/assets/f7296d4f-9fbd-460d-9d05-e4df31697c4b" /></td>
<td><img height="60" alt="Google ADK" src="https://github.com/user-attachments/assets/caf270a2-5aee-45c4-8222-41a2070c4f19" /></td>
<td><img height="60" alt="Greptile" src="https://github.com/user-attachments/assets/0be4bd8a-7cfa-48d3-9090-f415fe948280" /></td>
<td><img height="60" alt="OpenHands" src="https://github.com/user-attachments/assets/a6150c4c-149e-4cae-888b-8b92be6e003f" /></td>
<td><h2>Netflix</h2></td>
<td><img height="60" alt="OpenAI Agents SDK" src="https://github.com/user-attachments/assets/c02f7be0-8c2e-4d27-aea7-7c024bfaebc0" /></td>
</tr>
</table>
## Supported Providers ([Website Supported Models](https://models.litellm.ai/) | [Docs](https://docs.litellm.ai/docs/providers))
| Provider | `/chat/completions` | `/messages` | `/responses` | `/embeddings` | `/image/generations` | `/audio/transcriptions` | `/audio/speech` | `/moderations` | `/batches` | `/rerank` |
|-------------------------------------------------------------------------------------|---------------------|-------------|--------------|---------------|----------------------|-------------------------|-----------------|----------------|-----------|-----------|
| [Abliteration (`abliteration`)](https://docs.litellm.ai/docs/providers/abliteration) | ✅ | | | | | | | | | |
| [AI/ML API (`aiml`)](https://docs.litellm.ai/docs/providers/aiml) | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
| [AI21 (`ai21`)](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | | | | | | | |
| [AI21 Chat (`ai21_chat`)](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | | | | | | | |
| [Aleph Alpha](https://docs.litellm.ai/docs/providers/aleph_alpha) | ✅ | ✅ | ✅ | | | | | | | |
| [Amazon Nova](https://docs.litellm.ai/docs/providers/amazon_nova) | ✅ | ✅ | ✅ | | | | | | | |
| [Anthropic (`anthropic`)](https://docs.litellm.ai/docs/providers/anthropic) | ✅ | ✅ | ✅ | | | | | | ✅ | |
| [Anthropic Text (`anthropic_text`)](https://docs.litellm.ai/docs/providers/anthropic) | ✅ | ✅ | ✅ | | | | | | ✅ | |
| [Anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | | | | | | | |
| [AssemblyAI (`assemblyai`)](https://docs.litellm.ai/docs/pass_through/assembly_ai) | ✅ | ✅ | ✅ | | | ✅ | | | | |
| [Auto Router (`auto_router`)](https://docs.litellm.ai/docs/proxy/auto_routing) | ✅ | ✅ | ✅ | | | | | | | |
| [AWS - Bedrock (`bedrock`)](https://docs.litellm.ai/docs/providers/bedrock) | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ |
| [AWS - Sagemaker (`sagemaker`)](https://docs.litellm.ai/docs/providers/aws_sagemaker) | ✅ | ✅ | ✅ | ✅ | | | | | | |
| [Azure (`azure`)](https://docs.litellm.ai/docs/providers/azure) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [Azure AI (`azure_ai`)](https://docs.litellm.ai/docs/providers/azure_ai) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [Azure Text (`azure_text`)](https://docs.litellm.ai/docs/providers/azure) | ✅ | ✅ | ✅ | | | ✅ | ✅ | ✅ | ✅ | |
| [Baseten (`baseten`)](https://docs.litellm.ai/docs/providers/baseten) | ✅ | ✅ | ✅ | | | | | | | |
| [Bytez (`bytez`)](https://docs.litellm.ai/docs/providers/bytez) | ✅ | ✅ | ✅ | | | | | | | |
| [Cerebras (`cerebras`)](https://docs.litellm.ai/docs/providers/cerebras) | ✅ | ✅ | ✅ | | | | | | | |
| [Clarifai (`clarifai`)](https://docs.litellm.ai/docs/providers/clarifai) | ✅ | ✅ | ✅ | | | | | | | |
| [Cloudflare AI Workers (`cloudflare`)](https://docs.litellm.ai/docs/providers/cloudflare_workers) | ✅ | ✅ | ✅ | | | | | | | |
| [Codestral (`codestral`)](https://docs.litellm.ai/docs/providers/codestral) | ✅ | ✅ | ✅ | | | | | | | |
| [Cohere (`cohere`)](https://docs.litellm.ai/docs/providers/cohere) | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ |
| [Cohere Chat (`cohere_chat`)](https://docs.litellm.ai/docs/providers/cohere) | ✅ | ✅ | ✅ | | | | | | | |
| [CometAPI (`cometapi`)](https://docs.litellm.ai/docs/providers/cometapi) | ✅ | ✅ | ✅ | ✅ | | | | | | |
| [CompactifAI (`compactifai`)](https://docs.litellm.ai/docs/providers/compactifai) | ✅ | ✅ | ✅ | | | | | | | |
| [Custom (`custom`)](https://docs.litellm.ai/docs/providers/custom_llm_server) | ✅ | ✅ | ✅ | | | | | | | |
| [Custom OpenAI (`custom_openai`)](https://docs.litellm.ai/docs/providers/openai_compatible) | ✅ | ✅ | ✅ | | | ✅ | ✅ | ✅ | ✅ | |
| [Dashscope (`dashscope`)](https://docs.litellm.ai/docs/providers/dashscope) | ✅ | ✅ | ✅ | | | | | | | |
| [Databricks (`databricks`)](https://docs.litellm.ai/docs/providers/databricks) | ✅ | ✅ | ✅ | | | | | | | |
| [DataRobot (`datarobot`)](https://docs.litellm.ai/docs/providers/datarobot) | ✅ | ✅ | ✅ | | | | | | | |
| [Deepgram (`deepgram`)](https://docs.litellm.ai/docs/providers/deepgram) | ✅ | ✅ | ✅ | | | ✅ | | | | |
| [DeepInfra (`deepinfra`)](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | | | | | | | |
| [Deepseek (`deepseek`)](https://docs.litellm.ai/docs/providers/deepseek) | ✅ | ✅ | ✅ | | | | | | | |
| [ElevenLabs (`elevenlabs`)](https://docs.litellm.ai/docs/providers/elevenlabs) | ✅ | ✅ | ✅ | | | ✅ | ✅ | | | |
| [Empower (`empower`)](https://docs.litellm.ai/docs/providers/empower) | ✅ | ✅ | ✅ | | | | | | | |
| [Fal AI (`fal_ai`)](https://docs.litellm.ai/docs/providers/fal_ai) | ✅ | ✅ | ✅ | | ✅ | | | | | |
| [Featherless AI (`featherless_ai`)](https://docs.litellm.ai/docs/providers/featherless_ai) | ✅ | ✅ | ✅ | | | | | | | |
| [Fireworks AI (`fireworks_ai`)](https://docs.litellm.ai/docs/providers/fireworks_ai) | ✅ | ✅ | ✅ | | | | | | | |
| [FriendliAI (`friendliai`)](https://docs.litellm.ai/docs/providers/friendliai) | ✅ | ✅ | ✅ | | | | | | | |
| [Galadriel (`galadriel`)](https://docs.litellm.ai/docs/providers/galadriel) | ✅ | ✅ | ✅ | | | | | | | |
| [GitHub Copilot (`github_copilot`)](https://docs.litellm.ai/docs/providers/github_copilot) | ✅ | ✅ | ✅ | ✅ | | | | | | |
| [GitHub Models (`github`)](https://docs.litellm.ai/docs/providers/github) | ✅ | ✅ | ✅ | | | | | | | |
| [Google - PaLM](https://docs.litellm.ai/docs/providers/palm) | ✅ | ✅ | ✅ | | | | | | | |
| [Google - Vertex AI (`vertex_ai`)](https://docs.litellm.ai/docs/providers/vertex) | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
| [Google AI Studio - Gemini (`gemini`)](https://docs.litellm.ai/docs/providers/gemini) | ✅ | ✅ | ✅ | | | | | | | |
| [GradientAI (`gradient_ai`)](https://docs.litellm.ai/docs/providers/gradient_ai) | ✅ | ✅ | ✅ | | | | | | | |
| [Groq AI (`groq`)](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | | | | | | | |
| [Heroku (`heroku`)](https://docs.litellm.ai/docs/providers/heroku) | ✅ | ✅ | ✅ | | | | | | | |
| [Hosted VLLM (`hosted_vllm`)](https://docs.litellm.ai/docs/providers/vllm) | ✅ | ✅ | ✅ | | | | | | | |
| [Huggingface (`huggingface`)](https://docs.litellm.ai/docs/providers/huggingface) | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ |
| [Hyperbolic (`hyperbolic`)](https://docs.litellm.ai/docs/providers/hyperbolic) | ✅ | ✅ | ✅ | | | | | | | |
| [IBM - Watsonx.ai (`watsonx`)](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | | | | | | |
| [Infinity (`infinity`)](https://docs.litellm.ai/docs/providers/infinity) | | | | ✅ | | | | | | |
| [Jina AI (`jina_ai`)](https://docs.litellm.ai/docs/providers/jina_ai) | | | | ✅ | | | | | | |
| [Lambda AI (`lambda_ai`)](https://docs.litellm.ai/docs/providers/lambda_ai) | ✅ | ✅ | ✅ | | | | | | | |
| [Lemonade (`lemonade`)](https://docs.litellm.ai/docs/providers/lemonade) | ✅ | ✅ | ✅ | | | | | | | |
| [LiteLLM Proxy (`litellm_proxy`)](https://docs.litellm.ai/docs/providers/litellm_proxy) | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
| [Llamafile (`llamafile`)](https://docs.litellm.ai/docs/providers/llamafile) | ✅ | ✅ | ✅ | | | | | | | |
| [LM Studio (`lm_studio`)](https://docs.litellm.ai/docs/providers/lm_studio) | ✅ | ✅ | ✅ | | | | | | | |
| [Maritalk (`maritalk`)](https://docs.litellm.ai/docs/providers/maritalk) | ✅ | ✅ | ✅ | | | | | | | |
| [Meta - Llama API (`meta_llama`)](https://docs.litellm.ai/docs/providers/meta_llama) | ✅ | ✅ | ✅ | | | | | | | |
| [Mistral AI API (`mistral`)](https://docs.litellm.ai/docs/providers/mistral) | ✅ | ✅ | ✅ | ✅ | | | | | | |
| [Moonshot (`moonshot`)](https://docs.litellm.ai/docs/providers/moonshot) | ✅ | ✅ | ✅ | | | | | | | |
| [Morph (`morph`)](https://docs.litellm.ai/docs/providers/morph) | ✅ | ✅ | ✅ | | | | | | | |
| [Nebius AI Studio (`nebius`)](https://docs.litellm.ai/docs/providers/nebius) | ✅ | ✅ | ✅ | ✅ | | | | | | |
| [NLP Cloud (`nlp_cloud`)](https://docs.litellm.ai/docs/providers/nlp_cloud) | ✅ | ✅ | ✅ | | | | | | | |
| [Novita AI (`novita`)](https://novita.ai/models/llm?utm_source=github_litellm&utm_medium=github_readme&utm_campaign=github_link) | ✅ | ✅ | ✅ | | | | | | | |
| [Nscale (`nscale`)](https://docs.litellm.ai/docs/providers/nscale) | ✅ | ✅ | ✅ | | | | | | | |
| [Nvidia NIM (`nvidia_nim`)](https://docs.litellm.ai/docs/providers/nvidia_nim) | ✅ | ✅ | ✅ | | | | | | | |
| [OCI (`oci`)](https://docs.litellm.ai/docs/providers/oci) | ✅ | ✅ | ✅ | | | | | | | |
| [Ollama (`ollama`)](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ | | | | | | |
| [Ollama Chat (`ollama_chat`)](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | | | | | | | |
| [Oobabooga (`oobabooga`)](https://docs.litellm.ai/docs/providers/openai_compatible) | ✅ | ✅ | ✅ | | | ✅ | ✅ | ✅ | ✅ | |
| [OpenAI (`openai`)](https://docs.litellm.ai/docs/providers/openai) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [OpenAI-like (`openai_like`)](https://docs.litellm.ai/docs/providers/openai_compatible) | | | | ✅ | | | | | | |
| [OpenRouter (`openrouter`)](https://docs.litellm.ai/docs/providers/openrouter) | ✅ | ✅ | ✅ | | | | | | | |
| [OVHCloud AI Endpoints (`ovhcloud`)](https://docs.litellm.ai/docs/providers/ovhcloud) | ✅ | ✅ | ✅ | | | | | | | |
| [Perplexity AI (`perplexity`)](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | | | | | | | |
| [Petals (`petals`)](https://docs.litellm.ai/docs/providers/petals) | ✅ | ✅ | ✅ | | | | | | | |
| [Predibase (`predibase`)](https://docs.litellm.ai/docs/providers/predibase) | ✅ | ✅ | ✅ | | | | | | | |
| [Recraft (`recraft`)](https://docs.litellm.ai/docs/providers/recraft) | | | | | ✅ | | | | | |
| [Replicate (`replicate`)](https://docs.litellm.ai/docs/providers/replicate) | ✅ | ✅ | ✅ | | | | | | | |
| [Sagemaker Chat (`sagemaker_chat`)](https://docs.litellm.ai/docs/providers/aws_sagemaker) | ✅ | ✅ | ✅ | | | | | | | |
| [Sambanova (`sambanova`)](https://docs.litellm.ai/docs/providers/sambanova) | ✅ | ✅ | ✅ | | | | | | | |
| [Snowflake (`snowflake`)](https://docs.litellm.ai/docs/providers/snowflake) | ✅ | ✅ | ✅ | | | | | | | |
| [Text Completion Codestral (`text-completion-codestral`)](https://docs.litellm.ai/docs/providers/codestral) | ✅ | ✅ | ✅ | | | | | | | |
| [Text Completion OpenAI (`text-completion-openai`)](https://docs.litellm.ai/docs/providers/text_completion_openai) | ✅ | ✅ | ✅ | | | ✅ | ✅ | ✅ | ✅ | |
| [Together AI (`together_ai`)](https://docs.litellm.ai/docs/providers/togetherai) | ✅ | ✅ | ✅ | | | | | | | |
| [Topaz (`topaz`)](https://docs.litellm.ai/docs/providers/topaz) | ✅ | ✅ | ✅ | | | | | | | |
| [Triton (`triton`)](https://docs.litellm.ai/docs/providers/triton-inference-server) | ✅ | ✅ | ✅ | | | | | | | |
| [V0 (`v0`)](https://docs.litellm.ai/docs/providers/v0) | ✅ | ✅ | ✅ | | | | | | | |
| [Vercel AI Gateway (`vercel_ai_gateway`)](https://docs.litellm.ai/docs/providers/vercel_ai_gateway) | ✅ | ✅ | ✅ | | | | | | | |
| [VLLM (`vllm`)](https://docs.litellm.ai/docs/providers/vllm) | ✅ | ✅ | ✅ | | | | | | | |
| [Volcengine (`volcengine`)](https://docs.litellm.ai/docs/providers/volcano) | ✅ | ✅ | ✅ | | | | | | | |
| [Voyage AI (`voyage`)](https://docs.litellm.ai/docs/providers/voyage) | | | | ✅ | | | | | | |
| [WandB Inference (`wandb`)](https://docs.litellm.ai/docs/providers/wandb_inference) | ✅ | ✅ | ✅ | | | | | | | |
| [Watsonx Text (`watsonx_text`)](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | | | | | | | |
| [xAI (`xai`)](https://docs.litellm.ai/docs/providers/xai) | ✅ | ✅ | ✅ | | | | | | | |
| [Xinference (`xinference`)](https://docs.litellm.ai/docs/providers/xinference) | | | | ✅ | | | | | | |
[**Read the Docs**](https://docs.litellm.ai/docs/)
## Run in Developer mode
### Services
1. Setup .env file in root
2. Run dependent services `docker-compose up db prometheus`
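For step 1, a minimal `.env` might look like the following. The variable names shown are assumptions based on common proxy setups; check the LiteLLM docs for the authoritative list for your deployment.

```shell
# Illustrative .env for local development (variable names are assumptions)
DATABASE_URL="postgresql://llmproxy:dbpassword@localhost:5432/litellm"
LITELLM_MASTER_KEY="sk-1234"
```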
### Backend
1. (In root) create virtual environment `python -m venv .venv`
2. Activate virtual environment `source .venv/bin/activate`
3. Install dependencies `pip install -e ".[all]"`
4. `pip install prisma`
5. `prisma generate`
6. Start proxy backend `python litellm/proxy/proxy_cli.py`
### Frontend
1. Navigate to `ui/litellm-dashboard`
2. Install dependencies `npm install`
3. Run `npm run dev` to start the dashboard
# Enterprise
For companies that need better security, user management, and professional support.
[Talk to founders](https://calendly.com/d/cx9p-5yf-2nm/litellm-introductions)
This covers:
- ✅ **Features under the [LiteLLM Commercial License](https://docs.litellm.ai/docs/proxy/enterprise):**
- ✅ **Feature Prioritization**
- ✅ **Custom Integrations**
- ✅ **Professional Support - Dedicated discord + slack**
- ✅ **Custom SLAs**
- ✅ **Secure access with Single Sign-On**
# Contributing
We welcome contributions to LiteLLM! Whether you're fixing bugs, adding features, or improving documentation, we appreciate your help.
## Quick Start for Contributors
This requires Poetry to be installed.
```bash
git clone https://github.com/BerriAI/litellm.git
cd litellm
make install-dev # Install development dependencies
make format # Format your code
make lint # Run all linting checks
make test-unit # Run unit tests
make format-check # Check formatting only
```
For detailed contributing guidelines, see [CONTRIBUTING.md](CONTRIBUTING.md).
## Code Quality / Linting
LiteLLM follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html).
Our automated checks include:
- **Black** for code formatting
- **Ruff** for linting and code quality
- **MyPy** for type checking
- **Circular import detection**
- **Import safety checks**
All these checks must pass before your PR can be merged.
# Support / talk with founders
- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
- [Community Slack 💭](https://www.litellm.ai/support)
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai
# Why did we build this
- **Need for simplicity**: Our code started to get extremely complicated managing & translating calls between Azure, OpenAI and Cohere.
# Contributors
<!-- ALL-CONTRIBUTORS-LIST:START - Do not remove or modify this section -->
<!-- prettier-ignore-start -->
<!-- markdownlint-disable -->
<!-- markdownlint-restore -->
<!-- prettier-ignore-end -->
<!-- ALL-CONTRIBUTORS-LIST:END -->
<a href="https://github.com/BerriAI/litellm/graphs/contributors">
<img src="https://contrib.rocks/image?repo=BerriAI/litellm" />
</a>


@@ -0,0 +1,4 @@
Wheel-Version: 1.0
Generator: poetry-core 1.9.1
Root-Is-Purelib: true
Tag: py3-none-any


@@ -0,0 +1,4 @@
[console_scripts]
litellm=litellm:run_server
litellm-proxy=litellm.proxy.client.cli:cli


@@ -0,0 +1,449 @@
"""
Lazy Import System
This module implements lazy loading for LiteLLM attributes. Instead of importing
everything when the module loads, we only import things when they're actually used.
How it works:
1. When someone accesses `litellm.some_attribute`, Python calls __getattr__ in __init__.py
2. __getattr__ looks up the attribute name in a registry
3. The registry points to a handler function (like _lazy_import_utils)
4. The handler function imports the module and returns the attribute
5. The result is cached so we don't import it again
This makes importing litellm much faster because we don't load heavy dependencies
until they're actually needed.
"""
import importlib
import sys
from typing import Any, Optional, cast, Callable
# Import all the data structures that define what can be lazy-loaded
# These are just lists of names and maps of where to find them
from ._lazy_imports_registry import (
# Name tuples
COST_CALCULATOR_NAMES,
LITELLM_LOGGING_NAMES,
UTILS_NAMES,
TOKEN_COUNTER_NAMES,
LLM_CLIENT_CACHE_NAMES,
BEDROCK_TYPES_NAMES,
TYPES_UTILS_NAMES,
CACHING_NAMES,
HTTP_HANDLER_NAMES,
DOTPROMPT_NAMES,
LLM_CONFIG_NAMES,
TYPES_NAMES,
LLM_PROVIDER_LOGIC_NAMES,
UTILS_MODULE_NAMES,
# Import maps
_UTILS_IMPORT_MAP,
_COST_CALCULATOR_IMPORT_MAP,
_TYPES_UTILS_IMPORT_MAP,
_TOKEN_COUNTER_IMPORT_MAP,
_BEDROCK_TYPES_IMPORT_MAP,
_CACHING_IMPORT_MAP,
_LITELLM_LOGGING_IMPORT_MAP,
_DOTPROMPT_IMPORT_MAP,
_TYPES_IMPORT_MAP,
_LLM_CONFIGS_IMPORT_MAP,
_LLM_PROVIDER_LOGIC_IMPORT_MAP,
_UTILS_MODULE_IMPORT_MAP,
)
def _get_litellm_globals() -> dict:
"""
Get the globals dictionary of the litellm module.
This is where we cache imported attributes so we don't import them twice.
When you do `litellm.some_function`, it gets stored in this dictionary.
"""
return sys.modules["litellm"].__dict__
def _get_utils_globals() -> dict:
"""
Get the globals dictionary of the utils module.
This is where we cache imported attributes so we don't import them twice.
When you do `litellm.utils.some_function`, it gets stored in this dictionary.
"""
return sys.modules["litellm.utils"].__dict__
# These are special lazy loaders for things that are used internally
# They're separate from the main lazy import system because they have specific use cases
# Lazy loader for default encoding - avoids importing heavy tiktoken library at startup
_default_encoding: Optional[Any] = None
def _get_default_encoding() -> Any:
"""
Lazily load and cache the default OpenAI encoding.
This avoids importing `litellm.litellm_core_utils.default_encoding` (and thus tiktoken)
at `litellm` import time. The encoding is cached after the first import.
This is used internally by utils.py functions that need the encoding but shouldn't
trigger its import during module load.
"""
global _default_encoding
if _default_encoding is None:
from litellm.litellm_core_utils.default_encoding import encoding
_default_encoding = encoding
return _default_encoding
# Lazy loader for get_modified_max_tokens to avoid importing token_counter at module import time
_get_modified_max_tokens_func: Optional[Any] = None
def _get_modified_max_tokens() -> Any:
"""
Lazily load and cache the get_modified_max_tokens function.
This avoids importing `litellm.litellm_core_utils.token_counter` at `litellm` import time.
The function is cached after the first import.
This is used internally by utils.py functions that need the token counter but shouldn't
trigger its import during module load.
"""
global _get_modified_max_tokens_func
if _get_modified_max_tokens_func is None:
from litellm.litellm_core_utils.token_counter import (
get_modified_max_tokens as _get_modified_max_tokens_imported,
)
_get_modified_max_tokens_func = _get_modified_max_tokens_imported
return _get_modified_max_tokens_func
# Lazy loader for token_counter to avoid importing token_counter module at module import time
_token_counter_new_func: Optional[Any] = None
def _get_token_counter_new() -> Any:
"""
Lazily load and cache the token_counter function (aliased as token_counter_new).
This avoids importing `litellm.litellm_core_utils.token_counter` at `litellm` import time.
The function is cached after the first import.
This is used internally by utils.py functions that need the token counter but shouldn't
trigger its import during module load.
"""
global _token_counter_new_func
if _token_counter_new_func is None:
from litellm.litellm_core_utils.token_counter import (
token_counter as _token_counter_imported,
)
_token_counter_new_func = _token_counter_imported
return _token_counter_new_func
# ============================================================================
# MAIN LAZY IMPORT SYSTEM
# ============================================================================
# This registry maps attribute names (like "ModelResponse") to handler functions
# It's built once the first time someone accesses a lazy-loaded attribute
# Example: {"ModelResponse": _lazy_import_utils, "Cache": _lazy_import_caching, ...}
_LAZY_IMPORT_REGISTRY: Optional[dict[str, Callable[[str], Any]]] = None
def _get_lazy_import_registry() -> dict[str, Callable[[str], Any]]:
"""
Build the registry that maps attribute names to their handler functions.
This is called once, the first time someone accesses a lazy-loaded attribute.
After that, we just look up the handler function in this dictionary.
Returns:
Dictionary like {"ModelResponse": _lazy_import_utils, ...}
"""
global _LAZY_IMPORT_REGISTRY
if _LAZY_IMPORT_REGISTRY is None:
# Build the registry by going through each category and mapping
# all the names in that category to their handler function
_LAZY_IMPORT_REGISTRY = {}
# For each category, map all its names to the handler function
# Example: All names in UTILS_NAMES get mapped to _lazy_import_utils
for name in COST_CALCULATOR_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_cost_calculator
for name in LITELLM_LOGGING_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_litellm_logging
for name in UTILS_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_utils
for name in TOKEN_COUNTER_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_token_counter
for name in LLM_CLIENT_CACHE_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_llm_client_cache
for name in BEDROCK_TYPES_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_bedrock_types
for name in TYPES_UTILS_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_types_utils
for name in CACHING_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_caching
for name in HTTP_HANDLER_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_http_handlers
for name in DOTPROMPT_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_dotprompt
for name in LLM_CONFIG_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_llm_configs
for name in TYPES_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_types
for name in LLM_PROVIDER_LOGIC_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_llm_provider_logic
for name in UTILS_MODULE_NAMES:
_LAZY_IMPORT_REGISTRY[name] = _lazy_import_utils_module
return _LAZY_IMPORT_REGISTRY
def _generic_lazy_import(
name: str, import_map: dict[str, tuple[str, str]], category: str
) -> Any:
"""
Generic function that handles lazy importing for most attributes.
This is the workhorse function - it does the actual importing and caching.
Most handler functions just call this with their specific import map.
Steps:
1. Check if the name exists in the import map (if not, raise error)
2. Check if we've already imported it (if yes, return cached value)
3. Look up where to find it (module_path and attr_name from the map)
4. Import the module (Python caches this automatically)
5. Get the attribute from the module
6. Cache it in _globals so we don't import again
7. Return it
Args:
name: The attribute name someone is trying to access (e.g., "ModelResponse")
import_map: Dictionary telling us where to find each attribute
Format: {"ModelResponse": (".utils", "ModelResponse")}
category: Just for error messages (e.g., "Utils", "Cost calculator")
"""
# Step 1: Make sure this attribute exists in our map
if name not in import_map:
raise AttributeError(f"{category} lazy import: unknown attribute {name!r}")
# Step 2: Get the cache (where we store imported things)
_globals = _get_litellm_globals()
# Step 3: If we've already imported it, just return the cached version
if name in _globals:
return _globals[name]
# Step 4: Look up where to find this attribute
# The map tells us: (module_path, attribute_name)
# Example: (".utils", "ModelResponse") means "look in .utils module, get ModelResponse"
module_path, attr_name = import_map[name]
# Step 5: Import the module
# Python automatically caches modules in sys.modules, so calling this twice is fast
# If module_path starts with ".", it's a relative import (needs package="litellm")
# Otherwise it's an absolute import (like "litellm.caching.caching")
if module_path.startswith("."):
module = importlib.import_module(module_path, package="litellm")
else:
module = importlib.import_module(module_path)
# Step 6: Get the actual attribute from the module
# Example: getattr(utils_module, "ModelResponse") returns the ModelResponse class
value = getattr(module, attr_name)
# Step 7: Cache it so we don't have to import again next time
_globals[name] = value
# Step 8: Return it
return value
# ============================================================================
# HANDLER FUNCTIONS
# ============================================================================
# These functions are called when someone accesses a lazy-loaded attribute.
# Most of them just call _generic_lazy_import with their specific import map.
# The registry (above) maps attribute names to these handler functions.
def _lazy_import_utils(name: str) -> Any:
"""Handler for utils module attributes (ModelResponse, token_counter, etc.)"""
return _generic_lazy_import(name, _UTILS_IMPORT_MAP, "Utils")
def _lazy_import_cost_calculator(name: str) -> Any:
"""Handler for cost calculator functions (completion_cost, cost_per_token, etc.)"""
return _generic_lazy_import(name, _COST_CALCULATOR_IMPORT_MAP, "Cost calculator")
def _lazy_import_token_counter(name: str) -> Any:
"""Handler for token counter utilities"""
return _generic_lazy_import(name, _TOKEN_COUNTER_IMPORT_MAP, "Token counter")
def _lazy_import_bedrock_types(name: str) -> Any:
"""Handler for Bedrock type aliases"""
return _generic_lazy_import(name, _BEDROCK_TYPES_IMPORT_MAP, "Bedrock types")
def _lazy_import_types_utils(name: str) -> Any:
"""Handler for types from litellm.types.utils (BudgetConfig, ImageObject, etc.)"""
return _generic_lazy_import(name, _TYPES_UTILS_IMPORT_MAP, "Types utils")
def _lazy_import_caching(name: str) -> Any:
"""Handler for caching classes (Cache, DualCache, RedisCache, etc.)"""
return _generic_lazy_import(name, _CACHING_IMPORT_MAP, "Caching")
def _lazy_import_dotprompt(name: str) -> Any:
"""Handler for dotprompt integration globals"""
return _generic_lazy_import(name, _DOTPROMPT_IMPORT_MAP, "Dotprompt")
def _lazy_import_types(name: str) -> Any:
"""Handler for type classes (GuardrailItem, etc.)"""
return _generic_lazy_import(name, _TYPES_IMPORT_MAP, "Types")
def _lazy_import_llm_configs(name: str) -> Any:
"""Handler for LLM config classes (AnthropicConfig, OpenAILikeChatConfig, etc.)"""
return _generic_lazy_import(name, _LLM_CONFIGS_IMPORT_MAP, "LLM config")
def _lazy_import_litellm_logging(name: str) -> Any:
"""Handler for litellm_logging module (Logging, modify_integration)"""
return _generic_lazy_import(name, _LITELLM_LOGGING_IMPORT_MAP, "Litellm logging")
def _lazy_import_llm_provider_logic(name: str) -> Any:
"""Handler for LLM provider logic functions (get_llm_provider, etc.)"""
return _generic_lazy_import(
name, _LLM_PROVIDER_LOGIC_IMPORT_MAP, "LLM provider logic"
)
def _lazy_import_utils_module(name: str) -> Any:
"""
Handler for utils module lazy imports.
This uses a custom implementation because utils module needs to use
_get_utils_globals() instead of _get_litellm_globals() for caching.
"""
# Check if this attribute exists in our map
if name not in _UTILS_MODULE_IMPORT_MAP:
raise AttributeError(f"Utils module lazy import: unknown attribute {name!r}")
# Get the cache (where we store imported things) - use utils globals
_globals = _get_utils_globals()
# If we've already imported it, just return the cached version
if name in _globals:
return _globals[name]
# Look up where to find this attribute
module_path, attr_name = _UTILS_MODULE_IMPORT_MAP[name]
# Import the module
if module_path.startswith("."):
module = importlib.import_module(module_path, package="litellm")
else:
module = importlib.import_module(module_path)
# Get the actual attribute from the module
value = getattr(module, attr_name)
# Cache it so we don't have to import again next time
_globals[name] = value
# Return it
return value
# ============================================================================
# SPECIAL HANDLERS
# ============================================================================
# These handlers have custom logic that doesn't fit the generic pattern
def _lazy_import_llm_client_cache(name: str) -> Any:
"""
Handler for LLM client cache - has special logic for singleton instance.
This one is different because:
- "LLMClientCache" is the class itself
- "in_memory_llm_clients_cache" is a singleton instance of that class
So we need custom logic to handle both cases.
"""
_globals = _get_litellm_globals()
# If already cached, return it
if name in _globals:
return _globals[name]
# Import the class
module = importlib.import_module("litellm.caching.llm_caching_handler")
LLMClientCache = getattr(module, "LLMClientCache")
# If they want the class itself, return it
if name == "LLMClientCache":
_globals["LLMClientCache"] = LLMClientCache
return LLMClientCache
# If they want the singleton instance, create it (only once)
if name == "in_memory_llm_clients_cache":
instance = LLMClientCache()
_globals["in_memory_llm_clients_cache"] = instance
return instance
raise AttributeError(f"LLM client cache lazy import: unknown attribute {name!r}")
def _lazy_import_http_handlers(name: str) -> Any:
"""
Handler for HTTP clients - has special logic for creating client instances.
This one is different because:
- These aren't just imports, they're actual client instances that need to be created
- They need configuration (timeout, etc.) from the module globals
- They use factory functions instead of direct instantiation
"""
_globals = _get_litellm_globals()
if name == "module_level_aclient":
# Create an async HTTP client using the factory function
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
# Get timeout from module config (if set)
timeout = _globals.get("request_timeout")
params = {"timeout": timeout, "client_alias": "module level aclient"}
# Create the client instance
provider_id = cast(Any, "litellm_module_level_client")
async_client = get_async_httpx_client(
llm_provider=provider_id,
params=params,
)
# Cache it so we don't create it again
_globals["module_level_aclient"] = async_client
return async_client
if name == "module_level_client":
# Create a sync HTTP client
from litellm.llms.custom_httpx.http_handler import HTTPHandler
timeout = _globals.get("request_timeout")
sync_client = HTTPHandler(timeout=timeout)
# Cache it
_globals["module_level_client"] = sync_client
return sync_client
raise AttributeError(f"HTTP handlers lazy import: unknown attribute {name!r}")

File diff suppressed because it is too large


@@ -0,0 +1,352 @@
import ast
import logging
import os
import sys
from datetime import datetime
from logging import Formatter
from typing import Any, Dict, Optional
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
set_verbose = False
if set_verbose is True:
logging.warning(
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
)
json_logs = bool(os.getenv("JSON_LOGS", False))
# Create a handler for the logger (you may need to adapt this based on your needs)
log_level = os.getenv("LITELLM_LOG", "DEBUG")
numeric_level: int = getattr(logging, log_level.upper())
handler = logging.StreamHandler()
handler.setLevel(numeric_level)
def _try_parse_json_message(message: str) -> Optional[Dict[str, Any]]:
"""
Try to parse a log message as JSON. Returns parsed dict if valid, else None.
Handles messages that are entirely valid JSON (e.g. json.dumps output).
Uses shared safe_json_loads for consistent error handling.
"""
if not message or not isinstance(message, str):
return None
msg_stripped = message.strip()
if not (msg_stripped.startswith("{") or msg_stripped.startswith("[")):
return None
parsed = safe_json_loads(message, default=None)
if parsed is None or not isinstance(parsed, dict):
return None
return parsed
def _try_parse_embedded_python_dict(message: str) -> Optional[Dict[str, Any]]:
"""
Try to find and parse a Python dict repr (e.g. str(d) or repr(d)) embedded in
the message. Handles patterns like:
"get_available_deployment for model: X, Selected deployment: {'model_name': '...', ...} for model: X"
Uses ast.literal_eval for safe parsing. Returns the parsed dict or None.
"""
if not message or not isinstance(message, str) or "{" not in message:
return None
i = 0
while i < len(message):
start = message.find("{", i)
if start == -1:
break
depth = 0
for j in range(start, len(message)):
c = message[j]
if c == "{":
depth += 1
elif c == "}":
depth -= 1
if depth == 0:
substr = message[start : j + 1]
try:
result = ast.literal_eval(substr)
if isinstance(result, dict) and len(result) > 0:
return result
except (ValueError, SyntaxError, TypeError):
pass
break
i = start + 1
return None
# Standard LogRecord attribute names - used to identify 'extra' fields.
# Derived at runtime so we automatically include version-specific attrs (e.g. taskName).
def _get_standard_record_attrs() -> frozenset:
"""Standard LogRecord attribute names - excludes extra keys from logger.debug(..., extra={...})."""
return frozenset(logging.LogRecord("", 0, "", 0, "", (), None).__dict__.keys())
_STANDARD_RECORD_ATTRS = _get_standard_record_attrs()
class JsonFormatter(Formatter):
def __init__(self):
super(JsonFormatter, self).__init__()
def formatTime(self, record, datefmt=None):
# Use datetime to format the timestamp in ISO 8601 format
dt = datetime.fromtimestamp(record.created)
return dt.isoformat()
def format(self, record):
message_str = record.getMessage()
json_record: Dict[str, Any] = {
"message": message_str,
"level": record.levelname,
"timestamp": self.formatTime(record),
}
# Parse embedded JSON or Python dict repr in message so sub-fields become first-class properties
parsed = _try_parse_json_message(message_str)
if parsed is None:
parsed = _try_parse_embedded_python_dict(message_str)
if parsed is not None:
for key, value in parsed.items():
if key not in json_record:
json_record[key] = value
# Include extra attributes passed via logger.debug("msg", extra={...})
for key, value in record.__dict__.items():
if key not in _STANDARD_RECORD_ATTRS and key not in json_record:
json_record[key] = value
if record.exc_info:
json_record["stacktrace"] = self.formatException(record.exc_info)
return safe_dumps(json_record)
# Function to set up exception handlers for JSON logging
def _setup_json_exception_handlers(formatter):
# Create a handler with JSON formatting for exceptions
error_handler = logging.StreamHandler()
error_handler.setFormatter(formatter)
# Setup excepthook for uncaught exceptions
def json_excepthook(exc_type, exc_value, exc_traceback):
record = logging.LogRecord(
name="LiteLLM",
level=logging.ERROR,
pathname="",
lineno=0,
msg=str(exc_value),
args=(),
exc_info=(exc_type, exc_value, exc_traceback),
)
error_handler.handle(record)
sys.excepthook = json_excepthook
# Configure asyncio exception handler if possible
try:
import asyncio
def async_json_exception_handler(loop, context):
exception = context.get("exception")
if exception:
record = logging.LogRecord(
name="LiteLLM",
level=logging.ERROR,
pathname="",
lineno=0,
msg=str(exception),
args=(),
exc_info=None,
)
error_handler.handle(record)
else:
loop.default_exception_handler(context)
asyncio.get_event_loop().set_exception_handler(async_json_exception_handler)
except Exception:
pass
# Create a formatter and set it for the handler
if json_logs:
handler.setFormatter(JsonFormatter())
_setup_json_exception_handlers(JsonFormatter())
else:
formatter = logging.Formatter(
"\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
datefmt="%H:%M:%S",
)
handler.setFormatter(formatter)
verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
verbose_router_logger = logging.getLogger("LiteLLM Router")
verbose_logger = logging.getLogger("LiteLLM")
# Add the handler to the logger
verbose_router_logger.addHandler(handler)
verbose_proxy_logger.addHandler(handler)
verbose_logger.addHandler(handler)
def _suppress_loggers():
"""Suppress noisy loggers at INFO level"""
# Suppress httpx request logging at INFO level
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)
# Suppress APScheduler logging at INFO level
apscheduler_executors_logger = logging.getLogger("apscheduler.executors.default")
apscheduler_executors_logger.setLevel(logging.WARNING)
apscheduler_scheduler_logger = logging.getLogger("apscheduler.scheduler")
apscheduler_scheduler_logger.setLevel(logging.WARNING)
# Call the suppression function
_suppress_loggers()
ALL_LOGGERS = [
logging.getLogger(),
verbose_logger,
verbose_router_logger,
verbose_proxy_logger,
]
def _get_loggers_to_initialize():
"""
Get all loggers that should be initialized with the JSON handler.
Includes third-party integration loggers (like langfuse) if they are
configured as callbacks.
"""
import litellm
loggers = list(ALL_LOGGERS)
# Add langfuse logger if langfuse is being used as a callback
langfuse_callbacks = {"langfuse", "langfuse_otel"}
all_callbacks = set(litellm.success_callback + litellm.failure_callback)
if langfuse_callbacks & all_callbacks:
loggers.append(logging.getLogger("langfuse"))
return loggers
def _initialize_loggers_with_handler(handler: logging.Handler):
"""
Initialize all loggers with a handler
- Adds a handler to each logger
- Prevents bubbling to parent/root (critical to prevent duplicate JSON logs)
"""
for lg in _get_loggers_to_initialize():
lg.handlers.clear() # remove any existing handlers
lg.addHandler(handler) # add JSON formatter handler
lg.propagate = False # prevent bubbling to parent/root
def _get_uvicorn_json_log_config():
"""
Generate a uvicorn log_config dictionary that applies JSON formatting to all loggers.
This ensures that uvicorn's access logs, error logs, and all application logs
are formatted as JSON when json_logs is enabled.
"""
json_formatter_class = "litellm._logging.JsonFormatter"
# Use the module-level log_level variable for consistency
uvicorn_log_level = log_level.upper()
log_config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"json": {
"()": json_formatter_class,
},
"default": {
"()": json_formatter_class,
},
"access": {
"()": json_formatter_class,
},
},
"handlers": {
"default": {
"formatter": "json",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
},
"access": {
"formatter": "access",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
},
},
"loggers": {
"uvicorn": {
"handlers": ["default"],
"level": uvicorn_log_level,
"propagate": False,
},
"uvicorn.error": {
"handlers": ["default"],
"level": uvicorn_log_level,
"propagate": False,
},
"uvicorn.access": {
"handlers": ["access"],
"level": uvicorn_log_level,
"propagate": False,
},
},
}
return log_config
def _turn_on_json():
"""
Turn on JSON logging
- Adds a JSON formatter to all loggers
"""
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
_initialize_loggers_with_handler(handler)
# Set up exception handlers
_setup_json_exception_handlers(JsonFormatter())
def _turn_on_debug():
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
def _disable_debugging():
verbose_logger.disabled = True
verbose_router_logger.disabled = True
verbose_proxy_logger.disabled = True
def _enable_debugging():
verbose_logger.disabled = False
verbose_router_logger.disabled = False
verbose_proxy_logger.disabled = False
def print_verbose(print_statement):
try:
if set_verbose:
print(print_statement) # noqa
except Exception:
pass
def _is_debugging_on() -> bool:
"""
Returns True if debugging is on
"""
return verbose_logger.isEnabledFor(logging.DEBUG) or set_verbose is True
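The key trick in `JsonFormatter.format` is diffing `record.__dict__` against a baseline `LogRecord` so that keys passed via `extra={...}` surface as first-class JSON properties. A stripped-down, self-contained version of just that behavior (litellm's formatter additionally parses embedded JSON/dict reprs in the message and attaches stack traces):

```python
import json
import logging

# Baseline attrs of a fresh LogRecord; anything beyond these came from extra={...}
_STANDARD = frozenset(logging.LogRecord("", 0, "", 0, "", (), None).__dict__.keys())

class MiniJsonFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        out = {"message": record.getMessage(), "level": record.levelname}
        # Keys passed via logger.info("msg", extra={...}) land on the record object
        for key, value in record.__dict__.items():
            if key not in _STANDARD and key not in out:
                out[key] = value
        return json.dumps(out)

logger = logging.getLogger("demo")
logger.propagate = False  # avoid duplicate output via the root logger
handler = logging.StreamHandler()
handler.setFormatter(MiniJsonFormatter())
logger.addHandler(handler)
logger.warning("hello", extra={"request_id": "abc"})
# emits: {"message": "hello", "level": "WARNING", "request_id": "abc"}
```

Computing `_STANDARD` at runtime (rather than hard-coding the attribute list) is what keeps this robust across Python versions, e.g. 3.12's added `taskName` attribute.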


@@ -0,0 +1,598 @@
# +-----------------------------------------------+
# | |
# | Give Feedback / Get Help |
# | https://github.com/BerriAI/litellm/issues/new |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import inspect
import json
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
from typing import Callable, List, Optional, Union
import redis # type: ignore
import redis.asyncio as async_redis # type: ignore
from litellm import get_secret, get_secret_str
from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
from ._logging import verbose_logger
def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis)
# Only allow primitive arguments
exclude_args = {
"self",
"connection_pool",
"retry",
}
include_args = [
"url",
"redis_connect_func",
"gcp_service_account",
"gcp_ssl_ca_certs",
]
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
return available_args
def _get_redis_url_kwargs(client=None):
    if client is None:
        client = redis.Redis.from_url
    arg_spec = inspect.getfullargspec(client)
# Only allow primitive arguments
exclude_args = {
"self",
"connection_pool",
"retry",
}
include_args = ["url"]
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
return available_args
def _get_redis_cluster_kwargs(client=None):
    if client is None:
        client = redis.RedisCluster
    arg_spec = inspect.getfullargspec(client)
# Only allow primitive arguments
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
available_args = [x for x in arg_spec.args if x not in exclude_args]
available_args.append("password")
available_args.append("username")
available_args.append("ssl")
available_args.append("ssl_cert_reqs")
available_args.append("ssl_check_hostname")
available_args.append("ssl_ca_certs")
available_args.append(
"redis_connect_func"
) # Needed for sync clusters and IAM detection
available_args.append("gcp_service_account")
available_args.append("gcp_ssl_ca_certs")
available_args.append("max_connections")
return available_args
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
def _redis_kwargs_from_environment():
mapping = _get_redis_env_kwarg_mapping()
return_dict = {}
for k, v in mapping.items():
value = get_secret(k, default_value=None) # type: ignore
if value is not None:
return_dict[v] = value
return return_dict
def _generate_gcp_iam_access_token(service_account: str) -> str:
"""
Generate GCP IAM access token for Redis authentication.
Args:
service_account: GCP service account in format 'projects/-/serviceAccounts/name@project.iam.gserviceaccount.com'
Returns:
Access token string for GCP IAM authentication
"""
try:
from google.cloud import iam_credentials_v1
except ImportError:
raise ImportError(
"google-cloud-iam is required for GCP IAM Redis authentication. "
"Install it with: pip install google-cloud-iam"
)
client = iam_credentials_v1.IAMCredentialsClient()
request = iam_credentials_v1.GenerateAccessTokenRequest(
name=service_account,
scope=["https://www.googleapis.com/auth/cloud-platform"],
)
response = client.generate_access_token(request=request)
return str(response.access_token)
def create_gcp_iam_redis_connect_func(
service_account: str,
ssl_ca_certs: Optional[str] = None,
) -> Callable:
"""
Creates a custom Redis connection function for GCP IAM authentication.
Args:
service_account: GCP service account in format 'projects/-/serviceAccounts/name@project.iam.gserviceaccount.com'
ssl_ca_certs: Path to SSL CA certificate file for secure connections
Returns:
A connection function that can be used with Redis clients
"""
def iam_connect(self):
"""Initialize the connection and authenticate using GCP IAM"""
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
)
from redis.utils import str_if_bytes
self._parser.on_connect(self)
auth_args = (_generate_gcp_iam_access_token(service_account),)
self.send_command("AUTH", *auth_args, check_health=False)
try:
auth_response = self.read_response()
except AuthenticationWrongNumberOfArgsError:
# Fallback to password auth if IAM fails
if hasattr(self, "password") and self.password:
self.send_command("AUTH", self.password, check_health=False)
auth_response = self.read_response()
else:
raise
if str_if_bytes(auth_response) != "OK":
raise AuthenticationError("GCP IAM authentication failed")
return iam_connect
def get_redis_url_from_environment():
if "REDIS_URL" in os.environ:
return os.environ["REDIS_URL"]
if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
raise ValueError(
"Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
)
if "REDIS_SSL" in os.environ and os.environ["REDIS_SSL"].lower() == "true":
redis_protocol = "rediss"
else:
redis_protocol = "redis"
# Build authentication part of URL
auth_part = ""
if "REDIS_USERNAME" in os.environ and "REDIS_PASSWORD" in os.environ:
auth_part = f"{os.environ['REDIS_USERNAME']}:{os.environ['REDIS_PASSWORD']}@"
elif "REDIS_PASSWORD" in os.environ:
auth_part = f"{os.environ['REDIS_PASSWORD']}@"
return f"{redis_protocol}://{auth_part}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
def _get_redis_client_logic(**env_overrides):
"""
Common functionality across sync + async redis client implementations
"""
### check if "os.environ/<key-name>" passed in
for k, v in env_overrides.items():
if isinstance(v, str) and v.startswith("os.environ/"):
v = v.replace("os.environ/", "")
value = get_secret(v) # type: ignore
env_overrides[k] = value
redis_kwargs = {
**_redis_kwargs_from_environment(),
**env_overrides,
}
_startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret( # type: ignore
"REDIS_CLUSTER_NODES"
)
if _startup_nodes is not None and isinstance(_startup_nodes, str):
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
_sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore
"REDIS_SENTINEL_NODES"
)
if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)
_sentinel_password: Optional[str] = redis_kwargs.get(
"sentinel_password", None
) or get_secret_str("REDIS_SENTINEL_PASSWORD")
if _sentinel_password is not None:
redis_kwargs["sentinel_password"] = _sentinel_password
_service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret( # type: ignore
"REDIS_SERVICE_NAME"
)
if _service_name is not None:
redis_kwargs["service_name"] = _service_name
# Handle GCP IAM authentication
_gcp_service_account = redis_kwargs.get("gcp_service_account") or get_secret_str(
"REDIS_GCP_SERVICE_ACCOUNT"
)
_gcp_ssl_ca_certs = redis_kwargs.get("gcp_ssl_ca_certs") or get_secret_str(
"REDIS_GCP_SSL_CA_CERTS"
)
if _gcp_service_account is not None:
verbose_logger.debug(
"Setting up GCP IAM authentication for Redis with service account."
)
redis_kwargs["redis_connect_func"] = create_gcp_iam_redis_connect_func(
service_account=_gcp_service_account, ssl_ca_certs=_gcp_ssl_ca_certs
)
# Store GCP service account in redis_connect_func for async cluster access
redis_kwargs["redis_connect_func"]._gcp_service_account = _gcp_service_account
# Remove GCP-specific kwargs that shouldn't be passed to Redis client
redis_kwargs.pop("gcp_service_account", None)
redis_kwargs.pop("gcp_ssl_ca_certs", None)
# Only enable SSL if explicitly requested AND SSL CA certs are provided
if _gcp_ssl_ca_certs and redis_kwargs.get("ssl", False):
redis_kwargs["ssl_ca_certs"] = _gcp_ssl_ca_certs
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop("host", None)
redis_kwargs.pop("port", None)
redis_kwargs.pop("db", None)
redis_kwargs.pop("password", None)
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
pass
elif (
"sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
):
pass
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
_redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES") # type: ignore
if _redis_cluster_nodes_in_env is not None:
try:
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
except json.JSONDecodeError:
raise ValueError(
"REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
)
verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
cluster_kwargs.pop("startup_nodes", None)
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) # type: ignore
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
sentinel_password = redis_kwargs.get("sentinel_password")
service_name = redis_kwargs.get("service_name")
if not sentinel_nodes or not service_name:
raise ValueError(
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
)
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
# Set up the Sentinel client
sentinel = redis.Sentinel(
sentinel_nodes,
socket_timeout=REDIS_SOCKET_TIMEOUT,
password=sentinel_password,
)
# Return the master instance for the given service
return sentinel.master_for(service_name)
def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
sentinel_password = redis_kwargs.get("sentinel_password")
service_name = redis_kwargs.get("service_name")
if not sentinel_nodes or not service_name:
raise ValueError(
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
)
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
# Set up the Sentinel client
sentinel = async_redis.Sentinel(
sentinel_nodes,
socket_timeout=REDIS_SOCKET_TIMEOUT,
password=sentinel_password,
)
# Return the master instance for the given service
return sentinel.master_for(service_name)
def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
args = _get_redis_url_kwargs()
url_kwargs = {}
for arg in redis_kwargs:
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
return redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore
return init_redis_cluster(redis_kwargs)
# Check for Redis Sentinel
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
return _init_redis_sentinel(redis_kwargs)
return redis.Redis(**redis_kwargs)
def get_redis_async_client(
connection_pool: Optional[async_redis.BlockingConnectionPool] = None,
**env_overrides,
) -> Union[async_redis.Redis, async_redis.RedisCluster]:
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
if connection_pool is not None:
return async_redis.Redis(connection_pool=connection_pool)
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
url_kwargs = {}
for arg in redis_kwargs:
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
else:
verbose_logger.debug(
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
arg
)
)
return async_redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
# Handle GCP IAM authentication for async clusters
redis_connect_func = cluster_kwargs.pop("redis_connect_func", None)
from litellm import get_secret_str
# Get GCP service account - first try from redis_connect_func, then from environment
gcp_service_account = None
if redis_connect_func and hasattr(redis_connect_func, "_gcp_service_account"):
gcp_service_account = redis_connect_func._gcp_service_account
else:
gcp_service_account = redis_kwargs.get(
"gcp_service_account"
) or get_secret_str("REDIS_GCP_SERVICE_ACCOUNT")
verbose_logger.debug(
f"DEBUG: Redis cluster kwargs: redis_connect_func={redis_connect_func is not None}, gcp_service_account_provided={gcp_service_account is not None}"
)
# If GCP IAM is configured (indicated by redis_connect_func), generate access token and use as password
if redis_connect_func and gcp_service_account:
verbose_logger.debug(
"DEBUG: Generating IAM token for service account (value not logged for security reasons)"
)
try:
# Generate IAM access token using the helper function
access_token = _generate_gcp_iam_access_token(gcp_service_account)
cluster_kwargs["password"] = access_token
verbose_logger.debug(
"DEBUG: Successfully generated GCP IAM access token for async Redis cluster"
)
except Exception as e:
verbose_logger.error(f"Failed to generate GCP IAM access token: {e}")
from redis.exceptions import AuthenticationError
raise AuthenticationError("Failed to generate GCP IAM access token") from e
else:
verbose_logger.debug(
f"DEBUG: Not using GCP IAM auth - redis_connect_func={redis_connect_func is not None}, gcp_service_account_provided={gcp_service_account is not None}"
)
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
cluster_kwargs.pop("startup_nodes", None)
# Create async RedisCluster with IAM token as password if available
cluster_client = async_redis.RedisCluster(
startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
)
return cluster_client
# Check for Redis Sentinel
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
return _init_async_redis_sentinel(redis_kwargs)
_pretty_print_redis_config(redis_kwargs=redis_kwargs)
if connection_pool is not None:
redis_kwargs["connection_pool"] = connection_pool
return async_redis.Redis(
**redis_kwargs,
)
def get_redis_connection_pool(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
verbose_logger.debug("get_redis_connection_pool: redis_kwargs: %s", redis_kwargs)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
pool_kwargs = {
"timeout": REDIS_CONNECTION_POOL_TIMEOUT,
"url": redis_kwargs["url"],
}
if "max_connections" in redis_kwargs:
try:
pool_kwargs["max_connections"] = int(redis_kwargs["max_connections"])
except (TypeError, ValueError):
verbose_logger.warning(
"REDIS: invalid max_connections value %r, ignoring",
redis_kwargs["max_connections"],
)
return async_redis.BlockingConnectionPool.from_url(**pool_kwargs)
connection_class = async_redis.Connection
if "ssl" in redis_kwargs:
connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class
redis_kwargs.pop("startup_nodes", None)
return async_redis.BlockingConnectionPool(
timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
)
def _pretty_print_redis_config(redis_kwargs: dict) -> None:
"""Pretty print the Redis configuration using rich with sensitive data masking"""
try:
import logging
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
if not verbose_logger.isEnabledFor(logging.DEBUG):
return
console = Console()
# Initialize the sensitive data masker
masker = SensitiveDataMasker()
# Mask sensitive data in redis_kwargs
masked_redis_kwargs = masker.mask_dict(redis_kwargs)
# Create main panel title
title = Text("Redis Configuration", style="bold blue")
# Create configuration table
config_table = Table(
title="🔧 Redis Connection Parameters",
show_header=True,
header_style="bold magenta",
title_justify="left",
)
config_table.add_column("Parameter", style="cyan", no_wrap=True)
config_table.add_column("Value", style="yellow")
# Add rows for each configuration parameter
for key, value in masked_redis_kwargs.items():
if value is not None:
# Special handling for complex objects
if isinstance(value, list):
if key == "startup_nodes" and value:
# Special handling for cluster nodes
value_str = f"[{len(value)} cluster nodes]"
elif key == "sentinel_nodes" and value:
# Special handling for sentinel nodes
value_str = f"[{len(value)} sentinel nodes]"
else:
value_str = str(value)
else:
value_str = str(value)
config_table.add_row(key, value_str)
# Determine connection type
connection_type = "Standard Redis"
if masked_redis_kwargs.get("startup_nodes"):
connection_type = "Redis Cluster"
elif masked_redis_kwargs.get("sentinel_nodes"):
connection_type = "Redis Sentinel"
elif masked_redis_kwargs.get("url"):
connection_type = "Redis (URL-based)"
# Create connection type info
info_table = Table(
title="📊 Connection Info",
show_header=True,
header_style="bold green",
title_justify="left",
)
info_table.add_column("Property", style="cyan", no_wrap=True)
info_table.add_column("Value", style="yellow")
info_table.add_row("Connection Type", connection_type)
# Print everything in a nice panel
console.print("\n")
console.print(Panel(title, border_style="blue"))
console.print(info_table)
console.print(config_table)
console.print("\n")
except ImportError:
# Fallback to simple logging if rich is not available
masker = SensitiveDataMasker()
masked_redis_kwargs = masker.mask_dict(redis_kwargs)
verbose_logger.info(f"Redis configuration: {masked_redis_kwargs}")
except Exception as e:
verbose_logger.error(f"Error pretty printing Redis configuration: {e}")
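The client factories above repeatedly filter `redis_kwargs` down to only the arguments the target constructor accepts (via `_get_redis_url_kwargs` / `_get_redis_cluster_kwargs`). A minimal, self-contained sketch of that filtering pattern, with hypothetical names and no Redis server required:

```python
import inspect

def filter_kwargs_for(func, raw_kwargs):
    """Keep only the kwargs that appear in func's signature (illustrative sketch)."""
    allowed = set(inspect.signature(func).parameters)
    return {k: v for k, v in raw_kwargs.items() if k in allowed}

def connect(host="localhost", port=6379, db=0):
    # Stand-in for redis.Redis / redis.RedisCluster constructors.
    return (host, port, db)

filtered = filter_kwargs_for(connect, {"host": "redis", "port": 6380, "url": "ignored"})
print(filtered)  # {'host': 'redis', 'port': 6380}
```

Unknown keys (like `url` here) are dropped rather than raising a `TypeError` in the constructor, which is why the real code logs and ignores disallowed arguments.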

@@ -0,0 +1,323 @@
import asyncio
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Optional, Union
import litellm
from litellm._logging import verbose_logger
from .integrations.custom_logger import CustomLogger
from .integrations.datadog.datadog import DataDogLogger
from .integrations.opentelemetry import OpenTelemetry
from .integrations.prometheus_services import PrometheusServicesLogger
from .types.services import ServiceLoggerPayload, ServiceTypes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy._types import UserAPIKeyAuth
Span = Union[_Span, Any]
OTELClass = OpenTelemetry
else:
Span = Any
OTELClass = Any
UserAPIKeyAuth = Any
class ServiceLogging(CustomLogger):
"""
Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
"""
def __init__(self, mock_testing: bool = False) -> None:
self.mock_testing = mock_testing
self.mock_testing_sync_success_hook = 0
self.mock_testing_async_success_hook = 0
self.mock_testing_sync_failure_hook = 0
self.mock_testing_async_failure_hook = 0
if "prometheus_system" in litellm.service_callback:
self.prometheusServicesLogger = PrometheusServicesLogger()
def service_success_hook(
self,
service: ServiceTypes,
duration: float,
call_type: str,
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[float, datetime]] = None,
):
"""
Handles both sync and async monitoring by checking for existing event loop.
"""
if self.mock_testing:
self.mock_testing_sync_success_hook += 1
try:
# Try to get the current event loop
loop = asyncio.get_event_loop()
# Check if the loop is running
if loop.is_running():
# If we're in a running loop, create a task
loop.create_task(
self.async_service_success_hook(
service=service,
duration=duration,
call_type=call_type,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
)
)
else:
# Loop exists but not running, we can use run_until_complete
loop.run_until_complete(
self.async_service_success_hook(
service=service,
duration=duration,
call_type=call_type,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
)
)
except RuntimeError:
# No event loop exists, create a new one and run
asyncio.run(
self.async_service_success_hook(
service=service,
duration=duration,
call_type=call_type,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
)
)
def service_failure_hook(
self, service: ServiceTypes, duration: float, error: Exception, call_type: str
):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
"""
if self.mock_testing:
self.mock_testing_sync_failure_hook += 1
async def async_service_success_hook(
self,
service: ServiceTypes,
call_type: str,
duration: float,
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[datetime, float]] = None,
event_metadata: Optional[dict] = None,
):
"""
- For counting if the redis, postgres call is successful
"""
if self.mock_testing:
self.mock_testing_async_success_hook += 1
payload = ServiceLoggerPayload(
is_error=False,
error=None,
service=service,
duration=duration,
call_type=call_type,
event_metadata=event_metadata,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
await self.init_prometheus_services_logger_if_none()
await self.prometheusServicesLogger.async_service_success_hook(
payload=payload
)
elif callback == "datadog" or isinstance(callback, DataDogLogger):
await self.init_datadog_logger_if_none()
await self.dd_logger.async_service_success_hook(
payload=payload,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
event_metadata=event_metadata,
)
elif callback == "otel" or isinstance(callback, OpenTelemetry):
_otel_logger_to_use: Optional[OpenTelemetry] = None
if isinstance(callback, OpenTelemetry):
_otel_logger_to_use = callback
else:
from litellm.proxy.proxy_server import open_telemetry_logger
if open_telemetry_logger is not None and isinstance(
open_telemetry_logger, OpenTelemetry
):
_otel_logger_to_use = open_telemetry_logger
if _otel_logger_to_use is not None and parent_otel_span is not None:
await _otel_logger_to_use.async_service_success_hook(
payload=payload,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
event_metadata=event_metadata,
)
async def init_prometheus_services_logger_if_none(self):
"""
initializes prometheusServicesLogger if it is None or no attribute exists on ServiceLogging Object
"""
if not hasattr(self, "prometheusServicesLogger"):
self.prometheusServicesLogger = PrometheusServicesLogger()
elif self.prometheusServicesLogger is None:
self.prometheusServicesLogger = PrometheusServicesLogger()
return
async def init_datadog_logger_if_none(self):
"""
initializes dd_logger if it is None or no attribute exists on ServiceLogging Object
"""
from litellm.integrations.datadog.datadog import DataDogLogger
if not hasattr(self, "dd_logger"):
self.dd_logger: DataDogLogger = DataDogLogger()
return
async def init_otel_logger_if_none(self):
"""
initializes otel_logger if it is None or no attribute exists on ServiceLogging Object
"""
from litellm.proxy.proxy_server import open_telemetry_logger
if not hasattr(self, "otel_logger"):
if open_telemetry_logger is not None and isinstance(
open_telemetry_logger, OpenTelemetry
):
self.otel_logger: OpenTelemetry = open_telemetry_logger
else:
verbose_logger.warning(
"ServiceLogger: open_telemetry_logger is None or not an instance of OpenTelemetry"
)
return
async def async_service_failure_hook(
self,
service: ServiceTypes,
duration: float,
error: Union[str, Exception],
call_type: str,
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[float, datetime]] = None,
event_metadata: Optional[dict] = None,
):
"""
- For counting if the redis, postgres call is unsuccessful
"""
if self.mock_testing:
self.mock_testing_async_failure_hook += 1
error_message = ""
if isinstance(error, Exception):
error_message = str(error)
elif isinstance(error, str):
error_message = error
payload = ServiceLoggerPayload(
is_error=True,
error=error_message,
service=service,
duration=duration,
call_type=call_type,
event_metadata=event_metadata,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
await self.init_prometheus_services_logger_if_none()
await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload,
error=error,
)
elif callback == "datadog" or isinstance(callback, DataDogLogger):
await self.init_datadog_logger_if_none()
await self.dd_logger.async_service_failure_hook(
payload=payload,
error=error_message,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
event_metadata=event_metadata,
)
elif callback == "otel" or isinstance(callback, OpenTelemetry):
_otel_logger_to_use: Optional[OpenTelemetry] = None
if isinstance(callback, OpenTelemetry):
_otel_logger_to_use = callback
else:
from litellm.proxy.proxy_server import open_telemetry_logger
if open_telemetry_logger is not None and isinstance(
open_telemetry_logger, OpenTelemetry
):
_otel_logger_to_use = open_telemetry_logger
if not isinstance(error, str):
error = str(error)
if _otel_logger_to_use is not None and parent_otel_span is not None:
await _otel_logger_to_use.async_service_failure_hook(
payload=payload,
error=error,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
event_metadata=event_metadata,
)
async def async_post_call_failure_hook(
self,
request_data: dict,
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
traceback_str: Optional[str] = None,
):
"""
Hook to track failed litellm-service calls
"""
return await super().async_post_call_failure_hook(
request_data,
original_exception,
user_api_key_dict,
)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Hook to track latency for litellm proxy llm api calls
"""
try:
_duration = end_time - start_time
if isinstance(_duration, timedelta):
_duration = _duration.total_seconds()
elif isinstance(_duration, float):
pass
else:
raise Exception(
"Duration={} is not a float or timedelta object. type={}".format(
_duration, type(_duration)
)
) # invalid _duration value
# Batch polling callbacks (check_batch_cost) don't include call_type in kwargs.
# Use .get() to avoid KeyError.
await self.async_service_success_hook(
service=ServiceTypes.LITELLM,
duration=_duration,
call_type=kwargs.get("call_type", "unknown"),
)
except Exception as e:
raise e
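`service_success_hook` above bridges synchronous callers to the async hook by inspecting the event loop. A stripped-down, runnable sketch of that dispatch pattern (names are illustrative; this modernized version uses `asyncio.get_running_loop` rather than the deprecated `asyncio.get_event_loop`):

```python
import asyncio

async def _record(duration: float) -> float:
    # Stand-in for async_service_success_hook.
    return duration

def record_sync(duration: float):
    """Invoke _record from sync code: schedule on a running loop, else run one."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: safe to spin one up and block.
        return asyncio.run(_record(duration))
    # Inside a running loop: schedule as a fire-and-forget task instead of blocking.
    return loop.create_task(_record(duration))

result = record_sync(0.25)
print(result)  # 0.25 (no loop was running at call time)
```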

@@ -0,0 +1,16 @@
"""
Internal unified UUID helper.
Always uses fastuuid for performance.
"""
import fastuuid as _uuid # type: ignore
# Expose a module-like alias so callers can use: uuid.uuid4()
uuid = _uuid
def uuid4():
"""Return a UUID4 via the fastuuid backend."""
return uuid.uuid4()
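fastuuid mirrors the stdlib `uuid` API for the calls used here, so the alias pattern above can be sketched with a stdlib fallback (a hedged illustration only; the module above deliberately hard-requires fastuuid):

```python
try:
    import fastuuid as uuid  # fast Rust-backed implementation, if installed
except ImportError:
    import uuid  # stdlib fallback, for this sketch only

u = uuid.uuid4()
print(len(str(u)))  # canonical UUID string: 36 chars (32 hex digits + 4 hyphens)
```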

@@ -0,0 +1,6 @@
import importlib_metadata
try:
version = importlib_metadata.version("litellm")
except Exception:
version = "unknown"

@@ -0,0 +1,73 @@
"""
LiteLLM A2A - Wrapper for invoking A2A protocol agents.
This module provides a thin wrapper around the official `a2a` SDK that:
- Handles httpx client creation and agent card resolution
- Adds LiteLLM logging via @client decorator
- Matches the A2A SDK interface (SendMessageRequest, SendMessageResponse, etc.)
Example usage (standalone functions with @client decorator):
```python
from litellm.a2a_protocol import asend_message
from a2a.types import SendMessageRequest, MessageSendParams
from uuid import uuid4
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": uuid4().hex,
}
)
)
response = await asend_message(
base_url="http://localhost:10001",
request=request,
)
print(response.model_dump(mode='json', exclude_none=True))
```
Example usage (class-based):
```python
from litellm.a2a_protocol import A2AClient
client = A2AClient(base_url="http://localhost:10001")
response = await client.send_message(request)
```
"""
from litellm.a2a_protocol.client import A2AClient
from litellm.a2a_protocol.exceptions import (
A2AAgentCardError,
A2AConnectionError,
A2AError,
A2ALocalhostURLError,
)
from litellm.a2a_protocol.main import (
aget_agent_card,
asend_message,
asend_message_streaming,
create_a2a_client,
send_message,
)
from litellm.types.agents import LiteLLMSendMessageResponse
__all__ = [
# Client
"A2AClient",
# Functions
"asend_message",
"send_message",
"asend_message_streaming",
"aget_agent_card",
"create_a2a_client",
# Response types
"LiteLLMSendMessageResponse",
# Exceptions
"A2AError",
"A2AConnectionError",
"A2AAgentCardError",
"A2ALocalhostURLError",
]

@@ -0,0 +1,144 @@
"""
Custom A2A Card Resolver for LiteLLM.
Extends the A2A SDK's card resolver to support multiple well-known paths.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from litellm._logging import verbose_logger
from litellm.constants import LOCALHOST_URL_PATTERNS
if TYPE_CHECKING:
from a2a.types import AgentCard
# Runtime imports with availability check
_A2ACardResolver: Any = None
AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent-card.json"
PREV_AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent.json"
try:
from a2a.client import A2ACardResolver as _A2ACardResolver # type: ignore[no-redef]
from a2a.utils.constants import ( # type: ignore[no-redef]
AGENT_CARD_WELL_KNOWN_PATH,
PREV_AGENT_CARD_WELL_KNOWN_PATH,
)
except ImportError:
pass
def is_localhost_or_internal_url(url: Optional[str]) -> bool:
"""
Check if a URL is a localhost or internal URL.
This detects common development URLs that are accidentally left in
agent cards when deploying to production.
Args:
url: The URL to check
Returns:
True if the URL is localhost/internal
"""
if not url:
return False
url_lower = url.lower()
return any(pattern in url_lower for pattern in LOCALHOST_URL_PATTERNS)
def fix_agent_card_url(agent_card: "AgentCard", base_url: str) -> "AgentCard":
"""
Fix the agent card URL if it contains a localhost/internal address.
Many A2A agents are deployed with agent cards that contain internal URLs
like "http://0.0.0.0:8001/" or "http://localhost:8000/". This function
replaces such URLs with the provided base_url.
Args:
agent_card: The agent card to fix
base_url: The base URL to use as replacement
Returns:
The agent card with the URL fixed if necessary
"""
card_url = getattr(agent_card, "url", None)
if card_url and is_localhost_or_internal_url(card_url):
# Normalize base_url to ensure it ends with /
fixed_url = base_url.rstrip("/") + "/"
agent_card.url = fixed_url
return agent_card
class LiteLLMA2ACardResolver(_A2ACardResolver): # type: ignore[misc]
"""
Custom A2A card resolver that supports multiple well-known paths.
Extends the base A2ACardResolver to try both:
- /.well-known/agent-card.json (standard)
- /.well-known/agent.json (previous/alternative)
"""
async def get_agent_card(
self,
relative_card_path: Optional[str] = None,
http_kwargs: Optional[Dict[str, Any]] = None,
) -> "AgentCard":
"""
Fetch the agent card, trying multiple well-known paths.
First tries the standard path, then falls back to the previous path.
Args:
relative_card_path: Optional path to the agent card endpoint.
If None, tries both well-known paths.
http_kwargs: Optional dictionary of keyword arguments to pass to httpx.get
Returns:
AgentCard from the A2A agent
Raises:
A2AClientHTTPError or A2AClientJSONError if both paths fail
"""
# If a specific path is provided, use the parent implementation
if relative_card_path is not None:
return await super().get_agent_card(
relative_card_path=relative_card_path,
http_kwargs=http_kwargs,
)
# Try both well-known paths
paths = [
AGENT_CARD_WELL_KNOWN_PATH,
PREV_AGENT_CARD_WELL_KNOWN_PATH,
]
last_error = None
for path in paths:
try:
verbose_logger.debug(
f"Attempting to fetch agent card from {self.base_url}{path}"
)
return await super().get_agent_card(
relative_card_path=path,
http_kwargs=http_kwargs,
)
except Exception as e:
verbose_logger.debug(
f"Failed to fetch agent card from {self.base_url}{path}: {e}"
)
last_error = e
continue
# If we get here, all paths failed - re-raise the last error
if last_error is not None:
raise last_error
# This shouldn't happen, but just in case
raise Exception(
f"Failed to fetch agent card from {self.base_url}. "
f"Tried paths: {', '.join(paths)}"
)
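The localhost-detection and URL-rewrite behavior above can be exercised standalone. `LOCALHOST_URL_PATTERNS` here is an assumed stand-in for the values in `litellm.constants`:

```python
LOCALHOST_URL_PATTERNS = ("localhost", "127.0.0.1", "0.0.0.0")  # assumed values

def is_internal(url):
    """Substring match against known dev/internal host patterns."""
    return bool(url) and any(p in url.lower() for p in LOCALHOST_URL_PATTERNS)

def fix_url(card_url, base_url):
    # Replace an internal card URL with the caller-supplied base URL,
    # normalized to end with a trailing slash; leave public URLs untouched.
    return base_url.rstrip("/") + "/" if is_internal(card_url) else card_url

print(fix_url("http://0.0.0.0:8001/", "https://agent.example.com"))   # https://agent.example.com/
print(fix_url("https://prod.example.com/", "https://agent.example.com"))  # https://prod.example.com/
```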

@@ -0,0 +1,109 @@
"""
LiteLLM A2A Client class.
Provides a class-based interface for A2A agent invocation.
"""
from typing import TYPE_CHECKING, AsyncIterator, Dict, Optional
from litellm.types.agents import LiteLLMSendMessageResponse
if TYPE_CHECKING:
from a2a.client import A2AClient as A2AClientType
from a2a.types import (
AgentCard,
SendMessageRequest,
SendStreamingMessageRequest,
SendStreamingMessageResponse,
)
class A2AClient:
"""
LiteLLM wrapper for A2A agent invocation.
Creates the underlying A2A client once on first use and reuses it.
Example:
```python
from litellm.a2a_protocol import A2AClient
from a2a.types import SendMessageRequest, MessageSendParams
from uuid import uuid4
client = A2AClient(base_url="http://localhost:10001")
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": uuid4().hex,
}
)
)
response = await client.send_message(request)
```
"""
def __init__(
self,
base_url: str,
timeout: float = 60.0,
extra_headers: Optional[Dict[str, str]] = None,
):
"""
Initialize the A2A client wrapper.
Args:
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
timeout: Request timeout in seconds (default: 60.0)
extra_headers: Optional additional headers to include in requests
"""
self.base_url = base_url
self.timeout = timeout
self.extra_headers = extra_headers
self._a2a_client: Optional["A2AClientType"] = None
async def _get_client(self) -> "A2AClientType":
"""Get or create the underlying A2A client."""
if self._a2a_client is None:
from litellm.a2a_protocol.main import create_a2a_client
self._a2a_client = await create_a2a_client(
base_url=self.base_url,
timeout=self.timeout,
extra_headers=self.extra_headers,
)
return self._a2a_client
async def get_agent_card(self) -> "AgentCard":
"""Fetch the agent card from the server."""
from litellm.a2a_protocol.main import aget_agent_card
return await aget_agent_card(
base_url=self.base_url,
timeout=self.timeout,
extra_headers=self.extra_headers,
)
async def send_message(
self, request: "SendMessageRequest"
) -> LiteLLMSendMessageResponse:
"""Send a message to the A2A agent."""
from litellm.a2a_protocol.main import asend_message
a2a_client = await self._get_client()
return await asend_message(a2a_client=a2a_client, request=request)
async def send_message_streaming(
self, request: "SendStreamingMessageRequest"
) -> AsyncIterator["SendStreamingMessageResponse"]:
"""Send a streaming message to the A2A agent."""
from litellm.a2a_protocol.main import asend_message_streaming
a2a_client = await self._get_client()
async for chunk in asend_message_streaming(
a2a_client=a2a_client, request=request
):
yield chunk

@@ -0,0 +1,107 @@
"""
Cost calculator for A2A (Agent-to-Agent) calls.
Supports dynamic cost parameters that allow platform owners
to define custom costs per agent query or per token.
"""
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import (
Logging as LitellmLoggingObject,
)
else:
LitellmLoggingObject = Any
class A2ACostCalculator:
@staticmethod
def calculate_a2a_cost(
litellm_logging_obj: Optional[LitellmLoggingObject],
) -> float:
"""
Calculate the cost of an A2A send_message call.
Supports multiple cost parameters for platform owners:
- cost_per_query: Fixed cost per query
- input_cost_per_token + output_cost_per_token: Token-based pricing
Priority order:
1. response_cost - if set directly (backward compatibility)
2. cost_per_query - fixed cost per query
3. input_cost_per_token + output_cost_per_token - token-based cost
4. Default to 0.0
Args:
litellm_logging_obj: The LiteLLM logging object containing call details
Returns:
float: The cost of the A2A call
"""
if litellm_logging_obj is None:
return 0.0
model_call_details = litellm_logging_obj.model_call_details
# Check if user set a custom response cost (backward compatibility)
response_cost = model_call_details.get("response_cost", None)
if response_cost is not None:
return float(response_cost)
# Get litellm_params for cost parameters
litellm_params = model_call_details.get("litellm_params", {}) or {}
# Check for cost_per_query (fixed cost per query)
if litellm_params.get("cost_per_query") is not None:
return float(litellm_params["cost_per_query"])
# Check for token-based pricing
input_cost_per_token = litellm_params.get("input_cost_per_token")
output_cost_per_token = litellm_params.get("output_cost_per_token")
if input_cost_per_token is not None or output_cost_per_token is not None:
return A2ACostCalculator._calculate_token_based_cost(
model_call_details=model_call_details,
input_cost_per_token=input_cost_per_token,
output_cost_per_token=output_cost_per_token,
)
# Default to 0.0 for A2A calls
return 0.0
@staticmethod
def _calculate_token_based_cost(
model_call_details: dict,
input_cost_per_token: Optional[float],
output_cost_per_token: Optional[float],
) -> float:
"""
Calculate cost based on token usage and per-token pricing.
Args:
model_call_details: The model call details containing usage
input_cost_per_token: Cost per input token (can be None, defaults to 0)
output_cost_per_token: Cost per output token (can be None, defaults to 0)
Returns:
float: The calculated cost
"""
# Get usage from model_call_details
usage = model_call_details.get("usage")
if usage is None:
return 0.0
# Get token counts
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
# Calculate costs
input_cost = prompt_tokens * (
float(input_cost_per_token) if input_cost_per_token else 0.0
)
output_cost = completion_tokens * (
float(output_cost_per_token) if output_cost_per_token else 0.0
)
return input_cost + output_cost
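The token-based branch above multiplies usage counts by per-token prices, treating a missing price as zero. A self-contained sketch with a stand-in usage object (field names mirror the code above):

```python
from types import SimpleNamespace

def token_cost(usage, input_cost_per_token=None, output_cost_per_token=None):
    # Mirror the defensive getattr/or-0 access used in _calculate_token_based_cost.
    prompt = getattr(usage, "prompt_tokens", 0) or 0
    completion = getattr(usage, "completion_tokens", 0) or 0
    return (prompt * (float(input_cost_per_token) if input_cost_per_token else 0.0)
            + completion * (float(output_cost_per_token) if output_cost_per_token else 0.0))

usage = SimpleNamespace(prompt_tokens=1000, completion_tokens=500)
cost = token_cost(usage, input_cost_per_token=1e-6, output_cost_per_token=2e-6)
print(round(cost, 6))  # 0.002
```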

@@ -0,0 +1,203 @@
"""
A2A Protocol Exception Mapping Utils.
Maps A2A SDK exceptions to LiteLLM A2A exception types.
"""
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_logger
from litellm.a2a_protocol.card_resolver import (
fix_agent_card_url,
is_localhost_or_internal_url,
)
from litellm.a2a_protocol.exceptions import (
A2AAgentCardError,
A2AConnectionError,
A2AError,
A2ALocalhostURLError,
)
from litellm.constants import CONNECTION_ERROR_PATTERNS
if TYPE_CHECKING:
from a2a.client import A2AClient as A2AClientType
# Runtime import
A2A_SDK_AVAILABLE = False
try:
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
A2A_SDK_AVAILABLE = True
except ImportError:
_A2AClient = None # type: ignore[assignment, misc]
class A2AExceptionCheckers:
"""
Helper class for checking various A2A error conditions.
"""
@staticmethod
def is_connection_error(error_str: str) -> bool:
"""
Check if an error string indicates a connection error.
Args:
error_str: The error string to check
Returns:
True if the error indicates a connection issue
"""
if not isinstance(error_str, str):
return False
error_str_lower = error_str.lower()
return any(pattern in error_str_lower for pattern in CONNECTION_ERROR_PATTERNS)
@staticmethod
def is_localhost_url(url: Optional[str]) -> bool:
"""
Check if a URL is a localhost/internal URL.
Args:
url: The URL to check
Returns:
True if the URL is localhost/internal
"""
return is_localhost_or_internal_url(url)
@staticmethod
def is_agent_card_error(error_str: str) -> bool:
"""
Check if an error string indicates an agent card error.
Args:
error_str: The error string to check
Returns:
True if the error is related to agent card fetching/parsing
"""
if not isinstance(error_str, str):
return False
error_str_lower = error_str.lower()
agent_card_patterns = [
"agent card",
"agent-card",
".well-known",
"card not found",
"invalid agent",
]
return any(pattern in error_str_lower for pattern in agent_card_patterns)
def map_a2a_exception(
original_exception: Exception,
card_url: Optional[str] = None,
api_base: Optional[str] = None,
model: Optional[str] = None,
) -> Exception:
"""
Map an A2A SDK exception to a LiteLLM A2A exception type.
Args:
original_exception: The original exception from the A2A SDK
card_url: The URL from the agent card (if available)
api_base: The original API base URL
model: The model/agent name
Raises (this function always raises one of the following; it never returns normally):
A2ALocalhostURLError: If the error is a connection error to a localhost URL
A2AConnectionError: If the error is a general connection error
A2AAgentCardError: If the error is related to agent card issues
A2AError: For other A2A-related errors
"""
error_str = str(original_exception)
# Check for localhost URL connection error (special case - retryable)
if (
card_url
and api_base
and A2AExceptionCheckers.is_localhost_url(card_url)
and A2AExceptionCheckers.is_connection_error(error_str)
):
raise A2ALocalhostURLError(
localhost_url=card_url,
base_url=api_base,
original_error=original_exception,
model=model,
)
# Check for agent card errors
if A2AExceptionCheckers.is_agent_card_error(error_str):
raise A2AAgentCardError(
message=error_str,
url=api_base,
model=model,
)
# Check for general connection errors
if A2AExceptionCheckers.is_connection_error(error_str):
raise A2AConnectionError(
message=error_str,
url=card_url or api_base,
model=model,
)
# Default: wrap in generic A2AError
raise A2AError(
message=error_str,
model=model,
)
def handle_a2a_localhost_retry(
error: A2ALocalhostURLError,
agent_card: Any,
a2a_client: "A2AClientType",
is_streaming: bool = False,
) -> "A2AClientType":
"""
Handle A2ALocalhostURLError by fixing the URL and creating a new client.
This is called when we catch an A2ALocalhostURLError and want to retry
with the corrected URL.
Args:
error: The localhost URL error
agent_card: The agent card object to fix
a2a_client: The current A2A client
is_streaming: Whether this is a streaming request (for logging)
Returns:
A new A2A client with the fixed URL
Raises:
ImportError: If the A2A SDK is not installed
"""
if not A2A_SDK_AVAILABLE or _A2AClient is None:
raise ImportError(
"A2A SDK is required for localhost retry handling. "
"Install it with: pip install a2a"
)
request_type = "streaming " if is_streaming else ""
verbose_logger.warning(
f"A2A {request_type}request to '{error.localhost_url}' failed: {error.original_error}. "
f"Agent card contains localhost/internal URL. "
f"Retrying with base_url '{error.base_url}'."
)
# Fix the agent card URL
fix_agent_card_url(agent_card, error.base_url)
# Create a new client with the fixed agent card (transport caches URL)
return _A2AClient(
httpx_client=a2a_client._transport.httpx_client, # type: ignore[union-attr]
agent_card=agent_card,
)
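The localhost-retry flow above (detect a connection error, patch the agent-card URL, retry once) can be sketched in isolation. Everything below is illustrative: `send_with_url_fix`, `fake_send`, and the trimmed pattern list are stand-ins, not LiteLLM APIs.

```python
# Standalone sketch of the retry-with-corrected-URL pattern (assumed,
# trimmed version of CONNECTION_ERROR_PATTERNS).
CONNECTION_ERROR_PATTERNS = ["connection refused", "connect error", "timed out"]


def is_connection_error(error_str: str) -> bool:
    lowered = error_str.lower()
    return any(p in lowered for p in CONNECTION_ERROR_PATTERNS)


def send_with_url_fix(send, url: str, fallback_url: str):
    """Try send(url); on a connection error to a localhost/internal URL,
    retry once with fallback_url (mirrors handle_a2a_localhost_retry)."""
    try:
        return send(url)
    except RuntimeError as e:
        if is_connection_error(str(e)) and url.startswith(
            ("http://localhost", "http://0.0.0.0", "http://127.")
        ):
            return send(fallback_url)
        raise


def fake_send(url: str):
    # Simulates an agent card pointing at an unreachable internal address.
    if "0.0.0.0" in url:
        raise RuntimeError("Connection refused")
    return f"ok from {url}"


print(send_with_url_fix(fake_send, "http://0.0.0.0:8001/", "https://agents.example.com/"))
# ok from https://agents.example.com/
```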


@@ -0,0 +1,150 @@
"""
A2A Protocol Exceptions.
Custom exception types for A2A protocol operations, following LiteLLM's exception pattern.
"""
from typing import Optional
import httpx
class A2AError(Exception):
"""
Base exception for A2A protocol errors.
Follows the same pattern as LiteLLM's main exceptions.
"""
def __init__(
self,
message: str,
status_code: int = 500,
llm_provider: str = "a2a_agent",
model: Optional[str] = None,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = status_code
self.message = f"litellm.A2AError: {message}"
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
self.response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(method="POST", url="https://litellm.ai"),
)
super().__init__(self.message)
def __str__(self) -> str:
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self) -> str:
return self.__str__()
class A2AConnectionError(A2AError):
"""
Raised when connection to an A2A agent fails.
This typically occurs when:
- The agent is unreachable
- The agent card contains a localhost/internal URL
- Network issues prevent connection
"""
def __init__(
self,
message: str,
url: Optional[str] = None,
model: Optional[str] = None,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.url = url
super().__init__(
message=message,
status_code=503,
llm_provider="a2a_agent",
model=model,
response=response,
litellm_debug_info=litellm_debug_info,
max_retries=max_retries,
num_retries=num_retries,
)
class A2AAgentCardError(A2AError):
"""
Raised when there's an issue with the agent card.
This includes:
- Failed to fetch agent card
- Invalid agent card format
- Missing required fields
"""
def __init__(
self,
message: str,
url: Optional[str] = None,
model: Optional[str] = None,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
):
self.url = url
super().__init__(
message=message,
status_code=404,
llm_provider="a2a_agent",
model=model,
response=response,
litellm_debug_info=litellm_debug_info,
)
class A2ALocalhostURLError(A2AConnectionError):
"""
Raised when an agent card contains a localhost/internal URL.
Many A2A agents are deployed with agent cards that contain internal URLs
like "http://0.0.0.0:8001/" or "http://localhost:8000/". This error
indicates that the URL needs to be corrected and the request should be retried.
Attributes:
localhost_url: The localhost/internal URL found in the agent card
base_url: The public base URL that should be used instead
original_error: The original connection error that was raised
"""
def __init__(
self,
localhost_url: str,
base_url: str,
original_error: Optional[Exception] = None,
model: Optional[str] = None,
):
self.localhost_url = localhost_url
self.base_url = base_url
self.original_error = original_error
message = (
f"Agent card contains localhost/internal URL '{localhost_url}'. "
f"Retrying with base URL '{base_url}'."
)
super().__init__(
message=message,
url=localhost_url,
model=model,
)
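Because `A2ALocalhostURLError` subclasses `A2AConnectionError`, which subclasses `A2AError`, callers can catch at any level of specificity. A trimmed, standalone sketch of that hierarchy (same class names, bodies reduced to the fields the retry logic reads):

```python
# Minimal sketch of the exception hierarchy above; bodies are trimmed,
# not the full LiteLLM implementations.
class A2AError(Exception):
    def __init__(self, message: str, status_code: int = 500):
        self.status_code = status_code
        super().__init__(f"litellm.A2AError: {message}")


class A2AConnectionError(A2AError):
    def __init__(self, message: str, url=None):
        self.url = url
        super().__init__(message, status_code=503)


class A2ALocalhostURLError(A2AConnectionError):
    def __init__(self, localhost_url: str, base_url: str):
        self.localhost_url = localhost_url
        self.base_url = base_url
        super().__init__(
            f"Agent card contains localhost/internal URL '{localhost_url}'.",
            url=localhost_url,
        )


# A broad `except A2AConnectionError` also catches the retryable localhost case.
try:
    raise A2ALocalhostURLError("http://0.0.0.0:8001/", "https://agents.example.com/")
except A2AConnectionError as e:
    print(e.status_code)  # 503
```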


@@ -0,0 +1,74 @@
# A2A to LiteLLM Completion Bridge
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
## Flow
```
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
```
## SDK Usage
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
```python
from litellm.a2a_protocol import asend_message, asend_message_streaming
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
from uuid import uuid4
# Non-streaming
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
response = await asend_message(
request=request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
)
# Streaming
stream_request = SendStreamingMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
async for chunk in asend_message_streaming(
request=stream_request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
):
print(chunk)
```
## Proxy Usage
Configure an agent with `custom_llm_provider` in `litellm_params`:
```yaml
agents:
- agent_name: my-langgraph-agent
agent_card_params:
name: "LangGraph Agent"
url: "http://localhost:2024" # Used as api_base
litellm_params:
custom_llm_provider: langgraph
model: agent
```
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
1. Detects `custom_llm_provider` in agent's `litellm_params`
2. Transforms A2A message → OpenAI messages
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
4. Transforms response → A2A format
## Classes
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)
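Step 2 of the flow (A2A message → OpenAI messages) can be sketched standalone; `a2a_to_openai` is an illustrative name, not part of the bridge's API:

```python
# Hedged sketch mirroring A2ACompletionBridgeTransformation.a2a_message_to_openai_messages:
# join the text parts and keep the role.
def a2a_to_openai(a2a_message: dict) -> list:
    text = "\n".join(
        part.get("text", "")
        for part in a2a_message.get("parts", [])
        if part.get("kind") == "text"
    )
    return [{"role": a2a_message.get("role", "user"), "content": text}]


msg = {"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": "abc"}
print(a2a_to_openai(msg))  # [{'role': 'user', 'content': 'Hello!'}]
```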


@@ -0,0 +1,23 @@
"""
A2A to LiteLLM Completion Bridge.
This module provides transformation between A2A protocol messages and
LiteLLM completion API, enabling any LiteLLM-supported provider to be
invoked via the A2A protocol.
"""
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
A2ACompletionBridgeHandler,
handle_a2a_completion,
handle_a2a_completion_streaming,
)
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
A2ACompletionBridgeTransformation,
)
__all__ = [
"A2ACompletionBridgeTransformation",
"A2ACompletionBridgeHandler",
"handle_a2a_completion",
"handle_a2a_completion_streaming",
]


@@ -0,0 +1,299 @@
"""
Handler for A2A to LiteLLM completion bridge.
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
A2A Streaming Events (in order):
1. Task event (kind: "task") - Initial task creation with status "submitted"
2. Status update (kind: "status-update") - Status change to "working"
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
4. Status update (kind: "status-update") - Final status "completed" with final=true
"""
from typing import Any, AsyncIterator, Dict, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
A2ACompletionBridgeTransformation,
A2AStreamingContext,
)
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
class A2ACompletionBridgeHandler:
"""
Static methods for handling A2A requests via LiteLLM completion.
"""
@staticmethod
async def handle_non_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> Dict[str, Any]:
"""
Handle non-streaming A2A request via litellm.acompletion.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
api_base: API base URL from agent_card_params
Returns:
A2A SendMessageResponse dict
"""
# Get provider config for custom_llm_provider
custom_llm_provider = litellm_params.get("custom_llm_provider")
a2a_provider_config = A2AProviderConfigManager.get_provider_config(
custom_llm_provider=custom_llm_provider
)
# If provider config exists, use it
if a2a_provider_config is not None:
if api_base is None:
raise ValueError(f"api_base is required for {custom_llm_provider}")
verbose_logger.info(f"A2A: Using provider config for {custom_llm_provider}")
response_data = await a2a_provider_config.handle_non_streaming(
request_id=request_id,
params=params,
api_base=api_base,
)
return response_data
# Extract message from params
message = params.get("message", {})
# Transform A2A message to OpenAI format
openai_messages = (
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
)
        # Get completion params (custom_llm_provider was already read above)
        model = litellm_params.get("model", "agent")
# Build full model string if provider specified
# Skip prepending if model already starts with the provider prefix
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
full_model = f"{custom_llm_provider}/{model}"
else:
full_model = model
verbose_logger.info(
f"A2A completion bridge: model={full_model}, api_base={api_base}"
)
# Build completion params dict
completion_params = {
"model": full_model,
"messages": openai_messages,
"api_base": api_base,
"stream": False,
}
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
litellm_params_to_add = {
k: v
for k, v in litellm_params.items()
if k not in ("model", "custom_llm_provider")
}
completion_params.update(litellm_params_to_add)
# Call litellm.acompletion
response = await litellm.acompletion(**completion_params)
# Transform response to A2A format
a2a_response = (
A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
response=response,
request_id=request_id,
)
)
verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")
return a2a_response
@staticmethod
async def handle_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""
Handle streaming A2A request via litellm.acompletion with stream=True.
Emits proper A2A streaming events:
1. Task event (kind: "task") - Initial task with status "submitted"
2. Status update (kind: "status-update") - Status "working"
3. Artifact update (kind: "artifact-update") - Content delivery
4. Status update (kind: "status-update") - Final "completed" status
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
api_base: API base URL from agent_card_params
Yields:
A2A streaming response events
"""
# Get provider config for custom_llm_provider
custom_llm_provider = litellm_params.get("custom_llm_provider")
a2a_provider_config = A2AProviderConfigManager.get_provider_config(
custom_llm_provider=custom_llm_provider
)
# If provider config exists, use it
if a2a_provider_config is not None:
if api_base is None:
raise ValueError(f"api_base is required for {custom_llm_provider}")
verbose_logger.info(
f"A2A: Using provider config for {custom_llm_provider} (streaming)"
)
async for chunk in a2a_provider_config.handle_streaming(
request_id=request_id,
params=params,
api_base=api_base,
):
yield chunk
return
# Extract message from params
message = params.get("message", {})
# Create streaming context
ctx = A2AStreamingContext(
request_id=request_id,
input_message=message,
)
# Transform A2A message to OpenAI format
openai_messages = (
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
)
        # Get completion params (custom_llm_provider was already read above)
        model = litellm_params.get("model", "agent")
# Build full model string if provider specified
# Skip prepending if model already starts with the provider prefix
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
full_model = f"{custom_llm_provider}/{model}"
else:
full_model = model
verbose_logger.info(
f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
)
# Build completion params dict
completion_params = {
"model": full_model,
"messages": openai_messages,
"api_base": api_base,
"stream": True,
}
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
litellm_params_to_add = {
k: v
for k, v in litellm_params.items()
if k not in ("model", "custom_llm_provider")
}
completion_params.update(litellm_params_to_add)
# 1. Emit initial task event (kind: "task", status: "submitted")
task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
yield task_event
# 2. Emit status update (kind: "status-update", status: "working")
working_event = A2ACompletionBridgeTransformation.create_status_update_event(
ctx=ctx,
state="working",
final=False,
message_text="Processing request...",
)
yield working_event
# Call litellm.acompletion with streaming
response = await litellm.acompletion(**completion_params)
# 3. Accumulate content and emit artifact update
accumulated_text = ""
chunk_count = 0
async for chunk in response: # type: ignore[union-attr]
chunk_count += 1
# Extract delta content
content = ""
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta") and choice.delta:
content = choice.delta.content or ""
if content:
accumulated_text += content
# Emit artifact update with accumulated content
if accumulated_text:
artifact_event = (
A2ACompletionBridgeTransformation.create_artifact_update_event(
ctx=ctx,
text=accumulated_text,
)
)
yield artifact_event
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
ctx=ctx,
state="completed",
final=True,
)
yield completed_event
verbose_logger.info(
f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
)
# Convenience functions that delegate to the class methods
async def handle_a2a_completion(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> Dict[str, Any]:
"""Convenience function for non-streaming A2A completion."""
return await A2ACompletionBridgeHandler.handle_non_streaming(
request_id=request_id,
params=params,
litellm_params=litellm_params,
api_base=api_base,
)
async def handle_a2a_completion_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Convenience function for streaming A2A completion."""
async for chunk in A2ACompletionBridgeHandler.handle_streaming(
request_id=request_id,
params=params,
litellm_params=litellm_params,
api_base=api_base,
):
yield chunk
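The four-event streaming sequence described in the docstring can be sketched with a plain async generator. `stream_events` is illustrative; the real handler builds these events via `A2ACompletionBridgeTransformation`:

```python
# Standalone sketch of the A2A streaming event order:
# task -> working -> artifact-update -> completed(final).
import asyncio
from uuid import uuid4


async def stream_events(request_id: str, text_chunks):
    task_id = str(uuid4())
    # 1. Initial task event
    yield {"kind": "task", "taskId": task_id, "status": {"state": "submitted"}}
    # 2. Working status update
    yield {"kind": "status-update", "status": {"state": "working"}, "final": False}
    # 3. Artifact update with the accumulated content
    accumulated = "".join(text_chunks)
    if accumulated:
        yield {
            "kind": "artifact-update",
            "artifact": {"parts": [{"kind": "text", "text": accumulated}]},
        }
    # 4. Final completed status
    yield {"kind": "status-update", "status": {"state": "completed"}, "final": True}


async def main():
    kinds = [e["kind"] async for e in stream_events("req-1", ["Hel", "lo"])]
    print(kinds)


asyncio.run(main())
# ['task', 'status-update', 'artifact-update', 'status-update']
```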


@@ -0,0 +1,284 @@
"""
Transformation utilities for A2A <-> OpenAI message format conversion.
A2A Message Format:
{
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": "abc123"
}
OpenAI Message Format:
{"role": "user", "content": "Hello!"}
A2A Streaming Events:
- Task event (kind: "task") - Initial task creation with status "submitted"
- Status update (kind: "status-update") - Status changes (working, completed)
- Artifact update (kind: "artifact-update") - Content/artifact delivery
"""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from uuid import uuid4
from litellm._logging import verbose_logger
class A2AStreamingContext:
"""
Context holder for A2A streaming state.
Tracks task_id, context_id, and message accumulation.
"""
def __init__(self, request_id: str, input_message: Dict[str, Any]):
self.request_id = request_id
self.task_id = str(uuid4())
self.context_id = str(uuid4())
self.input_message = input_message
self.accumulated_text = ""
self.has_emitted_task = False
self.has_emitted_working = False
class A2ACompletionBridgeTransformation:
"""
Static methods for transforming between A2A and OpenAI message formats.
"""
@staticmethod
def a2a_message_to_openai_messages(
a2a_message: Dict[str, Any],
) -> List[Dict[str, str]]:
"""
Transform an A2A message to OpenAI message format.
Args:
a2a_message: A2A message with role, parts, and messageId
Returns:
List of OpenAI-format messages
"""
role = a2a_message.get("role", "user")
parts = a2a_message.get("parts", [])
        # Map A2A roles to OpenAI roles. A2A uses "agent" for the assistant
        # side; unknown roles fall back to "user".
        _role_map = {
            "user": "user",
            "agent": "assistant",
            "assistant": "assistant",
            "system": "system",
        }
        openai_role = _role_map.get(role, "user")
# Extract text content from parts
content_parts = []
for part in parts:
kind = part.get("kind", "")
if kind == "text":
text = part.get("text", "")
content_parts.append(text)
content = "\n".join(content_parts) if content_parts else ""
verbose_logger.debug(
f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
)
return [{"role": openai_role, "content": content}]
@staticmethod
def openai_response_to_a2a_response(
response: Any,
request_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.
Args:
response: LiteLLM ModelResponse object
request_id: Original A2A request ID
Returns:
A2A SendMessageResponse dict
"""
# Extract content from response
content = ""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and choice.message:
content = choice.message.content or ""
# Build A2A message
a2a_message = {
"role": "agent",
"parts": [{"kind": "text", "text": content}],
"messageId": uuid4().hex,
}
# Build A2A response
a2a_response = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": a2a_message,
},
}
verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")
return a2a_response
@staticmethod
def _get_timestamp() -> str:
"""Get current timestamp in ISO format with timezone."""
return datetime.now(timezone.utc).isoformat()
@staticmethod
def create_task_event(
ctx: A2AStreamingContext,
) -> Dict[str, Any]:
"""
Create the initial task event with status 'submitted'.
This is the first event emitted in an A2A streaming response.
"""
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"contextId": ctx.context_id,
"history": [
{
"contextId": ctx.context_id,
"kind": "message",
"messageId": ctx.input_message.get("messageId", uuid4().hex),
"parts": ctx.input_message.get("parts", []),
"role": ctx.input_message.get("role", "user"),
"taskId": ctx.task_id,
}
],
"id": ctx.task_id,
"kind": "task",
"status": {
"state": "submitted",
},
},
}
@staticmethod
def create_status_update_event(
ctx: A2AStreamingContext,
state: str,
final: bool = False,
message_text: Optional[str] = None,
) -> Dict[str, Any]:
"""
Create a status update event.
Args:
ctx: Streaming context
state: Status state ('working', 'completed')
final: Whether this is the final event
message_text: Optional message text for 'working' status
"""
status: Dict[str, Any] = {
"state": state,
"timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
}
# Add message for 'working' status
if state == "working" and message_text:
status["message"] = {
"contextId": ctx.context_id,
"kind": "message",
"messageId": str(uuid4()),
"parts": [{"kind": "text", "text": message_text}],
"role": "agent",
"taskId": ctx.task_id,
}
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"contextId": ctx.context_id,
"final": final,
"kind": "status-update",
"status": status,
"taskId": ctx.task_id,
},
}
@staticmethod
def create_artifact_update_event(
ctx: A2AStreamingContext,
text: str,
) -> Dict[str, Any]:
"""
Create an artifact update event with content.
Args:
ctx: Streaming context
text: The text content for the artifact
"""
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"artifact": {
"artifactId": str(uuid4()),
"name": "response",
"parts": [{"kind": "text", "text": text}],
},
"contextId": ctx.context_id,
"kind": "artifact-update",
"taskId": ctx.task_id,
},
}
@staticmethod
def openai_chunk_to_a2a_chunk(
chunk: Any,
request_id: Optional[str] = None,
is_final: bool = False,
) -> Optional[Dict[str, Any]]:
"""
Transform a LiteLLM streaming chunk to A2A streaming format.
NOTE: This method is deprecated for streaming. Use the event-based
methods (create_task_event, create_status_update_event,
create_artifact_update_event) instead for proper A2A streaming.
Args:
chunk: LiteLLM ModelResponse chunk
request_id: Original A2A request ID
is_final: Whether this is the final chunk
Returns:
A2A streaming chunk dict or None if no content
"""
# Extract delta content
content = ""
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta") and choice.delta:
content = choice.delta.content or ""
if not content and not is_final:
return None
# Build A2A streaming chunk (legacy format)
a2a_chunk = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": {
"role": "agent",
"parts": [{"kind": "text", "text": content}],
"messageId": uuid4().hex,
},
"final": is_final,
},
}
return a2a_chunk
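The response-side transform (`openai_response_to_a2a_response`) reduces to wrapping the first choice's content in a JSON-RPC result; a standalone sketch with an illustrative `openai_to_a2a` helper:

```python
# Sketch of the OpenAI -> A2A response wrapping (agent role, text part,
# fresh messageId, JSON-RPC 2.0 envelope).
from uuid import uuid4


def openai_to_a2a(content: str, request_id: str) -> dict:
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": {
            "message": {
                "role": "agent",
                "parts": [{"kind": "text", "text": content}],
                "messageId": uuid4().hex,
            }
        },
    }


resp = openai_to_a2a("Hi there!", "req-42")
print(resp["result"]["message"]["parts"][0]["text"])  # Hi there!
```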


@@ -0,0 +1,744 @@
"""
LiteLLM A2A SDK functions.
Provides standalone functions with @client decorator for LiteLLM logging integration.
"""
import asyncio
import datetime
import uuid
from typing import TYPE_CHECKING, Any, AsyncIterator, Coroutine, Dict, Optional, Union
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.a2a_protocol.streaming_iterator import A2AStreamingIterator
from litellm.a2a_protocol.utils import A2ARequestUtils
from litellm.constants import DEFAULT_A2A_AGENT_TIMEOUT
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.agents import LiteLLMSendMessageResponse
from litellm.utils import client
if TYPE_CHECKING:
from a2a.client import A2AClient as A2AClientType
from a2a.types import AgentCard, SendMessageRequest, SendStreamingMessageRequest
# Runtime imports with availability check
A2A_SDK_AVAILABLE = False
A2ACardResolver: Any = None
_A2AClient: Any = None
try:
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
A2A_SDK_AVAILABLE = True
except ImportError:
pass
# Import our custom card resolver that supports multiple well-known paths
from litellm.a2a_protocol.card_resolver import LiteLLMA2ACardResolver
from litellm.a2a_protocol.exception_mapping_utils import (
handle_a2a_localhost_retry,
map_a2a_exception,
)
from litellm.a2a_protocol.exceptions import A2ALocalhostURLError
# Use our custom resolver instead of the default A2A SDK resolver
A2ACardResolver = LiteLLMA2ACardResolver
def _set_usage_on_logging_obj(
kwargs: Dict[str, Any],
prompt_tokens: int,
completion_tokens: int,
) -> None:
"""
Set usage on litellm_logging_obj for standard logging payload.
Args:
kwargs: The kwargs dict containing litellm_logging_obj
prompt_tokens: Number of input tokens
completion_tokens: Number of output tokens
"""
litellm_logging_obj = kwargs.get("litellm_logging_obj")
if litellm_logging_obj is not None:
usage = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
litellm_logging_obj.model_call_details["usage"] = usage
def _set_agent_id_on_logging_obj(
kwargs: Dict[str, Any],
agent_id: Optional[str],
) -> None:
"""
Set agent_id on litellm_logging_obj for SpendLogs tracking.
Args:
kwargs: The kwargs dict containing litellm_logging_obj
agent_id: The A2A agent ID
"""
if agent_id is None:
return
litellm_logging_obj = kwargs.get("litellm_logging_obj")
if litellm_logging_obj is not None:
# Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
litellm_logging_obj.model_call_details["agent_id"] = agent_id
def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
"""
Extract agent info and set model/custom_llm_provider for cost tracking.
Sets model info on the litellm_logging_obj if available.
Returns the agent name for logging.
"""
agent_name = "unknown"
# Try to get agent card from our stored attribute first, then fallback to SDK attribute
agent_card = getattr(a2a_client, "_litellm_agent_card", None)
if agent_card is None:
agent_card = getattr(a2a_client, "agent_card", None)
if agent_card is not None:
agent_name = getattr(agent_card, "name", "unknown") or "unknown"
# Build model string
model = f"a2a_agent/{agent_name}"
custom_llm_provider = "a2a_agent"
# Set on litellm_logging_obj if available (for standard logging payload)
litellm_logging_obj = kwargs.get("litellm_logging_obj")
if litellm_logging_obj is not None:
litellm_logging_obj.model = model
litellm_logging_obj.custom_llm_provider = custom_llm_provider
litellm_logging_obj.model_call_details["model"] = model
litellm_logging_obj.model_call_details[
"custom_llm_provider"
] = custom_llm_provider
return agent_name
async def _send_message_via_completion_bridge(
request: "SendMessageRequest",
custom_llm_provider: str,
api_base: Optional[str],
litellm_params: Dict[str, Any],
) -> LiteLLMSendMessageResponse:
"""
Route a send_message through the LiteLLM completion bridge (e.g. LangGraph, Bedrock AgentCore).
Requires request; api_base is optional for providers that derive endpoint from model.
"""
verbose_logger.info(
f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
)
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
A2ACompletionBridgeHandler,
)
params = (
request.params.model_dump(mode="json")
if hasattr(request.params, "model_dump")
else dict(request.params)
)
response_dict = await A2ACompletionBridgeHandler.handle_non_streaming(
request_id=str(request.id),
params=params,
litellm_params=litellm_params,
api_base=api_base,
)
return LiteLLMSendMessageResponse.from_dict(response_dict)
async def _execute_a2a_send_with_retry(
a2a_client: Any,
request: Any,
agent_card: Any,
card_url: Optional[str],
api_base: Optional[str],
agent_name: Optional[str],
) -> Any:
"""Send an A2A message with retry logic for localhost URL errors."""
a2a_response = None
for _ in range(2): # max 2 attempts: original + 1 retry
try:
a2a_response = await a2a_client.send_message(request)
break # success, exit retry loop
except A2ALocalhostURLError as e:
a2a_client = handle_a2a_localhost_retry(
error=e,
agent_card=agent_card,
a2a_client=a2a_client,
is_streaming=False,
)
card_url = agent_card.url if agent_card else None
except Exception as e:
try:
map_a2a_exception(e, card_url, api_base, model=agent_name)
except A2ALocalhostURLError as localhost_err:
a2a_client = handle_a2a_localhost_retry(
error=localhost_err,
agent_card=agent_card,
a2a_client=a2a_client,
is_streaming=False,
)
card_url = agent_card.url if agent_card else None
continue
except Exception:
raise
if a2a_response is None:
raise RuntimeError(
"A2A send_message failed: no response received after retry attempts."
)
return a2a_response
@client
async def asend_message(
a2a_client: Optional["A2AClientType"] = None,
request: Optional["SendMessageRequest"] = None,
api_base: Optional[str] = None,
litellm_params: Optional[Dict[str, Any]] = None,
agent_id: Optional[str] = None,
agent_extra_headers: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> LiteLLMSendMessageResponse:
"""
Async: Send a message to an A2A agent.
Uses the @client decorator for LiteLLM logging and tracking.
If litellm_params contains custom_llm_provider, routes through the completion bridge.
Args:
a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
request: SendMessageRequest from a2a.types (optional if using completion bridge with api_base)
api_base: API base URL (required for completion bridge, optional for standard A2A)
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
agent_id: Optional agent ID for tracking in SpendLogs
        agent_extra_headers: Optional headers to attach to agent requests (take precedence over LiteLLM-internal headers)
**kwargs: Additional arguments passed to the client decorator
Returns:
LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)
Example (standard A2A):
```python
from litellm.a2a_protocol import asend_message, create_a2a_client
from a2a.types import SendMessageRequest, MessageSendParams
from uuid import uuid4
a2a_client = await create_a2a_client(base_url="http://localhost:10001")
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
response = await asend_message(a2a_client=a2a_client, request=request)
```
Example (completion bridge with LangGraph):
```python
from litellm.a2a_protocol import asend_message
from a2a.types import SendMessageRequest, MessageSendParams
from uuid import uuid4
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
response = await asend_message(
request=request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
)
```
"""
litellm_params = litellm_params or {}
logging_obj = kwargs.get("litellm_logging_obj")
trace_id = getattr(logging_obj, "litellm_trace_id", None) if logging_obj else None
custom_llm_provider = litellm_params.get("custom_llm_provider")
# Route through completion bridge if custom_llm_provider is set
if custom_llm_provider:
if request is None:
raise ValueError("request is required for completion bridge")
return await _send_message_via_completion_bridge(
request=request,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
litellm_params=litellm_params,
)
# Standard A2A client flow
if request is None:
raise ValueError("request is required")
# Create A2A client if not provided but api_base is available
if a2a_client is None:
if api_base is None:
raise ValueError(
"Either a2a_client or api_base is required for standard A2A flow"
)
trace_id = trace_id or str(uuid.uuid4())
extra_headers: Dict[str, str] = {"X-LiteLLM-Trace-Id": trace_id}
if agent_id:
extra_headers["X-LiteLLM-Agent-Id"] = agent_id
# Overlay agent-level headers (agent headers take precedence over LiteLLM internal ones)
if agent_extra_headers:
extra_headers.update(agent_extra_headers)
a2a_client = await create_a2a_client(
base_url=api_base, extra_headers=extra_headers
)
# Type assertion: a2a_client is guaranteed to be non-None here
assert a2a_client is not None
agent_name = _get_a2a_model_info(a2a_client, kwargs)
verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")
# Get agent card URL for localhost retry logic
agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
a2a_client, "agent_card", None
)
card_url = getattr(agent_card, "url", None) if agent_card else None
context_id = trace_id or str(uuid.uuid4())
message = request.params.message
if isinstance(message, dict):
if message.get("context_id") is None:
message["context_id"] = context_id
else:
if getattr(message, "context_id", None) is None:
message.context_id = context_id
a2a_response = await _execute_a2a_send_with_retry(
a2a_client=a2a_client,
request=request,
agent_card=agent_card,
card_url=card_url,
api_base=api_base,
agent_name=agent_name,
)
verbose_logger.info(f"A2A send_message completed, request_id={request.id}")
# Wrap in LiteLLM response type for _hidden_params support
response = LiteLLMSendMessageResponse.from_a2a_response(a2a_response)
# Calculate token usage from request and response
response_dict = a2a_response.model_dump(mode="json", exclude_none=True)
(
prompt_tokens,
completion_tokens,
_,
) = A2ARequestUtils.calculate_usage_from_request_response(
request=request,
response_dict=response_dict,
)
# Set usage on logging obj for standard logging payload
_set_usage_on_logging_obj(
kwargs=kwargs,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
# Set agent_id on logging obj for SpendLogs tracking
_set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=agent_id)
return response
@client
def send_message(
a2a_client: "A2AClientType",
request: "SendMessageRequest",
**kwargs: Any,
) -> Union[LiteLLMSendMessageResponse, Coroutine[Any, Any, LiteLLMSendMessageResponse]]:
"""
Sync: Send a message to an A2A agent.
Uses the @client decorator for LiteLLM logging and tracking.
Args:
a2a_client: An initialized a2a.client.A2AClient instance
request: SendMessageRequest from a2a.types
**kwargs: Additional arguments passed to the client decorator
Returns:
LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params).
If called from inside a running event loop, returns a coroutine that must be awaited.
"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None:
return asend_message(a2a_client=a2a_client, request=request, **kwargs)
else:
return asyncio.run(
asend_message(a2a_client=a2a_client, request=request, **kwargs)
)
def _build_streaming_logging_obj(
request: "SendStreamingMessageRequest",
agent_name: str,
agent_id: Optional[str],
litellm_params: Optional[Dict[str, Any]],
metadata: Optional[Dict[str, Any]],
proxy_server_request: Optional[Dict[str, Any]],
) -> Logging:
"""Build logging object for streaming A2A requests."""
start_time = datetime.datetime.now()
model = f"a2a_agent/{agent_name}"
logging_obj = Logging(
model=model,
messages=[{"role": "user", "content": "streaming-request"}],
stream=False,
call_type="asend_message_streaming",
start_time=start_time,
litellm_call_id=str(request.id),
function_id=str(request.id),
)
logging_obj.model = model
logging_obj.custom_llm_provider = "a2a_agent"
logging_obj.model_call_details["model"] = model
logging_obj.model_call_details["custom_llm_provider"] = "a2a_agent"
if agent_id:
logging_obj.model_call_details["agent_id"] = agent_id
_litellm_params = litellm_params.copy() if litellm_params else {}
if metadata:
_litellm_params["metadata"] = metadata
if proxy_server_request:
_litellm_params["proxy_server_request"] = proxy_server_request
logging_obj.litellm_params = _litellm_params
logging_obj.optional_params = _litellm_params
logging_obj.model_call_details["litellm_params"] = _litellm_params
logging_obj.model_call_details["metadata"] = metadata or {}
return logging_obj
async def asend_message_streaming( # noqa: PLR0915
a2a_client: Optional["A2AClientType"] = None,
request: Optional["SendStreamingMessageRequest"] = None,
api_base: Optional[str] = None,
litellm_params: Optional[Dict[str, Any]] = None,
agent_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
proxy_server_request: Optional[Dict[str, Any]] = None,
agent_extra_headers: Optional[Dict[str, str]] = None,
) -> AsyncIterator[Any]:
"""
Async: Send a streaming message to an A2A agent.
If litellm_params contains custom_llm_provider, routes through the completion bridge.
Args:
a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
request: SendStreamingMessageRequest from a2a.types
api_base: API base URL (required for completion bridge)
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
agent_id: Optional agent ID for tracking in SpendLogs
metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
proxy_server_request: Optional proxy server request data
agent_extra_headers: Optional agent-level headers (take precedence over LiteLLM internal headers)
Yields:
SendStreamingMessageResponse chunks from the agent
Example (completion bridge with LangGraph):
```python
from litellm.a2a_protocol import asend_message_streaming
from a2a.types import SendStreamingMessageRequest, MessageSendParams
from uuid import uuid4
request = SendStreamingMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
async for chunk in asend_message_streaming(
request=request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
):
print(chunk)
```
"""
litellm_params = litellm_params or {}
custom_llm_provider = litellm_params.get("custom_llm_provider")
# Route through completion bridge if custom_llm_provider is set
if custom_llm_provider:
if request is None:
raise ValueError("request is required for completion bridge")
# api_base is optional for providers that derive endpoint from model (e.g., bedrock/agentcore)
verbose_logger.info(
f"A2A streaming using completion bridge: provider={custom_llm_provider}"
)
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
A2ACompletionBridgeHandler,
)
# Extract params from request
params = (
request.params.model_dump(mode="json")
if hasattr(request.params, "model_dump")
else dict(request.params)
)
async for chunk in A2ACompletionBridgeHandler.handle_streaming(
request_id=str(request.id),
params=params,
litellm_params=litellm_params,
api_base=api_base,
):
yield chunk
return
# Standard A2A client flow
if request is None:
raise ValueError("request is required")
# Create A2A client if not provided but api_base is available
if a2a_client is None:
if api_base is None:
raise ValueError(
"Either a2a_client or api_base is required for standard A2A flow"
)
# Mirror the non-streaming path: always include trace and agent-id headers
streaming_extra_headers: Dict[str, str] = {
"X-LiteLLM-Trace-Id": str(request.id),
}
if agent_id:
streaming_extra_headers["X-LiteLLM-Agent-Id"] = agent_id
if agent_extra_headers:
streaming_extra_headers.update(agent_extra_headers)
a2a_client = await create_a2a_client(
base_url=api_base, extra_headers=streaming_extra_headers
)
# Type assertion: a2a_client is guaranteed to be non-None here
assert a2a_client is not None
verbose_logger.info(f"A2A send_message_streaming request_id={request.id}")
# Build logging object for streaming completion callbacks
agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
a2a_client, "agent_card", None
)
card_url = getattr(agent_card, "url", None) if agent_card else None
agent_name = getattr(agent_card, "name", "unknown") if agent_card else "unknown"
logging_obj = _build_streaming_logging_obj(
request=request,
agent_name=agent_name,
agent_id=agent_id,
litellm_params=litellm_params,
metadata=metadata,
proxy_server_request=proxy_server_request,
)
# Retry loop: if connection fails due to localhost URL in agent card, retry with fixed URL
# Connection errors in streaming typically occur on first chunk iteration
first_chunk = True
for attempt in range(2): # max 2 attempts: original + 1 retry
stream = a2a_client.send_message_streaming(request)
iterator = A2AStreamingIterator(
stream=stream,
request=request,
logging_obj=logging_obj,
agent_name=agent_name,
)
try:
first_chunk = True
async for chunk in iterator:
if first_chunk:
first_chunk = False # connection succeeded
yield chunk
return # stream completed successfully
except A2ALocalhostURLError as e:
# Only retry on first chunk, not mid-stream
if first_chunk and attempt == 0:
a2a_client = handle_a2a_localhost_retry(
error=e,
agent_card=agent_card,
a2a_client=a2a_client,
is_streaming=True,
)
card_url = agent_card.url if agent_card else None
else:
raise
except Exception as e:
# Only map exception on first chunk
if first_chunk and attempt == 0:
try:
map_a2a_exception(e, card_url, api_base, model=agent_name)
except A2ALocalhostURLError as localhost_err:
# Localhost URL error - fix and retry
a2a_client = handle_a2a_localhost_retry(
error=localhost_err,
agent_card=agent_card,
a2a_client=a2a_client,
is_streaming=True,
)
card_url = agent_card.url if agent_card else None
continue
except Exception:
# Re-raise the mapped exception
raise
raise
async def create_a2a_client(
base_url: str,
timeout: float = 60.0,
extra_headers: Optional[Dict[str, str]] = None,
) -> "A2AClientType":
"""
Create an A2A client for the given agent URL.
This resolves the agent card and returns a ready-to-use A2A client.
The client can be reused for multiple requests.
Args:
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
timeout: Request timeout in seconds (default: 60.0)
extra_headers: Optional additional headers to include in requests
Returns:
An initialized a2a.client.A2AClient instance
Example:
```python
from litellm.a2a_protocol import create_a2a_client, asend_message
# Create client once
client = await create_a2a_client(base_url="http://localhost:10001")
# Reuse for multiple requests
response1 = await asend_message(a2a_client=client, request=request1)
response2 = await asend_message(a2a_client=client, request=request2)
```
"""
if not A2A_SDK_AVAILABLE:
raise ImportError(
"The 'a2a' package is required for A2A agent invocation. "
"Install it with: pip install a2a-sdk"
)
verbose_logger.info(f"Creating A2A client for {base_url}")
# Use get_async_httpx_client with per-agent params so that different agents
# (with different extra_headers) get separate cached clients. The params
# dict is hashed into the cache key, keeping agent auth isolated while
# still reusing connections within the same agent.
#
# Only pass params that AsyncHTTPHandler.__init__ accepts (e.g. timeout).
# Use "disable_aiohttp_transport" key for cache-key-only data (it's
# filtered out before reaching the constructor).
_client_params: dict = {"timeout": timeout}
if extra_headers:
# Encode headers into a cache-key-only param so each unique header
# set produces a distinct cache key.
_client_params["disable_aiohttp_transport"] = str(sorted(extra_headers.items()))
_async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.A2AProvider,
params=_client_params,
)
httpx_client = _async_handler.client
if extra_headers:
httpx_client.headers.update(extra_headers)
verbose_proxy_logger.debug(
f"A2A client created with extra_headers={list(extra_headers.keys())}"
)
# Resolve agent card
resolver = A2ACardResolver(
httpx_client=httpx_client,
base_url=base_url,
)
agent_card = await resolver.get_agent_card()
verbose_logger.debug(
f"Resolved agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
)
# Create A2A client
a2a_client = _A2AClient(
httpx_client=httpx_client,
agent_card=agent_card,
)
# Store agent_card on client for later retrieval (SDK doesn't expose it)
a2a_client._litellm_agent_card = agent_card # type: ignore[attr-defined]
verbose_logger.info(f"A2A client created for {base_url}")
return a2a_client
async def aget_agent_card(
base_url: str,
timeout: float = DEFAULT_A2A_AGENT_TIMEOUT,
extra_headers: Optional[Dict[str, str]] = None,
) -> "AgentCard":
"""
Fetch the agent card from an A2A agent.
Args:
base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
timeout: Request timeout in seconds (default: DEFAULT_A2A_AGENT_TIMEOUT)
extra_headers: Optional additional headers to include in requests
Returns:
AgentCard from the A2A agent
"""
if not A2A_SDK_AVAILABLE:
raise ImportError(
"The 'a2a' package is required for A2A agent invocation. "
"Install it with: pip install a2a-sdk"
)
verbose_logger.info(f"Fetching agent card from {base_url}")
# Use LiteLLM's cached httpx client
http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.A2A,
params={"timeout": timeout},
)
httpx_client = http_handler.client
resolver = A2ACardResolver(
httpx_client=httpx_client,
base_url=base_url,
)
agent_card = await resolver.get_agent_card()
verbose_logger.info(
f"Fetched agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
)
return agent_card


@@ -0,0 +1,10 @@
"""
A2A Protocol Providers.
This module contains provider-specific implementations for the A2A protocol.
"""
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
__all__ = ["BaseA2AProviderConfig", "A2AProviderConfigManager"]


@@ -0,0 +1,62 @@
"""
Base configuration for A2A protocol providers.
"""
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict
class BaseA2AProviderConfig(ABC):
"""
Base configuration class for A2A protocol providers.
Each provider should implement this interface to define how to handle
A2A requests for their specific agent type.
"""
@abstractmethod
async def handle_non_streaming(
self,
request_id: str,
params: Dict[str, Any],
api_base: str,
**kwargs,
) -> Dict[str, Any]:
"""
Handle non-streaming A2A request.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
api_base: Base URL of the agent
**kwargs: Additional provider-specific parameters
Returns:
A2A SendMessageResponse dict
"""
pass
@abstractmethod
async def handle_streaming(
self,
request_id: str,
params: Dict[str, Any],
api_base: str,
**kwargs,
) -> AsyncIterator[Dict[str, Any]]:
"""
Handle streaming A2A request.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
api_base: Base URL of the agent
**kwargs: Additional provider-specific parameters
Yields:
A2A streaming response events
"""
# This is an abstract method - subclasses must implement
# The yield is here to make this a generator function
if False: # pragma: no cover
yield {}


@@ -0,0 +1,47 @@
"""
A2A Provider Config Manager.
Manages provider-specific configurations for A2A protocol.
"""
from typing import Optional
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
class A2AProviderConfigManager:
"""
Manager for A2A provider configurations.
Similar to ProviderConfigManager in litellm.utils but specifically for A2A providers.
"""
@staticmethod
def get_provider_config(
custom_llm_provider: Optional[str],
) -> Optional[BaseA2AProviderConfig]:
"""
Get the provider configuration for a given custom_llm_provider.
Args:
custom_llm_provider: The provider identifier (e.g., "pydantic_ai_agents")
Returns:
Provider configuration instance or None if not found
"""
if custom_llm_provider is None:
return None
if custom_llm_provider == "pydantic_ai_agents":
from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
PydanticAIProviderConfig,
)
return PydanticAIProviderConfig()
# Add more providers here as needed
# elif custom_llm_provider == "another_provider":
# from litellm.a2a_protocol.providers.another_provider.config import AnotherProviderConfig
# return AnotherProviderConfig()
return None


@@ -0,0 +1,74 @@
# A2A to LiteLLM Completion Bridge
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
## Flow
```
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
```
## SDK Usage
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
```python
from litellm.a2a_protocol import asend_message, asend_message_streaming
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
from uuid import uuid4
# Non-streaming
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
response = await asend_message(
request=request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
)
# Streaming
stream_request = SendStreamingMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
async for chunk in asend_message_streaming(
request=stream_request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
):
print(chunk)
```
## Proxy Usage
Configure an agent with `custom_llm_provider` in `litellm_params`:
```yaml
agents:
- agent_name: my-langgraph-agent
agent_card_params:
name: "LangGraph Agent"
url: "http://localhost:2024" # Used as api_base
litellm_params:
custom_llm_provider: langgraph
model: agent
```
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
1. Detects `custom_llm_provider` in agent's `litellm_params`
2. Transforms A2A message → OpenAI messages
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
4. Transforms response → A2A format
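For streaming (`message/stream`), the bridge emits A2A events in the fixed order documented in `A2ACompletionBridgeHandler.handle_streaming`. A minimal sketch of that sequence — event kinds and status states only; the real payloads (built by `A2ACompletionBridgeTransformation`) also carry `taskId`, `contextId`, and timestamps:

```python
# Event kinds emitted, in order, for one streamed A2A response:
bridge_stream_events = [
    {"kind": "task", "status": {"state": "submitted"}},                        # initial task
    {"kind": "status-update", "final": False, "status": {"state": "working"}}, # processing
    {"kind": "artifact-update"},                                               # accumulated content
    {"kind": "status-update", "final": True, "status": {"state": "completed"}},
]
print([e["kind"] for e in bridge_stream_events])
```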
## Classes
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)
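The core message conversion is simple. Below is a hypothetical standalone helper (`a2a_to_openai` is not part of the library) mirroring the behavior of `A2ACompletionBridgeTransformation.a2a_message_to_openai_messages`: text parts are joined with newlines, non-text parts are dropped, and the role is carried over:

```python
from typing import Any, Dict, List

def a2a_to_openai(a2a_message: Dict[str, Any]) -> List[Dict[str, str]]:
    # Keep the A2A role; join the text of all "text" parts, skipping other kinds.
    role = a2a_message.get("role", "user")
    texts = [
        p.get("text", "")
        for p in a2a_message.get("parts", [])
        if p.get("kind") == "text"
    ]
    return [{"role": role, "content": "\n".join(texts)}]

msg = {"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": "abc123"}
print(a2a_to_openai(msg))
# → [{'role': 'user', 'content': 'Hello!'}]
```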


@@ -0,0 +1,5 @@
"""
LiteLLM Completion bridge provider for A2A protocol.
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
"""


@@ -0,0 +1,301 @@
"""
Handler for A2A to LiteLLM completion bridge.
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
A2A Streaming Events (in order):
1. Task event (kind: "task") - Initial task creation with status "submitted"
2. Status update (kind: "status-update") - Status change to "working"
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
4. Status update (kind: "status-update") - Final status "completed" with final=true
"""
from typing import Any, AsyncIterator, Dict, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.a2a_protocol.litellm_completion_bridge.pydantic_ai_transformation import (
PydanticAITransformation,
)
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
A2ACompletionBridgeTransformation,
A2AStreamingContext,
)
class A2ACompletionBridgeHandler:
"""
Static methods for handling A2A requests via LiteLLM completion.
"""
@staticmethod
async def handle_non_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> Dict[str, Any]:
"""
Handle non-streaming A2A request via litellm.acompletion.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
api_base: API base URL from agent_card_params
Returns:
A2A SendMessageResponse dict
"""
# Check if this is a Pydantic AI agent request
custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider == "pydantic_ai_agents":
if api_base is None:
raise ValueError("api_base is required for Pydantic AI agents")
verbose_logger.info(
f"Pydantic AI: Routing to Pydantic AI agent at {api_base}"
)
# Send request directly to Pydantic AI agent
response_data = await PydanticAITransformation.send_non_streaming_request(
api_base=api_base,
request_id=request_id,
params=params,
)
return response_data
# Extract message from params
message = params.get("message", {})
# Transform A2A message to OpenAI format
openai_messages = (
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
)
# Get completion params (custom_llm_provider was already read above)
model = litellm_params.get("model", "agent")
# Build full model string if provider specified
# Skip prepending if model already starts with the provider prefix
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
full_model = f"{custom_llm_provider}/{model}"
else:
full_model = model
verbose_logger.info(
f"A2A completion bridge: model={full_model}, api_base={api_base}"
)
# Build completion params dict
completion_params = {
"model": full_model,
"messages": openai_messages,
"api_base": api_base,
"stream": False,
}
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
litellm_params_to_add = {
k: v
for k, v in litellm_params.items()
if k not in ("model", "custom_llm_provider")
}
completion_params.update(litellm_params_to_add)
# Call litellm.acompletion
response = await litellm.acompletion(**completion_params)
# Transform response to A2A format
a2a_response = (
A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
response=response,
request_id=request_id,
)
)
verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")
return a2a_response
@staticmethod
async def handle_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""
Handle streaming A2A request via litellm.acompletion with stream=True.
Emits proper A2A streaming events:
1. Task event (kind: "task") - Initial task with status "submitted"
2. Status update (kind: "status-update") - Status "working"
3. Artifact update (kind: "artifact-update") - Content delivery
4. Status update (kind: "status-update") - Final "completed" status
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
api_base: API base URL from agent_card_params
Yields:
A2A streaming response events
"""
# Check if this is a Pydantic AI agent request
custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider == "pydantic_ai_agents":
if api_base is None:
raise ValueError("api_base is required for Pydantic AI agents")
verbose_logger.info(
f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
)
# Get non-streaming response first
response_data = await PydanticAITransformation.send_non_streaming_request(
api_base=api_base,
request_id=request_id,
params=params,
)
# Convert to fake streaming
async for chunk in PydanticAITransformation.fake_streaming_from_response(
response_data=response_data,
request_id=request_id,
):
yield chunk
return
# Extract message from params
message = params.get("message", {})
# Create streaming context
ctx = A2AStreamingContext(
request_id=request_id,
input_message=message,
)
# Transform A2A message to OpenAI format
openai_messages = (
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
)
# Get completion params (custom_llm_provider was already read above)
model = litellm_params.get("model", "agent")
# Build full model string if provider specified
# Skip prepending if model already starts with the provider prefix
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
full_model = f"{custom_llm_provider}/{model}"
else:
full_model = model
verbose_logger.info(
f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
)
# Build completion params dict
completion_params = {
"model": full_model,
"messages": openai_messages,
"api_base": api_base,
"stream": True,
}
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
litellm_params_to_add = {
k: v
for k, v in litellm_params.items()
if k not in ("model", "custom_llm_provider")
}
completion_params.update(litellm_params_to_add)
# 1. Emit initial task event (kind: "task", status: "submitted")
task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
yield task_event
# 2. Emit status update (kind: "status-update", status: "working")
working_event = A2ACompletionBridgeTransformation.create_status_update_event(
ctx=ctx,
state="working",
final=False,
message_text="Processing request...",
)
yield working_event
# Call litellm.acompletion with streaming
response = await litellm.acompletion(**completion_params)
# 3. Accumulate content and emit artifact update
accumulated_text = ""
chunk_count = 0
async for chunk in response: # type: ignore[union-attr]
chunk_count += 1
# Extract delta content
content = ""
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta") and choice.delta:
content = choice.delta.content or ""
if content:
accumulated_text += content
# Emit artifact update with accumulated content
if accumulated_text:
artifact_event = (
A2ACompletionBridgeTransformation.create_artifact_update_event(
ctx=ctx,
text=accumulated_text,
)
)
yield artifact_event
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
ctx=ctx,
state="completed",
final=True,
)
yield completed_event
verbose_logger.info(
f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
)
# Convenience functions that delegate to the class methods
async def handle_a2a_completion(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> Dict[str, Any]:
"""Convenience function for non-streaming A2A completion."""
return await A2ACompletionBridgeHandler.handle_non_streaming(
request_id=request_id,
params=params,
litellm_params=litellm_params,
api_base=api_base,
)
async def handle_a2a_completion_streaming(
request_id: str,
params: Dict[str, Any],
litellm_params: Dict[str, Any],
api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Convenience function for streaming A2A completion."""
async for chunk in A2ACompletionBridgeHandler.handle_streaming(
request_id=request_id,
params=params,
litellm_params=litellm_params,
api_base=api_base,
):
yield chunk


@@ -0,0 +1,284 @@
"""
Transformation utilities for A2A <-> OpenAI message format conversion.
A2A Message Format:
{
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": "abc123"
}
OpenAI Message Format:
{"role": "user", "content": "Hello!"}
A2A Streaming Events:
- Task event (kind: "task") - Initial task creation with status "submitted"
- Status update (kind: "status-update") - Status changes (working, completed)
- Artifact update (kind: "artifact-update") - Content/artifact delivery
"""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from uuid import uuid4
from litellm._logging import verbose_logger
class A2AStreamingContext:
"""
Context holder for A2A streaming state.
Tracks task_id, context_id, and message accumulation.
"""
def __init__(self, request_id: str, input_message: Dict[str, Any]):
self.request_id = request_id
self.task_id = str(uuid4())
self.context_id = str(uuid4())
self.input_message = input_message
self.accumulated_text = ""
self.has_emitted_task = False
self.has_emitted_working = False
class A2ACompletionBridgeTransformation:
"""
Static methods for transforming between A2A and OpenAI message formats.
"""
@staticmethod
def a2a_message_to_openai_messages(
a2a_message: Dict[str, Any],
) -> List[Dict[str, str]]:
"""
Transform an A2A message to OpenAI message format.
Args:
a2a_message: A2A message with role, parts, and messageId
Returns:
List of OpenAI-format messages
"""
role = a2a_message.get("role", "user")
parts = a2a_message.get("parts", [])
# Map A2A roles to OpenAI roles; A2A uses "agent" for the assistant side
openai_role = "assistant" if role == "agent" else role
# Extract text content from parts
content_parts = []
for part in parts:
kind = part.get("kind", "")
if kind == "text":
text = part.get("text", "")
content_parts.append(text)
content = "\n".join(content_parts) if content_parts else ""
verbose_logger.debug(
f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
)
return [{"role": openai_role, "content": content}]
@staticmethod
def openai_response_to_a2a_response(
response: Any,
request_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.
Args:
response: LiteLLM ModelResponse object
request_id: Original A2A request ID
Returns:
A2A SendMessageResponse dict
"""
# Extract content from response
content = ""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and choice.message:
content = choice.message.content or ""
# Build A2A message
a2a_message = {
"role": "agent",
"parts": [{"kind": "text", "text": content}],
"messageId": uuid4().hex,
}
# Build A2A response
a2a_response = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": a2a_message,
},
}
verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")
return a2a_response
@staticmethod
def _get_timestamp() -> str:
"""Get current timestamp in ISO format with timezone."""
return datetime.now(timezone.utc).isoformat()
@staticmethod
def create_task_event(
ctx: A2AStreamingContext,
) -> Dict[str, Any]:
"""
Create the initial task event with status 'submitted'.
This is the first event emitted in an A2A streaming response.
"""
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"contextId": ctx.context_id,
"history": [
{
"contextId": ctx.context_id,
"kind": "message",
"messageId": ctx.input_message.get("messageId", uuid4().hex),
"parts": ctx.input_message.get("parts", []),
"role": ctx.input_message.get("role", "user"),
"taskId": ctx.task_id,
}
],
"id": ctx.task_id,
"kind": "task",
"status": {
"state": "submitted",
},
},
}
@staticmethod
def create_status_update_event(
ctx: A2AStreamingContext,
state: str,
final: bool = False,
message_text: Optional[str] = None,
) -> Dict[str, Any]:
"""
Create a status update event.
Args:
ctx: Streaming context
state: Status state ('working', 'completed')
final: Whether this is the final event
message_text: Optional message text for 'working' status
"""
status: Dict[str, Any] = {
"state": state,
"timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
}
# Add message for 'working' status
if state == "working" and message_text:
status["message"] = {
"contextId": ctx.context_id,
"kind": "message",
"messageId": str(uuid4()),
"parts": [{"kind": "text", "text": message_text}],
"role": "agent",
"taskId": ctx.task_id,
}
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"contextId": ctx.context_id,
"final": final,
"kind": "status-update",
"status": status,
"taskId": ctx.task_id,
},
}
@staticmethod
def create_artifact_update_event(
ctx: A2AStreamingContext,
text: str,
) -> Dict[str, Any]:
"""
Create an artifact update event with content.
Args:
ctx: Streaming context
text: The text content for the artifact
"""
return {
"id": ctx.request_id,
"jsonrpc": "2.0",
"result": {
"artifact": {
"artifactId": str(uuid4()),
"name": "response",
"parts": [{"kind": "text", "text": text}],
},
"contextId": ctx.context_id,
"kind": "artifact-update",
"taskId": ctx.task_id,
},
}
@staticmethod
def openai_chunk_to_a2a_chunk(
chunk: Any,
request_id: Optional[str] = None,
is_final: bool = False,
) -> Optional[Dict[str, Any]]:
"""
Transform a LiteLLM streaming chunk to A2A streaming format.
NOTE: This method is deprecated for streaming. Use the event-based
methods (create_task_event, create_status_update_event,
create_artifact_update_event) instead for proper A2A streaming.
Args:
chunk: LiteLLM ModelResponse chunk
request_id: Original A2A request ID
is_final: Whether this is the final chunk
Returns:
A2A streaming chunk dict or None if no content
"""
# Extract delta content
content = ""
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta") and choice.delta:
content = choice.delta.content or ""
if not content and not is_final:
return None
# Build A2A streaming chunk (legacy format)
a2a_chunk = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": {
"role": "agent",
"parts": [{"kind": "text", "text": content}],
"messageId": uuid4().hex,
},
"final": is_final,
},
}
return a2a_chunk
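
The event builders above always produce the same four-step envelope for one streamed response: task (submitted), status-update (working), artifact-update(s), then a final status-update (completed). A self-contained sketch of that sequence using plain dicts; the `sketch_event_sequence` helper is illustrative, not part of litellm:

```python
from uuid import uuid4

def sketch_event_sequence(request_id: str, text: str) -> list:
    # Mirrors the shapes built by create_task_event,
    # create_status_update_event, and create_artifact_update_event.
    task_id, context_id = uuid4().hex, uuid4().hex
    return [
        {"id": request_id, "jsonrpc": "2.0",
         "result": {"id": task_id, "kind": "task", "contextId": context_id,
                    "status": {"state": "submitted"}}},
        {"id": request_id, "jsonrpc": "2.0",
         "result": {"kind": "status-update", "final": False,
                    "contextId": context_id, "taskId": task_id,
                    "status": {"state": "working"}}},
        {"id": request_id, "jsonrpc": "2.0",
         "result": {"kind": "artifact-update",
                    "contextId": context_id, "taskId": task_id,
                    "artifact": {"artifactId": uuid4().hex,
                                 "parts": [{"kind": "text", "text": text}]}}},
        {"id": request_id, "jsonrpc": "2.0",
         "result": {"kind": "status-update", "final": True,
                    "contextId": context_id, "taskId": task_id,
                    "status": {"state": "completed"}}},
    ]

events = sketch_event_sequence("req-1", "hello")
kinds = [e["result"]["kind"] for e in events]
```

Only the last status-update carries `"final": true`; consumers use it to close the stream.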


@@ -0,0 +1,16 @@
"""
Pydantic AI agent provider for A2A protocol.
Pydantic AI agents follow A2A protocol but don't support streaming natively.
This provider handles fake streaming by converting non-streaming responses into streaming chunks.
"""
from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
PydanticAIProviderConfig,
)
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
PydanticAITransformation,
)
__all__ = ["PydanticAIHandler", "PydanticAITransformation", "PydanticAIProviderConfig"]


@@ -0,0 +1,50 @@
"""
Pydantic AI provider configuration.
"""
from typing import Any, AsyncIterator, Dict
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
class PydanticAIProviderConfig(BaseA2AProviderConfig):
"""
Provider configuration for Pydantic AI agents.
Pydantic AI agents follow A2A protocol but don't support streaming natively.
This config provides fake streaming by converting non-streaming responses into streaming chunks.
"""
async def handle_non_streaming(
self,
request_id: str,
params: Dict[str, Any],
api_base: str,
**kwargs,
) -> Dict[str, Any]:
"""Handle non-streaming request to Pydantic AI agent."""
return await PydanticAIHandler.handle_non_streaming(
request_id=request_id,
params=params,
api_base=api_base,
timeout=kwargs.get("timeout", 60.0),
)
async def handle_streaming(
self,
request_id: str,
params: Dict[str, Any],
api_base: str,
**kwargs,
) -> AsyncIterator[Dict[str, Any]]:
"""Handle streaming request with fake streaming."""
async for chunk in PydanticAIHandler.handle_streaming(
request_id=request_id,
params=params,
api_base=api_base,
timeout=kwargs.get("timeout", 60.0),
chunk_size=kwargs.get("chunk_size", 50),
delay_ms=kwargs.get("delay_ms", 10),
):
yield chunk


@@ -0,0 +1,102 @@
"""
Handler for Pydantic AI agents.
Pydantic AI agents follow A2A protocol but don't support streaming natively.
This handler provides fake streaming by converting non-streaming responses into streaming chunks.
"""
from typing import Any, AsyncIterator, Dict
from litellm._logging import verbose_logger
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
PydanticAITransformation,
)
class PydanticAIHandler:
"""
Handler for Pydantic AI agent requests.
Provides:
- Direct non-streaming requests to Pydantic AI agents
- Fake streaming by converting non-streaming responses into streaming chunks
"""
@staticmethod
async def handle_non_streaming(
request_id: str,
params: Dict[str, Any],
api_base: str,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Handle non-streaming request to Pydantic AI agent.
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
api_base: Base URL of the Pydantic AI agent
timeout: Request timeout in seconds
Returns:
A2A SendMessageResponse dict
"""
verbose_logger.info(f"Pydantic AI: Routing to Pydantic AI agent at {api_base}")
# Send request directly to Pydantic AI agent
response_data = await PydanticAITransformation.send_non_streaming_request(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
return response_data
@staticmethod
async def handle_streaming(
request_id: str,
params: Dict[str, Any],
api_base: str,
timeout: float = 60.0,
chunk_size: int = 50,
delay_ms: int = 10,
) -> AsyncIterator[Dict[str, Any]]:
"""
Handle streaming request to Pydantic AI agent with fake streaming.
Since Pydantic AI agents don't support streaming natively, this method:
1. Makes a non-streaming request
2. Converts the response into streaming chunks
Args:
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
api_base: Base URL of the Pydantic AI agent
timeout: Request timeout in seconds
chunk_size: Number of characters per chunk
delay_ms: Delay between chunks in milliseconds
Yields:
A2A streaming response events
"""
verbose_logger.info(
f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
)
# Get raw task response first (not the transformed A2A format)
raw_response = await PydanticAITransformation.send_and_get_raw_response(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
# Convert raw task response to fake streaming chunks
async for chunk in PydanticAITransformation.fake_streaming_from_response(
response_data=raw_response,
request_id=request_id,
chunk_size=chunk_size,
delay_ms=delay_ms,
):
yield chunk
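
The two-phase flow above (one upstream non-streaming request, then locally generated chunks) can be sketched with the network call stubbed out; `fetch_completed_task` is a hypothetical stand-in for `PydanticAITransformation.send_and_get_raw_response`:

```python
import asyncio

async def fetch_completed_task() -> dict:
    # Stub for the real non-streaming A2A round trip.
    return {
        "result": {
            "status": {"state": "completed"},
            "artifacts": [
                {"parts": [{"kind": "text", "text": "Hello from the agent"}]}
            ],
        }
    }

async def fake_stream(chunk_size: int = 5):
    task = await fetch_completed_task()  # 1. single non-streaming request
    text = task["result"]["artifacts"][0]["parts"][0]["text"]
    for i in range(0, len(text), chunk_size):  # 2. re-emit the text in chunks
        yield {"kind": "artifact-update", "text": text[i : i + chunk_size]}

async def collect() -> str:
    return "".join([c["text"] async for c in fake_stream()])

reassembled = asyncio.run(collect())
```

Reassembling the chunks recovers the original response text exactly, which is the invariant fake streaming must preserve.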


@@ -0,0 +1,530 @@
"""
Transformation layer for Pydantic AI agents.
Pydantic AI agents follow A2A protocol but don't support streaming.
This module provides fake streaming by converting non-streaming responses into streaming chunks.
"""
import asyncio
from typing import Any, AsyncIterator, Dict, cast
from uuid import uuid4
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
class PydanticAITransformation:
"""
Transformation layer for Pydantic AI agents.
Handles:
- Direct A2A requests to Pydantic AI endpoints
- Polling for task completion (since Pydantic AI doesn't support streaming)
- Fake streaming by chunking non-streaming responses
"""
@staticmethod
def _remove_none_values(obj: Any) -> Any:
"""
Recursively remove None values from a dict/list structure.
FastA2A/Pydantic AI servers don't accept None values for optional fields -
they expect those fields to be omitted entirely.
Args:
obj: Dict, list, or other value to clean
Returns:
Cleaned object with None values removed
"""
if isinstance(obj, dict):
return {
k: PydanticAITransformation._remove_none_values(v)
for k, v in obj.items()
if v is not None
}
elif isinstance(obj, list):
return [
PydanticAITransformation._remove_none_values(item)
for item in obj
if item is not None
]
else:
return obj
@staticmethod
def _params_to_dict(params: Any) -> Dict[str, Any]:
"""
Convert params to a dict, handling Pydantic models.
Args:
params: Dict or Pydantic model
Returns:
Dict representation of params
"""
if hasattr(params, "model_dump"):
# Pydantic v2 model
return params.model_dump(mode="python", exclude_none=True)
elif hasattr(params, "dict"):
# Pydantic v1 model
return params.dict(exclude_none=True)
elif isinstance(params, dict):
return params
else:
# Try to convert to dict
return dict(params)
@staticmethod
async def _poll_for_completion(
client: AsyncHTTPHandler,
endpoint: str,
task_id: str,
request_id: str,
max_attempts: int = 30,
poll_interval: float = 0.5,
) -> Dict[str, Any]:
"""
Poll for task completion using tasks/get method.
Args:
client: HTTPX async client
endpoint: API endpoint URL
task_id: Task ID to poll for
request_id: JSON-RPC request ID
max_attempts: Maximum polling attempts
poll_interval: Seconds between poll attempts
Returns:
Completed task response
"""
for attempt in range(max_attempts):
poll_request = {
"jsonrpc": "2.0",
"id": f"{request_id}-poll-{attempt}",
"method": "tasks/get",
"params": {"id": task_id},
}
response = await client.post(
endpoint,
json=poll_request,
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
poll_data = response.json()
result = poll_data.get("result", {})
status = result.get("status", {})
state = status.get("state", "")
verbose_logger.debug(
f"Pydantic AI: Poll attempt {attempt + 1}/{max_attempts}, state={state}"
)
if state == "completed":
return poll_data
elif state in ("failed", "canceled"):
raise Exception(f"Task {task_id} ended with state: {state}")
await asyncio.sleep(poll_interval)
raise TimeoutError(
f"Task {task_id} did not complete within {max_attempts * poll_interval} seconds"
)
@staticmethod
async def _send_and_poll_raw(
api_base: str,
request_id: str,
params: Any,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Send a request to Pydantic AI agent and return the raw task response.
This is an internal method used by both non-streaming and streaming handlers.
Returns the raw Pydantic AI task format with history/artifacts.
Args:
api_base: Base URL of the Pydantic AI agent
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
timeout: Request timeout in seconds
Returns:
Raw Pydantic AI task response (with history/artifacts)
"""
# Convert params to dict if it's a Pydantic model
params_dict = PydanticAITransformation._params_to_dict(params)
# Remove None values - FastA2A doesn't accept null for optional fields
params_dict = PydanticAITransformation._remove_none_values(params_dict)
# Ensure the message has 'kind': 'message' as required by FastA2A/Pydantic AI
if "message" in params_dict:
params_dict["message"]["kind"] = "message"
# Build A2A JSON-RPC request using message/send method for FastA2A compatibility
a2a_request = {
"jsonrpc": "2.0",
"id": request_id,
"method": "message/send",
"params": params_dict,
}
# FastA2A uses root endpoint (/) not /messages
endpoint = api_base.rstrip("/")
verbose_logger.info(f"Pydantic AI: Sending non-streaming request to {endpoint}")
# Send request to Pydantic AI agent using shared async HTTP client
        client = get_async_httpx_client(
            # cast: "pydantic_ai_agent" is not a typed LlmProviders member
            llm_provider=cast(Any, "pydantic_ai_agent"),
            params={"timeout": timeout},
        )
response = await client.post(
endpoint,
json=a2a_request,
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
response_data = response.json()
# Check if task is already completed
result = response_data.get("result", {})
status = result.get("status", {})
state = status.get("state", "")
if state != "completed":
# Need to poll for completion
task_id = result.get("id")
if task_id:
verbose_logger.info(
f"Pydantic AI: Task {task_id} submitted, polling for completion..."
)
response_data = await PydanticAITransformation._poll_for_completion(
client=client,
endpoint=endpoint,
task_id=task_id,
request_id=request_id,
)
verbose_logger.info(
f"Pydantic AI: Received completed response for request_id={request_id}"
)
return response_data
@staticmethod
async def send_non_streaming_request(
api_base: str,
request_id: str,
params: Any,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Send a non-streaming A2A request to Pydantic AI agent and wait for completion.
Args:
api_base: Base URL of the Pydantic AI agent (e.g., "http://localhost:9999")
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message (dict or Pydantic model)
timeout: Request timeout in seconds
Returns:
Standard A2A non-streaming response format with message
"""
# Get raw task response
raw_response = await PydanticAITransformation._send_and_poll_raw(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
# Transform to standard A2A non-streaming format
return PydanticAITransformation._transform_to_a2a_response(
response_data=raw_response,
request_id=request_id,
)
@staticmethod
async def send_and_get_raw_response(
api_base: str,
request_id: str,
params: Any,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Send a request to Pydantic AI agent and return the raw task response.
Used by streaming handler to get raw response for fake streaming.
Args:
api_base: Base URL of the Pydantic AI agent
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
timeout: Request timeout in seconds
Returns:
Raw Pydantic AI task response (with history/artifacts)
"""
return await PydanticAITransformation._send_and_poll_raw(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
@staticmethod
def _transform_to_a2a_response(
response_data: Dict[str, Any],
request_id: str,
) -> Dict[str, Any]:
"""
Transform Pydantic AI task response to standard A2A non-streaming format.
Pydantic AI returns a task with history/artifacts, but the standard A2A
non-streaming format expects:
{
"jsonrpc": "2.0",
"id": "...",
"result": {
"message": {
"role": "agent",
"parts": [{"kind": "text", "text": "..."}],
"messageId": "..."
}
}
}
Args:
response_data: Pydantic AI task response
request_id: Original request ID
Returns:
Standard A2A non-streaming response format
"""
# Extract the agent response text
full_text, message_id, parts = PydanticAITransformation._extract_response_text(
response_data
)
# Build standard A2A message
a2a_message = {
"role": "agent",
"parts": parts if parts else [{"kind": "text", "text": full_text}],
"messageId": message_id,
}
# Return standard A2A non-streaming format
return {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": a2a_message,
},
}
@staticmethod
def _extract_response_text(response_data: Dict[str, Any]) -> tuple[str, str, list]:
"""
Extract response text from completed task response.
Pydantic AI returns completed tasks with:
- history: list of messages (user and agent)
- artifacts: list of result artifacts
Args:
response_data: Completed task response
Returns:
Tuple of (full_text, message_id, parts)
"""
result = response_data.get("result", {})
# Try to extract from artifacts first (preferred for results)
artifacts = result.get("artifacts", [])
if artifacts:
for artifact in artifacts:
parts = artifact.get("parts", [])
for part in parts:
if part.get("kind") == "text":
text = part.get("text", "")
if text:
return text, str(uuid4()), parts
# Fall back to history - get the last agent message
history = result.get("history", [])
for msg in reversed(history):
if msg.get("role") == "agent":
parts = msg.get("parts", [])
message_id = msg.get("messageId", str(uuid4()))
full_text = ""
for part in parts:
if part.get("kind") == "text":
full_text += part.get("text", "")
if full_text:
return full_text, message_id, parts
# Fall back to message field (original format)
message = result.get("message", {})
if message:
parts = message.get("parts", [])
message_id = message.get("messageId", str(uuid4()))
full_text = ""
for part in parts:
if part.get("kind") == "text":
full_text += part.get("text", "")
return full_text, message_id, parts
return "", str(uuid4()), []
@staticmethod
async def fake_streaming_from_response(
response_data: Dict[str, Any],
request_id: str,
chunk_size: int = 50,
delay_ms: int = 10,
) -> AsyncIterator[Dict[str, Any]]:
"""
Convert a non-streaming A2A response into fake streaming chunks.
Emits proper A2A streaming events:
1. Task event (kind: "task") - Initial task with status "submitted"
2. Status update (kind: "status-update") - Status "working"
3. Artifact update chunks (kind: "artifact-update") - Content delivery in chunks
4. Status update (kind: "status-update") - Final "completed" status
Args:
response_data: Non-streaming A2A response dict (completed task)
request_id: A2A JSON-RPC request ID
chunk_size: Number of characters per chunk (default: 50)
delay_ms: Delay between chunks in milliseconds (default: 10)
Yields:
A2A streaming response events
"""
# Extract the response text from completed task
full_text, message_id, parts = PydanticAITransformation._extract_response_text(
response_data
)
# Extract input message from raw response for history
result = response_data.get("result", {})
history = result.get("history", [])
input_message = {}
for msg in history:
if msg.get("role") == "user":
input_message = msg
break
# Generate IDs for streaming events
task_id = str(uuid4())
context_id = str(uuid4())
artifact_id = str(uuid4())
input_message_id = input_message.get("messageId", str(uuid4()))
# 1. Emit initial task event (kind: "task", status: "submitted")
# Format matches A2ACompletionBridgeTransformation.create_task_event
task_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"history": [
{
"contextId": context_id,
"kind": "message",
"messageId": input_message_id,
"parts": input_message.get(
"parts", [{"kind": "text", "text": ""}]
),
"role": "user",
"taskId": task_id,
}
],
"id": task_id,
"kind": "task",
"status": {
"state": "submitted",
},
},
}
yield task_event
# 2. Emit status update (kind: "status-update", status: "working")
# Format matches A2ACompletionBridgeTransformation.create_status_update_event
working_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"final": False,
"kind": "status-update",
"status": {
"state": "working",
},
"taskId": task_id,
},
}
yield working_event
# Small delay to simulate processing
await asyncio.sleep(delay_ms / 1000.0)
# 3. Emit artifact update chunks (kind: "artifact-update")
# Format matches A2ACompletionBridgeTransformation.create_artifact_update_event
if full_text:
# Split text into chunks
for i in range(0, len(full_text), chunk_size):
chunk_text = full_text[i : i + chunk_size]
is_last_chunk = (i + chunk_size) >= len(full_text)
artifact_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"kind": "artifact-update",
"taskId": task_id,
"artifact": {
"artifactId": artifact_id,
"parts": [
{
"kind": "text",
"text": chunk_text,
}
],
},
},
}
yield artifact_event
# Add delay between chunks (except for last chunk)
if not is_last_chunk:
await asyncio.sleep(delay_ms / 1000.0)
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
completed_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"final": True,
"kind": "status-update",
"status": {
"state": "completed",
},
"taskId": task_id,
},
}
yield completed_event
verbose_logger.info(
f"Pydantic AI: Fake streaming completed for request_id={request_id}"
)
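
The loop in `_poll_for_completion` follows the standard poll-until-terminal pattern. A runnable sketch with the `tasks/get` JSON-RPC round trip stubbed as a hypothetical `get_task_state`:

```python
import asyncio

# Simulated task lifecycle returned by successive polls.
_states = iter(["submitted", "working", "working", "completed"])

async def get_task_state(task_id: str) -> str:
    # Stub for the tasks/get round trip.
    return next(_states)

async def poll_until_done(task_id: str, max_attempts: int = 30,
                          poll_interval: float = 0.0) -> str:
    for _attempt in range(max_attempts):
        state = await get_task_state(task_id)
        if state == "completed":
            return state
        if state in ("failed", "canceled"):
            raise RuntimeError(f"task {task_id} ended with state: {state}")
        await asyncio.sleep(poll_interval)
    raise TimeoutError(f"task {task_id} did not complete in {max_attempts} attempts")

final_state = asyncio.run(poll_until_done("task-1"))
```

Terminal failure states raise immediately rather than burning the remaining attempts, matching the handling above.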


@@ -0,0 +1,184 @@
"""
A2A Streaming Iterator with token tracking and logging support.
"""
import asyncio
from datetime import datetime
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.a2a_protocol.cost_calculator import A2ACostCalculator
from litellm.a2a_protocol.utils import A2ARequestUtils
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.thread_pool_executor import executor
if TYPE_CHECKING:
from a2a.types import SendStreamingMessageRequest, SendStreamingMessageResponse
class A2AStreamingIterator:
"""
Async iterator for A2A streaming responses with token tracking.
Collects chunks, extracts text, and logs usage on completion.
"""
def __init__(
self,
stream: AsyncIterator["SendStreamingMessageResponse"],
request: "SendStreamingMessageRequest",
logging_obj: LiteLLMLoggingObj,
agent_name: str = "unknown",
):
self.stream = stream
self.request = request
self.logging_obj = logging_obj
self.agent_name = agent_name
self.start_time = datetime.now()
# Collect chunks for token counting
self.chunks: List[Any] = []
self.collected_text_parts: List[str] = []
self.final_chunk: Optional[Any] = None
def __aiter__(self):
return self
async def __anext__(self) -> "SendStreamingMessageResponse":
try:
chunk = await self.stream.__anext__()
# Store chunk
self.chunks.append(chunk)
# Extract text from chunk for token counting
self._collect_text_from_chunk(chunk)
# Check if this is the final chunk (completed status)
if self._is_completed_chunk(chunk):
self.final_chunk = chunk
return chunk
except StopAsyncIteration:
# Stream ended - handle logging
if self.final_chunk is None and self.chunks:
self.final_chunk = self.chunks[-1]
await self._handle_stream_complete()
raise
def _collect_text_from_chunk(self, chunk: Any) -> None:
"""Extract text from a streaming chunk and add to collected parts."""
try:
chunk_dict = (
chunk.model_dump(mode="json", exclude_none=True)
if hasattr(chunk, "model_dump")
else {}
)
text = A2ARequestUtils.extract_text_from_response(chunk_dict)
if text:
self.collected_text_parts.append(text)
except Exception:
verbose_logger.debug("Failed to extract text from A2A streaming chunk")
def _is_completed_chunk(self, chunk: Any) -> bool:
"""Check if chunk indicates stream completion."""
try:
chunk_dict = (
chunk.model_dump(mode="json", exclude_none=True)
if hasattr(chunk, "model_dump")
else {}
)
result = chunk_dict.get("result", {})
if isinstance(result, dict):
status = result.get("status", {})
if isinstance(status, dict):
return status.get("state") == "completed"
except Exception:
pass
return False
async def _handle_stream_complete(self) -> None:
"""Handle logging and token counting when stream completes."""
try:
end_time = datetime.now()
# Calculate tokens from collected text
input_message = A2ARequestUtils.get_input_message_from_request(self.request)
input_text = A2ARequestUtils.extract_text_from_message(input_message)
prompt_tokens = A2ARequestUtils.count_tokens(input_text)
# Use the last (most complete) text from chunks
output_text = (
self.collected_text_parts[-1] if self.collected_text_parts else ""
)
completion_tokens = A2ARequestUtils.count_tokens(output_text)
total_tokens = prompt_tokens + completion_tokens
# Create usage object
usage = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
# Set usage on logging obj
self.logging_obj.model_call_details["usage"] = usage
            # Mark as non-streaming so downstream callbacks treat the
            # aggregated result as one complete response
            self.logging_obj.model_call_details["stream"] = False
# Calculate cost using A2ACostCalculator
response_cost = A2ACostCalculator.calculate_a2a_cost(self.logging_obj)
self.logging_obj.model_call_details["response_cost"] = response_cost
# Build result for logging
result = self._build_logging_result(usage)
# Call success handlers - they will build standard_logging_object
asyncio.create_task(
self.logging_obj.async_success_handler(
result=result,
start_time=self.start_time,
end_time=end_time,
cache_hit=None,
)
)
executor.submit(
self.logging_obj.success_handler,
result=result,
cache_hit=None,
start_time=self.start_time,
end_time=end_time,
)
verbose_logger.info(
f"A2A streaming completed: prompt_tokens={prompt_tokens}, "
f"completion_tokens={completion_tokens}, total_tokens={total_tokens}, "
f"response_cost={response_cost}"
)
except Exception as e:
verbose_logger.debug(f"Error in A2A streaming completion handler: {e}")
def _build_logging_result(self, usage: litellm.Usage) -> Dict[str, Any]:
"""Build a result dict for logging."""
result: Dict[str, Any] = {
"id": getattr(self.request, "id", "unknown"),
"jsonrpc": "2.0",
"usage": usage.model_dump()
if hasattr(usage, "model_dump")
else dict(usage),
}
# Add final chunk result if available
if self.final_chunk:
try:
chunk_dict = self.final_chunk.model_dump(mode="json", exclude_none=True)
result["result"] = chunk_dict.get("result", {})
except Exception:
pass
return result
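
The pass-through-and-collect pattern used by `A2AStreamingIterator` can be sketched without litellm: forward each chunk unchanged, remember its text, and compute usage once the stream is exhausted. `naive_token_count` is an illustrative stand-in for `litellm.token_counter`:

```python
import asyncio
from typing import AsyncIterator, Dict, List

def naive_token_count(text: str) -> int:
    # Whitespace tokenization; real token counting uses litellm.token_counter.
    return len(text.split())

class CollectingIterator:
    def __init__(self, stream: AsyncIterator[Dict]):
        self.stream = stream
        self.texts: List[str] = []
        self.usage: Dict[str, int] = {}

    def __aiter__(self):
        return self

    async def __anext__(self) -> Dict:
        try:
            chunk = await self.stream.__anext__()
        except StopAsyncIteration:
            # Stream done: count the last (most complete) text collected.
            output = self.texts[-1] if self.texts else ""
            self.usage = {"completion_tokens": naive_token_count(output)}
            raise
        if chunk.get("text"):
            self.texts.append(chunk["text"])
        return chunk

async def fake_stream():
    for text in ["hello", "hello world"]:
        yield {"text": text}

async def main():
    it = CollectingIterator(fake_stream())
    async for _ in it:
        pass
    return it.usage

usage = asyncio.run(main())
```

Keeping only the last text mirrors the iterator above, which treats later chunks as cumulative rather than additive.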


@@ -0,0 +1,138 @@
"""
Utility functions for A2A protocol.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
import litellm
from litellm._logging import verbose_logger
if TYPE_CHECKING:
from a2a.types import SendMessageRequest, SendStreamingMessageRequest
class A2ARequestUtils:
"""Utility class for A2A request/response processing."""
@staticmethod
def extract_text_from_message(message: Any) -> str:
"""
Extract text content from A2A message parts.
Args:
message: A2A message dict or object with 'parts' containing text parts
Returns:
Concatenated text from all text parts
"""
if message is None:
return ""
# Handle both dict and object access
if isinstance(message, dict):
parts = message.get("parts", [])
else:
parts = getattr(message, "parts", []) or []
text_parts: List[str] = []
for part in parts:
if isinstance(part, dict):
if part.get("kind") == "text":
text_parts.append(part.get("text", ""))
else:
if getattr(part, "kind", None) == "text":
text_parts.append(getattr(part, "text", ""))
return " ".join(text_parts)
@staticmethod
def extract_text_from_response(response_dict: Dict[str, Any]) -> str:
"""
Extract text content from A2A response result.
Args:
response_dict: A2A response dict with 'result' containing message
Returns:
Text from response message parts
"""
result = response_dict.get("result", {})
if not isinstance(result, dict):
return ""
message = result.get("message", {})
return A2ARequestUtils.extract_text_from_message(message)
@staticmethod
def get_input_message_from_request(
request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
) -> Any:
"""
Extract the input message from an A2A request.
Args:
request: The A2A SendMessageRequest or SendStreamingMessageRequest
Returns:
The message object/dict or None
"""
params = getattr(request, "params", None)
if params is None:
return None
return getattr(params, "message", None)
@staticmethod
def count_tokens(text: str) -> int:
"""
Count tokens in text using litellm.token_counter.
Args:
text: Text to count tokens for
Returns:
Token count, or 0 if counting fails
"""
if not text:
return 0
try:
return litellm.token_counter(text=text)
except Exception:
verbose_logger.debug("Failed to count tokens")
return 0
@staticmethod
def calculate_usage_from_request_response(
request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
response_dict: Dict[str, Any],
) -> Tuple[int, int, int]:
"""
Calculate token usage from A2A request and response.
Args:
request: The A2A SendMessageRequest or SendStreamingMessageRequest
response_dict: The A2A response as a dict
Returns:
Tuple of (prompt_tokens, completion_tokens, total_tokens)
"""
# Count input tokens
input_message = A2ARequestUtils.get_input_message_from_request(request)
input_text = A2ARequestUtils.extract_text_from_message(input_message)
prompt_tokens = A2ARequestUtils.count_tokens(input_text)
# Count output tokens
output_text = A2ARequestUtils.extract_text_from_response(response_dict)
completion_tokens = A2ARequestUtils.count_tokens(output_text)
total_tokens = prompt_tokens + completion_tokens
return prompt_tokens, completion_tokens, total_tokens
# Backwards compatibility aliases
def extract_text_from_a2a_message(message: Any) -> str:
return A2ARequestUtils.extract_text_from_message(message)
def extract_text_from_a2a_response(response_dict: Dict[str, Any]) -> str:
return A2ARequestUtils.extract_text_from_response(response_dict)
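
The part-extraction logic in `extract_text_from_message` boils down to a small pure function. A dict-only sketch (the real method also handles attribute-style Pydantic objects), runnable without litellm:

```python
from typing import Any, List

def extract_text(message: Any) -> str:
    # Join the text of all text-kind parts with single spaces;
    # non-text parts (files, data) are skipped.
    if message is None:
        return ""
    parts = message.get("parts", []) if isinstance(message, dict) else []
    texts: List[str] = [
        p.get("text", "")
        for p in parts
        if isinstance(p, dict) and p.get("kind") == "text"
    ]
    return " ".join(texts)

msg = {
    "role": "user",
    "parts": [
        {"kind": "text", "text": "hello"},
        {"kind": "file", "uri": "s3://bucket/report.pdf"},
        {"kind": "text", "text": "world"},
    ],
}
```

Token counts for usage tracking are computed over exactly this joined string, so skipped non-text parts contribute nothing.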


@@ -0,0 +1,182 @@
{
"description": "Mapping of Anthropic beta headers for each provider. Keys are input header names, values are provider-specific header names (or null if unsupported). Only headers present in mapping keys with non-null values can be forwarded.",
"anthropic": {
"advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20",
"bash_20241022": null,
"bash_20250124": null,
"code-execution-2025-08-25": "code-execution-2025-08-25",
"compact-2026-01-12": "compact-2026-01-12",
"computer-use-2025-01-24": "computer-use-2025-01-24",
"computer-use-2025-11-24": "computer-use-2025-11-24",
"context-1m-2025-08-07": "context-1m-2025-08-07",
"context-management-2025-06-27": "context-management-2025-06-27",
"effort-2025-11-24": "effort-2025-11-24",
"fast-mode-2026-02-01": "fast-mode-2026-02-01",
"files-api-2025-04-14": "files-api-2025-04-14",
"structured-output-2024-03-01": null,
"fine-grained-tool-streaming-2025-05-14": "fine-grained-tool-streaming-2025-05-14",
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
"mcp-client-2025-11-20": "mcp-client-2025-11-20",
"mcp-client-2025-04-04": "mcp-client-2025-04-04",
"mcp-servers-2025-12-04": null,
"oauth-2025-04-20": "oauth-2025-04-20",
"output-128k-2025-02-19": "output-128k-2025-02-19",
"prompt-caching-scope-2026-01-05": "prompt-caching-scope-2026-01-05",
"skills-2025-10-02": "skills-2025-10-02",
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
"text_editor_20241022": null,
"text_editor_20250124": null,
"token-efficient-tools-2025-02-19": "token-efficient-tools-2025-02-19",
"web-fetch-2025-09-10": "web-fetch-2025-09-10",
"web-search-2025-03-05": "web-search-2025-03-05"
},
"azure_ai": {
"advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20",
"bash_20241022": null,
"bash_20250124": null,
"code-execution-2025-08-25": "code-execution-2025-08-25",
"compact-2026-01-12": null,
"computer-use-2025-01-24": "computer-use-2025-01-24",
"computer-use-2025-11-24": "computer-use-2025-11-24",
"context-1m-2025-08-07": "context-1m-2025-08-07",
"context-management-2025-06-27": "context-management-2025-06-27",
"effort-2025-11-24": "effort-2025-11-24",
"fast-mode-2026-02-01": null,
"files-api-2025-04-14": "files-api-2025-04-14",
"fine-grained-tool-streaming-2025-05-14": null,
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
"mcp-client-2025-11-20": "mcp-client-2025-11-20",
"mcp-client-2025-04-04": "mcp-client-2025-04-04",
"mcp-servers-2025-12-04": null,
"output-128k-2025-02-19": null,
"structured-output-2024-03-01": null,
"prompt-caching-scope-2026-01-05": "prompt-caching-scope-2026-01-05",
"skills-2025-10-02": "skills-2025-10-02",
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
"text_editor_20241022": null,
"text_editor_20250124": null,
"token-efficient-tools-2025-02-19": null,
"web-fetch-2025-09-10": "web-fetch-2025-09-10",
"web-search-2025-03-05": "web-search-2025-03-05"
},
"bedrock_converse": {
"advanced-tool-use-2025-11-20": null,
"bash_20241022": null,
"bash_20250124": null,
"code-execution-2025-08-25": null,
"compact-2026-01-12": null,
"computer-use-2025-01-24": "computer-use-2025-01-24",
"computer-use-2025-11-24": "computer-use-2025-11-24",
"context-1m-2025-08-07": "context-1m-2025-08-07",
"context-management-2025-06-27": "context-management-2025-06-27",
"effort-2025-11-24": null,
"fast-mode-2026-02-01": null,
"files-api-2025-04-14": null,
"fine-grained-tool-streaming-2025-05-14": null,
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
"mcp-client-2025-11-20": null,
"mcp-client-2025-04-04": null,
"mcp-servers-2025-12-04": null,
"output-128k-2025-02-19": null,
"structured-output-2024-03-01": null,
"prompt-caching-scope-2026-01-05": null,
"skills-2025-10-02": null,
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
"text_editor_20241022": null,
"text_editor_20250124": null,
"token-efficient-tools-2025-02-19": null,
"tool-search-tool-2025-10-19": null,
"web-fetch-2025-09-10": null,
"web-search-2025-03-05": null
},
"bedrock": {
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
"bash_20241022": null,
"bash_20250124": null,
"code-execution-2025-08-25": null,
"compact-2026-01-12": "compact-2026-01-12",
"computer-use-2025-01-24": "computer-use-2025-01-24",
"computer-use-2025-11-24": "computer-use-2025-11-24",
"context-1m-2025-08-07": "context-1m-2025-08-07",
"context-management-2025-06-27": "context-management-2025-06-27",
"effort-2025-11-24": null,
"fast-mode-2026-02-01": null,
"files-api-2025-04-14": null,
"fine-grained-tool-streaming-2025-05-14": null,
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
"mcp-client-2025-11-20": null,
"mcp-client-2025-04-04": null,
"mcp-servers-2025-12-04": null,
"output-128k-2025-02-19": null,
"structured-output-2024-03-01": null,
"prompt-caching-scope-2026-01-05": null,
"skills-2025-10-02": null,
"structured-outputs-2025-11-13": null,
"text_editor_20241022": null,
"text_editor_20250124": null,
"token-efficient-tools-2025-02-19": null,
"tool-search-tool-2025-10-19": "tool-search-tool-2025-10-19",
"web-fetch-2025-09-10": null,
"web-search-2025-03-05": null
},
"vertex_ai": {
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
"bash_20241022": null,
"bash_20250124": null,
"code-execution-2025-08-25": null,
"compact-2026-01-12": null,
"computer-use-2025-01-24": "computer-use-2025-01-24",
"computer-use-2025-11-24": "computer-use-2025-11-24",
"context-1m-2025-08-07": "context-1m-2025-08-07",
"context-management-2025-06-27": "context-management-2025-06-27",
"effort-2025-11-24": null,
"fast-mode-2026-02-01": null,
"files-api-2025-04-14": null,
"fine-grained-tool-streaming-2025-05-14": null,
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
"mcp-client-2025-11-20": null,
"mcp-client-2025-04-04": null,
"mcp-servers-2025-12-04": null,
"output-128k-2025-02-19": null,
"structured-output-2024-03-01": null,
"prompt-caching-scope-2026-01-05": null,
"skills-2025-10-02": null,
"structured-outputs-2025-11-13": null,
"text_editor_20241022": null,
"text_editor_20250124": null,
"token-efficient-tools-2025-02-19": null,
"tool-search-tool-2025-10-19": "tool-search-tool-2025-10-19",
"web-fetch-2025-09-10": null,
"web-search-2025-03-05": "web-search-2025-03-05"
},
"databricks": {
"advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20",
"bash_20241022": null,
"bash_20250124": null,
"code-execution-2025-08-25": "code-execution-2025-08-25",
"compact-2026-01-12": "compact-2026-01-12",
"computer-use-2025-01-24": "computer-use-2025-01-24",
"computer-use-2025-11-24": "computer-use-2025-11-24",
"context-1m-2025-08-07": "context-1m-2025-08-07",
"context-management-2025-06-27": "context-management-2025-06-27",
"effort-2025-11-24": "effort-2025-11-24",
"fast-mode-2026-02-01": "fast-mode-2026-02-01",
"files-api-2025-04-14": "files-api-2025-04-14",
"structured-output-2024-03-01": null,
"fine-grained-tool-streaming-2025-05-14": "fine-grained-tool-streaming-2025-05-14",
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
"mcp-client-2025-11-20": "mcp-client-2025-11-20",
"mcp-client-2025-04-04": "mcp-client-2025-04-04",
"mcp-servers-2025-12-04": null,
"oauth-2025-04-20": "oauth-2025-04-20",
"output-128k-2025-02-19": "output-128k-2025-02-19",
"prompt-caching-scope-2026-01-05": "prompt-caching-scope-2026-01-05",
"skills-2025-10-02": "skills-2025-10-02",
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
"text_editor_20241022": null,
"text_editor_20250124": null,
"token-efficient-tools-2025-02-19": "token-efficient-tools-2025-02-19",
"web-fetch-2025-09-10": "web-fetch-2025-09-10",
"web-search-2025-03-05": "web-search-2025-03-05"
}
}


@@ -0,0 +1,385 @@
"""
Centralized manager for Anthropic beta headers across different providers.
This module provides utilities to:
1. Load beta header configuration from JSON (mapping of supported headers per provider)
2. Filter and map beta headers based on provider support
3. Handle provider-specific header name mappings (e.g., advanced-tool-use -> tool-search-tool)
4. Support remote fetching and caching similar to model cost map
Design:
- JSON config contains mapping of beta headers for each provider
- Keys are input header names, values are provider-specific header names (or null if unsupported)
- Only headers present in mapping keys with non-null values can be forwarded
- This enforces stricter validation than the previous unsupported list approach
Configuration can be loaded from:
- Remote URL (default): Fetches from GitHub repository
- Local file: Set LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS=True to use bundled config only
Environment Variables:
- LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS: Set to "True" to disable remote fetching
- LITELLM_ANTHROPIC_BETA_HEADERS_URL: Custom URL for remote config (optional)
"""
import json
import os
from importlib.resources import files
from typing import Dict, List, Optional, Set
import httpx
from litellm.litellm_core_utils.litellm_logging import verbose_logger
# Cache for the loaded configuration
_BETA_HEADERS_CONFIG: Optional[Dict] = None
class GetAnthropicBetaHeadersConfig:
"""
Handles fetching, validating, and loading the Anthropic beta headers configuration.
Similar to GetModelCostMap, this class manages the lifecycle of the beta headers
configuration with support for remote fetching and local fallback.
"""
@staticmethod
def load_local_beta_headers_config() -> Dict:
"""Load the local backup beta headers config bundled with the package."""
try:
content = json.loads(
files("litellm")
.joinpath("anthropic_beta_headers_config.json")
.read_text(encoding="utf-8")
)
return content
except Exception as e:
verbose_logger.error(f"Failed to load local beta headers config: {e}")
# Return empty config as fallback
return {
"anthropic": {},
"azure_ai": {},
"bedrock": {},
"bedrock_converse": {},
"vertex_ai": {},
"provider_aliases": {},
}
@staticmethod
def _check_is_valid_dict(fetched_config: dict) -> bool:
"""Check if fetched config is a non-empty dict with expected structure."""
if not isinstance(fetched_config, dict):
verbose_logger.warning(
"LiteLLM: Fetched beta headers config is not a dict (type=%s). "
"Falling back to local backup.",
type(fetched_config).__name__,
)
return False
if len(fetched_config) == 0:
verbose_logger.warning(
"LiteLLM: Fetched beta headers config is empty. "
"Falling back to local backup.",
)
return False
# Check for at least one provider key
provider_keys = [
"anthropic",
"azure_ai",
"bedrock",
"bedrock_converse",
"vertex_ai",
]
has_provider = any(key in fetched_config for key in provider_keys)
if not has_provider:
verbose_logger.warning(
"LiteLLM: Fetched beta headers config missing provider keys. "
"Falling back to local backup.",
)
return False
return True
@classmethod
def validate_beta_headers_config(cls, fetched_config: dict) -> bool:
"""
Validate the integrity of a fetched beta headers config.
Returns True if all checks pass, False otherwise.
"""
return cls._check_is_valid_dict(fetched_config)
@staticmethod
def fetch_remote_beta_headers_config(url: str, timeout: int = 5) -> dict:
"""
Fetch the beta headers config from a remote URL.
Returns the parsed JSON dict. Raises on network/parse errors
(caller is expected to handle).
"""
response = httpx.get(url, timeout=timeout)
response.raise_for_status()
return response.json()
def get_beta_headers_config(url: str) -> dict:
"""
Public entry point — returns the beta headers config dict.
1. If ``LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS`` is set, uses the local backup only.
2. Otherwise fetches from ``url``, validates integrity, and falls back
to the local backup on any failure.
Args:
url: URL to fetch the remote beta headers configuration from
Returns:
Dict containing the beta headers configuration
"""
# Check if local-only mode is enabled
if os.getenv("LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS", "").lower() == "true":
verbose_logger.debug("Using local Anthropic beta headers config (LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS=True)")
return GetAnthropicBetaHeadersConfig.load_local_beta_headers_config()
try:
content = GetAnthropicBetaHeadersConfig.fetch_remote_beta_headers_config(url)
except Exception as e:
verbose_logger.warning(
"LiteLLM: Failed to fetch remote beta headers config from %s: %s. "
"Falling back to local backup.",
url,
str(e),
)
return GetAnthropicBetaHeadersConfig.load_local_beta_headers_config()
# Validate the fetched config
if not GetAnthropicBetaHeadersConfig.validate_beta_headers_config(
fetched_config=content
):
verbose_logger.warning(
"LiteLLM: Fetched beta headers config failed integrity check. "
"Using local backup instead. url=%s",
url,
)
return GetAnthropicBetaHeadersConfig.load_local_beta_headers_config()
return content
def _load_beta_headers_config() -> Dict:
"""
Load the beta headers configuration.
Uses caching to avoid repeated fetches/file reads.
This function is called by all public API functions and manages the global cache.
Returns:
Dict containing the beta headers configuration
"""
global _BETA_HEADERS_CONFIG
if _BETA_HEADERS_CONFIG is not None:
return _BETA_HEADERS_CONFIG
# Get the URL from environment or use default
from litellm import anthropic_beta_headers_url
_BETA_HEADERS_CONFIG = get_beta_headers_config(url=anthropic_beta_headers_url)
verbose_logger.debug("Loaded and cached beta headers config")
return _BETA_HEADERS_CONFIG
def reload_beta_headers_config() -> Dict:
"""
Force reload the beta headers configuration from source (remote or local).
Clears the cache and fetches fresh configuration.
Returns:
Dict containing the newly loaded beta headers configuration
"""
global _BETA_HEADERS_CONFIG
_BETA_HEADERS_CONFIG = None
verbose_logger.info("Reloading beta headers config (cache cleared)")
return _load_beta_headers_config()
def get_provider_name(provider: str) -> str:
"""
Resolve provider aliases to canonical provider names.
Args:
provider: Provider name (may be an alias)
Returns:
Canonical provider name
"""
config = _load_beta_headers_config()
aliases = config.get("provider_aliases", {})
return aliases.get(provider, provider)
def filter_and_transform_beta_headers(
beta_headers: List[str],
provider: str,
) -> List[str]:
"""
Filter and transform beta headers based on provider's mapping configuration.
This function:
1. Only allows headers that are present in the provider's mapping keys
2. Filters out headers with null values (unsupported)
3. Maps headers to provider-specific names (e.g., advanced-tool-use -> tool-search-tool)
Args:
beta_headers: List of Anthropic beta header values
provider: Provider name (e.g., "anthropic", "bedrock", "vertex_ai")
Returns:
List of filtered and transformed beta headers for the provider
"""
if not beta_headers:
return []
config = _load_beta_headers_config()
provider = get_provider_name(provider)
# Get the header mapping for this provider
provider_mapping = config.get(provider, {})
filtered_headers: Set[str] = set()
for header in beta_headers:
header = header.strip()
# Check if header is in the mapping
if header not in provider_mapping:
verbose_logger.debug(
f"Dropping unknown beta header '{header}' for provider '{provider}' (not in mapping)"
)
continue
# Get the mapped header value
mapped_header = provider_mapping[header]
# Skip if header is unsupported (null value)
if mapped_header is None:
verbose_logger.debug(
f"Dropping unsupported beta header '{header}' for provider '{provider}'"
)
continue
# Add the mapped header
filtered_headers.add(mapped_header)
return sorted(list(filtered_headers))
def is_beta_header_supported(
beta_header: str,
provider: str,
) -> bool:
"""
Check if a specific beta header is supported by a provider.
Args:
beta_header: The Anthropic beta header value
provider: Provider name
Returns:
True if the header is in the mapping with a non-null value, False otherwise
"""
config = _load_beta_headers_config()
provider = get_provider_name(provider)
provider_mapping = config.get(provider, {})
# Header is supported if it's in the mapping and has a non-null value
return beta_header in provider_mapping and provider_mapping[beta_header] is not None
def get_provider_beta_header(
anthropic_beta_header: str,
provider: str,
) -> Optional[str]:
"""
Get the provider-specific beta header name for a given Anthropic beta header.
This function handles header transformations/mappings (e.g., advanced-tool-use -> tool-search-tool).
Args:
anthropic_beta_header: The Anthropic beta header value
provider: Provider name
Returns:
The provider-specific header name if supported, or None if unsupported/unknown
"""
config = _load_beta_headers_config()
provider = get_provider_name(provider)
# Get the header mapping for this provider
provider_mapping = config.get(provider, {})
# Check if header is in the mapping
if anthropic_beta_header not in provider_mapping:
return None
# Return the mapped value (could be None if unsupported)
return provider_mapping[anthropic_beta_header]
def update_headers_with_filtered_beta(
headers: dict,
provider: str,
) -> dict:
"""
Update headers dict by filtering and transforming anthropic-beta header values.
Modifies the headers dict in place and returns it.
Args:
headers: Request headers dict (will be modified in place)
provider: Provider name
Returns:
Updated headers dict
"""
existing_beta = headers.get("anthropic-beta")
if not existing_beta:
return headers
# Parse existing beta headers
beta_values = [b.strip() for b in existing_beta.split(",") if b.strip()]
# Filter and transform based on provider
filtered_beta_values = filter_and_transform_beta_headers(
beta_headers=beta_values,
provider=provider,
)
# Update or remove the header
if filtered_beta_values:
headers["anthropic-beta"] = ",".join(filtered_beta_values)
else:
# Remove the header if no values remain
headers.pop("anthropic-beta", None)
return headers
def get_unsupported_headers(provider: str) -> List[str]:
"""
Get all beta headers that are unsupported by a provider (have null values in mapping).
Args:
provider: Provider name
Returns:
List of unsupported Anthropic beta header names
"""
config = _load_beta_headers_config()
provider = get_provider_name(provider)
provider_mapping = config.get(provider, {})
# Return headers with null values
return [header for header, value in provider_mapping.items() if value is None]
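The filter-and-map rule implemented above can be sketched standalone. The mini-config below is hypothetical, for illustration only (the real mapping ships in `anthropic_beta_headers_config.json`): a header is forwarded only if it appears as a key for the provider and has a non-null value, and that value is the header name actually sent.

```python
# Standalone sketch of the filter-and-map rule: keep a header only if it
# is a key in the provider's mapping AND maps to a non-null value; the
# mapped value is the header name actually forwarded.
from typing import Dict, List, Optional

# Hypothetical mini-config for illustration only.
CONFIG: Dict[str, Dict[str, Optional[str]]] = {
    "bedrock": {
        "advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
        "output-128k-2025-02-19": None,  # known but unsupported -> dropped
    }
}

def filter_headers(beta_headers: List[str], provider: str) -> List[str]:
    mapping = CONFIG.get(provider, {})
    kept = {
        mapping[h.strip()]
        for h in beta_headers
        if h.strip() in mapping and mapping[h.strip()] is not None
    }
    return sorted(kept)

print(filter_headers(
    ["advanced-tool-use-2025-11-20", "output-128k-2025-02-19", "unknown-beta"],
    "bedrock",
))  # ['tool-search-tool-2025-10-19']
```

Unknown headers and null-mapped headers are both dropped, which is the stricter validation the module docstring describes.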


@@ -0,0 +1,6 @@
"""
Anthropic module for LiteLLM
"""
from .messages import acreate, create
__all__ = ["acreate", "create"]


@@ -0,0 +1,19 @@
"""Anthropic error format utilities."""
from .exception_mapping_utils import (
ANTHROPIC_ERROR_TYPE_MAP,
AnthropicExceptionMapping,
)
from .exceptions import (
AnthropicErrorDetail,
AnthropicErrorResponse,
AnthropicErrorType,
)
__all__ = [
"AnthropicErrorType",
"AnthropicErrorDetail",
"AnthropicErrorResponse",
"ANTHROPIC_ERROR_TYPE_MAP",
"AnthropicExceptionMapping",
]


@@ -0,0 +1,172 @@
"""
Utilities for mapping exceptions to Anthropic error format.
Similar to litellm/litellm_core_utils/exception_mapping_utils.py but for Anthropic response format.
"""
from typing import Dict, Optional
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from .exceptions import AnthropicErrorResponse, AnthropicErrorType
# HTTP status code -> Anthropic error type
# Source: https://docs.anthropic.com/en/api/errors
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {
400: "invalid_request_error",
401: "authentication_error",
403: "permission_error",
404: "not_found_error",
413: "request_too_large",
429: "rate_limit_error",
500: "api_error",
529: "overloaded_error",
}
class AnthropicExceptionMapping:
"""
Helper class for mapping exceptions to Anthropic error format.
Similar pattern to ExceptionCheckers in litellm_core_utils/exception_mapping_utils.py
"""
@staticmethod
def get_error_type(status_code: int) -> AnthropicErrorType:
"""Map HTTP status code to Anthropic error type."""
return ANTHROPIC_ERROR_TYPE_MAP.get(status_code, "api_error")
@staticmethod
def create_error_response(
status_code: int,
message: str,
request_id: Optional[str] = None,
) -> AnthropicErrorResponse:
"""
Create an Anthropic-formatted error response dict.
Anthropic error format:
{
"type": "error",
"error": {"type": "...", "message": "..."},
"request_id": "req_..."
}
"""
error_type = AnthropicExceptionMapping.get_error_type(status_code)
response: AnthropicErrorResponse = {
"type": "error",
"error": {
"type": error_type,
"message": message,
},
}
if request_id:
response["request_id"] = request_id
return response
@staticmethod
def extract_error_message(raw_message: str) -> str:
"""
Extract error message from various provider response formats.
Handles:
- Bedrock: {"detail": {"message": "..."}}
- AWS: {"Message": "..."}
- Generic: {"message": "..."}
- Plain strings
"""
parsed = safe_json_loads(raw_message)
if isinstance(parsed, dict):
# Bedrock format
if "detail" in parsed and isinstance(parsed["detail"], dict):
return parsed["detail"].get("message", raw_message)
# AWS/generic format
return parsed.get("Message") or parsed.get("message") or raw_message
return raw_message
@staticmethod
def _is_anthropic_error_dict(parsed: dict) -> bool:
"""
Check if a parsed dict is in Anthropic error format.
Anthropic error format:
{
"type": "error",
"error": {"type": "...", "message": "..."}
}
"""
return (
parsed.get("type") == "error"
and isinstance(parsed.get("error"), dict)
and "type" in parsed["error"]
and "message" in parsed["error"]
)
@staticmethod
def _extract_message_from_dict(parsed: dict, raw_message: str) -> str:
"""
Extract error message from a parsed provider-specific dict.
Handles:
- Bedrock: {"detail": {"message": "..."}}
- AWS: {"Message": "..."}
- Generic: {"message": "..."}
"""
# Bedrock format
if "detail" in parsed and isinstance(parsed["detail"], dict):
return parsed["detail"].get("message", raw_message)
# AWS/generic format
return parsed.get("Message") or parsed.get("message") or raw_message
@staticmethod
def transform_to_anthropic_error(
status_code: int,
raw_message: str,
request_id: Optional[str] = None,
) -> AnthropicErrorResponse:
"""
Transform an error message to Anthropic format.
- If already in Anthropic format: passthrough unchanged
- Otherwise: extract message and create Anthropic error
Parses JSON only once for efficiency.
Args:
status_code: HTTP status code
raw_message: Raw error message (may be JSON string or plain text)
request_id: Optional request ID to include
Returns:
AnthropicErrorResponse dict
"""
# Try to parse as JSON once
parsed = safe_json_loads(raw_message)  # may be a dict, str, list, or None
if not isinstance(parsed, dict):
parsed = None
# If parsed and already in Anthropic format - passthrough
if parsed is not None and AnthropicExceptionMapping._is_anthropic_error_dict(
parsed
):
# Optionally add request_id if provided and not present
if request_id and "request_id" not in parsed:
parsed["request_id"] = request_id
return parsed # type: ignore
# Extract message - use parsed dict if available, otherwise raw string
if parsed is not None:
message = AnthropicExceptionMapping._extract_message_from_dict(
parsed, raw_message
)
else:
message = raw_message
return AnthropicExceptionMapping.create_error_response(
status_code=status_code,
message=message,
request_id=request_id,
)
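The envelope this module produces can be illustrated in isolation. The snippet below mirrors the status-code table and error shape documented above; it is a self-contained sketch, not the module itself.

```python
# Self-contained sketch of the status-code -> Anthropic error envelope
# mapping described above; unknown codes fall back to "api_error".
from typing import Optional

ERROR_TYPES = {
    400: "invalid_request_error", 401: "authentication_error",
    403: "permission_error",      404: "not_found_error",
    413: "request_too_large",     429: "rate_limit_error",
    500: "api_error",             529: "overloaded_error",
}

def to_anthropic_error(status: int, message: str,
                       request_id: Optional[str] = None) -> dict:
    body: dict = {
        "type": "error",
        "error": {"type": ERROR_TYPES.get(status, "api_error"),
                  "message": message},
    }
    if request_id:
        body["request_id"] = request_id
    return body

print(to_anthropic_error(429, "Too many requests", "req_123"))
# {'type': 'error', 'error': {'type': 'rate_limit_error', 'message': 'Too many requests'}, 'request_id': 'req_123'}
```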


@@ -0,0 +1,41 @@
"""Anthropic error format type definitions."""
from typing_extensions import Literal, Required, TypedDict
# Known Anthropic error types
# Source: https://docs.anthropic.com/en/api/errors
AnthropicErrorType = Literal[
"invalid_request_error",
"authentication_error",
"permission_error",
"not_found_error",
"request_too_large",
"rate_limit_error",
"api_error",
"overloaded_error",
]
class AnthropicErrorDetail(TypedDict):
"""Inner error detail in Anthropic format."""
type: AnthropicErrorType
message: str
class AnthropicErrorResponse(TypedDict, total=False):
"""
Anthropic-formatted error response.
Format:
{
"type": "error",
"error": {"type": "...", "message": "..."},
"request_id": "req_..." # optional
}
"""
type: Required[Literal["error"]]
error: Required[AnthropicErrorDetail]
request_id: str


@@ -0,0 +1,144 @@
"""
Interface for Anthropic's messages API
Use this to call LLMs in Anthropic /messages Request/Response format
This is an __init__.py file to allow the following interface
- litellm.messages.acreate
- litellm.messages.create
"""
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Union
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
anthropic_messages as _async_anthropic_messages,
)
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
anthropic_messages_handler as _sync_anthropic_messages,
)
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
async def acreate(
max_tokens: int,
messages: List[Dict],
model: str,
metadata: Optional[Dict] = None,
stop_sequences: Optional[List[str]] = None,
stream: Optional[bool] = False,
system: Optional[str] = None,
temperature: Optional[float] = None,
thinking: Optional[Dict] = None,
tool_choice: Optional[Dict] = None,
tools: Optional[List[Dict]] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
container: Optional[Dict] = None,
**kwargs
) -> Union[AnthropicMessagesResponse, AsyncIterator]:
"""
Async wrapper for Anthropic's messages API
Args:
max_tokens (int): Maximum tokens to generate (required)
messages (List[Dict]): List of message objects with role and content (required)
model (str): Model name to use (required)
metadata (Dict, optional): Request metadata
stop_sequences (List[str], optional): Custom stop sequences
stream (bool, optional): Whether to stream the response
system (str, optional): System prompt
temperature (float, optional): Sampling temperature (0.0 to 1.0)
thinking (Dict, optional): Extended thinking configuration
tool_choice (Dict, optional): Tool choice configuration
tools (List[Dict], optional): List of tool definitions
top_k (int, optional): Top K sampling parameter
top_p (float, optional): Nucleus sampling parameter
container (Dict, optional): Container config with skills for code execution
**kwargs: Additional arguments
Returns:
Dict: Response from the API
"""
return await _async_anthropic_messages(
max_tokens=max_tokens,
messages=messages,
model=model,
metadata=metadata,
stop_sequences=stop_sequences,
stream=stream,
system=system,
temperature=temperature,
thinking=thinking,
tool_choice=tool_choice,
tools=tools,
top_k=top_k,
top_p=top_p,
container=container,
**kwargs,
)
def create(
max_tokens: int,
messages: List[Dict],
model: str,
metadata: Optional[Dict] = None,
stop_sequences: Optional[List[str]] = None,
stream: Optional[bool] = False,
system: Optional[str] = None,
temperature: Optional[float] = None,
thinking: Optional[Dict] = None,
tool_choice: Optional[Dict] = None,
tools: Optional[List[Dict]] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
container: Optional[Dict] = None,
**kwargs
) -> Union[
AnthropicMessagesResponse,
AsyncIterator[Any],
Coroutine[Any, Any, Union[AnthropicMessagesResponse, AsyncIterator[Any]]],
]:
"""
Synchronous wrapper for Anthropic's messages API
Args:
max_tokens (int): Maximum tokens to generate (required)
messages (List[Dict]): List of message objects with role and content (required)
model (str): Model name to use (required)
metadata (Dict, optional): Request metadata
stop_sequences (List[str], optional): Custom stop sequences
stream (bool, optional): Whether to stream the response
system (str, optional): System prompt
temperature (float, optional): Sampling temperature (0.0 to 1.0)
thinking (Dict, optional): Extended thinking configuration
tool_choice (Dict, optional): Tool choice configuration
tools (List[Dict], optional): List of tool definitions
top_k (int, optional): Top K sampling parameter
top_p (float, optional): Nucleus sampling parameter
container (Dict, optional): Container config with skills for code execution
**kwargs: Additional arguments
Returns:
Dict: Response from the API
"""
return _sync_anthropic_messages(
max_tokens=max_tokens,
messages=messages,
model=model,
metadata=metadata,
stop_sequences=stop_sequences,
stream=stream,
system=system,
temperature=temperature,
thinking=thinking,
tool_choice=tool_choice,
tools=tools,
top_k=top_k,
top_p=top_p,
container=container,
**kwargs,
)


@@ -0,0 +1,116 @@
## Use LLM API endpoints in Anthropic Interface
Note: this module is called `anthropic_interface` because `anthropic` is an existing Python package, and reusing that name broke mypy type checking.
## Usage
---
### LiteLLM Python SDK
#### Non-streaming example
```python showLineNumbers title="Example using LiteLLM Python SDK"
import litellm
response = await litellm.anthropic.messages.acreate(
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
api_key=api_key,
model="anthropic/claude-3-haiku-20240307",
max_tokens=100,
)
```
Example response:
```json
{
"content": [
{
"text": "Hi! this is a very short joke",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-haiku-20240307",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 2095,
"output_tokens": 503,
"cache_creation_input_tokens": 2095,
"cache_read_input_tokens": 0
}
}
```
#### Streaming example
```python showLineNumbers title="Example using LiteLLM Python SDK"
import litellm
response = await litellm.anthropic.messages.acreate(
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
api_key=api_key,
model="anthropic/claude-3-haiku-20240307",
max_tokens=100,
stream=True,
)
async for chunk in response:
print(chunk)
```
### LiteLLM Proxy Server
1. Setup config.yaml
```yaml
model_list:
- model_name: anthropic-claude
litellm_params:
model: anthropic/claude-3-7-sonnet-latest
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
<Tabs>
<TabItem label="Anthropic Python SDK" value="python">
```python showLineNumbers title="Example using LiteLLM Proxy Server"
import anthropic
# point anthropic sdk to litellm proxy
client = anthropic.Anthropic(
base_url="http://0.0.0.0:4000",
api_key="sk-1234",
)
response = client.messages.create(
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
model="anthropic-claude",  # model_name from config.yaml
max_tokens=100,
)
```
</TabItem>
<TabItem label="curl" value="curl">
```bash showLineNumbers title="Example using LiteLLM Proxy Server"
curl -L -X POST 'http://0.0.0.0:4000/v1/messages' \
-H 'content-type: application/json' \
-H "x-api-key: $LITELLM_API_KEY" \
-H 'anthropic-version: 2023-06-01' \
-d '{
"model": "anthropic-claude",
"messages": [
{
"role": "user",
"content": "Hello, can you tell me a short joke?"
}
],
"max_tokens": 100
}'
```

File diff suppressed because it is too large


@@ -0,0 +1,161 @@
from typing import Optional, Union
import litellm
from ..exceptions import UnsupportedParamsError
from ..types.llms.openai import *
def get_optional_params_add_message(
role: Optional[str],
content: Optional[
Union[
str,
List[
Union[
MessageContentTextObject,
MessageContentImageFileObject,
MessageContentImageURLObject,
]
],
]
],
attachments: Optional[List[Attachment]],
metadata: Optional[dict],
custom_llm_provider: str,
**kwargs,
):
"""
Azure doesn't support 'attachments' for creating a message
Reference - https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
"""
passed_params = locals()
custom_llm_provider = passed_params.pop("custom_llm_provider")
special_params = passed_params.pop("kwargs")
for k, v in special_params.items():
passed_params[k] = v
default_params = {
"role": None,
"content": None,
"attachments": None,
"metadata": None,
}
non_default_params = {
k: v
for k, v in passed_params.items()
if (k in default_params and v != default_params[k])
}
optional_params = {}
## raise exception if non-default value passed for non-openai/azure embedding calls
def _check_valid_arg(supported_params):
if len(non_default_params.keys()) > 0:
keys = list(non_default_params.keys())
for k in keys:
if (
litellm.drop_params is True and k not in supported_params
): # drop the unsupported non-default values
non_default_params.pop(k, None)
elif k not in supported_params:
raise litellm.utils.UnsupportedParamsError(
status_code=500,
message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
k, custom_llm_provider, supported_params
),
)
return non_default_params
if custom_llm_provider == "openai":
optional_params = non_default_params
elif custom_llm_provider == "azure":
supported_params = (
litellm.AzureOpenAIAssistantsAPIConfig().get_supported_openai_create_message_params()
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.AzureOpenAIAssistantsAPIConfig().map_openai_params_create_message_params(
non_default_params=non_default_params, optional_params=optional_params
)
for k in passed_params.keys():
if k not in default_params.keys():
optional_params[k] = passed_params[k]
return optional_params
def get_optional_params_image_gen(
n: Optional[int] = None,
quality: Optional[str] = None,
response_format: Optional[str] = None,
size: Optional[str] = None,
style: Optional[str] = None,
user: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
**kwargs,
):
# retrieve all parameters passed to the function
passed_params = locals()
custom_llm_provider = passed_params.pop("custom_llm_provider")
special_params = passed_params.pop("kwargs")
for k, v in special_params.items():
passed_params[k] = v
default_params = {
"n": None,
"quality": None,
"response_format": None,
"size": None,
"style": None,
"user": None,
}
non_default_params = {
k: v
for k, v in passed_params.items()
if (k in default_params and v != default_params[k])
}
optional_params = {}
## raise exception if non-default value passed for non-openai/azure embedding calls
def _check_valid_arg(supported_params):
if len(non_default_params.keys()) > 0:
keys = list(non_default_params.keys())
for k in keys:
if (
litellm.drop_params is True and k not in supported_params
): # drop the unsupported non-default values
non_default_params.pop(k, None)
elif k not in supported_params:
raise UnsupportedParamsError(
status_code=500,
message=f"Setting '{k}' is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
)
return non_default_params
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider in litellm.openai_compatible_providers
):
optional_params = non_default_params
elif custom_llm_provider == "bedrock":
supported_params = ["size"]
_check_valid_arg(supported_params=supported_params)
if size is not None:
width, height = size.split("x")
optional_params["width"] = int(width)
optional_params["height"] = int(height)
elif custom_llm_provider == "vertex_ai":
supported_params = ["n"]
"""
All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
"""
_check_valid_arg(supported_params=supported_params)
if n is not None:
optional_params["sampleCount"] = int(n)
for k in passed_params.keys():
if k not in default_params.keys():
optional_params[k] = passed_params[k]
return optional_params
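The bedrock branch above converts an OpenAI-style size string into separate integer dimensions; a minimal standalone sketch of just that mapping:

```python
# Standalone sketch of the bedrock `size` mapping above: an OpenAI-style
# "WIDTHxHEIGHT" string becomes integer width/height params.
def map_bedrock_size(size: str) -> dict:
    width, height = size.split("x")
    return {"width": int(width), "height": int(height)}

print(map_bedrock_size("1024x768"))  # {'width': 1024, 'height': 768}
```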


@@ -0,0 +1,11 @@
# Implementation of `litellm.batch_completion`, `litellm.batch_completion_models`, `litellm.batch_completion_models_all_responses`
Doc: https://docs.litellm.ai/docs/completion/batching
LiteLLM Python SDK allows you to:
1. `litellm.batch_completion` Batch litellm.completion function for a given model.
2. `litellm.batch_completion_models` Send a request to multiple language models concurrently and return the response
as soon as one of the models responds.
3. `litellm.batch_completion_models_all_responses` Send a request to multiple language models concurrently and return a list of responses
from all models that respond.
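Under the hood these helpers fan requests out over a thread pool. A hedged, self-contained sketch of the `batch_completion` pattern follows; `fake_completion` is a stand-in for `litellm.completion` so the example runs offline:

```python
# Sketch of the batch_completion fan-out pattern: one completion call per
# message list, executed on a thread pool, results returned in order.
from concurrent.futures import ThreadPoolExecutor

def fake_completion(model: str, messages: list) -> str:
    # Stand-in for litellm.completion(model=..., messages=...)
    return f"{model}: {messages[-1]['content']}"

def batch_completion_sketch(model: str, batch_messages: list,
                            max_workers: int = 4) -> list:
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(fake_completion, model, msgs)
                   for msgs in batch_messages]
        return [f.result() for f in futures]

out = batch_completion_sketch(
    "gpt-4",
    [[{"role": "user", "content": "hi"}], [{"role": "user", "content": "bye"}]],
)
print(out)  # ['gpt-4: hi', 'gpt-4: bye']
```

`batch_completion_models` differs only in waiting for the first future to finish (`concurrent.futures.wait(..., return_when=FIRST_COMPLETED)`) instead of collecting all results.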


@@ -0,0 +1,273 @@
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import List, Optional
import litellm
from litellm._logging import print_verbose
from litellm.utils import get_optional_params
from ..llms.vllm.completion import handler as vllm_handler
def batch_completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
functions: Optional[List] = None,
function_call: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stop=None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None,
user: Optional[str] = None,
deployment_id=None,
request_timeout: Optional[int] = None,
timeout: Optional[int] = 600,
max_workers: Optional[int] = 100,
# Optional liteLLM function params
**kwargs,
):
"""
Batch litellm.completion function for a given model.
Args:
model (str): The model to use for generating completions.
messages (List, optional): List of messages to use as input for generating completions. Defaults to [].
functions (List, optional): List of functions to use as input for generating completions. Defaults to None.
function_call (str, optional): The function call to use as input for generating completions. Defaults to None.
temperature (float, optional): The temperature parameter for generating completions. Defaults to None.
top_p (float, optional): The top-p parameter for generating completions. Defaults to None.
n (int, optional): The number of completions to generate. Defaults to None.
stream (bool, optional): Whether to stream completions or not. Defaults to None.
stop (optional): The stop parameter for generating completions. Defaults to None.
max_tokens (int, optional): The maximum number of tokens to generate. Defaults to None.
presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None.
frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None.
logit_bias (dict, optional): The logit bias for generating completions. Defaults to None.
user (str, optional): The user string for generating completions. Defaults to None.
deployment_id (optional): The deployment ID for generating completions. Defaults to None.
request_timeout (int, optional): The request timeout for generating completions. Defaults to None.
timeout (int, optional): The timeout for each completion call. Defaults to 600.
max_workers (int, optional): The maximum number of threads to use for parallel processing. Defaults to 100.
Returns:
list: A list of completion results.
"""
args = locals()
batch_messages = messages
completions = []
model = model
custom_llm_provider = None
if model.split("/", 1)[0] in litellm.provider_list:
custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1]
if custom_llm_provider == "vllm":
optional_params = get_optional_params(
functions=functions,
function_call=function_call,
temperature=temperature,
top_p=top_p,
n=n,
stream=stream or False,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
user=user,
# params to identify the model
model=model,
custom_llm_provider=custom_llm_provider,
)
results = vllm_handler.batch_completions(
model=model,
messages=batch_messages,
custom_prompt_dict=litellm.custom_prompt_dict,
optional_params=optional_params,
)
# all non-vLLM models: batch via a thread pool
else:
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for sub_batch in chunks(batch_messages, 100):
for message_list in sub_batch:
kwargs_modified = args.copy()
kwargs_modified.pop("max_workers")
kwargs_modified["messages"] = message_list
original_kwargs = {}
if "kwargs" in kwargs_modified:
original_kwargs = kwargs_modified.pop("kwargs")
future = executor.submit(
litellm.completion, **kwargs_modified, **original_kwargs
)
completions.append(future)
# Retrieve results from the futures, returning any exceptions in-place
results = []
for future in completions:
try:
results.append(future.result())
except Exception as exc:
results.append(exc)
return results
# send one request to multiple models
# return as soon as one of the llms responds
def batch_completion_models(*args, **kwargs):
"""
Send a request to multiple language models concurrently and return the response
as soon as one of the models responds.
Args:
*args: Variable-length positional arguments passed to the completion function.
**kwargs: Additional keyword arguments:
- models (str or list of str): The language models to send requests to.
- Other keyword arguments to be passed to the completion function.
Returns:
str or None: The response from one of the language models, or None if no response is received.
Note:
This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
It sends requests concurrently and returns the response from the first model that responds.
"""
if "model" in kwargs:
kwargs.pop("model")
if "models" in kwargs:
models = kwargs["models"]
kwargs.pop("models")
futures = {}
with ThreadPoolExecutor(max_workers=len(models)) as executor:
for model in models:
futures[model] = executor.submit(
litellm.completion, *args, model=model, **kwargs
)
for model, future in sorted(
futures.items(), key=lambda x: models.index(x[0])
):
if future.result() is not None:
return future.result()
elif "deployments" in kwargs:
deployments = kwargs["deployments"]
kwargs.pop("deployments")
kwargs.pop("model_list", None)
nested_kwargs = kwargs.pop("kwargs", {})
futures = {}
with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
for deployment in deployments:
for key in kwargs.keys():
if (
key not in deployment
): # don't override deployment values e.g. model name, api base, etc.
deployment[key] = kwargs[key]
call_kwargs = {**deployment, **nested_kwargs}
futures[deployment["model"]] = executor.submit(
litellm.completion, **call_kwargs
)
while futures:
# wait for the first returned future
print_verbose("\n\n waiting for next result\n\n")
done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
print_verbose(f"done list\n{done}")
for future in done:
try:
result = future.result()
return result
except Exception:
# if model 1 fails, continue with response from model 2, model3
print_verbose(
"\n\ngot an exception, ignoring, removing from futures"
)
print_verbose(futures)
new_futures = {}
for key, value in futures.items():
if future == value:
print_verbose(f"removing key {key}")
continue
else:
new_futures[key] = value
futures = new_futures
print_verbose(f"new futures: {futures}")
continue
print_verbose("\n\ndone looping through futures\n\n")
print_verbose(futures)
return None # If no response is received from any model
def batch_completion_models_all_responses(*args, **kwargs):
"""
Send a request to multiple language models concurrently and return a list of responses
from all models that respond.
Args:
*args: Variable-length positional arguments passed to the completion function.
**kwargs: Additional keyword arguments:
- models (str or list of str): The language models to send requests to.
- Other keyword arguments to be passed to the completion function.
Returns:
list: A list of responses from the language models that responded.
Note:
This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
It sends requests concurrently and collects responses from all models that respond.
"""
import concurrent.futures
if "model" in kwargs:
kwargs.pop("model")
if "models" in kwargs:
models = kwargs.pop("models")
else:
raise Exception("'models' param not in kwargs")
if isinstance(models, str):
models = [models]
elif isinstance(models, (list, tuple)):
models = list(models)
else:
raise TypeError("'models' must be a string or list of strings")
if len(models) == 0:
return []
responses = []
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
futures = [
executor.submit(litellm.completion, *args, model=model, **kwargs)
for model in models
]
for future in futures:
try:
result = future.result()
if result is not None:
responses.append(result)
except Exception as e:
print_verbose(
f"batch_completion_models_all_responses: model request failed: {str(e)}"
)
continue
return responses
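The first-responder behavior of `batch_completion_models` reduces to the `wait(..., FIRST_COMPLETED)` loop: keep waiting, return the first result that doesn't raise, and drop failed futures. A standalone sketch, where `slow_model` is a hypothetical stand-in for a provider call:

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def slow_model(delay, reply, fail=False):
    # hypothetical provider call: waits `delay` seconds, then answers or errors
    time.sleep(delay)
    if fail:
        raise RuntimeError("model errored")
    return reply

def first_successful_response(jobs):
    """Return the first result that completes without raising, mirroring
    the wait(..., return_when=FIRST_COMPLETED) loop in the module above."""
    with ThreadPoolExecutor(max_workers=max(len(jobs), 1)) as executor:
        pending = {executor.submit(slow_model, *job) for job in jobs}
        while pending:
            done, pending = wait(pending, return_when=FIRST_COMPLETED)
            for future in done:
                try:
                    return future.result()
                except RuntimeError:
                    continue  # this model failed; keep waiting on the rest
    return None

# the fastest job fails, so the next-fastest answer wins
print(first_successful_response([(0.01, "a", True), (0.05, "b"), (0.2, "c")]))
```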

View File

@@ -0,0 +1,442 @@
import json
from typing import Any, List, Literal, Optional, Tuple
import litellm
from litellm._logging import verbose_logger
from litellm.types.llms.openai import Batch
from litellm.types.utils import CallTypes, ModelInfo, Usage
from litellm.utils import token_counter
async def calculate_batch_cost_and_usage(
file_content_dictionary: List[dict],
custom_llm_provider: Literal[
"openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
],
model_name: Optional[str] = None,
model_info: Optional[ModelInfo] = None,
) -> Tuple[float, Usage, List[str]]:
"""
Calculate the cost and usage of a batch.
Args:
model_info: Optional deployment-level model info with custom batch
pricing. Threaded through to batch_cost_calculator so that
deployment-specific pricing (e.g. input_cost_per_token_batches)
is used instead of the global cost map.
"""
batch_cost = _batch_cost_calculator(
custom_llm_provider=custom_llm_provider,
file_content_dictionary=file_content_dictionary,
model_name=model_name,
model_info=model_info,
)
batch_usage = _get_batch_job_total_usage_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
model_name=model_name,
)
batch_models = _get_batch_models_from_file_content(
file_content_dictionary, model_name
)
return batch_cost, batch_usage, batch_models
async def _handle_completed_batch(
batch: Batch,
custom_llm_provider: Literal[
"openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
],
model_name: Optional[str] = None,
litellm_params: Optional[dict] = None,
) -> Tuple[float, Usage, List[str]]:
"""Helper function to process a completed batch and handle logging
Args:
batch: The batch object
custom_llm_provider: The LLM provider
model_name: Optional model name
litellm_params: Optional litellm parameters containing credentials (api_key, api_base, etc.)
"""
# Get batch results
file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
batch, custom_llm_provider, litellm_params=litellm_params
)
# Calculate costs and usage
batch_cost = _batch_cost_calculator(
custom_llm_provider=custom_llm_provider,
file_content_dictionary=file_content_dictionary,
model_name=model_name,
)
batch_usage = _get_batch_job_total_usage_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
model_name=model_name,
)
batch_models = _get_batch_models_from_file_content(
file_content_dictionary, model_name
)
return batch_cost, batch_usage, batch_models
def _get_batch_models_from_file_content(
file_content_dictionary: List[dict],
model_name: Optional[str] = None,
) -> List[str]:
"""
Get the models from the file content
"""
if model_name:
return [model_name]
batch_models = []
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
_model = _response_body.get("model")
if _model:
batch_models.append(_model)
return batch_models
def _batch_cost_calculator(
file_content_dictionary: List[dict],
custom_llm_provider: Literal[
"openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
] = "openai",
model_name: Optional[str] = None,
model_info: Optional[ModelInfo] = None,
) -> float:
"""
Calculate the cost of a batch based on the output file id
"""
# Handle Vertex AI with specialized method
if custom_llm_provider == "vertex_ai" and model_name:
batch_cost, _ = calculate_vertex_ai_batch_cost_and_usage(
file_content_dictionary, model_name
)
verbose_logger.debug("vertex_ai_total_cost=%s", batch_cost)
return batch_cost
# For other providers, use the existing logic
total_cost = _get_batch_job_cost_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
model_info=model_info,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost
def calculate_vertex_ai_batch_cost_and_usage(
vertex_ai_batch_responses: List[dict],
model_name: Optional[str] = None,
) -> Tuple[float, Usage]:
"""
Calculate both cost and usage from Vertex AI batch responses.
Vertex AI batch output lines have format:
{"request": ..., "status": "", "response": {"candidates": [...], "usageMetadata": {...}}}
usageMetadata contains promptTokenCount, candidatesTokenCount, totalTokenCount.
"""
from litellm.cost_calculator import batch_cost_calculator
total_cost = 0.0
total_tokens = 0
prompt_tokens = 0
completion_tokens = 0
actual_model_name = model_name or "gemini-2.0-flash-001"
for response in vertex_ai_batch_responses:
response_body = response.get("response")
if response_body is None:
continue
usage_metadata = response_body.get("usageMetadata", {})
_prompt = usage_metadata.get("promptTokenCount", 0) or 0
_completion = usage_metadata.get("candidatesTokenCount", 0) or 0
_total = usage_metadata.get("totalTokenCount", 0) or (_prompt + _completion)
line_usage = Usage(
prompt_tokens=_prompt,
completion_tokens=_completion,
total_tokens=_total,
)
try:
p_cost, c_cost = batch_cost_calculator(
usage=line_usage,
model=actual_model_name,
custom_llm_provider="vertex_ai",
)
total_cost += p_cost + c_cost
except Exception as e:
verbose_logger.debug(
"vertex_ai batch cost calculation error for line: %s", str(e)
)
prompt_tokens += _prompt
completion_tokens += _completion
total_tokens += _total
verbose_logger.info(
"vertex_ai batch cost: cost=%s, prompt=%d, completion=%d, total=%d",
total_cost,
prompt_tokens,
completion_tokens,
total_tokens,
)
return total_cost, Usage(
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
async def _get_batch_output_file_content_as_dictionary(
batch: Batch,
custom_llm_provider: Literal[
"openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
] = "openai",
litellm_params: Optional[dict] = None,
) -> List[dict]:
"""
Get the batch output file content as a list of dictionaries
Args:
batch: The batch object
custom_llm_provider: The LLM provider
litellm_params: Optional litellm parameters containing credentials (api_key, api_base, etc.)
Required for Azure and other providers that need authentication
"""
from litellm.files.main import afile_content
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
)
if custom_llm_provider == "vertex_ai":
raise ValueError("Vertex AI does not support file content retrieval")
if batch.output_file_id is None:
raise ValueError("Output file id is None; cannot retrieve file content")
file_id = batch.output_file_id
is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id)
if is_base64_unified_file_id:
try:
file_id = is_base64_unified_file_id.split("llm_output_file_id,")[1].split(
";"
)[0]
verbose_logger.debug(
f"Extracted LLM output file ID from unified file ID: {file_id}"
)
except (IndexError, AttributeError) as e:
verbose_logger.error(
f"Failed to extract LLM output file ID from unified file ID: {batch.output_file_id}, error: {e}"
)
# Build kwargs for afile_content with credentials from litellm_params
file_content_kwargs = {
"file_id": file_id,
"custom_llm_provider": custom_llm_provider,
}
# Extract and add credentials for file access
credentials = _extract_file_access_credentials(litellm_params)
file_content_kwargs.update(credentials)
_file_content = await afile_content(**file_content_kwargs) # type: ignore[reportArgumentType]
return _get_file_content_as_dictionary(_file_content.content)
def _extract_file_access_credentials(litellm_params: Optional[dict]) -> dict:
"""
Extract credentials from litellm_params for file access operations.
This method extracts relevant authentication and configuration parameters
needed for accessing files across different providers (Azure, Vertex AI, etc.).
Args:
litellm_params: Dictionary containing litellm parameters with credentials
Returns:
Dictionary containing only the credentials needed for file access
"""
credentials = {}
if litellm_params:
# List of credential keys that should be passed to file operations
credential_keys = [
"api_key",
"api_base",
"api_version",
"organization",
"azure_ad_token",
"azure_ad_token_provider",
"vertex_project",
"vertex_location",
"vertex_credentials",
"timeout",
"max_retries",
]
for key in credential_keys:
if key in litellm_params:
credentials[key] = litellm_params[key]
return credentials
def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
"""
Get the file content as a list of dictionaries from JSON Lines format
"""
try:
_file_content_str = file_content.decode("utf-8")
# Split by newlines and parse each line as a separate JSON object
json_objects = []
for line in _file_content_str.strip().split("\n"):
if line: # Skip empty lines
json_objects.append(json.loads(line))
verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
return json_objects
except Exception as e:
raise e
def _get_batch_job_cost_from_file_content(
file_content_dictionary: List[dict],
custom_llm_provider: Literal[
"openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
] = "openai",
model_info: Optional[ModelInfo] = None,
) -> float:
"""
Get the cost of a batch job from the file content
"""
from litellm.cost_calculator import batch_cost_calculator
try:
total_cost: float = 0.0
# parse the file content as json
verbose_logger.debug(
"file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
)
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
if model_info is not None:
usage = _get_batch_job_usage_from_response_body(_response_body)
model = _response_body.get("model", "")
prompt_cost, completion_cost = batch_cost_calculator(
usage=usage,
model=model,
custom_llm_provider=custom_llm_provider,
model_info=model_info,
)
total_cost += prompt_cost + completion_cost
else:
total_cost += litellm.completion_cost(
completion_response=_response_body,
custom_llm_provider=custom_llm_provider,
call_type=CallTypes.aretrieve_batch.value,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost
except Exception as e:
verbose_logger.error("error in _get_batch_job_cost_from_file_content: %s", e)
raise e
def _get_batch_job_total_usage_from_file_content(
file_content_dictionary: List[dict],
custom_llm_provider: Literal[
"openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
] = "openai",
model_name: Optional[str] = None,
) -> Usage:
"""
Get the tokens of a batch job from the file content
"""
# Handle Vertex AI with specialized method
if custom_llm_provider == "vertex_ai" and model_name:
_, batch_usage = calculate_vertex_ai_batch_cost_and_usage(
file_content_dictionary, model_name
)
return batch_usage
# For other providers, use the existing logic
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
total_tokens += usage.total_tokens
prompt_tokens += usage.prompt_tokens
completion_tokens += usage.completion_tokens
return Usage(
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
def _get_batch_job_input_file_usage(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
model_name: Optional[str] = None,
) -> Usage:
"""
Count the number of tokens in the batch input file.
Used for batch rate limiting.
"""
prompt_tokens: int = 0
completion_tokens: int = 0
for _item in file_content_dictionary:
body = _item.get("body", {})
model = body.get("model", model_name or "")
messages = body.get("messages", [])
if messages:
item_tokens = token_counter(model=model, messages=messages)
prompt_tokens += item_tokens
return Usage(
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
"""
Get the tokens of a batch job from the response body
"""
_usage_dict = response_body.get("usage", None) or {}
usage: Usage = Usage(**_usage_dict)
return usage
def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
"""
Get the response from the batch job output file
"""
_response: dict = batch_job_output_file.get("response", None) or {}
_response_body = _response.get("body", None) or {}
return _response_body
def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
"""
Check if the batch job response status == 200
"""
_response: dict = batch_job_output_file.get("response", None) or {}
return _response.get("status_code", None) == 200
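The JSON Lines parsing and per-line usage roll-up above condense into a small self-contained sketch; the field names (`response.status_code`, `response.body.usage`) follow the batch output format the helpers above read, while the function names here are illustrative:

```python
import json

def parse_jsonl(file_content: bytes):
    """Mirror of _get_file_content_as_dictionary: one JSON object per non-empty line."""
    return [
        json.loads(line)
        for line in file_content.decode("utf-8").strip().split("\n")
        if line
    ]

def total_usage(lines):
    """Sum token usage over successful (status_code == 200) lines only."""
    prompt = completion = 0
    for item in lines:
        response = item.get("response") or {}
        if response.get("status_code") != 200:
            continue  # failed lines contribute no usage, as in the helpers above
        usage = (response.get("body") or {}).get("usage") or {}
        prompt += usage.get("prompt_tokens", 0)
        completion += usage.get("completion_tokens", 0)
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": prompt + completion,
    }

raw = b"\n".join([
    json.dumps({"response": {"status_code": 200,
                "body": {"usage": {"prompt_tokens": 10, "completion_tokens": 5}}}}).encode(),
    json.dumps({"response": {"status_code": 500, "body": {}}}).encode(),
])
print(total_usage(parse_jsonl(raw)))  # -> 10 prompt, 5 completion, 15 total
```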

File diff suppressed because it is too large

View File

@@ -0,0 +1,10 @@
{
"posts": [
{
"title": "Incident Report: SERVER_ROOT_PATH regression broke UI routing",
"description": "How a single line removal caused UI 404s for all deployments using SERVER_ROOT_PATH, and the tests we added to prevent it from happening again.",
"date": "2026-02-21",
"url": "https://docs.litellm.ai/blog/server-root-path-incident"
}
]
}

View File

@@ -0,0 +1,230 @@
# +-----------------------------------------------+
# | |
# | NOT PROXY BUDGET MANAGER |
# | proxy budget manager is in proxy_server.py |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import json
import os
import threading
import time
from typing import Literal, Optional
import litellm
from litellm.constants import (
DAYS_IN_A_MONTH,
DAYS_IN_A_WEEK,
DAYS_IN_A_YEAR,
HOURS_IN_A_DAY,
)
from litellm.utils import ModelResponse
class BudgetManager:
def __init__(
self,
project_name: str,
client_type: str = "local",
api_base: Optional[str] = None,
headers: Optional[dict] = None,
):
self.client_type = client_type
self.project_name = project_name
self.api_base = api_base or "https://api.litellm.ai"
self.headers = headers or {"Content-Type": "application/json"}
## load the data or init the initial dictionaries
self.load_data()
def print_verbose(self, print_statement):
try:
if litellm.set_verbose:
import logging
logging.info(print_statement)
except Exception:
pass
def load_data(self):
if self.client_type == "local":
# Check if user dict file exists
if os.path.isfile("user_cost.json"):
# Load the user dict
with open("user_cost.json", "r") as json_file:
self.user_dict = json.load(json_file)
else:
self.print_verbose("User Dictionary not found!")
self.user_dict = {}
self.print_verbose(f"user dict from local: {self.user_dict}")
elif self.client_type == "hosted":
# Load the user_dict from hosted db
url = self.api_base + "/get_budget"
data = {"project_name": self.project_name}
response = litellm.module_level_client.post(
url, headers=self.headers, json=data
)
response = response.json()
if response["status"] == "error":
self.user_dict = (
{}
) # assume this means the user dict hasn't been stored yet
else:
self.user_dict = response["data"]
def create_budget(
self,
total_budget: float,
user: str,
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
created_at: Optional[float] = None,
):
# evaluate at call time: a `time.time()` default would be frozen at import
created_at = created_at if created_at is not None else time.time()
self.user_dict[user] = {"total_budget": total_budget}
if duration is None:
return self.user_dict[user]
if duration == "daily":
duration_in_days = 1
elif duration == "weekly":
duration_in_days = DAYS_IN_A_WEEK
elif duration == "monthly":
duration_in_days = DAYS_IN_A_MONTH
elif duration == "yearly":
duration_in_days = DAYS_IN_A_YEAR
else:
raise ValueError(
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
)
self.user_dict[user] = {
"total_budget": total_budget,
"duration": duration_in_days,
"created_at": created_at,
"last_updated_at": created_at,
}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return self.user_dict[user]
def projected_cost(self, model: str, messages: list, user: str):
text = "".join(message["content"] for message in messages)
prompt_tokens = litellm.token_counter(model=model, text=text)
prompt_cost, _ = litellm.cost_per_token(
model=model, prompt_tokens=prompt_tokens, completion_tokens=0
)
current_cost = self.user_dict[user].get("current_cost", 0)
projected_cost = prompt_cost + current_cost
return projected_cost
def get_total_budget(self, user: str):
return self.user_dict[user]["total_budget"]
def update_cost(
self,
user: str,
completion_obj: Optional[ModelResponse] = None,
model: Optional[str] = None,
input_text: Optional[str] = None,
output_text: Optional[str] = None,
):
if model and input_text and output_text:
prompt_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": input_text}]
)
completion_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": output_text}]
)
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
) = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
elif completion_obj:
cost = litellm.completion_cost(completion_response=completion_obj)
model = completion_obj["model"]
else:
raise ValueError(
"Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
)
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
"current_cost", 0
)
if "model_cost" in self.user_dict[user]:
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
"model_cost"
].get(model, 0)
else:
self.user_dict[user]["model_cost"] = {model: cost}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return {"user": self.user_dict[user]}
def get_current_cost(self, user):
return self.user_dict[user].get("current_cost", 0)
def get_model_cost(self, user):
return self.user_dict[user].get("model_cost", {})
def is_valid_user(self, user: str) -> bool:
return user in self.user_dict
def get_users(self):
return list(self.user_dict.keys())
def reset_cost(self, user):
self.user_dict[user]["current_cost"] = 0
self.user_dict[user]["model_cost"] = {}
return {"user": self.user_dict[user]}
def reset_on_duration(self, user: str):
# Get current and creation time
last_updated_at = self.user_dict[user]["last_updated_at"]
current_time = time.time()
# Convert duration from days to seconds
duration_in_seconds = (
self.user_dict[user]["duration"] * HOURS_IN_A_DAY * 60 * 60
)
# Check if duration has elapsed
if current_time - last_updated_at >= duration_in_seconds:
# Reset cost if duration has elapsed and update the creation time
self.reset_cost(user)
self.user_dict[user]["last_updated_at"] = current_time
self._save_data_thread() # Save the data
def update_budget_all_users(self):
for user in self.get_users():
if "duration" in self.user_dict[user]:
self.reset_on_duration(user)
def _save_data_thread(self):
thread = threading.Thread(
target=self.save_data
) # [Non-Blocking]: saves data without blocking execution
thread.start()
def save_data(self):
if self.client_type == "local":
import json
# save the user dict
with open("user_cost.json", "w") as json_file:
json.dump(
self.user_dict, json_file, indent=4
) # Indent for pretty formatting
return {"status": "success"}
elif self.client_type == "hosted":
url = self.api_base + "/set_budget"
data = {"project_name": self.project_name, "user_dict": self.user_dict}
response = litellm.module_level_client.post(
url, headers=self.headers, json=data
)
response = response.json()
return response
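The duration check in `reset_on_duration` is plain seconds arithmetic: convert the stored duration from days to seconds and compare against the elapsed time. A dependency-free restatement (`should_reset` is an illustrative name, not part of the class):

```python
import time

def should_reset(last_updated_at, duration_in_days, now=None):
    """True once `duration_in_days` has elapsed since the last reset --
    the same days -> seconds conversion reset_on_duration performs."""
    now = time.time() if now is None else now
    return now - last_updated_at >= duration_in_days * 24 * 60 * 60

print(should_reset(0.0, 1, now=12 * 3600))  # half a day elapsed -> False
print(should_reset(0.0, 1, now=25 * 3600))  # over a day elapsed -> True
```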

View File

@@ -0,0 +1,41 @@
# Caching on LiteLLM
LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
The following caching mechanisms are supported:
1. **RedisCache**
2. **RedisSemanticCache**
3. **QdrantSemanticCache**
4. **InMemoryCache**
5. **DiskCache**
6. **S3Cache**
7. **AzureBlobCache**
8. **DualCache** (updates both Redis and an in-memory cache simultaneously)
## Folder Structure
```
litellm/caching/
├── azure_blob_cache.py
├── base_cache.py
├── caching.py
├── caching_handler.py
├── disk_cache.py
├── dual_cache.py
├── gcs_cache.py
├── in_memory_cache.py
├── qdrant_semantic_cache.py
├── redis_cache.py
├── redis_cluster_cache.py
├── redis_semantic_cache.py
└── s3_cache.py
```
## Documentation
- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)
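DualCache's read-through behavior can be sketched with two dict-backed stubs: writes go to both layers, reads try the fast in-memory layer first and backfill it on a remote hit. All class names here are illustrative, not LiteLLM APIs:

```python
class InMemory:
    """Dict-backed stub standing in for either cache layer."""
    def __init__(self):
        self.store = {}
    def get(self, key):
        return self.store.get(key)
    def set(self, key, value):
        self.store[key] = value

class DualCacheSketch:
    """Pair a local cache with a remote one: write-through on set,
    local-first with backfill on get."""
    def __init__(self, local, remote):
        self.local, self.remote = local, remote
    def set_cache(self, key, value):
        self.local.set(key, value)
        self.remote.set(key, value)
    def get_cache(self, key):
        value = self.local.get(key)
        if value is None:
            value = self.remote.get(key)
            if value is not None:
                self.local.set(key, value)  # backfill the in-memory layer
        return value

local, remote = InMemory(), InMemory()
cache = DualCacheSketch(local, remote)
remote.set("k", "v")            # value only present remotely
print(cache.get_cache("k"))     # served from remote...
print(local.get("k"))           # ...and now cached locally too
```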

View File

@@ -0,0 +1,11 @@
from .azure_blob_cache import AzureBlobCache
from .caching import Cache, LiteLLMCacheType
from .disk_cache import DiskCache
from .dual_cache import DualCache
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache
from .gcs_cache import GCSCache

View File

@@ -0,0 +1,30 @@
from functools import lru_cache
from typing import Callable, Optional, TypeVar
T = TypeVar("T")
def lru_cache_wrapper(
maxsize: Optional[int] = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""
Wrapper around lru_cache that caches successful results and raised exceptions alike
"""
def decorator(f: Callable[..., T]) -> Callable[..., T]:
@lru_cache(maxsize=maxsize)
def wrapper(*args, **kwargs):
try:
return ("success", f(*args, **kwargs))
except Exception as e:
return ("error", e)
def wrapped(*args, **kwargs):
result = wrapper(*args, **kwargs)
if result[0] == "error":
raise result[1]
return result[1]
return wrapped
return decorator
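The effect of caching exceptions is easiest to see with a call counter. The sketch below restates the wrapper so it runs standalone, then shows that a failing call's body also executes only once per distinct input:

```python
from functools import lru_cache

def lru_cache_wrapper(maxsize=None):
    """Same shape as the decorator above: cache results and raised exceptions."""
    def decorator(f):
        @lru_cache(maxsize=maxsize)
        def cached(*args, **kwargs):
            try:
                return ("success", f(*args, **kwargs))
            except Exception as e:
                return ("error", e)
        def wrapped(*args, **kwargs):
            tag, payload = cached(*args, **kwargs)
            if tag == "error":
                raise payload
            return payload
        return wrapped
    return decorator

calls = {"n": 0}

@lru_cache_wrapper(maxsize=8)
def flaky(x):
    calls["n"] += 1
    if x < 0:
        raise ValueError("negative input")
    return x * 2

print(flaky(3), flaky(3))   # both 6, body ran once
for _ in range(2):          # the exception is cached too
    try:
        flaky(-1)
    except ValueError:
        pass
print(calls["n"])           # body ran once per distinct input -> 2
```

Note the trade-off this makes: a transient failure (say, a network error) stays cached for the lifetime of the entry, so this pattern suits deterministic failures.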

View File

@@ -0,0 +1,107 @@
"""
Azure Blob Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import asyncio
import json
from contextlib import suppress
from litellm._logging import print_verbose, verbose_logger
from .base_cache import BaseCache
class AzureBlobCache(BaseCache):
def __init__(self, account_url, container) -> None:
from azure.storage.blob import BlobServiceClient
from azure.core.exceptions import ResourceExistsError
from azure.identity import DefaultAzureCredential
from azure.identity.aio import (
DefaultAzureCredential as AsyncDefaultAzureCredential,
)
from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient
self.container_client = BlobServiceClient(
account_url=account_url,
credential=DefaultAzureCredential(),
).get_container_client(container)
self.async_container_client = AsyncBlobServiceClient(
account_url=account_url,
credential=AsyncDefaultAzureCredential(),
).get_container_client(container)
with suppress(ResourceExistsError):
self.container_client.create_container()
def set_cache(self, key, value, **kwargs) -> None:
print_verbose(f"LiteLLM SET Cache - Azure Blob. Key={key}. Value={value}")
serialized_value = json.dumps(value)
try:
self.container_client.upload_blob(key, serialized_value, overwrite=True)
except Exception as e:
# NON blocking - notify users Azure Blob is throwing an exception
print_verbose(f"LiteLLM set_cache() - Got exception from Azure Blob: {e}")
async def async_set_cache(self, key, value, **kwargs) -> None:
print_verbose(f"LiteLLM SET Cache - Azure Blob. Key={key}. Value={value}")
serialized_value = json.dumps(value)
try:
await self.async_container_client.upload_blob(
key, serialized_value, overwrite=True
)
except Exception as e:
# NON blocking - notify users Azure Blob is throwing an exception
print_verbose(f"LiteLLM set_cache() - Got exception from Azure Blob: {e}")
def get_cache(self, key, **kwargs):
from azure.core.exceptions import ResourceNotFoundError
try:
print_verbose(f"Get Azure Blob Cache: key: {key}")
as_bytes = self.container_client.download_blob(key).readall()
as_str = as_bytes.decode("utf-8")
cached_response = json.loads(as_str)
verbose_logger.debug(
f"Got Azure Blob Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
)
return cached_response
except ResourceNotFoundError:
return None
async def async_get_cache(self, key, **kwargs):
from azure.core.exceptions import ResourceNotFoundError
try:
print_verbose(f"Get Azure Blob Cache: key: {key}")
blob = await self.async_container_client.download_blob(key)
as_bytes = await blob.readall()
as_str = as_bytes.decode("utf-8")
cached_response = json.loads(as_str)
verbose_logger.debug(
f"Got Azure Blob Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
)
return cached_response
except ResourceNotFoundError:
return None
def flush_cache(self) -> None:
for blob in self.container_client.walk_blobs():
self.container_client.delete_blob(blob.name)
async def disconnect(self) -> None:
self.container_client.close()
await self.async_container_client.close()
async def async_set_cache_pipeline(self, cache_list, **kwargs) -> None:
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
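The pipeline method above fans each (key, value) pair out to `async_set_cache` and awaits them together. A minimal standalone sketch of the same pattern, using a hypothetical dict-backed async cache in place of the Azure blob client:

```python
import asyncio
import json


class DictAsyncCache:
    """Hypothetical stand-in for the blob-backed cache above."""

    def __init__(self):
        self.store = {}

    async def async_set_cache(self, key, value, **kwargs):
        # Serialize the value like the blob cache does before upload.
        self.store[key] = json.dumps(value)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        # Fan out one task per (key, value) pair, then await them together.
        tasks = [self.async_set_cache(k, v, **kwargs) for k, v in cache_list]
        await asyncio.gather(*tasks)


cache = DictAsyncCache()
asyncio.run(cache.async_set_cache_pipeline([("a", 1), ("b", {"x": 2})]))
```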


@@ -0,0 +1,64 @@
"""
Base Cache implementation. All cache implementations should inherit from this class.
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Union
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class BaseCache(ABC):
def __init__(self, default_ttl: int = 60):
self.default_ttl = default_ttl
def get_ttl(self, **kwargs) -> Optional[int]:
kwargs_ttl: Optional[int] = kwargs.get("ttl")
if kwargs_ttl is not None:
try:
return int(kwargs_ttl)
except ValueError:
return self.default_ttl
return self.default_ttl
def set_cache(self, key, value, **kwargs):
raise NotImplementedError
async def async_set_cache(self, key, value, **kwargs):
raise NotImplementedError
@abstractmethod
async def async_set_cache_pipeline(self, cache_list, **kwargs):
pass
def get_cache(self, key, **kwargs):
raise NotImplementedError
async def async_get_cache(self, key, **kwargs):
raise NotImplementedError
async def batch_cache_write(self, key, value, **kwargs):
raise NotImplementedError
async def disconnect(self):
raise NotImplementedError
async def test_connection(self) -> dict:
"""
Test the cache connection.
Returns:
dict: {"status": "success" | "failed", "message": str, "error": Optional[str]}
"""
raise NotImplementedError
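To make the contract concrete, here is a minimal sketch of a dict-backed subclass; the base class is restated in trimmed form (without the abstract pipeline method) so the snippet runs standalone:

```python
from abc import ABC
from typing import Optional


class BaseCache(ABC):
    """Trimmed restatement of the interface above."""

    def __init__(self, default_ttl: int = 60):
        self.default_ttl = default_ttl

    def get_ttl(self, **kwargs) -> Optional[int]:
        # Prefer a per-call ttl; fall back to the default on bad input.
        kwargs_ttl = kwargs.get("ttl")
        if kwargs_ttl is not None:
            try:
                return int(kwargs_ttl)
            except ValueError:
                return self.default_ttl
        return self.default_ttl


class DictCache(BaseCache):
    """Minimal synchronous implementation for illustration."""

    def __init__(self):
        super().__init__()
        self.store = {}

    def set_cache(self, key, value, **kwargs):
        # A real implementation would honor get_ttl(**kwargs) here.
        self.store[key] = value

    def get_cache(self, key, **kwargs):
        return self.store.get(key)


c = DictCache()
c.set_cache("k", "v")
```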


@@ -0,0 +1,926 @@
# +-----------------------------------------------+
# | |
# | Give Feedback / Get Help |
# | https://github.com/BerriAI/litellm/issues/new |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import ast
import hashlib
import json
import time
import traceback
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.types.caching import *
from litellm.types.utils import EmbeddingResponse, all_litellm_params
from .azure_blob_cache import AzureBlobCache
from .base_cache import BaseCache
from .disk_cache import DiskCache
from .dual_cache import DualCache # noqa
from .gcs_cache import GCSCache
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache
def print_verbose(print_statement):
try:
verbose_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except Exception:
pass
class CacheMode(str, Enum):
default_on = "default_on"
default_off = "default_off"
#### LiteLLM.Completion / Embedding Cache ####
class Cache:
def __init__(
self,
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
mode: Optional[
CacheMode
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
namespace: Optional[str] = None,
ttl: Optional[float] = None,
default_in_memory_ttl: Optional[float] = None,
default_in_redis_ttl: Optional[float] = None,
similarity_threshold: Optional[float] = None,
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
"responses",
"aresponses",
],
        # Azure Blob Cache Args
        azure_account_url: Optional[str] = None,
        azure_blob_container: Optional[str] = None,
        # s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None,
s3_region_name: Optional[str] = None,
s3_api_version: Optional[str] = None,
s3_use_ssl: Optional[bool] = True,
s3_verify: Optional[Union[bool, str]] = None,
s3_endpoint_url: Optional[str] = None,
s3_aws_access_key_id: Optional[str] = None,
s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None,
s3_path: Optional[str] = None,
gcs_bucket_name: Optional[str] = None,
gcs_path_service_account: Optional[str] = None,
gcs_path: Optional[str] = None,
redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
redis_semantic_cache_index_name: Optional[str] = None,
redis_flush_size: Optional[int] = None,
redis_startup_nodes: Optional[List] = None,
disk_cache_dir: Optional[str] = None,
qdrant_api_base: Optional[str] = None,
qdrant_api_key: Optional[str] = None,
qdrant_collection_name: Optional[str] = None,
qdrant_quantization_config: Optional[str] = None,
qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
qdrant_semantic_cache_vector_size: Optional[int] = None,
# GCP IAM authentication parameters
gcp_service_account: Optional[str] = None,
gcp_ssl_ca_certs: Optional[str] = None,
**kwargs,
):
"""
Initializes the cache based on the given type.
Args:
            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3", "gcs", "azure-blob" or "disk". Defaults to "local".
# Redis Cache Args
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
ttl (float, optional): The ttl for the Redis cache
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
# Qdrant Cache Args
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
# Disk Cache Args
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
# S3 Cache Args
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
# GCS Cache Args
gcs_bucket_name (str, optional): The bucket name for the gcs cache. Defaults to None.
gcs_path_service_account (str, optional): Path to the service account json.
gcs_path (str, optional): Folder path inside the bucket to store cache files.
# Common Cache Args
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
**kwargs: Additional keyword arguments for redis.Redis() cache
Raises:
ValueError: If an invalid cache type is provided.
Returns:
None. Cache is set as a litellm param
"""
if type == LiteLLMCacheType.REDIS:
# Check REDIS_CLUSTER_NODES env var if no explicit startup nodes
if not redis_startup_nodes:
_env_cluster_nodes = litellm.get_secret("REDIS_CLUSTER_NODES")
if _env_cluster_nodes is not None and isinstance(
_env_cluster_nodes, str
):
redis_startup_nodes = json.loads(_env_cluster_nodes)
if redis_startup_nodes:
# Only pass GCP parameters if they are provided
cluster_kwargs = {
"host": host,
"port": port,
"password": password,
"redis_flush_size": redis_flush_size,
"startup_nodes": redis_startup_nodes,
**kwargs,
}
if gcp_service_account is not None:
cluster_kwargs["gcp_service_account"] = gcp_service_account
if gcp_ssl_ca_certs is not None:
cluster_kwargs["gcp_ssl_ca_certs"] = gcp_ssl_ca_certs
self.cache: BaseCache = RedisClusterCache(**cluster_kwargs)
else:
self.cache = RedisCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
**kwargs,
)
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
self.cache = RedisSemanticCache(
host=host,
port=port,
password=password,
similarity_threshold=similarity_threshold,
embedding_model=redis_semantic_cache_embedding_model,
index_name=redis_semantic_cache_index_name,
**kwargs,
)
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
self.cache = QdrantSemanticCache(
qdrant_api_base=qdrant_api_base,
qdrant_api_key=qdrant_api_key,
collection_name=qdrant_collection_name,
similarity_threshold=similarity_threshold,
quantization_config=qdrant_quantization_config,
embedding_model=qdrant_semantic_cache_embedding_model,
vector_size=qdrant_semantic_cache_vector_size,
)
elif type == LiteLLMCacheType.LOCAL:
self.cache = InMemoryCache()
elif type == LiteLLMCacheType.S3:
self.cache = S3Cache(
s3_bucket_name=s3_bucket_name,
s3_region_name=s3_region_name,
s3_api_version=s3_api_version,
s3_use_ssl=s3_use_ssl,
s3_verify=s3_verify,
s3_endpoint_url=s3_endpoint_url,
s3_aws_access_key_id=s3_aws_access_key_id,
s3_aws_secret_access_key=s3_aws_secret_access_key,
s3_aws_session_token=s3_aws_session_token,
s3_config=s3_config,
s3_path=s3_path,
**kwargs,
)
elif type == LiteLLMCacheType.GCS:
self.cache = GCSCache(
bucket_name=gcs_bucket_name,
path_service_account=gcs_path_service_account,
gcs_path=gcs_path,
)
elif type == LiteLLMCacheType.AZURE_BLOB:
self.cache = AzureBlobCache(
account_url=azure_account_url,
container=azure_blob_container,
)
elif type == LiteLLMCacheType.DISK:
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
litellm.logging_callback_manager.add_litellm_success_callback("cache")
if "cache" not in litellm._async_success_callback:
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
self.type = type
self.namespace = namespace
self.redis_flush_size = redis_flush_size
self.ttl = ttl
self.mode: CacheMode = mode or CacheMode.default_on
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
self.ttl = default_in_memory_ttl
if (
self.type == LiteLLMCacheType.REDIS
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
) and default_in_redis_ttl is not None:
self.ttl = default_in_redis_ttl
if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace
def get_cache_key(self, **kwargs) -> str:
"""
Get the cache key for the given arguments.
Args:
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
            str: The cache key generated from the arguments.
"""
cache_key = ""
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
if preset_cache_key is not None:
verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
return preset_cache_key
combined_kwargs = ModelParamHelper._get_all_llm_api_params()
litellm_param_kwargs = all_litellm_params
for param in kwargs:
if param in combined_kwargs:
param_value: Optional[str] = self._get_param_value(param, kwargs)
if param_value is not None:
cache_key += f"{str(param)}: {str(param_value)}"
elif (
param not in litellm_param_kwargs
): # check if user passed in optional param - e.g. top_k
if (
litellm.enable_caching_on_provider_specific_optional_params is True
): # feature flagged for now
if kwargs[param] is None:
continue # ignore None params
param_value = kwargs[param]
cache_key += f"{str(param)}: {str(param_value)}"
verbose_logger.debug("\nCreated cache key: %s", cache_key)
hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
self._set_preset_cache_key_in_kwargs(
preset_cache_key=hashed_cache_key, **kwargs
)
return hashed_cache_key
def _get_param_value(
self,
param: str,
kwargs: dict,
) -> Optional[str]:
"""
Get the value for the given param from kwargs
"""
if param == "model":
return self._get_model_param_value(kwargs)
elif param == "file":
return self._get_file_param_value(kwargs)
return kwargs[param]
def _get_model_param_value(self, kwargs: dict) -> str:
"""
Handles getting the value for the 'model' param from kwargs
1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
3. Else use the `model` passed in kwargs
"""
metadata: Dict = kwargs.get("metadata", {}) or {}
litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
model_group: Optional[str] = metadata.get(
"model_group"
) or metadata_in_litellm_params.get("model_group")
caching_group = self._get_caching_group(metadata, model_group)
return caching_group or model_group or kwargs["model"]
def _get_caching_group(
self, metadata: dict, model_group: Optional[str]
) -> Optional[str]:
caching_groups: Optional[List] = metadata.get("caching_groups", [])
if caching_groups:
for group in caching_groups:
if model_group in group:
return str(group)
return None
    def _get_file_param_value(self, kwargs: dict) -> Optional[str]:
"""
Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
"""
file = kwargs.get("file")
metadata = kwargs.get("metadata", {})
litellm_params = kwargs.get("litellm_params", {})
return (
metadata.get("file_checksum")
or getattr(file, "name", None)
or metadata.get("file_name")
or litellm_params.get("file_name")
)
def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
"""
Get the preset cache key from kwargs["litellm_params"]
We use _get_preset_cache_keys for two reasons
1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
2. avoid doing duplicate / repeated work
"""
if kwargs:
if "litellm_params" in kwargs:
return kwargs["litellm_params"].get("preset_cache_key", None)
return None
def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
"""
Set the calculated cache key in kwargs
This is used to avoid doing duplicate / repeated work
Placed in kwargs["litellm_params"]
"""
if kwargs:
if "litellm_params" in kwargs:
kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
@staticmethod
def _get_hashed_cache_key(cache_key: str) -> str:
"""
Get the hashed cache key for the given cache key.
Use hashlib to create a sha256 hash of the cache key
Args:
cache_key (str): The cache key to hash.
Returns:
str: The hashed cache key.
"""
hash_object = hashlib.sha256(cache_key.encode())
# Hexadecimal representation of the hash
hash_hex = hash_object.hexdigest()
verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
return hash_hex
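Standalone, the hashing step above is deterministic and always yields a 64-character hex digest, regardless of how long the concatenated key string grows:

```python
import hashlib


def get_hashed_cache_key(cache_key: str) -> str:
    # Same scheme as Cache._get_hashed_cache_key above: SHA-256, hex-encoded.
    return hashlib.sha256(cache_key.encode()).hexdigest()


h1 = get_hashed_cache_key("model: gpt-4messages: [...]")
h2 = get_hashed_cache_key("model: gpt-4messages: [...]")
```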
def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
"""
If a redis namespace is provided, add it to the cache key
Args:
hash_hex (str): The hashed cache key.
**kwargs: Additional keyword arguments.
Returns:
str: The final hashed cache key with the redis namespace.
"""
        dynamic_cache_control: DynamicCacheControl = kwargs.get("cache") or {}
namespace = (
dynamic_cache_control.get("namespace")
or kwargs.get("metadata", {}).get("redis_namespace")
or self.namespace
)
if namespace:
hash_hex = f"{namespace}:{hash_hex}"
verbose_logger.debug("Final hashed key: %s", hash_hex)
return hash_hex
def generate_streaming_content(self, content):
chunk_size = 5 # Adjust the chunk size as needed
for i in range(0, len(content), chunk_size):
yield {
"choices": [
{
"delta": {
"role": "assistant",
"content": content[i : i + chunk_size],
}
}
]
}
time.sleep(CACHED_STREAMING_CHUNK_DELAY)
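Stripped of the litellm constant, the chunking pattern above is a plain slicing generator (the inter-chunk delay is dropped here for illustration):

```python
def generate_streaming_content(content, chunk_size=5):
    # Yield OpenAI-style delta chunks, chunk_size characters at a time.
    for i in range(0, len(content), chunk_size):
        yield {
            "choices": [
                {
                    "delta": {
                        "role": "assistant",
                        "content": content[i : i + chunk_size],
                    }
                }
            ]
        }


chunks = list(generate_streaming_content("hello world"))
reassembled = "".join(c["choices"][0]["delta"]["content"] for c in chunks)
```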
def _get_cache_logic(
self,
cached_result: Optional[Any],
max_age: Optional[float],
):
"""
Common get cache logic across sync + async implementations
"""
# Check if a timestamp was stored with the cached response
if (
cached_result is not None
and isinstance(cached_result, dict)
and "timestamp" in cached_result
):
timestamp = cached_result["timestamp"]
current_time = time.time()
# Calculate age of the cached response
response_age = current_time - timestamp
# Check if the cached response is older than the max-age
if max_age is not None and response_age > max_age:
return None # Cached response is too old
# If the response is fresh, or there's no max-age requirement, return the cached response
            # cached_response may be a JSON-encoded string; convert it back to a dict / ModelResponse
cached_response = cached_result.get("response")
try:
if isinstance(cached_response, dict):
pass
else:
cached_response = json.loads(
cached_response # type: ignore
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response) # type: ignore
return cached_response
return cached_result
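The freshness check in `_get_cache_logic` boils down to comparing the stored timestamp against the caller's max-age. A self-contained sketch of that check:

```python
import json
import time


def get_fresh_response(cached_result, max_age):
    # Mirrors the age check above: drop entries older than max_age seconds.
    if isinstance(cached_result, dict) and "timestamp" in cached_result:
        response_age = time.time() - cached_result["timestamp"]
        if max_age is not None and response_age > max_age:
            return None  # cached response is too old
        response = cached_result["response"]
        return json.loads(response) if isinstance(response, str) else response
    return cached_result


fresh = {"timestamp": time.time(), "response": json.dumps({"ok": True})}
stale = {"timestamp": time.time() - 3600, "response": json.dumps({"ok": True})}
```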
def get_cache(self, dynamic_cache_object: Optional[BaseCache] = None, **kwargs):
"""
Retrieves the cached result for the given arguments.
Args:
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
The cached result if it exists, otherwise None.
"""
try: # never block execution
if self.should_use_cache(**kwargs) is not True:
return
messages = kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
                cache_control_args: DynamicCacheControl = kwargs.get("cache") or {}
max_age = (
cache_control_args.get("s-maxage")
or cache_control_args.get("s-max-age")
or float("inf")
)
if dynamic_cache_object is not None:
cached_result = dynamic_cache_object.get_cache(
cache_key, messages=messages
)
else:
cached_result = self.cache.get_cache(cache_key, messages=messages)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
async def async_get_cache(
self, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
):
"""
Async get cache implementation.
Used for embedding calls in async wrapper
"""
try: # never block execution
if self.should_use_cache(**kwargs) is not True:
return
kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
if dynamic_cache_object is not None:
cached_result = await dynamic_cache_object.async_get_cache(
cache_key, **kwargs
)
else:
cached_result = await self.cache.async_get_cache(
cache_key, **kwargs
)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
def _add_cache_logic(self, result, **kwargs):
"""
Common implementation across sync + async add_cache functions
"""
try:
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
if isinstance(result, BaseModel):
result = result.model_dump_json()
## DEFAULT TTL ##
if self.ttl is not None:
kwargs["ttl"] = self.ttl
## Get Cache-Controls ##
_cache_kwargs = kwargs.get("cache", None)
if isinstance(_cache_kwargs, dict):
for k, v in _cache_kwargs.items():
if k == "ttl":
kwargs["ttl"] = v
cached_data = {"timestamp": time.time(), "response": result}
return cache_key, cached_data, kwargs
else:
raise Exception("cache key is None")
except Exception as e:
raise e
def add_cache(self, result, **kwargs):
"""
Adds a result to the cache.
Args:
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
None
"""
try:
if self.should_use_cache(**kwargs) is not True:
return
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, **kwargs
)
self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
async def async_add_cache(
self, result, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
):
"""
Async implementation of add_cache
"""
try:
if self.should_use_cache(**kwargs) is not True:
return
if self.type == "redis" and self.redis_flush_size is not None:
# high traffic - fill in results in memory and then flush
await self.batch_cache_write(result, **kwargs)
else:
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, **kwargs
)
if dynamic_cache_object is not None:
await dynamic_cache_object.async_set_cache(
cache_key, cached_data, **kwargs
)
else:
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
def _convert_to_cached_embedding(
self, embedding_response: Any, model: Optional[str]
) -> CachedEmbedding:
"""
Convert any embedding response into the standardized CachedEmbedding TypedDict format.
"""
try:
if isinstance(embedding_response, dict):
return {
"embedding": embedding_response.get("embedding"),
"index": embedding_response.get("index"),
"object": embedding_response.get("object"),
"model": model,
}
elif hasattr(embedding_response, "model_dump"):
data = embedding_response.model_dump()
return {
"embedding": data.get("embedding"),
"index": data.get("index"),
"object": data.get("object"),
"model": model,
}
else:
data = vars(embedding_response)
return {
"embedding": data.get("embedding"),
"index": data.get("index"),
"object": data.get("object"),
"model": model,
}
except KeyError as e:
raise ValueError(f"Missing expected key in embedding response: {e}")
def add_embedding_response_to_cache(
self,
result: EmbeddingResponse,
input: str,
kwargs: dict,
idx_in_result_data: int = 0,
) -> Tuple[str, dict, dict]:
preset_cache_key = self.get_cache_key(**{**kwargs, "input": input})
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx_in_result_data]
# Always convert to properly typed CachedEmbedding
model_name = result.model
embedding_dict: CachedEmbedding = self._convert_to_cached_embedding(
embedding_response, model_name
)
cache_key, cached_data, kwargs = self._add_cache_logic(
result=embedding_dict,
**kwargs,
)
return cache_key, cached_data, kwargs
async def async_add_cache_pipeline(
self, result, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
):
"""
Async implementation of add_cache for Embedding calls
Does a bulk write, to prevent using too many clients
"""
try:
if self.should_use_cache(**kwargs) is not True:
return
# set default ttl if not set
if self.ttl is not None:
kwargs["ttl"] = self.ttl
cache_list = []
if isinstance(kwargs["input"], list):
for idx, i in enumerate(kwargs["input"]):
(
cache_key,
cached_data,
kwargs,
) = self.add_embedding_response_to_cache(result, i, kwargs, idx)
cache_list.append((cache_key, cached_data))
elif isinstance(kwargs["input"], str):
cache_key, cached_data, kwargs = self.add_embedding_response_to_cache(
result, kwargs["input"], kwargs
)
cache_list.append((cache_key, cached_data))
if dynamic_cache_object is not None:
await dynamic_cache_object.async_set_cache_pipeline(
cache_list=cache_list, **kwargs
)
else:
await self.cache.async_set_cache_pipeline(
cache_list=cache_list, **kwargs
)
except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
def should_use_cache(self, **kwargs):
"""
Returns true if we should use the cache for LLM API calls
If cache is default_on then this is True
If cache is default_off then this is only true when user has opted in to use cache
"""
if self.mode == CacheMode.default_on:
return True
# when mode == default_off -> Cache is opt in only
_cache = kwargs.get("cache", None)
verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
if _cache and isinstance(_cache, dict):
if _cache.get("use-cache", False) is True:
return True
return False
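The opt-in logic is small enough to restate standalone (mode names matching the `CacheMode` enum above):

```python
def should_use_cache(mode: str, **kwargs) -> bool:
    # default_on: always cache; default_off: only when the caller opts in.
    if mode == "default_on":
        return True
    _cache = kwargs.get("cache", None)
    if _cache and isinstance(_cache, dict):
        if _cache.get("use-cache", False) is True:
            return True
    return False
```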
async def batch_cache_write(self, result, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
async def ping(self):
        cache_ping = getattr(self.cache, "ping", None)
if cache_ping:
return await cache_ping()
return None
async def delete_cache_keys(self, keys):
        cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
if cache_delete_cache_keys:
return await cache_delete_cache_keys(keys)
return None
async def disconnect(self):
if hasattr(self.cache, "disconnect"):
await self.cache.disconnect()
def _supports_async(self) -> bool:
"""
Internal method to check if the cache type supports async get/set operations
All cache types now support async operations
"""
return True
def enable_cache(
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
"responses",
"aresponses",
],
**kwargs,
):
"""
Enable cache with the specified configuration.
Args:
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
host (Optional[str]): The host address of the cache server. Defaults to None.
port (Optional[str]): The port number of the cache server. Defaults to None.
password (Optional[str]): The password for the cache server. Defaults to None.
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
**kwargs: Additional keyword arguments.
Returns:
None
Raises:
None
"""
print_verbose("LiteLLM: Enabling Cache")
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
litellm.logging_callback_manager.add_litellm_success_callback("cache")
if "cache" not in litellm._async_success_callback:
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
if litellm.cache is None:
litellm.cache = Cache(
type=type,
host=host,
port=port,
password=password,
supported_call_types=supported_call_types,
**kwargs,
)
print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
def update_cache(
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
"responses",
"aresponses",
],
**kwargs,
):
"""
Update the cache for LiteLLM.
Args:
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
host (Optional[str]): The host of the cache. Defaults to None.
port (Optional[str]): The port of the cache. Defaults to None.
password (Optional[str]): The password for the cache. Defaults to None.
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
**kwargs: Additional keyword arguments for the cache.
Returns:
None
"""
print_verbose("LiteLLM: Updating Cache")
litellm.cache = Cache(
type=type,
host=host,
port=port,
password=password,
supported_call_types=supported_call_types,
**kwargs,
)
print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
def disable_cache():
"""
Disable the cache used by LiteLLM.
This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
Parameters:
None
Returns:
None
"""
from contextlib import suppress
print_verbose("LiteLLM: Disabling Cache")
    with suppress(ValueError):
        litellm.input_callback.remove("cache")
    with suppress(ValueError):
        litellm.success_callback.remove("cache")
    with suppress(ValueError):
        litellm._async_success_callback.remove("cache")
litellm.cache = None
print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
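The teardown relies on `contextlib.suppress(ValueError)` so that removing "cache" from a callback list that does not contain it is a no-op. The pattern in isolation, with hypothetical callback lists:

```python
from contextlib import suppress

callbacks = ["cache", "langfuse"]
other_callbacks = ["langfuse"]  # no "cache" entry

# Each removal gets its own suppress block, so a missing entry in one
# list does not skip the removals that follow it.
for cb_list in (callbacks, other_callbacks):
    with suppress(ValueError):
        cb_list.remove("cache")
```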

File diff suppressed because it is too large


@@ -0,0 +1,93 @@
import json
from typing import TYPE_CHECKING, Any, Optional, Union
from .base_cache import BaseCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class DiskCache(BaseCache):
def __init__(self, disk_cache_dir: Optional[str] = None):
try:
import diskcache as dc
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"Please install litellm with `litellm[caching]` to use disk caching."
) from e
        # if users don't provide one, use the default litellm cache dir
if disk_cache_dir is None:
self.disk_cache = dc.Cache(".litellm_cache")
else:
self.disk_cache = dc.Cache(disk_cache_dir)
def set_cache(self, key, value, **kwargs):
if "ttl" in kwargs:
self.disk_cache.set(key, value, expire=kwargs["ttl"])
else:
self.disk_cache.set(key, value)
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, **kwargs):
for cache_key, cache_value in cache_list:
if "ttl" in kwargs:
self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
else:
self.set_cache(key=cache_key, value=cache_value)
def get_cache(self, key, **kwargs):
original_cached_response = self.disk_cache.get(key)
if original_cached_response:
try:
cached_response = json.loads(original_cached_response) # type: ignore
except Exception:
cached_response = original_cached_response
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value # type: ignore
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs) -> int:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value # type: ignore
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.disk_cache.clear()
async def disconnect(self):
pass
def delete_cache(self, key):
self.disk_cache.pop(key)
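The `increment_cache` path above is a plain read-modify-write: read the current value (defaulting to 0), add the delta, write back. A minimal sketch of the same logic against a dict-backed stand-in for `diskcache.Cache` (`SimpleStore` is an illustrative name, not part of litellm):

```python
class SimpleStore:
    """Hypothetical stand-in for diskcache.Cache with the same get/set surface."""

    def __init__(self):
        self._data = {}

    def set(self, key, value, expire=None):
        self._data[key] = value  # `expire` is ignored in this sketch

    def get(self, key):
        return self._data.get(key)


def increment(store, key, delta):
    # read-modify-write, mirroring DiskCache.increment_cache
    value = (store.get(key) or 0) + delta
    store.set(key, value)
    return value


store = SimpleStore()
increment(store, "requests", 1)
result = increment(store, "requests", 2)
```

Note this is not atomic: two concurrent callers can both read the same initial value, which is acceptable for a single-process disk cache but not for a shared backend.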

@@ -0,0 +1,506 @@
"""
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
Has 4 primary methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import asyncio
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
if TYPE_CHECKING:
from litellm.types.caching import RedisPipelineIncrementOperation
import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.constants import DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE
from .base_cache import BaseCache
from .in_memory_cache import InMemoryCache
from .redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
from collections import OrderedDict
class LimitedSizeOrderedDict(OrderedDict):
def __init__(self, *args, max_size=100, **kwargs):
super().__init__(*args, **kwargs)
self.max_size = max_size
def __setitem__(self, key, value):
# If inserting a new key exceeds max size, remove the oldest item
if len(self) >= self.max_size:
self.popitem(last=False)
super().__setitem__(key, value)
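The bounded dict above evicts in FIFO insertion order: once `max_size` entries are present, inserting a new key drops the oldest one. A quick self-contained demo of that behavior:

```python
from collections import OrderedDict


class LimitedSizeOrderedDict(OrderedDict):
    def __init__(self, *args, max_size=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_size = max_size

    def __setitem__(self, key, value):
        # evict the oldest entry before inserting past the cap
        if len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)


d = LimitedSizeOrderedDict(max_size=2)
d["a"] = 1
d["b"] = 2
d["c"] = 3  # "a", the oldest key, is evicted
```

Because `popitem(last=False)` pops the *insertion-order* head, this is FIFO rather than LRU: reads do not refresh an entry's position.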
class DualCache(BaseCache):
"""
DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
When data is updated or inserted, it is written to both the in-memory cache + Redis.
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
"""
def __init__(
self,
in_memory_cache: Optional[InMemoryCache] = None,
redis_cache: Optional[RedisCache] = None,
default_in_memory_ttl: Optional[float] = None,
default_redis_ttl: Optional[float] = None,
default_redis_batch_cache_expiry: Optional[float] = None,
default_max_redis_batch_cache_size: int = DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE,
) -> None:
super().__init__()
# If in_memory_cache is not provided, use the default InMemoryCache
self.in_memory_cache = in_memory_cache or InMemoryCache()
# If redis_cache is not provided, use the default RedisCache
self.redis_cache = redis_cache
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
max_size=default_max_redis_batch_cache_size
)
self._last_redis_batch_access_time_lock = Lock()
self.redis_batch_cache_expiry = (
default_redis_batch_cache_expiry
or litellm.default_redis_batch_cache_expiry
or 10
)
self.default_in_memory_ttl = (
default_in_memory_ttl or litellm.default_in_memory_ttl
)
self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
def update_cache_ttl(
self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
):
if default_in_memory_ttl is not None:
self.default_in_memory_ttl = default_in_memory_ttl
if default_redis_ttl is not None:
self.default_redis_ttl = default_redis_ttl
def set_cache(self, key, value, local_only: bool = False, **kwargs):
# Update both Redis and in-memory cache
try:
if self.in_memory_cache is not None:
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
kwargs["ttl"] = self.default_in_memory_ttl
self.in_memory_cache.set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
self.redis_cache.set_cache(key, value, **kwargs)
except Exception as e:
print_verbose(e)
def increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
result = self.redis_cache.increment_cache(key, value, **kwargs)
return result
except Exception as e:
verbose_logger.error(f"LiteLLM Cache: Exception increment_cache: {str(e)}")
raise e
def get_cache(
self,
key,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
# Try to fetch from in-memory cache first
try:
result = None
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
if in_memory_result is not None:
result = in_memory_result
if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.get_cache(
key, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
self.in_memory_cache.set_cache(key, redis_result, **kwargs)
result = redis_result
print_verbose(f"get cache: cache result: {result}")
return result
except Exception:
verbose_logger.error(traceback.format_exc())
def batch_get_cache(
self,
keys: list,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
received_args = locals()
received_args.pop("self")
def run_in_new_loop():
"""Run the coroutine in a new event loop within this thread."""
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
return new_loop.run_until_complete(
self.async_batch_get_cache(**received_args)
)
finally:
new_loop.close()
asyncio.set_event_loop(None)
try:
# First, try to get the current event loop
_ = asyncio.get_running_loop()
# If we're already in an event loop, run in a separate thread
# to avoid nested event loop issues
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
return future.result()
except RuntimeError:
# No running event loop, we can safely run in this thread
return run_in_new_loop()
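The sync `batch_get_cache` above delegates to its async twin using a common bridge pattern: run the coroutine in a fresh event loop, and if a loop is already running in this thread, hop to a worker thread to avoid nesting loops. A standalone sketch of that pattern, where `fetch` is a placeholder coroutine rather than the real cache call:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


async def fetch(keys):
    # placeholder for an async cache lookup
    return [k.upper() for k in keys]


def fetch_sync(keys):
    def run_in_new_loop():
        # a fresh event loop, isolated to whichever thread runs this
        new_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(new_loop)
            return new_loop.run_until_complete(fetch(keys))
        finally:
            new_loop.close()
            asyncio.set_event_loop(None)

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # no loop running in this thread: safe to run right here
        return run_in_new_loop()
    # already inside a running loop: run in a worker thread instead
    with ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(run_in_new_loop).result()


result = fetch_sync(["a", "b"])
```

The worker-thread branch blocks the caller, so this is a compatibility shim for sync call sites, not something to use on a hot async path.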
async def async_get_cache(
self,
key,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
# Try to fetch from in-memory cache first
try:
print_verbose(
f"async get cache: cache key: {key}; local_only: {local_only}"
)
result = None
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_get_cache(
key, **kwargs
)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_get_cache(
key, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
await self.in_memory_cache.async_set_cache(
key, redis_result, **kwargs
)
result = redis_result
print_verbose(f"get cache: cache result: {result}")
return result
except Exception:
verbose_logger.error(traceback.format_exc())
def _reserve_redis_batch_keys(
self,
current_time: float,
keys: List[str],
result: List[Any],
) -> Tuple[List[str], Dict[str, Optional[float]]]:
"""
Atomically choose keys to fetch from Redis and reserve their access time.
This prevents check-then-act races under concurrent async callers.
"""
sublist_keys: List[str] = []
previous_access_times: Dict[str, Optional[float]] = {}
with self._last_redis_batch_access_time_lock:
for key, value in zip(keys, result):
if value is not None:
continue
if (
key not in self.last_redis_batch_access_time
or current_time - self.last_redis_batch_access_time[key]
>= self.redis_batch_cache_expiry
):
sublist_keys.append(key)
previous_access_times[key] = self.last_redis_batch_access_time.get(
key
)
self.last_redis_batch_access_time[key] = current_time
return sublist_keys, previous_access_times
def _rollback_redis_batch_key_reservations(
self, previous_access_times: Dict[str, Optional[float]]
) -> None:
with self._last_redis_batch_access_time_lock:
for key, previous_time in previous_access_times.items():
if previous_time is None:
self.last_redis_batch_access_time.pop(key, None)
else:
self.last_redis_batch_access_time[key] = previous_time
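The two helpers above implement a reserve-then-rollback throttle: claim keys atomically under a lock (remembering their prior access times), and restore that prior state if the backend read later fails. The same shape can be sketched independently of Redis; `BatchThrottle` and `window` are illustrative names:

```python
from threading import Lock


class BatchThrottle:
    """Let each key through at most once per `window` seconds; reservations can be rolled back."""

    def __init__(self, window: float):
        self.window = window
        self.last_access: dict = {}
        self._lock = Lock()

    def reserve(self, keys, now):
        granted, previous = [], {}
        with self._lock:  # the check-then-act must be atomic across callers
            for key in keys:
                last = self.last_access.get(key)
                if last is None or now - last >= self.window:
                    granted.append(key)
                    previous[key] = last  # remember prior state for rollback
                    self.last_access[key] = now
        return granted, previous

    def rollback(self, previous):
        with self._lock:
            for key, last in previous.items():
                if last is None:
                    self.last_access.pop(key, None)
                else:
                    self.last_access[key] = last


t = BatchThrottle(window=10.0)
granted, prev = t.reserve(["a", "b"], now=100.0)  # both keys are fresh
again, _ = t.reserve(["a", "b"], now=105.0)       # within the window: throttled
t.rollback(prev)                                  # pretend the backend read failed
retry, _ = t.reserve(["a", "b"], now=105.0)       # reservations undone, allowed again
```

Without the rollback, a single failed Redis read would silently throttle every caller for a full expiry window.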
async def async_batch_get_cache(
self,
keys: list,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
try:
result = [None] * len(keys)
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
keys, **kwargs
)
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only is False:
"""
- for the none values in the result
- check the redis cache
"""
current_time = time.time()
sublist_keys, previous_access_times = self._reserve_redis_batch_keys(
current_time, keys, result
)
# Only hit Redis if enough time has passed since last access.
if len(sublist_keys) > 0:
try:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, parent_otel_span=parent_otel_span
)
except Exception:
# Do not throttle subsequent callers if the Redis read fails.
self._rollback_redis_batch_key_reservations(
previous_access_times
)
raise
# Short-circuit if redis_result is None or contains only None values
if redis_result is None or all(
v is None for v in redis_result.values()
):
return result
# Pre-compute key-to-index mapping for O(1) lookup
key_to_index = {key: i for i, key in enumerate(keys)}
# Update both result and in-memory cache in a single loop
for key, value in redis_result.items():
result[key_to_index[key]] = value
if value is not None and self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache(
key, value, **kwargs
)
return result
except Exception:
verbose_logger.error(traceback.format_exc())
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
print_verbose(
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
)
try:
if self.in_memory_cache is not None:
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
kwargs["ttl"] = self.default_in_memory_ttl
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache(key, value, **kwargs)
except Exception as e:
verbose_logger.exception(
f"LiteLLM Cache: Exception async_set_cache: {str(e)}"
)
# async_batch_set_cache
async def async_set_cache_pipeline(
self, cache_list: list, local_only: bool = False, **kwargs
):
"""
Batch write values to the cache
"""
print_verbose(
f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
)
try:
if self.in_memory_cache is not None:
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
kwargs["ttl"] = self.default_in_memory_ttl
await self.in_memory_cache.async_set_cache_pipeline(
cache_list=cache_list, **kwargs
)
if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache_pipeline(
cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
)
except Exception as e:
verbose_logger.exception(
f"LiteLLM Cache: Exception async_set_cache_pipeline: {str(e)}"
)
async def async_increment_cache(
self,
key,
value: float,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
) -> float:
"""
Key - the key in cache
Value - float - the value you want to increment by
Returns - float - the incremented value
"""
try:
result: float = value
if self.in_memory_cache is not None:
result = await self.in_memory_cache.async_increment(
key, value, **kwargs
)
if self.redis_cache is not None and local_only is False:
result = await self.redis_cache.async_increment(
key,
value,
parent_otel_span=parent_otel_span,
ttl=kwargs.get("ttl", None),
)
return result
except Exception as e:
raise e # don't log if exception is raised
async def async_increment_cache_pipeline(
self,
increment_list: List["RedisPipelineIncrementOperation"],
local_only: bool = False,
parent_otel_span: Optional[Span] = None,
**kwargs,
) -> Optional[List[float]]:
try:
result: Optional[List[float]] = None
if self.in_memory_cache is not None:
result = await self.in_memory_cache.async_increment_pipeline(
increment_list=increment_list,
parent_otel_span=parent_otel_span,
)
if self.redis_cache is not None and local_only is False:
result = await self.redis_cache.async_increment_pipeline(
increment_list=increment_list,
parent_otel_span=parent_otel_span,
)
return result
except Exception as e:
raise e # don't log if exception is raised
async def async_set_cache_sadd(
self, key, value: List, local_only: bool = False, **kwargs
) -> None:
"""
Add value to a set
Key - the key in cache
Value - str - the value you want to add to the set
Returns - None
"""
try:
if self.in_memory_cache is not None:
_ = await self.in_memory_cache.async_set_cache_sadd(
key, value, ttl=kwargs.get("ttl", None)
)
if self.redis_cache is not None and local_only is False:
_ = await self.redis_cache.async_set_cache_sadd(
key, value, ttl=kwargs.get("ttl", None)
)
return None
except Exception as e:
raise e # don't log, if exception is raised
def flush_cache(self):
if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache()
if self.redis_cache is not None:
self.redis_cache.flush_cache()
def delete_cache(self, key):
"""
Delete a key from the cache
"""
if self.in_memory_cache is not None:
self.in_memory_cache.delete_cache(key)
if self.redis_cache is not None:
self.redis_cache.delete_cache(key)
async def async_delete_cache(self, key: str):
"""
Delete a key from the cache
"""
if self.in_memory_cache is not None:
self.in_memory_cache.delete_cache(key)
if self.redis_cache is not None:
await self.redis_cache.async_delete_cache(key)
async def async_get_ttl(self, key: str) -> Optional[int]:
"""
Get the remaining TTL of a key in in-memory cache or redis
"""
ttl = await self.in_memory_cache.async_get_ttl(key)
if ttl is None and self.redis_cache is not None:
ttl = await self.redis_cache.async_get_ttl(key)
return ttl
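Stepping back, `DualCache`'s read path is a classic read-through hierarchy: consult the fast local layer, fall back to the shared backend, and backfill the local layer on a backend hit so the next read is local. A minimal sketch with plain dicts standing in for the in-memory cache and Redis:

```python
class ReadThroughCache:
    """Minimal sketch of DualCache's read path; `backend` stands in for Redis."""

    def __init__(self, backend: dict):
        self.local: dict = {}
        self.backend = backend

    def set(self, key, value):
        # writes go to both layers, so the local copy never lags our own writes
        self.local[key] = value
        self.backend[key] = value

    def get(self, key):
        value = self.local.get(key)
        if value is None and key in self.backend:
            value = self.backend[key]
            self.local[key] = value  # backfill so the next read stays local
        return value


backend = {"shared": "from-redis"}
cache = ReadThroughCache(backend)
first = cache.get("shared")   # local miss, served from the backend
second = cache.get("shared")  # now served from the local layer
cache.set("mine", 1)
```

The real class adds TTLs, batch reads, and the batch-access throttle, but this is the core invariant: a value observed once is subsequently served locally until it expires.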

@@ -0,0 +1,113 @@
"""GCS Cache implementation
Supports syncing responses to Google Cloud Storage Buckets using HTTP requests.
"""
import json
import asyncio
from typing import Optional
from litellm._logging import print_verbose, verbose_logger
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
_get_httpx_client,
httpxSpecialProvider,
)
from .base_cache import BaseCache
class GCSCache(BaseCache):
def __init__(
self,
bucket_name: Optional[str] = None,
path_service_account: Optional[str] = None,
gcs_path: Optional[str] = None,
) -> None:
super().__init__()
self.bucket_name = bucket_name or GCSBucketBase(bucket_name=None).BUCKET_NAME
self.path_service_account = (
path_service_account
or GCSBucketBase(bucket_name=None).path_service_account_json
)
self.key_prefix = gcs_path.rstrip("/") + "/" if gcs_path else ""
# create httpx clients
self.async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
self.sync_client = _get_httpx_client()
def _construct_headers(self) -> dict:
base = GCSBucketBase(bucket_name=self.bucket_name)
base.path_service_account_json = self.path_service_account
base.BUCKET_NAME = self.bucket_name
return base.sync_construct_request_headers()
def set_cache(self, key, value, **kwargs):
try:
print_verbose(f"LiteLLM SET Cache - GCS. Key={key}. Value={value}")
headers = self._construct_headers()
object_name = self.key_prefix + key
bucket_name = self.bucket_name
url = f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
data = json.dumps(value)
self.sync_client.post(url=url, data=data, headers=headers)
except Exception as e:
print_verbose(f"GCS Caching: set_cache() - Got exception from GCS: {e}")
async def async_set_cache(self, key, value, **kwargs):
try:
headers = self._construct_headers()
object_name = self.key_prefix + key
bucket_name = self.bucket_name
url = f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
data = json.dumps(value)
await self.async_client.post(url=url, data=data, headers=headers)
except Exception as e:
print_verbose(
f"GCS Caching: async_set_cache() - Got exception from GCS: {e}"
)
def get_cache(self, key, **kwargs):
try:
headers = self._construct_headers()
object_name = self.key_prefix + key
bucket_name = self.bucket_name
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
response = self.sync_client.get(url=url, headers=headers)
if response.status_code == 200:
cached_response = json.loads(response.text)
verbose_logger.debug(
f"Got GCS Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
)
return cached_response
return None
except Exception as e:
verbose_logger.error(
f"GCS Caching: get_cache() - Got exception from GCS: {e}"
)
async def async_get_cache(self, key, **kwargs):
try:
headers = self._construct_headers()
object_name = self.key_prefix + key
bucket_name = self.bucket_name
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
response = await self.async_client.get(url=url, headers=headers)
if response.status_code == 200:
return json.loads(response.text)
return None
except Exception as e:
verbose_logger.error(
f"GCS Caching: async_get_cache() - Got exception from GCS: {e}"
)
def flush_cache(self):
pass
async def disconnect(self):
pass
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
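`async_set_cache_pipeline` above fans the writes out concurrently with `asyncio.gather` rather than awaiting each upload in turn. The same shape works with any coroutine, sketched here with an in-memory store in place of the GCS HTTP upload:

```python
import asyncio

store: dict = {}


async def set_one(key, value):
    # stand-in for one HTTP upload to the bucket
    await asyncio.sleep(0)
    store[key] = value


async def set_pipeline(cache_list):
    # schedule every write, then wait for all of them together
    await asyncio.gather(*(set_one(k, v) for k, v in cache_list))


asyncio.run(set_pipeline([("a", 1), ("b", 2), ("c", 3)]))
```

One caveat with bare `gather`: the first raised exception propagates while the sibling tasks keep running, so a production pipeline may want `return_exceptions=True` and per-item error handling.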

@@ -0,0 +1,288 @@
"""
In-Memory Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import json
import sys
import time
import heapq
from typing import TYPE_CHECKING, Any, List, Optional
if TYPE_CHECKING:
from litellm.types.caching import RedisPipelineIncrementOperation
from pydantic import BaseModel
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
from .base_cache import BaseCache
class InMemoryCache(BaseCache):
def __init__(
self,
max_size_in_memory: Optional[int] = 200,
default_ttl: Optional[
int
] = 600, # default ttl is 10 minutes; litellm's rate-limiting logic requires objects to stay in memory for at most 1 minute
max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
):
"""
max_size_in_memory [int]: Maximum number of items in the cache; bounded to prevent unbounded memory growth. Defaults to 200 items.
"""
self.max_size_in_memory = (
max_size_in_memory if max_size_in_memory is not None else 200
) # set an upper bound of 200 items in-memory
self.default_ttl = default_ttl or 600
self.max_size_per_item = (
max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
) # 1MB = 1024KB
# in-memory cache
self.cache_dict: dict = {}
self.ttl_dict: dict = {}
self.expiration_heap: list[tuple[float, str]] = []
def check_value_size(self, value: Any):
"""
Check if value size exceeds max_size_per_item (1MB)
Returns True if value size is acceptable, False otherwise
"""
try:
# Fast path for common primitive types that are typically small
if (
isinstance(value, (bool, int, float, str))
and len(str(value))
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
): # Conservative estimate
return True
# Direct size check for bytes objects
if isinstance(value, bytes):
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
# Handle special types without full conversion when possible
if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
size = value.__sizeof__() / 1024
return size <= self.max_size_per_item
# Fallback for complex types
if isinstance(value, BaseModel) and hasattr(
value, "model_dump"
): # Pydantic v2
value = value.model_dump()
elif hasattr(value, "isoformat"): # datetime objects
return True # datetime strings are always small
# Only convert to JSON if absolutely necessary
if not isinstance(value, (str, bytes)):
value = json.dumps(value, default=str)
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
except Exception:
return False
def _is_key_expired(self, key: str) -> bool:
"""
Check if a specific key is expired
"""
return key in self.ttl_dict and time.time() > self.ttl_dict[key]
def _remove_key(self, key: str) -> None:
"""
Remove a key from both cache_dict and ttl_dict
"""
self.cache_dict.pop(key, None)
self.ttl_dict.pop(key, None)
def evict_cache(self):
"""
Eviction policy:
1. First, remove expired items from ttl_dict and cache_dict
2. If cache is still at or above max_size_in_memory, evict items with earliest expiration times
This guarantees the following:
- 1. When item ttl not set: At minimum each item will remain in memory for the default ttl
- 2. When ttl is set: the item will remain in memory for at least that amount of time, unless cache size requires eviction
- 3. the size of in-memory cache is bounded
"""
current_time = time.time()
# Step 1: Remove expired or outdated items
while self.expiration_heap:
expiration_time, key = self.expiration_heap[0]
# Case 1: Heap entry is outdated
if expiration_time != self.ttl_dict.get(key):
heapq.heappop(self.expiration_heap)
# Case 2: Entry is valid but expired
elif expiration_time <= current_time:
heapq.heappop(self.expiration_heap)
self._remove_key(key)
else:
# Case 3: Entry is valid and not expired
break
# Step 2: Evict if cache is still full
while len(self.cache_dict) >= self.max_size_in_memory:
expiration_time, key = heapq.heappop(self.expiration_heap)
# Skip if key was removed or updated
if self.ttl_dict.get(key) == expiration_time:
self._remove_key(key)
# de-reference the removed item so it can be garbage-collected; retaining
# references to objects that are no longer used is a common cause of Python
# memory leaks: https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
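The eviction loop above leans on lazy deletion: when a key's TTL is refreshed, the old heap entry is left in place and simply discarded later, once its expiry no longer matches `ttl_dict`. A stripped-down sketch of that bookkeeping:

```python
import heapq

ttl_dict = {"a": 100.0, "b": 50.0}
heap: list = []
heapq.heappush(heap, (50.0, "b"))
heapq.heappush(heap, (80.0, "a"))   # stale: "a" was later refreshed to 100.0
heapq.heappush(heap, (100.0, "a"))


def pop_expired(now):
    removed = []
    while heap:
        expiry, key = heap[0]
        if expiry != ttl_dict.get(key):
            heapq.heappop(heap)       # outdated entry: discard without evicting
        elif expiry <= now:
            heapq.heappop(heap)
            ttl_dict.pop(key, None)   # valid and expired: evict
            removed.append(key)
        else:
            break                     # earliest valid entry not yet expired
    return removed


expired = pop_expired(now=90.0)
```

Lazy deletion trades a slightly larger heap for O(log n) updates: removing an arbitrary entry from a heap would otherwise cost O(n).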
def allow_ttl_override(self, key: str) -> bool:
"""
Check if ttl is set for a key
"""
ttl_time = self.ttl_dict.get(key)
if ttl_time is None: # if ttl is not set, allow override
return True
elif float(ttl_time) < time.time(): # if ttl is expired, allow override
return True
else:
return False
def set_cache(self, key, value, **kwargs):
# Handle the edge case where max_size_in_memory is 0
if self.max_size_in_memory == 0:
return # Don't cache anything if max size is 0
if len(self.cache_dict) >= self.max_size_in_memory:
# only evict when cache is full
self.evict_cache()
if not self.check_value_size(value):
return
self.cache_dict[key] = value
if self.allow_ttl_override(key): # if ttl is not set, set it to default ttl
if "ttl" in kwargs and kwargs["ttl"] is not None:
self.ttl_dict[key] = time.time() + float(kwargs["ttl"])
heapq.heappush(self.expiration_heap, (self.ttl_dict[key], key))
else:
self.ttl_dict[key] = time.time() + self.default_ttl
heapq.heappush(self.expiration_heap, (self.ttl_dict[key], key))
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
for cache_key, cache_value in cache_list:
if ttl is not None:
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
else:
self.set_cache(key=cache_key, value=cache_value)
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
"""
Add value to set
"""
# get the value
init_value = self.get_cache(key=key) or set()
for val in value:
init_value.add(val)
self.set_cache(key, init_value, ttl=ttl)
return value
def evict_element_if_expired(self, key: str) -> bool:
"""
Returns True if the element is expired and removed from the cache
Returns False if the element is not expired
"""
if self._is_key_expired(key):
self._remove_key(key)
return True
return False
def get_cache(self, key, **kwargs):
if key in self.cache_dict:
if self.evict_element_if_expired(key):
return None
original_cached_response = self.cache_dict[key]
try:
cached_response = json.loads(original_cached_response)
except Exception:
cached_response = original_cached_response
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: float, **kwargs) -> float:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
return value
async def async_increment_pipeline(
self, increment_list: List["RedisPipelineIncrementOperation"], **kwargs
) -> Optional[List[float]]:
results = []
for increment in increment_list:
result = await self.async_increment(
increment["key"], increment["increment_value"], **kwargs
)
results.append(result)
return results
def flush_cache(self):
self.cache_dict.clear()
self.ttl_dict.clear()
self.expiration_heap.clear()
async def disconnect(self):
pass
def delete_cache(self, key):
self._remove_key(key)
async def async_get_ttl(self, key: str) -> Optional[int]:
"""
Get the remaining TTL of a key in in-memory cache
"""
return self.ttl_dict.get(key, None)
async def async_get_oldest_n_keys(self, n: int) -> List[str]:
"""
Get the oldest n keys in the cache
"""
# sorted ttl dict by ttl
sorted_ttl_dict = sorted(self.ttl_dict.items(), key=lambda x: x[1])
return [key for key, _ in sorted_ttl_dict[:n]]
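`async_get_oldest_n_keys` above sorts by the stored expiry time, so "oldest" here means "expiring soonest", not "inserted first". A quick sketch of the same selection:

```python
ttl_dict = {"a": 300.0, "b": 100.0, "c": 200.0}


def oldest_n_keys(n):
    # keys whose expiry deadline comes soonest, ascending by TTL
    return [key for key, _ in sorted(ttl_dict.items(), key=lambda item: item[1])[:n]]


oldest = oldest_n_keys(2)
```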

@@ -0,0 +1,50 @@
"""
Add the event loop to the cache key, to prevent event loop closed errors.
"""
import asyncio
from .in_memory_cache import InMemoryCache
class LLMClientCache(InMemoryCache):
"""Cache for LLM HTTP clients (OpenAI, Azure, httpx, etc.).
IMPORTANT: This cache intentionally does NOT close clients on eviction.
Evicted clients may still be in use by in-flight requests. Closing them
eagerly causes ``RuntimeError: Cannot send a request, as the client has
been closed.`` errors in production after the TTL (1 hour) expires.
Clients that are no longer referenced will be garbage-collected normally.
For explicit shutdown cleanup, use ``close_litellm_async_clients()``.
"""
def update_cache_key_with_event_loop(self, key):
"""
Add the event loop to the cache key, to prevent event loop closed errors.
If none, use the key as is.
"""
try:
event_loop = asyncio.get_running_loop()
stringified_event_loop = str(id(event_loop))
return f"{key}-{stringified_event_loop}"
except RuntimeError: # handle no current running event loop
return key
def set_cache(self, key, value, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return super().set_cache(key, value, **kwargs)
async def async_set_cache(self, key, value, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return await super().async_set_cache(key, value, **kwargs)
def get_cache(self, key, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return super().get_cache(key, **kwargs)
async def async_get_cache(self, key, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return await super().async_get_cache(key, **kwargs)
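The key-scoping trick above can be demonstrated on its own: inside a running event loop the key is suffixed with that loop's `id()`, so a client cached under one loop is never handed to code running on another (or a closed) loop; outside any loop the key is used as-is. A minimal sketch:

```python
import asyncio


def key_with_event_loop(key: str) -> str:
    # scope the key to the current event loop, if any, so a client created
    # on one loop is never reused after that loop closes
    try:
        loop = asyncio.get_running_loop()
        return f"{key}-{id(loop)}"
    except RuntimeError:  # no running loop: use the key as-is
        return key


sync_key = key_with_event_loop("openai-client")


async def main():
    return key_with_event_loop("openai-client")


async_key = asyncio.run(main())
```

Note `id(loop)` values can be reused after a loop is garbage-collected, so this prevents cross-loop reuse in practice rather than guaranteeing globally unique keys.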

@@ -0,0 +1,446 @@
"""
Qdrant Semantic Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import ast
import asyncio
import json
from typing import Any, cast
import litellm
from litellm._logging import print_verbose
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache
class QdrantSemanticCache(BaseCache):
def __init__( # noqa: PLR0915
self,
qdrant_api_base=None,
qdrant_api_key=None,
collection_name=None,
similarity_threshold=None,
quantization_config=None,
embedding_model="text-embedding-ada-002",
host_type=None,
vector_size=None,
):
import os
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.secret_managers.main import get_secret_str
if collection_name is None:
raise Exception("collection_name must be provided, passed None")
self.collection_name = collection_name
print_verbose(
f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
)
if similarity_threshold is None:
raise Exception("similarity_threshold must be provided, passed None")
self.similarity_threshold = similarity_threshold
self.embedding_model = embedding_model
self.vector_size = (
vector_size if vector_size is not None else QDRANT_VECTOR_SIZE
)
headers = {}
# check if defined as os.environ/ variable
if qdrant_api_base:
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
"os.environ/"
):
qdrant_api_base = get_secret_str(qdrant_api_base)
if qdrant_api_key:
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
"os.environ/"
):
qdrant_api_key = get_secret_str(qdrant_api_key)
qdrant_api_base = (
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
)
qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
headers = {"Content-Type": "application/json"}
if qdrant_api_key:
headers["api-key"] = qdrant_api_key
if qdrant_api_base is None:
raise ValueError("Qdrant url must be provided")
self.qdrant_api_base = qdrant_api_base
self.qdrant_api_key = qdrant_api_key
print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
self.headers = headers
self.sync_client = _get_httpx_client()
self.async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.Caching
)
if quantization_config is None:
print_verbose(
"Quantization config is not provided. Default binary quantization will be used."
)
collection_exists = self.sync_client.get(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
headers=self.headers,
)
if collection_exists.status_code != 200:
raise ValueError(
f"Error from qdrant checking whether the collection exists: {collection_exists.text}"
)
if collection_exists.json()["result"]["exists"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
print_verbose(
f"Collection already exists.\nCollection details:{self.collection_info}"
)
else:
if quantization_config is None or quantization_config == "binary":
quantization_params = {
"binary": {
"always_ram": False,
}
}
elif quantization_config == "scalar":
quantization_params = {
"scalar": {
"type": "int8",
"quantile": QDRANT_SCALAR_QUANTILE,
"always_ram": False,
}
}
elif quantization_config == "product":
quantization_params = {
"product": {"compression": "x16", "always_ram": False}
}
else:
raise Exception(
"Quantization config must be one of 'scalar', 'binary' or 'product'"
)
new_collection_status = self.sync_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
json={
"vectors": {"size": self.vector_size, "distance": "Cosine"},
"quantization_config": quantization_params,
},
headers=self.headers,
)
if new_collection_status.json()["result"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
print_verbose(
f"New collection created.\nCollection details:{self.collection_info}"
)
else:
raise Exception("Error while creating new collection")
def _get_cache_logic(self, cached_response: Any):
if cached_response is None:
return cached_response
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
return cached_response
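`_get_cache_logic` above tries strict JSON first and falls back to `ast.literal_eval` for payloads stored as Python reprs (e.g. single-quoted dicts, which `json.loads` rejects). A sketch of the same parse logic:

```python
import ast
import json


def parse_cached(cached_response):
    if cached_response is None:
        return None
    try:
        return json.loads(cached_response)        # valid JSON: double quotes
    except Exception:
        return ast.literal_eval(cached_response)  # Python literal: single quotes etc.


from_json = parse_cached('{"role": "assistant"}')
from_repr = parse_cached("{'role': 'assistant'}")
```

`ast.literal_eval` only evaluates Python literals (strings, numbers, tuples, lists, dicts, sets, booleans, `None`), so unlike `eval` it cannot execute arbitrary code from a cached payload.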
def set_cache(self, key, value, **kwargs):
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
from litellm._uuid import uuid
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
value = str(value)
assert isinstance(value, str)
data = {
"points": [
{
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
"text": prompt,
"response": value,
},
},
]
}
self.sync_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
)
return
def get_cache(self, key, **kwargs):
print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
# get the messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# convert to embedding
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
data = {
"vector": embedding,
"params": {
"quantization": {
"ignore": False,
"rescore": True,
"oversampling": 3.0,
}
},
"limit": 1,
"with_payload": True,
}
search_response = self.sync_client.post(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data,
)
results = search_response.json()["result"]
if results is None:
return None
if isinstance(results, list):
if len(results) == 0:
return None
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0]["payload"]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
async def async_set_cache(self, key, value, **kwargs):
from litellm._uuid import uuid
from litellm.proxy.proxy_server import llm_model_list, llm_router
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
value = str(value)
assert isinstance(value, str)
data = {
"points": [
{
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
"text": prompt,
"response": value,
},
},
]
}
await self.async_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
)
return
async def async_get_cache(self, key, **kwargs):
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
from litellm.proxy.proxy_server import llm_model_list, llm_router
# get the messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
data = {
"vector": embedding,
"params": {
"quantization": {
"ignore": False,
"rescore": True,
"oversampling": 3.0,
}
},
"limit": 1,
"with_payload": True,
}
search_response = await self.async_client.post(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data,
)
results = search_response.json()["result"]
if results is None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
if isinstance(results, list):
if len(results) == 0:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0]["payload"]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
async def _collection_info(self):
return self.collection_info
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
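The collection bootstrap above maps a `quantization_config` string onto the Qdrant quantization payload sent when creating the collection. A minimal sketch of just that mapping (the helper name `build_quantization_params` is illustrative, not part of litellm, and `0.99` is an assumed stand-in for litellm's `QDRANT_SCALAR_QUANTILE` constant):

```python
from typing import Any, Dict

QDRANT_SCALAR_QUANTILE = 0.99  # assumed value; litellm defines the real constant elsewhere


def build_quantization_params(quantization_config: str = "binary") -> Dict[str, Any]:
    """Mirror the config-string -> Qdrant quantization payload mapping used at collection creation."""
    if quantization_config == "binary":
        return {"binary": {"always_ram": False}}
    if quantization_config == "scalar":
        return {
            "scalar": {
                "type": "int8",
                "quantile": QDRANT_SCALAR_QUANTILE,
                "always_ram": False,
            }
        }
    if quantization_config == "product":
        return {"product": {"compression": "x16", "always_ram": False}}
    raise ValueError(
        "Quantization config must be one of 'scalar', 'binary' or 'product'"
    )
```

The returned dict corresponds to the `quantization_config` field of the collection-creation PUT body above.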

File diff suppressed because it is too large


@@ -0,0 +1,109 @@
"""
Redis Cluster Cache implementation
Key differences:
- The Redis client NEEDS to be re-used across requests; re-creating it adds ~3000ms of latency
"""
from typing import TYPE_CHECKING, Any, List, Optional, Union
from litellm.caching.redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from redis.asyncio import Redis, RedisCluster
from redis.asyncio.client import Pipeline
pipeline = Pipeline
async_redis_client = Redis
Span = Union[_Span, Any]
else:
pipeline = Any
async_redis_client = Any
Span = Any
class RedisClusterCache(RedisCache):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None
def init_async_client(self):
from redis.asyncio import RedisCluster
from .._redis import get_redis_async_client
if self.redis_async_redis_cluster_client:
return self.redis_async_redis_cluster_client
_redis_client = get_redis_async_client(
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
)
if isinstance(_redis_client, RedisCluster):
self.redis_async_redis_cluster_client = _redis_client
return _redis_client
def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
"""
Overrides `_run_redis_mget_operation` in redis_cache.py
"""
return self.redis_client.mget_nonatomic(keys=keys) # type: ignore
async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
"""
Overrides `_async_run_redis_mget_operation` in redis_cache.py
"""
async_redis_cluster_client = self.init_async_client()
return await async_redis_cluster_client.mget_nonatomic(keys=keys) # type: ignore
async def test_connection(self) -> dict:
"""
Test the Redis Cluster connection.
Returns:
dict: {"status": "success" | "failed", "message": str, "error": Optional[str]}
"""
try:
import redis.asyncio as redis_async
from redis.cluster import ClusterNode
# Create ClusterNode objects from startup_nodes
cluster_kwargs = self.redis_kwargs.copy()
startup_nodes = cluster_kwargs.pop("startup_nodes", [])
new_startup_nodes: List[ClusterNode] = []
for item in startup_nodes:
new_startup_nodes.append(ClusterNode(**item))
# Create a fresh Redis Cluster client with current settings
redis_client = redis_async.RedisCluster(
startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
)
# Test the connection
ping_result = await redis_client.ping() # type: ignore[attr-defined, misc]
# Close the connection
await redis_client.aclose() # type: ignore[attr-defined]
if ping_result:
return {
"status": "success",
"message": "Redis Cluster connection test successful",
}
else:
return {
"status": "failed",
"message": "Redis Cluster ping returned False",
}
except Exception as e:
from litellm._logging import verbose_logger
verbose_logger.error(f"Redis Cluster connection test failed: {str(e)}")
return {
"status": "failed",
"message": f"Redis Cluster connection failed: {str(e)}",
"error": str(e),
}
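The file header notes the cluster client must be re-used across requests because re-creating it is expensive; `init_async_client` above implements that as a memoized initializer. A stripped-down, Redis-free sketch of the same pattern (`MemoizedClient` and `make_client` are illustrative stand-ins, not litellm or redis-py APIs):

```python
from typing import Any, Callable, Optional


class MemoizedClient:
    """Create the expensive client once, then hand back the cached instance."""

    def __init__(self, make_client: Callable[[], Any]):
        self._make_client = make_client
        self._client: Optional[Any] = None

    def init_client(self) -> Any:
        if self._client is not None:
            return self._client  # re-use: avoids paying connection setup per request
        self._client = self._make_client()
        return self._client


calls = []
holder = MemoizedClient(lambda: calls.append(1) or object())
first = holder.init_client()
second = holder.init_client()  # same object, factory not invoked again
```

The real implementation additionally type-checks that the constructed client is a `RedisCluster` before caching it.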


@@ -0,0 +1,450 @@
"""
Redis Semantic Cache implementation for LiteLLM
The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
This cache stores responses based on the semantic similarity of prompts rather than
exact matching, allowing for more flexible caching of LLM responses.
This implementation uses RedisVL's SemanticCache to find semantically similar prompts
and their cached responses.
"""
import ast
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Tuple, cast
import litellm
from litellm._logging import print_verbose
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_str_from_messages,
)
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache
class RedisSemanticCache(BaseCache):
"""
Redis-backed semantic cache for LLM responses.
This cache uses vector similarity to find semantically similar prompts that have been
previously sent to the LLM, allowing for cache hits even when prompts are not identical
but carry similar meaning.
"""
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
def __init__(
self,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
redis_url: Optional[str] = None,
similarity_threshold: Optional[float] = None,
embedding_model: str = "text-embedding-ada-002",
index_name: Optional[str] = None,
**kwargs,
):
"""
Initialize the Redis Semantic Cache.
Args:
host: Redis host address
port: Redis port
password: Redis password
redis_url: Full Redis URL (alternative to separate host/port/password)
similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
where 1.0 requires exact matches and 0.0 accepts any match
embedding_model: Model to use for generating embeddings
index_name: Name for the Redis index
**kwargs: Additional arguments passed to the Redis client; a per-entry 'ttl' in seconds is read from call-time kwargs rather than this constructor
Raises:
Exception: If similarity_threshold is not provided or required Redis
connection information is missing
"""
from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import CustomTextVectorizer
if index_name is None:
index_name = self.DEFAULT_REDIS_INDEX_NAME
print_verbose(f"Redis semantic-cache initializing index - {index_name}")
# Validate similarity threshold
if similarity_threshold is None:
raise ValueError("similarity_threshold must be provided, passed None")
# Store configuration
self.similarity_threshold = similarity_threshold
# Convert similarity threshold [0,1] to distance threshold [0,2]
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
self.distance_threshold = 1 - similarity_threshold
self.embedding_model = embedding_model
# Set up Redis connection
if redis_url is None:
try:
# Attempt to use provided parameters or fallback to environment variables
host = host or os.environ["REDIS_HOST"]
port = port or os.environ["REDIS_PORT"]
password = password or os.environ["REDIS_PASSWORD"]
except KeyError as e:
# Raise a more informative exception if any of the required keys are missing
missing_var = e.args[0]
raise ValueError(
f"Missing required Redis configuration: {missing_var}. "
f"Provide {missing_var} or redis_url."
) from e
redis_url = f"redis://:{password}@{host}:{port}"
print_verbose(f"Redis semantic-cache redis_url: {redis_url}")
# Initialize the Redis vectorizer and cache
cache_vectorizer = CustomTextVectorizer(self._get_embedding)
self.llmcache = SemanticCache(
name=index_name,
redis_url=redis_url,
vectorizer=cache_vectorizer,
distance_threshold=self.distance_threshold,
overwrite=False,
)
def _get_ttl(self, **kwargs) -> Optional[int]:
"""
Get the TTL (time-to-live) value for cache entries.
Args:
**kwargs: Keyword arguments that may contain a custom TTL
Returns:
Optional[int]: The TTL value in seconds, or None if no TTL should be applied
"""
ttl = kwargs.get("ttl")
if ttl is not None:
ttl = int(ttl)
return ttl
def _get_embedding(self, prompt: str) -> List[float]:
"""
Generate an embedding vector for the given prompt using the configured embedding model.
Args:
prompt: The text to generate an embedding for
Returns:
List[float]: The embedding vector
"""
# Create an embedding from prompt
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
embedding = embedding_response["data"][0]["embedding"]
return embedding
def _get_cache_logic(self, cached_response: Any) -> Any:
"""
Process the cached response to prepare it for use.
Args:
cached_response: The raw cached response
Returns:
The processed cache response, or None if input was None
"""
if cached_response is None:
return cached_response
# Convert bytes to string if needed
if isinstance(cached_response, bytes):
cached_response = cached_response.decode("utf-8")
# Convert string representation to Python object
try:
cached_response = json.loads(cached_response)
except json.JSONDecodeError:
try:
cached_response = ast.literal_eval(cached_response)
except (ValueError, SyntaxError) as e:
print_verbose(f"Error parsing cached response: {str(e)}")
return None
return cached_response
def set_cache(self, key: str, value: Any, **kwargs) -> None:
"""
Store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
value: The response value to cache
**kwargs: Additional arguments including 'messages' for the prompt
and optional 'ttl' for time-to-live
"""
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
value_str: Optional[str] = None
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic caching")
return
prompt = get_str_from_messages(messages)
value_str = str(value)
# Get TTL and store in Redis semantic cache
ttl = self._get_ttl(**kwargs)
if ttl is not None:
self.llmcache.store(prompt, value_str, ttl=int(ttl))
else:
self.llmcache.store(prompt, value_str)
except Exception as e:
print_verbose(
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
)
def get_cache(self, key: str, **kwargs) -> Any:
"""
Retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
The cached response if a semantically similar prompt is found, else None
"""
print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic cache lookup")
return None
prompt = get_str_from_messages(messages)
# Check the cache for semantically similar prompts
results = self.llmcache.check(prompt=prompt)
# Return None if no similar prompts found
if not results:
return None
# Process the best matching result
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
# Convert vector distance back to similarity score
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
similarity = 1 - vector_distance
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]
print_verbose(
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
f"actual similarity: {similarity}, "
f"current prompt: {prompt}, "
f"cached prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
"""
Asynchronously generate an embedding for the given prompt.
Args:
prompt: The text to generate an embedding for
**kwargs: Additional arguments that may contain metadata
Returns:
List[float]: The embedding vector
"""
from litellm.proxy.proxy_server import llm_model_list, llm_router
# Route the embedding request through the proxy if appropriate
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
try:
if llm_router is not None and self.embedding_model in router_model_names:
# Use the router for embedding generation
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# Generate embedding directly
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# Extract and return the embedding vector
return embedding_response["data"][0]["embedding"]
except Exception as e:
print_verbose(f"Error generating async embedding: {str(e)}")
raise ValueError(f"Failed to generate embedding: {str(e)}") from e
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
"""
Asynchronously store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
value: The response value to cache
**kwargs: Additional arguments including 'messages' for the prompt
and optional 'ttl' for time-to-live
"""
print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic caching")
return
prompt = get_str_from_messages(messages)
value_str = str(value)
# Generate embedding for the prompt (used to index the cached response)
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Get TTL and store in Redis semantic cache
ttl = self._get_ttl(**kwargs)
if ttl is not None:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
ttl=ttl,
)
else:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
)
except Exception as e:
print_verbose(f"Error in async_set_cache: {str(e)}")
async def async_get_cache(self, key: str, **kwargs) -> Any:
"""
Asynchronously retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
The cached response if a semantically similar prompt is found, else None
"""
print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic cache lookup")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
prompt = get_str_from_messages(messages)
# Generate embedding for the prompt
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Check the cache for semantically similar prompts
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
# handle results / cache hit
if not results:
kwargs.setdefault("metadata", {})[
"semantic-similarity"
] = 0.0 # TODO why here but not above??
return None
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
# Convert vector distance back to similarity
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
similarity = 1 - vector_distance
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
print_verbose(
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
f"actual similarity: {similarity}, "
f"current prompt: {prompt}, "
f"cached prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
print_verbose(f"Error in async_get_cache: {str(e)}")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
async def _index_info(self) -> Dict[str, Any]:
"""
Get information about the Redis index.
Returns:
Dict[str, Any]: Information about the Redis index
"""
aindex = await self.llmcache._get_async_index()
return await aindex.info()
async def async_set_cache_pipeline(
self, cache_list: List[Tuple[str, Any]], **kwargs
) -> None:
"""
Asynchronously store multiple values in the semantic cache.
Args:
cache_list: List of (key, value) tuples to cache
**kwargs: Additional arguments
"""
try:
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
except Exception as e:
print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")


@@ -0,0 +1,193 @@
"""
S3 Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache (uses run_in_executor)
- async_get_cache (uses run_in_executor)
"""
import ast
import asyncio
import json
from functools import partial
from typing import Optional
from datetime import datetime, timezone, timedelta
from litellm._logging import print_verbose, verbose_logger
from .base_cache import BaseCache
class S3Cache(BaseCache):
def __init__(
self,
s3_bucket_name,
s3_region_name=None,
s3_api_version=None,
s3_use_ssl: Optional[bool] = True,
s3_verify=None,
s3_endpoint_url=None,
s3_aws_access_key_id=None,
s3_aws_secret_access_key=None,
s3_aws_session_token=None,
s3_config=None,
s3_path=None,
**kwargs,
):
import boto3
self.bucket_name = s3_bucket_name
self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
# Create an S3 client with custom endpoint URL
self.s3_client = boto3.client(
"s3",
region_name=s3_region_name,
endpoint_url=s3_endpoint_url,
api_version=s3_api_version,
use_ssl=s3_use_ssl,
verify=s3_verify,
aws_access_key_id=s3_aws_access_key_id,
aws_secret_access_key=s3_aws_secret_access_key,
aws_session_token=s3_aws_session_token,
config=s3_config,
**kwargs,
)
def _to_s3_key(self, key: str) -> str:
"""Convert cache key to S3 key"""
return self.key_prefix + key.replace(":", "/")
def set_cache(self, key, value, **kwargs):
try:
print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
ttl = kwargs.get("ttl", None)
# Convert value to JSON before storing in S3
serialized_value = json.dumps(value)
key = self._to_s3_key(key)
if ttl is not None:
cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
# Calculate expiration time
expiration_time = datetime.now(timezone.utc) + timedelta(seconds=ttl)
# Upload the data to S3 with the calculated expiration time
self.s3_client.put_object(
Bucket=self.bucket_name,
Key=key,
Body=serialized_value,
Expires=expiration_time,
CacheControl=cache_control,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{key}.json"',
)
else:
cache_control = "immutable, max-age=31536000, s-maxage=31536000"
# Upload the data to S3 without specifying Expires
self.s3_client.put_object(
Bucket=self.bucket_name,
Key=key,
Body=serialized_value,
CacheControl=cache_control,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{key}.json"',
)
except Exception as e:
print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
async def async_set_cache(self, key, value, **kwargs):
"""
Asynchronously set cache using run_in_executor to avoid blocking the event loop.
Compatible with Python 3.8+.
"""
try:
verbose_logger.debug(f"Set ASYNC S3 Cache: Key={key}. Value={value}")
loop = asyncio.get_event_loop()
func = partial(self.set_cache, key, value, **kwargs)
await loop.run_in_executor(None, func)
except Exception as e:
verbose_logger.error(
f"S3 Caching: async_set_cache() - Got exception from S3: {e}"
)
def get_cache(self, key, **kwargs):
import botocore
try:
key = self._to_s3_key(key)
print_verbose(f"Get S3 Cache: key: {key}")
# Download the data from S3
cached_response = self.s3_client.get_object(
Bucket=self.bucket_name, Key=key
)
if cached_response is not None:
if "Expires" in cached_response:
expires_time = cached_response["Expires"]
current_time = datetime.now(expires_time.tzinfo)
if current_time > expires_time:
return None
# response body is bytes; decode it to a string, then parse into a dict
cached_response = (
cached_response["Body"].read().decode("utf-8")
) # Convert bytes to string
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
if not isinstance(cached_response, dict):
cached_response = dict(cached_response)
verbose_logger.debug(
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
)
return cached_response
except botocore.exceptions.ClientError as e: # type: ignore
if e.response["Error"]["Code"] == "NoSuchKey":
verbose_logger.debug(
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
)
return None
except Exception as e:
verbose_logger.error(
f"S3 Caching: get_cache() - Got exception from S3: {e}"
)
async def async_get_cache(self, key, **kwargs):
"""
Asynchronously get cache using run_in_executor to avoid blocking the event loop.
Compatible with Python 3.8+.
"""
try:
verbose_logger.debug(f"Get ASYNC S3 Cache: key: {key}")
loop = asyncio.get_event_loop()
func = partial(self.get_cache, key, **kwargs)
result = await loop.run_in_executor(None, func)
return result
except Exception as e:
verbose_logger.error(
f"S3 Caching: async_get_cache() - Got exception from S3: {e}"
)
return None
def flush_cache(self):
pass
async def disconnect(self):
pass
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
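`_to_s3_key` above turns a cache key into an S3 object key by prepending the normalized `s3_path` prefix and mapping `:` separators to `/` path segments. A standalone sketch of that translation (the free function `to_s3_key` is illustrative; the real logic lives on the `S3Cache` instance):

```python
def to_s3_key(key: str, s3_path: str = "") -> str:
    """Replicate S3Cache._to_s3_key: optional normalized prefix plus ':' -> '/' translation."""
    key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
    return key_prefix + key.replace(":", "/")
```

So with `s3_path="cache/"`, a cache key like `litellm:chat:abc` becomes the object key `cache/litellm/chat/abc`, giving the bucket a browsable folder layout.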


@@ -0,0 +1,4 @@
Logic specific to `litellm.completion`.
Includes:
- Bridge for transforming completion requests to Responses API requests


@@ -0,0 +1,3 @@
from .litellm_responses_transformation import responses_api_bridge
__all__ = ["responses_api_bridge"]


@@ -0,0 +1,3 @@
from .handler import responses_api_bridge
__all__ = ["responses_api_bridge"]


@@ -0,0 +1,331 @@
"""
Handler for transforming /chat/completions API requests into litellm.responses requests
"""
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Union
from typing_extensions import TypedDict
from litellm.types.llms.openai import ResponsesAPIResponse
if TYPE_CHECKING:
from litellm import CustomStreamWrapper, LiteLLMLoggingObj, ModelResponse
class ResponsesToCompletionBridgeHandlerInputKwargs(TypedDict):
model: str
messages: list
optional_params: dict
litellm_params: dict
headers: dict
model_response: "ModelResponse"
logging_obj: "LiteLLMLoggingObj"
custom_llm_provider: str
class ResponsesToCompletionBridgeHandler:
def __init__(self):
from .transformation import LiteLLMResponsesTransformationHandler
super().__init__()
self.transformation_handler = LiteLLMResponsesTransformationHandler()
@staticmethod
def _resolve_stream_flag(optional_params: dict, litellm_params: dict) -> bool:
stream = optional_params.get("stream")
if stream is None:
stream = litellm_params.get("stream", False)
return bool(stream)
@staticmethod
def _coerce_response_object(
response_obj: Any,
hidden_params: Optional[dict],
) -> "ResponsesAPIResponse":
if isinstance(response_obj, ResponsesAPIResponse):
response = response_obj
elif isinstance(response_obj, dict):
try:
response = ResponsesAPIResponse(**response_obj)
except Exception:
response = ResponsesAPIResponse.model_construct(**response_obj)
else:
raise ValueError("Unexpected responses stream payload")
if hidden_params:
existing = getattr(response, "_hidden_params", None)
if not isinstance(existing, dict) or not existing:
setattr(response, "_hidden_params", dict(hidden_params))
else:
for key, value in hidden_params.items():
existing.setdefault(key, value)
return response
def _collect_response_from_stream(self, stream_iter: Any) -> "ResponsesAPIResponse":
for _ in stream_iter:
pass
completed = getattr(stream_iter, "completed_response", None)
response_obj = getattr(completed, "response", None) if completed else None
if response_obj is None:
raise ValueError("Stream ended without a completed response")
hidden_params = getattr(stream_iter, "_hidden_params", None)
response = self._coerce_response_object(response_obj, hidden_params)
if not isinstance(response, ResponsesAPIResponse):
raise ValueError("Stream completed response is invalid")
return response
async def _collect_response_from_stream_async(
self, stream_iter: Any
) -> "ResponsesAPIResponse":
async for _ in stream_iter:
pass
completed = getattr(stream_iter, "completed_response", None)
response_obj = getattr(completed, "response", None) if completed else None
if response_obj is None:
raise ValueError("Stream ended without a completed response")
hidden_params = getattr(stream_iter, "_hidden_params", None)
response = self._coerce_response_object(response_obj, hidden_params)
if not isinstance(response, ResponsesAPIResponse):
raise ValueError("Stream completed response is invalid")
return response
def validate_input_kwargs(
self, kwargs: dict
) -> ResponsesToCompletionBridgeHandlerInputKwargs:
from litellm import LiteLLMLoggingObj
from litellm.types.utils import ModelResponse
model = kwargs.get("model")
if model is None or not isinstance(model, str):
raise ValueError("model is required")
custom_llm_provider = kwargs.get("custom_llm_provider")
if custom_llm_provider is None or not isinstance(custom_llm_provider, str):
raise ValueError("custom_llm_provider is required")
messages = kwargs.get("messages")
if messages is None or not isinstance(messages, list):
raise ValueError("messages is required")
optional_params = kwargs.get("optional_params")
if optional_params is None or not isinstance(optional_params, dict):
raise ValueError("optional_params is required")
litellm_params = kwargs.get("litellm_params")
if litellm_params is None or not isinstance(litellm_params, dict):
raise ValueError("litellm_params is required")
headers = kwargs.get("headers")
if headers is None or not isinstance(headers, dict):
raise ValueError("headers is required")
model_response = kwargs.get("model_response")
if model_response is None or not isinstance(model_response, ModelResponse):
raise ValueError("model_response is required")
logging_obj = kwargs.get("logging_obj")
if logging_obj is None or not isinstance(logging_obj, LiteLLMLoggingObj):
raise ValueError("logging_obj is required")
return ResponsesToCompletionBridgeHandlerInputKwargs(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
model_response=model_response,
logging_obj=logging_obj,
custom_llm_provider=custom_llm_provider,
)
def completion(
self, *args, **kwargs
) -> Union[
Coroutine[Any, Any, Union["ModelResponse", "CustomStreamWrapper"]],
"ModelResponse",
"CustomStreamWrapper",
]:
if kwargs.get("acompletion") is True:
return self.acompletion(**kwargs)
from litellm import responses
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
validated_kwargs = self.validate_input_kwargs(kwargs)
model = validated_kwargs["model"]
messages = validated_kwargs["messages"]
optional_params = validated_kwargs["optional_params"]
litellm_params = validated_kwargs["litellm_params"]
headers = validated_kwargs["headers"]
model_response = validated_kwargs["model_response"]
logging_obj = validated_kwargs["logging_obj"]
custom_llm_provider = validated_kwargs["custom_llm_provider"]
request_data = self.transformation_handler.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
litellm_logging_obj=logging_obj,
client=kwargs.get("client"),
)
result = responses(
**request_data,
)
stream = self._resolve_stream_flag(optional_params, litellm_params)
if isinstance(result, ResponsesAPIResponse):
return self.transformation_handler.transform_response(
model=model,
raw_response=result,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=kwargs.get("encoding"),
api_key=kwargs.get("api_key"),
json_mode=kwargs.get("json_mode"),
)
elif not stream:
responses_api_response = self._collect_response_from_stream(result)
return self.transformation_handler.transform_response(
model=model,
raw_response=responses_api_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=kwargs.get("encoding"),
api_key=kwargs.get("api_key"),
json_mode=kwargs.get("json_mode"),
)
else:
completion_stream = self.transformation_handler.get_model_response_iterator(
streaming_response=result, # type: ignore
sync_stream=True,
json_mode=kwargs.get("json_mode"),
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
return self._apply_post_stream_processing(
streamwrapper, model, custom_llm_provider
)
async def acompletion(
self, *args, **kwargs
) -> Union["ModelResponse", "CustomStreamWrapper"]:
from litellm import aresponses
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
validated_kwargs = self.validate_input_kwargs(kwargs)
model = validated_kwargs["model"]
messages = validated_kwargs["messages"]
optional_params = validated_kwargs["optional_params"]
litellm_params = validated_kwargs["litellm_params"]
headers = validated_kwargs["headers"]
model_response = validated_kwargs["model_response"]
logging_obj = validated_kwargs["logging_obj"]
custom_llm_provider = validated_kwargs["custom_llm_provider"]
request_data = self.transformation_handler.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
litellm_logging_obj=logging_obj,
)
result = await aresponses(
**request_data,
aresponses=True,
)
stream = self._resolve_stream_flag(optional_params, litellm_params)
if isinstance(result, ResponsesAPIResponse):
return self.transformation_handler.transform_response(
model=model,
raw_response=result,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=kwargs.get("encoding"),
api_key=kwargs.get("api_key"),
json_mode=kwargs.get("json_mode"),
)
elif not stream:
responses_api_response = await self._collect_response_from_stream_async(
result
)
return self.transformation_handler.transform_response(
model=model,
raw_response=responses_api_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=kwargs.get("encoding"),
api_key=kwargs.get("api_key"),
json_mode=kwargs.get("json_mode"),
)
else:
completion_stream = self.transformation_handler.get_model_response_iterator(
streaming_response=result, # type: ignore
sync_stream=False,
json_mode=kwargs.get("json_mode"),
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
return self._apply_post_stream_processing(
streamwrapper, model, custom_llm_provider
)
@staticmethod
def _apply_post_stream_processing(
stream: "CustomStreamWrapper",
model: str,
custom_llm_provider: str,
) -> Any:
"""Apply provider-specific post-stream processing if available."""
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager
try:
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
except (ValueError, KeyError):
return stream
if provider_config is not None:
return provider_config.post_stream_processing(stream)
return stream
responses_api_bridge = ResponsesToCompletionBridgeHandler()

File diff suppressed because it is too large

View File

@@ -0,0 +1,241 @@
# Container Files API
This module provides a unified interface for container file operations across multiple LLM providers (OpenAI, Azure OpenAI, etc.).
## Architecture
```
endpoints.json # Declarative endpoint definitions
endpoint_factory.py # Auto-generates SDK functions
container_handler.py # Generic HTTP handler
BaseContainerConfig # Provider-specific transformations
├── OpenAIContainerConfig
└── AzureContainerConfig (example)
```
## Files Overview
| File | Purpose |
|------|---------|
| `endpoints.json` | **Single source of truth** - Defines all container file endpoints |
| `endpoint_factory.py` | Auto-generates SDK functions (`list_container_files`, etc.) |
| `main.py` | Core container operations (create, list, retrieve, delete containers) |
| `utils.py` | Request parameter utilities |
## Adding a New Endpoint
To add a new container file endpoint (e.g., `get_container_file_content`):
### Step 1: Add to `endpoints.json`
```json
{
"name": "get_container_file_content",
"async_name": "aget_container_file_content",
"path": "/containers/{container_id}/files/{file_id}/content",
"method": "GET",
"path_params": ["container_id", "file_id"],
"query_params": [],
"response_type": "ContainerFileContentResponse"
}
```
### Step 2: Add Response Type (if new)
In `litellm/types/containers/main.py`:
```python
class ContainerFileContentResponse(BaseModel):
"""Response for file content download."""
content: bytes
# ... other fields
```
### Step 3: Register Response Type
In `litellm/llms/custom_httpx/container_handler.py`, add to `RESPONSE_TYPES`:
```python
RESPONSE_TYPES = {
# ... existing types
"ContainerFileContentResponse": ContainerFileContentResponse,
}
```
### Step 4: Update Router (one-time setup)
In `litellm/router.py`, add the new call type to the `factory_function` `Literal` and to the `_init_containers_api_endpoints` condition.
In `litellm/proxy/route_llm_request.py`, add the endpoint to the route mappings and to the skip-model-routing lists.
### Step 5: Update Proxy Handler Factory (if new path params)
If your endpoint has a new combination of path parameters, add a handler in `litellm/proxy/container_endpoints/handler_factory.py`:
```python
elif path_params == ["container_id", "file_id", "new_param"]:
async def handler(...):
# handler implementation
```
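Taken together, the steps above amount to the factory templating each JSON entry's `path` with the supplied path parameters. A minimal sketch of that idea (the `build_request` helper is illustrative, not part of the module):

```python
import json

# Endpoint definition in the same shape as endpoints.json (Step 1 above).
ENDPOINT = json.loads("""
{
  "name": "get_container_file_content",
  "path": "/containers/{container_id}/files/{file_id}/content",
  "method": "GET",
  "path_params": ["container_id", "file_id"]
}
""")

def build_request(endpoint: dict, api_base: str, **path_values) -> tuple:
    """Resolve an endpoint definition into (method, url)."""
    missing = [p for p in endpoint["path_params"] if p not in path_values]
    if missing:
        raise ValueError(f"missing path params: {missing}")
    url = api_base.rstrip("/") + endpoint["path"].format(**path_values)
    return endpoint["method"], url

method, url = build_request(
    ENDPOINT,
    "https://api.openai.com/v1",
    container_id="cntr_123",
    file_id="file_456",
)
```

The real factory does the same resolution inside `generic_container_handler.handle`, plus auth and logging.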
---
## Adding a New Provider (e.g., Azure OpenAI)
### Step 1: Create Provider Config
Create `litellm/llms/azure/containers/transformation.py`:
```python
from typing import Dict, Optional, Tuple, Any
import httpx
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
from litellm.types.containers.main import (
ContainerFileListResponse,
ContainerFileObject,
DeleteContainerFileResponse,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.secret_managers.main import get_secret_str
class AzureContainerConfig(BaseContainerConfig):
"""Configuration class for Azure OpenAI container API."""
def get_supported_openai_params(self) -> list:
return ["name", "expires_after", "file_ids", "extra_headers"]
def map_openai_params(
self,
container_create_optional_params,
drop_params: bool,
) -> Dict:
return dict(container_create_optional_params)
def validate_environment(
self,
headers: dict,
api_key: Optional[str] = None,
) -> dict:
"""Azure uses api-key header instead of Bearer token."""
import litellm
api_key = (
api_key
or litellm.azure_key
or get_secret_str("AZURE_API_KEY")
)
headers["api-key"] = api_key
return headers
def get_complete_url(
self,
api_base: Optional[str],
litellm_params: dict,
) -> str:
"""
Azure format:
https://{resource}.openai.azure.com/openai/containers?api-version=2024-xx
"""
if api_base is None:
raise ValueError("api_base is required for Azure")
api_version = litellm_params.get("api_version", "2024-02-15-preview")
return f"{api_base.rstrip('/')}/openai/containers?api-version={api_version}"
# Implement remaining abstract methods from BaseContainerConfig:
# - transform_container_create_request
# - transform_container_create_response
# - transform_container_list_request
# - transform_container_list_response
# - transform_container_retrieve_request
# - transform_container_retrieve_response
# - transform_container_delete_request
# - transform_container_delete_response
# - transform_container_file_list_request
# - transform_container_file_list_response
```
### Step 2: Register Provider Config
In `litellm/utils.py`, find `ProviderConfigManager.get_provider_container_config()` and add:
```python
@staticmethod
def get_provider_container_config(
provider: LlmProviders,
) -> Optional[BaseContainerConfig]:
if provider == LlmProviders.OPENAI:
from litellm.llms.openai.containers.transformation import OpenAIContainerConfig
return OpenAIContainerConfig()
elif provider == LlmProviders.AZURE:
from litellm.llms.azure.containers.transformation import AzureContainerConfig
return AzureContainerConfig()
return None
```
### Step 3: Test the New Provider
```bash
# Create container via Azure
curl -X POST "http://localhost:4000/v1/containers" \
-H "Authorization: Bearer sk-1234" \
-H "custom-llm-provider: azure" \
-H "Content-Type: application/json" \
-d '{"name": "My Azure Container"}'
# List container files via Azure
curl -X GET "http://localhost:4000/v1/containers/cntr_123/files" \
-H "Authorization: Bearer sk-1234" \
-H "custom-llm-provider: azure"
```
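The two configs differ mainly in authentication and URL shape. A hedged sketch of the URL construction (helper names are illustrative; the OpenAI default base URL is an assumption here, and the real logic lives in each config's `validate_environment` / `get_complete_url`):

```python
from typing import Optional

def openai_container_url(api_base: Optional[str]) -> str:
    # OpenAI: Bearer token auth, flat /containers path.
    base = (api_base or "https://api.openai.com/v1").rstrip("/")
    return f"{base}/containers"

def azure_container_url(api_base: str, api_version: str = "2024-02-15-preview") -> str:
    # Azure: api-key header auth, /openai prefix plus a required api-version.
    return f"{api_base.rstrip('/')}/openai/containers?api-version={api_version}"
```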
---
## How Provider Selection Works
1. **Proxy receives request** with `custom-llm-provider` header/query/body
2. **Router calls** `ProviderConfigManager.get_provider_container_config(provider)`
3. **Generic handler** uses the provider config for:
- URL construction (`get_complete_url`)
- Authentication (`validate_environment`)
- Request/response transformation
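Step 2's lookup can be sketched as a plain provider-to-config mapping (the dict registry here is illustrative; the real implementation is the `if/elif` chain shown earlier):

```python
class OpenAIContainerConfig:  # stand-in for the real config class
    pass

class AzureContainerConfig:  # stand-in for the real config class
    pass

_CONTAINER_CONFIGS = {
    "openai": OpenAIContainerConfig,
    "azure": AzureContainerConfig,
}

def get_provider_container_config(provider: str):
    """Return a provider config instance, or None for unsupported providers."""
    config_cls = _CONTAINER_CONFIGS.get(provider)
    return config_cls() if config_cls is not None else None
```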
---
## Testing
Run the container API tests:
```bash
cd /path/to/litellm
python -m pytest tests/test_litellm/containers/ -v
```
Test via proxy:
```bash
# Start proxy
cd litellm/proxy && python proxy_cli.py --config proxy_config.yaml --port 4000
# Test endpoints
curl -X GET "http://localhost:4000/v1/containers/cntr_123/files" \
-H "Authorization: Bearer sk-1234"
```
---
## Endpoint Reference
| Endpoint | Method | Path |
|----------|--------|------|
| List container files | GET | `/v1/containers/{container_id}/files` |
| Retrieve container file | GET | `/v1/containers/{container_id}/files/{file_id}` |
| Delete container file | DELETE | `/v1/containers/{container_id}/files/{file_id}` |
See `endpoints.json` for the complete list.
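For list-style endpoints, the declared `query_params` (`after`, `limit`, `order` in `endpoints.json`) are appended to the templated path. A small sketch of that assembly (`list_files_url` is hypothetical, not the library's API):

```python
from urllib.parse import urlencode

def list_files_url(api_base: str, container_id: str, **query) -> str:
    """Build the list-files URL, keeping only the declared query params."""
    allowed = {"after", "limit", "order"}  # query_params from endpoints.json
    params = {k: v for k, v in query.items() if k in allowed and v is not None}
    url = f"{api_base.rstrip('/')}/containers/{container_id}/files"
    return f"{url}?{urlencode(sorted(params.items()))}" if params else url
```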

View File

@@ -0,0 +1,44 @@
"""Container management functions for LiteLLM."""
# Auto-generated container file functions from endpoints.json
from .endpoint_factory import (
adelete_container_file,
alist_container_files,
aretrieve_container_file,
aretrieve_container_file_content,
delete_container_file,
list_container_files,
retrieve_container_file,
retrieve_container_file_content,
)
from .main import (
acreate_container,
adelete_container,
alist_containers,
aretrieve_container,
create_container,
delete_container,
list_containers,
retrieve_container,
)
__all__ = [
# Core container operations
"acreate_container",
"adelete_container",
"alist_containers",
"aretrieve_container",
"create_container",
"delete_container",
"list_containers",
"retrieve_container",
# Container file operations (auto-generated from endpoints.json)
"adelete_container_file",
"alist_container_files",
"aretrieve_container_file",
"aretrieve_container_file_content",
"delete_container_file",
"list_container_files",
"retrieve_container_file",
"retrieve_container_file_content",
]

View File

@@ -0,0 +1,232 @@
"""
Factory for generating container SDK functions from JSON config.
This module reads endpoints.json and dynamically generates SDK functions
that use the generic container handler.
"""
import asyncio
import contextvars
import json
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Type
import litellm
from litellm.constants import request_timeout as DEFAULT_REQUEST_TIMEOUT
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
from litellm.llms.custom_httpx.container_handler import generic_container_handler
from litellm.types.containers.main import (
ContainerFileListResponse,
ContainerFileObject,
DeleteContainerFileResponse,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import ProviderConfigManager, client
# Response type mapping
RESPONSE_TYPES: Dict[str, Type] = {
"ContainerFileListResponse": ContainerFileListResponse,
"ContainerFileObject": ContainerFileObject,
"DeleteContainerFileResponse": DeleteContainerFileResponse,
}
def _load_endpoints_config() -> Dict:
"""Load the endpoints configuration from JSON file."""
config_path = Path(__file__).parent / "endpoints.json"
with open(config_path) as f:
return json.load(f)
def create_sync_endpoint_function(endpoint_config: Dict) -> Callable:
"""
Create a sync SDK function from endpoint config.
Uses the generic container handler instead of individual handler methods.
"""
endpoint_name = endpoint_config["name"]
response_type = RESPONSE_TYPES.get(endpoint_config["response_type"])
path_params = endpoint_config.get("path_params", [])
@client
def endpoint_func(
timeout: int = 600,
custom_llm_provider: Literal["openai"] = "openai",
extra_headers: Optional[Dict[str, Any]] = None,
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
**kwargs,
):
local_vars = locals()
try:
litellm_logging_obj: LiteLLMLoggingObj = kwargs.pop("litellm_logging_obj")
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id")
_is_async = kwargs.pop("async_call", False) is True
# Check for mock response
mock_response = kwargs.get("mock_response")
if mock_response is not None:
if isinstance(mock_response, str):
mock_response = json.loads(mock_response)
if response_type:
return response_type(**mock_response)
return mock_response
# Get provider config
litellm_params = GenericLiteLLMParams(**kwargs)
container_provider_config: Optional[
BaseContainerConfig
] = ProviderConfigManager.get_provider_container_config(
provider=litellm.LlmProviders(custom_llm_provider),
)
if container_provider_config is None:
raise ValueError(
f"Container provider config not found for: {custom_llm_provider}"
)
# Build optional params for logging
optional_params = {k: kwargs.get(k) for k in path_params if k in kwargs}
# Pre-call logging
litellm_logging_obj.update_environment_variables(
model="",
optional_params=optional_params,
litellm_params={"litellm_call_id": litellm_call_id},
custom_llm_provider=custom_llm_provider,
)
# Use generic handler
return generic_container_handler.handle(
endpoint_name=endpoint_name,
container_provider_config=container_provider_config,
litellm_params=litellm_params,
logging_obj=litellm_logging_obj,
extra_headers=extra_headers,
extra_query=extra_query,
timeout=timeout or DEFAULT_REQUEST_TIMEOUT,
_is_async=_is_async,
**kwargs,
)
except Exception as e:
raise litellm.exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
return endpoint_func
def create_async_endpoint_function(
sync_func: Callable,
endpoint_config: Dict,
) -> Callable:
"""Create an async SDK function that wraps the sync function."""
@client
async def async_endpoint_func(
timeout: int = 600,
custom_llm_provider: Literal["openai"] = "openai",
extra_headers: Optional[Dict[str, Any]] = None,
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
**kwargs,
):
local_vars = locals()
try:
loop = asyncio.get_running_loop()
kwargs["async_call"] = True
func = partial(
sync_func,
timeout=timeout,
custom_llm_provider=custom_llm_provider,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
**kwargs,
)
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response
except Exception as e:
raise litellm.exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
return async_endpoint_func
def generate_container_endpoints() -> Dict[str, Callable]:
"""
Generate all container endpoint functions from the JSON config.
Returns a dict mapping function names to their implementations.
"""
config = _load_endpoints_config()
endpoints = {}
for endpoint_config in config["endpoints"]:
# Create sync function
sync_func = create_sync_endpoint_function(endpoint_config)
endpoints[endpoint_config["name"]] = sync_func
# Create async function
async_func = create_async_endpoint_function(sync_func, endpoint_config)
endpoints[endpoint_config["async_name"]] = async_func
return endpoints
def get_all_endpoint_names() -> List[str]:
"""Get all endpoint names (sync and async) from config."""
config = _load_endpoints_config()
names = []
for endpoint in config["endpoints"]:
names.append(endpoint["name"])
names.append(endpoint["async_name"])
return names
def get_async_endpoint_names() -> List[str]:
"""Get all async endpoint names for router registration."""
config = _load_endpoints_config()
return [endpoint["async_name"] for endpoint in config["endpoints"]]
# Generate endpoints on module load
_generated_endpoints = generate_container_endpoints()
# Export generated functions dynamically
list_container_files = _generated_endpoints.get("list_container_files")
alist_container_files = _generated_endpoints.get("alist_container_files")
upload_container_file = _generated_endpoints.get("upload_container_file")
aupload_container_file = _generated_endpoints.get("aupload_container_file")
retrieve_container_file = _generated_endpoints.get("retrieve_container_file")
aretrieve_container_file = _generated_endpoints.get("aretrieve_container_file")
delete_container_file = _generated_endpoints.get("delete_container_file")
adelete_container_file = _generated_endpoints.get("adelete_container_file")
retrieve_container_file_content = _generated_endpoints.get(
"retrieve_container_file_content"
)
aretrieve_container_file_content = _generated_endpoints.get(
"aretrieve_container_file_content"
)

View File

@@ -0,0 +1,51 @@
{
"endpoints": [
{
"name": "list_container_files",
"async_name": "alist_container_files",
"path": "/containers/{container_id}/files",
"method": "GET",
"path_params": ["container_id"],
"query_params": ["after", "limit", "order"],
"response_type": "ContainerFileListResponse"
},
{
"name": "upload_container_file",
"async_name": "aupload_container_file",
"path": "/containers/{container_id}/files",
"method": "POST",
"path_params": ["container_id"],
"query_params": [],
"response_type": "ContainerFileObject",
"is_multipart": true
},
{
"name": "retrieve_container_file",
"async_name": "aretrieve_container_file",
"path": "/containers/{container_id}/files/{file_id}",
"method": "GET",
"path_params": ["container_id", "file_id"],
"query_params": [],
"response_type": "ContainerFileObject"
},
{
"name": "delete_container_file",
"async_name": "adelete_container_file",
"path": "/containers/{container_id}/files/{file_id}",
"method": "DELETE",
"path_params": ["container_id", "file_id"],
"query_params": [],
"response_type": "DeleteContainerFileResponse"
},
{
"name": "retrieve_container_file_content",
"async_name": "aretrieve_container_file_content",
"path": "/containers/{container_id}/files/{file_id}/content",
"method": "GET",
"path_params": ["container_id", "file_id"],
"query_params": [],
"response_type": "raw",
"returns_binary": true
}
]
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,70 @@
from typing import Dict
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
from litellm.types.containers.main import (
ContainerCreateOptionalRequestParams,
ContainerListOptionalRequestParams,
)
class ContainerRequestUtils:
@staticmethod
def get_requested_container_create_optional_param(
passed_params: dict,
) -> ContainerCreateOptionalRequestParams:
"""Extract only valid container creation parameters from the passed parameters."""
container_create_optional_params = ContainerCreateOptionalRequestParams()
valid_params = [
"expires_after",
"file_ids",
"extra_headers",
"extra_body",
]
for param in valid_params:
if param in passed_params and passed_params[param] is not None:
container_create_optional_params[param] = passed_params[param] # type: ignore
return container_create_optional_params
@staticmethod
def get_optional_params_container_create(
container_provider_config: BaseContainerConfig,
container_create_optional_params: ContainerCreateOptionalRequestParams,
) -> Dict:
"""Get the optional parameters for container creation."""
supported_params = container_provider_config.get_supported_openai_params()
# Filter out unsupported parameters
filtered_params = {
k: v
for k, v in container_create_optional_params.items()
if k in supported_params
}
return container_provider_config.map_openai_params(
container_create_optional_params=filtered_params, # type: ignore
drop_params=False,
)
@staticmethod
def get_requested_container_list_optional_param(
passed_params: dict,
) -> ContainerListOptionalRequestParams:
"""Extract only valid container list parameters from the passed parameters."""
container_list_optional_params = ContainerListOptionalRequestParams()
valid_params = [
"after",
"limit",
"order",
"extra_headers",
"extra_query",
]
for param in valid_params:
if param in passed_params and passed_params[param] is not None:
container_list_optional_params[param] = passed_params[param] # type: ignore
return container_list_optional_params

View File

@@ -0,0 +1,5 @@
{
"gpt-3.5-turbo-0613": 0.00015000000000000001,
"claude-2": 0.00016454,
"gpt-4-0613": 0.015408
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,128 @@
"""
Handler for transforming /chat/completions api requests to litellm.responses requests
"""
from typing import TYPE_CHECKING, Optional, Union
from typing_extensions import TypedDict
if TYPE_CHECKING:
from litellm import LiteLLMLoggingObj
from litellm.types.llms.openai import HttpxBinaryResponseContent
class SpeechToCompletionBridgeHandlerInputKwargs(TypedDict):
model: str
input: str
voice: Optional[Union[str, dict]]
optional_params: dict
litellm_params: dict
logging_obj: "LiteLLMLoggingObj"
headers: dict
custom_llm_provider: str
class SpeechToCompletionBridgeHandler:
def __init__(self):
from .transformation import SpeechToCompletionBridgeTransformationHandler
super().__init__()
self.transformation_handler = SpeechToCompletionBridgeTransformationHandler()
def validate_input_kwargs(
self, kwargs: dict
) -> SpeechToCompletionBridgeHandlerInputKwargs:
from litellm import LiteLLMLoggingObj
model = kwargs.get("model")
if model is None or not isinstance(model, str):
raise ValueError("model is required")
custom_llm_provider = kwargs.get("custom_llm_provider")
if custom_llm_provider is None or not isinstance(custom_llm_provider, str):
raise ValueError("custom_llm_provider is required")
input = kwargs.get("input")
if input is None or not isinstance(input, str):
raise ValueError("input is required")
optional_params = kwargs.get("optional_params")
if optional_params is None or not isinstance(optional_params, dict):
raise ValueError("optional_params is required")
litellm_params = kwargs.get("litellm_params")
if litellm_params is None or not isinstance(litellm_params, dict):
raise ValueError("litellm_params is required")
headers = kwargs.get("headers")
if headers is None or not isinstance(headers, dict):
raise ValueError("headers is required")
logging_obj = kwargs.get("logging_obj")
if logging_obj is None or not isinstance(logging_obj, LiteLLMLoggingObj):
raise ValueError("logging_obj is required")
return SpeechToCompletionBridgeHandlerInputKwargs(
model=model,
input=input,
voice=kwargs.get("voice"),
optional_params=optional_params,
litellm_params=litellm_params,
logging_obj=logging_obj,
custom_llm_provider=custom_llm_provider,
headers=headers,
)
def speech(
self,
model: str,
input: str,
voice: Optional[Union[str, dict]],
optional_params: dict,
litellm_params: dict,
headers: dict,
logging_obj: "LiteLLMLoggingObj",
custom_llm_provider: str,
) -> "HttpxBinaryResponseContent":
received_args = locals()
from litellm import completion
from litellm.types.utils import ModelResponse
validated_kwargs = self.validate_input_kwargs(received_args)
model = validated_kwargs["model"]
input = validated_kwargs["input"]
optional_params = validated_kwargs["optional_params"]
litellm_params = validated_kwargs["litellm_params"]
headers = validated_kwargs["headers"]
logging_obj = validated_kwargs["logging_obj"]
custom_llm_provider = validated_kwargs["custom_llm_provider"]
voice = validated_kwargs["voice"]
request_data = self.transformation_handler.transform_request(
model=model,
input=input,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
litellm_logging_obj=logging_obj,
custom_llm_provider=custom_llm_provider,
voice=voice,
)
result = completion(
**request_data,
)
if isinstance(result, ModelResponse):
return self.transformation_handler.transform_response(
model_response=result,
)
else:
raise Exception("Unmapped response type. Got type: {}".format(type(result)))
speech_to_completion_bridge_handler = SpeechToCompletionBridgeHandler()

View File

@@ -0,0 +1,134 @@
from typing import TYPE_CHECKING, Optional, Union, cast
from litellm.constants import OPENAI_CHAT_COMPLETION_PARAMS
if TYPE_CHECKING:
from litellm import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.utils import ModelResponse
class SpeechToCompletionBridgeTransformationHandler:
def transform_request(
self,
model: str,
input: str,
voice: Optional[Union[str, dict]],
optional_params: dict,
litellm_params: dict,
headers: dict,
litellm_logging_obj: "LiteLLMLoggingObj",
custom_llm_provider: str,
) -> dict:
passed_optional_params = {}
for op in optional_params:
if op in OPENAI_CHAT_COMPLETION_PARAMS:
passed_optional_params[op] = optional_params[op]
if voice is not None:
if isinstance(voice, str):
passed_optional_params["audio"] = {"voice": voice}
if "response_format" in optional_params:
passed_optional_params["audio"]["format"] = optional_params[
"response_format"
]
return_kwargs = {
"model": model,
"messages": [
{
"role": "user",
"content": input,
}
],
"modalities": ["audio"],
**passed_optional_params,
**litellm_params,
"headers": headers,
"litellm_logging_obj": litellm_logging_obj,
"custom_llm_provider": custom_llm_provider,
}
# filter out None values
return_kwargs = {k: v for k, v in return_kwargs.items() if v is not None}
return return_kwargs
def _convert_pcm16_to_wav(
self, pcm_data: bytes, sample_rate: int = 24000, channels: int = 1
) -> bytes:
"""
Convert raw PCM16 data to WAV format.
Args:
pcm_data: Raw PCM16 audio data
sample_rate: Sample rate in Hz (Gemini TTS typically uses 24000)
channels: Number of audio channels (1 for mono)
Returns:
bytes: WAV formatted audio data
"""
import struct
# WAV header parameters
byte_rate = sample_rate * channels * 2 # 2 bytes per sample (16-bit)
block_align = channels * 2
data_size = len(pcm_data)
file_size = 36 + data_size
# Create WAV header
wav_header = struct.pack(
"<4sI4s4sIHHIIHH4sI",
b"RIFF", # Chunk ID
file_size, # Chunk Size
b"WAVE", # Format
b"fmt ", # Subchunk1 ID
16, # Subchunk1 Size (PCM)
1, # Audio Format (PCM)
channels, # Number of Channels
sample_rate, # Sample Rate
byte_rate, # Byte Rate
block_align, # Block Align
16, # Bits per Sample
b"data", # Subchunk2 ID
data_size, # Subchunk2 Size
)
return wav_header + pcm_data
def _is_gemini_tts_model(self, model: str) -> bool:
"""Check if the model is a Gemini TTS model that returns PCM16 data."""
return "gemini" in model.lower() and "tts" in model.lower()
def transform_response(
self, model_response: "ModelResponse"
) -> "HttpxBinaryResponseContent":
import base64
import httpx
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.utils import Choices
audio_part = cast(Choices, model_response.choices[0]).message.audio
if audio_part is None:
raise ValueError("No audio part found in the response")
audio_content = audio_part.data
# Decode base64 to get binary content
binary_data = base64.b64decode(audio_content)
# Check if this is a Gemini TTS model that returns raw PCM16 data
model = getattr(model_response, "model", "")
headers = {}
if self._is_gemini_tts_model(model):
# Convert PCM16 to WAV format for proper audio file playback
binary_data = self._convert_pcm16_to_wav(binary_data)
headers["Content-Type"] = "audio/wav"
else:
headers["Content-Type"] = "audio/mpeg"
# Create an httpx.Response object
response = httpx.Response(status_code=200, content=binary_data, headers=headers)
return HttpxBinaryResponseContent(response)

View File

@@ -0,0 +1,33 @@
"""
Evals API operations
"""
from .main import (
acancel_eval,
acreate_eval,
adelete_eval,
aget_eval,
alist_evals,
aupdate_eval,
cancel_eval,
create_eval,
delete_eval,
get_eval,
list_evals,
update_eval,
)
__all__ = [
"acreate_eval",
"alist_evals",
"aget_eval",
"aupdate_eval",
"adelete_eval",
"acancel_eval",
"create_eval",
"list_evals",
"get_eval",
"update_eval",
"delete_eval",
"cancel_eval",
]

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,6 @@
# LiteLLM MCP Client
The LiteLLM MCP Client lets you use MCP tools with LiteLLM.

View File

@@ -0,0 +1,3 @@
from .tools import call_openai_tool, load_mcp_tools
__all__ = ["load_mcp_tools", "call_openai_tool"]

View File

@@ -0,0 +1,697 @@
"""
LiteLLM Proxy uses this MCP client to connect to other MCP servers.
"""
import asyncio
import base64
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
TypeVar,
Union,
)
import httpx
from mcp import ClientSession, ReadResourceResult, Resource, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
streamable_http_client: Optional[Any] = None
try:
import mcp.client.streamable_http as streamable_http_module # type: ignore
streamable_http_client = getattr(
streamable_http_module, "streamable_http_client", None
)
except ImportError:
pass
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import (
GetPromptRequestParams,
GetPromptResult,
Prompt,
ResourceTemplate,
TextContent,
)
from mcp.types import Tool as MCPTool
from pydantic import AnyUrl
from litellm._logging import verbose_logger
from litellm.constants import MCP_CLIENT_TIMEOUT
from litellm.llms.custom_httpx.http_handler import get_ssl_configuration
from litellm.types.llms.custom_http import VerifyTypes
from litellm.types.mcp import (
MCPAuth,
MCPAuthType,
MCPStdioConfig,
MCPTransport,
MCPTransportType,
)
def to_basic_auth(auth_value: str) -> str:
"""Convert auth value to Basic Auth format."""
return base64.b64encode(auth_value.encode("utf-8")).decode()
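For reference, `to_basic_auth` just base64-encodes a `"username:password"` string into the token used by the `Authorization: Basic` scheme:

```python
import base64

def to_basic_auth(auth_value: str) -> str:
    """Convert a "username:password" string to the Basic Auth token."""
    return base64.b64encode(auth_value.encode("utf-8")).decode()

token = to_basic_auth("user:pass")
# The header sent would then be: Authorization: Basic dXNlcjpwYXNz
```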
TSessionResult = TypeVar("TSessionResult")
class MCPSigV4Auth(httpx.Auth):
"""
httpx Auth class that signs each request with AWS SigV4.
This is used for MCP servers that require AWS SigV4 authentication,
such as AWS Bedrock AgentCore MCP servers. httpx calls auth_flow()
for every outgoing request, enabling per-request signature computation.
"""
requires_request_body = True
def __init__(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_service_name: Optional[str] = None,
):
try:
from botocore.credentials import Credentials
except ImportError:
raise ImportError(
"Missing botocore to use AWS SigV4 authentication. "
"Run 'pip install boto3'."
)
self.service_name = aws_service_name or "bedrock-agentcore"
self.region_name = aws_region_name or "us-east-1"
# Note: os.environ/ prefixed values are already resolved by
# ProxyConfig._check_for_os_environ_vars() at config load time.
# Values arrive here as plain strings.
if aws_access_key_id and aws_secret_access_key:
self.credentials = Credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
)
else:
# Fall back to default boto3 credential chain
import botocore.session
session = botocore.session.get_session()
self.credentials = session.get_credentials()
if self.credentials is None:
raise ValueError(
"No AWS credentials found. Provide aws_access_key_id and "
"aws_secret_access_key, or configure default credentials "
"(env vars, ~/.aws/credentials, instance profile)."
)
def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
# Build AWSRequest from the httpx Request.
# Pass all request headers so the canonical SigV4 signature covers them.
aws_request = AWSRequest(
method=request.method,
url=str(request.url),
data=request.content,
headers=dict(request.headers),
)
# Sign the request — SigV4Auth.add_auth() adds Authorization,
# X-Amz-Date, and X-Amz-Security-Token (if session token present).
# Host header is derived automatically from the URL.
sigv4 = SigV4Auth(self.credentials, self.service_name, self.region_name)
sigv4.add_auth(aws_request)
# Copy SigV4 headers back to the httpx request
for header_name, header_value in aws_request.headers.items():
request.headers[header_name] = header_value
yield request
class MCPClient:
"""
MCP Client supporting:
- SSE, streamable HTTP, and stdio transports
- Authentication via Bearer token, Basic Auth, API key, OAuth2, or AWS SigV4
- Tool calling with error handling and result parsing
"""
def __init__(
self,
server_url: str = "",
transport_type: MCPTransportType = MCPTransport.http,
auth_type: MCPAuthType = None,
auth_value: Optional[Union[str, Dict[str, str]]] = None,
timeout: Optional[float] = None,
stdio_config: Optional[MCPStdioConfig] = None,
extra_headers: Optional[Dict[str, str]] = None,
ssl_verify: Optional[VerifyTypes] = None,
aws_auth: Optional[httpx.Auth] = None,
):
self.server_url: str = server_url
self.transport_type: MCPTransport = transport_type
self.auth_type: MCPAuthType = auth_type
self.timeout: float = timeout if timeout is not None else MCP_CLIENT_TIMEOUT
self._mcp_auth_value: Optional[Union[str, Dict[str, str]]] = None
self.stdio_config: Optional[MCPStdioConfig] = stdio_config
self.extra_headers: Optional[Dict[str, str]] = extra_headers
self.ssl_verify: Optional[VerifyTypes] = ssl_verify
self._aws_auth: Optional[httpx.Auth] = aws_auth
# handle the basic auth value if provided
if auth_value:
self.update_auth_value(auth_value)
def _create_transport_context(
self,
) -> Tuple[Any, Optional[httpx.AsyncClient]]:
"""
Create the appropriate transport context based on transport type.
Returns:
Tuple of (transport_context, http_client).
http_client is only set for HTTP transport and needs cleanup.
"""
http_client: Optional[httpx.AsyncClient] = None
if self.transport_type == MCPTransport.stdio:
if not self.stdio_config:
raise ValueError("stdio_config is required for stdio transport")
server_params = StdioServerParameters(
command=self.stdio_config.get("command", ""),
args=self.stdio_config.get("args", []),
env=self.stdio_config.get("env", {}),
)
return stdio_client(server_params), None
if self.transport_type == MCPTransport.sse:
headers = self._get_auth_headers()
httpx_client_factory = self._create_httpx_client_factory()
return (
sse_client(
url=self.server_url,
timeout=self.timeout,
headers=headers,
httpx_client_factory=httpx_client_factory,
),
None,
)
# HTTP transport (default)
if streamable_http_client is None:
raise ImportError(
"streamable_http_client is not available. "
"Please install mcp with HTTP support."
)
headers = self._get_auth_headers()
httpx_client_factory = self._create_httpx_client_factory()
verbose_logger.debug("litellm headers for streamable_http_client: %s", headers)
http_client = httpx_client_factory(
headers=headers,
timeout=httpx.Timeout(self.timeout),
)
transport_ctx = streamable_http_client(
url=self.server_url,
http_client=http_client,
)
return transport_ctx, http_client
async def _execute_session_operation(
self,
transport_ctx: Any,
operation: Callable[[ClientSession], Awaitable[TSessionResult]],
) -> TSessionResult:
"""
Execute an operation within a transport and session context.
Handles entering/exiting contexts and running the operation.
"""
transport = await transport_ctx.__aenter__()
try:
read_stream, write_stream = transport[0], transport[1]
session_ctx = ClientSession(read_stream, write_stream)
session = await session_ctx.__aenter__()
try:
await session.initialize()
return await operation(session)
finally:
try:
await session_ctx.__aexit__(None, None, None)
except BaseException as e:
verbose_logger.debug(f"Error during session context exit: {e}")
finally:
try:
await transport_ctx.__aexit__(None, None, None)
except BaseException as e:
verbose_logger.debug(f"Error during transport context exit: {e}")
async def run_with_session(
self, operation: Callable[[ClientSession], Awaitable[TSessionResult]]
) -> TSessionResult:
"""Open a session, run the provided coroutine, and clean up."""
http_client: Optional[httpx.AsyncClient] = None
try:
transport_ctx, http_client = self._create_transport_context()
return await self._execute_session_operation(transport_ctx, operation)
except Exception:
verbose_logger.warning(
"MCP client run_with_session failed for %s", self.server_url or "stdio"
)
raise
finally:
if http_client is not None:
try:
await http_client.aclose()
except BaseException as e:
verbose_logger.debug(f"Error during http_client cleanup: {e}")
def update_auth_value(self, mcp_auth_value: Union[str, Dict[str, str]]):
"""
Set the authentication header for the MCP client.
"""
if isinstance(mcp_auth_value, dict):
self._mcp_auth_value = mcp_auth_value
else:
if self.auth_type == MCPAuth.basic:
# Assuming mcp_auth_value is in format "username:password", convert it when updating
mcp_auth_value = to_basic_auth(mcp_auth_value)
self._mcp_auth_value = mcp_auth_value
def _get_auth_headers(self) -> dict:
"""Generate authentication headers based on auth type."""
headers = {}
if self._mcp_auth_value:
if isinstance(self._mcp_auth_value, str):
if self.auth_type == MCPAuth.bearer_token:
headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
elif self.auth_type == MCPAuth.basic:
headers["Authorization"] = f"Basic {self._mcp_auth_value}"
elif self.auth_type == MCPAuth.api_key:
headers["X-API-Key"] = self._mcp_auth_value
elif self.auth_type == MCPAuth.authorization:
headers["Authorization"] = self._mcp_auth_value
elif self.auth_type == MCPAuth.oauth2:
headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
elif self.auth_type == MCPAuth.token:
headers["Authorization"] = f"token {self._mcp_auth_value}"
elif isinstance(self._mcp_auth_value, dict):
headers.update(self._mcp_auth_value)
# Note: aws_sigv4 auth is not handled here — SigV4 requires per-request
# signing (including the body hash), so it uses httpx.Auth flow instead
# of static headers. See MCPSigV4Auth and _create_httpx_client_factory().
# update the headers with the extra headers
if self.extra_headers:
headers.update(self.extra_headers)
return headers
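The scheme-to-header mapping above can be sketched standalone. This hypothetical `build_auth_headers` helper mirrors the logic with auth types as plain strings (the real code uses the `MCPAuth` enum and instance state):

```python
from typing import Dict, Optional, Union

def build_auth_headers(
    auth_type: Optional[str],
    auth_value: Optional[Union[str, Dict[str, str]]],
    extra_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Map an auth type + value to the HTTP headers sent to the MCP server."""
    headers: Dict[str, str] = {}
    if isinstance(auth_value, dict):
        headers.update(auth_value)            # pre-built header dict used as-is
    elif auth_value:
        if auth_type in ("bearer_token", "oauth2"):
            headers["Authorization"] = f"Bearer {auth_value}"
        elif auth_type == "basic":
            headers["Authorization"] = f"Basic {auth_value}"
        elif auth_type == "api_key":
            headers["X-API-Key"] = auth_value
        elif auth_type == "authorization":
            headers["Authorization"] = auth_value
        elif auth_type == "token":
            headers["Authorization"] = f"token {auth_value}"
    if extra_headers:
        headers.update(extra_headers)         # extra headers applied last, so they win
    return headers

hdrs = build_auth_headers("bearer_token", "sk-123", {"X-Trace": "abc"})
```

Note that SigV4 is intentionally absent here, matching the comment above: it signs each request (including the body hash) via `httpx.Auth` rather than static headers.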
def _create_httpx_client_factory(self) -> Callable[..., httpx.AsyncClient]:
"""
Create a custom httpx client factory that uses LiteLLM's SSL configuration.
This factory follows the same CA bundle path logic as http_handler.py:
1. Check ssl_verify parameter (can be SSLContext, bool, or path to CA bundle)
2. Check SSL_VERIFY environment variable
3. Check SSL_CERT_FILE environment variable
4. Fall back to certifi CA bundle
"""
def factory(
*,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[httpx.Timeout] = None,
auth: Optional[httpx.Auth] = None,
) -> httpx.AsyncClient:
"""Create an httpx.AsyncClient with LiteLLM's SSL configuration."""
# Get unified SSL configuration using the same logic as http_handler.py
ssl_config = get_ssl_configuration(self.ssl_verify)
verbose_logger.debug(
f"MCP client using SSL configuration: {type(ssl_config).__name__}"
)
# Use SigV4 auth if configured and no explicit auth provided.
# The MCP SDK's sse_client and streamable_http_client call this
# factory without passing auth=, so self._aws_auth is used.
# For non-SigV4 clients, self._aws_auth is None — no behavior change.
effective_auth = auth if auth is not None else self._aws_auth
return httpx.AsyncClient(
headers=headers,
timeout=timeout,
auth=effective_auth,
verify=ssl_config,
follow_redirects=True,
)
return factory
async def list_tools(self) -> List[MCPTool]:
"""List available tools from the server."""
verbose_logger.debug(
f"MCP client listing tools from {self.server_url or 'stdio'}"
)
async def _list_tools_operation(session: ClientSession):
return await session.list_tools()
try:
result = await self.run_with_session(_list_tools_operation)
tool_count = len(result.tools)
tool_names = [tool.name for tool in result.tools]
verbose_logger.info(
f"MCP client listed {tool_count} tools from {self.server_url or 'stdio'}: {tool_names}"
)
return result.tools
except asyncio.CancelledError:
verbose_logger.warning("MCP client list_tools was cancelled")
raise
except Exception as e:
error_type = type(e).__name__
verbose_logger.exception(
f"MCP client list_tools failed - "
f"Error Type: {error_type}, "
f"Error: {str(e)}, "
f"Server: {self.server_url or 'stdio'}, "
f"Transport: {self.transport_type}"
)
# Check if it's a stream/connection error
if "BrokenResourceError" in error_type or "Broken" in error_type:
verbose_logger.error(
"MCP client detected broken connection/stream during list_tools - "
"the MCP server may have crashed, disconnected, or timed out"
)
# Return empty list instead of raising to allow graceful degradation
return []
async def call_tool(
self,
call_tool_request_params: MCPCallToolRequestParams,
host_progress_callback: Optional[Callable] = None,
) -> MCPCallToolResult:
"""
Call an MCP Tool.
"""
verbose_logger.info(
f"MCP client calling tool '{call_tool_request_params.name}' with arguments: {call_tool_request_params.arguments}"
)
async def on_progress(
progress: float, total: Optional[float], message: Optional[str]
):
percentage = (progress / total * 100) if total else 0
verbose_logger.info(
f"MCP Tool '{call_tool_request_params.name}' progress: "
f"{progress}/{total} ({percentage:.0f}%) - {message or ''}"
)
# Forward to Host if callback provided
if host_progress_callback:
try:
await host_progress_callback(progress, total)
except Exception as e:
verbose_logger.warning(f"Failed to forward to Host: {e}")
async def _call_tool_operation(session: ClientSession):
verbose_logger.debug("MCP client sending tool call to session")
return await session.call_tool(
name=call_tool_request_params.name,
arguments=call_tool_request_params.arguments,
progress_callback=on_progress,
)
try:
tool_result = await self.run_with_session(_call_tool_operation)
verbose_logger.info(
f"MCP client tool call '{call_tool_request_params.name}' completed successfully"
)
return tool_result
except asyncio.CancelledError:
verbose_logger.warning("MCP client tool call was cancelled")
raise
except Exception as e:
import traceback
error_trace = traceback.format_exc()
verbose_logger.debug(f"MCP client tool call traceback:\n{error_trace}")
# Log detailed error information
error_type = type(e).__name__
verbose_logger.error(
f"MCP client call_tool failed - "
f"Error Type: {error_type}, "
f"Error: {str(e)}, "
f"Tool: {call_tool_request_params.name}, "
f"Server: {self.server_url or 'stdio'}, "
f"Transport: {self.transport_type}"
)
# Check if it's a stream/connection error
if "BrokenResourceError" in error_type or "Broken" in error_type:
verbose_logger.error(
"MCP client detected broken connection/stream - "
"the MCP server may have crashed, disconnected, or timed out."
)
# Return a default error result instead of raising
return MCPCallToolResult(
content=[
TextContent(type="text", text=f"{error_type}: {str(e)}")
],  # Error details surfaced as text content
isError=True,
)
async def list_prompts(self) -> List[Prompt]:
"""List available prompts from the server."""
verbose_logger.debug(
f"MCP client listing prompts from {self.server_url or 'stdio'}"
)
async def _list_prompts_operation(session: ClientSession):
return await session.list_prompts()
try:
result = await self.run_with_session(_list_prompts_operation)
prompt_count = len(result.prompts)
prompt_names = [prompt.name for prompt in result.prompts]
verbose_logger.info(
f"MCP client listed {prompt_count} prompts from {self.server_url or 'stdio'}: {prompt_names}"
)
return result.prompts
except asyncio.CancelledError:
verbose_logger.warning("MCP client list_prompts was cancelled")
raise
except Exception as e:
error_type = type(e).__name__
verbose_logger.error(
f"MCP client list_prompts failed - "
f"Error Type: {error_type}, "
f"Error: {str(e)}, "
f"Server: {self.server_url or 'stdio'}, "
f"Transport: {self.transport_type}"
)
# Check if it's a stream/connection error
if "BrokenResourceError" in error_type or "Broken" in error_type:
verbose_logger.error(
"MCP client detected broken connection/stream during list_prompts - "
"the MCP server may have crashed, disconnected, or timed out"
)
# Return empty list instead of raising to allow graceful degradation
return []
async def get_prompt(
self, get_prompt_request_params: GetPromptRequestParams
) -> GetPromptResult:
"""Fetch a prompt definition from the MCP server."""
verbose_logger.info(
f"MCP client fetching prompt '{get_prompt_request_params.name}' with arguments: {get_prompt_request_params.arguments}"
)
async def _get_prompt_operation(session: ClientSession):
verbose_logger.debug("MCP client sending get_prompt request to session")
return await session.get_prompt(
name=get_prompt_request_params.name,
arguments=get_prompt_request_params.arguments,
)
try:
get_prompt_result = await self.run_with_session(_get_prompt_operation)
verbose_logger.info(
f"MCP client get_prompt '{get_prompt_request_params.name}' completed successfully"
)
return get_prompt_result
except asyncio.CancelledError:
verbose_logger.warning("MCP client get_prompt was cancelled")
raise
except Exception as e:
import traceback
error_trace = traceback.format_exc()
verbose_logger.debug(f"MCP client get_prompt traceback:\n{error_trace}")
# Log detailed error information
error_type = type(e).__name__
verbose_logger.error(
f"MCP client get_prompt failed - "
f"Error Type: {error_type}, "
f"Error: {str(e)}, "
f"Prompt: {get_prompt_request_params.name}, "
f"Server: {self.server_url or 'stdio'}, "
f"Transport: {self.transport_type}"
)
# Check if it's a stream/connection error
if "BrokenResourceError" in error_type or "Broken" in error_type:
verbose_logger.error(
"MCP client detected broken connection/stream during get_prompt - "
"the MCP server may have crashed, disconnected, or timed out."
)
raise
async def list_resources(self) -> list[Resource]:
"""List available resources from the server."""
verbose_logger.debug(
f"MCP client listing resources from {self.server_url or 'stdio'}"
)
async def _list_resources_operation(session: ClientSession):
return await session.list_resources()
try:
result = await self.run_with_session(_list_resources_operation)
resource_count = len(result.resources)
resource_names = [resource.name for resource in result.resources]
verbose_logger.info(
f"MCP client listed {resource_count} resources from {self.server_url or 'stdio'}: {resource_names}"
)
return result.resources
except asyncio.CancelledError:
verbose_logger.warning("MCP client list_resources was cancelled")
raise
except Exception as e:
error_type = type(e).__name__
verbose_logger.error(
f"MCP client list_resources failed - "
f"Error Type: {error_type}, "
f"Error: {str(e)}, "
f"Server: {self.server_url or 'stdio'}, "
f"Transport: {self.transport_type}"
)
# Check if it's a stream/connection error
if "BrokenResourceError" in error_type or "Broken" in error_type:
verbose_logger.error(
"MCP client detected broken connection/stream during list_resources - "
"the MCP server may have crashed, disconnected, or timed out"
)
# Return empty list instead of raising to allow graceful degradation
return []
async def list_resource_templates(self) -> list[ResourceTemplate]:
"""List available resource templates from the server."""
verbose_logger.debug(
f"MCP client listing resource templates from {self.server_url or 'stdio'}"
)
async def _list_resource_templates_operation(session: ClientSession):
return await session.list_resource_templates()
try:
result = await self.run_with_session(_list_resource_templates_operation)
resource_template_count = len(result.resourceTemplates)
resource_template_names = [
resourceTemplate.name for resourceTemplate in result.resourceTemplates
]
verbose_logger.info(
f"MCP client listed {resource_template_count} resource templates from {self.server_url or 'stdio'}: {resource_template_names}"
)
return result.resourceTemplates
except asyncio.CancelledError:
verbose_logger.warning("MCP client list_resource_templates was cancelled")
raise
except Exception as e:
error_type = type(e).__name__
verbose_logger.error(
f"MCP client list_resource_templates failed - "
f"Error Type: {error_type}, "
f"Error: {str(e)}, "
f"Server: {self.server_url or 'stdio'}, "
f"Transport: {self.transport_type}"
)
# Check if it's a stream/connection error
if "BrokenResourceError" in error_type or "Broken" in error_type:
verbose_logger.error(
"MCP client detected broken connection/stream during list_resource_templates - "
"the MCP server may have crashed, disconnected, or timed out"
)
# Return empty list instead of raising to allow graceful degradation
return []
async def read_resource(self, url: AnyUrl) -> ReadResourceResult:
"""Fetch resource contents from the MCP server."""
verbose_logger.info(f"MCP client fetching resource '{url}'")
async def _read_resource_operation(session: ClientSession):
verbose_logger.debug("MCP client sending read_resource request to session")
return await session.read_resource(url)
try:
read_resource_result = await self.run_with_session(_read_resource_operation)
verbose_logger.info(
f"MCP client read_resource '{url}' completed successfully"
)
return read_resource_result
except asyncio.CancelledError:
verbose_logger.warning("MCP client read_resource was cancelled")
raise
except Exception as e:
import traceback
error_trace = traceback.format_exc()
verbose_logger.debug(f"MCP client read_resource traceback:\n{error_trace}")
# Log detailed error information
error_type = type(e).__name__
verbose_logger.error(
f"MCP client read_resource failed - "
f"Error Type: {error_type}, "
f"Error: {str(e)}, "
f"Url: {url}, "
f"Server: {self.server_url or 'stdio'}, "
f"Transport: {self.transport_type}"
)
# Check if it's a stream/connection error
if "BrokenResourceError" in error_type or "Broken" in error_type:
verbose_logger.error(
"MCP client detected broken connection/stream during read_resource - "
"the MCP server may have crashed, disconnected, or timed out."
)
raise

View File

@@ -0,0 +1,159 @@
import json
from typing import Dict, List, Literal, Union
from mcp import ClientSession
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import Tool as MCPTool
from openai.types.chat import ChatCompletionToolParam
from openai.types.responses.function_tool_param import FunctionToolParam
from openai.types.shared_params.function_definition import FunctionDefinition
from litellm.types.utils import ChatCompletionMessageToolCall
########################################################
# List MCP Tool functions
########################################################
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
"""Convert an MCP tool to an OpenAI tool."""
normalized_parameters = _normalize_mcp_input_schema(mcp_tool.inputSchema)
return ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=normalized_parameters,
strict=False,
),
)
def _normalize_mcp_input_schema(input_schema: dict) -> dict:
"""
Normalize MCP input schema to ensure it's valid for OpenAI function calling.
OpenAI requires that function parameters have:
- type: 'object'
- properties: dict (can be empty)
- additionalProperties: false (recommended)
"""
if not input_schema:
return {"type": "object", "properties": {}, "additionalProperties": False}
# Make a copy to avoid modifying the original
normalized_schema = dict(input_schema)
# Ensure type is 'object'
if "type" not in normalized_schema:
normalized_schema["type"] = "object"
# Ensure properties exists (can be empty)
if "properties" not in normalized_schema:
normalized_schema["properties"] = {}
# Add additionalProperties if not present (recommended by OpenAI)
if "additionalProperties" not in normalized_schema:
normalized_schema["additionalProperties"] = False
return normalized_schema
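The normalization rules above amount to filling in three defaults without mutating the caller's schema. A standalone mirror (hypothetical `normalize_schema` name) showing the behavior on an empty and a partial schema:

```python
def normalize_schema(input_schema: dict) -> dict:
    """Fill in the defaults OpenAI function calling expects on a JSON schema."""
    if not input_schema:
        return {"type": "object", "properties": {}, "additionalProperties": False}
    normalized = dict(input_schema)                  # copy: don't mutate the original
    normalized.setdefault("type", "object")
    normalized.setdefault("properties", {})
    normalized.setdefault("additionalProperties", False)
    return normalized

empty = normalize_schema({})
partial = normalize_schema({"properties": {"q": {"type": "string"}}})
```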
def transform_mcp_tool_to_openai_responses_api_tool(
mcp_tool: MCPTool,
) -> FunctionToolParam:
"""Convert an MCP tool to an OpenAI Responses API tool."""
normalized_parameters = _normalize_mcp_input_schema(mcp_tool.inputSchema)
return FunctionToolParam(
name=mcp_tool.name,
parameters=normalized_parameters,
strict=False,
type="function",
description=mcp_tool.description or "",
)
async def load_mcp_tools(
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
"""
Load all available MCP tools
Args:
session: The MCP session to use
format: The format to convert the tools to
By default, the tools are returned in MCP format.
If format is set to "openai", the tools are converted to OpenAI API compatible tools.
"""
tools = await session.list_tools()
if format == "openai":
return [
transform_mcp_tool_to_openai_tool(mcp_tool=tool) for tool in tools.tools
]
return tools.tools
########################################################
# Call MCP Tool functions
########################################################
async def call_mcp_tool(
session: ClientSession,
call_tool_request_params: MCPCallToolRequestParams,
) -> MCPCallToolResult:
"""Call an MCP tool."""
tool_result = await session.call_tool(
name=call_tool_request_params.name,
arguments=call_tool_request_params.arguments,
)
return tool_result
def _get_function_arguments(function: FunctionDefinition) -> dict:
"""Helper to safely get and parse function arguments."""
arguments = function.get("arguments", {})
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
return arguments if isinstance(arguments, dict) else {}
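OpenAI returns tool-call arguments as a JSON string, so the helper above parses it and degrades to `{}` on anything malformed or non-dict. A standalone mirror of that parsing rule (hypothetical `parse_arguments` name):

```python
import json
from typing import Optional, Union

def parse_arguments(arguments: Optional[Union[str, dict]]) -> dict:
    """Parse OpenAI tool-call arguments, falling back to {} on bad input."""
    if arguments is None:
        return {}
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)   # OpenAI sends arguments as a JSON string
        except json.JSONDecodeError:
            arguments = {}                      # malformed JSON degrades to empty args
    return arguments if isinstance(arguments, dict) else {}

parsed = parse_arguments('{"city": "Paris"}')
bad = parse_arguments("{not json")
```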
def transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool: Union[ChatCompletionMessageToolCall, Dict],
) -> MCPCallToolRequestParams:
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
function = openai_tool["function"]
return MCPCallToolRequestParams(
name=function["name"],
arguments=_get_function_arguments(function),
)
async def call_openai_tool(
session: ClientSession,
openai_tool: ChatCompletionMessageToolCall,
) -> MCPCallToolResult:
"""
Call an OpenAI tool using MCP client.
Args:
session: The MCP session to use
openai_tool: The OpenAI tool to call. You can get this from the `choices[0].message.tool_calls[0]` of the response from the OpenAI API.
Returns:
The result of the MCP tool call.
"""
mcp_tool_call_request_params = (
transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool=openai_tool,
)
)
return await call_mcp_tool(
session=session,
call_tool_request_params=mcp_tool_call_request_params,
)

View File

@@ -0,0 +1,984 @@
"""
Main File for Files API implementation
https://platform.openai.com/docs/api-reference/files
"""
import asyncio
import contextvars
import time
import uuid as uuid_module
from functools import partial
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
import httpx
# Type aliases for provider parameters
FileCreateProvider = Literal[
"openai",
"azure",
"gemini",
"vertex_ai",
"bedrock",
"hosted_vllm",
"manus",
"anthropic",
]
FileRetrieveProvider = Literal[
"openai", "azure", "gemini", "vertex_ai", "hosted_vllm", "manus", "anthropic"
]
FileDeleteProvider = Literal["openai", "azure", "gemini", "manus", "anthropic"]
FileListProvider = Literal["openai", "azure", "manus", "anthropic"]
FileContentProvider = Literal[
"openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
]
import litellm
from litellm import get_secret_str
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure.common_utils import get_azure_credentials
from litellm.llms.azure.files.handler import AzureOpenAIFilesAPI
from litellm.llms.bedrock.files.handler import BedrockFilesHandler
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.llms.openai.common_utils import get_openai_credentials
from litellm.llms.openai.openai import FileDeleted, FileObject, OpenAIFilesAPI
from litellm.llms.vertex_ai.files.handler import VertexAIFilesHandler
from litellm.types.llms.openai import (
CreateFileRequest,
FileContentRequest,
FileExpiresAfter,
FileTypes,
HttpxBinaryResponseContent,
OpenAIFileObject,
)
from litellm.types.router import *
from litellm.types.utils import (
OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS,
LlmProviders,
)
from litellm.utils import (
ProviderConfigManager,
client,
get_litellm_params,
supports_httpx_timeout,
)
base_llm_http_handler = BaseLLMHTTPHandler()
####### ENVIRONMENT VARIABLES ###################
openai_files_instance = OpenAIFilesAPI()
azure_files_instance = AzureOpenAIFilesAPI()
vertex_ai_files_instance = VertexAIFilesHandler()
bedrock_files_instance = BedrockFilesHandler()
#################################################
@client
async def acreate_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune", "messages"],
expires_after: Optional[FileExpiresAfter] = None,
custom_llm_provider: FileCreateProvider = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> OpenAIFileObject:
"""
Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
LiteLLM Equivalent of POST https://api.openai.com/v1/files
"""
try:
loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True
call_args = {
"file": file,
"purpose": purpose,
"expires_after": expires_after,
"custom_llm_provider": custom_llm_provider,
"extra_headers": extra_headers,
"extra_body": extra_body,
**kwargs,
}
# Use a partial function to pass your keyword arguments
func = partial(create_file, **call_args)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
@client
def create_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune", "messages"],
expires_after: Optional[FileExpiresAfter] = None,
custom_llm_provider: Optional[FileCreateProvider] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
"""
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
LiteLLM Equivalent of POST https://api.openai.com/v1/files
Specify the target provider via `custom_llm_provider`.
"""
try:
_is_async = kwargs.pop("acreate_file", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = dict(**kwargs)
logging_obj = cast(
Optional[LiteLLMLoggingObj], kwargs.get("litellm_logging_obj")
)
if logging_obj is None:
raise ValueError("logging_obj is required")
client = kwargs.get("client")
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
if expires_after is not None:
_create_file_request = CreateFileRequest(
file=file,
purpose=purpose,
expires_after=expires_after,
extra_headers=extra_headers,
extra_body=extra_body,
)
else:
_create_file_request = CreateFileRequest(
file=file,
purpose=purpose,
extra_headers=extra_headers,
extra_body=extra_body,
)
provider_config = ProviderConfigManager.get_provider_files_config(
model="",
provider=LlmProviders(custom_llm_provider),
)
if provider_config is not None:
response = base_llm_http_handler.create_file(
provider_config=provider_config,
litellm_params=litellm_params_dict,
create_file_data=_create_file_request,
headers=extra_headers or {},
api_base=optional_params.api_base,
api_key=optional_params.api_key,
logging_obj=logging_obj,
_is_async=_is_async,
client=(
client
if client is not None
and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
else None
),
timeout=timeout,
)
elif custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
openai_creds = get_openai_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
organization=optional_params.organization,
)
response = openai_files_instance.create_file(
_is_async=_is_async,
api_base=openai_creds.api_base,
api_key=openai_creds.api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=openai_creds.organization,
create_file_data=_create_file_request,
)
elif custom_llm_provider == "azure":
azure_creds = get_azure_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
api_version=optional_params.api_version,
)
response = azure_files_instance.create_file(
_is_async=_is_async,
api_base=azure_creds.api_base,
api_key=azure_creds.api_key,
api_version=azure_creds.api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_file'. Only ['openai', 'azure', 'vertex_ai', 'manus', 'anthropic'] are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="POST", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
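The branches above follow one dispatch order: a provider files config takes precedence, then the OpenAI-compatible and Azure handlers, and anything else raises BadRequestError. A minimal stdlib-only sketch of that routing (the names here are illustrative stand-ins, not LiteLLM APIs):

```python
def route_files_call(provider: str, has_provider_config: bool, handlers: dict) -> str:
    # Mirrors the dispatch order used above: provider config first,
    # then known per-provider handlers, else an unsupported-provider error.
    if has_provider_config:
        return f"base_llm_http_handler:{provider}"
    if provider in handlers:
        return f"{handlers[provider]}:{provider}"
    raise ValueError(f"LiteLLM doesn't support {provider} for 'create_file'")

handlers = {"openai": "openai_files_instance", "azure": "azure_files_instance"}
print(route_files_call("azure", False, handlers))  # azure_files_instance:azure
```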
@client
async def afile_retrieve(
file_id: str,
custom_llm_provider: FileRetrieveProvider = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> OpenAIFileObject:
"""
Async: Retrieve metadata for a file.
LiteLLM equivalent of GET https://api.openai.com/v1/files/{file_id}
"""
try:
loop = asyncio.get_event_loop()
kwargs["is_async"] = True
# Use a partial function to pass your keyword arguments
func = partial(
file_retrieve,
file_id,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return OpenAIFileObject(**response.model_dump())
except Exception as e:
raise e
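The async wrappers here (afile_retrieve, afile_delete, afile_list, ...) all use the same pattern: bind arguments with functools.partial, copy the current contextvars context, and run the sync function in the default executor so context variables survive the thread hop. A self-contained sketch of that pattern, with a toy sync_work function standing in for the real sync call:

```python
import asyncio
import contextvars
from functools import partial

tenant = contextvars.ContextVar("tenant", default="unset")

def sync_work(x: int) -> str:
    # Runs in a worker thread, but still sees the caller's context
    # because it is invoked through ctx.run below.
    return f"{tenant.get()}:{x}"

async def async_wrapper(x: int) -> str:
    loop = asyncio.get_running_loop()
    func = partial(sync_work, x)
    ctx = contextvars.copy_context()  # snapshot the caller's contextvars
    func_with_context = partial(ctx.run, func)
    return await loop.run_in_executor(None, func_with_context)

tenant.set("acme")
print(asyncio.run(async_wrapper(7)))  # acme:7
```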
@client
def file_retrieve(
file_id: str,
custom_llm_provider: FileRetrieveProvider = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FileObject:
"""
Returns metadata for the specified file.
LiteLLM equivalent of GET https://api.openai.com/v1/files/{file_id}
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("is_async", False) is True
if custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
openai_creds = get_openai_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
organization=optional_params.organization,
)
response = openai_files_instance.retrieve_file(
file_id=file_id,
_is_async=_is_async,
api_base=openai_creds.api_base,
api_key=openai_creds.api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=openai_creds.organization,
)
elif custom_llm_provider == "azure":
azure_creds = get_azure_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
api_version=optional_params.api_version,
)
response = azure_files_instance.retrieve_file(
_is_async=_is_async,
api_base=azure_creds.api_base,
api_key=azure_creds.api_key,
api_version=azure_creds.api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
file_id=file_id,
)
else:
# Try using provider config pattern (for Manus, Bedrock, etc.)
provider_config = ProviderConfigManager.get_provider_files_config(
model="",
provider=LlmProviders(custom_llm_provider),
)
if provider_config is not None:
litellm_params_dict = get_litellm_params(**kwargs)
litellm_params_dict["api_key"] = optional_params.api_key
litellm_params_dict["api_base"] = optional_params.api_base
logging_obj = kwargs.get("litellm_logging_obj")
if logging_obj is None:
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObj,
)
logging_obj = LiteLLMLoggingObj(
model="",
messages=[],
stream=False,
call_type="afile_retrieve" if _is_async else "file_retrieve",
start_time=time.time(),
litellm_call_id=kwargs.get(
"litellm_call_id", str(uuid_module.uuid4())
),
function_id=str(kwargs.get("id") or ""),
)
client = kwargs.get("client")
response = base_llm_http_handler.retrieve_file(
file_id=file_id,
provider_config=provider_config,
litellm_params=litellm_params_dict,
headers=extra_headers or {},
logging_obj=logging_obj,
_is_async=_is_async,
client=(
client
if client is not None
and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
else None
),
timeout=timeout,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'file_retrieve'. Only 'openai', 'azure', 'manus', and 'anthropic' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="GET", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return cast(FileObject, response)
except Exception as e:
raise e
# Delete file
@client
async def afile_delete(
file_id: str,
custom_llm_provider: FileDeleteProvider = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Coroutine[Any, Any, FileObject]:
"""
Async: Delete a file.
LiteLLM equivalent of DELETE https://api.openai.com/v1/files/{file_id}
"""
try:
loop = asyncio.get_event_loop()
model = kwargs.pop("model", None)
kwargs["is_async"] = True
# Use a partial function to pass your keyword arguments
func = partial(
file_delete,
file_id,
model,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return cast(FileDeleted, response) # type: ignore
except Exception as e:
raise e
@client
def file_delete(
file_id: str,
model: Optional[str] = None,
custom_llm_provider: Union[FileDeleteProvider, str] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FileDeleted:
"""
Delete a file.
LiteLLM equivalent of DELETE https://api.openai.com/v1/files/{file_id}
"""
try:
try:
if model is not None:
_, custom_llm_provider, _, _ = get_llm_provider(
model, custom_llm_provider
)
except Exception:
pass
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
client = kwargs.get("client")
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("is_async", False) is True
if custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
openai_creds = get_openai_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
organization=optional_params.organization,
)
response = openai_files_instance.delete_file(
file_id=file_id,
_is_async=_is_async,
api_base=openai_creds.api_base,
api_key=openai_creds.api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=openai_creds.organization,
)
elif custom_llm_provider == "azure":
azure_creds = get_azure_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
api_version=optional_params.api_version,
)
response = azure_files_instance.delete_file(
_is_async=_is_async,
api_base=azure_creds.api_base,
api_key=azure_creds.api_key,
api_version=azure_creds.api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
file_id=file_id,
client=client,
litellm_params=litellm_params_dict,
)
else:
# Try using provider config pattern (for Manus, Bedrock, etc.)
provider_config = ProviderConfigManager.get_provider_files_config(
model="",
provider=LlmProviders(custom_llm_provider),
)
if provider_config is not None:
litellm_params_dict["api_key"] = optional_params.api_key
litellm_params_dict["api_base"] = optional_params.api_base
logging_obj = kwargs.get("litellm_logging_obj")
if logging_obj is None:
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObj,
)
logging_obj = LiteLLMLoggingObj(
model="",
messages=[],
stream=False,
call_type="afile_delete" if _is_async else "file_delete",
start_time=time.time(),
litellm_call_id=kwargs.get(
"litellm_call_id", str(uuid_module.uuid4())
),
function_id=str(kwargs.get("id") or ""),
)
response = base_llm_http_handler.delete_file(
file_id=file_id,
provider_config=provider_config,
litellm_params=litellm_params_dict,
headers=extra_headers or {},
logging_obj=logging_obj,
_is_async=_is_async,
client=(
client
if client is not None
and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
else None
),
timeout=timeout,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'file_delete'. Only 'openai', 'azure', 'gemini', 'manus', and 'anthropic' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="DELETE", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return cast(FileDeleted, response)
except Exception as e:
raise e
# List files
@client
async def afile_list(
custom_llm_provider: FileListProvider = "openai",
purpose: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
Async: List files
LiteLLM Equivalent of GET https://api.openai.com/v1/files
"""
try:
loop = asyncio.get_event_loop()
kwargs["is_async"] = True
# Use a partial function to pass your keyword arguments
func = partial(
file_list,
custom_llm_provider,
purpose,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
@client
def file_list(
custom_llm_provider: FileListProvider = "openai",
purpose: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
List files
LiteLLM Equivalent of GET https://api.openai.com/v1/files
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("is_async", False) is True
# Check if provider has a custom files config (e.g., Manus, Bedrock, Vertex AI)
provider_config = ProviderConfigManager.get_provider_files_config(
model="",
provider=LlmProviders(custom_llm_provider),
)
if provider_config is not None:
litellm_params_dict = get_litellm_params(**kwargs)
litellm_params_dict["api_key"] = optional_params.api_key
litellm_params_dict["api_base"] = optional_params.api_base
logging_obj = kwargs.get("litellm_logging_obj")
if logging_obj is None:
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObj,
)
logging_obj = LiteLLMLoggingObj(
model="",
messages=[],
stream=False,
call_type="afile_list" if _is_async else "file_list",
start_time=time.time(),
litellm_call_id=kwargs.get(
"litellm_call_id", str(uuid_module.uuid4())
),
function_id=str(kwargs.get("id", "")),
)
client = kwargs.get("client")
response = base_llm_http_handler.list_files(
purpose=purpose,
provider_config=provider_config,
litellm_params=litellm_params_dict,
headers=extra_headers or {},
logging_obj=logging_obj,
_is_async=_is_async,
client=(
client
if client is not None
and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
else None
),
timeout=timeout,
)
return response
elif custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
openai_creds = get_openai_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
organization=optional_params.organization,
)
response = openai_files_instance.list_files(
purpose=purpose,
_is_async=_is_async,
api_base=openai_creds.api_base,
api_key=openai_creds.api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=openai_creds.organization,
)
elif custom_llm_provider == "azure":
azure_creds = get_azure_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
api_version=optional_params.api_version,
)
response = azure_files_instance.list_files(
_is_async=_is_async,
api_base=azure_creds.api_base,
api_key=azure_creds.api_key,
api_version=azure_creds.api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
purpose=purpose,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'file_list'. Only 'openai', 'azure', 'manus', and 'anthropic' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="GET", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
@client
async def afile_content(
file_id: str,
custom_llm_provider: FileContentProvider = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> HttpxBinaryResponseContent:
"""
Async: Get file contents.
LiteLLM equivalent of GET https://api.openai.com/v1/files/{file_id}/content
"""
try:
loop = asyncio.get_event_loop()
kwargs["afile_content"] = True
model = kwargs.pop("model", None)
# Use a partial function to pass your keyword arguments
func = partial(
file_content,
file_id,
model,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
@client
def file_content(
file_id: str,
model: Optional[str] = None,
custom_llm_provider: Optional[Union[FileContentProvider, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]]:
"""
Returns the contents of the specified file.
LiteLLM equivalent of GET https://api.openai.com/v1/files/{file_id}/content
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
client = kwargs.get("client")
# set timeout for 10 minutes by default
try:
if model is not None:
_, custom_llm_provider, _, _ = get_llm_provider(
model, custom_llm_provider
)
except Exception:
pass
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_file_content_request = FileContentRequest(
file_id=file_id,
extra_headers=extra_headers,
extra_body=extra_body,
)
_is_async = kwargs.pop("afile_content", False) is True
# Check if provider has a custom files config (e.g., Anthropic, Manus)
provider_config = ProviderConfigManager.get_provider_files_config(
model="",
provider=LlmProviders(custom_llm_provider),
)
if provider_config is not None:
litellm_params_dict["api_key"] = optional_params.api_key
litellm_params_dict["api_base"] = optional_params.api_base
logging_obj = kwargs.get("litellm_logging_obj")
if logging_obj is None:
logging_obj = LiteLLMLoggingObj(
model="",
messages=[],
stream=False,
call_type="afile_content" if _is_async else "file_content",
start_time=time.time(),
litellm_call_id=kwargs.get(
"litellm_call_id", str(uuid_module.uuid4())
),
function_id=str(kwargs.get("id") or ""),
)
response = base_llm_http_handler.retrieve_file_content(
file_content_request=_file_content_request,
provider_config=provider_config,
litellm_params=litellm_params_dict,
headers=extra_headers or {},
logging_obj=logging_obj,
_is_async=_is_async,
client=(
client
if client is not None
and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
else None
),
timeout=timeout,
)
return response
if custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
openai_creds = get_openai_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
organization=optional_params.organization,
)
response = openai_files_instance.file_content(
_is_async=_is_async,
file_content_request=_file_content_request,
api_base=openai_creds.api_base,
api_key=openai_creds.api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=openai_creds.organization,
)
elif custom_llm_provider == "azure":
azure_creds = get_azure_credentials(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
api_version=optional_params.api_version,
)
response = azure_files_instance.file_content(
_is_async=_is_async,
api_base=azure_creds.api_base,
api_key=azure_creds.api_key,
api_version=azure_creds.api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
file_content_request=_file_content_request,
client=client,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_ai_files_instance.file_content(
_is_async=_is_async,
file_content_request=_file_content_request,
api_base=api_base,
vertex_credentials=vertex_credentials,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
timeout=timeout,
max_retries=optional_params.max_retries,
)
elif custom_llm_provider == "bedrock":
response = bedrock_files_instance.file_content(
_is_async=_is_async,
file_content_request=_file_content_request,
api_base=optional_params.api_base,
optional_params=litellm_params_dict,
timeout=timeout,
max_retries=optional_params.max_retries,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'file_content'. Supported providers are 'openai', 'azure', 'vertex_ai', 'bedrock', 'manus', 'anthropic'.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="GET", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e

# ===== new file: files API utils (FilesAPIUtils) =====
from typing import Optional
from litellm.types.llms.openai import CreateFileRequest
from litellm.types.utils import ExtractedFileData
class FilesAPIUtils:
"""
Utils for files API interface on litellm
"""
@staticmethod
def is_batch_jsonl_file(
create_file_data: CreateFileRequest, extracted_file_data: ExtractedFileData
) -> bool:
"""
Check if the file is a batch jsonl file
"""
return (
create_file_data.get("purpose") == "batch"
and FilesAPIUtils.valid_content_type(
extracted_file_data.get("content_type")
)
and extracted_file_data.get("content") is not None
)
@staticmethod
def valid_content_type(content_type: Optional[str]) -> bool:
"""
Check if the content type is valid
"""
return content_type in set(["application/jsonl", "application/octet-stream"])
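A quick stdlib-only check of the two helpers above, re-declared inline with plain dicts (instead of the CreateFileRequest / ExtractedFileData typed dicts) so the example runs standalone:

```python
from typing import Optional

def valid_content_type(content_type: Optional[str]) -> bool:
    # Same membership test as FilesAPIUtils.valid_content_type.
    return content_type in {"application/jsonl", "application/octet-stream"}

def is_batch_jsonl_file(create_file_data: dict, extracted_file_data: dict) -> bool:
    # Plain-dict mirror of FilesAPIUtils.is_batch_jsonl_file.
    return (
        create_file_data.get("purpose") == "batch"
        and valid_content_type(extracted_file_data.get("content_type"))
        and extracted_file_data.get("content") is not None
    )

print(is_batch_jsonl_file(
    {"purpose": "batch"},
    {"content_type": "application/jsonl", "content": b"{}"},
))  # True
```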

# ===== new file: fine-tuning API main =====
"""
Main File for Fine Tuning API implementation
https://platform.openai.com/docs/api-reference/fine-tuning
- fine_tuning.jobs.create()
- fine_tuning.jobs.list()
- client.fine_tuning.jobs.list_events()
"""
import asyncio
import contextvars
import os
from functools import partial
from typing import Any, Coroutine, Dict, Literal, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.azure.fine_tuning.handler import AzureOpenAIFineTuningAPI
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
from litellm.llms.vertex_ai.fine_tuning.handler import VertexFineTuningAPI
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import FineTuningJobCreate, Hyperparameters
from litellm.types.router import *
from litellm.types.utils import LiteLLMFineTuningJob
from litellm.utils import client, supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
vertex_fine_tuning_apis_instance = VertexFineTuningAPI()
#################################################
def _prepare_azure_extra_body(
extra_body: Optional[Dict[str, Any]],
kwargs: Dict[str, Any],
azure_specific_hyperparams: Dict[str, Any],
) -> Dict[str, Any]:
"""
Prepare extra_body for Azure fine-tuning API by combining Azure-specific parameters.
Azure fine-tuning API accepts additional parameters beyond the standard OpenAI spec:
- trainingType: Type of training (e.g., 1 for supervised fine-tuning)
- prompt_loss_weight: Weight for prompt loss in training
These parameters must be passed in the extra_body field when calling the Azure OpenAI SDK.
Args:
extra_body: Optional existing extra_body dict
kwargs: Request kwargs that may contain Azure-specific parameters
azure_specific_hyperparams: Dict of Azure-specific hyperparameters already extracted
Returns:
Dict containing all Azure-specific parameters to be passed in extra_body
"""
if extra_body is None:
extra_body = {}
# Azure-specific root-level parameters
azure_specific_params = ["trainingType"]
for param in azure_specific_params:
if param in kwargs:
extra_body[param] = kwargs[param]
# Add Azure-specific hyperparameters
if azure_specific_hyperparams:
extra_body.update(azure_specific_hyperparams)
return extra_body
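The merge behaviour of _prepare_azure_extra_body can be seen with a small standalone mirror (plain function, no litellm imports):

```python
from typing import Any, Dict, Optional

def prepare_azure_extra_body(
    extra_body: Optional[Dict[str, Any]],
    kwargs: Dict[str, Any],
    azure_specific_hyperparams: Dict[str, Any],
) -> Dict[str, Any]:
    # Same merge order as _prepare_azure_extra_body above:
    # root-level Azure params first, then Azure-only hyperparameters.
    if extra_body is None:
        extra_body = {}
    for param in ["trainingType"]:
        if param in kwargs:
            extra_body[param] = kwargs[param]
    if azure_specific_hyperparams:
        extra_body.update(azure_specific_hyperparams)
    return extra_body

out = prepare_azure_extra_body(None, {"trainingType": 1}, {"prompt_loss_weight": 0.1})
print(out)  # {'trainingType': 1, 'prompt_loss_weight': 0.1}
```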
@client
async def acreate_fine_tuning_job(
model: str,
training_file: str,
hyperparameters: Optional[dict] = {},
suffix: Optional[str] = None,
validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None,
seed: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> LiteLLMFineTuningJob:
"""
Async: Creates a fine-tuning job from an uploaded training file.
"""
verbose_logger.debug(
"inside acreate_fine_tuning_job model=%s and kwargs=%s", model, kwargs
)
try:
loop = asyncio.get_event_loop()
kwargs["acreate_fine_tuning_job"] = True
# Use a partial function to pass your keyword arguments
func = partial(
create_fine_tuning_job,
model,
training_file,
hyperparameters,
suffix,
validation_file,
integrations,
seed,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def _build_fine_tuning_job_data(
model, training_file, hyperparameters, suffix, validation_file, integrations, seed
):
return FineTuningJobCreate(
model=model,
training_file=training_file,
hyperparameters=hyperparameters,
suffix=suffix,
validation_file=validation_file,
integrations=integrations,
seed=seed,
)
def _resolve_fine_tuning_timeout(
timeout: Any,
custom_llm_provider: str,
) -> Union[float, httpx.Timeout]:
"""Normalise a raw timeout value to a float (seconds) or httpx.Timeout for fine-tuning calls."""
timeout = timeout or 600.0
if isinstance(timeout, httpx.Timeout):
if not supports_httpx_timeout(custom_llm_provider):
return float(timeout.read or 600)
return timeout
return float(timeout)
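A standalone sketch of the same normalisation, using a stand-in Timeout class so the example runs without httpx (the stand-in only carries the .read attribute the helper touches):

```python
class StubTimeout:
    """Hypothetical stand-in for httpx.Timeout; only .read is used here."""
    def __init__(self, read=None):
        self.read = read

def resolve_timeout(timeout, provider_supports_httpx_timeout: bool):
    # Same rules as _resolve_fine_tuning_timeout: default to 600s, and
    # collapse a Timeout object to its read timeout when the provider
    # can't accept the full object.
    timeout = timeout or 600.0
    if isinstance(timeout, StubTimeout):
        if not provider_supports_httpx_timeout:
            return float(timeout.read or 600)
        return timeout
    return float(timeout)

print(resolve_timeout(None, True))                   # 600.0
print(resolve_timeout(StubTimeout(read=30), False))  # 30.0
```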
@client
def create_fine_tuning_job(
model: str,
training_file: str,
hyperparameters: Optional[dict] = {},
suffix: Optional[str] = None,
validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None,
seed: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
"""
Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete
"""
try:
_is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
# handle hyperparameters
hyperparameters = hyperparameters or {} # original hyperparameters
# For Azure, extract Azure-specific hyperparameters before creating OpenAI-spec hyperparameters
azure_specific_hyperparams = {}
if custom_llm_provider == "azure":
azure_hyperparameter_keys = ["prompt_loss_weight"]
for key in azure_hyperparameter_keys:
if key in hyperparameters:
azure_specific_hyperparams[key] = hyperparameters.pop(key)
_oai_hyperparameters: Hyperparameters = Hyperparameters(
**hyperparameters
) # Typed Hyperparameters for OpenAI Spec
timeout = _resolve_fine_tuning_timeout(
optional_params.timeout or kwargs.get("request_timeout", 600),
custom_llm_provider,
)
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_BASE_URL")
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
create_fine_tuning_job_data_dict = _build_fine_tuning_job_data(
model,
training_file,
_oai_hyperparameters,
suffix,
validation_file,
integrations,
seed,
).model_dump(exclude_none=True)
response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=optional_params.api_version,
organization=organization,
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get(
"client", None
), # note, when we add this to `GenericLiteLLMParams` it impacts a lot of other tests + linting
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
# Prepare Azure-specific parameters for extra_body
extra_body = _prepare_azure_extra_body(
extra_body, kwargs, azure_specific_hyperparams
)
create_fine_tuning_job_data_dict = _build_fine_tuning_job_data(
model,
training_file,
_oai_hyperparameters,
suffix,
validation_file,
integrations,
seed,
).model_dump(exclude_none=True)
# Add extra_body if it has Azure-specific parameters
if extra_body:
create_fine_tuning_job_data_dict["extra_body"] = extra_body
response = azure_fine_tuning_apis_instance.create_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=api_version,
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
organization=optional_params.organization,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_fine_tuning_apis_instance.create_fine_tuning_job(
_is_async=_is_async,
create_fine_tuning_job_data=_build_fine_tuning_job_data(
model,
training_file,
_oai_hyperparameters,
suffix,
validation_file,
integrations,
seed,
),
vertex_credentials=vertex_credentials,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
timeout=timeout,
api_base=api_base,
kwargs=kwargs,
original_hyperparameters=hyperparameters,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_fine_tuning_job'. Only 'openai', 'azure', and 'vertex_ai' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="POST", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
verbose_logger.error("got exception in create_fine_tuning_job=%s", str(e))
raise e
@client
async def acancel_fine_tuning_job(
fine_tuning_job_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> LiteLLMFineTuningJob:
"""
Async: Immediately cancel a fine-tune job.
"""
try:
loop = asyncio.get_event_loop()
kwargs["acancel_fine_tuning_job"] = True
# Use a partial function to pass your keyword arguments
func = partial(
cancel_fine_tuning_job,
fine_tuning_job_id,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
@client
def cancel_fine_tuning_job(
fine_tuning_job_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
"""
Immediately cancel a fine-tune job.
The response includes details of the job, including its updated status.
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_BASE_URL")
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=optional_params.api_version,
organization=organization,
fine_tuning_job_id=fine_tuning_job_id,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get("client", None),
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=api_version,
fine_tuning_job_id=fine_tuning_job_id,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
organization=optional_params.organization,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'cancel_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="cancel_fine_tuning_job", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
async def alist_fine_tuning_jobs(
after: Optional[str] = None,
limit: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
Async: List your organization's fine-tuning jobs
"""
try:
loop = asyncio.get_event_loop()
kwargs["alist_fine_tuning_jobs"] = True
# Use a partial function to pass your keyword arguments
func = partial(
list_fine_tuning_jobs,
after,
limit,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def list_fine_tuning_jobs(
after: Optional[str] = None,
limit: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
List your organization's fine-tuning jobs
Params:
- after: Optional[str] = None, Identifier for the last job from the previous pagination request.
- limit: Optional[int] = None, Number of fine-tuning jobs to retrieve. Defaults to 20.
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_BASE_URL")
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
api_base=api_base,
api_key=api_key,
api_version=optional_params.api_version,
organization=organization,
after=after,
limit=limit,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get("client", None),
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN")  # type: ignore
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
api_base=api_base,
api_key=api_key,
api_version=api_version,
after=after,
limit=limit,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
organization=optional_params.organization,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'list_fine_tuning_jobs'. Only 'openai' and 'azure' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="list_fine_tuning_jobs", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
@client
async def aretrieve_fine_tuning_job(
fine_tuning_job_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> LiteLLMFineTuningJob:
"""
Async: Get info about a fine-tuning job.
"""
try:
loop = asyncio.get_event_loop()
kwargs["aretrieve_fine_tuning_job"] = True
# Use a partial function to pass your keyword arguments
func = partial(
retrieve_fine_tuning_job,
fine_tuning_job_id,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
@client
def retrieve_fine_tuning_job(
fine_tuning_job_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
"""
Get info about a fine-tuning job.
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("aretrieve_fine_tuning_job", False) is True
# OpenAI
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_BASE_URL")
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_fine_tuning_apis_instance.retrieve_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=optional_params.api_version,
organization=organization,
fine_tuning_job_id=fine_tuning_job_id,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get("client", None),
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.retrieve_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=api_version,
fine_tuning_job_id=fine_tuning_job_id,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
organization=optional_params.organization,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'retrieve_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="retrieve_fine_tuning_job", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e

# LiteLLM Google GenAI Interface
An interface for calling Google GenAI functions in Google's native request/response format.
## Overview
This module provides a native interface to Google's Generative AI API, allowing you to use Google's content generation capabilities with both streaming and non-streaming modes, in both synchronous and asynchronous contexts.
## Available Functions
### Non-Streaming Functions
- `generate_content()` - Synchronous content generation
- `agenerate_content()` - Asynchronous content generation
### Streaming Functions
- `generate_content_stream()` - Synchronous streaming content generation
- `agenerate_content_stream()` - Asynchronous streaming content generation
## Usage Examples
### Basic Non-Streaming Usage
```python
from litellm.google_genai import generate_content, agenerate_content
from google.genai.types import ContentDict, PartDict
# Synchronous usage
contents = ContentDict(
parts=[
PartDict(text="Hello, can you tell me a short joke?")
],
)
response = generate_content(
contents=contents,
model="gemini-pro", # or your preferred model
# Add other model-specific parameters as needed
)
print(response)
```
### Async Non-Streaming Usage
```python
import asyncio
from litellm.google_genai import agenerate_content
from google.genai.types import ContentDict, PartDict
async def main():
contents = ContentDict(
parts=[
PartDict(text="Hello, can you tell me a short joke?")
],
)
response = await agenerate_content(
contents=contents,
model="gemini-pro",
# Add other model-specific parameters as needed
)
print(response)
# Run the async function
asyncio.run(main())
```
### Streaming Usage
```python
from litellm.google_genai import generate_content_stream
from google.genai.types import ContentDict, PartDict
# Synchronous streaming
contents = ContentDict(
parts=[
PartDict(text="Tell me a story about space exploration")
],
)
for chunk in generate_content_stream(
contents=contents,
model="gemini-pro",
):
print(f"Chunk: {chunk}")
```
### Async Streaming Usage
```python
import asyncio
from litellm.google_genai import agenerate_content_stream
from google.genai.types import ContentDict, PartDict
async def main():
contents = ContentDict(
parts=[
PartDict(text="Tell me a story about space exploration")
],
)
async for chunk in agenerate_content_stream(
contents=contents,
model="gemini-pro",
):
print(f"Async chunk: {chunk}")
asyncio.run(main())
```
## Testing
This module includes comprehensive tests covering:
- Sync and async non-streaming requests
- Sync and async streaming requests
- Response validation
- Error handling scenarios
See `tests/unified_google_tests/base_google_test.py` for test implementation examples.
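As a rough illustration of the response-validation style, the expected response shape can be exercised without a live API call. The nested dict layout below is an assumption based on the general Google GenAI response format, not something guaranteed by this module:

```python
# Sketch of a response-validation helper for GenAI-style responses.
# The dict shape below is assumed, not taken from litellm itself.

def extract_text(response: dict) -> str:
    """Return the first text part from a GenAI-style response dict."""
    return response["candidates"][0]["content"]["parts"][0]["text"]

# A stubbed response mimicking the assumed shape:
fake_response = {
    "candidates": [
        {"content": {"parts": [{"text": "Hello from the stub!"}]}}
    ]
}

assert extract_text(fake_response) == "Hello from the stub!"
```

Real tests would assert the same invariants against responses returned by `generate_content` and its async/streaming variants.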

"""
This module allows calling Google GenAI models through their native interface.
It provides `generate_content` functionality for these models.
"""
from .main import (
agenerate_content,
agenerate_content_stream,
generate_content,
generate_content_stream,
)
__all__ = [
"generate_content",
"agenerate_content",
"generate_content_stream",
"agenerate_content_stream",
]
