Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/proxy/litellm_pre_call_utils.py
2026-03-26 20:06:14 +08:00

1980 lines
74 KiB
Python

import asyncio
import copy
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from fastapi import Request
from starlette.datastructures import Headers
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.proxy._types import (
AddTeamCallback,
CommonProxyErrors,
LitellmDataForBackendLLMCall,
LitellmUserRoles,
SpecialHeaders,
TeamCallbackMetadata,
UserAPIKeyAuth,
)
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
# Cache special headers as a frozenset for O(1) lookup performance.
# _member_map_ includes enum aliases as well as canonical members, so every
# declared SpecialHeaders name is matched (lowercased for case-insensitive use).
_SPECIAL_HEADERS_CACHE = frozenset(
    v.value.lower() for v in SpecialHeaders._member_map_.values()
)
from litellm.router import Router
from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
from litellm.types.services import ServiceTypes
from litellm.types.utils import (
LlmProviders,
ProviderSpecificHeader,
StandardLoggingUserAPIKeyMetadata,
SupportedCacheControls,
)
# Shared module-level logger so all pre-call utilities report to one instance.
service_logger_obj = ServiceLogging()  # used for tracking latency on OTEL

if TYPE_CHECKING:
    # Real types for static analysis only; importing these at runtime would
    # create an import cycle with proxy_server.
    from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
    from litellm.types.proxy.policy_engine import PolicyMatchContext

    ProxyConfig = _ProxyConfig
else:
    # At runtime the names are opaque aliases to avoid the import cycle.
    ProxyConfig = Any
    PolicyMatchContext = Any
def parse_cache_control(cache_control):
    """
    Parse a Cache-Control style header value into a dict.

    Directives with a value (e.g. ``max-age=3600``) map to their string
    value; bare directives (e.g. ``no-cache``) map to ``True``.

    Args:
        cache_control (str): Raw header value, e.g. ``"no-cache, max-age=3600"``.

    Returns:
        dict: Mapping of directive name -> str value or True.
    """
    cache_dict = {}
    # Split on "," (not ", ") so directives without a space after the comma
    # are still parsed correctly; strip surrounding whitespace per directive.
    for directive in cache_control.split(","):
        directive = directive.strip()
        if not directive:
            # Skip empties from trailing/duplicate commas instead of
            # recording a bogus "" key.
            continue
        if "=" in directive:
            # maxsplit=1: values containing "=" must not raise ValueError.
            key, value = directive.split("=", 1)
            cache_dict[key] = value
        else:
            cache_dict[directive] = True
    return cache_dict
# Route fragments whose request bodies use "litellm_metadata" instead of
# "metadata" (matched by substring against the request path in
# _get_metadata_variable_name).
LITELLM_METADATA_ROUTES = (
    "batches",
    "/v1/messages",
    "responses",
    "files",
)
def _get_metadata_variable_name(request: Request) -> str:
    """
    Return the name of the "metadata" field for this request's body.

    Thread/assistant endpoints and any path containing a fragment from
    LITELLM_METADATA_ROUTES use "litellm_metadata"; every other endpoint
    uses "metadata".
    """
    request_path = request.url.path
    uses_litellm_metadata = (
        "thread" in request_path
        or "assistant" in request_path
        or any(route in request_path for route in LITELLM_METADATA_ROUTES)
    )
    return "litellm_metadata" if uses_litellm_metadata else "metadata"
def get_chain_id_from_headers(headers: Optional[Dict[str, str]]) -> Optional[str]:
    """
    Extract chain id for call chaining from request headers.
    x-litellm-trace-id and x-litellm-session-id are interchangeable; when both
    are present, x-litellm-trace-id takes precedence. Header keys are matched
    case-insensitively so this works with raw header dicts from any transport.
    Used by MCP (and other paths that have raw_headers but no Request) to set
    litellm_trace_id/litellm_session_id for spend logs and logging consistency.
    """
    if not headers:
        return None
    # Lowercase string keys for case-insensitive lookup; non-string keys
    # are ignored.
    lowered: Dict[str, str] = {}
    for key, value in headers.items():
        if isinstance(key, str):
            lowered[key.lower()] = value
    trace_id = lowered.get("x-litellm-trace-id")
    if trace_id:
        return trace_id
    return lowered.get("x-litellm-session-id")
def safe_add_api_version_from_query_params(data: dict, request: Request):
    """
    Copy the Azure-style `api-version` query param into data["api_version"].

    Best-effort: failures are logged (KeyError silently ignored) rather than
    raised, so the request proceeds without an api_version on error.
    """
    try:
        query_params = getattr(request, "query_params", None)
        if query_params is None:
            return
        params = dict(query_params)
        if "api-version" in params:
            data["api_version"] = params["api-version"]
    except KeyError:
        pass
    except Exception as e:
        verbose_logger.exception(
            "error checking api version in query params: %s", str(e)
        )
def convert_key_logging_metadata_to_callback(
    data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata]
) -> TeamCallbackMetadata:
    """
    Merge one AddTeamCallback entry into a TeamCallbackMetadata object.

    Creates the settings object when none is given, registers the callback
    name on the list(s) matching data.callback_type (no duplicates), and
    resolves each callback_var through get_secret (falling back to the raw
    value).
    """
    settings = team_callback_settings_obj
    if settings is None:
        settings = TeamCallbackMetadata()

    callback_type = data.callback_type
    if callback_type == "success":
        target_attrs = ("success_callback",)
    elif callback_type == "failure":
        target_attrs = ("failure_callback",)
    elif not callback_type or callback_type == "success_and_failure":
        # assume 'success_and_failure' = litellm.callbacks
        target_attrs = ("success_callback", "failure_callback", "callbacks")
    else:
        target_attrs = ()

    for attr in target_attrs:
        callback_list = getattr(settings, attr)
        if callback_list is None:
            callback_list = []
            setattr(settings, attr, callback_list)
        if data.callback_name not in callback_list:
            callback_list.append(data.callback_name)

    for var, value in data.callback_vars.items():
        if settings.callback_vars is None:
            settings.callback_vars = {}
        settings.callback_vars[var] = str(
            litellm.utils.get_secret(value, default_value=value) or value
        )
    return settings
class KeyAndTeamLoggingSettings:
    """
    Helper class to get the dynamic logging settings for the key and team
    """

    @staticmethod
    def get_key_dynamic_logging_settings(user_api_key_dict: UserAPIKeyAuth):
        """Return the key-level "logging" metadata entry, or None when unset."""
        metadata = user_api_key_dict.metadata or {}
        return metadata["logging"] if "logging" in metadata else None

    @staticmethod
    def get_team_dynamic_logging_settings(user_api_key_dict: UserAPIKeyAuth):
        """Return the team-level "logging" metadata entry, or None when unset."""
        team_metadata = user_api_key_dict.team_metadata or {}
        return team_metadata["logging"] if "logging" in team_metadata else None
def _get_dynamic_logging_metadata(
    user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig
) -> Optional[TeamCallbackMetadata]:
    """
    Resolve dynamic (per-key / per-team) logging callback settings.

    Precedence, first match wins:
      1. key metadata "logging"
      2. team metadata "logging"
      3. team metadata "callback_settings" (deprecated format, kept for
         backwards compatibility)
      4. team callbacks configured on the config.yaml file
    """
    key_settings = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(
        user_api_key_dict
    )
    team_settings = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(
        user_api_key_dict
    )

    # 1 + 2: key-level settings shadow team-level ones; both share the same
    # list-of-AddTeamCallback shape, so they merge through the same helper.
    dynamic_settings = key_settings if key_settings is not None else team_settings
    if dynamic_settings is not None:
        merged: Optional[TeamCallbackMetadata] = None
        for item in dynamic_settings:
            merged = convert_key_logging_metadata_to_callback(
                data=AddTeamCallback(**item),
                team_callback_settings_obj=merged,
            )
        return merged

    # 3: deprecated "callback_settings" format on team metadata, e.g.
    # {'callback_vars': {...}, 'failure_callback': [], 'success_callback': [...]}
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata is not None and "callback_settings" in team_metadata:
        callback_settings = team_metadata.get("callback_settings", None) or {}
        settings_obj = TeamCallbackMetadata(**callback_settings)
        verbose_proxy_logger.debug(
            "Team callback settings activated: %s", settings_obj
        )
        return settings_obj

    # 4: team callbacks declared on the config.yaml file.
    if user_api_key_dict.team_id is not None:
        return LiteLLMProxyRequestSetup.add_team_based_callbacks_from_config(
            team_id=user_api_key_dict.team_id, proxy_config=proxy_config
        )

    return None
