chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,70 @@
from typing import Any, Optional, Union
from pydantic import BaseModel
from litellm.types.utils import HiddenParams
def _add_headers_to_response(response: Any, headers: dict) -> Any:
    """
    Merge ``headers`` into the response's ``_hidden_params["additional_headers"]``.

    Non-pydantic responses (including ``None``) are returned untouched.
    """
    # Only pydantic models carry _hidden_params; isinstance() also rejects None.
    if not isinstance(response, BaseModel):
        return response

    existing: Optional[Union[dict, HiddenParams]] = getattr(
        response, "_hidden_params", {}
    )
    # Normalize whatever is stored into a plain dict we can mutate.
    if isinstance(existing, HiddenParams):
        merged = existing.model_dump()
    elif existing is None:
        merged = {}
    else:
        merged = existing

    merged.setdefault("additional_headers", {}).update(headers)
    response._hidden_params = merged
    return response
def add_retry_headers_to_response(
    response: Any,
    attempted_retries: int,
    max_retries: Optional[int] = None,
) -> Any:
    """
    Attach retry-tracking headers to the response's hidden params.

    Args:
        response: response to annotate (returned unchanged if not a pydantic model).
        attempted_retries: number of retries that were attempted.
        max_retries: optional retry ceiling; only added when provided.
    """
    headers: dict = {"x-litellm-attempted-retries": attempted_retries}
    if max_retries is not None:
        headers["x-litellm-max-retries"] = max_retries
    return _add_headers_to_response(response, headers)
def add_fallback_headers_to_response(
    response: Any,
    attempted_fallbacks: int,
) -> Any:
    """
    Attach the fallback-count header to the response's hidden params.

    Args:
        response: The response to add the headers to
        attempted_fallbacks: The number of fallbacks attempted

    Returns:
        The response with the headers added

    Note: It's intentional that we don't add max_fallbacks in response headers
    Want to avoid bloat in the response headers for performance.
    """
    return _add_headers_to_response(
        response, {"x-litellm-attempted-fallbacks": attempted_fallbacks}
    )

View File

@@ -0,0 +1,150 @@
import io
import json
from os import PathLike
from typing import List, Optional
from litellm._logging import verbose_logger
from litellm.types.llms.openai import FileTypes, OpenAIFilesPurpose
class InMemoryFile(io.BytesIO):
    """A BytesIO that also carries a filename and content type.

    Mimics the interface of an uploaded file object so it can be handed to
    file-upload consumers (e.g. multipart encoders) without touching disk.
    """

    def __init__(
        self, content: bytes, name: str, content_type: str = "application/jsonl"
    ):
        super().__init__(content)
        self.name = name  # filename reported to upload consumers
        self.content_type = content_type  # MIME type of the payload
def parse_jsonl_with_embedded_newlines(content: str) -> List[dict]:
    """
    Parse JSONL content that may contain JSON objects spanning multiple lines.

    Unlike simple ``splitlines()``-based parsing, this accumulates input until
    the buffer forms a complete JSON document, so objects pretty-printed across
    several lines still parse correctly.

    Args:
        content: The JSONL file content as a string

    Returns:
        List of parsed JSON objects

    Raises:
        json.JSONDecodeError: if trailing content never forms valid JSON.
    """
    json_objects: List[dict] = []
    buffer = ""
    # Accumulate whole lines rather than single characters (the previous
    # char-by-char `buffer += char` was O(n^2) in the content length) and
    # attempt a parse at each newline boundary, exactly as before.
    for line in content.split("\n"):
        buffer += line + "\n"
        try:
            json_objects.append(json.loads(buffer.strip()))
            buffer = ""  # complete object consumed; start a fresh buffer
        except json.JSONDecodeError:
            # Not a complete JSON object yet, keep accumulating
            continue
    # Handle any remaining content in buffer
    if buffer.strip():
        try:
            json_objects.append(json.loads(buffer.strip()))
        except json.JSONDecodeError as e:
            verbose_logger.error(
                f"error parsing final buffer: {buffer[:100]}..., error: {e}"
            )
            raise e
    return json_objects
def should_replace_model_in_jsonl(
    purpose: OpenAIFilesPurpose,
) -> bool:
    """
    Whether the model name inside the JSONL payload must be swapped for the
    deployment's model name.

    Azure rejects batch creation when the deployment model name is not the
    one referenced inside the .jsonl, so only "batch" files need the rewrite.
    """
    return purpose == "batch"
def replace_model_in_jsonl(file_content: FileTypes, new_model_name: str) -> FileTypes:
    """
    Rewrite the ``model`` field of every request body in a JSONL batch file.

    Best-effort: on any decode/parse problem the original content is returned
    unchanged rather than raising.
    """
    try:
        # Path-like inputs are passed through untouched.
        if isinstance(file_content, PathLike):
            return file_content

        # Pull raw bytes out of file objects or (name, content) tuples.
        if hasattr(file_content, "read"):
            raw = file_content.read()  # type: ignore
        elif isinstance(file_content, tuple):
            raw = file_content[1]
        else:
            raw = file_content

        # Normalize to str; bail out on anything that isn't text or bytes.
        if isinstance(raw, bytes):
            text = raw.decode("utf-8")
        elif isinstance(raw, str):
            text = raw
        else:
            return file_content

        # Robust JSONL parse (handles objects spanning multiple lines).
        json_objects = parse_jsonl_with_embedded_newlines(text)
        if not json_objects:
            # No valid JSON objects found - keep the original content.
            return file_content

        rewritten_lines = []
        for obj in json_objects:
            if "body" in obj:
                # Swap in the deployment's model name.
                obj["body"]["model"] = new_model_name
            rewritten_lines.append(json.dumps(obj))

        payload = "\n".join(rewritten_lines).encode("utf-8")
        return InMemoryFile(payload, name="modified_file.jsonl", content_type="application/jsonl")  # type: ignore
    except (json.JSONDecodeError, UnicodeDecodeError, TypeError):
        # Best-effort: keep the original file content on any error.
        return file_content
def _get_router_metadata_variable_name(function_name: Optional[str]) -> str:
"""
Helper to return what the "metadata" field should be called in the request data
For all /thread or /assistant endpoints we need to call this "litellm_metadata"
For ALL other endpoints we call this "metadata
"""
ROUTER_METHODS_USING_LITELLM_METADATA = set(
[
"batch",
"generic_api_call",
"_acreate_batch",
"file",
"_ageneric_api_call_with_fallbacks",
]
)
if function_name and any(
method in function_name for method in ROUTER_METHODS_USING_LITELLM_METADATA
):
return "litellm_metadata"
else:
return "metadata"

View File

@@ -0,0 +1,37 @@
import asyncio
from typing import TYPE_CHECKING, Any
from litellm.utils import calculate_max_parallel_requests
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
class InitalizeCachedClient:
    """Sets up per-deployment helper objects stored in the router's local cache."""

    @staticmethod
    def set_max_parallel_requests_client(
        litellm_router_instance: LitellmRouter, model: dict
    ):
        """
        Derive the max-parallel-requests limit for a deployment and, when one
        applies, cache an asyncio.Semaphore enforcing it (local cache only).
        """
        litellm_params = model.get("litellm_params", {})
        model_id = model["model_info"]["id"]
        calculated_max_parallel_requests = calculate_max_parallel_requests(
            rpm=litellm_params.get("rpm", None),
            max_parallel_requests=litellm_params.get("max_parallel_requests", None),
            tpm=litellm_params.get("tpm", None),
            default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
        )
        if calculated_max_parallel_requests:
            # Local-only: semaphores are process-bound and can't be shared via redis.
            litellm_router_instance.cache.set_cache(
                key=f"{model_id}_max_parallel_requests_client",
                value=asyncio.Semaphore(calculated_max_parallel_requests),
                local_only=True,
            )

View File

@@ -0,0 +1,37 @@
"""
Utils for handling clientside credentials
Supported clientside credentials:
- api_key
- api_base
- base_url
If given, generate a unique model_id for the deployment.
Ensures cooldowns are applied correctly.
"""
# Credential fields a caller may supply per-request from the client side.
clientside_credential_keys = ["api_key", "api_base", "base_url"]


def is_clientside_credential(request_kwargs: dict) -> bool:
    """Return True when the request carries its own credential overrides."""
    for key in clientside_credential_keys:
        if key in request_kwargs:
            return True
    return False
def get_dynamic_litellm_params(litellm_params: dict, request_kwargs: dict) -> dict:
    """
    Overlay clientside credentials from the request onto the deployment's
    litellm_params, so a unique model_id can be derived for the deployment
    (ensures cooldowns are applied correctly).

    Returns:
        litellm_params: dict, mutated in place with any clientside credential
        values applied, and also returned for convenience.
    """
    overrides = {
        key: request_kwargs[key]
        for key in clientside_credential_keys
        if key in request_kwargs
    }
    litellm_params.update(overrides)
    return litellm_params

View File

@@ -0,0 +1,130 @@
import hashlib
import json
from typing import TYPE_CHECKING, Dict, List, Optional, Union
if TYPE_CHECKING:
from litellm.types.llms.openai import OpenAIFileObject
from litellm.types.router import CredentialLiteLLMParams
from litellm._logging import verbose_logger
def get_litellm_params_sensitive_credential_hash(litellm_params: dict) -> str:
    """
    SHA-256 hash of the credential-related litellm params.

    Used for mapping a created file id back to the deployment that owns it.

    Args:
        litellm_params: full litellm_params dict for a deployment.

    Returns:
        Hex digest of the JSON-serialized credential params.
    """
    # Runtime import: the module-level import of CredentialLiteLLMParams sits
    # under `if TYPE_CHECKING:` only, so referencing it here without this
    # import would raise NameError at call time.
    from litellm.types.router import CredentialLiteLLMParams

    sensitive_params = CredentialLiteLLMParams(**litellm_params)
    return hashlib.sha256(
        json.dumps(sensitive_params.model_dump()).encode()
    ).hexdigest()
def add_model_file_id_mappings(
    healthy_deployments: Union[List[Dict], Dict], responses: List["OpenAIFileObject"]
) -> dict:
    """
    Build a {model_id: file_id} mapping.

    Accepts either a list of deployments zipped with file-create responses,
    or an already-built {model_id: file_id} dict (copied into a fresh dict).
    """
    mapping: dict = {}
    if isinstance(healthy_deployments, list):
        for deployment, response in zip(healthy_deployments, responses):
            model_id = deployment.get("model_info", {}).get("id")
            mapping[model_id] = response.id
    elif isinstance(healthy_deployments, dict):
        mapping.update(healthy_deployments)
    return mapping
def filter_team_based_models(
    healthy_deployments: Union[List[Dict], Dict],
    request_kwargs: Optional[Dict] = None,
) -> Union[List[Dict], Dict]:
    """
    Drop deployments pinned to a team other than the requesting team.

    Deployments without a `team_id` in model_info are always kept. A dict
    (single pre-selected deployment) or missing request kwargs pass through
    untouched.
    """
    if request_kwargs is None:
        return healthy_deployments
    if isinstance(healthy_deployments, dict):
        return healthy_deployments

    # The team id may arrive under either metadata key, depending on endpoint.
    metadata = request_kwargs.get("metadata") or {}
    litellm_metadata = request_kwargs.get("litellm_metadata") or {}
    request_team_id = metadata.get("user_api_key_team_id") or litellm_metadata.get(
        "user_api_key_team_id"
    )

    ids_to_remove = set()
    for deployment in healthy_deployments:
        info = deployment.get("model_info") or {}
        team_id = info.get("team_id")
        if team_id is not None and team_id != request_team_id:
            ids_to_remove.add(info.get("id"))

    return [
        deployment
        for deployment in healthy_deployments
        if deployment.get("model_info", {}).get("id") not in ids_to_remove
    ]
def _deployment_supports_web_search(deployment: Dict) -> bool:
"""
Check if a deployment supports web search.
Priority:
1. Check config-level override in model_info.supports_web_search
2. Default to True (assume supported unless explicitly disabled)
Note: Ideally we'd fall back to litellm.supports_web_search() but
model_prices_and_context_window.json doesn't have supports_web_search
tags on all models yet. TODO: backfill and add fallback.
"""
model_info = deployment.get("model_info", {})
if "supports_web_search" in model_info:
return model_info["supports_web_search"]
return True
def filter_web_search_deployments(
    healthy_deployments: Union[List[Dict], Dict],
    request_kwargs: Optional[Dict] = None,
) -> Union[List[Dict], Dict]:
    """
    For web-search requests, keep only deployments that support web search.

    Pass-through cases: request_kwargs is None, the request carries no
    web-search tool, or a specific deployment was already chosen (dict form).
    """
    if request_kwargs is None:
        return healthy_deployments
    # A dict means a specific deployment was pre-selected - nothing to filter.
    if isinstance(healthy_deployments, dict):
        return healthy_deployments

    # OpenAI / Azure expose web search via these two tool types.
    tools = request_kwargs.get("tools") or []
    is_web_search_request = any(
        tool.get("type") in ("web_search", "web_search_preview") for tool in tools
    )
    if not is_web_search_request:
        return healthy_deployments

    remaining = [
        deployment
        for deployment in healthy_deployments
        if _deployment_supports_web_search(deployment)
    ]
    if len(healthy_deployments) > 0 and len(remaining) == 0:
        verbose_logger.warning("No deployments support web search for request")
    return remaining

View File

@@ -0,0 +1,193 @@
"""
Wrapper around router cache. Meant to handle model cooldown logic
"""
import functools
import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
from typing_extensions import TypedDict
from litellm import verbose_logger
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class CooldownCacheValue(TypedDict):
    """Payload stored in the cache for a deployment that is cooling down."""

    exception_received: str  # masked string form of the triggering exception
    status_code: str  # status of the triggering exception, stringified
    timestamp: float  # unix time (time.time()) when the cooldown was recorded
    cooldown_time: float  # cooldown duration in seconds (also used as cache TTL)
