chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,680 @@
from typing import List, Set
from fastapi import APIRouter, Depends, HTTPException, status
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
CommonProxyErrors,
LiteLLM_AccessGroupTable,
LitellmUserRoles,
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_checks import (
_cache_access_object,
_cache_key_object,
_cache_team_object,
_delete_cache_access_object,
_get_team_object_from_cache,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.proxy.utils import get_prisma_client_or_throw
from litellm.types.access_group import (
AccessGroupCreateRequest,
AccessGroupResponse,
AccessGroupUpdateRequest,
)
# Router for all access-group management endpoints; mounted by the proxy app.
router = APIRouter(
    tags=["access group management"],
)
def _require_proxy_admin(user_api_key_dict: UserAPIKeyAuth) -> None:
    """Raise a 403 HTTPException unless the caller is a proxy admin."""
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail={"error": CommonProxyErrors.not_allowed_access.value},
    )
def _record_to_response(record) -> AccessGroupResponse:
    """Build an AccessGroupResponse from a Prisma access-group record."""
    field_names = (
        "access_group_id",
        "access_group_name",
        "description",
        "access_model_names",
        "access_mcp_server_ids",
        "access_agent_ids",
        "assigned_team_ids",
        "assigned_key_ids",
        "created_at",
        "created_by",
        "updated_at",
        "updated_by",
    )
    # Same keyword arguments as an explicit constructor call, built generically.
    return AccessGroupResponse(**{name: getattr(record, name) for name in field_names})
def _record_to_access_group_table(record) -> LiteLLM_AccessGroupTable:
    """Convert a Prisma record to a LiteLLM_AccessGroupTable pydantic object for caching."""
    record_fields = record.dict()
    return LiteLLM_AccessGroupTable(**record_fields)
async def _cache_access_group_record(record) -> None:
    """
    Store an access-group Prisma record in the user_api_key_cache.

    Imports user_api_key_cache / proxy_logging_obj lazily from proxy_server
    to avoid circular imports, mirroring key_management_endpoints.
    """
    from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

    await _cache_access_object(
        access_group_id=record.access_group_id,
        access_group_table=_record_to_access_group_table(record),
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
async def _invalidate_cache_access_group(access_group_id: str) -> None:
    """
    Drop an access-group entry from both the in-memory and Redis caches.

    Imports user_api_key_cache / proxy_logging_obj lazily from proxy_server
    to avoid circular imports, mirroring key_management_endpoints.
    """
    from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

    await _delete_cache_access_object(
        access_group_id=access_group_id,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )
# ---------------------------------------------------------------------------
# DB sync helpers (called inside a Prisma transaction)
# ---------------------------------------------------------------------------
async def _sync_add_access_group_to_teams(
tx, team_ids: List[str], access_group_id: str
) -> None:
"""Add access_group_id to each team's access_group_ids (idempotent)."""
for team_id in team_ids:
team = await tx.litellm_teamtable.find_unique(where={"team_id": team_id})
if team is not None and access_group_id not in (team.access_group_ids or []):
await tx.litellm_teamtable.update(
where={"team_id": team_id},
data={
"access_group_ids": list(team.access_group_ids or [])
+ [access_group_id]
},
)
async def _sync_remove_access_group_from_teams(
tx, team_ids: List[str], access_group_id: str
) -> None:
"""Remove access_group_id from each team's access_group_ids (idempotent)."""
for team_id in team_ids:
team = await tx.litellm_teamtable.find_unique(where={"team_id": team_id})
if team is not None and access_group_id in (team.access_group_ids or []):
await tx.litellm_teamtable.update(
where={"team_id": team_id},
data={
"access_group_ids": [
ag for ag in team.access_group_ids if ag != access_group_id
]
},
)
async def _sync_add_access_group_to_keys(
tx, key_tokens: List[str], access_group_id: str
) -> None:
"""Add access_group_id to each key's access_group_ids (idempotent)."""
for token in key_tokens:
key = await tx.litellm_verificationtoken.find_unique(where={"token": token})
if key is not None and access_group_id not in (key.access_group_ids or []):
await tx.litellm_verificationtoken.update(
where={"token": token},
data={
"access_group_ids": list(key.access_group_ids or [])
+ [access_group_id]
},
)
async def _sync_remove_access_group_from_keys(
tx, key_tokens: List[str], access_group_id: str
) -> None:
"""Remove access_group_id from each key's access_group_ids (idempotent)."""
for token in key_tokens:
key = await tx.litellm_verificationtoken.find_unique(where={"token": token})
if key is not None and access_group_id in (key.access_group_ids or []):
await tx.litellm_verificationtoken.update(
where={"token": token},
data={
"access_group_ids": [
ag for ag in key.access_group_ids if ag != access_group_id
]
},
)
# ---------------------------------------------------------------------------
# Cache patch helpers
# ---------------------------------------------------------------------------
async def _patch_team_caches_add_access_group(
    team_ids: List[str],
    access_group_id: str,
    user_api_key_cache,
    proxy_logging_obj,
) -> None:
    """Patch cached team objects so they include access_group_id."""
    for team_id in team_ids:
        cached_team = await _get_team_object_from_cache(
            key="team_id:{}".format(team_id),
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=None,
        )
        if cached_team is None:
            continue  # team not cached — nothing to patch
        existing_ids = cached_team.access_group_ids
        if existing_ids is None:
            cached_team.access_group_ids = [access_group_id]
        elif access_group_id not in existing_ids:
            cached_team.access_group_ids = [*existing_ids, access_group_id]
        else:
            continue  # already present — skip the cache write
        await _cache_team_object(
            team_id=team_id,
            team_table=cached_team,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
async def _patch_team_caches_remove_access_group(
    team_ids: List[str],
    access_group_id: str,
    user_api_key_cache,
    proxy_logging_obj,
) -> None:
    """Patch cached team objects so they no longer reference access_group_id."""
    for team_id in team_ids:
        cached_team = await _get_team_object_from_cache(
            key="team_id:{}".format(team_id),
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=None,
        )
        if cached_team is None or not cached_team.access_group_ids:
            continue  # nothing cached, or no groups to remove
        cached_team.access_group_ids = [
            ag for ag in cached_team.access_group_ids if ag != access_group_id
        ]
        await _cache_team_object(
            team_id=team_id,
            team_table=cached_team,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
async def _patch_key_caches_add_access_group(
    key_tokens: List[str],
    access_group_id: str,
    user_api_key_cache,
    proxy_logging_obj,
) -> None:
    """Patch cached key objects so they include access_group_id."""
    for token in key_tokens:
        cached_key = await user_api_key_cache.async_get_cache(key=token)
        if cached_key is None:
            continue  # key not cached — nothing to patch
        if isinstance(cached_key, dict):
            # Some cache backends return plain dicts; rehydrate them.
            cached_key = UserAPIKeyAuth(**cached_key)
        if not isinstance(cached_key, UserAPIKeyAuth):
            continue  # unexpected cached type — leave it alone
        existing_ids = cached_key.access_group_ids
        if existing_ids is None:
            cached_key.access_group_ids = [access_group_id]
        elif access_group_id not in existing_ids:
            cached_key.access_group_ids = [*existing_ids, access_group_id]
        else:
            continue  # already present — skip the cache write
        await _cache_key_object(
            hashed_token=token,
            user_api_key_obj=cached_key,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
async def _patch_key_caches_remove_access_group(
    key_tokens: List[str],
    access_group_id: str,
    user_api_key_cache,
    proxy_logging_obj,
) -> None:
    """Patch cached key objects so they no longer reference access_group_id."""
    for token in key_tokens:
        cached_key = await user_api_key_cache.async_get_cache(key=token)
        if cached_key is None:
            continue  # key not cached — nothing to patch
        if isinstance(cached_key, dict):
            # Some cache backends return plain dicts; rehydrate them.
            cached_key = UserAPIKeyAuth(**cached_key)
        if not isinstance(cached_key, UserAPIKeyAuth):
            continue  # unexpected cached type — leave it alone
        if not cached_key.access_group_ids:
            continue  # no groups to remove
        cached_key.access_group_ids = [
            ag for ag in cached_key.access_group_ids if ag != access_group_id
        ]
        await _cache_key_object(
            hashed_token=token,
            user_api_key_obj=cached_key,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
# ---------------------------------------------------------------------------
# CRUD endpoints
# ---------------------------------------------------------------------------
@router.post(
    "/v1/access_group",
    response_model=AccessGroupResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_access_group(
    data: AccessGroupCreateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> AccessGroupResponse:
    """
    Create a new access group (proxy-admin only).

    Inside one transaction: verifies the name is unused, inserts the row, and
    links any assigned teams/keys back to the new group id. After the
    transaction commits, the new group is cached and cached team/key objects
    are patched so in-flight requests see the assignment without a DB read.

    Raises:
        HTTPException: 403 if the caller is not a proxy admin; 409 if an
            access group with the same name already exists.
    """
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )
    try:
        async with prisma_client.db.tx() as tx:
            # Pre-check for a duplicate name; the unique-constraint handler
            # below covers the race where another request wins the insert.
            existing = await tx.litellm_accessgrouptable.find_unique(
                where={"access_group_name": data.access_group_name}
            )
            if existing is not None:
                raise HTTPException(
                    status_code=status.HTTP_409_CONFLICT,
                    detail=f"Access group '{data.access_group_name}' already exists",
                )
            record = await tx.litellm_accessgrouptable.create(
                data={
                    "access_group_name": data.access_group_name,
                    "description": data.description,
                    "access_model_names": data.access_model_names or [],
                    "access_mcp_server_ids": data.access_mcp_server_ids or [],
                    "access_agent_ids": data.access_agent_ids or [],
                    "assigned_team_ids": data.assigned_team_ids or [],
                    "assigned_key_ids": data.assigned_key_ids or [],
                    "created_by": user_api_key_dict.user_id,
                    "updated_by": user_api_key_dict.user_id,
                }
            )
            # Sync team and key tables to reference the new access group
            await _sync_add_access_group_to_teams(
                tx, data.assigned_team_ids or [], record.access_group_id
            )
            await _sync_add_access_group_to_keys(
                tx, data.assigned_key_ids or [], record.access_group_id
            )
    except HTTPException:
        raise
    except Exception as e:
        # Race condition: another request created the same name between find_unique and create.
        if "unique constraint" in str(e).lower() or "P2002" in str(e):
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Access group '{data.access_group_name}' already exists",
            )
        raise
    # Cache updates happen after commit so a rollback never leaves stale cache.
    from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

    await _cache_access_group_record(record)
    await _patch_team_caches_add_access_group(
        data.assigned_team_ids or [],
        record.access_group_id,
        user_api_key_cache,
        proxy_logging_obj,
    )
    await _patch_key_caches_add_access_group(
        data.assigned_key_ids or [],
        record.access_group_id,
        user_api_key_cache,
        proxy_logging_obj,
    )
    return _record_to_response(record)
@router.get(
    "/v1/access_group",
    response_model=List[AccessGroupResponse],
)
async def list_access_groups(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> List[AccessGroupResponse]:
    """Return every access group, newest first (proxy-admin only)."""
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )
    records = await prisma_client.db.litellm_accessgrouptable.find_many(
        order={"created_at": "desc"}
    )
    return list(map(_record_to_response, records))
@router.get(
    "/v1/access_group/{access_group_id}",
    response_model=AccessGroupResponse,
)
async def get_access_group(
    access_group_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> AccessGroupResponse:
    """Fetch one access group by id (proxy-admin only); 404 when missing."""
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )
    record = await prisma_client.db.litellm_accessgrouptable.find_unique(
        where={"access_group_id": access_group_id}
    )
    if record is not None:
        return _record_to_response(record)
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Access group '{access_group_id}' not found",
    )
@router.put(
    "/v1/access_group/{access_group_id}",
    response_model=AccessGroupResponse,
)
async def update_access_group(
    access_group_id: str,
    data: AccessGroupUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> AccessGroupResponse:
    """
    Update an access group (proxy-admin only).

    Only fields explicitly set in the request body are written. When
    assigned_team_ids / assigned_key_ids change, the delta (added / removed
    ids) is computed inside the transaction and mirrored onto the team and
    key tables; after commit the corresponding caches are patched.

    Raises:
        HTTPException: 403 if the caller is not a proxy admin; 404 if the
            group does not exist; 409 on a duplicate access_group_name.
    """
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )
    # exclude_unset: distinguish "field omitted" from "field explicitly None".
    update_fields = data.model_dump(exclude_unset=True)
    update_data: dict = {"updated_by": user_api_key_dict.user_id}
    for field, value in update_fields.items():
        # Coerce an explicit None to [] for list-valued columns so the update
        # clears the list rather than writing a null.
        if (
            field
            in (
                "assigned_team_ids",
                "assigned_key_ids",
                "access_model_names",
                "access_mcp_server_ids",
                "access_agent_ids",
            )
            and value is None
        ):
            value = []
        update_data[field] = value
    # Initialize delta lists before the try block so they remain accessible
    # for cache updates after the transaction, even if an error path is added later.
    teams_to_add: List[str] = []
    teams_to_remove: List[str] = []
    keys_to_add: List[str] = []
    keys_to_remove: List[str] = []
    try:
        async with prisma_client.db.tx() as tx:
            # Read inside the transaction so delta computation is consistent with the write,
            # avoiding a TOCTOU race where a concurrent update could make deltas stale.
            existing = await tx.litellm_accessgrouptable.find_unique(
                where={"access_group_id": access_group_id}
            )
            if existing is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Access group '{access_group_id}' not found",
                )
            old_team_ids: Set[str] = set(existing.assigned_team_ids or [])
            old_key_ids: Set[str] = set(existing.assigned_key_ids or [])
            # If an assignment list was not sent, keep the old set (no delta).
            new_team_ids: Set[str] = (
                set(update_fields["assigned_team_ids"] or [])
                if "assigned_team_ids" in update_fields
                else old_team_ids
            )
            new_key_ids: Set[str] = (
                set(update_fields["assigned_key_ids"] or [])
                if "assigned_key_ids" in update_fields
                else old_key_ids
            )
            teams_to_add = list(new_team_ids - old_team_ids)
            teams_to_remove = list(old_team_ids - new_team_ids)
            keys_to_add = list(new_key_ids - old_key_ids)
            keys_to_remove = list(old_key_ids - new_key_ids)
            record = await tx.litellm_accessgrouptable.update(
                where={"access_group_id": access_group_id},
                data=update_data,
            )
            # Mirror membership changes onto team/key tables in the same tx.
            await _sync_add_access_group_to_teams(tx, teams_to_add, access_group_id)
            await _sync_remove_access_group_from_teams(
                tx, teams_to_remove, access_group_id
            )
            await _sync_add_access_group_to_keys(tx, keys_to_add, access_group_id)
            await _sync_remove_access_group_from_keys(
                tx, keys_to_remove, access_group_id
            )
    except HTTPException:
        raise
    except Exception as e:
        # Unique constraint violation (e.g. access_group_name already exists).
        if "unique constraint" in str(e).lower() or "P2002" in str(e):
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Access group '{update_data.get('access_group_name', '')}' already exists",
            )
        raise
    # Cache patching runs after commit so a rollback never leaves stale cache.
    from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

    await _cache_access_group_record(record)
    await _patch_team_caches_add_access_group(
        teams_to_add, access_group_id, user_api_key_cache, proxy_logging_obj
    )
    await _patch_team_caches_remove_access_group(
        teams_to_remove, access_group_id, user_api_key_cache, proxy_logging_obj
    )
    await _patch_key_caches_add_access_group(
        keys_to_add, access_group_id, user_api_key_cache, proxy_logging_obj
    )
    await _patch_key_caches_remove_access_group(
        keys_to_remove, access_group_id, user_api_key_cache, proxy_logging_obj
    )
    return _record_to_response(record)
@router.delete(
    "/v1/access_group/{access_group_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_access_group(
    access_group_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> None:
    """
    Delete an access group (proxy-admin only).

    Inside one transaction: removes the group id from every team/key that
    references it (including out-of-sync rows that are only listed in the
    group's assigned_* columns) and deletes the group row. Afterwards the
    group cache entry is invalidated and cached team/key objects are patched.

    Raises:
        HTTPException: 403 if the caller is not a proxy admin; 404 if the
            group does not exist; 503 on DB connection errors; 500 otherwise.
    """
    _require_proxy_admin(user_api_key_dict)
    prisma_client = get_prisma_client_or_throw(
        CommonProxyErrors.db_not_connected_error.value
    )
    try:
        affected_team_ids: List[str] = []
        affected_key_tokens: List[str] = []
        async with prisma_client.db.tx() as tx:
            existing = await tx.litellm_accessgrouptable.find_unique(
                where={"access_group_id": access_group_id}
            )
            if existing is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Access group '{access_group_id}' not found",
                )
            # Union of: teams that have this access_group_id in their own access_group_ids
            # AND teams listed in assigned_team_ids (handles out-of-sync data from before this sync was added)
            teams_with_group = await tx.litellm_teamtable.find_many(
                where={"access_group_ids": {"hasSome": [access_group_id]}}
            )
            all_affected_team_ids: Set[str] = {
                team.team_id for team in teams_with_group
            } | set(existing.assigned_team_ids or [])
            affected_team_ids = list(all_affected_team_ids)
            # Union of: keys that have this access_group_id in their own access_group_ids
            # AND keys listed in assigned_key_ids (handles out-of-sync data)
            keys_with_group = await tx.litellm_verificationtoken.find_many(
                where={"access_group_ids": {"hasSome": [access_group_id]}}
            )
            all_affected_key_tokens: Set[str] = {
                key.token for key in keys_with_group
            } | set(existing.assigned_key_ids or [])
            affected_key_tokens = list(all_affected_key_tokens)
            # Update teams returned by find_many directly — we already have their data.
            for team in teams_with_group:
                await tx.litellm_teamtable.update(
                    where={"team_id": team.team_id},
                    data={
                        "access_group_ids": [
                            ag
                            for ag in (team.access_group_ids or [])
                            if ag != access_group_id
                        ]
                    },
                )
            # Use _sync_remove only for out-of-sync teams not found by the hasSome query.
            out_of_sync_team_ids = set(existing.assigned_team_ids or []) - {
                t.team_id for t in teams_with_group
            }
            await _sync_remove_access_group_from_teams(
                tx, list(out_of_sync_team_ids), access_group_id
            )
            # Update keys returned by find_many directly — we already have their data.
            for key in keys_with_group:
                await tx.litellm_verificationtoken.update(
                    where={"token": key.token},
                    data={
                        "access_group_ids": [
                            ag
                            for ag in (key.access_group_ids or [])
                            if ag != access_group_id
                        ]
                    },
                )
            # Use _sync_remove only for out-of-sync keys not found by the hasSome query.
            out_of_sync_key_tokens = set(existing.assigned_key_ids or []) - {
                k.token for k in keys_with_group
            }
            await _sync_remove_access_group_from_keys(
                tx, list(out_of_sync_key_tokens), access_group_id
            )
            await tx.litellm_accessgrouptable.delete(
                where={"access_group_id": access_group_id}
            )
        # NOTE(review): the cache calls below sit inside the outer try, so a
        # cache failure after the tx has committed surfaces as a 500 even
        # though the delete succeeded — confirm whether that is intended.
        from litellm.proxy.proxy_server import proxy_logging_obj, user_api_key_cache

        await _invalidate_cache_access_group(access_group_id)
        await _patch_team_caches_remove_access_group(
            affected_team_ids, access_group_id, user_api_key_cache, proxy_logging_obj
        )
        await _patch_key_caches_remove_access_group(
            affected_key_tokens, access_group_id, user_api_key_cache, proxy_logging_obj
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            "delete_access_group failed: access_group_id=%s error=%s",
            access_group_id,
            e,
        )
        if PrismaDBExceptionHandler.is_database_connection_error(e):
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=CommonProxyErrors.db_not_connected_error.value,
            )
        # P2025 is Prisma's "record not found" error code.
        if "P2025" in str(e) or (
            "record" in str(e).lower() and "not found" in str(e).lower()
        ):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Access group '{access_group_id}' not found",
            )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete access group. Please try again.",
        )
# Alias routes for /v1/unified_access_group
# The same handlers are registered under a second path prefix. add_api_route
# does not inherit response_model/status_code from the @router decorators
# above, so they are repeated here to keep both prefixes' OpenAPI identical.
router.add_api_route(
    "/v1/unified_access_group",
    create_access_group,
    methods=["POST"],
    response_model=AccessGroupResponse,
    status_code=status.HTTP_201_CREATED,
)
router.add_api_route(
    "/v1/unified_access_group",
    list_access_groups,
    methods=["GET"],
    response_model=List[AccessGroupResponse],
)
router.add_api_route(
    "/v1/unified_access_group/{access_group_id}",
    get_access_group,
    methods=["GET"],
    response_model=AccessGroupResponse,
)
router.add_api_route(
    "/v1/unified_access_group/{access_group_id}",
    update_access_group,
    methods=["PUT"],
    response_model=AccessGroupResponse,
)
router.add_api_route(
    "/v1/unified_access_group/{access_group_id}",
    delete_access_group,
    methods=["DELETE"],
    status_code=status.HTTP_204_NO_CONTENT,
)

View File

@@ -0,0 +1,352 @@
"""
BUDGET MANAGEMENT
All /budget management endpoints
/budget/new
/budget/info
/budget/update
/budget/delete
/budget/settings
/budget/list
"""
#### BUDGET TABLE MANAGEMENT ####
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.utils import jsonify_object
# Router for all /budget management endpoints; mounted by the proxy app.
router = APIRouter()
@router.post(
    "/budget/new",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_budget(
    budget_obj: BudgetNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new budget object. Can apply this to teams, orgs, end-users, keys.

    Parameters:
    - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
    - budget_id: Optional[str] - The id of the budget. If not provided, a new id will be generated.
    - max_budget: Optional[float] - The max budget for the budget.
    - soft_budget: Optional[float] - The soft budget for the budget.
    - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget.
    - tpm_limit: Optional[int] - The tokens per minute limit for the budget.
    - rpm_limit: Optional[int] - The requests per minute limit for the budget.
    - model_max_budget: Optional[dict] - Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}}
    - budget_reset_at: Optional[datetime] - Datetime when the initial budget is reset. Default is now.

    Raises:
        HTTPException: 500 if the DB is not connected; 400 if a budget value
            is negative or model_max_budget fails validation.
    """
    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    # Validate budget values are not negative
    if budget_obj.max_budget is not None and budget_obj.max_budget < 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"max_budget cannot be negative. Received: {budget_obj.max_budget}"
            },
        )
    if budget_obj.soft_budget is not None and budget_obj.soft_budget < 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"soft_budget cannot be negative. Received: {budget_obj.soft_budget}"
            },
        )
    # Validate model_max_budget if present
    if budget_obj.model_max_budget is not None and len(budget_obj.model_max_budget) > 0:
        # Lazy import — key_management_endpoints imports from this package.
        from litellm.proxy.management_endpoints.key_management_endpoints import (
            validate_model_max_budget,
        )

        try:
            validate_model_max_budget(budget_obj.model_max_budget)
        except ValueError as e:
            raise HTTPException(status_code=400, detail={"error": str(e)})
    # if no budget_reset_at date is set, but a budget_duration is given, then set budget_reset_at initially to the first completed duration interval in future
    # NOTE(review): datetime.utcnow() is naive (no tzinfo) — presumably the DB
    # column stores naive UTC; confirm before switching to timezone-aware.
    if budget_obj.budget_reset_at is None and budget_obj.budget_duration is not None:
        budget_obj.budget_reset_at = datetime.utcnow() + timedelta(
            seconds=duration_in_seconds(duration=budget_obj.budget_duration)
        )
    # exclude_none: let DB defaults apply for fields the caller did not send.
    budget_obj_json = budget_obj.model_dump(exclude_none=True)
    budget_obj_jsonified = jsonify_object(budget_obj_json)  # json dump any dictionaries
    response = await prisma_client.db.litellm_budgettable.create(
        data={
            **budget_obj_jsonified,  # type: ignore
            "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
            "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
        }  # type: ignore
    )
    return response
@router.post(
    "/budget/update",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_budget(
    budget_obj: BudgetNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing budget object.

    Parameters:
    - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
    - budget_id: Optional[str] - The id of the budget. If not provided, a new id will be generated.
    - max_budget: Optional[float] - The max budget for the budget.
    - soft_budget: Optional[float] - The soft budget for the budget.
    - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget.
    - tpm_limit: Optional[int] - The tokens per minute limit for the budget.
    - rpm_limit: Optional[int] - The requests per minute limit for the budget.
    - model_max_budget: Optional[dict] - Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}}
    - budget_reset_at: Optional[datetime] - Update the Datetime when the budget was last reset.

    Raises:
        HTTPException: 500 if the DB is not connected; 400 if budget_id is
            missing or a value fails validation.
    """
    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if budget_obj.budget_id is None:
        raise HTTPException(status_code=400, detail={"error": "budget_id is required"})
    # Validate budget values are not negative
    if budget_obj.max_budget is not None and budget_obj.max_budget < 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"max_budget cannot be negative. Received: {budget_obj.max_budget}"
            },
        )
    if budget_obj.soft_budget is not None and budget_obj.soft_budget < 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"soft_budget cannot be negative. Received: {budget_obj.soft_budget}"
            },
        )
    # Validate model_max_budget if present in update
    if budget_obj.model_max_budget is not None and len(budget_obj.model_max_budget) > 0:
        # Lazy import — key_management_endpoints imports from this package.
        from litellm.proxy.management_endpoints.key_management_endpoints import (
            validate_model_max_budget,
        )

        try:
            validate_model_max_budget(budget_obj.model_max_budget)
        except ValueError as e:
            raise HTTPException(status_code=400, detail={"error": str(e)})
    # exclude_unset: only fields the caller actually sent are written.
    response = await prisma_client.db.litellm_budgettable.update(
        where={"budget_id": budget_obj.budget_id},
        data={
            **budget_obj.model_dump(exclude_unset=True),  # type: ignore
            "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
        },  # type: ignore
    )
    return response
@router.post(
    "/budget/info",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_budget(data: BudgetRequest):
    """
    Get the budget id specific information

    Parameters:
    - budgets: List[str] - The list of budget ids to get information for

    Raises:
        HTTPException: 500 if the DB is not connected; 400 if no budget ids
            were supplied.
    """
    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        # Use the shared error constant in a {"error": ...} payload so the
        # message matches every sibling budget endpoint (was the ad-hoc
        # string "No db connected").
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if len(data.budgets) == 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Specify list of budget id's to query. Passed in={data.budgets}"
            },
        )
    response = await prisma_client.db.litellm_budgettable.find_many(
        where={"budget_id": {"in": data.budgets}},
    )
    return response
@router.get(
    "/budget/settings",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def budget_settings(
    budget_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get list of configurable params + current value for a budget item + description of each field

    Used on Admin UI.

    Query Parameters:
    - budget_id: str - The budget id to get information for
    """
    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        # NOTE(review): sibling endpoints use 500 for a missing DB connection;
        # this one returns 400 — kept as-is, confirm before unifying.
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        # NOTE(review): 400 rather than 403 for a permission failure — the
        # Admin UI may depend on this code; confirm before changing.
        raise HTTPException(
            status_code=400,
            detail={
                "error": "{}, your role={}".format(
                    CommonProxyErrors.not_allowed_access.value,
                    user_api_key_dict.user_role,
                )
            },
        )
    ## get budget item from db
    db_budget_row = await prisma_client.db.litellm_budgettable.find_first(
        where={"budget_id": budget_id}
    )
    # An unknown budget_id is not an error: fall through with an empty dict so
    # the UI still receives the field metadata with default values.
    if db_budget_row is not None:
        db_budget_row_dict = db_budget_row.model_dump(exclude_none=True)
    else:
        db_budget_row_dict = {}
    # Whitelist of fields the Admin UI may edit, with their display types.
    allowed_args = {
        "max_parallel_requests": {"type": "Integer"},
        "tpm_limit": {"type": "Integer"},
        "rpm_limit": {"type": "Integer"},
        "budget_duration": {"type": "String"},
        "max_budget": {"type": "Float"},
        "soft_budget": {"type": "Float"},
        "model_max_budget": {"type": "Object"},
    }
    return_val = []
    # Walk the pydantic model fields so descriptions/defaults stay in sync
    # with BudgetNewRequest automatically.
    for field_name, field_info in BudgetNewRequest.model_fields.items():
        if field_name in allowed_args:
            _stored_in_db = True
            _response_obj = ConfigList(
                field_name=field_name,
                field_type=allowed_args[field_name]["type"],
                field_description=field_info.description or "",
                field_value=db_budget_row_dict.get(field_name, None),
                stored_in_db=_stored_in_db,
                field_default_value=field_info.default,
            )
            return_val.append(_response_obj)
    return return_val
@router.get(
    "/budget/list",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def list_budget(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """List all the created budgets in proxy db. Used on Admin UI."""
    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        role_error = "{}, your role={}".format(
            CommonProxyErrors.not_allowed_access.value,
            user_api_key_dict.user_role,
        )
        raise HTTPException(status_code=400, detail={"error": role_error})
    return await prisma_client.db.litellm_budgettable.find_many()
@router.post(
    "/budget/delete",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_budget(
    data: BudgetDeleteRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete budget

    Parameters:
    - id: str - The budget id to delete
    """
    # Lazy import to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        role_error = "{}, your role={}".format(
            CommonProxyErrors.not_allowed_access.value,
            user_api_key_dict.user_role,
        )
        raise HTTPException(status_code=400, detail={"error": role_error})
    return await prisma_client.db.litellm_budgettable.delete(
        where={"budget_id": data.id}
    )

View File

@@ -0,0 +1,365 @@
"""
CACHE SETTINGS MANAGEMENT
Endpoints for managing cache configuration
GET /cache/settings - Get cache configuration including available settings
POST /cache/settings/test - Test cache connection with provided credentials
POST /cache/settings - Save cache settings to database
"""
import json
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.management_endpoints import (
CACHE_SETTINGS_FIELDS,
REDIS_TYPE_DESCRIPTIONS,
CacheSettingsField,
)
# Router for cache-settings management endpoints; mounted by the proxy app.
router = APIRouter()
class CacheSettingsManager:
    """
    Manages cache settings initialization and updates.

    Tracks last cache params to avoid unnecessary reinitialization.
    """

    # Last cache-params dict applied via _init_cache; None until the first
    # initialization. Class-level state: shared process-wide, not per-instance.
    _last_cache_params: Optional[Dict[str, Any]] = None
@staticmethod
def _cache_params_equal(params1: Dict[str, Any], params2: Dict[str, Any]) -> bool:
"""
Compare two cache parameter dictionaries for equality.
Normalizes values and filters out UI-only fields.
"""
# Normalize by removing None values and UI-only fields
def normalize(params: Dict[str, Any]) -> Dict[str, Any]:
normalized = {}
for k, v in params.items():
if k == "redis_type": # Skip UI-only field
continue
if v is not None:
# Convert to string for comparison to handle different types
normalized[k] = str(v) if not isinstance(v, (list, dict)) else v
return normalized
normalized1 = normalize(params1)
normalized2 = normalize(params2)
return normalized1 == normalized2
@staticmethod
async def init_cache_settings_in_db(prisma_client, proxy_config):
"""
Initialize cache settings from database into the router on startup.
Only reinitializes if cache params have changed.
"""
import json
try:
cache_config = await prisma_client.db.litellm_cacheconfig.find_unique(
where={"id": "cache_config"}
)
if cache_config is not None and cache_config.cache_settings:
# Parse cache settings JSON
cache_settings_json = cache_config.cache_settings
if isinstance(cache_settings_json, str):
cache_settings_dict = json.loads(cache_settings_json)
else:
cache_settings_dict = cache_settings_json
# Decrypt cache settings
decrypted_settings = proxy_config._decrypt_db_variables(
variables_dict=cache_settings_dict
)
# Remove redis_type if present (UI-only field, not a Cache parameter)
# We derive it for UI in get_cache_settings endpoint
cache_params = {
k: v for k, v in decrypted_settings.items() if k != "redis_type"
}
# Check if cache params have changed
if (
CacheSettingsManager._last_cache_params is not None
and CacheSettingsManager._cache_params_equal(
CacheSettingsManager._last_cache_params, cache_params
)
):
verbose_proxy_logger.debug(
"Cache settings unchanged, skipping reinitialization"
)
return
# Initialize cache only if params changed or cache not initialized
proxy_config._init_cache(cache_params=cache_params)
# Store the params we just initialized
CacheSettingsManager._last_cache_params = cache_params.copy()
# Switch on LLM response caching
proxy_config.switch_on_llm_response_caching()
verbose_proxy_logger.info("Cache settings initialized from database")
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.management_endpoints.cache_settings_endpoints.py::CacheSettingsManager::init_cache_settings_in_db - {}".format(
str(e)
)
)
@staticmethod
def update_cache_params(cache_params: Dict[str, Any]):
"""
Update the last cache params after initialization.
Called after cache settings are updated via the API.
"""
CacheSettingsManager._last_cache_params = cache_params.copy()
class CacheSettingsResponse(BaseModel):
    """Response payload for GET /cache/settings."""

    fields: List[CacheSettingsField] = Field(
        description="List of all configurable cache settings with metadata"
    )
    current_values: Dict[str, Any] = Field(
        description="Current values of cache settings"
    )
    redis_type_descriptions: Dict[str, str] = Field(
        description="Descriptions for each Redis type option"
    )
class CacheTestRequest(BaseModel):
    """Request body for POST /cache/settings/test."""

    cache_settings: Dict[str, Any] = Field(
        description="Cache settings to test connection with"
    )
class CacheTestResponse(BaseModel):
    """Result of a cache connection test (POST /cache/settings/test)."""

    status: str = Field(description="Connection status: 'success' or 'failed'")
    message: str = Field(description="Connection result message")
    error: Optional[str] = Field(
        default=None, description="Error message if connection failed"
    )
class CacheSettingsUpdateRequest(BaseModel):
    """Request body for POST /cache/settings: the settings to persist."""

    cache_settings: Dict[str, Any] = Field(description="Cache settings to save")