def clean_headers(
    headers: Headers,
    litellm_key_header_name: Optional[str] = None,
    forward_llm_provider_auth_headers: bool = False,
    authenticated_with_header: Optional[str] = None,
) -> dict:
    """
    Removes litellm api key from headers

    Args:
        headers: Request headers
        litellm_key_header_name: Custom header name for LiteLLM API key
        forward_llm_provider_auth_headers: Whether to forward provider auth headers
        authenticated_with_header: Which header was used for LiteLLM authentication
            (e.g., "x-litellm-api-key", "authorization", "x-api-key")

    Returns:
        Cleaned headers dict

    Branch order matters: anthropic OAuth tokens and x-api-key get dedicated
    handling before the general special-header forwarding rules.
    """
    from litellm.llms.anthropic.common_utils import is_anthropic_oauth_key

    clean_headers = {}
    # Lowercased custom key header name for case-insensitive comparison below.
    litellm_key_lower = (
        litellm_key_header_name.lower() if litellm_key_header_name is not None else None
    )
    for header, value in headers.items():
        header_lower = header.lower()
        # Anthropic OAuth token in Authorization: forward it unless that same
        # header was the one used to authenticate with LiteLLM.
        if header_lower == "authorization" and is_anthropic_oauth_key(value):
            if (
                authenticated_with_header is None
                or authenticated_with_header.lower() != "authorization"
            ):
                clean_headers[header] = value
            continue
        # Special handling for x-api-key: forward it based on authenticated_with_header
        elif header_lower == "x-api-key":
            # Only forwarded when provider-auth forwarding is enabled AND
            # x-api-key was not the header used for LiteLLM auth.
            if forward_llm_provider_auth_headers and (
                authenticated_with_header is None
                or authenticated_with_header.lower() != "x-api-key"
            ):
                clean_headers[header] = value
        elif (
            forward_llm_provider_auth_headers and header_lower in _SPECIAL_HEADERS_CACHE
        ):
            # Forwarding enabled: pass special headers through, except any that
            # carry LiteLLM credentials.
            if litellm_key_lower and header_lower == litellm_key_lower:
                continue
            if header_lower == "authorization":
                continue
            # Never forward x-litellm-api-key (it's for proxy auth only)
            if header_lower == "x-litellm-api-key":
                continue
            clean_headers[header] = value
        # Check if header should be excluded: either in special headers cache or matches custom litellm key
        elif header_lower not in _SPECIAL_HEADERS_CACHE and (
            litellm_key_lower is None or header_lower != litellm_key_lower
        ):
            clean_headers[header] = value
    return clean_headers