class CooldownCache:
    """
    Wrapper around the router cache for deployment cooldown state.

    A deployment in "cooldown" is temporarily excluded from routing after
    failures; entries expire automatically because they are written with a
    TTL equal to their cooldown duration.
    """

    def __init__(self, cache: DualCache, default_cooldown_time: float):
        """
        Args:
            cache: dual (in-memory + redis) cache backing cooldown entries.
            default_cooldown_time: seconds to cool down when no per-call
                cooldown time is given.
        """
        self.cache = cache
        self.default_cooldown_time = default_cooldown_time
        self.in_memory_cache = InMemoryCache()
        # Masker for exception strings: only the first 50 characters stay
        # visible so sensitive details aren't persisted in the cache.
        self.exception_masker = SensitiveDataMasker(
            visible_prefix=50,  # Show first 50 characters
            visible_suffix=0,  # Show last 0 characters
            mask_char="*",  # Use * for masking
        )

    def _common_add_cooldown_logic(
        self, model_id: str, original_exception, exception_status, cooldown_time: float
    ) -> Tuple[str, CooldownCacheValue]:
        """Build the (cache_key, cache_value) pair describing a cooldown entry."""
        try:
            current_time = time.time()
            cooldown_key = CooldownCache.get_cooldown_cache_key(model_id)

            # Store the cooldown information for the deployment separately
            cooldown_data = CooldownCacheValue(
                exception_received=self.exception_masker._mask_value(
                    str(original_exception)
                ),
                status_code=str(exception_status),
                timestamp=current_time,
                cooldown_time=cooldown_time,
            )

            return cooldown_key, cooldown_data
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        """
        Put a deployment into cooldown.

        Args:
            model_id: id of the deployment to cool down.
            original_exception: exception that triggered the cooldown
                (masked before being stored).
            exception_status: status code of the triggering exception.
            cooldown_time: per-call override; falls back to the default set
                on this CooldownCache when None.
        """
        try:
            #########################################################
            # get cooldown time
            # 1. If dynamic cooldown time is set for the model/deployment, use that
            # 2. If no dynamic cooldown time is set, use the default cooldown time set on CooldownCache
            _cooldown_time = cooldown_time
            if _cooldown_time is None:
                _cooldown_time = self.default_cooldown_time
            #########################################################

            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
                model_id=model_id,
                original_exception=original_exception,
                exception_status=exception_status,
                cooldown_time=_cooldown_time,
            )

            # Set the cache with a TTL equal to the cooldown time, so expiry
            # automatically ends the cooldown.
            self.cache.set_cache(
                value=cooldown_data,
                key=cooldown_key,
                ttl=_cooldown_time,
            )
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    @staticmethod
    @functools.lru_cache(maxsize=1024)
    def get_cooldown_cache_key(model_id: str) -> str:
        """Single source of truth for the cooldown cache key format."""
        return "deployment:" + model_id + ":cooldown"

    async def async_get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        """
        ASYNC: return (model_id, CooldownCacheValue) for every id currently
        cooling down. Uses one batched cache read instead of per-key calls
        (each redis round trip adds ~100ms latency).
        """
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]

        results = await self.cache.async_batch_get_cache(
            keys=keys, parent_otel_span=parent_otel_span
        )

        active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []
        # Common case: nothing is rate limited, so every entry is None.
        if results is None or all(v is None for v in results):
            return active_cooldowns

        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        """Sync equivalent of async_get_active_cooldowns."""
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]
        # Retrieve the values for the keys using mget
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        active_cooldowns = []
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_min_cooldown(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> float:
        """Return min cooldown time required for a group of model id's."""
        # Use the shared key builder - the key format was previously
        # duplicated inline here, risking drift from get_cooldown_cache_key.
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]
        # Retrieve the values for the keys using mget
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        min_cooldown_time: Optional[float] = None
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                if (
                    min_cooldown_time is None
                    or cooldown_cache_value["cooldown_time"] < min_cooldown_time
                ):
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]

        return min_cooldown_time or self.default_cooldown_time
# Usage example:
# cooldown_cache = CooldownCache(cache=your_cache_instance, default_cooldown_time=your_cooldown_time)
# cooldown_cache.add_deployment_to_cooldown(model_id, original_exception, exception_status, cooldown_time)
# active_cooldowns = cooldown_cache.get_active_cooldowns(model_ids, parent_otel_span)

View File

@@ -0,0 +1,101 @@
"""
Callbacks triggered on cooling down deployments
"""
import copy
from typing import TYPE_CHECKING, Any, Optional, Union
import litellm
from litellm._logging import verbose_logger
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
from litellm.integrations.prometheus import PrometheusLogger
else:
LitellmRouter = Any
PrometheusLogger = Any
async def router_cooldown_event_callback(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
    exception_status: Union[str, int],
    cooldown_time: Optional[float],
):
    """
    Callback triggered when a deployment is put into cooldown by litellm.

    - Updates deployment state on Prometheus
    - Increments cooldown metric for deployment on Prometheus

    Args:
        litellm_router_instance: router that owns the cooled-down deployment.
        deployment_id: id of the deployment that was cooled down.
        exception_status: status code of the exception causing the cooldown.
        cooldown_time: cooldown duration in seconds (not used in this callback).

    No-ops (with a warning) when the deployment can no longer be found.
    """
    verbose_logger.debug("In router_cooldown_event_callback - updating prometheus")
    _deployment = litellm_router_instance.get_deployment(model_id=deployment_id)
    if _deployment is None:
        verbose_logger.warning(
            f"in router_cooldown_event_callback but _deployment is None for deployment_id={deployment_id}. Doing nothing"
        )
        return
    _litellm_params = _deployment["litellm_params"]
    # Deep-copy before converting so downstream lookups can't mutate the
    # params stored on the deployment.
    temp_litellm_params = copy.deepcopy(_litellm_params)
    temp_litellm_params = dict(temp_litellm_params)
    _model_name = _deployment.get("model_name", None) or ""
    _api_base = (
        litellm.get_api_base(model=_model_name, optional_params=temp_litellm_params)
        or ""
    )
    model_info = _deployment["model_info"]
    # NOTE(review): model_info is accessed via attribute (`.id`) while the rest
    # of the deployment is treated as a dict - assumes a ModelInfo-style
    # object; confirm against Router.get_deployment's return type.
    model_id = model_info.id

    litellm_model_name = temp_litellm_params.get("model") or ""
    llm_provider = ""
    try:
        # Best-effort provider resolution; label stays "" on failure.
        _, llm_provider, _, _ = litellm.get_llm_provider(
            model=litellm_model_name,
            custom_llm_provider=temp_litellm_params.get("custom_llm_provider"),
        )
    except Exception:
        pass

    # get the prometheus logger from in memory loggers
    prometheusLogger: Optional[
        PrometheusLogger
    ] = _get_prometheus_logger_from_callbacks()

    if prometheusLogger is not None:
        # Mark the deployment as fully down, then count the cooldown event.
        prometheusLogger.set_deployment_complete_outage(
            litellm_model_name=_model_name,
            model_id=model_id,
            api_base=_api_base,
            api_provider=llm_provider,
        )

        prometheusLogger.increment_deployment_cooled_down(
            litellm_model_name=_model_name,
            model_id=model_id,
            api_base=_api_base,
            api_provider=llm_provider,
            exception_status=str(exception_status),
        )

    return
def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]:
    """
    Return the initialized PrometheusLogger callback, if one is registered.

    Async success callbacks are checked first, then the global callback list.
    """
    from litellm.integrations.prometheus import PrometheusLogger

    if PrometheusLogger is None:
        return None

    for callback in litellm._async_success_callback:
        if isinstance(callback, PrometheusLogger):
            return callback
    for callback in litellm.callbacks:
        if isinstance(callback, PrometheusLogger):
            return callback
    return None

View File

@@ -0,0 +1,459 @@
"""
Router cooldown handlers
- _set_cooldown_deployments: puts a deployment in the cooldown list
- get_cooldown_deployments: returns the list of deployments in the cooldown list
- async_get_cooldown_deployments: ASYNC: returns the list of deployments in the cooldown list
"""
import asyncio
import math
from typing import TYPE_CHECKING, Any, List, Optional, Union
import litellm
from litellm._logging import verbose_router_logger
from litellm.constants import (
DEFAULT_COOLDOWN_TIME_SECONDS,
DEFAULT_FAILURE_THRESHOLD_MINIMUM_REQUESTS,
DEFAULT_FAILURE_THRESHOLD_PERCENT,
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD,
)
from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
from .router_callbacks.track_deployment_metrics import (
get_deployment_failures_for_current_minute,
get_deployment_successes_for_current_minute,
)
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router as _Router
LitellmRouter = _Router
Span = Union[_Span, Any]
else:
LitellmRouter = Any
Span = Any
def _is_cooldown_required(
litellm_router_instance: LitellmRouter,
model_id: str,
exception_status: Union[str, int],
exception_str: Optional[str] = None,
) -> bool:
"""
A function to determine if a cooldown is required based on the exception status.
Parameters:
model_id (str) The id of the model in the model list
exception_status (Union[str, int]): The status of the exception.
Returns:
bool: True if a cooldown is required, False otherwise.
"""
try:
ignored_strings = ["APIConnectionError"]
if (
exception_str is not None
): # don't cooldown on litellm api connection errors errors
for ignored_string in ignored_strings:
if ignored_string in exception_str:
return False
if isinstance(exception_status, str):
if len(exception_status) == 0:
return False
exception_status = int(exception_status)
if exception_status >= 400 and exception_status < 500:
if exception_status == 429:
# Cool down 429 Rate Limit Errors
return True
elif exception_status == 401:
# Cool down 401 Auth Errors
return True
elif exception_status == 408:
return True
elif exception_status == 404:
return True
else:
# Do NOT cool down all other 4XX Errors
return False
else:
# should cool down for all other errors
return True
except Exception:
# Catch all - if any exceptions default to cooling down
return True
def _should_run_cooldown_logic(
    litellm_router_instance: LitellmRouter,
    deployment: Optional[str],
    exception_status: Union[str, int],
    original_exception: Any,
    time_to_cooldown: Optional[float] = None,
) -> bool:
    """
    Helper that decides if cooldown logic should be run.

    Returns False (i.e. skip cooldown logic) when:
    - deployment is None or its model group can't be found
    - time_to_cooldown is effectively 0
    - router.disable_cooldowns is True
    - _is_cooldown_required() returns False
    - deployment is in litellm_router_instance.provider_default_deployment_ids
    """
    if (
        deployment is None
        or litellm_router_instance.get_model_group(id=deployment) is None
    ):
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: deployment id is none or model group can't be found."
        )
        return False

    #########################################################
    # If time_to_cooldown is 0 or 0.0000000, don't run cooldown logic
    # (a requested cooldown of ~0 seconds means "never cool down")
    #########################################################
    if time_to_cooldown is not None and math.isclose(
        a=time_to_cooldown, b=0.0, abs_tol=1e-9
    ):
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: time_to_cooldown is effectively 0"
        )
        return False

    if litellm_router_instance.disable_cooldowns:
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: disable_cooldowns is True"
        )
        return False

    # NOTE: a redundant `deployment is None` check was removed here - it was
    # unreachable, since None already returns False in the first check above.

    if not _is_cooldown_required(
        litellm_router_instance=litellm_router_instance,
        model_id=deployment,
        exception_status=exception_status,
        exception_str=str(original_exception),
    ):
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: _is_cooldown_required returned False"
        )
        return False

    if deployment in litellm_router_instance.provider_default_deployment_ids:
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: deployment is in provider_default_deployment_ids"
        )
        return False

    return True
def _should_cooldown_deployment(
    litellm_router_instance: LitellmRouter,
    deployment: str,
    exception_status: Union[str, int],
    original_exception: Any,
) -> bool:
    """
    Helper that decides if a deployment should be put in cooldown.

    Returns True if the deployment should be put in cooldown.

    Deployment is put in cooldown when:
    - v2 logic (Current), used when no allowed-fails policy/override is set:
        cooldown if:
        - got a 429 error from LLM API (multi-deployment groups only)
        - 100% failure rate this minute with enough traffic
        - %fails above DEFAULT_FAILURE_THRESHOLD_PERCENT with enough traffic
          (multi-deployment groups only)
        - status is non-retryable per litellm._should_retry() (e.g. 401 Auth,
          404 NotFound)
    - v1 logic (Legacy): if an allowed-fails policy is set, cools down when
      num fails this minute > allowed fails
    """
    ## BASE CASE - single deployment
    model_group = litellm_router_instance.get_model_group(id=deployment)
    is_single_deployment_model_group = model_group is not None and len(model_group) == 1

    if (
        litellm_router_instance.allowed_fails_policy is None
        and _is_allowed_fails_set_on_router(
            litellm_router_instance=litellm_router_instance
        )
        is False
    ):
        num_successes_this_minute = get_deployment_successes_for_current_minute(
            litellm_router_instance=litellm_router_instance, deployment_id=deployment
        )
        num_fails_this_minute = get_deployment_failures_for_current_minute(
            litellm_router_instance=litellm_router_instance, deployment_id=deployment
        )

        total_requests_this_minute = num_successes_this_minute + num_fails_this_minute
        percent_fails = 0.0
        if total_requests_this_minute > 0:
            percent_fails = num_fails_this_minute / total_requests_this_minute
        verbose_router_logger.debug(
            "percent fails for deployment = %s, percent fails = %s, num successes = %s, num fails = %s",
            deployment,
            percent_fails,
            num_successes_this_minute,
            num_fails_this_minute,
        )

        exception_status_int = cast_exception_status_to_int(exception_status)
        if exception_status_int == 429 and not is_single_deployment_model_group:
            # Rate limited - take the deployment out of rotation.
            return True
        elif (
            percent_fails == 1.0
            and total_requests_this_minute
            >= SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD
        ):
            # Cooldown if all requests failed and we have reasonable traffic
            return True
        elif (
            percent_fails > DEFAULT_FAILURE_THRESHOLD_PERCENT
            and total_requests_this_minute >= DEFAULT_FAILURE_THRESHOLD_MINIMUM_REQUESTS
            and not is_single_deployment_model_group  # by default we should avoid cooldowns on single deployment model groups
        ):
            # Only apply error rate cooldown when we have enough requests to make the percentage meaningful
            return True
        elif litellm._should_retry(status_code=exception_status_int) is False:
            # Non-retryable statuses (e.g. 401 Auth, 404 NotFound) cool down immediately.
            return True

        return False
    else:
        return should_cooldown_based_on_allowed_fails_policy(
            litellm_router_instance=litellm_router_instance,
            deployment=deployment,
            original_exception=original_exception,
        )
    # NOTE: the trailing unreachable `return False` was removed - both branches
    # above already return.