@router.get(
    "/cache/settings",
    tags=["Cache Settings"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CacheSettingsResponse,
)
async def get_cache_settings(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get cache configuration and available settings.

    Returns:
        - fields: all configurable cache settings with metadata
          (type, description, default, options), populated with current values
        - current_values: current cache settings stored in the database
        - redis_type_descriptions: help text for each Redis deployment type
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    try:
        # Deep-copy the static field catalog so we can fill in values.
        settings_fields = [f.model_copy(deep=True) for f in CACHE_SETTINGS_FIELDS]
        current_values: Dict[str, Any] = {}
        if prisma_client is not None:
            db_row = await prisma_client.db.litellm_cacheconfig.find_unique(
                where={"id": "cache_config"}
            )
            if db_row is not None and db_row.cache_settings:
                raw = db_row.cache_settings
                stored = json.loads(raw) if isinstance(raw, str) else raw
                # Stored settings may contain encrypted secrets.
                decrypted = proxy_config._decrypt_db_variables(variables_dict=stored)
                # The backend only stores 'type'; derive the UI-only
                # redis_type discriminator from which params are present.
                if decrypted.get("type") == "redis":
                    if decrypted.get("redis_startup_nodes"):
                        decrypted["redis_type"] = "cluster"
                    elif decrypted.get("sentinel_nodes"):
                        decrypted["redis_type"] = "sentinel"
                    else:
                        decrypted["redis_type"] = "node"
                current_values = decrypted
        # Mirror current values onto the field catalog for the UI.
        for settings_field in settings_fields:
            if settings_field.field_name in current_values:
                settings_field.field_value = current_values[settings_field.field_name]
        return CacheSettingsResponse(
            fields=settings_fields,
            current_values=current_values,
            redis_type_descriptions=REDIS_TYPE_DESCRIPTIONS,
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error fetching cache settings: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching cache settings: {str(e)}"
        )
@router.post(
    "/cache/settings/test",
    tags=["Cache Settings"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CacheTestResponse,
)
async def test_cache_connection(
    request: CacheTestRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Test cache connection with provided credentials.

    Builds a throwaway Cache instance and calls its test_connection method,
    so the global proxy cache is never touched. Failures are reported in the
    response body rather than raised.
    """
    from litellm import Cache

    try:
        candidate_settings = request.cache_settings.copy()
        verbose_proxy_logger.debug(
            "Testing cache connection with settings: %s", candidate_settings
        )
        # Only Redis-backed caches can be probed today.
        if candidate_settings.get("type") != "redis":
            return CacheTestResponse(
                status="failed",
                message="Only Redis cache type is currently supported for testing",
            )
        probe = Cache(**candidate_settings)
        connection_result = await probe.cache.test_connection()
        return CacheTestResponse(**connection_result)
    except Exception as e:
        verbose_proxy_logger.error(f"Error testing cache connection: {str(e)}")
        return CacheTestResponse(
            status="failed",
            message=f"Cache connection test failed: {str(e)}",
            error=str(e),
        )
@router.post(
    "/cache/settings",
    tags=["Cache Settings"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_cache_settings(
    request: CacheSettingsUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Save cache settings to database and initialize cache.

    Steps:
        1. Encrypt sensitive fields (passwords, etc.)
        2. Persist to the LiteLLM_CacheConfig table
        3. Reinitialize the proxy cache with the new settings
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_config,
        store_model_in_db,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected. Please connect a database."},
        )
    if store_model_in_db is not True:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
            },
        )
    try:
        requested_settings = request.cache_settings.copy()
        # Encrypt before persisting; redis_type is stored too so the UI can
        # restore its state on reload.
        encrypted = proxy_config._encrypt_env_variables(
            environment_variables=requested_settings
        )
        serialized = json.dumps(encrypted)
        await prisma_client.db.litellm_cacheconfig.upsert(
            where={"id": "cache_config"},
            data={
                "create": {"id": "cache_config", "cache_settings": serialized},
                "update": {"cache_settings": serialized},
            },
        )
        # Decrypt again to get plain values for cache initialization.
        decrypted = proxy_config._decrypt_db_variables(variables_dict=encrypted)
        # Drop the UI-only redis_type field before handing params to the cache.
        cache_params = {k: v for k, v in decrypted.items() if k != "redis_type"}
        proxy_config._init_cache(cache_params=cache_params)
        # Remember what was applied so startup sync can skip a no-op reinit.
        CacheSettingsManager.update_cache_params(cache_params)
        proxy_config.switch_on_llm_response_caching()
        return {
            "message": "Cache settings updated successfully",
            "status": "success",
            "settings": requested_settings,
        }
    except Exception as e:
        verbose_proxy_logger.error(f"Error updating cache settings: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error updating cache settings: {str(e)}"
        )

View File

@@ -0,0 +1,54 @@
"""
Endpoints for managing callbacks
"""
import json
import os
from fastapi import APIRouter, Depends
from litellm.litellm_core_utils.logging_callback_manager import CallbacksByType
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
@router.get(
    "/callbacks/list",
    tags=["Logging Callbacks"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CallbacksByType,
)
async def list_callbacks():
    """
    View List of Active Logging Callbacks
    """
    from litellm import logging_callback_manager

    # The callback manager already knows how to group active callbacks by type.
    return logging_callback_manager.get_callbacks_by_type()
@router.get(
    "/callbacks/configs",
    tags=["Logging Callbacks"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_callback_configs():
    """
    Get Available Callback Configurations

    Returns the configuration details for all available logging callbacks,
    including supported parameters, field types, and descriptions.

    Reads the bundled ``litellm/integrations/callback_configs.json`` file
    (two directories above this module) and returns its parsed contents.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "integrations",
        "callback_configs.json",
    )
    # Explicit encoding: the platform default may not be UTF-8 (e.g. on
    # Windows), which would break non-ASCII descriptions in the config file.
    with open(config_path, "r", encoding="utf-8") as f:
        configs = json.load(f)
    return configs

View File

@@ -0,0 +1,814 @@
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from fastapi import HTTPException, status
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.utils import PrismaClient
from litellm.types.proxy.management_endpoints.common_daily_activity import (
BreakdownMetrics,
DailySpendData,
DailySpendMetadata,
KeyMetadata,
KeyMetricWithMetadata,
MetricWithMetadata,
SpendAnalyticsPaginatedResponse,
SpendMetrics,
)
# Mapping from Prisma accessor names to actual PostgreSQL table names.
# _build_aggregated_sql_query uses this to translate the lowercase prisma
# model accessor (e.g. "litellm_dailyuserspend") into the quoted table name
# used in raw SQL; unknown accessors are rejected there with a ValueError.
_PRISMA_TO_PG_TABLE: Dict[str, str] = {
    "litellm_dailyuserspend": "LiteLLM_DailyUserSpend",
    "litellm_dailyteamspend": "LiteLLM_DailyTeamSpend",
    "litellm_dailyorganizationspend": "LiteLLM_DailyOrganizationSpend",
    "litellm_dailyenduserspend": "LiteLLM_DailyEndUserSpend",
    "litellm_dailyagentspend": "LiteLLM_DailyAgentSpend",
    "litellm_dailytagspend": "LiteLLM_DailyTagSpend",
}
def update_metrics(existing_metrics: SpendMetrics, record: Any) -> SpendMetrics:
    """Accumulate one daily-spend record into *existing_metrics* (in place).

    Returns the same metrics object for convenient chaining.
    """
    # Counters copied 1:1 from the record onto the running totals.
    for attr in (
        "spend",
        "prompt_tokens",
        "completion_tokens",
        "cache_read_input_tokens",
        "cache_creation_input_tokens",
        "api_requests",
        "successful_requests",
        "failed_requests",
    ):
        setattr(
            existing_metrics, attr, getattr(existing_metrics, attr) + getattr(record, attr)
        )
    # total_tokens is derived from the record's prompt + completion counts.
    existing_metrics.total_tokens += record.prompt_tokens + record.completion_tokens
    return existing_metrics
def _is_user_agent_tag(tag: Optional[str]) -> bool:
"""Determine whether a tag should be treated as a User-Agent tag."""
if not tag:
return False
normalized_tag = tag.strip().lower()
return normalized_tag.startswith("user-agent:") or normalized_tag.startswith(
"user agent:"
)
def compute_tag_metadata_totals(records: List[Any]) -> SpendMetrics:
    """
    Deduplicate spend metrics for tags using request_id, ignoring User-Agent
    prefixed tags.

    Each unique request_id contributes at most one record (the tag with the
    highest spend) to the returned totals; records without a request_id are
    skipped entirely.
    """
    best_per_request: Dict[str, Any] = {}
    for rec in records:
        req_id = getattr(rec, "request_id", None)
        if not req_id:
            continue
        # User-Agent tags are excluded from the dedup totals by contract.
        if _is_user_agent_tag(getattr(rec, "tag", None)):
            continue
        champion = best_per_request.get(req_id)
        if champion is None or rec.spend > champion.spend:
            best_per_request[req_id] = rec
    totals = SpendMetrics()
    for rec in best_per_request.values():
        update_metrics(totals, rec)
    return totals
def _key_metadata(
    api_key_metadata: Dict[str, Dict[str, Any]], api_key: str
) -> KeyMetadata:
    """Build KeyMetadata (key_alias / team_id) for *api_key* from the lookup table."""
    info = api_key_metadata.get(api_key, {})
    return KeyMetadata(
        key_alias=info.get("key_alias", None),
        team_id=info.get("team_id", None),
    )


def _update_group_breakdown(
    group: Dict[str, MetricWithMetadata],
    group_key: str,
    record: Any,
    api_key_metadata: Dict[str, Dict[str, Any]],
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Accumulate *record* into group[group_key] and its per-api-key breakdown.

    Creates the MetricWithMetadata entry (with *metadata*) on first sight of
    *group_key*; metadata of an existing entry is left untouched.
    """
    if group_key not in group:
        group[group_key] = MetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=metadata or {},
        )
    entry = group[group_key]
    entry.metrics = update_metrics(entry.metrics, record)
    if record.api_key not in entry.api_key_breakdown:
        entry.api_key_breakdown[record.api_key] = KeyMetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=_key_metadata(api_key_metadata, record.api_key),
        )
    entry.api_key_breakdown[record.api_key].metrics = update_metrics(
        entry.api_key_breakdown[record.api_key].metrics, record
    )


def update_breakdown_metrics(
    breakdown: BreakdownMetrics,
    record: Any,
    model_metadata: Dict[str, Dict[str, Any]],
    provider_metadata: Dict[str, Dict[str, Any]],
    api_key_metadata: Dict[str, Dict[str, Any]],
    entity_id_field: Optional[str] = None,
    entity_metadata_field: Optional[Dict[str, dict]] = None,
) -> BreakdownMetrics:
    """Updates breakdown metrics for a single record.

    Accumulates the record into each dimension present on it (model, model
    group, MCP server, provider, endpoint, api key, and optionally an entity
    such as a team/user id), maintaining a per-api-key sub-breakdown inside
    each dimension via _update_group_breakdown.
    """
    if record.model:
        _update_group_breakdown(
            breakdown.models,
            record.model,
            record,
            api_key_metadata,
            metadata=model_metadata.get(record.model, {}),
        )
    if record.model_group:
        _update_group_breakdown(
            breakdown.model_groups,
            record.model_group,
            record,
            api_key_metadata,
            metadata=model_metadata.get(record.model_group, {}),
        )
    if record.mcp_namespaced_tool_name:
        _update_group_breakdown(
            breakdown.mcp_servers,
            record.mcp_namespaced_tool_name,
            record,
            api_key_metadata,
            metadata={},
        )
    # Provider is always tracked; a missing provider is grouped as "unknown".
    provider = record.custom_llm_provider or "unknown"
    _update_group_breakdown(
        breakdown.providers,
        provider,
        record,
        api_key_metadata,
        metadata=provider_metadata.get(provider, {}),
    )
    if record.endpoint:
        _update_group_breakdown(
            breakdown.endpoints,
            record.endpoint,
            record,
            api_key_metadata,
            metadata={},
        )
    # Top-level api-key breakdown (KeyMetricWithMetadata directly; no nested
    # per-key sub-breakdown for this dimension).
    if record.api_key not in breakdown.api_keys:
        breakdown.api_keys[record.api_key] = KeyMetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=_key_metadata(api_key_metadata, record.api_key),
        )
    breakdown.api_keys[record.api_key].metrics = update_metrics(
        breakdown.api_keys[record.api_key].metrics, record
    )
    # Optional entity dimension, only when the caller asks for it.
    if entity_id_field:
        # Null/empty entity ids are grouped under "Unassigned".
        entity_value = getattr(record, entity_id_field, None) or "Unassigned"
        _update_group_breakdown(
            breakdown.entities,
            entity_value,
            record,
            api_key_metadata,
            metadata=(
                entity_metadata_field.get(entity_value, {})
                if entity_metadata_field
                else {}
            ),
        )
    return breakdown
async def get_api_key_metadata(
    prisma_client: PrismaClient,
    api_keys: Set[str],
) -> Dict[str, Dict[str, Any]]:
    """Resolve key_alias / team_id for a set of api keys.

    Falls back to the deleted-keys table for keys no longer in the active
    table, so key_alias and team_id are preserved in historical activity
    logs even after a key is deleted or regenerated.
    """
    active_records = await prisma_client.db.litellm_verificationtoken.find_many(
        where={"token": {"in": list(api_keys)}}
    )
    metadata: Dict[str, Dict[str, Any]] = {
        rec.token: {"key_alias": rec.key_alias, "team_id": rec.team_id}
        for rec in active_records
    }
    unresolved = api_keys - set(metadata)
    if not unresolved:
        return metadata
    # Best-effort fallback: a failed lookup here should not break analytics.
    try:
        deleted_records = (
            await prisma_client.db.litellm_deletedverificationtoken.find_many(
                where={"token": {"in": list(unresolved)}},
                order={"deleted_at": "desc"},
            )
        )
        # Records arrive newest-first, so the first hit per token wins.
        for rec in deleted_records:
            metadata.setdefault(
                rec.token,
                {"key_alias": rec.key_alias, "team_id": rec.team_id},
            )
    except Exception as e:
        verbose_proxy_logger.warning(
            "Failed to fetch deleted key metadata for %d missing keys: %s",
            len(unresolved),
            e,
        )
    return metadata
def _adjust_dates_for_timezone(
start_date: str,
end_date: str,
timezone_offset_minutes: Optional[int],
) -> Tuple[str, str]:
"""
Adjust date range to account for timezone differences.
The database stores dates in UTC. When a user in a different timezone
selects a local date range, we need to expand the UTC query range to
capture all records that fall within their local date range.
Args:
start_date: Start date in YYYY-MM-DD format (user's local date)
end_date: End date in YYYY-MM-DD format (user's local date)
timezone_offset_minutes: Minutes behind UTC (positive = west of UTC)
This matches JavaScript's Date.getTimezoneOffset() convention.
For example: PST = +480 (8 hours * 60 = 480 minutes behind UTC)
Returns:
Tuple of (adjusted_start_date, adjusted_end_date) in YYYY-MM-DD format
"""
if timezone_offset_minutes is None or timezone_offset_minutes == 0:
return start_date, end_date
start = datetime.strptime(start_date, "%Y-%m-%d")
end = datetime.strptime(end_date, "%Y-%m-%d")
if timezone_offset_minutes > 0:
# West of UTC (Americas): local evening extends into next UTC day
# e.g., Feb 4 23:59 PST = Feb 5 07:59 UTC
end = end + timedelta(days=1)
else:
# East of UTC (Asia/Europe): local morning starts in previous UTC day
# e.g., Feb 4 00:00 IST = Feb 3 18:30 UTC
start = start - timedelta(days=1)
return start.strftime("%Y-%m-%d"), end.strftime("%Y-%m-%d")
def _build_where_conditions(
    *,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    start_date: str,
    end_date: str,
    model: Optional[str],
    api_key: Optional[Union[str, List[str]]],
    exclude_entity_ids: Optional[List[str]] = None,
    timezone_offset_minutes: Optional[int] = None,
) -> Dict[str, Any]:
    """Build a prisma ``where`` clause for daily activity queries.

    Single values filter with equality, lists with ``in``; exclusions are
    merged into the entity filter via prisma's ``not``/``in`` form.
    """
    # Widen the date window when the caller is in a non-UTC timezone.
    adjusted_start, adjusted_end = _adjust_dates_for_timezone(
        start_date, end_date, timezone_offset_minutes
    )
    where: Dict[str, Any] = {
        "date": {"gte": adjusted_start, "lte": adjusted_end}
    }
    if model:
        where["model"] = model
    if api_key:
        where["api_key"] = {"in": api_key} if isinstance(api_key, list) else api_key
    if entity_id is not None:
        if isinstance(entity_id, list):
            where[entity_id_field] = {"in": entity_id}
        else:
            where[entity_id_field] = {"equals": entity_id}
    if exclude_entity_ids:
        entity_filter = where.get(entity_id_field, {})
        # Defensive: fold a bare string filter into prisma's dict form.
        if isinstance(entity_filter, str):
            entity_filter = {"equals": entity_filter}
        entity_filter["not"] = {"in": exclude_entity_ids}
        where[entity_id_field] = entity_filter
    return where
def _build_aggregated_sql_query(
    *,
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    start_date: str,
    end_date: str,
    model: Optional[str],
    api_key: Optional[str],
    exclude_entity_ids: Optional[List[str]] = None,
    timezone_offset_minutes: Optional[int] = None,
) -> Tuple[str, List[Any]]:
    """Build a parameterized SQL GROUP BY query for aggregated daily activity.

    Groups by (date, api_key, model, model_group, custom_llm_provider,
    mcp_namespaced_tool_name, endpoint) with SUMs on all metric columns.
    The entity_id column is intentionally omitted from GROUP BY to collapse
    rows across entities — this is where the biggest row reduction comes from.

    Returns:
        Tuple of (sql_query, params_list) ready for prisma_client.db.query_raw().
    """
    pg_table = _PRISMA_TO_PG_TABLE.get(table_name)
    if pg_table is None:
        raise ValueError(f"Unknown table name: {table_name}")
    adjusted_start, adjusted_end = _adjust_dates_for_timezone(
        start_date, end_date, timezone_offset_minutes
    )
    conditions: List[str] = []
    params: List[Any] = []

    def bind(value: Any) -> str:
        """Append *value* to params and return its 1-based $N placeholder."""
        params.append(value)
        return f"${len(params)}"

    def bind_many(values: List[Any]) -> str:
        """Append *values* to params and return a comma-joined placeholder list."""
        return ", ".join(bind(v) for v in values)

    # Date range (always present).
    conditions.append(f"date >= {bind(adjusted_start)}")
    conditions.append(f"date <= {bind(adjusted_end)}")
    # Optional entity filter. NOTE: entity_id_field is interpolated into the
    # SQL — it must always come from internal callers, never user input.
    if entity_id is not None:
        if isinstance(entity_id, list):
            conditions.append(f'"{entity_id_field}" IN ({bind_many(entity_id)})')
        else:
            conditions.append(f'"{entity_id_field}" = {bind(entity_id)}')
    # Exclude specific entities.
    if exclude_entity_ids:
        conditions.append(
            f'"{entity_id_field}" NOT IN ({bind_many(exclude_entity_ids)})'
        )
    # Optional model / api_key filters.
    if model:
        conditions.append(f"model = {bind(model)}")
    if api_key:
        conditions.append(f"api_key = {bind(api_key)}")
    where_clause = " AND ".join(conditions)
    sql_query = f"""
    SELECT
        date,
        api_key,
        model,
        model_group,
        custom_llm_provider,
        mcp_namespaced_tool_name,
        endpoint,
        SUM(spend)::float AS spend,
        SUM(prompt_tokens)::bigint AS prompt_tokens,
        SUM(completion_tokens)::bigint AS completion_tokens,
        SUM(cache_read_input_tokens)::bigint AS cache_read_input_tokens,
        SUM(cache_creation_input_tokens)::bigint AS cache_creation_input_tokens,
        SUM(api_requests)::bigint AS api_requests,
        SUM(successful_requests)::bigint AS successful_requests,
        SUM(failed_requests)::bigint AS failed_requests
    FROM "{pg_table}"
    WHERE {where_clause}
    GROUP BY date, api_key, model, model_group, custom_llm_provider,
             mcp_namespaced_tool_name, endpoint
    ORDER BY date DESC
    """
    return sql_query, params
async def _aggregate_spend_records(
    *,
    prisma_client: PrismaClient,
    records: List[Any],
    entity_id_field: Optional[str],
    entity_metadata_field: Optional[Dict[str, dict]],
) -> Dict[str, Any]:
    """Aggregate raw rows into per-day DailySpendData plus overall totals.

    Returns {"results": [DailySpendData, ...], "totals": SpendMetrics} with
    results sorted newest-first.
    """
    seen_api_keys: Set[str] = {r.api_key for r in records if r.api_key}
    api_key_metadata: Dict[str, Dict[str, Any]] = (
        await get_api_key_metadata(prisma_client, seen_api_keys)
        if seen_api_keys
        else {}
    )
    # Model/provider metadata hooks — currently empty placeholders.
    model_metadata: Dict[str, Dict[str, Any]] = {}
    provider_metadata: Dict[str, Dict[str, Any]] = {}

    total_metrics = SpendMetrics()
    per_day: Dict[str, Dict[str, Any]] = {}
    for rec in records:
        if rec.date not in per_day:
            per_day[rec.date] = {
                "metrics": SpendMetrics(),
                "breakdown": BreakdownMetrics(),
            }
        day = per_day[rec.date]
        day["metrics"] = update_metrics(day["metrics"], rec)
        day["breakdown"] = update_breakdown_metrics(
            day["breakdown"],
            rec,
            model_metadata,
            provider_metadata,
            api_key_metadata,
            entity_id_field=entity_id_field,
            entity_metadata_field=entity_metadata_field,
        )
        total_metrics = update_metrics(total_metrics, rec)

    results = [
        DailySpendData(
            date=datetime.strptime(day_str, "%Y-%m-%d").date(),
            metrics=data["metrics"],
            breakdown=data["breakdown"],
        )
        for day_str, data in per_day.items()
    ]
    results.sort(key=lambda d: d.date, reverse=True)
    return {"results": results, "totals": total_metrics}
async def get_daily_activity(
    prisma_client: Optional[PrismaClient],
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    entity_metadata_field: Optional[Dict[str, dict]],
    start_date: Optional[str],
    end_date: Optional[str],
    model: Optional[str],
    api_key: Optional[Union[str, List[str]]],
    page: int,
    page_size: int,
    exclude_entity_ids: Optional[List[str]] = None,
    metadata_metrics_func: Optional[Callable[[List[Any]], SpendMetrics]] = None,
    timezone_offset_minutes: Optional[int] = None,
) -> SpendAnalyticsPaginatedResponse:
    """Common function to get daily activity for any entity type.

    Fetches a paginated window of daily spend rows from ``table_name``,
    aggregates them per day via _aggregate_spend_records, and wraps the
    result in a SpendAnalyticsPaginatedResponse.

    Args:
        prisma_client: DB client; a 500 is raised when it is None.
        table_name: Prisma accessor name (e.g. "litellm_dailyuserspend").
        entity_id_field: Column used to filter and break down by entity.
        entity_id: Single id or list of ids to filter on; None = no filter.
        entity_metadata_field: Optional per-entity metadata lookup.
        start_date / end_date: Required YYYY-MM-DD range; 400 when missing.
        model / api_key: Optional additional filters.
        page / page_size: 1-based pagination over the raw row set.
        exclude_entity_ids: Entity ids to exclude from results.
        metadata_metrics_func: Optional override for the totals reported in
            the response metadata (e.g. tag deduplication).
        timezone_offset_minutes: Caller's UTC offset used to widen the date
            window (see _adjust_dates_for_timezone).

    Raises:
        HTTPException: 500 when the DB is unavailable or a query fails;
            400 when the date range is missing.
    """
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if start_date is None or end_date is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "Please provide start_date and end_date"},
        )
    try:
        where_conditions = _build_where_conditions(
            entity_id_field=entity_id_field,
            entity_id=entity_id,
            start_date=start_date,
            end_date=end_date,
            model=model,
            api_key=api_key,
            exclude_entity_ids=exclude_entity_ids,
            timezone_offset_minutes=timezone_offset_minutes,
        )
        # Get total count for pagination
        total_count = await getattr(prisma_client.db, table_name).count(
            where=where_conditions
        )
        # Fetch paginated results
        daily_spend_data = await getattr(prisma_client.db, table_name).find_many(
            where=where_conditions,
            order=[
                {"date": "desc"},
            ],
            skip=(page - 1) * page_size,
            take=page_size,
        )
        aggregated = await _aggregate_spend_records(
            prisma_client=prisma_client,
            records=daily_spend_data,
            entity_id_field=entity_id_field,
            entity_metadata_field=entity_metadata_field,
        )
        # Totals default to the aggregate of the fetched page; callers may
        # supply metadata_metrics_func (e.g. tag dedup) to override them.
        metadata_metrics = aggregated["totals"]
        if metadata_metrics_func:
            metadata_metrics = metadata_metrics_func(daily_spend_data)
        return SpendAnalyticsPaginatedResponse(
            results=aggregated["results"],
            metadata=DailySpendMetadata(
                total_spend=metadata_metrics.spend,
                total_prompt_tokens=metadata_metrics.prompt_tokens,
                total_completion_tokens=metadata_metrics.completion_tokens,
                total_tokens=metadata_metrics.total_tokens,
                total_api_requests=metadata_metrics.api_requests,
                total_successful_requests=metadata_metrics.successful_requests,
                total_failed_requests=metadata_metrics.failed_requests,
                total_cache_read_input_tokens=metadata_metrics.cache_read_input_tokens,
                total_cache_creation_input_tokens=metadata_metrics.cache_creation_input_tokens,
                page=page,
                total_pages=-(-total_count // page_size),  # Ceiling division
                has_more=(page * page_size) < total_count,
            ),
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error fetching daily activity: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to fetch analytics: {str(e)}"},
        )
async def get_daily_activity_aggregated(
    prisma_client: Optional[PrismaClient],
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    entity_metadata_field: Optional[Dict[str, dict]],
    start_date: Optional[str],
    end_date: Optional[str],
    model: Optional[str],
    api_key: Optional[str],
    exclude_entity_ids: Optional[List[str]] = None,
    timezone_offset_minutes: Optional[int] = None,
) -> SpendAnalyticsPaginatedResponse:
    """Aggregated variant that returns the full result set (no pagination).

    Aggregation happens in the database via SQL GROUP BY instead of pulling
    every row into Python; rows across entities (users/teams/orgs) collapse
    from ~150k individual rows to a few thousand grouped rows.

    The response shape matches the paginated endpoint so the UI does not need
    to transform the payload.
    """
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if start_date is None or end_date is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "Please provide start_date and end_date"},
        )
    try:
        sql_query, sql_params = _build_aggregated_sql_query(
            table_name=table_name,
            entity_id_field=entity_id_field,
            entity_id=entity_id,
            start_date=start_date,
            end_date=end_date,
            model=model,
            api_key=api_key,
            exclude_entity_ids=exclude_entity_ids,
            timezone_offset_minutes=timezone_offset_minutes,
        )
        # GROUP BY runs server-side; each returned dict is a pre-aggregated row.
        raw_rows = await prisma_client.db.query_raw(sql_query, *sql_params) or []
        # Wrap dicts so _aggregate_spend_records can use attribute access.
        record_objs = [SimpleNamespace(**row) for row in raw_rows]
        # entity_id_field=None skips the entity breakdown: the GROUP BY already
        # collapsed the entity dimension, so per-entity data is not available.
        agg = await _aggregate_spend_records(
            prisma_client=prisma_client,
            records=record_objs,
            entity_id_field=None,
            entity_metadata_field=None,
        )
        totals = agg["totals"]
        return SpendAnalyticsPaginatedResponse(
            results=agg["results"],
            metadata=DailySpendMetadata(
                total_spend=totals.spend,
                total_prompt_tokens=totals.prompt_tokens,
                total_completion_tokens=totals.completion_tokens,
                total_tokens=totals.total_tokens,
                total_api_requests=totals.api_requests,
                total_successful_requests=totals.successful_requests,
                total_failed_requests=totals.failed_requests,
                total_cache_read_input_tokens=totals.cache_read_input_tokens,
                total_cache_creation_input_tokens=totals.cache_creation_input_tokens,
                # Single unpaginated page by definition.
                page=1,
                total_pages=1,
                has_more=False,
            ),
        )
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error fetching aggregated daily activity: {str(e)}"
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to fetch analytics: {str(e)}"},
        )

View File

@@ -0,0 +1,473 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.proxy._types import (
KeyRequestBase,
LiteLLM_ManagementEndpoint_MetadataFields,
LiteLLM_ManagementEndpoint_MetadataFields_Premium,
LiteLLM_OrganizationTable,
LiteLLM_ProjectTable,
LiteLLM_TeamTable,
LiteLLM_UserTable,
LitellmUserRoles,
NewProjectRequest,
UpdateProjectRequest,
UserAPIKeyAuth,
)
from litellm.proxy.utils import _premium_user_check
if TYPE_CHECKING:
from litellm.proxy._types import NewProjectRequest, UpdateProjectRequest
from litellm.proxy.utils import PrismaClient, ProxyLogging
def _user_has_admin_view(user_api_key_dict: UserAPIKeyAuth) -> bool:
    """Return True when the caller holds an admin-level proxy role (full or view-only)."""
    admin_view_roles = (
        LitellmUserRoles.PROXY_ADMIN,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
    )
    return user_api_key_dict.user_role in admin_view_roles
def _is_user_team_admin(
    user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
    """Return True when the calling user appears in the team's member list with role 'admin'."""
    caller_id = user_api_key_dict.user_id
    if caller_id is None:
        # Members with a None user_id never match, so a None caller can't be admin.
        return False
    return any(
        member.user_id == caller_id and member.role == "admin"
        for member in team_obj.members_with_roles
    )
async def _is_user_org_admin_for_team(
    user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
    """
    Check if the caller is an org admin for the team's organization.

    Returns True only when:
    - the team is attached to an organization, AND
    - the caller holds the ORG_ADMIN role in that organization
    """
    if not team_obj.organization_id or not user_api_key_dict.user_id:
        return False
    # Imported lazily to avoid circular imports with proxy_server.
    from litellm.proxy.auth.auth_checks import get_user_object
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    caller = await get_user_object(
        user_id=user_api_key_dict.user_id,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        user_id_upsert=False,
        proxy_logging_obj=proxy_logging_obj,
    )
    if caller is None:
        return False
    memberships = caller.organization_memberships or []
    return any(
        membership.organization_id == team_obj.organization_id
        and membership.user_role == LitellmUserRoles.ORG_ADMIN.value
        for membership in memberships
    )
def _team_member_has_permission(
    user_api_key_dict: UserAPIKeyAuth,
    team_obj: LiteLLM_TeamTable,
    permission: str,
) -> bool:
    """Check if a non-admin team member has a specific permission on a team."""
    granted = team_obj.team_member_permissions
    if not granted or permission not in granted:
        return False
    # The permission is granted team-wide, so plain membership is sufficient.
    return any(
        member.user_id is not None and member.user_id == user_api_key_dict.user_id
        for member in team_obj.members_with_roles
    )
async def _user_has_admin_privileges(
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: Optional["PrismaClient"] = None,
    user_api_key_cache: Optional["DualCache"] = None,
    proxy_logging_obj: Optional["ProxyLogging"] = None,
) -> bool:
    """
    Check if user has admin privileges (proxy admin, team admin, or org admin).

    Args:
        user_api_key_dict: User API key authentication object
        prisma_client: Prisma client for database operations
        user_api_key_cache: Cache for user API keys
        proxy_logging_obj: Proxy logging object

    Returns:
        True if user is proxy admin, team admin for any team, or org admin
        for any organization.
    """
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return True
    # Team/org admin checks need a DB connection and a concrete user id.
    if prisma_client is None or user_api_key_dict.user_id is None:
        return False
    from litellm.caching import DualCache as DualCacheImport
    from litellm.proxy.auth.auth_checks import get_user_object

    try:
        user_obj = await get_user_object(
            user_id=user_api_key_dict.user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache or DualCacheImport(),
            user_id_upsert=False,
            proxy_logging_obj=proxy_logging_obj,
        )
        if user_obj is None:
            return False
        # Org admin in ANY organization counts.
        for membership in user_obj.organization_memberships or []:
            if membership.user_role == LitellmUserRoles.ORG_ADMIN.value:
                return True
        # Team admin in ANY of the user's teams counts.
        if user_obj.teams:
            team_rows = await prisma_client.db.litellm_teamtable.find_many(
                where={"team_id": {"in": user_obj.teams}}
            )
            for team_row in team_rows:
                if _is_user_team_admin(
                    user_api_key_dict=user_api_key_dict,
                    team_obj=LiteLLM_TeamTable(**team_row.model_dump()),
                ):
                    return True
    except Exception as e:
        # Fail closed: any lookup error is treated as "not admin".
        verbose_proxy_logger.debug(
            f"Error checking admin privileges for user {user_api_key_dict.user_id}: {e}"
        )
        return False
    return False
def _org_admin_can_invite_user(
    admin_user_obj: LiteLLM_UserTable,
    target_user_obj: LiteLLM_UserTable,
) -> bool:
    """
    Check if an org admin can invite the target user.

    The invite is allowed only when the target user belongs to at least one
    organization where the admin holds the org admin role.

    Args:
        admin_user_obj: The admin user's full object (from get_user_object)
        target_user_obj: The target user's full object (from get_user_object)

    Returns:
        True if target user is in an org where admin has org admin role
    """
    admin_memberships = admin_user_obj.organization_memberships
    target_memberships = target_user_obj.organization_memberships
    if admin_memberships is None or target_memberships is None:
        return False
    # Organizations where the caller actually holds the admin role.
    admin_org_ids = {
        m.organization_id
        for m in admin_memberships
        if m.user_role == LitellmUserRoles.ORG_ADMIN.value
    }
    if not admin_org_ids:
        return False
    return any(m.organization_id in admin_org_ids for m in target_memberships)
async def _team_admin_can_invite_user(
    user_api_key_dict: UserAPIKeyAuth,
    admin_user_obj: LiteLLM_UserTable,
    target_user_obj: LiteLLM_UserTable,
    prisma_client: "PrismaClient",
) -> bool:
    """
    Check if a team admin can invite the target user.

    The invite is allowed only when the target user belongs to at least one
    team where the admin holds the team admin role.

    Args:
        user_api_key_dict: The admin user's API key auth object
        admin_user_obj: The admin user's full object (from get_user_object)
        target_user_obj: The target user's full object (from get_user_object)
        prisma_client: Prisma client for database operations

    Returns:
        True if target user is in a team where admin has team admin role
    """
    if not admin_user_obj.teams or not target_user_obj.teams:
        return False
    admin_teams = await prisma_client.db.litellm_teamtable.find_many(
        where={"team_id": {"in": admin_user_obj.teams}}
    )
    # Teams where the caller actually holds the admin role.
    admin_led_team_ids = {
        row.team_id
        for row in admin_teams
        if _is_user_team_admin(
            user_api_key_dict=user_api_key_dict,
            team_obj=LiteLLM_TeamTable(**row.model_dump()),
        )
    }
    if not admin_led_team_ids:
        return False
    return not admin_led_team_ids.isdisjoint(target_user_obj.teams)
async def admin_can_invite_user(
    target_user_id: str,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: Optional["PrismaClient"] = None,
    user_api_key_cache: Optional["DualCache"] = None,
    proxy_logging_obj: Optional["ProxyLogging"] = None,
) -> bool:
    """
    Check if the admin can create an invitation for the target user.

    - Proxy admins: can invite any user
    - Org admins: can only invite users in their org(s)
    - Team admins: can only invite users in their team(s)

    Both the admin and target user objects are fetched through
    get_user_object, which caches lookups.

    Args:
        target_user_id: The user_id of the user to invite
        user_api_key_dict: The admin user's API key auth object
        prisma_client: Prisma client for database operations
        user_api_key_cache: Cache for user API keys
        proxy_logging_obj: Proxy logging object

    Returns:
        True if user can invite the target user
    """
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return True
    if prisma_client is None or user_api_key_dict.user_id is None:
        return False
    from litellm.caching import DualCache as DualCacheImport
    from litellm.proxy.auth.auth_checks import get_user_object

    try:
        cache = user_api_key_cache or DualCacheImport()

        async def _load_user(user_id: str):
            # Shared lookup for both the caller and the invite target.
            return await get_user_object(
                user_id=user_id,
                prisma_client=prisma_client,
                user_api_key_cache=cache,
                user_id_upsert=False,
                proxy_logging_obj=proxy_logging_obj,
            )

        admin_user_obj = await _load_user(user_api_key_dict.user_id)
        if admin_user_obj is None:
            return False
        target_user_obj = await _load_user(target_user_id)
        if target_user_obj is None:
            return False
        if _org_admin_can_invite_user(admin_user_obj, target_user_obj):
            return True
        return await _team_admin_can_invite_user(
            user_api_key_dict=user_api_key_dict,
            admin_user_obj=admin_user_obj,
            target_user_obj=target_user_obj,
            prisma_client=prisma_client,
        )
    except Exception as e:
        # Fail closed on any lookup error.
        verbose_proxy_logger.debug(
            f"Error checking invite permission for user {user_api_key_dict.user_id}: {e}"
        )
        return False
def _set_object_metadata_field(
    object_data: Union[
        LiteLLM_TeamTable,
        KeyRequestBase,
        LiteLLM_OrganizationTable,
        LiteLLM_ProjectTable,
        "NewProjectRequest",
        "UpdateProjectRequest",
    ],
    field_name: str,
    value: Any,
) -> None:
    """
    Set a metadata field on a team/key/organization/project object.

    Premium-gated fields run a premium-user check before being written.

    Args:
        object_data: The team/key/organization/project data object to modify
        field_name: Name of the metadata field to set
        value: Value to set for the field
    """
    # Enforce the premium gate before mutating the object.
    if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
        _premium_user_check(field_name)
    object_data.metadata = object_data.metadata or {}
    object_data.metadata[field_name] = value
async def _upsert_budget_and_membership(
    tx,
    *,
    team_id: str,
    user_id: str,
    max_budget: Optional[float],
    existing_budget_id: Optional[str],
    user_api_key_dict: UserAPIKeyAuth,
    tpm_limit: Optional[int] = None,
    rpm_limit: Optional[int] = None,
):
    """
    Helper function to Create/Update or Delete the budget within the team membership
    Args:
        tx: The transaction object
        team_id: The ID of the team
        user_id: The ID of the user
        max_budget: The maximum budget for the team
        existing_budget_id: The ID of the existing budget, if any
        user_api_key_dict: User API Key dictionary containing user information
        tpm_limit: Tokens per minute limit for the team member
        rpm_limit: Requests per minute limit for the team member
    If max_budget, tpm_limit, and rpm_limit are all None, the user's budget is removed from the team membership.
    If any of these values exist, a budget is updated or created and linked to the team membership.
    """
    # NOTE(review): existing_budget_id is accepted but never read below — when
    # any limit is set, a fresh budget row is always created. Confirm whether
    # updating the existing budget row was the intent.
    if max_budget is None and tpm_limit is None and rpm_limit is None:
        # disconnect the budget since all limits are None
        # (only the membership->budget link is removed; the budget row itself
        # is not deleted)
        await tx.litellm_teammembership.update(
            where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}},
            data={"litellm_budget_table": {"disconnect": True}},
        )
        return
    # create a new budget
    # Only limits that were explicitly provided are written to the new row.
    create_data: Dict[str, Any] = {
        "created_by": user_api_key_dict.user_id or "",
        "updated_by": user_api_key_dict.user_id or "",
    }
    if max_budget is not None:
        create_data["max_budget"] = max_budget
    if tpm_limit is not None:
        create_data["tpm_limit"] = tpm_limit
    if rpm_limit is not None:
        create_data["rpm_limit"] = rpm_limit
    new_budget = await tx.litellm_budgettable.create(
        data=create_data,
        include={"team_membership": True},
    )
    # upsert the team membership with the new/updated budget
    # Creates the membership row if it doesn't exist yet; either way the row
    # ends up connected to the budget created above.
    await tx.litellm_teammembership.upsert(
        where={
            "user_id_team_id": {
                "user_id": user_id,
                "team_id": team_id,
            }
        },
        data={
            "create": {
                "user_id": user_id,
                "team_id": team_id,
                "litellm_budget_table": {
                    "connect": {"budget_id": new_budget.budget_id},
                },
            },
            "update": {
                "litellm_budget_table": {
                    "connect": {"budget_id": new_budget.budget_id},
                },
            },
        },
    )
def _update_metadata_field(updated_kv: dict, field_name: str) -> None:
    """
    Move a single management-endpoint field into the `metadata` dict of an
    update payload, running the premium gate for premium fields.

    The premium check is skipped for empty collections ([] or {}): the UI
    sends these as defaults even when the user hasn't configured any
    enterprise features (see issue #20304). The move itself still happens so
    a previously-set field can be intentionally cleared with an empty
    list/dict.

    Args:
        updated_kv: The key-value dict being used for the update (mutated in place)
        field_name: Name of the metadata field being updated
    """
    if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
        candidate = updated_kv.get(field_name)
        if candidate not in (None, [], {}):
            _premium_user_check()
    if updated_kv.get(field_name) is None:
        return
    # Relocate the value under the "metadata" key.
    moved_value = updated_kv.pop(field_name)
    existing_metadata = updated_kv.get("metadata")
    if existing_metadata is not None:
        existing_metadata[field_name] = moved_value
    else:
        updated_kv["metadata"] = {field_name: moved_value}
def _has_non_empty_value(value: Any) -> bool:
"""Check if a value has real content (not None, not empty list, not blank string)."""
if value is None:
return False
if isinstance(value, list) and len(value) == 0:
return False
if isinstance(value, str) and value.strip() == "":
return False
return True
def _update_metadata_fields(updated_kv: dict) -> None:
    """
    Move every known metadata field (premium and standard) present in
    ``updated_kv`` into its ``metadata`` dict.

    Args:
        updated_kv: The key-value dict being used for the update (mutated in place)
    """
    # Premium fields first, then standard. _update_metadata_field pops each
    # handled key, so a field appearing in both groups is processed once.
    for field_group in (
        LiteLLM_ManagementEndpoint_MetadataFields_Premium,
        LiteLLM_ManagementEndpoint_MetadataFields,
    ):
        for field in field_group:
            if updated_kv.get(field) is not None:
                _update_metadata_field(updated_kv=updated_kv, field_name=field)

View File

@@ -0,0 +1,79 @@
"""
COMPLIANCE CHECK ENDPOINTS
Endpoints for checking regulatory compliance of LLM request logs.
/compliance/eu-ai-act - Check EU AI Act compliance
/compliance/gdpr - Check GDPR compliance
"""
from fastapi import APIRouter, Depends, Request
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.compliance_checks import ComplianceChecker
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.types.proxy.compliance_endpoints import (
ComplianceCheckRequest,
ComplianceResponse,
)
router = APIRouter()
@router.post(
    "/compliance/eu-ai-act",
    tags=["compliance"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ComplianceResponse,
)
@management_endpoint_wrapper
async def check_eu_ai_act_compliance(
    data: ComplianceCheckRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> ComplianceResponse:
    """
    Check EU AI Act compliance for a spend log entry.

    Checks:
    - Art. 9: Guardrails applied (any guardrail)
    - Art. 5: Content screened before LLM (pre-call guardrails)
    - Art. 12: Audit record complete (user_id, model, timestamp, guardrail_results)
    """
    check_results = ComplianceChecker(data).check_eu_ai_act()
    # Compliant only when every individual check passed.
    overall_compliant = all(result.passed for result in check_results)
    return ComplianceResponse(
        compliant=overall_compliant,
        regulation="EU AI Act",
        checks=check_results,
    )
@router.post(
    "/compliance/gdpr",
    tags=["compliance"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ComplianceResponse,
)
@management_endpoint_wrapper
async def check_gdpr_compliance(
    data: ComplianceCheckRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> ComplianceResponse:
    """
    Check GDPR compliance for a spend log entry.

    Checks:
    - Art. 32: Data protection applied (pre-call guardrails)
    - Art. 5(1)(c): Sensitive data protected (masked/blocked or no issues)
    - Art. 30: Audit record complete (user_id, model, timestamp, guardrail_results)
    """
    check_results = ComplianceChecker(data).check_gdpr()
    # Compliant only when every individual check passed.
    overall_compliant = all(result.passed for result in check_results)
    return ComplianceResponse(
        compliant=overall_compliant,
        regulation="GDPR",
        checks=check_results,
    )

View File

@@ -0,0 +1,413 @@
import asyncio
import os
from typing import Any, Dict, Set
from fastapi import APIRouter, Depends, HTTPException
from pydantic import TypeAdapter
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
try:
from prisma.errors import RecordNotFoundError
except ImportError:
RecordNotFoundError = Exception # type: ignore
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import (
CommonProxyErrors,
KeyManagementSystem,
LitellmUserRoles,
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.proxy.management_endpoints.config_overrides import (
ConfigOverrideSettingsResponse,
HashicorpVaultConfig,
)
router = APIRouter()
# --- Hashicorp Vault constants ---
# Maps HashicorpVaultConfig field names -> env var names that the Hashicorp
# secret manager reads when it is (re)initialized.
HASHICORP_ENV_VAR_MAPPING: Dict[str, str] = {
    "vault_addr": "HCP_VAULT_ADDR",
    "vault_token": "HCP_VAULT_TOKEN",
    "approle_role_id": "HCP_VAULT_APPROLE_ROLE_ID",
    "approle_secret_id": "HCP_VAULT_APPROLE_SECRET_ID",
    "approle_mount_path": "HCP_VAULT_APPROLE_MOUNT_PATH",
    "client_cert": "HCP_VAULT_CLIENT_CERT",
    "client_key": "HCP_VAULT_CLIENT_KEY",
    "vault_cert_role": "HCP_VAULT_CERT_ROLE",
    "vault_namespace": "HCP_VAULT_NAMESPACE",
    "vault_mount_name": "HCP_VAULT_MOUNT_NAME",
    "vault_path_prefix": "HCP_VAULT_PATH_PREFIX",
}
# Fields holding secrets: masked in GET responses (never sent to the UI in
# plaintext) and encrypted before being persisted to the DB.
HASHICORP_SENSITIVE_FIELDS: Set[str] = {
    "vault_token",
    "approle_secret_id",
    "client_key",
}
# Module-level masker shared by _mask_sensitive_fields.
_sensitive_masker = SensitiveDataMasker()
# --- Shared helpers ---
def _mask_sensitive_fields(
data: Dict[str, Any], sensitive_fields: Set[str]
) -> Dict[str, Any]:
"""Mask sensitive fields for API responses. Non-sensitive fields are left as-is."""
masked = {}
for key, value in data.items():
if value is not None and key in sensitive_fields and isinstance(value, str):
masked[key] = _sensitive_masker._mask_value(value)
else:
masked[key] = value
return masked
def _get_current_env_values(env_var_mapping: Dict[str, str]) -> Dict[str, Any]:
"""Read current env var values as fallback when no DB record exists."""
values = {}
for field_name, env_var_name in env_var_mapping.items():
env_value = os.environ.get(env_var_name)
values[field_name] = env_value
return values
def _extract_field_type(field_info: Dict[str, Any]) -> str:
"""Extract the non-null type from a Pydantic v2 JSON schema field."""
if "type" in field_info:
return field_info["type"]
for option in field_info.get("anyOf", []):
if option.get("type") != "null":
return option.get("type", "string")
return "string"
def _build_field_schema(model_class: type) -> Dict[str, Any]:
    """Build field_schema dict from a Pydantic model for UI rendering."""
    schema = TypeAdapter(model_class).json_schema(by_alias=True)
    properties = {
        field_name: {
            "description": field_info.get("description", ""),
            "type": _extract_field_type(field_info),
        }
        for field_name, field_info in schema.get("properties", {}).items()
    }
    return {
        "description": schema.get("description", ""),
        "properties": properties,
    }
def _parse_config_value(raw: Any) -> Dict[str, Any]:
"""Parse a config_value from DB (may be JSON string or dict)."""
if isinstance(raw, str):
return safe_json_loads(raw, default={})
return dict(raw)
def _set_env_vars(config_data: Dict[str, Any]) -> None:
    """Set HCP_VAULT_* env vars from config data. Unsets vars for missing/None/empty fields."""
    for field_name, env_var_name in HASHICORP_ENV_VAR_MAPPING.items():
        raw_value = config_data.get(field_name)
        if raw_value is None or raw_value == "":
            # Missing/None/empty means "clear" — remove the variable entirely.
            os.environ.pop(env_var_name, None)
        else:
            os.environ[env_var_name] = str(raw_value)
def _clear_hashicorp_vault_state(proxy_config: Any) -> None:
    """Clear all Hashicorp Vault state: env vars, secret manager, and change-detection cache."""
    # An empty config unsets every HCP_VAULT_* env var.
    _set_env_vars({})
    # Drop the secret manager only if Hashicorp Vault is the active system,
    # so an unrelated key-management system is left untouched.
    if litellm._key_management_system == KeyManagementSystem.HASHICORP_VAULT:
        litellm.secret_manager_client = None
        litellm._key_management_system = None
    # Clear the cached config snapshot used for change detection.
    proxy_config._last_hashicorp_vault_config = None
# --- Hashicorp Vault endpoints ---
@router.post(
    "/config_overrides/hashicorp_vault",
    tags=["Config Overrides"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_hashicorp_vault_config(
    config: HashicorpVaultConfig,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update Hashicorp Vault secret manager configuration.
    Sets environment variables, encrypts sensitive fields, and stores in DB.
    Reinitializes the secret manager on this pod.

    Per-field merge semantics:
    - omitted field  -> keep existing value (from DB record, else current env var)
    - empty string   -> clear/remove the field
    - any other value -> overwrite

    The secret manager is initialized BEFORE the config is persisted; on
    failure, env vars are rolled back and nothing is written to the DB.
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail="Only admin users can update config overrides",
        )
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail=CommonProxyErrors.db_not_connected_error.value,
        )
    # exclude_none=True: fields the caller omitted are absent from the dump,
    # so the merge below can distinguish "not sent" from "sent as empty string".
    config_data = config.model_dump(exclude_none=True)
    # Merge ALL fields the user didn't send: try DB first, fall back to env vars.
    # Omitted field = keep existing; empty string = clear/remove the field.
    existing_record = await prisma_client.db.litellm_configoverrides.find_unique(
        where={"config_type": "hashicorp_vault"}
    )
    if existing_record is not None and existing_record.config_value is not None:
        existing_data = _parse_config_value(existing_record.config_value)
        # Stored values are encrypted — decrypt before merging.
        existing_decrypted = proxy_config._decrypt_db_variables(existing_data)
        for field in HASHICORP_ENV_VAR_MAPPING:
            if field not in config_data and existing_decrypted.get(field):
                config_data[field] = existing_decrypted[field]
    else:
        # No DB record yet — merge from current env vars
        env_values = _get_current_env_values(HASHICORP_ENV_VAR_MAPPING)
        for field in HASHICORP_ENV_VAR_MAPPING:
            if field not in config_data and env_values.get(field):
                config_data[field] = env_values[field]
    # Strip empty strings — they signal "clear this field"
    config_data = {k: v for k, v in config_data.items() if v != ""}
    # Validate that the config has enough fields to initialize
    has_vault_addr = bool(config_data.get("vault_addr"))
    has_token_auth = bool(config_data.get("vault_token"))
    has_approle_auth = bool(
        config_data.get("approle_role_id") and config_data.get("approle_secret_id")
    )
    has_tls_cert_auth = bool(
        config_data.get("client_cert") and config_data.get("client_key")
    )
    if not has_vault_addr:
        raise HTTPException(
            status_code=400,
            detail="Vault Address is required",
        )
    if not has_token_auth and not has_approle_auth and not has_tls_cert_auth:
        raise HTTPException(
            status_code=400,
            detail="At least one authentication method is required: "
            "provide a Token, both AppRole Role ID and Secret ID, "
            "or both Client Certificate and Client Key",
        )
    # Snapshot current env vars so we can restore on failure
    previous_env = _get_current_env_values(HASHICORP_ENV_VAR_MAPPING)
    # Set env vars and verify the secret manager can initialize before persisting
    _set_env_vars(config_data)
    try:
        proxy_config.initialize_secret_manager(key_management_system="hashicorp_vault")
    except Exception as e:
        # Roll back env vars so a broken config doesn't linger on this pod.
        _set_env_vars(previous_env)
        verbose_proxy_logger.exception(
            "Error reinitializing Hashicorp Vault secret manager: %s", str(e)
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initialize secret manager: {e}",
        )
    # Only persist to DB after successful init
    encrypted_data = proxy_config._encrypt_env_variables(config_data)
    config_value = safe_dumps(encrypted_data)
    await prisma_client.db.litellm_configoverrides.upsert(
        where={"config_type": "hashicorp_vault"},
        data={
            "create": {
                "config_type": "hashicorp_vault",
                "config_value": config_value,
            },
            "update": {
                "config_value": config_value,
            },
        },
    )
    # Update change-detection cache so the background reload doesn't redundantly re-init
    proxy_config._last_hashicorp_vault_config = safe_json_loads(config_value)
    return {
        "message": "Hashicorp Vault configuration updated successfully",
        "status": "success",
    }
@router.get(
    "/config_overrides/hashicorp_vault",
    tags=["Config Overrides"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ConfigOverrideSettingsResponse,
)
async def get_hashicorp_vault_config(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get current Hashicorp Vault configuration.
    Returns decrypted values from DB, or falls back to current env vars.
    Sensitive fields are masked in either case.
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail="Only admin users can view config overrides",
        )
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail=CommonProxyErrors.db_not_connected_error.value,
        )
    field_schema = _build_field_schema(HashicorpVaultConfig)
    db_record = await prisma_client.db.litellm_configoverrides.find_unique(
        where={"config_type": "hashicorp_vault"}
    )
    if db_record is not None and db_record.config_value is not None:
        # DB path: decrypt first, then mask — plaintext secrets never reach the UI.
        stored = _parse_config_value(db_record.config_value)
        values = _mask_sensitive_fields(
            proxy_config._decrypt_db_variables(stored),
            HASHICORP_SENSITIVE_FIELDS,
        )
    else:
        # Fallback: surface whatever is currently set via env vars, masked too.
        values = _mask_sensitive_fields(
            _get_current_env_values(HASHICORP_ENV_VAR_MAPPING),
            HASHICORP_SENSITIVE_FIELDS,
        )
    return ConfigOverrideSettingsResponse(
        config_type="hashicorp_vault",
        values=values,
        field_schema=field_schema,
    )
@router.delete(
    "/config_overrides/hashicorp_vault",
    tags=["Config Overrides"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_hashicorp_vault_config(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Delete Hashicorp Vault configuration. Idempotent."""
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail="Only admin users can delete config overrides",
        )
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail=CommonProxyErrors.db_not_connected_error.value,
        )
    # Remove the DB record; a missing record is fine (idempotent delete).
    try:
        await prisma_client.db.litellm_configoverrides.delete(
            where={"config_type": "hashicorp_vault"}
        )
    except RecordNotFoundError:
        verbose_proxy_logger.debug(
            "No existing Hashicorp Vault config record to delete"
        )
    # Always clear local state (env vars, secret manager, change-detection cache).
    _clear_hashicorp_vault_state(proxy_config)
    return {
        "message": "Hashicorp Vault configuration deleted successfully",
        "status": "success",
    }
@router.post(
    "/config_overrides/hashicorp_vault/test_connection",
    tags=["Config Overrides"],
    dependencies=[Depends(user_api_key_auth)],
)
async def test_hashicorp_vault_connection(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Test the connection to the currently configured Hashicorp Vault.
    Uses the already-initialized secret manager client. Does not modify any state.
    """
    from litellm.secret_managers.hashicorp_secret_manager import (
        HashicorpSecretManager,
    )

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail="Only admin users can test Vault connection",
        )
    client = litellm.secret_manager_client
    if not isinstance(client, HashicorpSecretManager):
        raise HTTPException(
            status_code=400,
            detail="Hashicorp Vault is not configured. Save a configuration first.",
        )
    # Step 1: Authenticate (exercises AppRole login, TLS cert login, or direct
    # token). _get_request_headers is synchronous, so run it off the event loop.
    try:
        headers = await asyncio.to_thread(client._get_request_headers)
    except Exception as e:
        raise HTTPException(
            status_code=502,
            detail=f"Vault authentication failed: {e}",
        )
    # Step 2: Verify the token is valid via token/lookup-self
    try:
        http_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager
        )
        if client.vault_namespace:
            headers["X-Vault-Namespace"] = client.vault_namespace
        lookup_response = await http_client.get(
            f"{client.vault_addr}/v1/auth/token/lookup-self", headers=headers
        )
        lookup_response.raise_for_status()
    except Exception as e:
        raise HTTPException(
            status_code=502,
            detail=f"Vault token validation failed: {e}",
        )
    return {
        "status": "success",
        "message": f"Successfully connected to Vault at {client.vault_addr}",
    }

View File

@@ -0,0 +1,588 @@
"""
COST TRACKING SETTINGS MANAGEMENT
Endpoints for managing cost discount and margin configuration
GET /config/cost_discount_config - Get current cost discount configuration
PATCH /config/cost_discount_config - Update cost discount configuration
GET /config/cost_margin_config - Get current cost margin configuration
PATCH /config/cost_margin_config - Update cost margin configuration
POST /cost/estimate - Estimate cost for a given model and token counts
"""
from typing import Dict, Optional, Tuple, Union
from fastapi import APIRouter, Depends, HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.cost_calculator import completion_cost
from litellm.proxy._types import (
CommonProxyErrors,
CostEstimateRequest,
CostEstimateResponse,
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import LlmProvidersSet
router = APIRouter()
def _resolve_model_for_cost_lookup(model: str) -> Tuple[str, Optional[str]]:
    """
    Map a requested model name (possibly a router alias / model_group) onto the
    underlying litellm model name used for cost lookups.

    Args:
        model: Model name from the request. May be a router alias such as
            'e-model-router' or a concrete model like 'azure_ai/gpt-4'.

    Returns:
        Tuple of (resolved_model_name, custom_llm_provider):
        - resolved_model_name: the name to use for the cost lookup
        - custom_llm_provider: provider when resolved via the router, else None
    """
    from litellm.proxy.proxy_server import llm_router

    provider: Optional[str] = None

    if llm_router is not None:
        try:
            # Router lookup handles aliases, wildcards, etc.
            deployments = llm_router.get_model_list(model_name=model)
            if deployments:
                deployment = deployments[0]
                litellm_params = deployment.get("litellm_params", {})
                model_info = deployment.get("model_info", {})

                # base_model takes priority (needed for Azure custom deployment names)
                base_model = model_info.get("base_model") or litellm_params.get(
                    "base_model"
                )
                if base_model:
                    verbose_proxy_logger.debug(
                        f"Resolved model '{model}' to base_model '{base_model}' from router"
                    )
                    provider = litellm_params.get("custom_llm_provider")
                    return (
                        str(base_model),
                        str(provider) if provider is not None else None,
                    )

                underlying = litellm_params.get("model")
                if underlying:
                    verbose_proxy_logger.debug(
                        f"Resolved model '{model}' to '{underlying}' from router"
                    )
                    provider = litellm_params.get("custom_llm_provider")
                    return (
                        str(underlying),
                        str(provider) if provider is not None else None,
                    )
        except Exception as e:
            verbose_proxy_logger.debug(
                f"Could not resolve model '{model}' from router: {e}"
            )

    # Fall back to the name as given when the router could not resolve it.
    return model, provider
def _calculate_period_costs(
num_requests, cost_per_request, input_cost, output_cost, margin_cost
):
"""
Calculate costs for a given number of requests.
Returns tuple of (total_cost, input_cost, output_cost, margin_cost) or all None if num_requests is None/0.
"""
if not num_requests:
return None, None, None, None
return (
cost_per_request * num_requests,
input_cost * num_requests,
output_cost * num_requests,
margin_cost * num_requests,
)
@router.get(
    "/config/cost_discount_config",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_cost_discount_config(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get current cost discount configuration.

    Reads cost_discount_config out of litellm_settings in the persisted proxy
    config; returns {"values": {}} when nothing is configured or the read fails.
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        # Pull the persisted config; missing keys degrade to an empty mapping.
        db_config = await proxy_config.get_config()
        discounts = db_config.get("litellm_settings", {}).get(
            "cost_discount_config", {}
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error fetching cost discount config: {str(e)}")
        return {"values": {}}
    return {"values": discounts}
@router.patch(
    "/config/cost_discount_config",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_cost_discount_config(
    cost_discount_config: Dict[str, float],
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update cost discount configuration.

    Persists cost_discount_config into litellm_settings and updates the
    in-memory copy. Discounts should be between 0 and 1
    (e.g., 0.05 = 5% discount).

    Example:
    ```json
    {
        "vertex_ai": 0.05,
        "gemini": 0.05,
        "openai": 0.01
    }
    ```
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_config,
        store_model_in_db,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    # Config persistence requires DB-backed config storage.
    if store_model_in_db is not True:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
            },
        )

    # Reject any provider name litellm does not recognize.
    invalid_providers = [
        provider
        for provider in cost_discount_config.keys()
        if provider not in LlmProvidersSet
    ]
    if invalid_providers:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Invalid provider(s): {', '.join(invalid_providers)}. Must be valid LiteLLM providers. See https://docs.litellm.ai/docs/providers for the full list."
            },
        )

    # Each discount must be a number within [0, 1].
    for provider, discount in cost_discount_config.items():
        if not isinstance(discount, (int, float)):
            raise HTTPException(
                status_code=400, detail=f"Discount for {provider} must be a number"
            )
        if not (0 <= discount <= 1):
            raise HTTPException(
                status_code=400,
                detail=f"Discount for {provider} must be between 0 and 1 (0% to 100%)",
            )

    try:
        db_config = await proxy_config.get_config()
        db_config.setdefault("litellm_settings", {})
        db_config["litellm_settings"]["cost_discount_config"] = cost_discount_config
        await proxy_config.save_config(new_config=db_config)
        # Keep the in-memory copy in sync so new requests pick it up immediately.
        litellm.cost_discount_config = cost_discount_config
        verbose_proxy_logger.info(
            f"Updated cost_discount_config: {cost_discount_config}"
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error updating cost discount config: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to update cost discount config: {str(e)}"},
        )
    return {
        "message": "Cost discount configuration updated successfully",
        "status": "success",
        "values": cost_discount_config,
    }
@router.get(
    "/config/cost_margin_config",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_cost_margin_config(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get current cost margin configuration.

    Reads cost_margin_config out of litellm_settings in the persisted proxy
    config; returns {"values": {}} when nothing is configured or the read fails.
    """
    from litellm.proxy.proxy_server import prisma_client, proxy_config

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        # Pull the persisted config; missing keys degrade to an empty mapping.
        db_config = await proxy_config.get_config()
        margins = db_config.get("litellm_settings", {}).get("cost_margin_config", {})
    except Exception as e:
        verbose_proxy_logger.error(f"Error fetching cost margin config: {str(e)}")
        return {"values": {}}
    return {"values": margins}
@router.patch(
    "/config/cost_margin_config",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_cost_margin_config(
    cost_margin_config: Dict[str, Union[float, Dict[str, float]]],
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update cost margin configuration.

    Updates the cost_margin_config in litellm_settings (persisted to DB) and
    the in-memory `litellm.cost_margin_config`.

    Margins can be:
    - Percentage: {"openai": 0.10} = 10% margin
    - Fixed amount: {"openai": {"fixed_amount": 0.001}} = $0.001 per request
    - Combined: {"vertex_ai": {"percentage": 0.08, "fixed_amount": 0.0005}}
    - Global: {"global": 0.05} = 5% global margin on all providers

    Raises:
        HTTPException 500: DB not connected, STORE_MODEL_IN_DB unset, or save failed.
        HTTPException 400: unknown provider key or invalid margin value.

    Example:
    ```json
    {
        "global": 0.05,
        "openai": 0.10,
        "anthropic": {"fixed_amount": 0.001},
        "vertex_ai": {"percentage": 0.08, "fixed_amount": 0.0005}
    }
    ```
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_config,
        store_model_in_db,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    # Config persistence requires DB-backed config storage.
    if store_model_in_db is not True:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
            },
        )
    # Validate that all providers are valid LiteLLM providers (except "global")
    invalid_providers = []
    for provider in cost_margin_config.keys():
        if provider != "global" and provider not in LlmProvidersSet:
            invalid_providers.append(provider)
    if invalid_providers:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Invalid provider(s): {', '.join(invalid_providers)}. Must be valid LiteLLM providers or 'global'. See https://docs.litellm.ai/docs/providers for the full list."
            },
        )
    # Validate margin values
    for provider, margin_value in cost_margin_config.items():
        if isinstance(margin_value, (int, float)):
            # Simple percentage format: {"openai": 0.10}
            if not (0 <= margin_value <= 10):  # Allow up to 1000% margin
                raise HTTPException(
                    status_code=400,
                    detail=f"Margin percentage for {provider} must be between 0 and 10 (0% to 1000%)",
                )
        elif isinstance(margin_value, dict):
            # Complex format: {"percentage": 0.08, "fixed_amount": 0.0005}
            # NOTE(review): a non-empty dict containing only unrecognized keys
            # (neither 'percentage' nor 'fixed_amount') passes validation
            # unchanged — confirm whether that should be rejected.
            if "percentage" in margin_value:
                percentage = margin_value["percentage"]
                if not isinstance(percentage, (int, float)):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Margin percentage for {provider} must be a number",
                    )
                if not (0 <= percentage <= 10):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Margin percentage for {provider} must be between 0 and 10 (0% to 1000%)",
                    )
            if "fixed_amount" in margin_value:
                fixed_amount = margin_value["fixed_amount"]
                if not isinstance(fixed_amount, (int, float)):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Fixed margin amount for {provider} must be a number",
                    )
                if fixed_amount < 0:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Fixed margin amount for {provider} must be non-negative",
                    )
            if not margin_value:  # Empty dict
                raise HTTPException(
                    status_code=400,
                    detail=f"Margin config for {provider} cannot be empty. Must include 'percentage' and/or 'fixed_amount'",
                )
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Margin for {provider} must be a number (percentage) or dict with 'percentage' and/or 'fixed_amount'",
            )
    try:
        # Load existing config
        config = await proxy_config.get_config()
        # Ensure litellm_settings exists
        if "litellm_settings" not in config:
            config["litellm_settings"] = {}
        # Update cost_margin_config
        config["litellm_settings"]["cost_margin_config"] = cost_margin_config
        # Save the updated config to DB
        await proxy_config.save_config(new_config=config)
        # Update in-memory litellm.cost_margin_config so it applies immediately
        litellm.cost_margin_config = cost_margin_config
        verbose_proxy_logger.info(f"Updated cost_margin_config: {cost_margin_config}")
        return {
            "message": "Cost margin configuration updated successfully",
            "status": "success",
            "values": cost_margin_config,
        }
    except Exception as e:
        verbose_proxy_logger.error(f"Error updating cost margin config: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to update cost margin config: {str(e)}"},
        )
@router.post(
    "/cost/estimate",
    tags=["Cost Tracking"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CostEstimateResponse,
)
async def estimate_cost(
    request: CostEstimateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> CostEstimateResponse:
    """
    Estimate cost for a given model and token counts.

    This endpoint uses the same cost calculation logic as actual requests,
    including any configured margins and discounts.

    Parameters:
    - model: Model name (e.g., "gpt-4", "claude-3-opus")
    - input_tokens: Expected input tokens per request
    - output_tokens: Expected output tokens per request
    - num_requests_per_day: Number of requests per day (optional)
    - num_requests_per_month: Number of requests per month (optional)

    Returns cost breakdown including:
    - Per-request costs (input, output, margin)
    - Daily costs (if num_requests_per_day provided)
    - Monthly costs (if num_requests_per_month provided)

    Raises:
        HTTPException 404: when no pricing could be computed for the model.

    Example:
    ```json
    {
        "model": "gpt-4",
        "input_tokens": 1000,
        "output_tokens": 500,
        "num_requests_per_day": 100,
        "num_requests_per_month": 3000
    }
    ```
    """
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
    from litellm.types.utils import ModelResponse, Usage

    # Resolve model name (handles router aliases like 'e-model-router' -> 'azure_ai/gpt-4')
    resolved_model, resolved_provider = _resolve_model_for_cost_lookup(request.model)
    verbose_proxy_logger.debug(
        f"Cost estimate: request.model='{request.model}' resolved to '{resolved_model}'"
    )
    # Create a mock response with usage for completion_cost — no LLM call is made.
    mock_response = ModelResponse(
        model=resolved_model,
        usage=Usage(
            prompt_tokens=request.input_tokens,
            completion_tokens=request.output_tokens,
            total_tokens=request.input_tokens + request.output_tokens,
        ),
    )
    # Create a logging object to capture the cost breakdown populated by completion_cost
    litellm_logging_obj = LiteLLMLoggingObj(
        model=resolved_model,
        messages=[],
        stream=False,
        call_type="completion",
        start_time=None,
        litellm_call_id="cost-estimate",
        function_id="cost-estimate",
    )
    # Use completion_cost which handles all the logic including margins/discounts
    try:
        cost_per_request = completion_cost(
            completion_response=mock_response,
            model=resolved_model,
            litellm_logging_obj=litellm_logging_obj,
        )
    except Exception as e:
        raise HTTPException(
            status_code=404,
            detail={
                "error": f"Could not calculate cost for model '{request.model}' (resolved to '{resolved_model}'): {str(e)}"
            },
        )
    # Get cost breakdown from the logging object; each component defaults to 0.0
    # when completion_cost produced no breakdown.
    cost_breakdown = litellm_logging_obj.cost_breakdown
    input_cost = cost_breakdown.get("input_cost", 0.0) if cost_breakdown else 0.0
    output_cost = cost_breakdown.get("output_cost", 0.0) if cost_breakdown else 0.0
    margin_cost = (
        cost_breakdown.get("margin_total_amount", 0.0) if cost_breakdown else 0.0
    )
    # Get model info for per-token pricing display; best-effort — all None on failure
    try:
        model_info = litellm.get_model_info(model=resolved_model)
        input_cost_per_token = model_info.get("input_cost_per_token")
        output_cost_per_token = model_info.get("output_cost_per_token")
        custom_llm_provider = model_info.get("litellm_provider")
    except Exception:
        input_cost_per_token = None
        output_cost_per_token = None
        custom_llm_provider = None
    # Use provider from router resolution if not found in model_info
    if custom_llm_provider is None and resolved_provider is not None:
        custom_llm_provider = resolved_provider
    # Calculate daily and monthly costs (all None when the request count is unset)
    (
        daily_cost,
        daily_input_cost,
        daily_output_cost,
        daily_margin_cost,
    ) = _calculate_period_costs(
        num_requests=request.num_requests_per_day,
        cost_per_request=cost_per_request,
        input_cost=input_cost,
        output_cost=output_cost,
        margin_cost=margin_cost,
    )
    (
        monthly_cost,
        monthly_input_cost,
        monthly_output_cost,
        monthly_margin_cost,
    ) = _calculate_period_costs(
        num_requests=request.num_requests_per_month,
        cost_per_request=cost_per_request,
        input_cost=input_cost,
        output_cost=output_cost,
        margin_cost=margin_cost,
    )
    return CostEstimateResponse(
        model=request.model,
        input_tokens=request.input_tokens,
        output_tokens=request.output_tokens,
        num_requests_per_day=request.num_requests_per_day,
        num_requests_per_month=request.num_requests_per_month,
        cost_per_request=cost_per_request,
        input_cost_per_request=input_cost,
        output_cost_per_request=output_cost,
        margin_cost_per_request=margin_cost,
        daily_cost=daily_cost,
        daily_input_cost=daily_input_cost,
        daily_output_cost=daily_output_cost,
        daily_margin_cost=daily_margin_cost,
        monthly_cost=monthly_cost,
        monthly_input_cost=monthly_input_cost,
        monthly_output_cost=monthly_output_cost,
        monthly_margin_cost=monthly_margin_cost,
        input_cost_per_token=input_cost_per_token,
        output_cost_per_token=output_cost_per_token,
        provider=custom_llm_provider,
    )

View File

@@ -0,0 +1,925 @@
"""
CUSTOMER MANAGEMENT
All /customer management endpoints
/customer/new
/customer/info
/customer/update
/customer/delete
"""
#### END-USER/CUSTOMER MANAGEMENT ####
from datetime import datetime, timedelta
from typing import List, Optional
import fastapi
from fastapi import APIRouter, Depends, HTTPException, Request
import litellm
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
from litellm.proxy.management_helpers.object_permission_utils import (
_set_object_permission,
handle_update_object_permission_common,
)
from litellm.proxy.utils import handle_exception_on_proxy
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
)
router = APIRouter()
@router.post(
    "/end_user/block",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    include_in_schema=False,
)
@router.post(
    "/customer/block",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def block_user(data: BlockUsers):
    """
    [BETA] Reject calls with this end-user id

    (any /chat/completion call with this user={end-user-id} param, will be rejected.)

    Parameters:
    - user_ids (List[str], required): The unique `user_id`s for the users to block

    ```
    curl -X POST "http://0.0.0.0:8000/user/block"
    -H "Authorization: Bearer sk-1234"
    -d '{
    "user_ids": [<user_id>, ...]
    }'
    ```

    Raises:
        HTTPException 500: when the DB is not connected or an upsert fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    # Fail fast before the generic handler below: previously this 500 was
    # raised inside the `try`, caught by `except Exception`, and re-wrapped so
    # the client saw a stringified HTTPException instead of the real detail.
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Postgres DB Not connected"},
        )
    try:
        records = []
        # Renamed loop variable from `id` to avoid shadowing the builtin.
        for user_id in data.user_ids:
            # Upsert so blocking an unknown user creates the record as blocked.
            record = await prisma_client.db.litellm_endusertable.upsert(
                where={"user_id": user_id},  # type: ignore
                data={
                    "create": {"user_id": user_id, "blocked": True},  # type: ignore
                    "update": {"blocked": True},
                },
            )
            records.append(record)
        return {"blocked_users": records}
    except Exception as e:
        verbose_proxy_logger.error(f"An error occurred - {str(e)}")
        raise HTTPException(status_code=500, detail={"error": str(e)})
@router.post(
    "/end_user/unblock",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    include_in_schema=False,
)
@router.post(
    "/customer/unblock",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def unblock_user(data: BlockUsers):
    """
    [BETA] Unblock calls with this user id

    Idempotent: ids that are not currently blocked are skipped silently.

    Example
    ```
    curl -X POST "http://0.0.0.0:8000/user/unblock"
    -H "Authorization: Bearer sk-1234"
    -d '{
    "user_ids": [<user_id>, ...]
    }'
    ```

    Raises:
        HTTPException 400: when the blocked-user check was never configured.
        HTTPException 500: when `blocked_user_list` is not an in-memory list.
    """
    try:
        from enterprise.enterprise_hooks.blocked_user_list import (
            _ENTERPRISE_BlockedUserList,
        )
    except ImportError:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Blocked user check was never set. This call has no effect."
                + CommonProxyErrors.missing_enterprise_package_docker.value
            },
        )
    if (
        not any(isinstance(x, _ENTERPRISE_BlockedUserList) for x in litellm.callbacks)
        or litellm.blocked_user_list is None
    ):
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Blocked user check was never set. This call has no effect."
            },
        )
    if isinstance(litellm.blocked_user_list, list):
        # Renamed loop variable from `id` to avoid shadowing the builtin.
        for user_id in data.user_ids:
            # Skip ids that are not currently blocked: previously
            # list.remove() raised ValueError here, surfacing as an
            # unhandled 500 when unblocking an already-unblocked user.
            if user_id in litellm.blocked_user_list:
                litellm.blocked_user_list.remove(user_id)
    else:
        raise HTTPException(
            status_code=500,
            detail={
                "error": "`blocked_user_list` must be set as a list. Filepaths can't be updated."
            },
        )
    return {"blocked_users": litellm.blocked_user_list}
def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNewRequest]:
    """
    Build a BudgetNewRequest from any budget params set on the customer
    request; return None when no budget fields were provided.
    """
    # Collect every non-None budget field present on the request, excluding
    # budget_id (that links to an existing budget instead of creating one).
    budget_kv_pairs = {
        field_name: getattr(data, field_name, None)
        for field_name in BudgetNewRequest.model_fields.keys()
        if field_name != "budget_id" and getattr(data, field_name, None) is not None
    }
    if not budget_kv_pairs:
        return None

    budget_request = BudgetNewRequest(**budget_kv_pairs)
    # Derive the first reset timestamp when only a duration was supplied.
    if (
        budget_request.budget_reset_at is None
        and budget_request.budget_duration is not None
    ):
        budget_request.budget_reset_at = datetime.utcnow() + timedelta(
            seconds=duration_in_seconds(duration=budget_request.budget_duration)
        )
    return budget_request
async def _handle_customer_object_permission_update(
    non_default_values: dict,
    end_user_table_data_typed: Optional[LiteLLM_EndUserTable],
    update_end_user_table_data: dict,
    prisma_client,
) -> None:
    """
    Apply object-permission updates (MCP servers, vector stores, etc.) for the
    customer endpoints.

    Mutates update_end_user_table_data in place, setting object_permission_id
    when an update was requested and produced a new permission record.

    Args:
        non_default_values: Update payload, possibly containing object_permission.
        end_user_table_data_typed: Existing end-user row, if any.
        update_end_user_table_data: Dict receiving the new object_permission_id.
        prisma_client: Prisma database client.
    """
    if "object_permission" not in non_default_values:
        return

    existing_id = (
        None
        if end_user_table_data_typed is None
        else end_user_table_data_typed.object_permission_id
    )
    new_permission_id = await handle_update_object_permission_common(
        data_json=non_default_values,
        existing_object_permission_id=existing_id,
        prisma_client=prisma_client,
    )
    if new_permission_id is not None:
        update_end_user_table_data["object_permission_id"] = new_permission_id
@router.post(
    "/end_user/new",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/customer/new",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_end_user(
    data: NewCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Allow creating a new Customer

    Parameters:
    - user_id: str - The unique identifier for the user.
    - alias: Optional[str] - A human-friendly alias for the user.
    - blocked: bool - Flag to allow or disallow requests for this end-user. Default is False.
    - max_budget: Optional[float] - The maximum budget allocated to the user. Either 'max_budget' or 'budget_id' should be provided, not both.
    - budget_id: Optional[str] - The identifier for an existing budget allocated to the user. Either 'max_budget' or 'budget_id' should be provided, not both.
    - allowed_model_region: Optional[Union[Literal["eu"], Literal["us"]]] - Require all user requests to use models in this specific region.
    - default_model: Optional[str] - If no equivalent model in the allowed region, default all requests to this model.
    - metadata: Optional[dict] = Metadata for customer, store information for customer. Example metadata = {"data_training_opt_out": True}
    - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
    - tpm_limit: Optional[int] - [Not Implemented Yet] Specify tpm limit for a given customer (Tokens per minute)
    - rpm_limit: Optional[int] - [Not Implemented Yet] Specify rpm limit for a given customer (Requests per minute)
    - model_max_budget: Optional[dict] - [Not Implemented Yet] Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d"}}
    - max_parallel_requests: Optional[int] - [Not Implemented Yet] Specify max parallel requests for a given customer.
    - soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests.
    - spend: Optional[float] - Specify initial spend for a given customer.
    - budget_reset_at: Optional[str] - Specify the date and time when the budget should be reset.
    - object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources.
        Supported fields:
        * mcp_servers: List[str] - List of allowed MCP server IDs
        * mcp_access_groups: List[str] - List of MCP access group names
        * mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names (e.g., {"server_1": ["tool_a", "tool_b"]})
        * vector_stores: List[str] - List of allowed vector store IDs
        * agents: List[str] - List of allowed agent IDs
        * agent_access_groups: List[str] - List of agent access group names
        Example: {"mcp_servers": ["server_1", "server_2"], "vector_stores": ["vector_store_1"], "agents": ["agent_1"]}
        IF null or {} then no object-level restrictions apply.
    - Allow specifying allowed regions
    - Allow specifying default model

    Example curl:
    ```
    curl --location 'http://0.0.0.0:4000/customer/new' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
            "user_id" : "ishaan-jaff-3",
            "allowed_region": "eu",
            "budget_id": "free_tier",
            "default_model": "azure/gpt-3.5-turbo-eu"
        }'

    # With object permissions
    curl -L -X POST 'http://localhost:4000/customer/new' \
    -H 'Authorization: Bearer sk-1234' \
    -H 'Content-Type: application/json' \
    -d '{
        "user_id": "user_1",
        "object_permission": {
            "mcp_servers": ["server_1"],
            "mcp_access_groups": ["public_group"],
            "vector_stores": ["vector_store_1"]
        }
    }'

    # return end-user object
    ```

    NOTE: This used to be called `/end_user/new`, we will still be maintaining compatibility for /end_user/XXX for these endpoints
    """
    """
    Validation:
        - check if default model exists
        - create budget object if not already created
        - Add user to end user table
    Return
        - end-user object
        - currently allowed models
    """
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        prisma_client,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        ## VALIDATION ##
        # default_model must be resolvable by the router before we persist it.
        if data.default_model is not None:
            if llm_router is None:
                raise HTTPException(
                    status_code=422,
                    detail={"error": CommonProxyErrors.no_llm_router.value},
                )
            elif data.default_model not in llm_router.get_model_names():
                raise HTTPException(
                    status_code=422,
                    detail={
                        "error": "Default Model not on proxy. Configure via `/model/new` or config.yaml. Default_model={}, proxy_model_names={}".format(
                            data.default_model, set(llm_router.get_model_names())
                        )
                    },
                )
        new_end_user_obj: Dict = {}
        ## CREATE BUDGET ## if set
        _new_budget = new_budget_request(data)
        if _new_budget is not None:
            try:
                budget_record = await prisma_client.db.litellm_budgettable.create(
                    data={
                        **_new_budget.model_dump(exclude_unset=True),
                        "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,  # type: ignore
                        "updated_by": user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                    }
                )
            except Exception as e:
                raise HTTPException(status_code=422, detail={"error": str(e)})
            new_end_user_obj["budget_id"] = budget_record.budget_id
        elif data.budget_id is not None:
            # Link to an existing budget instead of creating a new one.
            new_end_user_obj["budget_id"] = data.budget_id
        # Copy over every non-budget field from the request payload.
        _user_data = data.dict(exclude_none=True)
        for k, v in _user_data.items():
            if k not in BudgetNewRequest.model_fields.keys():
                new_end_user_obj[k] = v
        ## Handle Object Permission - MCP Servers, Vector Stores etc.
        new_end_user_obj = await _set_object_permission(
            data_json=new_end_user_obj,
            prisma_client=prisma_client,
        )
        # Ensure object_permission is not in the data being sent to create.
        # It should have been converted to object_permission_id by _set_object_permission.
        if "object_permission" in new_end_user_obj:
            verbose_proxy_logger.warning(
                f"object_permission still in new_end_user_obj after _set_object_permission: {new_end_user_obj.get('object_permission')}"
            )
            new_end_user_obj.pop("object_permission", None)
        ## WRITE TO DB ##
        end_user_record = await prisma_client.db.litellm_endusertable.create(
            data=new_end_user_obj,  # type: ignore
            include={"litellm_budget_table": True, "object_permission": True},
        )
        # Convert to dict and clean up recursive fields
        response_dict = end_user_record.model_dump()
        if response_dict.get("object_permission"):
            # Remove reverse relations from object_permission so the response
            # does not recurse back into related records.
            for field in [
                "teams",
                "verification_tokens",
                "organizations",
                "users",
                "end_users",
            ]:
                response_dict["object_permission"].pop(field, None)
        return response_dict
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.customer_endpoints.new_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        # Translate the Prisma unique-constraint failure into a 400 with a
        # user-actionable message.
        if "Unique constraint failed on the fields: (`user_id`)" in str(e):
            raise ProxyException(
                message=f"Customer already exists, passed user_id={data.user_id}. Please pass a new user_id.",
                type="bad_request",
                code=400,
                param="user_id",
            )
        raise handle_exception_on_proxy(e)
@router.get(
    "/customer/info",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_EndUserTable,
)
@router.get(
    "/end_user/info",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def end_user_info(
    end_user_id: str = fastapi.Query(
        description="End User ID in the request parameters"
    ),
):
    """
    Get information about an end-user. An `end_user` is a customer (external user) of the proxy.

    Parameters:
    - end_user_id (str, required): The unique identifier for the end-user

    Raises:
        404 (ProxyException): when no end-user with this id exists.
        500: when the DB is not connected.

    Example curl:
    ```
    curl -X GET 'http://localhost:4000/customer/info?end_user_id=test-litellm-user-4' \
    -H 'Authorization: Bearer sk-1234'
    ```
    """
    # NOTE(review): no per-caller scoping here — any authenticated key that
    # passes user_api_key_auth can look up any end-user id. Confirm intended.
    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )
        user_info = await prisma_client.db.litellm_endusertable.find_first(
            where={"user_id": end_user_id},
            include={"litellm_budget_table": True, "object_permission": True},
        )
        if user_info is None:
            raise ProxyException(
                message="End User Id={} does not exist in db".format(end_user_id),
                type="not_found",
                code=404,
                param="end_user_id",
            )
        # Convert to dict and clean up recursive fields
        response_dict = user_info.model_dump(exclude_none=True)
        if response_dict.get("object_permission"):
            # Remove reverse relations from object_permission so the response
            # does not recurse back into related records.
            for field in [
                "teams",
                "verification_tokens",
                "organizations",
                "users",
                "end_users",
            ]:
                response_dict["object_permission"].pop(field, None)
        return response_dict
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.customer_endpoints.end_user_info(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.post(
"/customer/update",
tags=["Customer Management"],
dependencies=[Depends(user_api_key_auth)],
)
@router.post(
"/end_user/update",
tags=["Customer Management"],
include_in_schema=False,
dependencies=[Depends(user_api_key_auth)],
)
async def update_end_user(
    data: UpdateCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing end-user (customer). Never creates a new one.

    Parameters:
    - user_id: str
    - alias: Optional[str] = None # human-friendly alias
    - blocked: bool = False # allow/disallow requests for this end-user
    - max_budget: Optional[float] = None
    - budget_id: Optional[str] = None # give either a budget_id or max_budget
    - allowed_model_region: Optional[AllowedModelRegion] = (
    None # require all user requests to use models in this specific region
    )
    - default_model: Optional[str] = (
    None # if no equivalent model in allowed region - default all requests to this model
    )
    - object_permission: Optional[LiteLLM_ObjectPermissionBase] - Customer-specific object permissions to control access to resources.
    Supported fields:
    * mcp_servers: List[str] - List of allowed MCP server IDs
    * mcp_access_groups: List[str] - List of MCP access group names
    * mcp_tool_permissions: Dict[str, List[str]] - Map of server ID to allowed tool names
    * vector_stores: List[str] - List of allowed vector store IDs
    * agents: List[str] - List of allowed agent IDs
    * agent_access_groups: List[str] - List of agent access group names
    Example: {"mcp_servers": ["server_1"], "vector_stores": ["vector_store_1"]}
    IF null or {} then no object-level restrictions apply.
    Example curl:
    ```
    curl --location 'http://0.0.0.0:4000/customer/update' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
        "user_id": "test-litellm-user-4",
        "budget_id": "paid_tier"
    }'
    # Updating object permissions
    curl -L -X POST 'http://localhost:4000/customer/update' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "user_id": "user_1",
        "object_permission": {
            "mcp_servers": ["server_3"],
            "vector_stores": ["vector_store_2", "vector_store_3"]
        }
    }'
    See below for all params
    ```
    """
    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
    try:
        data_json: dict = data.json()
        # get the row from db
        if prisma_client is None:
            raise Exception("Not connected to DB!")
        # get non default values for key
        non_default_values = {}
        for k, v in data_json.items():
            if v is not None and v not in (
                [],
                {},
                0,
            ): # models default to [], spend defaults to 0, we should not reset these values
                non_default_values[k] = v
        ## Get end user table data ##
        end_user_table_data = await prisma_client.db.litellm_endusertable.find_first(
            where={"user_id": data.user_id}, include={"litellm_budget_table": True}
        )
        # 404 if the customer does not already exist - this endpoint never creates.
        if end_user_table_data is None:
            raise ProxyException(
                message="End User Id={} does not exist in db".format(data.user_id),
                type="not_found",
                code=404,
                param="user_id",
            )
        end_user_table_data_typed = LiteLLM_EndUserTable(
            **end_user_table_data.model_dump()
        )
        ## Get budget table data ##
        end_user_budget_table = end_user_table_data_typed.litellm_budget_table
        ## Get all params for budget table ##
        # Route each incoming field to either the budget table or the end-user
        # table, based on which pydantic model declares it.
        budget_table_data = {}
        update_end_user_table_data = {}
        for k, v in non_default_values.items():
            # budget_id is for linking to existing budget, not for creating new budget
            if k == "budget_id":
                update_end_user_table_data[k] = v
            elif k in LiteLLM_BudgetTable.model_fields.keys():
                budget_table_data[k] = v
            elif k in LiteLLM_EndUserTable.model_fields.keys():
                update_end_user_table_data[k] = v
        ## Handle object permission updates (MCP servers, vector stores, etc.)
        # This helper converts any "object_permission" payload into an
        # object_permission_id on update_end_user_table_data.
        await _handle_customer_object_permission_update(
            non_default_values=non_default_values,
            end_user_table_data_typed=end_user_table_data_typed,
            update_end_user_table_data=update_end_user_table_data,
            prisma_client=prisma_client,
        )
        ## Check if we need to create a new budget (only if budget fields are provided, not just budget_id) ##
        if budget_table_data:
            if end_user_budget_table is None:
                ## Create new budget ##
                budget_table_data_record = (
                    await prisma_client.db.litellm_budgettable.create(
                        data={
                            **budget_table_data,
                            "created_by": user_api_key_dict.user_id
                            or litellm_proxy_admin_name,
                            "updated_by": user_api_key_dict.user_id
                            or litellm_proxy_admin_name,
                        },
                        include={"end_users": True},
                    )
                )
                # Link the customer to the newly-created budget.
                update_end_user_table_data[
                    "budget_id"
                ] = budget_table_data_record.budget_id
            else:
                ## Update existing budget ##
                budget_table_data_record = (
                    await prisma_client.db.litellm_budgettable.update(
                        where={"budget_id": end_user_budget_table.budget_id},
                        data=budget_table_data,
                    )
                )
        ## Update user table, with update params + new budget id (if set) ##
        verbose_proxy_logger.debug("/customer/update: Received data = %s", data)
        # Ensure object_permission is not in the update data
        # It should have been converted to object_permission_id by handle_update_object_permission_common
        if "object_permission" in update_end_user_table_data:
            verbose_proxy_logger.warning(
                f"object_permission still in update_end_user_table_data: {update_end_user_table_data.get('object_permission')}"
            )
            update_end_user_table_data.pop("object_permission", None)
        if data.user_id is not None and len(data.user_id) > 0:
            update_end_user_table_data["user_id"] = data.user_id # type: ignore
            verbose_proxy_logger.debug("In update customer, user_id condition block.")
            response = await prisma_client.db.litellm_endusertable.update(
                where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True, "object_permission": True} # type: ignore
            )
            if response is None:
                raise ValueError(
                    f"Failed updating customer data. User ID does not exist passed user_id={data.user_id}"
                )
            verbose_proxy_logger.debug(
                f"received response from updating prisma client. response={response}"
            )
            # Convert to dict and clean up recursive fields
            response_dict = response.model_dump()
            if response_dict.get("object_permission"):
                # Remove reverse relations from object_permission
                # so the response does not recurse back into parent objects.
                for field in [
                    "teams",
                    "verification_tokens",
                    "organizations",
                    "users",
                    "end_users",
                ]:
                    response_dict["object_permission"].pop(field, None)
            return response_dict
        else:
            raise ValueError(f"user_id is required, passed user_id = {data.user_id}")
    # update based on remaining passed in values
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.update_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.post(
    "/customer/delete",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/end_user/delete",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_end_user(
    data: DeleteCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete multiple end-users.

    Parameters:
    - user_ids (List[str], required): The unique `user_id`s for the users to delete

    Example curl:
    ```
    curl --location 'http://0.0.0.0:4000/customer/delete' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
            "user_ids" :["ishaan-jaff-5"]
    }'
    See below for all params
    ```
    """
    from litellm.proxy.proxy_server import prisma_client
    try:
        if prisma_client is None:
            raise Exception("Not connected to DB!")
        verbose_proxy_logger.debug("/customer/delete: Received data = %s", data)
        if (
            data.user_ids is not None
            and isinstance(data.user_ids, list)
            and len(data.user_ids) > 0
        ):
            # Validate every id exists before deleting anything, so the call
            # is all-or-nothing rather than partially applied.
            existing_users = await prisma_client.db.litellm_endusertable.find_many(
                where={"user_id": {"in": data.user_ids}}
            )
            existing_user_ids = {user.user_id for user in existing_users}
            missing_user_ids = [
                user_id for user_id in data.user_ids if user_id not in existing_user_ids
            ]
            if missing_user_ids:
                raise ProxyException(
                    message="End User Id(s)={} do not exist in db".format(
                        ", ".join(missing_user_ids)
                    ),
                    type="not_found",
                    code=404,
                    param="user_ids",
                )
            # All users exist, proceed with deletion
            response = await prisma_client.db.litellm_endusertable.delete_many(
                where={"user_id": {"in": data.user_ids}}
            )
            verbose_proxy_logger.debug(
                f"received response from updating prisma client. response={response}"
            )
            return {
                "deleted_customers": response,
                "message": "Successfully deleted customers with ids: "
                + str(data.user_ids),
            }
        else:
            # Fix: message previously referenced `user_id`; the field is `user_ids`.
            raise ValueError(f"user_ids is required, passed user_ids = {data.user_ids}")
    except Exception as e:
        # Use .exception (not .error) so the traceback is logged - consistent
        # with update_end_user / list_end_user in this module.
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.delete_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.get(
    "/customer/list",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_EndUserTable],
)
@router.get(
    "/end_user/list",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def list_end_user(
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Admin-only] List all available customers

    Example curl:
    ```
    curl --location --request GET 'http://0.0.0.0:4000/customer/list' \
        --header 'Authorization: Bearer sk-1234'
    ```
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        # Only admins (read-write or view-only) may list customers.
        admin_roles = (
            LitellmUserRoles.PROXY_ADMIN,
            LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
        )
        if user_api_key_dict.user_role not in admin_roles:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": "Admin-only endpoint. Your user role={}".format(
                        user_api_key_dict.user_role
                    )
                },
            )
        if prisma_client is None:
            raise HTTPException(
                status_code=400,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        records = await prisma_client.db.litellm_endusertable.find_many(
            include={"litellm_budget_table": True, "object_permission": True}
        )

        # Strip reverse relations from object_permission so the payload does
        # not recurse back into parent objects.
        reverse_relations = (
            "teams",
            "verification_tokens",
            "organizations",
            "users",
            "end_users",
        )
        results: List[LiteLLM_EndUserTable] = []
        for record in records:
            record_dict = record.model_dump()
            object_permission = record_dict.get("object_permission")
            if object_permission:
                for relation in reverse_relations:
                    object_permission.pop(relation, None)
            results.append(LiteLLM_EndUserTable(**record_dict))
        return results
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.customer_endpoints.list_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.get(
    "/customer/daily/activity",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=SpendAnalyticsPaginatedResponse,
)
@router.get(
    "/end_user/daily/activity",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def get_customer_daily_activity(
    end_user_ids: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
    exclude_end_user_ids: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get daily spend activity for specific customers (end users), or all of them.

    Parameters:
    - end_user_ids: comma-separated customer ids to include (all if omitted)
    - exclude_end_user_ids: comma-separated customer ids to exclude
    - start_date / end_date: date-range filter
    - model / api_key: optional additional filters
    - page / page_size: pagination controls
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Parse comma-separated ids (previously used a redundant nested
    # conditional for the exclusion list).
    end_user_ids_list = end_user_ids.split(",") if end_user_ids else None
    exclude_end_user_ids_list: Optional[List[str]] = (
        exclude_end_user_ids.split(",") if exclude_end_user_ids else None
    )

    # Fetch customer aliases for response metadata
    where_condition = {}
    if end_user_ids_list:
        where_condition["user_id"] = {"in": list(end_user_ids_list)}
    end_user_aliases = await prisma_client.db.litellm_endusertable.find_many(
        where=where_condition
    )
    end_user_alias_metadata = {e.user_id: {"alias": e.alias} for e in end_user_aliases}

    # Query daily activity for the selected customers
    return await get_daily_activity(
        prisma_client=prisma_client,
        table_name="litellm_dailyenduserspend",
        entity_id_field="end_user_id",
        entity_id=end_user_ids_list,
        entity_metadata_field=end_user_alias_metadata,
        exclude_entity_ids=exclude_end_user_ids_list,
        start_date=start_date,
        end_date=end_date,
        model=model,
        api_key=api_key,
        page=page,
        page_size=page_size,
    )

View File

@@ -0,0 +1,367 @@
"""
FALLBACK MANAGEMENT ENDPOINTS
Dedicated endpoints for managing model fallbacks separately from general config.
POST /fallback - Create or update fallbacks for a specific model
GET /fallback/{model} - Get fallbacks for a specific model
DELETE /fallback/{model} - Delete fallbacks for a specific model
"""
# pyright: reportMissingImports=false
import json
from typing import TYPE_CHECKING, Dict, List, Literal
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.model_checks import get_all_fallbacks
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
if TYPE_CHECKING:
from fastapi import APIRouter, Depends, HTTPException, status
else:
try:
from fastapi import APIRouter, Depends, HTTPException, status
except ImportError:
# fastapi is only required for proxy, not for SDK usage
pass
from litellm.types.management_endpoints.router_settings_endpoints import (
FallbackCreateRequest,
FallbackDeleteResponse,
FallbackGetResponse,
FallbackResponse,
)
router = APIRouter()
@router.post(
    "/fallback",
    tags=["Fallback Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=FallbackResponse,
    status_code=status.HTTP_200_OK,
)
async def create_fallback(
    data: FallbackCreateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create or update fallbacks for a specific model.
    This endpoint allows you to configure fallback models separately from the general config.
    Fallbacks are triggered when a model call fails after retries.
    **Example Request:**
    ```json
    {
        "model": "gpt-3.5-turbo",
        "fallback_models": ["gpt-4", "claude-3-haiku"],
        "fallback_type": "general"
    }
    ```
    **Fallback Types:**
    - `general`: Standard fallbacks for any error (default)
    - `context_window`: Fallbacks specifically for context window exceeded errors
    - `content_policy`: Fallbacks specifically for content policy violations

    Requires STORE_MODEL_IN_DB=True; the updated router_settings are persisted
    to the database and also applied to the in-memory router.
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        prisma_client,
        proxy_config,
        store_model_in_db,
    )
    try:
        # Validate that we have a router
        if llm_router is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"error": "Router not initialized"},
            )
        # Validate that the model exists in the router
        model_names = llm_router.model_names
        if data.model not in model_names:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={
                    "error": f"Model '{data.model}' not found in router",
                    "available_models": list(model_names),
                },
            )
        # Validate that all fallback models exist in the router
        invalid_fallback_models = [
            m for m in data.fallback_models if m not in model_names
        ]
        if invalid_fallback_models:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": f"Invalid fallback models: {invalid_fallback_models}",
                    "available_models": list(model_names),
                },
            )
        # Check if fallback model is the same as the primary model
        if data.model in data.fallback_models:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"error": f"Model '{data.model}' cannot be its own fallback"},
            )
        # Check if we need to store in DB
        if store_model_in_db is not True or prisma_client is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "Database storage not enabled. Set 'STORE_MODEL_IN_DB=True' in your environment to use this feature."
                },
            )
        # Load existing config
        config = await proxy_config.get_config()
        router_settings = config.get("router_settings", {})
        # Get the appropriate fallback list based on type
        fallback_key = "fallbacks"
        if data.fallback_type == "context_window":
            fallback_key = "context_window_fallbacks"
        elif data.fallback_type == "content_policy":
            fallback_key = "content_policy_fallbacks"
        # Get existing fallbacks; each entry is a {model_name: [fallback, ...]} dict
        existing_fallbacks: List[Dict[str, List[str]]] = router_settings.get(
            fallback_key, []
        )
        # Update or add the fallback configuration
        fallback_updated = False
        for i, fallback_dict in enumerate(existing_fallbacks):
            if data.model in fallback_dict:
                # Update existing fallback (replace, don't merge)
                existing_fallbacks[i] = {data.model: data.fallback_models}
                fallback_updated = True
                break
        if not fallback_updated:
            # Add new fallback
            existing_fallbacks.append({data.model: data.fallback_models})
        # Update router settings
        router_settings[fallback_key] = existing_fallbacks
        # Save to database - convert router_settings to JSON string
        router_settings_json = json.dumps(router_settings)
        await prisma_client.db.litellm_config.upsert(
            where={"param_name": "router_settings"},
            data={
                "create": {
                    "param_name": "router_settings",
                    "param_value": router_settings_json,
                },
                "update": {"param_value": router_settings_json},
            },
        )
        # Update the in-memory router configuration so the change takes
        # effect immediately, without a proxy restart.
        setattr(llm_router, fallback_key, existing_fallbacks)
        verbose_proxy_logger.info(
            f"Fallback configured: {data.model} -> {data.fallback_models} (type: {data.fallback_type})"
        )
        return FallbackResponse(
            model=data.model,
            fallback_models=data.fallback_models,
            fallback_type=data.fallback_type,
            message=f"Fallback configuration {'updated' if fallback_updated else 'created'} successfully",
        )
    except HTTPException:
        # Re-raise validation errors untouched.
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error creating fallback: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to create fallback: {str(e)}"},
        )
@router.get(
    "/fallback/{model}",
    tags=["Fallback Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=FallbackGetResponse,
)
async def get_fallback(
    model: str,
    fallback_type: Literal["general", "context_window", "content_policy"] = "general",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get fallback configuration for a specific model.

    **Parameters:**
    - `model`: The model name to get fallbacks for
    - `fallback_type`: Type of fallback to retrieve (query parameter)

    **Example:**
    ```
    GET /fallback/gpt-3.5-turbo?fallback_type=general
    ```
    """
    from litellm.proxy.proxy_server import llm_router

    try:
        if llm_router is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"error": "Router not initialized"},
            )
        # Resolve the configured fallbacks via the shared helper.
        configured = get_all_fallbacks(
            model=model, llm_router=llm_router, fallback_type=fallback_type
        )
        if not configured:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={
                    "error": f"No {fallback_type} fallbacks configured for model '{model}'"
                },
            )
        return FallbackGetResponse(
            model=model,
            fallback_models=configured,
            fallback_type=fallback_type,
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error getting fallback: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to get fallback: {str(e)}"},
        )
@router.delete(
    "/fallback/{model}",
    tags=["Fallback Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=FallbackDeleteResponse,
)
async def delete_fallback(
    model: str,
    fallback_type: Literal["general", "context_window", "content_policy"] = "general",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete fallback configuration for a specific model.

    **Parameters:**
    - `model`: The model name to delete fallbacks for
    - `fallback_type`: Type of fallback to delete (query parameter)

    **Example:**
    ```
    DELETE /fallback/gpt-3.5-turbo?fallback_type=general
    ```
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        prisma_client,
        proxy_config,
        store_model_in_db,
    )

    try:
        if llm_router is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"error": "Router not initialized"},
            )
        if store_model_in_db is not True or prisma_client is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "Database storage not enabled. Set 'STORE_MODEL_IN_DB=True' in your environment to use this feature."
                },
            )

        # Load the persisted config and pick the fallback list for this type.
        config = await proxy_config.get_config()
        router_settings = config.get("router_settings", {})
        fallback_key = {
            "context_window": "context_window_fallbacks",
            "content_policy": "content_policy_fallbacks",
        }.get(fallback_type, "fallbacks")
        existing_fallbacks: List[Dict[str, List[str]]] = router_settings.get(
            fallback_key, []
        )

        # Drop every entry keyed by this model; if nothing was dropped, the
        # model had no fallbacks of this type configured.
        updated_fallbacks = [
            entry for entry in existing_fallbacks if model not in entry
        ]
        if len(updated_fallbacks) == len(existing_fallbacks):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={
                    "error": f"No {fallback_type} fallbacks configured for model '{model}'"
                },
            )

        # Persist the trimmed settings, then sync the in-memory router.
        router_settings[fallback_key] = updated_fallbacks
        router_settings_json = json.dumps(router_settings)
        await prisma_client.db.litellm_config.upsert(
            where={"param_name": "router_settings"},
            data={
                "create": {
                    "param_name": "router_settings",
                    "param_value": router_settings_json,
                },
                "update": {"param_value": router_settings_json},
            },
        )
        setattr(llm_router, fallback_key, updated_fallbacks)
        verbose_proxy_logger.info(f"Fallback deleted: {model} (type: {fallback_type})")
        return FallbackDeleteResponse(
            model=model,
            fallback_type=fallback_type,
            message="Fallback configuration deleted successfully",
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error deleting fallback: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to delete fallback: {str(e)}"},
        )

View File

@@ -0,0 +1,256 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from litellm.proxy._types import (
CreateJWTKeyMappingRequest,
DeleteJWTKeyMappingRequest,
JWTKeyMappingResponse,
LitellmUserRoles,
UpdateJWTKeyMappingRequest,
UserAPIKeyAuth,
hash_token,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
def _to_response(mapping) -> JWTKeyMappingResponse:
    """Convert a Prisma mapping object to a safe response (no hashed token)."""
    # Copy only the whitelisted attributes - the hashed `token` column is
    # deliberately excluded from every API response.
    exposed_fields = (
        "id",
        "jwt_claim_name",
        "jwt_claim_value",
        "description",
        "is_active",
        "created_at",
        "updated_at",
        "created_by",
        "updated_by",
    )
    return JWTKeyMappingResponse(
        **{field: getattr(mapping, field) for field in exposed_fields}
    )
@router.post(
    "/jwt/key/mapping/new",
    tags=["JWT Key Mapping"],
    response_model=JWTKeyMappingResponse,
)
async def create_jwt_key_mapping(
    data: CreateJWTKeyMappingRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Admin-only] Map a JWT claim name/value pair to an existing virtual key.

    The supplied key is hashed before storage and never returned in responses.
    Returns 409 if a mapping for this claim pair already exists, 400 if the
    key does not match an existing virtual key.
    """
    from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can create JWT key mappings"
        )
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        # Store only the hash of the key, never the raw token.
        hashed_key = hash_token(data.key)
        create_data = {
            "jwt_claim_name": data.jwt_claim_name,
            "jwt_claim_value": data.jwt_claim_value,
            "token": hashed_key,
            "created_by": user_api_key_dict.user_id,
            "updated_by": user_api_key_dict.user_id,
        }
        if data.description is not None:
            create_data["description"] = data.description
        new_mapping = await prisma_client.db.litellm_jwtkeymapping.create(
            data=create_data
        )
        # Invalidate cache
        cache_key = f"jwt_key_mapping:{data.jwt_claim_name}:{data.jwt_claim_value}"
        await user_api_key_cache.async_delete_cache(cache_key)
        return _to_response(new_mapping)
    except HTTPException:
        raise
    except Exception as e:
        # Translate common Prisma error signatures (unique / foreign-key
        # violations) into friendlier HTTP errors.
        error_str = str(e).lower()
        if "unique" in error_str or "p2002" in error_str:
            raise HTTPException(
                status_code=409,
                detail=f"A mapping for claim '{data.jwt_claim_name}' = '{data.jwt_claim_value}' already exists.",
            )
        if "foreign" in error_str or "p2003" in error_str:
            raise HTTPException(
                status_code=400,
                detail="The provided key does not match an existing virtual key.",
            )
        raise HTTPException(status_code=500, detail="Failed to create JWT key mapping.")
@router.post(
    "/jwt/key/mapping/update",
    tags=["JWT Key Mapping"],
    response_model=JWTKeyMappingResponse,
)
async def update_jwt_key_mapping(
    data: UpdateJWTKeyMappingRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Admin-only] Update an existing JWT claim -> key mapping.

    Only fields explicitly set on the request are written. If a new key is
    provided it is hashed before storage. Cache entries for both the old and
    the new claim name/value pairs are invalidated.
    """
    from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can update JWT key mappings"
        )
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    # Only send explicitly-set fields to the DB; `id` keys the lookup and the
    # raw `key` is replaced by its hash below.
    update_data = data.model_dump(exclude_unset=True, exclude={"id", "key"})
    if data.key is not None:
        update_data["token"] = hash_token(data.key)
    update_data["updated_by"] = user_api_key_dict.user_id
    try:
        # Get old mapping for cache invalidation
        old_mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
            where={"id": data.id}
        )
        if old_mapping is None:
            raise HTTPException(status_code=404, detail="Mapping not found")
        # Evict the cache entry for the OLD claim pair before writing.
        cache_key = f"jwt_key_mapping:{old_mapping.jwt_claim_name}:{old_mapping.jwt_claim_value}"
        await user_api_key_cache.async_delete_cache(cache_key)
        updated_mapping = await prisma_client.db.litellm_jwtkeymapping.update(
            where={"id": data.id}, data=update_data
        )
        # Invalidate new cache key if claim fields changed
        cache_key = f"jwt_key_mapping:{updated_mapping.jwt_claim_name}:{updated_mapping.jwt_claim_value}"
        await user_api_key_cache.async_delete_cache(cache_key)
        return _to_response(updated_mapping)
    except HTTPException:
        raise
    except Exception as e:
        # Translate common Prisma error signatures into friendlier HTTP errors.
        error_str = str(e).lower()
        if "unique" in error_str or "p2002" in error_str:
            raise HTTPException(
                status_code=409,
                detail="A mapping with those claim values already exists.",
            )
        if "foreign" in error_str or "p2003" in error_str:
            raise HTTPException(
                status_code=400,
                detail="The provided key does not match an existing virtual key.",
            )
        raise HTTPException(status_code=500, detail="Failed to update JWT key mapping.")
@router.post("/jwt/key/mapping/delete", tags=["JWT Key Mapping"])
async def delete_jwt_key_mapping(
    data: DeleteJWTKeyMappingRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """[Admin-only] Delete a JWT claim -> key mapping and evict its cache entry."""
    from litellm.proxy.proxy_server import prisma_client, user_api_key_cache

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can delete JWT key mappings"
        )
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        # Look up the row first so its cache entry can be invalidated.
        mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
            where={"id": data.id}
        )
        if mapping is None:
            raise HTTPException(status_code=404, detail="Mapping not found")
        await user_api_key_cache.async_delete_cache(
            f"jwt_key_mapping:{mapping.jwt_claim_name}:{mapping.jwt_claim_value}"
        )
        await prisma_client.db.litellm_jwtkeymapping.delete(where={"id": data.id})
        return {"status": "success"}
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(status_code=500, detail="Failed to delete JWT key mapping.")
@router.get(
    "/jwt/key/mapping/list",
    tags=["JWT Key Mapping"],
)
async def list_jwt_key_mappings(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    page: int = Query(1, description="Page number", ge=1),
    size: int = Query(50, description="Page size", ge=1, le=100),
):
    """[Admin-only] Paginated list of JWT claim -> key mappings, newest first."""
    from litellm.proxy.proxy_server import prisma_client

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can list JWT key mappings"
        )
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        offset = (page - 1) * size
        rows = await prisma_client.db.litellm_jwtkeymapping.find_many(
            skip=offset,
            take=size,
            order={"created_at": "desc"},
        )
        total_count = await prisma_client.db.litellm_jwtkeymapping.count()
        total_pages = (total_count + size - 1) // size  # ceiling division
        return {
            "mappings": [_to_response(row) for row in rows],
            "total_count": total_count,
            "current_page": page,
            "total_pages": total_pages,
        }
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(status_code=500, detail="Failed to list JWT key mappings.")
@router.get(
    "/jwt/key/mapping/info",
    tags=["JWT Key Mapping"],
    response_model=JWTKeyMappingResponse,
)
async def info_jwt_key_mapping(
    id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """[Admin-only] Fetch a single JWT claim -> key mapping by its id."""
    from litellm.proxy.proxy_server import prisma_client

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can get JWT key mapping info"
        )
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        record = await prisma_client.db.litellm_jwtkeymapping.find_unique(
            where={"id": id}
        )
        if record is None:
            raise HTTPException(status_code=404, detail="Mapping not found")
        return _to_response(record)
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(
            status_code=500, detail="Failed to get JWT key mapping info."
        )

View File

@@ -0,0 +1,765 @@
"""
Allow proxy admin to manage model access groups
Endpoints here:
- POST /model_group/new - Create a new access group with multiple model names
"""
import json
from typing import Any, Dict, List, Tuple
from fastapi import APIRouter, Depends, HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
# Clear cache and reload models to pick up the access group changes
from litellm.proxy.management_endpoints.model_management_endpoints import (
clear_cache,
)
from litellm.proxy.utils import PrismaClient
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
AccessGroupInfo,
DeleteModelGroupResponse,
ListAccessGroupsResponse,
NewModelGroupRequest,
NewModelGroupResponse,
UpdateModelGroupRequest,
)
router = APIRouter()
def validate_models_exist(model_names: List[str], llm_router) -> Tuple[bool, List[str]]:
    """
    Validate that every requested model name is known to the router.

    Only exact model-name matches count.

    Returns:
        Tuple[bool, List[str]]: (all_valid, missing_models)
    """
    if llm_router is None:
        # No router -> nothing can be validated; report everything missing.
        return False, model_names
    known_models = set(llm_router.get_model_names())
    missing = [name for name in model_names if name not in known_models]
    return not missing, missing
def add_access_group_to_deployment(
    model_info: Dict[str, Any], access_group: str
) -> Tuple[Dict[str, Any], bool]:
    """
    Add an access group to a deployment's model_info (mutated in place).

    Args:
        model_info: The model_info dictionary from the deployment
        access_group: The access group name to add

    Returns:
        Tuple[Dict[str, Any], bool]: (updated_model_info, was_modified)
    """
    groups = model_info.get("access_groups", [])
    if access_group not in groups:
        groups.append(access_group)
        model_info["access_groups"] = groups
        return model_info, True
    # Already tagged with this group - nothing to change.
    return model_info, False
async def update_deployments_with_access_group(
    model_names: List[str],
    access_group: str,
    prisma_client: PrismaClient,
) -> int:
    """
    Tag every DB deployment of the given model names with an access group.

    Args:
        model_names: Model names whose deployments should be updated
        access_group: The access group name to add
        prisma_client: Database client

    Returns:
        int: Number of deployments that were actually modified

    Raises:
        HTTPException: If a model name has no deployments in the database
            (i.e. it is a config-only model).
    """
    updated_count = 0
    for model_name in model_names:
        verbose_proxy_logger.debug(f"Updating deployments for model_name: {model_name}")
        deployments = await prisma_client.db.litellm_proxymodeltable.find_many(
            where={"model_name": model_name}
        )
        verbose_proxy_logger.debug(
            f"Found {len(deployments)} deployments for model_name: {model_name}"
        )
        if not deployments:
            # Config-only models (not stored in the DB) cannot be tagged.
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Can't find model '{model_name}' in Database. Access group management is only supported for database models."
                },
            )
        for deployment in deployments:
            new_model_info, changed = add_access_group_to_deployment(
                model_info=deployment.model_info or {},
                access_group=access_group,
            )
            if not changed:
                continue
            # Persist only when the group was actually added.
            await prisma_client.db.litellm_proxymodeltable.update(
                where={"model_id": deployment.model_id},
                data={"model_info": json.dumps(new_model_info)},
            )
            updated_count += 1
            verbose_proxy_logger.debug(
                f"Updated deployment {deployment.model_id} with access group: {access_group}"
            )
    return updated_count