class LiteLLMProxyRequestSetup:
    """
    Static helpers that translate an incoming proxy request (headers plus
    key/team/user metadata) into litellm call kwargs and request metadata
    before the backend LLM call is made.
    """

    @staticmethod
    def _get_timeout_from_request(headers: dict) -> Optional[float]:
        """
        Workaround for client request from Vercel's AI SDK.
        Allow's user to set a timeout in the request headers.
        Example:
        ```js
        const openaiProvider = createOpenAI({
            baseURL: liteLLM.baseURL,
            apiKey: liteLLM.apiKey,
            compatibility: "compatible",
            headers: {
                "x-litellm-timeout": "90"
            },
        });
        ```
        """
        # Header value is a string; float() raises ValueError on bad input
        # (propagated to the caller).
        timeout_header = headers.get("x-litellm-timeout", None)
        if timeout_header is not None:
            return float(timeout_header)
        return None

    @staticmethod
    def _get_stream_timeout_from_request(headers: dict) -> Optional[float]:
        """
        Get the `stream_timeout` from the request headers
        (x-litellm-stream-timeout), or None when absent.
        """
        stream_timeout_header = headers.get("x-litellm-stream-timeout", None)
        if stream_timeout_header is not None:
            return float(stream_timeout_header)
        return None

    @staticmethod
    def _get_num_retries_from_request(headers: dict) -> Optional[int]:
        """
        Workaround for client request from Vercel's AI SDK.
        Reads x-litellm-num-retries and returns it as an int, or None.
        """
        num_retries_header = headers.get("x-litellm-num-retries", None)
        if num_retries_header is not None:
            return int(num_retries_header)
        return None

    @staticmethod
    def _get_spend_logs_metadata_from_request_headers(headers: dict) -> Optional[dict]:
        """
        Get the `spend_logs_metadata` from the request headers.
        Returns the JSON-decoded value of x-litellm-spend-logs-metadata,
        or None when the header is absent.
        """
        from litellm.litellm_core_utils.safe_json_loads import safe_json_loads

        spend_logs_metadata_header = headers.get("x-litellm-spend-logs-metadata", None)
        if spend_logs_metadata_header is not None:
            return safe_json_loads(spend_logs_metadata_header)
        return None

    @staticmethod
    def _get_forwardable_headers(
        headers: Union[Headers, dict],
    ):
        """
        Get the headers that should be forwarded to the LLM Provider.
        Looks for any `x-` headers and sends them to the LLM Provider.
        [07/09/2025] - Support 'anthropic-beta' header as well.
        """
        forwarded_headers = {}
        for header, value in headers.items():
            if header.lower().startswith("x-") and not header.lower().startswith(
                "x-stainless"
            ):  # causes openai sdk to fail
                forwarded_headers[header] = value
            elif header.lower().startswith("anthropic-beta"):
                forwarded_headers[header] = value
        return forwarded_headers

    @staticmethod
    def _get_case_insensitive_header(headers: dict, key: str) -> Optional[str]:
        """
        Get a case-insensitive header from the headers dictionary.
        Returns the first matching value, or None.
        """
        for header, value in headers.items():
            if header.lower() == key.lower():
                return value
        return None

    @staticmethod
    def add_internal_user_from_user_mapping(
        general_settings: Optional[Dict],
        user_api_key_dict: UserAPIKeyAuth,
        headers: dict,
    ) -> UserAPIKeyAuth:
        """
        Override user_api_key_dict.user_id from a request header when
        general_settings["user_header_mappings"] maps a header to the
        INTERNAL_USER role. Mutates and returns user_api_key_dict.
        """
        if general_settings is None:
            return user_api_key_dict
        user_header_mapping = general_settings.get("user_header_mappings")
        if not user_header_mapping:
            return user_api_key_dict
        header_name = LiteLLMProxyRequestSetup.get_internal_user_header_from_mapping(
            user_header_mapping
        )
        if not header_name:
            return user_api_key_dict
        header_value = LiteLLMProxyRequestSetup._get_case_insensitive_header(
            headers, header_name
        )
        if header_value:
            user_api_key_dict.user_id = header_value
            return user_api_key_dict
        return user_api_key_dict

    @staticmethod
    def get_user_from_headers(
        headers: dict, general_settings: Optional[Dict] = None
    ) -> Optional[str]:
        """
        Get the user from the specified header if `general_settings.user_header_name` is set.

        Raises:
            TypeError: if user_header_name is configured but not a string.
        """
        if general_settings is None:
            return None
        header_name = general_settings.get("user_header_name")
        if header_name is None or header_name == "":
            return None
        if not isinstance(header_name, str):
            raise TypeError(
                f"Expected user_header_name to be a str but got {type(header_name)}"
            )
        user = LiteLLMProxyRequestSetup._get_case_insensitive_header(
            headers, header_name
        )
        if user is not None:
            verbose_logger.info(f'found user "{user}" in header "{header_name}"')
        return user

    @staticmethod
    def get_openai_org_id_from_headers(
        headers: dict, general_settings: Optional[Dict] = None
    ) -> Optional[str]:
        """
        Get the OpenAI Org ID from the headers.
        Returns openai-organization only when forwarding is not disabled.
        """
        # NOTE(review): when general_settings is None this guard is skipped and
        # the org id IS forwarded — confirm that is the intended default.
        if (
            general_settings is not None
            and general_settings.get("forward_openai_org_id") is not True
        ):
            return None
        for header, value in headers.items():
            if header.lower() == "openai-organization":
                verbose_logger.info(f"found openai org id: {value}, sending to llm")
                return value
        return None

    @staticmethod
    def add_headers_to_llm_call(
        headers: dict, user_api_key_dict: UserAPIKeyAuth
    ) -> dict:
        """
        Add headers to the LLM call

        - Checks request headers for forwardable headers
        - Checks if user information should be added to the headers
        """
        returned_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(headers)
        if litellm.add_user_information_to_llm_headers is True:
            litellm_logging_metadata_headers = (
                LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
                    user_api_key_dict=user_api_key_dict
                )
            )
            # Each sanitized user/key field becomes an x-litellm-<field> header.
            for k, v in litellm_logging_metadata_headers.items():
                if v is not None:
                    returned_headers["x-litellm-{}".format(k)] = v
        return returned_headers

    @staticmethod
    def add_headers_to_llm_call_by_model_group(
        data: dict, headers: dict, user_api_key_dict: UserAPIKeyAuth
    ) -> dict:
        """
        Add headers to the LLM call by model group

        Only forwards client headers when the request's model is covered by
        litellm.model_group_settings.forward_client_headers_to_llm_api.
        Mutates and returns `data` (sets data["headers"] when non-empty).
        """
        from litellm.proxy.auth.auth_checks import _check_model_access_helper
        from litellm.proxy.proxy_server import llm_router

        data_model = data.get("model")
        if (
            data_model is not None
            and litellm.model_group_settings is not None
            and litellm.model_group_settings.forward_client_headers_to_llm_api
            is not None
            and _check_model_access_helper(
                model=data_model,
                llm_router=llm_router,
                models=litellm.model_group_settings.forward_client_headers_to_llm_api,
                team_model_aliases=user_api_key_dict.team_model_aliases,
                team_id=user_api_key_dict.team_id,
            )  # handles aliases, wildcards, etc.
        ):
            _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
                headers, user_api_key_dict
            )
            if _headers != {}:
                data["headers"] = _headers
        return data

    @staticmethod
    def get_internal_user_header_from_mapping(user_header_mapping) -> Optional[str]:
        """
        Return the header_name of the first mapping entry whose
        litellm_user_role equals INTERNAL_USER (compared case-insensitively).
        Accepts a single dict or a list of dicts; non-dict entries and
        entries missing role/header_name are skipped.
        """
        if not user_header_mapping:
            return None
        items = (
            user_header_mapping
            if isinstance(user_header_mapping, list)
            else [user_header_mapping]
        )
        for item in items:
            if not isinstance(item, dict):
                continue
            role = item.get("litellm_user_role")
            header_name = item.get("header_name")
            if role is None or not header_name:
                continue
            if str(role).lower() == str(LitellmUserRoles.INTERNAL_USER).lower():
                return header_name
        return None

    @staticmethod
    def add_litellm_data_for_backend_llm_call(
        *,
        headers: dict,
        user_api_key_dict: UserAPIKeyAuth,
        general_settings: Optional[Dict[str, Any]] = None,
    ) -> LitellmDataForBackendLLMCall:
        """
        - Adds user from headers
        - Adds forwardable headers
        - Adds org id

        Also picks up timeout / stream_timeout / num_retries overrides from
        x-litellm-* headers.
        """
        data = LitellmDataForBackendLLMCall()
        # Forward client headers only when explicitly enabled in settings.
        if (
            general_settings
            and general_settings.get("forward_client_headers_to_llm_api") is True
        ):
            _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
                headers, user_api_key_dict
            )
            if _headers != {}:
                data["headers"] = _headers
        _organization = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(
            headers, general_settings
        )
        if _organization is not None:
            data["organization"] = _organization
        timeout = LiteLLMProxyRequestSetup._get_timeout_from_request(headers)
        if timeout is not None:
            data["timeout"] = timeout
        stream_timeout = LiteLLMProxyRequestSetup._get_stream_timeout_from_request(
            headers
        )
        if stream_timeout is not None:
            data["stream_timeout"] = stream_timeout
        num_retries = LiteLLMProxyRequestSetup._get_num_retries_from_request(headers)
        if num_retries is not None:
            data["num_retries"] = num_retries
        return data

    @staticmethod
    def add_litellm_metadata_from_request_headers(
        headers: dict,
        data: dict,
        _metadata_variable_name: str,
    ) -> dict:
        """
        Add litellm metadata from request headers

        Relevant issue: https://github.com/BerriAI/litellm/issues/14008
        """
        from litellm.proxy._types import LitellmMetadataFromRequestHeaders

        metadata_from_headers = LitellmMetadataFromRequestHeaders()
        spend_logs_metadata = (
            LiteLLMProxyRequestSetup._get_spend_logs_metadata_from_request_headers(
                headers
            )
        )
        if spend_logs_metadata is not None:
            metadata_from_headers["spend_logs_metadata"] = spend_logs_metadata
        #########################################################################################
        # Finally update the requests metadata with the `metadata_from_headers`
        #########################################################################################
        agent_id_from_header = headers.get("x-litellm-agent-id")
        # x-litellm-trace-id and x-litellm-session-id are interchangeable for call chaining
        chain_id = headers.get("x-litellm-trace-id") or headers.get(
            "x-litellm-session-id"
        )
        if agent_id_from_header:
            metadata_from_headers["agent_id"] = agent_id_from_header
            verbose_proxy_logger.debug(
                f"Extracted agent_id from header: {agent_id_from_header}"
            )
        if chain_id:
            # One id serves as both trace and session id, on metadata and on
            # the top-level request data.
            metadata_from_headers["trace_id"] = chain_id
            metadata_from_headers["session_id"] = chain_id
            data["litellm_session_id"] = chain_id
            data["litellm_trace_id"] = chain_id
            verbose_proxy_logger.debug(
                f"Extracted chain_id from header (trace-id/session-id): {chain_id}"
            )
        # NOTE(review): assumes data[_metadata_variable_name] already exists —
        # a missing key raises KeyError here; confirm callers always set it.
        if isinstance(data[_metadata_variable_name], dict):
            data[_metadata_variable_name].update(metadata_from_headers)
        return data

    @staticmethod
    def get_sanitized_user_information_from_key(
        user_api_key_dict: UserAPIKeyAuth,
    ) -> StandardLoggingUserAPIKeyMetadata:
        """
        Build the loggable subset of UserAPIKeyAuth (hashed key, aliases,
        ids, budgets) for logging and header propagation.
        """
        user_api_key_logged_metadata = StandardLoggingUserAPIKeyMetadata(
            user_api_key_hash=user_api_key_dict.api_key,  # just the hashed token
            user_api_key_alias=user_api_key_dict.key_alias,
            user_api_key_spend=user_api_key_dict.spend,
            user_api_key_max_budget=user_api_key_dict.max_budget,
            user_api_key_team_id=user_api_key_dict.team_id,
            user_api_key_project_id=user_api_key_dict.project_id,
            user_api_key_user_id=user_api_key_dict.user_id,
            user_api_key_org_id=user_api_key_dict.org_id,
            user_api_key_team_alias=user_api_key_dict.team_alias,
            user_api_key_end_user_id=user_api_key_dict.end_user_id,
            user_api_key_user_email=user_api_key_dict.user_email,
            user_api_key_request_route=user_api_key_dict.request_route,
            user_api_key_budget_reset_at=(
                user_api_key_dict.budget_reset_at.isoformat()
                if user_api_key_dict.budget_reset_at
                else None
            ),
            user_api_key_auth_metadata=user_api_key_dict.metadata,
        )
        return user_api_key_logged_metadata

    @staticmethod
    def add_user_api_key_auth_to_request_metadata(
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        _metadata_variable_name: str,
    ) -> dict:
        """
        Adds the `UserAPIKeyAuth` object to the request metadata.
        Mutates data[_metadata_variable_name] in place and returns data.
        """
        user_api_key_logged_metadata = (
            LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
                user_api_key_dict=user_api_key_dict
            )
        )
        data[_metadata_variable_name].update(user_api_key_logged_metadata)
        data[_metadata_variable_name][
            "user_api_key"
        ] = user_api_key_dict.api_key  # this is just the hashed token
        # Key-owned agent_id for spend attribution; keep existing (e.g. from header) if key has none
        _key_agent_id = getattr(user_api_key_dict, "agent_id", None)
        _existing_agent_id = data[_metadata_variable_name].get("agent_id")
        _resolved_agent_id = _key_agent_id or _existing_agent_id
        data[_metadata_variable_name]["agent_id"] = _resolved_agent_id
        data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr(
            user_api_key_dict, "end_user_max_budget", None
        )
        # Add the full UserAPIKeyAuth object for MCP server access control
        data[_metadata_variable_name]["user_api_key_auth"] = user_api_key_dict
        return data

    @staticmethod
    def add_management_endpoint_metadata_to_request_metadata(
        data: dict,
        management_endpoint_metadata: dict,
        _metadata_variable_name: str,
    ) -> dict:
        """
        Adds the `UserAPIKeyAuth` metadata to the request metadata.

        ignore any sensitive fields like logging, api_key, etc.
        """
        # No-op when the request has no metadata dict to merge into.
        if _metadata_variable_name not in data:
            return data
        from litellm.proxy._types import (
            LiteLLM_ManagementEndpoint_MetadataFields,
            LiteLLM_ManagementEndpoint_MetadataFields_Premium,
        )

        # ignore any special fields
        added_metadata = {}
        for k, v in management_endpoint_metadata.items():
            if k not in (
                LiteLLM_ManagementEndpoint_MetadataFields_Premium
                + LiteLLM_ManagementEndpoint_MetadataFields
            ):
                added_metadata[k] = v
        if data[_metadata_variable_name].get("user_api_key_auth_metadata") is None:
            data[_metadata_variable_name]["user_api_key_auth_metadata"] = {}
        data[_metadata_variable_name]["user_api_key_auth_metadata"].update(
            added_metadata
        )
        return data

    @staticmethod
    def add_key_level_controls(
        key_metadata: Optional[dict], data: dict, _metadata_variable_name: str
    ):
        """
        Apply key-level controls from key metadata onto the request data:
        cache settings, tags, guardrail toggles, spend_logs_metadata,
        disable_fallbacks, and remaining (non-special) metadata fields.
        Mutates and returns data.
        """
        if key_metadata is None:
            return data
        if "cache" in key_metadata:
            data["cache"] = {}
            if isinstance(key_metadata["cache"], dict):
                # Only copy recognized cache-control keys.
                for k, v in key_metadata["cache"].items():
                    if k in SupportedCacheControls:
                        data["cache"][k] = v
        ## KEY-LEVEL SPEND LOGS / TAGS
        if "tags" in key_metadata and key_metadata["tags"] is not None:
            data[_metadata_variable_name][
                "tags"
            ] = LiteLLMProxyRequestSetup._merge_tags(
                request_tags=data[_metadata_variable_name].get("tags"),
                tags_to_add=key_metadata["tags"],
            )
        if "disable_global_guardrails" in key_metadata and isinstance(
            key_metadata["disable_global_guardrails"], bool
        ):
            data[_metadata_variable_name]["disable_global_guardrails"] = key_metadata[
                "disable_global_guardrails"
            ]
        if "spend_logs_metadata" in key_metadata and isinstance(
            key_metadata["spend_logs_metadata"], dict
        ):
            if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
                data[_metadata_variable_name]["spend_logs_metadata"], dict
            ):
                for key, value in key_metadata["spend_logs_metadata"].items():
                    if (
                        key not in data[_metadata_variable_name]["spend_logs_metadata"]
                    ):  # don't override k-v pair sent by request (user request)
                        data[_metadata_variable_name]["spend_logs_metadata"][
                            key
                        ] = value
            else:
                data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
                    "spend_logs_metadata"
                ]
        ## KEY-LEVEL DISABLE FALLBACKS
        if "disable_fallbacks" in key_metadata and isinstance(
            key_metadata["disable_fallbacks"], bool
        ):
            data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
        ## KEY-LEVEL METADATA
        data = LiteLLMProxyRequestSetup.add_management_endpoint_metadata_to_request_metadata(
            data=data,
            management_endpoint_metadata=key_metadata,
            _metadata_variable_name=_metadata_variable_name,
        )
        return data

    @staticmethod
    def _merge_tags(request_tags: Optional[list], tags_to_add: Optional[list]) -> list:
        """
        Helper function to merge two lists of tags, ensuring no duplicates.

        Args:
            request_tags (Optional[list]): List of tags from the original request
            tags_to_add (Optional[list]): List of tags to add

        Returns:
            list: Combined list of unique tags
        """
        final_tags = []
        if request_tags and isinstance(request_tags, list):
            final_tags.extend(request_tags)
        if tags_to_add and isinstance(tags_to_add, list):
            for tag in tags_to_add:
                if tag not in final_tags:
                    final_tags.append(tag)
        return final_tags

    @staticmethod
    def add_team_based_callbacks_from_config(
        team_id: str,
        proxy_config: ProxyConfig,
    ) -> Optional[TeamCallbackMetadata]:
        """
        Add team-based callbacks from the config
        Returns None when the team has no config entry.
        """
        team_config = proxy_config.load_team_config(team_id=team_id)
        if not isinstance(team_config, dict) or len(team_config) == 0:
            return None
        # Everything except the reserved keys below is treated as a callback var.
        callback_vars_dict = {**team_config.get("callback_vars", team_config)}
        callback_vars_dict.pop("team_id", None)
        callback_vars_dict.pop("success_callback", None)
        callback_vars_dict.pop("failure_callback", None)
        return TeamCallbackMetadata(
            success_callback=team_config.get("success_callback", None),
            failure_callback=team_config.get("failure_callback", None),
            callback_vars=callback_vars_dict,
        )

    @staticmethod
    def add_request_tag_to_metadata(
        llm_router: Optional[Router],
        headers: dict,
        data: dict,
    ) -> Optional[List[str]]:
        """
        Resolve request tags from the x-litellm-tags header (comma-separated
        string or list) or the request body's "tags" list. Body tags take
        precedence over header tags.
        """
        # NOTE(review): llm_router is accepted but never used here — confirm
        # whether it can be dropped by callers.
        tags = None
        # Check request headers for tags
        if "x-litellm-tags" in headers:
            if isinstance(headers["x-litellm-tags"], str):
                _tags = headers["x-litellm-tags"].split(",")
                tags = [tag.strip() for tag in _tags]
            elif isinstance(headers["x-litellm-tags"], list):
                tags = headers["x-litellm-tags"]
        # Check request body for tags
        if "tags" in data and isinstance(data["tags"], list):
            tags = data["tags"]
        return tags
