chore: initial snapshot for gitea/github upload
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Auth Checks for Organizations
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import status
|
||||
|
||||
from litellm.proxy._types import *
|
||||
|
||||
|
||||
def organization_role_based_access_check(
    request_body: dict,
    user_object: Optional[LiteLLM_UserTable],
    route: str,
):
    """
    Role based access control checks only run if a user is part of an Organization

    Organization Checks:
    ONLY RUN IF user_object.organization_memberships is not None

    1. Only Proxy Admins can access /organization/new
    2. IF route is a LiteLLMRoutes.org_admin_only_routes, then check if user is an Org Admin for that organization

    Args:
        request_body: parsed request payload; only the "organization_id" key is read here.
        user_object: the authenticated user's DB record, or None (then no checks run).
        route: the proxy route being invoked.

    Raises:
        ProxyException (401): when the caller lacks the required role for the route.
    """

    # No user object -> nothing to check here (authentication happens elsewhere).
    if user_object is None:
        return

    passed_organization_id: Optional[str] = request_body.get("organization_id", None)

    if route == "/organization/new":
        if user_object.user_role != LitellmUserRoles.PROXY_ADMIN.value:
            raise ProxyException(
                message=f"Only proxy admins can create new organizations. You are {user_object.user_role}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )

    # Proxy admins bypass all org-level role checks below.
    if user_object.user_role == LitellmUserRoles.PROXY_ADMIN.value:
        return

    # Checks if route is an Org Admin Only Route
    if route in LiteLLMRoutes.org_admin_only_routes.value:
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)

        if user_object.organization_memberships is None:
            raise ProxyException(
                message=f"Tried to access route={route} but you are not a member of any organization. Please contact the proxy admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        if passed_organization_id is None:
            raise ProxyException(
                message="Passed organization_id is None, please pass an organization_id in your request",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        # The role the user holds in the specific org named in the request.
        user_role: Optional[LitellmUserRoles] = _user_organization_role_mapping.get(
            passed_organization_id
        )
        if user_role is None:
            raise ProxyException(
                message=f"You do not have a role within the selected organization. Passed organization_id: {passed_organization_id}. Please contact the organization admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        if user_role != LitellmUserRoles.ORG_ADMIN.value:
            raise ProxyException(
                message=f"You do not have the required role to perform {route} in Organization {passed_organization_id}. Your role is {user_role} in Organization {passed_organization_id}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )
    elif route == "/team/new":
        # if user is part of multiple teams, then they need to specify the organization_id
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)
        # NOTE: this branch requires an organization_id whenever the user has
        # at least one org membership (not only when they have several).
        if (
            user_object.organization_memberships is not None
            and len(user_object.organization_memberships) > 0
        ):
            if passed_organization_id is None:
                raise ProxyException(
                    message=f"Passed organization_id is None, please specify the organization_id in your request. You are part of multiple organizations: {_user_organizations}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="organization_id",
                    code=status.HTTP_401_UNAUTHORIZED,
                )

            _user_role_in_passed_org = _user_organization_role_mapping.get(
                passed_organization_id
            )
            # Only an org admin of the target org may create teams inside it.
            if _user_role_in_passed_org != LitellmUserRoles.ORG_ADMIN.value:
                raise ProxyException(
                    message=f"You do not have the required role to call {route}. Your role is {_user_role_in_passed_org} in Organization {passed_organization_id}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="user_role",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
|
||||
|
||||
|
||||
def get_user_organization_info(
    user_object: LiteLLM_UserTable,
) -> Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]:
    """
    Collect the organizations a user belongs to and their role in each.

    Args:
        user_object (LiteLLM_UserTable): The user object containing organization memberships.

    Returns:
        Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]: A tuple containing:
            - List of organization IDs the user is a member of
            - Dictionary mapping organization IDs to user roles
    """
    org_ids: List[str] = []
    role_by_org: Dict[str, Optional[LitellmUserRoles]] = {}

    # Treat a missing membership list the same as an empty one.
    for membership in user_object.organization_memberships or []:
        org_id = membership.organization_id
        if org_id is None:
            continue  # skip malformed memberships with no org id
        org_ids.append(org_id)
        role_by_org[org_id] = membership.user_role  # type: ignore

    return org_ids, role_by_org
|
||||
|
||||
|
||||
def _user_is_org_admin(
    request_data: dict,
    user_object: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Helper function to check if user is an org admin for any of the passed organizations.

    Checks both:
    - `organization_id` (singular string) — legacy callers
    - `organizations` (list of strings) — used by /user/new
    """
    # No user, or a user with no org memberships, can never be an org admin.
    if user_object is None or user_object.organization_memberships is None:
        return False

    # Collect candidate org IDs from both request fields.
    candidate_ids: List[str] = []
    single_org = request_data.get("organization_id", None)
    if single_org is not None:
        candidate_ids.append(single_org)
    org_list = request_data.get("organizations", None)
    if isinstance(org_list, list):
        candidate_ids.extend(org_list)

    if not candidate_ids:
        return False

    # Admin of any one of the candidate orgs is sufficient.
    return any(
        membership.organization_id in candidate_ids
        and membership.user_role == LitellmUserRoles.ORG_ADMIN.value
        for membership in user_object.organization_memberships
    )
|
||||
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Handles Authentication Errors
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_utils import _get_request_ip_address
|
||||
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class UserAPIKeyAuthExceptionHandler:
    """Centralizes error handling for virtual-key authentication failures."""

    @staticmethod
    async def _handle_authentication_error(
        e: Exception,
        request: Request,
        request_data: dict,
        route: str,
        parent_otel_span: Optional[Span],
        api_key: str,
    ) -> UserAPIKeyAuth:
        """
        Handles Connection Errors when reading a Virtual Key from LiteLLM DB
        Use this if you don't want failed DB queries to block LLM API requests

        Reliability scenarios this covers:
        - DB is down and having an outage
        - Unable to read / recover a key from the DB

        Args:
            e: the exception raised during key lookup / auth.
            request: the incoming request (used for requester-IP logging).
            request_data: parsed request body, forwarded to failure callbacks.
            route: the proxy route being called.
            parent_otel_span: optional OTEL span for trace correlation.
            api_key: the (possibly invalid) api key from the request.

        Returns:
            - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True

        Raises:
            - Original Exception in all other cases
        """
        # Imported lazily to avoid a circular import with proxy_server.
        from litellm.proxy.proxy_server import (
            general_settings,
            litellm_proxy_admin_name,
            proxy_logging_obj,
        )

        if (
            PrismaDBExceptionHandler.should_allow_request_on_db_unavailable()
            and PrismaDBExceptionHandler.is_database_connection_error(e)
        ):
            # log this as a DB failure on prometheus
            proxy_logging_obj.service_logging_obj.service_failure_hook(
                service=ServiceTypes.DB,
                call_type="get_key_object",
                error=e,
                duration=0.0,
            )

            # Fail open: impersonate the proxy admin so traffic keeps flowing
            # while the DB is unavailable.
            return UserAPIKeyAuth(
                key_name="failed-to-connect-to-db",
                token="failed-to-connect-to-db",
                user_id=litellm_proxy_admin_name,
                request_route=route,
            )
        else:
            # raise the exception to the caller
            requester_ip = _get_request_ip_address(
                request=request,
                use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
            )
            verbose_proxy_logger.exception(
                "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format(
                    str(e),
                    requester_ip,
                ),
                extra={"requester_ip": requester_ip},
            )

            # Log this exception to OTEL, Datadog etc
            user_api_key_dict = UserAPIKeyAuth(
                parent_otel_span=parent_otel_span,
                api_key=api_key,
                request_route=route,
            )
            # Allow callbacks to transform the error response
            transformed_exception = await proxy_logging_obj.post_call_failure_hook(
                request_data=request_data,
                original_exception=e,
                user_api_key_dict=user_api_key_dict,
                error_type=ProxyErrorTypes.auth_error,
                route=route,
            )
            # Use transformed exception if callback returned one, otherwise use original
            if transformed_exception is not None:
                e = transformed_exception

            # Re-wrap known exception types so clients get a consistent ProxyException shape.
            if isinstance(e, litellm.BudgetExceededError):
                raise ProxyException(
                    message=e.message,
                    type=ProxyErrorTypes.budget_exceeded,
                    param=None,
                    code=400,
                )
            if isinstance(e, HTTPException):
                raise ProxyException(
                    message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                    type=ProxyErrorTypes.auth_error,
                    param=getattr(e, "param", "None"),
                    code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
                )
            elif isinstance(e, ProxyException):
                raise e
            # Fallback: anything else becomes a generic 401 auth error.
            raise ProxyException(
                message="Authentication Error, " + str(e),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=status.HTTP_401_UNAUTHORIZED,
            )
|
||||
@@ -0,0 +1,835 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from functools import lru_cache
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from litellm import Router, provider_list
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import STANDARD_CUSTOMER_ID_HEADERS
|
||||
from litellm.proxy._types import *
|
||||
from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
|
||||
|
||||
|
||||
def _get_request_ip_address(
    request: Request, use_x_forwarded_for: Optional[bool] = False
) -> Optional[str]:
    """
    Resolve the caller's IP for a request.

    Prefers the x-forwarded-for header when explicitly enabled, then the
    socket peer address, and finally falls back to an empty string.
    """
    # Honor the proxy header only when the deployment opted in.
    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
        return request.headers["x-forwarded-for"]
    if request.client is not None:
        return request.client.host
    return ""
|
||||
|
||||
|
||||
def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """
    Returns (is_allowed, client_ip) for the request against the IP allowlist.
    """
    # No allowlist configured -> every caller is allowed.
    if allowed_ips is None:
        return True, None

    # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
    caller_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )

    # Membership in the allowlist decides the outcome.
    return (caller_ip in allowed_ips), caller_ip
|
||||
|
||||
|
||||
def check_complete_credentials(request_body: dict) -> bool:
    """
    if 'api_base' in request body. Check if complete credentials given. Prevent malicious attacks.

    Returns True only when the caller names a model AND supplies their own
    api_key, excluding providers with complex credential schemes.
    """
    requested_model = request_body.get("model")
    if requested_model is None:
        return False

    # complex credentials - easier to make a malicious request
    complex_credential_providers = (
        "sagemaker",
        "bedrock",
        "vertex_ai",
        "vertex_ai_beta",
    )
    if any(provider in requested_model for provider in complex_credential_providers):
        return False

    return "api_key" in request_body
|
||||
|
||||
|
||||
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """
    Check if request_body_value matches the regex_str or is equal to param
    """
    # Try the regex interpretation first (same evaluation order as before),
    # then fall back to a literal string comparison.
    if re.match(regex_str, request_body_value) is not None:
        return True
    return regex_str == request_body_value
|
||||
|
||||
|
||||
def _is_param_allowed(
    param: str,
    request_body_value: Any,
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
    """
    Check if param is a str or dict and if request_body_value is in the list of allowed values
    """
    if configurable_clientside_auth_params is None:
        return False

    for allowed_entry in configurable_clientside_auth_params:
        # Plain string entries allow a param by exact name.
        if isinstance(allowed_entry, str):
            if allowed_entry == param:
                return True
            continue
        # Dict entries carry an api_base regex/literal to match the value against.
        if isinstance(allowed_entry, Dict) and param == "api_base":
            if check_regex_or_str_match(
                request_body_value=request_body_value,
                regex_str=allowed_entry["api_base"],
            ):  # assume param is a regex
                return True

    return False
|
||||
|
||||
|
||||
def _allow_model_level_clientside_configurable_parameters(
    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
    """
    Check if model is allowed to use configurable client-side params
    - get matching model group info from the router
    - fall back to the provider-wildcard entry (e.g. "openai" for "openai/gpt-4o")
    - check if 'configurable_clientside_auth_params' permits (param, value)

    Returns:
        bool: True if the (param, value) pair is allowed for this model,
        False when the router is unset or no matching config exists.
    """
    if llm_router is None:
        return False

    # check if model is set
    model_info = llm_router.get_model_group_info(model_group=model)
    if model_info is None:
        # check if wildcard model is set (provider prefix, e.g. "openai")
        provider_prefix = model.split("/", 1)[0]
        if provider_prefix in provider_list:
            model_info = llm_router.get_model_group_info(model_group=provider_prefix)

    # Fix: the original duplicated the `model_info is None` check (once inside
    # the wildcard branch, once here); a single combined guard is equivalent.
    if model_info is None or model_info.configurable_clientside_auth_params is None:
        return False

    return _is_param_allowed(
        param=param,
        request_body_value=request_body_value,
        configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
    )
|
||||
|
||||
|
||||
def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """
    Check if the request body is safe.

    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997

    Returns:
        True when no banned param is present, when the caller supplied complete
        client-side credentials, or when client-side credentials are explicitly
        allowed (globally or per-model).

    Raises:
        ValueError: when a banned param is present and not allowed.
    """
    # Params that could redirect traffic (and the proxy's credentials) elsewhere.
    banned_params = ["api_base", "base_url"]

    for param in banned_params:
        if (
            param in request_body
            and not check_complete_credentials(  # allow client-credentials to be passed to proxy
                request_body=request_body
            )
        ):
            # NOTE: both allow-paths return immediately — remaining banned
            # params are not inspected once one is permitted.
            if general_settings.get("allow_client_side_credentials") is True:
                return True
            elif (
                _allow_model_level_clientside_configurable_parameters(
                    model=model,
                    param=param,
                    request_body_value=request_body[param],
                    llm_router=llm_router,
                )
                is True
            ):
                return True
            raise ValueError(
                f"Rejected Request: {param} is not allowed in request body. "
                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
            )

    return True
|
||||
|
||||
|
||||
async def pre_db_read_auth_checks(
    request: Request,
    request_data: dict,
    route: str,
):
    """
    1. Checks if request size is under max_request_size_mb (if set)
    2. Check if request body is safe (example user has not set api_base in request body)
    3. Check if IP address is allowed (if set)
    4. Check if request route is an allowed route on the proxy (if set)

    Returns:
        - True

    Raises:
        - HTTPException if request fails initial auth checks
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

    # Check 1. request size
    await check_if_request_size_is_safe(request=request)

    # Check 2. Request body is safe (raises ValueError on banned params)
    is_request_body_safe(
        request_body=request_data,
        general_settings=general_settings,
        llm_router=llm_router,
        model=request_data.get(
            "model", ""
        ),  # [TODO] use model passed in url as well (azure openai routes)
    )

    # Check 3. Check if IP address is allowed
    is_valid_ip, passed_in_ip = _check_valid_ip(
        allowed_ips=general_settings.get("allowed_ips", None),
        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        request=request,
    )

    if not is_valid_ip:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
        )

    # Check 4. Check if request route is an allowed route on the proxy
    if "allowed_routes" in general_settings:
        _allowed_routes = general_settings["allowed_routes"]
        if premium_user is not True:
            # NOTE: the enterprise gate only logs — the allowlist below is
            # still enforced for non-premium deployments.
            verbose_proxy_logger.error(
                f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
            )
        if route not in _allowed_routes:
            verbose_proxy_logger.error(
                f"Route {route} not in allowed_routes={_allowed_routes}"
            )
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access forbidden: Route {route} not allowed",
            )
|
||||
|
||||
|
||||
def route_in_additonal_public_routes(current_route: str):
    """
    Helper to check if the user defined public_routes on config.yaml

    NOTE: the misspelling "additonal" in the name is preserved — callers
    import this function by this exact name.

    Parameters:
    - current_route: str - the route the user is trying to call

    Returns:
    - bool - True if the route is defined in public_routes
    - bool - False if the route is not defined in public_routes

    Supports wildcard patterns (e.g., "/api/*" matches "/api/users", "/api/users/123")

    In order to use this the litellm config.yaml should have the following in general_settings:

    ```yaml
    general_settings:
        master_key: sk-1234
        public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate", "/api/*"]
    ```
    """
    # Imported lazily to avoid circular imports with proxy_server.
    from litellm.proxy.auth.route_checks import RouteChecks
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        # Custom public routes are an enterprise-only feature.
        if premium_user is not True:
            return False
        if general_settings is None:
            return False

        routes_defined = general_settings.get("public_routes", [])

        # Check exact match first
        if current_route in routes_defined:
            return True

        # Check wildcard patterns
        for route_pattern in routes_defined:
            if RouteChecks._route_matches_wildcard_pattern(
                route=current_route, pattern=route_pattern
            ):
                return True

        return False
    except Exception as e:
        # Fail closed: any error means the route is treated as NOT public.
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False
|
||||
|
||||
|
||||
def get_request_route(request: Request) -> str:
    """
    Helper to get the route from the request

    remove base url from path if set e.g. `/genai/chat/completions` -> `/chat/completions
    """
    try:
        full_path = request.url.path
        mount_prefix = request.base_url.path if hasattr(request, "base_url") else None
        if mount_prefix is not None and full_path.startswith(mount_prefix):
            # Strip the mount prefix; `- 1` keeps the leading slash.
            return full_path[len(mount_prefix) - 1 :]
        return full_path
    except Exception as e:
        verbose_proxy_logger.debug(
            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
        )
        return request.url.path
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
def normalize_request_route(route: str) -> str:
    """
    Normalize request routes by replacing dynamic path parameters with placeholders.

    This prevents high cardinality in Prometheus metrics by collapsing routes like:
    - /v1/responses/1234567890 -> /v1/responses/{response_id}
    - /v1/threads/thread_123 -> /v1/threads/{thread_id}

    Args:
        route: The request route path

    Returns:
        Normalized route with dynamic parameters replaced by placeholders

    Examples:
        >>> normalize_request_route("/v1/responses/abc123")
        '/v1/responses/{response_id}'
        >>> normalize_request_route("/v1/responses/abc123/cancel")
        '/v1/responses/{response_id}/cancel'
        >>> normalize_request_route("/chat/completions")
        '/chat/completions'
    """
    # Define patterns for routes with dynamic IDs
    # Format: (regex_pattern, replacement_template)
    # Ordering matters: more specific patterns (with suffixes like /cancel)
    # must precede the generic single-ID patterns so re.sub hits them first.
    patterns = [
        # Responses API - must come before generic patterns
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)$", r"\1/{response_id}"),
        (r"^(/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)$", r"\1/{response_id}"),
        # Threads API
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}\5/{step_id}",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/cancel)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/submit_tool_outputs)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)$", r"\1/{thread_id}\3"),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)/([^/]+)$",
            r"\1/{thread_id}\3/{message_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)$", r"\1/{thread_id}\3"),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)$", r"\1/{thread_id}"),
        # Vector Stores API
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)/([^/]+)$",
            r"\1/{vector_store_id}\3/{file_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)$",
            r"\1/{vector_store_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)/([^/]+)$",
            r"\1/{vector_store_id}\3/{batch_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)$",
            r"\1/{vector_store_id}\3",
        ),
        (r"^(/(?:openai/)?v1/vector_stores)/([^/]+)$", r"\1/{vector_store_id}"),
        # Assistants API
        (r"^(/(?:openai/)?v1/assistants)/([^/]+)$", r"\1/{assistant_id}"),
        # Files API
        (r"^(/(?:openai/)?v1/files)/([^/]+)(/content)$", r"\1/{file_id}\3"),
        (r"^(/(?:openai/)?v1/files)/([^/]+)$", r"\1/{file_id}"),
        # Batches API
        (r"^(/(?:openai/)?v1/batches)/([^/]+)(/cancel)$", r"\1/{batch_id}\3"),
        (r"^(/(?:openai/)?v1/batches)/([^/]+)$", r"\1/{batch_id}"),
        # Fine-tuning API
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/events)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/cancel)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/checkpoints)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)$", r"\1/{fine_tuning_job_id}"),
        # Models API
        (r"^(/(?:openai/)?v1/models)/([^/]+)$", r"\1/{model}"),
    ]

    # Apply patterns in order; first pattern that changes the route wins.
    for pattern, replacement in patterns:
        normalized = re.sub(pattern, replacement, route)
        if normalized != route:
            return normalized

    # Return original route if no pattern matched
    return route