def _set_cooldown_deployments(
    litellm_router_instance: LitellmRouter,
    original_exception: Any,
    exception_status: Union[str, int],
    deployment: Optional[str] = None,
    time_to_cooldown: Optional[float] = None,
) -> bool:
    """
    Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute

    or

    the exception is not one that should be immediately retried (e.g. 401)

    Args:
        litellm_router_instance: router whose cooldown cache is updated.
        original_exception: the exception that triggered this call.
        exception_status: status code of the exception (str or int).
        deployment: id of the deployment that failed; no-op when None.
        time_to_cooldown: optional per-call cooldown duration in seconds.

    Returns:
    - True if the deployment should be put in cooldown
    - False if the deployment should not be put in cooldown
    """
    verbose_router_logger.debug("checks 'should_run_cooldown_logic'")

    # Gate 1: global/deployment-level checks (disabled cooldowns, missing
    # deployment or model group, zero cooldown time, non-cooldown-worthy status).
    if (
        _should_run_cooldown_logic(
            litellm_router_instance=litellm_router_instance,
            deployment=deployment,
            exception_status=exception_status,
            original_exception=original_exception,
            time_to_cooldown=time_to_cooldown,
        )
        is False
        or deployment is None
    ):
        verbose_router_logger.debug("should_run_cooldown_logic returned False")
        return False

    exception_status_int = cast_exception_status_to_int(exception_status)

    verbose_router_logger.debug(f"Attempting to add {deployment} to cooldown list")

    # Gate 2: failure-rate / allowed-fails-policy checks for this deployment.
    if _should_cooldown_deployment(
        litellm_router_instance=litellm_router_instance,
        deployment=deployment,
        exception_status=exception_status,
        original_exception=original_exception,
    ):
        litellm_router_instance.cooldown_cache.add_deployment_to_cooldown(
            model_id=deployment,
            original_exception=original_exception,
            exception_status=exception_status_int,
            cooldown_time=time_to_cooldown,
        )

        # Trigger cooldown callback handler
        # NOTE(review): asyncio.create_task requires a running event loop -
        # this helper appears to assume it is only called from async context;
        # confirm callers.
        asyncio.create_task(
            router_cooldown_event_callback(
                litellm_router_instance=litellm_router_instance,
                deployment_id=deployment,
                exception_status=exception_status,
                cooldown_time=time_to_cooldown,
            )
        )
        return True
    return False
async def _async_get_cooldown_deployments(
    litellm_router_instance: LitellmRouter,
    parent_otel_span: Optional[Span],
) -> List[str]:
    """
    ASYNC: return the ids of deployments currently in cooldown.
    """
    cooldown_models = (
        await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
            model_ids=litellm_router_instance.get_model_ids(),
            parent_otel_span=parent_otel_span,
        )
    )
    verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")

    # Entries are (model_id, CooldownCacheValue) tuples - keep just the ids.
    if (
        isinstance(cooldown_models, list)
        and len(cooldown_models) > 0
        and isinstance(cooldown_models[0], tuple)
    ):
        return [model_id for model_id, _ in cooldown_models]
    return []
async def _async_get_cooldown_deployments_with_debug_info(
    litellm_router_instance: LitellmRouter,
    parent_otel_span: Optional[Span],
) -> List[tuple]:
    """
    ASYNC: return (model_id, CooldownCacheValue) tuples for every deployment
    currently in cooldown - keeps the cooldown details for debugging.
    """
    cooldown_models = (
        await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
            model_ids=litellm_router_instance.get_model_ids(),
            parent_otel_span=parent_otel_span,
        )
    )
    verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
    return cooldown_models
def _get_cooldown_deployments(
    litellm_router_instance: LitellmRouter, parent_otel_span: Optional[Span]
) -> List[str]:
    """
    Sync: return the ids of deployments currently in cooldown this minute.
    """
    active = litellm_router_instance.cooldown_cache.get_active_cooldowns(
        model_ids=litellm_router_instance.get_model_ids(),
        parent_otel_span=parent_otel_span,
    )

    # Entries are (model_id, CooldownCacheValue) tuples - keep just the ids.
    if isinstance(active, list) and len(active) > 0 and isinstance(active[0], tuple):
        return [model_id for model_id, _ in active]
    return []
def should_cooldown_based_on_allowed_fails_policy(
    litellm_router_instance: LitellmRouter,
    deployment: str,
    original_exception: Any,
) -> bool:
    """
    Legacy (v1) cooldown check: compare the per-deployment failure count for
    this window against the allowed-fails policy / router default.

    Returns:
        True if fails exceed the allowed limit (should cooldown);
        False otherwise (and the failure counter is incremented instead).
    """
    allowed_fails = (
        litellm_router_instance.get_allowed_fails_from_policy(
            exception=original_exception,
        )
        or litellm_router_instance.allowed_fails
    )
    cooldown_time = (
        litellm_router_instance.cooldown_time or DEFAULT_COOLDOWN_TIME_SECONDS
    )

    previous_fails = litellm_router_instance.failed_calls.get_cache(key=deployment) or 0
    updated_fails = previous_fails + 1

    if updated_fails > allowed_fails:
        return True
    # Still within budget: record the new count (expires with cooldown_time).
    litellm_router_instance.failed_calls.set_cache(
        key=deployment, value=updated_fails, ttl=cooldown_time
    )
    return False
def _is_allowed_fails_set_on_router(
    litellm_router_instance: LitellmRouter,
) -> bool:
    """
    Report whether Router.allowed_fails was explicitly configured.

    Returns:
        - True when Router.allowed_fails is set to a non-default value.
        - False when it is None or equal to the module-level default
          (litellm.allowed_fails).
    """
    router_allowed_fails = litellm_router_instance.allowed_fails
    if router_allowed_fails is None:
        return False
    return router_allowed_fails != litellm.allowed_fails
def cast_exception_status_to_int(exception_status: Union[str, int]) -> int:
    """
    Coerce an exception status to an int.

    Non-string inputs are returned unchanged; a string that cannot be
    parsed as an int is logged and defaulted to 500.
    """
    if not isinstance(exception_status, str):
        return exception_status
    try:
        return int(exception_status)
    except Exception:
        verbose_router_logger.debug(
            f"Unable to cast exception status to int {exception_status}. Defaulting to status=500."
        )
        return 500

View File

@@ -0,0 +1,257 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import litellm
from litellm._logging import verbose_router_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.router_utils.add_retry_fallback_headers import (
add_fallback_headers_to_response,
)
from litellm.types.router import LiteLLMParamsTypedDict
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
def _check_stripped_model_group(model_group: str, fallback_key: str) -> bool:
    """
    Wildcard-routing helper: compare ``model_group`` to ``fallback_key``
    after removing a known provider prefix.

    Example: fallbacks are configured as
    ``[{"gpt-3.5-turbo": ["claude-3-haiku"]}]`` while the incoming
    model_group is ``"openai/gpt-3.5-turbo"``.

    Returns:
        True when stripping some provider prefix from ``model_group``
        makes it equal to ``fallback_key``.
    """
    for known_provider in litellm.provider_list:
        provider_name = (
            known_provider.value if isinstance(known_provider, Enum) else known_provider
        )
        prefix = f"{provider_name}/"
        if model_group.startswith(prefix) and (
            model_group.replace(prefix, "") == fallback_key
        ):
            return True
    return False
def get_fallback_model_group(
    fallbacks: List[Any], model_group: str
) -> Tuple[Optional[List[str]], Optional[int]]:
    """
    Resolve which fallback model group applies to ``model_group``.

    Args:
        fallbacks: fallback config, e.g.
            ``[{"gpt-4": ["gpt-3.5-turbo"]}, {"*": ["claude-3-haiku"]}]``
            or a flat list of model names. NOTE: a flat string entry is
            popped from this list in place (existing caller-visible
            side effect, preserved here).
        model_group: the requested model group, e.g. "gpt-4".

    Returns:
        - fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
        - generic_fallback_idx: int index of the generic "*" fallback in the fallbacks list, if any.

    Match priority: exact model-group match, then stripped (provider-prefixed)
    match, then the generic "*" fallback.
    """
    generic_fallback_idx: Optional[int] = None
    stripped_model_fallback: Optional[List[str]] = None
    fallback_model_group: Optional[List[str]] = None
    ## check for specific model group-specific fallbacks
    for idx, item in enumerate(fallbacks):
        if isinstance(item, dict):
            if list(item.keys())[0] == model_group:  # check exact match
                fallback_model_group = item[model_group]
                break
            elif _check_stripped_model_group(
                model_group=model_group, fallback_key=list(item.keys())[0]
            ):  # check stripped model-group match (wildcard routing)
                stripped_model_fallback = item[list(item.keys())[0]]
            elif list(item.keys())[0] == "*":  # check generic fallback
                generic_fallback_idx = idx
        elif isinstance(item, str):
            # Bug fix: previously the loop continued after popping, which
            # both skipped the next element (the list shrinks mid-iteration)
            # and could overwrite the result with a later string entry.
            # Stop at the first string entry instead.
            fallback_model_group = [fallbacks.pop(idx)]  # returns single-item list
            break
    ## if none, check for generic fallback
    if fallback_model_group is None:
        if stripped_model_fallback is not None:
            fallback_model_group = stripped_model_fallback
        elif generic_fallback_idx is not None:
            fallback_model_group = fallbacks[generic_fallback_idx]["*"]
    return fallback_model_group, generic_fallback_idx
async def run_async_fallback(
    *args: Tuple[Any],
    litellm_router: LitellmRouter,
    fallback_model_group: List[str],
    original_model_group: str,
    original_exception: Exception,
    max_fallbacks: int,
    fallback_depth: int,
    **kwargs,
) -> Any:
    """
    Loops through all the fallback model groups and calls kwargs["original_function"] with the arguments and keyword arguments provided.

    If the call is successful, it logs the success and returns the response.
    If the call fails, it logs the failure and continues to the next fallback model group.
    If all fallback model groups fail, it raises the most recent exception.

    Args:
        litellm_router: The litellm router instance.
        *args: Positional arguments.
        fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
        original_model_group: The original model group. example: "gpt-3.5-turbo"
        original_exception: The original exception.
        max_fallbacks: Maximum fallback depth before giving up.
        fallback_depth: Current fallback recursion depth.
        **kwargs: Keyword arguments.

    Returns:
        The response from the successful fallback model group.
    Raises:
        The most recent exception if all fallback model groups fail.
    """
    ### BASE CASE ### MAX FALLBACK DEPTH REACHED
    if fallback_depth >= max_fallbacks:
        raise original_exception
    error_from_fallbacks = original_exception
    for mg in fallback_model_group:
        # Never fall back to the group that just failed.
        if mg == original_model_group:
            continue
        try:
            # LOGGING
            kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception)
            verbose_router_logger.info(f"Falling back to model_group = {mg}")
            # A fallback entry may be a plain model-group name (str) or a
            # dict of override kwargs merged into the request.
            if isinstance(mg, str):
                kwargs["model"] = mg
            elif isinstance(mg, dict):
                kwargs.update(mg)
            kwargs.setdefault("metadata", {}).update(
                {"model_group": kwargs.get("model", None)}
            )  # update model_group used, if fallbacks are done
            # Recurse via async_function_with_fallbacks so nested fallbacks
            # remain bounded by max_fallbacks.
            fallback_depth = fallback_depth + 1
            kwargs["fallback_depth"] = fallback_depth
            kwargs["max_fallbacks"] = max_fallbacks
            response = await litellm_router.async_function_with_fallbacks(
                *args, **kwargs
            )
            verbose_router_logger.info("Successful fallback b/w models.")
            response = add_fallback_headers_to_response(
                response=response,
                attempted_fallbacks=fallback_depth,
            )
            # callback for successfull_fallback_event():
            await log_success_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
            return response
        except Exception as e:
            # Remember the latest failure; the failure event is logged once,
            # after every fallback group has been exhausted.
            error_from_fallbacks = e
    await log_failure_fallback_event(
        original_model_group=original_model_group,
        kwargs=kwargs,
        original_exception=original_exception,
    )
    raise error_from_fallbacks
async def log_success_fallback_event(
    original_model_group: str, kwargs: dict, original_exception: Exception
):
    """
    Notify every registered CustomLogger that a fallback succeeded.

    Loggers are fetched (deduplicated) via
    ``litellm.logging_callback_manager.get_custom_loggers_for_type``.

    Args:
        original_model_group (str): The original model group before fallback.
        kwargs (dict): kwargs for the request

    Note:
        Logging errors are reported and swallowed so they never interrupt
        the request path.
    """
    for custom_logger in litellm.logging_callback_manager.get_custom_loggers_for_type(
        CustomLogger
    ):
        try:
            await custom_logger.log_success_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
        except Exception as e:
            verbose_router_logger.error(
                f"Error in log_success_fallback_event: {str(e)}"
            )
async def log_failure_fallback_event(
    original_model_group: str, kwargs: dict, original_exception: Exception
):
    """
    Notify every registered CustomLogger that a fallback failed.

    Loggers are fetched (deduplicated) via
    ``litellm.logging_callback_manager.get_custom_loggers_for_type``.

    Args:
        original_model_group (str): The original model group before fallback.
        kwargs (dict): kwargs for the request

    Note:
        Logging errors are reported and swallowed so they never interrupt
        the request path.
    """
    for custom_logger in litellm.logging_callback_manager.get_custom_loggers_for_type(
        CustomLogger
    ):
        try:
            await custom_logger.log_failure_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
        except Exception as e:
            verbose_router_logger.error(
                f"Error in log_failure_fallback_event: {str(e)}"
            )