async def add_litellm_data_to_request(  # noqa: PLR0915
    data: dict,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth,
    proxy_config: ProxyConfig,
    general_settings: Optional[Dict[str, Any]] = None,
    version: Optional[str] = None,
):
    """
    Adds LiteLLM-specific data to the request.

    Args:
        data (dict): The data dictionary to be modified.
        request (Request): The incoming request.
        user_api_key_dict (UserAPIKeyAuth): The user API key dictionary.
        proxy_config (ProxyConfig): Proxy config, used for team callback settings.
        general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None.
        version (Optional[str], optional): Version. Defaults to None.

    Returns:
        dict: The modified data dictionary.
    """
    from litellm.proxy.proxy_server import llm_router, premium_user
    from litellm.types.proxy.litellm_pre_call_utils import RedactedDict, SecretFields

    _raw_headers: Dict[str, str] = RedactedDict(_safe_get_request_headers(request))

    forward_llm_auth = False
    if general_settings:
        forward_llm_auth = general_settings.get(
            "forward_llm_provider_auth_headers", False
        )
    if not forward_llm_auth:
        forward_llm_auth = getattr(litellm, "forward_llm_provider_auth_headers", False)

    # Determine which header was used for authentication
    # This enables forwarding provider keys (e.g., x-api-key) when they weren't used for LiteLLM auth
    authenticated_with_header = None
    if "x-litellm-api-key" in request.headers:
        # If x-litellm-api-key is present, it was used for auth
        authenticated_with_header = "x-litellm-api-key"
    elif "authorization" in request.headers:
        # Authorization header was used for auth
        authenticated_with_header = "authorization"
    else:
        # x-api-key or another header was used for auth
        authenticated_with_header = "x-api-key"

    _headers: Dict[str, str] = clean_headers(
        request.headers,
        litellm_key_header_name=(
            general_settings.get("litellm_key_header_name")
            if general_settings is not None
            else None
        ),
        forward_llm_provider_auth_headers=forward_llm_auth,
        authenticated_with_header=authenticated_with_header,
    )

    verbose_proxy_logger.debug(f"Request Headers: {_headers}")
    verbose_proxy_logger.debug(f"Raw Headers: {_raw_headers}")
    if forward_llm_auth and "x-api-key" in _headers:
        data["api_key"] = _headers["x-api-key"]
        verbose_proxy_logger.debug(
            "Setting client-provided x-api-key as api_key parameter (will override deployment key)"
        )

    ##########################################################
    # Init - Proxy Server Request
    # we do this as soon as entering so we track the original request
    ##########################################################
    # Track arrival time for queue time metric
    arrival_time = time.time()
    data["proxy_server_request"] = {
        "url": str(request.url),
        "method": request.method,
        "headers": _headers,
        "body": copy.copy(data),  # use copy instead of deepcopy
        "arrival_time": arrival_time,  # Track when request arrived at proxy
    }
    safe_add_api_version_from_query_params(data, request)

    _metadata_variable_name = _get_metadata_variable_name(request)

    if data.get(_metadata_variable_name, None) is None:
        data[_metadata_variable_name] = {}

    data.update(
        LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call(
            headers=_headers,
            user_api_key_dict=user_api_key_dict,
            general_settings=general_settings,
        )
    )
    LiteLLMProxyRequestSetup.add_litellm_metadata_from_request_headers(
        headers=_headers,
        data=data,
        _metadata_variable_name=_metadata_variable_name,
    )

    # Add headers to metadata for guardrails to access (fixes #17477)
    # Guardrails use metadata["headers"] to access request headers (e.g., User-Agent)
    if _metadata_variable_name in data and isinstance(
        data[_metadata_variable_name], dict
    ):
        data[_metadata_variable_name]["headers"] = _headers

    # check for forwardable headers
    data = LiteLLMProxyRequestSetup.add_headers_to_llm_call_by_model_group(
        data=data, headers=_headers, user_api_key_dict=user_api_key_dict
    )

    user_api_key_dict = LiteLLMProxyRequestSetup.add_internal_user_from_user_mapping(
        general_settings, user_api_key_dict, _headers
    )

    # Parse user info from headers (fallback to general_settings.user_header_name)
    user = LiteLLMProxyRequestSetup.get_user_from_headers(_headers, general_settings)
    if user is not None:
        if user_api_key_dict.end_user_id is None:
            user_api_key_dict.end_user_id = user
        if "user" not in data:
            data["user"] = user

    data["secret_fields"] = SecretFields(raw_headers=_raw_headers)

    ## Dynamic api version (Azure OpenAI endpoints) ##
    try:
        query_params = request.query_params
        # Convert query parameters to a dictionary (optional)
        query_dict = dict(query_params)
    except KeyError:
        query_dict = {}

    ## check for api version in query params
    dynamic_api_version: Optional[str] = query_dict.get("api-version")

    if dynamic_api_version is not None:  # only pass, if set
        data["api_version"] = dynamic_api_version

    ## Forward any LLM API Provider specific headers in extra_headers
    add_provider_specific_headers_to_request(data=data, headers=_headers)

    ## Cache Controls
    cache_control_header = _headers.get("Cache-Control", None)
    if cache_control_header:
        cache_dict = parse_cache_control(cache_control_header)
        data["ttl"] = cache_dict.get("s-maxage")

    verbose_proxy_logger.debug("receiving data: %s", data)

    # Parse metadata if it's a string (e.g., from multipart/form-data)
    if "metadata" in data and data["metadata"] is not None:
        if isinstance(data["metadata"], str):
            data["metadata"] = safe_json_loads(data["metadata"])
            if not isinstance(data["metadata"], dict):
                verbose_proxy_logger.warning(
                    f"Failed to parse 'metadata' as JSON dict. Received value: {data['metadata']}"
                )
        data[_metadata_variable_name]["requester_metadata"] = copy.deepcopy(
            data["metadata"]
        )

    # Parse litellm_metadata if it's a string (e.g., from multipart/form-data or extra_body)
    if "litellm_metadata" in data and data["litellm_metadata"] is not None:
        if isinstance(data["litellm_metadata"], str):
            parsed_litellm_metadata = safe_json_loads(data["litellm_metadata"])
            if not isinstance(parsed_litellm_metadata, dict):
                verbose_proxy_logger.warning(
                    f"Failed to parse 'litellm_metadata' as JSON dict. Received value: {data['litellm_metadata']}"
                )
            else:
                data["litellm_metadata"] = parsed_litellm_metadata
        # Merge litellm_metadata into the metadata variable (preserving existing values)
        if isinstance(data["litellm_metadata"], dict):
            for key, value in data["litellm_metadata"].items():
                if key not in data[_metadata_variable_name]:
                    data[_metadata_variable_name][key] = value

    data = LiteLLMProxyRequestSetup.add_user_api_key_auth_to_request_metadata(
        data=data,
        user_api_key_dict=user_api_key_dict,
        _metadata_variable_name=_metadata_variable_name,
    )

    data[_metadata_variable_name]["litellm_api_version"] = version

    if general_settings is not None:
        data[_metadata_variable_name][
            "global_max_parallel_requests"
        ] = general_settings.get("global_max_parallel_requests", None)

    ### KEY-LEVEL Controls
    key_metadata = user_api_key_dict.metadata
    data = LiteLLMProxyRequestSetup.add_key_level_controls(
        key_metadata=key_metadata,
        data=data,
        _metadata_variable_name=_metadata_variable_name,
    )

    ## TEAM-LEVEL SPEND LOGS/TAGS
    team_metadata = user_api_key_dict.team_metadata or {}
    if "tags" in team_metadata and team_metadata["tags"] is not None:
        data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags(
            request_tags=data[_metadata_variable_name].get("tags"),
            tags_to_add=team_metadata["tags"],
        )
    if "disable_global_guardrails" in team_metadata and isinstance(
        team_metadata["disable_global_guardrails"], bool
    ):
        data[_metadata_variable_name]["disable_global_guardrails"] = team_metadata[
            "disable_global_guardrails"
        ]
    if "spend_logs_metadata" in team_metadata and isinstance(
        team_metadata["spend_logs_metadata"], dict
    ):
        if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
            data[_metadata_variable_name]["spend_logs_metadata"], dict
        ):
            for key, value in team_metadata["spend_logs_metadata"].items():
                if (
                    key not in data[_metadata_variable_name]["spend_logs_metadata"]
                ):  # don't override k-v pair sent by request (user request)
                    data[_metadata_variable_name]["spend_logs_metadata"][key] = value
        else:
            data[_metadata_variable_name]["spend_logs_metadata"] = team_metadata[
                "spend_logs_metadata"
            ]

    ## PROJECT-LEVEL TAGS
    project_metadata = user_api_key_dict.project_metadata or {}
    if "tags" in project_metadata and project_metadata["tags"] is not None:
        data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags(
            request_tags=data[_metadata_variable_name].get("tags"),
            tags_to_add=project_metadata["tags"],
        )

    ## TEAM-LEVEL METADATA
    data = (
        LiteLLMProxyRequestSetup.add_management_endpoint_metadata_to_request_metadata(
            data=data,
            management_endpoint_metadata=team_metadata,
            _metadata_variable_name=_metadata_variable_name,
        )
    )

    # Team spend, budget - used by prometheus.py
    data[_metadata_variable_name][
        "user_api_key_team_max_budget"
    ] = user_api_key_dict.team_max_budget
    data[_metadata_variable_name][
        "user_api_key_team_spend"
    ] = user_api_key_dict.team_spend
    data[_metadata_variable_name][
        "user_api_key_request_route"
    ] = user_api_key_dict.request_route

    # API Key spend, budget - used by prometheus.py
    data[_metadata_variable_name]["user_api_key_spend"] = user_api_key_dict.spend
    data[_metadata_variable_name][
        "user_api_key_max_budget"
    ] = user_api_key_dict.max_budget
    data[_metadata_variable_name][
        "user_api_key_model_max_budget"
    ] = user_api_key_dict.model_max_budget
    data[_metadata_variable_name][
        "user_api_key_end_user_model_max_budget"
    ] = user_api_key_dict.end_user_model_max_budget

    # User spend, budget - used by prometheus.py
    # Follow same pattern as team and API key budgets
    data[_metadata_variable_name][
        "user_api_key_user_spend"
    ] = user_api_key_dict.user_spend
    data[_metadata_variable_name][
        "user_api_key_user_max_budget"
    ] = user_api_key_dict.user_max_budget

    data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
    data[_metadata_variable_name][
        "user_api_key_team_metadata"
    ] = user_api_key_dict.team_metadata
    data[_metadata_variable_name]["user_api_key_object_permission_id"] = getattr(
        user_api_key_dict, "object_permission_id", None
    )
    data[_metadata_variable_name]["user_api_key_team_object_permission_id"] = getattr(
        user_api_key_dict, "team_object_permission_id", None
    )
    data[_metadata_variable_name]["headers"] = _headers
    data[_metadata_variable_name]["endpoint"] = str(request.url)

    # OTEL Controls / Tracing
    # Add the OTEL Parent Trace before sending it LiteLLM
    data[_metadata_variable_name][
        "litellm_parent_otel_span"
    ] = user_api_key_dict.parent_otel_span
    _add_otel_traceparent_to_data(data, request=request)

    ### END-USER SPECIFIC PARAMS ###
    if user_api_key_dict.allowed_model_region is not None:
        data["allowed_model_region"] = user_api_key_dict.allowed_model_region

    start_time = time.time()

    ## [Enterprise Only]
    # Add User-IP Address (always recorded when available).
    # Prefer x-forwarded-for when general_settings enables it, else fall back
    # to the direct client host.
    requester_ip_address = ""
    if (
        general_settings is not None
        and general_settings.get("use_x_forwarded_for") is True
        and request is not None
        and hasattr(request, "headers")
        and "x-forwarded-for" in request.headers
    ):
        requester_ip_address = request.headers["x-forwarded-for"]
    elif (
        request is not None
        and hasattr(request, "client")
        and hasattr(request.client, "host")
        and request.client is not None
    ):
        requester_ip_address = request.client.host
    data[_metadata_variable_name]["requester_ip_address"] = requester_ip_address

    # Add User-Agent
    user_agent = ""
    if (
        request is not None
        and hasattr(request, "headers")
        and "user-agent" in request.headers
    ):
        user_agent = request.headers["user-agent"]
    data[_metadata_variable_name]["user_agent"] = user_agent

    # Check if using tag based routing
    tags = LiteLLMProxyRequestSetup.add_request_tag_to_metadata(
        llm_router=llm_router,
        headers=_headers,
        data=data,
    )

    if tags is not None:
        data[_metadata_variable_name]["tags"] = tags

    # Team Callbacks controls
    callback_settings_obj = _get_dynamic_logging_metadata(
        user_api_key_dict=user_api_key_dict, proxy_config=proxy_config
    )
    if callback_settings_obj is not None:
        data["success_callback"] = callback_settings_obj.success_callback
        data["failure_callback"] = callback_settings_obj.failure_callback

        if callback_settings_obj.callback_vars is not None:
            # unpack callback_vars in data
            for k, v in callback_settings_obj.callback_vars.items():
                data[k] = v

    # Add disabled callbacks from key metadata
    if (
        user_api_key_dict.metadata
        and "litellm_disabled_callbacks" in user_api_key_dict.metadata
    ):
        disabled_callbacks = user_api_key_dict.metadata["litellm_disabled_callbacks"]
        if disabled_callbacks and isinstance(disabled_callbacks, list):
            data["litellm_disabled_callbacks"] = disabled_callbacks

    # Guardrails from key/team metadata and policy engine
    await move_guardrails_to_metadata(
        data=data,
        _metadata_variable_name=_metadata_variable_name,
        user_api_key_dict=user_api_key_dict,
    )

    # Team Model Aliases
    _update_model_if_team_alias_exists(
        data=data,
        user_api_key_dict=user_api_key_dict,
    )

    # Key Model Aliases
    _update_model_if_key_alias_exists(
        data=data,
        user_api_key_dict=user_api_key_dict,
    )

    verbose_proxy_logger.debug(
        "[PROXY] returned data from litellm_pre_call_utils: %s", data
    )

    ## ENFORCED PARAMS CHECK
    # loop through each enforced param
    # example enforced_params ['user', 'metadata', 'metadata.generation_name']
    _enforced_params_check(
        request_body=data,
        general_settings=general_settings,
        user_api_key_dict=user_api_key_dict,
        premium_user=premium_user,
    )

    end_time = time.time()
    asyncio.create_task(
        service_logger_obj.async_service_success_hook(
            service=ServiceTypes.PROXY_PRE_CALL,
            duration=end_time - start_time,
            call_type="add_litellm_data_to_request",
            start_time=start_time,
            end_time=end_time,
            parent_otel_span=user_api_key_dict.parent_otel_span,
        )
    )

    return data