|
||||
|
||||
|
||||
async def check_if_request_size_is_safe(request: Request) -> bool:
    """
    Enterprise Only:
    - Checks if the request size is within the limit

    Args:
        request (Request): The incoming request.

    Returns:
        bool: True if the request size is within the limit

    Raises:
        ProxyException: If the request size is too large

    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_request_size_mb = general_settings.get("max_request_size_mb", None)

    if max_request_size_mb is not None:
        # Check if premium user
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True

        # Get the request body
        content_length = request.headers.get("content-length")

        if content_length:
            # Fast path: trust the Content-Length header instead of reading the body.
            header_size = int(content_length)
            header_size_mb = bytes_to_mb(bytes_value=header_size)
            verbose_proxy_logger.debug(
                f"content_length request size in MB={header_size_mb}"
            )

            if header_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {header_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )
        else:
            # If Content-Length is not available, read the body
            body = await request.body()
            body_size = len(body)
            request_size_mb = bytes_to_mb(bytes_value=body_size)

            verbose_proxy_logger.debug(
                f"request body request size in MB={request_size_mb}"
            )
            if request_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )

    return True
|
||||
|
||||
|
||||
async def check_response_size_is_safe(response: Any) -> bool:
    """
    Enterprise Only:
    - Checks if the response size is within the limit

    Args:
        response (Any): The response to check.

    Returns:
        bool: True if the response size is within the limit

    Raises:
        ProxyException: If the response size is too large

    """

    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_response_size_mb = general_settings.get("max_response_size_mb", None)
    if max_response_size_mb is not None:
        # Check if premium user
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True

        # NOTE: sys.getsizeof is a shallow measure — bytes held in nested
        # objects are not counted, so this is an approximation of payload size.
        response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
        verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
        if response_size_mb > max_response_size_mb:
            raise ProxyException(
                message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
                type=ProxyErrorTypes.bad_request_error.value,
                code=400,
                param="content-length",
            )

    return True
|
||||
|
||||
|
||||
def bytes_to_mb(bytes_value: int):
    """Convert a byte count to megabytes (as a float)."""
    bytes_per_mb = 1024 * 1024
    return bytes_value / bytes_per_mb
|
||||
|
||||
|
||||
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
|
||||
def get_key_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the model rpm limit for a given api key.

    Priority order (returns first found):
    1. Key metadata (model_rpm_limit)
    2. Key model_max_budget (rpm_limit per model)
    3. Team metadata (model_rpm_limit)
    """
    # 1. Key metadata takes priority.
    key_metadata = user_api_key_dict.metadata
    if key_metadata and key_metadata.get("model_rpm_limit"):
        return key_metadata.get("model_rpm_limit")

    # 2. Derive per-model limits from model_max_budget entries.
    budgets = user_api_key_dict.model_max_budget
    if budgets:
        per_model_rpm: Dict[str, Any] = {
            model_name: spec["rpm_limit"]
            for model_name, spec in budgets.items()
            if isinstance(spec, dict) and spec.get("rpm_limit") is not None
        }
        if per_model_rpm:
            return per_model_rpm

    # 3. Fall back to team metadata.
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata:
        return team_metadata.get("model_rpm_limit")

    return None