def _check_non_standard_fallback_format(fallbacks: Optional[List[Any]]) -> bool:
    """
    Detect non-standard fallback config shapes.

    Non-standard shapes are:
    - List[str], e.g. ``["claude-3-haiku", "openai/o-1"]``
    - List[Dict] whose first entry carries litellm-param keys, e.g.
      ``[{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}]``

    The standard shape ``[{"gpt-3.5-turbo": ["claude-3-haiku"]}]`` returns False.
    """
    if not fallbacks or not isinstance(fallbacks, list):
        return False
    if all(isinstance(entry, str) for entry in fallbacks):
        return True
    if all(isinstance(entry, dict) for entry in fallbacks):
        # Only the first entry is inspected for litellm-param keys.
        litellm_param_keys = LiteLLMParamsTypedDict.__annotations__.keys()
        return any(key in fallbacks[0] for key in litellm_param_keys)
    return False
def run_non_standard_fallback_format(
    fallbacks: Union[List[str], List[Dict[str, Any]]], model_group: str
):
    """
    Placeholder for handling non-standard fallback formats
    (see ``_check_non_standard_fallback_format``). Currently a no-op.
    """
    pass

View File

@@ -0,0 +1,71 @@
"""
Get num retries for an exception.
- Account for retry policy by exception type.
"""
from typing import Dict, Optional, Union
from litellm.exceptions import (
AuthenticationError,
BadRequestError,
ContentPolicyViolationError,
RateLimitError,
Timeout,
)
from litellm.types.router import RetryPolicy
def get_num_retries_from_retry_policy(
    exception: Exception,
    retry_policy: Optional[Union[RetryPolicy, dict]] = None,
    model_group: Optional[str] = None,
    model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
) -> Optional[int]:
    """
    Look up the configured number of retries for ``exception``.

    A model-group-specific policy (``model_group_retry_policy``) takes
    precedence over the general ``retry_policy``. Returns None when no
    policy applies or the exception type has no configured retry count.

    RetryPolicy fields consulted:
        BadRequestErrorRetries: Optional[int] = None
        AuthenticationErrorRetries: Optional[int] = None
        TimeoutErrorRetries: Optional[int] = None
        RateLimitErrorRetries: Optional[int] = None
        ContentPolicyViolationErrorRetries: Optional[int] = None
    """
    # if we can find the exception then in the retry policy -> return the number of retries
    if (
        model_group_retry_policy is not None
        and model_group is not None
        and model_group in model_group_retry_policy
    ):
        retry_policy = model_group_retry_policy.get(model_group, None)  # type: ignore
    if retry_policy is None:
        return None
    if isinstance(retry_policy, dict):
        # Allow plain-dict policies (e.g. parsed from config) as well.
        retry_policy = RetryPolicy(**retry_policy)
    if (
        isinstance(exception, AuthenticationError)
        and retry_policy.AuthenticationErrorRetries is not None
    ):
        return retry_policy.AuthenticationErrorRetries
    if isinstance(exception, Timeout) and retry_policy.TimeoutErrorRetries is not None:
        return retry_policy.TimeoutErrorRetries
    if (
        isinstance(exception, RateLimitError)
        and retry_policy.RateLimitErrorRetries is not None
    ):
        return retry_policy.RateLimitErrorRetries
    # NOTE(review): ContentPolicyViolationError is checked before
    # BadRequestError -- presumably because it subclasses it; confirm
    # before reordering these isinstance checks.
    if (
        isinstance(exception, ContentPolicyViolationError)
        and retry_policy.ContentPolicyViolationErrorRetries is not None
    ):
        return retry_policy.ContentPolicyViolationErrorRetries
    if (
        isinstance(exception, BadRequestError)
        and retry_policy.BadRequestErrorRetries is not None
    ):
        return retry_policy.BadRequestErrorRetries
def reset_retry_policy() -> RetryPolicy:
    """Return a fresh RetryPolicy with all retry counts unset (defaults)."""
    return RetryPolicy()

View File

@@ -0,0 +1,95 @@
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm._logging import verbose_router_logger
from litellm.constants import MAX_EXCEPTION_MESSAGE_LENGTH
from litellm.router_utils.cooldown_handlers import (
_async_get_cooldown_deployments_with_debug_info,
)
from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.router import RouterRateLimitError
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router as _Router
LitellmRouter = _Router
Span = Union[_Span, Any]
else:
LitellmRouter = Any
Span = Any
async def send_llm_exception_alert(
    litellm_router_instance: LitellmRouter,
    request_kwargs: dict,
    error_traceback_str: str,
    original_exception,
):
    """
    Send a Slack / MS Teams alert for an LLM API call failure.

    No-ops unless ``litellm_router_instance.slack_alerting_logger`` is set,
    and skips proxy-originated requests (the proxy is already instrumented
    to alert on LLM API call failures itself).

    Parameters:
        litellm_router_instance (_Router): The LitellmRouter instance.
        request_kwargs (dict): kwargs of the failed request.
        error_traceback_str (str): formatted traceback of the failure.
        original_exception (Any): The original exception that occurred.

    Returns:
        None
    """
    if litellm_router_instance is None:
        return
    # Missing attribute and explicit None are treated the same: no alerting.
    slack_logger = getattr(litellm_router_instance, "slack_alerting_logger", None)
    if slack_logger is None:
        return
    if "proxy_server_request" in request_kwargs:
        # Do not send any alert if it's a request from litellm proxy server request
        # the proxy is already instrumented to send LLM API call failures
        return

    exception_str = str(original_exception)
    litellm_debug_info = getattr(original_exception, "litellm_debug_info", None)
    if litellm_debug_info is not None:
        exception_str += litellm_debug_info
    # Cap the traceback so the alert payload stays a reasonable size.
    exception_str += f"\n\n{error_traceback_str[:MAX_EXCEPTION_MESSAGE_LENGTH]}"

    await slack_logger.send_alert(
        message=f"LLM API call failed: `{exception_str}`",
        level="High",
        alert_type=AlertType.llm_exceptions,
        alerting_metadata={},
    )
async def async_raise_no_deployment_exception(
    litellm_router_instance: LitellmRouter, model: str, parent_otel_span: Optional[Span]
):
    """
    Build a RouterRateLimitError for when no deployment is available for ``model``.

    NOTE: despite the name, this function *returns* the error (with cooldown
    debug info attached); the caller is expected to raise it.
    """
    verbose_router_logger.info(
        f"get_available_deployment for model: {model}, No deployment available"
    )
    model_ids = litellm_router_instance.get_model_ids(model_name=model)
    # Surface the shortest remaining cooldown so callers know when to retry.
    _cooldown_time = litellm_router_instance.cooldown_cache.get_min_cooldown(
        model_ids=model_ids, parent_otel_span=parent_otel_span
    )
    _cooldown_list = await _async_get_cooldown_deployments_with_debug_info(
        litellm_router_instance=litellm_router_instance,
        parent_otel_span=parent_otel_span,
    )
    verbose_router_logger.info(
        f"No deployment found for model: {model}, cooldown_list with debug info: {_cooldown_list}"
    )
    # Each cooldown entry is a (model_id, debug_info) tuple; only the ids
    # are embedded in the raised error.
    cooldown_list_ids = [cooldown_model[0] for cooldown_model in (_cooldown_list or [])]
    return RouterRateLimitError(
        model=model,
        cooldown_time=_cooldown_time,
        enable_pre_call_checks=litellm_router_instance.enable_pre_call_checks,
        cooldown_list=cooldown_list_ids,
    )

View File

@@ -0,0 +1,264 @@
"""
Class to handle llm wildcard routing and regex pattern matching
"""
import copy
import re
from re import Match
from typing import Dict, List, Optional, Tuple
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm._logging import verbose_router_logger
class PatternUtils:
    """Static helpers for ranking wildcard/regex routing patterns."""

    @staticmethod
    def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]:
        """
        Score a pattern's specificity for sort ordering.

        Args:
            pattern: Regex pattern to analyze

        Returns:
            Tuple of (length, complexity): longer patterns and patterns
            containing more regex metacharacters rank as more specific.
        """
        special_characters = ("*", "+", "?", "\\", "^", "$", "|", "(", ")")
        complexity_score = sum(pattern.count(ch) for ch in special_characters)
        return (len(pattern), complexity_score)

    @staticmethod
    def sorted_patterns(
        patterns: Dict[str, List[Dict]]
    ) -> List[Tuple[str, List[Dict]]]:
        """
        Return pattern -> deployments pairs, most-specific pattern first.

        Returns:
            Sorted list of pattern-deployment tuples
        """

        def _specificity(entry: Tuple[str, List[Dict]]) -> Tuple[int, int]:
            return PatternUtils.calculate_pattern_specificity(entry[0])

        return sorted(patterns.items(), key=_specificity, reverse=True)
class PatternMatchRouter:
    """
    Class to handle llm wildcard routing and regex pattern matching

    doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing

    This class will store a mapping for regex pattern: List[Deployments]
    """

    def __init__(self):
        # regex pattern (str) -> list of deployment dicts registered for it
        self.patterns: Dict[str, List] = {}

    def add_pattern(self, pattern: str, llm_deployment: Dict):
        """
        Add a regex pattern and the corresponding llm deployments to the patterns

        Args:
            pattern: str
            llm_deployment: str or List[str]
        """
        # Convert the pattern to a regex
        regex = self._pattern_to_regex(pattern)
        if regex not in self.patterns:
            self.patterns[regex] = []
        self.patterns[regex].append(llm_deployment)

    def _pattern_to_regex(self, pattern: str) -> str:
        """
        Convert a wildcard pattern to a regex pattern

        example:
        pattern: openai/*
        regex: openai/(.*)
        pattern: openai/fo::*::static::*
        regex: openai/fo::(.*)::static::(.*)

        Args:
            pattern: str

        Returns:
            str: regex pattern
        """
        # Escape regex metacharacters in the user pattern, then turn each
        # escaped "*" into a capture group so the matched segments can be
        # substituted back into the deployment's model name later
        # (see set_deployment_model_name).
        return re.escape(pattern).replace(r"\*", "(.*)")

    def _return_pattern_matched_deployments(
        self, matched_pattern: Match, deployments: List[Dict]
    ) -> List[Dict]:
        """
        Return copies of ``deployments`` with each copy's litellm model name
        rewritten using the wildcard segments captured in ``matched_pattern``.
        Deep copies are used so the stored deployment configs are never mutated.
        """
        new_deployments = []
        for deployment in deployments:
            new_deployment = copy.deepcopy(deployment)
            new_deployment["litellm_params"][
                "model"
            ] = PatternMatchRouter.set_deployment_model_name(
                matched_pattern=matched_pattern,
                litellm_deployment_litellm_model=deployment["litellm_params"]["model"],
            )
            new_deployments.append(new_deployment)
        return new_deployments

    def route(
        self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
    ) -> Optional[List[Dict]]:
        """
        Route a requested model to the corresponding llm deployments based on the regex pattern

        loop through all the patterns and find the matching pattern
        if a pattern is found, return the corresponding llm deployments
        if no pattern is found, return None

        Args:
            request: str - the received model name from the user (can be a wildcard route). If none, No deployments will be returned.
            filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
        Returns:
            Optional[List[Deployment]]: llm deployments
        """
        try:
            if request is None:
                return None
            # Most-specific patterns are tried first (see PatternUtils).
            sorted_patterns = PatternUtils.sorted_patterns(self.patterns)
            regex_filtered_model_names = (
                [self._pattern_to_regex(m) for m in filtered_model_names]
                if filtered_model_names is not None
                else []
            )
            for pattern, llm_deployments in sorted_patterns:
                if (
                    filtered_model_names is not None
                    and pattern not in regex_filtered_model_names
                ):
                    continue
                pattern_match = re.match(pattern, request)
                if pattern_match:
                    return self._return_pattern_matched_deployments(
                        matched_pattern=pattern_match, deployments=llm_deployments
                    )
        except Exception as e:
            # Best-effort routing: any error is logged and treated as "no match".
            verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")
        return None  # No matching pattern found

    @staticmethod
    def set_deployment_model_name(
        matched_pattern: Match,
        litellm_deployment_litellm_model: str,
    ) -> str:
        """
        Set the model name for the matched pattern llm deployment

        E.g.:

        Case 1:
        model_name: llmengine/* (can be any regex pattern or wildcard pattern)
        litellm_params:
            model: openai/*

        if model_name = "llmengine/foo" -> model = "openai/foo"

        Case 2:
        model_name: llmengine/fo::*::static::*
        litellm_params:
            model: openai/fo::*::static::*

        if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"

        Case 3:
        model_name: *meta.llama3*
        litellm_params:
            model: bedrock/meta.llama3*

        if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b"
        """
        ## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name
        if "*" not in litellm_deployment_litellm_model:
            return litellm_deployment_litellm_model
        wildcard_count = litellm_deployment_litellm_model.count("*")
        # Extract all dynamic segments from the request
        dynamic_segments = matched_pattern.groups()
        if len(dynamic_segments) > wildcard_count:
            return (
                matched_pattern.string
            )  # default to the user input, if unable to map based on wildcards.
        # Replace the corresponding wildcards in the litellm model pattern with extracted segments
        for segment in dynamic_segments:
            litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace(
                "*", segment, 1
            )
        return litellm_deployment_litellm_model

    def get_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> Optional[List[Dict]]:
        """
        Check if a pattern exists for the given model and custom llm provider

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            Optional[List[Dict]]: matching deployments, or None
        """
        if custom_llm_provider is None:
            try:
                (
                    _,
                    custom_llm_provider,
                    _,
                    _,
                ) = get_llm_provider(model=model)
            except Exception:
                # get_llm_provider raises exception when provider is unknown
                # NOTE(review): custom_llm_provider stays None here, so the
                # second route() call below probes "None/<model>" -- confirm
                # this is intentional.
                pass
        return self.route(model) or self.route(f"{custom_llm_provider}/{model}")

    def get_deployments_by_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> List[Dict]:
        """
        Get the deployments by pattern

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            List[Dict]: llm deployments matching the pattern (empty list when none match)
        """
        pattern_match = self.get_pattern(model, custom_llm_provider)
        if pattern_match:
            return pattern_match
        return []