def _update_model_if_team_alias_exists(
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
) -> None:
    """
    Rewrite `data["model"]` via the team's model alias map, when one applies.

    e.g. team.model_alias_map = {"gpt-4o": "gpt-4o-team-1"} and the caller
    requested "gpt-4o" -> the request is mutated in place to "gpt-4o-team-1".
    """
    requested = data.get("model")
    alias_map = user_api_key_dict.team_model_aliases
    if not requested or not alias_map:
        return
    if requested in alias_map:
        data["model"] = alias_map[requested]
def _update_model_if_key_alias_exists(
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
) -> None:
    """
    Rewrite `data["model"]` via the key's alias map, when one applies.

    e.g. key.aliases = {"modelAlias": "xai/grok-4-fast-non-reasoning"} and the
    caller requested "modelAlias" -> the request is mutated in place to the
    aliased model.
    """
    requested = data.get("model")
    key_aliases = user_api_key_dict.aliases
    if not requested or not isinstance(key_aliases, dict):
        return
    if requested in key_aliases:
        data["model"] = key_aliases[requested]
def _get_enforced_params(
    general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
) -> Optional[list]:
    """
    Collect the enforced request params that apply to this call.

    Sources, in order:
    1. general_settings["enforced_params"]
    2. service_account_settings["enforced_params"] (when the key is a service account)
    3. key metadata "enforced_params"

    Returns None when no source defines any.

    NOTE: a fresh list is always built before appending. Previously this
    function called .extend() directly on the list object stored inside
    general_settings, mutating shared config state so enforced params
    accumulated across requests.
    """
    enforced_params: Optional[list] = None
    if general_settings is not None:
        base_params = general_settings.get("enforced_params")
        if base_params is not None:
            # copy so we never mutate the list stored in general_settings
            enforced_params = list(base_params)
        if (
            "service_account_settings" in general_settings
            and check_if_token_is_service_account(user_api_key_dict) is True
        ):
            service_account_settings = general_settings["service_account_settings"]
            if "enforced_params" in service_account_settings:
                if enforced_params is None:
                    enforced_params = []
                enforced_params.extend(service_account_settings["enforced_params"])
    if user_api_key_dict.metadata.get("enforced_params", None) is not None:
        if enforced_params is None:
            enforced_params = []
        enforced_params.extend(user_api_key_dict.metadata["enforced_params"])
    return enforced_params