|
||||
|
||||
|
||||
def get_key_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the model tpm limit for a given api key.

    Priority order (returns first found):
    1. Key metadata (model_tpm_limit)
    2. Key model_max_budget (tpm_limit per model)
    3. Team metadata (model_tpm_limit)
    """
    # 1. Key metadata takes priority.
    key_metadata = user_api_key_dict.metadata
    if key_metadata and key_metadata.get("model_tpm_limit"):
        return key_metadata.get("model_tpm_limit")

    # 2. Derive per-model limits from model_max_budget entries.
    budgets = user_api_key_dict.model_max_budget
    if budgets:
        per_model_tpm: Dict[str, Any] = {
            model_name: spec["tpm_limit"]
            for model_name, spec in budgets.items()
            if isinstance(spec, dict) and spec.get("tpm_limit") is not None
        }
        if per_model_tpm:
            return per_model_tpm

    # 3. Fall back to team metadata.
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata:
        return team_metadata.get("model_tpm_limit")

    return None
|
||||
|
||||
|
||||
def get_model_rate_limit_from_metadata(
    user_api_key_dict: UserAPIKeyAuth,
    metadata_accessor_key: Literal["team_metadata", "organization_metadata"],
    rate_limit_key: Literal["model_rpm_limit", "model_tpm_limit"],
) -> Optional[Dict[str, int]]:
    """Read a per-model rate-limit mapping from team/org metadata, if present."""
    metadata = getattr(user_api_key_dict, metadata_accessor_key)
    return metadata.get(rate_limit_key) if metadata else None
|
||||
|
||||
|
||||
def get_team_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """Return the team-level per-model RPM limit mapping, if configured."""
    team_metadata = user_api_key_dict.team_metadata
    return team_metadata.get("model_rpm_limit") if team_metadata else None
|
||||
|
||||
|
||||
def get_team_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """Return the team-level per-model TPM limit mapping, if configured."""
    team_metadata = user_api_key_dict.team_metadata
    return team_metadata.get("model_tpm_limit") if team_metadata else None
|
||||
|
||||
|
||||
def is_pass_through_provider_route(route: str) -> bool:
    """
    Return True when the route targets a provider-specific pass-through
    endpoint (currently only "vertex-ai" prefixed routes).
    """
    provider_prefixes = [
        "vertex-ai",
    ]
    # A route matches if it contains any known provider prefix.
    return any(prefix in route for prefix in provider_prefixes)
|
||||
|
||||
|
||||
def _has_user_setup_sso():
|
||||
"""
|
||||
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID or generic client ID and UI username environment variables.
|
||||
Returns a boolean indicating whether SSO has been set up.
|
||||
"""
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
|
||||
sso_setup = (
|
||||
(microsoft_client_id is not None)
|
||||
or (google_client_id is not None)
|
||||
or (generic_client_id is not None)
|
||||
)
|
||||
|
||||
return sso_setup
|
||||
|
||||
|
||||
def get_customer_user_header_from_mapping(user_id_mapping) -> Optional[str]:
    """Return the header_name mapped to the CUSTOMER role, if any (dict-based)."""
    if not user_id_mapping:
        return None
    # Accept either a single mapping dict or a list of them.
    mappings = (
        user_id_mapping if isinstance(user_id_mapping, list) else [user_id_mapping]
    )
    for mapping in mappings:
        if not isinstance(mapping, dict):
            continue
        role = mapping.get("litellm_user_role")
        header_name = mapping.get("header_name")
        if role is None or not header_name:
            continue
        # Case-insensitive match against the CUSTOMER role.
        if str(role).lower() == str(LitellmUserRoles.CUSTOMER).lower():
            return header_name
    return None
|
||||
|
||||
|
||||
def _get_customer_id_from_standard_headers(
|
||||
request_headers: Optional[dict],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Check standard customer ID headers for a customer/end-user ID.
|
||||
|
||||
This enables tools like Claude Code to pass customer IDs via ANTHROPIC_CUSTOM_HEADERS.
|
||||
No configuration required - these headers are always checked.
|
||||
|
||||
Args:
|
||||
request_headers: The request headers dict
|
||||
|
||||
Returns:
|
||||
The customer ID if found in standard headers, None otherwise
|
||||
"""
|
||||
if request_headers is None:
|
||||
return None
|
||||
|
||||
for standard_header in STANDARD_CUSTOMER_ID_HEADERS:
|
||||
for header_name, header_value in request_headers.items():
|
||||
if header_name.lower() == standard_header.lower():
|
||||
user_id_str = str(header_value) if header_value is not None else ""
|
||||
if user_id_str.strip():
|
||||
return user_id_str
|
||||
return None
|
||||
|
||||
|
||||
def get_end_user_id_from_request_body(
    request_body: dict, request_headers: Optional[dict] = None
) -> Optional[str]:
    """
    Resolve the end-user / customer ID for a request.

    Sources are checked in priority order and the first non-empty match wins:
    1. Standard customer-ID headers (always checked)
    2. A configured custom header (`user_header_mappings`, falling back to the
       deprecated `user_header_name` general setting)
    3. `user` field in the request body (OpenAI convention)
    4. `litellm_metadata.user` in the request body (Anthropic convention)
    5. `metadata.user_id` in the request body
    6. `safety_identifier` in the request body (OpenAI Responses API)

    Returns the ID as a string, or None when no source matched.
    """
    # Import general_settings here to avoid potential circular import issues at module level
    # and to ensure it's fetched at runtime.
    from litellm.proxy.proxy_server import general_settings

    # Check 1: Standard customer ID headers (always checked, no configuration required)
    customer_id = _get_customer_id_from_standard_headers(
        request_headers=request_headers
    )
    if customer_id is not None:
        return customer_id

    # Check 2: Follow the user header mappings feature, if not found, then check for deprecated user_header_name (only if request_headers is provided)
    if request_headers is not None:
        custom_header_name_to_check: Optional[str] = None

        # Prefer user mappings (new behavior)
        user_id_mapping = general_settings.get("user_header_mappings", None)
        if user_id_mapping:
            custom_header_name_to_check = get_customer_user_header_from_mapping(
                user_id_mapping
            )

        # Fallback to deprecated user_header_name if mapping did not specify
        if not custom_header_name_to_check:
            user_id_header_config_key = "user_header_name"
            value = general_settings.get(user_id_header_config_key)
            # Only accept a non-blank string as a header name.
            if isinstance(value, str) and value.strip() != "":
                custom_header_name_to_check = value

        # If we have a header name to check, try to read it from request headers
        # (header-name comparison is case-insensitive; blank values are skipped).
        if isinstance(custom_header_name_to_check, str):
            for header_name, header_value in request_headers.items():
                if header_name.lower() == custom_header_name_to_check.lower():
                    user_id_from_header = header_value
                    user_id_str = (
                        str(user_id_from_header)
                        if user_id_from_header is not None
                        else ""
                    )
                    if user_id_str.strip():
                        return user_id_str

    # Check 3: 'user' field in request_body (commonly OpenAI)
    if "user" in request_body and request_body["user"] is not None:
        user_from_body_user_field = request_body["user"]
        return str(user_from_body_user_field)

    # Check 4: 'litellm_metadata.user' in request_body (commonly Anthropic)
    litellm_metadata = request_body.get("litellm_metadata")
    if isinstance(litellm_metadata, dict):
        user_from_litellm_metadata = litellm_metadata.get("user")
        if user_from_litellm_metadata is not None:
            return str(user_from_litellm_metadata)

    # Check 5: 'metadata.user_id' in request_body (another common pattern)
    metadata_dict = request_body.get("metadata")
    if isinstance(metadata_dict, dict):
        user_id_from_metadata_field = metadata_dict.get("user_id")
        if user_id_from_metadata_field is not None:
            return str(user_id_from_metadata_field)

    # Check 6: 'safety_identifier' in request body (OpenAI Responses API parameter)
    # SECURITY NOTE: safety_identifier can be set by any caller in the request body.
    # Only use this for end-user identification in trusted environments where you control
    # the calling application. For untrusted callers, prefer using headers or server-side
    # middleware to set the end_user_id to prevent impersonation.
    if request_body.get("safety_identifier") is not None:
        user_from_body_user_field = request_body["safety_identifier"]
        return str(user_from_body_user_field)

    return None
|
||||
|
||||
|
||||
def get_model_from_request(
    request_data: dict, route: str
) -> Optional[Union[str, List[str]]]:
    """
    Determine the requested model from the request body or, failing that,
    from the URL path.

    Returns a single model name, a list of names when the body holds a
    comma-separated string, or None when no model can be determined.
    """
    # Prefer the request body ("model", then "target_model_names").
    model = request_data.get("model") or request_data.get("target_model_names")
    if model is not None:
        segments = [part.strip() for part in model.split(",")]
        return segments[0] if len(segments) == 1 else segments

    route_lower = route.lower()

    # Azure-style deployment routes: /openai/deployments/{model}/*
    deployment_match = re.match(r"/openai/deployments/([^/]+)", route)
    if deployment_match:
        return deployment_match.group(1)

    if not route_lower.startswith("/vertex"):
        # Google generateContent-style routes embed the model in the path and
        # allow "/" inside the model id, e.g.:
        # - /v1beta/models/gemini-2.0-flash:generateContent
        # - /v1beta/models/bedrock/claude-sonnet-3.7:generateContent
        # - /models/custom/ns/model:streamGenerateContent
        for pattern in (r"/(?:v1beta|beta)/models/([^:]+):", r"^/models/([^:]+):"):
            google_match = re.search(pattern, route)
            if google_match:
                return google_match.group(1)
    else:
        # Vertex AI passthrough: /vertex_ai/.../models/{model_id}:*
        vertex_match = re.search(r"/models/([^:]+)", route)
        if vertex_match:
            return vertex_match.group(1)

    return None
|
||||
|
||||
|
||||
def abbreviate_api_key(api_key: str) -> str:
    """Return a redacted display form of an API key (last 4 characters only)."""
    visible_suffix = api_key[-4:]
    return "sk-...{}".format(visible_suffix)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