# Example usage:
# router = PatternMatchRouter()
# router.add_pattern('openai/*', [Deployment(), Deployment()])
# router.add_pattern('openai/fo::*::static::*', Deployment())
# print(router.route('openai/gpt-4'))  # Output: [Deployment(), Deployment()]
# print(router.route('openai/fo::hi::static::hi'))  # Output: [Deployment()]
# print(router.route('something/else'))  # Output: None

View File

@@ -0,0 +1,506 @@
"""
Unified deployment affinity (session stickiness) for the Router.
Features (independently enable-able):
1. Responses API continuity: when a `previous_response_id` is provided, route to the
deployment that generated the original response (highest priority).
2. API-key affinity: map an API key hash -> deployment id for a TTL and re-use that
deployment for subsequent requests to the same router deployment model name
(alias-safe, aligns to `model_map_information.model_map_key`).
This is designed to support "implicit prompt caching" scenarios (no explicit cache_control),
where routing to a consistent deployment is still beneficial.
"""
import hashlib
from typing import Any, Dict, List, Optional, cast
from typing_extensions import TypedDict
from litellm._logging import verbose_router_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CallTypes
class DeploymentAffinityCacheValue(TypedDict):
    """Cache value schema for a deployment-affinity entry."""

    model_id: str  # Router deployment id (model_info["id"]) that requests are pinned to
class DeploymentAffinityCheck(CustomLogger):
"""
Router deployment affinity callback.
NOTE: This is a Router-only callback intended to be wired through
`Router(optional_pre_call_checks=[...])`.
"""
CACHE_KEY_PREFIX = "deployment_affinity:v1"
def __init__(
    self,
    cache: DualCache,
    ttl_seconds: int,
    enable_user_key_affinity: bool,
    enable_responses_api_affinity: bool,
    enable_session_id_affinity: bool = False,
):
    """
    Args:
        cache: dual (in-memory + Redis) cache used to persist affinity mappings.
        ttl_seconds: lifetime of an affinity cache entry.
        enable_user_key_affinity: pin requests by API-key hash.
        enable_responses_api_affinity: pin requests by `previous_response_id` continuity.
        enable_session_id_affinity: pin requests by `metadata.session_id`.
    """
    super().__init__()
    self.cache = cache
    self.ttl_seconds = ttl_seconds
    self.enable_user_key_affinity = enable_user_key_affinity
    self.enable_responses_api_affinity = enable_responses_api_affinity
    self.enable_session_id_affinity = enable_session_id_affinity
@staticmethod
def _looks_like_sha256_hex(value: str) -> bool:
    """
    Return True when ``value`` looks like a hex-encoded SHA-256 digest
    (exactly 64 hex characters).

    A strict character check is used instead of ``int(value, 16)``: the
    int() parser also accepts "0x"/"0X" prefixes, leading +/- signs,
    surrounding whitespace, and "_" digit separators (PEP 515), so
    non-digest strings of length 64 could previously be misclassified
    as SHA-256 digests.
    """
    if len(value) != 64:
        return False
    return all(ch in "0123456789abcdefABCDEF" for ch in value)
@staticmethod
def _hash_user_key(user_key: str) -> str:
    """
    Produce a stable, fixed-length key component from a user identifier.

    Raw API keys / user identifiers must not end up in cache keys (and
    therefore Redis/logs/metrics). An identifier that already looks like a
    SHA-256 hex digest (e.g. proxy-provided `metadata.user_api_key_hash`)
    is kept as-is (lowercased) to avoid double-hashing and preserve
    correlation/debugging; anything else is hashed with SHA-256.
    """
    if DeploymentAffinityCheck._looks_like_sha256_hex(user_key):
        return user_key.lower()
    digest = hashlib.sha256(user_key.encode("utf-8"))
    return digest.hexdigest()
@staticmethod
def _get_model_map_key_from_litellm_model_name(
    litellm_model_name: str,
) -> Optional[str]:
    """
    Best-effort derivation of a stable "model map key" from a litellm
    model string, for affinity scoping.

    The intent is to align with
    `standard_logging_payload.model_map_information.model_map_key` (the
    base model identifier, stable across deployments/endpoints).

    Notes:
    - "provider/model" strings have the provider prefix stripped.
    - Azure is skipped entirely: the segment after "azure/" is commonly an
      Azure *deployment name*, which differs across instances, so deriving
      a key from it would be unstable when `base_model` is not set.
    """
    if not litellm_model_name:
        return None
    prefix, separator, remainder = litellm_model_name.partition("/")
    if not separator:
        # No provider prefix -> the string is already the base model name.
        return litellm_model_name
    if prefix == "azure":
        return None
    return remainder
@staticmethod
def _get_model_map_key_from_deployment(deployment: dict) -> Optional[str]:
    """
    Derive a stable model-map key from a router deployment dict.

    Lookup order (first non-empty string wins):
    1. `deployment.model_name` -- Router's canonical group name after alias
       resolution; stable across provider-specific deployments and aligned
       with `model_map_information.model_map_key` in standard logging.
    2. `model_info.base_model` (important for Azure).
    3. `litellm_params.base_model`.
    4. Parsed from `litellm_params.model` as a last resort.
    """
    candidate = deployment.get("model_name")
    if isinstance(candidate, str) and candidate:
        return candidate

    model_info = deployment.get("model_info")
    if isinstance(model_info, dict):
        candidate = model_info.get("base_model")
        if isinstance(candidate, str) and candidate:
            return candidate

    litellm_params = deployment.get("litellm_params")
    if isinstance(litellm_params, dict):
        candidate = litellm_params.get("base_model")
        if isinstance(candidate, str) and candidate:
            return candidate
        raw_model_name = litellm_params.get("model")
        if isinstance(raw_model_name, str) and raw_model_name:
            return DeploymentAffinityCheck._get_model_map_key_from_litellm_model_name(
                raw_model_name
            )
    return None
@staticmethod
def _get_stable_model_map_key_from_deployments(
    healthy_deployments: List[dict],
) -> Optional[str]:
    """
    Return the shared model-map key iff it is identical across every
    deployment in the set; otherwise None.

    Affinity is only applied when all deployments agree on the key -- this
    prevents accidentally keying on per-deployment identifiers such as
    Azure deployment names (when `base_model` is not configured).
    """
    if not healthy_deployments:
        return None
    shared_key: Optional[str] = None
    for deployment in healthy_deployments:
        derived_key = DeploymentAffinityCheck._get_model_map_key_from_deployment(
            deployment
        )
        if derived_key is None:
            return None
        if shared_key is None:
            shared_key = derived_key
        elif derived_key != shared_key:
            return None
    return shared_key
@staticmethod
def _shorten_for_logs(value: str, keep: int = 8) -> str:
    """Truncate ``value`` to ``keep`` characters (with ellipsis) for log output."""
    return value if len(value) <= keep else f"{value[:keep]}..."
@classmethod
def get_affinity_cache_key(cls, model_group: str, user_key: str) -> str:
    """Cache key for API-key affinity, scoped by model group; the user key is hashed first."""
    hashed_user_key = cls._hash_user_key(user_key=user_key)
    return ":".join([cls.CACHE_KEY_PREFIX, model_group, hashed_user_key])
@classmethod
def get_session_affinity_cache_key(cls, model_group: str, session_id: str) -> str:
    """Cache key for session-id affinity, scoped by model group."""
    return ":".join([cls.CACHE_KEY_PREFIX, "session", model_group, session_id])
@staticmethod
def _get_user_key_from_metadata_dict(metadata: dict) -> Optional[str]:
    """
    Extract the affinity key from a metadata dict.

    Affinity is keyed on the proxy-provided *API key hash*
    (`user_api_key_hash`) -- not the OpenAI `user` parameter, which is an
    end-user identifier.
    """
    raw_value = metadata.get("user_api_key_hash")
    return None if raw_value is None else str(raw_value)
@staticmethod
def _get_session_id_from_metadata_dict(metadata: dict) -> Optional[str]:
    """Extract the `session_id` from a metadata dict, stringified, if present."""
    raw_value = metadata.get("session_id")
    return None if raw_value is None else str(raw_value)
@staticmethod
def _iter_metadata_dicts(request_kwargs: dict) -> List[dict]:
    """
    Collect every metadata dict present on the request.

    Depending on the endpoint, Router may populate `metadata` or
    `litellm_metadata` (and callers may send either or both), so both keys
    are inspected rather than short-circuiting with `or`.
    """
    return [
        candidate
        for candidate in (
            request_kwargs.get("litellm_metadata"),
            request_kwargs.get("metadata"),
        )
        if isinstance(candidate, dict)
    ]
@staticmethod
def _get_user_key_from_request_kwargs(request_kwargs: dict) -> Optional[str]:
    """
    Extract a stable affinity key from request kwargs.

    Source (proxy): `metadata.user_api_key_hash`. The OpenAI `user`
    parameter (an end-user identifier) is intentionally not used.
    """
    for metadata_dict in DeploymentAffinityCheck._iter_metadata_dicts(request_kwargs):
        found_key = DeploymentAffinityCheck._get_user_key_from_metadata_dict(
            metadata=metadata_dict
        )
        if found_key is not None:
            return found_key
    return None
@staticmethod
def _get_session_id_from_request_kwargs(request_kwargs: dict) -> Optional[str]:
    """Return the first session id found across the request's metadata dicts."""
    for metadata_dict in DeploymentAffinityCheck._iter_metadata_dicts(request_kwargs):
        found_session_id = DeploymentAffinityCheck._get_session_id_from_metadata_dict(
            metadata=metadata_dict
        )
        if found_session_id is not None:
            return found_session_id
    return None
@staticmethod
def _find_deployment_by_model_id(
    healthy_deployments: List[dict], model_id: str
) -> Optional[dict]:
    """
    Return the first deployment whose `model_info["id"]` equals ``model_id``
    (compared as strings), or None when no deployment matches.
    """
    target_id = str(model_id)
    for candidate in healthy_deployments:
        model_info = candidate.get("model_info")
        if not isinstance(model_info, dict):
            continue
        candidate_id = model_info.get("id")
        if candidate_id is not None and str(candidate_id) == target_id:
            return candidate
    return None