async def update_specific_deployments_with_access_group(
    model_ids: List[str],
    access_group: str,
    prisma_client: PrismaClient,
) -> int:
    """
    Tag specific deployments (looked up by model_id) with an access group.

    Unlike update_deployments_with_access_group, which tags ALL deployments
    sharing a model_name, this only touches the deployments identified by
    their unique model_id.
    """
    updated_count = 0
    for model_id in model_ids:
        verbose_proxy_logger.debug(f"Updating specific deployment model_id: {model_id}")
        deployment = await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": model_id}
        )
        if deployment is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Deployment with model_id '{model_id}' not found in Database."
                },
            )
        new_model_info, changed = add_access_group_to_deployment(
            model_info=deployment.model_info or {},
            access_group=access_group,
        )
        if changed:
            # Persist only when the group was actually added.
            await prisma_client.db.litellm_proxymodeltable.update(
                where={"model_id": model_id},
                data={"model_info": json.dumps(new_model_info)},
            )
            updated_count += 1
            verbose_proxy_logger.debug(
                f"Updated deployment {model_id} with access group: {access_group}"
            )
    return updated_count
def remove_access_group_from_deployment(
    model_info: Dict[str, Any], access_group: str
) -> Tuple[Dict[str, Any], bool]:
    """
    Remove an access group from a deployment's model_info.

    The input dictionary is never mutated: when a change is needed, a shallow
    copy with a copied "access_groups" list is returned. Previously this
    function removed the entry from the caller's list in place, which could
    silently alter shared state.

    Args:
        model_info: The model_info dictionary from the deployment
        access_group: The access group name to remove

    Returns:
        Tuple[Dict[str, Any], bool]: (updated_model_info, was_modified)
    """
    access_groups = model_info.get("access_groups", [])
    # Nothing to do if the group isn't present on this deployment.
    if access_group not in access_groups:
        return model_info, False
    # Copy before removing so the caller's dict/list stay untouched.
    updated_groups = list(access_groups)
    updated_groups.remove(access_group)
    updated_model_info = dict(model_info)
    updated_model_info["access_groups"] = updated_groups
    return updated_model_info, True