def check_if_token_is_service_account(valid_token: UserAPIKeyAuth) -> bool:
    """
    Checks if the token is a service account.

    A token counts as a service account when its metadata contains a
    "service_account_id" entry.

    Returns:
        bool: True if token is a service account
    """
    metadata = valid_token.metadata
    return bool(metadata) and "service_account_id" in metadata
def _enforced_params_check(
    request_body: dict,
    general_settings: Optional[dict],
    user_api_key_dict: UserAPIKeyAuth,
    premium_user: bool,
) -> bool:
    """
    If enforced params are set, check if the request body contains the enforced params.

    Supports top-level params ("user") and one level of nesting
    ("metadata.generation_name"). Raises ValueError when a required param is
    missing, or when enforced params are configured without a premium license.
    """
    enforced_params: Optional[list] = _get_enforced_params(
        general_settings=general_settings, user_api_key_dict=user_api_key_dict
    )
    if enforced_params is None:
        return True
    if enforced_params and premium_user is not True:
        raise ValueError(
            f"Enforced Params is an Enterprise feature. Enforced Params: {enforced_params}. {CommonProxyErrors.not_premium_user.value}"
        )

    for enforced_param in enforced_params:
        _enforced_params = enforced_param.split(".")
        if len(_enforced_params) == 1:
            if _enforced_params[0] not in request_body:
                raise ValueError(
                    f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
                )
        elif len(_enforced_params) == 2:
            # this is a scenario where user requires request['metadata']['generation_name'] to exist
            if _enforced_params[0] not in request_body:
                raise ValueError(
                    f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
                )
            try:
                nested_missing = (
                    _enforced_params[1] not in request_body[_enforced_params[0]]
                )
            except TypeError:
                # parent value is not a container (e.g. int / None) - treat the
                # nested param as missing instead of surfacing a raw TypeError
                nested_missing = True
            if nested_missing:
                raise ValueError(
                    f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
                )
        # NOTE: params nested more than one level deep are currently ignored
    return True
def _add_guardrails_from_key_or_team_metadata(
    key_metadata: Optional[dict],
    team_metadata: Optional[dict],
    data: dict,
    metadata_variable_name: str,
) -> None:
    """
    Merge guardrails configured on the API key and/or team into request metadata.

    Key-level guardrails are collected first, then team-level ones; a set is
    used so duplicates appear only once. Requires a premium license whenever
    any non-empty guardrail list is configured.

    Args:
        key_metadata: The key metadata dictionary to check for guardrails
        team_metadata: The team metadata dictionary to check for guardrails
        data: The request data to update
        metadata_variable_name: The name of the metadata field in data
    """
    from litellm.proxy.utils import _premium_user_check

    merged_guardrails: set = set()
    for source_metadata in (key_metadata, team_metadata):
        if not source_metadata:
            continue
        configured = source_metadata.get("guardrails")
        if isinstance(configured, list) and configured:
            _premium_user_check()  # guardrails are enterprise-only
            merged_guardrails.update(configured)

    if merged_guardrails:
        data[metadata_variable_name]["guardrails"] = list(merged_guardrails)