async def async_filter_deployments(
    self,
    model: str,
    healthy_deployments: List,
    messages: Optional[List[AllMessageValues]],
    request_kwargs: Optional[dict] = None,
    parent_otel_span: Optional[Span] = None,
) -> List[dict]:
    """
    Optionally filter healthy deployments based on:
    1. `previous_response_id` (Responses API continuity) [highest priority]
    2. session-id affinity
    3. cached API-key deployment affinity

    Returns the full ``healthy_deployments`` list unchanged whenever no
    affinity rule applies or the pinned deployment is no longer healthy.
    """
    request_kwargs = request_kwargs or {}
    typed_healthy_deployments = cast(List[dict], healthy_deployments)
    # 1) Responses API continuity (high priority)
    if self.enable_responses_api_affinity:
        previous_response_id = request_kwargs.get("previous_response_id")
        if previous_response_id is not None:
            # The originating deployment id is encoded inside the response id.
            responses_model_id = (
                ResponsesAPIRequestUtils.get_model_id_from_response_id(
                    str(previous_response_id)
                )
            )
            if responses_model_id is not None:
                deployment = self._find_deployment_by_model_id(
                    healthy_deployments=typed_healthy_deployments,
                    model_id=responses_model_id,
                )
                if deployment is not None:
                    verbose_router_logger.debug(
                        "DeploymentAffinityCheck: previous_response_id pinning -> deployment=%s",
                        responses_model_id,
                    )
                    return [deployment]
    # Session/key affinity needs a model-map key that is stable across the
    # whole deployment set; otherwise skip affinity entirely.
    stable_model_map_key = self._get_stable_model_map_key_from_deployments(
        healthy_deployments=typed_healthy_deployments
    )
    if stable_model_map_key is None:
        return typed_healthy_deployments
    # 2) Session-id -> deployment affinity
    if self.enable_session_id_affinity:
        session_id = self._get_session_id_from_request_kwargs(
            request_kwargs=request_kwargs
        )
        if session_id is not None:
            session_cache_key = self.get_session_affinity_cache_key(
                model_group=stable_model_map_key, session_id=session_id
            )
            session_cache_result = await self.cache.async_get_cache(
                key=session_cache_key
            )
            # Cache value may be a dict ({"model_id": ...}) or a raw string.
            session_model_id: Optional[str] = None
            if isinstance(session_cache_result, dict):
                session_model_id = cast(
                    Optional[str], session_cache_result.get("model_id")
                )
            elif isinstance(session_cache_result, str):
                session_model_id = session_cache_result
            if session_model_id:
                session_deployment = self._find_deployment_by_model_id(
                    healthy_deployments=typed_healthy_deployments,
                    model_id=session_model_id,
                )
                if session_deployment is not None:
                    verbose_router_logger.debug(
                        "DeploymentAffinityCheck: session-id affinity hit -> deployment=%s session_id=%s",
                        session_model_id,
                        session_id,
                    )
                    return [session_deployment]
                else:
                    # Pinned deployment unhealthy/removed -> fall through to
                    # key affinity / unfiltered routing.
                    verbose_router_logger.debug(
                        "DeploymentAffinityCheck: session-id pinned deployment=%s not found in healthy_deployments",
                        session_model_id,
                    )
    # 3) User key -> deployment affinity
    if not self.enable_user_key_affinity:
        return typed_healthy_deployments
    user_key = self._get_user_key_from_request_kwargs(request_kwargs=request_kwargs)
    if user_key is None:
        return typed_healthy_deployments
    cache_key = self.get_affinity_cache_key(
        model_group=stable_model_map_key, user_key=user_key
    )
    cache_result = await self.cache.async_get_cache(key=cache_key)
    model_id: Optional[str] = None
    if isinstance(cache_result, dict):
        model_id = cast(Optional[str], cache_result.get("model_id"))
    elif isinstance(cache_result, str):
        # Backwards / safety: allow raw string values.
        model_id = cache_result
    if not model_id:
        return typed_healthy_deployments
    deployment = self._find_deployment_by_model_id(
        healthy_deployments=typed_healthy_deployments,
        model_id=model_id,
    )
    if deployment is None:
        verbose_router_logger.debug(
            "DeploymentAffinityCheck: pinned deployment=%s not found in healthy_deployments",
            model_id,
        )
        return typed_healthy_deployments
    verbose_router_logger.debug(
        "DeploymentAffinityCheck: api-key affinity hit -> deployment=%s user_key=%s",
        model_id,
        self._shorten_for_logs(user_key),
    )
    return [deployment]
    async def async_pre_call_deployment_hook(
        self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
    ) -> Optional[dict]:
        """
        Persist/update the API-key -> deployment and session-id -> deployment
        affinity mappings for this request.

        Why pre-call?
        - LiteLLM runs async success callbacks via a background logging worker for performance.
        - We want affinity to be immediately available for subsequent requests.

        Args:
            kwargs: Router request kwargs; must carry ``model_info`` (with an
                ``id``) and a ``deployment_model_name`` metadata entry for the
                mapping to be written.
            call_type: Part of the hook signature; not used here.

        Returns:
            Always ``None`` — this hook only records state and never mutates
            the outgoing request.
        """
        # Fast exit when neither affinity mode is enabled.
        if not self.enable_user_key_affinity and not self.enable_session_id_affinity:
            return None
        user_key = None
        if self.enable_user_key_affinity:
            user_key = self._get_user_key_from_request_kwargs(request_kwargs=kwargs)
        session_id = None
        if self.enable_session_id_affinity:
            session_id = self._get_session_id_from_request_kwargs(request_kwargs=kwargs)
        # Nothing to pin this request to.
        if user_key is None and session_id is None:
            return None
        # NOTE(review): metadata_dicts is iterated twice below (model_info and
        # deployment_model_name lookups) — assumes _iter_metadata_dicts returns a
        # re-iterable (e.g. a list), not a one-shot generator; confirm.
        metadata_dicts = self._iter_metadata_dicts(kwargs)
        # Prefer the top-level `model_info` kwarg; fall back to the first
        # metadata dict carrying one.
        model_info = kwargs.get("model_info")
        if not isinstance(model_info, dict):
            model_info = None
        if model_info is None:
            for metadata in metadata_dicts:
                maybe_model_info = metadata.get("model_info")
                if isinstance(maybe_model_info, dict):
                    model_info = maybe_model_info
                    break
        if model_info is None:
            # Router sets `model_info` after selecting a deployment. If it's missing, this is
            # likely a non-router call or a call path that doesn't support affinity.
            return None
        model_id = model_info.get("id")
        if not model_id:
            verbose_router_logger.warning(
                "DeploymentAffinityCheck: model_id missing; skipping affinity cache update."
            )
            return None
        # Scope affinity by the Router deployment model name (alias-safe, consistent across
        # heterogeneous providers, and matches standard logging's `model_map_key`).
        deployment_model_name: Optional[str] = None
        for metadata in metadata_dicts:
            maybe_deployment_model_name = metadata.get("deployment_model_name")
            if (
                isinstance(maybe_deployment_model_name, str)
                and maybe_deployment_model_name
            ):
                deployment_model_name = maybe_deployment_model_name
                break
        if not deployment_model_name:
            verbose_router_logger.warning(
                "DeploymentAffinityCheck: deployment_model_name missing; skipping affinity cache update. model_id=%s",
                model_id,
            )
            return None
        # Persist API-key affinity (best-effort; cache failures must not block the call).
        if user_key is not None:
            try:
                cache_key = self.get_affinity_cache_key(
                    model_group=deployment_model_name, user_key=user_key
                )
                await self.cache.async_set_cache(
                    cache_key,
                    DeploymentAffinityCacheValue(model_id=str(model_id)),
                    ttl=self.ttl_seconds,
                )
                verbose_router_logger.debug(
                    "DeploymentAffinityCheck: set affinity mapping model_map_key=%s deployment=%s ttl=%s user_key=%s",
                    deployment_model_name,
                    model_id,
                    self.ttl_seconds,
                    self._shorten_for_logs(user_key),
                )
            except Exception as e:
                # Non-blocking: affinity is a best-effort optimization.
                verbose_router_logger.debug(
                    "DeploymentAffinityCheck: failed to set user key affinity cache. model_map_key=%s error=%s",
                    deployment_model_name,
                    e,
                )
        # Also persist Session-ID affinity if enabled and session-id is provided
        if session_id is not None:
            try:
                session_cache_key = self.get_session_affinity_cache_key(
                    model_group=deployment_model_name, session_id=session_id
                )
                await self.cache.async_set_cache(
                    session_cache_key,
                    DeploymentAffinityCacheValue(model_id=str(model_id)),
                    ttl=self.ttl_seconds,
                )
                verbose_router_logger.debug(
                    "DeploymentAffinityCheck: set session affinity mapping model_map_key=%s deployment=%s ttl=%s session_id=%s",
                    deployment_model_name,
                    model_id,
                    self.ttl_seconds,
                    session_id,
                )
            except Exception as e:
                verbose_router_logger.debug(
                    "DeploymentAffinityCheck: failed to set session affinity cache. model_map_key=%s error=%s",
                    deployment_model_name,
                    e,
                )
        return None

View File

@@ -0,0 +1,172 @@
"""
Encrypted-content-aware deployment affinity for the Router.
When Codex or other models use `store: false` with `include: ["reasoning.encrypted_content"]`,
the response output items contain encrypted reasoning tokens tied to the originating
organization's API key. If a follow-up request containing those items is routed to a
different deployment (different org), OpenAI rejects it with an `invalid_encrypted_content`
error because the organization_id doesn't match.
This callback solves the problem by encoding the originating deployment's ``model_id``
into the response output items that carry ``encrypted_content``. Two encoding strategies:
1. **Items with IDs**: Encode model_id into the item ID itself (e.g., ``encitem_...``)
2. **Items without IDs** (Codex): Wrap the encrypted_content with model_id metadata
(e.g., ``litellm_enc:{base64_metadata};{original_encrypted_content}``)
The encoded model_id is decoded on the next request so the router can pin to the correct
deployment without any cache lookup.
Response post-processing (encoding) is handled by
``ResponsesAPIRequestUtils._update_encrypted_content_item_ids_in_response`` which is
called inside ``_update_responses_api_response_id_with_model_id`` in ``responses/utils.py``.
Request pre-processing (ID/content restoration before forwarding to upstream) is handled by
``ResponsesAPIRequestUtils._restore_encrypted_content_item_ids_in_input`` which is called
in ``get_optional_params_responses_api``.
This pre-call check is responsible only for the routing decision: it reads the encoded
``model_id`` from either item IDs or wrapped encrypted_content and pins the request to
the matching deployment.
Safe to enable globally:
- Only activates when encoded markers appear in the request ``input``.
- No effect on embedding models, chat completions, or first-time requests.
- No quota reduction -- first requests are fully load balanced.
- No cache required.
"""
from typing import Any, List, Optional, cast
from litellm._logging import verbose_router_logger
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import AllMessageValues
class EncryptedContentAffinityCheck(CustomLogger):
    """
    Routes follow-up Responses API requests to the deployment that produced
    the encrypted output items they reference.

    The ``model_id`` is decoded directly from the litellm-encoded item IDs —
    no caching or TTL management needed.

    Wired via ``Router(optional_pre_call_checks=["encrypted_content_affinity"])``.
    """

    def __init__(self) -> None:
        super().__init__()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_model_id_from_input(request_input: Any) -> Optional[str]:
        """
        Scan ``input`` items for litellm-encoded encrypted-content markers and
        return the ``model_id`` embedded in the first one found.

        Checks both:
        1. Encoded item IDs (encitem_...) - for clients that send IDs
        2. Wrapped encrypted_content (litellm_enc:...) - for clients like Codex that don't send IDs

        ``input`` can be:
        - a plain string -> no encoded markers
        - a list of items -> check each item's ``id`` and ``encrypted_content`` fields
        """
        if not isinstance(request_input, list):
            return None
        for item in request_input:
            if not isinstance(item, dict):
                continue
            # First, try to decode from item ID (if present)
            item_id = item.get("id")
            if item_id and isinstance(item_id, str):
                decoded = ResponsesAPIRequestUtils._decode_encrypted_item_id(item_id)
                if decoded:
                    return decoded.get("model_id")
            # If no encoded ID, check if encrypted_content itself is wrapped
            encrypted_content = item.get("encrypted_content")
            if encrypted_content and isinstance(encrypted_content, str):
                (
                    model_id,
                    _,
                ) = ResponsesAPIRequestUtils._unwrap_encrypted_content_with_model_id(
                    encrypted_content
                )
                if model_id:
                    return model_id
        return None

    @staticmethod
    def _find_deployment_by_model_id(
        healthy_deployments: List[dict], model_id: str
    ) -> Optional[dict]:
        """Return the deployment whose ``model_info.id`` matches ``model_id``, else None."""
        for deployment in healthy_deployments:
            model_info = deployment.get("model_info")
            if not isinstance(model_info, dict):
                continue
            deployment_model_id = model_info.get("id")
            # Compare as strings so int/str ids stored differently still match.
            if deployment_model_id is not None and str(deployment_model_id) == str(
                model_id
            ):
                return deployment
        return None

    # ------------------------------------------------------------------
    # Request routing (pre-call filter)
    # ------------------------------------------------------------------
    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        """
        If the request ``input`` contains litellm-encoded item IDs, decode the
        embedded ``model_id`` and pin the request to that deployment.
        """
        request_kwargs = request_kwargs or {}
        typed_healthy_deployments = cast(List[dict], healthy_deployments)
        # Signal to the response post-processor that encrypted item IDs should be
        # encoded in the output of this request.
        # BUGFIX: callers may pass `litellm_metadata=None`; `dict.setdefault`
        # would return that stored None and the item-assignment below would
        # raise TypeError. Replace any non-dict value with a fresh dict.
        litellm_metadata = request_kwargs.get("litellm_metadata")
        if not isinstance(litellm_metadata, dict):
            litellm_metadata = {}
            request_kwargs["litellm_metadata"] = litellm_metadata
        litellm_metadata["encrypted_content_affinity_enabled"] = True
        request_input = request_kwargs.get("input")
        model_id = self._extract_model_id_from_input(request_input)
        if not model_id:
            # First-time request (no markers) -> keep full load balancing.
            return typed_healthy_deployments
        verbose_router_logger.debug(
            "EncryptedContentAffinityCheck: decoded model_id=%s from input item IDs",
            model_id,
        )
        deployment = self._find_deployment_by_model_id(
            healthy_deployments=typed_healthy_deployments,
            model_id=model_id,
        )
        if deployment is not None:
            verbose_router_logger.debug(
                "EncryptedContentAffinityCheck: pinning -> deployment=%s",
                model_id,
            )
            request_kwargs["_encrypted_content_affinity_pinned"] = True
            return [deployment]
        # Pinned deployment disappeared (unhealthy/removed); log loudly and
        # fall back to normal routing rather than failing the request.
        verbose_router_logger.error(
            "EncryptedContentAffinityCheck: decoded deployment=%s not found in healthy_deployments",
            model_id,
        )
        return typed_healthy_deployments

View File

