chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,176 @@
"""
Auth Checks for Organizations
"""
from typing import Dict, List, Optional, Tuple
from fastapi import status
from litellm.proxy._types import *
def organization_role_based_access_check(
    request_body: dict,
    user_object: Optional[LiteLLM_UserTable],
    route: str,
):
    """
    Role based access control checks only run if a user is part of an Organization

    Organization Checks:
    ONLY RUN IF user_object.organization_memberships is not None

    1. Only Proxy Admins can access /organization/new
    2. IF route is a LiteLLMRoutes.org_admin_only_routes, then check if user is an Org Admin for that organization

    Raises:
        ProxyException (401): when the user lacks the required role for the route.
    """
    # No user object -> nothing to enforce at the organization level here.
    if user_object is None:
        return
    passed_organization_id: Optional[str] = request_body.get("organization_id", None)

    # Check 1: /organization/new is restricted to proxy admins.
    if route == "/organization/new":
        if user_object.user_role != LitellmUserRoles.PROXY_ADMIN.value:
            raise ProxyException(
                message=f"Only proxy admins can create new organizations. You are {user_object.user_role}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )

    # Proxy admins bypass all remaining org-level checks.
    if user_object.user_role == LitellmUserRoles.PROXY_ADMIN.value:
        return

    # Checks if route is an Org Admin Only Route
    if route in LiteLLMRoutes.org_admin_only_routes.value:
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)
        if user_object.organization_memberships is None:
            raise ProxyException(
                message=f"Tried to access route={route} but you are not a member of any organization. Please contact the proxy admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )
        if passed_organization_id is None:
            raise ProxyException(
                message="Passed organization_id is None, please pass an organization_id in your request",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )
        # The user must hold ORG_ADMIN specifically in the org they passed.
        user_role: Optional[LitellmUserRoles] = _user_organization_role_mapping.get(
            passed_organization_id
        )
        if user_role is None:
            raise ProxyException(
                message=f"You do not have a role within the selected organization. Passed organization_id: {passed_organization_id}. Please contact the organization admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )
        if user_role != LitellmUserRoles.ORG_ADMIN.value:
            raise ProxyException(
                message=f"You do not have the required role to perform {route} in Organization {passed_organization_id}. Your role is {user_role} in Organization {passed_organization_id}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )
    elif route == "/team/new":
        # if user is part of multiple teams, then they need to specify the organization_id
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)
        if (
            user_object.organization_memberships is not None
            and len(user_object.organization_memberships) > 0
        ):
            if passed_organization_id is None:
                raise ProxyException(
                    message=f"Passed organization_id is None, please specify the organization_id in your request. You are part of multiple organizations: {_user_organizations}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="organization_id",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            _user_role_in_passed_org = _user_organization_role_mapping.get(
                passed_organization_id
            )
            # Only an org admin of the passed organization may create teams in it.
            if _user_role_in_passed_org != LitellmUserRoles.ORG_ADMIN.value:
                raise ProxyException(
                    message=f"You do not have the required role to call {route}. Your role is {_user_role_in_passed_org} in Organization {passed_organization_id}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="user_role",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
def get_user_organization_info(
    user_object: LiteLLM_UserTable,
) -> Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]:
    """
    Extract organization membership details from a user object.

    Args:
        user_object (LiteLLM_UserTable): The user object containing organization memberships.

    Returns:
        Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]: A tuple containing:
            - List of organization IDs the user is a member of
            - Dictionary mapping organization IDs to the user's role in each
    """
    org_ids: List[str] = []
    role_by_org: Dict[str, Optional[LitellmUserRoles]] = {}
    memberships = user_object.organization_memberships or []
    for membership in memberships:
        org_id = membership.organization_id
        if org_id is None:
            # memberships without an org id carry no usable information here
            continue
        org_ids.append(org_id)
        role_by_org[org_id] = membership.user_role  # type: ignore
    return org_ids, role_by_org
def _user_is_org_admin(
    request_data: dict,
    user_object: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Return True if the user holds the ORG_ADMIN role in any organization
    referenced by the request.

    Org IDs are read from both:
    - `organization_id` (singular string) — legacy callers
    - `organizations` (list of strings) — used by /user/new
    """
    if user_object is None or user_object.organization_memberships is None:
        return False

    # Gather every organization id the request refers to.
    requested_orgs: List[str] = []
    single_org = request_data.get("organization_id", None)
    if single_org is not None:
        requested_orgs.append(single_org)
    multi_orgs = request_data.get("organizations", None)
    if isinstance(multi_orgs, list):
        requested_orgs.extend(multi_orgs)

    if not requested_orgs:
        return False

    return any(
        membership.organization_id in requested_orgs
        and membership.user_role == LitellmUserRoles.ORG_ADMIN.value
        for membership in user_object.organization_memberships
    )

View File

@@ -0,0 +1,125 @@
"""
Handles Authentication Errors
"""
from typing import TYPE_CHECKING, Any, Optional, Union
from fastapi import HTTPException, Request, status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import _get_request_ip_address
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.types.services import ServiceTypes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class UserAPIKeyAuthExceptionHandler:
    """Centralizes handling of exceptions raised while authenticating a virtual key."""

    @staticmethod
    async def _handle_authentication_error(
        e: Exception,
        request: Request,
        request_data: dict,
        route: str,
        parent_otel_span: Optional[Span],
        api_key: str,
    ) -> UserAPIKeyAuth:
        """
        Handles Connection Errors when reading a Virtual Key from LiteLLM DB

        Use this if you don't want failed DB queries to block LLM API requests

        Reliability scenarios this covers:
        - DB is down and having an outage
        - Unable to read / recover a key from the DB

        Returns:
        - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True

        Raises:
        - Original Exception in all other cases
        """
        # Imported at call time — presumably to avoid a circular import with
        # proxy_server (verify).
        from litellm.proxy.proxy_server import (
            general_settings,
            litellm_proxy_admin_name,
            proxy_logging_obj,
        )

        if (
            PrismaDBExceptionHandler.should_allow_request_on_db_unavailable()
            and PrismaDBExceptionHandler.is_database_connection_error(e)
        ):
            # log this as a DB failure on prometheus
            proxy_logging_obj.service_logging_obj.service_failure_hook(
                service=ServiceTypes.DB,
                call_type="get_key_object",
                error=e,
                duration=0.0,
            )
            # Fail open: treat the caller as the proxy admin so the LLM request
            # can proceed despite the DB outage.
            return UserAPIKeyAuth(
                key_name="failed-to-connect-to-db",
                token="failed-to-connect-to-db",
                user_id=litellm_proxy_admin_name,
                request_route=route,
            )
        else:
            # raise the exception to the caller
            requester_ip = _get_request_ip_address(
                request=request,
                use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
            )
            verbose_proxy_logger.exception(
                "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format(
                    str(e),
                    requester_ip,
                ),
                extra={"requester_ip": requester_ip},
            )

            # Log this exception to OTEL, Datadog etc
            user_api_key_dict = UserAPIKeyAuth(
                parent_otel_span=parent_otel_span,
                api_key=api_key,
                request_route=route,
            )
            # Allow callbacks to transform the error response
            transformed_exception = await proxy_logging_obj.post_call_failure_hook(
                request_data=request_data,
                original_exception=e,
                user_api_key_dict=user_api_key_dict,
                error_type=ProxyErrorTypes.auth_error,
                route=route,
            )
            # Use transformed exception if callback returned one, otherwise use original
            if transformed_exception is not None:
                e = transformed_exception

            # Re-raise as ProxyException, choosing the most specific error type.
            if isinstance(e, litellm.BudgetExceededError):
                raise ProxyException(
                    message=e.message,
                    type=ProxyErrorTypes.budget_exceeded,
                    param=None,
                    code=400,
                )
            if isinstance(e, HTTPException):
                raise ProxyException(
                    message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                    type=ProxyErrorTypes.auth_error,
                    param=getattr(e, "param", "None"),
                    code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
                )
            elif isinstance(e, ProxyException):
                raise e
            raise ProxyException(
                message="Authentication Error, " + str(e),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=status.HTTP_401_UNAUTHORIZED,
            )

View File

@@ -0,0 +1,835 @@
import os
import re
import sys
from functools import lru_cache
from typing import Any, List, Optional, Tuple
from fastapi import HTTPException, Request, status
from litellm import Router, provider_list
from litellm._logging import verbose_proxy_logger
from litellm.constants import STANDARD_CUSTOMER_ID_HEADERS
from litellm.proxy._types import *
from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
def _get_request_ip_address(
    request: Request, use_x_forwarded_for: Optional[bool] = False
) -> Optional[str]:
    """
    Resolve the caller's IP address for a request.

    When `use_x_forwarded_for` is True and the header is present, the
    x-forwarded-for value is trusted; otherwise the direct client host is
    used, falling back to an empty string when no client is attached.
    """
    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
        return request.headers["x-forwarded-for"]
    if request.client is not None:
        return request.client.host
    return ""
def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """
    Check whether the request's client IP is permitted.

    Returns (is_allowed, client_ip). With no allow-list configured, every
    caller is allowed and no IP is resolved.
    """
    # No allow-list configured -> allow everyone.
    if allowed_ips is None:
        return True, None

    # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
    client_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )

    return (client_ip in allowed_ips), client_ip
def check_complete_credentials(request_body: dict) -> bool:
    """
    When 'api_base' is in the request body, decide whether the caller also
    supplied complete credentials. Prevents malicious attacks that point the
    proxy at an attacker-controlled api_base.
    """
    model = request_body.get("model")
    if model is None:
        return False

    # Providers with complex multi-field credentials are never treated as
    # complete — it is easier to craft a malicious request for them.
    complex_credential_providers = (
        "sagemaker",
        "bedrock",
        "vertex_ai",
        "vertex_ai_beta",
    )
    if any(provider in model for provider in complex_credential_providers):
        return False

    return "api_key" in request_body
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """
    Return True if request_body_value matches regex_str (anchored at the
    start of the string) or equals regex_str exactly.
    """
    matched = re.match(regex_str, request_body_value)
    return bool(matched) or regex_str == request_body_value
def _is_param_allowed(
    param: str,
    request_body_value: Any,
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
    """
    Return True when `param` is allow-listed, either as a plain string entry
    or — for 'api_base' only — via a dict entry whose 'api_base' value is a
    regex (or literal) matched against the request value.
    """
    if configurable_clientside_auth_params is None:
        return False

    for allowed in configurable_clientside_auth_params:
        if isinstance(allowed, str):
            if param == allowed:
                return True
        elif isinstance(allowed, Dict):
            # dict entries only constrain 'api_base'; the value is treated as a regex
            if param == "api_base" and check_regex_or_str_match(
                request_body_value=request_body_value,
                regex_str=allowed["api_base"],
            ):
                return True
    return False
def _allow_model_level_clientside_configurable_parameters(
    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
    """
    Check if `model` allows `param` to be configured client-side.

    Looks up the model group on the router; if not found, retries with the
    provider prefix (wildcard models). Delegates the final allow-list check
    to `_is_param_allowed`.
    """
    if llm_router is None:
        return False

    # check if model is set
    model_info = llm_router.get_model_group_info(model_group=model)
    if model_info is None:
        # wildcard fallback: try the provider prefix (e.g. "openai" from "openai/gpt-4")
        provider_prefix = model.split("/", 1)[0]
        if provider_prefix in provider_list:
            model_info = llm_router.get_model_group_info(model_group=provider_prefix)

    if model_info is None or model_info.configurable_clientside_auth_params is None:
        return False

    return _is_param_allowed(
        param=param,
        request_body_value=request_body_value,
        configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
    )
def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """
    Check if the request body is safe.

    A malicious user can set the api_base to their own domain and invoke
    POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997
    """
    for param in ("api_base", "base_url"):
        if param not in request_body:
            continue
        # caller supplied complete client-side credentials -> param is acceptable
        if check_complete_credentials(request_body=request_body):
            continue
        # global opt-in for client-side credentials
        if general_settings.get("allow_client_side_credentials") is True:
            return True
        # per-model opt-in via configurable_clientside_auth_params
        if (
            _allow_model_level_clientside_configurable_parameters(
                model=model,
                param=param,
                request_body_value=request_body[param],
                llm_router=llm_router,
            )
            is True
        ):
            return True
        raise ValueError(
            f"Rejected Request: {param} is not allowed in request body. "
            "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
            "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
        )
    return True
async def pre_db_read_auth_checks(
    request: Request,
    request_data: dict,
    route: str,
):
    """
    Auth checks that run before any DB reads.

    1. Checks if request size is under max_request_size_mb (if set)
    2. Check if request body is safe (example user has not set api_base in request body)
    3. Check if IP address is allowed (if set)
    4. Check if request route is an allowed route on the proxy (if set)

    Returns:
    - True

    Raises:
    - HTTPException if request fails initial auth checks
    """
    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

    # Check 1. request size
    await check_if_request_size_is_safe(request=request)

    # Check 2. Request body is safe (raises inside is_request_body_safe on violation)
    is_request_body_safe(
        request_body=request_data,
        general_settings=general_settings,
        llm_router=llm_router,
        model=request_data.get(
            "model", ""
        ),  # [TODO] use model passed in url as well (azure openai routes)
    )

    # Check 3. Check if IP address is allowed
    is_valid_ip, passed_in_ip = _check_valid_ip(
        allowed_ips=general_settings.get("allowed_ips", None),
        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        request=request,
    )

    if not is_valid_ip:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
        )

    # Check 4. Check if request route is an allowed route on the proxy
    if "allowed_routes" in general_settings:
        _allowed_routes = general_settings["allowed_routes"]
        if premium_user is not True:
            # enterprise-only feature: log for non-premium users
            verbose_proxy_logger.error(
                f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
            )
        if route not in _allowed_routes:
            verbose_proxy_logger.error(
                f"Route {route} not in allowed_routes={_allowed_routes}"
            )
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access forbidden: Route {route} not allowed",
            )
def route_in_additonal_public_routes(current_route: str):
    """
    Helper to check if the user defined public_routes on config.yaml

    Parameters:
    - current_route: str - the route the user is trying to call

    Returns:
    - bool - True if the route is defined in public_routes
    - bool - False if the route is not defined in public_routes

    Supports wildcard patterns (e.g., "/api/*" matches "/api/users", "/api/users/123")

    In order to use this the litellm config.yaml should have the following in general_settings:

    ```yaml
    general_settings:
        master_key: sk-1234
        public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate", "/api/*"]
    ```
    """
    from litellm.proxy.auth.route_checks import RouteChecks
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        # enterprise-only feature
        if premium_user is not True:
            return False
        if general_settings is None:
            return False

        public_routes = general_settings.get("public_routes", [])

        # exact match first, then wildcard patterns
        if current_route in public_routes:
            return True
        return any(
            RouteChecks._route_matches_wildcard_pattern(
                route=current_route, pattern=route_pattern
            )
            for route_pattern in public_routes
        )
    except Exception as e:
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False
def get_request_route(request: Request) -> str:
    """
    Helper to get the route from the request.

    Removes the base url from the path if set,
    e.g. `/genai/chat/completions` -> `/chat/completions`
    """
    try:
        full_path = request.url.path
        if hasattr(request, "base_url") and full_path.startswith(
            request.base_url.path
        ):
            # drop the server-root prefix but keep its trailing "/" as the
            # leading "/" of the returned route
            prefix_len = len(request.base_url.path) - 1
            return full_path[prefix_len:]
        return full_path
    except Exception as e:
        verbose_proxy_logger.debug(
            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
        )
        return request.url.path
@lru_cache(maxsize=256)
def normalize_request_route(route: str) -> str:
    """
    Normalize request routes by replacing dynamic path parameters with placeholders.

    This prevents high cardinality in Prometheus metrics by collapsing routes like:
    - /v1/responses/1234567890 -> /v1/responses/{response_id}
    - /v1/threads/thread_123 -> /v1/threads/{thread_id}

    Args:
        route: The request route path

    Returns:
        Normalized route with dynamic parameters replaced by placeholders

    Examples:
        >>> normalize_request_route("/v1/responses/abc123")
        '/v1/responses/{response_id}'
        >>> normalize_request_route("/v1/responses/abc123/cancel")
        '/v1/responses/{response_id}/cancel'
        >>> normalize_request_route("/chat/completions")
        '/chat/completions'
    """
    # Define patterns for routes with dynamic IDs
    # Format: (regex_pattern, replacement_template)
    # NOTE: order matters — more specific sub-resource patterns (e.g. /cancel,
    # /input_items) must precede their generic parent pattern, since the first
    # pattern that changes the route wins.
    patterns = [
        # Responses API - must come before generic patterns
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)$", r"\1/{response_id}"),
        (r"^(/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)$", r"\1/{response_id}"),
        # Threads API
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}\5/{step_id}",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/cancel)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/submit_tool_outputs)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)$", r"\1/{thread_id}\3"),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)/([^/]+)$",
            r"\1/{thread_id}\3/{message_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)$", r"\1/{thread_id}\3"),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)$", r"\1/{thread_id}"),
        # Vector Stores API
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)/([^/]+)$",
            r"\1/{vector_store_id}\3/{file_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)$",
            r"\1/{vector_store_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)/([^/]+)$",
            r"\1/{vector_store_id}\3/{batch_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)$",
            r"\1/{vector_store_id}\3",
        ),
        (r"^(/(?:openai/)?v1/vector_stores)/([^/]+)$", r"\1/{vector_store_id}"),
        # Assistants API
        (r"^(/(?:openai/)?v1/assistants)/([^/]+)$", r"\1/{assistant_id}"),
        # Files API
        (r"^(/(?:openai/)?v1/files)/([^/]+)(/content)$", r"\1/{file_id}\3"),
        (r"^(/(?:openai/)?v1/files)/([^/]+)$", r"\1/{file_id}"),
        # Batches API
        (r"^(/(?:openai/)?v1/batches)/([^/]+)(/cancel)$", r"\1/{batch_id}\3"),
        (r"^(/(?:openai/)?v1/batches)/([^/]+)$", r"\1/{batch_id}"),
        # Fine-tuning API
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/events)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/cancel)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/checkpoints)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)$", r"\1/{fine_tuning_job_id}"),
        # Models API
        (r"^(/(?:openai/)?v1/models)/([^/]+)$", r"\1/{model}"),
    ]

    # Apply patterns in order; first substitution that changes the route wins.
    for pattern, replacement in patterns:
        normalized = re.sub(pattern, replacement, route)
        if normalized != route:
            return normalized

    # Return original route if no pattern matched
    return route
async def check_if_request_size_is_safe(request: Request) -> bool:
    """
    Enterprise Only:
    - Checks if the request size is within the limit

    Args:
        request (Request): The incoming request.

    Returns:
        bool: True if the request size is within the limit

    Raises:
        ProxyException: If the request size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_request_size_mb = general_settings.get("max_request_size_mb", None)
    if max_request_size_mb is not None:
        # Check if premium user — feature is enterprise-only, so non-premium
        # deployments log a warning and skip enforcement.
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True

        # Get the request body
        content_length = request.headers.get("content-length")

        if content_length:
            # Fast path: use the Content-Length header without reading the body.
            header_size = int(content_length)
            header_size_mb = bytes_to_mb(bytes_value=header_size)
            verbose_proxy_logger.debug(
                f"content_length request size in MB={header_size_mb}"
            )

            if header_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {header_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )
        else:
            # If Content-Length is not available, read the body
            body = await request.body()
            body_size = len(body)
            request_size_mb = bytes_to_mb(bytes_value=body_size)

            verbose_proxy_logger.debug(
                f"request body request size in MB={request_size_mb}"
            )
            if request_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )

    return True
async def check_response_size_is_safe(response: Any) -> bool:
    """
    Enterprise Only:
    - Checks if the response size is within the limit

    Args:
        response (Any): The response to check.

    Returns:
        bool: True if the response size is within the limit

    Raises:
        ProxyException: If the response size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_response_size_mb = general_settings.get("max_response_size_mb", None)
    if max_response_size_mb is not None:
        # Check if premium user — enterprise-only feature; non-premium
        # deployments log a warning and skip enforcement.
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True

        # NOTE(review): sys.getsizeof is shallow (does not include referenced
        # objects), so this is an approximation of the true response size.
        response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
        verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
        if response_size_mb > max_response_size_mb:
            raise ProxyException(
                message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
                type=ProxyErrorTypes.bad_request_error.value,
                code=400,
                param="content-length",
            )

    return True
def bytes_to_mb(bytes_value: int) -> float:
    """
    Helper to convert bytes to MB (binary megabytes: 1 MB = 1024 * 1024 bytes).

    Args:
        bytes_value: size in bytes.

    Returns:
        float: size in megabytes.
    """
    return bytes_value / (1024 * 1024)
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
def get_key_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the model rpm limit for a given api key.

    Priority order (returns first found):
    1. Key metadata (model_rpm_limit)
    2. Key model_max_budget (rpm_limit per model)
    3. Team metadata (model_rpm_limit)
    """
    # 1. key metadata takes priority
    key_metadata = user_api_key_dict.metadata
    if key_metadata:
        from_metadata = key_metadata.get("model_rpm_limit")
        if from_metadata:
            return from_metadata

    # 2. per-model budgets configured on the key
    if user_api_key_dict.model_max_budget:
        per_model_limits: Dict[str, Any] = {
            model_name: budget["rpm_limit"]
            for model_name, budget in user_api_key_dict.model_max_budget.items()
            if isinstance(budget, dict) and budget.get("rpm_limit") is not None
        }
        if per_model_limits:
            return per_model_limits

    # 3. fall back to team metadata
    if user_api_key_dict.team_metadata:
        return user_api_key_dict.team_metadata.get("model_rpm_limit")
    return None
def get_key_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the model tpm limit for a given api key.

    Priority order (returns first found):
    1. Key metadata (model_tpm_limit)
    2. Key model_max_budget (tpm_limit per model)
    3. Team metadata (model_tpm_limit)
    """
    # 1. key metadata takes priority
    key_metadata = user_api_key_dict.metadata
    if key_metadata:
        from_metadata = key_metadata.get("model_tpm_limit")
        if from_metadata:
            return from_metadata

    # 2. per-model budgets configured on the key (mirrors the RPM helper)
    if user_api_key_dict.model_max_budget:
        per_model_limits: Dict[str, Any] = {
            model_name: budget["tpm_limit"]
            for model_name, budget in user_api_key_dict.model_max_budget.items()
            if isinstance(budget, dict) and budget.get("tpm_limit") is not None
        }
        if per_model_limits:
            return per_model_limits

    # 3. fall back to team metadata
    if user_api_key_dict.team_metadata:
        return user_api_key_dict.team_metadata.get("model_tpm_limit")
    return None
def get_model_rate_limit_from_metadata(
    user_api_key_dict: UserAPIKeyAuth,
    metadata_accessor_key: Literal["team_metadata", "organization_metadata"],
    rate_limit_key: Literal["model_rpm_limit", "model_tpm_limit"],
) -> Optional[Dict[str, int]]:
    """Read a per-model rate-limit dict from the named metadata attribute, if set."""
    metadata = getattr(user_api_key_dict, metadata_accessor_key)
    if not metadata:
        return None
    return metadata.get(rate_limit_key)
def get_team_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the team-level per-model RPM limit from team metadata, if configured.

    Delegates to `get_model_rate_limit_from_metadata` so team/org metadata
    lookups stay consistent across helpers.
    """
    return get_model_rate_limit_from_metadata(
        user_api_key_dict=user_api_key_dict,
        metadata_accessor_key="team_metadata",
        rate_limit_key="model_rpm_limit",
    )
def get_team_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the team-level per-model TPM limit from team metadata, if configured.

    Delegates to `get_model_rate_limit_from_metadata` so team/org metadata
    lookups stay consistent across helpers.
    """
    return get_model_rate_limit_from_metadata(
        user_api_key_dict=user_api_key_dict,
        metadata_accessor_key="team_metadata",
        rate_limit_key="model_tpm_limit",
    )
def is_pass_through_provider_route(route: str) -> bool:
    """Return True if the route targets a provider-specific pass-through endpoint."""
    provider_pass_through_prefixes = [
        "vertex-ai",
    ]

    # substring match: the prefix may appear anywhere in the route
    return any(prefix in route for prefix in provider_pass_through_prefixes)
def _has_user_setup_sso():
"""
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID or generic client ID and UI username environment variables.
Returns a boolean indicating whether SSO has been set up.
"""
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
sso_setup = (
(microsoft_client_id is not None)
or (google_client_id is not None)
or (generic_client_id is not None)
)
return sso_setup
def get_customer_user_header_from_mapping(user_id_mapping) -> Optional[str]:
    """Return the header_name mapped to the CUSTOMER role, if any (dict-based)."""
    if not user_id_mapping:
        return None

    # accept either a single mapping dict or a list of mapping dicts
    mappings = (
        user_id_mapping if isinstance(user_id_mapping, list) else [user_id_mapping]
    )
    for mapping in mappings:
        if not isinstance(mapping, dict):
            continue
        role = mapping.get("litellm_user_role")
        header_name = mapping.get("header_name")
        if role is None or not header_name:
            continue
        # case-insensitive comparison against the CUSTOMER role
        if str(role).lower() == str(LitellmUserRoles.CUSTOMER).lower():
            return header_name
    return None
def _get_customer_id_from_standard_headers(
request_headers: Optional[dict],
) -> Optional[str]:
"""
Check standard customer ID headers for a customer/end-user ID.
This enables tools like Claude Code to pass customer IDs via ANTHROPIC_CUSTOM_HEADERS.
No configuration required - these headers are always checked.
Args:
request_headers: The request headers dict
Returns:
The customer ID if found in standard headers, None otherwise
"""
if request_headers is None:
return None
for standard_header in STANDARD_CUSTOMER_ID_HEADERS:
for header_name, header_value in request_headers.items():
if header_name.lower() == standard_header.lower():
user_id_str = str(header_value) if header_value is not None else ""
if user_id_str.strip():
return user_id_str
return None
def get_end_user_id_from_request_body(
    request_body: dict, request_headers: Optional[dict] = None
) -> Optional[str]:
    """
    Resolve the end-user (customer) id for a request.

    Checks, in priority order: standard customer-id headers, configured
    user-header mappings / deprecated `user_header_name`, then several
    request-body fields (`user`, `litellm_metadata.user`, `metadata.user_id`,
    `safety_identifier`). Returns the id as a string, or None if not found.
    """
    # Import general_settings here to avoid potential circular import issues at module level
    # and to ensure it's fetched at runtime.
    from litellm.proxy.proxy_server import general_settings

    # Check 1: Standard customer ID headers (always checked, no configuration required)
    customer_id = _get_customer_id_from_standard_headers(
        request_headers=request_headers
    )
    if customer_id is not None:
        return customer_id

    # Check 2: Follow the user header mappings feature, if not found, then check for deprecated user_header_name (only if request_headers is provided)
    # User query: "system not respecting user_header_name property"
    # This implies the key in general_settings is 'user_header_name'.
    if request_headers is not None:
        custom_header_name_to_check: Optional[str] = None

        # Prefer user mappings (new behavior)
        user_id_mapping = general_settings.get("user_header_mappings", None)
        if user_id_mapping:
            custom_header_name_to_check = get_customer_user_header_from_mapping(
                user_id_mapping
            )

        # Fallback to deprecated user_header_name if mapping did not specify
        if not custom_header_name_to_check:
            user_id_header_config_key = "user_header_name"
            value = general_settings.get(user_id_header_config_key)
            if isinstance(value, str) and value.strip() != "":
                custom_header_name_to_check = value

        # If we have a header name to check, try to read it from request headers
        # (header name comparison is case-insensitive)
        if isinstance(custom_header_name_to_check, str):
            for header_name, header_value in request_headers.items():
                if header_name.lower() == custom_header_name_to_check.lower():
                    user_id_from_header = header_value
                    user_id_str = (
                        str(user_id_from_header)
                        if user_id_from_header is not None
                        else ""
                    )
                    # ignore empty / whitespace-only header values
                    if user_id_str.strip():
                        return user_id_str

    # Check 3: 'user' field in request_body (commonly OpenAI)
    if "user" in request_body and request_body["user"] is not None:
        user_from_body_user_field = request_body["user"]
        return str(user_from_body_user_field)

    # Check 4: 'litellm_metadata.user' in request_body (commonly Anthropic)
    litellm_metadata = request_body.get("litellm_metadata")
    if isinstance(litellm_metadata, dict):
        user_from_litellm_metadata = litellm_metadata.get("user")
        if user_from_litellm_metadata is not None:
            return str(user_from_litellm_metadata)

    # Check 5: 'metadata.user_id' in request_body (another common pattern)
    metadata_dict = request_body.get("metadata")
    if isinstance(metadata_dict, dict):
        user_id_from_metadata_field = metadata_dict.get("user_id")
        if user_id_from_metadata_field is not None:
            return str(user_id_from_metadata_field)

    # Check 6: 'safety_identifier' in request body (OpenAI Responses API parameter)
    # SECURITY NOTE: safety_identifier can be set by any caller in the request body.
    # Only use this for end-user identification in trusted environments where you control
    # the calling application. For untrusted callers, prefer using headers or server-side
    # middleware to set the end_user_id to prevent impersonation.
    if request_body.get("safety_identifier") is not None:
        user_from_body_user_field = request_body["safety_identifier"]
        return str(user_from_body_user_field)

    return None
def get_model_from_request(
    request_data: dict, route: str
) -> Optional[Union[str, List[str]]]:
    """Resolve the requested model from the request body or, failing that, the route.

    Returns a single model name, a list of names (for a comma-separated body
    value), or None when no model can be determined.
    """
    # Prefer the explicit body fields ('model', then 'target_model_names').
    raw_model = request_data.get("model") or request_data.get("target_model_names")
    if raw_model is not None:
        # A comma-separated value means multiple models were requested.
        parts = [part.strip() for part in raw_model.split(",")]
        return parts[0] if len(parts) == 1 else parts

    lowered_route = route.lower()

    # Vertex AI passthrough routes embed the model after '/models/' and allow
    # no trailing ':' requirement.
    # Example: /vertex_ai/v1/.../models/gemini-1.5-pro:generateContent
    if lowered_route.startswith("/vertex"):
        vertex_match = re.search(r"/models/([^:]+)", route)
        return vertex_match.group(1) if vertex_match else None

    # Azure-style deployment routes: /openai/deployments/{model}/*
    deployment_match = re.match(r"/openai/deployments/([^/]+)", route)
    if deployment_match:
        return deployment_match.group(1)

    # Google generateContent-style routes put the model in the path and allow
    # "/" inside the model id.
    # Examples:
    #   /v1beta/models/gemini-2.0-flash:generateContent
    #   /v1beta/models/bedrock/claude-sonnet-3.7:generateContent
    #   /models/custom/ns/model:streamGenerateContent
    for pattern in (r"/(?:v1beta|beta)/models/([^:]+):", r"^/models/([^:]+):"):
        google_match = re.search(pattern, route)
        if google_match:
            return google_match.group(1)

    return None
def abbreviate_api_key(api_key: str) -> str:
    """Return a redacted display form of *api_key*, keeping only its last 4 characters."""
    last_four = api_key[-4:]
    return "sk-..." + last_four

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,156 @@
"""
IP address utilities for MCP public/private access control.
Internal callers (private IPs) see all MCP servers.
External callers (public IPs) only see servers with available_on_public_internet=True.
"""
import ipaddress
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from litellm._logging import verbose_proxy_logger
from litellm.proxy.auth.auth_utils import _get_request_ip_address
class IPAddressUtils:
    """Static utilities for IP-based MCP access control."""

    # Fallback allow-list of internal networks: RFC 1918 private IPv4 ranges,
    # IPv4/IPv6 loopback, and IPv6 unique-local. Used whenever no valid
    # `mcp_internal_ip_ranges` setting is provided.
    _DEFAULT_INTERNAL_NETWORKS = [
        ipaddress.ip_network("10.0.0.0/8"),
        ipaddress.ip_network("172.16.0.0/12"),
        ipaddress.ip_network("192.168.0.0/16"),
        ipaddress.ip_network("127.0.0.0/8"),
        ipaddress.ip_network("::1/128"),
        ipaddress.ip_network("fc00::/7"),
    ]

    @staticmethod
    def parse_internal_networks(
        configured_ranges: Optional[List[str]],
    ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
        """Parse configured CIDR ranges into network objects, falling back to defaults."""
        if not configured_ranges:
            return IPAddressUtils._DEFAULT_INTERNAL_NETWORKS
        networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = []
        for cidr in configured_ranges:
            try:
                # strict=False tolerates host bits being set (e.g. "10.0.0.1/8")
                networks.append(ipaddress.ip_network(cidr, strict=False))
            except ValueError:
                verbose_proxy_logger.warning(
                    "Invalid CIDR in mcp_internal_ip_ranges: %s, skipping", cidr
                )
        # If every configured entry was invalid, fall back to the defaults.
        return networks if networks else IPAddressUtils._DEFAULT_INTERNAL_NETWORKS

    @staticmethod
    def parse_trusted_proxy_networks(
        configured_ranges: Optional[List[str]],
    ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
        """
        Parse trusted proxy CIDR ranges for XFF validation.
        Returns empty list if not configured (XFF will not be trusted).
        """
        if not configured_ranges:
            return []
        networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = []
        for cidr in configured_ranges:
            try:
                networks.append(ipaddress.ip_network(cidr, strict=False))
            except ValueError:
                verbose_proxy_logger.warning(
                    "Invalid CIDR in mcp_trusted_proxy_ranges: %s, skipping", cidr
                )
        # NOTE: unlike parse_internal_networks there is NO default fallback -
        # an empty result means X-Forwarded-For is never trusted.
        return networks

    @staticmethod
    def is_trusted_proxy(
        proxy_ip: Optional[str],
        trusted_networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]],
    ) -> bool:
        """Check if the direct connection IP is from a trusted proxy."""
        # Fails closed: missing IP, no configured networks, or an unparseable
        # address all mean "not trusted".
        if not proxy_ip or not trusted_networks:
            return False
        try:
            addr = ipaddress.ip_address(proxy_ip.strip())
            return any(addr in network for network in trusted_networks)
        except ValueError:
            return False

    @staticmethod
    def is_internal_ip(
        client_ip: Optional[str],
        internal_networks: Optional[
            List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]
        ] = None,
    ) -> bool:
        """
        Check if a client IP is from an internal/private network.
        Handles X-Forwarded-For comma chains (takes leftmost = original client).
        Fails closed: empty/invalid IPs are treated as external.
        """
        if not client_ip:
            return False
        # X-Forwarded-For may contain comma-separated chain; leftmost is original client
        if "," in client_ip:
            client_ip = client_ip.split(",")[0].strip()
        # NOTE: an explicitly-passed empty list also falls back to the defaults here.
        networks = internal_networks or IPAddressUtils._DEFAULT_INTERNAL_NETWORKS
        try:
            addr = ipaddress.ip_address(client_ip.strip())
        except ValueError:
            return False
        return any(addr in network for network in networks)

    @staticmethod
    def get_mcp_client_ip(
        request: Request,
        general_settings: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Extract client IP from a FastAPI request for MCP access control.
        Security: Only trusts X-Forwarded-For if:
        1. use_x_forwarded_for is enabled in settings
        2. The direct connection is from a trusted proxy (if mcp_trusted_proxy_ranges configured)
        Args:
            request: FastAPI request object
            general_settings: Optional settings dict. If not provided, imports from proxy_server.
        """
        if general_settings is None:
            try:
                # Imported lazily to avoid a circular import at module load time.
                from litellm.proxy.proxy_server import (
                    general_settings as proxy_general_settings,
                )

                general_settings = proxy_general_settings
            except ImportError:
                general_settings = {}
        # Handle case where general_settings is still None after import
        if general_settings is None:
            general_settings = {}
        use_xff = general_settings.get("use_x_forwarded_for", False)
        # If XFF is enabled, validate the request comes from a trusted proxy
        if use_xff and "x-forwarded-for" in request.headers:
            trusted_ranges = general_settings.get("mcp_trusted_proxy_ranges")
            if trusted_ranges:
                # Validate direct connection is from trusted proxy
                direct_ip = request.client.host if request.client else None
                trusted_networks = IPAddressUtils.parse_trusted_proxy_networks(
                    trusted_ranges
                )
                if not IPAddressUtils.is_trusted_proxy(direct_ip, trusted_networks):
                    # Untrusted source trying to set XFF - ignore XFF, use direct IP
                    verbose_proxy_logger.warning(
                        "XFF header from untrusted IP %s, ignoring", direct_ip
                    )
                    return direct_ip
        # When no trusted proxy ranges are configured, the XFF header (if
        # enabled) is honored as-is by the shared helper below.
        return _get_request_ip_address(request, use_x_forwarded_for=use_xff)

View File

@@ -0,0 +1,214 @@
# What is this?
## If litellm license in env, checks if it's valid
import base64
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Optional
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.constants import NON_LLM_CONNECTION_TIMEOUT
from litellm.llms.custom_httpx.http_handler import HTTPHandler
if TYPE_CHECKING:
from litellm.proxy._types import EnterpriseLicenseData
class LicenseCheck:
    """
    Validates the LiteLLM enterprise license.

    - Check if license in env (`LITELLM_LICENSE`)
    - Returns if license is valid: prefers local (airgapped) signature
      verification, falling back to the hosted license API
    """

    base_url = "https://license.litellm.ai"

    def __init__(self) -> None:
        self.license_str = os.getenv("LITELLM_LICENSE", None)
        verbose_proxy_logger.debug("License Str value - {}".format(self.license_str))
        self.http_handler = HTTPHandler(timeout=NON_LLM_CONNECTION_TIMEOUT)
        # only emit the verbose premium-check logs once, to avoid log spam
        self._premium_check_logged = False
        self.public_key = None
        self.read_public_key()
        # populated by verify_license_without_api_request() on success
        self.airgapped_license_data: Optional["EnterpriseLicenseData"] = None

    def read_public_key(self):
        """Load the bundled PEM public key used for airgapped license verification."""
        try:
            from cryptography.hazmat.primitives import serialization

            # current dir
            current_dir = os.path.dirname(os.path.realpath(__file__))
            # check if public_key.pem exists
            _path_to_public_key = os.path.join(current_dir, "public_key.pem")
            if os.path.exists(_path_to_public_key):
                with open(_path_to_public_key, "rb") as key_file:
                    self.public_key = serialization.load_pem_public_key(key_file.read())
            else:
                self.public_key = None
        except Exception as e:
            verbose_proxy_logger.error(f"Error reading public key: {str(e)}")

    def _verify(self, license_str: str) -> bool:
        """
        Verify `license_str` against the hosted license API.

        Retries up to 3 times on HTTP status errors. Never raises - returns
        False on any failure so license-server outages can't break the proxy.
        """
        verbose_proxy_logger.debug(
            "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
                self.base_url, license_str
            )
        )
        url = "{}/verify_license/{}".format(self.base_url, license_str)
        response: Optional[httpx.Response] = None
        try:  # don't impact user, if call fails
            num_retries = 3
            for i in range(num_retries):
                try:
                    response = self.http_handler.get(url=url)
                    if response is None:
                        raise Exception("No response from license server")
                    response.raise_for_status()
                    # success - stop retrying. (Previously the loop kept
                    # re-requesting the endpoint even after a 2xx response.)
                    break
                except httpx.HTTPStatusError:
                    if i == num_retries - 1:
                        raise
            if response is None:
                raise Exception("No response from license server")
            response_json = response.json()
            premium = response_json["verify"]
            assert isinstance(premium, bool)
            verbose_proxy_logger.debug(
                "litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
                    license_str, premium
                )
            )
            return premium
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
                    license_str, str(e)
                )
            )
            return False

    def is_premium(self) -> bool:
        """
        1. verify_license_without_api_request: checks if license was generate using private / public key pair
        2. _verify: checks if license is valid calling litellm API. This is the old way we were generating/validating license
        """
        try:
            if not self._premium_check_logged:
                verbose_proxy_logger.debug(
                    "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
                        self.license_str
                    )
                )
            if self.license_str is None:
                # license may have been set in the env after init
                self.license_str = os.getenv("LITELLM_LICENSE", None)
            if not self._premium_check_logged:
                verbose_proxy_logger.debug(
                    "litellm.proxy.auth.litellm_license.py::is_premium() - Updated 'self.license_str' - {}".format(
                        self.license_str
                    )
                )
                self._premium_check_logged = True
            if self.license_str is None:
                return False
            elif (
                self.verify_license_without_api_request(
                    public_key=self.public_key, license_key=self.license_str
                )
                is True
            ):
                return True
            elif self._verify(license_str=self.license_str) is True:
                return True
            return False
        except Exception:
            return False

    def is_over_limit(self, total_users: int) -> bool:
        """Return True if `total_users` exceeds the airgapped license's `max_users` cap.

        Returns False when no airgapped license data or no integer cap is present.
        """
        if self.airgapped_license_data is None:
            return False
        if "max_users" not in self.airgapped_license_data or not isinstance(
            self.airgapped_license_data["max_users"], int
        ):
            return False
        return total_users > self.airgapped_license_data["max_users"]

    def is_team_count_over_limit(self, team_count: int) -> bool:
        """Return True if `team_count` exceeds the airgapped license's `max_teams` cap.

        Returns False when no airgapped license data or no integer cap is present.
        """
        if self.airgapped_license_data is None:
            return False
        _max_teams_in_license: Optional[int] = self.airgapped_license_data.get(
            "max_teams"
        )
        if "max_teams" not in self.airgapped_license_data or not isinstance(
            _max_teams_in_license, int
        ):
            return False
        return team_count > _max_teams_in_license

    def verify_license_without_api_request(self, public_key, license_key):
        """
        Verify a signed, base64-encoded license locally (airgapped mode).

        The decoded payload is `<json>.<signature>`; the signature is checked
        with the bundled RSA-PSS/SHA-256 public key. On success the parsed
        payload is stored in `self.airgapped_license_data`.

        Returns:
            bool: True if the signature verifies and the license is unexpired.
        """
        try:
            from cryptography.hazmat.primitives import hashes
            from cryptography.hazmat.primitives.asymmetric import padding

            from litellm.proxy._types import EnterpriseLicenseData

            # Decode the license key - add padding if needed for base64
            # Base64 strings need to be a multiple of 4 characters
            padding_needed = len(license_key) % 4
            if padding_needed:
                license_key += "=" * (4 - padding_needed)
            decoded = base64.b64decode(license_key)
            message, signature = decoded.split(b".", 1)
            # Verify the signature
            public_key.verify(
                signature,
                message,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH,
                ),
                hashes.SHA256(),
            )
            # Decode and parse the data
            license_data = json.loads(message.decode())
            self.airgapped_license_data = EnterpriseLicenseData(**license_data)
            # debug information provided in license data
            verbose_proxy_logger.debug("License data: %s", license_data)
            # Check expiration date
            expiration_date = datetime.strptime(
                license_data["expiration_date"], "%Y-%m-%d"
            )
            if expiration_date < datetime.now():
                # BUGFIX: previously returned the tuple (False, "License has
                # expired"), which is truthy and broke this method's bool contract.
                return False
            return True
        except Exception as e:
            verbose_proxy_logger.debug(
                "litellm.proxy.auth.litellm_license.py::verify_license_without_api_request - Unable to verify License locally. - {}".format(
                    str(e)
                )
            )
            return False