async def get_all_access_groups_from_db(
    prisma_client: PrismaClient,
) -> Dict[str, AccessGroupInfo]:
    """
    Get all access groups from the database.

    Scans every deployment's model_info for "access_groups" entries and
    aggregates them by group name.

    Returns:
        Dict[str, AccessGroupInfo]: Dictionary mapping access_group name to info
    """
    all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
    # group name -> set of model_names / number of deployments carrying it
    names_by_group: Dict[str, set] = {}
    counts_by_group: Dict[str, int] = {}
    for dep in all_deployments:
        info = dep.model_info or {}
        for group_name in info.get("access_groups", []):
            names_by_group.setdefault(group_name, set()).add(dep.model_name)
            counts_by_group[group_name] = counts_by_group.get(group_name, 0) + 1
    return {
        group_name: AccessGroupInfo(
            access_group=group_name,
            model_names=sorted(model_names),
            deployment_count=counts_by_group[group_name],
        )
        for group_name, model_names in names_by_group.items()
    }
@router.post(
    "/access_group/new",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewModelGroupResponse,
)
async def create_model_group(
    data: NewModelGroupRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new access group containing multiple model names.

    An access group is a named collection of model groups that can be referenced
    by teams/keys for simplified access control.

    Example:
    ```bash
    curl -X POST 'http://localhost:4000/access_group/new' \\
        -H 'Authorization: Bearer sk-1234' \\
        -H 'Content-Type: application/json' \\
        -d '{
            "access_group": "production-models",
            "model_names": ["gpt-4", "claude-3-opus", "gemini-pro"]
        }'
    ```

    Parameters:
    - access_group: str - The access group name (e.g., "production-models")
    - model_names: List[str] - List of existing model groups to include

    Returns:
    - NewModelGroupResponse with the created access group details

    Raises:
    - HTTPException 400: If any model names don't exist
    - HTTPException 409: If the access group already exists
    - HTTPException 500: If database operations fail
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        prisma_client,
    )

    verbose_proxy_logger.debug(
        f"Creating access group: {data.access_group} with models: {data.model_names}"
    )
    # Validation: Check if access_group is provided
    if not data.access_group or not data.access_group.strip():
        raise HTTPException(
            status_code=400,
            detail={"error": "access_group is required and cannot be empty"},
        )
    # Validation: Check that at least one of model_names or model_ids is provided
    has_model_names = data.model_names and len(data.model_names) > 0
    has_model_ids = data.model_ids and len(data.model_ids) > 0
    if not has_model_names and not has_model_ids:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Either model_names or model_ids must be provided and non-empty"
            },
        )
    # If model_ids is provided, use it (more precise targeting).
    # NOTE: when both are supplied, model_ids wins and model_names is ignored.
    use_model_ids = has_model_ids
    # Validate model_names exist in router (only if using model_names path)
    if not use_model_ids and has_model_names:
        assert data.model_names is not None
        all_valid, missing_models = validate_models_exist(
            model_names=data.model_names,
            llm_router=llm_router,
        )
        if not all_valid:
            raise HTTPException(
                status_code=400,
                detail={"error": f"Model(s) not found: {', '.join(missing_models)}"},
            )
    # Check if database is connected
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected. Cannot create access group."},
        )
    try:
        # Check if access group already exists. Group membership is derived from
        # each deployment's model_info, so existence is checked by scanning the DB.
        existing_access_groups = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )
        if data.access_group in existing_access_groups:
            raise HTTPException(
                status_code=409,
                detail={
                    "error": f"Access group '{data.access_group}' already exists. Use PUT /access_group/{data.access_group}/update to modify it."
                },
            )
        # Update deployments using the appropriate method
        if use_model_ids:
            assert data.model_ids is not None
            models_updated = await update_specific_deployments_with_access_group(
                model_ids=data.model_ids,
                access_group=data.access_group,
                prisma_client=prisma_client,
            )
        else:
            assert data.model_names is not None
            models_updated = await update_deployments_with_access_group(
                model_names=data.model_names,
                access_group=data.access_group,
                prisma_client=prisma_client,
            )
        # Clear caches so the new access group is visible on subsequent requests.
        await clear_cache()
        verbose_proxy_logger.info(
            f"Successfully created access group '{data.access_group}' with {models_updated} models updated"
        )
        return NewModelGroupResponse(
            access_group=data.access_group,
            model_names=data.model_names,
            model_ids=data.model_ids,
            models_updated=models_updated,
        )
    except HTTPException:
        # Re-raise HTTP errors unchanged so their status codes are preserved.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error creating access group '{data.access_group}': {str(e)}"
        )
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to create access group: {str(e)}"},
        )