IP address utilities for MCP public/private access control.
|
||||
|
||||
Internal callers (private IPs) see all MCP servers.
|
||||
External callers (public IPs) only see servers with available_on_public_internet=True.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.auth_utils import _get_request_ip_address
|
||||
|
||||
|
||||
class IPAddressUtils:
    """
    Static utilities for IP-based MCP access control.

    Internal callers (private IPs) see all MCP servers; external callers
    (public IPs) only see servers with available_on_public_internet=True.
    """

    # Default "internal" networks: RFC 1918 private IPv4 ranges, IPv4
    # loopback, IPv6 loopback, and IPv6 unique-local addresses.
    _DEFAULT_INTERNAL_NETWORKS = [
        ipaddress.ip_network("10.0.0.0/8"),
        ipaddress.ip_network("172.16.0.0/12"),
        ipaddress.ip_network("192.168.0.0/16"),
        ipaddress.ip_network("127.0.0.0/8"),
        ipaddress.ip_network("::1/128"),
        ipaddress.ip_network("fc00::/7"),
    ]

    @staticmethod
    def parse_internal_networks(
        configured_ranges: Optional[List[str]],
    ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
        """
        Parse configured CIDR ranges into network objects, falling back to defaults.

        Invalid CIDR strings are logged and skipped; if nothing valid remains,
        the built-in defaults are returned.
        """
        if not configured_ranges:
            return IPAddressUtils._DEFAULT_INTERNAL_NETWORKS
        networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = []
        for cidr in configured_ranges:
            try:
                # strict=False tolerates host bits set in the configured CIDR.
                networks.append(ipaddress.ip_network(cidr, strict=False))
            except ValueError:
                verbose_proxy_logger.warning(
                    "Invalid CIDR in mcp_internal_ip_ranges: %s, skipping", cidr
                )
        return networks if networks else IPAddressUtils._DEFAULT_INTERNAL_NETWORKS

    @staticmethod
    def parse_trusted_proxy_networks(
        configured_ranges: Optional[List[str]],
    ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
        """
        Parse trusted proxy CIDR ranges for XFF validation.
        Returns empty list if not configured (XFF will not be trusted).
        """
        if not configured_ranges:
            return []
        networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = []
        for cidr in configured_ranges:
            try:
                networks.append(ipaddress.ip_network(cidr, strict=False))
            except ValueError:
                verbose_proxy_logger.warning(
                    "Invalid CIDR in mcp_trusted_proxy_ranges: %s, skipping", cidr
                )
        # Unlike parse_internal_networks, no fallback here: an all-invalid
        # config yields an empty list, i.e. no proxy is trusted.
        return networks

    @staticmethod
    def is_trusted_proxy(
        proxy_ip: Optional[str],
        trusted_networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]],
    ) -> bool:
        """Check if the direct connection IP is from a trusted proxy."""
        if not proxy_ip or not trusted_networks:
            return False
        try:
            addr = ipaddress.ip_address(proxy_ip.strip())
            return any(addr in network for network in trusted_networks)
        except ValueError:
            # Unparseable IP -> fail closed (not trusted).
            return False

    @staticmethod
    def is_internal_ip(
        client_ip: Optional[str],
        internal_networks: Optional[
            List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]
        ] = None,
    ) -> bool:
        """
        Check if a client IP is from an internal/private network.

        Handles X-Forwarded-For comma chains (takes leftmost = original client).
        Fails closed: empty/invalid IPs are treated as external.
        """
        if not client_ip:
            return False

        # X-Forwarded-For may contain comma-separated chain; leftmost is original client
        if "," in client_ip:
            client_ip = client_ip.split(",")[0].strip()

        # NOTE: a falsy (empty-list) internal_networks also falls back to the
        # defaults here, due to the `or` check.
        networks = internal_networks or IPAddressUtils._DEFAULT_INTERNAL_NETWORKS

        try:
            addr = ipaddress.ip_address(client_ip.strip())
        except ValueError:
            return False

        return any(addr in network for network in networks)

    @staticmethod
    def get_mcp_client_ip(
        request: Request,
        general_settings: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Extract client IP from a FastAPI request for MCP access control.

        Security: Only trusts X-Forwarded-For if:
        1. use_x_forwarded_for is enabled in settings
        2. The direct connection is from a trusted proxy (if mcp_trusted_proxy_ranges configured)

        Args:
            request: FastAPI request object
            general_settings: Optional settings dict. If not provided, imports from proxy_server.
        """
        if general_settings is None:
            try:
                # Late import avoids a circular import with proxy_server.
                from litellm.proxy.proxy_server import (
                    general_settings as proxy_general_settings,
                )

                general_settings = proxy_general_settings
            except ImportError:
                general_settings = {}

        # Handle case where general_settings is still None after import
        if general_settings is None:
            general_settings = {}

        use_xff = general_settings.get("use_x_forwarded_for", False)

        # If XFF is enabled, validate the request comes from a trusted proxy
        if use_xff and "x-forwarded-for" in request.headers:
            trusted_ranges = general_settings.get("mcp_trusted_proxy_ranges")
            if trusted_ranges:
                # Validate direct connection is from trusted proxy
                direct_ip = request.client.host if request.client else None
                trusted_networks = IPAddressUtils.parse_trusted_proxy_networks(
                    trusted_ranges
                )
                if not IPAddressUtils.is_trusted_proxy(direct_ip, trusted_networks):
                    # Untrusted source trying to set XFF - ignore XFF, use direct IP
                    verbose_proxy_logger.warning(
                        "XFF header from untrusted IP %s, ignoring", direct_ip
                    )
                    return direct_ip
        # Default path: shared helper applies (or ignores) XFF per use_xff.
        return _get_request_ip_address(request, use_x_forwarded_for=use_xff)
|
||||
@@ -0,0 +1,214 @@
|
||||
# What is this?
|
||||
## If litellm license in env, checks if it's valid
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import NON_LLM_CONNECTION_TIMEOUT
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy._types import EnterpriseLicenseData
|
||||
|
||||
|
||||
class LicenseCheck:
    """
    Validates the LiteLLM enterprise license.

    - Check if license in env (LITELLM_LICENSE)
    - Returns if license is valid

    Verification order (see is_premium):
    1. Offline signature check against the bundled public key (air-gapped).
    2. Fallback: hosted license API at `base_url`.
    """

    base_url = "https://license.litellm.ai"

    def __init__(self) -> None:
        # License key, if provided; re-read lazily in is_premium() in case the
        # env var is set after startup.
        self.license_str = os.getenv("LITELLM_LICENSE", None)
        verbose_proxy_logger.debug("License Str value - {}".format(self.license_str))
        self.http_handler = HTTPHandler(timeout=NON_LLM_CONNECTION_TIMEOUT)
        # Log the premium-check details only once to avoid log spam.
        self._premium_check_logged = False
        self.public_key = None
        self.read_public_key()
        # Populated by verify_license_without_api_request() on a successful
        # offline verification; consumed by the usage-limit checks below.
        self.airgapped_license_data: Optional["EnterpriseLicenseData"] = None

    def read_public_key(self):
        """Load the bundled PEM public key used for offline license verification."""
        try:
            from cryptography.hazmat.primitives import serialization

            # public_key.pem is expected to sit next to this module, if shipped.
            current_dir = os.path.dirname(os.path.realpath(__file__))

            _path_to_public_key = os.path.join(current_dir, "public_key.pem")
            if os.path.exists(_path_to_public_key):
                with open(_path_to_public_key, "rb") as key_file:
                    self.public_key = serialization.load_pem_public_key(key_file.read())
            else:
                self.public_key = None
        except Exception as e:
            verbose_proxy_logger.error(f"Error reading public key: {str(e)}")

    def _verify(self, license_str: str) -> bool:
        """
        Verify a license against the hosted license API.

        Retries HTTP-status failures up to 3 times; returns False on any
        failure so license problems never break the proxy.
        """
        verbose_proxy_logger.debug(
            "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
                self.base_url, license_str
            )
        )
        url = "{}/verify_license/{}".format(self.base_url, license_str)

        response: Optional[httpx.Response] = None
        try:  # don't impact user, if call fails
            num_retries = 3
            for i in range(num_retries):
                try:
                    response = self.http_handler.get(url=url)
                    if response is None:
                        raise Exception("No response from license server")
                    response.raise_for_status()
                    # Bug fix: stop retrying once a request succeeds; the loop
                    # previously re-issued the request even after success.
                    break
                except httpx.HTTPStatusError:
                    if i == num_retries - 1:
                        raise

            if response is None:
                raise Exception("No response from license server")

            response_json = response.json()

            premium = response_json["verify"]

            assert isinstance(premium, bool)

            verbose_proxy_logger.debug(
                "litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
                    license_str, premium
                )
            )
            return premium
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
                    license_str, str(e)
                )
            )
            return False

    def is_premium(self) -> bool:
        """
        1. verify_license_without_api_request: checks if license was generate using private / public key pair
        2. _verify: checks if license is valid calling litellm API. This is the old way we were generating/validating license
        """
        try:
            if not self._premium_check_logged:
                verbose_proxy_logger.debug(
                    "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
                        self.license_str
                    )
                )

            # Re-read the env var in case the license was set after __init__.
            if self.license_str is None:
                self.license_str = os.getenv("LITELLM_LICENSE", None)

            if not self._premium_check_logged:
                verbose_proxy_logger.debug(
                    "litellm.proxy.auth.litellm_license.py::is_premium() - Updated 'self.license_str' - {}".format(
                        self.license_str
                    )
                )
                self._premium_check_logged = True

            if self.license_str is None:
                return False
            elif (
                self.verify_license_without_api_request(
                    public_key=self.public_key, license_key=self.license_str
                )
                is True
            ):
                return True
            elif self._verify(license_str=self.license_str) is True:
                return True
            return False
        except Exception:
            # Any unexpected failure is treated as "not premium".
            return False

    def is_over_limit(self, total_users: int) -> bool:
        """
        Return True when `total_users` exceeds the license's `max_users`.

        Only enforced for air-gapped licenses that carry a valid integer
        `max_users`; otherwise returns False (no limit applied).
        """
        if self.airgapped_license_data is None:
            return False
        if "max_users" not in self.airgapped_license_data or not isinstance(
            self.airgapped_license_data["max_users"], int
        ):
            return False
        return total_users > self.airgapped_license_data["max_users"]

    def is_team_count_over_limit(self, team_count: int) -> bool:
        """
        Return True when `team_count` exceeds the license's `max_teams`.

        Only enforced for air-gapped licenses that carry a valid integer
        `max_teams`; otherwise returns False (no limit applied).
        """
        if self.airgapped_license_data is None:
            return False

        _max_teams_in_license: Optional[int] = self.airgapped_license_data.get(
            "max_teams"
        )
        if "max_teams" not in self.airgapped_license_data or not isinstance(
            _max_teams_in_license, int
        ):
            return False
        return team_count > _max_teams_in_license

    def verify_license_without_api_request(self, public_key, license_key):
        """
        Offline (air-gapped) license verification.

        Decodes the base64 license, checks its RSA-PSS/SHA-256 signature with
        `public_key`, stores the decoded payload on
        `self.airgapped_license_data`, and validates the expiration date.
        Returns True when valid; False on any failure.
        """
        try:
            from cryptography.hazmat.primitives import hashes
            from cryptography.hazmat.primitives.asymmetric import padding

            from litellm.proxy._types import EnterpriseLicenseData

            # Decode the license key - add padding if needed for base64
            # Base64 strings need to be a multiple of 4 characters
            padding_needed = len(license_key) % 4
            if padding_needed:
                license_key += "=" * (4 - padding_needed)

            decoded = base64.b64decode(license_key)
            # Payload format: <json-message> "." <signature>
            message, signature = decoded.split(b".", 1)

            # Verify the signature
            public_key.verify(
                signature,
                message,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH,
                ),
                hashes.SHA256(),
            )

            # Decode and parse the data
            license_data = json.loads(message.decode())

            self.airgapped_license_data = EnterpriseLicenseData(**license_data)

            # debug information provided in license data
            verbose_proxy_logger.debug("License data: %s", license_data)

            # Check expiration date
            expiration_date = datetime.strptime(
                license_data["expiration_date"], "%Y-%m-%d"
            )
            if expiration_date < datetime.now():
                # Bug fix: previously returned a (False, str) tuple, which is
                # a truthy value and inconsistent with the bool contract.
                verbose_proxy_logger.debug("License has expired")
                return False

            return True

        except Exception as e:
            verbose_proxy_logger.debug(
                "litellm.proxy.auth.litellm_license.py::verify_license_without_api_request - Unable to verify License locally. - {}".format(
                    str(e)
                )
            )
            return False
|
||||
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
Login utilities for handling user authentication in the proxy server.
|
||||
|
||||
This module contains the core login logic that can be reused across different
|
||||
login endpoints (e.g., /login and /v2/login).
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from typing import Literal, Optional, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.constants import LITELLM_PROXY_ADMIN_NAME, LITELLM_UI_SESSION_DURATION
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
UpdateUserRequest,
|
||||
UserAPIKeyAuth,
|
||||
hash_token,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
generate_key_helper_fn,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.ui_sso import (
|
||||
get_disabled_non_admin_personal_key_creation,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, get_server_root_path
|
||||
from litellm.secret_managers.main import get_secret_bool
|
||||
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
|
||||
|
||||
|
||||
def get_ui_credentials(master_key: Optional[str]) -> tuple[str, str]:
    """
    Resolve the Admin UI username/password.

    The username comes from UI_USERNAME (default "admin"); the password from
    UI_PASSWORD, falling back to the proxy master key.

    Args:
        master_key: Master key for the proxy (used as fallback for password)

    Returns:
        tuple[str, str]: (ui_username, ui_password)

    Raises:
        ProxyException: If neither UI_PASSWORD nor master_key is available
    """
    ui_username = os.getenv("UI_USERNAME", "admin")
    ui_password = os.getenv("UI_PASSWORD", None)
    # Fall back to the master key when no explicit UI password is configured.
    if ui_password is None and master_key is not None:
        ui_password = str(master_key)
    if ui_password is None:
        raise ProxyException(
            message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="UI_PASSWORD",
            code=500,
        )
    return ui_username, ui_password
|
||||
|
||||
|
||||
class LoginResult:
    """
    Authentication payload produced by a successful login.

    Carries the resolved user identity, the session key minted for the UI,
    and how the user logged in.
    """

    # Resolved user identifier.
    user_id: str
    # API key (or JWT) minted for the UI session.
    key: str
    # Email, when known (None for UI_USERNAME/UI_PASSWORD admin logins).
    user_email: Optional[str]
    # Role string (e.g. proxy_admin, internal_user).
    user_role: str
    # How the user authenticated.
    login_method: Literal["sso", "username_password"]

    def __init__(
        self,
        user_id: str,
        key: str,
        user_email: Optional[str],
        user_role: str,
        login_method: Literal["sso", "username_password"] = "username_password",
    ):
        self.user_id = user_id
        self.key = key
        self.user_email = user_email
        self.user_role = user_role
        self.login_method = login_method
|
||||
|
||||
|
||||
async def authenticate_user(  # noqa: PLR0915
    username: str,
    password: str,
    master_key: Optional[str],
    prisma_client: Optional[PrismaClient],
) -> LoginResult:
    """
    Authenticate a user and generate an API key for UI access.

    This function handles two login scenarios:
    1. Admin login using UI_USERNAME and UI_PASSWORD
    2. User login using email and password from database

    Args:
        username: Username or email from the login form
        password: Password from the login form
        master_key: Master key for the proxy (required)
        prisma_client: Prisma database client (optional)

    Returns:
        LoginResult: Object containing authentication data

    Raises:
        ProxyException: If authentication fails or required configuration is missing
    """
    if master_key is None:
        raise ProxyException(
            message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="master_key",
            code=500,
        )

    ui_username, ui_password = get_ui_credentials(master_key)

    # Check if we can find the `username` in the db. On the UI, users can enter username=their email
    _user_row: Optional[LiteLLM_UserTable] = None
    user_role: Optional[
        Literal[
            LitellmUserRoles.PROXY_ADMIN,
            LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
            LitellmUserRoles.INTERNAL_USER,
            LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
        ]
    ] = None

    if prisma_client is not None:
        # Case-insensitive email lookup so "User@X.com" matches "user@x.com".
        _user_row = cast(
            Optional[LiteLLM_UserTable],
            await prisma_client.db.litellm_usertable.find_first(
                where={"user_email": {"equals": username, "mode": "insensitive"}}
            ),
        )

    """
    To login to Admin UI, we support the following
    - Login with UI_USERNAME and UI_PASSWORD
    - Login with Invite Link `user_email` and `password` combination
    """
    # compare_digest gives constant-time comparison against timing attacks.
    if secrets.compare_digest(
        username.encode("utf-8"), ui_username.encode("utf-8")
    ) and secrets.compare_digest(password.encode("utf-8"), ui_password.encode("utf-8")):
        # Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin
        user_role = LitellmUserRoles.PROXY_ADMIN
        user_id = LITELLM_PROXY_ADMIN_NAME

        # we want the key created to have PROXY_ADMIN_PERMISSIONS
        key_user_id = LITELLM_PROXY_ADMIN_NAME
        if (
            os.getenv("PROXY_ADMIN_ID", None) is not None
            and os.environ["PROXY_ADMIN_ID"] == user_id
        ) or user_id == LITELLM_PROXY_ADMIN_NAME:
            # checks if user is admin
            key_user_id = os.getenv("PROXY_ADMIN_ID", LITELLM_PROXY_ADMIN_NAME)

        # Admin is Authe'd in - generate key for the UI to access Proxy

        # ensure this user is set as the proxy admin, in this route there is no sso, we can assume this user is only the admin
        await user_update(
            data=UpdateUserRequest(
                user_id=key_user_id,
                user_role=user_role,
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
            ),
        )

        if os.getenv("DATABASE_URL") is not None:
            # Mint a short-lived session key scoped to the dashboard team.
            response = await generate_key_helper_fn(
                request_type="key",
                **{
                    "user_role": LitellmUserRoles.PROXY_ADMIN,
                    "duration": LITELLM_UI_SESSION_DURATION,
                    "key_max_budget": litellm.max_ui_session_budget,
                    "models": [],
                    "aliases": {},
                    "config": {},
                    "spend": 0,
                    "user_id": key_user_id,
                    "team_id": "litellm-dashboard",
                },  # type: ignore
            )
        else:
            raise ProxyException(
                message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
                type=ProxyErrorTypes.auth_error,
                param="DATABASE_URL",
                code=500,
            )

        key = response["token"]  # type: ignore

        if get_secret_bool("EXPERIMENTAL_UI_LOGIN"):
            # Experimental mode replaces the DB-backed key with a signed JWT.
            from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken

            user_info: Optional[LiteLLM_UserTable] = None
            if _user_row is not None:
                user_info = _user_row
            elif (
                user_id is not None
            ):  # if user_id is not None, we are using the UI_USERNAME and UI_PASSWORD
                user_info = LiteLLM_UserTable(
                    user_id=user_id,
                    user_role=user_role,
                    models=[],
                    max_budget=litellm.max_ui_session_budget,
                )
            if user_info is None:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "error": "User Information is required for experimental UI login"
                    },
                )

            key = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
                user_info
            )

        return LoginResult(
            user_id=user_id,
            key=key,
            user_email=None,
            user_role=user_role,
            login_method="username_password",
        )

    elif _user_row is not None:
        """
        When sharing invite links

        -> if the user has no role in the DB assume they are only a viewer
        """
        user_id = getattr(_user_row, "user_id", "unknown")
        user_role = getattr(
            _user_row, "user_role", LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
        )
        user_email = getattr(_user_row, "user_email", "unknown")
        _password = getattr(_user_row, "password", "unknown")

        if _password is None:
            raise ProxyException(
                message="User has no password set. Please set a password for the user via `/user/update`.",
                type=ProxyErrorTypes.auth_error,
                param="password",
                code=401,
            )

        # check if password == _user_row.password
        # NOTE(review): the stored password is accepted either as plaintext or
        # as a hash of the submitted password — presumably for legacy rows;
        # confirm whether plaintext storage is still expected.
        hash_password = hash_token(token=password)
        if secrets.compare_digest(
            password.encode("utf-8"), _password.encode("utf-8")
        ) or secrets.compare_digest(
            hash_password.encode("utf-8"), _password.encode("utf-8")
        ):
            if os.getenv("DATABASE_URL") is not None:
                # Mint a session key carrying the user's own role.
                response = await generate_key_helper_fn(
                    request_type="key",
                    **{  # type: ignore
                        "user_role": user_role,
                        "duration": LITELLM_UI_SESSION_DURATION,
                        "key_max_budget": litellm.max_ui_session_budget,
                        "models": [],
                        "aliases": {},
                        "config": {},
                        "spend": 0,
                        "user_id": user_id,
                        "team_id": "litellm-dashboard",
                    },
                )
            else:
                raise ProxyException(
                    message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
                    type=ProxyErrorTypes.auth_error,
                    param="DATABASE_URL",
                    code=500,
                )

            key = response["token"]  # type: ignore

            return LoginResult(
                user_id=user_id,
                key=key,
                user_email=user_email,
                user_role=cast(str, user_role),
                login_method="username_password",
            )
        else:
            raise ProxyException(
                message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
                type=ProxyErrorTypes.auth_error,
                param="invalid_credentials",
                code=401,
            )
    else:
        raise ProxyException(
            message="Invalid credentials used to access UI.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
            type=ProxyErrorTypes.auth_error,
            param="invalid_credentials",
            code=401,
        )
|
||||
|
||||
|
||||
def create_ui_token_object(
    login_result: LoginResult,
    general_settings: dict,
    premium_user: bool,
) -> ReturnedUITokenObject:
    """
    Build the ReturnedUITokenObject that gets encoded into the UI session JWT.

    Args:
        login_result: Result of the username/password authentication step.
        general_settings: General proxy settings dictionary.
        premium_user: Whether premium features are enabled.

    Returns:
        ReturnedUITokenObject: Token payload ready for JWT encoding.
    """
    # Resolve deployment-wide settings up front.
    personal_key_creation_disabled = get_disabled_non_admin_personal_key_creation()
    auth_header = general_settings.get("litellm_key_header_name", "Authorization")

    return ReturnedUITokenObject(
        user_id=login_result.user_id,
        key=login_result.key,
        user_email=login_result.user_email,
        user_role=login_result.user_role,
        login_method=login_result.login_method,
        premium_user=premium_user,
        auth_header_name=auth_header,
        disabled_non_admin_personal_key_creation=personal_key_creation_disabled,
        server_root_path=get_server_root_path(),
    )
|
||||
@@ -0,0 +1,381 @@
|
||||
# What is this?
|
||||
## Common checks for /v1/models and `/model/info`
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
|
||||
from litellm.router import Router
|
||||
from litellm.router_utils.fallback_event_handlers import get_fallback_model_group
|
||||
from litellm.types.router import LiteLLM_Params
|
||||
from litellm.utils import get_valid_models
|
||||
|
||||
|
||||
def _check_wildcard_routing(model: str) -> bool:
|
||||
"""
|
||||
Returns True if a model is a provider wildcard.
|
||||
|
||||
eg:
|
||||
- anthropic/*
|
||||
- openai/*
|
||||
- *
|
||||
"""
|
||||
if "*" in model:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_provider_models(
    provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
    """
    Return the list of models known for *provider*.

    ``"*"`` means "any provider" and returns every model litellm knows about.
    Returns None when the provider is not in ``litellm.models_by_provider``.
    """
    if provider == "*":
        return get_valid_models(litellm_params=litellm_params)

    if provider not in litellm.models_by_provider:
        return None

    return get_valid_models(
        custom_llm_provider=provider, litellm_params=litellm_params
    )
|
||||
|
||||
|
||||
def _get_models_from_access_groups(
|
||||
model_access_groups: Dict[str, List[str]],
|
||||
all_models: List[str],
|
||||
include_model_access_groups: Optional[bool] = False,
|
||||
) -> List[str]:
|
||||
idx_to_remove = []
|
||||
new_models = []
|
||||
for idx, model in enumerate(all_models):
|
||||
if model in model_access_groups:
|
||||
if (
|
||||
not include_model_access_groups
|
||||
): # remove access group, unless requested - e.g. when creating a key
|
||||
idx_to_remove.append(idx)
|
||||
new_models.extend(model_access_groups[model])
|
||||
|
||||
for idx in sorted(idx_to_remove, reverse=True):
|
||||
all_models.pop(idx)
|
||||
|
||||
all_models.extend(new_models)
|
||||
return all_models
|
||||
|
||||
|
||||
async def get_mcp_server_ids(
    user_api_key_dict: UserAPIKeyAuth,
) -> List[str]:
    """
    Look up the MCP server ids attached to a key via its object_permission row.

    Returns an empty list when there is no DB connection, the key has no
    object_permission_id, the row is missing, or the lookup fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    permission_id = user_api_key_dict.object_permission_id
    if prisma_client is None or permission_id is None:
        return []

    # Query only the object-permission row we need.
    try:
        row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": permission_id},
        )
        return row.mcp_servers if row and row.mcp_servers else []
    except Exception:
        # Best-effort lookup - treat DB errors as "no servers".
        return []
|
||||
|
||||
|
||||
def get_key_models(
    user_api_key_dict: UserAPIKeyAuth,
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Resolve the list of model names a key can access.

    Returns:
        - List of model name strings (deduplicated, order preserved)
        - Empty list if no models are set on the key
        - The ``all_team_models`` sentinel expands to the key's team models
        - The ``all_proxy_models`` sentinel expands to the proxy model list
        - Access-group names are expanded via *model_access_groups*; when
          ``include_model_access_groups`` is True the group names themselves
          are also kept - {"beta-models": [...]} -> 'beta-models' appears.

    NOTE(review): ``only_model_access_groups`` is accepted but not used in
    this function; it appears to be applied downstream
    (see get_complete_model_list) - confirm.
    """
    resolved: List[str] = []
    if user_api_key_dict.models:
        # copy to avoid mutating cached objects
        resolved = list(user_api_key_dict.models)
    if SpecialModelNames.all_team_models.value in resolved:
        # copy to avoid mutating cached objects
        resolved = list(user_api_key_dict.team_models)
    if SpecialModelNames.all_proxy_models.value in resolved:
        # copy to avoid mutating caller's list
        resolved = list(proxy_model_list)
        if include_model_access_groups:
            resolved.extend(model_access_groups.keys())

    resolved = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=resolved,
        include_model_access_groups=include_model_access_groups,
    )

    # deduplicate while preserving order
    resolved = list(dict.fromkeys(resolved))

    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(resolved)))
    return resolved
|
||||
|
||||
|
||||
def get_team_models(
    team_models: List[str],
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Resolve the list of model names a team can access.

    Args:
        team_models: Models configured on the team (may contain sentinel
            values and access-group names).
        proxy_model_list: All models known to the proxy.
        model_access_groups: Mapping of access-group name -> member models.
        include_model_access_groups: When True, access-group names themselves
            are kept in the result alongside their member models.

    Returns:
        - List of model name strings (deduplicated)
        - Empty list if no models are set
        - Access-group names are expanded via *model_access_groups*
    """
    all_models_set: Set[str] = set()
    if len(team_models) > 0:
        all_models_set.update(team_models)
    # NOTE: no expansion needed for the all_team_models sentinel here -
    # team_models already IS the team's model list. (A previous version
    # re-added team_models under that sentinel, which was a no-op.)
    if SpecialModelNames.all_proxy_models.value in all_models_set:
        all_models_set.update(proxy_model_list)
    if include_model_access_groups:
        all_models_set.update(model_access_groups.keys())

    all_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=list(all_models_set),
        include_model_access_groups=include_model_access_groups,
    )

    # deduplicate while preserving order
    all_models = list(dict.fromkeys(all_models))

    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
    return all_models
|
||||
|
||||
|
||||
def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
    model_access_groups: Optional[Dict[str, List[str]]] = None,
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Logic for returning the complete model list for a given key + team pair.

    Precedence:
        - Non-empty key list wins over the team list.
        - Non-empty team list wins over the proxy model list.
        - Otherwise fall back to the full proxy model list (optionally plus
          access-group names).

    Wildcard entries (e.g. ``anthropic/*``) are expanded into known provider
    models via ``_get_wildcard_models``.

    Args:
        key_models: Models resolved for the key.
        team_models: Models resolved for the team.
        proxy_model_list: All models known to the proxy.
        user_model: Optional single model configured for the user.
        infer_model_from_keys: When True, also include models inferred from
            configured provider API keys.
        return_wildcard_routes: When True, wildcard routes themselves are kept
            in the result alongside their expansions.
        llm_router: Router used to resolve wildcard deployments.
        model_access_groups: Mapping of access-group name -> member models.
        include_model_access_groups: When True, access-group names are added
            to the fallback proxy list.
        only_model_access_groups: When True, return only the access-group
            names present in the resolved list.
    """
    if model_access_groups is None:
        # fix: avoid a shared mutable default argument ({})
        model_access_groups = {}

    unique_models: List[str] = []

    def append_unique(models):
        # preserve first-seen order while skipping duplicates
        for model in models:
            if model not in unique_models:
                unique_models.append(model)

    if key_models:
        append_unique(key_models)
    elif team_models:
        append_unique(team_models)
    else:
        append_unique(proxy_model_list)
        if include_model_access_groups:
            append_unique(list(model_access_groups.keys()))  # TODO: keys order

    if user_model:
        append_unique([user_model])

    if infer_model_from_keys:
        append_unique(get_valid_models())

    if only_model_access_groups:
        # return only the access-group names, in resolved order
        return [model for model in unique_models if model in model_access_groups]

    all_wildcard_models = _get_wildcard_models(
        unique_models=unique_models,
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
    )

    return unique_models + all_wildcard_models
|
||||
|
||||
|
||||
def get_known_models_from_wildcard(
    wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
    """
    Expand a wildcard route (e.g. ``openai/*`` or ``gemini/gemini-*``) into
    the concrete provider models it matches, each carrying the wildcard's
    provider prefix.

    Returns an empty list when the wildcard has no "/" separator or the
    provider is unknown.
    """
    try:
        provider_prefix, model_pattern = wildcard_model.split("/", 1)
    except ValueError:  # no "/" in the wildcard - safely fail
        return []

    # Prefer the provider from litellm_params when available, otherwise the
    # wildcard prefix (e.g. "openai" from "openai/*" - needed for BYOK where
    # the wildcard isn't in the router).
    provider = provider_prefix
    if litellm_params is not None:
        try:
            provider = litellm_params.model.split("/", 1)[0]
        except ValueError:
            provider = provider_prefix

    # get all known provider models
    known_models = get_provider_models(
        provider=provider, litellm_params=litellm_params
    )
    if known_models is None:
        return []

    if model_pattern != "*":
        ## PARTIAL FILTER e.g. `gemini-*`
        model_prefix = model_pattern.replace("*", "")
        matching = [m for m in known_models if m.startswith(model_prefix)]
        if matching:
            known_models = matching
        else:
            # nothing matches the prefix -> prepend it to every known model
            known_models = [f"{model_prefix}{m}" for m in known_models]

    # ensure every returned model carries the provider prefix
    prefixed_models: List[str] = []
    for known_model in known_models:
        if not known_model.startswith(provider_prefix):
            known_model = f"{provider_prefix}/{known_model}"
        prefixed_models.append(known_model)
    return prefixed_models
|
||||
|
||||
|
||||
def _get_wildcard_models(
    unique_models: List[str],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Expand wildcard entries in *unique_models* into concrete model names.

    Side effect: wildcard entries expanded via the provider-model fallback
    path are removed from *unique_models* in place; entries expanded through
    router deployments are left in the list.

    Args:
        unique_models: Candidate model names; entries containing "*" are
            treated as wildcards. May be mutated (see above).
        return_wildcard_routes: When True, the wildcard route itself (e.g.
            ``anthropic/*``) is also added to the returned list.
        llm_router: Router used to look up litellm_params for each wildcard
            deployment; when absent, expansion uses litellm's known provider
            models only.

    Returns:
        List of expanded model names (may contain duplicates; callers dedupe).
    """
    models_to_remove = set()
    all_wildcard_models = []
    for model in unique_models:
        if _check_wildcard_routing(model=model):
            if (
                return_wildcard_routes
            ):  # will add the wildcard route to the list eg: anthropic/*.
                all_wildcard_models.append(model)

            ## get litellm params from model
            if llm_router is not None:
                model_list = llm_router.get_model_list(model_name=model)
                if model_list:
                    # Expand once per deployment, using that deployment's
                    # litellm_params to resolve the provider.
                    for router_model in model_list:
                        wildcard_models = get_known_models_from_wildcard(
                            wildcard_model=model,
                            litellm_params=LiteLLM_Params(
                                **router_model["litellm_params"]  # type: ignore
                            ),
                        )
                        all_wildcard_models.extend(wildcard_models)
                else:
                    # Router has no deployment for this wildcard (e.g., BYOK team models)
                    # Fall back to expanding from known provider models
                    wildcard_models = get_known_models_from_wildcard(
                        wildcard_model=model, litellm_params=None
                    )
                    if wildcard_models:
                        # expansion succeeded -> drop the raw wildcard entry
                        models_to_remove.add(model)
                    all_wildcard_models.extend(wildcard_models)
            else:
                # No router at all -> expand from known provider models
                wildcard_models = get_known_models_from_wildcard(
                    wildcard_model=model, litellm_params=None
                )

                if wildcard_models:
                    # expansion succeeded -> drop the raw wildcard entry
                    models_to_remove.add(model)
                all_wildcard_models.extend(wildcard_models)

    # Removal happens after iteration to avoid mutating the list mid-loop.
    for model in models_to_remove:
        unique_models.remove(model)

    return all_wildcard_models
|
||||
|
||||
|
||||
def get_all_fallbacks(
    model: str,
    llm_router: Optional[Router] = None,
    fallback_type: str = "general",
) -> List[str]:
    """
    Get all fallbacks for a given model from the router's fallback configuration.

    Args:
        model: The model name to get fallbacks for
        llm_router: The LiteLLM router instance
        fallback_type: Type of fallback ("general", "context_window", "content_policy")

    Returns:
        List of fallback model names. Empty list if no fallbacks found.
    """
    if llm_router is None:
        return []

    # Map each fallback type to the router attribute holding its config.
    attr_by_type = {
        "general": "fallbacks",
        "context_window": "context_window_fallbacks",
        "content_policy": "content_policy_fallbacks",
    }
    attr_name = attr_by_type.get(fallback_type)
    if attr_name is None:
        verbose_proxy_logger.warning(f"Unknown fallback_type: {fallback_type}")
        return []

    fallbacks_config: list = getattr(llm_router, attr_name, [])
    if not fallbacks_config:
        return []

    try:
        # Use existing function to get fallback model group
        fallback_model_group, _ = get_fallback_model_group(
            fallbacks=fallbacks_config, model_group=model
        )
        if fallback_model_group is None:
            return []
        return fallback_model_group
    except Exception as e:
        verbose_proxy_logger.error(f"Error getting fallbacks for model {model}: {e}")
        return []
|
||||
@@ -0,0 +1,222 @@
|
||||
import base64
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
|
||||
|
||||
class Oauth2Handler:
    """
    Handles OAuth2 token validation.

    Supports two endpoint styles:
        - RFC 7662 token introspection endpoints (POST, form-encoded)
        - generic token-info endpoints (GET with a Bearer header)
    """

    @staticmethod
    def _is_introspection_endpoint(
        token_info_endpoint: str,
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
    ) -> bool:
        """
        Determine if this is an introspection endpoint (requires POST) or token info endpoint (uses GET).

        Heuristic: the URL mentions "introspect" AND both client credentials
        are configured.

        Args:
            token_info_endpoint: The OAuth2 endpoint URL
            oauth_client_id: OAuth2 client ID
            oauth_client_secret: OAuth2 client secret

        Returns:
            bool: True if this is an introspection endpoint
        """
        return (
            "introspect" in token_info_endpoint.lower()
            and oauth_client_id is not None
            and oauth_client_secret is not None
        )

    @staticmethod
    def _prepare_introspection_request(
        token: str,
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
    ) -> Tuple[Dict[str, str], Dict[str, str]]:
        """
        Prepare headers and data for OAuth2 introspection endpoint (RFC 7662).

        Args:
            token: The OAuth2 token to validate
            oauth_client_id: OAuth2 client ID
            oauth_client_secret: OAuth2 client secret

        Returns:
            Tuple of (headers, data) for the introspection request
        """
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {"token": token}

        # Add client authentication if credentials are provided
        if oauth_client_id and oauth_client_secret:
            # Use HTTP Basic authentication for client credentials
            credentials = base64.b64encode(
                f"{oauth_client_id}:{oauth_client_secret}".encode()
            ).decode()
            headers["Authorization"] = f"Basic {credentials}"
        elif oauth_client_id:
            # For public clients, include client_id in the request body
            data["client_id"] = oauth_client_id

        return headers, data

    @staticmethod
    def _prepare_token_info_request(token: str) -> Dict[str, str]:
        """
        Prepare headers for generic token info endpoint.

        Args:
            token: The OAuth2 token to validate

        Returns:
            Dict of headers for the token info request
        """
        return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    @staticmethod
    def _extract_user_info(
        response_data: Dict,
        user_id_field_name: str,
        user_role_field_name: str,
        user_team_id_field_name: str,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Extract user information from OAuth2 response.

        Missing fields come back as None - no validation happens here.

        Args:
            response_data: The response data from OAuth2 endpoint
            user_id_field_name: Field name for user ID
            user_role_field_name: Field name for user role
            user_team_id_field_name: Field name for team ID

        Returns:
            Tuple of (user_id, user_role, user_team_id)
        """
        user_id = response_data.get(user_id_field_name)
        user_team_id = response_data.get(user_team_id_field_name)
        user_role = response_data.get(user_role_field_name)

        return user_id, user_role, user_team_id

    @staticmethod
    async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
        """
        Makes a request to the token introspection endpoint to validate the OAuth2 token.

        This function implements OAuth2 token introspection according to RFC 7662.
        It supports both generic token info endpoints (GET) and OAuth2 introspection endpoints (POST).

        Args:
            token (str): The OAuth2 token to validate.

        Returns:
            UserAPIKeyAuth: If the token is valid, containing user information.

        Raises:
            ValueError: If the token is invalid, the request fails, or the token info endpoint is not set.
        """
        from litellm.proxy.proxy_server import premium_user

        # Premium-only feature gate.
        if premium_user is not True:
            raise ValueError(
                "Oauth2 token validation is only available for premium users"
                + CommonProxyErrors.not_premium_user.value
            )

        verbose_proxy_logger.debug("Oauth2 token validation for token=%s", token)

        # Get the token info endpoint from environment variable
        token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
        # Field names in the endpoint's JSON response, overridable via env.
        user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
        user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
        user_team_id_field_name = os.environ.get(
            "OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id"
        )

        # OAuth2 client credentials for introspection endpoint authentication
        oauth_client_id = os.environ.get("OAUTH_CLIENT_ID")
        oauth_client_secret = os.environ.get("OAUTH_CLIENT_SECRET")

        if not token_info_endpoint:
            raise ValueError(
                "OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set"
            )

        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)

        # Determine if this is an introspection endpoint (requires POST) or token info endpoint (uses GET)
        is_introspection_endpoint = Oauth2Handler._is_introspection_endpoint(
            token_info_endpoint=token_info_endpoint,
            oauth_client_id=oauth_client_id,
            oauth_client_secret=oauth_client_secret,
        )

        try:
            if is_introspection_endpoint:
                # OAuth2 Token Introspection (RFC 7662) - requires POST with form data
                verbose_proxy_logger.debug("Using OAuth2 introspection endpoint (POST)")

                headers, data = Oauth2Handler._prepare_introspection_request(
                    token=token,
                    oauth_client_id=oauth_client_id,
                    oauth_client_secret=oauth_client_secret,
                )

                response = await client.post(
                    token_info_endpoint, headers=headers, data=data
                )
            else:
                # Generic token info endpoint - uses GET with Bearer token
                verbose_proxy_logger.debug("Using generic token info endpoint (GET)")
                headers = Oauth2Handler._prepare_token_info_request(token=token)
                response = await client.get(token_info_endpoint, headers=headers)

            # if it's a bad token we expect it to raise an HTTPStatusError
            response.raise_for_status()

            # If we get here, the request was successful
            # NOTE: `data` is rebound here - from the request form data to the
            # endpoint's JSON response body.
            data = response.json()

            verbose_proxy_logger.debug(
                "Oauth2 token validation for token=%s, response from endpoint=%s",
                token,
                data,
            )

            # For introspection endpoints, check if token is active (RFC 7662
            # "active" field; absent -> treated as active)
            if is_introspection_endpoint and not data.get("active", True):
                raise ValueError("Token is not active")

            # Extract user information from response
            user_id, user_role, user_team_id = Oauth2Handler._extract_user_info(
                response_data=data,
                user_id_field_name=user_id_field_name,
                user_role_field_name=user_role_field_name,
                user_team_id_field_name=user_team_id_field_name,
            )

            return UserAPIKeyAuth(
                api_key=token,
                team_id=user_team_id,
                user_id=user_id,
                # NOTE(review): the endpoint's role string is cast, not
                # validated, against LitellmUserRoles - confirm upstream checks
                user_role=cast(LitellmUserRoles, user_role),
            )
        except httpx.HTTPStatusError as e:
            # This will catch any 4xx or 5xx errors
            raise ValueError(f"Oauth 2.0 Token validation failed: {e}")
        except Exception as e:
            # This will catch any other errors (like network issues); note it
            # also re-wraps the "Token is not active" ValueError raised above.
            raise ValueError(f"An error occurred during token validation: {e}")
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
async def handle_oauth2_proxy_request(request: Request) -> UserAPIKeyAuth:
    """
    Handle request from oauth2 proxy.

    Reads identity fields from request headers according to the
    ``oauth2_config_mappings`` in general_settings and builds a
    UserAPIKeyAuth object from them.
    """
    from litellm.proxy.proxy_server import general_settings

    verbose_proxy_logger.debug("Handling oauth2 proxy request")

    # Mapping of UserAPIKeyAuth field name -> request header name.
    oauth2_config_mappings: Dict[str, str] = (
        general_settings.get("oauth2_config_mappings") or {}
    )
    verbose_proxy_logger.debug(f"Oauth2 config mappings: {oauth2_config_mappings}")

    if not oauth2_config_mappings:
        raise ValueError("Oauth2 config mappings not found in general_settings")

    # Collect mapped header values, coercing typed fields.
    auth_data: Dict[str, Any] = {}
    for field_name, header_name in oauth2_config_mappings.items():
        raw_value = request.headers.get(header_name)
        if not raw_value:
            continue
        if field_name == "max_budget":
            # budget arrives as a string header; coerce to float
            auth_data[field_name] = float(raw_value)
        elif field_name == "models":
            # comma-separated list of model names
            auth_data[field_name] = [part.strip() for part in raw_value.split(",")]
        else:
            auth_data[field_name] = raw_value

    verbose_proxy_logger.debug(
        f"Auth data before creating UserAPIKeyAuth object: {auth_data}"
    )
    user_api_key_auth = UserAPIKeyAuth(**auth_data)
    verbose_proxy_logger.debug(f"UserAPIKeyAuth object created: {user_api_key_auth}")
    return user_api_key_auth
|
||||
@@ -0,0 +1,9 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwcNBabWBZzrDhFAuA4Fh
|
||||
FhIcA3rF7vrLb8+1yhF2U62AghQp9nStyuJRjxMUuldWgJ1yRJ2s7UffVw5r8DeA
|
||||
dqXPD+w+3LCNwqJGaIKN08QGJXNArM3QtMaN0RTzAyQ4iibN1r6609W5muK9wGp0
|
||||
b1j5+iDUmf0ynItnhvaX6B8Xoaflc3WD/UBdrygLmsU5uR3XC86+/8ILoSZH3HtN
|
||||
6FJmWhlhjS2TR1cKZv8K5D0WuADTFf5MF8jYFR+uORPj5Pe/EJlLGN26Lfn2QnGu
|
||||
XgbPF6nCGwZ0hwH1Xkn3xzGaJ4xBEC761wqp5cHxWSDktHyFKnLbP3jVeegjVIHh
|
||||
pQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
@@ -0,0 +1,187 @@
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def init_rds_client(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    """
    Initialize a boto3 RDS client, resolving credentials in priority order:

    1. Web-identity (OIDC) token + role + session -> STS AssumeRoleWithWebIdentity
    2. Role name + session name -> STS AssumeRole
    3. Explicit access key / secret key
    4. Named AWS profile (usually ~/.aws/credentials)
    5. boto3's default environment-variable resolution

    Any parameter may be passed as ``os.environ/VAR_NAME`` to have it resolved
    via the secret manager.

    Args:
        aws_access_key_id: Explicit AWS access key id.
        aws_secret_access_key: Explicit AWS secret access key.
        aws_region_name: AWS region; falls back to AWS_REGION_NAME then
            AWS_REGION from the secret manager / environment.
        aws_session_name: STS session name (required for role assumption).
        aws_profile_name: Named AWS profile to use.
        aws_role_name: Role ARN to assume via STS.
        aws_web_identity_token: OIDC token - either a filepath or a
            secret-manager reference.
        timeout: Connect/read timeout as a float or httpx.Timeout.

    Returns:
        A boto3 "rds" client.

    Raises:
        Exception: If no AWS region can be resolved, or the OIDC token
            cannot be retrieved.
    """
    from litellm.secret_managers.main import get_secret

    # check for custom AWS_REGION_NAME and use it if not passed in
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IF 'os.environ/' passed in
    # Define the list of parameters to check
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]

    # Resolve any "os.environ/..." references through the secret manager
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)  # type: ignore
    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    ### SET REGION NAME
    if aws_region_name:
        region_name = aws_region_name
    elif litellm_aws_region_name:
        region_name = litellm_aws_region_name
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise Exception(
            "AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
        )

    import boto3

    # Translate the requested timeout into a boto3 client config.
    if isinstance(timeout, float):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()  # type: ignore

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        try:
            # Token may be a filepath; read with a context manager so the
            # handle is closed (previous version leaked the open file).
            with open(aws_web_identity_token) as token_file:
                oidc_token = token_file.read()
        except Exception:
            # Not a readable file - treat it as a secret-manager reference.
            oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise Exception(
                "OIDC token could not be retrieved from secret manager.",
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )

    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
        client = boto3.client(
            service_name="rds",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            config=config,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        client = boto3.Session(profile_name=aws_profile_name).client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables
        client = boto3.client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )

    return client
|
||||
|
||||
|
||||
def generate_iam_auth_token(
    db_host, db_port, db_user, client: Optional[Any] = None
) -> str:
    """
    Generate a URL-encoded RDS IAM auth token for the given DB endpoint.

    When *client* is not supplied, an RDS client is initialized from the
    standard AWS environment variables.
    """
    from urllib.parse import quote

    boto_client = client
    if boto_client is None:
        boto_client = init_rds_client(
            aws_region_name=os.getenv("AWS_REGION_NAME"),
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_name=os.getenv("AWS_SESSION_NAME"),
            aws_profile_name=os.getenv("AWS_PROFILE_NAME"),
            aws_role_name=os.getenv("AWS_ROLE_NAME", os.getenv("AWS_ROLE_ARN")),
            aws_web_identity_token=os.getenv(
                "AWS_WEB_IDENTITY_TOKEN", os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
            ),
        )

    raw_token = boto_client.generate_db_auth_token(
        DBHostname=db_host, Port=db_port, DBUsername=db_user
    )
    # URL-encode everything so the token can be embedded in a connection string
    return quote(raw_token, safe="")
|
||||
@@ -0,0 +1,669 @@
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLMRoutes,
|
||||
LitellmUserRoles,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
|
||||
from .auth_checks_organization import _user_is_org_admin
|
||||
|
||||
|
||||
class RouteChecks:
    @staticmethod
    def should_call_route(route: str, valid_token: UserAPIKeyAuth):
        """
        Check if management route is disabled and raise exception

        Runs the enterprise-only route check when the litellm_enterprise
        package is installed, then verifies the virtual key is allowed to
        call the route. Returns True when the call is permitted.
        """
        try:
            from litellm_enterprise.proxy.auth.route_checks import EnterpriseRouteChecks

            EnterpriseRouteChecks.should_call_route(route=route)
        except HTTPException as e:
            # enterprise check explicitly rejected the route - propagate
            raise e
        except Exception:
            # litellm_enterprise not installed (or check unavailable) - skip
            pass

        # Check if Virtual Key is allowed to call the route - Applies to all Roles
        RouteChecks.is_virtual_key_allowed_to_call_route(
            route=route, valid_token=valid_token
        )
        return True
|
||||
    @staticmethod
    def is_virtual_key_allowed_to_call_route(
        route: str, valid_token: UserAPIKeyAuth
    ) -> bool:
        """
        Raises Exception if Virtual Key is not allowed to call the route

        Checks run in order: exact/prefix match, LiteLLMRoutes group names
        (with a pass-through-endpoint extension for "llm_api_routes"), then
        wildcard patterns. Raises HTTPException(403) when nothing matches.
        """

        # Only check if valid_token.allowed_routes is set and is a list with at least one item
        if valid_token.allowed_routes is None:
            return True
        if not isinstance(valid_token.allowed_routes, list):
            return True
        if len(valid_token.allowed_routes) == 0:
            return True

        # explicit check for allowed routes (exact match or prefix match)
        for allowed_route in valid_token.allowed_routes:
            if RouteChecks._route_matches_allowed_route(
                route=route, allowed_route=allowed_route
            ):
                return True

        ## check if 'allowed_route' is a field name in LiteLLMRoutes
        if any(
            allowed_route in LiteLLMRoutes._member_names_
            for allowed_route in valid_token.allowed_routes
        ):
            for allowed_route in valid_token.allowed_routes:
                if allowed_route in LiteLLMRoutes._member_names_:
                    # allowed_route names a route group - check membership in it
                    if RouteChecks.check_route_access(
                        route=route,
                        allowed_routes=LiteLLMRoutes._member_map_[allowed_route].value,
                    ):
                        return True

                    ################################################
                    # For llm_api_routes, also check registered pass-through endpoints
                    ################################################
                    if allowed_route == "llm_api_routes":
                        # imported lazily to avoid a circular import at module load
                        from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
                            InitPassThroughEndpointHelpers,
                        )

                        if InitPassThroughEndpointHelpers.is_registered_pass_through_route(
                            route=route
                        ):
                            return True

        # check if wildcard pattern is allowed
        for allowed_route in valid_token.allowed_routes:
            if RouteChecks._route_matches_wildcard_pattern(
                route=route, pattern=allowed_route
            ):
                return True

        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Virtual key is not allowed to call this route. Only allowed to call routes: {valid_token.allowed_routes}. Tried to call route: {route}",
        )
|
||||
|
||||
    @staticmethod
    def _mask_user_id(user_id: str) -> str:
        """
        Mask user_id to prevent leaking sensitive information in error messages

        Args:
            user_id (str): The user_id to mask

        Returns:
            str: Masked user_id. IDs that are empty or <= 4 characters are
            fully masked as "***"; longer IDs keep the first 6 and last 2
            characters visible (per the SensitiveDataMasker config below).
        """
        # imported lazily to keep module import cost low
        from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker

        if not user_id or len(user_id) <= 4:
            return "***"

        # Use SensitiveDataMasker with custom configuration for user_id
        masker = SensitiveDataMasker(visible_prefix=6, visible_suffix=2, mask_char="*")

        return masker._mask_value(user_id)
|
||||
|
||||
@staticmethod
|
||||
def _raise_admin_only_route_exception(
|
||||
user_obj: Optional[LiteLLM_UserTable],
|
||||
route: str,
|
||||
) -> None:
|
||||
"""
|
||||
Raise exception for routes that require proxy admin access
|
||||
|
||||
Args:
|
||||
user_obj (Optional[LiteLLM_UserTable]): The user object
|
||||
route (str): The route being accessed
|
||||
|
||||
Raises:
|
||||
Exception: With user role and masked user_id information
|
||||
"""
|
||||
user_role = "unknown"
|
||||
user_id = "unknown"
|
||||
if user_obj is not None:
|
||||
user_role = user_obj.user_role or "unknown"
|
||||
user_id = user_obj.user_id or "unknown"
|
||||
|
||||
masked_user_id = RouteChecks._mask_user_id(user_id)
|
||||
raise Exception(
|
||||
f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={masked_user_id}"
|
||||
)
|
||||
|
||||
    @staticmethod
    def non_proxy_admin_allowed_routes_check(
        user_obj: Optional[LiteLLM_UserTable],
        _user_role: Optional[LitellmUserRoles],
        route: str,
        request: Request,
        valid_token: UserAPIKeyAuth,
        request_data: dict,
    ):
        """
        Checks if Non Proxy Admin User is allowed to access the route

        Walks an ordered elif chain of allow conditions; if none match,
        raises via _raise_admin_only_route_exception. Branch order matters:
        earlier allows short-circuit later, stricter checks.
        """

        # Check user has defined custom admin routes
        RouteChecks.custom_admin_only_route_check(
            route=route,
        )

        if RouteChecks.is_llm_api_route(route=route):
            pass
        elif RouteChecks.is_info_route(route=route):
            # check if user allowed to call an info route
            if route == "/key/info":
                # handled by function itself
                pass
            elif route == "/user/info":
                # check if user can access this route
                query_params = request.query_params
                user_id = query_params.get("user_id")
                verbose_proxy_logger.debug(
                    f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
                )
                # a key may only look up its own user's info
                if user_id and user_id != valid_token.user_id:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="key not allowed to access this user's info. user_id={}, key's user_id={}".format(
                            user_id, valid_token.user_id
                        ),
                    )
            elif route == "/model/info":
                # /model/info just shows models user has access to
                pass
            elif route == "/team/info":
                pass  # handled by function itself
        elif (
            route in LiteLLMRoutes.global_spend_tracking_routes.value
            and getattr(valid_token, "permissions", None) is not None
            and "get_spend_routes" in getattr(valid_token, "permissions", [])
        ):
            # key was explicitly granted the spend-route permission
            pass
        elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
            RouteChecks._check_proxy_admin_viewer_access(
                route=route,
                _user_role=_user_role,
                request_data=request_data,
            )
        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER.value
            and RouteChecks.check_route_access(
                route=route, allowed_routes=LiteLLMRoutes.internal_user_routes.value
            )
        ):
            pass
        elif _user_is_org_admin(
            request_data=request_data, user_object=user_obj
        ) and RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.org_admin_allowed_routes.value
        ):
            pass
        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
            and RouteChecks.check_route_access(
                route=route,
                allowed_routes=LiteLLMRoutes.internal_user_view_only_routes.value,
            )
        ):
            pass
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.self_managed_routes.value
        ):  # routes that manage their own allowed/disallowed logic
            pass
        elif route.startswith("/v1/mcp/") or route.startswith("/mcp-rest/"):
            pass  # authN/authZ handled by api itself
        elif RouteChecks.check_passthrough_route_access(
            route=route, user_api_key_dict=valid_token
        ):
            pass
        elif valid_token.allowed_routes is not None:
            # check if route is in allowed_routes (exact match or prefix match)
            route_allowed = False
            for allowed_route in valid_token.allowed_routes:
                if RouteChecks._route_matches_allowed_route(
                    route=route, allowed_route=allowed_route
                ):
                    route_allowed = True
                    break

                # also accept wildcard entries like "/scim/v2/*"
                if RouteChecks._route_matches_wildcard_pattern(
                    route=route, pattern=allowed_route
                ):
                    route_allowed = True
                    break

            if not route_allowed:
                RouteChecks._raise_admin_only_route_exception(
                    user_obj=user_obj, route=route
                )
        else:
            # no allow condition matched - admin-only
            RouteChecks._raise_admin_only_route_exception(
                user_obj=user_obj, route=route
            )
|
||||
|
||||
@staticmethod
|
||||
def custom_admin_only_route_check(route: str):
|
||||
from litellm.proxy.proxy_server import general_settings, premium_user
|
||||
|
||||
if "admin_only_routes" in general_settings:
|
||||
if premium_user is not True:
|
||||
verbose_proxy_logger.error(
|
||||
f"Trying to use 'admin_only_routes' this is an Enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
return
|
||||
if route in general_settings["admin_only_routes"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"user not allowed to access this route. Route={route} is an admin only route",
|
||||
)
|
||||
pass
|
||||
|
||||
    @staticmethod
    def is_llm_api_route(route: str) -> bool:
        """
        Helper to checks if provided route is an OpenAI route


        Returns:
            - True: if route is an OpenAI route
            - False: if route is not an OpenAI route
        """
        # Ensure route is a string before performing checks
        if not isinstance(route, str):
            return False

        # exact-match checks against the known route lists
        if route in LiteLLMRoutes.openai_routes.value:
            return True

        if route in LiteLLMRoutes.anthropic_routes.value:
            return True

        if route in LiteLLMRoutes.google_routes.value:
            return True

        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.mcp_routes.value
        ):
            return True

        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.agent_routes.value
        ):
            return True

        if route in LiteLLMRoutes.litellm_native_routes.value:
            return True

        # fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
        # Check for routes with placeholders or wildcard patterns
        for openai_route in LiteLLMRoutes.openai_routes.value:
            # Replace placeholders with regex pattern
            # placeholders are written as "/threads/{thread_id}"
            if "{" in openai_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=openai_route
                ):
                    return True
            # Check for wildcard patterns like "/containers/*"
            if RouteChecks._is_wildcard_pattern(pattern=openai_route):
                if RouteChecks._route_matches_wildcard_pattern(
                    route=route, pattern=openai_route
                ):
                    return True

        # Check for Google routes with placeholders like "/v1beta/models/{model_name}:generateContent"
        for google_route in LiteLLMRoutes.google_routes.value:
            if "{" in google_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=google_route
                ):
                    return True

        # Check for Anthropic routes with placeholders
        for anthropic_route in LiteLLMRoutes.anthropic_routes.value:
            if "{" in anthropic_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=anthropic_route
                ):
                    return True

        if RouteChecks._is_azure_openai_route(route=route):
            return True

        # NOTE(review): this is a substring (not prefix) match - any route
        # containing a mapped pass-through prefix anywhere is treated as an
        # LLM API route; confirm that is intentional
        for _llm_passthrough_route in LiteLLMRoutes.mapped_pass_through_routes.value:
            if _llm_passthrough_route in route:
                return True
        return False
|
||||
|
||||
@staticmethod
|
||||
def is_management_route(route: str) -> bool:
|
||||
"""
|
||||
Check if route is a management route
|
||||
"""
|
||||
return route in LiteLLMRoutes.management_routes.value
|
||||
|
||||
@staticmethod
|
||||
def is_info_route(route: str) -> bool:
|
||||
"""
|
||||
Check if route is an info route
|
||||
"""
|
||||
return route in LiteLLMRoutes.info_routes.value
|
||||
|
||||
@staticmethod
|
||||
def _is_azure_openai_route(route: str) -> bool:
|
||||
"""
|
||||
Check if route is a route from AzureOpenAI SDK client
|
||||
|
||||
eg.
|
||||
route='/openai/deployments/vertex_ai/gemini-1.5-flash/chat/completions'
|
||||
"""
|
||||
# Ensure route is a string before attempting regex matching
|
||||
if not isinstance(route, str):
|
||||
return False
|
||||
# Add support for deployment and engine model paths
|
||||
deployment_pattern = r"^/openai/deployments/[^/]+/[^/]+/chat/completions$"
|
||||
engine_pattern = r"^/engines/[^/]+/chat/completions$"
|
||||
|
||||
if re.match(deployment_pattern, route) or re.match(engine_pattern, route):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _route_matches_pattern(route: str, pattern: str) -> bool:
|
||||
"""
|
||||
Check if route matches the pattern placed in proxy/_types.py
|
||||
|
||||
Example:
|
||||
- pattern: "/threads/{thread_id}"
|
||||
- route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
|
||||
- returns: True
|
||||
|
||||
|
||||
- pattern: "/key/{token_id}/regenerate"
|
||||
- route: "/key/regenerate/82akk800000000jjsk"
|
||||
- returns: False, pattern is "/key/{token_id}/regenerate"
|
||||
"""
|
||||
# Ensure route is a string before attempting regex matching
|
||||
if not isinstance(route, str):
|
||||
return False
|
||||
|
||||
def _placeholder_to_regex(match: re.Match) -> str:
|
||||
placeholder = match.group(0).strip("{}")
|
||||
if placeholder.endswith(":path"):
|
||||
# allow "/" in the placeholder value, but don't eat the route suffix after ":"
|
||||
return r"[^:]+"
|
||||
return r"[^/]+"
|
||||
|
||||
pattern = re.sub(r"\{[^}]+\}", _placeholder_to_regex, pattern)
|
||||
# Anchor the pattern to match the entire string
|
||||
pattern = f"^{pattern}$"
|
||||
if re.match(pattern, route):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_wildcard_pattern(pattern: str) -> bool:
|
||||
"""
|
||||
Check if pattern is a wildcard pattern
|
||||
"""
|
||||
return pattern.endswith("*")
|
||||
|
||||
@staticmethod
|
||||
def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
|
||||
"""
|
||||
Check if route matches the wildcard pattern
|
||||
|
||||
eg.
|
||||
|
||||
pattern: "/scim/v2/*"
|
||||
route: "/scim/v2/Users"
|
||||
- returns: True
|
||||
|
||||
pattern: "/scim/v2/*"
|
||||
route: "/chat/completions"
|
||||
- returns: False
|
||||
|
||||
|
||||
pattern: "/scim/v2/*"
|
||||
route: "/scim/v2/Users/123"
|
||||
- returns: True
|
||||
|
||||
"""
|
||||
if pattern.endswith("*"):
|
||||
# Get the prefix (everything before the wildcard)
|
||||
prefix = pattern[:-1]
|
||||
return route.startswith(prefix)
|
||||
else:
|
||||
# If there's no wildcard, the pattern and route should match exactly
|
||||
return route == pattern
|
||||
|
||||
@staticmethod
|
||||
def _route_matches_allowed_route(route: str, allowed_route: str) -> bool:
|
||||
"""
|
||||
Check if route matches the allowed_route pattern.
|
||||
Supports both exact match and prefix match.
|
||||
|
||||
Examples:
|
||||
- allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-6" -> True (exact match)
|
||||
- allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-6/v1/chat/completions" -> True (prefix match)
|
||||
- allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-600" -> False (not a valid prefix)
|
||||
|
||||
Args:
|
||||
route: The actual route being accessed
|
||||
allowed_route: The allowed route pattern
|
||||
|
||||
Returns:
|
||||
bool: True if route matches (exact or prefix), False otherwise
|
||||
"""
|
||||
# Exact match
|
||||
if route == allowed_route:
|
||||
return True
|
||||
# Prefix match - ensure we add "/" to prevent false matches like /fake-openai-proxy-600
|
||||
if route.startswith(allowed_route + "/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
|
||||
"""
|
||||
Check if a route has access by checking both exact matches and patterns
|
||||
|
||||
Args:
|
||||
route (str): The route to check
|
||||
allowed_routes (list): List of allowed routes/patterns
|
||||
|
||||
Returns:
|
||||
bool: True if route is allowed, False otherwise
|
||||
"""
|
||||
#########################################################
|
||||
# exact match route is in allowed_routes
|
||||
#########################################################
|
||||
if route in allowed_routes:
|
||||
return True
|
||||
|
||||
#########################################################
|
||||
# wildcard match route is in allowed_routes
|
||||
# e.g calling /anthropic/v1/messages is allowed if allowed_routes has /anthropic/*
|
||||
#########################################################
|
||||
wildcard_allowed_routes = [
|
||||
route
|
||||
for route in allowed_routes
|
||||
if RouteChecks._is_wildcard_pattern(pattern=route)
|
||||
]
|
||||
for allowed_route in wildcard_allowed_routes:
|
||||
if RouteChecks._route_matches_wildcard_pattern(
|
||||
route=route, pattern=allowed_route
|
||||
):
|
||||
return True
|
||||
|
||||
#########################################################
|
||||
# pattern match route is in allowed_routes
|
||||
# pattern: "/threads/{thread_id}"
|
||||
# route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
|
||||
# returns: True
|
||||
#########################################################
|
||||
if any( # Check pattern match
|
||||
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
|
||||
for allowed_route in allowed_routes
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def check_passthrough_route_access(
|
||||
route: str, user_api_key_dict: UserAPIKeyAuth
|
||||
) -> bool:
|
||||
"""
|
||||
Check if route is a passthrough route.
|
||||
Supports both exact match and prefix match.
|
||||
"""
|
||||
metadata = user_api_key_dict.metadata
|
||||
team_metadata = user_api_key_dict.team_metadata or {}
|
||||
if metadata is None and team_metadata is None:
|
||||
return False
|
||||
if (
|
||||
"allowed_passthrough_routes" not in metadata
|
||||
and "allowed_passthrough_routes" not in team_metadata
|
||||
):
|
||||
return False
|
||||
if (
|
||||
metadata.get("allowed_passthrough_routes") is None
|
||||
and team_metadata.get("allowed_passthrough_routes") is None
|
||||
):
|
||||
return False
|
||||
|
||||
allowed_passthrough_routes = (
|
||||
metadata.get("allowed_passthrough_routes")
|
||||
or team_metadata.get("allowed_passthrough_routes")
|
||||
or []
|
||||
)
|
||||
|
||||
# Check if route matches any allowed passthrough route (exact or prefix match)
|
||||
for allowed_route in allowed_passthrough_routes:
|
||||
if RouteChecks._route_matches_allowed_route(
|
||||
route=route, allowed_route=allowed_route
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_assistants_api_request(request: Request) -> bool:
|
||||
"""
|
||||
Returns True if `thread` or `assistant` is in the request path
|
||||
|
||||
Args:
|
||||
request (Request): The request object
|
||||
|
||||
Returns:
|
||||
bool: True if `thread` or `assistant` is in the request path, False otherwise
|
||||
"""
|
||||
if "thread" in request.url.path or "assistant" in request.url.path:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_generate_content_route(route: str) -> bool:
|
||||
"""
|
||||
Returns True if this is a google generateContent or streamGenerateContent route
|
||||
|
||||
These routes from google allow passing key=api_key in the query params
|
||||
"""
|
||||
if "generateContent" in route:
|
||||
return True
|
||||
if "streamGenerateContent" in route:
|
||||
return True
|
||||
return False
|
||||
|
||||
    @staticmethod
    def _check_proxy_admin_viewer_access(
        route: str,
        _user_role: str,
        request_data: dict,
    ) -> None:
        """
        Check access for PROXY_ADMIN_VIEW_ONLY role

        View-only admins may read management/info data and spend reports but
        may not call LLM APIs or perform write operations; the only allowed
        write is updating their own user_email/password via /user/update.
        Raises HTTPException(403) on any disallowed access.
        """
        if RouteChecks.is_llm_api_route(route=route):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"user not allowed to access this OpenAI routes, role= {_user_role}",
            )

        # Check if this is a write operation on management routes
        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.management_routes.value
        ):
            # For management routes, only allow read operations or specific allowed updates
            if route == "/user/update":
                # Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY
                if request_data is not None and isinstance(request_data, dict):
                    _params_updated = request_data.keys()
                    for param in _params_updated:
                        if param not in ["user_email", "password"]:
                            raise HTTPException(
                                status_code=status.HTTP_403_FORBIDDEN,
                                detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route} and updating invalid param: {param}. only user_email and password can be updated",
                            )
            elif (
                # precedence: (route in [...]) OR (startswith AND endswith) -
                # the second clause catches "/key/<id>/regenerate" style paths
                route
                in [
                    "/user/new",
                    "/user/delete",
                    "/team/new",
                    "/team/update",
                    "/team/delete",
                    "/model/new",
                    "/model/update",
                    "/model/delete",
                    "/key/generate",
                    "/key/delete",
                    "/key/update",
                    "/key/regenerate",
                    "/key/service-account/generate",
                    "/key/block",
                    "/key/unblock",
                ]
                or route.startswith("/key/")
                and route.endswith("/regenerate")
            ):
                # Block write operations for PROXY_ADMIN_VIEW_ONLY
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
                )
            # Allow read operations on management routes (like /user/info, /team/info, /model/info)
            return
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.admin_viewer_routes.value
        ):
            # Allow access to admin viewer routes (read-only admin endpoints)
            return
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.global_spend_tracking_routes.value
        ):
            # Allow access to global spend tracking routes (read-only spend endpoints)
            # proxy_admin_viewer role description: "view all keys, view all spend"
            return
        else:
            # For other routes, block access
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
            )
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user