View File

@@ -0,0 +1,344 @@
"""
Login utilities for handling user authentication in the proxy server.
This module contains the core login logic that can be reused across different
login endpoints (e.g., /login and /v2/login).
"""
import os
import secrets
from dataclasses import dataclass
from typing import Literal, Optional, cast

from fastapi import HTTPException

import litellm
from litellm.constants import LITELLM_PROXY_ADMIN_NAME, LITELLM_UI_SESSION_DURATION
from litellm.proxy._types import (
    LiteLLM_UserTable,
    LitellmUserRoles,
    ProxyErrorTypes,
    ProxyException,
    UpdateUserRequest,
    UserAPIKeyAuth,
    hash_token,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
from litellm.proxy.management_endpoints.key_management_endpoints import (
    generate_key_helper_fn,
)
from litellm.proxy.management_endpoints.ui_sso import (
    get_disabled_non_admin_personal_key_creation,
)
from litellm.proxy.utils import PrismaClient, get_server_root_path
from litellm.secret_managers.main import get_secret_bool
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
def get_ui_credentials(master_key: Optional[str]) -> tuple[str, str]:
    """
    Get UI username and password from environment variables or master key.

    Args:
        master_key: Master key for the proxy (used as fallback for password)

    Returns:
        tuple[str, str]: A tuple containing (ui_username, ui_password)

    Raises:
        ProxyException: If neither UI_PASSWORD nor master_key is available
    """
    username = os.getenv("UI_USERNAME", "admin")
    password = os.getenv("UI_PASSWORD", None)
    # Fall back to the proxy master key when no explicit UI password is set.
    if password is None and master_key is not None:
        password = str(master_key)
    if password is None:
        raise ProxyException(
            message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="UI_PASSWORD",
            code=500,
        )
    return username, password
@dataclass
class LoginResult:
    """Result object containing authentication data from login.

    Attributes:
        user_id: Unique identifier of the authenticated user.
        key: API key (or JWT) generated for the UI session.
        user_email: Email address of the user, if known.
        user_role: Role string the user authenticated with.
        login_method: How the user authenticated ("sso" or "username_password").
    """

    user_id: str
    key: str
    user_email: Optional[str]
    user_role: str
    login_method: Literal["sso", "username_password"] = "username_password"
async def authenticate_user(  # noqa: PLR0915
    username: str,
    password: str,
    master_key: Optional[str],
    prisma_client: Optional[PrismaClient],
) -> LoginResult:
    """
    Authenticate a user and generate an API key for UI access.
    This function handles two login scenarios:
    1. Admin login using UI_USERNAME and UI_PASSWORD
    2. User login using email and password from database
    Args:
        username: Username or email from the login form
        password: Password from the login form
        master_key: Master key for the proxy (required)
        prisma_client: Prisma database client (optional)
    Returns:
        LoginResult: Object containing authentication data
    Raises:
        ProxyException: If authentication fails or required configuration is missing
    """
    if master_key is None:
        raise ProxyException(
            message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="master_key",
            code=500,
        )
    ui_username, ui_password = get_ui_credentials(master_key)
    # Check if we can find the `username` in the db. On the UI, users can enter username=their email
    _user_row: Optional[LiteLLM_UserTable] = None
    user_role: Optional[
        Literal[
            LitellmUserRoles.PROXY_ADMIN,
            LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
            LitellmUserRoles.INTERNAL_USER,
            LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
        ]
    ] = None
    if prisma_client is not None:
        # Case-insensitive email lookup - the login form may use any casing.
        _user_row = cast(
            Optional[LiteLLM_UserTable],
            await prisma_client.db.litellm_usertable.find_first(
                where={"user_email": {"equals": username, "mode": "insensitive"}}
            ),
        )
    """
    To login to Admin UI, we support the following
    - Login with UI_USERNAME and UI_PASSWORD
    - Login with Invite Link `user_email` and `password` combination
    """
    # --- Case 1: admin login via UI_USERNAME / UI_PASSWORD.
    # compare_digest gives constant-time comparison to resist timing attacks.
    if secrets.compare_digest(
        username.encode("utf-8"), ui_username.encode("utf-8")
    ) and secrets.compare_digest(password.encode("utf-8"), ui_password.encode("utf-8")):
        # Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin
        user_role = LitellmUserRoles.PROXY_ADMIN
        user_id = LITELLM_PROXY_ADMIN_NAME
        # we want the key created to have PROXY_ADMIN_PERMISSIONS
        key_user_id = LITELLM_PROXY_ADMIN_NAME
        if (
            os.getenv("PROXY_ADMIN_ID", None) is not None
            and os.environ["PROXY_ADMIN_ID"] == user_id
        ) or user_id == LITELLM_PROXY_ADMIN_NAME:
            # checks if user is admin
            # PROXY_ADMIN_ID (when set) overrides the default admin user id for the key.
            key_user_id = os.getenv("PROXY_ADMIN_ID", LITELLM_PROXY_ADMIN_NAME)
        # Admin is Authe'd in - generate key for the UI to access Proxy
        # ensure this user is set as the proxy admin, in this route there is no sso, we can assume this user is only the admin
        await user_update(
            data=UpdateUserRequest(
                user_id=key_user_id,
                user_role=user_role,
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
            ),
        )
        if os.getenv("DATABASE_URL") is not None:
            # Session key: scoped to the dashboard team, budget-capped, time-limited.
            response = await generate_key_helper_fn(
                request_type="key",
                **{
                    "user_role": LitellmUserRoles.PROXY_ADMIN,
                    "duration": LITELLM_UI_SESSION_DURATION,
                    "key_max_budget": litellm.max_ui_session_budget,
                    "models": [],
                    "aliases": {},
                    "config": {},
                    "spend": 0,
                    "user_id": key_user_id,
                    "team_id": "litellm-dashboard",
                },  # type: ignore
            )
        else:
            raise ProxyException(
                message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
                type=ProxyErrorTypes.auth_error,
                param="DATABASE_URL",
                code=500,
            )
        key = response["token"]  # type: ignore
        if get_secret_bool("EXPERIMENTAL_UI_LOGIN"):
            # Experimental path: replace the DB-backed key with a signed JWT.
            from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken

            user_info: Optional[LiteLLM_UserTable] = None
            if _user_row is not None:
                user_info = _user_row
            elif (
                user_id is not None
            ):  # if user_id is not None, we are using the UI_USERNAME and UI_PASSWORD
                user_info = LiteLLM_UserTable(
                    user_id=user_id,
                    user_role=user_role,
                    models=[],
                    max_budget=litellm.max_ui_session_budget,
                )
            if user_info is None:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "error": "User Information is required for experimental UI login"
                    },
                )
            key = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
                user_info
            )
        return LoginResult(
            user_id=user_id,
            key=key,
            user_email=None,
            user_role=user_role,
            login_method="username_password",
        )
    # --- Case 2: DB-backed user login (invite-link email + password).
    elif _user_row is not None:
        """
        When sharing invite links
        -> if the user has no role in the DB assume they are only a viewer
        """
        user_id = getattr(_user_row, "user_id", "unknown")
        user_role = getattr(
            _user_row, "user_role", LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
        )
        user_email = getattr(_user_row, "user_email", "unknown")
        _password = getattr(_user_row, "password", "unknown")
        if _password is None:
            raise ProxyException(
                message="User has no password set. Please set a password for the user via `/user/update`.",
                type=ProxyErrorTypes.auth_error,
                param="password",
                code=401,
            )
        # check if password == _user_row.password
        # Accepts either a direct match or a match of the hashed submitted
        # password against the stored value (both constant-time).
        hash_password = hash_token(token=password)
        if secrets.compare_digest(
            password.encode("utf-8"), _password.encode("utf-8")
        ) or secrets.compare_digest(
            hash_password.encode("utf-8"), _password.encode("utf-8")
        ):
            if os.getenv("DATABASE_URL") is not None:
                # Session key inherits the user's own role (not admin).
                response = await generate_key_helper_fn(
                    request_type="key",
                    **{  # type: ignore
                        "user_role": user_role,
                        "duration": LITELLM_UI_SESSION_DURATION,
                        "key_max_budget": litellm.max_ui_session_budget,
                        "models": [],
                        "aliases": {},
                        "config": {},
                        "spend": 0,
                        "user_id": user_id,
                        "team_id": "litellm-dashboard",
                    },
                )
            else:
                raise ProxyException(
                    message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
                    type=ProxyErrorTypes.auth_error,
                    param="DATABASE_URL",
                    code=500,
                )
            key = response["token"]  # type: ignore
            return LoginResult(
                user_id=user_id,
                key=key,
                user_email=user_email,
                user_role=cast(str, user_role),
                login_method="username_password",
            )
        else:
            raise ProxyException(
                message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
                type=ProxyErrorTypes.auth_error,
                param="invalid_credentials",
                code=401,
            )
    # --- Case 3: no credential matched -> reject.
    else:
        raise ProxyException(
            message="Invalid credentials used to access UI.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
            type=ProxyErrorTypes.auth_error,
            param="invalid_credentials",
            code=401,
        )