def _add_guardrails_from_policies_in_metadata(
    key_metadata: Optional[dict],
    team_metadata: Optional[dict],
    data: dict,
    metadata_variable_name: str,
) -> None:
    """
    Helper to resolve guardrails from policies attached to key/team metadata.

    This function:
    1. Gets policy names from key and team metadata
    2. Resolves guardrails from those policies (including inheritance)
    3. Adds resolved guardrails to request metadata

    Args:
        key_metadata: The key metadata dictionary to check for policies
        team_metadata: The team metadata dictionary to check for policies
        data: The request data to update
        metadata_variable_name: The name of the metadata field in data
    """
    from litellm._logging import verbose_proxy_logger
    from litellm.proxy.policy_engine.policy_registry import get_policy_registry
    from litellm.proxy.policy_engine.policy_resolver import PolicyResolver
    from litellm.proxy.utils import _premium_user_check
    from litellm.types.proxy.policy_engine import PolicyMatchContext

    # Collect policy names from key and team metadata
    policy_names: set = set()

    # Add key-level policies first
    if key_metadata and "policies" in key_metadata:
        if (
            isinstance(key_metadata["policies"], list)
            and len(key_metadata["policies"]) > 0
        ):
            # policies are enterprise-only; raises when user isn't premium
            _premium_user_check()
            policy_names.update(key_metadata["policies"])

    # Add team-level policies
    if team_metadata and "policies" in team_metadata:
        if (
            isinstance(team_metadata["policies"], list)
            and len(team_metadata["policies"]) > 0
        ):
            _premium_user_check()
            policy_names.update(team_metadata["policies"])

    if not policy_names:
        return

    verbose_proxy_logger.debug(
        f"Policy engine: resolving guardrails from key/team policies: {policy_names}"
    )

    # Check if policy registry is initialized
    registry = get_policy_registry()
    if not registry.is_initialized():
        verbose_proxy_logger.debug(
            "Policy engine not initialized, skipping policy resolution from metadata"
        )
        return

    # Build context for policy resolution (model from request data)
    context = PolicyMatchContext(model=data.get("model"))

    # Get all policies from registry
    all_policies = registry.get_all_policies()

    # Resolve guardrails from the specified policies
    resolved_guardrails: set = set()
    for policy_name in policy_names:
        if registry.has_policy(policy_name):
            resolved_policy = PolicyResolver.resolve_policy_guardrails(
                policy_name=policy_name,
                policies=all_policies,
                context=context,
            )
            resolved_guardrails.update(resolved_policy.guardrails)
            verbose_proxy_logger.debug(
                f"Policy engine: resolved guardrails from policy '{policy_name}': {resolved_policy.guardrails}"
            )
        else:
            # Unknown policy name: warn and continue, do not fail the request
            verbose_proxy_logger.warning(
                f"Policy engine: policy '{policy_name}' not found in registry"
            )

    if not resolved_guardrails:
        return

    # Add resolved guardrails to request metadata
    if metadata_variable_name not in data:
        data[metadata_variable_name] = {}

    existing_guardrails = data[metadata_variable_name].get("guardrails", [])
    if not isinstance(existing_guardrails, list):
        existing_guardrails = []

    # Combine existing guardrails with policy-resolved guardrails (no duplicates)
    combined = set(existing_guardrails)
    combined.update(resolved_guardrails)
    data[metadata_variable_name]["guardrails"] = list(combined)

    # Store applied policies in metadata for tracking
    if "applied_policies" not in data[metadata_variable_name]:
        data[metadata_variable_name]["applied_policies"] = []
    data[metadata_variable_name]["applied_policies"].extend(list(policy_names))

    verbose_proxy_logger.debug(
        f"Policy engine: added guardrails from key/team policies to request metadata: {list(resolved_guardrails)}"
    )
async def move_guardrails_to_metadata(
    data: dict,
    _metadata_variable_name: str,
    user_api_key_dict: UserAPIKeyAuth,
):
    """
    Helper to add guardrails from request to metadata
    - If guardrails set on API Key metadata then sets guardrails on request metadata
    - If guardrails not set on API key, then checks request metadata
    - Adds guardrails from policies attached to key/team metadata
    - Adds guardrails from policy engine based on team/key/model context

    Mutates `data` in place: guardrail-related fields are popped off the
    request body and moved under `data[_metadata_variable_name]` so they are
    not forwarded to the LLM provider.
    """
    # Early-out: skip all guardrails processing when nothing is configured
    key_metadata = user_api_key_dict.metadata
    team_metadata = user_api_key_dict.team_metadata
    has_key_config = key_metadata and (
        "guardrails" in key_metadata or "policies" in key_metadata
    )
    has_team_config = team_metadata and (
        "guardrails" in team_metadata or "policies" in team_metadata
    )
    has_request_config = (
        "guardrails" in data or "guardrail_config" in data or "policies" in data
    )
    # Only check policy engine if no local config (avoid import + registry lookup)
    if not (has_key_config or has_team_config or has_request_config):
        from litellm.proxy.policy_engine.policy_registry import get_policy_registry

        if not get_policy_registry().is_initialized():
            # Nothing configured anywhere - clean up request body fields and return
            data.pop("policies", None)
            return

    # Check key-level guardrails
    _add_guardrails_from_key_or_team_metadata(
        key_metadata=user_api_key_dict.metadata,
        team_metadata=user_api_key_dict.team_metadata,
        data=data,
        metadata_variable_name=_metadata_variable_name,
    )

    #########################################################################################
    # Add guardrails from policies attached to key/team metadata
    #########################################################################################
    _add_guardrails_from_policies_in_metadata(
        key_metadata=user_api_key_dict.metadata,
        team_metadata=user_api_key_dict.team_metadata,
        data=data,
        metadata_variable_name=_metadata_variable_name,
    )

    #########################################################################################
    # Add guardrails from policy engine based on team/key/model context
    #########################################################################################
    await add_guardrails_from_policy_engine(
        data=data,
        metadata_variable_name=_metadata_variable_name,
        user_api_key_dict=user_api_key_dict,
    )

    #########################################################################################
    # User's might send "guardrails" in the request body, we need to add them to the request metadata.
    # Since downstream logic requires "guardrails" to be in the request metadata
    #########################################################################################
    if "guardrails" in data:
        # pop so "guardrails" is removed from the body sent to the provider
        request_body_guardrails = data.pop("guardrails")
        if "guardrails" in data[_metadata_variable_name] and isinstance(
            data[_metadata_variable_name]["guardrails"], list
        ):
            data[_metadata_variable_name]["guardrails"].extend(request_body_guardrails)
        else:
            data[_metadata_variable_name]["guardrails"] = request_body_guardrails
    #########################################################################################
    if "guardrail_config" in data:
        # request-body config is merged over any existing metadata-level config
        request_body_guardrail_config = data.pop("guardrail_config")
        if "guardrail_config" in data[_metadata_variable_name] and isinstance(
            data[_metadata_variable_name]["guardrail_config"], dict
        ):
            data[_metadata_variable_name]["guardrail_config"].update(
                request_body_guardrail_config
            )
        else:
            data[_metadata_variable_name][
                "guardrail_config"
            ] = request_body_guardrail_config
def _is_policy_version_id(s: str) -> bool:
    """Return True if string is a policy version ID (starts with policy_<uuid> prefix)."""
    from litellm.proxy.policy_engine.policy_registry import POLICY_VERSION_ID_PREFIX

    if not isinstance(s, str):
        return False
    return s.startswith(POLICY_VERSION_ID_PREFIX)
def _extract_policy_id(s: str) -> Optional[str]:
    """Extract raw UUID from policy_<uuid> string, or None if not a valid version ID."""
    from litellm.proxy.policy_engine.policy_registry import POLICY_VERSION_ID_PREFIX

    if not _is_policy_version_id(s):
        return None
    remainder = s[len(POLICY_VERSION_ID_PREFIX) :].strip()
    return remainder if remainder else None
def _match_and_track_policies(
    data: dict,
    context: "PolicyMatchContext",
    request_body_policies: Any,
    policies_override: Optional[Dict[str, Any]] = None,
) -> tuple[list[str], dict[str, str]]:
    """
    Match policies via attachments and request body, track them in metadata.

    Combines attachment-matched policies with any dynamic policies supplied in
    the request body, filters them down to those whose conditions match the
    context, then records both the applied policies and their attribution
    sources in request metadata (used for response headers).

    Returns:
        Tuple of (applied_policy_names, policy_reasons)
    """
    from litellm._logging import verbose_proxy_logger
    from litellm.proxy.common_utils.callback_utils import (
        add_policy_sources_to_metadata,
        add_policy_to_applied_policies_header,
    )
    from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
    from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher

    # Attachment-based matches, each with the reason it matched (for attribution)
    matches = get_attachment_registry().get_attached_policies_with_reasons(context)
    attached_names = [match["policy_name"] for match in matches]
    policy_reasons = {match["policy_name"]: match["matched_via"] for match in matches}
    verbose_proxy_logger.debug(
        f"Policy engine: matched policies via attachments: {attached_names}"
    )

    # Fold in dynamic policies named directly in the request body
    candidate_names = set(attached_names)
    if request_body_policies and isinstance(request_body_policies, list):
        candidate_names.update(request_body_policies)
        verbose_proxy_logger.debug(
            f"Policy engine: added dynamic policies from request body: {request_body_policies}"
        )

    if not candidate_names:
        return [], {}

    # Keep only the policies whose conditions actually match this request
    applied_policy_names = PolicyMatcher.get_policies_with_matching_conditions(
        policy_names=list(candidate_names),
        context=context,
        policies=policies_override,
    )
    verbose_proxy_logger.debug(
        f"Policy engine: applied policies (conditions matched): {applied_policy_names}"
    )

    # Record applied policies so they can be surfaced in response headers
    for applied_name in applied_policy_names:
        add_policy_to_applied_policies_header(
            request_data=data, policy_name=applied_name
        )

    # Record attribution sources for the x-litellm-policy-sources header
    applied_reasons = {
        name: policy_reasons[name]
        for name in applied_policy_names
        if name in policy_reasons
    }
    add_policy_sources_to_metadata(request_data=data, policy_sources=applied_reasons)

    return applied_policy_names, policy_reasons
