chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
.env
secrets.toml

View File

@@ -0,0 +1,44 @@
# litellm-proxy
A local, fast, and lightweight **OpenAI-compatible server** to call 100+ LLM APIs.
## usage
```shell
$ pip install litellm
```
```shell
$ litellm --model ollama/codellama
#INFO: Ollama running on http://0.0.0.0:8000
```
## replace openai base
```python
import openai # openai v1.0.0+
client = openai.OpenAI(api_key="anything",base_url="http://0.0.0.0:8000") # set proxy to base_url
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
[**See how to call Huggingface, Bedrock, TogetherAI, Anthropic, etc.**](https://docs.litellm.ai/docs/simple_proxy)
---
### Folder Structure
**Routes**
- `proxy_server.py` - all openai-compatible routes - `/v1/chat/completion`, `/v1/embedding` - plus model info routes - `/v1/models`, `/v1/model/info`, `/v1/model_group_info`.
- `health_endpoints/` - `/health`, `/health/liveliness`, `/health/readiness`
- `management_endpoints/key_management_endpoints.py` - all `/key/*` routes
- `management_endpoints/team_endpoints.py` - all `/team/*` routes
- `management_endpoints/internal_user_endpoints.py` - all `/user/*` routes
- `management_endpoints/ui_sso.py` - all `/sso/*` routes

View File

@@ -0,0 +1 @@
from . import *

View File

@@ -0,0 +1,39 @@
from typing import Dict, List, Optional
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from litellm.proxy._types import UserAPIKeyAuth
class MCPAuthenticatedUser(AuthenticatedUser):
"""
Wrapper class to make LiteLLM's authentication and configuration compatible with MCP's AuthenticatedUser.
This class handles:
1. User API key authentication information
2. MCP authentication header (deprecated)
3. MCP server configuration (can include access groups)
4. Server-specific authentication headers
5. OAuth2 headers
6. Raw headers - allows forwarding specific headers to the MCP server, specified by the admin.
"""
def __init__(
self,
user_api_key_auth: UserAPIKeyAuth,
mcp_auth_header: Optional[str] = None,
mcp_servers: Optional[List[str]] = None,
mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None,
oauth2_headers: Optional[Dict[str, str]] = None,
mcp_protocol_version: Optional[str] = None,
raw_headers: Optional[Dict[str, str]] = None,
client_ip: Optional[str] = None,
):
self.user_api_key_auth = user_api_key_auth
self.mcp_auth_header = mcp_auth_header
self.mcp_servers = mcp_servers
self.mcp_server_auth_headers = mcp_server_auth_headers or {}
self.mcp_protocol_version = mcp_protocol_version
self.oauth2_headers = oauth2_headers
self.raw_headers = raw_headers
self.client_ip = client_ip

View File

@@ -0,0 +1,789 @@
"""
BYOK (Bring Your Own Key) OAuth 2.1 Authorization Server endpoints for MCP servers.
When an MCP client connects to a BYOK-enabled server and no stored credential exists,
LiteLLM runs a minimal OAuth 2.1 authorization code flow. The "authorization page" is
just a form that asks the user for their API key — not a full identity-provider OAuth.
Endpoints implemented here:
GET /.well-known/oauth-authorization-server — OAuth authorization server metadata
GET /.well-known/oauth-protected-resource — OAuth protected resource metadata
GET /v1/mcp/oauth/authorize — Shows HTML form to collect the API key
POST /v1/mcp/oauth/authorize — Stores temp auth code and redirects
POST /v1/mcp/oauth/token — Exchanges code for a bearer JWT token
"""
import base64
import hashlib
import html as _html_module
import time
import uuid
from typing import Dict, Optional, cast
from urllib.parse import urlencode, urlparse
import jwt
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from litellm._logging import verbose_proxy_logger
from litellm.proxy._experimental.mcp_server.db import store_user_credential
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
get_request_base_url,
)
# ---------------------------------------------------------------------------
# In-memory store for pending authorization codes.
# Each entry: {code: {api_key, server_id, code_challenge, redirect_uri, user_id, expires_at}}
# ---------------------------------------------------------------------------
_byok_auth_codes: Dict[str, dict] = {}
# Authorization codes expire after 5 minutes.
_AUTH_CODE_TTL_SECONDS = 300
# Hard cap to prevent memory exhaustion from incomplete OAuth flows.
_AUTH_CODES_MAX_SIZE = 1000
router = APIRouter(tags=["mcp"])
# ---------------------------------------------------------------------------
# PKCE helper
# ---------------------------------------------------------------------------
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
"""Return True iff SHA-256(code_verifier) == code_challenge (base64url, no padding)."""
digest = hashlib.sha256(code_verifier.encode()).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
return computed == code_challenge
# ---------------------------------------------------------------------------
# Cleanup of expired auth codes (called lazily on each request)
# ---------------------------------------------------------------------------
def _purge_expired_codes() -> None:
now = time.time()
expired = [k for k, v in _byok_auth_codes.items() if v["expires_at"] < now]
for k in expired:
del _byok_auth_codes[k]
def _build_authorize_html(
server_name: str,
server_initial: str,
client_id: str,
redirect_uri: str,
code_challenge: str,
code_challenge_method: str,
state: str,
server_id: str,
access_items: list,
help_url: str,
) -> str:
"""Build the 2-step BYOK OAuth authorization page HTML."""
# Escape all user-supplied / externally-derived values before interpolation
e = _html_module.escape
server_name = e(server_name)
server_initial = e(server_initial)
client_id = e(client_id)
redirect_uri = e(redirect_uri)
code_challenge = e(code_challenge)
code_challenge_method = e(code_challenge_method)
state = e(state)
server_id = e(server_id)
# Build access checklist rows
access_rows = "".join(
f'<div class="access-item"><span class="check">&#10003;</span>{e(item)}</div>'
for item in access_items
)
access_section = ""
if access_rows:
access_section = f"""
<div class="access-box">
<div class="access-header">
<span class="shield">&#9646;</span>
<span>Requested Access</span>
</div>
{access_rows}
</div>"""
# Help link for step 2
help_link_html = ""
if help_url:
help_link_html = f'<a class="help-link" href="{e(help_url)}" target="_blank">Where do I find my API key? &#8599;</a>'
return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Connect {server_name} &mdash; LiteLLM</title>
<style>
*, *::before, *::after {{ box-sizing: border-box; margin: 0; padding: 0; }}
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0f172a;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 24px;
}}
.modal {{
background: #ffffff;
border-radius: 20px;
padding: 36px 32px 32px;
width: 440px;
max-width: 100%;
position: relative;
box-shadow: 0 25px 60px rgba(0,0,0,0.35);
}}
/* Progress dots */
.dots {{
display: flex;
justify-content: center;
gap: 7px;
margin-bottom: 28px;
}}
.dot {{
width: 8px; height: 8px;
border-radius: 50%;
background: #e2e8f0;
}}
.dot.active {{ background: #38bdf8; }}
/* Close button */
.close-btn {{
position: absolute;
top: 16px; right: 16px;
background: none; border: none;
font-size: 16px; color: #94a3b8;
cursor: pointer; line-height: 1;
width: 28px; height: 28px;
border-radius: 6px;
display: flex; align-items: center; justify-content: center;
}}
.close-btn:hover {{ background: #f1f5f9; color: #475569; }}
/* Logo pair */
.logos {{
display: flex; align-items: center; justify-content: center;
gap: 12px; margin-bottom: 20px;
}}
.logo {{
width: 52px; height: 52px;
border-radius: 14px;
display: flex; align-items: center; justify-content: center;
font-size: 22px; font-weight: 800; color: white;
}}
.logo-img {{
width: 52px; height: 52px;
border-radius: 14px;
object-fit: cover;
border: 1.5px solid #e2e8f0;
}}
.logo-s {{ background: linear-gradient(135deg, #818cf8 0%, #4f46e5 100%); }}
.logo-arrow {{ color: #cbd5e1; font-size: 20px; font-weight: 300; }}
/* Headings */
.step-title {{
text-align: center;
font-size: 21px; font-weight: 700;
color: #0f172a; margin-bottom: 8px;
}}
.step-subtitle {{
text-align: center;
font-size: 14px; color: #64748b;
line-height: 1.55; margin-bottom: 22px;
}}
/* Info box */
.info-box {{
background: #f8fafc;
border-radius: 12px;
padding: 14px 16px;
display: flex; gap: 12px;
margin-bottom: 14px;
}}
.info-icon {{ font-size: 17px; flex-shrink: 0; margin-top: 1px; color: #38bdf8; }}
.info-box h4 {{ font-size: 13px; font-weight: 600; color: #1e293b; margin-bottom: 4px; }}
.info-box p {{ font-size: 13px; color: #64748b; line-height: 1.5; }}
/* Access checklist */
.access-box {{
background: #f8fafc;
border-radius: 12px;
padding: 14px 16px;
margin-bottom: 22px;
}}
.access-header {{
display: flex; align-items: center; gap: 8px;
margin-bottom: 10px;
}}
.shield {{ color: #22c55e; font-size: 15px; }}
.access-header > span:last-child {{
font-size: 11px; font-weight: 700;
letter-spacing: 0.07em;
text-transform: uppercase;
color: #475569;
}}
.access-item {{
display: flex; align-items: center; gap: 9px;
font-size: 13.5px; color: #374151;
padding: 3px 0;
}}
.check {{ color: #22c55e; font-weight: 700; font-size: 13px; }}
/* Primary CTA */
.btn-primary {{
width: 100%; padding: 15px;
background: #0f172a; color: white;
border: none; border-radius: 12px;
font-size: 15px; font-weight: 600;
cursor: pointer; margin-bottom: 10px;
}}
.btn-primary:hover {{ background: #1e293b; }}
.btn-cancel {{
width: 100%; padding: 8px;
background: none; border: none;
font-size: 13.5px; color: #94a3b8;
cursor: pointer;
}}
.btn-cancel:hover {{ color: #64748b; }}
/* Step 2 nav */
.step2-nav {{
display: flex; align-items: center;
justify-content: space-between;
margin-bottom: 24px;
}}
.back-btn {{
background: none; border: none;
font-size: 13.5px; color: #64748b;
cursor: pointer; display: flex; align-items: center; gap: 4px;
}}
.back-btn:hover {{ color: #374151; }}
/* Key icon */
.key-icon-wrap {{
width: 46px; height: 46px;
background: #e0f2fe;
border-radius: 12px;
display: flex; align-items: center; justify-content: center;
margin-bottom: 14px;
}}
.key-icon-wrap svg {{ width: 22px; height: 22px; color: #0284c7; }}
/* Form elements */
.field-label {{
font-size: 13.5px; font-weight: 600;
color: #1e293b; display: block;
margin-bottom: 7px;
}}
.key-input {{
width: 100%; padding: 11px 13px;
border: 1.5px solid #e2e8f0;
border-radius: 10px;
font-size: 14px; color: #0f172a;
outline: none; transition: border-color 0.15s, box-shadow 0.15s;
}}
.key-input:focus {{
border-color: #38bdf8;
box-shadow: 0 0 0 3px rgba(56,189,248,0.12);
}}
.help-link {{
display: inline-flex; align-items: center; gap: 4px;
color: #0ea5e9; font-size: 13px;
text-decoration: none; margin: 8px 0 16px;
}}
.help-link:hover {{ text-decoration: underline; }}
/* Save toggle card */
.save-card {{
border: 1.5px solid #e2e8f0;
border-radius: 12px;
padding: 13px 15px;
margin-bottom: 6px;
}}
.save-row {{
display: flex; align-items: center; gap: 10px;
}}
.save-icon {{ font-size: 16px; }}
.save-label {{
flex: 1;
font-size: 14px; font-weight: 500; color: #1e293b;
}}
/* Toggle switch */
.toggle {{ position: relative; width: 44px; height: 24px; flex-shrink: 0; }}
.toggle input {{ opacity: 0; width: 0; height: 0; }}
.slider {{
position: absolute; inset: 0;
background: #e2e8f0;
border-radius: 24px; cursor: pointer;
transition: background 0.18s;
}}
.slider::before {{
content: '';
position: absolute;
width: 18px; height: 18px;
left: 3px; bottom: 3px;
background: white;
border-radius: 50%;
transition: transform 0.18s;
box-shadow: 0 1px 3px rgba(0,0,0,0.18);
}}
input:checked + .slider {{ background: #38bdf8; }}
input:checked + .slider::before {{ transform: translateX(20px); }}
/* Duration pills */
.duration-section {{ margin-top: 14px; }}
.duration-label {{
font-size: 12px; font-weight: 600;
color: #64748b; margin-bottom: 8px;
text-transform: uppercase; letter-spacing: 0.05em;
}}
.pills {{ display: flex; flex-wrap: wrap; gap: 7px; }}
.pill {{
padding: 6px 13px;
border: 1.5px solid #e2e8f0;
border-radius: 20px;
font-size: 13px; color: #475569;
cursor: pointer; background: white;
transition: all 0.13s;
user-select: none;
}}
.pill:hover {{ border-color: #94a3b8; }}
.pill.sel {{
border-color: #38bdf8;
color: #0284c7;
background: #e0f2fe;
}}
/* Security note */
.sec-note {{
background: #f8fafc;
border-radius: 10px;
padding: 11px 14px;
display: flex; gap: 9px; align-items: flex-start;
margin: 16px 0;
}}
.sec-icon {{ font-size: 13px; color: #94a3b8; margin-top: 1px; flex-shrink: 0; }}
.sec-note p {{ font-size: 12.5px; color: #64748b; line-height: 1.5; }}
/* Connect button */
.btn-connect {{
width: 100%; padding: 15px;
border: none; border-radius: 12px;
font-size: 15px; font-weight: 600;
cursor: pointer;
background: #bae6fd; color: #0369a1;
transition: background 0.15s, color 0.15s;
}}
.btn-connect.ready {{
background: #0ea5e9; color: white;
}}
.btn-connect.ready:hover {{ background: #0284c7; }}
/* Step visibility */
.step {{ display: none; }}
.step.show {{ display: block; }}
</style>
</head>
<body>
<div class="modal">
<!-- ── STEP 1: Connect ─────────────────────────────────────── -->
<div id="s1" class="step show">
<div class="dots">
<div class="dot active"></div>
<div class="dot"></div>
</div>
<button class="close-btn" type="button" onclick="doCancel()" title="Close">&times;</button>
<div class="logos">
<img src="/ui/assets/logos/litellm_logo.jpg" class="logo-img" alt="LiteLLM">
<span class="logo-arrow">&#8594;</span>
<div class="logo logo-s">{server_initial}</div>
</div>
<h2 class="step-title">Connect {server_name} MCP</h2>
<p class="step-subtitle">LiteLLM needs access to {server_name} to complete your request.</p>
<div class="info-box">
<span class="info-icon">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="12" y1="8" x2="12" y2="12"/><line x1="12" y1="16" x2="12.01" y2="16"/></svg>
</span>
<div>
<h4>How it works</h4>
<p>LiteLLM acts as a secure bridge. Your requests are routed through our MCP client directly to {server_name}&rsquo;s API.</p>
</div>
</div>
{access_section}
<button class="btn-primary" type="button" onclick="goStep2()">
Continue to Authentication &rarr;
</button>
<button class="btn-cancel" type="button" onclick="doCancel()">Cancel</button>
</div>
<!-- ── STEP 2: Provide API Key ──────────────────────────────── -->
<div id="s2" class="step">
<div class="step2-nav">
<button class="back-btn" type="button" onclick="goStep1()">&#8592; Back</button>
<div class="dots">
<div class="dot active"></div>
<div class="dot active"></div>
</div>
<button class="close-btn" style="position:static;" type="button" onclick="doCancel()" title="Close">&times;</button>
</div>
<div class="key-icon-wrap">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#0284c7" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21 2l-2 2m-7.61 7.61a5.5 5.5 0 1 1-7.778 7.778 5.5 5.5 0 0 1 7.777-7.777zm0 0L15.5 7.5m0 0l3 3L22 7l-3-3m-3.5 3.5L19 4"/></svg>
</div>
<h2 class="step-title" style="text-align:left;">Provide API Key</h2>
<p class="step-subtitle" style="text-align:left;">Enter your {server_name} API key to authorize this connection.</p>
<form method="POST" id="authForm" onsubmit="prepareSubmit()">
<input type="hidden" name="client_id" value="{client_id}">
<input type="hidden" name="redirect_uri" value="{redirect_uri}">
<input type="hidden" name="code_challenge" value="{code_challenge}">
<input type="hidden" name="code_challenge_method" value="{code_challenge_method}">
<input type="hidden" name="state" value="{state}">
<input type="hidden" name="server_id" value="{server_id}">
<input type="hidden" name="duration" id="durInput" value="until_revoked">
<label class="field-label">{server_name} API Key</label>
<input
type="password"
name="api_key"
id="apiKey"
class="key-input"
placeholder="Enter your API key"
required
autofocus
oninput="syncBtn()"
>
{help_link_html}
<div class="save-card">
<div class="save-row">
<span class="save-label">Save key for future use</span>
<label class="toggle">
<input type="checkbox" id="saveToggle" onchange="toggleDur()">
<span class="slider"></span>
</label>
</div>
<div id="durSection" class="duration-section" style="display:none;">
<div class="duration-label">Duration</div>
<div class="pills">
<div class="pill" onclick="selDur('1h',this)">1 hour</div>
<div class="pill sel" onclick="selDur('24h',this)">24 hours</div>
<div class="pill" onclick="selDur('7d',this)">7 days</div>
<div class="pill" onclick="selDur('30d',this)">30 days</div>
<div class="pill" onclick="selDur('until_revoked',this)">Until I revoke</div>
</div>
</div>
</div>
<div class="sec-note">
<span class="sec-icon">
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="11" width="18" height="11" rx="2" ry="2"/><path d="M7 11V7a5 5 0 0 1 10 0v4"/></svg>
</span>
<p>Your key is stored securely and transmitted over HTTPS. It is never shared with third parties.</p>
</div>
<button type="submit" class="btn-connect" id="connectBtn">
Connect &amp; Authorize
</button>
</form>
</div>
</div>
<script>
function goStep2() {{
document.getElementById('s1').classList.remove('show');
document.getElementById('s2').classList.add('show');
}}
function goStep1() {{
document.getElementById('s2').classList.remove('show');
document.getElementById('s1').classList.add('show');
}}
function doCancel() {{
if (window.opener) window.close();
else window.history.back();
}}
function toggleDur() {{
const on = document.getElementById('saveToggle').checked;
document.getElementById('durSection').style.display = on ? 'block' : 'none';
}}
function selDur(val, el) {{
document.querySelectorAll('.pill').forEach(p => p.classList.remove('sel'));
el.classList.add('sel');
document.getElementById('durInput').value = val;
}}
function syncBtn() {{
const btn = document.getElementById('connectBtn');
if (document.getElementById('apiKey').value.length > 0) {{
btn.classList.add('ready');
}} else {{
btn.classList.remove('ready');
}}
}}
function prepareSubmit() {{
// nothing extra needed — duration is already in the hidden input
}}
</script>
</body>
</html>"""
# ---------------------------------------------------------------------------
# OAuth metadata discovery endpoints
# ---------------------------------------------------------------------------
@router.get("/.well-known/oauth-authorization-server", include_in_schema=False)
async def oauth_authorization_server_metadata(request: Request) -> JSONResponse:
"""RFC 8414 Authorization Server Metadata for the BYOK OAuth flow."""
base_url = get_request_base_url(request)
return JSONResponse(
{
"issuer": base_url,
"authorization_endpoint": f"{base_url}/v1/mcp/oauth/authorize",
"token_endpoint": f"{base_url}/v1/mcp/oauth/token",
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code"],
"code_challenge_methods_supported": ["S256"],
}
)
@router.get("/.well-known/oauth-protected-resource", include_in_schema=False)
async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
"""RFC 9728 Protected Resource Metadata pointing back at this server."""
base_url = get_request_base_url(request)
return JSONResponse(
{
"resource": base_url,
"authorization_servers": [base_url],
}
)
# ---------------------------------------------------------------------------
# Authorization endpoint — GET (show form) and POST (process form)
# ---------------------------------------------------------------------------
@router.get("/v1/mcp/oauth/authorize", include_in_schema=False)
async def byok_authorize_get(
request: Request,
client_id: Optional[str] = None,
redirect_uri: Optional[str] = None,
response_type: Optional[str] = None,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
state: Optional[str] = None,
server_id: Optional[str] = None,
) -> HTMLResponse:
"""
Show the BYOK API-key entry form.
The MCP client navigates the user here; the user types their API key and
clicks "Connect & Authorize", which POSTs back to this same path.
"""
if response_type != "code":
raise HTTPException(status_code=400, detail="response_type must be 'code'")
if not redirect_uri:
raise HTTPException(status_code=400, detail="redirect_uri is required")
if not code_challenge:
raise HTTPException(status_code=400, detail="code_challenge is required")
# Resolve server metadata (name, description items, help URL).
server_name = "MCP Server"
access_items: list = []
help_url = ""
if server_id:
try:
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
registry = global_mcp_server_manager.get_registry()
if server_id in registry:
srv = registry[server_id]
server_name = srv.server_name or srv.name
access_items = list(srv.byok_description or [])
help_url = srv.byok_api_key_help_url or ""
except Exception:
pass
server_initial = (server_name[0].upper()) if server_name else "S"
html = _build_authorize_html(
server_name=server_name,
server_initial=server_initial,
client_id=client_id or "",
redirect_uri=redirect_uri,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method or "S256",
state=state or "",
server_id=server_id or "",
access_items=access_items,
help_url=help_url,
)
return HTMLResponse(content=html)
@router.post("/v1/mcp/oauth/authorize", include_in_schema=False)
async def byok_authorize_post(
request: Request,
client_id: str = Form(default=""),
redirect_uri: str = Form(...),
code_challenge: str = Form(...),
code_challenge_method: str = Form(default="S256"),
state: str = Form(default=""),
server_id: str = Form(default=""),
api_key: str = Form(...),
) -> RedirectResponse:
"""
Process the BYOK API-key form submission.
Stores a short-lived authorization code and redirects the client back to
redirect_uri with ?code=...&state=... query parameters.
"""
_purge_expired_codes()
# Validate redirect_uri scheme to prevent open redirect
parsed_uri = urlparse(redirect_uri)
if parsed_uri.scheme not in ("http", "https"):
raise HTTPException(status_code=400, detail="Invalid redirect_uri scheme")
# Reject new codes if the store is at capacity (prevents memory exhaustion
# from a burst of abandoned OAuth flows).
if len(_byok_auth_codes) >= _AUTH_CODES_MAX_SIZE:
raise HTTPException(
status_code=503, detail="Too many pending authorization flows"
)
if code_challenge_method != "S256":
raise HTTPException(
status_code=400, detail="Only S256 code_challenge_method is supported"
)
auth_code = str(uuid.uuid4())
_byok_auth_codes[auth_code] = {
"api_key": api_key,
"server_id": server_id,
"code_challenge": code_challenge,
"redirect_uri": redirect_uri,
"user_id": client_id, # external client passes LiteLLM user-id as client_id
"expires_at": time.time() + _AUTH_CODE_TTL_SECONDS,
}
params = urlencode({"code": auth_code, "state": state})
separator = "&" if "?" in redirect_uri else "?"
location = f"{redirect_uri}{separator}{params}"
return RedirectResponse(url=location, status_code=302)
# ---------------------------------------------------------------------------
# Token endpoint
# ---------------------------------------------------------------------------
@router.post("/v1/mcp/oauth/token", include_in_schema=False)
async def byok_token(
request: Request,
grant_type: str = Form(...),
code: str = Form(...),
redirect_uri: str = Form(default=""),
code_verifier: str = Form(...),
client_id: str = Form(default=""),
) -> JSONResponse:
"""
Exchange an authorization code for a short-lived BYOK session JWT.
1. Validates the authorization code and PKCE challenge.
2. Stores the API key via store_user_credential().
3. Issues a signed JWT with type="byok_session".
"""
from litellm.proxy.proxy_server import master_key, prisma_client
_purge_expired_codes()
if grant_type != "authorization_code":
raise HTTPException(status_code=400, detail="unsupported_grant_type")
record = _byok_auth_codes.get(code)
if record is None:
raise HTTPException(status_code=400, detail="invalid_grant")
if time.time() > record["expires_at"]:
del _byok_auth_codes[code]
raise HTTPException(status_code=400, detail="invalid_grant")
# PKCE verification
if not _verify_pkce(code_verifier, record["code_challenge"]):
raise HTTPException(status_code=400, detail="invalid_grant")
# Consume the code (one-time use)
del _byok_auth_codes[code]
server_id: str = record["server_id"]
api_key_value: str = record["api_key"]
# Prefer the user_id that was stored when the code was issued; fall back to
# whatever client_id the token request supplies (they should match).
user_id: str = record.get("user_id") or client_id
if not user_id:
raise HTTPException(
status_code=400,
detail="Cannot determine user_id; pass LiteLLM user id as client_id",
)
# Persist the BYOK credential
if prisma_client is not None:
try:
await store_user_credential(
prisma_client=prisma_client,
user_id=user_id,
server_id=server_id,
credential=api_key_value,
)
# Invalidate any cached negative result so the user isn't blocked
# for up to the TTL period after completing the OAuth flow.
from litellm.proxy._experimental.mcp_server.server import (
_invalidate_byok_cred_cache,
)
_invalidate_byok_cred_cache(user_id, server_id)
except Exception as exc:
verbose_proxy_logger.error(
"byok_token: failed to store user credential for user=%s server=%s: %s",
user_id,
server_id,
exc,
)
raise HTTPException(status_code=500, detail="Failed to store credential")
else:
verbose_proxy_logger.warning(
"byok_token: prisma_client is None — credential not persisted"
)
if master_key is None:
raise HTTPException(
status_code=500, detail="Master key not configured; cannot issue token"
)
now = int(time.time())
payload = {
"user_id": user_id,
"server_id": server_id,
# "type" distinguishes this from regular proxy auth tokens.
# The proxy's SSO JWT path uses asymmetric keys (RS256/ES256), so an
# HS256 token signed with master_key cannot be accepted there.
"type": "byok_session",
"iat": now,
"exp": now + 3600,
}
access_token = jwt.encode(payload, cast(str, master_key), algorithm="HS256")
return JSONResponse(
{
"access_token": access_token,
"token_type": "bearer",
"expires_in": 3600,
}
)

View File

@@ -0,0 +1,77 @@
"""
Cost calculator for MCP tools.
"""
from typing import TYPE_CHECKING, Any, Optional, cast
from litellm.types.mcp import MCPServerCostInfo
from litellm.types.utils import StandardLoggingMCPToolCall
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import (
Logging as LitellmLoggingObject,
)
else:
LitellmLoggingObject = Any
class MCPCostCalculator:
@staticmethod
def calculate_mcp_tool_call_cost(
litellm_logging_obj: Optional[LitellmLoggingObject],
) -> float:
"""
Calculate the cost of an MCP tool call.
Default is 0.0, unless user specifies a custom cost per request for MCP tools.
"""
if litellm_logging_obj is None:
return 0.0
#########################################################
# Get the response cost from logging object model_call_details
# This is set when a user modifies the response in a post_mcp_tool_call_hook
#########################################################
response_cost = litellm_logging_obj.model_call_details.get(
"response_cost", None
)
if response_cost is not None:
return response_cost
#########################################################
# Unpack the mcp_tool_call_metadata
#########################################################
mcp_tool_call_metadata: StandardLoggingMCPToolCall = (
cast(
StandardLoggingMCPToolCall,
litellm_logging_obj.model_call_details.get(
"mcp_tool_call_metadata", {}
),
)
or {}
)
mcp_server_cost_info: MCPServerCostInfo = (
mcp_tool_call_metadata.get("mcp_server_cost_info") or MCPServerCostInfo()
)
#########################################################
# User defined cost per query
#########################################################
default_cost_per_query = mcp_server_cost_info.get(
"default_cost_per_query", None
)
tool_name_to_cost_per_query: dict = (
mcp_server_cost_info.get("tool_name_to_cost_per_query", {}) or {}
)
tool_name = mcp_tool_call_metadata.get("name", "")
#########################################################
# 1. If tool_name is in tool_name_to_cost_per_query, use the cost per query
# 2. If tool_name is not in tool_name_to_cost_per_query, use the default cost per query
# 3. Default to 0.0 if no cost per query is found
#########################################################
cost_per_query: float = 0.0
if tool_name in tool_name_to_cost_per_query:
cost_per_query = tool_name_to_cost_per_query[tool_name]
elif default_cost_per_query is not None:
cost_per_query = default_cost_per_query
return cost_per_query

View File

@@ -0,0 +1,767 @@
import base64
import json
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy._types import (
LiteLLM_MCPServerTable,
LiteLLM_ObjectPermissionTable,
LiteLLM_TeamTable,
MCPApprovalStatus,
MCPSubmissionsSummary,
NewMCPServerRequest,
SpecialMCPServerName,
UpdateMCPServerRequest,
UserAPIKeyAuth,
)
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
_get_salt_key,
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.utils import PrismaClient
from litellm.types.mcp import MCPCredentials
def _prepare_mcp_server_data(
data: Union[NewMCPServerRequest, UpdateMCPServerRequest],
) -> Dict[str, Any]:
"""
Helper function to prepare MCP server data for database operations.
Handles JSON field serialization for mcp_info and env fields.
Args:
data: NewMCPServerRequest or UpdateMCPServerRequest object
Returns:
Dict with properly serialized JSON fields
"""
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
# Convert model to dict
data_dict = data.model_dump(exclude_none=True)
# Ensure alias is always present in the dict (even if None)
if "alias" not in data_dict:
data_dict["alias"] = getattr(data, "alias", None)
# Handle credentials serialization
credentials = data_dict.get("credentials")
if credentials is not None:
data_dict["credentials"] = encrypt_credentials(
credentials=credentials, encryption_key=_get_salt_key()
)
data_dict["credentials"] = safe_dumps(data_dict["credentials"])
# Handle static_headers serialization
if data.static_headers is not None:
data_dict["static_headers"] = safe_dumps(data.static_headers)
# Handle mcp_info serialization
if data.mcp_info is not None:
data_dict["mcp_info"] = safe_dumps(data.mcp_info)
# Handle env serialization
if data.env is not None:
data_dict["env"] = safe_dumps(data.env)
# Handle tool name override serialization
if data.tool_name_to_display_name is not None:
data_dict["tool_name_to_display_name"] = safe_dumps(
data.tool_name_to_display_name
)
if data.tool_name_to_description is not None:
data_dict["tool_name_to_description"] = safe_dumps(
data.tool_name_to_description
)
# mcp_access_groups is already List[str], no serialization needed
# Force include is_byok even when False (exclude_none=True would not drop it,
# but be explicit to ensure a False value is always written to the DB).
data_dict["is_byok"] = getattr(data, "is_byok", False)
return data_dict
def encrypt_credentials(
    credentials: MCPCredentials, encryption_key: Optional[str]
) -> MCPCredentials:
    """Encrypt all secret fields in an MCPCredentials dict in place.

    Mirrors the field list used by ``decrypt_credentials`` so the two stay in
    sync; the original implementation duplicated the same encrypt stanza six
    times, which made it easy for the two lists to drift.

    Args:
        credentials: Credential dict; only recognized secret fields are touched.
        encryption_key: Key passed through to ``encrypt_value_helper``.

    Returns:
        The same dict, with secret fields replaced by their encrypted form.
    """
    secret_fields = [
        "auth_value",
        "client_id",
        "client_secret",
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
    ]
    for field in secret_fields:
        value = credentials.get(field)  # type: ignore[literal-required]
        if value is not None:
            credentials[field] = encrypt_value_helper(  # type: ignore[literal-required]
                value=value,
                new_encryption_key=encryption_key,
            )
    # aws_region_name and aws_service_name are NOT secrets — stored as-is
    return credentials
def decrypt_credentials(
    credentials: MCPCredentials,
) -> MCPCredentials:
    """Decrypt all secret fields in an MCPCredentials dict using the global salt key."""
    for secret_name in (
        "auth_value",
        "client_id",
        "client_secret",
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
    ):
        stored = credentials.get(secret_name)  # type: ignore[literal-required]
        # None is not a str, so this single check covers both conditions
        if isinstance(stored, str):
            credentials[secret_name] = decrypt_value_helper(  # type: ignore[literal-required]
                value=stored,
                key=secret_name,
                exception_type="debug",
                return_original_value=True,
            )
    return credentials
async def get_all_mcp_servers(
    prisma_client: PrismaClient,
    approval_status: Optional[str] = None,
) -> List[LiteLLM_MCPServerTable]:
    """
    Fetch MCP servers from the DB, optionally filtered by approval_status.

    Pass approval_status=None to return all servers regardless of approval
    state. Returns an empty list on any DB error (logged at debug level).
    """
    try:
        filters: Dict[str, Any] = (
            {} if approval_status is None else {"approval_status": approval_status}
        )
        rows = await prisma_client.db.litellm_mcpservertable.find_many(where=filters)
        return [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]
    except Exception as e:
        verbose_proxy_logger.debug(
            "litellm.proxy._experimental.mcp_server.db.py::get_all_mcp_servers - {}".format(
                str(e)
            )
        )
        return []
async def get_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """
    Look up a single MCP server by server_id; returns None when not found.
    """
    return await prisma_client.db.litellm_mcpservertable.find_unique(
        where={"server_id": server_id}
    )
async def get_mcp_servers(
    prisma_client: PrismaClient, server_ids: Iterable[str]
) -> List[LiteLLM_MCPServerTable]:
    """
    Fetch all MCP servers whose server_id is in the given collection.
    """
    rows = await prisma_client.db.litellm_mcpservertable.find_many(
        where={
            "server_id": {"in": server_ids},
        }
    )
    return [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]
async def get_mcp_servers_by_verificationtoken(
    prisma_client: PrismaClient, token: str
) -> List[str]:
    """
    Return the MCP server ids attached (via object_permission) to the given
    verification token. Empty list when the token or its permission row
    is missing.
    """
    record = await prisma_client.db.litellm_verificationtoken.find_unique(
        where={
            "token": token,
        },
        include={
            "object_permission": True,
        },
    )
    if record is None or record.object_permission is None:
        return []
    return record.object_permission.mcp_servers or []
async def get_mcp_servers_by_team(
    prisma_client: PrismaClient, team_id: str
) -> List[str]:
    """
    Return the MCP server ids attached (via object_permission) to the given
    team. Empty list when the team or its permission row is missing.
    """
    record = await prisma_client.db.litellm_teamtable.find_unique(
        where={
            "team_id": team_id,
        },
        include={
            "object_permission": True,
        },
    )
    if record is None or record.object_permission is None:
        return []
    return record.object_permission.mcp_servers or []
async def get_all_mcp_servers_for_user(
    prisma_client: PrismaClient,
    user: UserAPIKeyAuth,
) -> List[LiteLLM_MCPServerTable]:
    """
    Get all the mcp servers the given user has access to.

    Following the least-privilege principle, the requestor only sees the MCP
    servers their API key (and, when the special "all team servers" marker is
    present, their team) grants access to.
    """
    server_ids: Set[str] = set()

    if user.api_key:
        # Servers granted directly to the virtual key
        server_ids.update(
            await get_mcp_servers_by_verificationtoken(prisma_client, user.api_key)
        )
        # Special marker expands to the team's server list
        if (
            user.team_id is not None
            and SpecialMCPServerName.all_team_servers in server_ids
        ):
            server_ids.update(
                await get_mcp_servers_by_team(prisma_client, user.team_id)
            )

    if not server_ids:
        return []
    return await get_mcp_servers(prisma_client, server_ids)
async def get_objectpermissions_for_mcp_server(
    prisma_client: PrismaClient, mcp_server_id: str
) -> List[LiteLLM_ObjectPermissionTable]:
    """
    Get all object-permission records (with their associated team and
    verification-token records) that grant access to the given MCP server.
    """
    return await prisma_client.db.litellm_objectpermissiontable.find_many(
        where={
            "mcp_servers": {"has": mcp_server_id},
        },
        include={
            "teams": True,
            "verification_tokens": True,
        },
    )
async def get_virtualkeys_for_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> List:
    """
    Get all the virtual keys that have access to the given MCP server.
    """
    rows = await prisma_client.db.litellm_verificationtoken.find_many(
        where={
            "mcp_servers": {"has": server_id},
        },
    )
    return rows if rows is not None else []
async def delete_mcp_server_from_team(prisma_client: PrismaClient, server_id: str):
    """
    Remove the mcp server from the team

    NOTE(review): not implemented yet — currently a no-op placeholder.
    """
    pass
async def delete_mcp_server_from_virtualkey():
    """
    Remove the mcp server from the virtual key

    NOTE(review): not implemented yet — currently a no-op placeholder.
    """
    pass
async def delete_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """
    Delete the MCP server row identified by server_id.

    Returns the deleted record if it existed, otherwise None.
    """
    return await prisma_client.db.litellm_mcpservertable.delete(
        where={
            "server_id": server_id,
        },
    )
async def create_mcp_server(
    prisma_client: PrismaClient, data: NewMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """
    Create a new MCP server record in the DB.

    Generates a server_id when one is not supplied and stamps the audit
    fields with `touched_by`.
    """
    if data.server_id is None:
        data.server_id = str(uuid.uuid4())

    # Serialize JSON columns / encrypt credentials, then add audit fields
    payload = _prepare_mcp_server_data(data)
    payload["created_by"] = touched_by
    payload["updated_by"] = touched_by

    return await prisma_client.db.litellm_mcpservertable.create(
        data=payload  # type: ignore
    )
async def update_mcp_server(
    prisma_client: PrismaClient, data: UpdateMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """
    Update an existing MCP server record in the DB.

    Credential handling rules:
    - If auth_type changes and no new credentials are supplied, stale
      credentials from the previous auth type are cleared.
    - If new credentials are supplied and auth_type is unchanged, they are
      merged over the existing (encrypted) credentials so a partial update
      does not wipe secrets the UI cannot display back.

    Args:
        data: UpdateMCPServerRequest; must carry server_id.
        touched_by: Recorded in the updated_by audit field.

    Returns:
        The updated DB record.
    """
    import json

    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    # Use helper to prepare data with proper JSON serialization
    data_dict = _prepare_mcp_server_data(data)
    # Pre-fetch existing record once if we need it for auth_type or credential logic
    existing = None
    has_credentials = (
        "credentials" in data_dict and data_dict["credentials"] is not None
    )
    if data.auth_type or has_credentials:
        existing = await prisma_client.db.litellm_mcpservertable.find_unique(
            where={"server_id": data.server_id}
        )
    # Clear stale credentials when auth_type changes but no new credentials provided
    if (
        data.auth_type
        and "credentials" not in data_dict
        and existing
        and existing.auth_type is not None
        and existing.auth_type != data.auth_type
    ):
        data_dict["credentials"] = None
    # Merge credentials: preserve existing fields not present in the update.
    # Without this, a partial credential update (e.g. changing only region)
    # would wipe encrypted secrets that the UI cannot display back.
    if "credentials" in data_dict and data_dict["credentials"] is not None:
        if existing and existing.credentials:
            # Only merge when auth_type is unchanged. Switching auth types
            # (e.g. oauth2 → api_key) should replace credentials entirely
            # to avoid stale secrets from the previous auth type lingering.
            auth_type_unchanged = (
                data.auth_type is None or data.auth_type == existing.auth_type
            )
            if auth_type_unchanged:
                # Both sides may be a JSON string (DB column) or a dict
                existing_creds = (
                    json.loads(existing.credentials)
                    if isinstance(existing.credentials, str)
                    else dict(existing.credentials)
                )
                new_creds = (
                    json.loads(data_dict["credentials"])
                    if isinstance(data_dict["credentials"], str)
                    else dict(data_dict["credentials"])
                )
                # New values override existing; existing keys not in update are preserved
                merged = {**existing_creds, **new_creds}
                data_dict["credentials"] = safe_dumps(merged)
    # Add audit fields
    data_dict["updated_by"] = touched_by
    updated_mcp_server = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": data.server_id}, data=data_dict  # type: ignore
    )
    return updated_mcp_server
async def rotate_mcp_server_credentials_master_key(
    prisma_client: PrismaClient, touched_by: str, new_master_key: str
):
    """
    Re-encrypt every MCP server's stored credentials under a new master key.

    For each server with credentials: decrypt secret fields with the current
    key, re-encrypt them with ``new_master_key``, and persist the result.

    Args:
        prisma_client: DB client.
        touched_by: Identifier recorded in ``updated_by`` for audit.
        new_master_key: The key all credentials are re-encrypted with.
    """
    # Hoisted out of the per-server loop — the import is loop-invariant.
    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()
    for mcp_server in mcp_servers:
        credentials = mcp_server.credentials
        if not credentials:
            continue
        credentials_copy = dict(credentials)
        # Decrypt with current key first, then re-encrypt with new key
        decrypted_credentials = decrypt_credentials(
            credentials=cast(MCPCredentials, credentials_copy),
        )
        encrypted_credentials = encrypt_credentials(
            credentials=decrypted_credentials,
            encryption_key=new_master_key,
        )
        serialized_credentials = safe_dumps(encrypted_credentials)
        await prisma_client.db.litellm_mcpservertable.update(
            where={"server_id": mcp_server.server_id},
            data={
                "credentials": serialized_credentials,
                "updated_by": touched_by,
            },
        )
async def store_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
    credential: str,
) -> None:
    """Store (create-or-replace) a user credential for a BYOK MCP server."""
    # Stored base64-encoded in the credential_b64 column
    encoded_value = base64.urlsafe_b64encode(credential.encode()).decode()
    key = {"user_id": user_id, "server_id": server_id}
    await prisma_client.db.litellm_mcpusercredentials.upsert(
        where={"user_id_server_id": key},
        data={
            "create": {
                "user_id": user_id,
                "server_id": server_id,
                "credential_b64": encoded_value,
            },
            "update": {"credential_b64": encoded_value},
        },
    )
async def get_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> Optional[str]:
    """Return the stored credential for a user+server pair, or None if absent."""
    record = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if record is None:
        return None
    try:
        return base64.urlsafe_b64decode(record.credential_b64).decode()
    except Exception:
        # Legacy rows were encrypted rather than base64-encoded — fall back
        # to nacl decryption for credentials stored by older code.
        return decrypt_value_helper(
            value=record.credential_b64,
            key="byok_credential",
            exception_type="debug",
            return_original_value=False,
        )
async def has_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> bool:
    """Return True if the user has a stored credential for this server."""
    record = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    return record is not None
async def delete_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> None:
    """Delete the user's stored credential for a BYOK MCP server."""
    key = {"user_id": user_id, "server_id": server_id}
    await prisma_client.db.litellm_mcpusercredentials.delete(
        where={"user_id_server_id": key}
    )
# ── OAuth2 user-credential helpers ────────────────────────────────────────────
async def store_user_oauth_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
    access_token: str,
    refresh_token: Optional[str] = None,
    expires_in: Optional[int] = None,
    scopes: Optional[List[str]] = None,
) -> None:
    """Persist an OAuth2 access token for a user+server pair.

    The payload is JSON-serialised and stored base64-encoded in the same
    ``credential_b64`` column used by BYOK. A ``"type": "oauth2"`` key
    differentiates it from plain BYOK API keys.

    Args:
        access_token: The upstream OAuth2 access token.
        refresh_token: Optional refresh token, stored only when truthy.
        expires_in: Token lifetime in seconds; converted to an absolute
            UTC ``expires_at`` ISO timestamp.
        scopes: Optional list of granted scopes.

    Raises:
        ValueError: if a non-OAuth2 (BYOK) credential already exists for this
            user+server pair — refusing to silently overwrite it.
    """
    # Convert relative lifetime to an absolute UTC expiry timestamp
    expires_at: Optional[str] = None
    if expires_in is not None:
        expires_at = (
            datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        ).isoformat()
    payload: Dict[str, Any] = {
        "type": "oauth2",
        "access_token": access_token,
        "connected_at": datetime.now(timezone.utc).isoformat(),
    }
    if refresh_token:
        payload["refresh_token"] = refresh_token
    if expires_at:
        payload["expires_at"] = expires_at
    if scopes:
        payload["scopes"] = scopes
    # Guard against silently overwriting a BYOK credential with an OAuth token.
    # BYOK credentials lack a "type" field (or use a non-"oauth2" type).
    existing = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if existing is not None:
        _byok_error = ValueError(
            f"A non-OAuth2 credential already exists for user {user_id} "
            f"and server {server_id}. Refusing to overwrite."
        )
        try:
            raw = json.loads(base64.urlsafe_b64decode(existing.credential_b64).decode())
        except Exception:
            # Credential is not base64+JSON — it's a plain-text BYOK key.
            raise _byok_error
        if raw.get("type") != "oauth2":
            raise _byok_error
    # Existing row (if any) is confirmed OAuth2 — safe to upsert over it
    encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
    await prisma_client.db.litellm_mcpusercredentials.upsert(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}},
        data={
            "create": {
                "user_id": user_id,
                "server_id": server_id,
                "credential_b64": encoded,
            },
            "update": {"credential_b64": encoded},
        },
    )
def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool:
    """Return True if the OAuth2 credential's access_token has expired.

    Reads the ``expires_at`` ISO-format string from the credential payload.
    A missing or unparseable ``expires_at`` is treated as non-expired
    (returns False). Naive timestamps are assumed to be UTC.
    """
    raw_expiry = cred.get("expires_at")
    if not raw_expiry:
        return False
    try:
        expiry = datetime.fromisoformat(raw_expiry)
    except (ValueError, TypeError):
        return False
    if expiry.tzinfo is None:
        expiry = expiry.replace(tzinfo=timezone.utc)
    return datetime.now(timezone.utc) > expiry
async def get_user_oauth_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> Optional[Dict[str, Any]]:
    """Return the decoded OAuth2 payload dict for a user+server pair, or None.

    Rows holding plain BYOK strings (or anything not tagged ``type: oauth2``)
    yield None.
    """
    record = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if record is None:
        return None
    try:
        payload = json.loads(
            base64.urlsafe_b64decode(record.credential_b64).decode()
        )
    except Exception:
        # Not base64+JSON — a plain BYOK key, not an OAuth token
        return None
    if isinstance(payload, dict) and payload.get("type") == "oauth2":
        return payload
    return None
async def list_user_oauth_credentials(
    prisma_client: PrismaClient,
    user_id: str,
) -> List[Dict[str, Any]]:
    """Return all OAuth2 credential payloads for a user, tagged with server_id.

    Rows that are not OAuth2 payloads (plain BYOK strings) are skipped.
    """
    rows = await prisma_client.db.litellm_mcpusercredentials.find_many(
        where={"user_id": user_id}
    )
    oauth_payloads: List[Dict[str, Any]] = []
    for row in rows:
        try:
            payload = json.loads(
                base64.urlsafe_b64decode(row.credential_b64).decode()
            )
        except Exception:
            continue  # Skip non-OAuth rows (BYOK plain strings)
        if isinstance(payload, dict) and payload.get("type") == "oauth2":
            payload["server_id"] = row.server_id
            oauth_payloads.append(payload)
    return oauth_payloads
async def approve_mcp_server(
    prisma_client: PrismaClient,
    server_id: str,
    touched_by: str,
) -> LiteLLM_MCPServerTable:
    """Mark a submitted MCP server as active and record when it was reviewed."""
    review_time = datetime.now(timezone.utc)
    row = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": server_id},
        data={
            "approval_status": MCPApprovalStatus.active,
            "reviewed_at": review_time,
            "updated_by": touched_by,
        },
    )
    return LiteLLM_MCPServerTable(**row.model_dump())
async def reject_mcp_server(
    prisma_client: PrismaClient,
    server_id: str,
    touched_by: str,
    review_notes: Optional[str] = None,
) -> LiteLLM_MCPServerTable:
    """Mark a submitted MCP server as rejected, recording review time and notes."""
    update_fields: Dict[str, Any] = {
        "approval_status": MCPApprovalStatus.rejected,
        "reviewed_at": datetime.now(timezone.utc),
        "updated_by": touched_by,
    }
    # review_notes is optional — only persist it when the reviewer supplied one
    if review_notes is not None:
        update_fields["review_notes"] = review_notes
    row = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": server_id},
        data=update_fields,
    )
    return LiteLLM_MCPServerTable(**row.model_dump())
async def get_mcp_submissions(
    prisma_client: PrismaClient,
) -> MCPSubmissionsSummary:
    """
    Return all MCP servers submitted by non-admin users (submitted_at IS NOT
    NULL), newest first, with a count breakdown by approval_status.

    Mirrors get_guardrail_submissions() from guardrail_endpoints.py.
    """
    rows = await prisma_client.db.litellm_mcpservertable.find_many(
        where={"submitted_at": {"not": None}},
        order={"submitted_at": "desc"},
        take=500,  # safety cap; paginate if needed in a future iteration
    )
    items = [LiteLLM_MCPServerTable(**r.model_dump()) for r in rows]

    # Tally approval statuses in one pass
    status_counts: Dict[Any, int] = {}
    for item in items:
        status_counts[item.approval_status] = (
            status_counts.get(item.approval_status, 0) + 1
        )

    return MCPSubmissionsSummary(
        total=len(items),
        pending_review=status_counts.get(MCPApprovalStatus.pending_review, 0),
        active=status_counts.get(MCPApprovalStatus.active, 0),
        rejected=status_counts.get(MCPApprovalStatus.rejected, 0),
        items=items,
    )

View File

@@ -0,0 +1,741 @@
import json
from typing import Optional
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import get_server_root_path
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPServer
router = APIRouter(
tags=["mcp"],
)
def get_request_base_url(request: Request) -> str:
    """
    Get the base URL for the request, honoring X-Forwarded-* headers.

    When behind a reverse proxy (like nginx), the proxy may set:
    - X-Forwarded-Proto: The original protocol (http/https)
    - X-Forwarded-Host: The original host (may include port)
    - X-Forwarded-Port: The original port (if not in Host header)

    Args:
        request: FastAPI Request object

    Returns:
        The reconstructed base URL (e.g., "https://proxy.example.com")
    """
    parsed = urlparse(str(request.base_url).rstrip("/"))

    fwd_proto = request.headers.get("X-Forwarded-Proto")
    fwd_host = request.headers.get("X-Forwarded-Host")
    fwd_port = request.headers.get("X-Forwarded-Port")

    scheme = fwd_proto or parsed.scheme

    if fwd_host:
        # X-Forwarded-Host may already carry a port (e.g. "example.com:8080");
        # the bracket check avoids misreading an IPv6 literal's colons.
        host_has_port = ":" in fwd_host and not fwd_host.startswith("[")
        if host_has_port or not fwd_port:
            netloc = fwd_host
        else:
            netloc = f"{fwd_host}:{fwd_port}"
    else:
        netloc = parsed.netloc
        if fwd_port and ":" not in netloc:
            # Append forwarded port when the original netloc has none
            netloc = f"{netloc}:{fwd_port}"

    return urlunparse((scheme, netloc, parsed.path, "", "", ""))
def encode_state_with_base_url(
    base_url: str,
    original_state: str,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    client_redirect_uri: Optional[str] = None,
) -> str:
    """
    Pack base_url, the client's original state, and PKCE parameters into a
    single encrypted opaque string, suitable for use as the OAuth ``state``
    forwarded to the upstream provider.

    Args:
        base_url: The base URL to encode
        original_state: The original state parameter
        code_challenge: PKCE code challenge from client
        code_challenge_method: PKCE code challenge method from client
        client_redirect_uri: Original redirect_uri from client

    Returns:
        An encrypted string that encodes all values
    """
    session_payload = {
        "base_url": base_url,
        "original_state": original_state,
        "code_challenge": code_challenge,
        "code_challenge_method": code_challenge_method,
        "client_redirect_uri": client_redirect_uri,
    }
    # sort_keys keeps the serialized form deterministic
    return encrypt_value_helper(json.dumps(session_payload, sort_keys=True))
def decode_state_hash(encrypted_state: str) -> dict:
    """
    Decrypt an encoded OAuth ``state`` back into its session dict.

    Args:
        encrypted_state: The encrypted string produced by
            ``encode_state_with_base_url``.

    Returns:
        A dict containing base_url, original_state, and optional PKCE parameters

    Raises:
        ValueError: if decryption fails.
        Exception: if the decrypted payload is malformed JSON.
    """
    plaintext = decrypt_value_helper(encrypted_state, "oauth_state")
    if plaintext is None:
        raise ValueError("Failed to decrypt state parameter")
    return json.loads(plaintext)
def _resolve_oauth2_server_for_root_endpoints(
    client_ip: Optional[str] = None,
) -> Optional[MCPServer]:
    """
    Resolve the MCP server for root-level OAuth endpoints (no server name in path).

    When the MCP SDK hits root-level endpoints like /register, /authorize, /token
    without a server name prefix, we try to find the right server automatically.
    Returns the server only when exactly one OAuth2 server is configured;
    anything else (zero, or ambiguous many) yields None.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    registry = global_mcp_server_manager.get_filtered_registry(client_ip=client_ip)
    candidates = [
        srv for srv in registry.values() if srv.auth_type == MCPAuth.oauth2
    ]
    return candidates[0] if len(candidates) == 1 else None
async def authorize_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_id: str,
    redirect_uri: str,
    state: str = "",
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    response_type: Optional[str] = None,
    scope: Optional[str] = None,
):
    """
    Redirect the caller to the upstream OAuth2 provider's authorization URL.

    The client's state, redirect_uri, and PKCE parameters are packed into an
    encrypted ``state`` so they can be recovered at /callback. The proxy's own
    /callback endpoint is substituted as the redirect_uri sent upstream.

    Raises:
        HTTPException: 400 when the server is not OAuth2 or has no
            authorization_url configured.
    """
    if mcp_server.auth_type != "oauth2":
        raise HTTPException(status_code=400, detail="MCP server is not OAuth2")
    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )
    # Strip the query from the client's redirect_uri; only the base is needed
    # to route the eventual callback back to the client.
    parsed = urlparse(redirect_uri)
    base_url = urlunparse(parsed._replace(query=""))
    request_base_url = get_request_base_url(request)
    # Encrypt the whole session (client state + PKCE) into the state param
    encoded_state = encode_state_with_base_url(
        base_url=base_url,
        original_state=state,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method,
        client_redirect_uri=redirect_uri,
    )
    # Server-configured client_id takes precedence over the caller's
    params = {
        "client_id": mcp_server.client_id if mcp_server.client_id else client_id,
        "redirect_uri": f"{request_base_url}/callback",
        "state": encoded_state,
        "response_type": response_type or "code",
    }
    if scope:
        params["scope"] = scope
    elif mcp_server.scopes:
        params["scope"] = " ".join(mcp_server.scopes)
    if code_challenge:
        params["code_challenge"] = code_challenge
    if code_challenge_method:
        params["code_challenge_method"] = code_challenge_method
    # Preserve any query params already baked into the configured auth URL,
    # letting ours override on key collision.
    parsed_auth_url = urlparse(mcp_server.authorization_url)
    existing_params = dict(parse_qsl(parsed_auth_url.query))
    existing_params.update(params)
    final_url = urlunparse(parsed_auth_url._replace(query=urlencode(existing_params)))
    return RedirectResponse(final_url)
async def exchange_token_with_server(
    request: Request,
    mcp_server: MCPServer,
    grant_type: str,
    code: Optional[str],
    redirect_uri: Optional[str],
    client_id: str,
    client_secret: Optional[str],
    code_verifier: Optional[str],
):
    """
    Exchange an authorization code for an access token at the upstream
    provider's token endpoint, forwarding the PKCE code_verifier when present.

    Only the ``authorization_code`` grant is supported. Server-configured
    client_id/client_secret take precedence over the caller-supplied values.

    Raises:
        HTTPException: 400 for unsupported grant types or a missing token_url.
        httpx.HTTPStatusError: when the upstream token endpoint returns an error.
    """
    if grant_type != "authorization_code":
        raise HTTPException(status_code=400, detail="Unsupported grant_type")
    if mcp_server.token_url is None:
        raise HTTPException(status_code=400, detail="MCP server token url is not set")
    proxy_base_url = get_request_base_url(request)
    # redirect_uri must match what was sent in the authorize step:
    # the proxy's own /callback, not the client's original redirect_uri.
    token_data = {
        "grant_type": "authorization_code",
        "client_id": mcp_server.client_id if mcp_server.client_id else client_id,
        "client_secret": mcp_server.client_secret
        if mcp_server.client_secret
        else client_secret,
        "code": code,
        "redirect_uri": f"{proxy_base_url}/callback",
    }
    if code_verifier:
        token_data["code_verifier"] = code_verifier
    async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    response = await async_client.post(
        mcp_server.token_url,
        headers={"Accept": "application/json"},
        data=token_data,
    )
    response.raise_for_status()
    token_response = response.json()
    access_token = token_response["access_token"]
    # Normalize the upstream response; optional fields only when present
    result = {
        "access_token": access_token,
        "token_type": token_response.get("token_type", "Bearer"),
        "expires_in": token_response.get("expires_in", 3600),
    }
    if "refresh_token" in token_response and token_response["refresh_token"]:
        result["refresh_token"] = token_response["refresh_token"]
    if "scope" in token_response and token_response["scope"]:
        result["scope"] = token_response["scope"]
    return JSONResponse(result)
async def register_client_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_name: str,
    grant_types: Optional[list],
    response_types: Optional[list],
    token_endpoint_auth_method: Optional[str],
    fallback_client_id: Optional[str] = None,
):
    """
    Handle OAuth2 dynamic client registration for an MCP server.

    When the server already has a configured client_id/client_secret, or has
    no registration_url, a placeholder registration is returned so clients
    that insist on dynamic registration can proceed; otherwise the request is
    forwarded to the upstream registration endpoint.

    NOTE(review): the placeholder path returns a plain dict while the upstream
    path returns a JSONResponse — FastAPI serializes both, but presumably this
    should be unified; confirm before relying on the response type.
    """
    request_base_url = get_request_base_url(request)
    dummy_return = {
        "client_id": fallback_client_id or mcp_server.server_name,
        "client_secret": "dummy",
        "redirect_uris": [f"{request_base_url}/callback"],
    }
    # Pre-configured credentials — no upstream registration needed
    if mcp_server.client_id and mcp_server.client_secret:
        return dummy_return
    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )
    # No registration endpoint — fall back to the placeholder
    if mcp_server.registration_url is None:
        return dummy_return
    register_data = {
        "client_name": client_name,
        "redirect_uris": [f"{request_base_url}/callback"],
        "grant_types": grant_types or [],
        "response_types": response_types or [],
        "token_endpoint_auth_method": token_endpoint_auth_method or "",
    }
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Oauth2Register
    )
    response = await async_client.post(
        mcp_server.registration_url,
        headers=headers,
        json=register_data,
    )
    response.raise_for_status()
    token_response = response.json()
    return JSONResponse(token_response)
@router.get("/{mcp_server_name}/authorize")
@router.get("/authorize")
async def authorize(
    request: Request,
    redirect_uri: str,
    client_id: Optional[str] = None,
    state: str = "",
    mcp_server_name: Optional[str] = None,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    response_type: Optional[str] = None,
    scope: Optional[str] = None,
):
    """
    OAuth2 authorization endpoint (with and without a server-name prefix).

    Resolves the target MCP server — by path name, by client_id, or (for the
    root route) by falling back to the single configured OAuth2 server — then
    redirects the caller to the upstream provider with PKCE support.

    Raises:
        HTTPException: 404 when no server can be resolved; 400 when no
            client_id is available from either the request or the server record.
    """
    # Redirect to real OAuth provider with PKCE support
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    # Some clients pass the server name as client_id on the root route
    lookup_name: Optional[str] = mcp_server_name or client_id
    client_ip = IPAddressUtils.get_mcp_client_ip(request)
    mcp_server = (
        global_mcp_server_manager.get_mcp_server_by_name(
            lookup_name, client_ip=client_ip
        )
        if lookup_name
        else None
    )
    if mcp_server is None and mcp_server_name is None:
        mcp_server = _resolve_oauth2_server_for_root_endpoints()
    if mcp_server is None:
        raise HTTPException(status_code=404, detail="MCP server not found")
    # Use server's stored client_id when caller doesn't supply one.
    # Raise a clear error instead of passing an empty string — an empty
    # client_id would silently produce a broken authorization URL.
    resolved_client_id: str = mcp_server.client_id or client_id or ""
    if not resolved_client_id:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "client_id is required but was not supplied and is not "
                "stored on the MCP server record. Provide client_id as a query "
                "parameter or configure it on the server."
            },
        )
    return await authorize_with_server(
        request=request,
        mcp_server=mcp_server,
        client_id=resolved_client_id,
        redirect_uri=redirect_uri,
        state=state,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method,
        response_type=response_type,
        scope=scope,
    )
@router.post("/{mcp_server_name}/token")
@router.post("/token")
async def token_endpoint(
    request: Request,
    grant_type: str = Form(...),
    code: str = Form(None),
    redirect_uri: str = Form(None),
    client_id: str = Form(...),
    client_secret: Optional[str] = Form(None),
    code_verifier: str = Form(None),
    mcp_server_name: Optional[str] = None,
):
    """
    Accept the authorization code from client and exchange it for OAuth token.
    Supports PKCE flow by forwarding code_verifier to upstream provider.

    1. Call the token endpoint with PKCE parameters
    2. Store the user's token in the db - and generate a LiteLLM virtual key
    3. Return the token
    4. Return a virtual key in this response

    Raises:
        HTTPException: 404 when no MCP server can be resolved.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    # Root route: some clients pass the server name as client_id
    lookup_name = mcp_server_name or client_id
    client_ip = IPAddressUtils.get_mcp_client_ip(request)
    mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
        lookup_name, client_ip=client_ip
    )
    # Fall back to the single configured OAuth2 server on the root route
    if mcp_server is None and mcp_server_name is None:
        mcp_server = _resolve_oauth2_server_for_root_endpoints()
    if mcp_server is None:
        raise HTTPException(status_code=404, detail="MCP server not found")
    return await exchange_token_with_server(
        request=request,
        mcp_server=mcp_server,
        grant_type=grant_type,
        code=code,
        redirect_uri=redirect_uri,
        client_id=client_id,
        client_secret=client_secret,
        code_verifier=code_verifier,
    )
@router.get("/callback")
async def callback(code: str, state: str):
    """
    Upstream OAuth provider callback.

    Decrypts the state to recover the client's callback base URL and original
    state, then redirects the browser there with the authorization code.
    Falls back to a plain HTML page if the state cannot be decoded.
    """
    try:
        session = decode_state_hash(state)
        # Forward code and the client's original state to its callback URL
        redirect_query = urlencode(
            {"code": code, "state": session["original_state"]}
        )
        destination = f"{session['base_url']}?{redirect_query}"
        return RedirectResponse(url=destination, status_code=302)
    except Exception:
        # fallback if state hash not found
        return HTMLResponse(
            "<html><body>Authentication incomplete. You can close this window.</body></html>"
        )
# ------------------------------
# Optional .well-known endpoints for MCP + OAuth discovery
# ------------------------------
"""
Per SEP-985, the client MUST:
1. Try resource_metadata from WWW-Authenticate header (if present)
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
(
If the resource identifier value contains a path or query component, any terminating slash (/)
following the host component MUST be removed before inserting /.well-known/ and the well-known
URI path suffix between the host component and the path(include root path) and/or query components.
https://datatracker.ietf.org/doc/html/rfc9728#section-3.1)
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
Dual Pattern Support:
- Standard MCP pattern: /mcp/{server_name} (recommended, used by mcp-inspector, VSCode Copilot)
- LiteLLM legacy pattern: /{server_name}/mcp (backward compatibility)
The resource URL returned matches the pattern used in the discovery request.
"""
def _build_oauth_protected_resource_response(
    request: Request,
    mcp_server_name: Optional[str],
    use_standard_pattern: bool,
) -> dict:
    """
    Build the RFC 9728 OAuth protected-resource metadata document.

    Args:
        request: FastAPI Request object
        mcp_server_name: Name of the MCP server
        use_standard_pattern: If True, use /mcp/{server_name} pattern;
            if False, use /{server_name}/mcp pattern

    Returns:
        OAuth protected resource metadata dict
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    base_url = get_request_base_url(request)

    # Root-level discovery (no server in the path): fall back to the single
    # configured OAuth2 server when one can be resolved.
    if mcp_server_name is None:
        resolved = _resolve_oauth2_server_for_root_endpoints()
        if resolved:
            mcp_server_name = resolved.server_name or resolved.name

    mcp_server: Optional[MCPServer] = None
    if mcp_server_name:
        client_ip = IPAddressUtils.get_mcp_client_ip(request)
        mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
            mcp_server_name, client_ip=client_ip
        )

    # The advertised resource URL must mirror the pattern the client used for
    # discovery (standard: /mcp/{name}; legacy: /{name}/mcp).
    if not mcp_server_name:
        resource_url = f"{base_url}/mcp"
        auth_server = f"{base_url}"
    elif use_standard_pattern:
        resource_url = f"{base_url}/mcp/{mcp_server_name}"
        auth_server = f"{base_url}/{mcp_server_name}"
    else:
        resource_url = f"{base_url}/{mcp_server_name}/mcp"
        auth_server = f"{base_url}/{mcp_server_name}"

    scopes = mcp_server.scopes if mcp_server and mcp_server.scopes else []
    return {
        "authorization_servers": [auth_server],
        "resource": resource_url,
        "scopes_supported": scopes,
    }
# Standard MCP pattern: /.well-known/oauth-protected-resource/mcp/{server_name}
# This is the pattern expected by standard MCP clients (mcp-inspector, VSCode Copilot)
@router.get(
    f"/.well-known/oauth-protected-resource{'' if get_server_root_path() == '/' else get_server_root_path()}/mcp/{{mcp_server_name}}"
)
async def oauth_protected_resource_mcp_standard(request: Request, mcp_server_name: str):
    """
    OAuth protected resource discovery endpoint using standard MCP URL pattern.

    Standard pattern: /mcp/{server_name}
    Discovery path: /.well-known/oauth-protected-resource/mcp/{server_name}

    This endpoint is compliant with MCP specification and works with standard
    MCP clients like mcp-inspector and VSCode Copilot.
    """
    # The route string embeds the server root path (when not "/") so discovery
    # works behind a path-mounted reverse proxy.
    return _build_oauth_protected_resource_response(
        request=request,
        mcp_server_name=mcp_server_name,
        use_standard_pattern=True,
    )
# LiteLLM legacy pattern: /.well-known/oauth-protected-resource/{server_name}/mcp
# Kept for backward compatibility with existing deployments
@router.get(
    f"/.well-known/oauth-protected-resource{'' if get_server_root_path() == '/' else get_server_root_path()}/{{mcp_server_name}}/mcp"
)
@router.get("/.well-known/oauth-protected-resource")
async def oauth_protected_resource_mcp(
    request: Request, mcp_server_name: Optional[str] = None
):
    """
    OAuth protected resource discovery endpoint using LiteLLM legacy URL pattern.

    Legacy pattern: /{server_name}/mcp
    Discovery path: /.well-known/oauth-protected-resource/{server_name}/mcp

    Also serves the root discovery path (no server name), in which case the
    helper resolves the single configured OAuth2 server if there is one.

    This endpoint is kept for backward compatibility. New integrations should
    use the standard MCP pattern (/mcp/{server_name}) instead.
    """
    return _build_oauth_protected_resource_response(
        request=request,
        mcp_server_name=mcp_server_name,
        use_standard_pattern=False,
    )
"""
https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
RFC 8414: Path-aware OAuth discovery
If the issuer identifier value contains a path component, any
terminating "/" MUST be removed before inserting "/.well-known/" and
the well-known URI suffix between the host component and the path
(including the root path) component.
"""
def _build_oauth_authorization_server_response(
    request: Request,
    mcp_server_name: Optional[str],
) -> dict:
    """
    Build the RFC 8414 OAuth authorization-server metadata document.

    Args:
        request: FastAPI Request object
        mcp_server_name: Name of the MCP server

    Returns:
        OAuth authorization server metadata dict
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    base_url = get_request_base_url(request)

    # Root-level discovery: fall back to the single configured OAuth2 server.
    if mcp_server_name is None:
        resolved = _resolve_oauth2_server_for_root_endpoints()
        if resolved:
            mcp_server_name = resolved.server_name or resolved.name

    # Endpoint URLs are namespaced per server when a server name is known.
    endpoint_prefix = f"{base_url}/{mcp_server_name}" if mcp_server_name else base_url

    mcp_server: Optional[MCPServer] = None
    if mcp_server_name:
        client_ip = IPAddressUtils.get_mcp_client_ip(request)
        mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
            mcp_server_name, client_ip=client_ip
        )

    scopes = mcp_server.scopes if mcp_server and mcp_server.scopes else []
    return {
        "issuer": base_url,  # point to your proxy
        "authorization_endpoint": f"{endpoint_prefix}/authorize",
        "token_endpoint": f"{endpoint_prefix}/token",
        "response_types_supported": ["code"],
        "scopes_supported": scopes,
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "code_challenge_methods_supported": ["S256"],
        "token_endpoint_auth_methods_supported": ["client_secret_post"],
        # Claude expects a registration endpoint, even if we just fake it
        "registration_endpoint": f"{endpoint_prefix}/register",
    }
# Standard MCP pattern: /.well-known/oauth-authorization-server/mcp/{server_name}
@router.get(
    f"/.well-known/oauth-authorization-server{'' if get_server_root_path() == '/' else get_server_root_path()}/mcp/{{mcp_server_name}}"
)
async def oauth_authorization_server_mcp_standard(
    request: Request, mcp_server_name: str
):
    """
    OAuth authorization server discovery endpoint using standard MCP URL pattern.

    Standard pattern: /mcp/{server_name}
    Discovery path: /.well-known/oauth-authorization-server/mcp/{server_name}
    """
    return _build_oauth_authorization_server_response(
        request=request,
        mcp_server_name=mcp_server_name,
    )
# LiteLLM legacy pattern and root endpoint
@router.get(
    f"/.well-known/oauth-authorization-server{'' if get_server_root_path() == '/' else get_server_root_path()}/{{mcp_server_name}}"
)
@router.get("/.well-known/oauth-authorization-server")
async def oauth_authorization_server_mcp(
    request: Request, mcp_server_name: Optional[str] = None
):
    """
    OAuth authorization server discovery endpoint.

    Supports both legacy pattern (/{server_name}) and root endpoint.
    At the root endpoint (no server name) the helper resolves the single
    configured OAuth2 server if there is one.
    """
    return _build_oauth_authorization_server_response(
        request=request,
        mcp_server_name=mcp_server_name,
    )
# Alias for standard OpenID discovery
@router.get("/.well-known/openid-configuration")
async def openid_configuration(request: Request):
    """Serve OpenID-style discovery by delegating to the root OAuth
    authorization-server metadata endpoint (no server name)."""
    return await oauth_authorization_server_mcp(request)
# Additional legacy pattern support
@router.get("/.well-known/oauth-authorization-server/{mcp_server_name}/mcp")
async def oauth_authorization_server_legacy(request: Request, mcp_server_name: str):
    """
    OAuth authorization server discovery for legacy /{server_name}/mcp pattern.

    Delegates to the shared metadata builder with the extracted server name.
    """
    return _build_oauth_authorization_server_response(
        request=request,
        mcp_server_name=mcp_server_name,
    )
@router.post("/{mcp_server_name}/register")
@router.post("/register")
async def register_client(request: Request, mcp_server_name: Optional[str] = None):
    """
    Dynamic client registration endpoint for the MCP OAuth2 flow.

    Forwards registration to the upstream server when the named (or resolved)
    MCP server is known; otherwise returns a static dummy registration so
    clients that require a registration endpoint (e.g. Claude) can proceed.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    # Get the correct base URL considering X-Forwarded-* headers
    request_base_url = get_request_base_url(request)
    request_data = await _read_request_body(request=request)
    data: dict = {**request_data}
    # Fallback response when no real upstream registration can be performed.
    dummy_return = {
        "client_id": mcp_server_name or "dummy_client",
        "client_secret": "dummy",
        "redirect_uris": [f"{request_base_url}/callback"],
    }
    if not mcp_server_name:
        # Root /register: try the single configured OAuth2 server.
        resolved = _resolve_oauth2_server_for_root_endpoints()
        if resolved:
            return await register_client_with_server(
                request=request,
                mcp_server=resolved,
                client_name=data.get("client_name", ""),
                grant_types=data.get("grant_types", []),
                response_types=data.get("response_types", []),
                token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
                fallback_client_id=resolved.server_name or resolved.name,
            )
        return dummy_return
    client_ip = IPAddressUtils.get_mcp_client_ip(request)
    mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
        mcp_server_name, client_ip=client_ip
    )
    if mcp_server is None:
        # Unknown server name: degrade to the dummy registration.
        return dummy_return
    return await register_client_with_server(
        request=request,
        mcp_server=mcp_server,
        client_name=data.get("client_name", ""),
        grant_types=data.get("grant_types", []),
        response_types=data.get("response_types", []),
        token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
        fallback_client_id=mcp_server_name,
    )

View File

@@ -0,0 +1,16 @@
"""Guardrail translation mapping for MCP tool calls."""
from litellm.proxy._experimental.mcp_server.guardrail_translation.handler import (
MCPGuardrailTranslationHandler,
)
from litellm.types.utils import CallTypes
# This mapping lives alongside the MCP server implementation because MCP
# integrations are managed by the proxy subsystem, not litellm.llms providers.
# Unified guardrails import this module explicitly to register the handler.
# Maps the MCP tool-call CallType to its translation handler so unified
# guardrails can convert a call_tool payload into OpenAI-compatible inputs.
guardrail_translation_mappings = {
    CallTypes.call_mcp_tool: MCPGuardrailTranslationHandler,
}

__all__ = ["guardrail_translation_mappings", "MCPGuardrailTranslationHandler"]

View File

@@ -0,0 +1,99 @@
"""
MCP Guardrail Handler for Unified Guardrails.
Converts an MCP call_tool (name + arguments) into a single OpenAI-compatible
tool_call and passes it to apply_guardrail. Works with the synthetic payload
from ProxyLogging._convert_mcp_to_llm_format.
Note: For MCP tool definitions (schema) -> OpenAI tools=[], see
litellm.experimental_mcp_client.tools.transform_mcp_tool_to_openai_tool
when you have a full MCP Tool from list_tools. Here we only have the call
payload (name + arguments) so we just build the tool_call.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from mcp.types import Tool as MCPTool
from litellm._logging import verbose_proxy_logger
from litellm.experimental_mcp_client.tools import transform_mcp_tool_to_openai_tool
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from mcp.types import CallToolResult
from litellm.integrations.custom_guardrail import CustomGuardrail
class MCPGuardrailTranslationHandler(BaseTranslation):
    """Guardrail translation handler for MCP tool calls (passes a single tool_call to guardrail)."""

    async def process_input_messages(
        self,
        data: Dict[str, Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Translate an MCP call_tool payload into an OpenAI tool definition and
        run the guardrail against it.

        Args:
            data: Synthetic request payload; reads ``mcp_tool_name``/``name``,
                ``mcp_arguments``/``arguments`` and
                ``mcp_tool_description``/``description``.
            guardrail_to_apply: Guardrail whose ``apply_guardrail`` is invoked.
            litellm_logging_obj: Optional logging object forwarded to the guardrail.

        Returns:
            The ``data`` dict, unmodified.
        """
        mcp_tool_name = data.get("mcp_tool_name") or data.get("name")
        mcp_arguments = data.get("mcp_arguments") or data.get("arguments")
        mcp_tool_description = data.get("mcp_tool_description") or data.get(
            "description"
        )
        # NOTE(review): mcp_arguments is normalized here but never read below —
        # the guardrail receives the arguments via request_data instead.
        if mcp_arguments is None or not isinstance(mcp_arguments, dict):
            mcp_arguments = {}
        if not mcp_tool_name:
            # Without a tool name there is nothing to translate; pass through.
            verbose_proxy_logger.debug("MCP Guardrail: mcp_tool_name missing")
            return data
        # Convert MCP input via transform_mcp_tool_to_openai_tool, then map to litellm
        # ChatCompletionToolParam (openai SDK type has incompatible strict/cache_control).
        mcp_tool = MCPTool(
            name=mcp_tool_name,
            description=mcp_tool_description or "",
            inputSchema={},  # Call payload has no schema; guardrail gets args from request_data
        )
        openai_tool = transform_mcp_tool_to_openai_tool(mcp_tool)
        fn = openai_tool["function"]
        tool_def: ChatCompletionToolParam = {
            "type": "function",
            "function": ChatCompletionToolParamFunctionChunk(
                name=fn["name"],
                description=fn.get("description") or "",
                parameters=fn.get("parameters")
                or {
                    "type": "object",
                    "properties": {},
                    "additionalProperties": False,
                },
                strict=fn.get("strict", False) or False,  # Default to False if None
            ),
        }
        inputs: GenericGuardrailAPIInputs = GenericGuardrailAPIInputs(
            tools=[tool_def],
        )
        await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        return data

    async def process_output_response(
        self,
        response: "CallToolResult",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """Output-side translation is not implemented for MCP tools; the
        response is returned unchanged."""
        verbose_proxy_logger.debug(
            "MCP Guardrail: Output processing not implemented for MCP tools",
        )
        return response

View File

@@ -0,0 +1,325 @@
"""
MCP OAuth2 Debug Headers
========================
Client-side debugging for MCP authentication flows.
When a client sends the ``x-litellm-mcp-debug: true`` header, LiteLLM
returns masked diagnostic headers in the response so operators can
troubleshoot OAuth2 issues without SSH access to the gateway.
Response headers returned (all values are masked for safety):
x-mcp-debug-inbound-auth
Which inbound auth headers were present and how they were classified.
Example: ``x-litellm-api-key=Bearer sk-12****1234``
x-mcp-debug-oauth2-token
The OAuth2 token extracted from the Authorization header (masked).
Shows ``(none)`` if absent, or flags ``SAME_AS_LITELLM_KEY`` when
the LiteLLM API key is accidentally leaking to the MCP server.
x-mcp-debug-auth-resolution
Which auth priority was used for the outbound MCP call:
``per-request-header``, ``m2m-client-credentials``, ``static-token``,
``oauth2-passthrough``, or ``no-auth``.
x-mcp-debug-outbound-url
The upstream MCP server URL that will receive the request.
x-mcp-debug-server-auth-type
The ``auth_type`` configured on the MCP server (e.g. ``oauth2``,
``bearer_token``, ``none``).
Debugging Guide
---------------
**Common issue: LiteLLM API key leaking to the MCP server**
Symptom: ``x-mcp-debug-oauth2-token`` shows ``SAME_AS_LITELLM_KEY``.
This means the ``Authorization`` header carries the LiteLLM API key and
it's being forwarded to the upstream MCP server instead of an OAuth2 token.
Fix: Move the LiteLLM key to ``x-litellm-api-key`` so the ``Authorization``
header is free for OAuth2 discovery::
# WRONG — blocks OAuth2 discovery
claude mcp add --transport http my_server http://proxy/mcp/server \\
--header "Authorization: Bearer sk-..."
# CORRECT — LiteLLM key in dedicated header, Authorization free for OAuth2
claude mcp add --transport http my_server http://proxy/mcp/server \\
--header "x-litellm-api-key: Bearer sk-..." \\
--header "x-litellm-mcp-debug: true"
**Common issue: No OAuth2 token present**
Symptom: ``x-mcp-debug-oauth2-token`` shows ``(none)`` and
``x-mcp-debug-auth-resolution`` shows ``no-auth``.
This means the client didn't go through the OAuth2 flow. Check that:
1. The ``Authorization`` header is NOT set as a static header in the client config.
2. The ``.well-known/oauth-protected-resource`` endpoint returns valid metadata.
3. The MCP server in LiteLLM config has ``auth_type: oauth2``.
**Common issue: M2M token used instead of user token**
Symptom: ``x-mcp-debug-auth-resolution`` shows ``m2m-client-credentials``.
This means the server has ``client_id``/``client_secret``/``token_url``
configured and LiteLLM is fetching a machine-to-machine token instead of
using the per-user OAuth2 token. If you want per-user tokens, remove the
client credentials from the server config.
Usage from Claude Code::
claude mcp add --transport http my_server http://proxy/mcp/server \\
--header "x-litellm-api-key: Bearer sk-..." \\
--header "x-litellm-mcp-debug: true"
Usage with curl::
curl -H "x-litellm-mcp-debug: true" \\
-H "x-litellm-api-key: Bearer sk-..." \\
http://localhost:4000/mcp/atlassian_mcp
"""
from typing import TYPE_CHECKING, Dict, List, Optional
from starlette.types import Message, Send
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
if TYPE_CHECKING:
from litellm.types.mcp_server.mcp_server_manager import MCPServer
# Header the client sends to opt into debug mode
MCP_DEBUG_REQUEST_HEADER = "x-litellm-mcp-debug"

# Prefix for all debug response headers
_RESPONSE_HEADER_PREFIX = "x-mcp-debug"


class MCPDebug:
    """
    Static helper class for MCP OAuth2 debug headers.

    Provides opt-in client-side diagnostics by injecting masked
    authentication info into HTTP response headers.
    """

    # Masker: show first 6 and last 4 chars so you can distinguish token types
    # e.g. "Bearer****ef01" vs "sk-123****cdef"
    _masker = SensitiveDataMasker(
        sensitive_patterns={
            "authorization",
            "token",
            "key",
            "secret",
            "auth",
            "bearer",
        },
        visible_prefix=6,
        visible_suffix=4,
    )

    @staticmethod
    def _mask(value: Optional[str]) -> str:
        """Mask a single value for safe display in headers."""
        if not value:
            return "(none)"
        return MCPDebug._masker._mask_value(value)

    @staticmethod
    def is_debug_enabled(headers: Dict[str, str]) -> bool:
        """
        Check if the client opted into MCP debug mode.

        Looks for ``x-litellm-mcp-debug: true`` (case-insensitive) in the
        request headers.
        """
        for key, val in headers.items():
            if key.lower() == MCP_DEBUG_REQUEST_HEADER:
                # First case-insensitive key match wins; only these truthy
                # spellings enable debug.
                return val.strip().lower() in ("true", "1", "yes")
        return False

    @staticmethod
    def resolve_auth_resolution(
        server: "MCPServer",
        mcp_auth_header: Optional[str],
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
        oauth2_headers: Optional[Dict[str, str]],
    ) -> str:
        """
        Determine which auth priority will be used for the outbound MCP call.

        Returns one of: ``per-request-header``, ``m2m-client-credentials``,
        ``static-token``, ``oauth2-passthrough``, or ``no-auth``.
        """
        from litellm.types.mcp import MCPAuth

        # Server-specific per-request headers are keyed by alias or server_name.
        has_server_specific = bool(
            mcp_server_auth_headers
            and (
                mcp_server_auth_headers.get(server.alias or "")
                or mcp_server_auth_headers.get(server.server_name or "")
            )
        )
        # Ordering below mirrors the outbound auth resolution priority.
        if has_server_specific or mcp_auth_header:
            return "per-request-header"
        if server.has_client_credentials:
            return "m2m-client-credentials"
        if server.authentication_token:
            return "static-token"
        if oauth2_headers and server.auth_type == MCPAuth.oauth2:
            return "oauth2-passthrough"
        return "no-auth"

    @staticmethod
    def build_debug_headers(
        *,
        inbound_headers: Dict[str, str],
        oauth2_headers: Optional[Dict[str, str]],
        litellm_api_key: Optional[str],
        auth_resolution: str,
        server_url: Optional[str],
        server_auth_type: Optional[str],
    ) -> Dict[str, str]:
        """
        Build masked debug response headers.

        Parameters
        ----------
        inbound_headers : dict
            Raw headers received from the MCP client.
        oauth2_headers : dict or None
            Extracted OAuth2 headers (``{"Authorization": "Bearer ..."}``).
        litellm_api_key : str or None
            The LiteLLM API key extracted from ``x-litellm-api-key`` or
            ``Authorization`` header.
        auth_resolution : str
            Which auth priority was selected for the outbound call.
        server_url : str or None
            Upstream MCP server URL.
        server_auth_type : str or None
            The ``auth_type`` configured on the server (e.g. ``oauth2``).

        Returns
        -------
        dict
            Headers to include in the response (all values masked).
        """
        debug: Dict[str, str] = {}
        # --- Inbound auth summary ---
        inbound_parts = []
        for hdr_name in ("x-litellm-api-key", "authorization", "x-mcp-auth"):
            # Case-insensitive lookup; break after the first match per header.
            for k, v in inbound_headers.items():
                if k.lower() == hdr_name:
                    inbound_parts.append(f"{hdr_name}={MCPDebug._mask(v)}")
                    break
        debug[f"{_RESPONSE_HEADER_PREFIX}-inbound-auth"] = (
            "; ".join(inbound_parts) if inbound_parts else "(none)"
        )
        # --- OAuth2 token ---
        oauth2_token = (oauth2_headers or {}).get("Authorization")
        if oauth2_token and litellm_api_key:
            # Strip "Bearer " prefixes so the comparison is on the raw tokens.
            oauth2_raw = oauth2_token.removeprefix("Bearer ").strip()
            litellm_raw = litellm_api_key.removeprefix("Bearer ").strip()
            if oauth2_raw == litellm_raw:
                # The LiteLLM key is leaking into the Authorization header —
                # flag it so the operator can fix the client config.
                debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = (
                    f"{MCPDebug._mask(oauth2_token)} "
                    f"(SAME_AS_LITELLM_KEY - likely misconfigured)"
                )
            else:
                debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = MCPDebug._mask(
                    oauth2_token
                )
        else:
            debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = MCPDebug._mask(
                oauth2_token
            )
        # --- Auth resolution ---
        debug[f"{_RESPONSE_HEADER_PREFIX}-auth-resolution"] = auth_resolution
        # --- Server info ---
        debug[f"{_RESPONSE_HEADER_PREFIX}-outbound-url"] = server_url or "(unknown)"
        debug[f"{_RESPONSE_HEADER_PREFIX}-server-auth-type"] = (
            server_auth_type or "(none)"
        )
        return debug

    @staticmethod
    def wrap_send_with_debug_headers(send: Send, debug_headers: Dict[str, str]) -> Send:
        """
        Return a new ASGI ``send`` callable that injects *debug_headers*
        into the ``http.response.start`` message.
        """

        async def _send_with_debug(message: Message) -> None:
            if message["type"] == "http.response.start":
                # Copy + extend rather than mutating the caller's message.
                headers = list(message.get("headers", []))
                for k, v in debug_headers.items():
                    headers.append((k.encode(), v.encode()))
                message = {**message, "headers": headers}
            await send(message)

        return _send_with_debug

    @staticmethod
    def maybe_build_debug_headers(
        *,
        raw_headers: Optional[Dict[str, str]],
        scope: Dict,
        mcp_servers: Optional[List[str]],
        mcp_auth_header: Optional[str],
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
        oauth2_headers: Optional[Dict[str, str]],
        client_ip: Optional[str],
    ) -> Dict[str, str]:
        """
        Build debug headers if debug mode is enabled, otherwise return empty dict.

        This is the single entry point called from the MCP request handler.
        """
        if not raw_headers or not MCPDebug.is_debug_enabled(raw_headers):
            return {}
        from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
            MCPRequestHandler,
        )
        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
            global_mcp_server_manager,
        )

        server_url: Optional[str] = None
        server_auth_type: Optional[str] = None
        auth_resolution = "no-auth"
        # Only the first resolvable server is reported in the debug headers.
        for server_name in mcp_servers or []:
            server = global_mcp_server_manager.get_mcp_server_by_name(
                server_name, client_ip=client_ip
            )
            if server:
                server_url = server.url
                server_auth_type = server.auth_type
                auth_resolution = MCPDebug.resolve_auth_resolution(
                    server, mcp_auth_header, mcp_server_auth_headers, oauth2_headers
                )
                break
        scope_headers = MCPRequestHandler._safe_get_headers_from_scope(scope)
        litellm_key = MCPRequestHandler.get_litellm_api_key_from_headers(scope_headers)
        return MCPDebug.build_debug_headers(
            inbound_headers=raw_headers,
            oauth2_headers=oauth2_headers,
            litellm_api_key=litellm_key,
            auth_resolution=auth_resolution,
            server_url=server_url,
            server_auth_type=server_auth_type,
        )

View File

@@ -0,0 +1,170 @@
"""
OAuth2 client_credentials token cache for MCP servers.
Automatically fetches and refreshes access tokens for MCP servers configured
with ``client_id``, ``client_secret``, and ``token_url``.
"""
import asyncio
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
import httpx
from litellm._logging import verbose_logger
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.constants import (
MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL,
MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE,
MCP_OAUTH2_TOKEN_CACHE_MIN_TTL,
MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.custom_http import httpxSpecialProvider
if TYPE_CHECKING:
from litellm.types.mcp_server.mcp_server_manager import MCPServer
class MCPOAuth2TokenCache(InMemoryCache):
    """
    In-memory cache for OAuth2 client_credentials tokens, keyed by server_id.

    Inherits from ``InMemoryCache`` for TTL-based storage and eviction.
    Adds per-server ``asyncio.Lock`` to prevent duplicate concurrent fetches.
    """

    def __init__(self) -> None:
        super().__init__(
            max_size_in_memory=MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE,
            default_ttl=MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL,
        )
        # One lock per server_id; created lazily in _get_lock.
        self._locks: Dict[str, asyncio.Lock] = {}

    def _get_lock(self, server_id: str) -> asyncio.Lock:
        """Return (creating if needed) the per-server fetch lock."""
        return self._locks.setdefault(server_id, asyncio.Lock())

    async def async_get_token(self, server: "MCPServer") -> Optional[str]:
        """Return a valid access token, fetching or refreshing as needed.

        Returns ``None`` when the server lacks client credentials config.
        """
        if not server.has_client_credentials:
            return None
        server_id = server.server_id
        # Fast path — cached token is still valid
        cached = self.get_cache(server_id)
        if cached is not None:
            return cached
        # Slow path — acquire per-server lock then double-check
        # (another coroutine may have populated the cache while we waited).
        async with self._get_lock(server_id):
            cached = self.get_cache(server_id)
            if cached is not None:
                return cached
            token, ttl = await self._fetch_token(server)
            self.set_cache(server_id, token, ttl=ttl)
            return token

    async def _fetch_token(self, server: "MCPServer") -> Tuple[str, int]:
        """POST to ``token_url`` with ``grant_type=client_credentials``.

        Returns ``(access_token, ttl_seconds)`` where ttl accounts for the
        expiry buffer so the cache entry expires before the real token does.

        Raises:
            ValueError: on missing config fields, a non-2xx token response,
                or a malformed token response body.
        """
        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)
        if not server.client_id or not server.client_secret or not server.token_url:
            raise ValueError(
                f"MCP server '{server.server_id}' missing required OAuth2 fields: "
                f"client_id={bool(server.client_id)}, "
                f"client_secret={bool(server.client_secret)}, "
                f"token_url={bool(server.token_url)}"
            )
        data: Dict[str, str] = {
            "grant_type": "client_credentials",
            "client_id": server.client_id,
            "client_secret": server.client_secret,
        }
        if server.scopes:
            # OAuth2 "scope" is a single space-delimited string.
            data["scope"] = " ".join(server.scopes)
        verbose_logger.debug(
            "Fetching OAuth2 client_credentials token for MCP server %s",
            server.server_id,
        )
        try:
            response = await client.post(server.token_url, data=data)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise ValueError(
                f"OAuth2 token request for MCP server '{server.server_id}' "
                f"failed with status {exc.response.status_code}"
            ) from exc
        body = response.json()
        if not isinstance(body, dict):
            raise ValueError(
                f"OAuth2 token response for MCP server '{server.server_id}' "
                f"returned non-object JSON (got {type(body).__name__})"
            )
        access_token = body.get("access_token")
        if not access_token:
            raise ValueError(
                f"OAuth2 token response for MCP server '{server.server_id}' "
                f"missing 'access_token'"
            )
        # Safely parse expires_in — providers may return null or non-numeric values
        raw_expires_in = body.get("expires_in")
        try:
            expires_in = (
                int(raw_expires_in)
                if raw_expires_in is not None
                else MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL
            )
        except (TypeError, ValueError):
            expires_in = MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL
        # Expire the cache entry before the real token does, but never below
        # the minimum TTL floor.
        ttl = max(
            expires_in - MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS,
            MCP_OAUTH2_TOKEN_CACHE_MIN_TTL,
        )
        verbose_logger.info(
            "Fetched OAuth2 token for MCP server %s (expires in %ds)",
            server.server_id,
            expires_in,
        )
        return access_token, ttl

    def invalidate(self, server_id: str) -> None:
        """Remove a cached token (e.g. after a 401)."""
        self.delete_cache(server_id)


# Module-level singleton shared by all MCP server calls in this process.
mcp_oauth2_token_cache = MCPOAuth2TokenCache()
async def resolve_mcp_auth(
    server: "MCPServer",
    mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None,
) -> Optional[Union[str, Dict[str, str]]]:
    """Resolve the auth value for an MCP server.

    Priority:
        1. ``mcp_auth_header`` — per-request/per-user override
        2. OAuth2 client_credentials token — auto-fetched and cached
        3. ``server.authentication_token`` — static token from config/DB
    """
    # A per-request override always wins.
    if mcp_auth_header:
        return mcp_auth_header
    # No client credentials configured: fall back to the static token
    # (which may itself be None).
    if not server.has_client_credentials:
        return server.authentication_token
    # Auto-managed OAuth2 M2M token, fetched/refreshed via the cache.
    return await mcp_oauth2_token_cache.async_get_token(server)

View File

@@ -0,0 +1,435 @@
"""
This module is used to generate MCP tools from OpenAPI specs.
"""
import asyncio
import contextvars
import json
import os
from pathlib import PurePosixPath
from typing import Any, Dict, List, Optional
from urllib.parse import quote
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._experimental.mcp_server.tool_registry import (
global_mcp_tool_registry,
)
# Store the base URL and headers globally
BASE_URL = ""
HEADERS: Dict[str, str] = {}
# Per-request auth header override for BYOK servers.
# Set this ContextVar before calling a local tool handler to inject the user's
# stored credential into the HTTP request made by the tool function closure.
_request_auth_header: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"_request_auth_header", default=None
)
def _sanitize_path_parameter_value(param_value: Any, param_name: str) -> str:
"""Ensure path params cannot introduce directory traversal."""
if param_value is None:
return ""
value_str = str(param_value)
if value_str == "":
return ""
normalized_value = value_str.replace("\\", "/")
if "/" in normalized_value:
raise ValueError(
f"Path parameter '{param_name}' must not contain path separators"
)
if any(part in {".", ".."} for part in PurePosixPath(normalized_value).parts):
raise ValueError(
f"Path parameter '{param_name}' cannot include '.' or '..' segments"
)
return quote(value_str, safe="")
def load_openapi_spec(filepath: str) -> Dict[str, Any]:
    """
    Sync wrapper around :func:`load_openapi_spec_async`.

    Raises:
        RuntimeError: if called from within a running event loop — callers in
            async code must use ``await load_openapi_spec_async(...)``.
    """
    # Previous version raised its own RuntimeError inside the try block and
    # then filtered the except branch by message substring — fragile and easy
    # to break. Probe for a running loop first, then act.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop: safe to drive the async loader ourselves.
        return asyncio.run(load_openapi_spec_async(filepath))
    raise RuntimeError(
        "load_openapi_spec() was called from within a running event loop. "
        "Use 'await load_openapi_spec_async(...)' instead."
    )


async def load_openapi_spec_async(filepath: str) -> Dict[str, Any]:
    """
    Load an OpenAPI spec from an http(s) URL or a local JSON file.

    Args:
        filepath: http/https URL or filesystem path to the spec file.

    Returns:
        The parsed spec as a dict.

    Raises:
        FileNotFoundError: when *filepath* is a local path that does not exist.
    """
    if filepath.startswith(("http://", "https://")):
        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)
        # NOTE: do not close shared client if get_async_httpx_client returns a shared singleton.
        # If it returns a new client each time, consider wrapping it in an async context manager.
        response = await client.get(filepath)
        response.raise_for_status()
        return response.json()
    # Local filesystem path
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"OpenAPI spec not found at {filepath}")
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
def get_base_url(spec: Dict[str, Any], spec_path: Optional[str] = None) -> str:
    """Extract the API base URL from an OpenAPI/Swagger spec.

    Resolution order:
        1. OpenAPI 3.x ``servers[0].url`` — a relative URL is resolved
           against the domain of *spec_path* when it is an http(s) URL.
        2. Swagger 2.x ``schemes``/``host``/``basePath``.
        3. Derived from *spec_path* by stripping a known spec-file suffix.
        4. Empty string when nothing can be determined.
    """
    servers = spec.get("servers")
    if servers:
        # OpenAPI 3.x
        server_url = servers[0]["url"]
        if server_url.startswith("/") and spec_path:
            # Relative server URL: rebuild an absolute one from spec_path.
            if spec_path.startswith(("http://", "https://")):
                from urllib.parse import urlparse

                parsed = urlparse(spec_path)
                full_base_url = f"{parsed.scheme}://{parsed.netloc}" + server_url
                verbose_logger.info(
                    f"OpenAPI spec has relative server URL '{server_url}'. "
                    f"Deriving base from spec_path: {full_base_url}"
                )
                return full_base_url
        return server_url
    if "host" in spec:
        # OpenAPI 2.x (Swagger)
        scheme = spec.get("schemes", ["https"])[0]
        return f"{scheme}://{spec['host']}{spec.get('basePath', '')}"
    # Fallback: derive base URL from spec_path if it's a URL
    if spec_path and spec_path.startswith(("http://", "https://")):
        for suffix in (
            "/openapi.json",
            "/openapi.yaml",
            "/swagger.json",
            "/swagger.yaml",
        ):
            if spec_path.endswith(suffix):
                base_url = spec_path[: -len(suffix)]
                verbose_logger.info(
                    f"No server info in OpenAPI spec. Using derived base URL: {base_url}"
                )
                return base_url
        # Any other *.json/*.yaml/*.yml filename: strip the last path segment.
        if spec_path.rsplit("/", 1)[-1].endswith((".json", ".yaml", ".yml")):
            base_url = spec_path.rsplit("/", 1)[0]
            verbose_logger.info(
                f"No server info in OpenAPI spec. Using derived base URL: {base_url}"
            )
            return base_url
    return ""
def _resolve_ref(
param: Dict[str, Any], component_params: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""Resolve a single parameter, following a $ref if present.
Returns the resolved param dict, or None if the $ref target is absent from
components (so callers can skip/filter it rather than propagating a stub
with name=None that would corrupt deduplication).
"""
ref = param.get("$ref", "")
if not ref.startswith("#/components/parameters/"):
return param
return component_params.get(ref.split("/")[-1])
def _resolve_param_list(
    raw: List[Dict[str, Any]], component_params: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Resolve $refs in *raw*, keeping only named, resolvable parameters."""
    return [
        resolved
        for entry in raw
        # Unresolvable refs and name-less params are filtered out entirely.
        if (resolved := _resolve_ref(entry, component_params)) is not None
        and resolved.get("name")
    ]
def resolve_operation_params(
    operation: Dict[str, Any],
    path_item: Dict[str, Any],
    components: Dict[str, Any],
) -> Dict[str, Any]:
    """Return a copy of *operation* whose parameters are resolved and merged.

    Two real-world OpenAPI patterns are handled:

    1. **$ref parameters** - ``{"$ref": "#/components/parameters/per-page"}``
       entries are resolved against ``components["parameters"]``; dangling
       refs are silently dropped so they cannot pollute deduplication with
       ``(None, None)`` keys.
    2. **Path-level parameters** - params declared on the path item (e.g.
       ``owner``, ``repo``) apply to every method on that path. They are
       merged with the operation's own params; the operation-level entry wins
       when the same ``name`` + ``in`` pair appears in both.
    """
    registry = components.get("parameters", {})
    shared = _resolve_param_list(path_item.get("parameters", []), registry)
    own = _resolve_param_list(operation.get("parameters", []), registry)
    # Identity of a parameter is its (name, location) pair.
    claimed = {(p["name"], p.get("in")) for p in own}
    inherited = [p for p in shared if (p["name"], p.get("in")) not in claimed]
    resolved_op = dict(operation)
    resolved_op["parameters"] = inherited + own
    return resolved_op
def extract_parameters(operation: Dict[str, Any]) -> tuple:
    """Collect parameter names from an OpenAPI operation, grouped by location.

    Returns a ``(path_params, query_params, body_params)`` tuple of name
    lists. An OpenAPI 3.x ``requestBody`` is represented by the synthetic
    name ``"body"`` appended to the body list.
    """
    by_location: Dict[str, List[str]] = {"path": [], "query": [], "body": []}
    for param in operation.get("parameters", []):
        # Entries without a name cannot be mapped to a tool argument.
        if "name" not in param:
            continue
        bucket = by_location.get(param.get("in"))
        if bucket is not None:
            bucket.append(param["name"])
    # OpenAPI 3.x request bodies are surfaced as a single "body" parameter.
    if "requestBody" in operation:
        by_location["body"].append("body")
    return by_location["path"], by_location["query"], by_location["body"]
def build_input_schema(operation: Dict[str, Any]) -> Dict[str, Any]:
    """Build an MCP tool input schema from an OpenAPI operation.

    Each named parameter becomes a top-level property (type defaults to
    ``"string"`` when the spec omits one); a JSON ``requestBody`` is exposed
    as a single ``body`` object property. Required parameters and a required
    request body are listed under ``"required"``.
    """
    schema_props: Dict[str, Any] = {}
    required_names: List[str] = []
    for param in operation.get("parameters", []):
        if "name" not in param:
            continue  # name-less entries cannot become properties
        declared = param.get("schema", {})
        schema_props[param["name"]] = {
            "type": declared.get("type", "string"),
            "description": param.get("description", ""),
        }
        if param.get("required", False):
            required_names.append(param["name"])
    if "requestBody" in operation:
        body_spec = operation["requestBody"]
        content = body_spec.get("content", {})
        # Only JSON bodies are modeled; other content types are ignored here.
        if "application/json" in content:
            schema_props["body"] = {
                "type": "object",
                "description": body_spec.get("description", "Request body"),
                "properties": content["application/json"]
                .get("schema", {})
                .get("properties", {}),
            }
        if body_spec.get("required", False):
            required_names.append("body")
    return {
        "type": "object",
        "properties": schema_props,
        "required": required_names,
    }
def create_tool_function(
    path: str,
    method: str,
    operation: Dict[str, Any],
    base_url: str,
    headers: Optional[Dict[str, str]] = None,
):
    """Create a tool function for an OpenAPI operation.
    This function creates an async tool function that can be called with
    keyword arguments. Parameter names from the OpenAPI spec are accessed
    directly via **kwargs, avoiding syntax errors from invalid Python identifiers.
    Args:
        path: API endpoint path
        method: HTTP method (get, post, put, delete, patch)
        operation: OpenAPI operation object
        base_url: Base URL for the API
        headers: Optional headers to include in requests (e.g., authentication)
    Returns:
        An async function that accepts **kwargs and makes the HTTP request
    """
    # Default to an empty header map so the closure can always copy it safely.
    if headers is None:
        headers = {}
    # Captured once at factory time; the returned closure reuses these on every call.
    path_params, query_params, body_params = extract_parameters(operation)
    original_method = method.lower()
    async def tool_function(**kwargs: Any) -> str:
        """
        Dynamically generated tool function.
        Accepts keyword arguments where keys are the original OpenAPI parameter names.
        The function safely handles parameter names that aren't valid Python identifiers
        by using **kwargs instead of named parameters.
        """
        # Allow per-request auth override (e.g. BYOK credential set via ContextVar).
        # The ContextVar holds the full Authorization header value, including the
        # correct prefix (Bearer / ApiKey / Basic) formatted by the caller in
        # server.py based on the server's configured auth_type.
        effective_headers = dict(headers)
        override_auth = _request_auth_header.get()
        if override_auth:
            effective_headers["Authorization"] = override_auth
        # Build URL from base_url and path
        url = base_url + path
        # Replace path parameters using original names from OpenAPI spec
        # Apply path traversal validation and URL encoding
        for param_name in path_params:
            param_value = kwargs.get(param_name, "")
            if param_value:
                try:
                    # Sanitize and encode path parameter to prevent traversal attacks
                    safe_value = _sanitize_path_parameter_value(param_value, param_name)
                except ValueError as exc:
                    # Surface sanitization failures as the tool's string result
                    # rather than raising, so the MCP caller sees the message.
                    return "Invalid path parameter: " + str(exc)
                # Replace {param_name} or {{param_name}} in URL
                url = url.replace("{" + param_name + "}", safe_value)
                url = url.replace("{{" + param_name + "}}", safe_value)
        # Build query params using original parameter names
        params: Dict[str, Any] = {}
        for param_name in query_params:
            param_value = kwargs.get(param_name, "")
            if param_value:
                # Use original parameter name in query string (as expected by API)
                params[param_name] = param_value
        # Build request body
        json_body: Optional[Dict[str, Any]] = None
        if body_params:
            # Try "body" first (most common), then check all body param names
            body_value = kwargs.get("body", {})
            if not body_value:
                for param_name in body_params:
                    body_value = kwargs.get(param_name, {})
                    if body_value:
                        break
            if isinstance(body_value, dict):
                json_body = body_value
            elif body_value:
                # If it's a string, try to parse as JSON
                try:
                    json_body = (
                        json.loads(body_value)
                        if isinstance(body_value, str)
                        else {"data": body_value}
                    )
                except (json.JSONDecodeError, TypeError):
                    # Non-JSON strings are wrapped so the request still carries them.
                    json_body = {"data": body_value}
        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)
        # Dispatch on the HTTP verb; GET/DELETE carry no JSON body.
        if original_method == "get":
            response = await client.get(url, params=params, headers=effective_headers)
        elif original_method == "post":
            response = await client.post(
                url, params=params, json=json_body, headers=effective_headers
            )
        elif original_method == "put":
            response = await client.put(
                url, params=params, json=json_body, headers=effective_headers
            )
        elif original_method == "delete":
            response = await client.delete(
                url, params=params, headers=effective_headers
            )
        elif original_method == "patch":
            response = await client.patch(
                url, params=params, json=json_body, headers=effective_headers
            )
        else:
            return f"Unsupported HTTP method: {original_method}"
        # Raw response text is returned; no status-code checking is done here.
        return response.text
    return tool_function
def register_tools_from_openapi(spec: Dict[str, Any], base_url: str):
    """Register one MCP tool per HTTP operation found in an OpenAPI spec.

    Tool names come from ``operationId`` (falling back to ``<method>_<path>``),
    lower-cased with spaces replaced by underscores. Each tool is registered
    with the module-level ``global_mcp_tool_registry``.
    """
    for route, path_item in spec.get("paths", {}).items():
        for verb in ("get", "post", "put", "delete", "patch"):
            if verb not in path_item:
                continue
            operation = path_item[verb]
            # Fall back to a synthetic id when the spec omits operationId.
            fallback_id = f"{verb}_{route.replace('/', '_')}"
            tool_name = (
                operation.get("operationId", fallback_id).replace(" ", "_").lower()
            )
            description = operation.get(
                "summary", operation.get("description", f"{verb.upper()} {route}")
            )
            handler = create_tool_function(route, verb, operation, base_url)
            handler.__name__ = tool_name
            handler.__doc__ = description
            global_mcp_tool_registry.register_tool(
                name=tool_name,
                description=description,
                input_schema=build_input_schema(operation),
                handler=handler,
            )
            verbose_logger.debug(f"Registered tool: {tool_name}")

View File

@@ -0,0 +1,256 @@
"""
Semantic MCP Tool Filtering using semantic-router
Filters MCP tools semantically for /chat/completions and /responses endpoints.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from litellm._logging import verbose_logger
if TYPE_CHECKING:
from semantic_router.routers import SemanticRouter
from litellm.router import Router
class SemanticMCPToolFilter:
    """Filters MCP tools using semantic similarity to reduce context window size."""
    def __init__(
        self,
        embedding_model: str,
        litellm_router_instance: "Router",
        top_k: int = 10,
        similarity_threshold: float = 0.3,
        enabled: bool = True,
    ):
        """
        Initialize the semantic tool filter.
        Args:
            embedding_model: Model to use for embeddings (e.g., "text-embedding-3-small")
            litellm_router_instance: Router instance for embedding generation
            top_k: Maximum number of tools to return
            similarity_threshold: Minimum similarity score for filtering
            enabled: Whether filtering is enabled
        """
        self.enabled = enabled
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        self.router_instance = litellm_router_instance
        # Built lazily by build_router_from_mcp_registry() / _build_router();
        # None means "not built yet" or "build failed".
        self.tool_router: Optional["SemanticRouter"] = None
        self._tool_map: Dict[str, Any] = {}  # MCPTool objects or OpenAI function dicts
    async def build_router_from_mcp_registry(self) -> None:
        """Build semantic router from all MCP tools in the registry (no auth checks)."""
        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
            global_mcp_server_manager,
        )
        try:
            # Get all servers from registry without auth checks
            registry = global_mcp_server_manager.get_registry()
            if not registry:
                verbose_logger.warning("MCP registry is empty")
                self.tool_router = None
                return
            # Fetch tools from all servers in parallel
            # NOTE(review): fetches are awaited sequentially in this loop,
            # despite the comment above - confirm whether parallelism is intended.
            all_tools = []
            for server_id, server in registry.items():
                try:
                    tools = await global_mcp_server_manager.get_tools_for_server(
                        server_id
                    )
                    all_tools.extend(tools)
                except Exception as e:
                    # A single unreachable server should not abort the whole build.
                    verbose_logger.warning(
                        f"Failed to fetch tools from server {server_id}: {e}"
                    )
                    continue
            if not all_tools:
                verbose_logger.warning("No MCP tools found in registry")
                self.tool_router = None
                return
            verbose_logger.info(
                f"Fetched {len(all_tools)} tools from {len(registry)} MCP servers"
            )
            self._build_router(all_tools)
        except Exception as e:
            verbose_logger.error(f"Failed to build router from MCP registry: {e}")
            self.tool_router = None
            raise
    def _extract_tool_info(self, tool) -> tuple[str, str]:
        """Extract name and description from MCP tool or OpenAI function dict."""
        name: str
        description: str
        if isinstance(tool, dict):
            # OpenAI function format
            name = tool.get("name", "")
            description = tool.get("description", name)
        else:
            # MCPTool object
            name = str(tool.name)
            description = str(tool.description) if tool.description else str(tool.name)
        return name, description
    def _build_router(self, tools: List) -> None:
        """Build semantic router with tools (MCPTool objects or OpenAI function dicts)."""
        from semantic_router.routers import SemanticRouter
        from semantic_router.routers.base import Route
        from litellm.router_strategy.auto_router.litellm_encoder import (
            LiteLLMRouterEncoder,
        )
        if not tools:
            self.tool_router = None
            return
        try:
            # Convert tools to routes
            # Each tool becomes one Route whose sole utterance is its description.
            routes = []
            self._tool_map = {}
            for tool in tools:
                name, description = self._extract_tool_info(tool)
                self._tool_map[name] = tool
                routes.append(
                    Route(
                        name=name,
                        description=description,
                        utterances=[description],
                        score_threshold=self.similarity_threshold,
                    )
                )
            self.tool_router = SemanticRouter(
                routes=routes,
                encoder=LiteLLMRouterEncoder(
                    litellm_router_instance=self.router_instance,
                    model_name=self.embedding_model,
                    score_threshold=self.similarity_threshold,
                ),
                auto_sync="local",
            )
            verbose_logger.info(f"Built semantic router with {len(routes)} tools")
        except Exception as e:
            verbose_logger.error(f"Failed to build semantic router: {e}")
            self.tool_router = None
            raise
    async def filter_tools(
        self,
        query: str,
        available_tools: List[Any],
        top_k: Optional[int] = None,
    ) -> List[Any]:
        """
        Filter tools semantically based on query.
        Args:
            query: User query to match against tools
            available_tools: Full list of available MCP tools
            top_k: Override default top_k (optional)
        Returns:
            Filtered and ordered list of tools (up to top_k)
        """
        # Early returns for cases where we can't/shouldn't filter.
        # All fallback paths return the unfiltered tool list - filtering
        # failures never hide tools from the caller.
        if not self.enabled:
            return available_tools
        if not available_tools:
            return available_tools
        if not query or not query.strip():
            return available_tools
        # Router should be built on startup - if not, something went wrong
        if self.tool_router is None:
            verbose_logger.warning(
                "Router not initialized - was build_router_from_mcp_registry() called on startup?"
            )
            return available_tools
        # Run semantic filtering
        try:
            limit = top_k or self.top_k
            matches = self.tool_router(text=query, limit=limit)
            matched_tool_names = self._extract_tool_names_from_matches(matches)
            if not matched_tool_names:
                return available_tools
            return self._get_tools_by_names(matched_tool_names, available_tools)
        except Exception as e:
            verbose_logger.error(f"Semantic tool filter failed: {e}", exc_info=True)
            return available_tools
    def _extract_tool_names_from_matches(self, matches) -> List[str]:
        """Extract tool names from semantic router match results."""
        if not matches:
            return []
        # Handle single match
        if hasattr(matches, "name") and matches.name:
            return [matches.name]
        # Handle list of matches
        if isinstance(matches, list):
            return [m.name for m in matches if hasattr(m, "name") and m.name]
        return []
    def _get_tools_by_names(
        self, tool_names: List[str], available_tools: List[Any]
    ) -> List[Any]:
        """Get tools from available_tools by their names, preserving order."""
        # Match tools from available_tools (preserves format - dict or MCPTool)
        matched_tools = []
        for tool in available_tools:
            tool_name, _ = self._extract_tool_info(tool)
            if tool_name in tool_names:
                matched_tools.append(tool)
        # Reorder to match semantic router's ordering
        tool_map = {self._extract_tool_info(t)[0]: t for t in matched_tools}
        return [tool_map[name] for name in tool_names if name in tool_map]
    def extract_user_query(self, messages: List[Dict[str, Any]]) -> str:
        """
        Extract user query from messages for /chat/completions or /responses.
        Args:
            messages: List of message dictionaries (from 'messages' or 'input' field)
        Returns:
            Extracted query string
        """
        # Walk backwards so the *most recent* user message wins.
        for msg in reversed(messages):
            if msg.get("role") == "user":
                content = msg.get("content", "")
                if isinstance(content, str):
                    return content
                if isinstance(content, list):
                    # Multimodal content: join the text of every text-bearing block.
                    texts = [
                        block.get("text", "") if isinstance(block, dict) else str(block)
                        for block in content
                        if isinstance(block, (dict, str))
                    ]
                    return " ".join(texts)
        return ""

View File

@@ -0,0 +1,150 @@
"""
This is a modification of code from: https://github.com/SecretiveShell/MCP-Bridge/blob/master/mcp_bridge/mcp_server/sse_transport.py
Credit to the maintainers of SecretiveShell for their SSE Transport implementation
"""
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import quote
from uuid import UUID, uuid4
import anyio
import mcp.types as types
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from fastapi.requests import Request
from fastapi.responses import Response
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.types import Receive, Scope, Send
from litellm._logging import verbose_logger
class SseServerTransport:
    """
    SSE server transport for MCP. This class provides _two_ ASGI applications,
    suitable to be used with a framework like Starlette and a server like Hypercorn:
    1. connect_sse() is an ASGI application which receives incoming GET requests,
    and sets up a new SSE stream to send server messages to the client.
    2. handle_post_message() is an ASGI application which receives incoming POST
    requests, which should contain client messages that link to a
    previously-established SSE session.
    """
    # Relative/absolute URL the client is told to POST its messages to.
    _endpoint: str
    # session id -> writer that feeds client messages into that session's
    # read stream. NOTE(review): entries are added in connect_sse() but never
    # removed in the code visible here - confirm cleanup happens elsewhere.
    _read_stream_writers: dict[
        UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
    ]
    def __init__(self, endpoint: str) -> None:
        """
        Creates a new SSE server transport, which will direct the client to POST
        messages to the relative or absolute URL given.
        """
        super().__init__()
        self._endpoint = endpoint
        self._read_stream_writers = {}
        verbose_logger.debug(
            f"SseServerTransport initialized with endpoint: {endpoint}"
        )
    @asynccontextmanager
    async def connect_sse(self, request: Request):
        # Establish a new SSE session: create the paired in-memory streams,
        # register the session, and yield (read, write) streams to the caller.
        if request.scope["type"] != "http":
            verbose_logger.error("connect_sse received non-HTTP request")
            raise ValueError("connect_sse can only handle HTTP requests")
        verbose_logger.debug("Setting up SSE connection")
        read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
        read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
        write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
        write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
        # Buffer size 0: every send rendezvouses with a receive (backpressure).
        read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
        write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
        session_id = uuid4()
        session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
        self._read_stream_writers[session_id] = read_stream_writer
        verbose_logger.debug(f"Created new session with ID: {session_id}")
        sse_stream_writer: MemoryObjectSendStream[dict[str, Any]]
        sse_stream_reader: MemoryObjectReceiveStream[dict[str, Any]]
        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(
            0, dict[str, Any]
        )
        async def sse_writer():
            # Pumps server messages from write_stream into SSE events.
            # The first event tells the client where to POST its messages.
            verbose_logger.debug("Starting SSE writer")
            async with sse_stream_writer, write_stream_reader:
                await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
                verbose_logger.debug(f"Sent endpoint event: {session_uri}")
                async for message in write_stream_reader:
                    verbose_logger.debug(f"Sending message via SSE: {message}")
                    await sse_stream_writer.send(
                        {
                            "event": "message",
                            "data": message.model_dump_json(
                                by_alias=True, exclude_none=True
                            ),
                        }
                    )
        async with anyio.create_task_group() as tg:
            response = EventSourceResponse(
                content=sse_stream_reader, data_sender_callable=sse_writer
            )
            verbose_logger.debug("Starting SSE response task")
            # request._send is Starlette-private API; the response is driven as
            # a background task so we can yield the streams to the caller.
            tg.start_soon(response, request.scope, request.receive, request._send)
            verbose_logger.debug("Yielding read and write streams")
            yield (read_stream, write_stream)
    async def handle_post_message(
        self, scope: Scope, receive: Receive, send: Send
    ) -> Response:
        # Validates the session id, parses the JSON-RPC payload, and forwards
        # it into the matching session's read stream. Returns (but does not
        # itself send) the HTTP response.
        verbose_logger.debug("Handling POST message")
        request = Request(scope, receive)
        session_id_param = request.query_params.get("session_id")
        if session_id_param is None:
            verbose_logger.warning("Received request without session_id")
            response = Response("session_id is required", status_code=400)
            return response
        try:
            session_id = UUID(hex=session_id_param)
            verbose_logger.debug(f"Parsed session ID: {session_id}")
        except ValueError:
            verbose_logger.warning(f"Received invalid session ID: {session_id_param}")
            response = Response("Invalid session ID", status_code=400)
            return response
        writer = self._read_stream_writers.get(session_id)
        if not writer:
            verbose_logger.warning(f"Could not find session for ID: {session_id}")
            response = Response("Could not find session", status_code=404)
            return response
        json = await request.json()
        verbose_logger.debug(f"Received JSON: {json}")
        try:
            message = types.JSONRPCMessage.model_validate(json)
            verbose_logger.debug(f"Validated client message: {message}")
        except ValidationError as err:
            verbose_logger.error(f"Failed to parse message: {err}")
            response = Response("Could not parse message", status_code=400)
            # The ValidationError itself is forwarded so the session can react.
            await writer.send(err)
            return response
        verbose_logger.debug(f"Sending message to writer: {message}")
        # 202 Accepted: the message is queued for the session, not yet handled.
        response = Response("Accepted", status_code=202)
        await writer.send(message)
        return response

View File

@@ -0,0 +1,133 @@
import json
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from litellm._logging import verbose_logger
from litellm.proxy.types_utils.utils import get_instance_fn
from litellm.types.mcp_server.tool_registry import MCPTool
if TYPE_CHECKING:
from mcp.types import Tool as MCPToolSDKTool
else:
try:
from mcp.types import Tool as MCPToolSDKTool
except ImportError:
MCPToolSDKTool = None # type: ignore
class MCPToolRegistry:
    """
    A registry for managing MCP tools
    """
    def __init__(self):
        # Registry to store all registered tools
        self.tools: Dict[str, MCPTool] = {}
    def register_tool(
        self,
        name: str,
        description: str,
        input_schema: Dict[str, Any],
        handler: Callable,
    ) -> None:
        """
        Register a new tool in the registry

        Re-registering an existing name silently replaces the previous tool.
        """
        self.tools[name] = MCPTool(
            name=name,
            description=description,
            input_schema=input_schema,
            handler=handler,
        )
        verbose_logger.debug(f"Registered tool: {name}")
    def get_tool(self, name: str) -> Optional[MCPTool]:
        """
        Get a tool from the registry by name

        Returns None when the tool is not registered.
        """
        return self.tools.get(name)
    def list_tools(self, tool_prefix: Optional[str] = None) -> List[MCPTool]:
        """
        List all registered tools

        When tool_prefix is given, only tools whose name starts with that
        prefix are returned.
        """
        if tool_prefix:
            return [
                tool
                for tool in self.tools.values()
                if tool.name.startswith(tool_prefix)
            ]
        return list(self.tools.values())
    def convert_tools_to_mcp_sdk_tool_type(
        self, tools: List[MCPTool]
    ) -> List["MCPToolSDKTool"]:
        # Translates registry tools into the MCP SDK's Tool type; the handler
        # is intentionally dropped (the SDK type only carries metadata).
        if MCPToolSDKTool is None:
            raise ImportError(
                "MCP SDK is not installed. Please install it with: pip install 'litellm[proxy]'"
            )
        return [
            MCPToolSDKTool(
                name=tool.name,
                description=tool.description,
                inputSchema=tool.input_schema,
            )
            for tool in tools
        ]
    def load_tools_from_config(
        self, mcp_tools_config: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Load and register tools from the proxy config
        Args:
            mcp_tools_config: The mcp_tools config from the proxy config

        Raises:
            ValueError: when the config is missing or malformed.
        """
        if mcp_tools_config is None:
            raise ValueError(
                "mcp_tools_config is required, please set `mcp_tools` in your proxy config"
            )
        for tool_config in mcp_tools_config:
            if not isinstance(tool_config, dict):
                raise ValueError("mcp_tools_config must be a list of dictionaries")
            name = tool_config.get("name")
            description = tool_config.get("description")
            input_schema = tool_config.get("input_schema", {})
            handler_name = tool_config.get("handler")
            # Entries missing any required field are skipped, not rejected.
            if not all([name, description, handler_name]):
                continue
            # Try to resolve the handler
            # First check if it's a module path (e.g., "module.submodule.function")
            # NOTE: unreachable after the all() check above; kept for type narrowing.
            if handler_name is None:
                raise ValueError(f"handler is required for tool {name}")
            handler = get_instance_fn(handler_name)
            if handler is None:
                verbose_logger.warning(
                    f"Warning: Could not find handler {handler_name} for tool {name}"
                )
                continue
            # Register the tool
            # NOTE: these None checks are unreachable after the all() check
            # above; they exist to satisfy static type narrowing.
            if name is None:
                raise ValueError(f"name is required for tool {name}")
            if description is None:
                raise ValueError(f"description is required for tool {name}")
            self.register_tool(
                name=name,
                description=description,
                input_schema=input_schema,
                handler=handler,
            )
        verbose_logger.debug(
            "all registered tools: %s", json.dumps(self.tools, indent=4, default=str)
        )
# Module-level singleton shared across the proxy.
global_mcp_tool_registry = MCPToolRegistry()

View File

@@ -0,0 +1,85 @@
"""Helpers to resolve real team contexts for UI session tokens."""
from __future__ import annotations
from typing import List
from litellm._logging import verbose_logger
from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
from litellm.proxy._types import UserAPIKeyAuth
def clone_user_api_key_auth_with_team(
    user_api_key_auth: UserAPIKeyAuth,
    team_id: str,
) -> UserAPIKeyAuth:
    """Return a copy of the auth context rebound to *team_id*.

    The original object is left untouched. Works with both pydantic v2
    (``model_copy``) and v1 (``copy``) auth objects.
    """
    try:
        rebound = user_api_key_auth.model_copy()  # pydantic v2
    except AttributeError:
        rebound = user_api_key_auth.copy()  # type: ignore[attr-defined]  # pydantic v1 fallback
    rebound.team_id = team_id
    return rebound
async def resolve_ui_session_team_ids(
    user_api_key_auth: UserAPIKeyAuth,
) -> List[str]:
    """Resolve the real team ids backing a UI session token.

    Returns an empty list when the token is not a UI session token, when the
    DB is unavailable, when the user lookup fails, or when the user belongs
    to no teams. The returned ids are de-duplicated, preserving order.
    """
    if (
        user_api_key_auth.team_id != UI_SESSION_TOKEN_TEAM_ID
        or not user_api_key_auth.user_id
    ):
        return []
    # Imported lazily to avoid circular imports with the proxy server module.
    from litellm.proxy.auth.auth_checks import get_user_object
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )
    if prisma_client is None:
        verbose_logger.debug("Cannot resolve UI session team ids without DB access")
        return []
    try:
        user_obj = await get_user_object(
            user_id=user_api_key_auth.user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            user_id_upsert=False,
            parent_otel_span=user_api_key_auth.parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
    except Exception as exc:  # pragma: no cover - defensive logging
        # BUG FIX: the message previously had no %s placeholder while `exc`
        # was passed as a positional logging argument, which made the logging
        # machinery raise a string-formatting error instead of recording exc.
        verbose_logger.warning(
            "Failed to load teams for UI session token user: %s", exc
        )
        return []
    if user_obj is None or not user_obj.teams:
        return []
    # De-duplicate while preserving the order the teams appear on the user.
    resolved_team_ids: List[str] = []
    for team_id in user_obj.teams:
        if team_id and team_id not in resolved_team_ids:
            resolved_team_ids.append(team_id)
    return resolved_team_ids
async def build_effective_auth_contexts(
    user_api_key_auth: UserAPIKeyAuth,
) -> List[UserAPIKeyAuth]:
    """Expand a UI session token into one auth context per real team.

    Non-UI-session tokens (or users with no resolvable teams) pass through
    unchanged as a single-element list.
    """
    team_ids = await resolve_ui_session_team_ids(user_api_key_auth)
    if not team_ids:
        return [user_api_key_auth]
    return [
        clone_user_api_key_auth_with_team(user_api_key_auth, tid)
        for tid in team_ids
    ]

View File

@@ -0,0 +1,167 @@
"""
MCP Server Utilities
"""
from typing import Any, Dict, Mapping, Optional, Tuple
import os
import importlib
# Constants
LITELLM_MCP_SERVER_NAME = "litellm-mcp-server"
LITELLM_MCP_SERVER_VERSION = "1.0.0"
LITELLM_MCP_SERVER_DESCRIPTION = "MCP Server for LiteLLM"
# Separator placed between a server prefix and a tool name. Overridable via
# env so deployments can avoid clashes with names that contain "-".
MCP_TOOL_PREFIX_SEPARATOR = os.environ.get("MCP_TOOL_PREFIX_SEPARATOR", "-")
MCP_TOOL_PREFIX_FORMAT = "{server_name}{separator}{tool_name}"
def is_mcp_available() -> bool:
    """Report whether the optional ``mcp`` package can be imported."""
    try:
        importlib.import_module("mcp")
    except ImportError:
        return False
    return True
def normalize_server_name(server_name: str) -> str:
    """Make a server name prefix-safe by turning every space into an underscore."""
    return "_".join(server_name.split(" "))
def validate_and_normalize_mcp_server_payload(payload: Any) -> None:
    """Validate server_name/alias on *payload* and normalize the alias in place.

    Validation rejects either field when it contains the tool-prefix
    separator (raising an HTTPException). The alias then has spaces replaced
    with underscores; when no alias was supplied, the normalized server_name
    becomes the alias.

    Args:
        payload: Object carrying optional ``server_name`` and ``alias`` fields.

    Raises:
        HTTPException: If either field contains the prefix separator.
    """
    server_name = getattr(payload, "server_name", None)
    if server_name:
        validate_mcp_server_name(server_name, raise_http_exception=True)
    alias = getattr(payload, "alias", None)
    if alias:
        validate_mcp_server_name(alias, raise_http_exception=True)
        alias = normalize_server_name(alias)
    elif server_name:
        # Default the alias from the server name when none was provided.
        alias = normalize_server_name(server_name)
    # Write back only when the payload actually exposes an alias field.
    if hasattr(payload, "alias"):
        payload.alias = alias
def add_server_prefix_to_name(name: str, server_name: str) -> str:
    """Prefix an MCP resource *name* with the normalized server name + separator."""
    return MCP_TOOL_PREFIX_FORMAT.format(
        server_name=normalize_server_name(server_name),
        separator=MCP_TOOL_PREFIX_SEPARATOR,
        tool_name=name,
    )
def get_server_prefix(server: Any) -> str:
    """Pick a server's display prefix: alias, else server_name, else server_id."""
    alias = getattr(server, "alias", None)
    if alias:
        return alias
    server_name = getattr(server, "server_name", None)
    if server_name:
        return server_name
    if hasattr(server, "server_id"):
        # server_id is returned whenever the attribute exists, even if falsy,
        # mirroring the attribute-presence check (not a truthiness check).
        return server.server_id
    return ""
def split_server_prefix_from_name(prefixed_name: str) -> Tuple[str, str]:
    """Split ``<server><sep><tool>`` into ``(tool, server)``.

    Unprefixed names come back as ``(name, "")``. Only the first separator
    occurrence splits; the remainder stays part of the tool name.
    """
    prefix, sep, remainder = prefixed_name.partition(MCP_TOOL_PREFIX_SEPARATOR)
    if sep:
        return remainder, prefix
    return prefixed_name, ""
def is_tool_name_prefixed(tool_name: str) -> bool:
    """
    Check if tool name has server prefix

    Args:
        tool_name: Tool name to check

    Returns:
        True if tool name contains the prefix separator, False otherwise
    """
    return tool_name.find(MCP_TOOL_PREFIX_SEPARATOR) != -1
def validate_mcp_server_name(
    server_name: str, raise_http_exception: bool = False
) -> None:
    """
    Validate that MCP server name does not contain 'MCP_TOOL_PREFIX_SEPARATOR'.

    Args:
        server_name: The server name to validate
        raise_http_exception: If True, raises HTTPException (400) instead of
            a generic Exception

    Raises:
        Exception or HTTPException: If server name contains 'MCP_TOOL_PREFIX_SEPARATOR'
    """
    if server_name and MCP_TOOL_PREFIX_SEPARATOR in server_name:
        # FIX: the message previously ran two sentences together
        # ("...instead Found: ..."), making the user-facing error hard to read.
        error_message = (
            f"Server name cannot contain '{MCP_TOOL_PREFIX_SEPARATOR}'. "
            f"Use an alternative character instead. Found: {server_name}"
        )
        if raise_http_exception:
            # Imported lazily so this module stays usable without FastAPI.
            from fastapi import HTTPException
            from starlette import status

            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"error": error_message},
            )
        raise Exception(error_message)
def merge_mcp_headers(
    *,
    extra_headers: Optional[Mapping[str, str]] = None,
    static_headers: Optional[Mapping[str, str]] = None,
) -> Optional[Dict[str, str]]:
    """Combine outbound HTTP headers for MCP calls.

    Used when calling out to external MCP servers (or OpenAPI-based MCP
    tools). ``extra_headers`` (typically OAuth2-derived) are applied first,
    then overlaid with ``static_headers`` (user-configured per server) - so
    on a key conflict the static header wins, matching MCPServerManager's
    ordering. Returns None when there is nothing to send.
    """
    combined: Dict[str, str] = {}
    # Later layers override earlier ones; keys/values are coerced to str.
    for layer in (extra_headers, static_headers):
        if not layer:
            continue
        for key, value in layer.items():
            combined[str(key)] = str(value)
    return combined if combined else None

View File

@@ -0,0 +1,4 @@
def my_custom_rule(input):  # receives the model response
    """Fail the rule (trigger fallback) when the model response is too short.

    Returns:
        False when the response has fewer than 5 characters, True otherwise.
    """
    # BUG FIX: the length check had been commented out, leaving an
    # unconditional `return False` with an unreachable `return True` -
    # the rule rejected every response regardless of length.
    if len(input) < 5:  # trigger fallback if the model response is too short
        return False
    return True

View File

@@ -0,0 +1,40 @@
### DEPRECATED ###
## unused file. initially written for json logging on proxy.
import json
import logging
import os
from logging import Formatter
from litellm import json_logs
# Set default log level to INFO
log_level = os.getenv("LITELLM_LOG", "INFO")
# BUG FIX: getattr(logging, "INFO") etc. returns the *numeric* level (an int),
# so the previous `numeric_level: str` annotation was wrong.
# NOTE(review): an invalid LITELLM_LOG value raises AttributeError here - TODO
# confirm whether falling back to logging.INFO would be preferable.
numeric_level: int = getattr(logging, log_level.upper())
class JsonFormatter(Formatter):
    """Log formatter that renders every record as one JSON object.

    Emitted keys: ``message`` (fully interpolated), ``level`` (name), and
    ``timestamp`` (formatted via the formatter's ``datefmt``).
    """

    def __init__(self):
        super().__init__()

    def format(self, record):
        """Serialize *record* to a JSON string."""
        payload = {
            "message": record.getMessage(),
            "level": record.levelname,
            "timestamp": self.formatTime(record, self.datefmt),
        }
        return json.dumps(payload)
# Configure the *root* logger with a single stream handler; the formatter
# depends on whether JSON logging is enabled via litellm.json_logs.
logger = logging.root
handler = logging.StreamHandler()
if json_logs:
    handler.setFormatter(JsonFormatter())
else:
    # ANSI-colored, human-readable format for interactive use.
    formatter = logging.Formatter(
        "\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    handler.setFormatter(formatter)
# Replace (not append to) any existing handlers to avoid duplicate output.
logger.handlers = [handler]
logger.setLevel(numeric_level)

View File

@@ -0,0 +1,14 @@
model_list:
- model_name: bedrock-claude
litellm_params:
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
aws_region_name: us-east-1
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
litellm_settings:
callbacks: ["datadog"] # logs llm success + failure logs on datadog
service_callback: ["datadog"] # logs redis, postgres failures on datadog
general_settings:
store_prompts_in_spend_logs: true

View File

@@ -0,0 +1,41 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
- model_name: gpt-4o
litellm_params:
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
- model_name: claude-sonnet-4-5-20250929
litellm_params:
model: anthropic/claude-sonnet-4-5-20250929
- model_name: gpt-4.1-mini
litellm_params:
model: openai/gpt-4.1-mini
- model_name: gpt-5-mini
litellm_params:
model: openai/gpt-5-mini
- model_name: custom_litellm_model
litellm_params:
model: litellm_agent/claude-sonnet-4-5-20250929
litellm_system_prompt: "Be a helpful assistant."
guardrails:
- guardrail_name: "tool_policy"
litellm_params:
guardrail: tool_policy
mode: [pre_call, post_call]
default_on: true
mcp_servers:
my_http_server:
url: "http://0.0.0.0:8001/mcp"
transport: "http"
description: "My custom MCP server"
available_on_public_internet: true
general_settings:
store_model_in_db: true
store_prompts_in_spend_logs: true

View File

@@ -0,0 +1,110 @@
model_list:
- model_name: claude-3-5-sonnet
litellm_params:
model: claude-3-haiku-20240307
# - model_name: gemini-1.5-flash-gemini
# litellm_params:
# model: vertex_ai_beta/gemini-1.5-flash
# api_base: https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash
- litellm_params:
api_base: http://0.0.0.0:8080
api_key: ''
model: gpt-4o
rpm: 800
input_cost_per_token: 300
model_name: gpt-4o
- model_name: llama3-70b-8192
litellm_params:
model: groq/llama3-70b-8192
- model_name: fake-openai-endpoint
litellm_params:
model: predibase/llama-3-8b-instruct
api_key: os.environ/PREDIBASE_API_KEY
tenant_id: os.environ/PREDIBASE_TENANT_ID
max_new_tokens: 256
# - litellm_params:
# api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
# api_key: os.environ/AZURE_EUROPE_API_KEY
# model: azure/gpt-35-turbo
# rpm: 10
# model_name: gpt-3.5-turbo-fake-model
- litellm_params:
api_base: https://openai-gpt-4-test-v-1.openai.azure.com
api_key: os.environ/AZURE_API_KEY
api_version: 2024-02-15-preview
model: azure/chatgpt-v-2
tpm: 100
model_name: gpt-3.5-turbo
- litellm_params:
model: anthropic.claude-3-sonnet-20240229-v1:0
model_name: bedrock-anthropic-claude-3
- litellm_params:
model: claude-3-haiku-20240307
model_name: anthropic-claude-3
- litellm_params:
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_key: os.environ/AZURE_API_KEY
api_version: 2024-02-15-preview
model: azure/chatgpt-v-2
drop_params: True
tpm: 100
model_name: gpt-3.5-turbo
- model_name: tts
litellm_params:
model: openai/tts-1
- model_name: gpt-4-turbo-preview
litellm_params:
api_base: https://openai-france-1234.openai.azure.com
api_key: os.environ/AZURE_FRANCE_API_KEY
api_version: 2024-02-15-preview
model: azure/gpt-turbo
- model_name: text-embedding
litellm_params:
model: textembedding-gecko-multilingual@001
vertex_project: my-project-9d5c
vertex_location: us-central1
- model_name: lbl/command-r-plus
litellm_params:
model: openai/lbl/command-r-plus
api_key: "os.environ/VLLM_API_KEY"
api_base: http://vllm-command:8000/v1
rpm: 1000
input_cost_per_token: 0
output_cost_per_token: 0
model_info:
max_input_tokens: 80920
# litellm_settings:
# callbacks: ["dynamic_rate_limiter"]
# # success_callback: ["langfuse"]
# # failure_callback: ["langfuse"]
# # default_team_settings:
# # - team_id: proj1
# # success_callback: ["langfuse"]
# # langfuse_public_key: os.environ/LANGFUSE_PUBLIC_KEY
# # langfuse_secret: os.environ/LANGFUSE_SECRET
# # langfuse_host: https://us.cloud.langfuse.com
# # - team_id: proj2
# # success_callback: ["langfuse"]
# # langfuse_public_key: os.environ/LANGFUSE_PUBLIC_KEY
# # langfuse_secret: os.environ/LANGFUSE_SECRET
# # langfuse_host: https://us.cloud.langfuse.com
assistant_settings:
custom_llm_provider: openai
litellm_params:
api_key: os.environ/OPENAI_API_KEY
router_settings:
enable_pre_call_checks: true
litellm_settings:
callbacks: ["s3"]
# general_settings:
# # alerting: ["slack"]
# enable_jwt_auth: True
# litellm_jwtauth:
# team_id_jwt_field: "client_id"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,500 @@
"""
A2A Protocol endpoints for LiteLLM Proxy.
Allows clients to invoke agents through LiteLLM using the A2A protocol.
The A2A SDK can point to LiteLLM's URL and invoke agents registered with LiteLLM.
"""
import json
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.agent_endpoints.utils import merge_agent_headers
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import all_litellm_params
# Shared APIRouter for all A2A endpoints; mounted by the main proxy app.
router = APIRouter()
def _jsonrpc_error(
    request_id: Optional[str],
    code: int,
    message: str,
    status_code: int = 400,
) -> JSONResponse:
    """Build a JSON-RPC 2.0 error envelope as a JSONResponse.

    ``request_id`` echoes the caller's request id (may be None), ``code`` and
    ``message`` populate the JSON-RPC ``error`` object, and ``status_code``
    sets the HTTP status (default 400).
    """
    envelope = {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {"code": code, "message": message},
    }
    return JSONResponse(content=envelope, status_code=status_code)
def _get_agent(agent_id: str):
    """Resolve an agent by id, falling back to a name lookup; None if unknown."""
    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry

    found = global_agent_registry.get_agent_by_id(agent_id=agent_id)
    if found is not None:
        return found
    # Allow callers to address agents by their human-readable name as well.
    return global_agent_registry.get_agent_by_name(agent_name=agent_id)
def _enforce_inbound_trace_id(agent: Any, request: Request) -> None:
    """Reject the request with HTTP 400 when the agent's litellm_params set
    ``require_trace_id_on_calls_to_agent`` and the inbound call carries no
    x-litellm-trace-id header."""
    config = agent.litellm_params or {}
    if not config.get("require_trace_id_on_calls_to_agent"):
        return
    from litellm.proxy.litellm_pre_call_utils import get_chain_id_from_headers

    if get_chain_id_from_headers(dict(request.headers)):
        return
    raise HTTPException(
        status_code=400,
        detail=(
            f"Agent '{agent.agent_id}' requires x-litellm-trace-id header "
            "on all inbound requests."
        ),
    )
async def _handle_stream_message(
    api_base: Optional[str],
    request_id: str,
    params: dict,
    litellm_params: Optional[dict] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[dict] = None,
    proxy_server_request: Optional[dict] = None,
    *,
    agent_extra_headers: Optional[Dict[str, str]] = None,
    user_api_key_dict: Optional[UserAPIKeyAuth] = None,
    request_data: Optional[dict] = None,
    proxy_logging_obj: Optional[Any] = None,
) -> StreamingResponse:
    """Handle message/stream method via SDK functions.
    When user_api_key_dict, request_data, and proxy_logging_obj are provided,
    uses common_request_processing.async_streaming_data_generator with NDJSON
    serializers so proxy hooks and cost injection apply.

    Returns a StreamingResponse with media type application/x-ndjson: one
    JSON object per line. Errors are emitted as JSON-RPC 2.0 error objects
    (code -32603, "Internal error" per the JSON-RPC spec) rather than raised,
    except HTTPException which is re-raised after failure hooks run.
    """
    from litellm.a2a_protocol import asend_message_streaming
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
    # Without the optional 'a2a' SDK we can only stream a single error line.
    if not A2A_SDK_AVAILABLE:
        async def _error_stream():
            yield json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {
                        "code": -32603,
                        "message": "Server error: 'a2a' package not installed",
                    },
                }
            ) + "\n"
        return StreamingResponse(_error_stream(), media_type="application/x-ndjson")
    from a2a.types import MessageSendParams, SendStreamingMessageRequest
    # Proxy hooks (cost injection, guardrails) apply only when the caller
    # supplies all three proxy context objects.
    use_proxy_hooks = (
        user_api_key_dict is not None
        and request_data is not None
        and proxy_logging_obj is not None
    )
    async def stream_response():
        try:
            a2a_request = SendStreamingMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            a2a_stream = asend_message_streaming(
                request=a2a_request,
                api_base=api_base,
                litellm_params=litellm_params,
                agent_id=agent_id,
                metadata=metadata,
                proxy_server_request=proxy_server_request,
                agent_extra_headers=agent_extra_headers,
            )
            # The repeated None-checks mirror `use_proxy_hooks` so type
            # narrowing holds inside this branch.
            if (
                use_proxy_hooks
                and user_api_key_dict is not None
                and request_data is not None
                and proxy_logging_obj is not None
            ):
                from litellm.proxy.common_request_processing import (
                    ProxyBaseLLMRequestProcessing,
                )
                # Serialize one stream chunk to an NDJSON line; pydantic
                # models are dumped, plain dicts pass through.
                def _ndjson_chunk(chunk: Any) -> str:
                    if hasattr(chunk, "model_dump"):
                        obj = chunk.model_dump(mode="json", exclude_none=True)
                    else:
                        obj = chunk
                    return json.dumps(obj) + "\n"
                # Serialize a proxy-transformed exception to a JSON-RPC
                # error NDJSON line.
                def _ndjson_error(proxy_exc: Any) -> str:
                    return (
                        json.dumps(
                            {
                                "jsonrpc": "2.0",
                                "id": request_id,
                                "error": {
                                    "code": -32603,
                                    "message": getattr(
                                        proxy_exc,
                                        "message",
                                        f"Streaming error: {proxy_exc!s}",
                                    ),
                                },
                            }
                        )
                        + "\n"
                    )
                async for (
                    line
                ) in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
                    response=a2a_stream,
                    user_api_key_dict=user_api_key_dict,
                    request_data=request_data,
                    proxy_logging_obj=proxy_logging_obj,
                    serialize_chunk=_ndjson_chunk,
                    serialize_error=_ndjson_error,
                ):
                    yield line
            else:
                # No proxy context: stream raw chunks as NDJSON directly.
                async for chunk in a2a_stream:
                    if hasattr(chunk, "model_dump"):
                        yield json.dumps(
                            chunk.model_dump(mode="json", exclude_none=True)
                        ) + "\n"
                    else:
                        yield json.dumps(chunk) + "\n"
        except Exception as e:
            verbose_proxy_logger.exception(f"Error streaming A2A response: {e}")
            # Give proxy failure hooks a chance to transform the exception
            # before the final error line is emitted.
            if (
                use_proxy_hooks
                and proxy_logging_obj is not None
                and user_api_key_dict is not None
                and request_data is not None
            ):
                transformed_exception = await proxy_logging_obj.post_call_failure_hook(
                    user_api_key_dict=user_api_key_dict,
                    original_exception=e,
                    request_data=request_data,
                )
                if transformed_exception is not None:
                    e = transformed_exception
            # HTTPExceptions propagate so FastAPI can map them to a status;
            # everything else becomes a terminal JSON-RPC error line.
            if isinstance(e, HTTPException):
                raise
            yield json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {"code": -32603, "message": f"Streaming error: {str(e)}"},
                }
            ) + "\n"
    return StreamingResponse(stream_response(), media_type="application/x-ndjson")
@router.get(
    "/a2a/{agent_id}/.well-known/agent-card.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.get(
    "/a2a/{agent_id}/.well-known/agent.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_agent_card(
    agent_id: str,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get the agent card for an agent (A2A discovery endpoint).
    Supports both standard paths:
    - /.well-known/agent-card.json
    - /.well-known/agent.json
    The URL in the agent card is rewritten to point to the LiteLLM proxy,
    so all subsequent A2A calls go through LiteLLM for logging and cost tracking.

    Raises 404 when the agent is unknown, 403 when the key/team lacks access,
    500 for any other failure.
    """
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )
    try:
        # _get_agent resolves by id first, then by name.
        agent = _get_agent(agent_id)
        if agent is None:
            raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
        # Check agent permission (skip for admin users)
        is_allowed = await AgentRequestHandler.is_agent_allowed(
            agent_id=agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )
        # Copy and rewrite URL to point to LiteLLM proxy
        # (dict() copy so the registry's stored card is never mutated).
        agent_card = dict(agent.agent_card_params)
        agent_card["url"] = f"{str(request.base_url).rstrip('/')}/a2a/{agent_id}"
        verbose_proxy_logger.debug(
            f"Returning agent card for '{agent_id}' with proxy URL: {agent_card['url']}"
        )
        return JSONResponse(content=agent_card)
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting agent card: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post(
    "/a2a/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/v1/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def invoke_agent_a2a(  # noqa: PLR0915
    agent_id: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Invoke an agent using the A2A protocol (JSON-RPC 2.0).
    Supported methods:
    - message/send: Send a message and get a response
    - message/stream: Send a message and stream the response

    Errors are returned as JSON-RPC error envelopes (not raised), except
    permission failures which raise HTTPException(403).
    """
    # NOTE(review): `fastapi_response` is accepted but not used in this body —
    # presumably kept for signature parity with other proxy routes; confirm.
    from litellm.a2a_protocol import asend_message
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )
    from litellm.proxy.proxy_server import (
        general_settings,
        proxy_config,
        proxy_logging_obj,
        version,
    )
    # Initialized before the try so the except handler can read body.get("id")
    # even when request.json() itself fails.
    body = {}
    try:
        body = await request.json()
        verbose_proxy_logger.debug(f"A2A request for agent '{agent_id}': {body}")
        # Validate JSON-RPC format
        if body.get("jsonrpc") != "2.0":
            return _jsonrpc_error(
                body.get("id"), -32600, "Invalid Request: jsonrpc must be '2.0'"
            )
        request_id = body.get("id")
        method = body.get("method")
        params = body.get("params", {})
        if params:
            # extract any litellm params from the params - eg. 'guardrails'
            # They are lifted onto the top-level body (so pre-call logic sees
            # them) and removed from the A2A params passed to the SDK.
            params_to_remove = []
            for key, value in params.items():
                if key in all_litellm_params:
                    params_to_remove.append(key)
                    body[key] = value
            for key in params_to_remove:
                params.pop(key)
        if not A2A_SDK_AVAILABLE:
            return _jsonrpc_error(
                request_id,
                -32603,
                "Server error: 'a2a' package not installed. Please install 'a2a-sdk'.",
                500,
            )
        # Find the agent
        agent = _get_agent(agent_id)
        if agent is None:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' not found", 404
            )
        is_allowed = await AgentRequestHandler.is_agent_allowed(
            agent_id=agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )
        # Raises 400 if the agent mandates an x-litellm-trace-id header.
        _enforce_inbound_trace_id(agent, request)
        # Get backend URL and agent name
        agent_url = agent.agent_card_params.get("url")
        agent_name = agent.agent_card_params.get("name", agent_id)
        # Get litellm_params (may include custom_llm_provider for completion bridge)
        litellm_params = agent.litellm_params or {}
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        # URL is required unless using completion bridge with a provider that derives endpoint from model
        # (e.g., bedrock/agentcore derives endpoint from ARN in model string)
        if not agent_url and not custom_llm_provider:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500
            )
        verbose_proxy_logger.info(
            f"Proxying A2A request to agent '{agent_id}' at {agent_url or 'completion-bridge'}"
        )
        # Set up data dict for litellm processing
        if "metadata" not in body:
            body["metadata"] = {}
        body["metadata"]["agent_id"] = agent.agent_id
        body.update(
            {
                "model": f"a2a_agent/{agent_name}",
                "custom_llm_provider": "a2a_agent",
            }
        )
        # Add litellm data (user_api_key, user_id, team_id, etc.)
        from litellm.proxy.common_request_processing import (
            ProxyBaseLLMRequestProcessing,
        )
        processor = ProxyBaseLLMRequestProcessing(data=body)
        data, logging_obj = await processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="asend_message",
            version=version,
        )
        # Build merged headers for the backend agent
        static_headers: Dict[str, str] = dict(agent.static_headers or {})
        raw_headers = dict(request.headers)
        # HTTP header names are case-insensitive; normalize for lookup.
        normalized = {k.lower(): v for k, v in raw_headers.items()}
        dynamic_headers: Dict[str, str] = {}
        # 1. Admin-configured extra_headers: forward named headers from client request
        if agent.extra_headers:
            for header_name in agent.extra_headers:
                val = normalized.get(header_name.lower())
                if val is not None:
                    dynamic_headers[header_name] = val
        # 2. Convention-based forwarding: x-a2a-{agent_id_or_name}-{header_name}
        #    Matches both agent_id (UUID) and agent_name (alias), case-insensitive.
        for alias in (agent.agent_id.lower(), agent.agent_name.lower()):
            prefix = f"x-a2a-{alias}-"
            for key, val in normalized.items():
                if key.startswith(prefix):
                    header_name = key[len(prefix) :]
                    if header_name:
                        dynamic_headers[header_name] = val
        # static (admin-pinned) headers win over dynamic ones on conflict.
        agent_extra_headers = merge_agent_headers(
            dynamic_headers=dynamic_headers or None,
            static_headers=static_headers or None,
        )
        # Route through SDK functions
        if method == "message/send":
            from a2a.types import MessageSendParams, SendMessageRequest
            a2a_request = SendMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            response = await asend_message(
                request=a2a_request,
                api_base=agent_url,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                litellm_logging_obj=logging_obj,
                agent_extra_headers=agent_extra_headers,
            )
            # post_call hook may rewrite the response (cost headers, guardrails).
            response = await proxy_logging_obj.post_call_success_hook(
                user_api_key_dict=user_api_key_dict,
                data=data,
                response=response,
            )
            return JSONResponse(
                content=(
                    response.model_dump(mode="json", exclude_none=True)  # type: ignore
                    if hasattr(response, "model_dump")
                    else response
                )
            )
        elif method == "message/stream":
            return await _handle_stream_message(
                api_base=agent_url,
                request_id=request_id,
                params=params,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                agent_extra_headers=agent_extra_headers,
                user_api_key_dict=user_api_key_dict,
                request_data=data,
                proxy_logging_obj=proxy_logging_obj,
            )
        else:
            return _jsonrpc_error(request_id, -32601, f"Method '{method}' not found")
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error invoking agent: {e}")
        return _jsonrpc_error(body.get("id"), -32603, f"Internal error: {str(e)}", 500)

View File

@@ -0,0 +1,53 @@
"""
A2A Agent Routing
Handles routing for A2A agents (models with "a2a/<agent-name>" prefix).
Looks up agents in the registry and injects their API base URL.
"""
from typing import Any, Optional
import litellm
from litellm._logging import verbose_proxy_logger
def route_a2a_agent_request(data: dict, route_type: str) -> Optional[Any]:
    """
    Route A2A agent requests directly to litellm with injected API base.
    Returns None if not an A2A request (allows normal routing to continue).

    Args:
        data: the request payload; ``data["api_base"]`` is injected in place
            before dispatch, so callers see the mutation.
        route_type: name of the litellm entrypoint to invoke (e.g.
            "acompletion"); mapped through ROUTE_ENDPOINT_MAPPING for error
            messages.

    Raises:
        ProxyModelNotFoundError: model uses the "a2a/" prefix but the named
            agent is missing from the registry or has no configured URL.
    """
    # Import here to avoid circular imports
    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
    from litellm.proxy.route_llm_request import (
        ROUTE_ENDPOINT_MAPPING,
        ProxyModelNotFoundError,
    )

    model_name = data.get("model", "")
    # Not an A2A agent request -> let normal routing handle it.
    if not isinstance(model_name, str) or not model_name.startswith("a2a/"):
        return None
    # Extract agent name (e.g., "a2a/my-agent" -> "my-agent")
    agent_name = model_name[len("a2a/") :]
    # Look up agent in registry
    agent = global_agent_registry.get_agent_by_name(agent_name)
    if agent is None:
        verbose_proxy_logger.error(f"[A2A] Agent '{agent_name}' not found in registry")
        route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
        raise ProxyModelNotFoundError(route=route_name, model_name=model_name)
    # Agent must carry a backend URL in its card to be invokable here.
    if not agent.agent_card_params or "url" not in agent.agent_card_params:
        verbose_proxy_logger.error(f"[A2A] Agent '{agent_name}' has no URL configured")
        route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
        raise ProxyModelNotFoundError(route=route_name, model_name=model_name)
    # Inject API base and dispatch to the matching litellm entrypoint.
    data["api_base"] = agent.agent_card_params["url"]
    verbose_proxy_logger.debug(f"[A2A] Routing {model_name} to {data['api_base']}")
    return getattr(litellm, route_type)(**data)

View File

@@ -0,0 +1,458 @@
import hashlib
import json
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
import litellm
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.management_helpers.object_permission_utils import (
handle_update_object_permission_common,
)
from litellm.proxy.utils import PrismaClient
from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest
class AgentRegistry:
    """In-memory registry of A2A agents plus Prisma-backed persistence helpers.

    The in-memory list (`agent_list`) is the source of truth for lookups at
    request time; `load_agents_from_db_and_config` rebuilds it from the YAML
    config and/or the LiteLLM_AgentsTable rows. The async ``*_db`` methods
    wrap Prisma CRUD and serialize nested params to JSON strings via
    ``safe_dumps`` before storage.
    """

    def __init__(self):
        # Flat list; lookups are linear scans (agent counts are small).
        self.agent_list: List[AgentResponse] = []
    def reset_agent_list(self):
        """Drop all registered agents."""
        self.agent_list = []
    def register_agent(self, agent_config: AgentResponse):
        """Append an agent; no de-duplication is performed here."""
        self.agent_list.append(agent_config)
    def deregister_agent(self, agent_name: str):
        """Remove every agent whose name matches ``agent_name``."""
        self.agent_list = [
            agent for agent in self.agent_list if agent.agent_name != agent_name
        ]
    def get_agent_list(self, agent_names: Optional[List[str]] = None):
        """Return all agents, or only those whose name is in ``agent_names``."""
        if agent_names is not None:
            return [
                agent for agent in self.agent_list if agent.agent_name in agent_names
            ]
        return self.agent_list
    def get_public_agent_list(self) -> List[AgentResponse]:
        """Return agents whose id appears in ``litellm.public_agent_groups``
        (empty list when public agent groups are not configured)."""
        public_agent_list: List[AgentResponse] = []
        if litellm.public_agent_groups is None:
            return public_agent_list
        for agent in self.agent_list:
            if agent.agent_id in litellm.public_agent_groups:
                public_agent_list.append(agent)
        return public_agent_list
    def _create_agent_id(self, agent_config: AgentConfig) -> str:
        # Stable content-hash id: sort_keys makes the digest independent of
        # dict insertion order, so the same config always maps to one id.
        return hashlib.sha256(
            json.dumps(agent_config, sort_keys=True).encode()
        ).hexdigest()
    def load_agents_from_config(self, agent_config: Optional[List[AgentConfig]] = None):
        """Register agents from a YAML-derived config list; items missing
        ``agent_name`` or ``agent_card_params`` are silently skipped."""
        if agent_config is None:
            return None
        for agent_config_item in agent_config:
            if not isinstance(agent_config_item, dict):
                raise ValueError("agent_config must be a list of dictionaries")
            agent_name = agent_config_item.get("agent_name")
            agent_card_params = agent_config_item.get("agent_card_params")
            if not all([agent_name, agent_card_params]):
                continue
            # create a stable hash id for config item
            config_hash = self._create_agent_id(agent_config_item)
            self.register_agent(agent_config=AgentResponse(agent_id=config_hash, **agent_config_item))  # type: ignore
    def load_agents_from_db_and_config(
        self,
        agent_config: Optional[List[AgentConfig]] = None,
        db_agents: Optional[List[Dict[str, Any]]] = None,
    ):
        """Rebuild the registry from config items and DB rows.

        Resets the list first, so this is idempotent across reloads. Unlike
        ``load_agents_from_config``, config items here are NOT filtered for
        missing name/card fields.
        """
        self.reset_agent_list()
        if agent_config:
            for agent_config_item in agent_config:
                if not isinstance(agent_config_item, dict):
                    raise ValueError("agent_config must be a list of dictionaries")
                self.register_agent(agent_config=AgentResponse(agent_id=self._create_agent_id(agent_config_item), **agent_config_item))  # type: ignore
        if db_agents:
            for db_agent in db_agents:
                if not isinstance(db_agent, dict):
                    raise ValueError("db_agents must be a list of dictionaries")
                self.register_agent(agent_config=AgentResponse(**db_agent))  # type: ignore
        return self.agent_list
    ###########################################################
    ########### DB management helpers for agents ###########
    ############################################################
    async def add_agent_to_db(
        self, agent: AgentConfig, prisma_client: PrismaClient, created_by: str
    ) -> AgentResponse:
        """
        Add an agent to the database.

        Serializes nested params to JSON strings, resolves any
        ``object_permission`` into an object_permission_id row, and returns
        the created row as an AgentResponse. All failures are re-raised as a
        generic Exception with a descriptive message.
        """
        try:
            agent_name = agent.get("agent_name")
            # Serialize litellm_params
            # (accepts either a pydantic model or a plain mapping)
            litellm_params_obj: Any = agent.get("litellm_params", {})
            if hasattr(litellm_params_obj, "model_dump"):
                litellm_params_dict = litellm_params_obj.model_dump()
            else:
                litellm_params_dict = (
                    dict(litellm_params_obj) if litellm_params_obj else {}
                )
            litellm_params: str = safe_dumps(litellm_params_dict)
            # Serialize agent_card_params
            agent_card_params_obj: Any = agent.get("agent_card_params", {})
            if hasattr(agent_card_params_obj, "model_dump"):
                agent_card_params_dict = agent_card_params_obj.model_dump()
            else:
                agent_card_params_dict = (
                    dict(agent_card_params_obj) if agent_card_params_obj else {}
                )
            agent_card_params: str = safe_dumps(agent_card_params_dict)
            # Handle object_permission (MCP tool access for agent)
            object_permission_id: Optional[str] = None
            if agent.get("object_permission") is not None:
                agent_copy = dict(agent)
                object_permission_id = await handle_update_object_permission_common(
                    agent_copy, None, prisma_client
                )
            # Serialize static_headers
            static_headers_obj = agent.get("static_headers")
            static_headers_val: Optional[str] = (
                safe_dumps(dict(static_headers_obj)) if static_headers_obj else None
            )
            extra_headers_val: Optional[List[str]] = agent.get("extra_headers")
            create_data: Dict[str, Any] = {
                "agent_name": agent_name,
                "litellm_params": litellm_params,
                "agent_card_params": agent_card_params,
                "created_by": created_by,
                "updated_by": created_by,
                "created_at": datetime.now(timezone.utc),
                "updated_at": datetime.now(timezone.utc),
            }
            # Optional columns are only written when present so DB defaults apply.
            if static_headers_val is not None:
                create_data["static_headers"] = static_headers_val
            if extra_headers_val is not None:
                create_data["extra_headers"] = extra_headers_val
            if object_permission_id is not None:
                create_data["object_permission_id"] = object_permission_id
            for rate_field in (
                "tpm_limit",
                "rpm_limit",
                "session_tpm_limit",
                "session_rpm_limit",
            ):
                _val = agent.get(rate_field)
                if _val is not None:
                    create_data[rate_field] = _val
            # Create agent in DB
            created_agent = await prisma_client.db.litellm_agentstable.create(
                data=create_data,
                include={"object_permission": True},
            )
            created_agent_dict = created_agent.model_dump()
            if created_agent.object_permission is not None:
                # .dict() fallback supports older pydantic v1-style prisma models
                try:
                    created_agent_dict[
                        "object_permission"
                    ] = created_agent.object_permission.model_dump()
                except Exception:
                    created_agent_dict[
                        "object_permission"
                    ] = created_agent.object_permission.dict()
            return AgentResponse(**created_agent_dict)  # type: ignore
        except Exception as e:
            raise Exception(f"Error adding agent to DB: {str(e)}")
    async def delete_agent_from_db(
        self, agent_id: str, prisma_client: PrismaClient
    ) -> Dict[str, Any]:
        """
        Delete an agent from the database and return the deleted row as a dict.
        """
        try:
            deleted_agent = await prisma_client.db.litellm_agentstable.delete(
                where={"agent_id": agent_id}
            )
            return dict(deleted_agent)
        except Exception as e:
            raise Exception(f"Error deleting agent from DB: {str(e)}")
    async def patch_agent_in_db(
        self,
        agent_id: str,
        agent: PatchAgentRequest,
        prisma_client: PrismaClient,
        updated_by: str,
    ) -> AgentResponse:
        """
        Patch an agent in the database.
        Get the existing agent from the database and patch it with the new values.
        Args:
            agent_id: The ID of the agent to patch
            agent: The new agent values to patch
            prisma_client: The Prisma client to use
            updated_by: The user ID of the user who is patching the agent
        Returns:
            The patched agent

        NOTE(review): ``{**existing_agent, **agent}``, ``rate_field in agent``
        and ``agent.get(...)`` all treat PatchAgentRequest as a mapping —
        assumes it is a TypedDict/dict, not a pydantic model; confirm.
        """
        try:
            existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id}
            )
            if existing_agent is not None:
                existing_agent = dict(existing_agent)
            if existing_agent is None:
                raise Exception(f"Agent with ID {agent_id} not found")
            # Patch semantics: caller-supplied keys override the stored row.
            augment_agent = {**existing_agent, **agent}
            update_data: Dict[str, Any] = {}
            if augment_agent.get("agent_name"):
                update_data["agent_name"] = augment_agent.get("agent_name")
            if augment_agent.get("litellm_params"):
                update_data["litellm_params"] = safe_dumps(
                    augment_agent.get("litellm_params")
                )
            if augment_agent.get("agent_card_params"):
                update_data["agent_card_params"] = safe_dumps(
                    augment_agent.get("agent_card_params")
                )
            # Rate limits use key-presence (not truthiness) so an explicit
            # None/0 in the patch clears or zeroes the stored limit.
            for rate_field in (
                "tpm_limit",
                "rpm_limit",
                "session_tpm_limit",
                "session_rpm_limit",
            ):
                if rate_field in agent:
                    update_data[rate_field] = agent.get(rate_field)
            if "static_headers" in agent:
                headers_value = agent.get("static_headers")
                update_data["static_headers"] = safe_dumps(
                    dict(headers_value) if headers_value is not None else {}
                )
            if "extra_headers" in agent:
                extra_headers_value = agent.get("extra_headers")
                update_data["extra_headers"] = (
                    extra_headers_value if extra_headers_value is not None else []
                )
            if agent.get("object_permission") is not None:
                agent_copy = dict(augment_agent)
                existing_object_permission_id = existing_agent.get(
                    "object_permission_id"
                )
                object_permission_id = await handle_update_object_permission_common(
                    agent_copy,
                    existing_object_permission_id,
                    prisma_client,
                )
                if object_permission_id is not None:
                    update_data["object_permission_id"] = object_permission_id
            # Patch agent in DB
            patched_agent = await prisma_client.db.litellm_agentstable.update(
                where={"agent_id": agent_id},
                data={
                    **update_data,
                    "updated_by": updated_by,
                    "updated_at": datetime.now(timezone.utc),
                },
                include={"object_permission": True},
            )
            patched_agent_dict = patched_agent.model_dump()
            if patched_agent.object_permission is not None:
                try:
                    patched_agent_dict[
                        "object_permission"
                    ] = patched_agent.object_permission.model_dump()
                except Exception:
                    patched_agent_dict[
                        "object_permission"
                    ] = patched_agent.object_permission.dict()
            return AgentResponse(**patched_agent_dict)  # type: ignore
        except Exception as e:
            raise Exception(f"Error patching agent in DB: {str(e)}")
    async def update_agent_in_db(
        self,
        agent_id: str,
        agent: AgentConfig,
        prisma_client: PrismaClient,
        updated_by: str,
    ) -> AgentResponse:
        """
        Update an agent in the database.

        Full-replace semantics (unlike ``patch_agent_in_db``): static_headers
        and extra_headers are always overwritten, defaulting to empty when
        absent from ``agent``.
        """
        try:
            agent_name = agent.get("agent_name")
            # Serialize litellm_params
            litellm_params_obj: Any = agent.get("litellm_params", {})
            if hasattr(litellm_params_obj, "model_dump"):
                litellm_params_dict = litellm_params_obj.model_dump()
            else:
                litellm_params_dict = (
                    dict(litellm_params_obj) if litellm_params_obj else {}
                )
            litellm_params: str = safe_dumps(litellm_params_dict)
            # Serialize agent_card_params
            agent_card_params_obj: Any = agent.get("agent_card_params", {})
            if hasattr(agent_card_params_obj, "model_dump"):
                agent_card_params_dict = agent_card_params_obj.model_dump()
            else:
                agent_card_params_dict = (
                    dict(agent_card_params_obj) if agent_card_params_obj else {}
                )
            agent_card_params: str = safe_dumps(agent_card_params_dict)
            # Serialize static_headers for update
            static_headers_obj_u = agent.get("static_headers")
            static_headers_val_u: str = (
                safe_dumps(dict(static_headers_obj_u))
                if static_headers_obj_u is not None
                else safe_dumps({})
            )
            extra_headers_val_u: List[str] = agent.get("extra_headers") or []
            update_data: Dict[str, Any] = {
                "agent_name": agent_name,
                "litellm_params": litellm_params,
                "agent_card_params": agent_card_params,
                "static_headers": static_headers_val_u,
                "extra_headers": extra_headers_val_u,
                "updated_by": updated_by,
                "updated_at": datetime.now(timezone.utc),
            }
            for rate_field in (
                "tpm_limit",
                "rpm_limit",
                "session_tpm_limit",
                "session_rpm_limit",
            ):
                _val = agent.get(rate_field)
                if _val is not None:
                    update_data[rate_field] = _val
            if agent.get("object_permission") is not None:
                # Existing permission id is needed so the helper updates the
                # current permission row instead of always creating a new one.
                existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
                    where={"agent_id": agent_id}
                )
                existing_object_permission_id = (
                    existing_agent.object_permission_id
                    if existing_agent is not None
                    else None
                )
                agent_copy = dict(agent)
                object_permission_id = await handle_update_object_permission_common(
                    agent_copy,
                    existing_object_permission_id,
                    prisma_client,
                )
                if object_permission_id is not None:
                    update_data["object_permission_id"] = object_permission_id
            # Update agent in DB
            updated_agent = await prisma_client.db.litellm_agentstable.update(
                where={"agent_id": agent_id},
                data=update_data,
                include={"object_permission": True},
            )
            updated_agent_dict = updated_agent.model_dump()
            if updated_agent.object_permission is not None:
                try:
                    updated_agent_dict[
                        "object_permission"
                    ] = updated_agent.object_permission.model_dump()
                except Exception:
                    updated_agent_dict[
                        "object_permission"
                    ] = updated_agent.object_permission.dict()
            return AgentResponse(**updated_agent_dict)  # type: ignore
        except Exception as e:
            raise Exception(f"Error updating agent in DB: {str(e)}")
    @staticmethod
    async def get_all_agents_from_db(
        prisma_client: PrismaClient,
    ) -> List[Dict[str, Any]]:
        """
        Get all agents from the database, newest first, with their
        object_permission relation eagerly loaded and dumped to dicts.
        """
        try:
            agents_from_db = await prisma_client.db.litellm_agentstable.find_many(
                order={"created_at": "desc"},
                include={"object_permission": True},
            )
            agents: List[Dict[str, Any]] = []
            for agent in agents_from_db:
                agent_dict = dict(agent)
                # object_permission is eagerly loaded via include above
                if agent.object_permission is not None:
                    try:
                        agent_dict[
                            "object_permission"
                        ] = agent.object_permission.model_dump()
                    except Exception:
                        agent_dict["object_permission"] = agent.object_permission.dict()
                agents.append(agent_dict)
            return agents
        except Exception as e:
            raise Exception(f"Error getting agents from DB: {str(e)}")
    def get_agent_by_id(
        self,
        agent_id: str,
    ) -> Optional[AgentResponse]:
        """
        Get an agent by its ID from the in-memory registry (not the DB,
        despite the error message); returns None if not found.
        """
        try:
            for agent in self.agent_list:
                if agent.agent_id == agent_id:
                    return agent
            return None
        except Exception as e:
            raise Exception(f"Error getting agent from DB: {str(e)}")
    def get_agent_by_name(self, agent_name: str) -> Optional[AgentResponse]:
        """
        Get an agent by its name from the in-memory registry (not the DB,
        despite the error message); returns None if not found.
        """
        try:
            for agent in self.agent_list:
                if agent.agent_name == agent_name:
                    return agent
            return None
        except Exception as e:
            raise Exception(f"Error getting agent from DB: {str(e)}")
# Module-level singleton: the shared registry of agents (config + DB loaded).
global_agent_registry = AgentRegistry()

View File

@@ -0,0 +1,451 @@
"""
Agent Permission Handler for LiteLLM Proxy.
Handles agent permission checking for keys and teams using object_permission_id.
Follows the same pattern as MCP permission handling.
"""
from typing import List, Optional, Set
from litellm._logging import verbose_logger
from litellm.proxy._types import (
LiteLLM_ObjectPermissionTable,
LiteLLM_TeamTable,
UI_TEAM_ID,
UserAPIKeyAuth,
)
class AgentRequestHandler:
    """
    Class to handle agent permission checking, including:
    1. Key-level agent permissions
    2. Team-level agent permissions
    3. Agent access group resolution
    Follows the same inheritance logic as MCP:
    - If team has restrictions and key has restrictions: use intersection
    - If team has restrictions and key has none: inherit from team
    - If team has no restrictions: use key restrictions
    - If no restrictions: allow all agents
    """

    @staticmethod
    async def get_allowed_agents(
        user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    ) -> List[str]:
        """
        Get list of allowed agent IDs for the given user/key based on permissions.
        Returns:
            List[str]: List of allowed agent IDs. Empty list means no restrictions (allow all).
        """
        try:
            allowed_agents: List[str] = []
            allowed_agents_for_key = (
                await AgentRequestHandler._get_allowed_agents_for_key(user_api_key_auth)
            )
            allowed_agents_for_team = (
                await AgentRequestHandler._get_allowed_agents_for_team(
                    user_api_key_auth
                )
            )
            # If team has agent restrictions, handle inheritance and intersection logic
            if len(allowed_agents_for_team) > 0:
                if len(allowed_agents_for_key) > 0:
                    # Key has its own agent permissions - use intersection with team permissions
                    for agent_id in allowed_agents_for_key:
                        if agent_id in allowed_agents_for_team:
                            allowed_agents.append(agent_id)
                else:
                    # Key has no agent permissions - inherit from team
                    allowed_agents = allowed_agents_for_team
            else:
                # Team unrestricted: the key's own restrictions (if any) apply.
                allowed_agents = allowed_agents_for_key
            return list(set(allowed_agents))
        except Exception as e:
            # Fail open: an empty list means "no restrictions" to callers.
            verbose_logger.warning(f"Failed to get allowed agents: {str(e)}")
            return []

    @staticmethod
    async def is_agent_allowed(
        agent_id: str,
        user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    ) -> bool:
        """
        Check if a specific agent is allowed for the given user/key.
        Args:
            agent_id: The agent ID to check
            user_api_key_auth: User authentication info
        Returns:
            bool: True if agent is allowed, False otherwise
        """
        allowed_agents = await AgentRequestHandler.get_allowed_agents(user_api_key_auth)
        # Empty list means no restrictions - allow all
        if len(allowed_agents) == 0:
            return True
        return agent_id in allowed_agents

    @staticmethod
    def _get_key_object_permission(
        user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    ) -> Optional[LiteLLM_ObjectPermissionTable]:
        """
        Get key object_permission - already loaded by get_key_object() in main auth flow.
        Note: object_permission is automatically populated when the key is fetched via
        get_key_object() in litellm/proxy/auth/auth_checks.py
        """
        if not user_api_key_auth:
            return None
        return user_api_key_auth.object_permission

    @staticmethod
    async def _get_team_object_permission(
        user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    ) -> Optional[LiteLLM_ObjectPermissionTable]:
        """
        Get team object_permission - automatically loaded by get_team_object() in main auth flow.
        Note: object_permission is automatically populated when the team is fetched via
        get_team_object() in litellm/proxy/auth/auth_checks.py
        """
        # Imported lazily to avoid circular imports with proxy_server at module load.
        from litellm.proxy.auth.auth_checks import get_team_object
        from litellm.proxy.proxy_server import (
            prisma_client,
            proxy_logging_obj,
            user_api_key_cache,
        )

        if not user_api_key_auth or not user_api_key_auth.team_id or not prisma_client:
            return None
        # Get the team object (which has object_permission already loaded)
        team_obj: Optional[LiteLLM_TeamTable] = await get_team_object(
            team_id=user_api_key_auth.team_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=user_api_key_auth.parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
        if not team_obj:
            return None
        return team_obj.object_permission

    @staticmethod
    async def _get_allowed_agents_for_key(
        user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    ) -> List[str]:
        """
        Get allowed agents for a key.
        1. First checks native key-level agent permissions (object_permission)
        2. Also includes agents from key's access_group_ids (unified access groups)
        Note: object_permission is already loaded by get_key_object() in main auth flow.
        """
        if user_api_key_auth is None:
            return []
        try:
            all_agents: List[str] = []
            # 1. Get agents from object_permission (native permissions)
            key_object_permission = AgentRequestHandler._get_key_object_permission(
                user_api_key_auth
            )
            if key_object_permission is not None:
                # Get direct agents
                direct_agents = key_object_permission.agents or []
                # Get agents from access groups
                access_group_agents = (
                    await AgentRequestHandler._get_agents_from_access_groups(
                        key_object_permission.agent_access_groups or []
                    )
                )
                all_agents = direct_agents + access_group_agents
            # 2. Fallback: get agent IDs from key's access_group_ids (unified access groups)
            key_access_group_ids = user_api_key_auth.access_group_ids or []
            if key_access_group_ids:
                from litellm.proxy.auth.auth_checks import (
                    _get_agent_ids_from_access_groups,
                )

                unified_agents = await _get_agent_ids_from_access_groups(
                    access_group_ids=key_access_group_ids,
                )
                all_agents.extend(unified_agents)
            # De-duplicate before returning.
            return list(set(all_agents))
        except Exception as e:
            verbose_logger.warning(f"Failed to get allowed agents for key: {str(e)}")
            return []

    @staticmethod
    async def _get_allowed_agents_for_team(
        user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    ) -> List[str]:
        """
        Get allowed agents for a team.
        1. First checks native team-level agent permissions (object_permission)
        2. Also includes agents from team's access_group_ids (unified access groups)
        Fetches the team object once and reuses it for both permission sources.
        """
        if user_api_key_auth is None:
            return []
        if user_api_key_auth.team_id is None:
            return []
        try:
            from litellm.proxy.auth.auth_checks import get_team_object
            from litellm.proxy.proxy_server import (
                prisma_client,
                proxy_logging_obj,
                user_api_key_cache,
            )

            if not prisma_client:
                return []
            # Fetch the team object once for both permission sources
            team_obj = await get_team_object(
                team_id=user_api_key_auth.team_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=user_api_key_auth.parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
            if team_obj is None:
                return []
            all_agents: List[str] = []
            # 1. Get agents from object_permission (native permissions)
            object_permissions = team_obj.object_permission
            if object_permissions is not None:
                # Get direct agents
                direct_agents = object_permissions.agents or []
                # Get agents from access groups
                access_group_agents = (
                    await AgentRequestHandler._get_agents_from_access_groups(
                        object_permissions.agent_access_groups or []
                    )
                )
                all_agents = direct_agents + access_group_agents
            # 2. Also include agents from team's access_group_ids (unified access groups)
            team_access_group_ids = team_obj.access_group_ids or []
            if team_access_group_ids:
                from litellm.proxy.auth.auth_checks import (
                    _get_agent_ids_from_access_groups,
                )

                unified_agents = await _get_agent_ids_from_access_groups(
                    access_group_ids=team_access_group_ids,
                )
                all_agents.extend(unified_agents)
            return list(set(all_agents))
        except Exception as e:
            # litellm-dashboard is the default UI team and will never have agents;
            # skip noisy warnings for it.
            if user_api_key_auth.team_id != UI_TEAM_ID:
                verbose_logger.warning(
                    f"Failed to get allowed agents for team: {str(e)}"
                )
            return []

    @staticmethod
    def _get_config_agent_ids_for_access_groups(
        config_agents: List, access_groups: List[str]
    ) -> Set[str]:
        """
        Helper to get agent_ids from config-loaded agents that match any of the given access groups.
        """
        server_ids: Set[str] = set()
        for agent in config_agents:
            # getattr: config agents may predate the agent_access_groups field.
            agent_access_groups = getattr(agent, "agent_access_groups", None)
            if agent_access_groups:
                if any(group in agent_access_groups for group in access_groups):
                    server_ids.add(agent.agent_id)
        return server_ids

    @staticmethod
    async def _get_db_agent_ids_for_access_groups(
        prisma_client, access_groups: List[str]
    ) -> Set[str]:
        """
        Helper to get agent_ids from DB agents that match any of the given access groups.
        """
        agent_ids: Set[str] = set()
        if access_groups and prisma_client is not None:
            try:
                # hasSome: any overlap between the agent's groups and the requested groups.
                agents = await prisma_client.db.litellm_agentstable.find_many(
                    where={"agent_access_groups": {"hasSome": access_groups}}
                )
                for agent in agents:
                    agent_ids.add(agent.agent_id)
            except Exception as e:
                verbose_logger.debug(f"Error getting agents from access groups: {e}")
        return agent_ids

    @staticmethod
    async def _get_agents_from_access_groups(
        access_groups: List[str],
    ) -> List[str]:
        """
        Resolve agent access groups to agent IDs by querying BOTH the agent table (DB) AND config-loaded agents.
        """
        from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
        from litellm.proxy.proxy_server import prisma_client

        try:
            # Use the helper for config-loaded agents
            agent_ids = AgentRequestHandler._get_config_agent_ids_for_access_groups(
                global_agent_registry.agent_list, access_groups
            )
            # Use the helper for DB agents
            db_agent_ids = (
                await AgentRequestHandler._get_db_agent_ids_for_access_groups(
                    prisma_client, access_groups
                )
            )
            agent_ids.update(db_agent_ids)
            return list(agent_ids)
        except Exception as e:
            verbose_logger.warning(f"Failed to get agents from access groups: {str(e)}")
            return []

    @staticmethod
    async def get_agent_access_groups(
        user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    ) -> List[str]:
        """
        Get list of agent access groups for the given user/key based on permissions.
        """
        access_groups: List[str] = []
        access_groups_for_key = (
            await AgentRequestHandler._get_agent_access_groups_for_key(
                user_api_key_auth
            )
        )
        access_groups_for_team = (
            await AgentRequestHandler._get_agent_access_groups_for_team(
                user_api_key_auth
            )
        )
        # If team has access groups, then key must have a subset of the team's access groups
        if len(access_groups_for_team) > 0:
            for access_group in access_groups_for_key:
                if access_group in access_groups_for_team:
                    access_groups.append(access_group)
        else:
            access_groups = access_groups_for_key
        return list(set(access_groups))

    @staticmethod
    async def _get_agent_access_groups_for_key(
        user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    ) -> List[str]:
        """Get agent access groups for the key."""
        from litellm.proxy.auth.auth_checks import get_object_permission
        from litellm.proxy.proxy_server import (
            prisma_client,
            proxy_logging_obj,
            user_api_key_cache,
        )

        if user_api_key_auth is None:
            return []
        if user_api_key_auth.object_permission_id is None:
            return []
        if prisma_client is None:
            verbose_logger.debug("prisma_client is None")
            return []
        try:
            key_object_permission = await get_object_permission(
                object_permission_id=user_api_key_auth.object_permission_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=user_api_key_auth.parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
            if key_object_permission is None:
                return []
            return key_object_permission.agent_access_groups or []
        except Exception as e:
            verbose_logger.warning(
                f"Failed to get agent access groups for key: {str(e)}"
            )
            return []

    @staticmethod
    async def _get_agent_access_groups_for_team(
        user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    ) -> List[str]:
        """Get agent access groups for the team."""
        from litellm.proxy.auth.auth_checks import get_team_object
        from litellm.proxy.proxy_server import (
            prisma_client,
            proxy_logging_obj,
            user_api_key_cache,
        )

        if user_api_key_auth is None:
            return []
        if user_api_key_auth.team_id is None:
            return []
        if prisma_client is None:
            verbose_logger.debug("prisma_client is None")
            return []
        try:
            team_obj: Optional[LiteLLM_TeamTable] = await get_team_object(
                team_id=user_api_key_auth.team_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=user_api_key_auth.parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
            if team_obj is None:
                verbose_logger.debug("team_obj is None")
                return []
            object_permissions = team_obj.object_permission
            if object_permissions is None:
                return []
            return object_permissions.agent_access_groups or []
        except Exception as e:
            verbose_logger.warning(
                f"Failed to get agent access groups for team: {str(e)}"
            )
            return []

View File

@@ -0,0 +1,944 @@
"""
Agent endpoints for registering + discovering agents via LiteLLM.
Follows the A2A Spec.
1. Register an agent via POST `/v1/agents`
2. Discover agents via GET `/v1/agents`
3. Get specific agent via GET `/v1/agents/{agent_id}`
"""
import asyncio
import os
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.rbac_utils import check_feature_access_for_user
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
from litellm.types.agents import (
AgentConfig,
AgentMakePublicResponse,
AgentResponse,
MakeAgentsPublicRequest,
PatchAgentRequest,
)
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
)
# Router mounted by the proxy server; carries all /v1/agents routes below.
router = APIRouter()
def _check_agent_management_permission(user_api_key_dict: UserAPIKeyAuth) -> None:
    """
    Raises HTTP 403 if the caller does not have permission to create, update,
    or delete agents. Only PROXY_ADMIN users are allowed to perform these
    write operations.
    """
    # Guard clause: proxy admins may always manage agents.
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return
    error_message = (
        "Only proxy admins can create, update, or delete agents. "
        "Your role={}".format(user_api_key_dict.user_role)
    )
    raise HTTPException(status_code=403, detail={"error": error_message})
# Per-request timeout (seconds) for a single agent URL health probe.
AGENT_HEALTH_CHECK_TIMEOUT_SECONDS = float(
    os.environ.get("LITELLM_AGENT_HEALTH_CHECK_TIMEOUT", "5.0")
)
# Overall timeout (seconds) for the gathered health probes across all agents.
AGENT_HEALTH_CHECK_GATHER_TIMEOUT_SECONDS = float(
    os.environ.get("LITELLM_AGENT_HEALTH_CHECK_GATHER_TIMEOUT", "30.0")
)
async def _check_agent_url_health(
    agent: AgentResponse,
) -> Dict[str, Any]:
    """
    Perform a GET request against the agent's URL and return the health result.
    Returns a dict with ``agent_id``, ``healthy`` (bool), and an optional
    ``error`` message.
    """
    card_params = agent.agent_card_params or {}
    url = card_params.get("url")
    if not url:
        # No URL configured -> nothing to probe; treat as healthy.
        return {"agent_id": agent.agent_id, "healthy": True}
    try:
        http_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.AgentHealthCheck,
            params={"timeout": AGENT_HEALTH_CHECK_TIMEOUT_SECONDS},
        )
        response = await http_client.get(url)
    except Exception as exc:
        # Connection errors / timeouts mark the agent unhealthy.
        return {
            "agent_id": agent.agent_id,
            "healthy": False,
            "error": str(exc),
        }
    if response.status_code >= 500:
        # Server-side failures count as unhealthy; 4xx responses do not.
        return {
            "agent_id": agent.agent_id,
            "healthy": False,
            "error": f"HTTP {response.status_code}",
        }
    return {"agent_id": agent.agent_id, "healthy": True}
@router.get(
    "/v1/agents",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[AgentResponse],
)
async def get_agents(
    request: Request,
    health_check: bool = Query(
        False,
        description="When true, performs a GET request to each agent's URL. Agents with reachable URLs (HTTP status < 500) and agents without a URL are returned; unreachable agents are filtered out.",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),  # Used for auth
):
    """
    Example usage:
    ```
    curl -X GET "http://localhost:4000/v1/agents" \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer your-key" \
    ```
    Pass `?health_check=true` to filter out agents whose URL is unreachable:
    ```
    curl -X GET "http://localhost:4000/v1/agents?health_check=true" \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer your-key" \
    ```
    Returns: List[AgentResponse]
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")
    # Lazy imports to avoid circular imports at module load.
    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )

    try:
        returned_agents: List[AgentResponse] = []
        # Admin users get all agents
        if (
            user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
            or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
        ):
            returned_agents = global_agent_registry.get_agent_list()
        else:
            # Get allowed agents from object_permission (key/team level)
            allowed_agent_ids = await AgentRequestHandler.get_allowed_agents(
                user_api_key_auth=user_api_key_dict
            )
            # If no restrictions (empty list), return all agents
            if len(allowed_agent_ids) == 0:
                returned_agents = global_agent_registry.get_agent_list()
            else:
                # Filter agents by allowed IDs
                all_agents = global_agent_registry.get_agent_list()
                returned_agents = [
                    agent for agent in all_agents if agent.agent_id in allowed_agent_ids
                ]
        # Fetch current spend from DB for all returned agents
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is not None:
            agent_ids = [agent.agent_id for agent in returned_agents]
            if agent_ids:
                db_agents = await prisma_client.db.litellm_agentstable.find_many(
                    where={"agent_id": {"in": agent_ids}},
                )
                spend_map = {a.agent_id: a.spend for a in db_agents}
                # Overlay DB spend onto the (possibly config-loaded) agents.
                for agent in returned_agents:
                    if agent.agent_id in spend_map:
                        agent.spend = spend_map[agent.agent_id]
        # add is_public field to each agent - we do it this way, to allow setting config agents as public
        for agent in returned_agents:
            if agent.litellm_params is None:
                agent.litellm_params = {}
            agent.litellm_params[
                "is_public"
            ] = litellm.public_agent_groups is not None and (
                agent.agent_id in litellm.public_agent_groups
            )
        if health_check:
            # Only agents with a URL are probed; URL-less agents pass through.
            agents_with_url = [
                agent
                for agent in returned_agents
                if (agent.agent_card_params or {}).get("url")
            ]
            agents_without_url = [
                agent
                for agent in returned_agents
                if not (agent.agent_card_params or {}).get("url")
            ]
            try:
                # Probe all URLs concurrently, bounded by the gather timeout.
                health_results = await asyncio.wait_for(
                    asyncio.gather(
                        *[_check_agent_url_health(agent) for agent in agents_with_url]
                    ),
                    timeout=AGENT_HEALTH_CHECK_GATHER_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                # On overall timeout, treat every probed agent as unhealthy.
                verbose_proxy_logger.warning(
                    "Agent health check gather timed out after %s seconds",
                    AGENT_HEALTH_CHECK_GATHER_TIMEOUT_SECONDS,
                )
                health_results = [
                    {
                        "agent_id": agent.agent_id,
                        "healthy": False,
                        "error": "Health check timed out",
                    }
                    for agent in agents_with_url
                ]
            healthy_ids = {
                result["agent_id"] for result in health_results if result["healthy"]
            }
            returned_agents = [
                agent for agent in agents_with_url if agent.agent_id in healthy_ids
            ] + agents_without_url
        return returned_agents
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.agent_endpoints.get_agents(): Exception occurred - {}".format(
                str(e)
            )
        )
        raise HTTPException(
            status_code=500, detail={"error": f"Internal server error: {str(e)}"}
        )
#### CRUD ENDPOINTS FOR AGENTS ####
from litellm.proxy.agent_endpoints.agent_registry import (
global_agent_registry as AGENT_REGISTRY,
)
@router.post(
    "/v1/agents",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentResponse,
)
async def create_agent(
    request: AgentConfig,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new agent
    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/agents" \\
    -H "Authorization: Bearer <your_api_key>" \\
    -H "Content-Type: application/json" \\
    -d '{
        "agent": {
            "agent_name": "my-custom-agent",
            "agent_card_params": {
                "protocolVersion": "1.0",
                "name": "Hello World Agent",
                "description": "Just a hello world agent",
                "url": "http://localhost:9999/",
                "version": "1.0.0",
                "defaultInputModes": ["text"],
                "defaultOutputModes": ["text"],
                "capabilities": {
                    "streaming": true
                },
                "skills": [
                    {
                        "id": "hello_world",
                        "name": "Returns hello world",
                        "description": "just returns hello world",
                        "tags": ["hello world"],
                        "examples": ["hi", "hello world"]
                    }
                ]
            },
            "litellm_params": {
                "make_public": true
            }
        }
    }'
    ```
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")
    from litellm.proxy.proxy_server import prisma_client

    # Admin-only write operation.
    _check_agent_management_permission(user_api_key_dict)
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")
    try:
        # Get the user ID from the API key auth
        created_by = user_api_key_dict.user_id or "unknown"
        # check for naming conflicts
        existing_agent = AGENT_REGISTRY.get_agent_by_name(
            agent_name=request.get("agent_name")  # type: ignore
        )
        if existing_agent is not None:
            raise HTTPException(
                status_code=400,
                detail=f"Agent with name {request.get('agent_name')} already exists",
            )
        result = await AGENT_REGISTRY.add_agent_to_db(
            agent=request, prisma_client=prisma_client, created_by=created_by
        )
        agent_name = result.agent_name
        agent_id = result.agent_id
        # Also register in memory
        # NOTE: in-memory registration failure is non-fatal; the DB row exists.
        try:
            AGENT_REGISTRY.register_agent(agent_config=result)
            verbose_proxy_logger.info(
                f"Successfully registered agent '{agent_name}' (ID: {agent_id}) in memory"
            )
        except Exception as reg_error:
            verbose_proxy_logger.warning(
                f"Failed to register agent '{agent_name}' (ID: {agent_id}) in memory: {reg_error}"
            )
        return result
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error adding agent to db: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get(
    "/v1/agents/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentResponse,
)
async def get_agent_by_id(
    agent_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get a specific agent by ID
    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
    -H "Authorization: Bearer <your_api_key>"
    ```
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")
    try:
        # First try the in-memory registry; fall back to the DB.
        agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
        if agent is None:
            agent_row = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id},
                include={"object_permission": True},
            )
            if agent_row is not None:
                agent_dict = agent_row.model_dump()
                if agent_row.object_permission is not None:
                    try:
                        agent_dict[
                            "object_permission"
                        ] = agent_row.object_permission.model_dump()
                    except Exception:
                        # Fallback for pydantic v1 models without model_dump()
                        agent_dict[
                            "object_permission"
                        ] = agent_row.object_permission.dict()
                agent = AgentResponse(**agent_dict)  # type: ignore
        else:
            # Agent found in memory — refresh spend from DB
            db_row = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id}
            )
            if db_row is not None:
                agent.spend = db_row.spend
        if agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )
        return agent
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting agent from db: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.put(
    "/v1/agents/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentResponse,
)
async def update_agent(
    agent_id: str,
    request: AgentConfig,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing agent
    Example Request:
    ```bash
    curl -X PUT "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
    -H "Authorization: Bearer <your_api_key>" \\
    -H "Content-Type: application/json" \\
    -d '{
        "agent": {
            "agent_name": "updated-agent",
            "agent_card_params": {
                "protocolVersion": "1.0",
                "name": "Updated Agent",
                "description": "Updated description",
                "url": "http://localhost:9999/",
                "version": "1.1.0",
                "defaultInputModes": ["text"],
                "defaultOutputModes": ["text"],
                "capabilities": {
                    "streaming": true
                },
                "skills": []
            },
            "litellm_params": {
                "make_public": false
            }
        }
    }'
    ```
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")
    from litellm.proxy.proxy_server import prisma_client

    # Admin-only write operation.
    _check_agent_management_permission(user_api_key_dict)
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    try:
        # Check if agent exists
        existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
            where={"agent_id": agent_id}
        )
        # Convert the Prisma model to a plain dict for .get() access below.
        if existing_agent is not None:
            existing_agent = dict(existing_agent)
        if existing_agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )
        # Get the user ID from the API key auth
        updated_by = user_api_key_dict.user_id or "unknown"
        result = await AGENT_REGISTRY.update_agent_in_db(
            agent_id=agent_id,
            agent=request,
            prisma_client=prisma_client,
            updated_by=updated_by,
        )
        # deregister in memory
        AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name"))  # type: ignore
        # register in memory
        AGENT_REGISTRY.register_agent(agent_config=result)
        verbose_proxy_logger.info(
            f"Successfully updated agent '{existing_agent.get('agent_name')}' (ID: {agent_id}) in memory"
        )
        return result
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating agent: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.patch(
    "/v1/agents/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentResponse,
)
async def patch_agent(
    agent_id: str,
    request: PatchAgentRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Partially update an existing agent (only the supplied fields are changed).
    Example Request:
    ```bash
    curl -X PATCH "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
    -H "Authorization: Bearer <your_api_key>" \\
    -H "Content-Type: application/json" \\
    -d '{
        "agent": {
            "agent_name": "updated-agent",
            "agent_card_params": {
                "protocolVersion": "1.0",
                "name": "Updated Agent",
                "description": "Updated description",
                "url": "http://localhost:9999/",
                "version": "1.1.0",
                "defaultInputModes": ["text"],
                "defaultOutputModes": ["text"],
                "capabilities": {
                    "streaming": true
                },
                "skills": []
            },
            "litellm_params": {
                "make_public": false
            }
        }
    }'
    ```
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")
    from litellm.proxy.proxy_server import prisma_client

    # Admin-only write operation.
    _check_agent_management_permission(user_api_key_dict)
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    try:
        # Check if agent exists
        existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
            where={"agent_id": agent_id}
        )
        # Convert the Prisma model to a plain dict for .get() access below.
        if existing_agent is not None:
            existing_agent = dict(existing_agent)
        if existing_agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )
        # Get the user ID from the API key auth
        updated_by = user_api_key_dict.user_id or "unknown"
        result = await AGENT_REGISTRY.patch_agent_in_db(
            agent_id=agent_id,
            agent=request,
            prisma_client=prisma_client,
            updated_by=updated_by,
        )
        # deregister in memory
        AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name"))  # type: ignore
        # register in memory
        AGENT_REGISTRY.register_agent(agent_config=result)
        verbose_proxy_logger.info(
            f"Successfully updated agent '{existing_agent.get('agent_name')}' (ID: {agent_id}) in memory"
        )
        return result
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating agent: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.delete(
    "/v1/agents/{agent_id}",
    # Tag normalized to match every other agent route in this module.
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_agent(
    agent_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete an agent
    Example Request:
    ```bash
    curl -X DELETE "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
    -H "Authorization: Bearer <your_api_key>"
    ```
    Example Response:
    ```json
    {
        "message": "Agent 123e4567-e89b-12d3-a456-426614174000 deleted successfully"
    }
    ```
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")
    from litellm.proxy.proxy_server import prisma_client

    # Admin-only write operation.
    _check_agent_management_permission(user_api_key_dict)
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")
    try:
        # Check if agent exists
        existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
            where={"agent_id": agent_id}
        )
        if existing_agent is not None:
            # Plain dict() conversion; dict[Any, Any](...) was an equivalent oddity.
            existing_agent = dict(existing_agent)
        if existing_agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found in DB."
            )
        # Remove from DB first, then from the in-memory registry.
        await AGENT_REGISTRY.delete_agent_from_db(
            agent_id=agent_id, prisma_client=prisma_client
        )
        AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name"))  # type: ignore
        return {"message": f"Agent {agent_id} deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error deleting agent: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post(
    "/v1/agents/{agent_id}/make_public",
    tags=["[beta] A2A Agents"],
    # NOTE(review): handler actually returns {"message", "public_agent_groups",
    # "updated_by"} — confirm AgentMakePublicResponse matches that shape.
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentMakePublicResponse,
)
async def make_agent_public(
    agent_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Make an agent publicly discoverable
    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/v1/agents/123e4567-e89b-12d3-a456-426614174000/make_public" \\
    -H "Authorization: Bearer <your_api_key>" \\
    -H "Content-Type: application/json"
    ```
    Example Response:
    ```json
    {
        "message": "Successfully updated public agent groups",
        "public_agent_groups": ["123e4567-e89b-12d3-a456-426614174000"],
        "updated_by": "user123"
    }
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    try:
        # Update the public model groups
        import litellm
        from litellm.proxy.agent_endpoints.agent_registry import (
            global_agent_registry as AGENT_REGISTRY,
        )
        from litellm.proxy.proxy_server import proxy_config

        # Check if user has admin permissions
        if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Only proxy admins can update public model groups. Your role={}".format(
                        user_api_key_dict.user_role
                    )
                },
            )
        # Look up in memory first, then fall back to the DB.
        agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
        if agent is None:
            # check if agent exists in DB
            agent = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id}
            )
            if agent is not None:
                agent = AgentResponse(**agent.model_dump())  # type: ignore
        if agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )
        if litellm.public_agent_groups is None:
            litellm.public_agent_groups = []
        # handle duplicates
        if agent.agent_id in litellm.public_agent_groups:
            raise HTTPException(
                status_code=400,
                detail=f"Agent with name {agent.agent_name} already in public agent groups",
            )
        litellm.public_agent_groups.append(agent.agent_id)
        # Load existing config
        config = await proxy_config.get_config()
        # Update config with new settings
        if "litellm_settings" not in config or config["litellm_settings"] is None:
            config["litellm_settings"] = {}
        config["litellm_settings"]["public_agent_groups"] = litellm.public_agent_groups
        # Save the updated config
        # Persist so the public flag survives proxy restarts.
        await proxy_config.save_config(new_config=config)
        verbose_proxy_logger.debug(
            f"Updated public agent groups to: {litellm.public_agent_groups} by user: {user_api_key_dict.user_id}"
        )
        return {
            "message": "Successfully updated public agent groups",
            "public_agent_groups": litellm.public_agent_groups,
            "updated_by": user_api_key_dict.user_id,
        }
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error making agent public: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post(
    "/v1/agents/make_public",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentMakePublicResponse,
)
async def make_agents_public(
    request: MakeAgentsPublicRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Make multiple agents publicly discoverable.

    REPLACES the current `litellm.public_agent_groups` list with
    `request.agent_ids` (it does not append), then persists the new list
    into the proxy config under `litellm_settings.public_agent_groups`.

    Admin-only: caller must have the PROXY_ADMIN role.

    Raises:
        HTTPException 500: database not connected, or unexpected failure.
        HTTPException 403: caller is not a proxy admin.
        HTTPException 404: any requested agent ID is neither in the in-memory
            registry nor in the database.

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/v1/agents/make_public" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "agent_ids": ["123e4567-e89b-12d3-a456-426614174000", "123e4567-e89b-12d3-a456-426614174001"]
        }'
    ```
    Example Response:
    ```json
    {
        "message": "Successfully updated public agent groups",
        "public_agent_groups": ["123e4567-e89b-12d3-a456-426614174000", "123e4567-e89b-12d3-a456-426614174001"],
        "updated_by": "user123"
    }
    ```
    """
    from litellm.proxy.proxy_server import prisma_client
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    try:
        # Update the public model groups
        import litellm
        from litellm.proxy.agent_endpoints.agent_registry import (
            global_agent_registry as AGENT_REGISTRY,
        )
        from litellm.proxy.proxy_server import proxy_config
        # Load existing config
        config = await proxy_config.get_config()
        # Check if user has admin permissions
        if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Only proxy admins can update public model groups. Your role={}".format(
                        user_api_key_dict.user_role
                    )
                },
            )
        if litellm.public_agent_groups is None:
            litellm.public_agent_groups = []
        # Validate every requested ID before mutating any state: in-memory
        # registry first, then fall back to the DB.
        for agent_id in request.agent_ids:
            agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
            if agent is None:
                # check if agent exists in DB
                agent = await prisma_client.db.litellm_agentstable.find_unique(
                    where={"agent_id": agent_id}
                )
                if agent is not None:
                    agent = AgentResponse(**agent.model_dump()) # type: ignore
            if agent is None:
                raise HTTPException(
                    status_code=404, detail=f"Agent with ID {agent_id} not found"
                )
        # Full replacement (not append) of the public agent list.
        litellm.public_agent_groups = request.agent_ids
        # Update config with new settings
        if "litellm_settings" not in config or config["litellm_settings"] is None:
            config["litellm_settings"] = {}
        config["litellm_settings"]["public_agent_groups"] = litellm.public_agent_groups
        # Save the updated config
        await proxy_config.save_config(new_config=config)
        verbose_proxy_logger.debug(
            f"Updated public agent groups to: {litellm.public_agent_groups} by user: {user_api_key_dict.user_id}"
        )
        # NOTE(review): this dict must match AgentMakePublicResponse's fields
        # (declared as the response_model above) -- confirm against the type.
        return {
            "message": "Successfully updated public agent groups",
            "public_agent_groups": litellm.public_agent_groups,
            "updated_by": user_api_key_dict.user_id,
        }
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error making agent public: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get(
    "/agent/daily/activity",
    tags=["Agent Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=SpendAnalyticsPaginatedResponse,
)
async def get_agent_daily_activity(
    agent_ids: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
    exclude_agent_ids: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get daily activity for specific agents or all accessible agents.

    Parameters:
    - agent_ids: comma-separated agent IDs to include (None = all agents)
    - start_date / end_date: date-range filters forwarded to `get_daily_activity`
    - model / api_key: optional additional filters
    - page / page_size: pagination controls
    - exclude_agent_ids: comma-separated agent IDs to exclude

    Raises:
        HTTPException 500: if the database is not connected.
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Comma-separated query params -> lists (None when the param is absent).
    # The original wrapped the second split in a redundant `if` around a
    # ternary that re-tested the same value; a single conditional expression
    # is equivalent.
    agent_ids_list: Optional[List[str]] = agent_ids.split(",") if agent_ids else None
    exclude_agent_ids_list: Optional[List[str]] = (
        exclude_agent_ids.split(",") if exclude_agent_ids else None
    )

    # Fetch agent names so the response can attach metadata to spend rows.
    where_condition: dict = {}
    if agent_ids_list:
        where_condition["agent_id"] = {"in": list(agent_ids_list)}
    agent_records = await prisma_client.db.litellm_agentstable.find_many(
        where=where_condition
    )
    agent_metadata = {
        agent.agent_id: {"agent_name": agent.agent_name} for agent in agent_records
    }

    return await get_daily_activity(
        prisma_client=prisma_client,
        table_name="litellm_dailyagentspend",
        entity_id_field="agent_id",
        entity_id=agent_ids_list,
        entity_metadata_field=agent_metadata,
        exclude_entity_ids=exclude_agent_ids_list,
        start_date=start_date,
        end_date=end_date,
        model=model,
        api_key=api_key,
        page=page,
        page_size=page_size,
    )

View File

@@ -0,0 +1,94 @@
"""
Helper functions for appending A2A agents to model lists.
Used by proxy model endpoints to make agents appear in UI alongside models.
"""
from typing import List
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
ModelGroupInfoProxy,
)
async def append_agents_to_model_group(
    model_groups: List[ModelGroupInfoProxy],
    user_api_key_dict: UserAPIKeyAuth,
) -> List[ModelGroupInfoProxy]:
    """
    Append A2A agents to the model groups list for UI display.

    Each agent the caller is permitted to access is surfaced as a pseudo
    model group named "a2a/<agent-name>" so it shows up in the playground
    and routes through LiteLLM. Any failure is swallowed and logged at
    debug level; the (possibly extended) input list is always returned.
    """
    try:
        from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
        from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
            AgentRequestHandler,
        )

        permitted_ids = await AgentRequestHandler.get_allowed_agents(
            user_api_key_auth=user_api_key_dict
        )
        # Resolve IDs against the registry, dropping any that are unknown.
        resolved = (
            global_agent_registry.get_agent_by_id(aid) for aid in permitted_ids
        )
        model_groups.extend(
            ModelGroupInfoProxy(
                model_group=f"a2a/{found.agent_name}",
                mode="chat",
                providers=["a2a"],
            )
            for found in resolved
            if found is not None
        )
    except Exception as e:
        verbose_proxy_logger.debug(f"Error appending agents to model_group/info: {e}")
    return model_groups
async def append_agents_to_model_info(
    models: List[dict],
    user_api_key_dict: UserAPIKeyAuth,
) -> List[dict]:
    """
    Append A2A agents to the model info list for UI display.

    Each accessible agent is represented as a model dict named
    "a2a/<agent-name>" so it appears on the models page and works with
    LiteLLM routing. Any failure is swallowed and logged at debug level;
    the (possibly extended) input list is always returned.
    """
    try:
        from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
        from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
            AgentRequestHandler,
        )

        def _as_model_entry(found) -> dict:
            # Shape mirrors the dicts returned for DB models in v2/model/info.
            display_name = f"a2a/{found.agent_name}"
            return {
                "model_name": display_name,
                "litellm_params": {
                    "model": display_name,
                    "custom_llm_provider": "a2a",
                },
                "model_info": {
                    "id": found.agent_id,
                    "mode": "chat",
                    "db_model": True,
                    "created_by": found.created_by,
                    "created_at": found.created_at,
                    "updated_at": found.updated_at,
                },
            }

        permitted_ids = await AgentRequestHandler.get_allowed_agents(
            user_api_key_auth=user_api_key_dict
        )
        for aid in permitted_ids:
            found = global_agent_registry.get_agent_by_id(aid)
            if found is not None:
                models.append(_as_model_entry(found))
    except Exception as e:
        verbose_proxy_logger.debug(f"Error appending agents to v2/model/info: {e}")
    return models

View File

@@ -0,0 +1,27 @@
"""Utility helpers for A2A agent endpoints."""
from typing import Dict, Mapping, Optional
def merge_agent_headers(
    *,
    dynamic_headers: Optional[Mapping[str, str]] = None,
    static_headers: Optional[Mapping[str, str]] = None,
) -> Optional[Dict[str, str]]:
    """Merge outbound HTTP headers for A2A agent calls.

    Precedence: ``static_headers`` (admin-configured per agent) override
    ``dynamic_headers`` (values extracted from the incoming client request)
    when both define the same key. All keys/values are coerced to ``str``.
    Returns ``None`` when the merged result is empty.
    """
    # Later sources win on key collisions, so static headers come last.
    combined: Dict[str, str] = {
        str(key): str(value)
        for source in (dynamic_headers, static_headers)
        if source
        for key, value in source.items()
    }
    return combined if combined else None

View File

@@ -0,0 +1,106 @@
#### Analytics Endpoints #####
from datetime import datetime
from typing import List, Optional
import fastapi
from fastapi import APIRouter, Depends, HTTPException, status
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
@router.get(
    "/global/activity/cache_hits",
    tags=["Budget & Spend Tracking"],
    dependencies=[Depends(user_api_key_auth)],
    responses={
        200: {"model": List[LiteLLM_SpendLogs]},
    },
    include_in_schema=False,
)
async def get_global_activity(
    start_date: Optional[str] = fastapi.Query(
        default=None,
        description="Time from which to start viewing spend",
    ),
    end_date: Optional[str] = fastapi.Query(
        default=None,
        description="Time till which to view spend",
    ),
):
    """
    Get cache hits vs. LLM API calls, grouped by key alias, call type and model.

    Parameters:
    - start_date: inclusive window start, "YYYY-MM-DD"
    - end_date: inclusive window end, "YYYY-MM-DD" (the SQL adds one day so
      the whole end date is counted)

    Returns rows with:
        api_key, call_type, model, total_rows, cache_hit_true_rows,
        cached_completion_tokens, generated_completion_tokens

    Raises:
        HTTPException 400: missing or malformed dates, or any query failure.
    """
    if start_date is None or end_date is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "Please provide start_date and end_date"},
        )

    # Validate date format up front. Previously strptime ran outside any
    # try/except, so a malformed date surfaced as an unhandled ValueError
    # (HTTP 500) instead of a clear 400.
    try:
        start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
        end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "start_date and end_date must be in YYYY-MM-DD format"},
        )

    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise ValueError(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )

        # cache_hit is stored as the string 'True'/'False' in SpendLogs, hence
        # the string comparison below.
        sql_query = """
        SELECT
          CASE
            WHEN vt."key_alias" IS NOT NULL THEN vt."key_alias"
            ELSE 'Unnamed Key'
          END AS api_key,
          sl."call_type",
          sl."model",
          COUNT(*) AS total_rows,
          SUM(CASE WHEN sl."cache_hit" = 'True' THEN 1 ELSE 0 END) AS cache_hit_true_rows,
          SUM(CASE WHEN sl."cache_hit" = 'True' THEN sl."completion_tokens" ELSE 0 END) AS cached_completion_tokens,
          SUM(CASE WHEN sl."cache_hit" != 'True' THEN sl."completion_tokens" ELSE 0 END) AS generated_completion_tokens
        FROM "LiteLLM_SpendLogs" sl
        LEFT JOIN "LiteLLM_VerificationToken" vt ON sl."api_key" = vt."token"
        WHERE
          sl."startTime" >= $1::timestamptz AND "startTime" < ($2::timestamptz + INTERVAL '1 day')
        GROUP BY
          vt."key_alias",
          sl."call_type",
          sl."model"
        """
        db_response = await prisma_client.db.query_raw(
            sql_query, start_date_obj, end_date_obj
        )
        if db_response is None:
            return []

        return db_response
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(e)},
        )

View File

@@ -0,0 +1,11 @@
"""
Claude Code Endpoints
Provides endpoints for Claude Code plugin marketplace integration.
"""
from litellm.proxy.anthropic_endpoints.claude_code_endpoints.claude_code_marketplace import (
router as claude_code_marketplace_router,
)
__all__ = ["claude_code_marketplace_router"]

View File

@@ -0,0 +1,546 @@
"""
CLAUDE CODE MARKETPLACE
Provides a registry/discovery layer for Claude Code plugins.
Plugins are stored as metadata + git source references in LiteLLM database.
Actual plugin files are hosted on GitHub/GitLab/Bitbucket.
Endpoints:
/claude-code/marketplace.json - GET - List plugins for Claude Code discovery
/claude-code/plugins - POST - Register a plugin
/claude-code/plugins - GET - List plugins (admin)
/claude-code/plugins/{name} - GET - Get plugin details
/claude-code/plugins/{name}/enable - POST - Enable a plugin
/claude-code/plugins/{name}/disable - POST - Disable a plugin
/claude-code/plugins/{name} - DELETE - Delete a plugin
"""
import json
import re
from datetime import datetime, timezone
from typing import Any, Dict
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.proxy.claude_code_endpoints import (
ListPluginsResponse,
PluginListItem,
RegisterPluginRequest,
)
router = APIRouter()
async def _get_prisma_client():
    """Return the active prisma client; raise HTTP 500 when the DB is not connected."""
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is not None:
        return prisma_client
    raise HTTPException(
        status_code=500,
        detail={"error": CommonProxyErrors.db_not_connected_error.value},
    )
@router.get(
    "/claude-code/marketplace.json",
    tags=["Claude Code Marketplace"],
)
async def get_marketplace():
    """
    Serve marketplace.json for Claude Code plugin discovery.
    This endpoint is accessed by Claude Code CLI when users run:
    - claude plugin marketplace add <url>
    - claude plugin install <name>@<marketplace>
    Returns:
        Marketplace catalog with list of available plugins and their git sources.
        Only enabled plugins whose manifest parses and contains a "source"
        field are included; others are skipped with a warning.
    Example:
    ```bash
    claude plugin marketplace add http://localhost:4000/claude-code/marketplace.json
    claude plugin install my-plugin@litellm
    ```
    """
    try:
        prisma_client = await _get_prisma_client()
        # Only enabled plugins are discoverable via the public catalog.
        plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many(
            where={"enabled": True}
        )
        plugin_list = []
        for plugin in plugins:
            try:
                manifest = json.loads(plugin.manifest_json)
            except json.JSONDecodeError:
                # Corrupt manifest: skip this plugin rather than failing the
                # whole catalog.
                verbose_proxy_logger.warning(
                    f"Plugin {plugin.name} has invalid manifest JSON, skipping"
                )
                continue
            # Source must be specified for URL-based marketplaces
            if "source" not in manifest:
                verbose_proxy_logger.warning(
                    f"Plugin {plugin.name} has no source field, skipping"
                )
                continue
            entry: Dict[str, Any] = {
                "name": plugin.name,
                "source": manifest["source"],
            }
            # Optional metadata - only included when present.
            if plugin.version:
                entry["version"] = plugin.version
            if plugin.description:
                entry["description"] = plugin.description
            if "author" in manifest:
                entry["author"] = manifest["author"]
            if "homepage" in manifest:
                entry["homepage"] = manifest["homepage"]
            if "keywords" in manifest:
                entry["keywords"] = manifest["keywords"]
            if "category" in manifest:
                entry["category"] = manifest["category"]
            plugin_list.append(entry)
        marketplace = {
            "name": "litellm",
            "owner": {"name": "LiteLLM", "email": "support@litellm.ai"},
            "plugins": plugin_list,
        }
        return JSONResponse(content=marketplace)
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error generating marketplace: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to generate marketplace: {str(e)}"},
        )
@router.post(
    "/claude-code/plugins",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def register_plugin(
    request: RegisterPluginRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Register a plugin in the LiteLLM marketplace (upsert by plugin name).
    LiteLLM acts as a registry/discovery layer. Plugins are hosted on
    GitHub/GitLab/Bitbucket. Claude Code will clone from the git source
    when users install.
    Parameters:
    - name: Plugin name (kebab-case)
    - source: Git source reference (github or url format)
    - version: Semantic version (optional)
    - description: Plugin description (optional)
    - author: Author information (optional)
    - homepage: Plugin homepage URL (optional)
    - keywords: Search keywords (optional)
    - category: Plugin category (optional)
    Returns:
        Registration status ("created" or "updated") and plugin information.
    Raises:
        HTTPException 400: invalid name or source format.
        HTTPException 500: database not connected or unexpected failure.
    Example:
    ```bash
    curl -X POST http://localhost:4000/claude-code/plugins \\
        -H "Authorization: Bearer sk-..." \\
        -H "Content-Type: application/json" \\
        -d '{
            "name": "my-plugin",
            "source": {"source": "github", "repo": "org/my-plugin"},
            "version": "1.0.0",
            "description": "My awesome plugin"
        }'
    ```
    """
    try:
        prisma_client = await _get_prisma_client()
        # Validate name format
        if not re.match(r"^[a-z0-9-]+$", request.name):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Plugin name must be kebab-case (lowercase letters, numbers, hyphens)"
                },
            )
        # Validate source format: either {"source": "github", "repo": ...}
        # or {"source": "url", "url": ...}.
        source = request.source
        source_type = source.get("source")
        if source_type == "github":
            if "repo" not in source:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "GitHub source must include 'repo' field (e.g., 'org/repo')"
                    },
                )
        elif source_type == "url":
            if "url" not in source:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "URL source must include 'url' field (e.g., 'https://github.com/org/repo.git')"
                    },
                )
        else:
            raise HTTPException(
                status_code=400,
                detail={"error": "source.source must be 'github' or 'url'"},
            )
        # Build manifest for storage; optional fields only included when set.
        manifest: Dict[str, Any] = {
            "name": request.name,
            "source": request.source,
        }
        if request.version:
            manifest["version"] = request.version
        if request.description:
            manifest["description"] = request.description
        if request.author:
            manifest["author"] = request.author.model_dump(exclude_none=True)
        if request.homepage:
            manifest["homepage"] = request.homepage
        if request.keywords:
            manifest["keywords"] = request.keywords
        if request.category:
            manifest["category"] = request.category
        # Check if plugin exists
        existing = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
            where={"name": request.name}
        )
        if existing:
            # NOTE(review): the update branch writes version/description even
            # when None, clearing any previous values, and resets files_json
            # to "{}" -- confirm this full-replacement behavior is intended.
            plugin = await prisma_client.db.litellm_claudecodeplugintable.update(
                where={"name": request.name},
                data={
                    "version": request.version,
                    "description": request.description,
                    "manifest_json": json.dumps(manifest),
                    "files_json": "{}",
                    "updated_at": datetime.now(timezone.utc),
                },
            )
            action = "updated"
        else:
            plugin = await prisma_client.db.litellm_claudecodeplugintable.create(
                data={
                    "name": request.name,
                    "version": request.version,
                    "description": request.description,
                    "manifest_json": json.dumps(manifest),
                    "files_json": "{}",
                    "enabled": True,
                    "created_at": datetime.now(timezone.utc),
                    "updated_at": datetime.now(timezone.utc),
                    "created_by": user_api_key_dict.user_id,
                }
            )
            action = "created"
        verbose_proxy_logger.info(f"Plugin {request.name} {action} successfully")
        return {
            "status": "success",
            "action": action,
            "plugin": {
                "id": plugin.id,
                "name": plugin.name,
                "version": plugin.version,
                "description": plugin.description,
                "source": request.source,
                "enabled": plugin.enabled,
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error registering plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Registration failed: {str(e)}"},
        )
@router.get(
    "/claude-code/plugins",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListPluginsResponse,
)
async def list_plugins(
    enabled_only: bool = False,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List all plugins in the marketplace.

    Parameters:
    - enabled_only: If true, only return enabled plugins

    Returns:
        ListPluginsResponse with plugins sorted newest-first by created_at.

    Raises:
        HTTPException 500: database not connected or unexpected failure.
    """
    try:
        prisma_client = await _get_prisma_client()
        where = {"enabled": True} if enabled_only else {}
        plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many(
            where=where
        )
        plugin_list = []
        for p in plugins:
            # Parse manifest to get additional fields. A single corrupt
            # manifest must not 500 the whole admin listing (the marketplace
            # endpoint already tolerates bad manifests by skipping); here we
            # fall back to an empty manifest so admins can still see and fix
            # the broken plugin.
            try:
                manifest = json.loads(p.manifest_json) if p.manifest_json else {}
            except json.JSONDecodeError:
                verbose_proxy_logger.warning(
                    f"Plugin {p.name} has invalid manifest JSON"
                )
                manifest = {}
            plugin_list.append(
                PluginListItem(
                    id=p.id,
                    name=p.name,
                    version=p.version,
                    description=p.description,
                    source=manifest.get("source", {}),
                    author=manifest.get("author"),
                    homepage=manifest.get("homepage"),
                    keywords=manifest.get("keywords"),
                    category=manifest.get("category"),
                    enabled=p.enabled,
                    created_at=p.created_at.isoformat() if p.created_at else None,
                    updated_at=p.updated_at.isoformat() if p.updated_at else None,
                )
            )
        # Sort by created_at descending (newest first)
        plugin_list.sort(key=lambda x: x.created_at or "", reverse=True)
        return ListPluginsResponse(
            plugins=plugin_list,
            count=len(plugin_list),
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error listing plugins: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
@router.get(
    "/claude-code/plugins/{plugin_name}",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_plugin(
    plugin_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get details of a specific plugin.
    Parameters:
    - plugin_name: The name of the plugin
    Returns:
        Plugin details including source and metadata.
    Raises:
        HTTPException 404: plugin does not exist.
        HTTPException 500: database not connected or unexpected failure
            (including a manifest that fails to parse as JSON).
    """
    try:
        prisma_client = await _get_prisma_client()
        plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
            where={"name": plugin_name}
        )
        if not plugin:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Plugin '{plugin_name}' not found"},
            )
        # Optional metadata lives inside the stored manifest JSON.
        manifest = json.loads(plugin.manifest_json) if plugin.manifest_json else {}
        return {
            "id": plugin.id,
            "name": plugin.name,
            "version": plugin.version,
            "description": plugin.description,
            "source": manifest.get("source"),
            "author": manifest.get("author"),
            "homepage": manifest.get("homepage"),
            "keywords": manifest.get("keywords"),
            "category": manifest.get("category"),
            "enabled": plugin.enabled,
            "created_at": plugin.created_at.isoformat() if plugin.created_at else None,
            "updated_at": plugin.updated_at.isoformat() if plugin.updated_at else None,
            "created_by": plugin.created_by,
        }
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
@router.post(
    "/claude-code/plugins/{plugin_name}/enable",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def enable_plugin(
    plugin_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Enable a disabled plugin.
    Parameters:
    - plugin_name: The name of the plugin to enable
    Raises:
        HTTPException 404: plugin does not exist.
        HTTPException 500: database not connected or unexpected failure.
    """
    try:
        db = await _get_prisma_client()
        # 404 early if the plugin was never registered.
        existing = await db.litellm_claudecodeplugintable.find_unique(
            where={"name": plugin_name}
        )
        if existing is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Plugin '{plugin_name}' not found"},
            )
        await db.litellm_claudecodeplugintable.update(
            where={"name": plugin_name},
            data={"enabled": True, "updated_at": datetime.now(timezone.utc)},
        )
        verbose_proxy_logger.info(f"Plugin {plugin_name} enabled")
        return {"status": "success", "message": f"Plugin '{plugin_name}' enabled"}
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error enabling plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
@router.post(
    "/claude-code/plugins/{plugin_name}/disable",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def disable_plugin(
    plugin_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Disable a plugin without deleting it.
    Parameters:
    - plugin_name: The name of the plugin to disable
    Raises:
        HTTPException 404: plugin does not exist.
        HTTPException 500: database not connected or unexpected failure.
    """
    try:
        db = await _get_prisma_client()
        # 404 early if the plugin was never registered.
        existing = await db.litellm_claudecodeplugintable.find_unique(
            where={"name": plugin_name}
        )
        if existing is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Plugin '{plugin_name}' not found"},
            )
        await db.litellm_claudecodeplugintable.update(
            where={"name": plugin_name},
            data={"enabled": False, "updated_at": datetime.now(timezone.utc)},
        )
        verbose_proxy_logger.info(f"Plugin {plugin_name} disabled")
        return {"status": "success", "message": f"Plugin '{plugin_name}' disabled"}
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error disabling plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
@router.delete(
    "/claude-code/plugins/{plugin_name}",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_plugin(
    plugin_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a plugin from the marketplace.
    Parameters:
    - plugin_name: The name of the plugin to delete
    Raises:
        HTTPException 404: plugin does not exist.
        HTTPException 500: database not connected or unexpected failure.
    """
    try:
        db = await _get_prisma_client()
        # 404 early if the plugin was never registered.
        existing = await db.litellm_claudecodeplugintable.find_unique(
            where={"name": plugin_name}
        )
        if existing is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Plugin '{plugin_name}' not found"},
            )
        await db.litellm_claudecodeplugintable.delete(
            where={"name": plugin_name}
        )
        verbose_proxy_logger.info(f"Plugin {plugin_name} deleted")
        return {"status": "success", "message": f"Plugin '{plugin_name}' deleted"}
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error deleting plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )

View File

@@ -0,0 +1,264 @@
"""
Unified /v1/messages endpoint - (Anthropic Spec)
"""
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from litellm._logging import verbose_proxy_logger
from litellm.anthropic_interface.exceptions import AnthropicExceptionMapping
from litellm.integrations.custom_guardrail import ModifyResponseException
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import (
ProxyBaseLLMRequestProcessing,
create_response,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.types.utils import TokenCountResponse
router = APIRouter()
@router.post(
    "/v1/messages",
    tags=["[beta] Anthropic `/v1/messages`"],
    dependencies=[Depends(user_api_key_auth)],
)
async def anthropic_response( # noqa: PLR0915
    fastapi_response: Response,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/pass_through/anthropic_completion).
    This was a BETA endpoint that calls 100+ LLMs in the anthropic format.

    Three outcomes:
    1. Normal path: request is processed through ProxyBaseLLMRequestProcessing.
    2. ModifyResponseException (guardrail flagged content in passthrough
       mode): returns HTTP 200 with an Anthropic-shaped violation message,
       streamed via SSE when the client requested streaming.
    3. Any other exception: failure hooks fire and a ProxyException is raised.
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )
    data = await _read_request_body(request=request)
    base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        result = await base_llm_response_processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="anthropic_messages",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=None,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
        return result
    except ModifyResponseException as e:
        # Guardrail flagged content in passthrough mode - return 200 with violation message
        _data = e.request_data
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=_data,
        )
        # Create Anthropic-formatted response with violation message
        import uuid
        from litellm.types.utils import AnthropicMessagesResponse
        _anthropic_response = AnthropicMessagesResponse(
            id=f"msg_{str(uuid.uuid4())}",
            type="message",
            role="assistant",
            content=[{"type": "text", "text": e.message}],
            model=e.model,
            stop_reason="end_turn",
            usage={"input_tokens": 0, "output_tokens": 0},
        )
        if data.get("stream", None) is not None and data["stream"] is True:
            # For streaming, use the standard SSE data generator
            # (single-chunk generator wrapping the violation response).
            async def _passthrough_stream_generator():
                yield _anthropic_response
            selected_data_generator = (
                ProxyBaseLLMRequestProcessing.async_sse_data_generator(
                    response=_passthrough_stream_generator(),
                    user_api_key_dict=user_api_key_dict,
                    request_data=_data,
                    proxy_logging_obj=proxy_logging_obj,
                )
            )
            return await create_response(
                generator=selected_data_generator,
                media_type="text/event-stream",
                headers={},
            )
        return _anthropic_response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format(
                str(e)
            )
        )
        # Extract model_id from request metadata (same as success path)
        litellm_metadata = data.get("litellm_metadata", {}) or {}
        model_info = litellm_metadata.get("model_info", {}) or {}
        model_id = model_info.get("id", "") or ""
        # Get headers
        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=user_api_key_dict,
            call_id=data.get("litellm_call_id", ""),
            model_id=model_id,
            version=version,
            response_cost=0,
            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            request_data=data,
            timeout=getattr(e, "timeout", None),
            litellm_logging_obj=None,
        )
        error_msg = f"{str(e)}"
        raise ProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
            headers=headers,
        )
@router.post(
    "/v1/messages/count_tokens",
    tags=["[beta] Anthropic Messages Token Counting"],
    dependencies=[Depends(user_api_key_auth)],
)
async def count_tokens(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), # Used for auth
):
    """
    Count tokens for Anthropic Messages API format.
    This endpoint follows the Anthropic Messages API token counting specification.
    It accepts the same parameters as the /v1/messages endpoint but returns
    token counts instead of generating a response.
    Example usage:
    ```
    curl -X POST "http://localhost:4000/v1/messages/count_tokens?beta=true" \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer your-key" \
        -d '{
            "model": "claude-3-sonnet-20240229",
            "messages": [{"role": "user", "content": "Hello Claude!"}]
        }'
    ```
    Returns: {"input_tokens": <number>}
    Raises:
        HTTPException 400: missing model or messages.
        HTTPException: status mapped from any ProxyException raised internally.
        HTTPException 500: any other unexpected failure.
    """
    from litellm.proxy.proxy_server import token_counter as internal_token_counter
    try:
        request_data = await _read_request_body(request=request)
        data: dict = {**request_data}
        # Extract required fields
        model_name = data.get("model")
        messages = data.get("messages", [])
        if not model_name:
            raise HTTPException(
                status_code=400, detail={"error": "model parameter is required"}
            )
        if not messages:
            raise HTTPException(
                status_code=400, detail={"error": "messages parameter is required"}
            )
        # Create TokenCountRequest for the internal endpoint
        from litellm.proxy._types import TokenCountRequest
        token_request = TokenCountRequest(
            model=model_name,
            messages=messages,
            tools=data.get("tools"),
            system=data.get("system"),
        )
        # Call the internal token counter with call_endpoint=True
        token_response = await internal_token_counter(
            request=token_request,
            call_endpoint=True,
        )
        # Normalize the internal response (model object or plain dict) to a dict.
        _token_response_dict: dict = {}
        if isinstance(token_response, TokenCountResponse):
            _token_response_dict = token_response.model_dump()
        elif isinstance(token_response, dict):
            _token_response_dict = token_response
        # Convert the internal response to Anthropic API format
        return {"input_tokens": _token_response_dict.get("total_tokens", 0)}
    except HTTPException:
        raise
    except ProxyException as e:
        # NOTE(review): assumes ProxyException.code is a numeric string;
        # non-digit / missing codes fall back to 500 -- confirm against the type.
        status_code = int(e.code) if e.code and e.code.isdigit() else 500
        detail = AnthropicExceptionMapping.transform_to_anthropic_error(
            status_code=status_code,
            raw_message=e.message,
        )
        raise HTTPException(
            status_code=status_code,
            detail=detail,
        )
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(
                str(e)
            )
        )
        raise HTTPException(
            status_code=500, detail={"error": f"Internal server error: {str(e)}"}
        )
@router.post(
    "/api/event_logging/batch",
    tags=["[beta] Anthropic Event Logging"],
)
async def event_logging_batch(
    request: Request,
):
    """
    Stubbed endpoint for Anthropic event logging batch requests.
    This endpoint accepts event logging requests but does nothing with them.
    It exists to prevent 404 errors from Claude Code clients that send telemetry.
    The request body is never read; the response is always {"status": "ok"}.
    Note: intentionally unauthenticated (no user_api_key_auth dependency) so
    client telemetry never fails on auth.
    """
    return {"status": "ok"}

View File

@@ -0,0 +1,437 @@
"""
Anthropic Skills API endpoints - /v1/skills
"""
from typing import Optional
import orjson
from fastapi import APIRouter, Depends, Request, Response
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.http_parsing_utils import (
convert_upload_files_to_file_data,
get_form_data,
)
from litellm.types.llms.anthropic_skills import (
DeleteSkillResponse,
ListSkillsResponse,
Skill,
)
router = APIRouter()
@router.post(
    "/v1/skills",
    tags=["[beta] Anthropic Skills API"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=Skill,
)
async def create_skill(
    fastapi_response: Response,
    request: Request,
    custom_llm_provider: Optional[str] = "anthropic",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new skill on Anthropic.

    Requires `?beta=true` query parameter.

    Model-based routing (for multi-account support):
    - Pass model via header: `x-litellm-model: claude-account-1`
    - Pass model via query: `?model=claude-account-1`
    - Pass model via form field: `model=claude-account-1`

    Example usage:
    ```bash
    # Basic usage
    curl -X POST "http://localhost:4000/v1/skills?beta=true" \
    -H "Content-Type: multipart/form-data" \
    -H "Authorization: Bearer your-key" \
    -F "display_title=My Skill" \
    -F "files[]=@skill.zip"
    # With model-based routing
    curl -X POST "http://localhost:4000/v1/skills?beta=true" \
    -H "Content-Type: multipart/form-data" \
    -H "Authorization: Bearer your-key" \
    -H "x-litellm-model: claude-account-1" \
    -F "display_title=My Skill" \
    -F "files[]=@skill.zip"
    ```

    Returns: Skill object with id, display_title, etc.
    """
    # Deferred import: proxy_server globals are populated at startup; importing
    # at module load time would capture them before configuration.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )
    # Read form data and convert UploadFile objects to file data tuples
    form_data = await get_form_data(request)
    data = await convert_upload_files_to_file_data(form_data)
    # Resolve the model alias used for routing.
    # NOTE(review): the short-circuit order gives the form/body field highest
    # precedence (body > query > header), NOT "header > query > body" as the
    # original comment claimed — confirm which priority is intended.
    model = (
        data.get("model")
        or request.query_params.get("model")
        or request.headers.get("x-litellm-model")
    )
    if model:
        data["model"] = model
    # Default the provider only when the form did not specify one.
    if "custom_llm_provider" not in data:
        data["custom_llm_provider"] = custom_llm_provider
    # Process request using ProxyBaseLLMRequestProcessing
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="acreate_skill",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=data.get("model"),
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        # Normalize provider/router errors into proxy-standard HTTP errors.
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
@router.get(
    "/v1/skills",
    tags=["[beta] Anthropic Skills API"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListSkillsResponse,
)
async def list_skills(
    fastapi_response: Response,
    request: Request,
    limit: Optional[int] = 10,
    after_id: Optional[str] = None,
    before_id: Optional[str] = None,
    custom_llm_provider: Optional[str] = "anthropic",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List skills stored on Anthropic (requires the `?beta=true` query param).

    Supports cursor pagination via `limit`, `after_id` and `before_id`, and
    multi-account routing by naming a deployment through a `model` body
    field, a `?model=` query param, or the `x-litellm-model` header.

    Example:
    ```bash
    curl "http://localhost:4000/v1/skills?beta=true&limit=10" \
    -H "Authorization: Bearer your-key" \
    -H "x-litellm-model: claude-account-1"
    ```

    Returns: ListSkillsResponse with the list of skills.
    """
    # Proxy globals are only valid after startup, hence the local import.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )
    raw_body = await request.body()
    payload: dict = orjson.loads(raw_body) if raw_body else {}
    # Fold pagination query params in, without clobbering body-supplied values.
    for key, value in (
        ("limit", limit),
        ("after_id", after_id),
        ("before_id", before_id),
    ):
        if value is not None and key not in payload:
            payload[key] = value
    # Routing model precedence: body, then query string, then header.
    routed_model = (
        payload.get("model")
        or request.query_params.get("model")
        or request.headers.get("x-litellm-model")
    )
    if routed_model:
        payload["model"] = routed_model
    # Provider default applies only when the body left it unset.
    if "custom_llm_provider" not in payload:
        payload["custom_llm_provider"] = custom_llm_provider
    request_processor = ProxyBaseLLMRequestProcessing(data=payload)
    try:
        return await request_processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="alist_skills",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=payload.get("model"),
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        # Translate any failure into the proxy's standard error response.
        raise await request_processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
@router.get(
    "/v1/skills/{skill_id}",
    tags=["[beta] Anthropic Skills API"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=Skill,
)
async def get_skill(
    skill_id: str,
    fastapi_response: Response,
    request: Request,
    custom_llm_provider: Optional[str] = "anthropic",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Fetch a single skill by ID from Anthropic (requires `?beta=true`).

    Multi-account routing: name a deployment via a `model` body field,
    a `?model=` query param, or the `x-litellm-model` header.

    Example:
    ```bash
    curl "http://localhost:4000/v1/skills/skill_123?beta=true" \
    -H "Authorization: Bearer your-key" \
    -H "x-litellm-model: claude-account-1"
    ```

    Returns: Skill object.
    """
    # Proxy globals are only valid after startup, hence the local import.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )
    raw_body = await request.body()
    payload: dict = orjson.loads(raw_body) if raw_body else {}
    # The path parameter always wins over anything in the body.
    payload["skill_id"] = skill_id
    # Routing model precedence: body, then query string, then header.
    routed_model = (
        payload.get("model")
        or request.query_params.get("model")
        or request.headers.get("x-litellm-model")
    )
    if routed_model:
        payload["model"] = routed_model
    # Provider default applies only when the body left it unset.
    if "custom_llm_provider" not in payload:
        payload["custom_llm_provider"] = custom_llm_provider
    request_processor = ProxyBaseLLMRequestProcessing(data=payload)
    try:
        return await request_processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="aget_skill",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=payload.get("model"),
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        # Translate any failure into the proxy's standard error response.
        raise await request_processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
@router.delete(
    "/v1/skills/{skill_id}",
    tags=["[beta] Anthropic Skills API"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=DeleteSkillResponse,
)
async def delete_skill(
    skill_id: str,
    fastapi_response: Response,
    request: Request,
    custom_llm_provider: Optional[str] = "anthropic",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a skill by ID on Anthropic (requires `?beta=true`).

    Note: Anthropic rejects deletion of skills that still have versions.

    Multi-account routing: name a deployment via a `model` body field,
    a `?model=` query param, or the `x-litellm-model` header.

    Example:
    ```bash
    curl -X DELETE "http://localhost:4000/v1/skills/skill_123?beta=true" \
    -H "Authorization: Bearer your-key" \
    -H "x-litellm-model: claude-account-1"
    ```

    Returns: DeleteSkillResponse with type="skill_deleted".
    """
    # Proxy globals are only valid after startup, hence the local import.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )
    raw_body = await request.body()
    payload: dict = orjson.loads(raw_body) if raw_body else {}
    # The path parameter always wins over anything in the body.
    payload["skill_id"] = skill_id
    # Routing model precedence: body, then query string, then header.
    routed_model = (
        payload.get("model")
        or request.query_params.get("model")
        or request.headers.get("x-litellm-model")
    )
    if routed_model:
        payload["model"] = routed_model
    # Provider default applies only when the body left it unset.
    if "custom_llm_provider" not in payload:
        payload["custom_llm_provider"] = custom_llm_provider
    request_processor = ProxyBaseLLMRequestProcessing(data=payload)
    try:
        return await request_processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="adelete_skill",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=payload.get("model"),
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        # Translate any failure into the proxy's standard error response.
        raise await request_processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,176 @@
"""
Auth Checks for Organizations
"""
from typing import Dict, List, Optional, Tuple
from fastapi import status
from litellm.proxy._types import *
def organization_role_based_access_check(
    request_body: dict,
    user_object: Optional[LiteLLM_UserTable],
    route: str,
):
    """
    Role based access control checks only run if a user is part of an Organization

    Organization Checks:
    ONLY RUN IF user_object.organization_memberships is not None
    1. Only Proxy Admins can access /organization/new
    2. IF route is a LiteLLMRoutes.org_admin_only_routes, then check if user is an Org Admin for that organization

    Raises:
        ProxyException: 401 when the caller lacks the required organization role.
    """
    if user_object is None:
        # No user record to check roles against (e.g. key-only auth) -> allow.
        return
    passed_organization_id: Optional[str] = request_body.get("organization_id", None)
    if route == "/organization/new":
        if user_object.user_role != LitellmUserRoles.PROXY_ADMIN.value:
            raise ProxyException(
                message=f"Only proxy admins can create new organizations. You are {user_object.user_role}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )
    # Proxy admins are exempt from all organization-scoped checks below.
    if user_object.user_role == LitellmUserRoles.PROXY_ADMIN.value:
        return
    # Checks if route is an Org Admin Only Route
    if route in LiteLLMRoutes.org_admin_only_routes.value:
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)
        if user_object.organization_memberships is None:
            raise ProxyException(
                message=f"Tried to access route={route} but you are not a member of any organization. Please contact the proxy admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )
        if passed_organization_id is None:
            raise ProxyException(
                message="Passed organization_id is None, please pass an organization_id in your request",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )
        user_role: Optional[LitellmUserRoles] = _user_organization_role_mapping.get(
            passed_organization_id
        )
        if user_role is None:
            raise ProxyException(
                message=f"You do not have a role within the selected organization. Passed organization_id: {passed_organization_id}. Please contact the organization admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )
        # NOTE(review): compares the mapped role against `.value` — assumes the
        # mapping stores raw strings or that LitellmUserRoles is a str-enum; confirm.
        if user_role != LitellmUserRoles.ORG_ADMIN.value:
            raise ProxyException(
                message=f"You do not have the required role to perform {route} in Organization {passed_organization_id}. Your role is {user_role} in Organization {passed_organization_id}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )
    elif route == "/team/new":
        # if user is part of multiple teams, then they need to specify the organization_id
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)
        if (
            user_object.organization_memberships is not None
            and len(user_object.organization_memberships) > 0
        ):
            if passed_organization_id is None:
                raise ProxyException(
                    message=f"Passed organization_id is None, please specify the organization_id in your request. You are part of multiple organizations: {_user_organizations}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="organization_id",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            _user_role_in_passed_org = _user_organization_role_mapping.get(
                passed_organization_id
            )
            if _user_role_in_passed_org != LitellmUserRoles.ORG_ADMIN.value:
                raise ProxyException(
                    message=f"You do not have the required role to call {route}. Your role is {_user_role_in_passed_org} in Organization {passed_organization_id}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="user_role",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
def get_user_organization_info(
    user_object: LiteLLM_UserTable,
) -> Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]:
    """
    Collect the user's organization memberships into convenient lookups.

    Args:
        user_object: user record whose `organization_memberships` is read.

    Returns:
        A 2-tuple of:
        - list of organization IDs the user belongs to
        - mapping of organization ID -> user's role in that organization
    """
    org_ids: List[str] = []
    role_by_org: Dict[str, Optional[LitellmUserRoles]] = {}
    for membership in user_object.organization_memberships or []:
        org_id = membership.organization_id
        if org_id is None:
            # Skip malformed membership rows without an org id.
            continue
        org_ids.append(org_id)
        role_by_org[org_id] = membership.user_role  # type: ignore
    return org_ids, role_by_org
def _user_is_org_admin(
    request_data: dict,
    user_object: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Return True when the user is an ORG_ADMIN of any organization named
    in the request.

    Organization IDs are read from both:
    - `organization_id` (singular string) — legacy callers
    - `organizations` (list of strings) — used by /user/new
    """
    if user_object is None or user_object.organization_memberships is None:
        return False

    # Gather every org id the request refers to.
    target_org_ids: List[str] = []
    single_org = request_data.get("organization_id", None)
    if single_org is not None:
        target_org_ids.append(single_org)
    many_orgs = request_data.get("organizations", None)
    if isinstance(many_orgs, list):
        target_org_ids.extend(many_orgs)

    if not target_org_ids:
        return False

    # Admin of any one of the referenced orgs is sufficient.
    return any(
        membership.organization_id in target_org_ids
        and membership.user_role == LitellmUserRoles.ORG_ADMIN.value
        for membership in user_object.organization_memberships
    )

View File

@@ -0,0 +1,125 @@
"""
Handles Authentication Errors
"""
from typing import TYPE_CHECKING, Any, Optional, Union
from fastapi import HTTPException, Request, status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import _get_request_ip_address
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.types.services import ServiceTypes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class UserAPIKeyAuthExceptionHandler:
    # Collects the error-path logic for user_api_key_auth so the happy path
    # stays readable.
    @staticmethod
    async def _handle_authentication_error(
        e: Exception,
        request: Request,
        request_data: dict,
        route: str,
        parent_otel_span: Optional[Span],
        api_key: str,
    ) -> UserAPIKeyAuth:
        """
        Handles Connection Errors when reading a Virtual Key from LiteLLM DB

        Use this if you don't want failed DB queries to block LLM API requests

        Reliability scenarios this covers:
        - DB is down and having an outage
        - Unable to read / recover a key from the DB

        Returns:
        - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True

        Raises:
        - Original Exception in all other cases
        """
        from litellm.proxy.proxy_server import (
            general_settings,
            litellm_proxy_admin_name,
            proxy_logging_obj,
        )
        # Degraded mode: only when explicitly enabled AND the failure is a DB
        # connectivity error do we let the request through as proxy admin.
        if (
            PrismaDBExceptionHandler.should_allow_request_on_db_unavailable()
            and PrismaDBExceptionHandler.is_database_connection_error(e)
        ):
            # log this as a DB failure on prometheus
            # NOTE(review): the result of service_failure_hook is not awaited
            # here — confirm the hook is synchronous (or that dropping the
            # coroutine is intended).
            proxy_logging_obj.service_logging_obj.service_failure_hook(
                service=ServiceTypes.DB,
                call_type="get_key_object",
                error=e,
                duration=0.0,
            )
            return UserAPIKeyAuth(
                key_name="failed-to-connect-to-db",
                token="failed-to-connect-to-db",
                user_id=litellm_proxy_admin_name,
                request_route=route,
            )
        else:
            # raise the exception to the caller
            requester_ip = _get_request_ip_address(
                request=request,
                use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
            )
            verbose_proxy_logger.exception(
                "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format(
                    str(e),
                    requester_ip,
                ),
                extra={"requester_ip": requester_ip},
            )
            # Log this exception to OTEL, Datadog etc
            user_api_key_dict = UserAPIKeyAuth(
                parent_otel_span=parent_otel_span,
                api_key=api_key,
                request_route=route,
            )
            # Allow callbacks to transform the error response
            transformed_exception = await proxy_logging_obj.post_call_failure_hook(
                request_data=request_data,
                original_exception=e,
                user_api_key_dict=user_api_key_dict,
                error_type=ProxyErrorTypes.auth_error,
                route=route,
            )
            # Use transformed exception if callback returned one, otherwise use original
            if transformed_exception is not None:
                e = transformed_exception
            # Re-raise in priority order: budget errors map to 400, HTTP errors
            # keep their status, ProxyExceptions pass through untouched, and
            # anything else becomes a generic 401 auth error.
            if isinstance(e, litellm.BudgetExceededError):
                raise ProxyException(
                    message=e.message,
                    type=ProxyErrorTypes.budget_exceeded,
                    param=None,
                    code=400,
                )
            if isinstance(e, HTTPException):
                raise ProxyException(
                    message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                    type=ProxyErrorTypes.auth_error,
                    param=getattr(e, "param", "None"),
                    code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
                )
            elif isinstance(e, ProxyException):
                raise e
            raise ProxyException(
                message="Authentication Error, " + str(e),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=status.HTTP_401_UNAUTHORIZED,
            )

View File

@@ -0,0 +1,835 @@
import os
import re
import sys
from functools import lru_cache
from typing import Any, List, Optional, Tuple
from fastapi import HTTPException, Request, status
from litellm import Router, provider_list
from litellm._logging import verbose_proxy_logger
from litellm.constants import STANDARD_CUSTOMER_ID_HEADERS
from litellm.proxy._types import *
from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
def _get_request_ip_address(
    request: Request, use_x_forwarded_for: Optional[bool] = False
) -> Optional[str]:
    """
    Best-effort client IP for a request.

    Prefers the `x-forwarded-for` header when the proxy is configured to
    trust it, falls back to the socket peer, and returns "" when neither
    is available.
    """
    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
        return request.headers["x-forwarded-for"]
    if request.client is not None:
        return request.client.host
    return ""
def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """
    Return `(is_allowed, client_ip)` for the request.

    A `None` allowlist means IP filtering is disabled and everything passes
    (client_ip is not resolved in that case).
    """
    if allowed_ips is None:  # if not set, assume true
        return True, None
    # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
    resolved_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )
    # Membership in the allowlist decides the outcome.
    return (resolved_ip in allowed_ips), resolved_ip
def check_complete_credentials(request_body: dict) -> bool:
    """
    Return True only when the caller supplied complete, simple credentials.

    Used to decide whether a client-provided `api_base` is acceptable: the
    request must name a model AND ship its own `api_key`. Providers with
    multi-part credentials (sagemaker / bedrock / vertex) are always
    rejected, since partial credentials there are easier to abuse.
    """
    model_name = request_body.get("model")
    if model_name is None:
        return False

    # complex credentials - easier to make a malicious request
    complex_credential_providers = (
        "sagemaker",
        "bedrock",
        "vertex_ai",
        "vertex_ai_beta",
    )
    if any(provider in model_name for provider in complex_credential_providers):
        return False

    return "api_key" in request_body
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """
    True when *regex_str* regex-matches the value from its start,
    or when the two are exactly equal.
    """
    return bool(re.match(regex_str, request_body_value)) or (
        regex_str == request_body_value
    )
def _is_param_allowed(
    param: str,
    request_body_value: Any,
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
    """
    True when *param* is whitelisted for client-side override.

    String entries whitelist a param by exact name; dict entries whitelist
    `api_base` values matching the entry's `api_base` regex (or exact string).
    """
    if configurable_clientside_auth_params is None:
        return False
    for allowed_entry in configurable_clientside_auth_params:
        if isinstance(allowed_entry, str):
            if allowed_entry == param:
                return True
        elif isinstance(allowed_entry, Dict):
            # assume the dict entry's value is a regex (or literal) for api_base
            if param == "api_base" and check_regex_or_str_match(
                request_body_value=request_body_value,
                regex_str=allowed_entry["api_base"],
            ):
                return True
    return False
def _allow_model_level_clientside_configurable_parameters(
    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
    """
    Check whether *model*'s deployment config whitelists *param* for
    client-side override.

    Resolution order:
    1. Look up the exact model group on the router.
    2. Fall back to the provider-wildcard group (e.g. "openai" for
       "openai/gpt-x") when the first lookup misses.

    Returns False when no router is configured, no model group is found,
    or the group declares no `configurable_clientside_auth_params`.
    """
    if llm_router is None:
        return False
    # check if model is set
    model_info = llm_router.get_model_group_info(model_group=model)
    if model_info is None:
        # check if wildcard model is set (provider prefix of "provider/model")
        provider_prefix = model.split("/", 1)[0]
        if provider_prefix in provider_list:
            model_info = llm_router.get_model_group_info(model_group=provider_prefix)
    # Single combined guard — the original had a redundant second
    # `model_info is None` check, making one branch unreachable.
    if model_info is None or model_info.configurable_clientside_auth_params is None:
        return False
    return _is_param_allowed(
        param=param,
        request_body_value=request_body_value,
        configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
    )
def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """
    Check if the request body is safe.

    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997

    Returns:
        bool: True when no banned param is present, or when it is explicitly allowed.

    Raises:
        ValueError: when a banned param is present and not permitted by config.
    """
    banned_params = ["api_base", "base_url"]
    for param in banned_params:
        if (
            param in request_body
            and not check_complete_credentials(  # allow client-credentials to be passed to proxy
                request_body=request_body
            )
        ):
            # Escape hatch 1: global opt-in from proxy config.
            if general_settings.get("allow_client_side_credentials") is True:
                return True
            # Escape hatch 2: the target model's config whitelists this param.
            elif (
                _allow_model_level_clientside_configurable_parameters(
                    model=model,
                    param=param,
                    request_body_value=request_body[param],
                    llm_router=llm_router,
                )
                is True
            ):
                return True
            # Banned param present and nothing allows it -> reject the request.
            raise ValueError(
                f"Rejected Request: {param} is not allowed in request body. "
                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
            )
    return True
async def pre_db_read_auth_checks(
    request: Request,
    request_data: dict,
    route: str,
):
    """
    Cheap request-level checks run BEFORE any database read.

    1. Checks if request size is under max_request_size_mb (if set)
    2. Check if request body is safe (example user has not set api_base in request body)
    3. Check if IP address is allowed (if set)
    4. Check if request route is an allowed route on the proxy (if set)

    Returns:
    - True

    Raises:
    - HTTPException if request fails initial auth checks
    """
    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user
    # Check 1. request size
    await check_if_request_size_is_safe(request=request)
    # Check 2. Request body is safe (raises ValueError on banned params)
    is_request_body_safe(
        request_body=request_data,
        general_settings=general_settings,
        llm_router=llm_router,
        model=request_data.get(
            "model", ""
        ),  # [TODO] use model passed in url as well (azure openai routes)
    )
    # Check 3. Check if IP address is allowed
    is_valid_ip, passed_in_ip = _check_valid_ip(
        allowed_ips=general_settings.get("allowed_ips", None),
        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        request=request,
    )
    if not is_valid_ip:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
        )
    # Check 4. Check if request route is an allowed route on the proxy
    if "allowed_routes" in general_settings:
        _allowed_routes = general_settings["allowed_routes"]
        if premium_user is not True:
            # NOTE(review): only logs for non-premium users — the route check
            # below still runs and can reject; confirm whether enforcement
            # should be skipped for non-premium (compare
            # check_if_request_size_is_safe, which returns early).
            verbose_proxy_logger.error(
                f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
            )
        if route not in _allowed_routes:
            verbose_proxy_logger.error(
                f"Route {route} not in allowed_routes={_allowed_routes}"
            )
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access forbidden: Route {route} not allowed",
            )
def route_in_additonal_public_routes(current_route: str):
    """
    Check whether the proxy admin declared *current_route* public via
    `general_settings.public_routes` in config.yaml (enterprise only).

    Supports exact matches plus wildcard patterns such as "/api/*"
    (so "/api/users" and "/api/users/123" both match "/api/*").

    Example config:
    ```yaml
    general_settings:
      master_key: sk-1234
      public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate", "/api/*"]
    ```

    Returns False for non-premium deployments and on any error.
    """
    from litellm.proxy.auth.route_checks import RouteChecks
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        # Feature is enterprise-gated.
        if premium_user is not True:
            return False
        if general_settings is None:
            return False
        public_routes = general_settings.get("public_routes", [])
        # Exact match short-circuits the wildcard scan.
        if current_route in public_routes:
            return True
        return any(
            RouteChecks._route_matches_wildcard_pattern(
                route=current_route, pattern=candidate
            )
            for candidate in public_routes
        )
    except Exception as e:
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False
def get_request_route(request: Request) -> str:
    """
    Return the request path with any server-root prefix stripped,
    e.g. `/genai/chat/completions` -> `/chat/completions`.
    """
    try:
        full_path = request.url.path
        if hasattr(request, "base_url") and full_path.startswith(
            request.base_url.path
        ):
            # Slice off the prefix, keeping its trailing "/" so the
            # result still begins with "/".
            return full_path[len(request.base_url.path) - 1 :]
        return full_path
    except Exception as e:
        verbose_proxy_logger.debug(
            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
        )
        return request.url.path
@lru_cache(maxsize=256)
def normalize_request_route(route: str) -> str:
    """
    Collapse dynamic path segments into `{placeholder}` names.

    Keeps metric labels (e.g. Prometheus) low-cardinality by mapping routes
    such as `/v1/responses/abc123` to `/v1/responses/{response_id}`; routes
    without a recognized dynamic segment are returned unchanged.

    Args:
        route: The request route path.

    Returns:
        The normalized route string.

    Examples:
        >>> normalize_request_route("/v1/responses/abc123")
        '/v1/responses/{response_id}'
        >>> normalize_request_route("/v1/responses/abc123/cancel")
        '/v1/responses/{response_id}/cancel'
        >>> normalize_request_route("/chat/completions")
        '/chat/completions'
    """
    # (regex, replacement) pairs. Specific sub-resource routes precede the
    # generic `/{id}` forms so suffixes like `/cancel` survive normalization.
    route_templates = (
        # Responses API
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)$", r"\1/{response_id}"),
        (r"^(/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)$", r"\1/{response_id}"),
        # Threads API
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}\5/{step_id}",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/cancel)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/submit_tool_outputs)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)$", r"\1/{thread_id}\3"),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)/([^/]+)$",
            r"\1/{thread_id}\3/{message_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)$", r"\1/{thread_id}\3"),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)$", r"\1/{thread_id}"),
        # Vector Stores API
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)/([^/]+)$",
            r"\1/{vector_store_id}\3/{file_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)$",
            r"\1/{vector_store_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)/([^/]+)$",
            r"\1/{vector_store_id}\3/{batch_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)$",
            r"\1/{vector_store_id}\3",
        ),
        (r"^(/(?:openai/)?v1/vector_stores)/([^/]+)$", r"\1/{vector_store_id}"),
        # Assistants API
        (r"^(/(?:openai/)?v1/assistants)/([^/]+)$", r"\1/{assistant_id}"),
        # Files API
        (r"^(/(?:openai/)?v1/files)/([^/]+)(/content)$", r"\1/{file_id}\3"),
        (r"^(/(?:openai/)?v1/files)/([^/]+)$", r"\1/{file_id}"),
        # Batches API
        (r"^(/(?:openai/)?v1/batches)/([^/]+)(/cancel)$", r"\1/{batch_id}\3"),
        (r"^(/(?:openai/)?v1/batches)/([^/]+)$", r"\1/{batch_id}"),
        # Fine-tuning API
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/events)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/cancel)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/checkpoints)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)$", r"\1/{fine_tuning_job_id}"),
        # Models API
        (r"^(/(?:openai/)?v1/models)/([^/]+)$", r"\1/{model}"),
    )
    # First template that actually changes the route wins.
    for raw_pattern, template in route_templates:
        rewritten = re.sub(raw_pattern, template, route)
        if rewritten != route:
            return rewritten
    # No dynamic segment recognized.
    return route
async def check_if_request_size_is_safe(request: Request) -> bool:
    """
    Enterprise Only:
    - Checks if the request size is within the limit

    Args:
        request (Request): The incoming request.

    Returns:
        bool: True if the request size is within the limit

    Raises:
        ProxyException: If the request size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user
    max_request_size_mb = general_settings.get("max_request_size_mb", None)
    if max_request_size_mb is not None:
        # Check if premium user — limit is enforced only for enterprise users.
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True
        # Prefer the Content-Length header: avoids buffering the whole body.
        content_length = request.headers.get("content-length")
        if content_length:
            header_size = int(content_length)
            # NOTE: bytes_to_mb is a module-level helper defined elsewhere in this file.
            header_size_mb = bytes_to_mb(bytes_value=header_size)
            verbose_proxy_logger.debug(
                f"content_length request size in MB={header_size_mb}"
            )
            if header_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {header_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )
        else:
            # If Content-Length is not available, read the body
            body = await request.body()
            body_size = len(body)
            request_size_mb = bytes_to_mb(bytes_value=body_size)
            verbose_proxy_logger.debug(
                f"request body request size in MB={request_size_mb}"
            )
            if request_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )
    return True
async def check_response_size_is_safe(response: Any) -> bool:
    """
    Enterprise Only:
        - Checks if the response size is within the limit

    Args:
        response (Any): The response to check.

    Returns:
        bool: True if the response size is within the limit

    Raises:
        ProxyException: If the response size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_response_size_mb = general_settings.get("max_response_size_mb", None)
    if max_response_size_mb is None:
        # no limit configured -> nothing to enforce
        return True

    if premium_user is not True:
        # enforcement is an enterprise-only feature; warn and let the response through
        verbose_proxy_logger.warning(
            f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
        )
        return True

    # NOTE: sys.getsizeof is a shallow size (does not follow references), so
    # this approximates the payload size rather than the serialized byte count.
    response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
    verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
    if response_size_mb > max_response_size_mb:
        raise ProxyException(
            message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
            type=ProxyErrorTypes.bad_request_error.value,
            code=400,
            param="content-length",
        )
    return True
def bytes_to_mb(bytes_value: int):
    """
    Helper to convert bytes to MB
    """
    one_mb = 1024 * 1024
    return bytes_value / one_mb
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
def get_key_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the model rpm limit for a given api key.

    Priority order (returns first found):
    1. Key metadata (model_rpm_limit)
    2. Key model_max_budget (rpm_limit per model)
    3. Team metadata (model_rpm_limit)
    """
    # 1. Key metadata wins outright when present.
    metadata = user_api_key_dict.metadata
    if metadata:
        from_metadata = metadata.get("model_rpm_limit")
        if from_metadata:
            return from_metadata

    # 2. Collect per-model rpm limits out of model_max_budget.
    budgets = user_api_key_dict.model_max_budget
    if budgets:
        per_model = {
            model: budget["rpm_limit"]
            for model, budget in budgets.items()
            if isinstance(budget, dict) and budget.get("rpm_limit") is not None
        }
        if per_model:
            return per_model

    # 3. Fall back to team metadata.
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata:
        return team_metadata.get("model_rpm_limit")
    return None
def get_key_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the model tpm limit for a given api key.

    Priority order (returns first found):
    1. Key metadata (model_tpm_limit)
    2. Key model_max_budget (tpm_limit per model)
    3. Team metadata (model_tpm_limit)
    """
    # 1. Key metadata wins outright when present.
    metadata = user_api_key_dict.metadata
    if metadata:
        from_metadata = metadata.get("model_tpm_limit")
        if from_metadata:
            return from_metadata

    # 2. Collect per-model tpm limits out of model_max_budget.
    budgets = user_api_key_dict.model_max_budget
    if budgets:
        per_model = {
            model: budget["tpm_limit"]
            for model, budget in budgets.items()
            if isinstance(budget, dict) and budget.get("tpm_limit") is not None
        }
        if per_model:
            return per_model

    # 3. Fall back to team metadata.
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata:
        return team_metadata.get("model_tpm_limit")
    return None
def get_model_rate_limit_from_metadata(
    user_api_key_dict: UserAPIKeyAuth,
    metadata_accessor_key: Literal["team_metadata", "organization_metadata"],
    rate_limit_key: Literal["model_rpm_limit", "model_tpm_limit"],
) -> Optional[Dict[str, int]]:
    """Read a model rate-limit mapping from the named metadata attribute, if set."""
    metadata = getattr(user_api_key_dict, metadata_accessor_key)
    if not metadata:
        return None
    return metadata.get(rate_limit_key)
def get_team_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """Return the team-level model_rpm_limit mapping for this key, if any."""
    team_metadata = user_api_key_dict.team_metadata
    return team_metadata.get("model_rpm_limit") if team_metadata else None
def get_team_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """Return the team-level model_tpm_limit mapping for this key, if any."""
    team_metadata = user_api_key_dict.team_metadata
    return team_metadata.get("model_tpm_limit") if team_metadata else None
def is_pass_through_provider_route(route: str) -> bool:
    """Return True if the route contains a provider-specific pass-through prefix."""
    PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
        "vertex-ai",
    ]
    # substring match against each known provider prefix
    return any(prefix in route for prefix in PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES)
def _has_user_setup_sso():
    """
    Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID or generic client ID and UI username environment variables.
    Returns a boolean indicating whether SSO has been set up.
    """
    sso_env_vars = ("MICROSOFT_CLIENT_ID", "GOOGLE_CLIENT_ID", "GENERIC_CLIENT_ID")
    # SSO is considered configured as soon as any provider client id is present
    return any(os.getenv(env_var, None) is not None for env_var in sso_env_vars)
def get_customer_user_header_from_mapping(user_id_mapping) -> Optional[str]:
    """Return the header_name mapped to CUSTOMER role, if any (dict-based)."""
    if not user_id_mapping:
        return None
    # accept both a single mapping dict and a list of them
    if isinstance(user_id_mapping, list):
        candidates = user_id_mapping
    else:
        candidates = [user_id_mapping]
    for candidate in candidates:
        if not isinstance(candidate, dict):
            continue
        role = candidate.get("litellm_user_role")
        header_name = candidate.get("header_name")
        if role is None or not header_name:
            continue
        # case-insensitive comparison against the CUSTOMER role name
        if str(role).lower() == str(LitellmUserRoles.CUSTOMER).lower():
            return header_name
    return None
def _get_customer_id_from_standard_headers(
    request_headers: Optional[dict],
) -> Optional[str]:
    """
    Check standard customer ID headers for a customer/end-user ID.
    This enables tools like Claude Code to pass customer IDs via ANTHROPIC_CUSTOM_HEADERS.
    No configuration required - these headers are always checked.

    Args:
        request_headers: The request headers dict

    Returns:
        The customer ID if found in standard headers, None otherwise
    """
    if request_headers is None:
        return None
    for standard_header in STANDARD_CUSTOMER_ID_HEADERS:
        wanted = standard_header.lower()
        # header names are matched case-insensitively
        for header_name, header_value in request_headers.items():
            if header_name.lower() != wanted:
                continue
            candidate = "" if header_value is None else str(header_value)
            if candidate.strip():
                return candidate
    return None
def get_end_user_id_from_request_body(
    request_body: dict, request_headers: Optional[dict] = None
) -> Optional[str]:
    """
    Resolve the end-user/customer id for a request.

    Sources are consulted in priority order: standard customer-id headers,
    configured custom headers (new `user_header_mappings`, then the deprecated
    `user_header_name` setting), then several well-known request-body fields.
    """
    # Imported at runtime (not module level) to avoid circular imports and to
    # read the live settings object.
    from litellm.proxy.proxy_server import general_settings

    # Check 1: Standard customer ID headers (always checked, no configuration required)
    customer_id = _get_customer_id_from_standard_headers(
        request_headers=request_headers
    )
    if customer_id is not None:
        return customer_id

    # Check 2: user header mappings feature; if nothing is mapped, fall back to
    # the deprecated `user_header_name` setting (only when headers are provided)
    if request_headers is not None:
        header_to_check: Optional[str] = None

        mappings = general_settings.get("user_header_mappings", None)
        if mappings:
            header_to_check = get_customer_user_header_from_mapping(mappings)

        if not header_to_check:
            legacy_value = general_settings.get("user_header_name")
            if isinstance(legacy_value, str) and legacy_value.strip() != "":
                header_to_check = legacy_value

        if isinstance(header_to_check, str):
            wanted = header_to_check.lower()
            for header_name, header_value in request_headers.items():
                if header_name.lower() != wanted:
                    continue
                candidate = "" if header_value is None else str(header_value)
                if candidate.strip():
                    return candidate

    # Check 3: 'user' field in request_body (commonly OpenAI)
    if "user" in request_body and request_body["user"] is not None:
        return str(request_body["user"])

    # Check 4: 'litellm_metadata.user' in request_body (commonly Anthropic)
    litellm_metadata = request_body.get("litellm_metadata")
    if isinstance(litellm_metadata, dict):
        user_value = litellm_metadata.get("user")
        if user_value is not None:
            return str(user_value)

    # Check 5: 'metadata.user_id' in request_body (another common pattern)
    metadata_dict = request_body.get("metadata")
    if isinstance(metadata_dict, dict):
        user_id_value = metadata_dict.get("user_id")
        if user_id_value is not None:
            return str(user_id_value)

    # Check 6: 'safety_identifier' in request body (OpenAI Responses API parameter)
    # SECURITY NOTE: safety_identifier can be set by any caller in the request body.
    # Only use this for end-user identification in trusted environments where you
    # control the calling application. For untrusted callers, prefer headers or
    # server-side middleware to set the end_user_id to prevent impersonation.
    if request_body.get("safety_identifier") is not None:
        return str(request_body["safety_identifier"])

    return None
def get_model_from_request(
    request_data: dict, route: str
) -> Optional[Union[str, List[str]]]:
    """
    Extract the requested model (or list of models) from the request body,
    falling back to parsing it out of the route path.
    """
    model = request_data.get("model") or request_data.get("target_model_names")
    if model is not None:
        # Comma-separated values become a list; a single value stays a string.
        parts = [part.strip() for part in model.split(",")]
        return parts[0] if len(parts) == 1 else parts

    # Azure-style route: /openai/deployments/{model}/*
    deployment_match = re.match(r"/openai/deployments/([^/]+)", route)
    if deployment_match:
        return deployment_match.group(1)

    is_vertex_route = route.lower().startswith("/vertex")

    # Google generateContent-style routes keep the model in the path and allow
    # "/" inside the model id. Examples:
    #   /v1beta/models/gemini-2.0-flash:generateContent
    #   /v1beta/models/bedrock/claude-sonnet-3.7:generateContent
    #   /models/custom/ns/model:streamGenerateContent
    if not is_vertex_route:
        google_match = re.search(r"/(?:v1beta|beta)/models/([^:]+):", route)
        if google_match:
            return google_match.group(1)
        google_match = re.search(r"^/models/([^:]+):", route)
        if google_match:
            return google_match.group(1)

    # Vertex AI passthrough route: /vertex_ai/.../models/{model_id}:*
    # Example: /vertex_ai/v1/.../models/gemini-1.5-pro:generateContent
    if is_vertex_route:
        vertex_match = re.search(r"/models/([^:]+)", route)
        if vertex_match:
            return vertex_match.group(1)

    return None
def abbreviate_api_key(api_key: str) -> str:
    """Return a redacted display form of an API key: 'sk-...' plus its last 4 chars."""
    suffix = api_key[-4:]
    return "sk-...{}".format(suffix)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,156 @@
"""
IP address utilities for MCP public/private access control.
Internal callers (private IPs) see all MCP servers.
External callers (public IPs) only see servers with available_on_public_internet=True.
"""
import ipaddress
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from litellm._logging import verbose_proxy_logger
from litellm.proxy.auth.auth_utils import _get_request_ip_address
class IPAddressUtils:
    """Static utilities for IP-based MCP access control.

    Internal callers (private-range IPs) are allowed to see all MCP servers;
    external callers only see publicly-flagged servers. All methods are
    static and stateless.
    """
    # RFC 1918 private ranges, IPv4/IPv6 loopback, and IPv6 unique-local
    # addresses. Used whenever no `mcp_internal_ip_ranges` are configured.
    _DEFAULT_INTERNAL_NETWORKS = [
        ipaddress.ip_network("10.0.0.0/8"),
        ipaddress.ip_network("172.16.0.0/12"),
        ipaddress.ip_network("192.168.0.0/16"),
        ipaddress.ip_network("127.0.0.0/8"),
        ipaddress.ip_network("::1/128"),
        ipaddress.ip_network("fc00::/7"),
    ]
    @staticmethod
    def parse_internal_networks(
        configured_ranges: Optional[List[str]],
    ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
        """Parse configured CIDR ranges into network objects, falling back to defaults.

        Invalid CIDR entries are skipped with a warning; if every entry is
        invalid (or none were given), the default private networks are used.
        """
        if not configured_ranges:
            return IPAddressUtils._DEFAULT_INTERNAL_NETWORKS
        networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = []
        for cidr in configured_ranges:
            try:
                # strict=False tolerates host bits set (e.g. "10.0.0.1/8")
                networks.append(ipaddress.ip_network(cidr, strict=False))
            except ValueError:
                verbose_proxy_logger.warning(
                    "Invalid CIDR in mcp_internal_ip_ranges: %s, skipping", cidr
                )
        return networks if networks else IPAddressUtils._DEFAULT_INTERNAL_NETWORKS
    @staticmethod
    def parse_trusted_proxy_networks(
        configured_ranges: Optional[List[str]],
    ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
        """
        Parse trusted proxy CIDR ranges for XFF validation.
        Returns empty list if not configured (XFF will not be trusted).
        """
        if not configured_ranges:
            return []
        networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = []
        for cidr in configured_ranges:
            try:
                networks.append(ipaddress.ip_network(cidr, strict=False))
            except ValueError:
                # invalid entries are skipped rather than failing the request
                verbose_proxy_logger.warning(
                    "Invalid CIDR in mcp_trusted_proxy_ranges: %s, skipping", cidr
                )
        return networks
    @staticmethod
    def is_trusted_proxy(
        proxy_ip: Optional[str],
        trusted_networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]],
    ) -> bool:
        """Check if the direct connection IP is from a trusted proxy.

        Fails closed: missing IP, no trusted networks, or an unparseable IP
        all return False.
        """
        if not proxy_ip or not trusted_networks:
            return False
        try:
            addr = ipaddress.ip_address(proxy_ip.strip())
            return any(addr in network for network in trusted_networks)
        except ValueError:
            return False
    @staticmethod
    def is_internal_ip(
        client_ip: Optional[str],
        internal_networks: Optional[
            List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]
        ] = None,
    ) -> bool:
        """
        Check if a client IP is from an internal/private network.
        Handles X-Forwarded-For comma chains (takes leftmost = original client).
        Fails closed: empty/invalid IPs are treated as external.
        """
        if not client_ip:
            return False
        # X-Forwarded-For may contain comma-separated chain; leftmost is original client
        if "," in client_ip:
            client_ip = client_ip.split(",")[0].strip()
        # NOTE: an explicitly-passed empty list also falls back to the defaults
        networks = internal_networks or IPAddressUtils._DEFAULT_INTERNAL_NETWORKS
        try:
            addr = ipaddress.ip_address(client_ip.strip())
        except ValueError:
            return False
        return any(addr in network for network in networks)
    @staticmethod
    def get_mcp_client_ip(
        request: Request,
        general_settings: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Extract client IP from a FastAPI request for MCP access control.
        Security: Only trusts X-Forwarded-For if:
        1. use_x_forwarded_for is enabled in settings
        2. The direct connection is from a trusted proxy (if mcp_trusted_proxy_ranges configured)
        Args:
            request: FastAPI request object
            general_settings: Optional settings dict. If not provided, imports from proxy_server.
        """
        if general_settings is None:
            try:
                # runtime import to avoid a circular dependency on proxy_server
                from litellm.proxy.proxy_server import (
                    general_settings as proxy_general_settings,
                )
                general_settings = proxy_general_settings
            except ImportError:
                general_settings = {}
        # Handle case where general_settings is still None after import
        if general_settings is None:
            general_settings = {}
        use_xff = general_settings.get("use_x_forwarded_for", False)
        # If XFF is enabled, validate the request comes from a trusted proxy
        # NOTE(review): when `mcp_trusted_proxy_ranges` is NOT configured, this
        # block is skipped and the XFF header is still honored below (via
        # _get_request_ip_address with use_x_forwarded_for=True), even though
        # parse_trusted_proxy_networks' docstring says XFF "will not be trusted"
        # without configuration — confirm which behavior is intended.
        if use_xff and "x-forwarded-for" in request.headers:
            trusted_ranges = general_settings.get("mcp_trusted_proxy_ranges")
            if trusted_ranges:
                # Validate direct connection is from trusted proxy
                direct_ip = request.client.host if request.client else None
                trusted_networks = IPAddressUtils.parse_trusted_proxy_networks(
                    trusted_ranges
                )
                if not IPAddressUtils.is_trusted_proxy(direct_ip, trusted_networks):
                    # Untrusted source trying to set XFF - ignore XFF, use direct IP
                    verbose_proxy_logger.warning(
                        "XFF header from untrusted IP %s, ignoring", direct_ip
                    )
                    return direct_ip
        return _get_request_ip_address(request, use_x_forwarded_for=use_xff)

View File

@@ -0,0 +1,214 @@
# What is this?
## If litellm license in env, checks if it's valid
import base64
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Optional
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.constants import NON_LLM_CONNECTION_TIMEOUT
from litellm.llms.custom_httpx.http_handler import HTTPHandler
if TYPE_CHECKING:
from litellm.proxy._types import EnterpriseLicenseData
class LicenseCheck:
    """
    - Check if license in env
    - Returns if license is valid

    Verification paths:
    1. verify_license_without_api_request: local ("air-gapped") check of a
       signed license blob against the bundled public key.
    2. _verify: remote check against the LiteLLM license server (legacy path).
    """
    base_url = "https://license.litellm.ai"
    def __init__(self) -> None:
        # License string comes from the LITELLM_LICENSE env var; may be None.
        self.license_str = os.getenv("LITELLM_LICENSE", None)
        verbose_proxy_logger.debug("License Str value - {}".format(self.license_str))
        self.http_handler = HTTPHandler(timeout=NON_LLM_CONNECTION_TIMEOUT)
        # guard so is_premium()'s debug logs are only emitted once
        self._premium_check_logged = False
        self.public_key = None
        self.read_public_key()
        # set by verify_license_without_api_request() on successful local
        # verification; carries license limits (e.g. max_users, max_teams)
        self.airgapped_license_data: Optional["EnterpriseLicenseData"] = None
    def read_public_key(self):
        """Load the bundled public_key.pem (if present) for local license verification."""
        try:
            from cryptography.hazmat.primitives import serialization
            # current dir
            current_dir = os.path.dirname(os.path.realpath(__file__))
            # check if public_key.pem exists
            _path_to_public_key = os.path.join(current_dir, "public_key.pem")
            if os.path.exists(_path_to_public_key):
                with open(_path_to_public_key, "rb") as key_file:
                    self.public_key = serialization.load_pem_public_key(key_file.read())
            else:
                self.public_key = None
        except Exception as e:
            verbose_proxy_logger.error(f"Error reading public key: {str(e)}")
    def _verify(self, license_str: str) -> bool:
        """
        Verify the license by calling the LiteLLM license server.

        Retries up to 3 times on HTTP status errors. Returns False on any
        failure (never raises), so license problems don't break the proxy.
        """
        verbose_proxy_logger.debug(
            "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
                self.base_url, license_str
            )
        )
        url = "{}/verify_license/{}".format(self.base_url, license_str)
        response: Optional[httpx.Response] = None
        try:  # don't impact user, if call fails
            num_retries = 3
            for i in range(num_retries):
                try:
                    response = self.http_handler.get(url=url)
                    if response is None:
                        raise Exception("No response from license server")
                    response.raise_for_status()
                    # Bug fix: previously the loop had no break, so even a
                    # successful verification re-fetched the license 3 times.
                    break
                except httpx.HTTPStatusError:
                    if i == num_retries - 1:
                        raise
            if response is None:
                raise Exception("No response from license server")
            response_json = response.json()
            premium = response_json["verify"]
            assert isinstance(premium, bool)
            verbose_proxy_logger.debug(
                "litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
                    license_str, premium
                )
            )
            return premium
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
                    license_str, str(e)
                )
            )
            return False
    def is_premium(self) -> bool:
        """
        1. verify_license_without_api_request: checks if license was generate using private / public key pair
        2. _verify: checks if license is valid calling litellm API. This is the old way we were generating/validating license

        Returns False (rather than raising) on any unexpected error.
        """
        try:
            if not self._premium_check_logged:
                verbose_proxy_logger.debug(
                    "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
                        self.license_str
                    )
                )
            # re-read the env var in case the license was set after startup
            if self.license_str is None:
                self.license_str = os.getenv("LITELLM_LICENSE", None)
            if not self._premium_check_logged:
                verbose_proxy_logger.debug(
                    "litellm.proxy.auth.litellm_license.py::is_premium() - Updated 'self.license_str' - {}".format(
                        self.license_str
                    )
                )
                self._premium_check_logged = True
            if self.license_str is None:
                return False
            elif (
                self.verify_license_without_api_request(
                    public_key=self.public_key, license_key=self.license_str
                )
                is True
            ):
                return True
            elif self._verify(license_str=self.license_str) is True:
                return True
            return False
        except Exception:
            return False
    def is_over_limit(self, total_users: int) -> bool:
        """
        Check if the license is over the limit

        Returns True only when an air-gapped license with an integer
        `max_users` limit is loaded and `total_users` exceeds it.
        """
        if self.airgapped_license_data is None:
            return False
        if "max_users" not in self.airgapped_license_data or not isinstance(
            self.airgapped_license_data["max_users"], int
        ):
            return False
        return total_users > self.airgapped_license_data["max_users"]
    def is_team_count_over_limit(self, team_count: int) -> bool:
        """
        Check if the license is over the limit

        Returns True only when an air-gapped license with an integer
        `max_teams` limit is loaded and `team_count` exceeds it.
        """
        if self.airgapped_license_data is None:
            return False
        _max_teams_in_license: Optional[int] = self.airgapped_license_data.get(
            "max_teams"
        )
        if "max_teams" not in self.airgapped_license_data or not isinstance(
            _max_teams_in_license, int
        ):
            return False
        return team_count > _max_teams_in_license
    def verify_license_without_api_request(self, public_key, license_key):
        """
        Verify a signed air-gapped license locally.

        The license is a base64 blob of `<json payload>.<signature>`; the
        signature is checked with RSA-PSS/SHA-256 against `public_key`.
        Returns True on success, False on any failure (bad signature,
        malformed blob, or expired license).
        """
        try:
            from cryptography.hazmat.primitives import hashes
            from cryptography.hazmat.primitives.asymmetric import padding
            from litellm.proxy._types import EnterpriseLicenseData
            # Decode the license key - add padding if needed for base64
            # Base64 strings need to be a multiple of 4 characters
            padding_needed = len(license_key) % 4
            if padding_needed:
                license_key += "=" * (4 - padding_needed)
            decoded = base64.b64decode(license_key)
            message, signature = decoded.split(b".", 1)
            # Verify the signature
            public_key.verify(
                signature,
                message,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH,
                ),
                hashes.SHA256(),
            )
            # Decode and parse the data
            license_data = json.loads(message.decode())
            self.airgapped_license_data = EnterpriseLicenseData(**license_data)
            # debug information provided in license data
            verbose_proxy_logger.debug("License data: %s", license_data)
            # Check expiration date
            expiration_date = datetime.strptime(
                license_data["expiration_date"], "%Y-%m-%d"
            )
            if expiration_date < datetime.now():
                # Bug fix: this previously returned the tuple
                # (False, "License has expired"), which is truthy and could
                # make an expired license look valid to truthiness-based
                # callers. Return a plain bool instead.
                verbose_proxy_logger.debug("License has expired")
                return False
            return True
        except Exception as e:
            verbose_proxy_logger.debug(
                "litellm.proxy.auth.litellm_license.py::verify_license_without_api_request - Unable to verify License locally. - {}".format(
                    str(e)
                )
            )
            return False

View File

@@ -0,0 +1,344 @@
"""
Login utilities for handling user authentication in the proxy server.
This module contains the core login logic that can be reused across different
login endpoints (e.g., /login and /v2/login).
"""
import os
import secrets
from typing import Literal, Optional, cast
from fastapi import HTTPException
import litellm
from litellm.constants import LITELLM_PROXY_ADMIN_NAME, LITELLM_UI_SESSION_DURATION
from litellm.proxy._types import (
LiteLLM_UserTable,
LitellmUserRoles,
ProxyErrorTypes,
ProxyException,
UpdateUserRequest,
UserAPIKeyAuth,
hash_token,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
from litellm.proxy.management_endpoints.key_management_endpoints import (
generate_key_helper_fn,
)
from litellm.proxy.management_endpoints.ui_sso import (
get_disabled_non_admin_personal_key_creation,
)
from litellm.proxy.utils import PrismaClient, get_server_root_path
from litellm.secret_managers.main import get_secret_bool
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
def get_ui_credentials(master_key: Optional[str]) -> tuple[str, str]:
    """
    Get UI username and password from environment variables or master key.

    Args:
        master_key: Master key for the proxy (used as fallback for password)

    Returns:
        tuple[str, str]: A tuple containing (ui_username, ui_password)

    Raises:
        ProxyException: If neither UI_PASSWORD nor master_key is available
    """
    ui_username = os.getenv("UI_USERNAME", "admin")
    ui_password = os.getenv("UI_PASSWORD", None)
    # fall back to the master key when no explicit UI password is configured
    if ui_password is None and master_key is not None:
        ui_password = str(master_key)
    if ui_password is None:
        raise ProxyException(
            message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="UI_PASSWORD",
            code=500,
        )
    return ui_username, ui_password
class LoginResult:
    """Result object containing authentication data from login."""
    user_id: str
    key: str
    user_email: Optional[str]
    user_role: str
    login_method: Literal["sso", "username_password"]
    def __init__(
        self,
        user_id: str,
        key: str,
        user_email: Optional[str],
        user_role: str,
        login_method: Literal["sso", "username_password"] = "username_password",
    ):
        # Plain data holder: store each field verbatim for downstream token creation.
        for attr_name, attr_value in (
            ("user_id", user_id),
            ("key", key),
            ("user_email", user_email),
            ("user_role", user_role),
            ("login_method", login_method),
        ):
            setattr(self, attr_name, attr_value)
async def authenticate_user(  # noqa: PLR0915
    username: str,
    password: str,
    master_key: Optional[str],
    prisma_client: Optional[PrismaClient],
) -> LoginResult:
    """
    Authenticate a user and generate an API key for UI access.
    This function handles two login scenarios:
    1. Admin login using UI_USERNAME and UI_PASSWORD
    2. User login using email and password from database
    Args:
        username: Username or email from the login form
        password: Password from the login form
        master_key: Master key for the proxy (required)
        prisma_client: Prisma database client (optional)
    Returns:
        LoginResult: Object containing authentication data
    Raises:
        ProxyException: If authentication fails or required configuration is missing
    """
    # A master key is a hard requirement for UI login - fail fast with setup guidance.
    if master_key is None:
        raise ProxyException(
            message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="master_key",
            code=500,
        )
    ui_username, ui_password = get_ui_credentials(master_key)
    # Check if we can find the `username` in the db. On the UI, users can enter username=their email
    _user_row: Optional[LiteLLM_UserTable] = None
    user_role: Optional[
        Literal[
            LitellmUserRoles.PROXY_ADMIN,
            LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
            LitellmUserRoles.INTERNAL_USER,
            LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
        ]
    ] = None
    if prisma_client is not None:
        # Case-insensitive email lookup so login works regardless of email casing.
        _user_row = cast(
            Optional[LiteLLM_UserTable],
            await prisma_client.db.litellm_usertable.find_first(
                where={"user_email": {"equals": username, "mode": "insensitive"}}
            ),
        )
    """
    To login to Admin UI, we support the following
    - Login with UI_USERNAME and UI_PASSWORD
    - Login with Invite Link `user_email` and `password` combination
    """
    # compare_digest = constant-time comparison, mitigating timing attacks.
    if secrets.compare_digest(
        username.encode("utf-8"), ui_username.encode("utf-8")
    ) and secrets.compare_digest(password.encode("utf-8"), ui_password.encode("utf-8")):
        # Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin
        user_role = LitellmUserRoles.PROXY_ADMIN
        user_id = LITELLM_PROXY_ADMIN_NAME
        # we want the key created to have PROXY_ADMIN_PERMISSIONS
        key_user_id = LITELLM_PROXY_ADMIN_NAME
        if (
            os.getenv("PROXY_ADMIN_ID", None) is not None
            and os.environ["PROXY_ADMIN_ID"] == user_id
        ) or user_id == LITELLM_PROXY_ADMIN_NAME:
            # checks if user is admin
            key_user_id = os.getenv("PROXY_ADMIN_ID", LITELLM_PROXY_ADMIN_NAME)
        # Admin is Authe'd in - generate key for the UI to access Proxy
        # ensure this user is set as the proxy admin, in this route there is no sso, we can assume this user is only the admin
        await user_update(
            data=UpdateUserRequest(
                user_id=key_user_id,
                user_role=user_role,
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
            ),
        )
        if os.getenv("DATABASE_URL") is not None:
            # Session key: short-lived, budget-capped, scoped to the dashboard team.
            response = await generate_key_helper_fn(
                request_type="key",
                **{
                    "user_role": LitellmUserRoles.PROXY_ADMIN,
                    "duration": LITELLM_UI_SESSION_DURATION,
                    "key_max_budget": litellm.max_ui_session_budget,
                    "models": [],
                    "aliases": {},
                    "config": {},
                    "spend": 0,
                    "user_id": key_user_id,
                    "team_id": "litellm-dashboard",
                },  # type: ignore
            )
        else:
            raise ProxyException(
                message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
                type=ProxyErrorTypes.auth_error,
                param="DATABASE_URL",
                code=500,
            )
        key = response["token"]  # type: ignore
        if get_secret_bool("EXPERIMENTAL_UI_LOGIN"):
            # Experimental path: replace the DB-backed key with a signed JWT.
            from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken
            user_info: Optional[LiteLLM_UserTable] = None
            if _user_row is not None:
                user_info = _user_row
            elif (
                user_id is not None
            ):  # if user_id is not None, we are using the UI_USERNAME and UI_PASSWORD
                # synthesize a minimal user record for the env-configured admin
                user_info = LiteLLM_UserTable(
                    user_id=user_id,
                    user_role=user_role,
                    models=[],
                    max_budget=litellm.max_ui_session_budget,
                )
            if user_info is None:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "error": "User Information is required for experimental UI login"
                    },
                )
            key = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
                user_info
            )
        return LoginResult(
            user_id=user_id,
            key=key,
            user_email=None,
            user_role=user_role,
            login_method="username_password",
        )
    elif _user_row is not None:
        """
        When sharing invite links
        -> if the user has no role in the DB assume they are only a viewer
        """
        user_id = getattr(_user_row, "user_id", "unknown")
        user_role = getattr(
            _user_row, "user_role", LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
        )
        user_email = getattr(_user_row, "user_email", "unknown")
        _password = getattr(_user_row, "password", "unknown")
        if _password is None:
            raise ProxyException(
                message="User has no password set. Please set a password for the user via `/user/update`.",
                type=ProxyErrorTypes.auth_error,
                param="password",
                code=401,
            )
        # check if password == _user_row.password
        hash_password = hash_token(token=password)
        # NOTE(review): this accepts EITHER a plaintext match OR a hashed match
        # against the stored password — presumably to support legacy rows that
        # stored plaintext. Confirm plaintext storage is intended; both
        # comparisons use constant-time compare_digest.
        if secrets.compare_digest(
            password.encode("utf-8"), _password.encode("utf-8")
        ) or secrets.compare_digest(
            hash_password.encode("utf-8"), _password.encode("utf-8")
        ):
            if os.getenv("DATABASE_URL") is not None:
                # Session key scoped to this user's role, same dashboard team.
                response = await generate_key_helper_fn(
                    request_type="key",
                    **{  # type: ignore
                        "user_role": user_role,
                        "duration": LITELLM_UI_SESSION_DURATION,
                        "key_max_budget": litellm.max_ui_session_budget,
                        "models": [],
                        "aliases": {},
                        "config": {},
                        "spend": 0,
                        "user_id": user_id,
                        "team_id": "litellm-dashboard",
                    },
                )
            else:
                raise ProxyException(
                    message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
                    type=ProxyErrorTypes.auth_error,
                    param="DATABASE_URL",
                    code=500,
                )
            key = response["token"]  # type: ignore
            return LoginResult(
                user_id=user_id,
                key=key,
                user_email=user_email,
                user_role=cast(str, user_role),
                login_method="username_password",
            )
        else:
            raise ProxyException(
                message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
                type=ProxyErrorTypes.auth_error,
                param="invalid_credentials",
                code=401,
            )
    else:
        # Neither the env-admin credentials matched nor a DB row was found.
        raise ProxyException(
            message="Invalid credentials used to access UI.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
            type=ProxyErrorTypes.auth_error,
            param="invalid_credentials",
            code=401,
        )
def create_ui_token_object(
    login_result: LoginResult,
    general_settings: dict,
    premium_user: bool,
) -> ReturnedUITokenObject:
    """
    Build the UI session token payload from a successful login.

    Args:
        login_result: The result from authenticate_user
        general_settings: General proxy settings dictionary
        premium_user: Whether premium features are enabled

    Returns:
        ReturnedUITokenObject: Token object ready for JWT encoding
    """
    personal_key_creation_disabled = get_disabled_non_admin_personal_key_creation()
    auth_header = general_settings.get("litellm_key_header_name", "Authorization")
    return ReturnedUITokenObject(
        user_id=login_result.user_id,
        key=login_result.key,
        user_email=login_result.user_email,
        user_role=login_result.user_role,
        login_method=login_result.login_method,
        premium_user=premium_user,
        auth_header_name=auth_header,
        disabled_non_admin_personal_key_creation=personal_key_creation_disabled,
        server_root_path=get_server_root_path(),
    )

View File

@@ -0,0 +1,381 @@
# What is this?
## Common checks for /v1/models and `/model/info`
from typing import Dict, List, Optional, Set
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
from litellm.router import Router
from litellm.router_utils.fallback_event_handlers import get_fallback_model_group
from litellm.types.router import LiteLLM_Params
from litellm.utils import get_valid_models
def _check_wildcard_routing(model: str) -> bool:
"""
Returns True if a model is a provider wildcard.
eg:
- anthropic/*
- openai/*
- *
"""
if "*" in model:
return True
return False
def get_provider_models(
    provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
    """
    Return the list of models litellm knows for *provider*.

    ``provider == "*"`` means "all providers"; an unrecognized provider
    yields ``None``.
    """
    if provider == "*":
        return get_valid_models(litellm_params=litellm_params)
    if provider not in litellm.models_by_provider:
        return None
    return get_valid_models(
        custom_llm_provider=provider, litellm_params=litellm_params
    )
def _get_models_from_access_groups(
model_access_groups: Dict[str, List[str]],
all_models: List[str],
include_model_access_groups: Optional[bool] = False,
) -> List[str]:
idx_to_remove = []
new_models = []
for idx, model in enumerate(all_models):
if model in model_access_groups:
if (
not include_model_access_groups
): # remove access group, unless requested - e.g. when creating a key
idx_to_remove.append(idx)
new_models.extend(model_access_groups[model])
for idx in sorted(idx_to_remove, reverse=True):
all_models.pop(idx)
all_models.extend(new_models)
return all_models
async def get_mcp_server_ids(
    user_api_key_dict: UserAPIKeyAuth,
) -> List[str]:
    """
    Look up the MCP server ids attached to a key via its object-permission row.

    Returns an empty list when there is no DB connection, the key has no
    object_permission_id, or the lookup fails for any reason (best-effort).
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None or user_api_key_dict.object_permission_id is None:
        return []

    # Direct query for just the permission row we need; swallow DB errors
    try:
        permission_row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": user_api_key_dict.object_permission_id},
        )
        if permission_row and permission_row.mcp_servers:
            return permission_row.mcp_servers
        return []
    except Exception:
        return []
def get_key_models(
    user_api_key_dict: UserAPIKeyAuth,
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Resolve the list of model names a key may call.

    Returns:
        - List of model name strings
        - Empty list if no models set
        - Special names (all-team-models / all-proxy-models) are expanded
        - If model_access_groups is provided, group aliases are expanded to
          their member models
        - If include_model_access_groups is True, the 'keys' of the
          model_access_groups are kept in the response as well -
          {"beta-models": ["gpt-4", "claude-v1"]} -> returns 'beta-models'
    """
    resolved: List[str] = []
    if user_api_key_dict.models:
        # copy to avoid mutating cached objects
        resolved = list(user_api_key_dict.models)
        if SpecialModelNames.all_team_models.value in resolved:
            resolved = list(user_api_key_dict.team_models)
        if SpecialModelNames.all_proxy_models.value in resolved:
            resolved = list(proxy_model_list)

    if include_model_access_groups:
        resolved.extend(model_access_groups.keys())

    resolved = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=resolved,
        include_model_access_groups=include_model_access_groups,
    )

    # deduplicate while preserving order
    resolved = list(dict.fromkeys(resolved))
    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(resolved)))
    return resolved
def get_team_models(
    team_models: List[str],
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Returns:
        - List of model name strings
        - Empty list if no models set
        - If model_access_groups is provided, only return models that are in the access groups
    """
    all_models_set: Set[str] = set()
    if len(team_models) > 0:
        all_models_set.update(team_models)
    # NOTE(review): updating with `team_models` again below is a no-op — those
    # entries are already in the set. Presumably meant to mirror the expansion
    # in get_key_models(); confirm intended behavior.
    if SpecialModelNames.all_team_models.value in all_models_set:
        all_models_set.update(team_models)
    if SpecialModelNames.all_proxy_models.value in all_models_set:
        all_models_set.update(proxy_model_list)
    if include_model_access_groups:
        all_models_set.update(model_access_groups.keys())
    all_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=list(all_models_set),
        include_model_access_groups=include_model_access_groups,
    )
    # NOTE(review): the set above already discards insertion order, so this
    # line only deduplicates; result ordering is not actually guaranteed
    # despite the comment's wording in the sibling get_key_models().
    # deduplicate while preserving order
    all_models = list(dict.fromkeys(all_models))
    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
    return all_models
def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
    model_access_groups: Dict[str, List[str]] = {},
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """Logic for returning complete model list for a given key + team pair.

    - If key list is empty -> defer to team list
    - If team list is empty -> defer to proxy model list
    - If list contains a wildcard -> expand to known provider models
    """
    unique_models: List[str] = []
    seen = set()

    def _append_unique(candidates):
        # preserve first-seen order while deduplicating
        for candidate in candidates:
            if candidate not in seen:
                seen.add(candidate)
                unique_models.append(candidate)

    if key_models:
        _append_unique(key_models)
    elif team_models:
        _append_unique(team_models)
    else:
        _append_unique(proxy_model_list)

    if include_model_access_groups:
        _append_unique(list(model_access_groups.keys()))  # TODO: keys order
    if user_model:
        _append_unique([user_model])
    if infer_model_from_keys:
        _append_unique(get_valid_models())

    if only_model_access_groups:
        # caller only wants the access-group names present in the result
        return [m for m in unique_models if m in model_access_groups]

    wildcard_expansions = _get_wildcard_models(
        unique_models=unique_models,
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
    )
    return unique_models + wildcard_expansions
def get_known_models_from_wildcard(
    wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
    """
    Expand a wildcard model entry (e.g. ``openai/*`` or ``openai/gpt-4*``) into
    the list of concrete, provider-prefixed model names litellm knows about.

    Args:
        wildcard_model: entry of the form ``<prefix>/<pattern>`` where pattern
            is ``*`` or a partial filter such as ``gpt-4*``.
        litellm_params: optional deployment params; when set, the provider is
            taken from ``litellm_params.model`` instead of the wildcard prefix
            (needed for BYOK where the wildcard isn't in the router).

    Returns:
        List of model names; empty when the wildcard is malformed (no "/") or
        the provider is unknown.
    """
    try:
        wildcard_provider_prefix, wildcard_suffix = wildcard_model.split("/", 1)
    except ValueError:  # no "/" present -> safely fail
        return []

    # Use provider from litellm_params when available, otherwise from wildcard prefix
    # (e.g., "openai" from "openai/*").
    # NOTE: str.split(...)[0] can never raise ValueError, so the former
    # try/except around this line was unreachable and has been removed.
    # A model string without "/" yields itself as the provider, as before.
    if litellm_params is not None:
        provider = litellm_params.model.split("/", 1)[0]
    else:
        provider = wildcard_provider_prefix

    # get all known provider models
    wildcard_models = get_provider_models(
        provider=provider, litellm_params=litellm_params
    )
    if wildcard_models is None:
        return []

    if wildcard_suffix != "*":
        ## CHECK IF PARTIAL FILTER e.g. `gemini-*`
        model_prefix = wildcard_suffix.replace("*", "")
        is_partial_filter = any(
            wc_model.startswith(model_prefix) for wc_model in wildcard_models
        )
        if is_partial_filter:
            # keep only the models matching the partial filter
            wildcard_models = [
                wc_model
                for wc_model in wildcard_models
                if wc_model.startswith(model_prefix)
            ]
        else:
            # add model prefix to wildcard models
            wildcard_models = [f"{model_prefix}{model}" for model in wildcard_models]

    # ensure every returned model carries the wildcard's provider prefix
    suffix_appended_wildcard_models = []
    for model in wildcard_models:
        if not model.startswith(wildcard_provider_prefix):
            model = f"{wildcard_provider_prefix}/{model}"
        suffix_appended_wildcard_models.append(model)
    return suffix_appended_wildcard_models
def _get_wildcard_models(
    unique_models: List[str],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Expand wildcard entries (e.g. ``openai/*``) in *unique_models* into
    concrete model names.

    Side effect: wildcard entries expanded via the known-provider fallback are
    removed from *unique_models* in place.

    Args:
        unique_models: candidate model names; may contain wildcard routes.
        return_wildcard_routes: when True, the wildcard route itself (e.g.
            ``anthropic/*``) is also appended to the returned list.
        llm_router: used to resolve deployment-specific litellm_params for the
            wildcard; falls back to litellm's known provider models when the
            router has no matching deployment.

    Returns:
        List of concrete model names derived from the wildcard entries.
    """
    models_to_remove = set()
    all_wildcard_models = []
    for model in unique_models:
        if _check_wildcard_routing(model=model):
            if (
                return_wildcard_routes
            ):  # will add the wildcard route to the list eg: anthropic/*.
                all_wildcard_models.append(model)
            ## get litellm params from model
            if llm_router is not None:
                model_list = llm_router.get_model_list(model_name=model)
                if model_list:
                    # Expand using each deployment's own litellm_params.
                    # NOTE(review): in this branch the wildcard entry is left
                    # in unique_models (never added to models_to_remove) —
                    # confirm this asymmetry with the fallback branches is
                    # intentional.
                    for router_model in model_list:
                        wildcard_models = get_known_models_from_wildcard(
                            wildcard_model=model,
                            litellm_params=LiteLLM_Params(
                                **router_model["litellm_params"]  # type: ignore
                            ),
                        )
                        all_wildcard_models.extend(wildcard_models)
                else:
                    # Router has no deployment for this wildcard (e.g., BYOK team models)
                    # Fall back to expanding from known provider models
                    wildcard_models = get_known_models_from_wildcard(
                        wildcard_model=model, litellm_params=None
                    )
                    if wildcard_models:
                        models_to_remove.add(model)
                        all_wildcard_models.extend(wildcard_models)
            else:
                # get all known provider models
                wildcard_models = get_known_models_from_wildcard(
                    wildcard_model=model, litellm_params=None
                )
                if wildcard_models:
                    models_to_remove.add(model)
                    all_wildcard_models.extend(wildcard_models)

    # drop wildcard entries that were replaced by concrete models
    for model in models_to_remove:
        unique_models.remove(model)

    return all_wildcard_models
def get_all_fallbacks(
    model: str,
    llm_router: Optional[Router] = None,
    fallback_type: str = "general",
) -> List[str]:
    """
    Get all fallbacks for a given model from the router's fallback configuration.

    Args:
        model: The model name to get fallbacks for
        llm_router: The LiteLLM router instance
        fallback_type: Type of fallback ("general", "context_window", "content_policy")

    Returns:
        List of fallback model names. Empty list if no fallbacks found.
    """
    if llm_router is None:
        return []

    # Map fallback type -> router attribute holding its configuration
    attr_for_type = {
        "general": "fallbacks",
        "context_window": "context_window_fallbacks",
        "content_policy": "content_policy_fallbacks",
    }
    attr_name = attr_for_type.get(fallback_type)
    if attr_name is None:
        verbose_proxy_logger.warning(f"Unknown fallback_type: {fallback_type}")
        return []

    fallbacks_config: list = getattr(llm_router, attr_name, [])
    if not fallbacks_config:
        return []

    try:
        # Use existing function to get fallback model group
        fallback_model_group, _ = get_fallback_model_group(
            fallbacks=fallbacks_config, model_group=model
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error getting fallbacks for model {model}: {e}")
        return []

    return fallback_model_group if fallback_model_group is not None else []

View File

@@ -0,0 +1,222 @@
import base64
import os
from typing import Dict, Optional, Tuple, cast
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
class Oauth2Handler:
    """
    Handles OAuth2 token validation.

    Supports two endpoint styles:
      - RFC 7662 token introspection (POST with form data + client auth)
      - generic token-info endpoints (GET with a Bearer header)
    """

    @staticmethod
    def _is_introspection_endpoint(
        token_info_endpoint: str,
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
    ) -> bool:
        """
        Determine if this is an introspection endpoint (requires POST) or token info endpoint (uses GET).

        Args:
            token_info_endpoint: The OAuth2 endpoint URL
            oauth_client_id: OAuth2 client ID
            oauth_client_secret: OAuth2 client secret

        Returns:
            bool: True if this is an introspection endpoint
        """
        # Heuristic: the URL mentions "introspect" AND full client credentials
        # are configured — both are required for the RFC 7662 POST flow.
        return (
            "introspect" in token_info_endpoint.lower()
            and oauth_client_id is not None
            and oauth_client_secret is not None
        )

    @staticmethod
    def _prepare_introspection_request(
        token: str,
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
    ) -> Tuple[Dict[str, str], Dict[str, str]]:
        """
        Prepare headers and data for OAuth2 introspection endpoint (RFC 7662).

        Args:
            token: The OAuth2 token to validate
            oauth_client_id: OAuth2 client ID
            oauth_client_secret: OAuth2 client secret

        Returns:
            Tuple of (headers, data) for the introspection request
        """
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {"token": token}

        # Add client authentication if credentials are provided
        if oauth_client_id and oauth_client_secret:
            # Use HTTP Basic authentication for client credentials
            credentials = base64.b64encode(
                f"{oauth_client_id}:{oauth_client_secret}".encode()
            ).decode()
            headers["Authorization"] = f"Basic {credentials}"
        elif oauth_client_id:
            # For public clients, include client_id in the request body
            data["client_id"] = oauth_client_id

        return headers, data

    @staticmethod
    def _prepare_token_info_request(token: str) -> Dict[str, str]:
        """
        Prepare headers for generic token info endpoint.

        Args:
            token: The OAuth2 token to validate

        Returns:
            Dict of headers for the token info request
        """
        return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    @staticmethod
    def _extract_user_info(
        response_data: Dict,
        user_id_field_name: str,
        user_role_field_name: str,
        user_team_id_field_name: str,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Extract user information from OAuth2 response.

        Args:
            response_data: The response data from OAuth2 endpoint
            user_id_field_name: Field name for user ID
            user_role_field_name: Field name for user role
            user_team_id_field_name: Field name for team ID

        Returns:
            Tuple of (user_id, user_role, user_team_id)
        """
        # Missing fields simply come back as None — callers must handle that.
        user_id = response_data.get(user_id_field_name)
        user_team_id = response_data.get(user_team_id_field_name)
        user_role = response_data.get(user_role_field_name)

        return user_id, user_role, user_team_id

    @staticmethod
    async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
        """
        Makes a request to the token introspection endpoint to validate the OAuth2 token.

        This function implements OAuth2 token introspection according to RFC 7662.
        It supports both generic token info endpoints (GET) and OAuth2 introspection endpoints (POST).

        Args:
            token (str): The OAuth2 token to validate.

        Returns:
            UserAPIKeyAuth: If the token is valid, containing user information.

        Raises:
            ValueError: If the token is invalid, the request fails, or the token info endpoint is not set.
        """
        from litellm.proxy.proxy_server import premium_user

        # Enterprise-gated feature
        if premium_user is not True:
            raise ValueError(
                "Oauth2 token validation is only available for premium users"
                + CommonProxyErrors.not_premium_user.value
            )

        verbose_proxy_logger.debug("Oauth2 token validation for token=%s", token)
        # Get the token info endpoint from environment variable
        token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
        # Field names in the provider's response are configurable via env vars
        user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
        user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
        user_team_id_field_name = os.environ.get(
            "OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id"
        )

        # OAuth2 client credentials for introspection endpoint authentication
        oauth_client_id = os.environ.get("OAUTH_CLIENT_ID")
        oauth_client_secret = os.environ.get("OAUTH_CLIENT_SECRET")

        if not token_info_endpoint:
            raise ValueError(
                "OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set"
            )

        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
        # Determine if this is an introspection endpoint (requires POST) or token info endpoint (uses GET)
        is_introspection_endpoint = Oauth2Handler._is_introspection_endpoint(
            token_info_endpoint=token_info_endpoint,
            oauth_client_id=oauth_client_id,
            oauth_client_secret=oauth_client_secret,
        )

        try:
            if is_introspection_endpoint:
                # OAuth2 Token Introspection (RFC 7662) - requires POST with form data
                verbose_proxy_logger.debug("Using OAuth2 introspection endpoint (POST)")
                headers, data = Oauth2Handler._prepare_introspection_request(
                    token=token,
                    oauth_client_id=oauth_client_id,
                    oauth_client_secret=oauth_client_secret,
                )
                response = await client.post(
                    token_info_endpoint, headers=headers, data=data
                )
            else:
                # Generic token info endpoint - uses GET with Bearer token
                verbose_proxy_logger.debug("Using generic token info endpoint (GET)")
                headers = Oauth2Handler._prepare_token_info_request(token=token)
                response = await client.get(token_info_endpoint, headers=headers)

            # if it's a bad token we expect it to raise an HTTPStatusError
            response.raise_for_status()

            # If we get here, the request was successful
            # NOTE: `data` is intentionally rebound here — it no longer holds
            # the request form data, but the endpoint's JSON response.
            data = response.json()

            verbose_proxy_logger.debug(
                "Oauth2 token validation for token=%s, response from endpoint=%s",
                token,
                data,
            )

            # For introspection endpoints, check if token is active
            if is_introspection_endpoint and not data.get("active", True):
                raise ValueError("Token is not active")

            # Extract user information from response
            user_id, user_role, user_team_id = Oauth2Handler._extract_user_info(
                response_data=data,
                user_id_field_name=user_id_field_name,
                user_role_field_name=user_role_field_name,
                user_team_id_field_name=user_team_id_field_name,
            )

            # NOTE(review): the cast below does not validate that the provider's
            # role string is actually a LitellmUserRoles member — confirm
            # downstream handles unexpected role values.
            return UserAPIKeyAuth(
                api_key=token,
                team_id=user_team_id,
                user_id=user_id,
                user_role=cast(LitellmUserRoles, user_role),
            )
        except httpx.HTTPStatusError as e:
            # This will catch any 4xx or 5xx errors
            raise ValueError(f"Oauth 2.0 Token validation failed: {e}")
        except Exception as e:
            # This will catch any other errors (like network issues)
            raise ValueError(f"An error occurred during token validation: {e}")

View File

@@ -0,0 +1,45 @@
from typing import Any, Dict
from fastapi import Request
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
async def handle_oauth2_proxy_request(request: Request) -> UserAPIKeyAuth:
    """
    Handle request from oauth2 proxy.

    Reads header values according to ``general_settings.oauth2_config_mappings``
    (auth field name -> request header name) and builds a UserAPIKeyAuth from
    them.
    """
    from litellm.proxy.proxy_server import general_settings

    verbose_proxy_logger.debug("Handling oauth2 proxy request")

    # Define the OAuth2 config mappings
    oauth2_config_mappings: Dict[str, str] = (
        general_settings.get("oauth2_config_mappings") or {}
    )
    verbose_proxy_logger.debug(f"Oauth2 config mappings: {oauth2_config_mappings}")

    if not oauth2_config_mappings:
        raise ValueError("Oauth2 config mappings not found in general_settings")

    # Extract values from headers based on the mappings
    auth_data: Dict[str, Any] = {}
    for field_name, header_name in oauth2_config_mappings.items():
        header_value = request.headers.get(header_name)
        if not header_value:
            continue
        if field_name == "max_budget":
            # budgets arrive as strings; coerce to float
            auth_data[field_name] = float(header_value)
        elif field_name == "models":
            # comma-separated list of model names
            auth_data[field_name] = [m.strip() for m in header_value.split(",")]
        else:
            auth_data[field_name] = header_value

    verbose_proxy_logger.debug(
        f"Auth data before creating UserAPIKeyAuth object: {auth_data}"
    )
    user_api_key_auth = UserAPIKeyAuth(**auth_data)
    verbose_proxy_logger.debug(f"UserAPIKeyAuth object created: {user_api_key_auth}")
    # Create and return UserAPIKeyAuth object
    return user_api_key_auth

View File

@@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwcNBabWBZzrDhFAuA4Fh
FhIcA3rF7vrLb8+1yhF2U62AghQp9nStyuJRjxMUuldWgJ1yRJ2s7UffVw5r8DeA
dqXPD+w+3LCNwqJGaIKN08QGJXNArM3QtMaN0RTzAyQ4iibN1r6609W5muK9wGp0
b1j5+iDUmf0ynItnhvaX6B8Xoaflc3WD/UBdrygLmsU5uR3XC86+/8ILoSZH3HtN
6FJmWhlhjS2TR1cKZv8K5D0WuADTFf5MF8jYFR+uORPj5Pe/EJlLGN26Lfn2QnGu
XgbPF6nCGwZ0hwH1Xkn3xzGaJ4xBEC761wqp5cHxWSDktHyFKnLbP3jVeegjVIHh
pQIDAQAB
-----END PUBLIC KEY-----

View File

@@ -0,0 +1,187 @@
import os
from typing import Any, Optional, Union
import httpx
def init_rds_client(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    """
    Initialize a boto3 RDS client using one of several AWS auth strategies.

    Credential resolution order:
      1. Web-identity (OIDC) token + role -> STS assume_role_with_web_identity
      2. Role name + session name         -> STS assume_role
      3. Explicit access key / secret
      4. Named AWS profile (~/.aws/credentials)
      5. boto3's default env-variable resolution

    Any parameter given as "os.environ/<VAR>" is resolved via the secret
    manager before use.

    Raises:
        Exception: if no AWS region can be determined, or the OIDC token
            cannot be retrieved.
    """
    from litellm.secret_managers.main import get_secret

    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IS 'os.environ/' passed in
    # Define the list of parameters to check
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]

    # Resolve any "os.environ/" style references through the secret manager
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)  # type: ignore

    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    ### SET REGION NAME
    # (the former pre-assignment `region_name = aws_region_name` was dead code:
    # every branch below assigns region_name, so it has been removed)
    if aws_region_name:
        region_name = aws_region_name
    elif litellm_aws_region_name:
        region_name = litellm_aws_region_name
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise Exception(
            "AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
        )

    import boto3

    # Build a boto3 Config honoring the caller's timeout (float or httpx.Timeout)
    if isinstance(timeout, float):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()  # type: ignore

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        # OIDC web-identity flow: the token may be a file path or a secret ref
        try:
            oidc_token = open(aws_web_identity_token).read()  # check if filepath
        except Exception:
            oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise Exception(
                "OIDC token could not be retrieved from secret manager.",
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )

    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
        client = boto3.client(
            service_name="rds",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            config=config,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        client = boto3.Session(profile_name=aws_profile_name).client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables
        client = boto3.client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )

    return client
def generate_iam_auth_token(
    db_host, db_port, db_user, client: Optional[Any] = None
) -> str:
    """
    Generate a URL-encoded RDS IAM auth token for the given DB endpoint.

    When *client* is not supplied, an RDS boto3 client is initialized from the
    AWS_* environment variables.
    """
    from urllib.parse import quote

    if client is not None:
        rds_client = client
    else:
        rds_client = init_rds_client(
            aws_region_name=os.getenv("AWS_REGION_NAME"),
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_name=os.getenv("AWS_SESSION_NAME"),
            aws_profile_name=os.getenv("AWS_PROFILE_NAME"),
            aws_role_name=os.getenv("AWS_ROLE_NAME", os.getenv("AWS_ROLE_ARN")),
            aws_web_identity_token=os.getenv(
                "AWS_WEB_IDENTITY_TOKEN", os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
            ),
        )

    raw_token = rds_client.generate_db_auth_token(
        DBHostname=db_host, Port=db_port, DBUsername=db_user
    )
    # percent-encode everything (safe="") so the token can be embedded in a DB URL
    return quote(raw_token, safe="")

View File

@@ -0,0 +1,669 @@
import re
from typing import List, Optional
from fastapi import HTTPException, Request, status
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
CommonProxyErrors,
LiteLLM_UserTable,
LiteLLMRoutes,
LitellmUserRoles,
UserAPIKeyAuth,
)
from .auth_checks_organization import _user_is_org_admin
class RouteChecks:
@staticmethod
def should_call_route(route: str, valid_token: UserAPIKeyAuth):
"""
Check if management route is disabled and raise exception
"""
try:
from litellm_enterprise.proxy.auth.route_checks import EnterpriseRouteChecks
EnterpriseRouteChecks.should_call_route(route=route)
except HTTPException as e:
raise e
except Exception:
pass
# Check if Virtual Key is allowed to call the route - Applies to all Roles
RouteChecks.is_virtual_key_allowed_to_call_route(
route=route, valid_token=valid_token
)
return True
    @staticmethod
    def is_virtual_key_allowed_to_call_route(
        route: str, valid_token: UserAPIKeyAuth
    ) -> bool:
        """
        Raises Exception if Virtual Key is not allowed to call the route

        A key with no `allowed_routes` restriction (None, non-list, or empty
        list) may call anything. Otherwise the route is allowed when it matches
        any entry by: exact/prefix match, LiteLLMRoutes group-name membership,
        registered pass-through endpoints (for "llm_api_routes"), or wildcard
        pattern.
        """
        # Only check if valid_token.allowed_routes is set and is a list with at least one item
        if valid_token.allowed_routes is None:
            return True
        if not isinstance(valid_token.allowed_routes, list):
            return True
        if len(valid_token.allowed_routes) == 0:
            return True

        # explicit check for allowed routes (exact match or prefix match)
        for allowed_route in valid_token.allowed_routes:
            if RouteChecks._route_matches_allowed_route(
                route=route, allowed_route=allowed_route
            ):
                return True

        ## check if 'allowed_route' is a field name in LiteLLMRoutes
        if any(
            allowed_route in LiteLLMRoutes._member_names_
            for allowed_route in valid_token.allowed_routes
        ):
            for allowed_route in valid_token.allowed_routes:
                # expand the named route group and test membership
                if allowed_route in LiteLLMRoutes._member_names_:
                    if RouteChecks.check_route_access(
                        route=route,
                        allowed_routes=LiteLLMRoutes._member_map_[allowed_route].value,
                    ):
                        return True

                ################################################
                # For llm_api_routes, also check registered pass-through endpoints
                ################################################
                if allowed_route == "llm_api_routes":
                    from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
                        InitPassThroughEndpointHelpers,
                    )

                    if InitPassThroughEndpointHelpers.is_registered_pass_through_route(
                        route=route
                    ):
                        return True

        # check if wildcard pattern is allowed
        for allowed_route in valid_token.allowed_routes:
            if RouteChecks._route_matches_wildcard_pattern(
                route=route, pattern=allowed_route
            ):
                return True

        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Virtual key is not allowed to call this route. Only allowed to call routes: {valid_token.allowed_routes}. Tried to call route: {route}",
        )
@staticmethod
def _mask_user_id(user_id: str) -> str:
"""
Mask user_id to prevent leaking sensitive information in error messages
Args:
user_id (str): The user_id to mask
Returns:
str: Masked user_id showing only first 2 and last 2 characters
"""
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
if not user_id or len(user_id) <= 4:
return "***"
# Use SensitiveDataMasker with custom configuration for user_id
masker = SensitiveDataMasker(visible_prefix=6, visible_suffix=2, mask_char="*")
return masker._mask_value(user_id)
@staticmethod
def _raise_admin_only_route_exception(
user_obj: Optional[LiteLLM_UserTable],
route: str,
) -> None:
"""
Raise exception for routes that require proxy admin access
Args:
user_obj (Optional[LiteLLM_UserTable]): The user object
route (str): The route being accessed
Raises:
Exception: With user role and masked user_id information
"""
user_role = "unknown"
user_id = "unknown"
if user_obj is not None:
user_role = user_obj.user_role or "unknown"
user_id = user_obj.user_id or "unknown"
masked_user_id = RouteChecks._mask_user_id(user_id)
raise Exception(
f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={masked_user_id}"
)
    @staticmethod
    def non_proxy_admin_allowed_routes_check(
        user_obj: Optional[LiteLLM_UserTable],
        _user_role: Optional[LitellmUserRoles],
        route: str,
        request: Request,
        valid_token: UserAPIKeyAuth,
        request_data: dict,
    ):
        """
        Checks if Non Proxy Admin User is allowed to access the route

        Walks a priority ladder: LLM API routes pass through; info routes have
        per-route rules; then role-specific route groups (proxy-admin viewer,
        internal user, org admin, internal view-only), self-managed routes,
        MCP routes, pass-through routes, and finally the key's explicit
        allowed_routes. Raises when nothing matches.
        """
        # Check user has defined custom admin routes
        RouteChecks.custom_admin_only_route_check(
            route=route,
        )

        if RouteChecks.is_llm_api_route(route=route):
            pass
        elif RouteChecks.is_info_route(route=route):
            # check if user allowed to call an info route
            if route == "/key/info":
                # handled by function itself
                pass
            elif route == "/user/info":
                # check if user can access this route: a key may only read
                # info for its own user_id
                query_params = request.query_params
                user_id = query_params.get("user_id")
                verbose_proxy_logger.debug(
                    f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
                )
                if user_id and user_id != valid_token.user_id:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="key not allowed to access this user's info. user_id={}, key's user_id={}".format(
                            user_id, valid_token.user_id
                        ),
                    )
            elif route == "/model/info":
                # /model/info just shows models user has access to
                pass
            elif route == "/team/info":
                pass  # handled by function itself
        elif (
            route in LiteLLMRoutes.global_spend_tracking_routes.value
            and getattr(valid_token, "permissions", None) is not None
            and "get_spend_routes" in getattr(valid_token, "permissions", [])
        ):
            # key explicitly granted the spend-tracking permission
            pass
        elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
            RouteChecks._check_proxy_admin_viewer_access(
                route=route,
                _user_role=_user_role,
                request_data=request_data,
            )
        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER.value
            and RouteChecks.check_route_access(
                route=route, allowed_routes=LiteLLMRoutes.internal_user_routes.value
            )
        ):
            pass
        elif _user_is_org_admin(
            request_data=request_data, user_object=user_obj
        ) and RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.org_admin_allowed_routes.value
        ):
            pass
        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
            and RouteChecks.check_route_access(
                route=route,
                allowed_routes=LiteLLMRoutes.internal_user_view_only_routes.value,
            )
        ):
            pass
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.self_managed_routes.value
        ):  # routes that manage their own allowed/disallowed logic
            pass
        elif route.startswith("/v1/mcp/") or route.startswith("/mcp-rest/"):
            pass  # authN/authZ handled by api itself
        elif RouteChecks.check_passthrough_route_access(
            route=route, user_api_key_dict=valid_token
        ):
            pass
        elif valid_token.allowed_routes is not None:
            # check if route is in allowed_routes (exact match or prefix match)
            route_allowed = False
            for allowed_route in valid_token.allowed_routes:
                if RouteChecks._route_matches_allowed_route(
                    route=route, allowed_route=allowed_route
                ):
                    route_allowed = True
                    break
                if RouteChecks._route_matches_wildcard_pattern(
                    route=route, pattern=allowed_route
                ):
                    route_allowed = True
                    break
            if not route_allowed:
                RouteChecks._raise_admin_only_route_exception(
                    user_obj=user_obj, route=route
                )
        else:
            # nothing matched -> admin-only
            RouteChecks._raise_admin_only_route_exception(
                user_obj=user_obj, route=route
            )
@staticmethod
def custom_admin_only_route_check(route: str):
from litellm.proxy.proxy_server import general_settings, premium_user
if "admin_only_routes" in general_settings:
if premium_user is not True:
verbose_proxy_logger.error(
f"Trying to use 'admin_only_routes' this is an Enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
)
return
if route in general_settings["admin_only_routes"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"user not allowed to access this route. Route={route} is an admin only route",
)
pass
@staticmethod
def is_llm_api_route(route: str) -> bool:
"""
Helper to checks if provided route is an OpenAI route
Returns:
- True: if route is an OpenAI route
- False: if route is not an OpenAI route
"""
# Ensure route is a string before performing checks
if not isinstance(route, str):
return False
if route in LiteLLMRoutes.openai_routes.value:
return True
if route in LiteLLMRoutes.anthropic_routes.value:
return True
if route in LiteLLMRoutes.google_routes.value:
return True
if RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.mcp_routes.value
):
return True
if RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.agent_routes.value
):
return True
if route in LiteLLMRoutes.litellm_native_routes.value:
return True
# fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
# Check for routes with placeholders or wildcard patterns
for openai_route in LiteLLMRoutes.openai_routes.value:
# Replace placeholders with regex pattern
# placeholders are written as "/threads/{thread_id}"
if "{" in openai_route:
if RouteChecks._route_matches_pattern(
route=route, pattern=openai_route
):
return True
# Check for wildcard patterns like "/containers/*"
if RouteChecks._is_wildcard_pattern(pattern=openai_route):
if RouteChecks._route_matches_wildcard_pattern(
route=route, pattern=openai_route
):
return True
# Check for Google routes with placeholders like "/v1beta/models/{model_name}:generateContent"
for google_route in LiteLLMRoutes.google_routes.value:
if "{" in google_route:
if RouteChecks._route_matches_pattern(
route=route, pattern=google_route
):
return True
# Check for Anthropic routes with placeholders
for anthropic_route in LiteLLMRoutes.anthropic_routes.value:
if "{" in anthropic_route:
if RouteChecks._route_matches_pattern(
route=route, pattern=anthropic_route
):
return True
if RouteChecks._is_azure_openai_route(route=route):
return True
for _llm_passthrough_route in LiteLLMRoutes.mapped_pass_through_routes.value:
if _llm_passthrough_route in route:
return True
return False
@staticmethod
def is_management_route(route: str) -> bool:
"""
Check if route is a management route
"""
return route in LiteLLMRoutes.management_routes.value
@staticmethod
def is_info_route(route: str) -> bool:
"""
Check if route is an info route
"""
return route in LiteLLMRoutes.info_routes.value
@staticmethod
def _is_azure_openai_route(route: str) -> bool:
"""
Check if route is a route from AzureOpenAI SDK client
eg.
route='/openai/deployments/vertex_ai/gemini-1.5-flash/chat/completions'
"""
# Ensure route is a string before attempting regex matching
if not isinstance(route, str):
return False
# Add support for deployment and engine model paths
deployment_pattern = r"^/openai/deployments/[^/]+/[^/]+/chat/completions$"
engine_pattern = r"^/engines/[^/]+/chat/completions$"
if re.match(deployment_pattern, route) or re.match(engine_pattern, route):
return True
return False
@staticmethod
def _route_matches_pattern(route: str, pattern: str) -> bool:
"""
Check if route matches the pattern placed in proxy/_types.py
Example:
- pattern: "/threads/{thread_id}"
- route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
- returns: True
- pattern: "/key/{token_id}/regenerate"
- route: "/key/regenerate/82akk800000000jjsk"
- returns: False, pattern is "/key/{token_id}/regenerate"
"""
# Ensure route is a string before attempting regex matching
if not isinstance(route, str):
return False
def _placeholder_to_regex(match: re.Match) -> str:
placeholder = match.group(0).strip("{}")
if placeholder.endswith(":path"):
# allow "/" in the placeholder value, but don't eat the route suffix after ":"
return r"[^:]+"
return r"[^/]+"
pattern = re.sub(r"\{[^}]+\}", _placeholder_to_regex, pattern)
# Anchor the pattern to match the entire string
pattern = f"^{pattern}$"
if re.match(pattern, route):
return True
return False
@staticmethod
def _is_wildcard_pattern(pattern: str) -> bool:
"""
Check if pattern is a wildcard pattern
"""
return pattern.endswith("*")
@staticmethod
def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
"""
Check if route matches the wildcard pattern
eg.
pattern: "/scim/v2/*"
route: "/scim/v2/Users"
- returns: True
pattern: "/scim/v2/*"
route: "/chat/completions"
- returns: False
pattern: "/scim/v2/*"
route: "/scim/v2/Users/123"
- returns: True
"""
if pattern.endswith("*"):
# Get the prefix (everything before the wildcard)
prefix = pattern[:-1]
return route.startswith(prefix)
else:
# If there's no wildcard, the pattern and route should match exactly
return route == pattern
@staticmethod
def _route_matches_allowed_route(route: str, allowed_route: str) -> bool:
"""
Check if route matches the allowed_route pattern.
Supports both exact match and prefix match.
Examples:
- allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-6" -> True (exact match)
- allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-6/v1/chat/completions" -> True (prefix match)
- allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-600" -> False (not a valid prefix)
Args:
route: The actual route being accessed
allowed_route: The allowed route pattern
Returns:
bool: True if route matches (exact or prefix), False otherwise
"""
# Exact match
if route == allowed_route:
return True
# Prefix match - ensure we add "/" to prevent false matches like /fake-openai-proxy-600
if route.startswith(allowed_route + "/"):
return True
return False
@staticmethod
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
"""
Check if a route has access by checking both exact matches and patterns
Args:
route (str): The route to check
allowed_routes (list): List of allowed routes/patterns
Returns:
bool: True if route is allowed, False otherwise
"""
#########################################################
# exact match route is in allowed_routes
#########################################################
if route in allowed_routes:
return True
#########################################################
# wildcard match route is in allowed_routes
# e.g calling /anthropic/v1/messages is allowed if allowed_routes has /anthropic/*
#########################################################
wildcard_allowed_routes = [
route
for route in allowed_routes
if RouteChecks._is_wildcard_pattern(pattern=route)
]
for allowed_route in wildcard_allowed_routes:
if RouteChecks._route_matches_wildcard_pattern(
route=route, pattern=allowed_route
):
return True
#########################################################
# pattern match route is in allowed_routes
# pattern: "/threads/{thread_id}"
# route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
# returns: True
#########################################################
if any( # Check pattern match
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
for allowed_route in allowed_routes
):
return True
return False
@staticmethod
def check_passthrough_route_access(
route: str, user_api_key_dict: UserAPIKeyAuth
) -> bool:
"""
Check if route is a passthrough route.
Supports both exact match and prefix match.
"""
metadata = user_api_key_dict.metadata
team_metadata = user_api_key_dict.team_metadata or {}
if metadata is None and team_metadata is None:
return False
if (
"allowed_passthrough_routes" not in metadata
and "allowed_passthrough_routes" not in team_metadata
):
return False
if (
metadata.get("allowed_passthrough_routes") is None
and team_metadata.get("allowed_passthrough_routes") is None
):
return False
allowed_passthrough_routes = (
metadata.get("allowed_passthrough_routes")
or team_metadata.get("allowed_passthrough_routes")
or []
)
# Check if route matches any allowed passthrough route (exact or prefix match)
for allowed_route in allowed_passthrough_routes:
if RouteChecks._route_matches_allowed_route(
route=route, allowed_route=allowed_route
):
return True
return False
@staticmethod
def _is_assistants_api_request(request: Request) -> bool:
"""
Returns True if `thread` or `assistant` is in the request path
Args:
request (Request): The request object
Returns:
bool: True if `thread` or `assistant` is in the request path, False otherwise
"""
if "thread" in request.url.path or "assistant" in request.url.path:
return True
return False
@staticmethod
def is_generate_content_route(route: str) -> bool:
"""
Returns True if this is a google generateContent or streamGenerateContent route
These routes from google allow passing key=api_key in the query params
"""
if "generateContent" in route:
return True
if "streamGenerateContent" in route:
return True
return False
    @staticmethod
    def _check_proxy_admin_viewer_access(
        route: str,
        _user_role: str,
        request_data: dict,
    ) -> None:
        """
        Check access for the PROXY_ADMIN_VIEW_ONLY role.

        Raises HTTPException (403) when the route is not permitted for a
        view-only admin; returns None when access is allowed. LLM API routes
        and write operations on management routes are blocked; read-only
        management, admin-viewer, and global-spend routes are allowed.
        """
        # View-only admins may never call LLM API routes.
        if RouteChecks.is_llm_api_route(route=route):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"user not allowed to access this OpenAI routes, role= {_user_role}",
            )
        # Check if this is a write operation on management routes
        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.management_routes.value
        ):
            # For management routes, only allow read operations or specific allowed updates
            if route == "/user/update":
                # Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY:
                # only "user_email" and "password" may be updated.
                if request_data is not None and isinstance(request_data, dict):
                    _params_updated = request_data.keys()
                    for param in _params_updated:
                        if param not in ["user_email", "password"]:
                            raise HTTPException(
                                status_code=status.HTTP_403_FORBIDDEN,
                                detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route} and updating invalid param: {param}. only user_email and password can be updated",
                            )
            elif (
                route
                in [
                    "/user/new",
                    "/user/delete",
                    "/team/new",
                    "/team/update",
                    "/team/delete",
                    "/model/new",
                    "/model/update",
                    "/model/delete",
                    "/key/generate",
                    "/key/delete",
                    "/key/update",
                    "/key/regenerate",
                    "/key/service-account/generate",
                    "/key/block",
                    "/key/unblock",
                ]
                # NOTE: "and" binds tighter than "or" — this branch also fires
                # for any "/key/.../regenerate" style route.
                or route.startswith("/key/")
                and route.endswith("/regenerate")
            ):
                # Block write operations for PROXY_ADMIN_VIEW_ONLY
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
                )
            # Allow read operations on management routes (like /user/info, /team/info, /model/info)
            return
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.admin_viewer_routes.value
        ):
            # Allow access to admin viewer routes (read-only admin endpoints)
            return
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.global_spend_tracking_routes.value
        ):
            # Allow access to global spend tracking routes (read-only spend endpoints)
            # proxy_admin_viewer role description: "view all keys, view all spend"
            return
        else:
            # For other routes, block access
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
            )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,967 @@
######################################################################
# /v1/batches Endpoints
######################################################################
import asyncio
from typing import Dict, Optional, cast
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.main import CancelBatchRequest, RetrieveBatchRequest
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_headers,
get_custom_llm_provider_from_request_query,
)
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
decode_model_from_file_id,
encode_batch_response_ids,
encode_file_id_with_model,
get_batch_from_database,
get_credentials_for_model,
get_model_id_from_unified_batch_id,
get_models_from_unified_file_id,
get_original_file_id,
prepare_data_with_credentials,
resolve_input_file_id_to_unified,
update_batch_in_database,
)
from litellm.proxy.utils import handle_exception_on_proxy, is_known_model
from litellm.types.llms.openai import LiteLLMBatchCreateRequest
router = APIRouter()
@router.post(
    "/{provider}/v1/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.post(
    "/v1/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.post(
    "/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
async def create_batch(  # noqa: PLR0915
    request: Request,
    fastapi_response: Response,
    provider: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create large batches of API requests for asynchronous processing.
    This is the equivalent of POST https://api.openai.com/v1/batch
    Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch

    Routing scenarios (in priority order):
    1. input_file_id encoded with model info -> use that model's credentials
    2. loadbalanced router model / unified (managed) file id -> route via llm_router
    3. model from body/query/header -> model-based credentials
    4. fallback -> custom_llm_provider (env-variable credentials)

    Example Curl
    ```
    curl http://localhost:4000/v1/batches \
        -H "Authorization: Bearer sk-1234" \
        -H "Content-Type: application/json" \
        -d '{
            "input_file_id": "file-abc123",
            "endpoint": "/v1/chat/completions",
            "completion_window": "24h"
    }'
    ```
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    data: Dict = {}
    try:
        data = await _read_request_body(request=request)
        verbose_proxy_logger.debug(
            "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
        )

        # Shared pre-call processing: auth metadata, logging object, etc.
        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
        (
            data,
            litellm_logging_obj,
        ) = await base_llm_response_processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="acreate_batch",
        )

        ## check if model is a loadbalanced model
        router_model: Optional[str] = None
        is_router_model = False
        if litellm.enable_loadbalancing_on_batch_endpoints is True:
            router_model = data.get("model", None)
            is_router_model = is_known_model(model=router_model, llm_router=llm_router)

        # Provider resolution: path param > body > header > default "openai".
        custom_llm_provider = (
            provider
            or data.pop("custom_llm_provider", None)
            or get_custom_llm_provider_from_request_headers(request=request)
            or "openai"
        )
        _create_batch_data = LiteLLMBatchCreateRequest(**data)

        # Apply team-level batch output expiry enforcement
        team_metadata = user_api_key_dict.team_metadata or {}
        enforced_batch_expiry = team_metadata.get("enforced_batch_output_expires_after")
        if enforced_batch_expiry is not None:
            # Validate the team setting before applying it; a malformed value
            # is a server-side config error (500), not a client error.
            if (
                "anchor" not in enforced_batch_expiry
                or "seconds" not in enforced_batch_expiry
            ):
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "Server configuration error: team metadata field 'enforced_batch_output_expires_after' is malformed - must contain 'anchor' and 'seconds' keys. Contact your team or proxy admin to fix this setting.",
                    },
                )
            if enforced_batch_expiry["anchor"] != "created_at":
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": f"Server configuration error: team metadata field 'enforced_batch_output_expires_after' has invalid anchor '{enforced_batch_expiry['anchor']}' - must be 'created_at'. Contact your team or proxy admin to fix this setting.",
                    },
                )
            _create_batch_data["output_expires_after"] = {
                "anchor": "created_at",
                "seconds": int(enforced_batch_expiry["seconds"]),
            }

        # Inspect the input file id for routing hints (encoded model /
        # unified managed-file id).
        input_file_id = _create_batch_data.get("input_file_id", None)
        unified_file_id: Union[str, Literal[False]] = False
        model_from_file_id = None
        if input_file_id:
            model_from_file_id = decode_model_from_file_id(input_file_id)
            unified_file_id = _is_base64_encoded_unified_file_id(input_file_id)

        # SCENARIO 1: File ID is encoded with model info
        if model_from_file_id is not None and input_file_id:
            credentials = get_credentials_for_model(
                llm_router=llm_router,
                model_id=model_from_file_id,
                operation_context="batch creation (file created with model)",
            )
            original_file_id = get_original_file_id(input_file_id)
            _create_batch_data["input_file_id"] = original_file_id
            prepare_data_with_credentials(
                data=_create_batch_data,  # type: ignore
                credentials=credentials,
            )
            # Create batch using model credentials
            response = await litellm.acreate_batch(
                custom_llm_provider=credentials["custom_llm_provider"],
                **_create_batch_data,  # type: ignore
            )
            # Encode the batch ID and related file IDs with model information
            # so follow-up retrieve/cancel calls can be routed the same way.
            if response and hasattr(response, "id") and response.id:
                original_batch_id = response.id
                encoded_batch_id = encode_file_id_with_model(
                    file_id=original_batch_id,
                    model=model_from_file_id,
                    id_type="batch",
                )
                response.id = encoded_batch_id
                if hasattr(response, "output_file_id") and response.output_file_id:
                    response.output_file_id = encode_file_id_with_model(
                        file_id=response.output_file_id, model=model_from_file_id
                    )
                if hasattr(response, "error_file_id") and response.error_file_id:
                    response.error_file_id = encode_file_id_with_model(
                        file_id=response.error_file_id, model=model_from_file_id
                    )
                verbose_proxy_logger.debug(
                    f"Created batch using model: {model_from_file_id}, "
                    f"original_batch_id: {original_batch_id}, encoded: {encoded_batch_id}"
                )
            # Return the caller's original (encoded) file id, not the provider's.
            response.input_file_id = input_file_id
        elif (
            litellm.enable_loadbalancing_on_batch_endpoints is True
            and is_router_model
            and router_model is not None
        ):
            # SCENARIO 2a: loadbalanced router model.
            if llm_router is None:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "LLM Router not initialized. Ensure models added to proxy."
                    },
                )

            response = await llm_router.acreate_batch(**_create_batch_data)  # type: ignore
        elif (
            unified_file_id and input_file_id
        ):  # litellm_proxy:application/octet-stream;unified_id,c4843482-b176-4901-8292-7523fd0f2c6e;target_model_names,gpt-4o-mini
            # SCENARIO 2b: unified (managed) file id carrying target model names.
            target_model_names = get_models_from_unified_file_id(unified_file_id)
            ## EXPECTS 1 MODEL
            if len(target_model_names) != 1:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Expected 1 model, got {}".format(
                            len(target_model_names)
                        )
                    },
                )
            model = target_model_names[0]
            _create_batch_data["model"] = model
            if llm_router is None:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "LLM Router not initialized. Ensure models added to proxy."
                    },
                )
            response = await llm_router.acreate_batch(**_create_batch_data)
            response.input_file_id = input_file_id
            response._hidden_params["unified_file_id"] = unified_file_id
        else:
            # Check if model specified via header/query/body param
            model_param = (
                data.get("model")
                or request.query_params.get("model")
                or request.headers.get("x-litellm-model")
            )

            # SCENARIO 2 & 3: Model from header/query OR custom_llm_provider fallback
            if model_param:
                # SCENARIO 2: Use model-based routing from header/query/body
                credentials = get_credentials_for_model(
                    llm_router=llm_router,
                    model_id=model_param,
                    operation_context="batch creation",
                )
                prepare_data_with_credentials(
                    data=_create_batch_data,  # type: ignore
                    credentials=credentials,
                )

                # Create batch using model credentials
                response = await litellm.acreate_batch(
                    custom_llm_provider=credentials["custom_llm_provider"],
                    **_create_batch_data,  # type: ignore
                )

                encode_batch_response_ids(response, model=model_param)
                verbose_proxy_logger.debug(f"Created batch using model: {model_param}")
            else:
                # SCENARIO 3: Fallback to custom_llm_provider (uses env variables)
                response = await litellm.acreate_batch(
                    custom_llm_provider=custom_llm_provider, **_create_batch_data  # type: ignore
                )

        ### CALL HOOKS ### - modify outgoing data
        response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response
        )

        ### ALERTING ###
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.get(
    "/{provider}/v1/batches/{batch_id:path}",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.get(
    "/v1/batches/{batch_id:path}",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.get(
    "/batches/{batch_id:path}",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
async def retrieve_batch(  # noqa: PLR0915
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    provider: Optional[str] = None,
    batch_id: str = Path(
        title="Batch ID to retrieve", description="The ID of the batch to retrieve"
    ),
):
    """
    Retrieves a batch.
    This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id}
    Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/retrieve

    The batch is first read from the proxy database; a batch already in a
    terminal state is returned without contacting the provider. Otherwise the
    provider is queried (routed by model-encoded id, router, or provider
    fallback) and the database is refreshed with the latest state.

    Example Curl
    ```
    curl http://localhost:4000/v1/batches/batch_abc123 \
        -H "Authorization: Bearer sk-1234" \
        -H "Content-Type: application/json" \
    ```
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    data: Dict = {}
    try:
        # batch_id may be encoded with model info and/or be a unified
        # (managed) batch id — decode routing hints up front.
        model_from_id = decode_model_from_file_id(batch_id)

        _retrieve_batch_request = RetrieveBatchRequest(
            batch_id=batch_id,
        )
        data = cast(dict, _retrieve_batch_request)

        unified_batch_id = _is_base64_encoded_unified_file_id(batch_id)

        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
        (
            data,
            litellm_logging_obj,
        ) = await base_llm_response_processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="aretrieve_batch",
        )

        # FIX: First, try to read from ManagedObjectTable for consistent state
        managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
        from litellm.proxy.proxy_server import prisma_client

        db_batch_object, response = await get_batch_from_database(
            batch_id=batch_id,
            unified_batch_id=unified_batch_id,
            managed_files_obj=managed_files_obj,
            prisma_client=prisma_client,
            verbose_proxy_logger=verbose_proxy_logger,
        )

        # If batch is in a terminal state, return immediately
        if response is not None and response.status in [
            "completed",
            "failed",
            "cancelled",
            "expired",
        ]:
            # Call hooks and return
            response = await proxy_logging_obj.post_call_success_hook(
                data=data, user_api_key_dict=user_api_key_dict, response=response
            )
            # async_post_call_success_hook replaces batch.id and output_file_id with unified IDs
            # but not input_file_id. Resolve raw provider ID to unified ID.
            if unified_batch_id:
                await resolve_input_file_id_to_unified(response, prisma_client)
            asyncio.create_task(
                proxy_logging_obj.update_request_status(
                    litellm_call_id=data.get("litellm_call_id", ""), status="success"
                )
            )
            hidden_params = getattr(response, "_hidden_params", {}) or {}
            model_id = hidden_params.get("model_id", None) or ""
            cache_key = hidden_params.get("cache_key", None) or ""
            api_base = hidden_params.get("api_base", None) or ""
            fastapi_response.headers.update(
                ProxyBaseLLMRequestProcessing.get_custom_headers(
                    user_api_key_dict=user_api_key_dict,
                    model_id=model_id,
                    cache_key=cache_key,
                    api_base=api_base,
                    version=version,
                    model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                    request_data=data,
                )
            )
            return response

        # If batch is still processing, sync with provider to get latest state
        if response is not None:
            verbose_proxy_logger.debug(
                f"Batch {batch_id} is in non-terminal state {response.status}, syncing with provider"
            )

        # Retrieve from provider (for non-terminal states or if DB lookup failed)
        # SCENARIO 1: Batch ID is encoded with model info
        if model_from_id is not None:
            credentials = get_credentials_for_model(
                llm_router=llm_router,
                model_id=model_from_id,
                operation_context="batch retrieval (batch created with model)",
            )
            original_batch_id = get_original_file_id(batch_id)
            prepare_data_with_credentials(
                data=data,
                credentials=credentials,
                file_id=original_batch_id,  # Sets data["batch_id"] = original_batch_id
            )

            # Fix: The helper sets "file_id" but we need "batch_id"
            data["batch_id"] = data.pop("file_id", original_batch_id)

            # Retrieve batch using model credentials
            response = await litellm.aretrieve_batch(
                custom_llm_provider=credentials["custom_llm_provider"],
                **data,  # type: ignore
            )

            # Re-encode ids so the caller keeps model-routable identifiers.
            encode_batch_response_ids(response, model=model_from_id)
            verbose_proxy_logger.debug(
                f"Retrieved batch using model: {model_from_id}, original_id: {original_batch_id}"
            )
        # SCENARIO 2: loadbalancing enabled or unified managed batch id.
        elif (
            litellm.enable_loadbalancing_on_batch_endpoints is True or unified_batch_id
        ):
            if llm_router is None:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "LLM Router not initialized. Ensure models added to proxy."
                    },
                )
            response = await llm_router.aretrieve_batch(**data)  # type: ignore
            response._hidden_params["unified_batch_id"] = unified_batch_id
            if unified_batch_id:
                model_id_from_batch = get_model_id_from_unified_batch_id(
                    unified_batch_id
                )
                if model_id_from_batch:
                    response._hidden_params["model_id"] = model_id_from_batch
        # SCENARIO 3: Fallback to custom_llm_provider (uses env variables)
        else:
            custom_llm_provider = (
                provider
                or get_custom_llm_provider_from_request_headers(request=request)
                or get_custom_llm_provider_from_request_query(request=request)
                or "openai"
            )
            response = await litellm.aretrieve_batch(
                custom_llm_provider=custom_llm_provider, **data  # type: ignore
            )

        # FIX: Update the database with the latest state from provider
        await update_batch_in_database(
            batch_id=batch_id,
            unified_batch_id=unified_batch_id,
            response=response,
            managed_files_obj=managed_files_obj,
            prisma_client=prisma_client,
            verbose_proxy_logger=verbose_proxy_logger,
            db_batch_object=db_batch_object,
            operation="retrieve",
        )

        ### CALL HOOKS ### - modify outgoing data
        response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response
        )

        # Fix: bug_feb14_batch_retrieve_returns_raw_input_file_id
        # Resolve raw provider input_file_id to unified ID.
        if unified_batch_id:
            await resolve_input_file_id_to_unified(response, prisma_client)

        ### ALERTING ###
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.get(
    "/{provider}/v1/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.get(
    "/v1/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.get(
    "/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
async def list_batches(
    request: Request,
    fastapi_response: Response,
    provider: Optional[str] = None,
    limit: Optional[int] = None,
    after: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    target_model_names: Optional[str] = None,
):
    """
    Lists batches.
    This is the equivalent of GET https://api.openai.com/v1/batches/
    Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/list

    Routing scenarios (in priority order): managed-objects table, model from
    body/query/header, target_model_names via the router, then
    custom_llm_provider fallback (env-variable credentials).

    Fixes:
    - credentials were merged into `data` wholesale before calling
      `litellm.alist_batches(custom_llm_provider=..., **data)`, passing
      `custom_llm_provider` twice and raising a TypeError; the key is now
      excluded from the merge.
    - the exception log named the wrong handler ("retrieve_batch") and used
      `.error`; it now names list_batches and uses `.exception` like the
      sibling endpoints.

    Example Curl
    ```
    curl http://localhost:4000/v1/batches?limit=2 \
        -H "Authorization: Bearer sk-1234" \
        -H "Content-Type: application/json" \
    ```
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
    try:
        if llm_router is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.no_llm_router.value},
            )

        # Include original request and headers in the data
        data = await _read_request_body(request=request)
        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
        (
            data,
            litellm_logging_obj,
        ) = await base_llm_response_processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="alist_batches",
        )

        # Try to use managed objects table for listing batches (returns encoded IDs)
        managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
        if managed_files_obj is not None and hasattr(
            managed_files_obj, "list_user_batches"
        ):
            verbose_proxy_logger.debug("Using managed objects table for batch listing")
            response = await managed_files_obj.list_user_batches(
                user_api_key_dict=user_api_key_dict,
                limit=limit,
                after=after,
                provider=provider,
                target_model_names=target_model_names,
                llm_router=llm_router,
            )
        elif model_param := (
            data.get("model")
            or request.query_params.get("model")
            or request.headers.get("x-litellm-model")
        ):
            # SCENARIO 2: Use model-based routing from header/query/body
            credentials = get_credentials_for_model(
                llm_router=llm_router,
                model_id=model_param,
                operation_context="batch listing",
            )
            # Merge credential fields into data, but keep custom_llm_provider
            # out: it is passed explicitly below, and including it in **data
            # raised "got multiple values for keyword argument".
            data.update(
                {k: v for k, v in credentials.items() if k != "custom_llm_provider"}
            )
            response = await litellm.alist_batches(
                custom_llm_provider=credentials["custom_llm_provider"],
                after=after,
                limit=limit,
                **data,  # type: ignore
            )
            # Encode batch IDs in the list response so clients can use
            # them for retrieve/cancel/file downloads through the proxy.
            if response and hasattr(response, "data") and response.data:
                for batch in response.data:
                    encode_batch_response_ids(batch, model=model_param)
            verbose_proxy_logger.debug(f"Listed batches using model: {model_param}")
        # SCENARIO 2 (alternative): target_model_names based routing
        elif target_model_names or data.get("target_model_names", None):
            target_model_names = target_model_names or data.get(
                "target_model_names", None
            )
            if target_model_names is None:
                raise ValueError(
                    "target_model_names is required for this routing scenario"
                )
            # Route via the first listed model.
            model = target_model_names.split(",")[0]
            data.pop("model", None)
            response = await llm_router.alist_batches(
                model=model,
                after=after,
                limit=limit,
                **data,
            )
        # SCENARIO 3: Fallback to custom_llm_provider (uses env variables)
        else:
            custom_llm_provider = (
                provider
                or get_custom_llm_provider_from_request_headers(request=request)
                or get_custom_llm_provider_from_request_query(request=request)
                or "openai"
            )
            response = await litellm.alist_batches(
                custom_llm_provider=custom_llm_provider,  # type: ignore
                after=after,
                limit=limit,
                **data,
            )

        ## POST CALL HOOKS ###
        _response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response  # type: ignore
        )
        # Only adopt the hook's result when it kept the response type intact.
        if _response is not None and type(response) is type(_response):
            response = _response

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data={"after": after, "limit": limit},
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.list_batches(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.post(
    "/{provider}/v1/batches/{batch_id:path}/cancel",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.post(
    "/v1/batches/{batch_id:path}/cancel",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.post(
    "/batches/{batch_id:path}/cancel",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
async def cancel_batch(
    request: Request,
    batch_id: str,
    fastapi_response: Response,
    provider: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Cancel a batch.
    This is the equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel

    Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/cancel

    Routing scenarios, checked in order:
    1. batch_id encodes model info -> cancel using that model's credentials
    2. batch_id is a unified (base64-encoded) id -> route through llm_router
    3. fallback -> resolve a custom_llm_provider (env-based credentials)

    Example Curl
    ```
    curl http://localhost:4000/v1/batches/batch_abc123/cancel \
    -H "Authorization: Bearer sk-1234" \
    -H "Content-Type: application/json" \
    -X POST
    ```
    """
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    data: Dict = {}
    try:
        # Check for encoded batch ID with model info
        model_from_id = decode_model_from_file_id(batch_id)

        # Create CancelBatchRequest with batch_id to enable ownership checking
        _cancel_batch_request = CancelBatchRequest(
            batch_id=batch_id,
        )
        data = cast(dict, _cancel_batch_request)

        unified_batch_id = _is_base64_encoded_unified_file_id(batch_id)

        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
        (
            data,
            litellm_logging_obj,
        ) = await base_llm_response_processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="acancel_batch",
        )

        # Include original request and headers in the data
        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        # SCENARIO 1: Batch ID is encoded with model info
        if model_from_id is not None:
            credentials = get_credentials_for_model(
                llm_router=llm_router,
                model_id=model_from_id,
                operation_context="batch cancellation (batch created with model)",
            )
            original_batch_id = get_original_file_id(batch_id)
            prepare_data_with_credentials(
                data=data,
                credentials=credentials,
                file_id=original_batch_id,
            )
            # Fix: The helper sets "file_id" but we need "batch_id"
            data["batch_id"] = data.pop("file_id", original_batch_id)

            # Cancel batch using model credentials
            response = await litellm.acancel_batch(
                custom_llm_provider=credentials["custom_llm_provider"],
                **data,  # type: ignore
            )
            encode_batch_response_ids(response, model=model_from_id)
            verbose_proxy_logger.debug(
                f"Cancelled batch using model: {model_from_id}, original_id: {original_batch_id}"
            )
        # SCENARIO 2: target_model_names based routing
        elif unified_batch_id:
            if llm_router is None:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "LLM Router not initialized. Ensure models added to proxy."
                    },
                )
            # Hook has already extracted model and unwrapped batch_id into data dict
            response = await llm_router.acancel_batch(**data)  # type: ignore
            response._hidden_params["unified_batch_id"] = unified_batch_id
            # Ensure model_id is set for the post_call_success_hook to re-encode IDs
            if not response._hidden_params.get("model_id") and data.get("model"):
                response._hidden_params["model_id"] = data["model"]
        # SCENARIO 3: Fallback to custom_llm_provider (uses env variables)
        else:
            custom_llm_provider = (
                provider or data.pop("custom_llm_provider", None) or "openai"
            )
            # Extract batch_id from data to avoid "multiple values for keyword argument" error
            # data was cast from CancelBatchRequest which already contains batch_id
            data.pop("batch_id", None)
            _cancel_batch_data = CancelBatchRequest(batch_id=batch_id, **data)
            response = await litellm.acancel_batch(
                custom_llm_provider=custom_llm_provider,  # type: ignore
                **_cancel_batch_data,
            )

        # FIX: Update the database with the new cancelled state
        managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
        from litellm.proxy.proxy_server import prisma_client

        await update_batch_in_database(
            batch_id=batch_id,
            unified_batch_id=unified_batch_id,
            response=response,
            managed_files_obj=managed_files_obj,
            prisma_client=prisma_client,
            verbose_proxy_logger=verbose_proxy_logger,
            operation="cancel",
        )

        ### CALL HOOKS ### - modify outgoing data
        response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response
        )

        ### ALERTING ###
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        # BUG FIX: this log previously said "create_batch(): Exception occured",
        # which pointed debuggers at the wrong endpoint.
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.cancel_batch(): Exception occurred - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
######################################################################
# END OF /v1/batches Endpoints Implementation
######################################################################

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

View File

@@ -0,0 +1,257 @@
from typing import Any, Dict, List, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import RedisCache
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
from litellm.proxy._types import ProxyErrorTypes, ProxyException
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.caching import CachePingResponse, HealthCheckCacheParams
masker = SensitiveDataMasker()
router = APIRouter(
prefix="/cache",
tags=["caching"],
)
def _extract_cache_params() -> Dict[str, Any]:
    """
    Safely extract, clean, and mask the active cache's configuration.

    The health check UI displays these parameters so users can see how
    their cache is set up, e.g.::

        {
            "host": "localhost",
            "port": 6379,
            "redis_kwargs": {"db": 0},
            "namespace": "test",
        }

    Returns:
        Dict containing cleaned and masked cache parameters; empty dict
        when no cache is configured or the parameters cannot be read.
    """
    if litellm.cache is None:
        return {}
    try:
        raw_params = vars(litellm.cache.cache)
        if raw_params:
            cleaned = HealthCheckCacheParams(**raw_params).model_dump()
        else:
            cleaned = {}
        return masker.mask_dict(cleaned)
    except (AttributeError, TypeError) as e:
        verbose_proxy_logger.debug(f"Error extracting cache params: {str(e)}")
        return {}
@router.get(
    "/ping",
    response_model=CachePingResponse,
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_ping():
    """
    Endpoint for checking if cache can be pinged.

    For Redis caches this performs a round-trip ping plus a test write
    before reporting healthy; other cache types only verify initialization.
    Raises a ProxyException (503) with diagnostic details on any failure.
    """
    litellm_cache_params: Dict[str, Any] = {}
    cleaned_cache_params: Dict[str, Any] = {}
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        litellm_cache_params = masker.mask_dict(vars(litellm.cache))
        # remove field that might reference itself
        litellm_cache_params.pop("cache", None)
        cleaned_cache_params = _extract_cache_params()

        if litellm.cache.type != "redis":
            # Non-redis caches: being initialized is enough to report healthy.
            return CachePingResponse(
                status="healthy",
                cache_type=str(litellm.cache.type),
                litellm_cache_params=safe_dumps(litellm_cache_params),
            )

        # Redis: exercise both the read (ping) and write (add_cache) paths.
        ping_response = await litellm.cache.ping()
        verbose_proxy_logger.debug(
            "/cache/ping: ping_response: " + str(ping_response)
        )

        # add cache does not return anything
        await litellm.cache.async_add_cache(
            result="test_key",
            model="test-model",
            messages=[{"role": "user", "content": "test from litellm"}],
        )
        verbose_proxy_logger.debug("/cache/ping: done with set_cache()")

        return CachePingResponse(
            status="healthy",
            cache_type=str(litellm.cache.type),
            ping_response=True,
            set_cache_response="success",
            litellm_cache_params=safe_dumps(litellm_cache_params),
            health_check_cache_params=cleaned_cache_params,
        )
    except Exception as e:
        import traceback

        error_message = {
            "message": f"Service Unhealthy ({str(e)})",
            "litellm_cache_params": safe_dumps(litellm_cache_params),
            "health_check_cache_params": safe_dumps(cleaned_cache_params),
            "traceback": traceback.format_exc(),
        }
        raise ProxyException(
            message=safe_dumps(error_message),
            type=ProxyErrorTypes.cache_ping_error,
            param="cache_ping",
            code=503,
        )
@router.post(
    "/delete",
    tags=["caching"],
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_delete(request: Request):
    """
    Endpoint for deleting a key from the cache. All responses from litellm proxy have `x-litellm-cache-key` in the headers

    Parameters:
    - **keys**: *Optional[List[str]]* - A list of keys to delete from the cache. Example {"keys": ["key1", "key2"]}

    ```shell
    curl -X POST "http://0.0.0.0:4000/cache/delete" \
    -H "Authorization: Bearer sk-1234" \
    -d '{"keys": ["key1", "key2"]}'
    ```
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        payload = await request.json()
        keys_to_delete = payload.get("keys", None)

        if litellm.cache.type != "redis":
            # Only the redis backend exposes bulk key deletion.
            raise HTTPException(
                status_code=500,
                detail=f"Cache type {litellm.cache.type} does not support deleting a key. only `redis` is supported",
            )

        await litellm.cache.delete_cache_keys(keys=keys_to_delete)
        return {
            "status": "success",
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Cache Delete Failed({str(e)})",
        )
def _get_redis_client_info(cache_instance) -> Tuple[List, int]:
"""
Helper function to safely get Redis client list information.
Returns:
tuple: (client_list, num_clients) where num_clients is -1 if CLIENT LIST is unavailable
"""
try:
client_list = cache_instance.client_list()
return client_list, len(client_list)
except Exception as e:
verbose_proxy_logger.warning(
f"CLIENT LIST command failed (likely restricted on managed Redis): {str(e)}"
)
return ["CLIENT LIST command not available on this Redis instance"], -1
@router.get(
    "/redis/info",
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_redis_info():
    """
    Endpoint for getting /redis/info.

    Returns the connected-client list/count plus the Redis server's INFO
    output. Fails with 503 when no cache is configured and 500 when the
    configured cache is not a RedisCache.
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        is_redis_cache = litellm.cache.type == "redis" and isinstance(
            litellm.cache.cache, RedisCache
        )
        if not is_redis_cache:
            raise HTTPException(
                status_code=500,
                detail=f"Cache type {litellm.cache.type} does not support redis info",
            )

        # Get client information (handles CLIENT LIST restrictions gracefully)
        connected_clients, client_count = _get_redis_client_info(litellm.cache.cache)

        # Get Redis server information
        server_info = litellm.cache.cache.info()

        return {
            "num_clients": client_count,
            "clients": connected_clients,
            "info": server_info,
        }
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)})",
        )
@router.post(
    "/flushall",
    tags=["caching"],
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_flushall():
    """
    A function to flush all items from the cache. (All items will be deleted from the cache with this)

    Raises HTTPException if the cache is not initialized or if the cache type does not support flushing.
    Returns a dictionary with the status of the operation.

    Usage:
    ```
    curl -X POST http://0.0.0.0:4000/cache/flushall -H "Authorization: Bearer sk-1234"
    ```
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        supports_flush = litellm.cache.type == "redis" and isinstance(
            litellm.cache.cache, RedisCache
        )
        if not supports_flush:
            raise HTTPException(
                status_code=500,
                detail=f"Cache type {litellm.cache.type} does not support flushing",
            )

        litellm.cache.cache.flushall()
        return {
            "status": "success",
        }
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)})",
        )

View File

@@ -0,0 +1,394 @@
# LiteLLM Proxy Client
A Python client library for interacting with the LiteLLM proxy server. This client provides a clean, typed interface for managing models, keys, credentials, and making chat completions.
## Installation
```bash
pip install litellm
```
## Quick Start
```python
from litellm.proxy.client import Client
# Initialize the client
client = Client(
base_url="http://localhost:4000", # Your LiteLLM proxy server URL
api_key="sk-api-key" # Optional: API key for authentication
)
# Make a chat completion request
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Hello, how are you?"}
]
)
print(response.choices[0].message.content)
```
## Features
The client is organized into several resource clients for different functionality:
- `chat`: Chat completions
- `models`: Model management
- `model_groups`: Model group management
- `keys`: API key management
- `credentials`: Credential management
- `users`: User management
## Chat Completions
Make chat completion requests to your LiteLLM proxy:
```python
# Basic chat completion
response = client.chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the capital of France?"}
]
)
# Stream responses
for chunk in client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Tell me a story"}],
stream=True
):
print(chunk.choices[0].delta.content or "", end="")
```
## Model Management
Manage available models on your proxy:
```python
# List available models
models = client.models.list()
# Add a new model
client.models.add(
model_name="gpt-4",
litellm_params={
"api_key": "your-openai-key",
"api_base": "https://api.openai.com/v1"
}
)
# Delete a model
client.models.delete(model_name="gpt-4")
```
## API Key Management
Manage virtual API keys:
```python
# Generate a new API key
key = client.keys.generate(
models=["gpt-4", "gpt-3.5-turbo"],
aliases={"gpt4": "gpt-4"},
duration="24h",
key_alias="my-key",
team_id="team123"
)
# List all keys
keys = client.keys.list(
page=1,
size=10,
return_full_object=True
)
# Delete keys
client.keys.delete(
keys=["sk-key1", "sk-key2"],
key_aliases=["alias1", "alias2"]
)
```
## Credential Management
Manage model credentials:
```python
# Create new credentials
client.credentials.create(
credential_name="azure1",
credential_info={"api_type": "azure"},
credential_values={
"api_key": "your-azure-key",
"api_base": "https://example.azure.openai.com"
}
)
# List all credentials
credentials = client.credentials.list()
# Get a specific credential
credential = client.credentials.get(credential_name="azure1")
# Delete credentials
client.credentials.delete(credential_name="azure1")
```
## Model Groups
Manage model groups for load balancing and fallbacks:
```python
# Create a model group
client.model_groups.create(
name="gpt4-group",
models=[
{"model_name": "gpt-4", "litellm_params": {"api_key": "key1"}},
{"model_name": "gpt-4-backup", "litellm_params": {"api_key": "key2"}}
]
)
# List model groups
groups = client.model_groups.list()
# Delete a model group
client.model_groups.delete(name="gpt4-group")
```
## Users Management
Manage users on your proxy:
```python
from litellm.proxy.client import UsersManagementClient
users = UsersManagementClient(base_url="http://localhost:4000", api_key="sk-test")
# List users
user_list = users.list_users()
# Get user info
user_info = users.get_user(user_id="u1")
# Create a new user
created = users.create_user({
"user_email": "a@b.com",
"user_role": "internal_user",
"user_alias": "Alice",
"teams": ["team1"],
"max_budget": 100.0
})
# Delete users
users.delete_user(["u1", "u2"])
```
## Low-Level HTTP Client
The client provides access to a low-level HTTP client for making direct requests
to the LiteLLM proxy server. This is useful when you need more control or when
working with endpoints that don't yet have a high-level interface.
```python
# Access the HTTP client
client = Client(
base_url="http://localhost:4000",
api_key="sk-api-key"
)
# Make a custom request
response = client.http.request(
method="POST",
uri="/health/test_connection",
json={
"litellm_params": {
"model": "gpt-4",
"api_key": "your-api-key",
"api_base": "https://api.openai.com/v1"
},
"mode": "chat"
}
)
# The response is automatically parsed from JSON
print(response)
```
### HTTP Client Features
- Automatic URL handling (handles trailing/leading slashes)
- Built-in authentication (adds Bearer token if `api_key` is provided)
- JSON request/response handling
- Configurable timeout (default: 30 seconds)
- Comprehensive error handling
- Support for custom headers and request parameters
### HTTP Client `request` method parameters
- `method`: HTTP method (GET, POST, PUT, DELETE, etc.)
- `uri`: URI path (will be appended to base_url)
- `data`: (optional) Data to send in the request body
- `json`: (optional) JSON data to send in the request body
- `headers`: (optional) Custom HTTP headers
- Additional keyword arguments are passed to the underlying requests library
## Error Handling
The client provides clear error handling with custom exceptions:
```python
from litellm.proxy.client.exceptions import UnauthorizedError
try:
response = client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}]
)
except UnauthorizedError as e:
print("Authentication failed:", e)
except Exception as e:
print("Request failed:", e)
```
## Advanced Usage
### Request Customization
All methods support returning the raw request object for inspection or modification:
```python
# Get the prepared request without sending it
request = client.models.list(return_request=True)
print(request.method) # GET
print(request.url)     # http://localhost:4000/models
print(request.headers) # {'Content-Type': 'application/json', ...}
```
### Pagination
Methods that return lists support pagination:
```python
# Get the first page of keys
page1 = client.keys.list(page=1, size=10)
# Get the second page
page2 = client.keys.list(page=2, size=10)
```
### Filtering
Many list methods support filtering:
```python
# Filter keys by user and team
keys = client.keys.list(
user_id="user123",
team_id="team456",
include_team_keys=True
)
```
## Contributing
Contributions are welcome! Please check out our [contributing guidelines](../../CONTRIBUTING.md) for details.
## License
This project is licensed under the MIT License - see the [LICENSE](../../LICENSE) file for details.
## CLI Authentication Flow
The LiteLLM CLI supports SSO authentication through a polling-based approach that works with any OAuth-compatible SSO provider.
### How CLI Authentication Works
```mermaid
sequenceDiagram
participant CLI as CLI
participant Browser as Browser
participant Proxy as LiteLLM Proxy
participant SSO as SSO Provider
CLI->>CLI: Generate key ID (sk-uuid)
CLI->>Browser: Open /sso/key/generate?source=litellm-cli&key=sk-uuid
Browser->>Proxy: GET /sso/key/generate?source=litellm-cli&key=sk-uuid
Proxy->>Proxy: Set cli_state = litellm-session-token:sk-uuid
Proxy->>SSO: Redirect with state=litellm-session-token:sk-uuid
SSO->>Browser: Show login page
Browser->>SSO: User authenticates
SSO->>Proxy: Redirect to /sso/callback?state=litellm-session-token:sk-uuid
Proxy->>Proxy: Check if state starts with "litellm-session-token:"
Proxy->>Proxy: Generate API key with ID=sk-uuid
Proxy->>Browser: Show success page
CLI->>Proxy: Poll /sso/cli/poll/sk-uuid
Proxy->>CLI: Return {"status": "ready", "key": "sk-uuid"}
CLI->>CLI: Save key to ~/.litellm/token.json
```
### Authentication Commands
The CLI provides three authentication commands:
- **`litellm-proxy login`** - Start SSO authentication flow
- **`litellm-proxy logout`** - Clear stored authentication token
- **`litellm-proxy whoami`** - Show current authentication status
### Authentication Flow Steps
1. **Generate Session ID**: CLI generates a unique key ID (`sk-{uuid}`)
2. **Open Browser**: CLI opens browser to `/sso/key/generate` with CLI source and key parameters
3. **SSO Redirect**: Proxy sets the formatted state (`litellm-session-token:sk-uuid`) as OAuth state parameter and redirects to SSO provider
4. **User Authentication**: User completes SSO authentication in browser
5. **Callback Processing**: SSO provider redirects back to proxy with state parameter
6. **Key Generation**: Proxy detects CLI login (state starts with "litellm-session-token:") and generates API key with pre-specified ID
7. **Polling**: CLI polls `/sso/cli/poll/{key_id}` endpoint until key is ready
8. **Token Storage**: CLI saves the authentication token to `~/.litellm/token.json`
### Benefits of This Approach
- **No Local Server**: No need to run a local callback server
- **Standard OAuth**: Uses OAuth 2.0 state parameter correctly
- **Remote Compatible**: Works with remote proxy servers
- **Secure**: Uses UUID session identifiers
- **Simple Setup**: No additional OAuth redirect URL configuration needed
### Token Storage
Authentication tokens are stored in `~/.litellm/token.json` with restricted file permissions (600). The stored token includes:
```json
{
"key": "sk-...",
"user_id": "cli-user",
"user_email": "user@example.com",
"user_role": "cli",
"auth_header_name": "Authorization",
"timestamp": 1234567890
}
```
### Usage
Once authenticated, the CLI will automatically use the stored token for all requests. You no longer need to specify `--api-key` for subsequent commands.
```bash
# Login
litellm-proxy login
# Use CLI without specifying API key
litellm-proxy models list
# Check authentication status
litellm-proxy whoami
# Logout
litellm-proxy logout
```

View File

@@ -0,0 +1,17 @@
from .client import Client
from .chat import ChatClient
from .models import ModelsManagementClient
from .model_groups import ModelGroupsManagementClient
from .exceptions import UnauthorizedError
from .users import UsersManagementClient
from .health import HealthManagementClient
__all__ = [
"Client",
"ChatClient",
"ModelsManagementClient",
"ModelGroupsManagementClient",
"UsersManagementClient",
"UnauthorizedError",
"HealthManagementClient",
]

View File

@@ -0,0 +1,185 @@
import json
from typing import Any, Dict, Iterator, List, Optional, Union
import requests
from .exceptions import UnauthorizedError
class ChatClient:
    """Client for the LiteLLM proxy's OpenAI-compatible /chat/completions route."""

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the ChatClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:8000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # Remove trailing slash if present
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """
        Get the headers for API requests, including authorization if api_key is set.

        Returns:
            Dict[str, str]: Headers to use for API requests
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    @staticmethod
    def _build_request_data(
        model: str,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        user: Optional[str] = None,
        stream: bool = False,
    ) -> Dict[str, Any]:
        """
        Build the JSON payload for a chat completion request.

        Shared by ``completions`` and ``completions_stream`` so optional
        parameters are handled identically (omitted when ``None``).

        Returns:
            Dict[str, Any]: The request body to send to the server.
        """
        data: Dict[str, Any] = {"model": model, "messages": messages}
        if stream:
            data["stream"] = True
        optional_params = {
            "temperature": temperature,
            "top_p": top_p,
            "n": n,
            "max_tokens": max_tokens,
            "presence_penalty": presence_penalty,
            "frequency_penalty": frequency_penalty,
            "user": user,
        }
        # Only include parameters the caller actually set.
        data.update({k: v for k, v in optional_params.items() if v is not None})
        return data

    def completions(
        self,
        model: str,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        user: Optional[str] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Create a chat completion.

        Args:
            model (str): The model to use for completion
            messages (List[Dict[str, str]]): The messages to generate a completion for
            temperature (Optional[float]): Sampling temperature between 0 and 2
            top_p (Optional[float]): Nucleus sampling parameter between 0 and 1
            n (Optional[int]): Number of completions to generate
            max_tokens (Optional[int]): Maximum number of tokens to generate
            presence_penalty (Optional[float]): Presence penalty between -2.0 and 2.0
            frequency_penalty (Optional[float]): Frequency penalty between -2.0 and 2.0
            user (Optional[str]): Unique identifier for the end user
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the completion response from the server or
                a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/chat/completions"
        data = self._build_request_data(
            model=model,
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            n=n,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            user=user,
        )

        request = requests.Request("POST", url, headers=self._get_headers(), json=data)
        if return_request:
            return request

        # Prepare and send the request
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise

    def completions_stream(
        self,
        model: str,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        user: Optional[str] = None,
    ) -> Iterator[Dict[str, Any]]:
        """
        Create a streaming chat completion.

        Args:
            model (str): The model to use for completion
            messages (List[Dict[str, str]]): The messages to generate a completion for
            temperature (Optional[float]): Sampling temperature between 0 and 2
            top_p (Optional[float]): Nucleus sampling parameter between 0 and 1
            n (Optional[int]): Number of completions to generate
            max_tokens (Optional[int]): Maximum number of tokens to generate
            presence_penalty (Optional[float]): Presence penalty between -2.0 and 2.0
            frequency_penalty (Optional[float]): Frequency penalty between -2.0 and 2.0
            user (Optional[str]): Unique identifier for the end user

        Yields:
            Dict[str, Any]: Streaming response chunks from the server

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/chat/completions"
        data = self._build_request_data(
            model=model,
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            n=n,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            user=user,
            stream=True,
        )

        # Make streaming request
        session = requests.Session()
        try:
            response = session.post(
                url, headers=self._get_headers(), json=data, stream=True
            )
            response.raise_for_status()

            # Parse SSE stream: each event is a line of the form "data: {...}",
            # terminated by a "data: [DONE]" sentinel.
            for line in response.iter_lines():
                if line:
                    line = line.decode("utf-8")
                    if line.startswith("data: "):
                        data_str = line[6:]  # Remove 'data: ' prefix
                        if data_str.strip() == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data_str)
                            yield chunk
                        except json.JSONDecodeError:
                            # Skip malformed keep-alive / partial lines.
                            continue
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise

View File

@@ -0,0 +1,536 @@
# LiteLLM Proxy CLI
The LiteLLM Proxy CLI is a command-line tool for managing your LiteLLM proxy server. It provides commands for managing models, viewing server status, and interacting with the proxy server.
## Installation
```bash
pip install 'litellm[proxy]'
```
## Configuration
The CLI can be configured using environment variables or command-line options:
- `LITELLM_PROXY_URL`: Base URL of the LiteLLM proxy server (default: http://localhost:4000)
- `LITELLM_PROXY_API_KEY`: API key for authentication
## Global Options
- `--version`, `-v`: Print the LiteLLM Proxy client and server version and exit.
Example:
```bash
litellm-proxy version
# or
litellm-proxy --version
# or
litellm-proxy -v
```
## Commands
### Models Management
The CLI provides several commands for managing models on your LiteLLM proxy server:
#### List Models
View all available models:
```bash
litellm-proxy models list [--format table|json]
```
Options:
- `--format`: Output format (table or json, default: table)
#### Model Information
Get detailed information about all models:
```bash
litellm-proxy models info [options]
```
Options:
- `--format`: Output format (table or json, default: table)
- `--columns`: Comma-separated list of columns to display. Valid columns:
- `public_model`
- `upstream_model`
- `credential_name`
- `created_at`
- `updated_at`
- `id`
- `input_cost`
- `output_cost`
Default columns: `public_model`, `upstream_model`, `updated_at`
#### Add Model
Add a new model to the proxy:
```bash
litellm-proxy models add <model-name> [options]
```
Options:
- `--param`, `-p`: Model parameters in key=value format (can be specified multiple times)
- `--info`, `-i`: Model info in key=value format (can be specified multiple times)
Example:
```bash
litellm-proxy models add gpt-4 -p api_key=sk-123 -p api_base=https://api.openai.com -i description="GPT-4 model"
```
#### Get Model Info
Get information about a specific model:
```bash
litellm-proxy models get [--id MODEL_ID] [--name MODEL_NAME]
```
Options:
- `--id`: ID of the model to retrieve
- `--name`: Name of the model to retrieve
#### Delete Model
Delete a model from the proxy:
```bash
litellm-proxy models delete <model-id>
```
#### Update Model
Update an existing model's configuration:
```bash
litellm-proxy models update <model-id> [options]
```
Options:
- `--param`, `-p`: Model parameters in key=value format (can be specified multiple times)
- `--info`, `-i`: Model info in key=value format (can be specified multiple times)
#### Import Models
Import models from a YAML file:
```bash
litellm-proxy models import models.yaml
```
Options:
- `--dry-run`: Show what would be imported without making any changes.
- `--only-models-matching-regex <regex>`: Only import models where `litellm_params.model` matches the given regex.
- `--only-access-groups-matching-regex <regex>`: Only import models where at least one item in `model_info.access_groups` matches the given regex.
Examples:
1. Import all models from a YAML file:
```bash
litellm-proxy models import models.yaml
```
2. Dry run (show what would be imported):
```bash
litellm-proxy models import models.yaml --dry-run
```
3. Only import models where the model name contains 'gpt':
```bash
litellm-proxy models import models.yaml --only-models-matching-regex gpt
```
4. Only import models with access group containing 'beta':
```bash
litellm-proxy models import models.yaml --only-access-groups-matching-regex beta
```
5. Combine both filters:
```bash
litellm-proxy models import models.yaml --only-models-matching-regex gpt --only-access-groups-matching-regex beta
```
### Credentials Management
The CLI provides commands for managing credentials on your LiteLLM proxy server:
#### List Credentials
View all available credentials:
```bash
litellm-proxy credentials list [--format table|json]
```
Options:
- `--format`: Output format (table or json, default: table)
The table format displays:
- Credential Name
- Custom LLM Provider
#### Create Credential
Create a new credential:
```bash
litellm-proxy credentials create <credential-name> --info <json-string> --values <json-string>
```
Options:
- `--info`: JSON string containing credential info (e.g., custom_llm_provider)
- `--values`: JSON string containing credential values (e.g., api_key)
Example:
```bash
litellm-proxy credentials create azure-cred \
--info '{"custom_llm_provider": "azure"}' \
--values '{"api_key": "sk-123", "api_base": "https://example.azure.openai.com"}'
```
#### Get Credential
Get information about a specific credential:
```bash
litellm-proxy credentials get <credential-name>
```
#### Delete Credential
Delete a credential:
```bash
litellm-proxy credentials delete <credential-name>
```
### Keys Management
The CLI provides commands for managing API keys on your LiteLLM proxy server:
#### List Keys
View all API keys:
```bash
litellm-proxy keys list [--format table|json] [options]
```
Options:
- `--format`: Output format (table or json, default: table)
- `--page`: Page number for pagination
- `--size`: Number of items per page
- `--user-id`: Filter keys by user ID
- `--team-id`: Filter keys by team ID
- `--organization-id`: Filter keys by organization ID
- `--key-hash`: Filter by specific key hash
- `--key-alias`: Filter by key alias
- `--return-full-object`: Return the full key object
- `--include-team-keys`: Include team keys in the response
#### Generate Key
Generate a new API key:
```bash
litellm-proxy keys generate [options]
```
Options:
- `--models`: Comma-separated list of allowed models
- `--aliases`: JSON string of model alias mappings
- `--spend`: Maximum spend limit for this key
- `--duration`: Duration for which the key is valid (e.g. '24h', '7d')
- `--key-alias`: Alias/name for the key
- `--team-id`: Team ID to associate the key with
- `--user-id`: User ID to associate the key with
- `--budget-id`: Budget ID to associate the key with
- `--config`: JSON string of additional configuration parameters
Example:
```bash
litellm-proxy keys generate --models gpt-4,gpt-3.5-turbo --spend 100 --duration 24h --key-alias my-key --team-id team123
```
#### Delete Keys
Delete API keys by key or alias:
```bash
litellm-proxy keys delete [--keys <comma-separated-keys>] [--key-aliases <comma-separated-aliases>]
```
Options:
- `--keys`: Comma-separated list of API keys to delete
- `--key-aliases`: Comma-separated list of key aliases to delete
Example:
```bash
litellm-proxy keys delete --keys sk-key1,sk-key2 --key-aliases alias1,alias2
```
#### Get Key Info
Get information about a specific API key:
```bash
litellm-proxy keys info --key <key-hash>
```
Options:
- `--key`: The key hash to get information about
Example:
```bash
litellm-proxy keys info --key sk-key1
```
### User Management
The CLI provides commands for managing users on your LiteLLM proxy server:
#### List Users
View all users:
```bash
litellm-proxy users list
```
#### Get User Info
Get information about a specific user:
```bash
litellm-proxy users get --id <user-id>
```
#### Create User
Create a new user:
```bash
litellm-proxy users create --email user@example.com --role internal_user --alias "Alice" --team team1 --max-budget 100.0
```
#### Delete User
Delete one or more users by user_id:
```bash
litellm-proxy users delete <user-id-1> <user-id-2>
```
### Chat Commands
The CLI provides commands for interacting with chat models through your LiteLLM proxy server:
#### Chat Completions
Create a chat completion:
```bash
litellm-proxy chat completions <model> [options]
```
Arguments:
- `model`: The model to use (e.g., gpt-4, claude-2)
Options:
- `--message`, `-m`: Messages in 'role:content' format. Can be specified multiple times to create a conversation.
- `--temperature`, `-t`: Sampling temperature between 0 and 2
- `--top-p`: Nucleus sampling parameter between 0 and 1
- `--n`: Number of completions to generate
- `--max-tokens`: Maximum number of tokens to generate
- `--presence-penalty`: Presence penalty between -2.0 and 2.0
- `--frequency-penalty`: Frequency penalty between -2.0 and 2.0
- `--user`: Unique identifier for the end user
Examples:
1. Simple completion:
```bash
litellm-proxy chat completions gpt-4 -m "user:Hello, how are you?"
```
2. Multi-message conversation:
```bash
litellm-proxy chat completions gpt-4 \
-m "system:You are a helpful assistant" \
-m "user:What's the capital of France?" \
-m "assistant:The capital of France is Paris." \
-m "user:What's its population?"
```
3. With generation parameters:
```bash
litellm-proxy chat completions gpt-4 \
-m "user:Write a story" \
--temperature 0.7 \
--max-tokens 500 \
--top-p 0.9
```
### HTTP Commands
The CLI provides commands for making direct HTTP requests to your LiteLLM proxy server:
#### Make HTTP Request
Make an HTTP request to any endpoint:
```bash
litellm-proxy http request <method> <uri> [options]
```
Arguments:
- `method`: HTTP method (GET, POST, PUT, DELETE, etc.)
- `uri`: URI path (will be appended to base_url)
Options:
- `--data`, `-d`: Data to send in the request body (as JSON string)
- `--json`, `-j`: JSON data to send in the request body (as JSON string)
- `--header`, `-H`: HTTP headers in 'key:value' format. Can be specified multiple times.
Examples:
1. List models:
```bash
litellm-proxy http request GET /models
```
2. Create a chat completion:
```bash
litellm-proxy http request POST /chat/completions -j '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}'
```
3. Test connection with custom headers:
```bash
litellm-proxy http request GET /health/test_connection -H "X-Custom-Header:value"
```
## Environment Variables
The CLI respects the following environment variables:
- `LITELLM_PROXY_URL`: Base URL of the proxy server
- `LITELLM_PROXY_API_KEY`: API key for authentication
## Examples
1. List all models in table format:
```bash
litellm-proxy models list
```
2. Add a new model with parameters:
```bash
litellm-proxy models add gpt-4 -p api_key=sk-123 -p max_tokens=2048
```
3. Get model information in JSON format:
```bash
litellm-proxy models info --format json
```
4. Update model parameters:
```bash
litellm-proxy models update model-123 -p temperature=0.7 -i description="Updated model"
```
5. List all credentials in table format:
```bash
litellm-proxy credentials list
```
6. Create a new credential for Azure:
```bash
litellm-proxy credentials create azure-prod \
--info '{"custom_llm_provider": "azure"}' \
--values '{"api_key": "sk-123", "api_base": "https://prod.azure.openai.com"}'
```
7. Make a custom HTTP request:
```bash
litellm-proxy http request POST /chat/completions \
-j '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}' \
-H "X-Custom-Header:value"
```
8. User management:
```bash
# List users
litellm-proxy users list
# Get user info
litellm-proxy users get --id u1
# Create a user
litellm-proxy users create --email a@b.com --role internal_user --alias "Alice" --team team1 --max-budget 100.0
# Delete users
litellm-proxy users delete u1 u2
```
9. Import models from a YAML file (with filters):
```bash
# Only import models where the model name contains 'gpt'
litellm-proxy models import models.yaml --only-models-matching-regex gpt
# Only import models with access group containing 'beta'
litellm-proxy models import models.yaml --only-access-groups-matching-regex beta
# Combine both filters
litellm-proxy models import models.yaml --only-models-matching-regex gpt --only-access-groups-matching-regex beta
```
## Error Handling
The CLI will display appropriate error messages when:
- The proxy server is not accessible
- Authentication fails
- Invalid parameters are provided
- The requested model or credential doesn't exist
- Invalid JSON is provided for credential creation
- Any other operation fails
For detailed debugging, use the `--debug` flag with any command.

View File

@@ -0,0 +1,5 @@
"""CLI package for LiteLLM Proxy Client."""
from .main import cli
__all__ = ["cli"]

View File

@@ -0,0 +1 @@
"""Command groups for the LiteLLM proxy CLI."""

View File

@@ -0,0 +1,623 @@
import json
import os
import sys
import time
import webbrowser
from pathlib import Path
from typing import Any, Dict, List, Optional
import click
import requests
from rich.console import Console
from rich.table import Table
from litellm.constants import CLI_JWT_EXPIRATION_HOURS
# Token storage utilities
def get_token_file_path() -> str:
    """Return the path of the CLI auth token file, creating ``~/.litellm`` if needed."""
    config_dir = Path.home() / ".litellm"
    config_dir.mkdir(exist_ok=True)
    return str(config_dir / "token.json")
def save_token(token_data: Dict[str, Any]) -> None:
    """Persist *token_data* as pretty-printed JSON, readable only by the owner."""
    path = get_token_file_path()
    with open(path, "w") as fh:
        json.dump(token_data, fh, indent=2)
    # The file holds an API credential, so restrict it to the owner.
    os.chmod(path, 0o600)
def load_token() -> Optional[Dict[str, Any]]:
    """Read stored token data; return None when the file is missing or unreadable."""
    # EAFP: a missing file raises FileNotFoundError, a subclass of IOError,
    # so one try/except covers both the existence check and read errors.
    try:
        with open(get_token_file_path(), "r") as fh:
            return json.load(fh)
    except (json.JSONDecodeError, IOError):
        return None
def clear_token() -> None:
    """Delete the stored token file, ignoring the case where none exists."""
    try:
        os.remove(get_token_file_path())
    except FileNotFoundError:
        pass
def get_stored_api_key() -> Optional[str]:
    """Return the stored LiteLLM gateway API key, if any.

    Thin wrapper around the SDK-level helper so CLI code has a single
    local entry point for credential lookup.
    """
    from litellm.litellm_core_utils.cli_token_utils import (
        get_litellm_gateway_api_key,
    )

    return get_litellm_gateway_api_key()
# Team selection utilities
def display_teams_table(teams: List[Dict[str, Any]]) -> None:
    """Render *teams* as a rich table on stdout.

    Args:
        teams: Team dicts as returned by the proxy; keys read here are
            ``team_alias``, ``team_id``, ``models``, and ``max_budget``.
    """
    console = Console()
    if not teams:
        console.print("❌ No teams found for your user.")
        return
    table = Table(title="Available Teams")
    table.add_column("Index", style="cyan", no_wrap=True)
    table.add_column("Team Alias", style="magenta")
    table.add_column("Team ID", style="green")
    table.add_column("Models", style="yellow")
    table.add_column("Max Budget", style="blue")
    for i, team in enumerate(teams):
        team_alias = team.get("team_alias") or "N/A"
        team_id = team.get("team_id", "N/A")
        models = team.get("models", [])
        max_budget = team.get("max_budget")
        # Truncate long model lists so the table stays readable.
        if models:
            if len(models) > 3:
                models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
            else:
                models_str = ", ".join(models)
        else:
            models_str = "All models"
        # BUG FIX: a budget of 0 is a real (fully spent) limit, not
        # "Unlimited" — compare against None instead of truthiness.
        budget_str = f"${max_budget}" if max_budget is not None else "Unlimited"
        table.add_row(str(i + 1), team_alias, team_id, models_str, budget_str)
    console.print(table)
def get_key_input() -> Optional[str]:
    """Read a single keypress from the user (cross-platform).

    Returns one of "up", "down", "enter", "escape", or "quit"; returns
    None for any unrecognized key or when raw terminal input is
    unavailable (caller then falls back to a plain numeric prompt).
    """
    try:
        if sys.platform == "win32":
            import msvcrt

            key = msvcrt.getch()
            if key == b"\xe0":  # Arrow keys on Windows
                # Windows sends a two-byte sequence for arrows: 0xE0 prefix
                # followed by the actual arrow code.
                key = msvcrt.getch()
                if key == b"H":  # Up arrow
                    return "up"
                elif key == b"P":  # Down arrow
                    return "down"
            elif key == b"\r":  # Enter key
                return "enter"
            elif key == b"\x1b":  # Escape key
                return "escape"
            elif key == b"q":
                return "quit"
            return None
        else:
            import termios
            import tty

            fd = sys.stdin.fileno()
            old_settings = termios.tcgetattr(fd)
            try:
                # Raw mode so a single keypress is delivered immediately
                # without waiting for Enter.
                tty.setraw(sys.stdin.fileno())
                key = sys.stdin.read(1)
                if key == "\x1b":  # Escape sequence
                    # ANSI arrow keys arrive as ESC [ A / ESC [ B.
                    key += sys.stdin.read(2)
                    if key == "\x1b[A":  # Up arrow
                        return "up"
                    elif key == "\x1b[B":  # Down arrow
                        return "down"
                    elif key == "\x1b":  # Just escape
                        return "escape"
                elif key == "\r" or key == "\n":  # Enter key
                    return "enter"
                elif key == "q":
                    return "quit"
                return None
            finally:
                # Always restore the terminal to its previous (cooked) mode.
                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
    except ImportError:
        # Fallback to simple input if termios/msvcrt not available
        return None
def display_interactive_team_selection(
    teams: List[Dict[str, Any]], selected_index: int = 0
) -> None:
    """Redraw the team list with one entry highlighted for selection.

    Args:
        teams: Team dicts (keys read: ``team_alias``, ``team_id``,
            ``models``, ``max_budget``).
        selected_index: Index of the currently highlighted team.
    """
    console = Console()
    # Clear the screen using Rich's method
    console.clear()
    console.print("🎯 Select a Team (Use ↑↓ arrows, Enter to select, 'q' to skip):\n")
    for i, team in enumerate(teams):
        team_alias = team.get("team_alias") or "N/A"
        team_id = team.get("team_id", "N/A")
        models = team.get("models", [])
        max_budget = team.get("max_budget")
        # Truncate long model lists so each entry stays on a few lines.
        if models:
            if len(models) > 3:
                models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
            else:
                models_str = ", ".join(models)
        else:
            models_str = "All models"
        # BUG FIX: a budget of 0 is a real limit, not "Unlimited" —
        # compare against None instead of truthiness.
        budget_str = f"${max_budget}" if max_budget is not None else "Unlimited"
        # Highlight the selected item
        if i == selected_index:
            console.print(f"➤ [bold cyan]{team_alias}[/bold cyan] ({team_id})")
            console.print(f" Models: [yellow]{models_str}[/yellow]")
            console.print(f" Budget: [blue]{budget_str}[/blue]\n")
        else:
            console.print(f" [dim]{team_alias}[/dim] ({team_id})")
            console.print(f" Models: [dim]{models_str}[/dim]")
            console.print(f" Budget: [dim]{budget_str}[/dim]\n")
def prompt_team_selection(teams: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Interactive team selection with arrow keys.

    Falls back to a plain numeric prompt when stdin is not a TTY or raw
    key input is unavailable. Returns the chosen team dict, or None when
    the user skips or cancels.
    """
    if not teams:
        return None
    selected_index = 0
    try:
        # Check if we can use interactive mode
        if not sys.stdin.isatty():
            # Fallback to simple selection for non-interactive environments
            return prompt_team_selection_fallback(teams)
        while True:
            # Redraw the full list with the current highlight each iteration.
            display_interactive_team_selection(teams, selected_index)
            key = get_key_input()
            if key == "up":
                # Modulo wraps the cursor from the first entry to the last.
                selected_index = (selected_index - 1) % len(teams)
            elif key == "down":
                selected_index = (selected_index + 1) % len(teams)
            elif key == "enter":
                selected_team = teams[selected_index]
                # Clear screen and show selection
                console = Console()
                console.clear()
                click.echo(
                    f"✅ Selected team: {selected_team.get('team_alias', 'N/A')} ({selected_team.get('team_id')})"
                )
                return selected_team
            elif key == "quit" or key == "escape":
                # Clear screen
                console = Console()
                console.clear()
                click.echo(" Team selection skipped.")
                return None
            elif key is None:
                # If we can't get key input, fall back to simple selection
                # NOTE(review): get_key_input also returns None for merely
                # unrecognized keys, so any stray keypress drops to the
                # fallback prompt — confirm this is intended.
                return prompt_team_selection_fallback(teams)
    except KeyboardInterrupt:
        console = Console()
        console.clear()
        click.echo("\n❌ Team selection cancelled.")
        return None
    except Exception:
        # If interactive mode fails, fall back to simple selection
        return prompt_team_selection_fallback(teams)
def prompt_team_selection_fallback(
    teams: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Plain numeric team picker for terminals without raw key input."""
    if not teams:
        return None
    prompt_text = (
        "\nSelect a team by entering the index number "
        "(or 'skip' to continue without a team)"
    )
    while True:
        try:
            raw = click.prompt(prompt_text, type=str).strip()
        except KeyboardInterrupt:
            click.echo("\n❌ Team selection cancelled.")
            return None
        if raw.lower() == "skip":
            return None
        try:
            idx = int(raw) - 1
        except ValueError:
            click.echo("❌ Invalid input. Please enter a number or 'skip'")
            continue
        if 0 <= idx < len(teams):
            chosen = teams[idx]
            click.echo(
                f"\n✅ Selected team: {chosen.get('team_alias', 'N/A')} ({chosen.get('team_id')})"
            )
            return chosen
        click.echo(
            f"❌ Invalid selection. Please enter a number between 1 and {len(teams)}"
        )
# Polling-based authentication - no local server needed
def _poll_for_ready_data(
    url: str,
    *,
    total_timeout: int = 300,
    poll_interval: int = 2,
    request_timeout: int = 10,
    pending_message: Optional[str] = None,
    pending_log_every: int = 10,
    other_status_message: Optional[str] = None,
    other_status_log_every: int = 10,
    http_error_log_every: int = 10,
    connection_error_log_every: int = 10,
) -> Optional[Dict[str, Any]]:
    """Poll *url* until its JSON payload reports ``status == "ready"``.

    Issues a GET every *poll_interval* seconds for up to *total_timeout*
    seconds. The ``*_log_every`` knobs throttle progress/error echoes to
    every Nth attempt (0 suppresses that message entirely); the
    corresponding ``*_message`` strings must also be non-empty to print.

    Returns:
        The parsed JSON payload once status is "ready", or None on timeout.
    """
    for attempt in range(total_timeout // poll_interval):
        try:
            response = requests.get(url, timeout=request_timeout)
            if response.status_code == 200:
                data = response.json()
                status = data.get("status")
                if status == "ready":
                    return data
                if status == "pending":
                    if (
                        pending_message
                        and pending_log_every > 0
                        and attempt % pending_log_every == 0
                    ):
                        click.echo(pending_message)
                # Any other status (neither ready nor pending).
                elif (
                    other_status_message
                    and other_status_log_every > 0
                    and attempt % other_status_log_every == 0
                ):
                    click.echo(other_status_message)
            # Non-200 responses are retried, with throttled logging.
            elif http_error_log_every > 0 and attempt % http_error_log_every == 0:
                click.echo(f"Polling error: HTTP {response.status_code}")
        except requests.RequestException as e:
            # Transient network failures are retried, with throttled logging.
            if (
                connection_error_log_every > 0
                and attempt % connection_error_log_every == 0
            ):
                click.echo(f"Connection error (will retry): {e}")
        time.sleep(poll_interval)
    return None
def _normalize_teams(teams, team_details):
"""If team_details are a
Args:
teams (_type_): _description_
team_details (_type_): _description_
Returns:
_type_: _description_
"""
if isinstance(team_details, list) and team_details:
return [
{
"team_id": i.get("team_id") or i.get("id"),
"team_alias": i.get("team_alias"),
}
for i in team_details
if isinstance(i, dict) and (i.get("team_id") or i.get("id"))
]
if isinstance(teams, list):
return [{"team_id": str(t), "team_alias": None} for t in teams]
return []
def _poll_for_authentication(base_url: str, key_id: str) -> Optional[dict]:
    """
    Poll the server for authentication completion and handle team selection.

    Args:
        base_url: Proxy server base URL.
        key_id: Pre-generated session key ID used as the poll token.

    Returns:
        Dictionary with authentication data if successful, None otherwise
    """
    poll_url = f"{base_url}/sso/cli/poll/{key_id}"
    data = _poll_for_ready_data(
        poll_url,
        pending_message="Still waiting for authentication...",
    )
    if not data:
        return None
    if data.get("requires_team_selection"):
        teams = data.get("teams", [])
        team_details = data.get("team_details")
        user_id = data.get("user_id")
        normalized_teams: List[Dict[str, Any]] = _normalize_teams(teams, team_details)
        if not normalized_teams:
            click.echo("⚠️ No teams available for selection.")
            return None
        # User has multiple teams - let them select
        jwt_with_team = _handle_team_selection_during_polling(
            base_url=base_url,
            key_id=key_id,
            teams=normalized_teams,
        )
        # Use the team-specific JWT if selection succeeded
        if jwt_with_team:
            return {
                "api_key": jwt_with_team,
                "user_id": user_id,
                "teams": teams,
                "team_id": None,  # Set by server in JWT
            }
        click.echo("❌ Team selection cancelled or JWT generation failed.")
        return None
    # JWT is ready (single team or team already selected)
    api_key = data.get("key")
    user_id = data.get("user_id")
    teams = data.get("teams", [])
    team_id = data.get("team_id")
    # Show which team was assigned
    if team_id and len(teams) == 1:
        click.echo(f"\n✅ Automatically assigned to team: {team_id}")
    if api_key:
        return {
            "api_key": api_key,
            "user_id": user_id,
            "teams": teams,
            "team_id": team_id,
        }
    return None
def _handle_team_selection_during_polling(
    base_url: str, key_id: str, teams: List[Dict[str, Any]]
) -> Optional[str]:
    """
    Handle team selection and re-poll with selected team_id.

    Args:
        base_url: Proxy server base URL.
        key_id: Session key ID for the poll endpoint.
        teams: Normalized team dicts with "team_id" and "team_alias" keys.

    Returns:
        The JWT token with the selected team, or None if selection was skipped
    """
    if not teams:
        click.echo(
            " No teams found. You can create or join teams using the web interface."
        )
        return None
    click.echo("\n" + "=" * 60)
    click.echo("📋 Select a team for your CLI session...")
    team_id = _render_and_prompt_for_team_selection(teams)
    if not team_id:
        click.echo(" No team selected.")
        return None
    click.echo(f"\n🔄 Generating JWT for team: {team_id}")
    # Re-poll the same session with the chosen team so the server issues a
    # credential scoped to that team.
    poll_url = f"{base_url}/sso/cli/poll/{key_id}?team_id={team_id}"
    data = _poll_for_ready_data(
        poll_url,
        pending_message="Still waiting for team authentication...",
        other_status_message="Waiting for team authentication to complete...",
        http_error_log_every=10,
    )
    if not data:
        return None
    jwt_token = data.get("key")
    if jwt_token:
        click.echo(f"✅ Successfully generated JWT for team: {team_id}")
        return jwt_token
    return None
def _render_and_prompt_for_team_selection(teams: List[Dict[str, Any]]) -> Optional[str]:
    """Show the available teams and prompt the user to pick one.

    Returns the selected team_id as a string. Entering 'skip' falls back
    to the first team's ID (or None when no teams exist); Ctrl-C returns
    None.
    """
    # Render a table that shows aliases where available while keeping the
    # underlying IDs visible.
    table = Table(title="Available Teams")
    table.add_column("Index", style="cyan", no_wrap=True)
    table.add_column("Team Name", style="magenta")
    table.add_column("Team ID", style="green")
    for position, team in enumerate(teams, start=1):
        tid = str(team.get("team_id"))
        table.add_row(str(position), team.get("team_alias") or tid, tid)
    Console().print(table)
    # Simple numeric selection loop.
    while True:
        try:
            answer = click.prompt(
                "\nSelect a team by entering the index number (or 'skip' to use first team)",
                type=str,
            ).strip()
        except KeyboardInterrupt:
            click.echo("\n❌ Team selection cancelled.")
            return None
        if answer.lower() == "skip":
            # Default to the first team's ID when the user opts out of an
            # explicit choice.
            if teams:
                return str(teams[0].get("team_id"))
            return None
        try:
            pos = int(answer) - 1
        except ValueError:
            click.echo("❌ Invalid input. Please enter a number or 'skip'")
            continue
        if 0 <= pos < len(teams):
            chosen = teams[pos]
            tid = str(chosen.get("team_id"))
            alias = chosen.get("team_alias") or tid
            click.echo(f"\n✅ Selected team: {alias} ({tid})")
            return tid
        click.echo(
            f"❌ Invalid selection. Please enter a number between 1 and {len(teams)}"
        )
@click.command(name="login")
@click.pass_context
def login(ctx: click.Context):
    """Login to LiteLLM proxy using SSO authentication"""
    # Imports deferred to keep CLI startup fast and avoid import cycles.
    from litellm._uuid import uuid
    from litellm.constants import LITELLM_CLI_SOURCE_IDENTIFIER
    from litellm.proxy.client.cli.interface import show_commands

    base_url = ctx.obj["base_url"]
    # Check if we have an existing key to regenerate
    existing_key = get_stored_api_key()
    # Generate unique key ID for this login session
    key_id = f"sk-{str(uuid.uuid4())}"
    try:
        # Construct SSO login URL with CLI source and pre-generated key
        sso_url = f"{base_url}/sso/key/generate?source={LITELLM_CLI_SOURCE_IDENTIFIER}&key={key_id}"
        # If we have an existing key, include it as a parameter to the login endpoint
        # The server will encode it in the OAuth state parameter for the SSO flow
        if existing_key:
            sso_url += f"&existing_key={existing_key}"
        click.echo(f"Opening browser to: {sso_url}")
        click.echo("Please complete the SSO authentication in your browser...")
        click.echo(f"Session ID: {key_id}")
        # Open browser
        webbrowser.open(sso_url)
        # Poll for authentication completion
        click.echo("Waiting for authentication...")
        auth_result = _poll_for_authentication(base_url=base_url, key_id=key_id)
        if auth_result:
            api_key = auth_result["api_key"]
            user_id = auth_result["user_id"]
            # Save token data (simplified for CLI - we just need the key)
            save_token(
                {
                    "key": api_key,
                    "user_id": user_id or "cli-user",
                    "user_email": "unknown",
                    "user_role": "cli",
                    "auth_header_name": "Authorization",
                    "jwt_token": "",
                    "timestamp": time.time(),
                }
            )
            click.echo("\n✅ Login successful!")
            # Only a short prefix of the key is echoed to avoid leaking it.
            click.echo(f"JWT Token: {api_key[:20]}...")
            click.echo("You can now use the CLI without specifying --api-key")
            # Show available commands after successful login
            click.echo("\n" + "=" * 60)
            show_commands()
            return
        else:
            click.echo("❌ Authentication timed out. Please try again.")
            return
    except KeyboardInterrupt:
        click.echo("\n❌ Authentication cancelled by user.")
        return
    except Exception as e:
        click.echo(f"❌ Authentication failed: {e}")
        return
@click.command(name="logout")
def logout():
    """Logout and clear stored authentication"""
    # Deleting the local token file is the entire logout; no server call
    # is made here.
    clear_token()
    click.echo("✅ Logged out successfully. Authentication token cleared.")
@click.command(name="whoami")
def whoami():
    """Show current authentication status"""
    token_data = load_token()
    if not token_data:
        click.echo("❌ Not authenticated. Run 'litellm-proxy login' to authenticate.")
        return
    click.echo("✅ Authenticated")
    for label, field in (
        ("User Email", "user_email"),
        ("User ID", "user_id"),
        ("User Role", "user_role"),
    ):
        click.echo(f"{label}: {token_data.get(field, 'Unknown')}")
    # Rough freshness check based on when the token was written locally.
    age_hours = (time.time() - token_data.get("timestamp", 0)) / 3600
    click.echo(f"Token age: {age_hours:.1f} hours")
    if age_hours > CLI_JWT_EXPIRATION_HOURS:
        click.echo(
            f"⚠️ Warning: Token is more than {CLI_JWT_EXPIRATION_HOURS} hours old and may have expired."
        )
# Export functions for use by other CLI commands
__all__ = ["login", "logout", "whoami", "prompt_team_selection"]
# Export individual commands instead of grouping them
# login, logout, and whoami will be added as top-level commands

View File

@@ -0,0 +1,406 @@
import json
import sys
from typing import Any, Dict, List, Optional
import click
import requests
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt
from rich.table import Table
from ... import Client
from ...chat import ChatClient
def _get_available_models(ctx: click.Context) -> List[Dict[str, Any]]:
    """Fetch the proxy's model list; return [] (with a warning) on any failure."""
    try:
        client = Client(base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"])
        raw = client.models.list()
        if not isinstance(raw, list):
            return []
        # Drop non-dict entries so callers can rely on .get() access.
        return [entry for entry in raw if isinstance(entry, dict)]
    except Exception as e:
        click.echo(f"Warning: Could not fetch models list: {e}", err=True)
        return []
def _select_model(
    console: Console, available_models: List[Dict[str, Any]]
) -> Optional[str]:
    """Prompt the user to pick a model, by index or by literal name.

    Returns the chosen model ID/name, or None when the input is empty or
    the user cancels with Ctrl-C.
    """
    # Without a model list, fall back to free-form entry.
    if not available_models:
        console.print(
            "[yellow]No models available or could not fetch models list.[/yellow]"
        )
        model_name = Prompt.ask("Please enter a model name")
        return model_name if model_name.strip() else None
    MAX_MODELS_TO_DISPLAY = 200
    shown: List[Dict[str, Any]] = available_models[:MAX_MODELS_TO_DISPLAY]
    table = Table(title="Available Models")
    table.add_column("Index", style="cyan", no_wrap=True)
    table.add_column("Model ID", style="green")
    table.add_column("Owned By", style="yellow")
    for position, entry in enumerate(shown, start=1):
        table.add_row(
            str(position), str(entry.get("id", "")), str(entry.get("owned_by", ""))
        )
    overflow = len(available_models) - MAX_MODELS_TO_DISPLAY
    if overflow > 0:
        console.print(f"\n[dim]... and {overflow} more models[/dim]")
    console.print(table)
    while True:
        try:
            choice = Prompt.ask(
                "\nSelect a model by entering the index number (or type a model name directly)",
                default="1",
            ).strip()
        except KeyboardInterrupt:
            console.print("\n[yellow]Model selection cancelled.[/yellow]")
            return None
        try:
            index = int(choice) - 1
        except ValueError:
            # Not a number: treat non-empty input as a literal model name.
            if choice:
                return choice
            console.print("[red]Please enter a valid model name or index[/red]")
            continue
        if 0 <= index < len(available_models):
            return available_models[index]["id"]
        console.print(
            f"[red]Invalid index. Please enter a number between 1 and {len(available_models)}[/red]"
        )
@click.command()
@click.argument("model", required=False)
@click.option(
    "--temperature",
    "-t",
    type=float,
    default=0.7,
    help="Sampling temperature between 0 and 2 (default: 0.7)",
)
@click.option(
    "--max-tokens",
    type=int,
    help="Maximum number of tokens to generate",
)
@click.option(
    "--system",
    "-s",
    type=str,
    help="System message to set the behavior of the assistant",
)
@click.pass_context
def chat(
    ctx: click.Context,
    model: Optional[str],
    temperature: float,
    max_tokens: Optional[int] = None,
    system: Optional[str] = None,
):
    """Interactive chat with streaming responses

    Examples:
        # Chat with a specific model
        litellm-proxy chat gpt-4
        # Chat without specifying model (will show model selection)
        litellm-proxy chat
        # Chat with custom settings
        litellm-proxy chat gpt-4 --temperature 0.9 --system "You are a helpful coding assistant"
    """
    console = Console()
    # If no model specified, prompt the user to pick one interactively.
    if not model:
        available_models = _get_available_models(ctx)
        model = _select_model(console, available_models)
        if not model:
            console.print("[red]No model selected. Exiting.[/red]")
            return
    client = ChatClient(ctx.obj["base_url"], ctx.obj["api_key"])
    # Conversation history is resent with every request so the model keeps context.
    messages: List[Dict[str, Any]] = []
    # Add system message if provided
    if system:
        messages.append({"role": "system", "content": system})
    # Display welcome banner showing the active session settings.
    console.print(
        Panel.fit(
            f"[bold blue]LiteLLM Interactive Chat[/bold blue]\n"
            f"Model: [green]{model}[/green]\n"
            f"Temperature: [yellow]{temperature}[/yellow]\n"
            f"Max Tokens: [yellow]{max_tokens or 'unlimited'}[/yellow]\n\n"
            f"Type your messages and press Enter. Type '/quit' or '/exit' to end the session.\n"
            f"Type '/help' for more commands.",
            title="🤖 Chat Session",
        )
    )
    try:
        while True:
            # Read one user line; EOF/Ctrl+C ends the session cleanly.
            try:
                user_input = console.input("\n[bold cyan]You:[/bold cyan] ").strip()
            except (EOFError, KeyboardInterrupt):
                console.print("\n[yellow]Chat session ended.[/yellow]")
                break
            # Slash commands (help/history/save/...) are handled out-of-band;
            # they may replace the history and/or switch the active model.
            should_exit, messages, new_model = _handle_special_commands(
                console, user_input, messages, system, ctx
            )
            if should_exit:
                break
            if new_model:
                model = new_model
            # Skip the completion round-trip when the input was a handled
            # special command or an empty line.
            if (
                user_input.lower().startswith(
                    (
                        "/quit",
                        "/exit",
                        "/q",
                        "/help",
                        "/clear",
                        "/history",
                        "/save",
                        "/load",
                        "/model",
                    )
                )
                or not user_input
            ):
                continue
            # Add user message to conversation
            messages.append({"role": "user", "content": user_input})
            # Display assistant label
            console.print("\n[bold green]Assistant:[/bold green]")
            # Stream tokens to the terminal and collect the full reply.
            assistant_content = _stream_response(
                console=console,
                client=client,
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Keep the assistant turn in history so follow-ups have context.
            if assistant_content:
                messages.append({"role": "assistant", "content": assistant_content})
            else:
                console.print("[red]Error: No content received from the model[/red]")
    except KeyboardInterrupt:
        console.print("\n[yellow]Chat session interrupted.[/yellow]")
def _show_help(console: Console):
    """Show help for interactive chat commands"""
    # Rich markup string; [cyan]/[bold] tags are rendered by the Panel below.
    help_text = """
[bold]Interactive Chat Commands:[/bold]
[cyan]/help[/cyan] - Show this help message
[cyan]/quit[/cyan] - Exit the chat session (also /exit, /q)
[cyan]/clear[/cyan] - Clear conversation history
[cyan]/history[/cyan] - Show conversation history
[cyan]/model[/cyan] - Switch to a different model
[cyan]/save <name>[/cyan] - Save conversation to file
[cyan]/load <name>[/cyan] - Load conversation from file
[bold]Tips:[/bold]
- Your conversation history is maintained during the session
- Use Ctrl+C to interrupt at any time
- Responses are streamed in real-time
- You can switch models mid-conversation with /model
"""
    console.print(Panel(help_text, title="Help"))
def _show_history(console: Console, messages: List[Dict[str, Any]]):
    """Print the running conversation, numbering each turn.

    Assistant turns are truncated to 100 characters to keep the listing short.
    """
    if not messages:
        console.print("[yellow]No conversation history.[/yellow]")
        return
    console.print(Panel.fit("[bold]Conversation History[/bold]", title="History"))
    for idx, msg in enumerate(messages, 1):
        role, content = msg["role"], msg["content"]
        if role == "system":
            console.print(
                f"[dim]{idx}. [bold magenta]System:[/bold magenta] {content}[/dim]"
            )
        elif role == "user":
            console.print(f"{idx}. [bold cyan]You:[/bold cyan] {content}")
        elif role == "assistant":
            snippet = content[:100] + ("..." if len(content) > 100 else "")
            console.print(f"{idx}. [bold green]Assistant:[/bold green] {snippet}")
def _save_conversation(console: Console, messages: List[Dict[str, Any]], command: str):
    """Save the conversation history to a JSON file.

    ``command`` is the raw "/save <filename>" input; a ".json" suffix is
    appended when missing. Failures are reported on the console, not raised.
    """
    parts = command.split()
    if len(parts) < 2:
        console.print("[red]Usage: /save <filename>[/red]")
        return
    filename = parts[1]
    if not filename.endswith(".json"):
        filename += ".json"
    try:
        with open(filename, "w") as f:
            json.dump(messages, f, indent=2)
        # Fix: interpolate the actual filename (the message previously did not
        # include it, so the user could not tell where the file was written).
        console.print(f"[green]Conversation saved to {filename}[/green]")
    except Exception as e:
        console.print(f"[red]Error saving conversation: {e}[/red]")
def _load_conversation(
    console: Console, command: str, system: Optional[str]
) -> List[Dict[str, Any]]:
    """Load a conversation history from a JSON file.

    ``command`` is the raw "/load <filename>" input; a ".json" suffix is
    appended when missing. On failure, returns a fresh history containing just
    the system message (if one was configured) instead of raising.
    """
    parts = command.split()
    if len(parts) < 2:
        console.print("[red]Usage: /load <filename>[/red]")
        return []
    filename = parts[1]
    if not filename.endswith(".json"):
        filename += ".json"
    try:
        with open(filename, "r") as f:
            messages = json.load(f)
        # Fix: interpolate the actual filename (these messages previously did
        # not include it, making the output useless for diagnosis).
        console.print(f"[green]Conversation loaded from {filename}[/green]")
        return messages
    except FileNotFoundError:
        console.print(f"[red]File not found: {filename}[/red]")
    except Exception as e:
        console.print(f"[red]Error loading conversation: {e}[/red]")
    # Load failed: fall back to an empty history, preserving the system prompt.
    if system:
        return [{"role": "system", "content": system}]
    return []
def _handle_special_commands(
    console: Console,
    user_input: str,
    messages: List[Dict[str, Any]],
    system: Optional[str],
    ctx: click.Context,
) -> tuple[bool, List[Dict[str, Any]], Optional[str]]:
    """Handle special chat commands.

    Returns (should_exit, updated_messages, updated_model). ``updated_model``
    is non-None only when the user switched models via /model.
    """
    if user_input.lower() in ["/quit", "/exit", "/q"]:
        console.print("[yellow]Chat session ended.[/yellow]")
        return True, messages, None
    elif user_input.lower() == "/help":
        _show_help(console)
        return False, messages, None
    elif user_input.lower() == "/clear":
        # Reset history, re-seeding the system prompt if one was configured.
        new_messages = []
        if system:
            new_messages.append({"role": "system", "content": system})
        console.print("[green]Conversation history cleared.[/green]")
        return False, new_messages, None
    elif user_input.lower() == "/history":
        _show_history(console, messages)
        return False, messages, None
    elif user_input.lower().startswith("/save"):
        _save_conversation(console, messages, user_input)
        return False, messages, None
    elif user_input.lower().startswith("/load"):
        # /load replaces the current history with the file's contents.
        new_messages = _load_conversation(console, user_input, system)
        return False, new_messages, None
    elif user_input.lower() == "/model":
        # Re-run interactive model selection; history is kept across models.
        available_models = _get_available_models(ctx)
        new_model = _select_model(console, available_models)
        if new_model:
            console.print(f"[green]Switched to model: {new_model}[/green]")
            return False, messages, new_model
        return False, messages, None
    elif not user_input:
        return False, messages, None
    # Not a special command: caller should treat it as a normal chat message.
    return False, messages, None
def _stream_response(
    console: Console,
    client: ChatClient,
    model: str,
    messages: List[Dict[str, Any]],
    temperature: float,
    max_tokens: Optional[int],
) -> Optional[str]:
    """Stream the model response and return the complete content.

    Prints each streamed delta as it arrives; returns None when no content was
    received or when the request failed (the error is printed, not raised).
    """
    try:
        assistant_content = ""
        for chunk in client.completions_stream(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        ):
            # Each streamed chunk carries an incremental delta; print it
            # immediately and accumulate it for the conversation history.
            if "choices" in chunk and len(chunk["choices"]) > 0:
                delta = chunk["choices"][0].get("delta", {})
                content = delta.get("content", "")
                if content:
                    assistant_content += content
                    console.print(content, end="")
                    sys.stdout.flush()
        console.print()  # Add newline after streaming
        return assistant_content if assistant_content else None
    except requests.exceptions.HTTPError as e:
        # Prefer the structured error message from the server when available.
        console.print(f"\n[red]Error: HTTP {e.response.status_code}[/red]")
        try:
            error_body = e.response.json()
            console.print(
                f"[red]{error_body.get('error', {}).get('message', 'Unknown error')}[/red]"
            )
        except json.JSONDecodeError:
            console.print(f"[red]{e.response.text}[/red]")
        return None
    except Exception as e:
        console.print(f"\n[red]Error: {str(e)}[/red]")
        return None

View File

@@ -0,0 +1,116 @@
import json
from typing import Literal
import click
import rich
import requests
from rich.table import Table
from ...credentials import CredentialsManagementClient
@click.group()
def credentials():
    """Manage credentials for the LiteLLM proxy server"""
    # Container group only; subcommands (list/create/delete/get) do the work.
    pass
@credentials.command()
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (table or json)",
)
@click.pass_context
def list(ctx: click.Context, output_format: Literal["table", "json"]):
    """List all credentials"""
    client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    response = client.list()
    assert isinstance(response, dict)
    if output_format == "json":
        rich.print_json(data=response)
    else:  # table format
        table = Table(title="Credentials")
        # Add columns
        table.add_column("Credential Name", style="cyan")
        table.add_column("Custom LLM Provider", style="green")
        # Only name + provider metadata are displayed in the table view.
        for cred in response.get("credentials", []):
            info = cred.get("credential_info", {})
            table.add_row(
                str(cred.get("credential_name", "")),
                str(info.get("custom_llm_provider", "")),
            )
        rich.print(table)
@credentials.command()
@click.argument("credential_name")
@click.option(
    "--info",
    type=str,
    help="JSON string containing credential info",
    required=True,
)
@click.option(
    "--values",
    type=str,
    help="JSON string containing credential values",
    required=True,
)
@click.pass_context
def create(ctx: click.Context, credential_name: str, info: str, values: str):
    """Create a new credential"""
    client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    # Both payloads arrive as JSON strings on the command line.
    try:
        parsed_info = json.loads(info)
        parsed_values = json.loads(values)
    except json.JSONDecodeError as e:
        raise click.BadParameter(f"Invalid JSON: {str(e)}")
    try:
        result = client.create(credential_name, parsed_info, parsed_values)
        rich.print_json(data=result)
    except requests.exceptions.HTTPError as e:
        # Surface the server's error body when possible, then abort non-zero.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            rich.print_json(data=e.response.json())
        except json.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
@credentials.command()
@click.argument("credential_name")
@click.pass_context
def delete(ctx: click.Context, credential_name: str):
    """Delete a credential by name"""
    client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    try:
        result = client.delete(credential_name)
        rich.print_json(data=result)
    except requests.exceptions.HTTPError as e:
        # Show the HTTP status and the server-provided error body, then abort.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            rich.print_json(data=e.response.json())
        except json.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
@credentials.command()
@click.argument("credential_name")
@click.pass_context
def get(ctx: click.Context, credential_name: str):
    """Get a credential by name"""
    base_url, api_key = ctx.obj["base_url"], ctx.obj["api_key"]
    client = CredentialsManagementClient(base_url, api_key)
    rich.print_json(data=client.get(credential_name))

View File

@@ -0,0 +1,102 @@
import json as json_lib
from typing import Optional
import click
import rich
import requests
from ...http_client import HTTPClient
@click.group()
def http():
    """Make HTTP requests to the LiteLLM proxy server"""
    # Container group only; the `request` subcommand does the work.
    pass
@http.command()
@click.argument("method")
@click.argument("uri")
@click.option(
    "--data",
    "-d",
    type=str,
    help="Data to send in the request body (as JSON string)",
)
@click.option(
    "--json",
    "-j",
    type=str,
    help="JSON data to send in the request body (as JSON string)",
)
@click.option(
    "--header",
    "-H",
    multiple=True,
    help="HTTP headers in 'key:value' format. Can be specified multiple times.",
)
@click.pass_context
def request(
    ctx: click.Context,
    method: str,
    uri: str,
    data: Optional[str] = None,
    json: Optional[str] = None,
    header: tuple[str, ...] = (),
):
    """Make an HTTP request to the LiteLLM proxy server

    METHOD: HTTP method (GET, POST, PUT, DELETE, etc.)
    URI: URI path (will be appended to base_url)

    Examples:
        litellm http request GET /models
        litellm http request POST /chat/completions -j '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}'
        litellm http request GET /health/test_connection -H "X-Custom-Header:value"
    """
    # Parse repeated -H options; each must be "key:value" (split on the first
    # colon so header values may themselves contain colons).
    headers = {}
    for h in header:
        try:
            key, value = h.split(":", 1)
            headers[key.strip()] = value.strip()
        except ValueError:
            raise click.BadParameter(
                f"Invalid header format: {h}. Expected format: 'key:value'"
            )
    # NOTE: the local name `json` is the CLI option (a string), which is why
    # the json module is imported as json_lib in this file.
    json_data = None
    if json:
        try:
            json_data = json_lib.loads(json)
        except ValueError as e:
            raise click.BadParameter(f"Invalid JSON format: {e}")
    # --data accepts JSON when it parses; otherwise it is sent as raw text.
    request_data = None
    if data:
        try:
            request_data = json_lib.loads(data)
        except ValueError:
            # If not JSON, use as raw data
            request_data = data
    client = HTTPClient(ctx.obj["base_url"], ctx.obj["api_key"])
    try:
        response = client.request(
            method=method,
            uri=uri,
            data=request_data,
            json=json_data,
            headers=headers,
        )
        rich.print_json(data=response)
    except requests.exceptions.HTTPError as e:
        # Show the HTTP status and the server-provided error body, then abort.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except json_lib.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()

View File

@@ -0,0 +1,415 @@
import json
from datetime import datetime
from typing import Literal, Optional, List, Dict, Any
import click
import rich
import requests
from rich.table import Table
from ...keys import KeysManagementClient
@click.group()
def keys():
    """Manage API keys for the LiteLLM proxy server"""
    # Container group only; subcommands (list/generate/delete/import) do the work.
    pass
@keys.command()
@click.option("--page", type=int, help="Page number for pagination")
@click.option("--size", type=int, help="Number of items per page")
@click.option("--user-id", type=str, help="Filter keys by user ID")
@click.option("--team-id", type=str, help="Filter keys by team ID")
@click.option("--organization-id", type=str, help="Filter keys by organization ID")
@click.option("--key-hash", type=str, help="Filter by specific key hash")
@click.option("--key-alias", type=str, help="Filter by key alias")
@click.option(
    "--return-full-object",
    is_flag=True,
    default=True,
    help="Return the full key object",
)
@click.option(
    "--include-team-keys", is_flag=True, help="Include team keys in the response"
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (table or json)",
)
@click.pass_context
def list(
    ctx: click.Context,
    page: Optional[int],
    size: Optional[int],
    user_id: Optional[str],
    team_id: Optional[str],
    organization_id: Optional[str],
    key_hash: Optional[str],
    key_alias: Optional[str],
    include_team_keys: bool,
    output_format: Literal["table", "json"],
    return_full_object: bool,
):
    """List all API keys"""
    client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    # All filter options are forwarded unchanged to the management client.
    response = client.list(
        page=page,
        size=size,
        user_id=user_id,
        team_id=team_id,
        organization_id=organization_id,
        key_hash=key_hash,
        key_alias=key_alias,
        return_full_object=return_full_object,
        include_team_keys=include_team_keys,
    )
    assert isinstance(response, dict)
    if output_format == "json":
        rich.print_json(data=response)
    else:
        # Table view: summary line first, then one row per key.
        rich.print(
            f"Showing {len(response.get('keys', []))} keys out of {response.get('total_count', 0)}"
        )
        table = Table(title="API Keys")
        table.add_column("Key Hash", style="cyan")
        table.add_column("Alias", style="green")
        table.add_column("User ID", style="magenta")
        table.add_column("Team ID", style="yellow")
        table.add_column("Spend", style="red")
        for key in response.get("keys", []):
            table.add_row(
                str(key.get("token", "")),
                str(key.get("key_alias", "")),
                str(key.get("user_id", "")),
                str(key.get("team_id", "")),
                str(key.get("spend", "")),
            )
        rich.print(table)
@keys.command()
@click.option("--models", type=str, help="Comma-separated list of allowed models")
@click.option("--aliases", type=str, help="JSON string of model alias mappings")
@click.option("--spend", type=float, help="Maximum spend limit for this key")
@click.option(
    "--duration",
    type=str,
    help="Duration for which the key is valid (e.g. '24h', '7d')",
)
@click.option("--key-alias", type=str, help="Alias/name for the key")
@click.option("--team-id", type=str, help="Team ID to associate the key with")
@click.option("--user-id", type=str, help="User ID to associate the key with")
@click.option("--budget-id", type=str, help="Budget ID to associate the key with")
@click.option(
    "--config", type=str, help="JSON string of additional configuration parameters"
)
@click.pass_context
def generate(
    ctx: click.Context,
    models: Optional[str],
    aliases: Optional[str],
    spend: Optional[float],
    duration: Optional[str],
    key_alias: Optional[str],
    team_id: Optional[str],
    user_id: Optional[str],
    budget_id: Optional[str],
    config: Optional[str],
):
    """Generate a new API key"""
    client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    # Decode CLI string inputs: comma-list for models, JSON for aliases/config.
    try:
        models_list = [m.strip() for m in models.split(",")] if models else None
        aliases_dict = json.loads(aliases) if aliases else None
        config_dict = json.loads(config) if config else None
    except json.JSONDecodeError as e:
        raise click.BadParameter(f"Invalid JSON: {str(e)}")
    try:
        response = client.generate(
            models=models_list,
            aliases=aliases_dict,
            spend=spend,
            duration=duration,
            key_alias=key_alias,
            team_id=team_id,
            user_id=user_id,
            budget_id=budget_id,
            config=config_dict,
        )
        rich.print_json(data=response)
    except requests.exceptions.HTTPError as e:
        # Show the HTTP status plus the server's error body (JSON if possible).
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except json.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
@keys.command()
@click.option("--keys", type=str, help="Comma-separated list of API keys to delete")
@click.option(
    "--key-aliases", type=str, help="Comma-separated list of key aliases to delete"
)
@click.pass_context
def delete(ctx: click.Context, keys: Optional[str], key_aliases: Optional[str]):
    """Delete API keys by key or alias"""
    client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])

    # Turn a comma-separated CLI string into a trimmed list (or None).
    def _split_csv(raw: Optional[str]):
        return [item.strip() for item in raw.split(",")] if raw else None

    try:
        result = client.delete(
            keys=_split_csv(keys), key_aliases=_split_csv(key_aliases)
        )
        rich.print_json(data=result)
    except requests.exceptions.HTTPError as e:
        # Show the HTTP status and the server-provided error body, then abort.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            rich.print_json(data=e.response.json())
        except json.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
def _parse_created_since_filter(created_since: Optional[str]) -> Optional[datetime]:
"""Parse and validate the created_since date filter."""
if not created_since:
return None
try:
# Support formats: YYYY-MM-DD_HH:MM or YYYY-MM-DD
if "_" in created_since:
return datetime.strptime(created_since, "%Y-%m-%d_%H:%M")
else:
return datetime.strptime(created_since, "%Y-%m-%d")
except ValueError:
click.echo(
f"Error: Invalid date format '{created_since}'. Use YYYY-MM-DD_HH:MM or YYYY-MM-DD",
err=True,
)
raise click.Abort()
def _fetch_all_keys_with_pagination(
    source_client: KeysManagementClient, source_base_url: str
) -> List[Dict[str, Any]]:
    """Fetch every key from the source instance, paging until exhausted."""
    click.echo(f"Fetching keys from source server: {source_base_url}")
    collected: List[Dict[str, Any]] = []
    page_size = 100  # large pages keep the number of round-trips low
    page = 1
    while True:
        resp = source_client.list(return_full_object=True, page=page, size=page_size)
        # list() returns the parsed payload dict when return_request is False.
        assert isinstance(resp, dict), "Expected dict response from list API"
        batch = resp.get("keys", [])
        if not batch:
            break
        collected.extend(batch)
        click.echo(f"Fetched page {page}: {len(batch)} keys")
        # A short page means this was the final one.
        if len(batch) < page_size:
            break
        page += 1
    return collected
def _filter_keys_by_created_since(
    source_keys: List[Dict[str, Any]],
    created_since_dt: Optional[datetime],
    created_since: str,
) -> List[Dict[str, Any]]:
    """Filter keys to those created at or after ``created_since_dt``.

    Keys without a usable ``created_at`` are excluded. When no cutoff datetime
    is given, the input list is returned unchanged.
    """
    if not created_since_dt:
        return source_keys
    filtered_keys = []
    for key in source_keys:
        key_created_at = key.get("created_at")
        if not key_created_at:
            continue
        # Normalize created_at to a naive datetime (assumed UTC) so it can be
        # compared against the user-supplied cutoff.
        if isinstance(key_created_at, str):
            if "T" in key_created_at:
                key_dt = datetime.fromisoformat(
                    key_created_at.replace("Z", "+00:00")
                )
            else:
                key_dt = datetime.fromisoformat(key_created_at)
        elif isinstance(key_created_at, datetime):
            # Fix: a non-string created_at previously left key_dt unassigned,
            # raising UnboundLocalError on the comparison below.
            key_dt = key_created_at
        else:
            continue  # unknown type: skip rather than crash
        if key_dt.tzinfo:
            key_dt = key_dt.replace(tzinfo=None)
        if key_dt >= created_since_dt:
            filtered_keys.append(key)
    click.echo(
        f"Filtered {len(source_keys)} keys to {len(filtered_keys)} keys created since {created_since}"
    )
    return filtered_keys
def _display_dry_run_table(source_keys: List[Dict[str, Any]]) -> None:
    """Display a table of keys that would be imported in dry-run mode."""
    click.echo("\n--- DRY RUN MODE ---")
    table = Table(title="Keys that would be imported")
    table.add_column("Key Alias", style="green")
    table.add_column("User ID", style="magenta")
    table.add_column("Created", style="cyan")
    for key in source_keys:
        created_at = key.get("created_at", "")
        # Render ISO-8601 timestamps as "YYYY-MM-DD HH:MM" for readability;
        # any other value is shown unchanged.
        if created_at:
            if isinstance(created_at, str):
                if "T" in created_at:
                    dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
                    created_at = dt.strftime("%Y-%m-%d %H:%M")
        table.add_row(
            str(key.get("key_alias", "")), str(key.get("user_id", "")), str(created_at)
        )
    rich.print(table)
def _prepare_key_import_data(key: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare key data for import by extracting relevant fields."""
import_data = {}
# Copy relevant fields if they exist
for field in [
"models",
"aliases",
"spend",
"key_alias",
"team_id",
"user_id",
"budget_id",
"config",
]:
if key.get(field):
import_data[field] = key[field]
return import_data
def _import_keys_to_destination(
    source_keys: List[Dict[str, Any]], dest_client: KeysManagementClient
) -> tuple[int, int]:
    """Recreate each exported key on the destination proxy.

    Returns (imported_count, failed_count); failures are logged and skipped so
    one bad key does not stop the whole import.
    """
    imported = 0
    failed = 0
    for key in source_keys:
        try:
            payload = _prepare_key_import_data(key)
            # generate() returns the parsed JSON payload, not a Response.
            response = dest_client.generate(**payload)
            click.echo(f"Generated key: {response}")
            imported += 1
            click.echo(f"✓ Imported key: {key.get('key_alias', 'N/A')}")
        except Exception as e:
            failed += 1
            click.echo(
                f"✗ Failed to import key {key.get('key_alias', 'N/A')}: {str(e)}",
                err=True,
            )
    return imported, failed
@keys.command(name="import")
@click.option(
    "--source-base-url",
    required=True,
    help="Base URL of the source LiteLLM proxy server to import keys from",
)
@click.option(
    "--source-api-key", help="API key for authentication to the source server"
)
@click.option(
    "--dry-run",
    is_flag=True,
    help="Show what would be imported without actually importing",
)
@click.option(
    "--created-since",
    help="Only import keys created after this date/time (format: YYYY-MM-DD_HH:MM or YYYY-MM-DD)",
)
@click.pass_context
def import_keys(
    ctx: click.Context,
    source_base_url: str,
    source_api_key: Optional[str],
    dry_run: bool,
    created_since: Optional[str],
):
    """Import API keys from another LiteLLM instance"""
    # Validate/parse the optional cutoff first so bad input fails fast.
    created_since_dt = _parse_created_since_filter(created_since)
    # Source = instance keys are read from; destination = the configured proxy.
    source_client = KeysManagementClient(source_base_url, source_api_key)
    dest_client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    try:
        # Get all keys from source instance with pagination
        source_keys = _fetch_all_keys_with_pagination(source_client, source_base_url)
        # Filter keys by created_since if specified
        if created_since:
            source_keys = _filter_keys_by_created_since(
                source_keys, created_since_dt, created_since
            )
        if not source_keys:
            click.echo("No keys found in source instance.")
            return
        click.echo(f"Found {len(source_keys)} keys in source instance.")
        if dry_run:
            # Preview only: show the table and make no changes.
            _display_dry_run_table(source_keys)
            return
        # Import each key
        imported_count, failed_count = _import_keys_to_destination(
            source_keys, dest_client
        )
        # Summary
        click.echo("\nImport completed:")
        click.echo(f"  Successfully imported: {imported_count}")
        click.echo(f"  Failed to import: {failed_count}")
        click.echo(f"  Total keys processed: {len(source_keys)}")
    except requests.exceptions.HTTPError as e:
        # Show the HTTP status and the server-provided error body, then abort.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except json.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
    except Exception as e:
        click.echo(f"Error: {str(e)}", err=True)
        raise click.Abort()

View File

@@ -0,0 +1,485 @@
# stdlib imports
from datetime import datetime
import re
from typing import Optional, Literal, Any
import yaml
from dataclasses import dataclass
from collections import defaultdict
# third party imports
import click
import rich
# local imports
from ... import Client
@dataclass
class ModelYamlInfo:
    """Normalized view of one model entry from a LiteLLM config YAML."""

    model_name: str
    model_params: dict[str, Any]
    model_info: dict[str, Any]
    model_id: str
    access_groups: list[str]
    provider: str

    @property
    def access_groups_str(self) -> str:
        """Comma-separated access groups; empty string when there are none."""
        if not self.access_groups:
            return ""
        return ", ".join(self.access_groups)
def _get_model_info_obj_from_yaml(model: dict[str, Any]) -> ModelYamlInfo:
    """Convert one raw YAML model entry into a ModelYamlInfo dataclass."""
    params: dict[str, Any] = model["litellm_params"]
    info: dict[str, Any] = model.get("model_info", {})
    upstream: str = params["model"]
    # The provider is the prefix before "/" (e.g. "openai/gpt-4" -> "openai");
    # a bare model string is used verbatim.
    return ModelYamlInfo(
        model_name=model["model_name"],
        model_params=params,
        model_info=info,
        model_id=upstream,
        access_groups=info.get("access_groups", []),
        provider=upstream.split("/", 1)[0] if "/" in upstream else upstream,
    )
def format_iso_datetime_str(iso_datetime_str: Optional[str]) -> str:
    """Render an ISO-8601 string as "YYYY-MM-DD HH:MM"; pass junk through."""
    if not iso_datetime_str:
        return ""
    try:
        # fromisoformat() does not accept the "Z" suffix on older Pythons,
        # so normalize it to an explicit UTC offset first.
        normalized = iso_datetime_str.replace("Z", "+00:00")
        return datetime.fromisoformat(normalized).strftime("%Y-%m-%d %H:%M")
    except (TypeError, ValueError):
        return str(iso_datetime_str)
def format_timestamp(timestamp: Optional[int]) -> str:
    """Render a Unix timestamp as "YYYY-MM-DD HH:MM" (local time).

    None yields ""; unconvertible values are returned as their str() form.
    """
    if timestamp is None:
        return ""
    try:
        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M")
    except (TypeError, ValueError):
        return str(timestamp)
def format_cost_per_1k_tokens(cost: Optional[float]) -> str:
    """Render a per-token cost as dollars per 1000 tokens (4 decimals).

    Accepts numeric strings too; unconvertible values come back as str(cost).
    """
    if cost is None:
        return ""
    try:
        per_thousand = float(cost) * 1000
        return f"${per_thousand:.4f}"
    except (TypeError, ValueError):
        return str(cost)
def create_client(ctx: click.Context) -> Client:
    """Build an API Client from the CLI context's base URL and API key."""
    cfg = ctx.obj
    return Client(base_url=cfg["base_url"], api_key=cfg["api_key"])
@click.group()
def models() -> None:
    """Manage models on your LiteLLM proxy server"""
    # Container group only; see subcommands: list, add, delete, get, info.
    pass
@models.command("list")
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (table or json)",
)
@click.pass_context
def list_models(ctx: click.Context, output_format: Literal["table", "json"]) -> None:
    """List all available models"""
    client = create_client(ctx)
    models_list = client.models.list()
    assert isinstance(models_list, list)
    if output_format == "json":
        rich.print_json(data=models_list)
    else:  # table format
        table = rich.table.Table(title="Available Models")
        # Add columns based on the data structure
        table.add_column("ID", style="cyan")
        table.add_column("Object", style="green")
        table.add_column("Created", style="magenta")
        table.add_column("Owned By", style="yellow")
        # Add rows
        for model in models_list:
            created = model.get("created")
            # "created" may arrive as a numeric string; normalize to int so
            # the Unix-timestamp formatter can be used below.
            if isinstance(created, str) and created.isdigit():
                created = int(created)
            table.add_row(
                str(model.get("id", "")),
                str(model.get("object", "model")),
                format_timestamp(created)
                if isinstance(created, int)
                else format_iso_datetime_str(created),
                str(model.get("owned_by", "")),
            )
        rich.print(table)
@models.command("add")
@click.argument("model-name")
@click.option(
    "--param",
    "-p",
    multiple=True,
    help="Model parameters in key=value format (can be specified multiple times)",
)
@click.option(
    "--info",
    "-i",
    multiple=True,
    help="Model info in key=value format (can be specified multiple times)",
)
@click.pass_context
def add_model(
    ctx: click.Context, model_name: str, param: tuple[str, ...], info: tuple[str, ...]
) -> None:
    """Add a new model to the proxy"""
    # Convert key=value pairs to dicts; report malformed entries as a clean
    # CLI error instead of an unhandled ValueError traceback (consistent with
    # how the other commands surface bad --aliases/--config JSON).
    try:
        model_params = dict(p.split("=", 1) for p in param)
        model_info = dict(i.split("=", 1) for i in info) if info else None
    except ValueError:
        raise click.BadParameter("Parameters must be in key=value format")
    client = create_client(ctx)
    result = client.models.new(
        model_name=model_name,
        model_params=model_params,
        model_info=model_info,
    )
    rich.print_json(data=result)
@models.command("delete")
@click.argument("model-id")
@click.pass_context
def delete_model(ctx: click.Context, model_id: str) -> None:
    """Delete a model from the proxy"""
    api = create_client(ctx)
    rich.print_json(data=api.models.delete(model_id=model_id))
@models.command("get")
@click.option("--id", "model_id", help="ID of the model to retrieve")
@click.option("--name", "model_name", help="Name of the model to retrieve")
@click.pass_context
def get_model(
    ctx: click.Context, model_id: Optional[str], model_name: Optional[str]
) -> None:
    """Get information about a specific model"""
    # At least one selector is required.
    if not (model_id or model_name):
        raise click.UsageError("Either --id or --name must be provided")
    api = create_client(ctx)
    rich.print_json(data=api.models.get(model_id=model_id, model_name=model_name))
@models.command("info")
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (table or json)",
)
@click.option(
    "--columns",
    "columns",
    default="public_model,upstream_model,updated_at",
    help="Comma-separated list of columns to display. Valid columns: public_model, upstream_model, credential_name, created_at, updated_at, id, input_cost, output_cost. Default: public_model,upstream_model,updated_at",
)
@click.pass_context
def get_models_info(
    ctx: click.Context, output_format: Literal["table", "json"], columns: str
) -> None:
    """Get detailed information about all models.

    In ``json`` mode the raw server payload is printed as-is. In ``table``
    mode only the columns requested via --columns are rendered; unknown
    column names produce a warning on stderr but do not fail the command.
    """
    client = create_client(ctx)
    models_info = client.models.info()
    # The info endpoint is expected to return a list of model dicts.
    assert isinstance(models_info, list)
    if output_format == "json":
        rich.print_json(data=models_info)
    else:  # table format
        table = rich.table.Table(title="Models Information")
        # Define all possible columns with their configurations.
        # Each entry maps a column key to its header text, rich style, optional
        # justification, and a getter that extracts the cell value from one
        # model dict (missing keys render as empty strings).
        column_configs: dict[str, dict[str, Any]] = {
            "public_model": {
                "header": "Public Model",
                "style": "cyan",
                "get_value": lambda m: str(m.get("model_name", "")),
            },
            "upstream_model": {
                "header": "Upstream Model",
                "style": "green",
                "get_value": lambda m: str(
                    m.get("litellm_params", {}).get("model", "")
                ),
            },
            "credential_name": {
                "header": "Credential Name",
                "style": "yellow",
                "get_value": lambda m: str(
                    m.get("litellm_params", {}).get("litellm_credential_name", "")
                ),
            },
            "created_at": {
                "header": "Created At",
                "style": "magenta",
                "get_value": lambda m: format_iso_datetime_str(
                    m.get("model_info", {}).get("created_at")
                ),
            },
            "updated_at": {
                "header": "Updated At",
                "style": "magenta",
                "get_value": lambda m: format_iso_datetime_str(
                    m.get("model_info", {}).get("updated_at")
                ),
            },
            "id": {
                "header": "ID",
                "style": "blue",
                "get_value": lambda m: str(m.get("model_info", {}).get("id", "")),
            },
            "input_cost": {
                "header": "Input Cost",
                "style": "green",
                "justify": "right",
                "get_value": lambda m: format_cost_per_1k_tokens(
                    m.get("model_info", {}).get("input_cost_per_token")
                ),
            },
            "output_cost": {
                "header": "Output Cost",
                "style": "green",
                "justify": "right",
                "get_value": lambda m: format_cost_per_1k_tokens(
                    m.get("model_info", {}).get("output_cost_per_token")
                ),
            },
        }
        # Add requested columns, preserving the order the user specified.
        requested_columns = [col.strip() for col in columns.split(",")]
        for col_name in requested_columns:
            if col_name in column_configs:
                config = column_configs[col_name]
                table.add_column(
                    config["header"],
                    style=config["style"],
                    justify=config.get("justify", "left"),
                )
            else:
                # Warn but continue; unknown names are simply skipped below too.
                click.echo(f"Warning: Unknown column '{col_name}'", err=True)
        # Add rows with only the requested (and known) columns.
        for model in models_info:
            row_values = []
            for col_name in requested_columns:
                if col_name in column_configs:
                    row_values.append(column_configs[col_name]["get_value"](model))
            if row_values:
                table.add_row(*row_values)
        rich.print(table)
@models.command("update")
@click.argument("model-id")
@click.option(
    "--param",
    "-p",
    multiple=True,
    help="Model parameters in key=value format (can be specified multiple times)",
)
@click.option(
    "--info",
    "-i",
    multiple=True,
    help="Model info in key=value format (can be specified multiple times)",
)
@click.pass_context
def update_model(
    ctx: click.Context, model_id: str, param: tuple[str, ...], info: tuple[str, ...]
) -> None:
    """Update an existing model's configuration."""
    # Each repeated option arrives as "key=value"; split on the first '='
    # only, so values may themselves contain '=' characters.
    params_payload = {k: v for k, v in (p.split("=", 1) for p in param)}
    info_payload = {k: v for k, v in (i.split("=", 1) for i in info)} if info else None
    proxy_client = create_client(ctx)
    update_result = proxy_client.models.update(
        model_id=model_id,
        model_params=params_payload,
        model_info=info_payload,
    )
    rich.print_json(data=update_result)
def _filter_model(model, model_regex, access_group_regex):
model_name = model.get("model_name")
model_params = model.get("litellm_params")
model_info = model.get("model_info", {})
if not model_name or not model_params:
return False
model_id = model_params.get("model")
if not model_id or not isinstance(model_id, str):
return False
if model_regex and not model_regex.search(model_id):
return False
access_groups = model_info.get("access_groups", [])
if access_group_regex:
if not isinstance(access_groups, list):
return False
if not any(
isinstance(group, str) and access_group_regex.search(group)
for group in access_groups
):
return False
return True
def _print_models_table(added_models: list[ModelYamlInfo], table_title: str):
    """Render the given models as a rich table; no-op for an empty list."""
    if not added_models:
        return
    models_table = rich.table.Table(title=table_title)
    # Three fixed columns, each with its own style.
    for header, style in (
        ("Model Name", "cyan"),
        ("Upstream Model", "green"),
        ("Access Groups", "magenta"),
    ):
        models_table.add_column(header, style=style)
    for entry in added_models:
        models_table.add_row(entry.model_name, entry.model_id, entry.access_groups_str)
    rich.print(models_table)
def _print_summary_table(provider_counts):
    """Print a per-provider count table followed by a bold grand total row."""
    summary = rich.table.Table(title="Model Import Summary")
    summary.add_column("Provider", style="cyan")
    summary.add_column("Count", style="green")
    # Accumulate the total while emitting the per-provider rows.
    grand_total = 0
    for provider, count in provider_counts.items():
        summary.add_row(str(provider), str(count))
        grand_total += count
    summary.add_row("[bold]Total[/bold]", f"[bold]{grand_total}[/bold]")
    rich.print(summary)
def get_model_list_from_yaml_file(yaml_file: str) -> list[dict[str, Any]]:
    """Load and validate the model list from a YAML file.

    Raises click.ClickException when the file lacks a 'model_list' key or
    when that key does not hold a list.
    """
    with open(yaml_file, "r") as handle:
        parsed = yaml.safe_load(handle)
    # An empty document or a document without the key is a user error.
    if not parsed or "model_list" not in parsed:
        raise click.ClickException(
            "YAML file must contain a 'model_list' key with a list of models."
        )
    model_entries = parsed["model_list"]
    if not isinstance(model_entries, list):
        raise click.ClickException("'model_list' must be a list of model definitions.")
    return model_entries
def _get_filtered_model_list(
    model_list, only_models_matching_regex, only_access_groups_matching_regex
):
    """Return a list of models that pass the filter criteria."""
    # Compile each pattern once; None means "no filtering on this axis".
    model_pattern = None
    if only_models_matching_regex:
        model_pattern = re.compile(only_models_matching_regex)
    group_pattern = None
    if only_access_groups_matching_regex:
        group_pattern = re.compile(only_access_groups_matching_regex)
    filtered = []
    for candidate in model_list:
        if _filter_model(candidate, model_pattern, group_pattern):
            filtered.append(candidate)
    return filtered
def _import_models_get_table_title(dry_run: bool) -> str:
if dry_run:
return "Models that would be imported if [yellow]--dry-run[/yellow] was not provided"
else:
return "Models Imported"
@models.command("import")
@click.argument(
    "yaml_file", type=click.Path(exists=True, dir_okay=False, readable=True)
)
@click.option(
    "--dry-run",
    is_flag=True,
    help="Show what would be imported without making any changes.",
)
@click.option(
    "--only-models-matching-regex",
    default=None,
    help="Only import models where litellm_params.model matches the given regex.",
)
@click.option(
    "--only-access-groups-matching-regex",
    default=None,
    help="Only import models where at least one item in model_info.access_groups matches the given regex.",
)
@click.pass_context
def import_models(
    ctx: click.Context,
    yaml_file: str,
    dry_run: bool,
    only_models_matching_regex: Optional[str],
    only_access_groups_matching_regex: Optional[str],
) -> None:
    """Import models from a YAML file and add them to the proxy.

    With --dry-run nothing is sent to the server; the tables show what would
    have been imported. Regex options narrow the set of imported models.
    """
    provider_counts: dict[str, int] = defaultdict(int)
    added_models: list[ModelYamlInfo] = []

    model_list = get_model_list_from_yaml_file(yaml_file)
    filtered_model_list = _get_filtered_model_list(
        model_list, only_models_matching_regex, only_access_groups_matching_regex
    )

    # Only talk to the server when actually importing; a dry run needs no
    # client. (Previously the dry-run flag was checked both here and inside
    # the loop, and the loop relied on `client` being conditionally bound.)
    client = None if dry_run else create_client(ctx)
    for model in filtered_model_list:
        model_info_obj = _get_model_info_obj_from_yaml(model)
        if client is not None:
            try:
                client.models.new(
                    model_name=model_info_obj.model_name,
                    model_params=model_info_obj.model_params,
                    model_info=model_info_obj.model_info,
                )
            except Exception:
                # Intentionally best-effort: a failed create still appears in
                # the summary tables so the user sees what was attempted.
                pass
        added_models.append(model_info_obj)
        provider_counts[model_info_obj.provider] += 1

    table_title = _import_models_get_table_title(dry_run)
    _print_models_table(added_models, table_title)
    _print_summary_table(provider_counts)

View File

@@ -0,0 +1,167 @@
"""Team management commands for LiteLLM CLI."""
from typing import Any, Dict, List, Optional
import click
import requests
from rich.console import Console
from rich.table import Table
from litellm.proxy.client import Client
@click.group()
def teams():
    """Manage teams and team assignments"""
    # Container group only; behavior lives in the subcommands registered below.
    pass
def display_teams_table(teams: List[Dict[str, Any]]) -> None:
    """Display teams in a formatted table.

    Args:
        teams: Team dicts as returned by the proxy's team-list endpoints.
    """
    console = Console()

    if not teams:
        console.print("❌ No teams found for your user.")
        return

    table = Table(title="Available Teams")
    table.add_column("Index", style="cyan", no_wrap=True)
    table.add_column("Team Alias", style="magenta")
    table.add_column("Team ID", style="green")
    table.add_column("Models", style="yellow")
    table.add_column("Max Budget", style="blue")
    table.add_column("Role", style="red")

    for i, team in enumerate(teams):
        team_alias = team.get("team_alias") or "N/A"
        team_id = team.get("team_id", "N/A")
        models = team.get("models", [])
        max_budget = team.get("max_budget")

        # Show at most three model names, then a "+N more" suffix.
        if models:
            if len(models) > 3:
                models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
            else:
                models_str = ", ".join(models)
        else:
            models_str = "All models"

        # A budget of 0 is a real (zero-dollar) limit, so only a missing value
        # means "Unlimited" -- the previous truthiness check mislabeled $0
        # budgets as unlimited.
        budget_str = f"${max_budget}" if max_budget is not None else "Unlimited"

        # Role detection is not implemented; members_with_roles would need to
        # be inspected here once the response structure is confirmed.
        role = "Member"  # Default role
        if (
            isinstance(team, dict)
            and "members_with_roles" in team
            and team["members_with_roles"]
        ):
            # This would need to be implemented based on actual API response structure
            pass

        table.add_row(str(i + 1), team_alias, team_id, models_str, budget_str, role)

    console.print(table)
@teams.command("list")
@click.pass_context
def list_teams(ctx: click.Context):
    """List teams that you belong to"""
    # Function renamed from `list` (which shadowed the builtin); the explicit
    # command name keeps the CLI interface identical and matches the
    # users module's `@users.command("list") / list_users` convention.
    client = Client(ctx.obj["base_url"], ctx.obj["api_key"])

    try:
        # Use list() for simpler response structure (returns array directly)
        teams = client.teams.list()
        display_teams_table(teams)
    except requests.exceptions.HTTPError as e:
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        error_body = e.response.json()
        click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
        raise click.Abort()
    except Exception as e:
        click.echo(f"Error: {str(e)}", err=True)
        raise click.Abort()
@teams.command()
@click.pass_context
def available(ctx: click.Context):
    """List teams that are available to join"""
    client = Client(ctx.obj["base_url"], ctx.obj["api_key"])

    try:
        teams = client.teams.get_available()
        if teams:
            console = Console()
            console.print("\n🎯 Available Teams to Join:")
            display_teams_table(teams)
        else:
            click.echo(" No available teams to join.")
    except requests.exceptions.HTTPError as e:
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        error_body = e.response.json()
        click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
        # Abort so the process exits non-zero on HTTP failures, matching the
        # other team commands (previously this branch fell through and the
        # command exited 0 despite the error).
        raise click.Abort()
    except Exception as e:
        click.echo(f"Error: {str(e)}", err=True)
        raise click.Abort()
@teams.command()
@click.option("--team-id", type=str, help="Team ID to assign the key to")
@click.pass_context
def assign_key(ctx: click.Context, team_id: Optional[str]):
    """Assign your current CLI key to a team.

    If --team-id is omitted, the user's teams are listed and one is chosen
    interactively. After a successful update, the selected team's model
    access is summarized.
    """
    client = Client(ctx.obj["base_url"], ctx.obj["api_key"])
    api_key = ctx.obj["api_key"]

    # The key being assigned is the one currently configured for the CLI.
    if not api_key:
        click.echo("❌ No API key found. Please login first using 'litellm login'")
        raise click.Abort()

    try:
        # If no team_id provided, show teams and let user select
        if not team_id:
            teams = client.teams.list()
            if not teams:
                click.echo("❌ No teams found for your user.")
                return
            # Use interactive selection from auth module
            from .auth import prompt_team_selection

            selected_team = prompt_team_selection(teams)
            if selected_team:
                team_id = selected_team.get("team_id")
            else:
                # User declined/cancelled the interactive selection.
                click.echo("❌ Operation cancelled.")
                return
        # Update the key with the selected team
        if team_id:
            click.echo(f"\n🔄 Assigning your key to team: {team_id}")
            client.keys.update(key=api_key, team_id=team_id)
            click.echo(f"✅ Successfully assigned key to team: {team_id}")
            # Show team details if available: re-fetch the team list and
            # report the chosen team's model access.
            teams = client.teams.list()
            for team in teams:
                if team.get("team_id") == team_id:
                    models = team.get("models", [])
                    if models:
                        click.echo(f"🎯 You can now access models: {', '.join(models)}")
                    else:
                        click.echo("🎯 You can now access all available models")
                    break
    except requests.exceptions.HTTPError as e:
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        error_body = e.response.json()
        click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
        raise click.Abort()
    except Exception as e:
        click.echo(f"Error: {str(e)}", err=True)
        raise click.Abort()

View File

@@ -0,0 +1,91 @@
import click
import rich
from ... import UsersManagementClient
@click.group()
def users():
    """Manage users on your LiteLLM proxy server"""
    # Container group only; behavior lives in the subcommands registered below.
    pass
@users.command("list")
@click.pass_context
def list_users(ctx: click.Context):
    """List all users"""
    client = UsersManagementClient(
        base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
    )
    response = client.list_users()
    # Unwrap a {"users": [...]} envelope when the server returns one.
    if isinstance(response, dict) and "users" in response:
        response = response["users"]
    if not response:
        click.echo("No users found.")
        return
    from rich.console import Console
    from rich.table import Table

    user_table = Table(title="Users")
    for header, style in (
        ("User ID", "cyan"),
        ("Email", "green"),
        ("Role", "magenta"),
        ("Teams", "yellow"),
    ):
        user_table.add_column(header, style=style)
    for record in response:
        user_table.add_row(
            str(record.get("user_id", "")),
            str(record.get("user_email", "")),
            str(record.get("user_role", "")),
            ", ".join(record.get("teams", []) or []),
        )
    Console().print(user_table)
@users.command("get")
@click.option("--id", "user_id", help="ID of the user to retrieve")
@click.pass_context
def get_user(ctx: click.Context, user_id: str):
    """Get information about a specific user"""
    # --id is declared optional by click, so guard explicitly (mirroring the
    # models "get" command) instead of sending user_id=None to the server.
    if not user_id:
        raise click.UsageError("--id must be provided")
    client = UsersManagementClient(
        base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
    )
    result = client.get_user(user_id=user_id)
    rich.print_json(data=result)
@users.command("create")
@click.option("--email", required=True, help="User email")
@click.option("--role", default="internal_user", help="User role")
@click.option("--alias", default=None, help="User alias")
@click.option("--team", multiple=True, help="Team IDs (can specify multiple)")
@click.option("--max-budget", type=float, default=None, help="Max budget for user")
@click.pass_context
def create_user(ctx: click.Context, email, role, alias, team, max_budget):
    """Create a new user"""
    client = UsersManagementClient(
        base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
    )
    # Start from the required fields; attach optional ones only when the
    # caller actually supplied them.
    payload = {
        "user_email": email,
        "user_role": role,
    }
    if alias:
        payload["user_alias"] = alias
    if team:
        payload["teams"] = list(team)
    if max_budget is not None:
        payload["max_budget"] = max_budget
    rich.print_json(data=client.create_user(payload))
@users.command("delete")
@click.argument("user_ids", nargs=-1)
@click.pass_context
def delete_user(ctx: click.Context, user_ids):
    """Delete one or more users by user_id"""
    # user_ids arrives as a tuple of positional arguments; the client expects
    # a list.
    client = UsersManagementClient(
        base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
    )
    rich.print_json(data=client.delete_user(list(user_ids)))

View File

@@ -0,0 +1,207 @@
# stdlib imports
import os
import sys
from typing import TYPE_CHECKING
# third party imports
import click
from litellm._logging import verbose_logger
if TYPE_CHECKING:
pass
def styled_prompt():
    """Create a styled blue box prompt for user input.

    Draws a bordered input box, repositions the cursor inside it with ANSI
    escape sequences, reads one line, and returns the stripped input.
    Re-raises KeyboardInterrupt/EOFError after restoring the cursor.
    """
    # Get terminal height to ensure we have enough space
    try:
        terminal_height = os.get_terminal_size().lines
        # Ensure we have at least 5 lines of space (for the box + some buffer)
        if terminal_height < 10:
            # If terminal is too small, just add some newlines to push content up
            click.echo("\n" * 3)
    except Exception as e:
        # Fallback if we can't get terminal size (e.g. not attached to a TTY)
        verbose_logger.debug(f"Error getting terminal size: {e}")
        click.echo("\n" * 3)
    # Unicode box drawing characters
    # NOTE(review): these literals are empty strings, so the "box" renders as
    # blank styled lines -- it looks like the box-drawing glyphs were lost in
    # transit; confirm against the upstream source.
    top_left = ""
    top_right = ""
    bottom_left = ""
    bottom_right = ""
    horizontal = ""
    vertical = ""
    # Create the box with increased width
    width = 80
    top_line = top_left + horizontal * (width - 2) + top_right
    bottom_line = bottom_left + horizontal * (width - 2) + bottom_right
    # Create styled elements
    left_border = click.style(vertical, fg="blue", bold=True)
    right_border = click.style(vertical, fg="blue", bold=True)
    prompt_text = click.style("> ", fg="cyan", bold=True)
    # Display the complete box structure first to reserve space
    click.echo(click.style(top_line, fg="blue", bold=True))
    # Create empty space in the box for input
    empty_space = " " * (width - 4)
    click.echo(f"{left_border} {empty_space} {right_border}")
    # Display bottom border to complete the box
    click.echo(click.style(bottom_line, fg="blue", bold=True))
    # Now move cursor up to the input line and get input
    click.echo("\033[2A", nl=False)  # Move cursor up 2 lines (ANSI CUU)
    click.echo(
        f"\r{left_border} {prompt_text}", nl=False
    )  # Position at start of input line
    try:
        # Get user input
        user_input = input().strip()
        # Move cursor down to after the box
        click.echo("\033[1B")  # Move cursor down 1 line (ANSI CUD)
        click.echo("")  # Add some space after
    except (KeyboardInterrupt, EOFError):
        # Move cursor down and add space, then let the caller handle the exit.
        click.echo("\033[1B")
        click.echo("")
        raise
    return user_input
def show_commands():
    """Display available commands."""
    # (command, description) pairs shown in the interactive shell's help.
    commands = [
        ("login", "Authenticate with the LiteLLM proxy server"),
        ("logout", "Clear stored authentication"),
        ("whoami", "Show current authentication status"),
        ("models", "Manage and view model configurations"),
        ("credentials", "Manage API credentials"),
        ("chat", "Interactive streaming chat with models"),
        ("http", "Make HTTP requests to the proxy"),
        ("keys", "Manage API keys"),
        ("teams", "Manage teams and team assignments"),
        ("users", "Manage users"),
        ("version", "Show version information"),
        ("help", "Show this help message"),
        ("quit", "Exit the interactive session"),
    ]
    click.echo("Available commands:")
    for name, blurb in commands:
        click.echo(f"  {name:<20} {blurb}")
    click.echo()
def setup_shell(ctx: click.Context):
    """Set up the interactive shell: banner, connection info, command list."""
    from litellm.proxy.common_utils.banner import show_banner

    show_banner()
    # Report which server this session talks to.
    server_url = ctx.obj.get("base_url")
    click.secho(f"Connected to LiteLLM server: {server_url}\n", fg="green")
    show_commands()
def handle_special_commands(user_input: str) -> bool:
    """Handle special commands like exit, help, clear. Returns True if command was handled."""
    normalized = user_input.lower()
    if normalized in ("exit", "quit"):
        click.echo("Goodbye!")
        return True
    if normalized == "help":
        click.echo("")  # blank line before the command listing
        show_commands()
        return True
    if normalized == "clear":
        click.clear()
        from litellm.proxy.common_utils.banner import show_banner

        show_banner()
        show_commands()
        return True
    return False
def execute_command(user_input: str, ctx: click.Context):
    """Parse and execute a command from the interactive shell.

    Looks the first token up among the CLI's registered commands and invokes
    it with the remaining tokens as arguments. Errors are reported but never
    terminate the interactive session.
    """
    # Parse command and arguments
    parts = user_input.split()
    command = parts[0]
    args = parts[1:] if len(parts) > 1 else []

    # Import cli here to avoid circular import
    from . import main

    cli = main.cli

    # Check if command exists
    if command not in cli.commands:
        click.echo(f"Unknown command: {command}")
        click.echo("Type 'help' to see available commands.")
        return

    # Execute the command
    saved_argv = sys.argv
    try:
        # Present a plausible argv to anything that inspects it, but restore
        # the real one afterwards -- previously sys.argv was clobbered for the
        # remainder of the interactive session.
        sys.argv = ["litellm-proxy"] + [command] + args

        # Get the command object and invoke it
        cmd = cli.commands[command]

        # Create a new context for the subcommand
        with ctx.scope():
            cmd.main(args, parent=ctx, standalone_mode=False)
    except click.ClickException as e:
        e.show()
    except click.Abort:
        click.echo("Command aborted.")
    except SystemExit:
        # Prevent the interactive shell from exiting on command errors
        pass
    except Exception as e:
        click.echo(f"Error executing command: {e}")
    finally:
        sys.argv = saved_argv
def interactive_shell(ctx: click.Context):
    """Run the interactive read-eval loop for the proxy CLI."""
    setup_shell(ctx)

    running = True
    while running:
        try:
            click.echo("\n")  # breathing room before the styled input box
            line = styled_prompt()
            if not line:
                continue
            if handle_special_commands(line):
                # "exit"/"quit" are the only handled commands that terminate.
                running = line.lower() not in ["exit", "quit"]
                continue
            execute_command(line, ctx)
        except (KeyboardInterrupt, EOFError):
            click.echo("\nGoodbye!")
            running = False
        except Exception as e:
            click.echo(f"Error: {e}")

View File

@@ -0,0 +1,115 @@
# stdlib imports
from typing import Optional
# third party imports
import click
from litellm._version import version as litellm_version
from litellm.proxy.client.health import HealthManagementClient
from .commands.auth import get_stored_api_key, login, logout, whoami
from .commands.chat import chat
from .commands.credentials import credentials
from .commands.http import http
from .commands.keys import keys
# local imports
from .commands.models import models
from .commands.teams import teams
from .commands.users import users
from .interface import interactive_shell
def print_version(base_url: str, api_key: Optional[str]):
    """Print CLI version, and server URL/version when a base URL is known."""
    click.echo(f"LiteLLM Proxy CLI Version: {litellm_version}")
    if not base_url:
        return
    click.echo(f"LiteLLM Proxy Server URL: {base_url}")
    # The server version comes from the health/readiness endpoint; any
    # connection problem is reported rather than raised.
    try:
        health_client = HealthManagementClient(base_url=base_url, api_key=api_key)
        server_version = health_client.get_server_version()
    except Exception as e:
        click.echo(f"Could not retrieve server version: {e}")
        return
    if server_version:
        click.echo(f"LiteLLM Proxy Server Version: {server_version}")
    else:
        click.echo("LiteLLM Proxy Server Version: (unavailable)")
@click.group(invoke_without_command=True)
@click.option(
    "--version",
    "-v",
    is_flag=True,
    is_eager=True,
    expose_value=False,
    help="Show the LiteLLM Proxy CLI and server version and exit.",
    # Eager callback: runs during option parsing, prints versions, and exits.
    # NOTE(review): it reads base_url/api_key via ctx.params, which only holds
    # options parsed *before* --version on the command line; the literal
    # default covers the common case -- confirm the intended ordering behavior.
    callback=lambda ctx, param, value: (
        print_version(
            ctx.params.get("base_url") or "http://localhost:4000",
            ctx.params.get("api_key"),
        )
        or ctx.exit()
    )
    if value and not ctx.resilient_parsing
    else None,
)
@click.option(
    "--base-url",
    envvar="LITELLM_PROXY_URL",
    show_envvar=True,
    default="http://localhost:4000",
    help="Base URL of the LiteLLM proxy server",
)
@click.option(
    "--api-key",
    envvar="LITELLM_PROXY_API_KEY",
    show_envvar=True,
    help="API key for authentication",
)
@click.pass_context
def cli(ctx: click.Context, base_url: str, api_key: Optional[str]) -> None:
    """LiteLLM Proxy CLI - Manage your LiteLLM proxy server"""
    ctx.ensure_object(dict)
    # If no API key provided via flag or environment variable, try to load from saved token
    if api_key is None:
        api_key = get_stored_api_key()
    # Stash connection details for every subcommand (and the interactive shell).
    ctx.obj["base_url"] = base_url
    ctx.obj["api_key"] = api_key
    # If no subcommand was invoked, start interactive mode
    if ctx.invoked_subcommand is None:
        interactive_shell(ctx)
@cli.command()
@click.pass_context
def version(ctx: click.Context):
    """Show the LiteLLM Proxy CLI and server version."""
    # Reuses the same printer as the eager --version flag.
    print_version(ctx.obj.get("base_url"), ctx.obj.get("api_key"))
# Register the authentication commands and every command group on the
# top-level CLI in one pass.
for _command in (
    login,
    logout,
    whoami,
    models,
    credentials,
    chat,
    http,
    keys,
    teams,
    users,
):
    cli.add_command(_command)

if __name__ == "__main__":
    cli()

View File

@@ -0,0 +1,50 @@
from typing import Optional
from litellm.litellm_core_utils.cli_token_utils import get_litellm_gateway_api_key
from .chat import ChatClient
from .credentials import CredentialsManagementClient
from .http_client import HTTPClient
from .keys import KeysManagementClient
from .model_groups import ModelGroupsManagementClient
from .models import ModelsManagementClient
from .teams import TeamsManagementClient
class Client:
    """Main client for interacting with the LiteLLM proxy API."""

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        timeout: int = 30,
    ):
        """
        Initialize the LiteLLM proxy client.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:4000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a
                Bearer token. A key stored by the CLI login flow (if any) takes precedence.
            timeout: Request timeout in seconds (default: 30)
        """
        self._base_url = base_url.rstrip("/")  # Remove trailing slash if present
        self._api_key = get_litellm_gateway_api_key() or api_key

        # Initialize resource clients. All sub-clients receive the normalized
        # base URL and the resolved API key; previously the HTTP client was
        # constructed from the raw constructor arguments, so it could disagree
        # with every other sub-client about both the URL and the key.
        self.http = HTTPClient(
            base_url=self._base_url, api_key=self._api_key, timeout=timeout
        )
        self.models = ModelsManagementClient(
            base_url=self._base_url, api_key=self._api_key
        )
        self.model_groups = ModelGroupsManagementClient(
            base_url=self._base_url, api_key=self._api_key
        )
        self.chat = ChatClient(base_url=self._base_url, api_key=self._api_key)
        self.keys = KeysManagementClient(base_url=self._base_url, api_key=self._api_key)
        self.credentials = CredentialsManagementClient(
            base_url=self._base_url, api_key=self._api_key
        )
        self.teams = TeamsManagementClient(
            base_url=self._base_url, api_key=self._api_key
        )

View File

@@ -0,0 +1,185 @@
import requests
from typing import Dict, Any, Optional, Union
from .exceptions import UnauthorizedError
class CredentialsManagementClient:
    """Client for the /credentials management endpoints of a LiteLLM proxy."""

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the CredentialsManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:8000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # Remove trailing slash if present
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """
        Get the headers for API requests, including authorization if api_key is set.

        Returns:
            Dict[str, str]: Headers to use for API requests
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    def _send(self, request: requests.Request) -> Dict[str, Any]:
        """
        Prepare and send a request, translating 401 responses.

        Every public method shares this execution path so that error handling
        stays consistent across endpoints (previously the session/send/401
        logic was duplicated verbatim in each method).

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise

    def list(
        self,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        List all credentials.

        Args:
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
                a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/credentials"
        request = requests.Request("GET", url, headers=self._get_headers())
        if return_request:
            return request
        return self._send(request)

    def create(
        self,
        credential_name: str,
        credential_info: Dict[str, Any],
        credential_values: Dict[str, Any],
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Create a new credential.

        Args:
            credential_name (str): Name of the credential
            credential_info (Dict[str, Any]): Additional information about the credential
            credential_values (Dict[str, Any]): Values for the credential
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
                a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/credentials"
        data = {
            "credential_name": credential_name,
            "credential_info": credential_info,
            "credential_values": credential_values,
        }
        request = requests.Request("POST", url, headers=self._get_headers(), json=data)
        if return_request:
            return request
        return self._send(request)

    def delete(
        self,
        credential_name: str,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Delete a credential by name.

        Args:
            credential_name (str): Name of the credential to delete
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
                a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/credentials/{credential_name}"
        request = requests.Request("DELETE", url, headers=self._get_headers())
        if return_request:
            return request
        return self._send(request)

    def get(
        self,
        credential_name: str,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Get a credential by name.

        Args:
            credential_name (str): Name of the credential to retrieve
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
                a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/credentials/by_name/{credential_name}"
        request = requests.Request("GET", url, headers=self._get_headers())
        if return_request:
            return request
        return self._send(request)

View File

@@ -0,0 +1,19 @@
from typing import Union
import requests
class UnauthorizedError(Exception):
    """Exception raised when the API returns a 401 Unauthorized response."""

    def __init__(self, orig_exception: Union[requests.exceptions.HTTPError, str]):
        # Keep the original HTTPError (or message) so callers can inspect the
        # underlying response.
        self.orig_exception = orig_exception
        super().__init__(str(orig_exception))
class NotFoundError(Exception):
    """Exception raised when the API returns a 404 Not Found response or indicates a resource was not found."""

    def __init__(self, orig_exception: Union[requests.exceptions.HTTPError, str]):
        # Keep the original HTTPError (or message) so callers can inspect the
        # underlying response.
        self.orig_exception = orig_exception
        super().__init__(str(orig_exception))

View File

@@ -0,0 +1,42 @@
from typing import Optional, Dict, Any
from .http_client import HTTPClient
class HealthManagementClient:
    """
    Client for interacting with the health endpoints of the LiteLLM proxy server.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None, timeout: int = 30):
        """
        Initialize the HealthManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:4000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
            timeout (int): Request timeout in seconds (default: 30)
        """
        # All health calls are delegated to the generic HTTP client.
        self._http = HTTPClient(base_url=base_url, api_key=api_key, timeout=timeout)

    def get_readiness(self) -> Dict[str, Any]:
        """
        Check the readiness of the LiteLLM proxy server.

        Returns:
            Dict[str, Any]: The readiness status and details from the server.

        Raises:
            requests.exceptions.RequestException: If the request fails
            ValueError: If the response is not valid JSON
        """
        return self._http.request("GET", "/health/readiness")

    def get_server_version(self) -> Optional[str]:
        """
        Get the LiteLLM server version from the readiness endpoint.

        Returns:
            Optional[str]: The server version if available, otherwise None.
        """
        readiness_payload = self.get_readiness()
        # Absent key means the server did not report a version.
        return readiness_payload.get("litellm_version")

View File

@@ -0,0 +1,95 @@
"""HTTP client for making requests to the LiteLLM proxy server."""
from typing import Any, Dict, Optional, Union
import requests
class HTTPClient:
    """HTTP client for making requests to the LiteLLM proxy server."""

    def __init__(self, base_url: str, api_key: Optional[str] = None, timeout: int = 30):
        """Initialize the HTTP client.

        Args:
            base_url: Base URL of the LiteLLM proxy server
            api_key: Optional API key for authentication
            timeout: Request timeout in seconds (default: 30)
        """
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._timeout = timeout

    def request(
        self,
        method: str,
        uri: str,
        *,
        data: Optional[Union[Dict[str, Any], list, bytes]] = None,
        json: Optional[Union[Dict[str, Any], list]] = None,
        headers: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Make an HTTP request to the LiteLLM proxy server.

        This is the generic escape hatch for endpoints that have no dedicated
        client or method.

        Args:
            method: HTTP method (GET, POST, PUT, DELETE, etc.)
            uri: URI path appended to the base URL (e.g., "/credentials")
            data: (optional) Dictionary, list of tuples, bytes, or file-like
                object to send in the body of the request.
            json: (optional) A JSON serializable Python object to send in the
                body of the request.
            headers: (optional) Dictionary of HTTP headers to send with the
                request.
            **kwargs: Additional keyword arguments forwarded to
                ``requests.request``.

        Returns:
            Parsed JSON response from the server

        Raises:
            requests.exceptions.RequestException: If the request fails
            ValueError: If the response is not valid JSON

        Example:
            >>> client.http.request("POST", "/health/test_connection", json={
                    "litellm_params": {"model": "gpt-4"},
                    "mode": "chat",
                })
        """
        # Join base URL and path with exactly one slash between them.
        target_url = f"{self._base_url}/{uri.lstrip('/')}"

        # Caller-supplied headers come first; the Authorization header (when
        # an API key is configured) always wins.
        request_headers: Dict[str, str] = dict(headers) if headers else {}
        if self._api_key:
            request_headers["Authorization"] = f"Bearer {self._api_key}"

        response = requests.request(
            method=method,
            url=target_url,
            data=data,
            json=json,
            headers=request_headers,
            timeout=self._timeout,
            **kwargs,
        )
        # Surface HTTP errors, then decode the JSON payload.
        response.raise_for_status()
        return response.json()

View File

@@ -0,0 +1,319 @@
from typing import Any, Dict, List, Optional, Union
import requests
from .exceptions import UnauthorizedError
class KeysManagementClient:
    """Client for the ``/key/*`` management endpoints of a LiteLLM proxy.

    Every public method builds a ``requests.Request``; pass
    ``return_request=True`` to get the unsent request object back for
    inspection, otherwise the request is executed and the parsed JSON
    response is returned.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the KeysManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:8000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # normalize: no trailing slash
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """
        Get the headers for API requests, including authorization if api_key is set.

        Returns:
            Dict[str, str]: Headers to use for API requests
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    def _send(self, request: requests.Request) -> Dict[str, Any]:
        """Execute a prepared request and return the parsed JSON body.

        Centralizes the send / raise_for_status / 401-mapping logic shared by
        every public method so error semantics stay consistent.

        Raises:
            UnauthorizedError: If the server responds with a 401 status code
            requests.exceptions.RequestException: For any other failure
        """
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise

    def list(
        self,
        page: Optional[int] = None,
        size: Optional[int] = None,
        user_id: Optional[str] = None,
        team_id: Optional[str] = None,
        organization_id: Optional[str] = None,
        key_hash: Optional[str] = None,
        key_alias: Optional[str] = None,
        return_full_object: Optional[bool] = None,
        include_team_keys: Optional[bool] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        List all API keys with optional filtering and pagination.

        Args:
            page (Optional[int]): Page number for pagination
            size (Optional[int]): Number of items per page
            user_id (Optional[str]): Filter keys by user ID
            team_id (Optional[str]): Filter keys by team ID
            organization_id (Optional[str]): Filter keys by organization ID
            key_hash (Optional[str]): Filter by specific key hash
            key_alias (Optional[str]): Filter by key alias
            return_full_object (Optional[bool]): Whether to return the full key object
            include_team_keys (Optional[bool]): Whether to include team keys in the response
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server
            (a page of API keys) or a prepared request object if return_request is True.

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/key/list"
        params: Dict[str, Any] = {}
        if page is not None:
            params["page"] = page
        if size is not None:
            params["size"] = size
        if user_id is not None:
            params["user_id"] = user_id
        if team_id is not None:
            params["team_id"] = team_id
        if organization_id is not None:
            params["organization_id"] = organization_id
        if key_hash is not None:
            params["key_hash"] = key_hash
        if key_alias is not None:
            params["key_alias"] = key_alias
        # Booleans are serialized as lowercase "true"/"false" query values.
        if return_full_object is not None:
            params["return_full_object"] = str(return_full_object).lower()
        if include_team_keys is not None:
            params["include_team_keys"] = str(include_team_keys).lower()
        request = requests.Request(
            "GET", url, headers=self._get_headers(), params=params
        )
        if return_request:
            return request
        return self._send(request)

    def generate(
        self,
        models: Optional[List[str]] = None,
        aliases: Optional[Dict[str, str]] = None,
        spend: Optional[float] = None,
        duration: Optional[str] = None,
        key_alias: Optional[str] = None,
        team_id: Optional[str] = None,
        user_id: Optional[str] = None,
        budget_id: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Generate an API key based on the provided data.

        Docs: https://docs.litellm.ai/docs/proxy/virtual_keys

        Args:
            models (Optional[List[str]]): List of allowed models for this key
            aliases (Optional[Dict[str, str]]): Model alias mappings
            spend (Optional[float]): Maximum spend limit for this key
            duration (Optional[str]): Duration for which the key is valid (e.g. "24h", "7d")
            key_alias (Optional[str]): Alias/name for the key for easier identification
            team_id (Optional[str]): Team ID to associate the key with
            user_id (Optional[str]): User ID to associate the key with
            budget_id (Optional[str]): Budget ID to associate the key with
            config (Optional[Dict[str, Any]]): Additional configuration parameters
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/key/generate"
        data: Dict[str, Any] = {}
        if models is not None:
            data["models"] = models
        if aliases is not None:
            data["aliases"] = aliases
        if spend is not None:
            data["spend"] = spend
        if duration is not None:
            data["duration"] = duration
        if key_alias is not None:
            data["key_alias"] = key_alias
        if team_id is not None:
            data["team_id"] = team_id
        if user_id is not None:
            data["user_id"] = user_id
        if budget_id is not None:
            data["budget_id"] = budget_id
        if config is not None:
            data["config"] = config
        request = requests.Request("POST", url, headers=self._get_headers(), json=data)
        if return_request:
            return request
        return self._send(request)

    def delete(
        self,
        keys: Optional[List[str]] = None,
        key_aliases: Optional[List[str]] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Delete existing keys.

        Args:
            keys (Optional[List[str]]): List of API keys to delete
            key_aliases (Optional[List[str]]): List of key aliases to delete
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/key/delete"
        # Both fields are always included (possibly null) to preserve the
        # original wire format.
        data = {
            "keys": keys,
            "key_aliases": key_aliases,
        }
        request = requests.Request("POST", url, headers=self._get_headers(), json=data)
        if return_request:
            return request
        return self._send(request)

    def update(
        self,
        key: str,
        models: Optional[List[str]] = None,
        aliases: Optional[Dict[str, str]] = None,
        spend: Optional[float] = None,
        duration: Optional[str] = None,
        key_alias: Optional[str] = None,
        team_id: Optional[str] = None,
        user_id: Optional[str] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Update an existing API key's parameters.

        Args:
            key (str): The key (hash) to update
            models (Optional[List[str]]): New list of allowed models
            aliases (Optional[Dict[str, str]]): New model alias mappings
            spend (Optional[float]): New maximum spend limit
            duration (Optional[str]): New validity duration (e.g. "24h", "7d")
            key_alias (Optional[str]): New alias/name for the key
            team_id (Optional[str]): New team ID association
            user_id (Optional[str]): New user ID association
            return_request (bool): If True, returns the prepared request object
                instead of executing it. (Added for parity with the other
                methods; the default False preserves previous behavior.)

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/key/update"
        data: Dict[str, Any] = {"key": key}
        if key_alias is not None:
            data["key_alias"] = key_alias
        if user_id is not None:
            data["user_id"] = user_id
        if team_id is not None:
            data["team_id"] = team_id
        if models is not None:
            data["models"] = models
        if spend is not None:
            data["spend"] = spend
        if duration is not None:
            data["duration"] = duration
        if aliases is not None:
            data["aliases"] = aliases
        request = requests.Request("POST", url, headers=self._get_headers(), json=data)
        if return_request:
            return request
        # Previously every failure here was collapsed into a bare Exception,
        # discarding the original error; _send now preserves the documented
        # UnauthorizedError / RequestException contract.
        return self._send(request)

    def info(
        self, key: str, return_request: bool = False
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Get information about an API key.

        Args:
            key (str): The key hash to get information about
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/key/info"
        # Pass the key via params so it is URL-encoded correctly instead of
        # being concatenated into the query string by hand.
        request = requests.Request(
            "GET", url, headers=self._get_headers(), params={"key": key}
        )
        if return_request:
            return request
        return self._send(request)

View File

@@ -0,0 +1,62 @@
import requests
from typing import List, Dict, Any, Optional, Union
from .exceptions import UnauthorizedError
class ModelGroupsManagementClient:
    """Client for the ``/model_group`` endpoints of a LiteLLM proxy."""

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the ModelGroupsManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:8000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # normalized: no trailing slash
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """Return request headers, with a Bearer token when an API key is set."""
        if self._api_key:
            return {"Authorization": f"Bearer {self._api_key}"}
        return {}

    def info(
        self, return_request: bool = False
    ) -> Union[List[Dict[str, Any]], requests.Request]:
        """Fetch detailed information about all model groups.

        Args:
            return_request (bool): If True, return the prepared request object
                instead of executing it.

        Returns:
            Union[List[Dict[str, Any]], requests.Request]: The list of model
            group dictionaries, or the unsent request when return_request is True.

        Raises:
            UnauthorizedError: If the server responds with 401
            requests.exceptions.RequestException: For any other failure
        """
        req = requests.Request(
            "GET",
            f"{self._base_url}/model_group/info",
            headers=self._get_headers(),
        )
        if return_request:
            return req
        try:
            resp = requests.Session().send(req.prepare())
            resp.raise_for_status()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise
        # The server wraps the payload in a top-level "data" field.
        return resp.json()["data"]

View File

@@ -0,0 +1,298 @@
import requests
from typing import List, Dict, Any, Optional, Union
from .exceptions import UnauthorizedError, NotFoundError
class ModelsManagementClient:
    """Client for the model management endpoints (``/models``, ``/model/*``)
    of a LiteLLM proxy server."""

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the ModelsManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:8000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # normalize: no trailing slash
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """Return request headers, with a Bearer token when an API key is set."""
        headers: Dict[str, str] = {}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    def _send(self, request: requests.Request, check_not_found: bool = False) -> Any:
        """Execute a prepared request and return the parsed JSON body.

        Centralizes the send / raise_for_status / error-mapping logic that was
        previously duplicated in every public method.

        Args:
            request: The request to send.
            check_not_found: If True, map 404 responses (or bodies containing
                "not found") to NotFoundError.

        Raises:
            UnauthorizedError: If the server responds with 401
            NotFoundError: On 404 (when check_not_found is True)
            requests.exceptions.RequestException: For any other failure
        """
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            if check_not_found and (
                e.response.status_code == 404
                or "not found" in e.response.text.lower()
            ):
                raise NotFoundError(e)
            raise

    def list(
        self, return_request: bool = False
    ) -> Union[List[Dict[str, Any]], requests.Request]:
        """
        Get the list of models supported by the server.

        Args:
            return_request (bool): If True, returns the prepared request object instead of executing it.

        Returns:
            Union[List[Dict[str, Any]], requests.Request]: Either a list of model information dictionaries
            or a prepared request object if return_request is True.

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        request = requests.Request(
            "GET", f"{self._base_url}/models", headers=self._get_headers()
        )
        if return_request:
            return request
        return self._send(request)["data"]

    def new(
        self,
        model_name: str,
        model_params: Dict[str, Any],
        model_info: Optional[Dict[str, Any]] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Add a new model to the proxy.

        Args:
            model_name (str): Name of the model to add
            model_params (Dict[str, Any]): Parameters for the model (e.g., model type, api_base, api_key)
            model_info (Optional[Dict[str, Any]]): Additional information about the model
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        data: Dict[str, Any] = {
            "model_name": model_name,
            "litellm_params": model_params,
        }
        if model_info:
            data["model_info"] = model_info
        request = requests.Request(
            "POST",
            f"{self._base_url}/model/new",
            headers=self._get_headers(),
            json=data,
        )
        if return_request:
            return request
        return self._send(request)

    def delete(
        self, model_id: str, return_request: bool = False
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Delete a model from the proxy.

        Args:
            model_id (str): ID of the model to delete (e.g., "2f23364f-4579-4d79-a43a-2d48dd551c2e")
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            NotFoundError: If the request fails with a 404 status code or indicates the model was not found
            requests.exceptions.RequestException: If the request fails with any other error
        """
        request = requests.Request(
            "POST",
            f"{self._base_url}/model/delete",
            headers=self._get_headers(),
            json={"id": model_id},
        )
        if return_request:
            return request
        return self._send(request, check_not_found=True)

    def get(
        self,
        model_id: Optional[str] = None,
        model_name: Optional[str] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Get information about a specific model by its ID or name.

        Note: filtering happens client-side; with return_request=True the
        returned request is the unfiltered /v1/model/info request.

        Args:
            model_id (Optional[str]): ID of the model to retrieve
            model_name (Optional[str]): Name of the model to retrieve
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the model information from the server or
            a prepared request object if return_request is True

        Raises:
            ValueError: If neither model_id nor model_name is provided, or if both are provided
            UnauthorizedError: If the request fails with a 401 status code
            NotFoundError: If the model is not found
            requests.exceptions.RequestException: If the request fails with any other error
        """
        if (model_id is None and model_name is None) or (
            model_id is not None and model_name is not None
        ):
            raise ValueError("Exactly one of model_id or model_name must be provided")
        if return_request:
            result = self.info(return_request=True)
            assert isinstance(result, requests.Request)  # narrow the Union
            return result
        models = self.info()
        assert isinstance(models, list)  # narrow the Union (builtin list, not typing.List)
        for model in models:
            if (model_id and model.get("model_info", {}).get("id") == model_id) or (
                model_name and model.get("model_name") == model_name
            ):
                return model
        # No match found; synthesize a NotFoundError since no direct HTTP
        # request produced this failure.
        if model_id:
            msg = f"Model with id={model_id} not found"
        elif model_name:
            msg = f"Model with model_name={model_name} not found"
        else:
            msg = "Unknown error trying to find model"
        raise NotFoundError(
            requests.exceptions.HTTPError(
                msg,
                response=requests.Response(),  # empty response: no direct request was made
            )
        )

    def info(
        self, return_request: bool = False
    ) -> Union[List[Dict[str, Any]], requests.Request]:
        """
        Get detailed information about all models from the server.

        Args:
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[List[Dict[str, Any]], requests.Request]: Either a list of model information dictionaries
            or a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        request = requests.Request(
            "GET", f"{self._base_url}/v1/model/info", headers=self._get_headers()
        )
        if return_request:
            return request
        return self._send(request)["data"]

    def update(
        self,
        model_id: str,
        model_params: Dict[str, Any],
        model_info: Optional[Dict[str, Any]] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Update an existing model's configuration.

        Args:
            model_id (str): ID of the model to update
            model_params (Dict[str, Any]): New parameters for the model (e.g., model type, api_base, api_key)
            model_info (Optional[Dict[str, Any]]): Additional information about the model
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            NotFoundError: If the model is not found
            requests.exceptions.RequestException: If the request fails with any other error
        """
        data: Dict[str, Any] = {
            "id": model_id,
            "litellm_params": model_params,
        }
        if model_info:
            data["model_info"] = model_info
        request = requests.Request(
            "POST",
            f"{self._base_url}/model/update",
            headers=self._get_headers(),
            json=data,
        )
        if return_request:
            return request
        return self._send(request, check_not_found=True)

View File

@@ -0,0 +1,146 @@
"""Teams management client for LiteLLM proxy."""
from typing import Any, Dict, List, Optional, Union
import requests
from .exceptions import UnauthorizedError
class TeamsManagementClient:
"""Client for managing teams in LiteLLM proxy."""
def __init__(self, base_url: str, api_key: Optional[str] = None):
"""
Initialize the TeamsManagementClient.
Args:
base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:4000")
api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
"""
self._base_url = base_url.rstrip("/") # Remove trailing slash if present
self._api_key = api_key
def _get_headers(self) -> Dict[str, str]:
"""
Get the headers for API requests, including authorization if api_key is set.
Returns:
Dict[str, str]: Headers to use for API requests
"""
headers = {"Content-Type": "application/json"}
if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"
return headers
def list(
self,
user_id: Optional[str] = None,
organization_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
List teams that the user belongs to.
Args:
user_id (Optional[str]): Only return teams which this user belongs to
organization_id (Optional[str]): Only return teams which belong to this organization
Returns:
List[Dict[str, Any]]: List of team objects
Raises:
requests.exceptions.HTTPError: If the request fails
UnauthorizedError: If authentication fails
"""
url = f"{self._base_url}/team/list"
params = {}
if user_id:
params["user_id"] = user_id
if organization_id:
params["organization_id"] = organization_id
response = requests.get(url, headers=self._get_headers(), params=params)
if response.status_code == 401:
raise UnauthorizedError("Authentication failed. Check your API key.")
response.raise_for_status()
return response.json()
def list_v2(
self,
user_id: Optional[str] = None,
organization_id: Optional[str] = None,
team_id: Optional[str] = None,
team_alias: Optional[str] = None,
page: int = 1,
page_size: int = 10,
sort_by: Optional[str] = None,
sort_order: str = "asc",
) -> Dict[str, Any]:
"""
Get a paginated list of teams with filtering and sorting options.
Args:
user_id (Optional[str]): Only return teams which this user belongs to
organization_id (Optional[str]): Only return teams which belong to this organization
team_id (Optional[str]): Filter teams by exact team_id match
team_alias (Optional[str]): Filter teams by partial team_alias match
page (int): Page number for pagination
page_size (int): Number of teams per page
sort_by (Optional[str]): Column to sort by (e.g. 'team_id', 'team_alias', 'created_at')
sort_order (str): Sort order ('asc' or 'desc')
Returns:
Dict[str, Any]: Paginated response containing teams and pagination info
Raises:
requests.exceptions.HTTPError: If the request fails
UnauthorizedError: If authentication fails
"""
url = f"{self._base_url}/v2/team/list"
params: Dict[str, Union[str, int]] = {
"page": page,
"page_size": page_size,
"sort_order": sort_order,
}
if user_id:
params["user_id"] = user_id
if organization_id:
params["organization_id"] = organization_id
if team_id:
params["team_id"] = team_id
if team_alias:
params["team_alias"] = team_alias
if sort_by:
params["sort_by"] = sort_by
response = requests.get(url, headers=self._get_headers(), params=params)
if response.status_code == 401:
raise UnauthorizedError("Authentication failed. Check your API key.")
response.raise_for_status()
return response.json()
def get_available(self) -> List[Dict[str, Any]]:
"""
Get list of available teams that the user can join.
Returns:
List[Dict[str, Any]]: List of available team objects
Raises:
requests.exceptions.HTTPError: If the request fails
UnauthorizedError: If authentication fails
"""
url = f"{self._base_url}/team/available"
response = requests.get(url, headers=self._get_headers())
if response.status_code == 401:
raise UnauthorizedError("Authentication failed. Check your API key.")
response.raise_for_status()
return response.json()

View File

@@ -0,0 +1,58 @@
import requests
from typing import List, Dict, Any, Optional
from .exceptions import UnauthorizedError, NotFoundError
class UsersManagementClient:
    """Client for the ``/user/*`` management endpoints of a LiteLLM proxy."""

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the UsersManagementClient.

        Args:
            base_url (str): Base URL of the LiteLLM proxy server.
            api_key (Optional[str]): API key; sent as a Bearer token when provided.
        """
        # NOTE: attributes are public (unlike the sibling clients' private
        # ones) — kept as-is for backward compatibility with existing callers.
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """Return JSON request headers, with a Bearer token when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def list_users(
        self, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """List users (GET /user/list).

        Args:
            params (Optional[Dict[str, Any]]): Optional query parameters forwarded verbatim.

        Returns:
            List[Dict[str, Any]]: The list of users.

        Raises:
            UnauthorizedError: On a 401 response.
            requests.exceptions.HTTPError: On any other HTTP error.
        """
        url = f"{self.base_url}/user/list"
        response = requests.get(url, headers=self._get_headers(), params=params)
        if response.status_code == 401:
            raise UnauthorizedError(response.text)
        response.raise_for_status()
        # Parse once (previously json() was called twice) and unwrap the
        # {"users": [...]} envelope when present; a bare list is returned as-is.
        payload = response.json()
        if isinstance(payload, dict):
            return payload.get("users", payload)
        return payload

    def get_user(self, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Get user info (GET /user/info).

        Args:
            user_id (Optional[str]): User to look up; omits the parameter when None.

        Returns:
            Dict[str, Any]: The user info payload.

        Raises:
            UnauthorizedError: On a 401 response.
            NotFoundError: On a 404 response.
            requests.exceptions.HTTPError: On any other HTTP error.
        """
        url = f"{self.base_url}/user/info"
        params = {"user_id": user_id} if user_id else {}
        response = requests.get(url, headers=self._get_headers(), params=params)
        if response.status_code == 401:
            raise UnauthorizedError(response.text)
        if response.status_code == 404:
            raise NotFoundError(response.text)
        response.raise_for_status()
        return response.json()

    def create_user(self, user_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a new user (POST /user/new).

        Args:
            user_data (Dict[str, Any]): JSON body describing the new user.

        Returns:
            Dict[str, Any]: The created user payload.

        Raises:
            UnauthorizedError: On a 401 response.
            requests.exceptions.HTTPError: On any other HTTP error.
        """
        url = f"{self.base_url}/user/new"
        response = requests.post(url, headers=self._get_headers(), json=user_data)
        if response.status_code == 401:
            raise UnauthorizedError(response.text)
        response.raise_for_status()
        return response.json()

    def delete_user(self, user_ids: List[str]) -> Dict[str, Any]:
        """Delete users (POST /user/delete).

        Args:
            user_ids (List[str]): IDs of the users to delete.

        Returns:
            Dict[str, Any]: The server's deletion response.

        Raises:
            UnauthorizedError: On a 401 response.
            requests.exceptions.HTTPError: On any other HTTP error.
        """
        url = f"{self.base_url}/user/delete"
        response = requests.post(
            url, headers=self._get_headers(), json={"user_ids": user_ids}
        )
        if response.status_code == 401:
            raise UnauthorizedError(response.text)
        response.raise_for_status()
        return response.json()

View File

@@ -0,0 +1,169 @@
def show_missing_vars_in_env():
    """Return an HTMLResponse listing missing required env vars, or None.

    Checks the proxy's runtime state: a missing ``prisma_client`` means
    DATABASE_URL is unset, a missing ``master_key`` means LITELLM_MASTER_KEY
    is unset. Returns None when both are configured.
    """
    from fastapi.responses import HTMLResponse

    from litellm.proxy.proxy_server import master_key, prisma_client

    missing = []
    if prisma_client is None:
        missing.append("DATABASE_URL")
    if master_key is None:
        missing.append("LITELLM_MASTER_KEY")
    if not missing:
        return None
    return HTMLResponse(
        content=missing_keys_form(missing_key_names=", ".join(missing)),
        status_code=200,
    )
def missing_keys_form(missing_key_names: str):
    """Render the environment-setup help page for missing env vars.

    Args:
        missing_key_names: Comma-separated names of the missing variables,
            interpolated into the page body.

    Returns:
        str: A complete HTML document. CSS braces are escaped as ``{{``/``}}``
        because the template is run through ``str.format``.
    """
    html_template = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}}
.container {{
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
font-size: 24px;
margin-bottom: 20px;
}}
pre {{
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}}
.env-var {{
font-weight: normal;
}}
.comment {{
font-weight: normal;
color: #777;
}}
</style>
<title>Environment Setup Instructions</title>
</head>
<body>
<div class="container">
<h1>Environment Setup Instructions</h1>
<p>Please add the following variables to your environment variables:</p>
<pre>
<span class="env-var">LITELLM_MASTER_KEY="sk-1234"</span> <span class="comment"># Your master key for the proxy server. Can use this to send /chat/completion requests etc</span>
<span class="env-var">LITELLM_SALT_KEY="sk-XXXXXXXX"</span> <span class="comment"># Can NOT CHANGE THIS ONCE SET - It is used to encrypt/decrypt credentials stored in DB. If value of 'LITELLM_SALT_KEY' changes your models cannot be retrieved from DB</span>
<span class="env-var">DATABASE_URL="postgres://..."</span> <span class="comment"># Need a postgres database? (Check out Supabase, Neon, etc)</span>
<span class="comment">## OPTIONAL ##</span>
<span class="env-var">PORT=4000</span> <span class="comment"># DO THIS FOR RENDER/RAILWAY</span>
<span class="env-var">STORE_MODEL_IN_DB="True"</span> <span class="comment"># Allow storing models in db</span>
</pre>
<h1>Missing Environment Variables</h1>
<p>{missing_keys}</p>
</div>
<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
    return html_template.format(missing_keys=missing_key_names)
def admin_ui_disabled():
    """Return an HTMLResponse telling the operator the Admin UI is disabled.

    Served when the DISABLE_ADMIN_UI flag is active; the page explains which
    env var to flip to re-enable the UI.

    Fix: the page is returned verbatim (no ``str.format`` call), so the CSS
    braces must be single ``{``/``}`` — the previous ``{{``/``}}`` escaping
    leaked doubled braces into the rendered stylesheet and broke it.
    """
    from fastapi.responses import HTMLResponse

    disabled_page_html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}
.container {
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}
h1 {
font-size: 24px;
margin-bottom: 20px;
}
pre {
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}
.env-var {
font-weight: normal;
}
.comment {
font-weight: normal;
color: #777;
}
</style>
<title>Admin UI Disabled</title>
</head>
<body>
<div class="container">
<h1>Admin UI is Disabled</h1>
<p>The Admin UI has been disabled by the administrator. To re-enable it, please update the following environment variable:</p>
<pre>
<span class="env-var">DISABLE_ADMIN_UI="False"</span> <span class="comment"># Set this to "False" to enable the Admin UI.</span>
</pre>
<p>After making this change, restart the application for it to take effect.</p>
</div>
<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
    return HTMLResponse(
        content=disabled_page_html,
        status_code=200,
    )

View File

@@ -0,0 +1,17 @@
# LiteLLM ASCII banner
LITELLM_BANNER = """ ██╗ ██╗████████╗███████╗██╗ ██╗ ███╗ ███╗
██║ ██║╚══██╔══╝██╔════╝██║ ██║ ████╗ ████║
██║ ██║ ██║ █████╗ ██║ ██║ ██╔████╔██║
██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║
███████╗██║ ██║ ███████╗███████╗███████╗██║ ╚═╝ ██║
╚══════╝╚═╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝"""


def show_banner():
    """Display the LiteLLM CLI banner on stdout.

    Prefers ``click.echo``; when click is not installed, falls back to the
    built-in ``print`` so the banner is still shown (previously the fallback
    printed only a blank line and dropped the banner entirely).
    """
    try:
        import click

        click.echo(f"\n{LITELLM_BANNER}\n")
    except ImportError:
        # click unavailable — still show the banner.
        print(f"\n{LITELLM_BANNER}\n")  # noqa: T201

View File

@@ -0,0 +1,190 @@
"""
Event-driven cache coordinator to prevent cache stampede.
Use this when many requests can miss the same cache key at once (e.g. after
expiry or restart). Without coordination, they would all run the expensive
load (DB query, API call) in parallel and overload the backend.
This module ensures only one request performs the load; the rest wait for a
signal and then read the freshly cached value. Reuse it for any cache-aside
pattern: global spend, feature flags, config, or other shared read-through data.
"""
import asyncio
import time
from typing import Any, Awaitable, Callable, Optional, Protocol, TypeVar
from litellm._logging import verbose_proxy_logger
T = TypeVar("T")
class AsyncCacheProtocol(Protocol):
    """Protocol for cache backends used by EventDrivenCacheCoordinator.

    Any object exposing these two async methods satisfies the protocol
    structurally; no inheritance is required.
    """

    async def async_get_cache(self, key: str, **kwargs: Any) -> Any:
        # Expected to return the cached value, or None on a miss (the
        # coordinator treats a None result as a cache miss).
        ...

    async def async_set_cache(self, key: str, value: Any, **kwargs: Any) -> Any:
        # Stores ``value`` under ``key``.
        ...
class EventDrivenCacheCoordinator:
    """
    Serializes cache-miss loads so only one request runs the expensive load.

    Pattern:
    - The first request that misses becomes the "loader": it runs the load
      (e.g. a DB query), writes the result to the cache, then wakes waiters.
    - Every other concurrent request waits for that wake-up, then re-reads
      the freshly populated cache.

    Create one instance per logical resource (e.g. one for global spend,
    another for feature flags).
    """

    def __init__(self, log_prefix: str = "[CACHE]"):
        self._lock = asyncio.Lock()  # guards the loader-election state below
        self._event: Optional[asyncio.Event] = None  # set => waiters may re-read cache
        self._query_in_progress = False  # True while a loader is running
        self._log_prefix = log_prefix  # falsy prefix silences all debug logging

    def _debug(self, fmt: str, *args: Any) -> None:
        """Guarded debug log; each format string starts with %s for the prefix."""
        if self._log_prefix:
            verbose_proxy_logger.debug(fmt, self._log_prefix, *args)

    async def _get_cached(
        self, cache_key: str, cache: AsyncCacheProtocol
    ) -> Optional[Any]:
        """Return value from cache if present, else None."""
        return await cache.async_get_cache(key=cache_key)

    def _log_cache_hit(self, value: T) -> None:
        self._debug("%s Cache hit, value: %s", value)

    def _log_cache_miss(self) -> None:
        self._debug("%s Cache miss")

    async def _claim_role(self) -> Optional[asyncio.Event]:
        """
        Elect this caller's role under the lock.

        Returns the event to wait on when another request is already loading;
        otherwise marks this caller as the loader and returns None.
        """
        async with self._lock:
            if self._query_in_progress and self._event is not None:
                self._debug("%s Load in flight, waiting for signal")
                return self._event
            self._query_in_progress = True
            self._event = asyncio.Event()
            self._debug("%s Starting load (will signal others when done)")
            return None

    async def _wait_for_signal_and_get(
        self,
        event: asyncio.Event,
        cache_key: str,
        cache: AsyncCacheProtocol,
    ) -> Optional[T]:
        """Block until the loader signals, then read the cache."""
        await event.wait()
        self._debug("%s Signal received, reading from cache")
        value: Optional[T] = await cache.async_get_cache(key=cache_key)
        if value is not None:
            self._debug("%s Cache filled by other request, value: %s", value)
        else:
            # Loader may have failed or cached nothing; caller handles None.
            self._debug("%s Signal received but cache still empty")
        return value

    async def _load_and_cache(
        self,
        cache_key: str,
        cache: AsyncCacheProtocol,
        load_fn: Callable[[], Awaitable[T]],
    ) -> Optional[T]:
        """
        Double-check the cache, run load_fn, store and return the result.
        Caller must invoke _signal_done in a finally block.
        """
        value = await cache.async_get_cache(key=cache_key)
        if value is not None:
            # Another request populated the cache between our miss and now.
            self._debug("%s Cache filled while acquiring lock, value: %s", value)
            return value
        self._debug("%s Running load")
        started = time.perf_counter()
        value = await load_fn()
        self._debug(
            "%s Load completed in %.2fms, result: %s",
            (time.perf_counter() - started) * 1000,
            value,
        )
        await cache.async_set_cache(key=cache_key, value=value)
        self._debug("%s Result cached")
        return value

    async def _signal_done(self) -> None:
        """Clear loader state and wake every waiting request."""
        async with self._lock:
            self._query_in_progress = False
            if self._event is not None:
                self._debug("%s Signaling all waiting requests")
                self._event.set()
                self._event = None

    async def get_or_load(
        self,
        cache_key: str,
        cache: AsyncCacheProtocol,
        load_fn: Callable[[], Awaitable[T]],
    ) -> Optional[T]:
        """
        Return cached value or load it once and signal waiters.

        - cache_key: Key to read/write in the cache.
        - cache: Object with async_get_cache(key) and async_set_cache(key, value).
        - load_fn: Async callable that performs the load (e.g. DB query). No args.

        If load_fn raises, waiters are still signaled so they can retry or
        handle the empty cache. Returns the value from cache or from load_fn,
        or None if the load failed or the cache was still empty after waiting.
        """
        cached = await self._get_cached(cache_key, cache)
        if cached is not None:
            self._log_cache_hit(cached)
            return cached
        self._log_cache_miss()

        wait_event = await self._claim_role()
        if wait_event is not None:
            return await self._wait_for_signal_and_get(wait_event, cache_key, cache)

        try:
            return await self._load_and_cache(cache_key, cache, load_fn)
        finally:
            # Always wake waiters, even when load_fn raised.
            await self._signal_done()

View File

@@ -0,0 +1,526 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.types_utils.utils import get_instance_fn
from litellm.types.utils import (
StandardLoggingGuardrailInformation,
StandardLoggingPayload,
)
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
def initialize_callbacks_on_proxy(  # noqa: PLR0915
    value: Any,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
    callback_specific_params: Optional[dict] = None,
):
    """
    Resolve the configured callbacks into instances and register them on
    ``litellm.callbacks``.

    Args:
        value: Either a list of callback names / CustomLogger instances, or a
            single importable-path string for a custom callback.
        premium_user: Whether enterprise-only callbacks may be enabled.
        config_file_path: Proxy config path, used to resolve custom callback
            imports relative to it.
        litellm_settings: The full ``litellm_settings`` block from the config.
        callback_specific_params: Optional per-callback init params keyed by
            callback name.

    Raises:
        Exception: When an enterprise-only callback is requested without the
            enterprise package installed or without premium access.
    """
    from litellm.integrations.custom_logger import CustomLogger
    from litellm.litellm_core_utils.logging_callback_manager import (
        LoggingCallbackManager,
    )
    from litellm.proxy.proxy_server import prisma_client

    # Avoid a shared mutable default argument.
    if callback_specific_params is None:
        callback_specific_params = {}

    verbose_proxy_logger.debug(
        f"{blue_color_code}initializing callbacks={value} on proxy{reset_color_code}"
    )
    if isinstance(value, list):
        imported_list: List[Any] = []
        for callback in value:  # ["presidio", <my-custom-callback>]
            # check if callback is a custom logger compatible callback
            if isinstance(callback, str):
                callback = LoggingCallbackManager._add_custom_callback_generic_api_str(
                    callback
                )
            if (
                isinstance(callback, str)
                and callback in litellm._known_custom_logger_compatible_callbacks
            ):
                imported_list.append(callback)
            elif isinstance(callback, str) and callback == "presidio":
                from litellm.proxy.guardrails.guardrail_hooks.presidio import (
                    _OPTIONAL_PresidioPIIMasking,
                )

                presidio_logging_only: Optional[bool] = litellm_settings.get(
                    "presidio_logging_only", None
                )
                if presidio_logging_only is not None:
                    presidio_logging_only = bool(
                        presidio_logging_only
                    )  # validate boolean given

                _presidio_params = {}
                if "presidio" in callback_specific_params and isinstance(
                    callback_specific_params["presidio"], dict
                ):
                    _presidio_params = callback_specific_params["presidio"]

                params: Dict[str, Any] = {
                    "logging_only": presidio_logging_only,
                    **_presidio_params,
                }
                pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
                imported_list.append(pii_masking_object)
            elif isinstance(callback, str) and callback == "llamaguard_moderations":
                try:
                    from litellm_enterprise.enterprise_callbacks.llama_guard import (
                        _ENTERPRISE_LlamaGuard,
                    )
                except ImportError:
                    # fixed garbled message ("MissingTrying to use Llama Guard")
                    raise Exception(
                        "Trying to use Llama Guard"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Llama Guard"
                        + CommonProxyErrors.not_premium_user.value
                    )

                llama_guard_object = _ENTERPRISE_LlamaGuard()
                imported_list.append(llama_guard_object)
            elif isinstance(callback, str) and callback == "hide_secrets":
                try:
                    from litellm_enterprise.enterprise_callbacks.secret_detection import (
                        _ENTERPRISE_SecretDetection,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Secret Detection"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use secret hiding"
                        + CommonProxyErrors.not_premium_user.value
                    )

                _secret_detection_object = _ENTERPRISE_SecretDetection()
                imported_list.append(_secret_detection_object)
            elif isinstance(callback, str) and callback == "openai_moderations":
                try:
                    from enterprise.enterprise_hooks.openai_moderation import (
                        _ENTERPRISE_OpenAI_Moderation,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use OpenAI Moderations Check,"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use OpenAI Moderations Check"
                        + CommonProxyErrors.not_premium_user.value
                    )

                openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
                imported_list.append(openai_moderations_object)
            elif isinstance(callback, str) and callback == "lakera_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
                    lakeraAI_Moderation,
                )

                init_params = {}
                if "lakera_prompt_injection" in callback_specific_params:
                    init_params = callback_specific_params["lakera_prompt_injection"]
                lakera_moderations_object = lakeraAI_Moderation(**init_params)
                imported_list.append(lakera_moderations_object)
            elif isinstance(callback, str) and callback == "aporia_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.aporia_ai.aporia_ai import (
                    AporiaGuardrail,
                )

                aporia_guardrail_object = AporiaGuardrail()
                imported_list.append(aporia_guardrail_object)
            elif isinstance(callback, str) and callback == "google_text_moderation":
                try:
                    from enterprise.enterprise_hooks.google_text_moderation import (
                        _ENTERPRISE_GoogleTextModeration,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Google Text Moderation,"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Google Text Moderation"
                        + CommonProxyErrors.not_premium_user.value
                    )

                google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
                imported_list.append(google_text_moderation_obj)
            elif isinstance(callback, str) and callback == "llmguard_moderations":
                try:
                    from litellm_enterprise.enterprise_callbacks.llm_guard import (
                        _ENTERPRISE_LLMGuard,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Llm Guard"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Llm Guard"
                        + CommonProxyErrors.not_premium_user.value
                    )

                llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
                imported_list.append(llm_guard_moderation_obj)
            elif isinstance(callback, str) and callback == "blocked_user_check":
                try:
                    from enterprise.enterprise_hooks.blocked_user_list import (
                        _ENTERPRISE_BlockedUserList,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Blocked User List"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BlockedUser"
                        + CommonProxyErrors.not_premium_user.value
                    )

                blocked_user_list = _ENTERPRISE_BlockedUserList(
                    prisma_client=prisma_client
                )
                imported_list.append(blocked_user_list)
            elif isinstance(callback, str) and callback == "banned_keywords":
                try:
                    from enterprise.enterprise_hooks.banned_keywords import (
                        _ENTERPRISE_BannedKeywords,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Banned Keywords"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BannedKeyword"
                        + CommonProxyErrors.not_premium_user.value
                    )

                banned_keywords_obj = _ENTERPRISE_BannedKeywords()
                imported_list.append(banned_keywords_obj)
            elif isinstance(callback, str) and callback == "detect_prompt_injection":
                from litellm.proxy.hooks.prompt_injection_detection import (
                    _OPTIONAL_PromptInjectionDetection,
                )

                prompt_injection_params = None
                if "prompt_injection_params" in litellm_settings:
                    prompt_injection_params_in_config = litellm_settings[
                        "prompt_injection_params"
                    ]
                    prompt_injection_params = LiteLLMPromptInjectionParams(
                        **prompt_injection_params_in_config
                    )

                prompt_injection_detection_obj = _OPTIONAL_PromptInjectionDetection(
                    prompt_injection_params=prompt_injection_params,
                )
                imported_list.append(prompt_injection_detection_obj)
            elif isinstance(callback, str) and callback == "batch_redis_requests":
                from litellm.proxy.hooks.batch_redis_get import (
                    _PROXY_BatchRedisRequests,
                )

                batch_redis_obj = _PROXY_BatchRedisRequests()
                imported_list.append(batch_redis_obj)
            elif isinstance(callback, str) and callback == "azure_content_safety":
                from litellm.proxy.hooks.azure_content_safety import (
                    _PROXY_AzureContentSafety,
                )

                azure_content_safety_params = litellm_settings[
                    "azure_content_safety_params"
                ]
                # resolve "os.environ/" references to actual secret values
                for k, v in azure_content_safety_params.items():
                    if (
                        v is not None
                        and isinstance(v, str)
                        and v.startswith("os.environ/")
                    ):
                        azure_content_safety_params[k] = get_secret(v)

                azure_content_safety_obj = _PROXY_AzureContentSafety(
                    **azure_content_safety_params,
                )
                imported_list.append(azure_content_safety_obj)
            elif isinstance(callback, str) and callback == "websearch_interception":
                from litellm.integrations.websearch_interception.handler import (
                    WebSearchInterceptionLogger,
                )

                websearch_interception_obj = (
                    WebSearchInterceptionLogger.initialize_from_proxy_config(
                        litellm_settings=litellm_settings,
                        callback_specific_params=callback_specific_params,
                    )
                )
                imported_list.append(websearch_interception_obj)
            elif isinstance(callback, str) and callback == "datadog_cost_management":
                from litellm.integrations.datadog.datadog_cost_management import (
                    DatadogCostManagementLogger,
                )

                datadog_cost_management_obj = DatadogCostManagementLogger()
                imported_list.append(datadog_cost_management_obj)
            elif isinstance(callback, CustomLogger):
                imported_list.append(callback)
            else:
                verbose_proxy_logger.debug(
                    f"{blue_color_code} attempting to import custom callback={callback} {reset_color_code}"
                )
                imported_list.append(
                    get_instance_fn(
                        value=callback,
                        config_file_path=config_file_path,
                    )
                )
        if isinstance(litellm.callbacks, list):
            litellm.callbacks.extend(imported_list)
        else:
            litellm.callbacks = imported_list  # type: ignore
        if "prometheus" in value:
            from litellm.integrations.prometheus import PrometheusLogger

            PrometheusLogger._mount_metrics_endpoint()
    else:
        # single importable-path string
        litellm.callbacks = [
            get_instance_fn(
                value=value,
                config_file_path=config_file_path,
            )
        ]
    verbose_proxy_logger.debug(
        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
    )
def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
    """Return the model_group recorded in litellm_params metadata, if any."""
    litellm_params = kwargs.get("litellm_params", None) or {}
    # Inlined get_metadata_variable_name_from_kwargs: newer endpoints store
    # metadata under "litellm_metadata", older ones under "metadata".
    metadata_key = "litellm_metadata" if "litellm_metadata" in kwargs else "metadata"
    metadata = litellm_params.get(metadata_key) or {}
    return metadata.get("model_group", None)
def get_model_group_from_request_data(data: dict) -> Optional[str]:
    """Return metadata["model_group"] from the request data, or None."""
    metadata = data.get("metadata", None) or {}
    return metadata.get("model_group", None)
def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]:
    """
    Build x-litellm-key-remaining-{requests,tokens}-{model_group} headers
    from request metadata.

    Returns {} when no api_key + model rpm/tpm limit was recorded.
    """
    metadata = data.get("metadata", None) or {}
    # Inlined get_model_group_from_request_data.
    model_group = metadata.get("model_group", None)
    # h11 raises LocalProtocolError for "/" or ":" in header names, so
    # sanitize the model group before embedding it in a header name.
    safe_group = (
        model_group.replace("/", "-").replace(":", "-") if model_group else None
    )

    headers: Dict[str, str] = {}

    remaining_requests = metadata.get(
        f"litellm-key-remaining-requests-{model_group}", None
    )
    if remaining_requests:
        headers[f"x-litellm-key-remaining-requests-{safe_group}"] = remaining_requests

    remaining_tokens = metadata.get(f"litellm-key-remaining-tokens-{model_group}", None)
    if remaining_tokens:
        headers[f"x-litellm-key-remaining-tokens-{safe_group}"] = remaining_tokens

    return headers
def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
    """
    Build x-litellm-* response headers from request metadata.

    Reads "metadata" first, falling back to "litellm_metadata"; emits
    guardrail, policy, semantic-similarity and pillar headers when present.
    """
    meta = request_data.get("metadata", None) or request_data.get(
        "litellm_metadata", None
    )
    if not isinstance(meta, dict):
        meta = {}

    headers: Dict = {}
    if "applied_guardrails" in meta:
        headers["x-litellm-applied-guardrails"] = ",".join(meta["applied_guardrails"])

    if "applied_policies" in meta:
        headers["x-litellm-applied-policies"] = ",".join(meta["applied_policies"])

    if "policy_sources" in meta:
        sources = meta["policy_sources"]
        if isinstance(sources, dict) and sources:
            # ';' delimiter — matched_via reasons may contain commas
            headers["x-litellm-policy-sources"] = "; ".join(
                f"{policy}={via}" for policy, via in sources.items()
            )

    if "semantic-similarity" in meta:
        headers["x-litellm-semantic-similarity"] = str(meta["semantic-similarity"])

    pillar = meta.get("pillar_response_headers")
    if isinstance(pillar, dict):
        headers.update(pillar)
    elif "pillar_flagged" in meta:
        headers["x-pillar-flagged"] = str(meta["pillar_flagged"]).lower()

    return headers
def add_guardrail_to_applied_guardrails_header(
    request_data: Dict, guardrail_name: Optional[str]
):
    """Record guardrail_name in request metadata's applied_guardrails list."""
    if guardrail_name is None:
        return
    meta = request_data.get("metadata", None) or {}
    meta.setdefault("applied_guardrails", []).append(guardrail_name)
    # Write metadata back — required when it did not exist before.
    request_data["metadata"] = meta
def add_policy_to_applied_policies_header(
    request_data: Dict, policy_name: Optional[str]
):
    """
    Record policy_name (deduplicated) in request metadata's applied_policies
    list — the policy analogue of applied_guardrails tracking.
    """
    if policy_name is None:
        return
    meta = request_data.get("metadata", None) or {}
    applied = meta.setdefault("applied_policies", [])
    if policy_name not in applied:
        applied.append(policy_name)
    # Write metadata back — required when it did not exist before.
    request_data["metadata"] = meta
def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str, str]):
    """
    Merge policy match reasons into metadata for the
    x-litellm-policy-sources header.

    Args:
        request_data: The request data dict
        policy_sources: Map of policy_name -> matched_via reason
    """
    if not policy_sources:
        return
    meta = request_data.get("metadata", None) or {}
    merged = meta.get("policy_sources", {})
    if not isinstance(merged, dict):
        merged = {}
    merged.update(policy_sources)
    meta["policy_sources"] = merged
    request_data["metadata"] = meta
def add_guardrail_response_to_standard_logging_object(
    litellm_logging_obj: Optional["LiteLLMLogging"],
    guardrail_response: StandardLoggingGuardrailInformation,
):
    """
    Append guardrail_response to the standard logging payload's
    guardrail_information list.

    Returns the updated payload, or None when no logging object / payload
    is available.
    """
    if litellm_logging_obj is None:
        return
    payload: Optional[StandardLoggingPayload] = (
        litellm_logging_obj.model_call_details.get("standard_logging_object")
    )
    if payload is None:
        return
    entries = payload.get("guardrail_information", [])
    if entries is None:
        entries = []
    entries.append(guardrail_response)
    payload["guardrail_information"] = entries
    return payload
def get_metadata_variable_name_from_kwargs(
    kwargs: dict,
) -> Literal["metadata", "litellm_metadata"]:
    """
    Return the field name under which request metadata is stored.

    Context: LiteLLM historically used "metadata"; OpenAI later claimed that
    field, so newer endpoints use "litellm_metadata" instead.
    """
    if "litellm_metadata" in kwargs:
        return "litellm_metadata"
    return "metadata"
def process_callback(
    _callback: str, callback_type: str, environment_variables: dict
) -> dict:
    """Describe a callback plus the env vars it reads (value or None each)."""
    env_vars = CustomLogger.get_callback_env_vars(_callback)
    env_vars_dict: dict[str, str | None] = {
        var: environment_variables.get(var, None) for var in env_vars
    }
    return {"name": _callback, "variables": env_vars_dict, "type": callback_type}
def normalize_callback_names(callbacks: Iterable[Any]) -> List[Any]:
    """Lowercase string callback names; pass non-strings through unchanged."""
    if callbacks is None:
        return []
    normalized: List[Any] = []
    for cb in callbacks:
        normalized.append(cb.lower() if isinstance(cb, str) else cb)
    return normalized

View File

@@ -0,0 +1,437 @@
from typing import Any, Dict, List, Optional, Type
from litellm._logging import verbose_proxy_logger
class CustomOpenAPISpec:
"""
Handler for customizing OpenAPI specifications with Pydantic models
for documentation purposes without runtime validation.
"""
CHAT_COMPLETION_PATHS = [
"/v1/chat/completions",
"/chat/completions",
"/engines/{model}/chat/completions",
"/openai/deployments/{model}/chat/completions",
]
EMBEDDING_PATHS = [
"/v1/embeddings",
"/embeddings",
"/engines/{model}/embeddings",
"/openai/deployments/{model}/embeddings",
]
RESPONSES_API_PATHS = ["/v1/responses", "/responses"]
@staticmethod
def get_pydantic_schema(model_class) -> Optional[Dict[str, Any]]:
"""
Get JSON schema from a Pydantic model, handling both v1 and v2 APIs.
Args:
model_class: Pydantic model class
Returns:
JSON schema dict or None if failed
"""
try:
# Try Pydantic v2 method first
return model_class.model_json_schema() # type: ignore
except AttributeError:
try:
# Fallback to Pydantic v1 method
return model_class.schema() # type: ignore
except AttributeError:
# If both methods fail, return None
return None
except Exception as e:
# FastAPI 0.120+ may fail schema generation for certain types (e.g., openai.Timeout)
# Log the error and return None to skip schema generation for this model
verbose_proxy_logger.debug(
f"Failed to generate schema for {model_class}: {e}"
)
return None
@staticmethod
def add_schema_to_components(
openapi_schema: Dict[str, Any], schema_name: str, schema_def: Dict[str, Any]
) -> None:
"""
Add a schema definition to the OpenAPI components/schemas section.
Args:
openapi_schema: The OpenAPI schema dict to modify
schema_name: Name for the schema component
schema_def: The schema definition
"""
# Ensure components/schemas structure exists
if "components" not in openapi_schema:
openapi_schema["components"] = {}
if "schemas" not in openapi_schema["components"]:
openapi_schema["components"]["schemas"] = {}
# Add the schema
CustomOpenAPISpec._move_defs_to_components(
openapi_schema, {schema_name: schema_def}
)
    @staticmethod
    def add_request_body_to_paths(
        openapi_schema: Dict[str, Any], paths: List[str], schema_ref: str
    ) -> None:
        """
        Add request body with expanded form fields for better Swagger UI display.
        This keeps the request body but expands it to show individual fields in the UI.

        Args:
            openapi_schema: The OpenAPI schema dict to modify
            paths: List of paths to update
            schema_ref: Reference to the schema component (e.g., "#/components/schemas/ModelName")
        """
        for path in paths:
            # Only touch POST operations that actually exist in the spec.
            if (
                path in openapi_schema.get("paths", {})
                and "post" in openapi_schema["paths"][path]
            ):
                # Get the actual schema to extract ALL field definitions
                schema_name = schema_ref.split("/")[
                    -1
                ]  # Extract "ProxyChatCompletionRequest" from the ref
                actual_schema = (
                    openapi_schema.get("components", {})
                    .get("schemas", {})
                    .get(schema_name, {})
                )
                schema_properties = actual_schema.get("properties", {})
                required_fields = actual_schema.get("required", [])
                # Extract $defs and add them to components/schemas
                # This fixes Pydantic v2 $defs not being resolvable in Swagger/OpenAPI
                if "$defs" in actual_schema:
                    CustomOpenAPISpec._move_defs_to_components(
                        openapi_schema, actual_schema["$defs"]
                    )
                # Create an expanded inline schema instead of just a $ref
                # This makes Swagger UI show all individual fields in the request body editor
                expanded_schema = {
                    "type": "object",
                    "required": required_fields,
                    "properties": {},
                }
                # Add all properties with their full definitions
                for field_name, field_def in schema_properties.items():
                    expanded_field = CustomOpenAPISpec._expand_field_definition(
                        field_def
                    )
                    # Rewrite $defs references to use components/schemas instead
                    expanded_field = CustomOpenAPISpec._rewrite_defs_refs(
                        expanded_field
                    )
                    # Add a simple example for the messages field
                    if field_name == "messages":
                        expanded_field["example"] = [
                            {"role": "user", "content": "Hello, how are you?"}
                        ]
                    expanded_schema["properties"][field_name] = expanded_field
                # Set the request body with the expanded schema
                openapi_schema["paths"][path]["post"]["requestBody"] = {
                    "required": True,
                    "content": {"application/json": {"schema": expanded_schema}},
                }
                # Keep any existing parameters (like path parameters) but remove conflicting query params
                if "parameters" in openapi_schema["paths"][path]["post"]:
                    existing_params = openapi_schema["paths"][path]["post"][
                        "parameters"
                    ]
                    # Only keep path parameters, remove query params that conflict with request body
                    filtered_params = [
                        param for param in existing_params if param.get("in") == "path"
                    ]
                    openapi_schema["paths"][path]["post"][
                        "parameters"
                    ] = filtered_params
@staticmethod
def _move_defs_to_components(
openapi_schema: Dict[str, Any], defs: Dict[str, Any]
) -> None:
"""
Move $defs from Pydantic v2 schema to OpenAPI components/schemas.
This makes the definitions resolvable in Swagger/OpenAPI viewers.
Args:
openapi_schema: The OpenAPI schema dict to modify
defs: The $defs dictionary from Pydantic schema
"""
if not defs:
return
# Ensure components/schemas exists
if "components" not in openapi_schema:
openapi_schema["components"] = {}
if "schemas" not in openapi_schema["components"]:
openapi_schema["components"]["schemas"] = {}
# Add each definition to components/schemas
for def_name, def_schema in defs.items():
# Recursively rewrite any nested $defs references within this definition
rewritten_def = CustomOpenAPISpec._rewrite_defs_refs(def_schema)
openapi_schema["components"]["schemas"][def_name] = rewritten_def
# If this definition also has $defs, process them recursively
if "$defs" in def_schema:
CustomOpenAPISpec._move_defs_to_components(
openapi_schema, def_schema["$defs"]
)
@staticmethod
def _rewrite_defs_refs(schema: Any) -> Any:
"""
Recursively rewrite $ref values from #/$defs/... to #/components/schemas/...
This converts Pydantic v2 references to OpenAPI-compatible references.
Args:
schema: Schema object to process (can be dict, list, or primitive)
Returns:
Schema with rewritten references
"""
if isinstance(schema, dict):
result = {}
for key, value in schema.items():
if (
key == "$ref"
and isinstance(value, str)
and value.startswith("#/$defs/")
):
# Rewrite the reference to use components/schemas
def_name = value.replace("#/$defs/", "")
result[key] = f"#/components/schemas/{def_name}"
elif key == "$defs":
# Remove $defs from the schema since they're moved to components
continue
else:
# Recursively process nested structures
result[key] = CustomOpenAPISpec._rewrite_defs_refs(value)
return result
elif isinstance(schema, list):
return [CustomOpenAPISpec._rewrite_defs_refs(item) for item in schema]
else:
return schema
@staticmethod
def _extract_field_schema(field_def: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract a simple schema from a Pydantic field definition for parameter display.
Args:
field_def: Pydantic field definition
Returns:
Simplified schema for OpenAPI parameter
"""
# Handle simple types
if "type" in field_def:
return {"type": field_def["type"]}
# Handle anyOf (Optional fields in Pydantic v2)
if "anyOf" in field_def:
any_of = field_def["anyOf"]
# Find the non-null type
for option in any_of:
if option.get("type") != "null":
return option
# Fallback to string if all else fails
return {"type": "string"}
# Default fallback
return {"type": "string"}
@staticmethod
def _expand_field_definition(field_def: Dict[str, Any]) -> Dict[str, Any]:
"""
Expand a Pydantic field definition for inline use in OpenAPI schema.
This creates a full field definition that Swagger UI can render as individual form fields.
Args:
field_def: Pydantic field definition
Returns:
Expanded field definition for OpenAPI schema
"""
# Return the field definition as-is since Pydantic already provides proper schemas
return field_def.copy()
@staticmethod
def add_request_schema(
openapi_schema: Dict[str, Any],
model_class: Type,
schema_name: str,
paths: List[str],
operation_name: str,
) -> Dict[str, Any]:
"""
Generic method to add a request schema to OpenAPI specification.
Args:
openapi_schema: The OpenAPI schema dict to modify
model_class: The Pydantic model class to get schema from
schema_name: Name for the schema component
paths: List of paths to add the request body to
operation_name: Name of the operation for logging (e.g., "chat completion", "embedding")
Returns:
Modified OpenAPI schema
"""
try:
# Get the schema for the model class
request_schema = CustomOpenAPISpec.get_pydantic_schema(model_class)
# Only proceed if we successfully got the schema
if request_schema is not None:
# Add schema to components
CustomOpenAPISpec.add_schema_to_components(
openapi_schema, schema_name, request_schema
)
# Add request body to specified endpoints
CustomOpenAPISpec.add_request_body_to_paths(
openapi_schema, paths, f"#/components/schemas/{schema_name}"
)
verbose_proxy_logger.debug(
f"Successfully added {schema_name} schema to OpenAPI spec"
)
else:
verbose_proxy_logger.debug(f"Could not get schema for {schema_name}")
except Exception as e:
# If schema addition fails, continue without it
verbose_proxy_logger.debug(
f"Failed to add {operation_name} request schema: {str(e)}"
)
return openapi_schema
@staticmethod
def add_chat_completion_request_schema(
openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
"""
Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation.
This shows the request body in Swagger without runtime validation.
Args:
openapi_schema: The OpenAPI schema dict to modify
Returns:
Modified OpenAPI schema
"""
try:
from litellm.proxy._types import ProxyChatCompletionRequest
return CustomOpenAPISpec.add_request_schema(
openapi_schema=openapi_schema,
model_class=ProxyChatCompletionRequest,
schema_name="ProxyChatCompletionRequest",
paths=CustomOpenAPISpec.CHAT_COMPLETION_PATHS,
operation_name="chat completion",
)
except ImportError as e:
verbose_proxy_logger.debug(
f"Failed to import ProxyChatCompletionRequest: {str(e)}"
)
return openapi_schema
@staticmethod
def add_embedding_request_schema(openapi_schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Document the embedding request body in Swagger.

    Attaches the EmbeddingRequest schema to the embedding paths for display
    purposes only - no runtime validation is introduced.

    Args:
        openapi_schema: The OpenAPI schema dict to modify

    Returns:
        Modified OpenAPI schema (unchanged if the model cannot be imported)
    """
    try:
        from litellm.types.embedding import EmbeddingRequest
    except ImportError as e:
        verbose_proxy_logger.debug(f"Failed to import EmbeddingRequest: {str(e)}")
        return openapi_schema

    return CustomOpenAPISpec.add_request_schema(
        openapi_schema=openapi_schema,
        model_class=EmbeddingRequest,
        schema_name="EmbeddingRequest",
        paths=CustomOpenAPISpec.EMBEDDING_PATHS,
        operation_name="embedding",
    )
@staticmethod
def add_responses_api_request_schema(
    openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Document the responses API request body in Swagger.

    Attaches the ResponsesAPIRequestParams schema to the responses API paths
    for display purposes only - no runtime validation is introduced.

    Args:
        openapi_schema: The OpenAPI schema dict to modify

    Returns:
        Modified OpenAPI schema (unchanged if the model cannot be imported)
    """
    try:
        from litellm.types.llms.openai import ResponsesAPIRequestParams
    except ImportError as e:
        verbose_proxy_logger.debug(
            f"Failed to import ResponsesAPIRequestParams: {str(e)}"
        )
        return openapi_schema

    return CustomOpenAPISpec.add_request_schema(
        openapi_schema=openapi_schema,
        model_class=ResponsesAPIRequestParams,
        schema_name="ResponsesAPIRequestParams",
        paths=CustomOpenAPISpec.RESPONSES_API_PATHS,
        operation_name="responses API",
    )
@staticmethod
def add_llm_api_request_schema_body(
    openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Attach all LLM API request body schemas to the OpenAPI spec for docs.

    Runs the chat completion, embedding, and responses API schema helpers
    in sequence; each one is best-effort and leaves the schema untouched
    on failure.

    Args:
        openapi_schema: The base OpenAPI schema

    Returns:
        OpenAPI schema with added request body schemas
    """
    for add_schema in (
        CustomOpenAPISpec.add_chat_completion_request_schema,
        CustomOpenAPISpec.add_embedding_request_schema,
        CustomOpenAPISpec.add_responses_api_request_schema,
    ):
        openapi_schema = add_schema(openapi_schema)
    return openapi_schema

View File

@@ -0,0 +1,832 @@
# Start tracing memory allocations
import asyncio
import gc
import json
import os
import sys
import tracemalloc
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, Query
from litellm import get_secret_str
from litellm._logging import verbose_proxy_logger
from litellm.constants import PYTHON_GC_THRESHOLD
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
# Configure garbage collection thresholds from environment variables
def configure_gc_thresholds():
    """Apply Python garbage-collection thresholds from PYTHON_GC_THRESHOLD.

    Expects a comma-separated triple such as "1000,50,50" (gen0, gen1, gen2).
    Malformed values are logged as warnings and ignored. The thresholds in
    effect afterwards are always logged.
    """
    gc_threshold_env = PYTHON_GC_THRESHOLD
    if gc_threshold_env:
        try:
            # Parse threshold string like "1000,50,50"
            parsed = [int(part.strip()) for part in gc_threshold_env.split(",")]
        except ValueError as e:
            verbose_proxy_logger.warning(
                f"Failed to parse GC threshold: {gc_threshold_env}. Error: {e}"
            )
        else:
            if len(parsed) == 3:
                gc.set_threshold(*parsed)
                verbose_proxy_logger.info(f"GC thresholds set to: {parsed}")
            else:
                verbose_proxy_logger.warning(
                    f"GC threshold not set: {gc_threshold_env}. Expected format: 'gen0,gen1,gen2'"
                )

    # Log current thresholds
    gen0, gen1, gen2 = gc.get_threshold()
    verbose_proxy_logger.info(
        f"Current GC thresholds: gen0={gen0}, gen1={gen1}, gen2={gen2}"
    )


# Initialize GC configuration at import time
configure_gc_thresholds()
@router.get("/debug/asyncio-tasks")
async def get_active_tasks_stats():
"""
Returns:
total_active_tasks: int
by_name: { coroutine_name: count }
"""
MAX_TASKS_TO_CHECK = 5000
# Gather all tasks in this event loop (including this endpoints own task).
all_tasks = asyncio.all_tasks()
# Filter out tasks that are already done.
active_tasks = [t for t in all_tasks if not t.done()]
# Count how many active tasks exist, grouped by coroutine function name.
counter = Counter()
for idx, task in enumerate(active_tasks):
# reasonable max circuit breaker
if idx >= MAX_TASKS_TO_CHECK:
break
coro = task.get_coro()
# Derive a humanreadable name from the coroutine:
name = (
getattr(coro, "__qualname__", None)
or getattr(coro, "__name__", None)
or repr(coro)
)
counter[name] += 1
return {
"total_active_tasks": len(active_tasks),
"by_name": dict(counter),
}
if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
try:
import objgraph # type: ignore
print("growth of objects") # noqa
objgraph.show_growth()
print("\n\nMost common types") # noqa
objgraph.show_most_common_types()
roots = objgraph.get_leaking_objects()
print("\n\nLeaking objects") # noqa
objgraph.show_most_common_types(objects=roots)
except ImportError:
raise ImportError(
"objgraph not found. Please install objgraph to use this feature."
)
tracemalloc.start(10)
@router.get("/memory-usage", include_in_schema=False)
async def memory_usage():
# Take a snapshot of the current memory usage
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics("lineno")
verbose_proxy_logger.debug("TOP STATS: %s", top_stats)
# Get the top 50 memory usage lines
top_50 = top_stats[:50]
result = []
for stat in top_50:
result.append(f"{stat.traceback.format(limit=10)}: {stat.size / 1024} KiB")
return {"top_50_memory_usage": result}
@router.get("/memory-usage-in-mem-cache", include_in_schema=False)
async def memory_usage_in_mem_cache(
_: UserAPIKeyAuth = Depends(user_api_key_auth),
):
# returns the size of all in-memory caches on the proxy server
"""
1. user_api_key_cache
2. router_cache
3. proxy_logging_cache
4. internal_usage_cache
"""
from litellm.proxy.proxy_server import (
llm_router,
proxy_logging_obj,
user_api_key_cache,
)
if llm_router is None:
num_items_in_llm_router_cache = 0
else:
num_items_in_llm_router_cache = len(
llm_router.cache.in_memory_cache.cache_dict
) + len(llm_router.cache.in_memory_cache.ttl_dict)
num_items_in_user_api_key_cache = len(
user_api_key_cache.in_memory_cache.cache_dict
) + len(user_api_key_cache.in_memory_cache.ttl_dict)
num_items_in_proxy_logging_obj_cache = len(
proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
) + len(proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict)
return {
"num_items_in_user_api_key_cache": num_items_in_user_api_key_cache,
"num_items_in_llm_router_cache": num_items_in_llm_router_cache,
"num_items_in_proxy_logging_obj_cache": num_items_in_proxy_logging_obj_cache,
}
@router.get("/memory-usage-in-mem-cache-items", include_in_schema=False)
async def memory_usage_in_mem_cache_items(
_: UserAPIKeyAuth = Depends(user_api_key_auth),
):
# returns the size of all in-memory caches on the proxy server
"""
1. user_api_key_cache
2. router_cache
3. proxy_logging_cache
4. internal_usage_cache
"""
from litellm.proxy.proxy_server import (
llm_router,
proxy_logging_obj,
user_api_key_cache,
)
if llm_router is None:
llm_router_in_memory_cache_dict = {}
llm_router_in_memory_ttl_dict = {}
else:
llm_router_in_memory_cache_dict = llm_router.cache.in_memory_cache.cache_dict
llm_router_in_memory_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict
return {
"user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict,
"user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict,
"llm_router_cache": llm_router_in_memory_cache_dict,
"llm_router_ttl": llm_router_in_memory_ttl_dict,
"proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
"proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict,
}
@router.get("/debug/memory/summary", include_in_schema=False)
async def get_memory_summary(
_: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> Dict[str, Any]:
"""
Get simplified memory usage summary for the proxy.
Returns:
- worker_pid: Process ID
- status: Overall health based on memory usage
- memory: Process memory usage and RAM info
- caches: Cache item counts and descriptions
- garbage_collector: GC status and pending object counts
Example usage:
curl http://localhost:4000/debug/memory/summary -H "Authorization: Bearer sk-1234"
For detailed analysis, call GET /debug/memory/details
For cache management, use the cache management endpoints
"""
from litellm.proxy.proxy_server import (
llm_router,
proxy_logging_obj,
user_api_key_cache,
)
# Get process memory info
process_memory = {}
health_status = "healthy"
try:
import psutil
process = psutil.Process()
memory_info = process.memory_info()
memory_mb = memory_info.rss / (1024 * 1024)
memory_percent = process.memory_percent()
process_memory = {
"summary": f"{memory_mb:.1f} MB ({memory_percent:.1f}% of system memory)",
"ram_usage_mb": round(memory_mb, 2),
"system_memory_percent": round(memory_percent, 2),
}
# Check memory health status
if memory_percent > 80:
health_status = "critical"
elif memory_percent > 60:
health_status = "warning"
else:
health_status = "healthy"
except ImportError:
process_memory[
"error"
] = "Install psutil for memory monitoring: pip install psutil"
except Exception as e:
process_memory["error"] = str(e)
# Get cache information
caches: Dict[str, Any] = {}
total_cache_items = 0
try:
# User API key cache
user_cache_items = len(user_api_key_cache.in_memory_cache.cache_dict)
total_cache_items += user_cache_items
caches["user_api_keys"] = {
"count": user_cache_items,
"count_readable": f"{user_cache_items:,}",
"what_it_stores": "Validated API keys for faster authentication",
}
# Router cache
if llm_router is not None:
router_cache_items = len(llm_router.cache.in_memory_cache.cache_dict)
total_cache_items += router_cache_items
caches["llm_responses"] = {
"count": router_cache_items,
"count_readable": f"{router_cache_items:,}",
"what_it_stores": "LLM responses for identical requests",
}
# Proxy logging cache
logging_cache_items = len(
proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
)
total_cache_items += logging_cache_items
caches["usage_tracking"] = {
"count": logging_cache_items,
"count_readable": f"{logging_cache_items:,}",
"what_it_stores": "Usage metrics before database write",
}
except Exception as e:
caches["error"] = str(e)
# Get garbage collector stats
gc_enabled = gc.isenabled()
objects_pending = gc.get_count()[0]
uncollectable = len(gc.garbage)
gc_info = {
"status": "enabled" if gc_enabled else "disabled",
"objects_awaiting_collection": objects_pending,
}
# Add warning if garbage collection issues detected
if uncollectable > 0:
gc_info[
"warning"
] = f"{uncollectable} uncollectable objects (possible memory leak)"
return {
"worker_pid": os.getpid(),
"status": health_status,
"memory": process_memory,
"caches": {
"total_items": total_cache_items,
"breakdown": caches,
},
"garbage_collector": gc_info,
}
def _get_gc_statistics() -> Dict[str, Any]:
"""Get garbage collector statistics."""
return {
"enabled": gc.isenabled(),
"thresholds": {
"generation_0": gc.get_threshold()[0],
"generation_1": gc.get_threshold()[1],
"generation_2": gc.get_threshold()[2],
"explanation": "Number of allocations before automatic collection for each generation",
},
"current_counts": {
"generation_0": gc.get_count()[0],
"generation_1": gc.get_count()[1],
"generation_2": gc.get_count()[2],
"explanation": "Current number of allocated objects in each generation",
},
"collection_history": [
{
"generation": i,
"total_collections": stat["collections"],
"total_collected": stat["collected"],
"uncollectable": stat["uncollectable"],
}
for i, stat in enumerate(gc.get_stats())
],
}
def _get_object_type_counts(top_n: int) -> Tuple[int, List[Dict[str, Any]]]:
"""Count objects by type and return total count and top N types."""
type_counts: Counter = Counter()
total_objects = 0
for obj in gc.get_objects():
total_objects += 1
obj_type = type(obj).__name__
type_counts[obj_type] += 1
top_object_types = [
{"type": obj_type, "count": count, "count_readable": f"{count:,}"}
for obj_type, count in type_counts.most_common(top_n)
]
return total_objects, top_object_types
def _get_uncollectable_objects_info() -> Dict[str, Any]:
"""Get information about uncollectable objects (potential memory leaks)."""
uncollectable = gc.garbage
return {
"count": len(uncollectable),
"sample_types": [type(obj).__name__ for obj in uncollectable[:10]],
"warning": "If count > 0, you may have reference cycles preventing garbage collection"
if len(uncollectable) > 0
else None,
}
def _get_cache_memory_stats(
    user_api_key_cache, llm_router, proxy_logging_obj, redis_usage_cache
) -> Dict[str, Any]:
    """Calculate memory usage for all caches.

    Sizes come from sys.getsizeof on the cache dicts, which is shallow:
    it measures the dict object itself, not the keys/values it references,
    so the reported MB understates true usage.

    Args:
        user_api_key_cache: cache of validated API keys (has .in_memory_cache).
        llm_router: the proxy's router instance; may be None before init.
        proxy_logging_obj: owns the internal usage cache (dual cache).
        redis_usage_cache: optional Redis-backed cache; only metadata reported.

    Returns:
        Mapping of cache name -> item counts and byte/MB sizes; on failure an
        "error" key is set instead of raising.
    """
    cache_stats: Dict[str, Any] = {}
    try:
        # User API key cache
        user_cache_size = sys.getsizeof(user_api_key_cache.in_memory_cache.cache_dict)
        user_ttl_size = sys.getsizeof(user_api_key_cache.in_memory_cache.ttl_dict)
        cache_stats["user_api_key_cache"] = {
            "num_items": len(user_api_key_cache.in_memory_cache.cache_dict),
            "cache_dict_size_bytes": user_cache_size,
            "ttl_dict_size_bytes": user_ttl_size,
            "total_size_mb": round(
                (user_cache_size + user_ttl_size) / (1024 * 1024), 2
            ),
        }
        # Router cache - only present once the router is initialized
        if llm_router is not None:
            router_cache_size = sys.getsizeof(
                llm_router.cache.in_memory_cache.cache_dict
            )
            router_ttl_size = sys.getsizeof(llm_router.cache.in_memory_cache.ttl_dict)
            cache_stats["llm_router_cache"] = {
                "num_items": len(llm_router.cache.in_memory_cache.cache_dict),
                "cache_dict_size_bytes": router_cache_size,
                "ttl_dict_size_bytes": router_ttl_size,
                "total_size_mb": round(
                    (router_cache_size + router_ttl_size) / (1024 * 1024), 2
                ),
            }
        # Proxy logging cache (in-memory half of the dual cache)
        logging_cache_size = sys.getsizeof(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
        )
        logging_ttl_size = sys.getsizeof(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict
        )
        cache_stats["proxy_logging_cache"] = {
            "num_items": len(
                proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
            ),
            "cache_dict_size_bytes": logging_cache_size,
            "ttl_dict_size_bytes": logging_ttl_size,
            "total_size_mb": round(
                (logging_cache_size + logging_ttl_size) / (1024 * 1024), 2
            ),
        }
        # Redis cache info - only connection metadata, no Redis-side memory stats
        if redis_usage_cache is not None:
            cache_stats["redis_usage_cache"] = {
                "enabled": True,
                "cache_type": type(redis_usage_cache).__name__,
            }
            # Try to get Redis connection pool info if available (best-effort)
            try:
                if (
                    hasattr(redis_usage_cache, "redis_client")
                    and redis_usage_cache.redis_client
                ):
                    if hasattr(redis_usage_cache.redis_client, "connection_pool"):
                        pool_info = redis_usage_cache.redis_client.connection_pool  # type: ignore
                        cache_stats["redis_usage_cache"]["connection_pool"] = {
                            "max_connections": pool_info.max_connections
                            if hasattr(pool_info, "max_connections")
                            else None,
                            "connection_class": pool_info.connection_class.__name__
                            if hasattr(pool_info, "connection_class")
                            else None,
                        }
            except Exception as e:
                verbose_proxy_logger.debug(f"Error getting Redis pool info: {e}")
        else:
            cache_stats["redis_usage_cache"] = {"enabled": False}
    except Exception as e:
        verbose_proxy_logger.debug(f"Error calculating cache stats: {e}")
        cache_stats["error"] = str(e)
    return cache_stats
def _get_router_memory_stats(llm_router) -> Dict[str, Any]:
"""Get memory usage statistics for LiteLLM router."""
litellm_router_memory: Dict[str, Any] = {}
try:
if llm_router is not None:
# Model list memory size
if hasattr(llm_router, "model_list") and llm_router.model_list:
model_list_size = sys.getsizeof(llm_router.model_list)
litellm_router_memory["model_list"] = {
"num_models": len(llm_router.model_list),
"size_bytes": model_list_size,
"size_mb": round(model_list_size / (1024 * 1024), 4),
}
# Model names set
if hasattr(llm_router, "model_names") and llm_router.model_names:
model_names_size = sys.getsizeof(llm_router.model_names)
litellm_router_memory["model_names_set"] = {
"num_model_groups": len(llm_router.model_names),
"size_bytes": model_names_size,
"size_mb": round(model_names_size / (1024 * 1024), 4),
}
# Deployment names list
if hasattr(llm_router, "deployment_names") and llm_router.deployment_names:
deployment_names_size = sys.getsizeof(llm_router.deployment_names)
litellm_router_memory["deployment_names"] = {
"num_deployments": len(llm_router.deployment_names),
"size_bytes": deployment_names_size,
"size_mb": round(deployment_names_size / (1024 * 1024), 4),
}
# Deployment latency map
if (
hasattr(llm_router, "deployment_latency_map")
and llm_router.deployment_latency_map
):
latency_map_size = sys.getsizeof(llm_router.deployment_latency_map)
litellm_router_memory["deployment_latency_map"] = {
"num_tracked_deployments": len(llm_router.deployment_latency_map),
"size_bytes": latency_map_size,
"size_mb": round(latency_map_size / (1024 * 1024), 4),
}
# Fallback configuration
if hasattr(llm_router, "fallbacks") and llm_router.fallbacks:
fallbacks_size = sys.getsizeof(llm_router.fallbacks)
litellm_router_memory["fallbacks"] = {
"num_fallback_configs": len(llm_router.fallbacks),
"size_bytes": fallbacks_size,
"size_mb": round(fallbacks_size / (1024 * 1024), 4),
}
# Total router object size
router_obj_size = sys.getsizeof(llm_router)
litellm_router_memory["router_object"] = {
"size_bytes": router_obj_size,
"size_mb": round(router_obj_size / (1024 * 1024), 4),
}
else:
litellm_router_memory = {"note": "Router not initialized"}
except Exception as e:
verbose_proxy_logger.debug(f"Error getting router memory info: {e}")
litellm_router_memory = {"error": str(e)}
return litellm_router_memory
def _get_process_memory_info(
worker_pid: int, include_process_info: bool
) -> Optional[Dict[str, Any]]:
"""Get process-level memory information using psutil."""
if not include_process_info:
return None
try:
import psutil
process = psutil.Process()
memory_info = process.memory_info()
ram_usage_mb = round(memory_info.rss / (1024 * 1024), 2)
virtual_memory_mb = round(memory_info.vms / (1024 * 1024), 2)
memory_percent = round(process.memory_percent(), 2)
return {
"pid": worker_pid,
"summary": f"Worker PID {worker_pid} using {ram_usage_mb:.1f} MB of RAM ({memory_percent:.1f}% of system memory)",
"ram_usage": {
"megabytes": ram_usage_mb,
"description": "Actual physical RAM used by this process",
},
"virtual_memory": {
"megabytes": virtual_memory_mb,
"description": "Total virtual memory allocated (includes swapped memory)",
},
"system_memory_percent": {
"percent": memory_percent,
"description": "Percentage of total system RAM being used",
},
"open_file_handles": {
"count": process.num_fds()
if hasattr(process, "num_fds")
else "N/A (Windows)",
"description": "Number of open file descriptors/handles",
},
"threads": {
"count": process.num_threads(),
"description": "Number of active threads in this process",
},
}
except ImportError:
return {
"pid": worker_pid,
"error": "psutil not installed. Install with: pip install psutil",
}
except Exception as e:
verbose_proxy_logger.debug(f"Error getting process info: {e}")
return {"pid": worker_pid, "error": str(e)}
@router.get("/debug/memory/details", include_in_schema=False)
async def get_memory_details(
_: UserAPIKeyAuth = Depends(user_api_key_auth),
top_n: int = Query(20, description="Number of top object types to return"),
include_process_info: bool = Query(True, description="Include process memory info"),
) -> Dict[str, Any]:
"""
Get detailed memory diagnostics for deep debugging.
Returns:
- worker_pid: Process ID
- process_memory: RAM usage, virtual memory, file handles, threads
- garbage_collector: GC thresholds, counts, collection history
- objects: Total tracked objects and top object types
- uncollectable: Objects that can't be garbage collected (potential leaks)
- cache_memory: Memory usage of user_api_key, router, and logging caches
- router_memory: Memory usage of router components (model_list, deployment_names, etc.)
Query Parameters:
- top_n: Number of top object types to return (default: 20)
- include_process_info: Include process-level memory info using psutil (default: true)
Example usage:
curl "http://localhost:4000/debug/memory/details?top_n=30" -H "Authorization: Bearer sk-1234"
All memory sizes are reported in both bytes and MB.
"""
from litellm.proxy.proxy_server import (
llm_router,
proxy_logging_obj,
user_api_key_cache,
redis_usage_cache,
)
worker_pid = os.getpid()
# Collect all diagnostics using helper functions
gc_stats = _get_gc_statistics()
total_objects, top_object_types = _get_object_type_counts(top_n)
uncollectable_info = _get_uncollectable_objects_info()
cache_stats = _get_cache_memory_stats(
user_api_key_cache, llm_router, proxy_logging_obj, redis_usage_cache
)
litellm_router_memory = _get_router_memory_stats(llm_router)
process_info = _get_process_memory_info(worker_pid, include_process_info)
return {
"worker_pid": worker_pid,
"process_memory": process_info,
"garbage_collector": gc_stats,
"objects": {
"total_tracked": total_objects,
"total_tracked_readable": f"{total_objects:,}",
"top_types": top_object_types,
},
"uncollectable": uncollectable_info,
"cache_memory": cache_stats,
"router_memory": litellm_router_memory,
}
@router.post("/debug/memory/gc/configure", include_in_schema=False)
async def configure_gc_thresholds_endpoint(
_: UserAPIKeyAuth = Depends(user_api_key_auth),
generation_0: int = Query(700, description="Generation 0 threshold (default: 700)"),
generation_1: int = Query(10, description="Generation 1 threshold (default: 10)"),
generation_2: int = Query(10, description="Generation 2 threshold (default: 10)"),
) -> Dict[str, Any]:
"""
Configure Python garbage collection thresholds.
Lower thresholds mean more frequent GC cycles (less memory, more CPU overhead).
Higher thresholds mean less frequent GC cycles (more memory, less CPU overhead).
Returns:
- message: Confirmation message
- previous_thresholds: Old threshold values
- new_thresholds: New threshold values
- objects_awaiting_collection: Current object count in gen-0
- tip: Hint about when next collection will occur
Query Parameters:
- generation_0: Number of allocations before gen-0 collection (default: 700)
- generation_1: Number of gen-0 collections before gen-1 collection (default: 10)
- generation_2: Number of gen-1 collections before gen-2 collection (default: 10)
Example for more aggressive collection:
curl -X POST "http://localhost:4000/debug/memory/gc/configure?generation_0=500" -H "Authorization: Bearer sk-1234"
Example for less aggressive collection:
curl -X POST "http://localhost:4000/debug/memory/gc/configure?generation_0=1000" -H "Authorization: Bearer sk-1234"
Monitor memory usage with GET /debug/memory/summary after changes.
"""
# Get current thresholds for logging
old_thresholds = gc.get_threshold()
# Set new thresholds with error handling
try:
gc.set_threshold(generation_0, generation_1, generation_2)
verbose_proxy_logger.info(
f"GC thresholds updated from {old_thresholds} to "
f"({generation_0}, {generation_1}, {generation_2})"
)
except Exception as e:
verbose_proxy_logger.error(f"Failed to set GC thresholds: {e}")
raise HTTPException(
status_code=500, detail=f"Failed to set GC thresholds: {str(e)}"
)
# Get current object count to show immediate impact
current_count = gc.get_count()[0]
return {
"message": "GC thresholds updated",
"previous_thresholds": f"{old_thresholds[0]}, {old_thresholds[1]}, {old_thresholds[2]}",
"new_thresholds": f"{generation_0}, {generation_1}, {generation_2}",
"objects_awaiting_collection": current_count,
"tip": f"Next collection will run after {generation_0 - current_count} more allocations",
}
@router.get("/otel-spans", include_in_schema=False)
async def get_otel_spans():
from litellm.proxy.proxy_server import open_telemetry_logger
if open_telemetry_logger is None:
return {
"otel_spans": [],
"spans_grouped_by_parent": {},
"most_recent_parent": None,
}
otel_exporter = open_telemetry_logger.OTEL_EXPORTER
if hasattr(otel_exporter, "get_finished_spans"):
recorded_spans = otel_exporter.get_finished_spans() # type: ignore
else:
recorded_spans = []
print("Spans: ", recorded_spans) # noqa
most_recent_parent = None
most_recent_start_time = 1000000
spans_grouped_by_parent = {}
for span in recorded_spans:
if span.parent is not None:
parent_trace_id = span.parent.trace_id
if parent_trace_id not in spans_grouped_by_parent:
spans_grouped_by_parent[parent_trace_id] = []
spans_grouped_by_parent[parent_trace_id].append(span.name)
# check time of span
if span.start_time > most_recent_start_time:
most_recent_parent = parent_trace_id
most_recent_start_time = span.start_time
# these are otel spans - get the span name
span_names = [span.name for span in recorded_spans]
return {
"otel_spans": span_names,
"spans_grouped_by_parent": spans_grouped_by_parent,
"most_recent_parent": most_recent_parent,
}
# Helper functions for debugging
def _set_verbose_log_levels(level: int, include_package_logger: bool) -> None:
    """Apply `level` to litellm's router and proxy verbose loggers.

    When include_package_logger is True, the package-wide verbose_logger is
    set too (used for debug/detailed_debug so users also see package logs).
    """
    from litellm._logging import (
        verbose_logger,
        verbose_proxy_logger,
        verbose_router_logger,
    )

    if include_package_logger:
        verbose_logger.setLevel(level=level)
    verbose_router_logger.setLevel(level=level)
    verbose_proxy_logger.setLevel(level=level)


def init_verbose_loggers():
    """Initialize litellm logger verbosity from the WORKER_CONFIG secret.

    WORKER_CONFIG may be a filepath (ignored here - config comes from file)
    or a JSON dict with optional `debug` / `detailed_debug` flags. When both
    flags are explicitly False, the LITELLM_LOG env var ("INFO"/"DEBUG")
    controls the proxy + router log levels instead. Failures are logged as a
    warning and never raised.

    Changes vs previous version: removed the always-true
    `litellm_log_setting is not None` check (os.environ.get with a ""
    default never returns None) and deduplicated the four near-identical
    logger-setup stanzas into _set_verbose_log_levels. Behavior preserved.
    """
    import logging

    try:
        worker_config = get_secret_str("WORKER_CONFIG")
        # if not set, nothing to configure
        if worker_config is None:
            return
        # a filepath means the config comes from a yaml file, not inline json
        if os.path.isfile(worker_config):
            return
        _settings = json.loads(worker_config)
        if not isinstance(_settings, dict):
            return

        debug = _settings.get("debug", None)
        detailed_debug = _settings.get("detailed_debug", None)
        if debug is True:
            # this needs to be first, so users can see Router init debug logs
            # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
            _set_verbose_log_levels(logging.INFO, include_package_logger=True)
        if detailed_debug is True:
            _set_verbose_log_levels(logging.DEBUG, include_package_logger=True)
        elif debug is False and detailed_debug is False:
            # users can control proxy debugging using env variable = 'LITELLM_LOG'
            litellm_log_setting = os.environ.get("LITELLM_LOG", "")
            if litellm_log_setting.upper() == "INFO":
                # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
                _set_verbose_log_levels(logging.INFO, include_package_logger=False)
            elif litellm_log_setting.upper() == "DEBUG":
                _set_verbose_log_levels(logging.DEBUG, include_package_logger=False)
    except Exception as e:
        logging.warning(f"Failed to init verbose loggers: {str(e)}")

View File

@@ -0,0 +1,122 @@
import base64
import os
from typing import Literal, Optional
from litellm._logging import verbose_proxy_logger
def _get_salt_key():
    """Return the key used for encrypt/decrypt: the LITELLM_SALT_KEY env var,
    falling back to the proxy master key when the env var is unset."""
    from litellm.proxy.proxy_server import master_key

    env_salt_key = os.getenv("LITELLM_SALT_KEY")
    # explicit None check: an empty-string env value is still honored
    return master_key if env_salt_key is None else env_salt_key
def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None):
    """
    Encrypt `value` with the configured salt key and return URL-safe base64 text.

    Args:
        value: plaintext to encrypt; non-string values are returned unchanged.
        new_encryption_key: optional key overriding the configured salt key
            (e.g. when re-encrypting secrets with a rotated key).

    Returns:
        URL-safe base64 ciphertext for string input, or the original value
        untouched for any other type.

    Raises:
        Whatever encrypt_value raises (e.g. on an invalid signing key).
        Note: the previous `except Exception as e: raise e` wrapper was
        removed - it only re-raised the same exception and added noise to
        the traceback.
    """
    signing_key = new_encryption_key or _get_salt_key()
    if isinstance(value, str):
        encrypted_value = encrypt_value(value=value, signing_key=signing_key)  # type: ignore
        # Use urlsafe_b64encode for URL-safe base64 encoding (replaces + with - and / with _)
        return base64.urlsafe_b64encode(encrypted_value).decode("utf-8")

    verbose_proxy_logger.debug(
        f"Invalid value type passed to encrypt_value: {type(value)} for Value: {value}\n Value must be a string"
    )
    # if it's not a string - do not encrypt it and return the value
    return value
def decrypt_value_helper(
    value: str,
    key: str,  # this is just for debug purposes, showing the k,v pair that's invalid. not a signing key.
    exception_type: Literal["debug", "error"] = "error",
    return_original_value: bool = False,
):
    """Decrypt base64 ciphertext produced by encrypt_value_helper.

    Args:
        value: base64-encoded ciphertext (URL-safe or standard alphabet).
            Non-string values are returned unchanged without decryption.
        key: name of the config key being decrypted - used only in log
            messages, never for cryptography.
        exception_type: "debug" logs failures quietly at debug level;
            "error" logs a full exception (unless return_original_value
            short-circuits that path).
        return_original_value: when decryption fails, return the raw input
            instead of None.

    Returns:
        The decrypted string; the original value for non-str input or on
        failure with return_original_value=True; otherwise None on failure.
    """
    signing_key = _get_salt_key()
    try:
        if isinstance(value, str):
            # Try URL-safe base64 decoding first (new format)
            # Fall back to standard base64 decoding for backwards compatibility (old format)
            try:
                decoded_b64 = base64.urlsafe_b64decode(value)
            except Exception:
                # If URL-safe decoding fails, try standard base64 decoding for backwards compatibility
                decoded_b64 = base64.b64decode(value)
            value = decrypt_value(value=decoded_b64, signing_key=signing_key)  # type: ignore
            return value
        # if it's not str - do not decrypt it, return the value
        return value
    except Exception as e:
        # Most common cause: the master_key / salt key changed after values were encrypted.
        error_message = f"Error decrypting value for key: {key}, Did your master_key/salt key change recently? \nError: {str(e)}\nSet permanent salt key - https://docs.litellm.ai/docs/proxy/prod#5-set-litellm-salt-key"
        if exception_type == "debug":
            # quiet failure mode: log at debug level and fall back
            verbose_proxy_logger.debug(error_message)
            return value if return_original_value else None

        verbose_proxy_logger.debug(
            f"Unable to decrypt value={value} for key: {key}, returning None"
        )
        if return_original_value:
            return value
        else:
            verbose_proxy_logger.exception(error_message)
            # [Non-Blocking Exception. - this should not block decrypting other values]
            return None
def encrypt_value(value: str, signing_key: str):
    """Encrypt `value` with a NaCl SecretBox keyed by SHA-256(signing_key).

    The signing key is hashed to derive the 32-byte symmetric key SecretBox
    requires; the returned ciphertext embeds the nonce.
    """
    import hashlib

    import nacl.secret
    import nacl.utils

    # derive the 32-byte master key from the signing key
    key_bytes = hashlib.sha256(signing_key.encode()).digest()

    # symmetric encryption via NaCl SecretBox
    secret_box = nacl.secret.SecretBox(key_bytes)

    # encode the plaintext and encrypt
    return secret_box.encrypt(value.encode("utf-8"))
def decrypt_value(value: bytes, signing_key: str) -> str:
    """
    Decrypt NaCl SecretBox ciphertext produced by `encrypt_value`.

    Args:
        value: raw ciphertext bytes (already base64-decoded by the caller).
        signing_key: the salt/master key; hashed with SHA-256 to derive the
            same 32-byte SecretBox key that `encrypt_value` used.

    Returns:
        The UTF-8 plaintext; empty ciphertext decrypts to "".

    Raises:
        nacl.exceptions.CryptoError: if the ciphertext or key is invalid.
        Note: the previous `except Exception as e: raise e` wrapper was
        removed - it re-raised the same exception unchanged and only added
        noise to the traceback.
    """
    import hashlib

    import nacl.secret
    import nacl.utils

    # get 32 byte master key - must match the derivation in encrypt_value
    hash_object = hashlib.sha256(signing_key.encode())
    hash_bytes = hash_object.digest()

    # initialize secret box
    box = nacl.secret.SecretBox(hash_bytes)

    # empty ciphertext maps to empty plaintext instead of a decrypt error
    if len(value) == 0:
        return ""

    return box.decrypt(value).decode("utf-8")

View File

@@ -0,0 +1,82 @@
"""
Utility class for getting routes from a FastAPI app.
"""
from typing import Any, Dict, List, Optional
from starlette.routing import BaseRoute
from litellm._logging import verbose_logger
class GetRoutes:
    """Utility helpers for introspecting the routes registered on a FastAPI app."""

    @staticmethod
    def get_app_routes(
        route: BaseRoute,
        endpoint_route: Any,
    ) -> List[Dict[str, Any]]:
        """
        Describe a regular (non-mounted) route.

        Args:
            route: the route object to inspect.
            endpoint_route: the callable serving the route.

        Returns:
            Single-element list with path/methods/name/endpoint info.
        """
        routes: List[Dict[str, Any]] = []
        route_info = {
            "path": getattr(route, "path", None),
            "methods": getattr(route, "methods", None),
            "name": getattr(route, "name", None),
            # Fix: resolve the name via _safe_get_endpoint_name (consistent
            # with get_routes_for_mounted_app) instead of dereferencing
            # endpoint_route.__name__ directly, which raised AttributeError
            # for callables without __name__ (e.g. partials, class instances).
            "endpoint": (
                GetRoutes._safe_get_endpoint_name(endpoint_route)
                if getattr(route, "endpoint", None)
                else None
            ),
        }
        routes.append(route_info)
        return routes

    @staticmethod
    def get_routes_for_mounted_app(
        route: BaseRoute,
    ) -> List[Dict[str, Any]]:
        """
        Describe all routes of a mounted sub-application.

        Args:
            route: a Mount route whose `app` exposes `.routes`.

        Returns:
            One info dict per sub-route, with the mount path prefixed onto
            each sub-route path and "mounted_app": True set.
        """
        routes: List[Dict[str, Any]] = []
        mount_path = getattr(route, "path", "")
        sub_app = getattr(route, "app", None)
        if sub_app and hasattr(sub_app, "routes"):
            for sub_route in sub_app.routes:
                # Get endpoint - either from endpoint attribute or app attribute
                endpoint_func = getattr(sub_route, "endpoint", None) or getattr(
                    sub_route, "app", None
                )
                if endpoint_func is not None:
                    sub_route_path = getattr(sub_route, "path", "")
                    full_path = mount_path.rstrip("/") + sub_route_path
                    route_info = {
                        "path": full_path,
                        "methods": getattr(sub_route, "methods", ["GET", "POST"]),
                        "name": getattr(sub_route, "name", None),
                        "endpoint": GetRoutes._safe_get_endpoint_name(endpoint_func),
                        "mounted_app": True,
                    }
                    routes.append(route_info)
        return routes

    @staticmethod
    def _safe_get_endpoint_name(endpoint_function: Any) -> Optional[str]:
        """
        Best-effort name for an endpoint callable: its __name__, else the
        name of its class, else None. Never raises.
        """
        try:
            if hasattr(endpoint_function, "__name__"):
                return getattr(endpoint_function, "__name__")
            elif hasattr(endpoint_function, "__class__") and hasattr(
                endpoint_function.__class__, "__name__"
            ):
                return getattr(endpoint_function.__class__, "__name__")
            else:
                return None
        except Exception:
            verbose_logger.exception(
                f"Error getting endpoint name for route: {endpoint_function}"
            )
            return None

View File

@@ -0,0 +1,207 @@
from litellm.proxy.common_utils.banner import LITELLM_BANNER
def render_cli_sso_success_page() -> str:
    """
    Renders the CLI SSO authentication success page with minimal styling

    The page shows the LiteLLM ASCII banner, a success confirmation, and
    next-step instructions, then auto-closes itself via a 3-second
    JavaScript countdown.

    Returns:
        str: HTML content for the success page
    """
    # Standalone HTML document. LITELLM_BANNER is interpolated into the
    # pre-formatted .banner block; all CSS/JS braces are doubled ({{ }})
    # to escape f-string formatting.
    html_content = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>CLI Authentication Successful - LiteLLM</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
                background-color: #f8fafc;
                margin: 0;
                padding: 20px;
                display: flex;
                justify-content: center;
                align-items: center;
                min-height: 100vh;
                color: #1e293b;
            }}
            .container {{
                background-color: #fff;
                padding: 40px;
                border-radius: 8px;
                box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
                width: 450px;
                max-width: 100%;
                text-align: center;
            }}
            .logo-container {{
                margin-bottom: 20px;
            }}
            .logo {{
                font-size: 24px;
                font-weight: 600;
                color: #1e293b;
            }}
            h1 {{
                margin: 0 0 10px;
                color: #1e293b;
                font-size: 28px;
                font-weight: 600;
            }}
            .subtitle {{
                color: #64748b;
                margin: 0 0 30px;
                font-size: 16px;
            }}
            .banner {{
                background-color: #f8fafc;
                color: #334155;
                font-family: 'Courier New', Consolas, monospace;
                font-size: 10px;
                line-height: 1.1;
                white-space: pre;
                padding: 20px;
                border-radius: 6px;
                margin: 20px 0;
                text-align: center;
                border: 1px solid #e2e8f0;
                overflow-x: auto;
            }}
            .success-box {{
                background-color: #f8fafc;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 30px;
                border: 1px solid #e2e8f0;
            }}
            .success-header {{
                display: flex;
                align-items: center;
                justify-content: center;
                margin-bottom: 12px;
                color: #1e293b;
                font-weight: 600;
                font-size: 16px;
            }}
            .success-header svg {{
                margin-right: 8px;
            }}
            .success-box p {{
                color: #64748b;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}
            .instructions {{
                background-color: #f8fafc;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 20px;
                border: 1px solid #e2e8f0;
            }}
            .instructions-header {{
                display: flex;
                align-items: center;
                justify-content: center;
                margin-bottom: 12px;
                color: #1e293b;
                font-weight: 600;
                font-size: 16px;
            }}
            .instructions-header svg {{
                margin-right: 8px;
            }}
            .instructions p {{
                color: #64748b;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}
            .countdown {{
                color: #64748b;
                font-size: 14px;
                font-weight: 500;
                padding: 12px;
                background-color: #f8fafc;
                border-radius: 6px;
                border: 1px solid #e2e8f0;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="logo-container">
                <div class="logo">
                    🚅 LiteLLM
                </div>
            </div>
            <div class="banner">{LITELLM_BANNER}</div>
            <h1>Authentication Successful!</h1>
            <p class="subtitle">Your CLI authentication is complete.</p>
            <div class="success-box">
                <div class="success-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <path d="M9 12l2 2 4-4"></path>
                        <circle cx="12" cy="12" r="10"></circle>
                    </svg>
                    CLI Authentication Complete
                </div>
                <p>Your LiteLLM CLI has been successfully authenticated and is ready to use.</p>
            </div>
            <div class="instructions">
                <div class="instructions-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <circle cx="12" cy="12" r="10"></circle>
                        <line x1="12" y1="16" x2="12" y2="12"></line>
                        <line x1="12" y1="8" x2="12.01" y2="8"></line>
                    </svg>
                    Next Steps
                </div>
                <p>Return to your terminal - the CLI will automatically detect the successful authentication.</p>
                <p>You can now use LiteLLM CLI commands with your authenticated session.</p>
            </div>
            <div class="countdown" id="countdown">This window will close in 3 seconds...</div>
        </div>
        <script>
            let seconds = 3;
            const countdownElement = document.getElementById('countdown');
            const countdown = setInterval(function() {{
                seconds--;
                if (seconds > 0) {{
                    countdownElement.textContent = `This window will close in ${{seconds}} second${{seconds === 1 ? '' : 's'}}...`;
                }} else {{
                    countdownElement.textContent = 'Closing...';
                    clearInterval(countdown);
                    window.close();
                }}
            }}, 1000);
        </script>
    </body>
    </html>
    """
    return html_content

View File

@@ -0,0 +1,284 @@
# JWT display template for SSO debug callback
jwt_display_template = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>LiteLLM SSO Debug - JWT Information</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background-color: #f8fafc;
margin: 0;
padding: 20px;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
color: #333;
}
.container {
background-color: #fff;
padding: 40px;
border-radius: 8px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 800px;
max-width: 100%;
}
.logo-container {
text-align: center;
margin-bottom: 30px;
}
.logo {
font-size: 24px;
font-weight: 600;
color: #1e293b;
}
h2 {
margin: 0 0 10px;
color: #1e293b;
font-size: 28px;
font-weight: 600;
text-align: center;
}
.subtitle {
color: #64748b;
margin: 0 0 20px;
font-size: 16px;
text-align: center;
}
.info-box {
background-color: #f1f5f9;
border-radius: 6px;
padding: 20px;
margin-bottom: 30px;
border-left: 4px solid #2563eb;
}
.success-box {
background-color: #f0fdf4;
border-radius: 6px;
padding: 20px;
margin-bottom: 30px;
border-left: 4px solid #16a34a;
}
.info-header {
display: flex;
align-items: center;
margin-bottom: 12px;
color: #1e40af;
font-weight: 600;
font-size: 16px;
}
.success-header {
display: flex;
align-items: center;
margin-bottom: 12px;
color: #166534;
font-weight: 600;
font-size: 16px;
}
.info-header svg, .success-header svg {
margin-right: 8px;
}
.data-container {
margin-top: 20px;
}
.data-row {
display: flex;
border-bottom: 1px solid #e2e8f0;
padding: 12px 0;
}
.data-row:last-child {
border-bottom: none;
}
.data-label {
font-weight: 500;
color: #334155;
width: 180px;
flex-shrink: 0;
}
.data-value {
color: #475569;
word-break: break-all;
}
.jwt-container {
background-color: #f8fafc;
border-radius: 6px;
padding: 15px;
margin-top: 20px;
overflow-x: auto;
border: 1px solid #e2e8f0;
}
.jwt-text {
font-family: monospace;
white-space: pre-wrap;
word-break: break-all;
margin: 0;
color: #334155;
}
.back-button {
display: inline-block;
background-color: #6466E9;
color: #fff;
text-decoration: none;
padding: 10px 16px;
border-radius: 6px;
font-weight: 500;
margin-top: 20px;
text-align: center;
}
.back-button:hover {
background-color: #4138C2;
text-decoration: none;
}
.buttons {
display: flex;
gap: 10px;
margin-top: 20px;
}
.copy-button {
background-color: #e2e8f0;
color: #334155;
border: none;
padding: 8px 12px;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
display: flex;
align-items: center;
}
.copy-button:hover {
background-color: #cbd5e1;
}
.copy-button svg {
margin-right: 6px;
}
</style>
</head>
<body>
<div class="container">
<div class="logo-container">
<div class="logo">
🚅 LiteLLM
</div>
</div>
<h2>SSO Debug Information</h2>
<p class="subtitle">Results from the SSO authentication process.</p>
<div class="success-box">
<div class="success-header">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M22 11.08V12a10 10 0 1 1-5.93-9.14"></path>
<polyline points="22 4 12 14.01 9 11.01"></polyline>
</svg>
Authentication Successful
</div>
<p>The SSO authentication completed successfully. Below is the information returned by the provider.</p>
</div>
<div class="data-container" id="userData">
<!-- Data will be inserted here by JavaScript -->
</div>
<div class="info-box">
<div class="info-header">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<circle cx="12" cy="12" r="10"></circle>
<line x1="12" y1="16" x2="12" y2="12"></line>
<line x1="12" y1="8" x2="12.01" y2="8"></line>
</svg>
JSON Representation
</div>
<div class="jwt-container">
<pre class="jwt-text" id="jsonData">Loading...</pre>
</div>
<div class="buttons">
<button class="copy-button" onclick="copyToClipboard('jsonData')">
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
</svg>
Copy to Clipboard
</button>
</div>
</div>
<a href="/sso/debug/login" class="back-button">
Try Another SSO Login
</a>
</div>
<script>
// This will be populated with the actual data from the server
const userData = SSO_DATA;
function renderUserData() {
const container = document.getElementById('userData');
const jsonDisplay = document.getElementById('jsonData');
// Format JSON with indentation for display
jsonDisplay.textContent = JSON.stringify(userData, null, 2);
// Clear container
container.innerHTML = '';
// Add each key-value pair to the UI
for (const [key, value] of Object.entries(userData)) {
if (typeof value !== 'object' || value === null) {
const row = document.createElement('div');
row.className = 'data-row';
const label = document.createElement('div');
label.className = 'data-label';
label.textContent = key;
const dataValue = document.createElement('div');
dataValue.className = 'data-value';
dataValue.textContent = value !== null ? value : 'null';
row.appendChild(label);
row.appendChild(dataValue);
container.appendChild(row);
}
}
}
function copyToClipboard(elementId) {
const text = document.getElementById(elementId).textContent;
navigator.clipboard.writeText(text).then(() => {
alert('Copied to clipboard!');
}).catch(err => {
console.error('Could not copy text: ', err);
});
}
// Render the data when the page loads
document.addEventListener('DOMContentLoaded', renderUserData);
</script>
</body>
</html>
"""

View File

@@ -0,0 +1,269 @@
import os
from litellm.proxy.utils import get_custom_url
# Target for the legacy login form POST: PROXY_BASE_URL plus the optional
# SERVER_ROOT_PATH prefix, with "/login" appended.
url_to_redirect_to = os.getenv("PROXY_BASE_URL", "")
server_root_path = os.getenv("SERVER_ROOT_PATH", "")
if server_root_path:
    url_to_redirect_to = f"{url_to_redirect_to}{server_root_path}"
url_to_redirect_to = f"{url_to_redirect_to}/login"
# Destination advertised by the deprecation banner on the old login page.
new_ui_login_url = get_custom_url("", "ui/login")
def build_ui_login_form(show_deprecation_banner: bool = False) -> str:
    """
    Build the legacy username/password login page as a standalone HTML string.

    Args:
        show_deprecation_banner: when True, prepends a banner pointing users
            at the new login page (``new_ui_login_url``).

    Returns:
        str: complete HTML document; the form POSTs to ``url_to_redirect_to``.
    """
    # Banner injected at the top of the form only when requested.
    banner_html = (
        f"""
        <div class="deprecation-banner">
            <strong>Deprecated:</strong> Logging in with username and password on this page is deprecated.
            Please use the <a href="{new_ui_login_url}">new login page</a> instead.
            This page will be dedicated to signing in via SSO in the future.
        </div>
        """
        if show_deprecation_banner
        else ""
    )
    # CSS/JS braces are doubled ({{ }}) to escape f-string formatting.
    return f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>LiteLLM Login</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
                background-color: #f8fafc;
                margin: 0;
                padding: 20px;
                display: flex;
                justify-content: center;
                align-items: center;
                min-height: 100vh;
                color: #333;
            }}
            form {{
                background-color: #fff;
                padding: 40px;
                border-radius: 8px;
                box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
                width: 450px;
                max-width: 100%;
            }}
            .logo-container {{
                text-align: center;
                margin-bottom: 30px;
            }}
            .logo {{
                font-size: 24px;
                font-weight: 600;
                color: #1e293b;
            }}
            h2 {{
                margin: 0 0 10px;
                color: #1e293b;
                font-size: 28px;
                font-weight: 600;
                text-align: center;
            }}
            .subtitle {{
                color: #64748b;
                margin: 0 0 20px;
                font-size: 16px;
                text-align: center;
            }}
            .info-box {{
                background-color: #f1f5f9;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 30px;
                border-left: 4px solid #2563eb;
            }}
            .info-header {{
                display: flex;
                align-items: center;
                margin-bottom: 12px;
                color: #1e40af;
                font-weight: 600;
                font-size: 16px;
            }}
            .info-header svg {{
                margin-right: 8px;
            }}
            .info-box p {{
                color: #475569;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}
            label {{
                display: block;
                margin-bottom: 8px;
                font-weight: 500;
                color: #334155;
                font-size: 14px;
            }}
            .required {{
                color: #dc2626;
                margin-left: 2px;
            }}
            input[type="text"],
            input[type="password"] {{
                width: 100%;
                padding: 10px 14px;
                margin-bottom: 20px;
                box-sizing: border-box;
                border: 1px solid #e2e8f0;
                border-radius: 6px;
                font-size: 15px;
                color: #1e293b;
                background-color: #fff;
                transition: border-color 0.2s, box-shadow 0.2s;
            }}
            input[type="text"]:focus,
            input[type="password"]:focus {{
                outline: none;
                border-color: #3b82f6;
                box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.2);
            }}
            .toggle-password {{
                display: flex;
                align-items: center;
                margin-top: -15px;
                margin-bottom: 20px;
            }}
            .toggle-password input[type="checkbox"] {{
                margin-right: 8px;
                vertical-align: middle;
                width: 16px;
                height: 16px;
            }}
            .toggle-password label {{
                margin-bottom: 0;
                font-size: 14px;
                cursor: pointer;
                line-height: 1;
            }}
            input[type="submit"] {{
                background-color: #6466E9;
                color: #fff;
                cursor: pointer;
                font-weight: 500;
                border: none;
                padding: 10px 16px;
                transition: background-color 0.2s;
                border-radius: 6px;
                margin-top: 10px;
                font-size: 14px;
                width: 100%;
            }}
            input[type="submit"]:hover {{
                background-color: #4138C2;
            }}
            a {{
                color: #3b82f6;
                text-decoration: none;
            }}
            a:hover {{
                text-decoration: underline;
            }}
            code {{
                background-color: #f1f5f9;
                padding: 2px 4px;
                border-radius: 4px;
                font-family: monospace;
                font-size: 13px;
                color: #334155;
            }}
            .help-text {{
                color: #64748b;
                font-size: 14px;
                margin-top: -12px;
                margin-bottom: 20px;
            }}
            .deprecation-banner {{
                background-color: #fee2e2;
                border: 1px solid #ef4444;
                color: #991b1b;
                padding: 14px 16px;
                border-radius: 6px;
                margin-bottom: 20px;
                font-size: 14px;
                line-height: 1.5;
            }}
            .deprecation-banner a {{
                color: #991b1b;
                font-weight: 600;
                text-decoration: underline;
            }}
        </style>
    </head>
    <body>
        <form action="{url_to_redirect_to}" method="post">
            {banner_html}
            <div class="logo-container">
                <div class="logo">
                    🚅 LiteLLM
                </div>
            </div>
            <h2>Login</h2>
            <p class="subtitle">Access your LiteLLM Admin UI.</p>
            <div class="info-box">
                <div class="info-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <circle cx="12" cy="12" r="10"></circle>
                        <line x1="12" y1="16" x2="12" y2="12"></line>
                        <line x1="12" y1="8" x2="12.01" y2="8"></line>
                    </svg>
                    Default Credentials
                </div>
                <p>By default, Username is <code>admin</code> and Password is your set LiteLLM Proxy <code>MASTER_KEY</code>.</p>
                <p>Need to set UI credentials or SSO? <a href="https://docs.litellm.ai/docs/proxy/ui" target="_blank">Check the documentation</a>.</p>
            </div>
            <label for="username">Username<span class="required">*</span></label>
            <input type="text" id="username" name="username" required placeholder="Enter your username" autocomplete="username">
            <label for="password">Password<span class="required">*</span></label>
            <input type="password" id="password" name="password" required placeholder="Enter your password" autocomplete="current-password">
            <div class="toggle-password">
                <input type="checkbox" id="show-password" onclick="togglePasswordVisibility()">
                <label for="show-password">Show password</label>
            </div>
            <input type="submit" value="Login">
        </form>
        <script>
            function togglePasswordVisibility() {{
                var passwordField = document.getElementById("password");
                passwordField.type = passwordField.type === "password" ? "text" : "password";
            }}
        </script>
    </body>
    </html>
    """
# Pre-rendered legacy login page (with deprecation banner), built once at import time.
html_form = build_ui_login_form(show_deprecation_banner=True)

View File

@@ -0,0 +1,522 @@
import json
import re
from typing import Any, Collection, Dict, List, Optional
import orjson
from fastapi import Request, UploadFile, status
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyException
from litellm.proxy.common_utils.callback_utils import (
get_metadata_variable_name_from_kwargs,
)
from litellm.types.router import Deployment
async def _read_request_body(request: Optional[Request]) -> Dict:
    """
    Safely read the request body and parse it as JSON.

    The parsed result is cached on ``request.scope`` (via
    ``_safe_set_request_parsed_body``) so repeated calls within the same
    request do not re-read or re-parse the body.

    Parameters:
    - request: The request object to read the body from

    Returns:
    - dict: Parsed request data as a dictionary or an empty dictionary if parsing fails

    Raises:
    - ProxyException (HTTP 400): when a non-empty body cannot be parsed as
      JSON even after stripping invalid surrogate characters.
    """
    try:
        if request is None:
            return {}

        # Check if we already read and parsed the body
        _cached_request_body: Optional[dict] = _safe_get_request_parsed_body(
            request=request
        )
        if _cached_request_body is not None:
            return _cached_request_body
        _request_headers: dict = _safe_get_request_headers(request=request)
        content_type = _request_headers.get("content-type", "")

        if "form" in content_type:
            # Form-encoded body; a "metadata" field may arrive as a JSON string.
            parsed_body = dict(await request.form())
            if "metadata" in parsed_body and isinstance(parsed_body["metadata"], str):
                parsed_body["metadata"] = json.loads(parsed_body["metadata"])
        else:
            # Read the request body
            body = await request.body()

            # Return empty dict if body is empty or None
            if not body:
                parsed_body = {}
            else:
                try:
                    parsed_body = orjson.loads(body)
                except orjson.JSONDecodeError as e:
                    # orjson rejects lone surrogates; scrub them and retry with
                    # the (more forgiving) stdlib json module.
                    # First try the standard json module which is more forgiving
                    # First decode bytes to string if needed
                    body_str = body.decode("utf-8") if isinstance(body, bytes) else body

                    # Replace invalid surrogate pairs
                    # This regex finds incomplete surrogate pairs
                    body_str = re.sub(
                        r"[\uD800-\uDBFF](?![\uDC00-\uDFFF])", "", body_str
                    )
                    # This regex finds low surrogates without high surrogates
                    body_str = re.sub(
                        r"(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]", "", body_str
                    )

                    try:
                        parsed_body = json.loads(body_str)
                    except json.JSONDecodeError:
                        # If both orjson and json.loads fail, throw a proper error
                        verbose_proxy_logger.error(
                            f"Invalid JSON payload received: {str(e)}"
                        )
                        raise ProxyException(
                            message=f"Invalid JSON payload: {str(e)}",
                            type="invalid_request_error",
                            param="request_body",
                            code=status.HTTP_400_BAD_REQUEST,
                        )

        # Cache the parsed result
        _safe_set_request_parsed_body(request=request, parsed_body=parsed_body)
        return parsed_body
    except (json.JSONDecodeError, orjson.JSONDecodeError, ProxyException) as e:
        # Re-raise ProxyException as-is
        # NOTE(review): raw JSON decode errors that escape the inner handlers
        # (e.g. from the form "metadata" field) are also re-raised here.
        verbose_proxy_logger.error(f"Invalid JSON payload received: {str(e)}")
        raise
    except Exception as e:
        # Catch unexpected errors to avoid crashes
        verbose_proxy_logger.exception(
            "Unexpected error reading request body - {}".format(e)
        )
        return {}
def _safe_get_request_parsed_body(request: Optional[Request]) -> Optional[dict]:
    """
    Return the cached parsed body stored on ``request.scope``, if present.

    The cache entry is an ``(accepted_keys, parsed_body)`` tuple written by
    ``_safe_set_request_parsed_body``; only the accepted keys are returned.
    """
    if request is None:
        return None
    if not hasattr(request, "scope"):
        return None
    if "parsed_body" not in request.scope:
        return None
    entry = request.scope["parsed_body"]
    if not isinstance(entry, tuple):
        return None
    accepted_keys, body = entry
    return {key: body[key] for key in accepted_keys}
def _safe_get_request_query_params(request: Optional[Request]) -> Dict:
    """
    Best-effort copy of the request's query params; returns {} on any failure.
    """
    if request is None:
        return {}
    try:
        if not hasattr(request, "query_params"):
            return {}
        return dict(request.query_params)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error reading request query params - {}".format(e)
        )
        return {}
def _safe_set_request_parsed_body(
    request: Optional[Request],
    parsed_body: dict,
) -> None:
    """
    Cache ``parsed_body`` on ``request.scope`` as an (accepted_keys, body)
    tuple so later reads can skip re-parsing. Failures are logged, never raised.
    """
    if request is None:
        return
    try:
        request.scope["parsed_body"] = (tuple(parsed_body.keys()), parsed_body)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error setting request parsed body - {}".format(e)
        )
def _safe_get_request_headers(request: Optional[Request]) -> dict:
    """
    [Non-Blocking] Safely get the request headers.

    The header dict is cached on ``request.state`` so repeated calls within
    the same request avoid re-building ``dict(request.headers)``.

    Warning: Callers must NOT mutate the returned dict — it is shared across
    all callers within the same request via the cache.
    """
    if request is None:
        return {}

    state = getattr(request, "state", None)
    cached_headers = getattr(state, "_cached_headers", None)
    if cached_headers is not None:
        if isinstance(cached_headers, dict):
            return cached_headers
        verbose_proxy_logger.debug(
            "Unexpected cached request headers type - {}".format(type(cached_headers))
        )

    try:
        header_copy = dict(request.headers)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error reading request headers - {}".format(e)
        )
        header_copy = {}

    if state is not None:
        try:
            state._cached_headers = header_copy
        except Exception:
            pass  # request.state may not be available in all contexts
    return header_copy
def check_file_size_under_limit(
    request_data: dict,
    file: UploadFile,
    router_model_names: Collection[str],
) -> bool:
    """
    Check if any files passed in request are under max_file_size_mb

    The limit is configured per-deployment via ``litellm_params.max_file_size_mb``
    and is an enterprise (premium user) feature.

    Returns True -> when file size is under max_file_size_mb limit

    Raises ProxyException -> when file size is over max_file_size_mb limit or not a premium_user
    """
    from litellm.proxy.proxy_server import (
        CommonProxyErrors,
        ProxyException,
        llm_router,
        premium_user,
    )

    file_contents_size = file.size or 0
    file_content_size_in_mb = file_contents_size / (1024 * 1024)
    # Record the observed size in request metadata for downstream logging.
    if "metadata" not in request_data:
        request_data["metadata"] = {}
    request_data["metadata"]["file_size_in_mb"] = file_content_size_in_mb
    max_file_size_mb = None

    # Use .get() so requests without a "model" key simply skip the
    # per-deployment limit lookup instead of raising a KeyError.
    request_model = request_data.get("model")
    if llm_router is not None and request_model in router_model_names:
        try:
            deployment: Optional[
                Deployment
            ] = llm_router.get_deployment_by_model_group_name(
                model_group_name=request_model
            )
            if (
                deployment
                and deployment.litellm_params is not None
                and deployment.litellm_params.max_file_size_mb is not None
            ):
                max_file_size_mb = deployment.litellm_params.max_file_size_mb
        except Exception as e:
            verbose_proxy_logger.error(
                "Got error when checking file size: %s", (str(e))
            )

    if max_file_size_mb is not None:
        verbose_proxy_logger.debug(
            "Checking file size, file content size=%s, max_file_size_mb=%s",
            file_content_size_in_mb,
            max_file_size_mb,
        )
        # Per-deployment file-size limits are enterprise-only.
        if not premium_user:
            raise ProxyException(
                message=f"Tried setting max_file_size_mb for /audio/transcriptions. {CommonProxyErrors.not_premium_user.value}",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )
        if file_content_size_in_mb > max_file_size_mb:
            raise ProxyException(
                message=f"File size is too large. Please check your file size. Passed file size: {file_content_size_in_mb} MB. Max file size: {max_file_size_mb} MB",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )
    return True
async def get_form_data(request: Request) -> Dict[str, Any]:
    """
    Read form data from request

    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"`
    instead of `timestamp_granularities=["word", "sentence"]`.

    Returns:
        Dict with `key[]`-style repeated fields collected into lists under
        the bracket-stripped key; all other fields passed through as-is.
    """
    form = await request.form()
    parsed_form_data: Dict[str, Any] = {}
    # Iterate multi_items() instead of dict(form): form data is a multidict,
    # and dict() would keep only the LAST value for a repeated key — dropping
    # all but one entry of e.g. `timestamp_granularities[]`.
    for key, value in form.multi_items():
        if key.endswith("[]"):
            clean_key = key[:-2]
            parsed_form_data.setdefault(clean_key, []).append(value)
        else:
            parsed_form_data[key] = value
    return parsed_form_data
async def convert_upload_files_to_file_data(
    form_data: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Convert FastAPI UploadFile objects to file data tuples for litellm.

    Each UploadFile is read and replaced by a ``(filename, content,
    content_type)`` tuple — the format httpx and litellm's HTTP handlers
    expect. Single files are wrapped in a one-element list for consistency.

    Args:
        form_data: form fields, possibly containing UploadFile objects
            (single or in lists).

    Returns:
        New dict with every file-like value converted; other values untouched.
    """

    async def _to_tuple(upload: Any) -> Any:
        # (filename, content, content_type) — the httpx file-tuple shape.
        content = await upload.read()
        return (upload.filename, content, upload.content_type)

    converted: Dict[str, Any] = {}
    for field, payload in form_data.items():
        if isinstance(payload, list):
            # Only convert lists whose elements are file-like (duck-typed on .read).
            if payload and hasattr(payload[0], "read"):
                converted[field] = [await _to_tuple(f) for f in payload]
            else:
                converted[field] = payload
        elif hasattr(payload, "read"):
            converted[field] = [await _to_tuple(payload)]
        else:
            # Regular form field — pass through unchanged.
            converted[field] = payload
    return converted
async def get_request_body(request: Request) -> Dict[str, Any]:
    """
    Read the request body and parse it as JSON or form data, by content type.

    Returns {} for non-POST requests.

    Raises:
        ValueError: for a POST with an unsupported content type.
    """
    if request.method != "POST":
        return {}
    content_type = request.headers.get("content-type", "")
    # Use substring matching (consistent with the form branches) so
    # "application/json; charset=utf-8" is still treated as JSON — exact
    # equality would incorrectly reject it.
    if "application/json" in content_type:
        return await _read_request_body(request)
    if (
        "multipart/form-data" in content_type
        or "application/x-www-form-urlencoded" in content_type
    ):
        return await get_form_data(request)
    raise ValueError(
        f"Unsupported content type: {request.headers.get('content-type')}"
    )
def extract_nested_form_metadata(
    form_data: Dict[str, Any], prefix: str = "litellm_metadata["
) -> Dict[str, Any]:
    """
    Extract nested metadata from form data that uses bracket notation.

    Keys like ``litellm_metadata[spend_logs_metadata][owner]`` are rebuilt
    into nested dictionaries; keys without the prefix are ignored, as are
    UploadFile values. Malformed keys are logged and skipped.

    Args:
        form_data: form fields (from ``request.form()``)
        prefix: key prefix to match (default: "litellm_metadata[")

    Returns:
        Nested dict reconstructed from the bracket notation, e.g.
        ``{"litellm_metadata[tags]": "prod"}`` -> ``{"tags": "prod"}``.
    """
    if not form_data:
        return {}

    metadata: Dict[str, Any] = {}
    for key, value in form_data.items():
        # Only string keys carrying the expected prefix participate.
        if not isinstance(key, str) or not key.startswith(prefix):
            continue
        # Files never belong in metadata.
        if isinstance(value, UploadFile):
            verbose_proxy_logger.warning(
                f"Skipping UploadFile in metadata extraction for key: {key}"
            )
            continue
        try:
            # "litellm_metadata[a][b]" -> ["a", "b"]
            parts = key.replace(prefix, "").rstrip("]").split("][")
            if not parts or not parts[0]:
                verbose_proxy_logger.warning(
                    f"Invalid metadata key format (empty path): {key}"
                )
                continue

            # Walk/create intermediate dicts; bail out if a non-dict blocks the path.
            node: Any = metadata
            path_ok = True
            for part in parts[:-1]:
                if not isinstance(node, dict):
                    verbose_proxy_logger.warning(
                        f"Cannot create nested path - intermediate value is not a dict at: {part}"
                    )
                    path_ok = False
                    break
                node = node.setdefault(part, {})

            if path_ok:
                if isinstance(node, dict):
                    node[parts[-1]] = value
                else:
                    verbose_proxy_logger.warning(
                        f"Cannot set value - parent is not a dict for key: {key}"
                    )
        except Exception as e:
            verbose_proxy_logger.error(f"Error parsing metadata key '{key}': {str(e)}")
            continue
    return metadata
def get_tags_from_request_body(request_body: dict) -> List[str]:
    """
    Extract tags from a request body.

    Tags may live under the body's metadata dict ("tags") and/or at the top
    level ("tags"); only list-typed sources are merged, and only string
    entries survive.

    Args:
        request_body: The request body dictionary

    Returns:
        List of tag names (strings); empty list if no valid tags found.
    """
    metadata_key = get_metadata_variable_name_from_kwargs(request_body)
    metadata = request_body.get(metadata_key) or {}

    collected: List[str] = []
    # Merge both sources, but only when they are actually lists.
    for candidate in (metadata.get("tags", []), request_body.get("tags", [])):
        if isinstance(candidate, list):
            collected.extend(candidate)
    return [tag for tag in collected if isinstance(tag, str)]
def populate_request_with_path_params(request_data: dict, request: Request) -> dict:
    """
    Copy FastAPI path params and query params into the request payload so
    downstream checks (e.g. vector store RBAC, organization RBAC) see them
    the same way as body params.

    Since path_params may not be available during dependency injection, we
    fall back to parsing the URL path directly for known patterns.

    Args:
        request_data: The request data dictionary to populate
        request: The FastAPI Request object

    Returns:
        dict: Updated request_data with path and query parameters added
    """
    # Merge query params (GET-style requests); existing body values win.
    for key, value in _safe_get_request_query_params(request).items():
        request_data.setdefault(key, value)

    path_params = getattr(request, "path_params", None)
    if not (isinstance(path_params, dict) and path_params):
        # path_params not populated (yet) — parse the URL path directly.
        _add_vector_store_id_from_path(request_data=request_data, request=request)
        return request_data

    for key, value in path_params.items():
        if key == "vector_store_id":
            # Mirror the id into both singular and plural fields.
            request_data.setdefault("vector_store_id", value)
            existing_ids = request_data.get("vector_store_ids")
            if isinstance(existing_ids, list):
                if value not in existing_ids:
                    existing_ids.append(value)
            else:
                request_data["vector_store_ids"] = [value]
        else:
            request_data.setdefault(key, value)
    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Found path_params, vector_store_ids={request_data.get('vector_store_ids')}"
    )
    return request_data
def _add_vector_store_id_from_path(request_data: dict, request: Request) -> None:
    """
    Parse the request path for /vector_stores/{vector_store_id}/... segments.

    When found, populate both ``vector_store_id`` (if absent) and the
    ``vector_store_ids`` list on request_data, in place.

    Args:
        request_data: The request data dictionary to populate
        request: The FastAPI Request object
    """
    path = request.url.path
    match = re.search(r"/vector_stores/([^/]+)/", path)
    if not match:
        verbose_proxy_logger.debug(
            f"populate_request_with_path_params: No vector_store_id present in path={path}"
        )
        return

    vector_store_id = match.group(1)
    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Extracted vector_store_id={vector_store_id} from path={path}"
    )
    request_data.setdefault("vector_store_id", vector_store_id)
    existing_ids = request_data.get("vector_store_ids")
    if isinstance(existing_ids, list):
        if vector_store_id not in existing_ids:
            existing_ids.append(vector_store_id)
    else:
        request_data["vector_store_ids"] = [vector_store_id]
    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Updated request_data with vector_store_ids={request_data.get('vector_store_ids')}"
    )

View File

@@ -0,0 +1,187 @@
"""
Key Rotation Manager - Automated key rotation based on rotation schedules
Handles finding keys that need rotation based on their individual schedules.
"""
from datetime import datetime, timezone
from typing import List
from litellm._logging import verbose_proxy_logger
from litellm.constants import (
LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
LITELLM_KEY_ROTATION_GRACE_PERIOD,
)
from litellm.proxy._types import (
GenerateKeyResponse,
LiteLLM_VerificationToken,
RegenerateKeyRequest,
)
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.key_management_endpoints import (
_calculate_key_rotation_time,
regenerate_key_fn,
)
from litellm.proxy.utils import PrismaClient
class KeyRotationManager:
    """
    Manages automated key rotation based on individual key rotation schedules.
    """

    def __init__(self, prisma_client: PrismaClient):
        self.prisma_client = prisma_client

    @staticmethod
    def _describe_key(key: LiteLLM_VerificationToken) -> str:
        """Human-readable identifier for log lines: alias, truncated token, or 'unknown'."""
        return key.key_name or (key.token[:8] + "..." if key.token else "unknown")

    async def process_rotations(self):
        """
        Main entry point - find and rotate keys that are due for rotation
        """
        try:
            verbose_proxy_logger.info("Starting scheduled key rotation check...")

            # Clean up expired deprecated keys first
            await self._cleanup_expired_deprecated_keys()

            # Find keys that are due for rotation
            due_keys = await self._find_keys_needing_rotation()
            if not due_keys:
                verbose_proxy_logger.debug("No keys are due for rotation at this time")
                return

            verbose_proxy_logger.info(
                f"Found {len(due_keys)} keys due for rotation"
            )

            # Rotate each key independently so one failure does not block the rest.
            for due_key in due_keys:
                try:
                    await self._rotate_key(due_key)
                    verbose_proxy_logger.info(
                        f"Successfully rotated key: {self._describe_key(due_key)}"
                    )
                except Exception as e:
                    verbose_proxy_logger.error(
                        f"Failed to rotate key {self._describe_key(due_key)}: {e}"
                    )
        except Exception as e:
            verbose_proxy_logger.error(f"Key rotation process failed: {e}")

    async def _find_keys_needing_rotation(self) -> List[LiteLLM_VerificationToken]:
        """
        Find keys that are due for rotation based on their key_rotation_at timestamp.

        Logic:
        - Key has auto_rotate = true
        - key_rotation_at is null (needs initial setup) OR key_rotation_at <= now
        """
        now = datetime.now(timezone.utc)
        where_clause = {
            "auto_rotate": True,  # Only keys marked for auto rotation
            "OR": [
                # Keys that need initial rotation time setup
                {"key_rotation_at": None},
                # Keys where rotation time has passed
                {"key_rotation_at": {"lte": now}},
            ],
        }
        return await self.prisma_client.db.litellm_verificationtoken.find_many(
            where=where_clause
        )

    async def _cleanup_expired_deprecated_keys(self) -> None:
        """
        Remove deprecated key entries whose revoke_at has passed.
        """
        try:
            cutoff = datetime.now(timezone.utc)
            deleted_count = await self.prisma_client.db.litellm_deprecatedverificationtoken.delete_many(
                where={"revoke_at": {"lt": cutoff}}
            )
            if deleted_count > 0:
                verbose_proxy_logger.debug(
                    "Cleaned up %s expired deprecated key(s)", deleted_count
                )
        except Exception as e:
            # Best-effort cleanup: the deprecated-token table may not exist on older schemas.
            verbose_proxy_logger.debug(
                "Deprecated key cleanup skipped (table may not exist): %s", e
            )

    def _should_rotate_key(self, key: LiteLLM_VerificationToken, now: datetime) -> bool:
        """
        Determine if a key should be rotated based on key_rotation_at timestamp.
        """
        if not key.rotation_interval:
            # No interval configured -> never auto-rotate.
            return False
        if key.key_rotation_at is None:
            # If key_rotation_at is not set, rotate immediately (and set it).
            return True
        # Rotate once the scheduled rotation time has passed.
        return key.key_rotation_at <= now

    async def _rotate_key(self, key: LiteLLM_VerificationToken):
        """
        Rotate a single key using existing regenerate_key_fn and call the rotation hook
        """
        from litellm.proxy._types import UserAPIKeyAuth

        # Grace period keeps the old key valid briefly for seamless cutover.
        rotation_request = RegenerateKeyRequest(
            key=key.token or "",
            key_alias=key.key_alias,  # Pass key alias to ensure correct secret is updated in AWS Secrets Manager
            grace_period=LITELLM_KEY_ROTATION_GRACE_PERIOD or None,
        )

        # Rotations run as the internal-jobs service account, not a real user.
        system_user = UserAPIKeyAuth.get_litellm_internal_jobs_user_api_key_auth()

        # Use existing regenerate key function
        response = await regenerate_key_fn(
            data=rotation_request,
            user_api_key_dict=system_user,
            litellm_changed_by=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
        )

        rotated_ok = isinstance(response, GenerateKeyResponse)

        # Update the NEW key with rotation info (regenerate_key_fn creates a new token)
        if rotated_ok and response.token_id and key.rotation_interval:
            # Calculate next rotation time using helper function
            next_rotation_time = _calculate_key_rotation_time(key.rotation_interval)
            await self.prisma_client.db.litellm_verificationtoken.update(
                where={"token": response.token_id},
                data={
                    "rotation_count": (key.rotation_count or 0) + 1,
                    "last_rotation_at": datetime.now(timezone.utc),
                    "key_rotation_at": next_rotation_time,
                },
            )

        # Call the existing rotation hook for notifications, audit logs, etc.
        if rotated_ok:
            await KeyManagementEventHooks.async_key_rotated_hook(
                data=rotation_request,
                existing_key_row=key,
                response=response,
                user_api_key_dict=system_user,
                litellm_changed_by=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
            )

View File

@@ -0,0 +1,178 @@
import os
import yaml
from litellm._logging import verbose_proxy_logger
def get_file_contents_from_s3(bucket_name, object_key):
    """
    Fetch a YAML object from S3 and return it parsed, or None on any failure.

    Credentials are resolved through litellm's bedrock credential helper so
    IAM roles / env vars work the same way they do for bedrock calls.
    """
    try:
        # v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc
        import boto3
        from botocore.credentials import Credentials

        from litellm.main import bedrock_converse_chat_completion

        credentials: Credentials = bedrock_converse_chat_completion.get_credentials()
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,  # Optional, if using temporary credentials
        )

        verbose_proxy_logger.debug(
            f"Retrieving {object_key} from S3 bucket: {bucket_name}"
        )
        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
        verbose_proxy_logger.debug(f"Response: {response}")

        # Decode the streamed body, then parse the YAML directly from the string.
        file_contents = response["Body"].read().decode("utf-8")
        verbose_proxy_logger.debug("File contents retrieved from S3")
        return yaml.safe_load(file_contents)
    except ImportError as e:
        # this is most likely if a user is not using the litellm docker container
        verbose_proxy_logger.error(f"ImportError: {str(e)}")
    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
    return None
async def get_config_file_contents_from_gcs(bucket_name, object_key):
    """
    Download a YAML config object from GCS and return it parsed, or None on failure.
    """
    try:
        from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger

        gcs_bucket = GCSBucketLogger(
            bucket_name=bucket_name,
        )
        raw_bytes = await gcs_bucket.download_gcs_object(object_key)
        if raw_bytes is None:
            raise Exception(f"File contents are None for {object_key}")
        # download_gcs_object returns bytes; decode before parsing the YAML.
        return yaml.safe_load(raw_bytes.decode("utf-8"))
    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
        return None
def download_python_file_from_s3(
    bucket_name: str,
    object_key: str,
    local_file_path: str,
) -> bool:
    """
    Download a Python file from S3 and save it to local filesystem.

    Args:
        bucket_name (str): S3 bucket name
        object_key (str): S3 object key (file path in bucket)
        local_file_path (str): Local path where file should be saved

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        import boto3
        from botocore.credentials import Credentials

        from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM

        # Resolve AWS credentials via litellm's shared AWS credential chain.
        base_aws_llm = BaseAWSLLM()
        credentials: Credentials = base_aws_llm.get_credentials()
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,
        )

        verbose_proxy_logger.debug(
            f"Downloading Python file {object_key} from S3 bucket: {bucket_name}"
        )
        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)

        # Read the file contents
        file_contents = response["Body"].read().decode("utf-8")
        verbose_proxy_logger.debug(f"File contents: {file_contents}")

        # Ensure directory exists. os.path.dirname returns "" for a bare
        # filename, and os.makedirs("") raises FileNotFoundError — guard it.
        parent_dir = os.path.dirname(local_file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Pin utf-8 so the bytes round-trip regardless of the platform's
        # default locale encoding (content was decoded as utf-8 above).
        with open(local_file_path, "w", encoding="utf-8") as f:
            f.write(file_contents)

        verbose_proxy_logger.debug(
            f"Python file downloaded successfully to {local_file_path}"
        )
        return True
    except ImportError as e:
        # Most likely boto3 is absent outside the litellm docker image.
        verbose_proxy_logger.error(f"ImportError: {str(e)}")
        return False
    except Exception as e:
        verbose_proxy_logger.exception(f"Error downloading Python file: {str(e)}")
        return False
async def download_python_file_from_gcs(
    bucket_name: str,
    object_key: str,
    local_file_path: str,
) -> bool:
    """
    Download a Python file from GCS and save it to local filesystem.

    Args:
        bucket_name (str): GCS bucket name
        object_key (str): GCS object key (file path in bucket)
        local_file_path (str): Local path where file should be saved

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger

        gcs_bucket = GCSBucketLogger(
            bucket_name=bucket_name,
        )
        file_contents = await gcs_bucket.download_gcs_object(object_key)
        if file_contents is None:
            raise Exception(f"File contents are None for {object_key}")

        # file_contents is a bytes object, decode it
        file_contents = file_contents.decode("utf-8")

        # Ensure directory exists. os.path.dirname returns "" for a bare
        # filename, and os.makedirs("") raises FileNotFoundError — guard it.
        parent_dir = os.path.dirname(local_file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Pin utf-8 so the bytes round-trip regardless of the platform's
        # default locale encoding (content was decoded as utf-8 above).
        with open(local_file_path, "w", encoding="utf-8") as f:
            f.write(file_contents)

        verbose_proxy_logger.debug(
            f"Python file downloaded successfully to {local_file_path}"
        )
        return True
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error downloading Python file from GCS: {str(e)}"
        )
        return False
# # Example usage
# bucket_name = 'litellm-proxy'
# object_key = 'litellm_proxy_config.yaml'

Some files were not shown because too many files have changed in this diff Show More