def create_ui_token_object(
    login_result: LoginResult,
    general_settings: dict,
    premium_user: bool,
) -> ReturnedUITokenObject:
    """
    Create a ReturnedUITokenObject from a LoginResult.

    Args:
        login_result: The result from authenticate_user
        general_settings: General proxy settings dictionary
        premium_user: Whether premium features are enabled

    Returns:
        ReturnedUITokenObject: Token object ready for JWT encoding
    """
    personal_key_creation_disabled = get_disabled_non_admin_personal_key_creation()
    header_name = general_settings.get("litellm_key_header_name", "Authorization")
    return ReturnedUITokenObject(
        user_id=login_result.user_id,
        key=login_result.key,
        user_email=login_result.user_email,
        user_role=login_result.user_role,
        login_method=login_result.login_method,
        premium_user=premium_user,
        auth_header_name=header_name,
        disabled_non_admin_personal_key_creation=personal_key_creation_disabled,
        server_root_path=get_server_root_path(),
    )

View File

@@ -0,0 +1,381 @@
# What is this?
## Common checks for /v1/models and `/model/info`
from typing import Dict, List, Optional, Set
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
from litellm.router import Router
from litellm.router_utils.fallback_event_handlers import get_fallback_model_group
from litellm.types.router import LiteLLM_Params
from litellm.utils import get_valid_models
def _check_wildcard_routing(model: str) -> bool:
"""
Returns True if a model is a provider wildcard.
eg:
- anthropic/*
- openai/*
- *
"""
if "*" in model:
return True
return False
def get_provider_models(
    provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
    """
    Returns the list of known models by provider.

    A provider of "*" means "every model litellm knows about"; an unknown
    provider returns None.
    """
    if provider == "*":
        return get_valid_models(litellm_params=litellm_params)
    if provider not in litellm.models_by_provider:
        return None
    return get_valid_models(custom_llm_provider=provider, litellm_params=litellm_params)
def _get_models_from_access_groups(
model_access_groups: Dict[str, List[str]],
all_models: List[str],
include_model_access_groups: Optional[bool] = False,
) -> List[str]:
idx_to_remove = []
new_models = []
for idx, model in enumerate(all_models):
if model in model_access_groups:
if (
not include_model_access_groups
): # remove access group, unless requested - e.g. when creating a key
idx_to_remove.append(idx)
new_models.extend(model_access_groups[model])
for idx in sorted(idx_to_remove, reverse=True):
all_models.pop(idx)
all_models.extend(new_models)
return all_models
async def get_mcp_server_ids(
    user_api_key_dict: UserAPIKeyAuth,
) -> List[str]:
    """
    Returns the list of MCP server ids for a given key by querying the object_permission table
    """
    # imported lazily to avoid a circular import at module load time
    from litellm.proxy.proxy_server import prisma_client

    object_permission_id = user_api_key_dict.object_permission_id
    if prisma_client is None or object_permission_id is None:
        return []
    # Direct lookup - we only need the mcp_servers column
    try:
        record = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": object_permission_id},
        )
    except Exception:
        # best-effort: treat lookup failures as "no servers"
        return []
    if record and record.mcp_servers:
        return record.mcp_servers
    return []