@@ -0,0 +1,332 @@
"""
Enforce TPM/RPM rate limits set on model deployments.
This pre-call check ensures that model-level TPM/RPM limits are enforced
across all requests, regardless of routing strategy.
When enabled via `enforce_model_rate_limits: true` in litellm_settings,
requests that exceed the configured TPM/RPM limits will receive a 429 error.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_router_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.router import RouterErrors
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_utc_datetime
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class RoutingArgs:
    """Tunables for the rate-limit counters."""

    # Per-minute TPM/RPM counters expire after one minute, matching the
    # "%H-%M" window embedded in their cache keys.
    ttl: int = 60  # 1min (RPM/TPM expire key)
class ModelRateLimitingCheck(CustomLogger):
    """
    Pre-call check that enforces TPM/RPM limits on model deployments.

    This check runs before each request and raises a RateLimitError
    if the deployment has exceeded its configured TPM or RPM limits.

    Unlike the usage-based-routing strategy which uses limits for routing decisions,
    this check actively enforces those limits across ALL routing strategies.
    """

    def __init__(self, dual_cache: DualCache):
        # Shared router cache (in-memory + optional remote) holding the
        # per-minute TPM/RPM counters.
        self.dual_cache = dual_cache

    def _get_deployment_limits(
        self, deployment: Dict
    ) -> tuple[Optional[int], Optional[int]]:
        """
        Extract TPM and RPM limits from a deployment configuration.

        Checks in order:
        1. Top-level 'tpm'/'rpm' fields
        2. litellm_params.tpm/rpm
        3. model_info.tpm/rpm

        Returns:
            Tuple of (tpm_limit, rpm_limit); either element may be None.
        """
        # Check top-level
        tpm = deployment.get("tpm")
        rpm = deployment.get("rpm")
        # Check litellm_params
        if tpm is None:
            tpm = deployment.get("litellm_params", {}).get("tpm")
        if rpm is None:
            rpm = deployment.get("litellm_params", {}).get("rpm")
        # Check model_info
        if tpm is None:
            tpm = deployment.get("model_info", {}).get("tpm")
        if rpm is None:
            rpm = deployment.get("model_info", {}).get("rpm")
        return tpm, rpm

    def _get_cache_keys(
        self, deployment: Dict, current_minute: str
    ) -> tuple[str, str]:
        """Get the cache keys for TPM and RPM tracking in this minute window."""
        model_id = deployment.get("model_info", {}).get("id")
        deployment_name = deployment.get("litellm_params", {}).get("model")
        tpm_key = f"{model_id}:{deployment_name}:tpm:{current_minute}"
        rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
        return tpm_key, rpm_key

    @staticmethod
    def _rate_limit_error(
        limit_type: str,
        limit: int,
        current: Any,
        model_id: Any,
        model_name: Any,
        model_group: str,
    ) -> "litellm.RateLimitError":
        """
        Build the 429 RateLimitError raised when a TPM/RPM limit is exceeded.

        Deduplicates the four previously copy-pasted error constructions and
        guarantees sync/async paths raise identically (including num_retries=0,
        which the sync path previously omitted).

        Args:
            limit_type: "tpm" or "rpm" (lower-case; upper-cased in the message).
            limit: The configured limit that was exceeded.
            current: Observed usage for the current minute.
            model_id: Deployment id (model_info.id).
            model_name: litellm_params.model of the deployment.
            model_group: Router model group name.
        """
        return litellm.RateLimitError(
            message=f"Model rate limit exceeded. {limit_type.upper()} limit={limit}, current usage={current}",
            llm_provider="",
            model=model_name,
            response=httpx.Response(
                status_code=429,
                content=f"{RouterErrors.user_defined_ratelimit_error.value} {limit_type} limit={limit}. current usage={current}. id={model_id}, model_group={model_group}",
                headers={"retry-after": str(60)},
                request=httpx.Request(
                    method="model_rate_limit_check",
                    url="https://github.com/BerriAI/litellm",
                ),
            ),
            num_retries=0,  # Don't retry - return 429 immediately
        )

    def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
        """
        Synchronous pre-call check for model rate limits.

        Raises RateLimitError if deployment exceeds TPM/RPM limits; otherwise
        returns the deployment unchanged (including when the check itself
        errors, so rate limiting never breaks a request).
        """
        try:
            tpm_limit, rpm_limit = self._get_deployment_limits(deployment)
            # If no limits are set, allow the request
            if tpm_limit is None and rpm_limit is None:
                return deployment
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            tpm_key, rpm_key = self._get_cache_keys(deployment, current_minute)
            model_id = deployment.get("model_info", {}).get("id")
            model_name = deployment.get("litellm_params", {}).get("model")
            model_group = deployment.get("model_name", "")
            # Check TPM limit (read-only; the counter is updated post-call in
            # log_success_event). Local-only read keeps this check cheap.
            if tpm_limit is not None:
                current_tpm = self.dual_cache.get_cache(key=tpm_key, local_only=True)
                if current_tpm is not None and current_tpm >= tpm_limit:
                    raise self._rate_limit_error(
                        limit_type="tpm",
                        limit=tpm_limit,
                        current=current_tpm,
                        model_id=model_id,
                        model_name=model_name,
                        model_group=model_group,
                    )
            # Check RPM limit (atomic increment-first to avoid race conditions)
            if rpm_limit is not None:
                current_rpm = self.dual_cache.increment_cache(
                    key=rpm_key, value=1, ttl=RoutingArgs.ttl
                )
                if current_rpm is not None and current_rpm > rpm_limit:
                    raise self._rate_limit_error(
                        limit_type="rpm",
                        limit=rpm_limit,
                        current=current_rpm,
                        model_id=model_id,
                        model_name=model_name,
                        model_group=model_group,
                    )
            return deployment
        except litellm.RateLimitError:
            raise
        except Exception as e:
            verbose_router_logger.debug(
                f"Error in ModelRateLimitingCheck.pre_call_check: {str(e)}"
            )
            # Don't fail the request if rate limit check fails
            return deployment

    async def async_pre_call_check(
        self, deployment: Dict, parent_otel_span: Optional[Span] = None
    ) -> Optional[Dict]:
        """
        Async pre-call check for model rate limits.

        Raises RateLimitError if deployment exceeds TPM/RPM limits.
        """
        try:
            tpm_limit, rpm_limit = self._get_deployment_limits(deployment)
            # If no limits are set, allow the request
            if tpm_limit is None and rpm_limit is None:
                return deployment
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            tpm_key, rpm_key = self._get_cache_keys(deployment, current_minute)
            model_id = deployment.get("model_info", {}).get("id")
            model_name = deployment.get("litellm_params", {}).get("model")
            model_group = deployment.get("model_name", "")
            # Check TPM limit (read-only; see sync variant for rationale)
            if tpm_limit is not None:
                current_tpm = await self.dual_cache.async_get_cache(
                    key=tpm_key, local_only=True
                )
                if current_tpm is not None and current_tpm >= tpm_limit:
                    raise self._rate_limit_error(
                        limit_type="tpm",
                        limit=tpm_limit,
                        current=current_tpm,
                        model_id=model_id,
                        model_name=model_name,
                        model_group=model_group,
                    )
            # Check RPM limit (atomic increment-first to avoid race conditions)
            if rpm_limit is not None:
                current_rpm = await self.dual_cache.async_increment_cache(
                    key=rpm_key,
                    value=1,
                    ttl=RoutingArgs.ttl,
                    parent_otel_span=parent_otel_span,
                )
                if current_rpm is not None and current_rpm > rpm_limit:
                    raise self._rate_limit_error(
                        limit_type="rpm",
                        limit=rpm_limit,
                        current=current_rpm,
                        model_id=model_id,
                        model_name=model_name,
                        model_group=model_group,
                    )
            return deployment
        except litellm.RateLimitError:
            raise
        except Exception as e:
            verbose_router_logger.debug(
                f"Error in ModelRateLimitingCheck.async_pre_call_check: {str(e)}"
            )
            # Don't fail the request if rate limit check fails
            return deployment

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Track TPM usage after successful request.

        This updates the TPM counter with the actual tokens used.
        Always tracks tokens - the pre-call check handles enforcement.
        """
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object"
            )
            if standard_logging_object is None:
                return
            model_id = standard_logging_object.get("model_id")
            if model_id is None:
                return
            total_tokens = standard_logging_object.get("total_tokens", 0)
            # litellm_model_name matches litellm_params.model used in _get_cache_keys.
            model = standard_logging_object.get("hidden_params", {}).get(
                "litellm_model_name"
            )
            verbose_router_logger.debug(
                f"[TPM TRACKING] model_id={model_id}, total_tokens={total_tokens}, model={model}"
            )
            if not model or not total_tokens:
                return
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            tpm_key = f"{model_id}:{model}:tpm:{current_minute}"
            verbose_router_logger.debug(
                f"[TPM TRACKING] Incrementing {tpm_key} by {total_tokens}"
            )
            await self.dual_cache.async_increment_cache(
                key=tpm_key,
                value=total_tokens,
                ttl=RoutingArgs.ttl,
            )
        except Exception as e:
            verbose_router_logger.debug(
                f"Error in ModelRateLimitingCheck.async_log_success_event: {str(e)}"
            )

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Sync version of tracking TPM usage after successful request.

        Always tracks tokens - the pre-call check handles enforcement.
        """
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object"
            )
            if standard_logging_object is None:
                return
            model_id = standard_logging_object.get("model_id")
            if model_id is None:
                return
            total_tokens = standard_logging_object.get("total_tokens", 0)
            model = standard_logging_object.get("hidden_params", {}).get(
                "litellm_model_name"
            )
            if not model or not total_tokens:
                return
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            tpm_key = f"{model_id}:{model}:tpm:{current_minute}"
            self.dual_cache.increment_cache(
                key=tpm_key,
                value=total_tokens,
                ttl=RoutingArgs.ttl,
            )
        except Exception as e:
            verbose_router_logger.debug(
                f"Error in ModelRateLimitingCheck.log_success_event: {str(e)}"
            )

View File

@@ -0,0 +1,100 @@
"""
Check if prompt caching is valid for a given deployment
Route to previously cached model id, if valid
"""
from typing import List, Optional, cast
from litellm import verbose_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CallTypes, StandardLoggingPayload
from litellm.utils import is_prompt_caching_valid_prompt
from ..prompt_caching_cache import PromptCachingCache
class PromptCachingDeploymentCheck(CustomLogger):
    """
    Pins prompt-caching-eligible requests (>1024 token prompts) to the
    deployment that previously served the same cacheable prefix, so provider
    prompt caches actually get hit.
    """

    def __init__(self, cache: DualCache):
        self.cache = cache

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        """Return the previously-cached deployment if one matches, else all healthy ones."""
        # Only prompts large enough to be cacheable (> 1024 tokens) qualify.
        if messages is None or not is_prompt_caching_valid_prompt(
            messages=messages,
            model=model,
        ):
            return healthy_deployments
        cached_entry = await PromptCachingCache(
            cache=self.cache,
        ).async_get_model_id(
            messages=cast(List[AllMessageValues], messages),
            tools=None,
        )
        if cached_entry is None:
            return healthy_deployments
        pinned_model_id = cached_entry["model_id"]
        matching = [
            deployment
            for deployment in healthy_deployments
            if deployment["model_info"]["id"] == pinned_model_id
        ]
        if matching:
            # Pin to the first matching deployment, as the original lookup did.
            return [matching[0]]
        return healthy_deployments

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Record which deployment served a prompt-caching-eligible completion."""
        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if standard_logging_object is None:
            return
        completion_call_types = (
            CallTypes.completion.value,
            CallTypes.acompletion.value,
            CallTypes.anthropic_messages.value,
        )
        # only use prompt caching for completion calls
        if standard_logging_object["call_type"] not in completion_call_types:
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION or ANTHROPIC MESSAGE"
            )
            return
        model = standard_logging_object["model"]
        messages = standard_logging_object["messages"]
        model_id = standard_logging_object["model_id"]
        if messages is None or not isinstance(messages, list):
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST"
            )
            return
        if model_id is None:
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE"
            )
            return
        ## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider
        if not is_prompt_caching_valid_prompt(
            model=model,
            messages=cast(List[AllMessageValues], messages),
        ):
            return
        await PromptCachingCache(cache=self.cache).async_add_model_id(
            model_id=model_id,
            messages=messages,
            tools=None,  # [TODO]: add tools once standard_logging_object supports it
        )
        return

View File

@@ -0,0 +1,57 @@
"""
For Responses API, we need routing affinity when a user sends a previous_response_id.
eg. If proxy admins are load balancing between N gpt-4.1-turbo deployments, and a user sends a previous_response_id,
we want to route to the same gpt-4.1-turbo deployment.
This is different from the normal behavior of the router, which does not have routing affinity for previous_response_id.
If previous_response_id is provided, route to the deployment that returned the previous response
"""
import warnings
from typing import List, Optional
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import AllMessageValues
class ResponsesApiDeploymentCheck(CustomLogger):
    """
    Deprecated Responses API affinity check.

    Pins a request carrying ``previous_response_id`` to the deployment that
    produced that response, by decoding the model_id embedded in the id.
    """

    def __init__(self) -> None:
        super().__init__()
        warnings.warn(
            (
                "ResponsesApiDeploymentCheck is deprecated. "
                "Use DeploymentAffinityCheck(enable_responses_api_affinity=True) instead."
            ),
            DeprecationWarning,
            stacklevel=2,
        )

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        """Return the pinned deployment when previous_response_id decodes to one."""
        previous_response_id = (request_kwargs or {}).get(
            "previous_response_id", None
        )
        if previous_response_id is None:
            return healthy_deployments
        decoded = ResponsesAPIRequestUtils._decode_responses_api_response_id(
            response_id=previous_response_id,
        )
        target_model_id = decoded.get("model_id")
        if target_model_id is None:
            return healthy_deployments
        pinned = next(
            (
                deployment
                for deployment in healthy_deployments
                if deployment["model_info"]["id"] == target_model_id
            ),
            None,
        )
        return [pinned] if pinned is not None else healthy_deployments

View File

@@ -0,0 +1,259 @@
"""
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
"""
import hashlib
import json
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
from typing_extensions import TypedDict
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router
litellm_router = Router
Span = Union[_Span, Any]
else:
Span = Any
litellm_router = Any
class PromptCachingCacheValue(TypedDict):
    """Value stored per cacheable-prefix key: the deployment that served it."""

    # Router deployment id (model_info.id) that produced the cached prefix.
    model_id: str