@router.get(
    "/access_group/list",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListAccessGroupsResponse,
)
async def list_access_groups(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List all access groups.

    Returns every access group found in the database together with its model
    names and deployment counts, sorted by access group name.

    Example:
    ```bash
    curl -X GET 'http://localhost:4000/access_group/list' \\
        -H 'Authorization: Bearer sk-1234'
    ```

    Returns:
    - ListAccessGroupsResponse with all access groups
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected."},
        )
    try:
        groups_by_name = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )
        # Present a stable, name-ordered listing.
        ordered_groups = sorted(
            groups_by_name.values(), key=lambda info: info.access_group
        )
        return ListAccessGroupsResponse(access_groups=ordered_groups)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error listing access groups: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to list access groups: {str(e)}"},
        )
@router.get(
    "/access_group/{access_group}/info",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AccessGroupInfo,
)
async def get_access_group_info(
    access_group: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get information about a specific access group.

    Example:
    ```bash
    curl -X GET 'http://localhost:4000/access_group/production-models/info' \\
        -H 'Authorization: Bearer sk-1234'
    ```

    Parameters:
    - access_group: str - The access group name (URL path parameter)

    Returns:
    - AccessGroupInfo with the access group details

    Raises:
    - HTTPException 404: If access group not found
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected."},
        )
    try:
        known_groups = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )
        group_info = known_groups.get(access_group)
        if group_info is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Access group '{access_group}' not found"},
            )
        return group_info
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error getting access group info for '{access_group}': {str(e)}"
        )
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to get access group info: {str(e)}"},
        )
@router.put(
    "/access_group/{access_group}/update",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewModelGroupResponse,
)
async def update_access_group(
    access_group: str,
    data: UpdateModelGroupRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an access group's model names.

    This will:
    1. Remove the access group from all current deployments
    2. Add the access group to all deployments for the new model_names list

    Example:
    ```bash
    curl -X PUT 'http://localhost:4000/access_group/production-models/update' \\
        -H 'Authorization: Bearer sk-1234' \\
        -H 'Content-Type: application/json' \\
        -d '{
            "model_names": ["gpt-4", "claude-3-sonnet"]
        }'
    ```

    Parameters:
    - access_group: str - The access group name (URL path parameter)
    - model_names: List[str] - New list of model groups to include

    Returns:
    - NewModelGroupResponse with the updated access group details

    Raises:
    - HTTPException 400: If any model names don't exist
    - HTTPException 404: If access group not found
    """
    from litellm.proxy.proxy_server import llm_router, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected."},
        )
    verbose_proxy_logger.debug(
        f"Updating access group: {access_group} with models: {data.model_names}"
    )
    # Validation: Check that at least one of model_names or model_ids is provided
    has_model_names = data.model_names and len(data.model_names) > 0
    has_model_ids = data.model_ids and len(data.model_ids) > 0
    if not has_model_names and not has_model_ids:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Either model_names or model_ids must be provided and non-empty"
            },
        )
    # model_ids (when given) take precedence over model_names.
    use_model_ids = has_model_ids
    # Validation: Check if access group exists
    try:
        access_groups_map = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )
        if access_group not in access_groups_map:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Access group '{access_group}' not found"},
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to check access group existence: {str(e)}"},
        )
    # Validation: Check if all new models exist (only if using model_names path)
    if not use_model_ids and has_model_names:
        assert data.model_names is not None
        all_valid, missing_models = validate_models_exist(
            model_names=data.model_names,
            llm_router=llm_router,
        )
        if not all_valid:
            raise HTTPException(
                status_code=400,
                detail={"error": f"Model(s) not found: {', '.join(missing_models)}"},
            )
    try:
        # Step 1: Remove access group from ALL DB deployments (skip config models)
        # NOTE(review): steps 1 and 2 are not transactional — a failure between
        # them can leave the group removed but not re-added; confirm acceptable.
        all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
        for deployment in all_deployments:
            model_info = deployment.model_info or {}
            updated_model_info, was_modified = remove_access_group_from_deployment(
                model_info=model_info,
                access_group=access_group,
            )
            if was_modified:
                await prisma_client.db.litellm_proxymodeltable.update(
                    where={"model_id": deployment.model_id},
                    data={"model_info": json.dumps(updated_model_info)},
                )
        # Step 2: Add access group using the appropriate method
        if use_model_ids:
            assert data.model_ids is not None
            models_updated = await update_specific_deployments_with_access_group(
                model_ids=data.model_ids,
                access_group=access_group,
                prisma_client=prisma_client,
            )
        else:
            assert data.model_names is not None
            models_updated = await update_deployments_with_access_group(
                model_names=data.model_names,
                access_group=access_group,
                prisma_client=prisma_client,
            )
        # Clear cache and reload models to pick up the access group changes
        await clear_cache()
        verbose_proxy_logger.info(
            f"Successfully updated access group '{access_group}' with {models_updated} models updated"
        )
        return NewModelGroupResponse(
            access_group=access_group,
            model_names=data.model_names,
            model_ids=data.model_ids,
            models_updated=models_updated,
        )
    except HTTPException:
        # Preserve explicit HTTP errors raised above.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error updating access group '{access_group}': {str(e)}"
        )
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to update access group: {str(e)}"},
        )
@router.delete(
    "/access_group/{access_group}/delete",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=DeleteModelGroupResponse,
)
async def delete_access_group(
    access_group: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete an access group.

    Removes the access group from all deployments that have it.

    Example:
    ```bash
    curl -X DELETE 'http://localhost:4000/access_group/production-models/delete' \\
        -H 'Authorization: Bearer sk-1234'
    ```

    Parameters:
    - access_group: str - The access group name (URL path parameter)

    Returns:
    - DeleteModelGroupResponse with deletion details

    Raises:
    - HTTPException 404: If access group not found
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Database not connected."},
        )
    verbose_proxy_logger.debug(f"Deleting access group: {access_group}")
    # Validation: Check if access group exists
    try:
        access_groups_map = await get_all_access_groups_from_db(
            prisma_client=prisma_client
        )
        if access_group not in access_groups_map:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Access group '{access_group}' not found"},
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to check access group existence: {str(e)}"},
        )
    try:
        # Remove access group from all DB deployments (skip config models —
        # only rows in the proxymodel table are scanned and updated here).
        all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
        models_updated = 0
        for deployment in all_deployments:
            model_info = deployment.model_info or {}
            updated_model_info, was_modified = remove_access_group_from_deployment(
                model_info=model_info,
                access_group=access_group,
            )
            # Only count/write deployments that actually carried the group.
            if was_modified:
                await prisma_client.db.litellm_proxymodeltable.update(
                    where={"model_id": deployment.model_id},
                    data={"model_info": json.dumps(updated_model_info)},
                )
                models_updated += 1
        # Clear cache and reload models to pick up the access group changes
        await clear_cache()
        verbose_proxy_logger.info(
            f"Successfully deleted access group '{access_group}' from {models_updated} deployments"
        )
        return DeleteModelGroupResponse(
            access_group=access_group,
            models_updated=models_updated,
            message=f"Access group '{access_group}' deleted successfully",
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error deleting access group '{access_group}': {str(e)}"
        )
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to delete access group: {str(e)}"},
        )

View File

@@ -0,0 +1,20 @@
"""
Policy endpoints package.
Re-exports everything from endpoints module so existing imports
like `from litellm.proxy.management_endpoints.policy_endpoints import router`
continue to work. Patch targets also resolve correctly since names
are imported directly into this namespace.
"""
from litellm.proxy.management_endpoints.policy_endpoints.endpoints import * # noqa: F401, F403
from litellm.proxy.management_endpoints.policy_endpoints.endpoints import ( # noqa: F401
_build_all_names_per_competitor,
_build_comparison_blocked_words,
_build_competitor_guardrail_definitions,
_build_name_blocked_words,
_build_recommendation_blocked_words,
_build_refinement_prompt,
_clean_competitor_line,
_parse_variations_response,
)

View File

@@ -0,0 +1,131 @@
"""
AI Policy Suggester - uses LLM tool calling to suggest policy templates
based on user-provided attack examples and descriptions.
"""
import json
from typing import List, Optional
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import DEFAULT_COMPETITOR_DISCOVERY_MODEL
# OpenAI-style JSON-schema tool definition passed to the LLM so it must answer
# via a structured `select_policy_templates` function call instead of free text.
SUGGEST_TOOL = {
    "type": "function",
    "function": {
        "name": "select_policy_templates",
        "description": "Select one or more policy templates that best match the user's security requirements",
        "parameters": {
            "type": "object",
            "properties": {
                "selected_templates": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "template_id": {
                                "type": "string",
                                "description": "The ID of the selected template",
                            },
                            "reason": {
                                "type": "string",
                                "description": "Brief reason why this template matches",
                            },
                        },
                        "required": ["template_id", "reason"],
                    },
                    "description": "List of templates that match the user's requirements",
                },
                "explanation": {
                    "type": "string",
                    "description": "Overall explanation of why these templates were suggested",
                },
            },
            "required": ["selected_templates", "explanation"],
        },
    },
}
class AiPolicySuggester:
    """Suggests policy templates using LLM tool calling."""

    async def suggest(
        self,
        templates: list,
        attack_examples: List[str],
        description: str,
        model: Optional[str] = None,
    ) -> dict:
        """Ask the LLM to pick matching templates via a forced tool call."""
        chosen_model = model or DEFAULT_COMPETITOR_DISCOVERY_MODEL
        prompt_messages = [
            {"role": "system", "content": self._build_system_prompt(templates)},
            {
                "role": "user",
                "content": self._build_user_prompt(attack_examples, description),
            },
        ]
        try:
            response = await litellm.acompletion(
                model=chosen_model,
                messages=prompt_messages,
                tools=[SUGGEST_TOOL],
                tool_choice={
                    "type": "function",
                    "function": {"name": "select_policy_templates"},
                },
                temperature=0.2,
            )
            tool_calls = response.choices[0].message.tool_calls  # type: ignore
            if not tool_calls:
                return {
                    "selected_templates": [],
                    "explanation": "No templates could be matched to your requirements.",
                }
            parsed = json.loads(tool_calls[0].function.arguments)
            known_ids = {t["id"] for t in templates}
            # Drop any template ids the model may have invented.
            parsed["selected_templates"] = [
                selection
                for selection in parsed.get("selected_templates", [])
                if selection.get("template_id") in known_ids
            ]
            return parsed
        except Exception as e:
            verbose_proxy_logger.error("AI policy suggestion failed: %s", e)
            raise

    def _build_system_prompt(self, templates: list) -> str:
        """Render the advisor instructions plus a catalog of the templates."""
        catalog_entries = []
        for template in templates:
            sentences = template.get("example_sentences", [])
            rendered = ", ".join(f'"{s}"' for s in sentences) if sentences else "none"
            catalog_entries.append(
                f"- ID: {template['id']}\n"
                f" Title: {template['title']}\n"
                f" Description: {template['description']}\n"
                f" Example attacks it protects against: {rendered}"
            )
        return (
            "You are a security policy advisor. The user will describe attacks or content "
            "they want to block. Your job is to select the most relevant policy templates "
            "from the available set. Use the select_policy_templates tool to return your "
            "selections. Only select templates that are clearly relevant to what the user "
            "wants to block.\n\n"
            "Available templates:\n\n" + "\n\n".join(catalog_entries)
        )

    def _build_user_prompt(self, attack_examples: List[str], description: str) -> str:
        """Render the user's attack examples and description as one prompt."""
        sections: List[str] = []
        non_empty_examples = [ex for ex in attack_examples if ex.strip()]
        if non_empty_examples:
            sections.append("Example attack prompts I want to block:")
            sections.extend(
                f" {idx}. {example}"
                for idx, example in enumerate(non_empty_examples, 1)
            )
        if description.strip():
            sections.append(f"\nDescription of what I want to block: {description}")
        return "\n".join(sections)

View File

@@ -0,0 +1,936 @@
"""
Endpoints for /project operations
/project/new
/project/update
/project/delete
/project/info
/project/list
"""
#### PROJECT MANAGEMENT ####
import json
from typing import List, Optional, Union
from fastapi import APIRouter, Depends, HTTPException, Request
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _set_object_metadata_field
from litellm.proxy.management_helpers.utils import (
management_endpoint_wrapper,
)
from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy
router = APIRouter()
async def _check_user_permission_for_project(
    user_api_key_dict: UserAPIKeyAuth,
    team_id: Optional[str],
    prisma_client: PrismaClient,
    require_admin: bool = False,
    team_object: Optional[LiteLLM_TeamTable] = None,
) -> bool:
    """
    Check if user has permission to manage a project.

    Returns True if user is proxy admin or team admin (when team_id provided).
    If require_admin=True, only proxy admins are allowed.
    If team_object is provided, it will be used instead of fetching from DB
    (avoids duplicate DB queries when team was already fetched for validation).
    """
    is_proxy_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
    # Proxy admins always pass; with require_admin they are the ONLY ones who do.
    if require_admin or is_proxy_admin:
        return is_proxy_admin
    if not team_id or not user_api_key_dict.user_id:
        return False
    resolved_team = team_object
    if resolved_team is None:
        resolved_team = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id}
        )
    if resolved_team is None or not resolved_team.admins:
        return False
    return user_api_key_dict.user_id in resolved_team.admins
async def _validate_team_exists(
    team_id: str,
    prisma_client: PrismaClient,
):
    """Validate that a team exists. Returns the team row."""
    team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id},
    )
    if team_row is not None:
        return team_row
    # Surface a 404-style ProxyException so callers get a clear error.
    raise ProxyException(
        message=f"Team not found, team_id={team_id}",
        type="not_found",
        code=404,
        param="team_id",
    )
def _check_team_project_limits(
    team_object: LiteLLM_TeamTable,
    data: Union[NewProjectRequest, UpdateProjectRequest],
) -> None:
    """
    Check that project limits respect its parent Team's limits.
    Mirrors _check_org_team_limits() from team_endpoints.py.

    Validates:
    - Project models are a subset of Team models
    - Project max_budget <= Team max_budget
    - Project tpm_limit <= Team tpm_limit
    - Project rpm_limit <= Team rpm_limit
    - Budget values are non-negative
    - soft_budget < max_budget

    Raises:
        HTTPException: 400 on the first violated constraint.
    """

    def _bad_request(error: str) -> None:
        # All validation failures surface as HTTP 400 with a structured body.
        raise HTTPException(status_code=400, detail={"error": error})

    # --- Budget non-negativity checks ---
    for budget_field in ("max_budget", "soft_budget"):
        budget_value = getattr(data, budget_field, None)
        if budget_value is not None and budget_value < 0:
            _bad_request(
                f"{budget_field} cannot be negative. Received: {budget_value}"
            )
    # --- soft_budget < max_budget ---
    if data.soft_budget is not None and data.max_budget is not None:
        if data.soft_budget >= data.max_budget:
            _bad_request(
                f"soft_budget ({data.soft_budget}) must be strictly lower than max_budget ({data.max_budget})"
            )
    # --- Validate project models are a subset of team models ---
    project_models = getattr(data, "models", None)
    team_models = team_object.models or []
    if project_models and len(team_models) > 0:
        # If team has 'all-proxy-models', skip validation as it allows all models
        if SpecialModelNames.all_proxy_models.value not in team_models:
            for m in project_models:
                if m not in team_models:
                    _bad_request(
                        f"Model '{m}' not in team's allowed models. Team allowed models={team_models}. Team: {team_object.team_id}"
                    )
    # --- Validate project max_budget/tpm_limit/rpm_limit <= team's values ---
    # Team stores budget fields directly (max_budget, tpm_limit, rpm_limit)
    # unlike Project which uses a separate LiteLLM_BudgetTable relation.
    for limit_field in ("max_budget", "tpm_limit", "rpm_limit"):
        project_value = getattr(data, limit_field, None)
        team_value = getattr(team_object, limit_field, None)
        if (
            project_value is not None
            and team_value is not None
            and project_value > team_value
        ):
            _bad_request(
                f"Project {limit_field} ({project_value}) exceeds team's {limit_field} ({team_value}). Team: {team_object.team_id}"
            )
async def _create_budget_for_project(
    data: NewProjectRequest,
    user_id: Optional[str],
    litellm_proxy_admin_name: str,
    prisma_client: PrismaClient,
) -> str:
    """Create a budget for the project and return budget_id."""
    # Keep only the request fields that belong to the budget table.
    allowed_budget_fields = LiteLLM_BudgetTable.model_fields.keys()
    request_payload = data.json(exclude_none=True)
    budget_payload = {
        key: value
        for key, value in request_payload.items()
        if key in allowed_budget_fields
    }
    budget_object = LiteLLM_BudgetTable(**budget_payload)
    serialized_budget = prisma_client.jsonify_object(
        budget_object.json(exclude_none=True)
    )
    # Attribute creation to the calling user, falling back to the proxy admin.
    actor = user_id or litellm_proxy_admin_name
    created_budget = await prisma_client.db.litellm_budgettable.create(
        data={
            **serialized_budget,
            "created_by": actor,
            "updated_by": actor,
        }
    )
    return created_budget.budget_id
async def _set_project_object_permission(
    data: NewProjectRequest,
    prisma_client: Optional[PrismaClient],
) -> Optional[str]:
    """
    Creates the LiteLLM_ObjectPermissionTable record for the project.

    Returns the object_permission_id if created, otherwise None.
    """
    if prisma_client is None or data.object_permission is None:
        return None
    permission_row = await prisma_client.db.litellm_objectpermissiontable.create(
        data=data.object_permission.model_dump(exclude_none=True),
    )
    # Drop the nested payload so it is not serialized again onto the project row.
    del data.object_permission
    return permission_row.object_permission_id
def _remove_budget_fields_from_project_data(project_data: dict) -> dict:
    """
    Remove budget fields from project data.

    Budget fields belong to LiteLLM_BudgetTable, not LiteLLM_ProjectTable.
    Keep budget_id as it's a foreign key.
    Following the pattern from organization_endpoints.py
    """
    for budget_field in list(LiteLLM_BudgetTable.model_fields.keys()):
        # budget_id stays: it is the foreign key linking project -> budget.
        if budget_field == "budget_id":
            continue
        project_data.pop(budget_field, None)
    return project_data