def get_key_models(
    user_api_key_dict: UserAPIKeyAuth,
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Returns:
    - List of model name strings
    - Empty list if no models set
    - If model_access_groups is provided, only return models that are in the access groups
    - If include_model_access_groups is True, it includes the 'keys' of the model_access_groups
    in the response - {"beta-models": ["gpt-4", "claude-v1"]} -> returns 'beta-models'
    """
    key_models: List[str] = []
    if len(user_api_key_dict.models) > 0:
        key_models = list(
            user_api_key_dict.models
        )  # copy to avoid mutating cached objects
        if SpecialModelNames.all_team_models.value in key_models:
            # special marker -> substitute the team's model list
            key_models = list(
                user_api_key_dict.team_models
            )  # copy to avoid mutating cached objects
    if SpecialModelNames.all_proxy_models.value in key_models:
        # special marker -> substitute every model on the proxy
        key_models = list(proxy_model_list)  # copy to avoid mutating caller's list
    if include_model_access_groups:
        key_models.extend(model_access_groups.keys())
    key_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=key_models,
        include_model_access_groups=include_model_access_groups,
    )
    # deduplicate while preserving order
    key_models = list(dict.fromkeys(key_models))
    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(key_models)))
    return key_models
def get_team_models(
    team_models: List[str],
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Returns:
    - List of model name strings
    - Empty list if no models set
    - If model_access_groups is provided, only return models that are in the access groups

    Result order is now deterministic (input order, de-duplicated); the
    previous set-based accumulation produced an arbitrary order despite the
    "preserving order" comment.
    """
    # de-dup while keeping the team's configured ordering
    all_models: List[str] = list(dict.fromkeys(team_models))
    # NOTE: the `all_team_models` marker needs no expansion here - `team_models`
    # already IS the team's model list (the old `update(team_models)` was a no-op).
    if SpecialModelNames.all_proxy_models.value in all_models:
        all_models.extend(proxy_model_list)
    if include_model_access_groups:
        all_models.extend(model_access_groups.keys())
    all_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=list(dict.fromkeys(all_models)),
        include_model_access_groups=include_model_access_groups,
    )
    # deduplicate while preserving order
    all_models = list(dict.fromkeys(all_models))
    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
    return all_models