class PromptCachingCache:
    """
    Maps a prompt's cacheable prefix (hashed) -> the deployment (model_id)
    that last served it, so repeat prompts are routed to the same deployment.

    WARNING: the cache key is a hash of the serialized prefix; changing
    serialize_object / get_prompt_caching_cache_key invalidates existing keys.
    """

    def __init__(self, cache: DualCache):
        self.cache = cache
        # NOTE(review): not used anywhere in this class — presumably reserved
        # for future local-only lookups; confirm before removing.
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """Helper function to serialize Pydantic objects, dictionaries, or fallback to string."""
        if hasattr(obj, "dict"):
            # If the object is a Pydantic model, use its `dict()` method
            return obj.dict()
        elif isinstance(obj, dict):
            # If the object is a dictionary, serialize it with sorted keys
            return json.dumps(
                obj, sort_keys=True, separators=(",", ":")
            )  # Standardize serialization
        elif isinstance(obj, list):
            # Serialize lists by ensuring each element is handled properly
            return [PromptCachingCache.serialize_object(item) for item in obj]
        elif isinstance(obj, (int, float, bool)):
            return obj  # Keep primitive types as-is
        return str(obj)

    @staticmethod
    def extract_cacheable_prefix(
        messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """
        Extract the cacheable prefix from messages.

        The cacheable prefix is everything UP TO AND INCLUDING the LAST content block
        (across all messages) that has cache_control. This includes ALL blocks before
        the last cacheable block (even if they don't have cache_control).

        Args:
            messages: List of messages to extract cacheable prefix from

        Returns:
            List of messages containing only the cacheable prefix
        """
        if not messages:
            return messages
        # Find the last content block (across all messages) that has cache_control
        last_cacheable_message_idx = None
        last_cacheable_content_idx = None
        for msg_idx, message in enumerate(messages):
            content = message.get("content")
            # Check for cache_control at message level (when content is a string)
            # This handles the case where cache_control is a sibling of string content:
            # {"role": "user", "content": "...", "cache_control": {"type": "ephemeral"}}
            message_level_cache_control = message.get("cache_control")
            if (
                message_level_cache_control is not None
                and isinstance(message_level_cache_control, dict)
                and message_level_cache_control.get("type") == "ephemeral"
            ):
                last_cacheable_message_idx = msg_idx
                # Set to None to indicate the entire message content is cacheable
                # (not a specific content block index within a list)
                last_cacheable_content_idx = None
            # Also check for cache_control within content blocks (when content is a list)
            if not isinstance(content, list):
                continue
            for content_idx, content_block in enumerate(content):
                if isinstance(content_block, dict):
                    cache_control = content_block.get("cache_control")
                    if (
                        cache_control is not None
                        and isinstance(cache_control, dict)
                        and cache_control.get("type") == "ephemeral"
                    ):
                        last_cacheable_message_idx = msg_idx
                        last_cacheable_content_idx = content_idx
        # If no cacheable block found, return empty list (no cacheable prefix)
        if last_cacheable_message_idx is None:
            return []
        # Build the cacheable prefix: all messages up to and including the last cacheable message
        cacheable_prefix = []
        for msg_idx, message in enumerate(messages):
            if msg_idx < last_cacheable_message_idx:
                # Include entire message (comes before last cacheable block)
                cacheable_prefix.append(message)
            elif msg_idx == last_cacheable_message_idx:
                # Include message but only up to and including the last cacheable content block
                content = message.get("content")
                if isinstance(content, list) and last_cacheable_content_idx is not None:
                    # Create a copy of the message with only cacheable content blocks
                    message_copy = cast(
                        AllMessageValues,
                        {
                            **message,
                            "content": content[: last_cacheable_content_idx + 1],
                        },
                    )
                    cacheable_prefix.append(message_copy)
                else:
                    # Content is not a list or cacheable content idx is None, include full message
                    cacheable_prefix.append(message)
            else:
                # Message comes after last cacheable block, don't include
                break
        return cacheable_prefix

    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[str]:
        """
        Build the deterministic cache key for a (messages, tools) pair, or
        return None when there is nothing cacheable (no messages/tools, or no
        cache_control block found in the messages).
        """
        if messages is None and tools is None:
            return None
        # Extract cacheable prefix from messages (only include up to last cache_control block)
        cacheable_messages = None
        if messages is not None:
            cacheable_messages = PromptCachingCache.extract_cacheable_prefix(messages)
            # If no cacheable prefix found, return None (can't cache)
            if not cacheable_messages:
                return None
        # Use serialize_object for consistent and stable serialization
        data_to_hash = {}
        if cacheable_messages is not None:
            serialized_messages = PromptCachingCache.serialize_object(
                cacheable_messages
            )
            data_to_hash["messages"] = serialized_messages
        if tools is not None:
            serialized_tools = PromptCachingCache.serialize_object(tools)
            data_to_hash["tools"] = serialized_tools
        # Combine serialized data into a single string
        data_to_hash_str = json.dumps(
            data_to_hash,
            sort_keys=True,
            separators=(",", ":"),
        )
        # Create a hash of the serialized data for a stable cache key
        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
        return f"deployment:{hashed_data}:prompt_caching"

    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """Sync variant of async_add_model_id (5-minute TTL)."""
        if messages is None and tools is None:
            return None
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # If no cacheable prefix found, don't cache (can't generate cache key)
        if cache_key is None:
            return None
        self.cache.set_cache(
            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
        )
        return None

    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """Store the serving deployment's model_id under the prefix's cache key."""
        if messages is None and tools is None:
            return None
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # If no cacheable prefix found, don't cache (can't generate cache key)
        if cache_key is None:
            return None
        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=300,  # store for 5 minutes
        )
        return None

    async def async_get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """
        Get model ID from cache using the cacheable prefix.

        The cache key is based on the cacheable prefix (everything up to and including
        the last cache_control block), so requests with the same cacheable prefix but
        different user messages will have the same cache key.
        """
        if messages is None and tools is None:
            return None
        # Generate cache key using cacheable prefix
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        if cache_key is None:
            return None
        # Perform cache lookup
        cache_result = await self.cache.async_get_cache(key=cache_key)
        return cache_result

    def get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """Sync variant of async_get_model_id."""
        if messages is None and tools is None:
            return None
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # If no cacheable prefix found, return None (can't cache)
        if cache_key is None:
            return None
        return self.cache.get_cache(cache_key)

View File

@@ -0,0 +1,90 @@
"""
Helper functions to get/set num success and num failures per deployment
set_deployment_failures_for_current_minute
set_deployment_successes_for_current_minute
get_deployment_failures_for_current_minute
get_deployment_successes_for_current_minute
"""
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
def increment_deployment_successes_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> str:
    """
    In-Memory: bump the per-minute success counter for a deployment_id.

    The counter lives only in the local cache and carries a 60-second TTL,
    so it naturally resets every minute.

    Returns:
        The cache key under which the counter is stored.
    """
    success_key = f"{deployment_id}:successes"
    litellm_router_instance.cache.increment_cache(
        local_only=True, key=success_key, value=1, ttl=60
    )
    return success_key
def increment_deployment_failures_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> str:
    """
    In-Memory: Increments the number of failures for the current minute for a deployment_id

    The counter is stored local-only with a 60-second TTL, so it resets
    every minute (mirrors the successes counter).

    Returns:
        The cache key under which the counter is stored. Returned for
        consistency with increment_deployment_successes_for_current_minute
        (previously this function implicitly returned None).
    """
    key = f"{deployment_id}:fails"
    litellm_router_instance.cache.increment_cache(
        local_only=True,
        key=key,
        value=1,
        ttl=60,
    )
    return key
def get_deployment_successes_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> int:
    """
    Number of successes recorded for *deployment_id* in the current minute.

    Returns 0 when no counter exists (never recorded, or the 60s TTL expired).
    """
    counter = litellm_router_instance.cache.get_cache(
        local_only=True,
        key=f"{deployment_id}:successes",
    )
    # A cache miss yields None; coerce any falsy result to zero.
    return counter or 0
def get_deployment_failures_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> int:
    """
    Number of failures recorded for *deployment_id* in the current minute.

    Returns 0 when no counter exists (never recorded, or the 60s TTL expired).
    """
    counter = litellm_router_instance.cache.get_cache(
        local_only=True,
        key=f"{deployment_id}:fails",
    )
    # A cache miss yields None; coerce any falsy result to zero.
    return counter or 0

View File

@@ -0,0 +1,227 @@
"""
Router utilities for Search API integration.
Handles search tool selection, load balancing, and fallback logic for search requests.
"""
import asyncio
import random
import traceback
from functools import partial
from typing import Any, Callable
from litellm._logging import verbose_router_logger
class SearchAPIRouter:
    """
    Static utility class for routing search API calls through the LiteLLM router.

    Provides methods for search tool selection, load balancing, and fallback
    handling. All methods are static; the Router instance is passed in
    explicitly, so this class carries no state of its own.
    """

    @staticmethod
    async def update_router_search_tools(router_instance: Any, search_tools: list) -> None:
        """
        Update the router with search tools from the database.

        This method is called by a cron job to sync search tools from DB to router.

        Args:
            router_instance: The Router instance to update
            search_tools: List of search tool configurations from the database
                (dict-like entries; accessed via .get below)

        Raises:
            Exception: any error during conversion/assignment is logged and
                re-raised to the caller.
        """
        try:
            from litellm.types.router import SearchToolTypedDict

            verbose_router_logger.debug(
                f"Adding {len(search_tools)} search tools to router"
            )
            # Convert search tools to the format expected by the router
            router_search_tools: list = []
            for tool in search_tools:
                # Create dict that matches SearchToolTypedDict structure.
                # Missing keys become None, except litellm_params which
                # defaults to an empty dict.
                router_search_tool: SearchToolTypedDict = {  # type: ignore
                    "search_tool_id": tool.get("search_tool_id"),
                    "search_tool_name": tool.get("search_tool_name"),
                    "litellm_params": tool.get("litellm_params", {}),
                    "search_tool_info": tool.get("search_tool_info"),
                }
                router_search_tools.append(router_search_tool)
            # Update the router's search_tools list. Wholesale replacement,
            # not a merge -- the DB snapshot is treated as the source of truth.
            router_instance.search_tools = router_search_tools
            verbose_router_logger.info(
                f"Successfully updated router with {len(router_search_tools)} search tool(s)"
            )
        except Exception as e:
            verbose_router_logger.exception(
                f"Error updating router with search tools: {str(e)}"
            )
            raise e

    @staticmethod
    def get_matching_search_tools(
        router_instance: Any,
        search_tool_name: str,
    ) -> list:
        """
        Get all search tools matching the given name.

        Multiple entries may share one name (same logical tool backed by
        several providers), hence a list is returned rather than one entry.

        Args:
            router_instance: The Router instance
            search_tool_name: Name of the search tool to find

        Returns:
            List of matching search tool configurations

        Raises:
            ValueError: If no matching search tools are found
        """
        matching_tools = [
            tool
            for tool in router_instance.search_tools
            if tool.get("search_tool_name") == search_tool_name
        ]
        if not matching_tools:
            raise ValueError(
                f"Search tool '{search_tool_name}' not found in router.search_tools"
            )
        return matching_tools

    @staticmethod
    async def async_search_with_fallbacks(
        router_instance: Any,
        original_function: Callable,
        **kwargs,
    ):
        """
        Helper function to make a search API call through the router with load balancing and fallbacks.

        Reuses the router's retry/fallback infrastructure by aliasing the
        search tool name into the "model" field and installing
        async_search_with_fallbacks_helper as the per-attempt callable.

        Args:
            router_instance: The Router instance
            original_function: The original litellm.asearch function
            **kwargs: Search parameters including search_tool_name, query, etc.

        Returns:
            SearchResponse from the search API

        Raises:
            ValueError: if neither search_tool_name nor model is supplied.
            Exception: any failure is alerted asynchronously, then re-raised.
        """
        try:
            # "model" doubles as the tool name for callers using the
            # chat-completions-style interface.
            search_tool_name = kwargs.get("search_tool_name", kwargs.get("model"))
            if not search_tool_name:
                raise ValueError(
                    "search_tool_name or model parameter is required for search"
                )
            # Set up kwargs for the fallback system
            kwargs[
                "model"
            ] = search_tool_name  # Use model field for compatibility with fallback system
            kwargs["original_generic_function"] = original_function
            # Bind router_instance to the helper method using partial so the
            # fallback machinery can call it with only (model, **kwargs).
            kwargs["original_function"] = partial(
                SearchAPIRouter.async_search_with_fallbacks_helper,
                router_instance=router_instance,
            )
            # Update kwargs before fallbacks (for logging, metadata, etc)
            router_instance._update_kwargs_before_fallbacks(
                model=search_tool_name,
                kwargs=kwargs,
                metadata_variable_name="litellm_metadata",
            )
            available_search_tool_names = [
                tool.get("search_tool_name") for tool in router_instance.search_tools
            ]
            verbose_router_logger.debug(
                f"Inside SearchAPIRouter.async_search_with_fallbacks() - search_tool_name: {search_tool_name}, Available Search Tools: {available_search_tool_names}, kwargs: {kwargs}"
            )
            # Use the existing retry/fallback infrastructure
            response = await router_instance.async_function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            from litellm.router_utils.handle_error import send_llm_exception_alert

            # Fire-and-forget alert; NOTE(review): the task handle is not
            # retained, so alert failures are presumably ignored -- confirm
            # this is intentional.
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=router_instance,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e

    @staticmethod
    async def async_search_with_fallbacks_helper(
        router_instance: Any,
        model: str,
        original_generic_function: Callable,
        **kwargs,
    ):
        """
        Helper function for search API calls - selects a search tool and calls the original function.

        Called by async_function_with_fallbacks for each retry attempt.

        Args:
            router_instance: The Router instance
            model: The search tool name (passed as model for compatibility)
            original_generic_function: The original litellm.asearch function
            **kwargs: Search parameters

        Returns:
            SearchResponse from the selected search provider

        Raises:
            ValueError: when the selected tool lacks a search_provider.
            Exception: any provider error is logged and re-raised so the
                fallback machinery can try the next option.
        """
        search_tool_name = model  # model field contains the search_tool_name
        try:
            # Find matching search tools
            matching_tools = SearchAPIRouter.get_matching_search_tools(
                router_instance=router_instance,
                search_tool_name=search_tool_name,
            )
            # Simple random selection for load balancing across multiple providers with same name
            # For search tools, we use simple random choice since they don't have TPM/RPM constraints
            selected_tool = random.choice(matching_tools)
            # Extract search provider and other params from litellm_params
            litellm_params = selected_tool.get("litellm_params", {})
            search_provider = litellm_params.get("search_provider")
            api_key = litellm_params.get("api_key")
            api_base = litellm_params.get("api_base")
            if not search_provider:
                raise ValueError(
                    f"search_provider not found in litellm_params for search tool '{search_tool_name}'"
                )
            verbose_router_logger.debug(
                f"Selected search tool with provider: {search_provider}"
            )
            # Call the original search function with the provider config.
            # NOTE(review): **kwargs may still contain routing keys such as
            # search_tool_name set upstream; presumably the search function
            # tolerates/ignores them -- confirm against litellm.asearch.
            response = await original_generic_function(
                search_provider=search_provider,
                api_key=api_key,
                api_base=api_base,
                **kwargs,
            )
            return response
        except Exception as e:
            verbose_router_logger.error(
                f"Error in SearchAPIRouter.async_search_with_fallbacks_helper for {search_tool_name}: {str(e)}"
            )
            raise e