@router.post(
    "/project/new",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewProjectResponse,
)
@management_endpoint_wrapper
async def new_project(
    data: NewProjectRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new project. Projects sit between teams and keys in the hierarchy.
    Only admins or team admins can create projects.

    # Parameters

    - project_alias: *Optional[str]* - The name of the project.
    - description: *Optional[str]* - Description of the project's purpose and use case.
    - team_id: *str* - The team id that this project belongs to. Required.
    - models: *List* - The models the project has access to.
    - budget_id: *Optional[str]* - The id for a budget (tpm/rpm/max budget) for the project.
    ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
    - max_budget: *Optional[float]* - Max budget for project
    - tpm_limit: *Optional[int]* - Max tpm limit for project
    - rpm_limit: *Optional[int]* - Max rpm limit for project
    - max_parallel_requests: *Optional[int]* - Max parallel requests for project
    - soft_budget: *Optional[float]* - Get a slack alert when this soft budget is reached. Don't block requests.
    - model_max_budget: *Optional[dict]* - Max budget for a specific model. Example: {"gpt-4": 100.0, "gpt-3.5-turbo": 50.0}
    - model_rpm_limit: *Optional[dict]* - RPM limits per model. Example: {"gpt-4": 1000, "gpt-3.5-turbo": 5000}
    - model_tpm_limit: *Optional[dict]* - TPM limits per model. Example: {"gpt-4": 50000, "gpt-3.5-turbo": 100000}
    - budget_duration: *Optional[str]* - Frequency of reseting project budget
    - metadata: *Optional[dict]* - Metadata for project, store information for project. Example metadata - {"use_case_id": "SNOW-12345", "responsible_ai_id": "RAI-67890"}
    - tags: *Optional[list]* - Tags for the project. Example: ["production", "api"]
    - blocked: *bool* - Flag indicating if the project is blocked or not - will stop all calls from keys with this project_id.
    - object_permission: Optional[LiteLLM_ObjectPermissionBase] - project-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission.

    Example 1: Create new project **without** a budget_id, with model-specific limits

    ```bash
    curl --location 'http://0.0.0.0:4000/project/new' \\
    --header 'Authorization: Bearer sk-1234' \\
    --header 'Content-Type: application/json' \\
    --data '{
        "project_alias": "flight-search-assistant",
        "description": "AI-powered flight search and booking assistant",
        "team_id": "team-123",
        "models": ["gpt-4", "gpt-3.5-turbo"],
        "max_budget": 100,
        "model_rpm_limit": {
            "gpt-4": 1000,
            "gpt-3.5-turbo": 5000
        },
        "model_tpm_limit": {
            "gpt-4": 50000,
            "gpt-3.5-turbo": 100000
        },
        "metadata": {
            "use_case_id": "SNOW-12345",
            "responsible_ai_id": "RAI-67890"
        }
    }'
    ```

    Example 2: Create new project **with** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/project/new' \\
    --header 'Authorization: Bearer sk-1234' \\
    --header 'Content-Type: application/json' \\
    --data '{
        "project_alias": "hotel-recommendations",
        "description": "Personalized hotel recommendation engine",
        "team_id": "team-123",
        "models": ["claude-3-sonnet"],
        "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689",
        "metadata": {
            "use_case_id": "SNOW-54321"
        }
    }'
    ```
    """
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        premium_user,
        prisma_client,
    )

    try:
        # Tags are a premium-only feature on projects.
        if getattr(data, "tags", None) is not None and not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Only premium users can add tags to projects. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )
        # Project management as a whole is enterprise-gated.
        if not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Project management is an enterprise feature. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )
        # ADD METADATA FIELDS
        # Move premium metadata fields off the request object into its metadata
        # dict, then remove the attribute so it is not persisted twice.
        for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
            if getattr(data, field, None) is not None:
                _set_object_metadata_field(
                    object_data=data,
                    field_name=field,
                    value=getattr(data, field),
                )
                delattr(data, field)
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )
        # Validate team exists and get team object with budget
        team_object = await _validate_team_exists(
            team_id=data.team_id, prisma_client=prisma_client
        )
        # Validate project limits against team limits
        _check_team_project_limits(
            team_object=LiteLLM_TeamTable(**team_object.model_dump()),
            data=data,
        )
        # Check if user has permission to create projects for this team
        # only team admins can create projects for their team
        has_permission = await _check_user_permission_for_project(
            user_api_key_dict=user_api_key_dict,
            team_id=data.team_id,
            prisma_client=prisma_client,
            team_object=LiteLLM_TeamTable(**team_object.model_dump()),
        )
        if not has_permission:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": f"Only admins or team admins can create projects. Your role is {user_api_key_dict.user_role}"
                },
            )
        # Generate project_id if not provided
        if data.project_id is None:
            data.project_id = str(uuid.uuid4())
        else:
            # Check if project_id already exists
            existing_project = await prisma_client.db.litellm_projecttable.find_unique(
                where={"project_id": data.project_id}
            )
            if existing_project is not None:
                raise ProxyException(
                    message=f"Project id = {data.project_id} already exists. Please use a different project id.",
                    type="bad_request",
                    code=400,
                    param="project_id",
                )
        # Create budget if not provided
        if data.budget_id is None:
            data.budget_id = await _create_budget_for_project(
                data=data,
                user_id=user_api_key_dict.user_id,
                litellm_proxy_admin_name=litellm_proxy_admin_name,
                prisma_client=prisma_client,
            )
        ## Handle Object Permission - MCP, Vector Stores etc.
        object_permission_id = await _set_project_object_permission(
            data=data,
            prisma_client=prisma_client,
        )
        # Create project row (following organization_endpoints.py pattern)
        project_row = LiteLLM_ProjectTable(
            **data.json(exclude_none=True),
            object_permission_id=object_permission_id,
            created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
            updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
        )
        # Copy non-premium metadata fields onto the project row's metadata.
        for field in LiteLLM_ManagementEndpoint_MetadataFields:
            if getattr(data, field, None) is not None:
                _set_object_metadata_field(
                    object_data=project_row,
                    field_name=field,
                    value=getattr(data, field),
                )
        new_project_row = prisma_client.jsonify_object(
            project_row.json(exclude_none=True)
        )
        # Remove budget fields (following organization_endpoints.py pattern)
        new_project_row = _remove_budget_fields_from_project_data(new_project_row)
        verbose_proxy_logger.info(
            f"new_project_row: {json.dumps(new_project_row, indent=2)}"
        )
        response = await prisma_client.db.litellm_projecttable.create(
            data={
                **new_project_row,  # type: ignore
            },
            include={"litellm_budget_table": True},
        )
        return response
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.new_project(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.post(
    "/project/update",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_ProjectTable,
)
@management_endpoint_wrapper
async def update_project(  # noqa: PLR0915
    data: UpdateProjectRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update a project. Only proxy admins or the owning team's admins may update.

    Parameters:
    - project_id: *str* - The project id to update. Required.
    - project_alias: *Optional[str]* - Updated name for the project
    - description: *Optional[str]* - Updated description for the project
    - team_id: *Optional[str]* - Updated team_id for the project
    - metadata: *Optional[dict]* - Updated metadata for project
    - models: *Optional[list]* - Updated list of models for the project
    - blocked: *Optional[bool]* - Updated blocked status
    - max_budget: *Optional[float]* - Updated max budget
    - tpm_limit: *Optional[int]* - Updated tpm limit
    - rpm_limit: *Optional[int]* - Updated rpm limit
    - model_rpm_limit: *Optional[dict]* - Updated RPM limits per model
    - model_tpm_limit: *Optional[dict]* - Updated TPM limits per model
    - budget_duration: *Optional[str]* - Updated budget duration
    - tags: *Optional[list]* - Updated list of tags for the project (premium only)
    - object_permission: Optional[LiteLLM_ObjectPermissionBase] - Updated object permission

    Example:
    ```bash
    curl --location 'http://0.0.0.0:4000/project/update' \\
    --header 'Authorization: Bearer sk-1234' \\
    --header 'Content-Type: application/json' \\
    --data '{
        "project_id": "project-123",
        "description": "Updated flight search system with enhanced capabilities",
        "max_budget": 200,
        "model_rpm_limit": {
            "gpt-4": 2000,
            "gpt-3.5-turbo": 10000
        },
        "metadata": {
            "use_case_id": "SNOW-12345",
            "status": "active"
        }
    }'
    ```
    """
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        premium_user,
        prisma_client,
    )
    try:
        # Checked before the generic premium gate below so that non-premium
        # callers sending tags get the more specific error message.
        if getattr(data, "tags", None) is not None and not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Only premium users can add tags to projects. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )
        # Project management as a whole is enterprise-only.
        if not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Project management is an enterprise feature. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )
        # Move premium metadata fields from top-level request attributes into
        # the object's metadata, then strip them from the request object so
        # they are not serialized twice.
        for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
            if getattr(data, field, None) is not None:
                _set_object_metadata_field(
                    object_data=data,
                    field_name=field,
                    value=getattr(data, field),
                )
                delattr(data, field)
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )
        if data.project_id is None:
            raise HTTPException(
                status_code=400,
                detail={"error": "project_id is required"},
            )
        # Fetch existing project; 404 if it does not exist.
        existing_project = await prisma_client.db.litellm_projecttable.find_unique(
            where={"project_id": data.project_id}
        )
        if existing_project is None:
            raise ProxyException(
                message=f"Project not found, project_id={data.project_id}",
                type="not_found",
                code=404,
                param="project_id",
            )
        # Validate team exists and get team object for limit + permission checks.
        # If the request moves the project to a new team, validate that team;
        # otherwise validate the project's current team.
        team_id_to_check = data.team_id or existing_project.team_id
        team_obj_for_checks = None
        if team_id_to_check is not None:
            team_obj_for_checks = await _validate_team_exists(
                team_id=team_id_to_check, prisma_client=prisma_client
            )
        # Check if user has permission to update this project
        # (proxy admin, or an admin of the project's *current* team).
        has_permission = await _check_user_permission_for_project(
            user_api_key_dict=user_api_key_dict,
            team_id=existing_project.team_id,
            prisma_client=prisma_client,
            team_object=LiteLLM_TeamTable(**team_obj_for_checks.model_dump())
            if team_obj_for_checks
            else None,
        )
        if not has_permission:
            raise HTTPException(
                status_code=403,
                detail={"error": "Only admins or team admins can update projects"},
            )
        # Validate requested project limits do not exceed the team's limits.
        if team_obj_for_checks is not None:
            _check_team_project_limits(
                team_object=LiteLLM_TeamTable(**team_obj_for_checks.model_dump()),
                data=data,
            )
        # Prepare update data: serialize request (minus project_id) and record
        # who performed the update.
        update_data = data.json(exclude_none=True, exclude={"project_id"})
        update_data = prisma_client.jsonify_object(update_data)
        update_data["updated_by"] = (
            user_api_key_dict.user_id or litellm_proxy_admin_name
        )
        # Handle budget updates: any fields belonging to the budget table are
        # written to the project's existing budget row rather than the project.
        budget_fields = LiteLLM_BudgetTable.model_fields.keys()
        budget_updates = {k: v for k, v in update_data.items() if k in budget_fields}
        if budget_updates and existing_project.budget_id:
            # Update existing budget
            await prisma_client.db.litellm_budgettable.update(
                where={"budget_id": existing_project.budget_id},
                data={
                    **budget_updates,
                    "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                },
            )
            # Remove budget fields from project update
            for field in budget_updates.keys():
                update_data.pop(field, None)
        # Handle object permissions: update the existing permission row if the
        # project has one, otherwise create a new row and link it.
        if "object_permission" in update_data:
            object_permission_data = update_data.pop("object_permission")
            if object_permission_data:
                if existing_project.object_permission_id:
                    # Update existing permission
                    await prisma_client.db.litellm_objectpermissiontable.update(
                        where={
                            "object_permission_id": existing_project.object_permission_id
                        },
                        data=object_permission_data,
                    )
                else:
                    # Create new permission
                    created_permission = (
                        await prisma_client.db.litellm_objectpermissiontable.create(
                            data=object_permission_data,
                        )
                    )
                    update_data[
                        "object_permission_id"
                    ] = created_permission.object_permission_id
        # Handle metadata fields: fold any recognized metadata keys into the
        # project's metadata dict.
        for field in LiteLLM_ManagementEndpoint_MetadataFields:
            if field in update_data:
                if update_data.get("metadata") is None:
                    update_data["metadata"] = {}
                update_data["metadata"][field] = update_data.pop(field)
        # Remove budget fields (following organization_endpoints.py pattern)
        update_data = _remove_budget_fields_from_project_data(update_data)
        # Update project
        updated_project = await prisma_client.db.litellm_projecttable.update(
            where={"project_id": data.project_id},
            data=update_data,
            include={"litellm_budget_table": True, "object_permission": True},
        )
        return updated_project
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.update_project(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.delete(
    "/project/delete",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_ProjectTable],
)
@management_endpoint_wrapper
async def delete_project(
    data: DeleteProjectRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete projects. Admin-only; a project with keys attached cannot be deleted.

    Parameters:
    - project_ids: *List[str]* - List of project ids to delete

    Example:
    ```bash
    curl --location --request DELETE 'http://0.0.0.0:4000/project/delete' \\
    --header 'Authorization: Bearer sk-1234' \\
    --header 'Content-Type: application/json' \\
    --data '{
        "project_ids": ["project-123", "project-456"]
    }'
    ```
    """
    from litellm.proxy.proxy_server import premium_user, prisma_client
    try:
        # Project management is enterprise-only.
        if not premium_user:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Project management is an enterprise feature. "
                    + CommonProxyErrors.not_premium_user.value
                },
            )
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )
        # Check if user is admin (only admins can delete projects)
        has_permission = await _check_user_permission_for_project(
            user_api_key_dict=user_api_key_dict,
            team_id=None,
            prisma_client=prisma_client,
            require_admin=True,
        )
        if not has_permission:
            raise HTTPException(
                status_code=403,
                detail={"error": "Only admins can delete projects"},
            )
        deleted_projects = []
        # NOTE(review): deletions run sequentially and are not transactional —
        # if a later project in the list fails validation, earlier projects
        # have already been deleted.
        for project_id in data.project_ids:
            # Check if project exists
            existing_project = await prisma_client.db.litellm_projecttable.find_unique(
                where={"project_id": project_id}
            )
            if existing_project is None:
                raise ProxyException(
                    message=f"Project not found, project_id={project_id}",
                    type="not_found",
                    code=404,
                    param="project_ids",
                )
            # Refuse to delete while virtual keys still reference the project,
            # so keys are never left pointing at a missing project.
            associated_keys = (
                await prisma_client.db.litellm_verificationtoken.find_many(
                    where={"project_id": project_id}
                )
            )
            if len(associated_keys) > 0:
                raise ProxyException(
                    message=f"Cannot delete project {project_id}. {len(associated_keys)} key(s) are associated with it. Please delete or reassign the keys first.",
                    type="bad_request",
                    code=400,
                    param="project_ids",
                )
            # Delete the project
            deleted_project = await prisma_client.db.litellm_projecttable.delete(
                where={"project_id": project_id}
            )
            deleted_projects.append(deleted_project)
        return deleted_projects
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.delete_project(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.get(
    "/project/info",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_ProjectTable,
)
async def project_info(
    project_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Fetch details for a single project.

    Visible to proxy admins, and to admins/members of the project's owning team.

    Parameters:
    - project_id: *str* - The project id to fetch info for

    Example:
    ```bash
    curl --location 'http://0.0.0.0:4000/project/info?project_id=project-123' \\
    --header 'Authorization: Bearer sk-1234'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )
        # Load the project along with its budget and object-permission rows.
        project_record = await prisma_client.db.litellm_projecttable.find_unique(
            where={"project_id": project_id},
            include={"litellm_budget_table": True, "object_permission": True},
        )
        if project_record is None:
            raise ProxyException(
                message=f"Project not found, project_id={project_id}",
                type="not_found",
                code=404,
                param="project_id",
            )
        # Authorization: proxy admins always pass; otherwise the caller must be
        # listed on the owning team (either in its admins or members list).
        is_proxy_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
        caller_in_team = False
        if project_record.team_id and user_api_key_dict.user_id:
            owning_team = await prisma_client.db.litellm_teamtable.find_unique(
                where={"team_id": project_record.team_id}
            )
            if owning_team:
                caller_in_team = (
                    user_api_key_dict.user_id in owning_team.admins
                    or user_api_key_dict.user_id in owning_team.members
                )
        if not (is_proxy_admin or caller_in_team):
            raise HTTPException(
                status_code=403,
                detail={"error": "You don't have access to this project"},
            )
        return project_record
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.project_info(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
@router.get(
    "/project/list",
    tags=["project management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_ProjectTable],
)
async def list_projects(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List every project visible to the caller.

    Proxy admins see all projects; other users only see projects belonging to
    teams where they appear in the members or admins list.

    Example:
    ```bash
    curl --location 'http://0.0.0.0:4000/project/list' \\
    --header 'Authorization: Bearer sk-1234'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )
        include_relations = {"litellm_budget_table": True, "object_permission": True}
        if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
            # Admins are not scoped to any team: return everything.
            return await prisma_client.db.litellm_projecttable.find_many(
                include=include_relations
            )
        # Restrict to projects owned by teams the caller belongs to.
        member_teams = await prisma_client.db.litellm_teamtable.find_many(
            where={
                "OR": [
                    {"members": {"has": user_api_key_dict.user_id}},
                    {"admins": {"has": user_api_key_dict.user_id}},
                ]
            }
        )
        visible_team_ids = [team.team_id for team in member_teams]
        return await prisma_client.db.litellm_projecttable.find_many(
            where={"team_id": {"in": visible_team_ids}},
            include=include_relations,
        )
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.project_endpoints.list_projects(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)

View File

@@ -0,0 +1,181 @@
"""
ROUTER SETTINGS MANAGEMENT
Endpoints for accessing router configuration and metadata
GET /router/settings - Get router configuration including available routing strategies
GET /router/fields - Get router settings field definitions without values (for UI rendering)
"""
import inspect
from typing import Any, Dict, List, get_args
from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.router import Router
from litellm.types.management_endpoints import (
ROUTER_SETTINGS_FIELDS,
ROUTING_STRATEGY_DESCRIPTIONS,
RouterSettingsField,
)
router = APIRouter()
class RouterSettingsResponse(BaseModel):
    """Response payload for GET /router/settings: field metadata plus current values."""

    fields: List[RouterSettingsField] = Field(
        description="List of all configurable router settings with metadata"
    )
    current_values: Dict[str, Any] = Field(
        description="Current values of router settings"
    )
    routing_strategy_descriptions: Dict[str, str] = Field(
        description="Descriptions for each routing strategy option"
    )
class RouterFieldsResponse(BaseModel):
    """Response payload for GET /router/fields: field metadata only, no values."""

    fields: List[RouterSettingsField] = Field(
        description="List of all configurable router settings with metadata (without field values)"
    )
    routing_strategy_descriptions: Dict[str, str] = Field(
        description="Descriptions for each routing strategy option"
    )
def _get_routing_strategies_from_router_class() -> List[str]:
    """
    Discover the valid routing strategies by inspecting the Literal annotation
    on the ``routing_strategy`` parameter of ``Router.__init__``.

    Raises:
        ValueError: if the parameter or its Literal values cannot be found.
    """
    init_signature = inspect.signature(Router.__init__)
    strategy_param = init_signature.parameters.get("routing_strategy")
    if strategy_param and strategy_param.annotation:
        # get_args unpacks Literal["a", "b", ...] into its string members.
        strategy_values = get_args(strategy_param.annotation)
        if strategy_values:
            return list(strategy_values)
    raise ValueError("Unable to extract routing strategies from Router class")
@router.get(
    "/router/settings",
    tags=["Router Settings"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=RouterSettingsResponse,
)
async def get_router_settings(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get router configuration and available settings.

    Returns:
    - fields: all configurable router settings with metadata (type, description,
      default, options); the routing_strategy field carries the options
      extracted from the Router class
    - current_values: current router setting values (config overrides live router)
    - routing_strategy_descriptions: description per routing strategy
    """
    from litellm.proxy.proxy_server import llm_router, proxy_config

    try:
        # Options for the routing_strategy dropdown come straight from the
        # Router class signature, so they never drift from the code.
        strategy_options = _get_routing_strategies_from_router_class()
        # Work on deep copies so the shared field definitions stay pristine.
        settings_fields = [
            definition.model_copy(deep=True) for definition in ROUTER_SETTINGS_FIELDS
        ]
        for settings_field in settings_fields:
            if settings_field.field_name == "routing_strategy":
                settings_field.options = strategy_options
                break
        # Values configured in the proxy config file.
        proxy_cfg = await proxy_config.get_config()
        config_router_settings = proxy_cfg.get("router_settings", {})
        # Values from the live router instance, if one is running.
        live_values: Dict[str, Any] = {}
        if llm_router is not None:
            for settings_field in settings_fields:
                if hasattr(llm_router, settings_field.field_name):
                    live_values[settings_field.field_name] = getattr(
                        llm_router, settings_field.field_name
                    )
        # Config values take precedence over live router values.
        live_values.update(config_router_settings)
        # Mirror the merged values onto the field definitions for the UI.
        for settings_field in settings_fields:
            if settings_field.field_name in live_values:
                settings_field.field_value = live_values[settings_field.field_name]
        return RouterSettingsResponse(
            fields=settings_fields,
            current_values=live_values,
            routing_strategy_descriptions=ROUTING_STRATEGY_DESCRIPTIONS,
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error fetching router settings: {str(e)}")
        raise
@router.get(
    "/router/fields",
    tags=["Router Settings"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=RouterFieldsResponse,
)
async def get_router_fields(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get router settings field definitions without values.

    Returns only field metadata (type, description, default, options) with
    field_value left as None — useful for UI components that render the form
    but fetch actual values elsewhere.

    Returns:
    - fields: all configurable router settings with metadata; the
      routing_strategy field carries options extracted from the Router class
    - routing_strategy_descriptions: description per routing strategy
    """
    try:
        strategy_options = _get_routing_strategies_from_router_class()
        # Deep-copy the shared definitions so mutations stay local to this call.
        field_definitions = [
            definition.model_copy(deep=True) for definition in ROUTER_SETTINGS_FIELDS
        ]
        strategy_field = next(
            (
                definition
                for definition in field_definitions
                if definition.field_name == "routing_strategy"
            ),
            None,
        )
        if strategy_field is not None:
            strategy_field.options = strategy_options
        # This endpoint intentionally never reports values.
        for definition in field_definitions:
            definition.field_value = None
        return RouterFieldsResponse(
            fields=field_definitions,
            routing_strategy_descriptions=ROUTING_STRATEGY_DESCRIPTIONS,
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error fetching router fields: {str(e)}")
        raise

View File

@@ -0,0 +1,118 @@
# SCIM v2 Integration for LiteLLM Proxy
This module provides SCIM v2 (System for Cross-domain Identity Management) endpoints for LiteLLM Proxy, allowing identity providers to manage users and teams (groups) within the LiteLLM ecosystem.
## Overview
SCIM is an open standard designed to simplify user management across different systems. This implementation allows compatible identity providers (like Okta, Azure AD, OneLogin, etc.) to automatically provision and deprovision users and groups in LiteLLM Proxy.
## Endpoints
The SCIM v2 API follows the standard specification with the following base URL:
```
/scim/v2
```
### User Management
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/Users` | GET | List all users with pagination support |
| `/Users/{user_id}` | GET | Get a specific user by ID |
| `/Users` | POST | Create a new user |
| `/Users/{user_id}` | PUT | Update an existing user |
| `/Users/{user_id}` | DELETE | Delete a user |
### Group Management
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/Groups` | GET | List all groups with pagination support |
| `/Groups/{group_id}` | GET | Get a specific group by ID |
| `/Groups` | POST | Create a new group |
| `/Groups/{group_id}` | PUT | Update an existing group |
| `/Groups/{group_id}` | DELETE | Delete a group |
## SCIM Schema
This implementation follows the standard SCIM v2 schema with the following mappings:
### Users
- SCIM User ID → LiteLLM `user_id`
- SCIM User Email → LiteLLM `user_email`
- SCIM User Group Memberships → LiteLLM User-Team relationships
### Groups
- SCIM Group ID → LiteLLM `team_id`
- SCIM Group Display Name → LiteLLM `team_alias`
- SCIM Group Members → LiteLLM Team members list
## Configuration
To enable SCIM in your identity provider, use the full URL to the SCIM endpoint:
```
https://your-litellm-proxy-url/scim/v2
```
Most identity providers will require authentication. You should use a valid LiteLLM API key with administrative privileges.
## Features
- Full CRUD operations for users and groups
- Pagination support
- Basic filtering support
- Automatic synchronization of user-team relationships
- Proper status codes and error handling per SCIM specification
## Example Usage
### Listing Users
```
GET /scim/v2/Users?startIndex=1&count=10
```
### Creating a User
```json
POST /scim/v2/Users
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"userName": "john.doe@example.com",
"active": true,
"emails": [
{
"value": "john.doe@example.com",
"primary": true
}
]
}
```
### Adding a User to Groups
```json
PUT /scim/v2/Users/{user_id}
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"userName": "john.doe@example.com",
"active": true,
"emails": [
{
"value": "john.doe@example.com",
"primary": true
}
],
"groups": [
{
"value": "team-123",
"display": "Engineering Team"
}
]
}
```

View File

@@ -0,0 +1,177 @@
from typing import List, Union
from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLM_UserTable,
Member,
NewUserResponse,
)
from litellm.types.proxy.management_endpoints.scim_v2 import *
class ScimTransformations:
    """
    Helpers that convert LiteLLM user/team records into SCIM v2 resources.

    SCIM requires several display fields to be non-empty strings, so each
    transformation falls back to the DEFAULT_* placeholders below when the
    LiteLLM record has no usable value.
    """

    DEFAULT_SCIM_NAME = "Unknown User"
    DEFAULT_SCIM_FAMILY_NAME = "Unknown Family Name"
    DEFAULT_SCIM_DISPLAY_NAME = "Unknown Display Name"
    DEFAULT_SCIM_MEMBER_VALUE = "Unknown Member Value"

    @staticmethod
    async def transform_litellm_user_to_scim_user(
        user: Union[LiteLLM_UserTable, NewUserResponse],
    ) -> SCIMUser:
        """
        Convert a LiteLLM user record into a SCIM User resource.

        Raises:
            HTTPException: 500 if the database is not connected (team lookups
                are needed to populate the user's groups).
        """
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500, detail={"error": "No database connected"}
            )
        # Get user's teams/groups
        groups = []
        for team_id in user.teams or []:
            team = await prisma_client.db.litellm_teamtable.find_unique(
                where={"team_id": team_id}
            )
            if team:
                # BUGFIX: pydantic models always *have* team_alias (possibly
                # None), so a plain getattr default never fires. Use `or` so a
                # None/empty alias actually falls back to the team_id.
                team_alias = getattr(team, "team_alias", None) or team.team_id
                groups.append(SCIMUserGroup(value=team.team_id, display=team_alias))
        user_created_at = user.created_at.isoformat() if user.created_at else None
        user_updated_at = user.updated_at.isoformat() if user.updated_at else None
        emails = []
        # Only add email if it's a valid email address (contains @)
        # user_email can be a UUID when users are created without an email
        if user.user_email and "@" in user.user_email:
            emails.append(SCIMUserEmail(value=user.user_email, primary=True))
        return SCIMUser(
            schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
            id=user.user_id,
            userName=ScimTransformations._get_scim_user_name(user),
            displayName=ScimTransformations._get_scim_user_name(user),
            name=SCIMUserName(
                familyName=ScimTransformations._get_scim_family_name(user),
                givenName=ScimTransformations._get_scim_given_name(user),
            ),
            emails=emails,
            groups=groups,
            active=True,
            meta={
                "resourceType": "User",
                "created": user_created_at,
                "lastModified": user_updated_at,
            },
        )

    @staticmethod
    def _get_scim_user_name(user: Union[LiteLLM_UserTable, NewUserResponse]) -> str:
        """
        SCIM requires a display name with length > 0.
        We use the same userName and displayName for SCIM users.
        """
        if user.user_email and len(user.user_email) > 0:
            return user.user_email
        return ScimTransformations.DEFAULT_SCIM_DISPLAY_NAME

    @staticmethod
    def _get_scim_family_name(user: Union[LiteLLM_UserTable, NewUserResponse]) -> str:
        """
        SCIM requires a family name with length > 0.

        Preference order: SCIM-provided familyName (stored in metadata) >
        user_alias > placeholder.
        """
        metadata = user.metadata or {}
        if "scim_metadata" in metadata:
            scim_metadata: LiteLLM_UserScimMetadata = LiteLLM_UserScimMetadata(
                **metadata["scim_metadata"]
            )
            if scim_metadata.familyName and len(scim_metadata.familyName) > 0:
                return scim_metadata.familyName
        if user.user_alias and len(user.user_alias) > 0:
            return user.user_alias
        return ScimTransformations.DEFAULT_SCIM_FAMILY_NAME

    @staticmethod
    def _get_scim_given_name(user: Union[LiteLLM_UserTable, NewUserResponse]) -> str:
        """
        SCIM requires a given name with length > 0.

        Preference order: SCIM-provided givenName (stored in metadata) >
        user_alias > placeholder.
        """
        metadata = user.metadata or {}
        if "scim_metadata" in metadata:
            scim_metadata: LiteLLM_UserScimMetadata = LiteLLM_UserScimMetadata(
                **metadata["scim_metadata"]
            )
            if scim_metadata.givenName and len(scim_metadata.givenName) > 0:
                return scim_metadata.givenName
        if user.user_alias and len(user.user_alias) > 0:
            # Guarded above: user_alias is known non-empty here (the redundant
            # `or DEFAULT` fallback from the original was dead code).
            return user.user_alias
        return ScimTransformations.DEFAULT_SCIM_NAME

    @staticmethod
    async def transform_litellm_team_to_scim_group(
        team: Union[LiteLLM_TeamTable, dict],
    ) -> SCIMGroup:
        """
        Convert a LiteLLM team record (model or raw dict) into a SCIM Group.

        Raises:
            HTTPException: 500 if the database is not connected.
        """
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500, detail={"error": "No database connected"}
            )
        if isinstance(team, dict):
            team = LiteLLM_TeamTable(**team)
        # Get team members with proper display names
        scim_members: List[SCIMMember] = []
        for member in team.members_with_roles or []:
            if isinstance(member, dict):
                member = Member(**member)
            scim_members.append(
                SCIMMember(
                    value=ScimTransformations._get_scim_member_value(member),
                    display=ScimTransformations._get_scim_member_display(member),
                )
            )
        # BUGFIX: see transform_litellm_user_to_scim_user — fall back to
        # team_id when team_alias is None/empty, since SCIM displayName must
        # be a non-empty string.
        team_alias = getattr(team, "team_alias", None) or team.team_id
        team_created_at = team.created_at.isoformat() if team.created_at else None
        team_updated_at = team.updated_at.isoformat() if team.updated_at else None
        return SCIMGroup(
            schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
            id=team.team_id,
            displayName=team_alias,
            members=scim_members,
            meta={
                "resourceType": "Group",
                "created": team_created_at,
                "lastModified": team_updated_at,
            },
        )

    @staticmethod
    def _get_scim_member_value(member: Member) -> str:
        """
        Get the SCIM member value. Use user_email if available, otherwise use user_id.
        SCIM member value should be the unique identifier for the user.
        """
        if hasattr(member, "user_email") and member.user_email:
            return member.user_email
        elif hasattr(member, "user_id"):
            return member.user_id or ScimTransformations.DEFAULT_SCIM_MEMBER_VALUE
        return ScimTransformations.DEFAULT_SCIM_MEMBER_VALUE

    @staticmethod
    def _get_scim_member_display(member: Member) -> str:
        """
        Get the SCIM member display. Use user_email if available, otherwise use user_id.
        SCIM member display should be the display name for the user.
        """
        if hasattr(member, "user_email") and member.user_email:
            return member.user_email
        elif hasattr(member, "user_id"):
            return member.user_id or ScimTransformations.DEFAULT_SCIM_MEMBER_VALUE
        return ScimTransformations.DEFAULT_SCIM_MEMBER_VALUE

View File

@@ -0,0 +1,11 @@
"""
SSO (Single Sign-On) related modules for LiteLLM Proxy.
This package contains custom SSO implementations and utilities.
"""
from litellm.proxy.management_endpoints.sso.custom_microsoft_sso import (
CustomMicrosoftSSO,
)
__all__ = ["CustomMicrosoftSSO"]

View File

@@ -0,0 +1,94 @@
"""
Custom Microsoft SSO class that allows overriding default Microsoft endpoints.
This module provides a subclass of fastapi_sso's MicrosoftSSO that allows
custom authorization, token, and userinfo endpoints to be specified via environment
variables.
Environment Variables:
- MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL
- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL
- MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL
If these are not set, the default Microsoft endpoints are used.
"""
import os
from typing import List, Optional, Union
import pydantic
from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.microsoft import MicrosoftSSO
from litellm._logging import verbose_proxy_logger
class CustomMicrosoftSSO(MicrosoftSSO):
    """
    Microsoft SSO subclass whose OAuth endpoints can be overridden via
    environment variables.

    Recognized variables:
    - MICROSOFT_AUTHORIZATION_ENDPOINT
    - MICROSOFT_TOKEN_ENDPOINT
    - MICROSOFT_USERINFO_ENDPOINT

    Any variable left unset falls back to the standard Microsoft endpoint.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = None,
        allow_insecure_http: bool = False,
        scope: Optional[List[str]] = None,
        tenant: Optional[str] = None,
    ):
        # No extra state of our own — defer entirely to the upstream
        # MicrosoftSSO constructor.
        super().__init__(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=redirect_uri,
            allow_insecure_http=allow_insecure_http,
            scope=scope,
            tenant=tenant,
        )

    async def get_discovery_document(self) -> DiscoveryDocument:
        """
        Build the OIDC discovery document, preferring environment-variable
        overrides and falling back to the default Microsoft endpoints.
        """
        env_authorization = os.getenv("MICROSOFT_AUTHORIZATION_ENDPOINT", None)
        env_token = os.getenv("MICROSOFT_TOKEN_ENDPOINT", None)
        env_userinfo = os.getenv("MICROSOFT_USERINFO_ENDPOINT", None)
        # Default Microsoft endpoints, parameterized by tenant / Graph version.
        default_authorization = (
            f"https://login.microsoftonline.com/{self.tenant}/oauth2/v2.0/authorize"
        )
        default_token = (
            f"https://login.microsoftonline.com/{self.tenant}/oauth2/v2.0/token"
        )
        default_userinfo = f"https://graph.microsoft.com/{self.version}/me"
        authorization_endpoint = env_authorization or default_authorization
        token_endpoint = env_token or default_token
        userinfo_endpoint = env_userinfo or default_userinfo
        if env_authorization or env_token or env_userinfo:
            verbose_proxy_logger.debug(
                f"Using custom Microsoft SSO endpoints - "
                f"authorization: {authorization_endpoint}, "
                f"token: {token_endpoint}, "
                f"userinfo: {userinfo_endpoint}"
            )
        return DiscoveryDocument(
            authorization_endpoint=authorization_endpoint,
            token_endpoint=token_endpoint,
            userinfo_endpoint=userinfo_endpoint,
        )

View File

@@ -0,0 +1,27 @@
from typing import Dict, Union
from litellm.proxy._types import LitellmUserRoles
def check_is_admin_only_access(ui_access_mode: Union[str, Dict]) -> bool:
    """Return True only when the UI access mode is the literal string "admin_only"."""
    # Dict-based access modes (granular configs) never count as admin-only.
    return isinstance(ui_access_mode, str) and ui_access_mode == "admin_only"
def has_admin_ui_access(user_role: str) -> bool:
    """
    Check if the user has admin access to the UI.

    Returns:
        bool: True if user is 'proxy_admin' or 'proxy_admin_view_only', False otherwise.
    """
    admin_ui_roles = (
        LitellmUserRoles.PROXY_ADMIN.value,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
    )
    return user_role in admin_ui_roles

View File

@@ -0,0 +1,574 @@
"""
TAG MANAGEMENT
All /tag management endpoints
/tag/new
/tag/info
/tag/update
/tag/delete
/tag/list
"""
import asyncio
import json
from typing import TYPE_CHECKING, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
get_daily_activity,
)
from litellm.proxy.management_helpers.utils import handle_budget_for_entity
from litellm.types.tag_management import (
TagConfig,
TagDeleteRequest,
TagInfoRequest,
TagNewRequest,
TagUpdateRequest,
)
if TYPE_CHECKING:
from litellm import Router
from litellm.types.router import Deployment
router = APIRouter()
async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]:
"""Helper function to get model names from model IDs"""
try:
models = await prisma_client.db.litellm_proxymodeltable.find_many(
where={"model_id": {"in": model_ids}}
)
return {model.model_id: model.model_name for model in models}
except Exception as e:
verbose_proxy_logger.error(f"Error getting model names: {str(e)}")
return {}
async def get_deployments_by_model(
    model: str, llm_router: "Router"
) -> List["Deployment"]:
    """
    Resolve ``model`` — either a deployment id or a public model name — to the
    list of router deployments that serve it.
    """
    from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo

    # First interpretation: the string is a deployment/model id.
    by_id = llm_router.get_deployment(model_id=model)
    if by_id is not None:
        return [by_id]
    # Second interpretation: the string is a public model name.
    matching_entries = llm_router.get_model_list(model_name=model)
    if matching_entries is None:
        return []
    resolved: List["Deployment"] = []
    for entry in matching_entries:
        resolved.append(
            Deployment(
                model_name=entry["model_name"],
                litellm_params=LiteLLM_Params(**entry["litellm_params"]),  # type: ignore
                model_info=ModelInfo(**entry.get("model_info") or {}),
            )
        )
    return resolved
@router.post(
    "/tag/new",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_tag(
    tag: TagNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new tag.

    Parameters:
    - name: str - The name of the tag
    - description: Optional[str] - Description of what this tag represents
    - models: List[str] - List of either 'model_id' or 'model_name' allowed for this tag
    - budget_id: Optional[str] - The id for a budget (tpm/rpm/max budget) for the tag
    ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
    - max_budget: Optional[float] - Max budget for tag
    - tpm_limit: Optional[int] - Max tpm limit for tag
    - rpm_limit: Optional[int] - Max rpm limit for tag
    - max_parallel_requests: Optional[int] - Max parallel requests for tag
    - soft_budget: Optional[float] - Get a slack alert when this soft budget is reached
    - model_max_budget: Optional[dict] - Max budget for a specific model
    - budget_duration: Optional[str] - Frequency of resetting tag budget

    Raises:
    - 400 when a tag with the same name already exists
    - 500 on DB / router unavailability or unexpected errors
    """
    from litellm.proxy._types import CommonProxyErrors
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        prisma_client,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    if llm_router is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.no_llm_router.value
        )
    try:
        # Reject duplicates up-front so the caller gets a clean 400 rather
        # than a DB unique-constraint error.
        existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
            where={"tag_name": tag.name}
        )
        if existing_tag is not None:
            raise HTTPException(
                status_code=400, detail=f"Tag {tag.name} already exists"
            )

        # Handle budget creation/assignment using the shared helper (creates
        # a budget row when budget params were supplied without a budget_id).
        budget_id = await handle_budget_for_entity(
            data=tag,
            existing_budget_id=None,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            litellm_proxy_admin_name=litellm_proxy_admin_name,
        )

        # Resolve model ids -> display names for the stored model_info blob.
        model_info = await _get_model_names(prisma_client, tag.models or [])

        # Create new tag in database
        new_tag_record = await prisma_client.db.litellm_tagtable.create(
            data={
                "tag_name": tag.name,
                "description": tag.description,
                "models": tag.models or [],
                "model_info": json.dumps(model_info),
                "spend": 0.0,
                "budget_id": budget_id,
                "created_by": user_api_key_dict.user_id,
            }
        )

        # Attach the tag to every deployment of each referenced model so the
        # router can enforce tag-based routing.
        if tag.models:
            tasks = []
            for model in tag.models:
                deployments = await get_deployments_by_model(model, llm_router)
                tasks.extend(
                    [
                        _add_tag_to_deployment(
                            deployment=deployment,
                            tag=tag.name,
                        )
                        for deployment in deployments
                    ]
                )
            await asyncio.gather(*tasks)

        # Build response
        tag_config = TagConfig(
            name=new_tag_record.tag_name,
            description=new_tag_record.description,
            models=new_tag_record.models,
            model_info=model_info,
            created_at=new_tag_record.created_at.isoformat(),
            updated_at=new_tag_record.updated_at.isoformat(),
            created_by=new_tag_record.created_by,
        )
        return {
            "message": f"Tag {tag.name} created successfully",
            "tag": tag_config,
        }
    except HTTPException:
        # Preserve intentional status codes (e.g. the 400 duplicate-tag error)
        # instead of collapsing them into a generic 500 below.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error creating tag: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def _add_tag_to_deployment(deployment: "Deployment", tag: str):
    """
    Append `tag` to a deployment's litellm_params["tags"] in the DB.

    Reads the stored row and mutates only the parsed litellm_params JSON so
    encrypted fields inside it are written back untouched. Idempotent: a tag
    already present is not duplicated.

    Raises:
    - 404 when the deployment's model_id is not in the DB
    - 500 on unexpected errors
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        # Get current model from database to preserve encrypted fields
        db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": deployment.model_info.id}
        )
        if db_model is None:
            raise HTTPException(
                status_code=404,
                detail=f"Model {deployment.model_info.id} not found in database",
            )

        # Prisma usually returns litellm_params as an already-parsed dict,
        # but tolerate a raw JSON string as well.
        existing_params = db_model.litellm_params
        if isinstance(existing_params, str):
            existing_params = json.loads(existing_params)
        elif not isinstance(existing_params, dict):
            raise Exception(f"Unexpected litellm_params type: {type(existing_params)}")

        # Add tag to tags array (preserve encryption of other fields)
        if "tags" not in existing_params:
            existing_params["tags"] = []
        if tag not in existing_params["tags"]:
            existing_params["tags"].append(tag)

        # Update database with modified params (keeps encrypted fields encrypted)
        await prisma_client.db.litellm_proxymodeltable.update(
            where={"model_id": deployment.model_info.id},
            data={"litellm_params": json.dumps(existing_params)},
        )
    except HTTPException:
        # Keep the 404 (model not found) instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error adding tag to deployment: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post(
    "/tag/update",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_tag(
    tag: TagUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing tag.

    Parameters:
    - name: str - The name of the tag to update
    - description: Optional[str] - Updated description
    - models: List[str] - Updated list of allowed LLM models
    - budget_id: Optional[str] - The id for a budget to associate with the tag
    ### BUDGET UPDATE PARAMS ###
    - max_budget: Optional[float] - Max budget for tag
    - tpm_limit: Optional[int] - Max tpm limit for tag
    - rpm_limit: Optional[int] - Max rpm limit for tag
    - max_parallel_requests: Optional[int] - Max parallel requests for tag
    - soft_budget: Optional[float] - Get a slack alert when this soft budget is reached
    - model_max_budget: Optional[dict] - Max budget for a specific model
    - budget_duration: Optional[str] - Frequency of resetting tag budget

    Raises:
    - 404 when the tag does not exist
    - 500 on DB unavailability or unexpected errors
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        # Check if tag exists
        existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
            where={"tag_name": tag.name}
        )
        if existing_tag is None:
            raise HTTPException(status_code=404, detail=f"Tag {tag.name} not found")

        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Handle budget updates using the shared helper (may create, update,
        # or reuse a budget row).
        budget_id = await handle_budget_for_entity(
            data=tag,
            existing_budget_id=existing_tag.budget_id,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            litellm_proxy_admin_name=litellm_proxy_admin_name,
        )

        # Resolve model ids -> display names for the stored model_info blob.
        model_info = await _get_model_names(prisma_client, tag.models or [])

        # Prepare update data
        update_data = {
            "description": tag.description,
            "models": tag.models or [],
            "model_info": json.dumps(model_info),
        }
        # Only touch budget_id when the helper changed it, so an unchanged
        # budget association is not rewritten.
        if budget_id != existing_tag.budget_id:
            update_data["budget_id"] = budget_id

        # Update tag in database
        updated_tag_record = await prisma_client.db.litellm_tagtable.update(
            where={"tag_name": tag.name},
            data=update_data,
        )

        # Build response
        tag_config = TagConfig(
            name=updated_tag_record.tag_name,
            description=updated_tag_record.description,
            models=updated_tag_record.models,
            model_info=model_info,
            created_at=updated_tag_record.created_at.isoformat(),
            updated_at=updated_tag_record.updated_at.isoformat(),
            created_by=updated_tag_record.created_by,
        )
        return {
            "message": f"Tag {tag.name} updated successfully",
            "tag": tag_config,
        }
    except HTTPException:
        # Preserve the 404 (tag not found) instead of collapsing it into the
        # generic 500 below.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating tag: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post(
    "/tag/info",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_tag(
    data: TagInfoRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get information about specific tags.

    Parameters:
    - names: List[str] - List of tag names to get information for

    Returns a mapping of tag_name -> tag details (including budget info when
    a budget is attached). Raises 404 when any requested tag does not exist.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        # Query tags from database with budget info
        tag_records = await prisma_client.db.litellm_tagtable.find_many(
            where={"tag_name": {"in": data.names}},
            include={"litellm_budget_table": True},
        )

        # Check if any requested tags don't exist
        found_tag_names = {tag.tag_name for tag in tag_records}
        missing_tags = [name for name in data.names if name not in found_tag_names]
        if missing_tags:
            raise HTTPException(
                status_code=404, detail=f"Tags not found: {missing_tags}"
            )

        # Build response keyed by tag name
        requested_tags = {}
        for tag_record in tag_records:
            # model_info may be stored as a JSON string or an already-parsed dict
            model_info = {}
            if tag_record.model_info:
                if isinstance(tag_record.model_info, str):
                    model_info = json.loads(tag_record.model_info)
                else:
                    model_info = tag_record.model_info
            tag_dict = {
                "name": tag_record.tag_name,
                "description": tag_record.description,
                "models": tag_record.models,
                "model_info": model_info,
                "created_at": tag_record.created_at.isoformat(),
                "updated_at": tag_record.updated_at.isoformat(),
                "created_by": tag_record.created_by,
            }
            # Add budget info if available
            if (
                hasattr(tag_record, "litellm_budget_table")
                and tag_record.litellm_budget_table
            ):
                tag_dict["litellm_budget_table"] = tag_record.litellm_budget_table
            requested_tags[tag_record.tag_name] = tag_dict
        return requested_tags
    except HTTPException:
        # Keep the 404 (missing tags) instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        # Log before converting to a 500 so operators get the stack trace,
        # consistent with the other tag endpoints.
        verbose_proxy_logger.exception(f"Error getting tag info: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get(
    "/tag/list",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def list_tags(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List all available tags with their budget information.

    Returns both explicitly created tags (rows in LiteLLM_TagTable) and
    "dynamic" tags that only exist because they were passed on requests and
    accumulated spend in LiteLLM_DailyTagSpend.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        ## QUERY STORED TAGS ##
        tag_records = await prisma_client.db.litellm_tagtable.find_many(
            include={"litellm_budget_table": True}
        )
        stored_tag_names = set()
        list_of_tags = []
        for tag_record in tag_records:
            stored_tag_names.add(tag_record.tag_name)
            # model_info may be stored as a JSON string or an already-parsed dict
            model_info = {}
            if tag_record.model_info:
                if isinstance(tag_record.model_info, str):
                    model_info = json.loads(tag_record.model_info)
                else:
                    model_info = tag_record.model_info
            tag_dict = {
                "name": tag_record.tag_name,
                "description": tag_record.description,
                "models": tag_record.models,
                "model_info": model_info,
                "created_at": tag_record.created_at.isoformat(),
                "updated_at": tag_record.updated_at.isoformat(),
                "created_by": tag_record.created_by,
            }
            # Add budget info if available
            if (
                hasattr(tag_record, "litellm_budget_table")
                and tag_record.litellm_budget_table
            ):
                tag_dict["litellm_budget_table"] = tag_record.litellm_budget_table
            list_of_tags.append(tag_dict)

        ## QUERY DYNAMIC TAGS ##
        # Use group_by instead of find_many(distinct=["tag"]).
        # Prisma's distinct fetches all columns for all rows and deduplicates
        # in application code, which is extremely slow on large tables.
        # See: https://www.prisma.io/docs/orm/prisma-client/queries/aggregation-grouping-summarizing#distinct-under-the-hood
        dynamic_tag_rows = await prisma_client.db.litellm_dailytagspend.group_by(
            by=["tag"],
            where={"tag": {"not": None}},
            # The old find_many(distinct=...) returned arbitrary timestamps from
            # whichever row Prisma happened to pick. MIN/MAX give more meaningful
            # values: earliest appearance and most recent activity.
            _min={"created_at": True},
            _max={"updated_at": True},
        )
        dynamic_tag_config = [
            {
                "name": row["tag"],
                "description": "This is just a spend tag that was passed dynamically in a request. It does not control any LLM models.",
                "models": None,
                "created_at": row["_min"]["created_at"].isoformat(),
                "updated_at": row["_max"]["updated_at"].isoformat(),
            }
            for row in dynamic_tag_rows
            # Stored tags take precedence; skip dynamic duplicates.
            if row["tag"] not in stored_tag_names
        ]
        return list_of_tags + dynamic_tag_config
    except HTTPException:
        raise
    except Exception as e:
        # Log before converting to a 500 so operators get the stack trace,
        # consistent with the other tag endpoints.
        verbose_proxy_logger.exception(f"Error listing tags: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post(
    "/tag/delete",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_tag(
    data: TagDeleteRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a tag.

    Parameters:
    - name: str - The name of the tag to delete

    Raises 404 when the tag does not exist.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        # Check if tag exists
        existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
            where={"tag_name": data.name}
        )
        if existing_tag is None:
            raise HTTPException(status_code=404, detail=f"Tag {data.name} not found")

        # Delete tag from database
        await prisma_client.db.litellm_tagtable.delete(where={"tag_name": data.name})
        return {"message": f"Tag {data.name} deleted successfully"}
    except HTTPException:
        # Keep the 404 (tag not found) instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        # Log before converting to a 500 so operators get the stack trace,
        # consistent with the other tag endpoints.
        verbose_proxy_logger.exception(f"Error deleting tag: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get(
    "/tag/daily/activity",
    response_model=SpendAnalyticsPaginatedResponse,
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_tag_daily_activity(
    tags: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
):
    """
    Get daily activity for specific tags or all tags.

    Args:
        tags (Optional[str]): Comma-separated list of tags to filter by. If not provided, returns data for all tags.
        start_date (Optional[str]): Start date for the activity period (YYYY-MM-DD).
        end_date (Optional[str]): End date for the activity period (YYYY-MM-DD).
        model (Optional[str]): Filter by model name.
        api_key (Optional[str]): Filter by API key.
        page (int): Page number for pagination.
        page_size (int): Number of items per page.

    Returns:
        SpendAnalyticsPaginatedResponse: Paginated response containing daily activity data.
    """
    from litellm.proxy.proxy_server import prisma_client

    # Fail fast with a clear error when the DB is unavailable, consistent
    # with every other tag endpoint, instead of passing None downstream.
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    # Convert comma-separated tags string to list if provided
    tag_list = tags.split(",") if tags else None

    return await get_daily_activity(
        prisma_client=prisma_client,
        table_name="litellm_dailytagspend",
        entity_id_field="tag",
        entity_id=tag_list,
        entity_metadata_field=None,
        start_date=start_date,
        end_date=end_date,
        model=model,
        api_key=api_key,
        page=page,
        page_size=page_size,
        # metadata_metrics_func=None because litellm_dailytagspend rows are
        # pre-aggregated per (date, tag, model, …) and have no request_id.
        # Deduplication across tags is therefore not possible at this level —
        # a request tagged with N tags contributes its spend to N separate rows,
        # so passing compute_tag_metadata_totals would double-count spend when
        # multiple tags are present. The panel is primarily used to inspect
        # individual tags, making this trade-off acceptable.
        metadata_metrics_func=None,
    )

View File

@@ -0,0 +1,346 @@
"""
Endpoints to control callbacks per team
Use this when each team should control its own callbacks
"""
import json
import traceback
from typing import List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
AddTeamCallback,
ProxyErrorTypes,
ProxyException,
TeamCallbackMetadata,
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
router = APIRouter()
@router.post(
    "/team/{team_id:path}/callback",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def add_team_callbacks(
    data: AddTeamCallback,
    http_request: Request,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Add a success/failure callback to a team

    Use this if you want different teams to have different success/failure callbacks

    Parameters:
    - callback_name (Literal["langfuse", "langsmith", "gcs"], required): The name of the callback to add
    - callback_type (Literal["success", "failure", "success_and_failure"], required): The type of callback to add. One of:
        - "success": Callback for successful LLM calls
        - "failure": Callback for failed LLM calls
        - "success_and_failure": Callback for both successful and failed LLM calls
    - callback_vars (StandardCallbackDynamicParams, required): A dictionary of variables to pass to the callback
        - langfuse_public_key: The public key for the Langfuse callback
        - langfuse_secret_key: The secret key for the Langfuse callback
        - langfuse_secret: The secret for the Langfuse callback
        - langfuse_host: The host for the Langfuse callback
        - gcs_bucket_name: The name of the GCS bucket
        - gcs_path_service_account: The path to the GCS service account
        - langsmith_api_key: The API key for the Langsmith callback
        - langsmith_project: The project for the Langsmith callback
        - langsmith_base_url: The base URL for the Langsmith callback

    Example curl:
    ```
    curl -X POST 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \
        -H 'Content-Type: application/json' \
        -H 'Authorization: Bearer sk-1234' \
        -d '{
            "callback_name": "langfuse",
            "callback_type": "success",
            "callback_vars": {"langfuse_public_key": "pk-lf-xxxx1", "langfuse_secret_key": "sk-xxxxx"}
        }'
    ```

    This means for the team where team_id = dbe2f686-a686-4896-864a-4c3924458709, all LLM calls will be logged to langfuse using the public key pk-lf-xxxx1 and the secret key sk-xxxxx
    """
    try:
        from litellm.proxy._types import CommonProxyErrors
        from litellm.proxy.proxy_server import prisma_client
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )
        # Check if team_id exists already
        _existing_team = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if _existing_team is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Team id = {team_id} does not exist. Please use a different team id."
                },
            )
        # Team callback settings are stored under the "logging" key of team
        # metadata as a list of AddTeamCallback-shaped dicts.
        # NOTE(review): assumes `metadata` is always a dict — a NULL metadata
        # column would raise AttributeError here; confirm the DB default.
        team_metadata = _existing_team.metadata
        team_callback_settings: List[dict] = team_metadata.get(
            "logging"
        )  # may be None on first use; normalized to [] just below
        if team_callback_settings is None or not isinstance(
            team_callback_settings, list
        ):
            team_callback_settings = []
        ## check if it already exists, for the same callback event
        for callback in team_callback_settings:
            if (
                callback.get("callback_name") == data.callback_name
                and callback.get("callback_type") == data.callback_type
            ):
                raise ProxyException(
                    message=f"callback_name = {data.callback_name} already exists in team_callback_settings, for team_id = {team_id} and event = {data.callback_type}",
                    code=status.HTTP_400_BAD_REQUEST,
                    type=ProxyErrorTypes.bad_request_error,
                    param="callback_name",
                )
        # Append the new callback and persist the whole metadata blob back to
        # the team row (metadata is a JSON column, hence the dumps).
        team_callback_settings.append(data.model_dump())
        team_metadata["logging"] = team_callback_settings
        team_metadata_json = json.dumps(team_metadata)  # serialize for the JSON column
        new_team_row = await prisma_client.db.litellm_teamtable.update(
            where={"team_id": team_id}, data={"metadata": team_metadata_json}  # type: ignore
        )
        return {
            "status": "success",
            "data": new_team_row,
        }
    except HTTPException as e:
        # Deliberate 4xx/5xx raised above: pass through untouched.
        raise e
    except ProxyException as e:
        raise e
    except Exception as e:
        # Unexpected failure: log with stack trace, surface as a 500.
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.add_team_callbacks(): Exception occured - {}".format(
                str(e)
            )
        )
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error.value,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
@router.post(
    "/team/{team_id}/disable_logging",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def disable_team_logging(
    http_request: Request,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Disable all logging callbacks for a team

    Clears the team's success/failure callback lists stored under
    metadata["callback_settings"]; any configured callback_vars are kept.

    Parameters:
    - team_id (str, required): The unique identifier for the team

    Example curl:
    ```
    curl -X POST 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/disable_logging' \
        -H 'Authorization: Bearer sk-1234'
    ```
    """
    try:
        from litellm.proxy.proxy_server import prisma_client
        if prisma_client is None:
            raise HTTPException(status_code=500, detail={"error": "No db connected"})
        # Check if team exists
        _existing_team = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if _existing_team is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Team id = {team_id} does not exist."},
            )
        # Callback settings live under metadata["callback_settings"].
        # NOTE(review): assumes `metadata` is always a dict — a NULL metadata
        # column would raise AttributeError here; confirm the DB default.
        team_metadata = _existing_team.metadata
        team_callback_settings = team_metadata.get("callback_settings", {})
        team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings)
        # Reset callbacks
        team_callback_settings_obj.success_callback = []
        team_callback_settings_obj.failure_callback = []
        # Update metadata
        team_metadata["callback_settings"] = team_callback_settings_obj.model_dump()
        team_metadata_json = json.dumps(team_metadata)
        # Update team in database
        updated_team = await prisma_client.db.litellm_teamtable.update(
            where={"team_id": team_id}, data={"metadata": team_metadata_json}  # type: ignore
        )
        if updated_team is None:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": f"Team id = {team_id} does not exist. Error updating team logging"
                },
            )
        return {
            "status": "success",
            "message": f"Logging disabled for team {team_id}",
            "data": {
                "team_id": updated_team.team_id,
                "success_callbacks": [],
                "failure_callbacks": [],
            },
        }
    except Exception as e:
        verbose_proxy_logger.error(
            f"litellm.proxy.proxy_server.disable_team_logging(): Exception occurred - {str(e)}"
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        # HTTPExceptions are re-wrapped as ProxyExceptions while preserving
        # their original status code; other errors become a 500.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type=ProxyErrorTypes.internal_server_error.value,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error.value,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
@router.get(
    "/team/{team_id:path}/callback",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def get_team_callbacks(
    http_request: Request,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get the success/failure callbacks and variables for a team

    Parameters:
    - team_id (str, required): The unique identifier for the team

    Example curl:
    ```
    curl -X GET 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \
        -H 'Authorization: Bearer sk-1234'
    ```

    This will return the callback settings for the team with id dbe2f686-a686-4896-864a-4c3924458709

    Returns {
            "status": "success",
            "data": {
                "team_id": team_id,
                "success_callbacks": team_callback_settings_obj.success_callback,
                "failure_callbacks": team_callback_settings_obj.failure_callback,
                "callback_vars": team_callback_settings_obj.callback_vars,
            },
    }
    """
    try:
        from litellm.proxy.proxy_server import prisma_client
        if prisma_client is None:
            raise HTTPException(status_code=500, detail={"error": "No db connected"})
        # Check if team_id exists
        _existing_team = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if _existing_team is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Team id = {team_id} does not exist."},
            )
        # Retrieve team callback settings from metadata.
        # NOTE(review): assumes `metadata` is always a dict — a NULL metadata
        # column would raise AttributeError here; confirm the DB default.
        team_metadata = _existing_team.metadata
        team_callback_settings = team_metadata.get("callback_settings", {})
        # Convert to TeamCallbackMetadata object for consistent structure
        # (missing keys become empty defaults).
        team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings)
        return {
            "status": "success",
            "data": {
                "team_id": team_id,
                "success_callbacks": team_callback_settings_obj.success_callback,
                "failure_callbacks": team_callback_settings_obj.failure_callback,
                "callback_vars": team_callback_settings_obj.callback_vars,
            },
        }
    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.get_team_callbacks(): Exception occurred - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        # HTTPExceptions are re-wrapped as ProxyExceptions while preserving
        # their original status code; other errors become a 500.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type=ProxyErrorTypes.internal_server_error.value,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error.value,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

View File

@@ -0,0 +1,605 @@
"""
TOOL POLICY MANAGEMENT
All /tool management endpoints
GET /v1/tool/list - List all discovered tools and their policies
GET /v1/tool/policy/options - List available input/output policy options with descriptions
GET /v1/tool/{tool_name} - Get a single tool's details
POST /v1/tool/policy - Update the input_policy / output_policy for a tool
"""
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.tool_management import (
LiteLLM_ToolTableRow,
ToolDetailResponse,
ToolInputPolicy,
ToolListResponse,
ToolPolicyOption,
ToolPolicyOptionsResponse,
ToolPolicyUpdateRequest,
ToolPolicyUpdateResponse,
ToolUsageLogEntry,
ToolUsageLogsResponse,
)
router = APIRouter()
# Static catalogue of selectable tool policies, served unchanged by
# GET /v1/tool/policy/options. input_policies govern what data a tool may be
# called with; output_policies govern how much downstream tools may trust
# this tool's output.
TOOL_POLICY_OPTIONS = ToolPolicyOptionsResponse(
    input_policies=[
        ToolPolicyOption(
            value="untrusted",
            label="Untrusted",
            description="Tool accepts any input, including data from untrusted tool outputs. Default for newly discovered tools.",
        ),
        ToolPolicyOption(
            value="trusted",
            label="Trusted",
            description="Tool requires trusted input. Blocked if the conversation contains output from any tool with output_policy=untrusted.",
        ),
        ToolPolicyOption(
            value="blocked",
            label="Blocked",
            description="Tool is completely prohibited. Any attempt to call it is rejected.",
        ),
    ],
    output_policies=[
        ToolPolicyOption(
            value="untrusted",
            label="Untrusted",
            description="Tool output may contain unsafe content (prompt injection, risky code). Downstream tools with input_policy=trusted will be blocked.",
        ),
        ToolPolicyOption(
            value="trusted",
            label="Trusted",
            description="Tool output is verified safe. Will not trigger trust-chain blocks on downstream tools.",
        ),
    ],
)
@router.get(
    "/v1/tool/policy/options",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolPolicyOptionsResponse,
)
async def get_tool_policy_options(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Return the available input and output policy options with descriptions.

    Static data — no DB call; always serves the module-level
    TOOL_POLICY_OPTIONS constant.
    """
    return TOOL_POLICY_OPTIONS
@router.get(
    "/v1/tool/list",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolListResponse,
)
async def list_tools(
    input_policy: Optional[ToolInputPolicy] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List every auto-discovered tool together with its current policies.

    Parameters:
    - input_policy: Optional filter — one of "trusted", "untrusted", "blocked"
    """
    from litellm.proxy.db.tool_registry_writer import list_tools as db_list_tools
    from litellm.proxy.proxy_server import prisma_client

    # A DB connection is required; fail fast otherwise.
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    try:
        discovered = await db_list_tools(
            prisma_client=prisma_client, input_policy=input_policy
        )
        return ToolListResponse(tools=discovered, total=len(discovered))
    except Exception as e:
        verbose_proxy_logger.exception("Error listing tools: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@router.get(
    "/v1/tool/{tool_name:path}/detail",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolDetailResponse,
)
async def get_tool_detail(
    tool_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get a single tool with its policy overrides (for UI detail view).

    Raises 404 when the tool has never been discovered.
    """
    from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool
    from litellm.proxy.db.tool_registry_writer import list_overrides_for_tool
    from litellm.proxy.proxy_server import prisma_client

    # A DB connection is required; fail fast otherwise.
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    try:
        tool_row = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name)
        if tool_row is None:
            raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
        policy_overrides = await list_overrides_for_tool(
            prisma_client=prisma_client, tool_name=tool_name
        )
        return ToolDetailResponse(tool=tool_row, overrides=policy_overrides)
    except HTTPException:
        # Pass the 404 through unchanged.
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error getting tool detail: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
def _input_snippet_for_tool_log(sl: Any, max_len: int = 200) -> Optional[str]:
"""Short snippet from messages or proxy_server_request for tool usage log row."""
if sl is None:
return None
messages = getattr(sl, "messages", None)
if messages is not None:
s = _snippet_str(messages, max_len)
if s:
return s
psr = getattr(sl, "proxy_server_request", None)
if not psr:
return None
if isinstance(psr, str):
import json
try:
psr = json.loads(psr)
except Exception:
return _snippet_str(psr, max_len)
if isinstance(psr, dict):
msgs = psr.get("messages")
if msgs is None and isinstance(psr.get("body"), dict):
msgs = psr["body"].get("messages")
s = _snippet_str(msgs, max_len)
if s:
return s
return _snippet_str(psr, max_len)
def _snippet_str(text: Any, max_len: int = 200) -> Optional[str]:
if text is None:
return None
if isinstance(text, str):
s = text
elif isinstance(text, list):
parts = []
for item in text:
if isinstance(item, dict) and "content" in item:
c = item["content"]
parts.append(c if isinstance(c, str) else str(c))
else:
parts.append(str(item))
s = " ".join(parts)
else:
s = str(text)
if not s or s == "{}":
return None
return (s[:max_len] + "...") if len(s) > max_len else s
@router.get(
    "/v1/tool/{tool_name:path}/logs",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolUsageLogsResponse,
)
async def get_tool_usage_logs(
    tool_name: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
    start_date: Optional[str] = Query(None, description="YYYY-MM-DD"),
    end_date: Optional[str] = Query(None, description="YYYY-MM-DD"),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Return paginated spend logs for requests that used this tool (from SpendLogToolIndex).

    Parameters:
    - tool_name: tool to look up in the spend-log tool index
    - page / page_size: pagination over the index rows (newest first)
    - start_date / end_date: optional inclusive UTC date bounds (YYYY-MM-DD);
      a malformed date is silently ignored rather than rejected with a 400
    """
    from litellm.proxy.proxy_server import prisma_client
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    try:
        where: dict = {"tool_name": tool_name}
        if start_date or end_date:
            # Build an optional start_time range filter; each bound covers the
            # whole UTC day (00:00:00 .. 23:59:59).
            start_time_filter: Optional[datetime] = None
            end_time_filter: Optional[datetime] = None
            if start_date:
                try:
                    start_time_filter = datetime.strptime(
                        start_date + "T00:00:00", "%Y-%m-%dT%H:%M:%S"
                    ).replace(tzinfo=timezone.utc)
                except ValueError:
                    pass
            if end_date:
                try:
                    end_time_filter = datetime.strptime(
                        end_date + "T23:59:59", "%Y-%m-%dT%H:%M:%S"
                    ).replace(tzinfo=timezone.utc)
                except ValueError:
                    pass
            if start_time_filter is not None or end_time_filter is not None:
                where["start_time"] = {}
                if start_time_filter is not None:
                    where["start_time"]["gte"] = start_time_filter
                if end_time_filter is not None:
                    where["start_time"]["lte"] = end_time_filter
        # Page over the index table first, then join to the spend logs by
        # request_id (the index has no payload of its own).
        total = await prisma_client.db.litellm_spendlogtoolindex.count(where=where)
        index_rows = await prisma_client.db.litellm_spendlogtoolindex.find_many(
            where=where,
            order={"start_time": "desc"},
            skip=(page - 1) * page_size,
            take=page_size,
        )
        request_ids = [r.request_id for r in index_rows]
        if not request_ids:
            return ToolUsageLogsResponse(
                logs=[], total=total, page=page, page_size=page_size
            )
        spend_logs = await prisma_client.db.litellm_spendlogs.find_many(
            where={"request_id": {"in": request_ids}}
        )
        log_by_id = {s.request_id: s for s in spend_logs}
        logs_out: List[ToolUsageLogEntry] = []
        # Iterate index rows (not spend_logs) to preserve the desc ordering.
        for r in index_rows:
            sl = log_by_id.get(r.request_id)
            if not sl:
                # Index row without a matching spend log (e.g. log pruned):
                # skip rather than emit a partial entry.
                continue
            ts = (
                sl.startTime.isoformat()
                if hasattr(sl.startTime, "isoformat")
                else str(sl.startTime)
            )
            logs_out.append(
                ToolUsageLogEntry(
                    id=sl.request_id,
                    timestamp=ts,
                    model=getattr(sl, "model", None) or None,
                    spend=getattr(sl, "spend", None),
                    total_tokens=getattr(sl, "total_tokens", None),
                    input_snippet=_input_snippet_for_tool_log(sl),
                )
            )
        return ToolUsageLogsResponse(
            logs=logs_out, total=total, page=page, page_size=page_size
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error getting tool usage logs: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@router.get(
    "/v1/tool/{tool_name:path}",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_ToolTableRow,
)
async def get_tool(
    tool_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get details for a single tool.

    Parameters:
    - tool_name: str - name of the tool to look up (path parameter; ":path"
      allows names containing slashes).

    Returns:
    - LiteLLM_ToolTableRow for the tool.

    Raises:
    - HTTPException 404 if the tool does not exist; 500 if the DB is not
      connected or the lookup fails.
    """
    # Imported lazily to avoid circular imports at module load time.
    from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool
    from litellm.proxy.proxy_server import prisma_client
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    try:
        tool = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name)
        if tool is None:
            raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
        return tool
    except HTTPException:
        # Re-raise HTTP errors unchanged (e.g. the 404 above).
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error getting tool: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
async def _resolve_key_hash_to_object_permission_id(
    prisma_client: "PrismaClient",
    key_hash: str,
) -> Optional[str]:
    """Resolve key (hash or raw) to object_permission_id; create permission if key has none.

    A raw key (value containing "sk-") is hashed first; any other value is
    assumed to already be the stored token hash. If the key row has no object
    permission attached, a fresh LiteLLM_ObjectPermissionTable row is created
    and attached via a conditional update so that two concurrent callers
    cannot attach two different permissions to the same key.

    Returns:
    - the key's object_permission_id, or None when the input is empty or no
      matching key row exists.
    """
    from litellm.proxy.proxy_server import hash_token
    # Raw keys contain "sk-"; stored token hashes do not, so only raw values
    # are hashed here.
    hashed = key_hash if "sk-" not in (key_hash or "") else hash_token(key_hash)
    if not hashed:
        return None
    row = await prisma_client.db.litellm_verificationtoken.find_unique(
        where={"token": hashed}
    )
    if row is None:
        return None
    op_id = getattr(row, "object_permission_id", None)
    if op_id:
        return op_id
    # No permission attached yet: create one, then attach it only if the key
    # STILL has none (the `object_permission_id: None` filter guards against a
    # concurrent attach).
    new_id = str(uuid.uuid4())
    await prisma_client.db.litellm_objectpermissiontable.create(
        data={"object_permission_id": new_id, "blocked_tools": []}
    )
    updated_count = await prisma_client.db.litellm_verificationtoken.update_many(
        where={"token": hashed, "object_permission_id": None},
        data={"object_permission_id": new_id},
    )
    if updated_count == 0:
        # Lost the race: another writer attached a permission first. Delete the
        # orphaned row we created and return whatever is attached now.
        await prisma_client.db.litellm_objectpermissiontable.delete(
            where={"object_permission_id": new_id}
        )
        row = await prisma_client.db.litellm_verificationtoken.find_unique(
            where={"token": hashed}
        )
        return getattr(row, "object_permission_id", None) if row else None
    return new_id
async def _resolve_team_id_to_object_permission_id(
    prisma_client: "PrismaClient",
    team_id: str,
) -> Optional[str]:
    """Resolve team_id to object_permission_id; create permission if team has none.

    Mirrors _resolve_key_hash_to_object_permission_id: if the team row has no
    object permission, one is created and attached with a conditional update
    (and rolled back if a concurrent writer attached one first).

    Returns:
    - the team's object_permission_id, or None when team_id is blank or no
      matching team row exists.
    """
    if not team_id or not team_id.strip():
        return None
    team_id_clean = team_id.strip()
    row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id_clean},
        select={"object_permission_id": True},
    )
    if row is None:
        return None
    op_id = getattr(row, "object_permission_id", None)
    if op_id:
        return op_id
    # No permission attached yet: create one and attach it only if the team
    # still has none (guards against a concurrent attach).
    new_id = str(uuid.uuid4())
    await prisma_client.db.litellm_objectpermissiontable.create(
        data={"object_permission_id": new_id, "blocked_tools": []}
    )
    updated_count = await prisma_client.db.litellm_teamtable.update_many(
        where={"team_id": team_id_clean, "object_permission_id": None},
        data={"object_permission_id": new_id},
    )
    if updated_count == 0:
        # Lost the race: clean up our orphaned permission row and re-read the
        # one the winner attached.
        await prisma_client.db.litellm_objectpermissiontable.delete(
            where={"object_permission_id": new_id}
        )
        row = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id_clean},
            select={"object_permission_id": True},
        )
        return getattr(row, "object_permission_id", None) if row else None
    return new_id
@router.post(
    "/v1/tool/policy",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ToolPolicyUpdateResponse,
)
async def update_tool_policy(
    data: ToolPolicyUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Set the input_policy and/or output_policy for a tool (global), or block for a specific team/key (override).

    Parameters:
    - tool_name: str - The tool to update
    - input_policy: optional - "trusted" | "untrusted" | "blocked"
    - output_policy: optional - "trusted" | "untrusted"
    - team_id: optional - if set, create/update override for this team only
    - key_hash: optional - if set, create/update override for this key only

    NOTE(review): in the scoped (team/key) branch only input_policy ==
    "blocked" has any persistent effect; output_policy is echoed back in the
    response but not stored — confirm this is intended.
    """
    from litellm.proxy.db.tool_registry_writer import (
        add_tool_to_object_permission_blocked,
        get_tool_policy_registry,
        remove_tool_from_object_permission_blocked,
    )
    from litellm.proxy.db.tool_registry_writer import (
        update_tool_policy as db_update_tool_policy,
    )
    from litellm.proxy.proxy_server import prisma_client
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    try:
        # --- Scoped override path: block/unblock the tool for one team or key ---
        if data.team_id is not None or data.key_hash is not None:
            if data.team_id is not None and data.key_hash is not None:
                raise HTTPException(
                    status_code=400,
                    detail="Provide either team_id or key_hash, not both",
                )
            # Resolve the target's object-permission row (created on demand).
            if data.key_hash is not None:
                op_id = await _resolve_key_hash_to_object_permission_id(
                    prisma_client, data.key_hash
                )
            else:
                op_id = await _resolve_team_id_to_object_permission_id(
                    prisma_client, data.team_id or ""
                )
            if op_id is None:
                raise HTTPException(
                    status_code=404,
                    detail="Key or team not found for the given identifier",
                )
            # "blocked" adds the tool to the scope's blocklist; any other
            # input_policy value removes it.
            is_blocking = data.input_policy == "blocked"
            if is_blocking:
                ok = await add_tool_to_object_permission_blocked(
                    prisma_client=prisma_client,
                    object_permission_id=op_id,
                    tool_name=data.tool_name,
                )
            else:
                ok = await remove_tool_from_object_permission_blocked(
                    prisma_client=prisma_client,
                    object_permission_id=op_id,
                    tool_name=data.tool_name,
                )
            if not ok:
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to update policy override for tool '{data.tool_name}'",
                )
            # Refresh the in-memory policy registry so the change applies
            # immediately (only when the registry was initialized).
            registry = get_tool_policy_registry()
            if registry.is_initialized():
                await registry.sync_tool_policy_from_db(prisma_client)
            return ToolPolicyUpdateResponse(
                tool_name=data.tool_name,
                input_policy=data.input_policy,
                output_policy=data.output_policy,
                updated=True,
                team_id=data.team_id,
                key_hash=data.key_hash,
            )
        # --- Global path: update the tool's own policy row ---
        if data.input_policy is None and data.output_policy is None:
            raise HTTPException(
                status_code=400,
                detail="At least one of input_policy or output_policy must be provided",
            )
        updated = await db_update_tool_policy(
            prisma_client=prisma_client,
            tool_name=data.tool_name,
            updated_by=user_api_key_dict.user_id,
            input_policy=data.input_policy,
            output_policy=data.output_policy,
        )
        if updated is None:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to update policy for tool '{data.tool_name}'",
            )
        registry = get_tool_policy_registry()
        if registry.is_initialized():
            await registry.sync_tool_policy_from_db(prisma_client)
        return ToolPolicyUpdateResponse(
            tool_name=updated.tool_name,
            input_policy=updated.input_policy,
            output_policy=updated.output_policy,
            updated=True,
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error updating tool policy: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@router.delete(
    "/v1/tool/{tool_name:path}/overrides",
    tags=["tool management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_tool_policy_override(
    tool_name: str,
    team_id: Optional[str] = Query(
        None, description="Team ID of the override to remove"
    ),
    key_hash: Optional[str] = Query(
        None, description="Key hash of the override to remove"
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Remove a policy override for a tool. Specify the override by team_id or key_hash
    (exactly one required).

    Returns:
    - {"deleted": True, "tool_name": ...} on success.

    Raises:
    - HTTPException 400 when zero or both scope parameters are given, 404 when
      the scope or the override does not exist, 500 on DB errors.
    """
    from litellm.proxy.db.tool_registry_writer import (
        get_tool_policy_registry,
        remove_tool_from_object_permission_blocked,
    )
    from litellm.proxy.proxy_server import prisma_client
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    # Exactly one of team_id / key_hash must identify the override's scope.
    if team_id is None and key_hash is None:
        raise HTTPException(
            status_code=400,
            detail="At least one of team_id or key_hash is required to identify the override",
        )
    if team_id is not None and key_hash is not None:
        raise HTTPException(
            status_code=400,
            detail="Provide either team_id or key_hash, not both",
        )
    try:
        # Resolve the scope to its object-permission row (created on demand).
        if key_hash is not None:
            op_id = await _resolve_key_hash_to_object_permission_id(
                prisma_client, key_hash
            )
        else:
            op_id = await _resolve_team_id_to_object_permission_id(
                prisma_client, team_id or ""
            )
        if op_id is None:
            raise HTTPException(
                status_code=404,
                detail="Key or team not found for the given identifier",
            )
        deleted = await remove_tool_from_object_permission_blocked(
            prisma_client=prisma_client,
            object_permission_id=op_id,
            tool_name=tool_name,
        )
        if not deleted:
            raise HTTPException(
                status_code=404,
                detail=f"No override found for tool '{tool_name}' with the given scope",
            )
        # Refresh the in-memory policy registry so the removal takes effect
        # immediately.
        registry = get_tool_policy_registry()
        if registry.is_initialized():
            await registry.sync_tool_policy_from_db(prisma_client)
        return {"deleted": True, "tool_name": tool_name}
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception("Error deleting tool policy override: %s", e)
        raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,59 @@
"""
Types for the management endpoints
Might include fastapi/proxy requirements.txt related imports
"""
from typing import Any, Dict, List, Optional, cast
from fastapi_sso.sso.base import OpenID
from litellm.proxy._types import LitellmUserRoles
def is_valid_litellm_user_role(role_str: str) -> bool:
    """
    Check whether *role_str* names a valid LitellmUserRoles value.

    Comparison is case-insensitive (e.g. "proxy_admin" and "PROXY_ADMIN"
    both match); any non-string input simply yields False.
    """
    try:
        normalized = role_str.lower()
    except Exception:
        # Non-string input (None, int, ...) is not a valid role.
        return False
    # _value2member_map_ gives an O(1) value lookup on the enum.
    return normalized in LitellmUserRoles._value2member_map_
def get_litellm_user_role(role_str) -> Optional[LitellmUserRoles]:
    """
    Map a string (or list of strings) onto a LitellmUserRoles member.

    Some SSO providers (e.g. Keycloak) return roles as arrays like
    ["proxy_admin"]; in that case only the first element is considered.
    Returns None for empty lists, unknown roles, or non-string input.
    """
    try:
        candidate = role_str
        if isinstance(candidate, list):
            if not candidate:
                return None
            candidate = candidate[0]
        # Case-insensitive O(1) lookup via the enum's value map.
        member = LitellmUserRoles._value2member_map_.get(candidate.lower())
        return cast(Optional[LitellmUserRoles], member)
    except Exception:
        return None
class CustomOpenID(OpenID):
    """OpenID SSO payload extended with LiteLLM-specific fields."""
    # Team IDs resolved from the SSO provider's claims.
    team_ids: List[str]
    # Mapped LiteLLM role, when the provider supplied a recognizable one.
    user_role: Optional[LitellmUserRoles] = None
    # Raw additional provider claims preserved for downstream consumers.
    extra_fields: Optional[Dict[str, Any]] = None

View File

@@ -0,0 +1,9 @@
"""
Usage endpoints package.
Re-exports the router from endpoints module.
"""
from litellm.proxy.management_endpoints.usage_endpoints.endpoints import ( # noqa: F401
router,
)

View File

@@ -0,0 +1,578 @@
"""
AI Usage Chat - uses LLM tool calling to answer questions about
usage/spend data by querying the aggregated daily activity endpoints.
"""
import json
from datetime import date
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, cast
from typing_extensions import TypedDict
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import DEFAULT_COMPETITOR_DISCOVERY_MODEL
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Sampling temperature for every usage-chat LLM call (low → focused answers).
USAGE_AI_TEMPERATURE = 0.2
# Prisma model names of the aggregated daily-spend tables.
TABLE_DAILY_USER_SPEND = "litellm_dailyuserspend"
TABLE_DAILY_TEAM_SPEND = "litellm_dailyteamspend"
TABLE_DAILY_TAG_SPEND = "litellm_dailytagspend"
# Entity-id column per table.
ENTITY_FIELD_USER = "user_id"
ENTITY_FIELD_TEAM = "team_id"
ENTITY_FIELD_TAG = "tag"
# Page size for the non-aggregated daily-activity query path.
PAGINATED_PAGE_SIZE = 200
# Only the most recent messages are forwarded to the LLM.
MAX_CHAT_MESSAGES = 20
# Truncation limits for the ranked summaries fed back to the LLM.
TOP_N_MODELS = 15
TOP_N_PROVIDERS = 10
TOP_N_KEYS = 10
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
class SSEStatusEvent(TypedDict):
    # Transient progress message ("Thinking...", "Analyzing results...").
    type: Literal["status"]
    message: str


class SSEToolCallEvent(TypedDict, total=False):
    # Tool-invocation lifecycle event; `status` moves running → complete/error.
    type: Literal["tool_call"]
    tool_name: str
    tool_label: str
    arguments: Dict[str, str]
    status: Literal["running", "complete", "error"]
    error: str


class SSEChunkEvent(TypedDict):
    # One streamed fragment of the assistant's final answer.
    type: Literal["chunk"]
    content: str


class SSEDoneEvent(TypedDict):
    # Terminal event: the stream finished normally.
    type: Literal["done"]


class SSEErrorEvent(TypedDict):
    # Fatal error; the stream ends after this event.
    type: Literal["error"]
    message: str


# Union of every event shape emitted over the SSE stream (PEP 604 syntax,
# requires Python 3.10+ at import time).
SSEEvent = (
    SSEStatusEvent | SSEToolCallEvent | SSEChunkEvent | SSEDoneEvent | SSEErrorEvent
)


class ToolHandler(TypedDict):
    # Async callable returning the raw JSON data for the tool.
    fetch: Callable[..., Any]
    # Converts the raw JSON into concise text the LLM can reason over.
    summarise: Callable[[Dict[str, Any]], str]
    # Human-readable name used in SSE events and error messages.
    label: str
# ---------------------------------------------------------------------------
# Tool definitions (OpenAI function-calling schema)
# ---------------------------------------------------------------------------
# JSON-schema fragment shared by every tool: the required date-range params.
_DATE_PARAMS = {
    "start_date": {"type": "string", "description": "Start date in YYYY-MM-DD format"},
    "end_date": {"type": "string", "description": "End date in YYYY-MM-DD format"},
}

# Global/user-level usage tool — available to every caller.
_TOOL_USAGE = {
    "type": "function",
    "function": {
        "name": "get_usage_data",
        "description": (
            "Fetch aggregated global usage/spend data. Returns daily spend, "
            "token counts, request counts, and breakdowns by model, provider, "
            "and API key. Use for overall spend, top models, top providers."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                **_DATE_PARAMS,
                "user_id": {
                    "type": "string",
                    "description": "Optional user ID filter. Omit for global view.",
                },
            },
            "required": ["start_date", "end_date"],
        },
    },
}

# Team breakdown tool — exposed to admins only (see TOOLS_ADMIN below).
_TOOL_TEAM = {
    "type": "function",
    "function": {
        "name": "get_team_usage_data",
        "description": (
            "Fetch usage/spend data broken down by team. Use for questions "
            "like 'which team spends the most' or 'show me team X usage'."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                **_DATE_PARAMS,
                "team_ids": {
                    "type": "string",
                    "description": "Optional comma-separated team IDs. Omit for all teams.",
                },
            },
            "required": ["start_date", "end_date"],
        },
    },
}

# Tag breakdown tool — exposed to admins only.
_TOOL_TAG = {
    "type": "function",
    "function": {
        "name": "get_tag_usage_data",
        "description": (
            "Fetch usage/spend data broken down by tag. Tags are labels "
            "attached to requests (features, environments, credentials)."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                **_DATE_PARAMS,
                "tags": {
                    "type": "string",
                    "description": "Optional comma-separated tag names. Omit for all tags.",
                },
            },
            "required": ["start_date", "end_date"],
        },
    },
}

# Tool sets per role: non-admins only see the personal-usage tool.
TOOLS_BASE = [_TOOL_USAGE]
TOOLS_ADMIN = [_TOOL_USAGE, _TOOL_TEAM, _TOOL_TAG]
def get_tools_for_role(is_admin: bool) -> List[Dict[str, Any]]:
    """Return the tool list appropriate for the user's role."""
    if is_admin:
        return TOOLS_ADMIN
    return TOOLS_BASE
# Core behavioural instructions shared by both roles; tool descriptions and
# today's date are appended per request in _build_system_prompt.
_SYSTEM_PROMPT_BASE = (
    "You are an AI assistant embedded in the LiteLLM Usage dashboard. "
    "You help users understand their LLM API spend and usage data.\n\n"
    "ALWAYS call the appropriate tool(s) first to fetch data before answering. "
    "You may call multiple tools if the question spans different dimensions.\n\n"
    "Guidelines:\n"
    "- Be concise and specific. Use exact numbers from the data.\n"
    "- Format costs as dollar amounts (e.g. $12.34).\n"
    "- When comparing entities, show a ranked list.\n"
    "- If data is empty or no results found, say so clearly.\n"
    "- Do not hallucinate data — only use what the tools return.\n"
    "- Today's date will be provided below. Use it to interpret relative dates "
    "like 'this week', 'this month', 'last 7 days', etc."
)

# Tool lists appended to the system prompt, chosen by role.
_TOOL_DESCRIPTIONS_ADMIN = (
    "You have access to these tools:\n"
    "- `get_usage_data`: Global/user-level usage (spend, models, providers, API keys)\n"
    "- `get_team_usage_data`: Team-level usage breakdown\n"
    "- `get_tag_usage_data`: Tag-level usage breakdown\n\n"
)
_TOOL_DESCRIPTIONS_BASE = (
    "You have access to this tool:\n"
    "- `get_usage_data`: Your usage data (spend, models, providers, API keys)\n\n"
)
def _build_system_prompt(is_admin: bool) -> str:
    """Build role-appropriate system prompt with today's date appended."""
    # Admins get the team/tag tools described; regular users only their own usage.
    tool_desc = _TOOL_DESCRIPTIONS_ADMIN if is_admin else _TOOL_DESCRIPTIONS_BASE
    return (
        f"{_SYSTEM_PROMPT_BASE}\n\n{tool_desc}"
        f"Today's date: {date.today().isoformat()}"
    )


# keep a public reference for test assertions
SYSTEM_PROMPT = _SYSTEM_PROMPT_BASE
# ---------------------------------------------------------------------------
# Data fetchers
# ---------------------------------------------------------------------------
def _parse_csv_ids(raw: Optional[str]) -> Optional[List[str]]:
if not raw:
return None
return [t.strip() for t in raw.split(",") if t.strip()]
async def _query_activity(
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Any],
    start_date: str,
    end_date: str,
    *,
    use_aggregated: bool = False,
) -> SpendAnalyticsPaginatedResponse:
    """Shared helper that calls the daily activity query layer.

    Parameters:
    - table_name: Prisma model name of the daily-spend table to query.
    - entity_id_field: column identifying the entity (user_id / team_id / tag).
    - entity_id: optional entity filter (single value or list, per query layer).
    - start_date / end_date: date range in YYYY-MM-DD format.
    - use_aggregated: when True, use the pre-aggregated query path; otherwise
      return the first page (PAGINATED_PAGE_SIZE rows) of the paginated path.
    """
    # Imported lazily to avoid circular imports with the proxy server module.
    from litellm.proxy.management_endpoints.common_daily_activity import (
        get_daily_activity,
        get_daily_activity_aggregated,
    )
    from litellm.proxy.proxy_server import prisma_client
    if use_aggregated:
        return await get_daily_activity_aggregated(
            prisma_client=prisma_client,
            table_name=table_name,
            entity_id_field=entity_id_field,
            entity_id=entity_id,
            entity_metadata_field=None,
            start_date=start_date,
            end_date=end_date,
            model=None,
            api_key=None,
        )
    return await get_daily_activity(
        prisma_client=prisma_client,
        table_name=table_name,
        entity_id_field=entity_id_field,
        entity_id=entity_id,
        entity_metadata_field=None,
        start_date=start_date,
        end_date=end_date,
        model=None,
        api_key=None,
        page=1,
        page_size=PAGINATED_PAGE_SIZE,
    )
async def _fetch_usage_data(
    start_date: str, end_date: str, user_id: Optional[str] = None
) -> Dict[str, Any]:
    """Fetch aggregated global (or per-user) usage data as plain JSON."""
    aggregated = await _query_activity(
        TABLE_DAILY_USER_SPEND,
        ENTITY_FIELD_USER,
        user_id,
        start_date,
        end_date,
        use_aggregated=True,
    )
    return aggregated.model_dump(mode="json")
async def _fetch_team_usage_data(
    start_date: str, end_date: str, team_ids: Optional[str] = None
) -> Dict[str, Any]:
    """Fetch per-team usage data as plain JSON; team_ids is a CSV filter."""
    parsed_ids = _parse_csv_ids(team_ids)
    paginated = await _query_activity(
        TABLE_DAILY_TEAM_SPEND,
        ENTITY_FIELD_TEAM,
        parsed_ids,
        start_date,
        end_date,
    )
    return paginated.model_dump(mode="json")
async def _fetch_tag_usage_data(
    start_date: str, end_date: str, tags: Optional[str] = None
) -> Dict[str, Any]:
    """Fetch per-tag usage data as plain JSON; tags is a CSV filter."""
    parsed_tags = _parse_csv_ids(tags)
    paginated = await _query_activity(
        TABLE_DAILY_TAG_SPEND,
        ENTITY_FIELD_TAG,
        parsed_tags,
        start_date,
        end_date,
    )
    return paginated.model_dump(mode="json")
# ---------------------------------------------------------------------------
# Summarisers — convert raw JSON to concise text the LLM can reason over
# ---------------------------------------------------------------------------
def _accumulate_breakdown(
results: List[Dict[str, Any]], dimension: str, fields: List[str]
) -> Dict[str, Dict[str, float]]:
"""Aggregate a single breakdown dimension across days."""
totals: Dict[str, Dict[str, float]] = {}
for day in results:
for key, entry in day.get("breakdown", {}).get(dimension, {}).items():
if key not in totals:
totals[key] = {f: 0.0 for f in fields}
m = entry.get("metrics", {})
for f in fields:
totals[key][f] += m.get(f, 0)
return totals
def _ranked_lines(
totals: Dict[str, Dict[str, float]],
fmt: Callable[[str, Dict[str, float]], str],
limit: int,
) -> List[str]:
"""Sort by spend descending, format each entry, and truncate."""
return [
fmt(name, vals)
for name, vals in sorted(totals.items(), key=lambda x: -x[1].get("spend", 0))[
:limit
]
]
def _summarise_usage_data(data: Dict[str, Any]) -> str:
    """Render aggregated usage JSON as concise text: overall totals plus
    ranked top-model and top-provider sections for the LLM to reason over."""
    meta = data.get("metadata", {})
    results = data.get("results", [])
    # Overall totals from the response metadata.
    header = (
        f"Total Spend: ${meta.get('total_spend', 0):.4f}\n"
        f"Total Requests: {meta.get('total_api_requests', 0)}\n"
        f"Successful: {meta.get('total_successful_requests', 0)} | "
        f"Failed: {meta.get('total_failed_requests', 0)}\n"
        f"Total Tokens: {meta.get('total_tokens', 0)}"
    )
    # Per-dimension totals accumulated across all days.
    models = _accumulate_breakdown(
        results, "models", ["spend", "api_requests", "total_tokens"]
    )
    providers = _accumulate_breakdown(results, "providers", ["spend", "api_requests"])
    model_lines = _ranked_lines(
        models,
        lambda n, d: f"  - {n}: ${d['spend']:.4f} ({int(d['api_requests'])} reqs, {int(d['total_tokens'])} tokens)",
        TOP_N_MODELS,
    )
    provider_lines = _ranked_lines(
        providers,
        lambda n, d: f"  - {n}: ${d['spend']:.4f} ({int(d['api_requests'])} reqs)",
        TOP_N_PROVIDERS,
    )
    sections = [header, ""]
    sections += ["Top Models by Spend:"] + (model_lines or ["  (no data)"]) + [""]
    sections += ["Top Providers by Spend:"] + (provider_lines or ["  (no data)"])
    return "\n".join(sections)
def _summarise_entity_data(data: Dict[str, Any], entity_label: str) -> str:
"""Summarise team/tag entity usage data."""
results = data.get("results", [])
if not results:
return f"No {entity_label} usage data found for the given date range."
totals: Dict[str, Dict[str, Any]] = {}
for day in results:
for eid, entry in day.get("breakdown", {}).get("entities", {}).items():
if eid not in totals:
alias = entry.get("metadata", {}).get("alias", eid)
totals[eid] = {"alias": alias, "spend": 0.0, "requests": 0, "tokens": 0}
m = entry.get("metrics", {})
totals[eid]["spend"] += m.get("spend", 0)
totals[eid]["requests"] += m.get("api_requests", 0)
totals[eid]["tokens"] += m.get("total_tokens", 0)
lines = [f"{entity_label} Usage ({len(totals)} {entity_label.lower()}s):", ""]
for eid, d in sorted(totals.items(), key=lambda x: -x[1]["spend"]):
label = d["alias"] if d["alias"] != eid else eid
lines.append(
f"- {label} (ID: {eid}): ${d['spend']:.4f} | "
f"{int(d['requests'])} reqs | {int(d['tokens'])} tokens"
)
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Tool dispatch registry
# ---------------------------------------------------------------------------
# Dispatch table: tool name → fetch/summarise pair. Keys must match the
# "name" fields of the tool schemas defined above.
TOOL_HANDLERS: Dict[str, ToolHandler] = {
    "get_usage_data": ToolHandler(
        fetch=_fetch_usage_data,
        summarise=_summarise_usage_data,
        label="global usage data",
    ),
    "get_team_usage_data": ToolHandler(
        fetch=_fetch_team_usage_data,
        summarise=lambda data: _summarise_entity_data(data, "Team"),
        label="team usage data",
    ),
    "get_tag_usage_data": ToolHandler(
        fetch=_fetch_tag_usage_data,
        summarise=lambda data: _summarise_entity_data(data, "Tag"),
        label="tag usage data",
    ),
}
# ---------------------------------------------------------------------------
# SSE streaming
# ---------------------------------------------------------------------------
def _sse(event: SSEEvent) -> str:
return f"data: {json.dumps(event)}\n\n"
def _resolve_fetch_kwargs(
fn_name: str,
fn_args: Dict[str, str],
user_id: Optional[str],
is_admin: bool,
) -> Dict[str, Any]:
"""Build keyword arguments for a tool's fetch function."""
start_date = fn_args.get("start_date", "")
end_date = fn_args.get("end_date", "")
if not start_date or not end_date:
raise ValueError("Missing required start_date or end_date from tool arguments")
kwargs: Dict[str, Any] = {"start_date": start_date, "end_date": end_date}
if fn_name == "get_usage_data":
if not is_admin:
kwargs["user_id"] = user_id
elif fn_args.get("user_id"):
kwargs["user_id"] = fn_args["user_id"]
elif fn_name == "get_team_usage_data" and fn_args.get("team_ids"):
kwargs["team_ids"] = fn_args["team_ids"]
elif fn_name == "get_tag_usage_data" and fn_args.get("tags"):
kwargs["tags"] = fn_args["tags"]
return kwargs
async def _execute_tool_call(
    handler: ToolHandler,
    fn_name: str,
    fn_args: Dict[str, str],
    user_id: Optional[str],
    is_admin: bool,
) -> str:
    """Run a single tool and return the summarised result text.

    Raises ValueError (via _resolve_fetch_kwargs) when the LLM omitted the
    required date range; DB/query errors propagate from the fetch callable.
    """
    kwargs = _resolve_fetch_kwargs(fn_name, fn_args, user_id, is_admin)
    raw_data = await handler["fetch"](**kwargs)
    return handler["summarise"](raw_data)
async def _process_tool_call(
    tc: Any,
    chat_messages: List[Dict[str, Any]],
    user_id: Optional[str],
    is_admin: bool,
) -> AsyncIterator[str]:
    """Execute a single tool call, yielding SSE events for status.

    Always appends a `role: "tool"` message to chat_messages (even on failure
    or an unknown tool name) so the follow-up completion request stays valid.
    """
    fn_name = tc.function.name
    fn_args = json.loads(tc.function.arguments)
    # Re-check role-based tool access at execution time: the LLM may emit
    # tool names outside the set it was advertised.
    allowed_names = {t["function"]["name"] for t in get_tools_for_role(is_admin)}
    handler = TOOL_HANDLERS.get(fn_name)
    if fn_name not in allowed_names or not handler:
        chat_messages.append(
            {
                "role": "tool",
                "tool_call_id": tc.id,
                "content": f"Tool not available: {fn_name}",
            }
        )
        return
    tool_event_base = {
        "type": "tool_call",
        "tool_name": fn_name,
        "tool_label": handler["label"],
        "arguments": fn_args,
    }
    yield _sse(cast(SSEToolCallEvent, {**tool_event_base, "status": "running"}))
    try:
        tool_result = await _execute_tool_call(
            handler, fn_name, fn_args, user_id, is_admin
        )
        yield _sse(cast(SSEToolCallEvent, {**tool_event_base, "status": "complete"}))
    except Exception as e:
        # Feed a generic error back to the LLM instead of aborting the stream.
        verbose_proxy_logger.error("Tool %s failed: %s", fn_name, e)
        tool_result = f"Error fetching {handler['label']}. Please try again."
        yield _sse(cast(SSEToolCallEvent, {**tool_event_base, "status": "error"}))
    chat_messages.append(
        {"role": "tool", "tool_call_id": tc.id, "content": tool_result}
    )
async def _stream_final_response(
    model: str, chat_messages: List[Dict[str, Any]]
) -> AsyncIterator[str]:
    """Stream the final LLM response after tool results are appended.

    Yields a status event, then one chunk event per non-empty content delta.
    """
    yield _sse({"type": "status", "message": "Analyzing results..."})
    response = await litellm.acompletion(
        model=model,
        messages=chat_messages,
        stream=True,
        temperature=USAGE_AI_TEMPERATURE,
    )
    async for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta:
            yield _sse({"type": "chunk", "content": delta})
async def stream_usage_ai_chat(
    messages: List[Dict[str, str]],
    model: Optional[str] = None,
    user_id: Optional[str] = None,
    is_admin: bool = False,
) -> AsyncIterator[str]:
    """Stream SSE events: status → tool_call → chunk → done.

    Parameters:
    - messages: prior user/assistant turns; only the last MAX_CHAT_MESSAGES
      are forwarded to the LLM.
    - model: optional model override; falls back to
      DEFAULT_COMPETITOR_DISCOVERY_MODEL when empty/blank.
    - user_id / is_admin: control tool availability and data scoping.

    Any unexpected failure is reported as a single `error` event rather than
    raising, so the HTTP stream always ends cleanly.
    """
    resolved_model = (model or "").strip() or DEFAULT_COMPETITOR_DISCOVERY_MODEL
    truncated = (
        messages[-MAX_CHAT_MESSAGES:] if len(messages) > MAX_CHAT_MESSAGES else messages
    )
    chat_messages: List[Dict[str, Any]] = [
        {"role": "system", "content": _build_system_prompt(is_admin)},
        *truncated,
    ]
    try:
        yield _sse({"type": "status", "message": "Thinking..."})
        # First, non-streaming completion to discover tool calls.
        tools = get_tools_for_role(is_admin)
        response = await litellm.acompletion(
            model=resolved_model,
            messages=chat_messages,
            tools=tools,
            temperature=USAGE_AI_TEMPERATURE,
        )
        choice = response.choices[0]  # type: ignore
        if not choice.message.tool_calls:
            # No tools requested: emit any direct answer and finish.
            if choice.message.content:
                yield _sse({"type": "chunk", "content": choice.message.content})
            yield _sse({"type": "done"})
            return
        # Record the assistant turn, run each requested tool, then stream the
        # model's final answer over the accumulated tool results.
        chat_messages.append(choice.message.model_dump())
        for tc in choice.message.tool_calls:
            async for event in _process_tool_call(tc, chat_messages, user_id, is_admin):
                yield event
        async for event in _stream_final_response(resolved_model, chat_messages):
            yield event
        yield _sse({"type": "done"})
    except Exception as e:
        verbose_proxy_logger.error("AI usage chat failed: %s", e)
        yield _sse(
            {
                "type": "error",
                "message": "An internal error occurred. Please try again.",
            }
        )

View File

@@ -0,0 +1,65 @@
"""
USAGE AI CHAT ENDPOINTS
/usage/ai/chat - Stream AI chat responses about usage data
"""
from typing import List, Literal, Optional
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
class ChatMessage(BaseModel):
    """One turn of the usage-chat conversation (system turns are not accepted)."""
    role: Literal["user", "assistant"]
    content: str


class UsageAIChatRequest(BaseModel):
    """Request body for POST /usage/ai/chat."""
    messages: List[ChatMessage] = Field(
        ..., description="Chat messages (user/assistant history)"
    )
    model: Optional[str] = Field(default=None, description="Model to use for AI chat")
@router.post(
    "/usage/ai/chat",
    tags=["Budget & Spend Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def usage_ai_chat(
    data: UsageAIChatRequest,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    AI chat about usage data. Streams SSE events with the AI response.
    The AI agent has access to tools that query aggregated daily activity data.

    Admin callers get team/tag breakdown tools in addition to their own usage
    tool; non-admins are scoped to their own user_id.
    """
    # Imported lazily to avoid circular imports at module load time.
    from litellm.proxy.management_endpoints.common_utils import (
        _user_has_admin_view,
    )
    from litellm.proxy.management_endpoints.usage_endpoints.ai_usage_chat import (
        stream_usage_ai_chat,
    )
    is_admin = _user_has_admin_view(user_api_key_dict)
    user_id = user_api_key_dict.user_id
    # Convert pydantic models to the plain dicts the chat layer expects.
    messages = [{"role": m.role, "content": m.content} for m in data.messages]
    return StreamingResponse(
        stream_usage_ai_chat(
            messages=messages,
            model=data.model,
            user_id=user_id,
            is_admin=is_admin,
        ),
        media_type="text/event-stream",
        # X-Accel-Buffering: no → disable nginx proxy buffering so SSE events
        # reach the client immediately.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )

View File

@@ -0,0 +1,775 @@
"""
User Agent Analytics Endpoints
This module provides optimized endpoints for tracking user agent activity metrics including:
- Daily Active Users (DAU) by tags for configurable number of days
- Weekly Active Users (WAU) by tags for configurable number of weeks
- Monthly Active Users (MAU) by tags for configurable number of months
- Summary analytics by tags
These endpoints use optimized single SQL queries with joins to efficiently calculate
user metrics from tag activity data and return time series for dashboard visualization.
"""
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
# Constants for analytics periods (window sizes for the time-series endpoints)
MAX_DAYS = 7  # Number of days to show in DAU analytics
MAX_WEEKS = 7  # Number of weeks to show in WAU analytics
MAX_MONTHS = 7  # Number of months to show in MAU analytics
MAX_TAGS = 250  # Maximum number of distinct tags to return
router = APIRouter()
class TagActiveUsersResponse(BaseModel):
    """Response for tag active users metrics"""

    tag: str
    active_users: int
    date: str  # The specific date or period identifier
    period_start: Optional[
        str
    ] = None  # For WAU/MAU, this will be the start of the period
    period_end: Optional[str] = None  # For WAU/MAU, this will be the end of the period


class ActiveUsersAnalyticsResponse(BaseModel):
    """Response for active users analytics"""

    results: List[TagActiveUsersResponse]


class TagSummaryMetrics(BaseModel):
    """Summary metrics for a tag"""

    tag: str
    unique_users: int
    total_requests: int
    successful_requests: int
    failed_requests: int
    total_tokens: int
    total_spend: float


class TagSummaryResponse(BaseModel):
    """Response for tag summary analytics"""

    results: List[TagSummaryMetrics]


class DistinctTagResponse(BaseModel):
    """Response for distinct user agent tags"""

    tag: str


class DistinctTagsResponse(BaseModel):
    """Response for all distinct user agent tags"""

    results: List[DistinctTagResponse]


class PerUserMetrics(BaseModel):
    """Metrics for individual user"""

    user_id: str
    user_email: Optional[str] = None
    user_agent: Optional[str] = None
    successful_requests: int = 0
    failed_requests: int = 0
    total_requests: int = 0
    total_tokens: int = 0
    spend: float = 0.0


class PerUserAnalyticsResponse(BaseModel):
    """Response for per-user analytics (paginated)"""

    results: List[PerUserMetrics]
    total_count: int
    page: int
    page_size: int
    total_pages: int
@router.get(
    "/tag/distinct",
    response_model=DistinctTagsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_distinct_user_agent_tags(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get all distinct user agent tags, up to MAX_TAGS (250) tags.

    This endpoint returns all unique user agent tags found in the database,
    sorted by frequency of usage.

    Returns:
        DistinctTagsResponse: List of distinct user agent tags
    """
    from litellm.proxy.proxy_server import prisma_client
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        # f-string interpolation here is safe: MAX_TAGS is a module constant,
        # not user input. Matches "User-Agent:..." tags plus legacy tags that
        # contain no colon at all.
        sql_query = f"""
        SELECT
            dts.tag,
            COUNT(*) as usage_count
        FROM "LiteLLM_DailyTagSpend" dts
        WHERE dts.tag LIKE 'User-Agent:%' OR dts.tag NOT LIKE '%:%'
        GROUP BY dts.tag
        ORDER BY usage_count DESC
        LIMIT {MAX_TAGS}
        """
        db_response = await prisma_client.db.query_raw(sql_query)
        results = [DistinctTagResponse(tag=row["tag"]) for row in db_response]
        return DistinctTagsResponse(results=results)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch distinct user agent tags: {str(e)}",
        )
@router.get(
    "/tag/dau",
    response_model=ActiveUsersAnalyticsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_daily_active_users(
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get Daily Active Users (DAU) by tags for the last {MAX_DAYS} days ending on UTC today + 1 day.

    Calculates unique users per tag for each of the last {MAX_DAYS} days with a
    single SQL query, joining daily tag spend rows to verification tokens to
    resolve user ids.

    Args:
        tag_filter: Optional substring filter on a single tag (legacy)
        tag_filters: Optional list of exact tags (takes precedence over tag_filter)

    Returns:
        ActiveUsersAnalyticsResponse: DAU data by tag for each of the last {MAX_DAYS} days

    Raises:
        HTTPException: 500 if the database is not connected or the query fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        from datetime import timezone

        # Window ends at the start of tomorrow (UTC) so today's rows are included.
        end_dt = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        ) + timedelta(days=1)
        end_date = end_dt.strftime("%Y-%m-%d")
        # Last MAX_DAYS days.
        start_dt = end_dt - timedelta(days=MAX_DAYS)
        start_date = start_dt.strftime("%Y-%m-%d")

        # Build the WHERE clause with positional SQL parameters; user-supplied
        # tag values are always bound as parameters, never interpolated.
        where_clause = (
            "WHERE dts.date >= $1 AND dts.date <= $2 AND vt.user_id IS NOT NULL"
        )
        params: List[str] = [start_date, end_date]
        if tag_filters:
            # Multiple exact-match tags (takes precedence over tag_filter).
            tag_conditions = []
            for tag in tag_filters:
                tag_conditions.append(f"dts.tag = ${len(params) + 1}")
                params.append(tag)
            where_clause += f" AND ({' OR '.join(tag_conditions)})"
        elif tag_filter:
            # Single substring match (legacy, case-insensitive). The placeholder
            # index is derived from len(params) instead of being hard-coded.
            where_clause += f" AND dts.tag ILIKE ${len(params) + 1}"
            params.append(f"%{tag_filter}%")

        sql_query = f"""
        SELECT
            dts.tag,
            dts.date,
            COUNT(DISTINCT vt.user_id) as active_users
        FROM "LiteLLM_DailyTagSpend" dts
        INNER JOIN "LiteLLM_VerificationToken" vt ON dts.api_key = vt.token
        {where_clause}
        GROUP BY dts.tag, dts.date
        ORDER BY dts.date DESC, active_users DESC
        """
        db_response = await prisma_client.db.query_raw(sql_query, *params)
        results = [
            TagActiveUsersResponse(
                tag=row["tag"], active_users=row["active_users"], date=row["date"]
            )
            for row in db_response
        ]
        return ActiveUsersAnalyticsResponse(results=results)
    except Exception as e:
        # Chain the original exception so the root cause survives in logs.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch DAU analytics: {str(e)}",
        ) from e
@router.get(
    "/tag/wau",
    response_model=ActiveUsersAnalyticsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_weekly_active_users(
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get Weekly Active Users (WAU) by tags for the last {MAX_WEEKS} weeks ending on UTC today + 1 day.

    Shows a week-by-week breakdown where Week 1 is the earliest week and the
    highest week number is the most recent week ending on UTC today + 1 day.
    Week labels include the week's start date, e.g. "Week 1 (Jan 1)".

    Args:
        tag_filter: Optional substring filter on a single tag (legacy)
        tag_filters: Optional list of exact tags (takes precedence over tag_filter)

    Returns:
        ActiveUsersAnalyticsResponse: WAU data by tag for each of the last {MAX_WEEKS} weeks

    Raises:
        HTTPException: 500 if the database is not connected or the query fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        from datetime import timezone

        # Window ends at the start of tomorrow (UTC) so today's rows are included.
        end_dt = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        ) + timedelta(days=1)
        end_date = end_dt.strftime("%Y-%m-%d")
        # Cover exactly MAX_WEEKS complete 7-day periods (MAX_WEEKS*7 - 1 days
        # before end_date, inclusive).
        start_dt = end_dt - timedelta(days=(MAX_WEEKS * 7 - 1))
        start_date = start_dt.strftime("%Y-%m-%d")

        # Build the WHERE clause with positional SQL parameters; user-supplied
        # tag values are always bound as parameters, never interpolated.
        where_clause = (
            "WHERE dts.date >= $1 AND dts.date <= $2 AND vt.user_id IS NOT NULL"
        )
        params: List[str] = [start_date, end_date]
        if tag_filters:
            # Multiple exact-match tags (takes precedence over tag_filter).
            tag_conditions = []
            for tag in tag_filters:
                tag_conditions.append(f"dts.tag = ${len(params) + 1}")
                params.append(tag)
            where_clause += f" AND ({' OR '.join(tag_conditions)})"
        elif tag_filter:
            # Single substring match (legacy, case-insensitive). The placeholder
            # index is derived from len(params) instead of being hard-coded.
            where_clause += f" AND dts.tag ILIKE ${len(params) + 1}"
            params.append(f"%{tag_filter}%")

        # Bucket rows into 7-day offsets from end_date. end_date is computed
        # server-side above, so interpolating it into the SQL text is safe.
        sql_query = f"""
        WITH weekly_data AS (
            SELECT
                dts.tag,
                dts.date,
                vt.user_id,
                -- Calculate week number (0 = Week 1 most recent, 1 = Week 2, etc.)
                FLOOR((DATE '{end_date}' - dts.date::date) / 7) as week_offset
            FROM "LiteLLM_DailyTagSpend" dts
            INNER JOIN "LiteLLM_VerificationToken" vt ON dts.api_key = vt.token
            {where_clause}
        )
        SELECT
            tag,
            COUNT(DISTINCT user_id) as active_users,
            -- Week identifier with month and day (Week 1 (earliest), Week 2, etc.)
            'Week ' || ({MAX_WEEKS} - week_offset)::text || ' (' ||
            TO_CHAR(DATE '{end_date}' - (week_offset * 7 || ' days')::interval - '6 days'::interval, 'Mon DD') || ')' as date,
            -- Calculate week start and end dates for each week
            (DATE '{end_date}' - (week_offset * 7 || ' days')::interval - '6 days'::interval)::text as period_start,
            (DATE '{end_date}' - (week_offset * 7 || ' days')::interval)::text as period_end,
            week_offset
        FROM weekly_data
        WHERE week_offset < {MAX_WEEKS}
        GROUP BY tag, week_offset
        ORDER BY week_offset DESC, active_users DESC
        """
        db_response = await prisma_client.db.query_raw(sql_query, *params)
        results = [
            TagActiveUsersResponse(
                tag=row["tag"],
                active_users=row["active_users"],
                # "Week 1 (Jan 15)", "Week 2 (Jan 8)", etc.
                date=row["date"],
                period_start=row["period_start"],
                period_end=row["period_end"],
            )
            for row in db_response
        ]
        return ActiveUsersAnalyticsResponse(results=results)
    except Exception as e:
        # Chain the original exception so the root cause survives in logs.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch WAU analytics: {str(e)}",
        ) from e
@router.get(
    "/tag/mau",
    response_model=ActiveUsersAnalyticsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_monthly_active_users(
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get Monthly Active Users (MAU) by tags for the last {MAX_MONTHS} months ending on UTC today + 1 day.

    A "month" is a 30-day period. Month 1 is the earliest period and the
    highest month number is the most recent period ending on UTC today + 1
    day. Month labels include the month name, e.g. "Month 1 (Nov)".

    Args:
        tag_filter: Optional substring filter on a single tag (legacy)
        tag_filters: Optional list of exact tags (takes precedence over tag_filter)

    Returns:
        ActiveUsersAnalyticsResponse: MAU data by tag for each of the last {MAX_MONTHS} months

    Raises:
        HTTPException: 500 if the database is not connected or the query fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        from datetime import timezone

        # Window ends at the start of tomorrow (UTC) so today's rows are included.
        end_dt = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        ) + timedelta(days=1)
        end_date = end_dt.strftime("%Y-%m-%d")
        # Cover exactly MAX_MONTHS complete 30-day periods (MAX_MONTHS*30 - 1
        # days before end_date, inclusive).
        start_dt = end_dt - timedelta(days=(MAX_MONTHS * 30 - 1))
        start_date = start_dt.strftime("%Y-%m-%d")

        # Build the WHERE clause with positional SQL parameters; user-supplied
        # tag values are always bound as parameters, never interpolated.
        where_clause = (
            "WHERE dts.date >= $1 AND dts.date <= $2 AND vt.user_id IS NOT NULL"
        )
        params: List[str] = [start_date, end_date]
        if tag_filters:
            # Multiple exact-match tags (takes precedence over tag_filter).
            tag_conditions = []
            for tag in tag_filters:
                tag_conditions.append(f"dts.tag = ${len(params) + 1}")
                params.append(tag)
            where_clause += f" AND ({' OR '.join(tag_conditions)})"
        elif tag_filter:
            # Single substring match (legacy, case-insensitive). The placeholder
            # index is derived from len(params) instead of being hard-coded.
            where_clause += f" AND dts.tag ILIKE ${len(params) + 1}"
            params.append(f"%{tag_filter}%")

        # Bucket rows into 30-day offsets from end_date. end_date is computed
        # server-side above, so interpolating it into the SQL text is safe.
        sql_query = f"""
        WITH monthly_data AS (
            SELECT
                dts.tag,
                dts.date,
                vt.user_id,
                -- Calculate month number (0 = Month 1 most recent, 1 = Month 2, etc.)
                FLOOR((DATE '{end_date}' - dts.date::date) / 30) as month_offset
            FROM "LiteLLM_DailyTagSpend" dts
            INNER JOIN "LiteLLM_VerificationToken" vt ON dts.api_key = vt.token
            {where_clause}
        )
        SELECT
            tag,
            COUNT(DISTINCT user_id) as active_users,
            -- Month identifier with month name (Month 1 (earliest), Month 2, etc.)
            'Month ' || ({MAX_MONTHS} - month_offset)::text || ' (' ||
            TO_CHAR(DATE '{end_date}' - (month_offset * 30 || ' days')::interval - '29 days'::interval, 'Mon') || ')' as date,
            -- Calculate month start and end dates for each month
            (DATE '{end_date}' - (month_offset * 30 || ' days')::interval - '29 days'::interval)::text as period_start,
            (DATE '{end_date}' - (month_offset * 30 || ' days')::interval)::text as period_end,
            month_offset
        FROM monthly_data
        WHERE month_offset < {MAX_MONTHS}
        GROUP BY tag, month_offset
        ORDER BY month_offset DESC, active_users DESC
        """
        db_response = await prisma_client.db.query_raw(sql_query, *params)
        results = [
            TagActiveUsersResponse(
                tag=row["tag"],
                active_users=row["active_users"],
                # "Month 1 (Jan)", "Month 2 (Dec)", etc.
                date=row["date"],
                period_start=row["period_start"],
                period_end=row["period_end"],
            )
            for row in db_response
        ]
        return ActiveUsersAnalyticsResponse(results=results)
    except Exception as e:
        # Chain the original exception so the root cause survives in logs.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch MAU analytics: {str(e)}",
        ) from e
@router.get(
    "/tag/summary",
    response_model=TagSummaryResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_tag_summary(
    start_date: str = Query(description="Start date in YYYY-MM-DD format"),
    end_date: str = Query(description="End date in YYYY-MM-DD format"),
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get summary analytics for tags including unique users, requests, tokens, and spend.

    Args:
        start_date: Start date for the analytics period (YYYY-MM-DD)
        end_date: End date for the analytics period (YYYY-MM-DD)
        tag_filter: Optional substring filter on a single tag (legacy)
        tag_filters: Optional list of exact tags (takes precedence over tag_filter)

    Returns:
        TagSummaryResponse: Summary analytics data by tag

    Raises:
        HTTPException: 400 on malformed dates; 500 if the database is not
            connected or the query fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Validate the date inputs in a narrow try block so that a ValueError
    # raised anywhere else in the handler is not misreported as a bad date.
    try:
        datetime.strptime(start_date, "%Y-%m-%d")
        datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid date format. Use YYYY-MM-DD: {str(e)}",
        ) from e

    try:
        # Build the WHERE clause with positional SQL parameters; user-supplied
        # values are always bound as parameters, never interpolated.
        where_clause = "WHERE dts.date >= $1 AND dts.date <= $2"
        params: List[str] = [start_date, end_date]
        if tag_filters:
            # Multiple exact-match tags (takes precedence over tag_filter).
            tag_conditions = []
            for tag in tag_filters:
                tag_conditions.append(f"dts.tag = ${len(params) + 1}")
                params.append(tag)
            where_clause += f" AND ({' OR '.join(tag_conditions)})"
        elif tag_filter:
            # Single substring match (legacy, case-insensitive). The placeholder
            # index is derived from len(params) instead of being hard-coded.
            where_clause += f" AND dts.tag ILIKE ${len(params) + 1}"
            params.append(f"%{tag_filter}%")

        sql_query = f"""
        SELECT
            dts.tag,
            COUNT(DISTINCT vt.user_id) as unique_users,
            SUM(dts.api_requests) as total_requests,
            SUM(dts.successful_requests) as successful_requests,
            SUM(dts.failed_requests) as failed_requests,
            SUM(dts.prompt_tokens + dts.completion_tokens) as total_tokens,
            SUM(dts.spend) as total_spend
        FROM "LiteLLM_DailyTagSpend" dts
        LEFT JOIN "LiteLLM_VerificationToken" vt ON dts.api_key = vt.token
        {where_clause}
        GROUP BY dts.tag
        ORDER BY total_requests DESC
        """
        db_response = await prisma_client.db.query_raw(sql_query, *params)
        results = [
            TagSummaryMetrics(
                tag=row["tag"],
                unique_users=row["unique_users"] or 0,
                total_requests=int(row["total_requests"] or 0),
                successful_requests=int(row["successful_requests"] or 0),
                failed_requests=int(row["failed_requests"] or 0),
                total_tokens=int(row["total_tokens"] or 0),
                total_spend=float(row["total_spend"] or 0.0),
            )
            for row in db_response
        ]
        return TagSummaryResponse(results=results)
    except Exception as e:
        # Chain the original exception so the root cause survives in logs.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch tag summary analytics: {str(e)}",
        ) from e
@router.get(
    "/tag/user-agent/per-user-analytics",
    response_model=PerUserAnalyticsResponse,
    tags=["tag management", "user agent analytics"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_per_user_analytics(
    tag_filter: Optional[str] = Query(
        default=None,
        description="Filter by specific tag (optional)",
    ),
    tag_filters: Optional[List[str]] = Query(
        default=None,
        description="Filter by multiple specific tags (optional, takes precedence over tag_filter)",
    ),
    page: int = Query(default=1, description="Page number for pagination", ge=1),
    page_size: int = Query(default=50, description="Items per page", ge=1, le=1000),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get per-user analytics including successful requests, tokens, and spend by individual users.

    Aggregates the last 30 days (ending UTC today + 1 day) of daily tag spend
    rows per user, resolving api_key -> user_id via the verification-token
    table and user_id -> user_email via the user table. Results are sorted by
    successful requests (descending) and paginated.

    Args:
        tag_filter: Optional substring filter on a single tag (legacy)
        tag_filters: Optional list of exact tags (takes precedence over tag_filter)
        page: Page number for pagination (1-based)
        page_size: Number of items per page

    Returns:
        PerUserAnalyticsResponse: Analytics data broken down by individual users for the last 30 days

    Raises:
        HTTPException: 500 if the database is not connected or a query fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        from datetime import timezone

        # Window ends at the start of tomorrow (UTC) so today's rows are included.
        end_dt = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        ) + timedelta(days=1)
        end_date = end_dt.strftime("%Y-%m-%d")
        start_dt = end_dt - timedelta(days=30)
        start_date = start_dt.strftime("%Y-%m-%d")

        # Prisma-style where clause over the 30-day window, with optional
        # tag filtering (exact-match list takes precedence over substring).
        where_clause: Dict[str, Any] = {"date": {"gte": start_date, "lte": end_date}}
        if tag_filters:
            where_clause["tag"] = {"in": tag_filters}
        elif tag_filter:
            where_clause["tag"] = {"contains": tag_filter}

        tag_records = await prisma_client.db.litellm_dailytagspend.find_many(
            where=where_clause
        )

        # Distinct api_keys seen in the window; bail out early if none.
        api_keys = {record.api_key for record in tag_records if record.api_key}
        if not api_keys:
            return PerUserAnalyticsResponse(
                results=[],
                total_count=0,
                page=page,
                page_size=page_size,
                total_pages=0,
            )

        # Resolve api_key -> user_id (keys without a user are dropped).
        api_key_records = await prisma_client.db.litellm_verificationtoken.find_many(
            where={"token": {"in": list(api_keys)}}
        )
        api_key_to_user_id = {
            record.token: record.user_id for record in api_key_records if record.user_id
        }

        # Resolve user_id -> user_email.
        user_ids = list(set(api_key_to_user_id.values()))
        user_records = await prisma_client.db.litellm_usertable.find_many(
            where={"user_id": {"in": user_ids}}
        )
        user_id_to_email = {
            record.user_id: record.user_email for record in user_records
        }

        # Aggregate metrics per user, keeping the first non-empty tag seen
        # as the user's user_agent label.
        user_metrics: Dict[str, PerUserMetrics] = {}
        for record in tag_records:
            if record.api_key not in api_key_to_user_id:
                continue
            user_id = api_key_to_user_id[record.api_key]
            tag = record.tag  # the full tag doubles as the user_agent label
            metrics = user_metrics.get(user_id)
            if metrics is None:
                metrics = PerUserMetrics(
                    user_id=user_id,
                    user_email=user_id_to_email.get(user_id),
                    user_agent=tag,
                )
                user_metrics[user_id] = metrics
            elif tag and not metrics.user_agent:
                metrics.user_agent = tag
            metrics.successful_requests += record.successful_requests or 0
            metrics.failed_requests += record.failed_requests or 0
            metrics.total_requests += record.api_requests or 0
            # total_tokens = prompt_tokens + completion_tokens
            metrics.total_tokens += int(
                (record.prompt_tokens or 0) + (record.completion_tokens or 0)
            )
            metrics.spend += record.spend or 0.0

        # Sort by successful requests (descending), then paginate.
        results = sorted(
            user_metrics.values(),
            key=lambda m: m.successful_requests,
            reverse=True,
        )
        total_count = len(results)
        total_pages = (total_count + page_size - 1) // page_size
        start_idx = (page - 1) * page_size
        paginated_results = results[start_idx : start_idx + page_size]
        return PerUserAnalyticsResponse(
            results=paginated_results,
            total_count=total_count,
            page=page,
            page_size=page_size,
            total_pages=total_pages,
        )
    except Exception as e:
        # Chain the original exception so the root cause survives in logs.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch per-user analytics: {str(e)}",
        ) from e