def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
    model_access_groups: Dict[str, List[str]] = {},
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """Logic for returning complete model list for a given key + team pair"""
    """
    - If key list is empty -> defer to team list
    - If team list is empty -> defer to proxy model list
    If list contains wildcard -> return known provider models
    """
    # ordered, de-duplicated accumulator
    unique_models: List[str] = []
    seen: set = set()

    def _extend_unique(candidates) -> None:
        for candidate in candidates:
            if candidate not in seen:
                seen.add(candidate)
                unique_models.append(candidate)

    # precedence: key-level models > team-level models > proxy-wide models
    _extend_unique(key_models or team_models or proxy_model_list)
    if include_model_access_groups:
        _extend_unique(list(model_access_groups.keys()))  # TODO: keys order
    if user_model:
        _extend_unique([user_model])
    if infer_model_from_keys:
        _extend_unique(get_valid_models())
    if only_model_access_groups:
        # caller only wants the access-group names present in the final list
        return [model for model in unique_models if model in model_access_groups]
    # NOTE: _get_wildcard_models removes expanded wildcard entries from
    # `unique_models` in place before we concatenate.
    all_wildcard_models = _get_wildcard_models(
        unique_models=unique_models,
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
    )
    return unique_models + all_wildcard_models
def get_known_models_from_wildcard(
    wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
    """Expand a wildcard entry (e.g. 'openai/*' or 'openai/gemini-*') into the
    concrete provider models it matches; returns [] when it cannot."""
    try:
        wildcard_prefix, wildcard_suffix = wildcard_model.split("/", 1)
    except ValueError:  # safely fail
        return []
    # Use provider from litellm_params when available, otherwise from wildcard prefix
    # (e.g., "openai" from "openai/*" - needed for BYOK where wildcard isn't in router)
    provider = wildcard_prefix
    if litellm_params is not None:
        try:
            provider = litellm_params.model.split("/", 1)[0]
        except ValueError:
            provider = wildcard_prefix
    # get all known provider models
    known_models = get_provider_models(
        provider=provider, litellm_params=litellm_params
    )
    if known_models is None:
        return []
    if wildcard_suffix != "*":
        ## CHECK IF PARTIAL FILTER e.g. `gemini-*`
        model_prefix = wildcard_suffix.replace("*", "")
        matching = [m for m in known_models if m.startswith(model_prefix)]
        if matching:
            # at least one known model matches the partial filter -> keep those
            known_models = matching
        else:
            # add model prefix to wildcard models
            known_models = [f"{model_prefix}{m}" for m in known_models]
    # make sure every result carries the wildcard's provider prefix
    prefixed_models: List[str] = []
    for candidate in known_models:
        if not candidate.startswith(wildcard_prefix):
            candidate = f"{wildcard_prefix}/{candidate}"
        prefixed_models.append(candidate)
    return prefixed_models or []
def _get_wildcard_models(
    unique_models: List[str],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """Expand wildcard entries (e.g. 'anthropic/*') in `unique_models` into concrete models.

    Side effect: wildcard entries that were expanded WITHOUT router deployment
    metadata are removed from `unique_models` in place - callers rely on this.
    Entries expanded via router deployments are kept in the list.
    Returns the expanded model names (plus the raw wildcard routes when
    `return_wildcard_routes` is True).
    """
    models_to_remove = set()
    all_wildcard_models = []
    for model in unique_models:
        if _check_wildcard_routing(model=model):
            if (
                return_wildcard_routes
            ):  # will add the wildcard route to the list eg: anthropic/*.
                all_wildcard_models.append(model)
            ## get litellm params from model
            if llm_router is not None:
                model_list = llm_router.get_model_list(model_name=model)
                if model_list:
                    # expand once per deployment, using its litellm_params for provider info
                    for router_model in model_list:
                        wildcard_models = get_known_models_from_wildcard(
                            wildcard_model=model,
                            litellm_params=LiteLLM_Params(
                                **router_model["litellm_params"]  # type: ignore
                            ),
                        )
                        all_wildcard_models.extend(wildcard_models)
                else:
                    # Router has no deployment for this wildcard (e.g., BYOK team models)
                    # Fall back to expanding from known provider models
                    wildcard_models = get_known_models_from_wildcard(
                        wildcard_model=model, litellm_params=None
                    )
                    if wildcard_models:
                        models_to_remove.add(model)
                        all_wildcard_models.extend(wildcard_models)
            else:
                # get all known provider models
                wildcard_models = get_known_models_from_wildcard(
                    wildcard_model=model, litellm_params=None
                )
                if wildcard_models:
                    models_to_remove.add(model)
                    all_wildcard_models.extend(wildcard_models)
    # drop successfully-expanded wildcard entries from the caller's list (in place)
    for model in models_to_remove:
        unique_models.remove(model)
    return all_wildcard_models
def get_all_fallbacks(
    model: str,
    llm_router: Optional[Router] = None,
    fallback_type: str = "general",
) -> List[str]:
    """
    Get all fallbacks for a given model from the router's fallback configuration.
    Args:
        model: The model name to get fallbacks for
        llm_router: The LiteLLM router instance
        fallback_type: Type of fallback ("general", "context_window", "content_policy")
    Returns:
        List of fallback model names. Empty list if no fallbacks found.
    """
    if llm_router is None:
        return []
    # Map each fallback type to the router attribute holding its config.
    attr_by_type = {
        "general": "fallbacks",
        "context_window": "context_window_fallbacks",
        "content_policy": "content_policy_fallbacks",
    }
    router_attr = attr_by_type.get(fallback_type)
    if router_attr is None:
        verbose_proxy_logger.warning(f"Unknown fallback_type: {fallback_type}")
        return []
    fallbacks_config: list = getattr(llm_router, router_attr, [])
    if not fallbacks_config:
        return []
    try:
        # Use existing function to get fallback model group
        fallback_model_group, _ = get_fallback_model_group(
            fallbacks=fallbacks_config, model_group=model
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error getting fallbacks for model {model}: {e}")
        return []
    return fallback_model_group if fallback_model_group is not None else []

View File

@@ -0,0 +1,222 @@
import base64
import os
from typing import Dict, Optional, Tuple, cast
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
class Oauth2Handler:
    """
    Handles OAuth2 token validation.

    Supports two endpoint styles:
    - RFC 7662 token introspection endpoints (POST with form-encoded body)
    - generic token-info endpoints (GET with a Bearer header)

    All configuration is read from environment variables in
    ``check_oauth2_token`` (see that method's docstring).
    """

    @staticmethod
    def _is_introspection_endpoint(
        token_info_endpoint: str,
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
    ) -> bool:
        """
        Determine if this is an introspection endpoint (requires POST) or token info endpoint (uses GET).

        A URL is treated as an introspection endpoint only when it contains
        "introspect" AND full client credentials are configured.

        Args:
            token_info_endpoint: The OAuth2 endpoint URL
            oauth_client_id: OAuth2 client ID
            oauth_client_secret: OAuth2 client secret

        Returns:
            bool: True if this is an introspection endpoint
        """
        return (
            "introspect" in token_info_endpoint.lower()
            and oauth_client_id is not None
            and oauth_client_secret is not None
        )

    @staticmethod
    def _prepare_introspection_request(
        token: str,
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
    ) -> Tuple[Dict[str, str], Dict[str, str]]:
        """
        Prepare headers and data for OAuth2 introspection endpoint (RFC 7662).

        Args:
            token: The OAuth2 token to validate
            oauth_client_id: OAuth2 client ID
            oauth_client_secret: OAuth2 client secret

        Returns:
            Tuple of (headers, data) for the introspection request
        """
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {"token": token}

        # Add client authentication if credentials are provided
        if oauth_client_id and oauth_client_secret:
            # Use HTTP Basic authentication for client credentials
            credentials = base64.b64encode(
                f"{oauth_client_id}:{oauth_client_secret}".encode()
            ).decode()
            headers["Authorization"] = f"Basic {credentials}"
        elif oauth_client_id:
            # For public clients, include client_id in the request body
            data["client_id"] = oauth_client_id

        return headers, data

    @staticmethod
    def _prepare_token_info_request(token: str) -> Dict[str, str]:
        """
        Prepare headers for generic token info endpoint.

        Args:
            token: The OAuth2 token to validate

        Returns:
            Dict of headers for the token info request
        """
        return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    @staticmethod
    def _extract_user_info(
        response_data: Dict,
        user_id_field_name: str,
        user_role_field_name: str,
        user_team_id_field_name: str,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Extract user information from OAuth2 response.

        Args:
            response_data: The response data from OAuth2 endpoint
            user_id_field_name: Field name for user ID
            user_role_field_name: Field name for user role
            user_team_id_field_name: Field name for team ID

        Returns:
            Tuple of (user_id, user_role, user_team_id) — any of which may be
            None when the field is absent from the response.
        """
        user_id = response_data.get(user_id_field_name)
        user_team_id = response_data.get(user_team_id_field_name)
        user_role = response_data.get(user_role_field_name)
        return user_id, user_role, user_team_id

    @staticmethod
    async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
        """
        Makes a request to the token introspection endpoint to validate the OAuth2 token.

        This function implements OAuth2 token introspection according to RFC 7662.
        It supports both generic token info endpoints (GET) and OAuth2 introspection endpoints (POST).

        Configuration (environment variables):
            OAUTH_TOKEN_INFO_ENDPOINT (required), OAUTH_USER_ID_FIELD_NAME,
            OAUTH_USER_ROLE_FIELD_NAME, OAUTH_USER_TEAM_ID_FIELD_NAME,
            OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET.

        Args:
            token (str): The OAuth2 token to validate.

        Returns:
            UserAPIKeyAuth: If the token is valid, containing user information.

        Raises:
            ValueError: If the token is invalid, the request fails, or the token info endpoint is not set.
        """
        from litellm.proxy.proxy_server import premium_user

        # Enterprise-only feature gate.
        if premium_user is not True:
            raise ValueError(
                "Oauth2 token validation is only available for premium users"
                + CommonProxyErrors.not_premium_user.value
            )

        verbose_proxy_logger.debug("Oauth2 token validation for token=%s", token)

        # Get the token info endpoint from environment variable
        token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
        user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
        user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
        user_team_id_field_name = os.environ.get(
            "OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id"
        )

        # OAuth2 client credentials for introspection endpoint authentication
        oauth_client_id = os.environ.get("OAUTH_CLIENT_ID")
        oauth_client_secret = os.environ.get("OAUTH_CLIENT_SECRET")

        if not token_info_endpoint:
            raise ValueError(
                "OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set"
            )

        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)

        # Determine if this is an introspection endpoint (requires POST) or token info endpoint (uses GET)
        is_introspection_endpoint = Oauth2Handler._is_introspection_endpoint(
            token_info_endpoint=token_info_endpoint,
            oauth_client_id=oauth_client_id,
            oauth_client_secret=oauth_client_secret,
        )

        try:
            if is_introspection_endpoint:
                # OAuth2 Token Introspection (RFC 7662) - requires POST with form data
                verbose_proxy_logger.debug("Using OAuth2 introspection endpoint (POST)")
                headers, data = Oauth2Handler._prepare_introspection_request(
                    token=token,
                    oauth_client_id=oauth_client_id,
                    oauth_client_secret=oauth_client_secret,
                )
                response = await client.post(
                    token_info_endpoint, headers=headers, data=data
                )
            else:
                # Generic token info endpoint - uses GET with Bearer token
                verbose_proxy_logger.debug("Using generic token info endpoint (GET)")
                headers = Oauth2Handler._prepare_token_info_request(token=token)
                response = await client.get(token_info_endpoint, headers=headers)

            # if it's a bad token we expect it to raise an HTTPStatusError
            response.raise_for_status()

            # If we get here, the request was successful
            data = response.json()

            verbose_proxy_logger.debug(
                "Oauth2 token validation for token=%s, response from endpoint=%s",
                token,
                data,
            )

            # For introspection endpoints, check if token is active (RFC 7662
            # says a missing "active" field should not be treated as inactive,
            # hence the True default).
            if is_introspection_endpoint and not data.get("active", True):
                raise ValueError("Token is not active")

            # Extract user information from response
            user_id, user_role, user_team_id = Oauth2Handler._extract_user_info(
                response_data=data,
                user_id_field_name=user_id_field_name,
                user_role_field_name=user_role_field_name,
                user_team_id_field_name=user_team_id_field_name,
            )

            return UserAPIKeyAuth(
                api_key=token,
                team_id=user_team_id,
                user_id=user_id,
                user_role=cast(LitellmUserRoles, user_role),
            )
        except httpx.HTTPStatusError as e:
            # This will catch any 4xx or 5xx errors
            raise ValueError(f"Oauth 2.0 Token validation failed: {e}")
        except ValueError:
            # Fix: propagate our own validation errors (e.g. "Token is not
            # active") as-is instead of re-wrapping them in the generic
            # "An error occurred during token validation" message below.
            raise
        except Exception as e:
            # This will catch any other errors (like network issues)
            raise ValueError(f"An error occurred during token validation: {e}")

View File

@@ -0,0 +1,45 @@
from typing import Any, Dict
from fastapi import Request
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
async def handle_oauth2_proxy_request(request: Request) -> UserAPIKeyAuth:
    """
    Handle request from oauth2 proxy.

    Builds a UserAPIKeyAuth object from incoming request headers, using the
    field-to-header mappings configured under
    ``general_settings["oauth2_config_mappings"]``.

    Raises:
        ValueError: when no oauth2 config mappings are configured.
    """
    from litellm.proxy.proxy_server import general_settings

    verbose_proxy_logger.debug("Handling oauth2 proxy request")

    # Mapping: UserAPIKeyAuth field name -> incoming header name.
    oauth2_config_mappings: Dict[str, str] = (
        general_settings.get("oauth2_config_mappings") or {}
    )
    verbose_proxy_logger.debug(f"Oauth2 config mappings: {oauth2_config_mappings}")

    if not oauth2_config_mappings:
        raise ValueError("Oauth2 config mappings not found in general_settings")

    # Collect mapped header values, coercing types where needed.
    auth_data: Dict[str, Any] = {}
    for field_name, header_name in oauth2_config_mappings.items():
        raw_value = request.headers.get(header_name)
        if not raw_value:
            continue
        if field_name == "max_budget":
            # Budgets arrive as strings; UserAPIKeyAuth expects a float.
            auth_data[field_name] = float(raw_value)
        elif field_name == "models":
            # Comma-separated list of model names.
            auth_data[field_name] = [m.strip() for m in raw_value.split(",")]
        else:
            auth_data[field_name] = raw_value

    verbose_proxy_logger.debug(
        f"Auth data before creating UserAPIKeyAuth object: {auth_data}"
    )
    user_api_key_auth = UserAPIKeyAuth(**auth_data)
    verbose_proxy_logger.debug(f"UserAPIKeyAuth object created: {user_api_key_auth}")
    return user_api_key_auth

View File

@@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwcNBabWBZzrDhFAuA4Fh
FhIcA3rF7vrLb8+1yhF2U62AghQp9nStyuJRjxMUuldWgJ1yRJ2s7UffVw5r8DeA
dqXPD+w+3LCNwqJGaIKN08QGJXNArM3QtMaN0RTzAyQ4iibN1r6609W5muK9wGp0
b1j5+iDUmf0ynItnhvaX6B8Xoaflc3WD/UBdrygLmsU5uR3XC86+/8ILoSZH3HtN
6FJmWhlhjS2TR1cKZv8K5D0WuADTFf5MF8jYFR+uORPj5Pe/EJlLGN26Lfn2QnGu
XgbPF6nCGwZ0hwH1Xkn3xzGaJ4xBEC761wqp5cHxWSDktHyFKnLbP3jVeegjVIHh
pQIDAQAB
-----END PUBLIC KEY-----

View File

@@ -0,0 +1,187 @@
import os
from typing import Any, Optional, Union
import httpx
def init_rds_client(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Any:
    """
    Build a boto3 RDS client, resolving credentials and region from explicit
    arguments, "os.environ/"-prefixed secret references, or environment
    variables.

    Credential resolution precedence (first matching branch wins):
        1. web-identity (OIDC) token + role + session -> STS AssumeRoleWithWebIdentity
        2. role name + session name                   -> STS AssumeRole
        3. static access key (+ secret)
        4. named AWS profile (~/.aws/credentials)
        5. boto3 default environment-based auth

    Args:
        aws_access_key_id: Static AWS access key, or an "os.environ/..." ref.
        aws_secret_access_key: Matching secret key, or an "os.environ/..." ref.
        aws_region_name: Explicit region; falls back to AWS_REGION_NAME then AWS_REGION.
        aws_session_name: STS session name (required for role-based auth).
        aws_profile_name: AWS shared-config profile name.
        aws_role_name: Role ARN to assume via STS.
        aws_web_identity_token: OIDC token value, file path, or secret reference.
        timeout: Connect/read timeout as a float or httpx.Timeout.

    Returns:
        A configured boto3 "rds" client.

    Raises:
        Exception: if no AWS region can be resolved, or the OIDC token
            cannot be retrieved.
    """
    from litellm.secret_managers.main import get_secret

    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IS 'os.environ/' passed in
    # Define the list of parameters to check
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]

    # Iterate over parameters and update if needed
    # (values written as "os.environ/NAME" are resolved via the secret manager)
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)  # type: ignore
    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    ### SET REGION NAME
    # Precedence: explicit arg > AWS_REGION_NAME > AWS_REGION; otherwise error.
    region_name = aws_region_name
    if aws_region_name:
        region_name = aws_region_name
    elif litellm_aws_region_name:
        region_name = litellm_aws_region_name
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise Exception(
            "AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
        )

    import boto3

    # Translate the requested timeout into a botocore Config object.
    if isinstance(timeout, float):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()  # type: ignore

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        # The token value may be a file path (common for EKS IRSA) or a
        # secret-manager reference — try the file first.
        try:
            oidc_token = open(aws_web_identity_token).read()  # check if filepath
        except Exception:
            oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise Exception(
                "OIDC token could not be retrieved from secret manager.",
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )
    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
        client = boto3.client(
            service_name="rds",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            config=config,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        client = boto3.Session(profile_name=aws_profile_name).client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables
        client = boto3.client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )

    return client
def generate_iam_auth_token(
    db_host, db_port, db_user, client: Optional[Any] = None
) -> str:
    """
    Generate an AWS RDS IAM authentication token for the given database.

    Args:
        db_host: RDS hostname.
        db_port: RDS port.
        db_user: Database username to authenticate as.
        client: Optional pre-built boto3 RDS client; when omitted, one is
            initialized from AWS_* environment variables.

    Returns:
        The auth token, URL-quoted so it can be embedded in a connection string.
    """
    from urllib.parse import quote

    if client is not None:
        boto_client = client
    else:
        # Build an RDS client from environment-based AWS credentials/config.
        boto_client = init_rds_client(
            aws_region_name=os.getenv("AWS_REGION_NAME"),
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_name=os.getenv("AWS_SESSION_NAME"),
            aws_profile_name=os.getenv("AWS_PROFILE_NAME"),
            aws_role_name=os.getenv("AWS_ROLE_NAME", os.getenv("AWS_ROLE_ARN")),
            aws_web_identity_token=os.getenv(
                "AWS_WEB_IDENTITY_TOKEN", os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
            ),
        )

    raw_token = boto_client.generate_db_auth_token(
        DBHostname=db_host, Port=db_port, DBUsername=db_user
    )
    # quote with safe="" so "/", "+", ":" etc. are percent-encoded for URLs.
    return quote(raw_token, safe="")

View File

@@ -0,0 +1,669 @@
import re
from typing import List, Optional
from fastapi import HTTPException, Request, status
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
CommonProxyErrors,
LiteLLM_UserTable,
LiteLLMRoutes,
LitellmUserRoles,
UserAPIKeyAuth,
)
from .auth_checks_organization import _user_is_org_admin
class RouteChecks:
@staticmethod
def should_call_route(route: str, valid_token: UserAPIKeyAuth):
"""
Check if management route is disabled and raise exception
"""
try:
from litellm_enterprise.proxy.auth.route_checks import EnterpriseRouteChecks
EnterpriseRouteChecks.should_call_route(route=route)
except HTTPException as e:
raise e
except Exception:
pass
# Check if Virtual Key is allowed to call the route - Applies to all Roles
RouteChecks.is_virtual_key_allowed_to_call_route(
route=route, valid_token=valid_token
)
return True
    @staticmethod
    def is_virtual_key_allowed_to_call_route(
        route: str, valid_token: UserAPIKeyAuth
    ) -> bool:
        """
        Raises Exception if Virtual Key is not allowed to call the route

        Enforcement only applies when the key has a non-empty
        ``allowed_routes`` list. Entries may be: literal routes / prefixes,
        LiteLLMRoutes member names (e.g. "llm_api_routes"), or trailing-"*"
        wildcard patterns.

        Returns:
            bool: True when the key may call ``route``.

        Raises:
            HTTPException: 403 when no allowed-route entry matches.
        """
        # Only check if valid_token.allowed_routes is set and is a list with at least one item
        if valid_token.allowed_routes is None:
            return True
        if not isinstance(valid_token.allowed_routes, list):
            return True
        if len(valid_token.allowed_routes) == 0:
            return True

        # explicit check for allowed routes (exact match or prefix match)
        for allowed_route in valid_token.allowed_routes:
            if RouteChecks._route_matches_allowed_route(
                route=route, allowed_route=allowed_route
            ):
                return True

        ## check if 'allowed_route' is a field name in LiteLLMRoutes
        if any(
            allowed_route in LiteLLMRoutes._member_names_
            for allowed_route in valid_token.allowed_routes
        ):
            for allowed_route in valid_token.allowed_routes:
                if allowed_route in LiteLLMRoutes._member_names_:
                    # Expand the member name to its concrete route list.
                    if RouteChecks.check_route_access(
                        route=route,
                        allowed_routes=LiteLLMRoutes._member_map_[allowed_route].value,
                    ):
                        return True

                    ################################################
                    # For llm_api_routes, also check registered pass-through endpoints
                    ################################################
                    if allowed_route == "llm_api_routes":
                        from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
                            InitPassThroughEndpointHelpers,
                        )

                        if InitPassThroughEndpointHelpers.is_registered_pass_through_route(
                            route=route
                        ):
                            return True

        # check if wildcard pattern is allowed
        for allowed_route in valid_token.allowed_routes:
            if RouteChecks._route_matches_wildcard_pattern(
                route=route, pattern=allowed_route
            ):
                return True

        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Virtual key is not allowed to call this route. Only allowed to call routes: {valid_token.allowed_routes}. Tried to call route: {route}",
        )
    @staticmethod
    def _mask_user_id(user_id: str) -> str:
        """
        Mask user_id to prevent leaking sensitive information in error messages

        Args:
            user_id (str): The user_id to mask

        Returns:
            str: "***" for empty/short ids (<= 4 chars); otherwise the id with
            only the first 6 and last 2 characters visible (per the masker
            configuration below)
        """
        from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker

        # Very short ids would be fully revealed by prefix/suffix masking.
        if not user_id or len(user_id) <= 4:
            return "***"

        # Use SensitiveDataMasker with custom configuration for user_id
        masker = SensitiveDataMasker(visible_prefix=6, visible_suffix=2, mask_char="*")
        return masker._mask_value(user_id)
@staticmethod
def _raise_admin_only_route_exception(
user_obj: Optional[LiteLLM_UserTable],
route: str,
) -> None:
"""
Raise exception for routes that require proxy admin access
Args:
user_obj (Optional[LiteLLM_UserTable]): The user object
route (str): The route being accessed
Raises:
Exception: With user role and masked user_id information
"""
user_role = "unknown"
user_id = "unknown"
if user_obj is not None:
user_role = user_obj.user_role or "unknown"
user_id = user_obj.user_id or "unknown"
masked_user_id = RouteChecks._mask_user_id(user_id)
raise Exception(
f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={masked_user_id}"
)
    @staticmethod
    def non_proxy_admin_allowed_routes_check(
        user_obj: Optional[LiteLLM_UserTable],
        _user_role: Optional[LitellmUserRoles],
        route: str,
        request: Request,
        valid_token: UserAPIKeyAuth,
        request_data: dict,
    ):
        """
        Checks if Non Proxy Admin User is allowed to access the route

        Walks an ordered chain of access rules: LLM API routes, info routes,
        spend-tracking permissions, role-specific allow-lists, org-admin
        routes, self-managed routes, MCP routes, key-level passthrough
        routes, and finally key-level allowed_routes. The branch ORDER is
        load-bearing — earlier, broader grants short-circuit later checks.

        Raises:
            HTTPException / Exception: when no rule permits the route.
        """
        # Check user has defined custom admin routes
        RouteChecks.custom_admin_only_route_check(
            route=route,
        )

        if RouteChecks.is_llm_api_route(route=route):
            pass
        elif RouteChecks.is_info_route(route=route):
            # check if user allowed to call an info route
            if route == "/key/info":
                # handled by function itself
                pass
            elif route == "/user/info":
                # check if user can access this route
                # only permitted when querying one's own user_id
                query_params = request.query_params
                user_id = query_params.get("user_id")
                verbose_proxy_logger.debug(
                    f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
                )
                if user_id and user_id != valid_token.user_id:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="key not allowed to access this user's info. user_id={}, key's user_id={}".format(
                            user_id, valid_token.user_id
                        ),
                    )
            elif route == "/model/info":
                # /model/info just shows models user has access to
                pass
            elif route == "/team/info":
                pass  # handled by function itself
        elif (
            route in LiteLLMRoutes.global_spend_tracking_routes.value
            and getattr(valid_token, "permissions", None) is not None
            and "get_spend_routes" in getattr(valid_token, "permissions", [])
        ):
            # keys can be granted spend-route access via an explicit permission
            pass
        elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
            RouteChecks._check_proxy_admin_viewer_access(
                route=route,
                _user_role=_user_role,
                request_data=request_data,
            )
        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER.value
            and RouteChecks.check_route_access(
                route=route, allowed_routes=LiteLLMRoutes.internal_user_routes.value
            )
        ):
            pass
        elif _user_is_org_admin(
            request_data=request_data, user_object=user_obj
        ) and RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.org_admin_allowed_routes.value
        ):
            pass
        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
            and RouteChecks.check_route_access(
                route=route,
                allowed_routes=LiteLLMRoutes.internal_user_view_only_routes.value,
            )
        ):
            pass
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.self_managed_routes.value
        ):  # routes that manage their own allowed/disallowed logic
            pass
        elif route.startswith("/v1/mcp/") or route.startswith("/mcp-rest/"):
            pass  # authN/authZ handled by api itself
        elif RouteChecks.check_passthrough_route_access(
            route=route, user_api_key_dict=valid_token
        ):
            pass
        elif valid_token.allowed_routes is not None:
            # check if route is in allowed_routes (exact match or prefix match)
            route_allowed = False
            for allowed_route in valid_token.allowed_routes:
                if RouteChecks._route_matches_allowed_route(
                    route=route, allowed_route=allowed_route
                ):
                    route_allowed = True
                    break
                if RouteChecks._route_matches_wildcard_pattern(
                    route=route, pattern=allowed_route
                ):
                    route_allowed = True
                    break
            if not route_allowed:
                RouteChecks._raise_admin_only_route_exception(
                    user_obj=user_obj, route=route
                )
        else:
            # nothing matched: route is admin-only for this caller
            RouteChecks._raise_admin_only_route_exception(
                user_obj=user_obj, route=route
            )
@staticmethod
def custom_admin_only_route_check(route: str):
from litellm.proxy.proxy_server import general_settings, premium_user
if "admin_only_routes" in general_settings:
if premium_user is not True:
verbose_proxy_logger.error(
f"Trying to use 'admin_only_routes' this is an Enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
)
return
if route in general_settings["admin_only_routes"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"user not allowed to access this route. Route={route} is an admin only route",
)
pass
    @staticmethod
    def is_llm_api_route(route: str) -> bool:
        """
        Helper to checks if provided route is an OpenAI route

        Checks, in order: exact membership in the OpenAI / Anthropic / Google
        route tables, MCP and agent route tables, litellm-native routes, then
        fuzzy matches (placeholder patterns like "/threads/{thread_id}",
        trailing-"*" wildcards, AzureOpenAI SDK paths, and mapped
        pass-through prefixes).

        Returns:
        - True: if route is an OpenAI route
        - False: if route is not an OpenAI route
        """
        # Ensure route is a string before performing checks
        if not isinstance(route, str):
            return False

        if route in LiteLLMRoutes.openai_routes.value:
            return True

        if route in LiteLLMRoutes.anthropic_routes.value:
            return True

        if route in LiteLLMRoutes.google_routes.value:
            return True

        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.mcp_routes.value
        ):
            return True

        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.agent_routes.value
        ):
            return True

        if route in LiteLLMRoutes.litellm_native_routes.value:
            return True

        # fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
        # Check for routes with placeholders or wildcard patterns
        for openai_route in LiteLLMRoutes.openai_routes.value:
            # Replace placeholders with regex pattern
            # placeholders are written as "/threads/{thread_id}"
            if "{" in openai_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=openai_route
                ):
                    return True
            # Check for wildcard patterns like "/containers/*"
            if RouteChecks._is_wildcard_pattern(pattern=openai_route):
                if RouteChecks._route_matches_wildcard_pattern(
                    route=route, pattern=openai_route
                ):
                    return True

        # Check for Google routes with placeholders like "/v1beta/models/{model_name}:generateContent"
        for google_route in LiteLLMRoutes.google_routes.value:
            if "{" in google_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=google_route
                ):
                    return True

        # Check for Anthropic routes with placeholders
        for anthropic_route in LiteLLMRoutes.anthropic_routes.value:
            if "{" in anthropic_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=anthropic_route
                ):
                    return True

        if RouteChecks._is_azure_openai_route(route=route):
            return True

        # mapped pass-through prefixes match anywhere in the route
        for _llm_passthrough_route in LiteLLMRoutes.mapped_pass_through_routes.value:
            if _llm_passthrough_route in route:
                return True

        return False
@staticmethod
def is_management_route(route: str) -> bool:
"""
Check if route is a management route
"""
return route in LiteLLMRoutes.management_routes.value
@staticmethod
def is_info_route(route: str) -> bool:
"""
Check if route is an info route
"""
return route in LiteLLMRoutes.info_routes.value
@staticmethod
def _is_azure_openai_route(route: str) -> bool:
"""
Check if route is a route from AzureOpenAI SDK client
eg.
route='/openai/deployments/vertex_ai/gemini-1.5-flash/chat/completions'
"""
# Ensure route is a string before attempting regex matching
if not isinstance(route, str):
return False
# Add support for deployment and engine model paths
deployment_pattern = r"^/openai/deployments/[^/]+/[^/]+/chat/completions$"
engine_pattern = r"^/engines/[^/]+/chat/completions$"
if re.match(deployment_pattern, route) or re.match(engine_pattern, route):
return True
return False
@staticmethod
def _route_matches_pattern(route: str, pattern: str) -> bool:
"""
Check if route matches the pattern placed in proxy/_types.py
Example:
- pattern: "/threads/{thread_id}"
- route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
- returns: True
- pattern: "/key/{token_id}/regenerate"
- route: "/key/regenerate/82akk800000000jjsk"
- returns: False, pattern is "/key/{token_id}/regenerate"
"""
# Ensure route is a string before attempting regex matching
if not isinstance(route, str):
return False
def _placeholder_to_regex(match: re.Match) -> str:
placeholder = match.group(0).strip("{}")
if placeholder.endswith(":path"):
# allow "/" in the placeholder value, but don't eat the route suffix after ":"
return r"[^:]+"
return r"[^/]+"
pattern = re.sub(r"\{[^}]+\}", _placeholder_to_regex, pattern)
# Anchor the pattern to match the entire string
pattern = f"^{pattern}$"
if re.match(pattern, route):
return True
return False
@staticmethod
def _is_wildcard_pattern(pattern: str) -> bool:
"""
Check if pattern is a wildcard pattern
"""
return pattern.endswith("*")
@staticmethod
def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
"""
Check if route matches the wildcard pattern
eg.
pattern: "/scim/v2/*"
route: "/scim/v2/Users"
- returns: True
pattern: "/scim/v2/*"
route: "/chat/completions"
- returns: False
pattern: "/scim/v2/*"
route: "/scim/v2/Users/123"
- returns: True
"""
if pattern.endswith("*"):
# Get the prefix (everything before the wildcard)
prefix = pattern[:-1]
return route.startswith(prefix)
else:
# If there's no wildcard, the pattern and route should match exactly
return route == pattern
@staticmethod
def _route_matches_allowed_route(route: str, allowed_route: str) -> bool:
"""
Check if route matches the allowed_route pattern.
Supports both exact match and prefix match.
Examples:
- allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-6" -> True (exact match)
- allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-6/v1/chat/completions" -> True (prefix match)
- allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-600" -> False (not a valid prefix)
Args:
route: The actual route being accessed
allowed_route: The allowed route pattern
Returns:
bool: True if route matches (exact or prefix), False otherwise
"""
# Exact match
if route == allowed_route:
return True
# Prefix match - ensure we add "/" to prevent false matches like /fake-openai-proxy-600
if route.startswith(allowed_route + "/"):
return True
return False
@staticmethod
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
"""
Check if a route has access by checking both exact matches and patterns
Args:
route (str): The route to check
allowed_routes (list): List of allowed routes/patterns
Returns:
bool: True if route is allowed, False otherwise
"""
#########################################################
# exact match route is in allowed_routes
#########################################################
if route in allowed_routes:
return True
#########################################################
# wildcard match route is in allowed_routes
# e.g calling /anthropic/v1/messages is allowed if allowed_routes has /anthropic/*
#########################################################
wildcard_allowed_routes = [
route
for route in allowed_routes
if RouteChecks._is_wildcard_pattern(pattern=route)
]
for allowed_route in wildcard_allowed_routes:
if RouteChecks._route_matches_wildcard_pattern(
route=route, pattern=allowed_route
):
return True
#########################################################
# pattern match route is in allowed_routes
# pattern: "/threads/{thread_id}"
# route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
# returns: True
#########################################################
if any( # Check pattern match
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
for allowed_route in allowed_routes
):
return True
return False
@staticmethod
def check_passthrough_route_access(
route: str, user_api_key_dict: UserAPIKeyAuth
) -> bool:
"""
Check if route is a passthrough route.
Supports both exact match and prefix match.
"""
metadata = user_api_key_dict.metadata
team_metadata = user_api_key_dict.team_metadata or {}
if metadata is None and team_metadata is None:
return False
if (
"allowed_passthrough_routes" not in metadata
and "allowed_passthrough_routes" not in team_metadata
):
return False
if (
metadata.get("allowed_passthrough_routes") is None
and team_metadata.get("allowed_passthrough_routes") is None
):
return False
allowed_passthrough_routes = (
metadata.get("allowed_passthrough_routes")
or team_metadata.get("allowed_passthrough_routes")
or []
)
# Check if route matches any allowed passthrough route (exact or prefix match)
for allowed_route in allowed_passthrough_routes:
if RouteChecks._route_matches_allowed_route(
route=route, allowed_route=allowed_route
):
return True
return False
@staticmethod
def _is_assistants_api_request(request: Request) -> bool:
"""
Returns True if `thread` or `assistant` is in the request path
Args:
request (Request): The request object
Returns:
bool: True if `thread` or `assistant` is in the request path, False otherwise
"""
if "thread" in request.url.path or "assistant" in request.url.path:
return True
return False
@staticmethod
def is_generate_content_route(route: str) -> bool:
"""
Returns True if this is a google generateContent or streamGenerateContent route
These routes from google allow passing key=api_key in the query params
"""
if "generateContent" in route:
return True
if "streamGenerateContent" in route:
return True
return False
@staticmethod
def _check_proxy_admin_viewer_access(
route: str,
_user_role: str,
request_data: dict,
) -> None:
"""
Check access for PROXY_ADMIN_VIEW_ONLY role
"""
if RouteChecks.is_llm_api_route(route=route):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"user not allowed to access this OpenAI routes, role= {_user_role}",
)
# Check if this is a write operation on management routes
if RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.management_routes.value
):
# For management routes, only allow read operations or specific allowed updates
if route == "/user/update":
# Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY
if request_data is not None and isinstance(request_data, dict):
_params_updated = request_data.keys()
for param in _params_updated:
if param not in ["user_email", "password"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route} and updating invalid param: {param}. only user_email and password can be updated",
)
elif (
route
in [
"/user/new",
"/user/delete",
"/team/new",
"/team/update",
"/team/delete",
"/model/new",
"/model/update",
"/model/delete",
"/key/generate",
"/key/delete",
"/key/update",
"/key/regenerate",
"/key/service-account/generate",
"/key/block",
"/key/unblock",
]
or route.startswith("/key/")
and route.endswith("/regenerate")
):
# Block write operations for PROXY_ADMIN_VIEW_ONLY
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
)
# Allow read operations on management routes (like /user/info, /team/info, /model/info)
return
elif RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.admin_viewer_routes.value
):
# Allow access to admin viewer routes (read-only admin endpoints)
return
elif RouteChecks.check_route_access(
route=route, allowed_routes=LiteLLMRoutes.global_spend_tracking_routes.value
):
# Allow access to global spend tracking routes (read-only spend endpoints)
# proxy_admin_viewer role description: "view all keys, view all spend"
return
else:
# For other routes, block access
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
)

File diff suppressed because it is too large Load Diff