def _apply_resolved_guardrails_to_metadata(
    data: dict,
    metadata_variable_name: str,
    context: "PolicyMatchContext",
    policy_names: Optional[List[str]] = None,
    policies: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Apply resolved guardrails and pipelines to request metadata.

    Resolves both the flat guardrail list and any guardrail pipelines for the
    given policy context, then merges them into
    `data[metadata_variable_name]`. Guardrails that are managed by a pipeline
    are excluded from the flat list so they are not executed twice.
    """
    from litellm._logging import verbose_proxy_logger
    from litellm.proxy.policy_engine.policy_resolver import PolicyResolver

    # Resolve guardrails from matching policies
    resolved_guardrails = PolicyResolver.resolve_guardrails_for_context(
        context=context,
        policies=policies,
        policy_names=policy_names,
    )
    verbose_proxy_logger.debug(
        f"Policy engine: resolved guardrails: {resolved_guardrails}"
    )

    # Resolve pipelines from matching policies
    pipelines = PolicyResolver.resolve_pipelines_for_context(
        context=context,
        policies=policies,
        policy_names=policy_names,
    )

    # Add resolved guardrails to request metadata
    if metadata_variable_name not in data:
        data[metadata_variable_name] = {}

    # Track pipeline-managed guardrails to exclude from independent execution
    pipeline_managed_guardrails: set = set()
    if pipelines:
        pipeline_managed_guardrails = PolicyResolver.get_pipeline_managed_guardrails(
            pipelines
        )
        # stash pipelines + their managed guardrails for downstream execution
        data[metadata_variable_name]["_guardrail_pipelines"] = pipelines
        data[metadata_variable_name][
            "_pipeline_managed_guardrails"
        ] = pipeline_managed_guardrails
        verbose_proxy_logger.debug(
            f"Policy engine: resolved {len(pipelines)} pipeline(s), "
            f"managed guardrails: {pipeline_managed_guardrails}"
        )

    if not resolved_guardrails and not pipelines:
        return

    existing_guardrails = data[metadata_variable_name].get("guardrails", [])
    if not isinstance(existing_guardrails, list):
        existing_guardrails = []

    # Combine existing guardrails with policy-resolved guardrails (no duplicates)
    # Exclude pipeline-managed guardrails from the flat list
    combined = set(existing_guardrails)
    combined.update(resolved_guardrails)
    combined -= pipeline_managed_guardrails
    data[metadata_variable_name]["guardrails"] = list(combined)
    verbose_proxy_logger.debug(
        f"Policy engine: added guardrails to request metadata: {list(combined)}"
    )
async def add_guardrails_from_policy_engine(
    data: dict,
    metadata_variable_name: str,
    user_api_key_dict: UserAPIKeyAuth,
) -> None:
    """
    Resolve guardrails via the policy engine and attach them to request metadata.

    Steps performed:
    1. Pops "policies" from the request body (dynamic policy application),
       so it is never forwarded to the LLM provider.
    2. Accepts policy_<uuid> entries to pin a specific policy version
       (e.g. a published version) from the in-memory cache.
    3. Builds a match context from team_alias, key_alias, model and tags,
       and matches attachment-based policies against it.
    4. Merges dynamic (request-body) policies with matched ones, resolves
       their guardrails (including inheritance), writes them into request
       metadata, and records the applied policies for response headers.

    Args:
        data: The request data to update.
        metadata_variable_name: The name of the metadata field in data.
        user_api_key_dict: The user's API key authentication info.
    """
    from litellm._logging import verbose_proxy_logger
    from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body
    from litellm.proxy.policy_engine.policy_registry import get_policy_registry
    from litellm.types.proxy.policy_engine import PolicyMatchContext

    # Strip "policies" from the body up-front so the provider never sees it.
    raw_policies = data.pop("policies", None)
    registry = get_policy_registry()
    verbose_proxy_logger.debug(
        f"Policy engine: registry initialized={registry.is_initialized()}, "
        f"policy_count={len(registry.get_all_policies())}"
    )
    if not registry.is_initialized():
        verbose_proxy_logger.debug(
            "Policy engine not initialized, skipping policy matching"
        )
        return

    # Build the match context from request tags plus key/team metadata.
    request_tags = get_tags_from_request_body(data) or None
    team_alias = user_api_key_dict.team_alias
    key_alias = user_api_key_dict.key_alias
    match_context = PolicyMatchContext(
        team_alias=team_alias if isinstance(team_alias, str) else None,
        key_alias=key_alias if isinstance(key_alias, str) else None,
        model=data.get("model"),
        tags=request_tags,
    )
    verbose_proxy_logger.debug(
        f"Policy engine: matching policies for context team_alias={match_context.team_alias}, "
        f"key_alias={match_context.key_alias}, model={match_context.model}, tags={match_context.tags}"
    )

    # Partition request-body entries into plain names vs policy_<uuid> version IDs.
    plain_names: List[str] = []
    version_ids: List[str] = []
    if raw_policies and isinstance(raw_policies, list):
        for entry in raw_policies:
            if not isinstance(entry, str):
                continue
            if _is_policy_version_id(entry):
                extracted_id = _extract_policy_id(entry)
                if extracted_id:
                    version_ids.append(extracted_id)
            else:
                plain_names.append(entry)

    # Resolve pinned versions from the in-memory cache (populated by the sync
    # job) — the hot path never touches the DB.
    merged_policies: Dict[str, Any] = dict(registry.get_all_policies())
    fetched_names: List[str] = []
    for version_id in version_ids:
        lookup = registry.get_policy_by_id_for_request(policy_id=version_id)
        if lookup is None:
            verbose_proxy_logger.debug(
                f"Policy engine: policy version {version_id} not found in cache, skipping"
            )
            continue
        resolved_name, resolved_policy = lookup
        merged_policies[resolved_name] = resolved_policy
        fetched_names.append(resolved_name)
        verbose_proxy_logger.debug(
            f"Policy engine: loaded version by ID policy_{version_id} -> {resolved_name}"
        )

    # Only pass the merged override when a version pin actually replaced something.
    policies_override = merged_policies if version_ids else None

    # Match and record applied policies (plain names + names from pinned versions).
    applied_policy_names, _ = _match_and_track_policies(
        data,
        match_context,
        plain_names + fetched_names,
        policies_override=policies_override,
    )

    # Resolve guardrails for every applied policy and write them into metadata.
    _apply_resolved_guardrails_to_metadata(
        data,
        metadata_variable_name,
        match_context,
        policy_names=applied_policy_names or None,
        policies=policies_override,
    )
def add_provider_specific_headers_to_request(
    data: dict,
    headers: dict,
):
    """
    Collect Anthropic-specific request headers and attach them to ``data``.

    Any known Anthropic API header — and an Authorization header carrying an
    Anthropic OAuth token (sk-ant-oat*) — is stored under
    ``data["provider_specific_header"]`` so it is forwarded only to
    Anthropic-compatible providers, not to every provider in the router.

    Args:
        data: The request data to update in place.
        headers: The incoming request headers.
    """
    from litellm.llms.anthropic.common_utils import is_anthropic_oauth_key

    # Gather every recognized Anthropic API header present on the request.
    collected = {name: headers[name] for name in ANTHROPIC_API_HEADERS if name in headers}

    # An Authorization header holding an Anthropic OAuth token must ride along
    # as a provider-specific header too; handling it here ensures it only
    # reaches Anthropic-compatible providers.
    for name, value in headers.items():
        if name.lower() == "authorization" and is_anthropic_oauth_key(value):
            collected[name] = value
            break

    if collected:
        # Anthropic headers work across multiple providers; store the provider
        # list comma-separated so retrieval can match any one of them.
        data["provider_specific_header"] = ProviderSpecificHeader(
            custom_llm_provider=f"{LlmProviders.ANTHROPIC.value},{LlmProviders.BEDROCK.value},{LlmProviders.VERTEX_AI.value}",
            extra_headers=collected,
        )
def _add_otel_traceparent_to_data(data: dict, request: Request):
    """
    Forward the incoming ``traceparent`` header to the LLM provider.

    Only does so when OTEL is configured on the proxy and
    ``litellm.forward_traceparent_to_llm_provider`` is enabled; the header is
    placed into ``data["extra_headers"]`` without overwriting an existing one.
    """
    from litellm.proxy.proxy_server import open_telemetry_logger

    if data is None:
        return
    # If the user is not using OTEL, don't send extra_headers.
    # Relevant issue: https://github.com/BerriAI/litellm/issues/4448
    if open_telemetry_logger is None:
        return
    if litellm.forward_traceparent_to_llm_provider is not True:
        return
    if not request.headers or "traceparent" not in request.headers:
        return
    # Forward the caller's traceparent to the LLM provider via extra_headers.
    # Relevant issue: https://github.com/BerriAI/litellm/issues/4419
    extra_headers = data.setdefault("extra_headers", {})
    if "traceparent" not in extra_headers:
        extra_headers["traceparent"] = request.headers["traceparent"]