chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.types.utils import HiddenParams
|
||||
|
||||
|
||||
def _add_headers_to_response(response: Any, headers: dict) -> Any:
|
||||
"""
|
||||
Helper function to add headers to a response's hidden params
|
||||
"""
|
||||
if response is None or not isinstance(response, BaseModel):
|
||||
return response
|
||||
|
||||
hidden_params: Optional[Union[dict, HiddenParams]] = getattr(
|
||||
response, "_hidden_params", {}
|
||||
)
|
||||
|
||||
if hidden_params is None:
|
||||
hidden_params_dict = {}
|
||||
elif isinstance(hidden_params, HiddenParams):
|
||||
hidden_params_dict = hidden_params.model_dump()
|
||||
else:
|
||||
hidden_params_dict = hidden_params
|
||||
|
||||
hidden_params_dict.setdefault("additional_headers", {})
|
||||
hidden_params_dict["additional_headers"].update(headers)
|
||||
|
||||
setattr(response, "_hidden_params", hidden_params_dict)
|
||||
return response
|
||||
|
||||
|
||||
def add_retry_headers_to_response(
    response: Any,
    attempted_retries: int,
    max_retries: Optional[int] = None,
) -> Any:
    """
    Attach retry-tracking headers to the response's hidden params.

    Always records the attempted retry count; the max retry count is only
    included when the caller supplies one.
    """
    headers: dict = {"x-litellm-attempted-retries": attempted_retries}
    if max_retries is not None:
        headers["x-litellm-max-retries"] = max_retries
    return _add_headers_to_response(response, headers)
|
||||
|
||||
|
||||
def add_fallback_headers_to_response(
    response: Any,
    attempted_fallbacks: int,
) -> Any:
    """
    Attach the attempted-fallback count to the response's hidden params.

    Args:
        response: The response to annotate.
        attempted_fallbacks: Number of fallbacks attempted for this request.

    Returns:
        The response with the header added.

    Note: It's intentional that we don't add max_fallbacks in response headers
    Want to avoid bloat in the response headers for performance.
    """
    return _add_headers_to_response(
        response, {"x-litellm-attempted-fallbacks": attempted_fallbacks}
    )
|
||||
@@ -0,0 +1,150 @@
|
||||
import io
|
||||
import json
|
||||
from os import PathLike
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import FileTypes, OpenAIFilesPurpose
|
||||
|
||||
|
||||
class InMemoryFile(io.BytesIO):
    """A BytesIO buffer that mimics an uploaded file: carries a filename and content type."""

    def __init__(
        self, content: bytes, name: str, content_type: str = "application/jsonl"
    ):
        super().__init__(content)
        # Metadata read by HTTP clients when this buffer is sent as a file part.
        self.name = name
        self.content_type = content_type
|
||||
|
||||
|
||||
def parse_jsonl_with_embedded_newlines(content: str) -> List[dict]:
    """
    Parse JSONL content whose JSON objects may span multiple lines.

    A simple splitlines() approach breaks when an object's text contains
    newline characters; instead this accumulates text until the buffer forms
    a complete JSON object, committing it at each newline boundary.

    Args:
        content: The JSONL file content as a string

    Returns:
        List of parsed JSON objects

    Raises:
        json.JSONDecodeError: if trailing content never forms a valid object.

    Example:
        >>> content = '{"id":1,"msg":"Line 1\\nLine 2"}\\n{"id":2,"msg":"test"}'
        >>> parse_jsonl_with_embedded_newlines(content)
        [{"id":1,"msg":"Line 1\\nLine 2"}, {"id":2,"msg":"test"}]
    """
    parsed: List[dict] = []
    pending = ""

    segments = content.split("\n")
    # Every segment except the last was terminated by a newline in `content`.
    for segment in segments[:-1]:
        pending += segment + "\n"
        try:
            parsed.append(json.loads(pending.strip()))
            pending = ""  # object complete; start fresh
        except json.JSONDecodeError:
            # Incomplete object (embedded newline) - keep accumulating.
            continue

    # Whatever is left (final segment plus any unfinished object) must parse.
    pending += segments[-1]
    if pending.strip():
        try:
            parsed.append(json.loads(pending.strip()))
        except json.JSONDecodeError as e:
            verbose_logger.error(
                f"error parsing final buffer: {pending[:100]}..., error: {e}"
            )
            raise e

    return parsed
|
||||
|
||||
|
||||
def should_replace_model_in_jsonl(
    purpose: OpenAIFilesPurpose,
) -> bool:
    """
    Return True when the deployment model name must be written into the JSONL.

    Azure raises an error on create batch if the deployment's model name is
    not present in the .jsonl, so only the "batch" purpose needs rewriting.
    """
    return purpose == "batch"
|
||||
|
||||
|
||||
def replace_model_in_jsonl(file_content: FileTypes, new_model_name: str) -> FileTypes:
    """
    Rewrite each JSONL record's ``body.model`` to ``new_model_name``.

    Accepts path-like objects, file-like objects, (name, bytes) tuples, raw
    bytes, or str. On any parse/decode failure the original content is
    returned unchanged (best-effort behavior — callers rely on never raising
    for malformed content).

    Returns an InMemoryFile with the rewritten records, or the original
    ``file_content`` when rewriting is skipped or fails.
    """
    try:
        ## if pathlike, return the original file content
        if isinstance(file_content, PathLike):
            return file_content

        # Extract raw bytes/str from whichever input form we were given.
        # If file_content is a file-like object, read the bytes
        if hasattr(file_content, "read"):
            file_content_bytes = file_content.read()  # type: ignore
        elif isinstance(file_content, tuple):
            # (filename, content) tuple — the payload is the second element.
            file_content_bytes = file_content[1]
        else:
            file_content_bytes = file_content

        # Decode the bytes to a string; any other type can't be rewritten.
        if isinstance(file_content_bytes, bytes):
            file_content_str = file_content_bytes.decode("utf-8")
        elif isinstance(file_content_bytes, str):
            file_content_str = file_content_bytes
        else:
            return file_content

        # Parse JSONL properly, handling potential multiline JSON objects
        json_objects = parse_jsonl_with_embedded_newlines(file_content_str)

        # If no valid JSON objects were found, return the original content
        if len(json_objects) == 0:
            return file_content

        modified_lines = []
        for json_object in json_objects:
            # Replace the model name if it exists
            if "body" in json_object:
                json_object["body"]["model"] = new_model_name

            # Convert the modified JSON object back to a string
            modified_lines.append(json.dumps(json_object))

        # Reassemble the modified lines and return as bytes
        modified_file_content = "\n".join(modified_lines).encode("utf-8")

        return InMemoryFile(modified_file_content, name="modified_file.jsonl", content_type="application/jsonl")  # type: ignore

    except (json.JSONDecodeError, UnicodeDecodeError, TypeError):
        # return the original file content if there is an error replacing the model name
        return file_content
|
||||
|
||||
|
||||
def _get_router_metadata_variable_name(function_name: Optional[str]) -> str:
|
||||
"""
|
||||
Helper to return what the "metadata" field should be called in the request data
|
||||
|
||||
For all /thread or /assistant endpoints we need to call this "litellm_metadata"
|
||||
|
||||
For ALL other endpoints we call this "metadata
|
||||
"""
|
||||
ROUTER_METHODS_USING_LITELLM_METADATA = set(
|
||||
[
|
||||
"batch",
|
||||
"generic_api_call",
|
||||
"_acreate_batch",
|
||||
"file",
|
||||
"_ageneric_api_call_with_fallbacks",
|
||||
]
|
||||
)
|
||||
if function_name and any(
|
||||
method in function_name for method in ROUTER_METHODS_USING_LITELLM_METADATA
|
||||
):
|
||||
return "litellm_metadata"
|
||||
else:
|
||||
return "metadata"
|
||||
@@ -0,0 +1,37 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from litellm.utils import calculate_max_parallel_requests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
|
||||
|
||||
class InitalizeCachedClient:
    # NOTE(review): class name is misspelled ("Initalize") — kept as-is for
    # backward compatibility with existing callers.

    @staticmethod
    def set_max_parallel_requests_client(
        litellm_router_instance: LitellmRouter, model: dict
    ):
        """
        Compute and cache an asyncio.Semaphore that bounds parallel requests
        for one deployment.

        Reads rpm / tpm / max_parallel_requests from the model's
        litellm_params, derives an effective limit via
        calculate_max_parallel_requests, and stores the semaphore in the
        router's local cache under "<model_id>_max_parallel_requests_client".
        If no limit can be derived, nothing is cached.
        """
        litellm_params = model.get("litellm_params", {})
        # KeyError here is intentional: a deployment must have a model_info id.
        model_id = model["model_info"]["id"]
        rpm = litellm_params.get("rpm", None)
        tpm = litellm_params.get("tpm", None)
        max_parallel_requests = litellm_params.get("max_parallel_requests", None)
        calculated_max_parallel_requests = calculate_max_parallel_requests(
            rpm=rpm,
            max_parallel_requests=max_parallel_requests,
            tpm=tpm,
            default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
        )
        if calculated_max_parallel_requests:
            semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
            cache_key = f"{model_id}_max_parallel_requests_client"
            # local_only: semaphores are process-local objects and must not be
            # pushed to a shared/remote cache.
            litellm_router_instance.cache.set_cache(
                key=cache_key,
                value=semaphore,
                local_only=True,
            )
|
||||
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Utils for handling clientside credentials
|
||||
|
||||
Supported clientside credentials:
|
||||
- api_key
|
||||
- api_base
|
||||
- base_url
|
||||
|
||||
If given, generate a unique model_id for the deployment.
|
||||
|
||||
Ensures cooldowns are applied correctly.
|
||||
"""
|
||||
|
||||
clientside_credential_keys = ["api_key", "api_base", "base_url"]
|
||||
|
||||
|
||||
def is_clientside_credential(request_kwargs: dict) -> bool:
    """
    Return True when the request carries its own credential overrides
    (api_key / api_base / base_url).
    """
    for credential_key in clientside_credential_keys:
        if credential_key in request_kwargs:
            return True
    return False
|
||||
|
||||
|
||||
def get_dynamic_litellm_params(litellm_params: dict, request_kwargs: dict) -> dict:
    """
    Copy clientside credential overrides from the request into litellm_params.

    Returns
    - litellm_params: dict

    The enriched params are used downstream to generate a unique model_id
    for the dynamic deployment (ensuring cooldowns apply correctly).
    """
    # Fold any clientside credentials present in the request into the params.
    litellm_params.update(
        {
            credential_key: request_kwargs[credential_key]
            for credential_key in clientside_credential_keys
            if credential_key in request_kwargs
        }
    )
    return litellm_params
|
||||
@@ -0,0 +1,130 @@
|
||||
import hashlib
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.llms.openai import OpenAIFileObject
|
||||
|
||||
from litellm.types.router import CredentialLiteLLMParams
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
def get_litellm_params_sensitive_credential_hash(litellm_params: dict) -> str:
    """
    Return a sha256 hex digest over the credential-bearing litellm params.

    Used to map a created file id back to the deployment that owns the
    credentials.
    """
    credential_params = CredentialLiteLLMParams(**litellm_params)
    serialized = json.dumps(credential_params.model_dump())
    return hashlib.sha256(serialized.encode()).hexdigest()
|
||||
|
||||
|
||||
def add_model_file_id_mappings(
    healthy_deployments: Union[List[Dict], Dict], responses: List["OpenAIFileObject"]
) -> dict:
    """
    Build a mapping of deployment model id -> created file id:
    {
        "model_id": "file_id",
        ...
    }

    A list of deployments is zipped pairwise with the create-file responses;
    a dict input is treated as an existing model_id -> file_id mapping and
    copied. Any other input yields an empty mapping.
    """
    if isinstance(healthy_deployments, list):
        return {
            deployment.get("model_info", {}).get("id"): response.id
            for deployment, response in zip(healthy_deployments, responses)
        }
    if isinstance(healthy_deployments, dict):
        return {model_id: file_id for model_id, file_id in healthy_deployments.items()}
    return {}
|
||||
|
||||
|
||||
def filter_team_based_models(
    healthy_deployments: Union[List[Dict], Dict],
    request_kwargs: Optional[Dict] = None,
) -> Union[List[Dict], Dict]:
    """
    Drop deployments pinned to a team other than the requesting team.

    Deployments without a team_id are always kept; a single pre-selected
    deployment (dict) and requests without kwargs pass through unfiltered.
    """
    if request_kwargs is None:
        return healthy_deployments
    if isinstance(healthy_deployments, dict):
        return healthy_deployments

    # The team id may arrive under either metadata key, depending on endpoint.
    metadata = request_kwargs.get("metadata") or {}
    litellm_metadata = request_kwargs.get("litellm_metadata") or {}
    request_team_id = metadata.get("user_api_key_team_id") or litellm_metadata.get(
        "user_api_key_team_id"
    )

    # Collect ids of deployments owned by a *different* team.
    ids_to_remove = set()
    for deployment in healthy_deployments:
        info = deployment.get("model_info") or {}
        model_team_id = info.get("team_id")
        if model_team_id is not None and model_team_id != request_team_id:
            ids_to_remove.add(info.get("id"))

    return [
        deployment
        for deployment in healthy_deployments
        if deployment.get("model_info", {}).get("id") not in ids_to_remove
    ]
|
||||
|
||||
|
||||
def _deployment_supports_web_search(deployment: Dict) -> bool:
|
||||
"""
|
||||
Check if a deployment supports web search.
|
||||
|
||||
Priority:
|
||||
1. Check config-level override in model_info.supports_web_search
|
||||
2. Default to True (assume supported unless explicitly disabled)
|
||||
|
||||
Note: Ideally we'd fall back to litellm.supports_web_search() but
|
||||
model_prices_and_context_window.json doesn't have supports_web_search
|
||||
tags on all models yet. TODO: backfill and add fallback.
|
||||
"""
|
||||
model_info = deployment.get("model_info", {})
|
||||
|
||||
if "supports_web_search" in model_info:
|
||||
return model_info["supports_web_search"]
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def filter_web_search_deployments(
    healthy_deployments: Union[List[Dict], Dict],
    request_kwargs: Optional[Dict] = None,
) -> Union[List[Dict], Dict]:
    """
    For web-search requests, keep only deployments that support web search.

    Requests that are not web-search (or have no kwargs) pass through
    unchanged, as does a single pre-selected deployment dict.
    """
    if request_kwargs is None:
        return healthy_deployments
    # When a specific deployment was already chosen, it's returned as a dict
    # rather than a list - nothing to filter, just pass through.
    if isinstance(healthy_deployments, dict):
        return healthy_deployments

    # These are the two websearch tool types for OpenAI / Azure.
    websearch_tool_types = ("web_search", "web_search_preview")
    tools = request_kwargs.get("tools") or []
    if not any(tool.get("type") in websearch_tool_types for tool in tools):
        return healthy_deployments

    # Filter out deployments that don't support web search.
    supported = [
        deployment
        for deployment in healthy_deployments
        if _deployment_supports_web_search(deployment)
    ]
    if healthy_deployments and not supported:
        verbose_logger.warning("No deployments support web search for request")
    return supported
|
||||
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Wrapper around router cache. Meant to handle model cooldown logic
|
||||
"""
|
||||
|
||||
import functools
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class CooldownCacheValue(TypedDict):
    # Masked string form of the exception that triggered the cooldown.
    exception_received: str
    # HTTP-style status code of the exception, stored as a string.
    status_code: str
    # Unix timestamp (time.time()) at which the deployment was cooled down.
    timestamp: float
    # Cooldown duration in seconds (also used as the cache-entry TTL).
    cooldown_time: float
|
||||
|
||||
|
||||
class CooldownCache:
    """
    Wrapper around the router's DualCache that tracks deployments in cooldown.

    Each cooled-down deployment gets one cache entry keyed by
    "deployment:<model_id>:cooldown", whose TTL equals the cooldown duration,
    so entries expire automatically when the cooldown ends.
    """

    def __init__(self, cache: DualCache, default_cooldown_time: float):
        # Shared router cache (in-memory + optional redis).
        self.cache = cache
        # Seconds to cool down when no per-deployment time is given.
        self.default_cooldown_time = default_cooldown_time
        self.in_memory_cache = InMemoryCache()
        # Initialize the masker with custom settings for exception strings —
        # stored exceptions are truncated/masked so credentials in error
        # messages never land in the cache.
        self.exception_masker = SensitiveDataMasker(
            visible_prefix=50,  # Show first 50 characters
            visible_suffix=0,  # Show last 0 characters
            mask_char="*",  # Use * for masking
        )

    def _common_add_cooldown_logic(
        self, model_id: str, original_exception, exception_status, cooldown_time: float
    ) -> Tuple[str, CooldownCacheValue]:
        """Build the (cache_key, CooldownCacheValue) pair for one deployment."""
        try:
            current_time = time.time()
            cooldown_key = CooldownCache.get_cooldown_cache_key(model_id)

            # Store the cooldown information for the deployment separately;
            # the exception string is masked before persisting.
            cooldown_data = CooldownCacheValue(
                exception_received=self.exception_masker._mask_value(
                    str(original_exception)
                ),
                status_code=str(exception_status),
                timestamp=current_time,
                cooldown_time=cooldown_time,
            )

            return cooldown_key, cooldown_data
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        """Put one deployment into cooldown; logs and re-raises on failure."""
        try:
            #########################################################
            # get cooldown time
            # 1. If dynamic cooldown time is set for the model/deployment, use that
            # 2. If no dynamic cooldown time is set, use the default cooldown time set on CooldownCache
            _cooldown_time = cooldown_time
            if _cooldown_time is None:
                _cooldown_time = self.default_cooldown_time
            #########################################################

            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
                model_id=model_id,
                original_exception=original_exception,
                exception_status=exception_status,
                cooldown_time=_cooldown_time,
            )

            # Set the cache with a TTL equal to the cooldown time, so the
            # entry disappears exactly when the cooldown expires.
            self.cache.set_cache(
                value=cooldown_data,
                key=cooldown_key,
                ttl=_cooldown_time,
            )
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    @staticmethod
    @functools.lru_cache(maxsize=1024)
    # lru_cache is safe here: key derivation is a pure function of model_id.
    def get_cooldown_cache_key(model_id: str) -> str:
        """Canonical cache key for a deployment's cooldown entry."""
        return "deployment:" + model_id + ":cooldown"

    async def async_get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        """ASYNC: return (model_id, cooldown_value) pairs still in cooldown."""
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]

        # Retrieve the values for the keys using mget
        ## more likely to be none if no models ratelimited. So just check redis every 1s
        ## each redis call adds ~100ms latency.

        ## check in memory cache first
        results = await self.cache.async_batch_get_cache(
            keys=keys, parent_otel_span=parent_otel_span
        )
        active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []

        # Fast path: nothing cooled down at all.
        if results is None or all(v is None for v in results):
            return active_cooldowns

        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        """Sync variant of async_get_active_cooldowns."""
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]
        # Retrieve the values for the keys using mget
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        active_cooldowns = []
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_min_cooldown(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> float:
        """Return min cooldown time required for a group of model id's."""

        # Generate the keys for the deployments
        # NOTE(review): inline f-string key duplicates get_cooldown_cache_key —
        # keep the two in sync if the key format ever changes.
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]

        # Retrieve the values for the keys using mget
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        min_cooldown_time: Optional[float] = None
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                if min_cooldown_time is None:
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]
                elif cooldown_cache_value["cooldown_time"] < min_cooldown_time:
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]

        # A cooldown_time of 0 (falsy) also falls back to the default here.
        return min_cooldown_time or self.default_cooldown_time
|
||||
|
||||
|
||||
# Usage example:
|
||||
# cooldown_cache = CooldownCache(cache=your_cache_instance, cooldown_time=your_cooldown_time)
|
||||
# cooldown_cache.add_deployment_to_cooldown(deployment, original_exception, exception_status)
|
||||
# active_cooldowns = cooldown_cache.get_active_cooldowns()
|
||||
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Callbacks triggered on cooling down deployments
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
from litellm.integrations.prometheus import PrometheusLogger
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
PrometheusLogger = Any
|
||||
|
||||
|
||||
async def router_cooldown_event_callback(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
    exception_status: Union[str, int],
    cooldown_time: Optional[float],
):
    """
    Callback triggered when a deployment is put into cooldown by litellm

    - Updates deployment state on Prometheus
    - Increments cooldown metric for deployment on Prometheus

    No-op when the deployment can't be found or no PrometheusLogger is
    registered. `cooldown_time` is currently unused here.
    """
    verbose_logger.debug("In router_cooldown_event_callback - updating prometheus")
    _deployment = litellm_router_instance.get_deployment(model_id=deployment_id)
    if _deployment is None:
        verbose_logger.warning(
            f"in router_cooldown_event_callback but _deployment is None for deployment_id={deployment_id}. Doing nothing"
        )
        return
    _litellm_params = _deployment["litellm_params"]
    # Deep copy so get_api_base can't mutate the live deployment params.
    temp_litellm_params = copy.deepcopy(_litellm_params)
    temp_litellm_params = dict(temp_litellm_params)
    _model_name = _deployment.get("model_name", None) or ""
    _api_base = (
        litellm.get_api_base(model=_model_name, optional_params=temp_litellm_params)
        or ""
    )
    model_info = _deployment["model_info"]
    # NOTE(review): attribute access (.id) implies model_info is an object
    # here, while other router utils treat it as a dict — confirm its type.
    model_id = model_info.id

    litellm_model_name = temp_litellm_params.get("model") or ""
    llm_provider = ""
    # Best-effort provider lookup; metrics are still emitted with "" on failure.
    try:
        _, llm_provider, _, _ = litellm.get_llm_provider(
            model=litellm_model_name,
            custom_llm_provider=temp_litellm_params.get("custom_llm_provider"),
        )
    except Exception:
        pass

    # get the prometheus logger from in memory loggers
    prometheusLogger: Optional[
        PrometheusLogger
    ] = _get_prometheus_logger_from_callbacks()

    if prometheusLogger is not None:
        prometheusLogger.set_deployment_complete_outage(
            litellm_model_name=_model_name,
            model_id=model_id,
            api_base=_api_base,
            api_provider=llm_provider,
        )

        prometheusLogger.increment_deployment_cooled_down(
            litellm_model_name=_model_name,
            model_id=model_id,
            api_base=_api_base,
            api_provider=llm_provider,
            exception_status=str(exception_status),
        )

    return
|
||||
|
||||
|
||||
def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]:
    """
    Return the initialized PrometheusLogger callback, if one is registered.

    Async success callbacks are searched before the global callback list.
    """
    from litellm.integrations.prometheus import PrometheusLogger

    if PrometheusLogger is None:
        return None

    for candidate in (*litellm._async_success_callback, *litellm.callbacks):
        if isinstance(candidate, PrometheusLogger):
            return candidate

    return None
|
||||
@@ -0,0 +1,459 @@
|
||||
"""
|
||||
Router cooldown handlers
|
||||
- _set_cooldown_deployments: puts a deployment in the cooldown list
|
||||
- get_cooldown_deployments: returns the list of deployments in the cooldown list
|
||||
- async_get_cooldown_deployments: ASYNC: returns the list of deployments in the cooldown list
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.constants import (
|
||||
DEFAULT_COOLDOWN_TIME_SECONDS,
|
||||
DEFAULT_FAILURE_THRESHOLD_MINIMUM_REQUESTS,
|
||||
DEFAULT_FAILURE_THRESHOLD_PERCENT,
|
||||
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD,
|
||||
)
|
||||
from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
|
||||
|
||||
from .router_callbacks.track_deployment_metrics import (
|
||||
get_deployment_failures_for_current_minute,
|
||||
get_deployment_successes_for_current_minute,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
def _is_cooldown_required(
|
||||
litellm_router_instance: LitellmRouter,
|
||||
model_id: str,
|
||||
exception_status: Union[str, int],
|
||||
exception_str: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
A function to determine if a cooldown is required based on the exception status.
|
||||
|
||||
Parameters:
|
||||
model_id (str) The id of the model in the model list
|
||||
exception_status (Union[str, int]): The status of the exception.
|
||||
|
||||
Returns:
|
||||
bool: True if a cooldown is required, False otherwise.
|
||||
"""
|
||||
try:
|
||||
ignored_strings = ["APIConnectionError"]
|
||||
if (
|
||||
exception_str is not None
|
||||
): # don't cooldown on litellm api connection errors errors
|
||||
for ignored_string in ignored_strings:
|
||||
if ignored_string in exception_str:
|
||||
return False
|
||||
|
||||
if isinstance(exception_status, str):
|
||||
if len(exception_status) == 0:
|
||||
return False
|
||||
exception_status = int(exception_status)
|
||||
|
||||
if exception_status >= 400 and exception_status < 500:
|
||||
if exception_status == 429:
|
||||
# Cool down 429 Rate Limit Errors
|
||||
return True
|
||||
|
||||
elif exception_status == 401:
|
||||
# Cool down 401 Auth Errors
|
||||
return True
|
||||
|
||||
elif exception_status == 408:
|
||||
return True
|
||||
|
||||
elif exception_status == 404:
|
||||
return True
|
||||
|
||||
else:
|
||||
# Do NOT cool down all other 4XX Errors
|
||||
return False
|
||||
|
||||
else:
|
||||
# should cool down for all other errors
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
# Catch all - if any exceptions default to cooling down
|
||||
return True
|
||||
|
||||
|
||||
def _should_run_cooldown_logic(
    litellm_router_instance: LitellmRouter,
    deployment: Optional[str],
    exception_status: Union[str, int],
    original_exception: Any,
    time_to_cooldown: Optional[float] = None,
) -> bool:
    """
    Helper that decides if cooldown logic should be run
    Returns False if cooldown logic should not be run

    Does not run cooldown logic when:
    - router.disable_cooldowns is True
    - deployment is None
    - time_to_cooldown is effectively 0
    - _is_cooldown_required() returns False
    - deployment is in litellm_router_instance.provider_default_deployment_ids
    - exception_status is not one that should be immediately retried (e.g. 401)
    """
    # Guard: no deployment id, or the deployment no longer maps to a model group.
    if (
        deployment is None
        or litellm_router_instance.get_model_group(id=deployment) is None
    ):
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: deployment id is none or model group can't be found."
        )
        return False

    #########################################################
    # If time_to_cooldown is 0 or 0.0000000, don't run cooldown logic
    # (isclose handles float noise around zero)
    #########################################################
    if time_to_cooldown is not None and math.isclose(
        a=time_to_cooldown, b=0.0, abs_tol=1e-9
    ):
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: time_to_cooldown is effectively 0"
        )
        return False

    if litellm_router_instance.disable_cooldowns:
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: disable_cooldowns is True"
        )
        return False

    # NOTE(review): redundant — deployment is None was already handled by the
    # first guard above; kept for byte-compatibility.
    if deployment is None:
        verbose_router_logger.debug("Should Not Run Cooldown Logic: deployment is None")
        return False

    if not _is_cooldown_required(
        litellm_router_instance=litellm_router_instance,
        model_id=deployment,
        exception_status=exception_status,
        exception_str=str(original_exception),
    ):
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: _is_cooldown_required returned False"
        )
        return False

    # Provider-default deployments are never cooled down.
    if deployment in litellm_router_instance.provider_default_deployment_ids:
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: deployment is in provider_default_deployment_ids"
        )
        return False

    return True
|
||||
|
||||
|
||||
def _should_cooldown_deployment(
    litellm_router_instance: LitellmRouter,
    deployment: str,
    exception_status: Union[str, int],
    original_exception: Any,
) -> bool:
    """
    Helper that decides if a deployment should be put in cooldown.

    Returns True if the deployment should be put in cooldown,
    False otherwise.

    v2 logic (current) — used when no allowed-fails policy / custom
    allowed_fails is configured on the router. Cooldown if:
        - got a 429 error from the LLM API (multi-deployment groups only)
        - all of this minute's requests failed and traffic is at least
          SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD
        - %fails exceeds DEFAULT_FAILURE_THRESHOLD_PERCENT with at least
          DEFAULT_FAILURE_THRESHOLD_MINIMUM_REQUESTS requests
          (multi-deployment groups only)
        - got a non-retryable error (e.g. 401 Auth, 404 NotFound) — checked
          via litellm._should_retry()

    v1 logic (legacy) — used when an allowed-fails policy or custom
    allowed_fails is set: cooldown once fails this minute exceed the
    allowed fails (see should_cooldown_based_on_allowed_fails_policy).
    """
    ## BASE CASE - single deployment: avoid cooling down a model group's only deployment
    model_group = litellm_router_instance.get_model_group(id=deployment)
    is_single_deployment_model_group = False
    if model_group is not None and len(model_group) == 1:
        is_single_deployment_model_group = True
    if (
        litellm_router_instance.allowed_fails_policy is None
        and _is_allowed_fails_set_on_router(
            litellm_router_instance=litellm_router_instance
        )
        is False
    ):
        num_successes_this_minute = get_deployment_successes_for_current_minute(
            litellm_router_instance=litellm_router_instance, deployment_id=deployment
        )
        num_fails_this_minute = get_deployment_failures_for_current_minute(
            litellm_router_instance=litellm_router_instance, deployment_id=deployment
        )

        total_requests_this_minute = num_successes_this_minute + num_fails_this_minute
        percent_fails = 0.0
        if total_requests_this_minute > 0:
            # fix: reuse the already-computed total instead of re-summing
            percent_fails = num_fails_this_minute / total_requests_this_minute
        verbose_router_logger.debug(
            "percent fails for deployment = %s, percent fails = %s, num successes = %s, num fails = %s",
            deployment,
            percent_fails,
            num_successes_this_minute,
            num_fails_this_minute,
        )

        exception_status_int = cast_exception_status_to_int(exception_status)
        if exception_status_int == 429 and not is_single_deployment_model_group:
            return True
        elif (
            percent_fails == 1.0
            and total_requests_this_minute
            >= SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD
        ):
            # Cooldown if all requests failed and we have reasonable traffic
            return True
        elif (
            percent_fails > DEFAULT_FAILURE_THRESHOLD_PERCENT
            and total_requests_this_minute >= DEFAULT_FAILURE_THRESHOLD_MINIMUM_REQUESTS
            and not is_single_deployment_model_group  # by default we should avoid cooldowns on single deployment model groups
        ):
            # Only apply error rate cooldown when we have enough requests to make the percentage meaningful
            return True
        elif (
            # fix: reuse exception_status_int instead of casting a second time
            litellm._should_retry(status_code=exception_status_int)
            is False
        ):
            return True

        return False
    else:
        return should_cooldown_based_on_allowed_fails_policy(
            litellm_router_instance=litellm_router_instance,
            deployment=deployment,
            original_exception=original_exception,
        )
    # fix: removed unreachable trailing `return False` — both branches above return
|
||||
|
||||
|
||||
def _set_cooldown_deployments(
    litellm_router_instance: LitellmRouter,
    original_exception: Any,
    exception_status: Union[str, int],
    deployment: Optional[str] = None,
    time_to_cooldown: Optional[float] = None,
) -> bool:
    """
    Add a model to the list of models being cooled down for that minute, if it
    exceeds the allowed fails / minute

    or

    the exception is not one that should be immediately retried (e.g. 401)

    Args:
        litellm_router_instance: Router whose cooldown cache is updated.
        original_exception: The exception raised by the failed call.
        exception_status: HTTP status (str or int) of the failure.
        deployment: Model id of the deployment that failed (may be None).
        time_to_cooldown: Optional explicit cooldown duration in seconds.

    Returns:
        - True if the deployment was put in cooldown
        - False if the deployment was not put in cooldown
    """
    verbose_router_logger.debug("checks 'should_run_cooldown_logic'")

    # Gate: bail out early when cooldown logic should not run at all
    # (cooldowns disabled, unknown deployment, ~0 cooldown time, etc. —
    # see _should_run_cooldown_logic). `deployment is None` is re-checked
    # here so the narrowing holds for the calls below.
    if (
        _should_run_cooldown_logic(
            litellm_router_instance=litellm_router_instance,
            deployment=deployment,
            exception_status=exception_status,
            original_exception=original_exception,
            time_to_cooldown=time_to_cooldown,
        )
        is False
        or deployment is None
    ):
        verbose_router_logger.debug("should_run_cooldown_logic returned False")
        return False

    exception_status_int = cast_exception_status_to_int(exception_status)
    verbose_router_logger.debug(f"Attempting to add {deployment} to cooldown list")

    if _should_cooldown_deployment(
        litellm_router_instance=litellm_router_instance,
        deployment=deployment,
        exception_status=exception_status,
        original_exception=original_exception,
    ):
        litellm_router_instance.cooldown_cache.add_deployment_to_cooldown(
            model_id=deployment,
            original_exception=original_exception,
            exception_status=exception_status_int,
            cooldown_time=time_to_cooldown,
        )

        # Trigger cooldown callback handler (fire-and-forget).
        # NOTE(review): asyncio.create_task requires a running event loop —
        # presumably this sync function is only called from async contexts;
        # confirm against callers.
        asyncio.create_task(
            router_cooldown_event_callback(
                litellm_router_instance=litellm_router_instance,
                deployment_id=deployment,
                exception_status=exception_status,
                cooldown_time=time_to_cooldown,
            )
        )
        return True
    return False
|
||||
|
||||
|
||||
async def _async_get_cooldown_deployments(
    litellm_router_instance: LitellmRouter,
    parent_otel_span: Optional[Span],
) -> List[str]:
    """
    Async implementation of '_get_cooldown_deployments'.

    Fetches the currently-active cooldowns from the router's cooldown cache
    and returns just the deployment ids.
    """
    active_cooldowns = (
        await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
            model_ids=litellm_router_instance.get_model_ids(),
            parent_otel_span=parent_otel_span,
        )
    )

    # The cache returns a list of (deployment_id, debug_info) tuples;
    # any other shape yields an empty result.
    deployment_ids: List[str] = []
    if (
        isinstance(active_cooldowns, list)
        and active_cooldowns
        and isinstance(active_cooldowns[0], tuple)
    ):
        deployment_ids = [entry[0] for entry in active_cooldowns]

    verbose_router_logger.debug(f"retrieve cooldown models: {active_cooldowns}")
    return deployment_ids
|
||||
|
||||
|
||||
async def _async_get_cooldown_deployments_with_debug_info(
    litellm_router_instance: LitellmRouter,
    parent_otel_span: Optional[Span],
) -> List[tuple]:
    """
    Async variant of '_get_cooldown_deployments' that keeps debug info.

    Unlike _async_get_cooldown_deployments, the raw entries from the
    cooldown cache are returned unmodified.
    """
    all_model_ids = litellm_router_instance.get_model_ids()
    active_cooldowns = (
        await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
            model_ids=all_model_ids, parent_otel_span=parent_otel_span
        )
    )

    verbose_router_logger.debug(f"retrieve cooldown models: {active_cooldowns}")
    return active_cooldowns
|
||||
|
||||
|
||||
def _get_cooldown_deployments(
    litellm_router_instance: LitellmRouter, parent_otel_span: Optional[Span]
) -> List[str]:
    """
    Get the list of models being cooled down for this minute.

    Sync counterpart of '_async_get_cooldown_deployments'.
    """
    # ----------------------
    # Return cooldown models
    # ----------------------
    active_cooldowns = litellm_router_instance.cooldown_cache.get_active_cooldowns(
        model_ids=litellm_router_instance.get_model_ids(),
        parent_otel_span=parent_otel_span,
    )

    # Cache entries are (deployment_id, debug_info) tuples; any other
    # shape results in an empty list.
    if (
        isinstance(active_cooldowns, list)
        and active_cooldowns
        and isinstance(active_cooldowns[0], tuple)
    ):
        return [entry[0] for entry in active_cooldowns]
    return []
|
||||
|
||||
|
||||
def should_cooldown_based_on_allowed_fails_policy(
    litellm_router_instance: LitellmRouter,
    deployment: str,
    original_exception: Any,
) -> bool:
    """
    Check if fails are within the allowed limit and update the number of fails.

    Returns:
        - True if fails exceed the allowed limit (should cooldown)
        - False if fails are within the allowed limit (should not cooldown)
    """
    policy_allowed_fails = litellm_router_instance.get_allowed_fails_from_policy(
        exception=original_exception,
    )
    allowed_fails = policy_allowed_fails or litellm_router_instance.allowed_fails
    cooldown_ttl = (
        litellm_router_instance.cooldown_time or DEFAULT_COOLDOWN_TIME_SECONDS
    )

    previous_fails = litellm_router_instance.failed_calls.get_cache(key=deployment) or 0
    new_fail_count = previous_fails + 1

    if new_fail_count > allowed_fails:
        return True

    # Still within budget: persist the incremented count; the entry
    # expires with the cooldown window.
    litellm_router_instance.failed_calls.set_cache(
        key=deployment, value=new_fail_count, ttl=cooldown_ttl
    )
    return False
|
||||
|
||||
|
||||
def _is_allowed_fails_set_on_router(
    litellm_router_instance: LitellmRouter,
) -> bool:
    """
    Report whether Router.allowed_fails was explicitly configured.

    Returns:
        - True when allowed_fails is set and differs from the litellm
          module-level default
        - False when it is None or equals the default
    """
    router_allowed_fails = litellm_router_instance.allowed_fails
    if router_allowed_fails is None:
        return False
    return router_allowed_fails != litellm.allowed_fails
|
||||
|
||||
|
||||
def cast_exception_status_to_int(exception_status: Union[str, int]) -> int:
    """
    Coerce an exception status to an int.

    Strings that cannot be parsed default to 500 (with a debug log);
    int inputs are returned unchanged.
    """
    if not isinstance(exception_status, str):
        return exception_status
    try:
        return int(exception_status)
    except Exception:
        verbose_router_logger.debug(
            f"Unable to cast exception status to int {exception_status}. Defaulting to status=500."
        )
        return 500
|
||||
@@ -0,0 +1,257 @@
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.router_utils.add_retry_fallback_headers import (
|
||||
add_fallback_headers_to_response,
|
||||
)
|
||||
from litellm.types.router import LiteLLMParamsTypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
|
||||
|
||||
def _check_stripped_model_group(model_group: str, fallback_key: str) -> bool:
    """
    Handles the wildcard-routing scenario:

    fallbacks are configured like
        [{"gpt-3.5-turbo": ["claude-3-haiku"]}]
    but the incoming model_group is like
        "openai/gpt-3.5-turbo"

    Returns:
        - True if the provider-stripped model group == fallback_key
    """
    for provider in litellm.provider_list:
        provider_name = provider.value if isinstance(provider, Enum) else provider
        prefix = f"{provider_name}/"
        if model_group.startswith(prefix) and (
            model_group.replace(prefix, "") == fallback_key
        ):
            return True
    return False
|
||||
|
||||
|
||||
def get_fallback_model_group(
    fallbacks: List[Any], model_group: str
) -> Tuple[Optional[List[str]], Optional[int]]:
    """
    Resolve which fallback model group applies to `model_group`.

    Match priority:
        1. exact match on the model group key
        2. stripped (wildcard-routing) model group match
        3. generic "*" fallback

    Args:
        fallbacks: Fallback config entries; each is either a
            {model_group: [fallback, ...]} dict or a bare model-name string.
        model_group: The model group that failed.

    Returns:
        - fallback_model_group: List[str] of fallback model groups.
          example: ["gpt-4", "gpt-3.5-turbo"]
        - generic_fallback_idx: index of the generic "*" entry in the
          fallbacks list, if present (None otherwise).
    """
    generic_fallback_idx: Optional[int] = None
    stripped_model_fallback: Optional[List[str]] = None
    fallback_model_group: Optional[List[str]] = None
    ## check for specific model group-specific fallbacks
    for idx, item in enumerate(fallbacks):
        if isinstance(item, dict):
            fallback_key = list(item.keys())[0]
            if fallback_key == model_group:  # check exact match
                fallback_model_group = item[model_group]
                break
            elif _check_stripped_model_group(
                model_group=model_group, fallback_key=fallback_key
            ):  # check wildcard-routing (provider-stripped) match
                stripped_model_fallback = item[fallback_key]
            elif fallback_key == "*":  # remember generic fallback position
                generic_fallback_idx = idx
        elif isinstance(item, str):
            # fix: previously `fallbacks.pop(idx)` mutated the caller's
            # fallback config while iterating, which skipped the next entry
            # and could overwrite the match; take the value and stop instead.
            fallback_model_group = [item]  # returns single-item list
            break
    ## if none, check for generic fallback
    if fallback_model_group is None:
        if stripped_model_fallback is not None:
            fallback_model_group = stripped_model_fallback
        elif generic_fallback_idx is not None:
            fallback_model_group = fallbacks[generic_fallback_idx]["*"]

    return fallback_model_group, generic_fallback_idx
|
||||
|
||||
|
||||
async def run_async_fallback(
    *args: Tuple[Any],
    litellm_router: LitellmRouter,
    fallback_model_group: List[str],
    original_model_group: str,
    original_exception: Exception,
    max_fallbacks: int,
    fallback_depth: int,
    **kwargs,
) -> Any:
    """
    Loops through all the fallback model groups and calls kwargs["original_function"] with the arguments and keyword arguments provided.

    If the call is successful, it logs the success and returns the response.
    If the call fails, it logs the failure and continues to the next fallback model group.
    If all fallback model groups fail, it raises the most recent exception.

    Args:
        litellm_router: The litellm router instance.
        *args: Positional arguments.
        fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
        original_model_group: The original model group. example: "gpt-3.5-turbo"
        original_exception: The original exception.
        max_fallbacks: Maximum recursion depth for nested fallbacks.
        fallback_depth: Current recursion depth (incremented per attempt).
        **kwargs: Keyword arguments.

    Returns:
        The response from the successful fallback model group.
    Raises:
        The most recent exception if all fallback model groups fail.
    """

    ### BASE CASE ### MAX FALLBACK DEPTH REACHED
    if fallback_depth >= max_fallbacks:
        raise original_exception

    error_from_fallbacks = original_exception

    for mg in fallback_model_group:
        # never "fall back" to the model group that already failed
        if mg == original_model_group:
            continue
        try:
            # LOGGING
            kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception)
            verbose_router_logger.info(f"Falling back to model_group = {mg}")
            # a fallback entry can be a bare model-group name or a dict of
            # override kwargs to merge into the request
            if isinstance(mg, str):
                kwargs["model"] = mg
            elif isinstance(mg, dict):
                kwargs.update(mg)
            kwargs.setdefault("metadata", {}).update(
                {"model_group": kwargs.get("model", None)}
            )  # update model_group used, if fallbacks are done
            # thread the incremented depth through so nested fallbacks
            # terminate at max_fallbacks (see base case above)
            fallback_depth = fallback_depth + 1
            kwargs["fallback_depth"] = fallback_depth
            kwargs["max_fallbacks"] = max_fallbacks
            response = await litellm_router.async_function_with_fallbacks(
                *args, **kwargs
            )
            verbose_router_logger.info("Successful fallback b/w models.")
            # expose x-litellm-attempted-fallbacks in the response's hidden params
            response = add_fallback_headers_to_response(
                response=response,
                attempted_fallbacks=fallback_depth,
            )
            # callback for successful_fallback_event():
            await log_success_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
            return response
        except Exception as e:
            # remember the most recent failure; it is what gets raised if
            # every fallback model group fails
            error_from_fallbacks = e
            await log_failure_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
    raise error_from_fallbacks
|
||||
|
||||
|
||||
async def log_success_fallback_event(
    original_model_group: str, kwargs: dict, original_exception: Exception
):
    """
    Notify every registered CustomLogger that a fallback succeeded.

    Loggers are fetched (deduplicated across all callback lists) via
    LoggingCallbackManager.get_custom_loggers_for_type().

    Args:
        original_model_group (str): The original model group before fallback.
        kwargs (dict): kwargs for the request
        original_exception (Exception): The exception that triggered fallback.

    Note:
        Errors during logging are caught and reported but do not interrupt the process.
    """
    for custom_logger in litellm.logging_callback_manager.get_custom_loggers_for_type(
        CustomLogger
    ):
        try:
            await custom_logger.log_success_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
        except Exception as e:
            verbose_router_logger.error(
                f"Error in log_success_fallback_event: {str(e)}"
            )
|
||||
|
||||
|
||||
async def log_failure_fallback_event(
    original_model_group: str, kwargs: dict, original_exception: Exception
):
    """
    Notify every registered CustomLogger that a fallback attempt failed.

    Loggers are fetched (deduplicated across all callback lists) via
    LoggingCallbackManager.get_custom_loggers_for_type().

    Args:
        original_model_group (str): The original model group before fallback.
        kwargs (dict): kwargs for the request
        original_exception (Exception): The exception that triggered fallback.

    Note:
        Errors during logging are caught and reported but do not interrupt the process.
    """
    for custom_logger in litellm.logging_callback_manager.get_custom_loggers_for_type(
        CustomLogger
    ):
        try:
            await custom_logger.log_failure_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
        except Exception as e:
            verbose_router_logger.error(
                f"Error in log_failure_fallback_event: {str(e)}"
            )
|
||||
|
||||
|
||||
def _check_non_standard_fallback_format(fallbacks: Optional[List[Any]]) -> bool:
|
||||
"""
|
||||
Checks if the fallbacks list is a list of strings or a list of dictionaries.
|
||||
|
||||
If
|
||||
- List[str]: e.g. ["claude-3-haiku", "openai/o-1"]
|
||||
- List[Dict[<LiteLLMParamsTypedDict>, Any]]: e.g. [{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}]
|
||||
|
||||
If [{"gpt-3.5-turbo": ["claude-3-haiku"]}] then standard format.
|
||||
"""
|
||||
if fallbacks is None or not isinstance(fallbacks, list) or len(fallbacks) == 0:
|
||||
return False
|
||||
if all(isinstance(item, str) for item in fallbacks):
|
||||
return True
|
||||
elif all(isinstance(item, dict) for item in fallbacks):
|
||||
for key in LiteLLMParamsTypedDict.__annotations__.keys():
|
||||
if key in fallbacks[0].keys():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def run_non_standard_fallback_format(
    fallbacks: Union[List[str], List[Dict[str, Any]]], model_group: str
):
    """
    Handle fallbacks given in a non-standard format
    (see _check_non_standard_fallback_format).

    NOTE(review): currently an unimplemented stub — intentionally a no-op.
    """
    pass
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Get num retries for an exception.
|
||||
|
||||
- Account for retry policy by exception type.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from litellm.exceptions import (
|
||||
AuthenticationError,
|
||||
BadRequestError,
|
||||
ContentPolicyViolationError,
|
||||
RateLimitError,
|
||||
Timeout,
|
||||
)
|
||||
from litellm.types.router import RetryPolicy
|
||||
|
||||
|
||||
def get_num_retries_from_retry_policy(
    exception: Exception,
    retry_policy: Optional[Union[RetryPolicy, dict]] = None,
    model_group: Optional[str] = None,
    model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
):
    """
    Resolve the retry count for `exception` from the applicable retry policy.

    A model-group-specific policy (model_group_retry_policy) takes precedence
    over the generic `retry_policy`. Policy fields consulted:

        BadRequestErrorRetries: Optional[int] = None
        AuthenticationErrorRetries: Optional[int] = None
        TimeoutErrorRetries: Optional[int] = None
        RateLimitErrorRetries: Optional[int] = None
        ContentPolicyViolationErrorRetries: Optional[int] = None

    Returns the configured retry count for the exception's type, or None
    (implicitly) when no policy applies or no matching field is set.
    """
    # if we can find the exception then in the retry policy -> return the number of retries
    if (
        model_group_retry_policy is not None
        and model_group is not None
        and model_group in model_group_retry_policy
    ):
        retry_policy = model_group_retry_policy.get(model_group, None)  # type: ignore

    if retry_policy is None:
        return None
    if isinstance(retry_policy, dict):
        # allow plain-dict policies (e.g. loaded from config) to be coerced
        retry_policy = RetryPolicy(**retry_policy)

    # NOTE(review): the order of these isinstance checks matters if any of
    # these exception types subclass others (more specific errors must be
    # tested before broader ones like BadRequestError) — confirm against
    # litellm.exceptions' class hierarchy before reordering.
    if (
        isinstance(exception, AuthenticationError)
        and retry_policy.AuthenticationErrorRetries is not None
    ):
        return retry_policy.AuthenticationErrorRetries
    if isinstance(exception, Timeout) and retry_policy.TimeoutErrorRetries is not None:
        return retry_policy.TimeoutErrorRetries
    if (
        isinstance(exception, RateLimitError)
        and retry_policy.RateLimitErrorRetries is not None
    ):
        return retry_policy.RateLimitErrorRetries
    if (
        isinstance(exception, ContentPolicyViolationError)
        and retry_policy.ContentPolicyViolationErrorRetries is not None
    ):
        return retry_policy.ContentPolicyViolationErrorRetries
    if (
        isinstance(exception, BadRequestError)
        and retry_policy.BadRequestErrorRetries is not None
    ):
        return retry_policy.BadRequestErrorRetries
|
||||
|
||||
|
||||
def reset_retry_policy() -> RetryPolicy:
    """Return a default-constructed RetryPolicy (no per-exception overrides)."""
    return RetryPolicy()
|
||||
@@ -0,0 +1,95 @@
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.constants import MAX_EXCEPTION_MESSAGE_LENGTH
|
||||
from litellm.router_utils.cooldown_handlers import (
|
||||
_async_get_cooldown_deployments_with_debug_info,
|
||||
)
|
||||
from litellm.types.integrations.slack_alerting import AlertType
|
||||
from litellm.types.router import RouterRateLimitError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
async def send_llm_exception_alert(
    litellm_router_instance: LitellmRouter,
    request_kwargs: dict,
    error_traceback_str: str,
    original_exception,
):
    """
    Send a Slack / MS Teams alert for an LLM API call failure.

    No-op unless litellm_router_instance.slack_alerting_logger is set.
    Requests that originate from the litellm proxy server are skipped,
    because the proxy is already instrumented to alert on LLM API failures.

    Parameters:
        litellm_router_instance (_Router): The LitellmRouter instance.
        request_kwargs (dict): kwargs of the failed request.
        error_traceback_str (str): Formatted traceback (truncated in the alert).
        original_exception (Any): The original exception that occurred.

    Returns:
        None
    """
    if litellm_router_instance is None:
        return
    # merged hasattr + None check: absent attribute behaves like None
    if getattr(litellm_router_instance, "slack_alerting_logger", None) is None:
        return
    if "proxy_server_request" in request_kwargs:
        # Do not send any alert if it's a request from litellm proxy server request
        # the proxy is already instrumented to send LLM API call failures
        return

    exception_str = str(original_exception)
    litellm_debug_info = getattr(original_exception, "litellm_debug_info", None)
    if litellm_debug_info is not None:
        exception_str += litellm_debug_info
    exception_str += f"\n\n{error_traceback_str[:MAX_EXCEPTION_MESSAGE_LENGTH]}"

    await litellm_router_instance.slack_alerting_logger.send_alert(
        message=f"LLM API call failed: `{exception_str}`",
        level="High",
        alert_type=AlertType.llm_exceptions,
        alerting_metadata={},
    )
|
||||
|
||||
|
||||
async def async_raise_no_deployment_exception(
    litellm_router_instance: LitellmRouter, model: str, parent_otel_span: Optional[Span]
):
    """
    Build a RouterRateLimitError describing why no deployment is available
    for `model`.

    NOTE(review): despite the name, this function RETURNS the error rather
    than raising it — callers are expected to `raise` the returned exception.
    """
    verbose_router_logger.info(
        f"get_available_deployment for model: {model}, No deployment available"
    )
    model_ids = litellm_router_instance.get_model_ids(model_name=model)
    # minimum remaining cooldown across this model's deployments — surfaced
    # in the error so the caller can retry after it expires
    _cooldown_time = litellm_router_instance.cooldown_cache.get_min_cooldown(
        model_ids=model_ids, parent_otel_span=parent_otel_span
    )
    _cooldown_list = await _async_get_cooldown_deployments_with_debug_info(
        litellm_router_instance=litellm_router_instance,
        parent_otel_span=parent_otel_span,
    )
    verbose_router_logger.info(
        f"No deployment found for model: {model}, cooldown_list with debug info: {_cooldown_list}"
    )

    # _cooldown_list entries are (deployment_id, debug_info) tuples; keep ids only
    cooldown_list_ids = [cooldown_model[0] for cooldown_model in (_cooldown_list or [])]
    return RouterRateLimitError(
        model=model,
        cooldown_time=_cooldown_time,
        enable_pre_call_checks=litellm_router_instance.enable_pre_call_checks,
        cooldown_list=cooldown_list_ids,
    )
|
||||
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Class to handle llm wildcard routing and regex pattern matching
|
||||
"""
|
||||
|
||||
import copy
|
||||
import re
|
||||
from re import Match
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm._logging import verbose_router_logger
|
||||
|
||||
|
||||
class PatternUtils:
    @staticmethod
    def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]:
        """
        Score a regex pattern's specificity for sorting.

        Longer patterns, and patterns containing more regex metacharacters,
        are considered more specific.

        Args:
            pattern: Regex pattern to analyze

        Returns:
            Tuple of (length, metacharacter count) for sorting
        """
        special_chars = ["*", "+", "?", "\\", "^", "$", "|", "(", ")"]
        meta_count = sum(pattern.count(ch) for ch in special_chars)
        return (len(pattern), meta_count)

    @staticmethod
    def sorted_patterns(
        patterns: Dict[str, List[Dict]]
    ) -> List[Tuple[str, List[Dict]]]:
        """
        Return (pattern, deployments) pairs ordered most-specific first.

        Returns:
            Sorted list of pattern-deployment tuples
        """
        return sorted(
            patterns.items(),
            key=lambda entry: PatternUtils.calculate_pattern_specificity(entry[0]),
            reverse=True,
        )
|
||||
|
||||
|
||||
class PatternMatchRouter:
|
||||
"""
|
||||
Class to handle llm wildcard routing and regex pattern matching
|
||||
|
||||
doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing
|
||||
|
||||
This class will store a mapping for regex pattern: List[Deployments]
|
||||
"""
|
||||
|
||||
    def __init__(self):
        # Maps regex pattern -> list of llm deployment dicts registered
        # under that pattern (populated via add_pattern).
        self.patterns: Dict[str, List] = {}
|
||||
|
||||
def add_pattern(self, pattern: str, llm_deployment: Dict):
|
||||
"""
|
||||
Add a regex pattern and the corresponding llm deployments to the patterns
|
||||
|
||||
Args:
|
||||
pattern: str
|
||||
llm_deployment: str or List[str]
|
||||
"""
|
||||
# Convert the pattern to a regex
|
||||
regex = self._pattern_to_regex(pattern)
|
||||
if regex not in self.patterns:
|
||||
self.patterns[regex] = []
|
||||
self.patterns[regex].append(llm_deployment)
|
||||
|
||||
def _pattern_to_regex(self, pattern: str) -> str:
|
||||
"""
|
||||
Convert a wildcard pattern to a regex pattern
|
||||
|
||||
example:
|
||||
pattern: openai/*
|
||||
regex: openai/.*
|
||||
|
||||
pattern: openai/fo::*::static::*
|
||||
regex: openai/fo::.*::static::.*
|
||||
|
||||
Args:
|
||||
pattern: str
|
||||
|
||||
Returns:
|
||||
str: regex pattern
|
||||
"""
|
||||
# # Replace '*' with '.*' for regex matching
|
||||
# regex = pattern.replace("*", ".*")
|
||||
# # Escape other special characters
|
||||
# regex = re.escape(regex).replace(r"\.\*", ".*")
|
||||
# return f"^{regex}$"
|
||||
return re.escape(pattern).replace(r"\*", "(.*)")
|
||||
|
||||
def _return_pattern_matched_deployments(
|
||||
self, matched_pattern: Match, deployments: List[Dict]
|
||||
) -> List[Dict]:
|
||||
new_deployments = []
|
||||
for deployment in deployments:
|
||||
new_deployment = copy.deepcopy(deployment)
|
||||
new_deployment["litellm_params"][
|
||||
"model"
|
||||
] = PatternMatchRouter.set_deployment_model_name(
|
||||
matched_pattern=matched_pattern,
|
||||
litellm_deployment_litellm_model=deployment["litellm_params"]["model"],
|
||||
)
|
||||
new_deployments.append(new_deployment)
|
||||
|
||||
return new_deployments
|
||||
|
||||
    def route(
        self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
    ) -> Optional[List[Dict]]:
        """
        Route a requested model to the corresponding llm deployments based on the regex pattern

        loop through all the patterns and find the matching pattern
        if a pattern is found, return the corresponding llm deployments
        if no pattern is found, return None

        Args:
            request: str - the received model name from the user (can be a wildcard route). If none, No deployments will be returned.
            filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
        Returns:
            Optional[List[Deployment]]: llm deployments
        """
        try:
            if request is None:
                return None

            # most-specific patterns first, so e.g. "openai/gpt-*" beats "openai/*"
            sorted_patterns = PatternUtils.sorted_patterns(self.patterns)
            # filtered_model_names are wildcard patterns too — convert them to
            # regex form so they can be compared against the stored keys
            regex_filtered_model_names = (
                [self._pattern_to_regex(m) for m in filtered_model_names]
                if filtered_model_names is not None
                else []
            )
            for pattern, llm_deployments in sorted_patterns:
                if (
                    filtered_model_names is not None
                    and pattern not in regex_filtered_model_names
                ):
                    continue
                pattern_match = re.match(pattern, request)
                if pattern_match:
                    # first (most specific) match wins
                    return self._return_pattern_matched_deployments(
                        matched_pattern=pattern_match, deployments=llm_deployments
                    )
        except Exception as e:
            # routing is best-effort: on any error fall through to "no match"
            verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")

        return None  # No matching pattern found
|
||||
|
||||
@staticmethod
|
||||
def set_deployment_model_name(
|
||||
matched_pattern: Match,
|
||||
litellm_deployment_litellm_model: str,
|
||||
) -> str:
|
||||
"""
|
||||
Set the model name for the matched pattern llm deployment
|
||||
|
||||
E.g.:
|
||||
|
||||
Case 1:
|
||||
model_name: llmengine/* (can be any regex pattern or wildcard pattern)
|
||||
litellm_params:
|
||||
model: openai/*
|
||||
|
||||
if model_name = "llmengine/foo" -> model = "openai/foo"
|
||||
|
||||
Case 2:
|
||||
model_name: llmengine/fo::*::static::*
|
||||
litellm_params:
|
||||
model: openai/fo::*::static::*
|
||||
|
||||
if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"
|
||||
|
||||
Case 3:
|
||||
model_name: *meta.llama3*
|
||||
litellm_params:
|
||||
model: bedrock/meta.llama3*
|
||||
|
||||
if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b"
|
||||
"""
|
||||
|
||||
## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name
|
||||
if "*" not in litellm_deployment_litellm_model:
|
||||
return litellm_deployment_litellm_model
|
||||
|
||||
wildcard_count = litellm_deployment_litellm_model.count("*")
|
||||
|
||||
# Extract all dynamic segments from the request
|
||||
dynamic_segments = matched_pattern.groups()
|
||||
|
||||
if len(dynamic_segments) > wildcard_count:
|
||||
return (
|
||||
matched_pattern.string
|
||||
) # default to the user input, if unable to map based on wildcards.
|
||||
# Replace the corresponding wildcards in the litellm model pattern with extracted segments
|
||||
for segment in dynamic_segments:
|
||||
litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace(
|
||||
"*", segment, 1
|
||||
)
|
||||
|
||||
return litellm_deployment_litellm_model
|
||||
|
||||
def get_pattern(
    self, model: str, custom_llm_provider: Optional[str] = None
) -> Optional[List[Dict]]:
    """
    Look up llm deployments whose wildcard pattern matches the given model.

    Tries the bare model name first, then the provider-prefixed form
    ("<provider>/<model>").

    Args:
        model: the model name to look up.
        custom_llm_provider: provider prefix; inferred via get_llm_provider
            when not supplied.

    Returns:
        Optional[List[Dict]]: matching llm deployments, or None.
    """
    if custom_llm_provider is None:
        try:
            # get_llm_provider returns (model, provider, api_key, api_base);
            # only the provider is needed here.
            custom_llm_provider = get_llm_provider(model=model)[1]
        except Exception:
            # get_llm_provider raises when the provider is unknown; leave
            # custom_llm_provider as None in that case.
            pass
    return self.route(model) or self.route(f"{custom_llm_provider}/{model}")
|
||||
|
||||
def get_deployments_by_pattern(
    self, model: str, custom_llm_provider: Optional[str] = None
) -> List[Dict]:
    """
    Return the llm deployments matching the wildcard pattern for ``model``.

    Same lookup as :meth:`get_pattern`, but normalizes a miss to an empty
    list instead of None.

    Args:
        model: the model name to look up.
        custom_llm_provider: optional provider prefix.

    Returns:
        List[Dict]: matching llm deployments (empty when none match).
    """
    return self.get_pattern(model, custom_llm_provider) or []
|
||||
|
||||
|
||||
# Example usage:
|
||||
# router = PatternRouter()
|
||||
# router.add_pattern('openai/*', [Deployment(), Deployment()])
|
||||
# router.add_pattern('openai/fo::*::static::*', Deployment())
|
||||
# print(router.route('openai/gpt-4')) # Output: [Deployment(), Deployment()]
|
||||
# print(router.route('openai/fo::hi::static::hi')) # Output: [Deployment()]
|
||||
# print(router.route('something/else')) # Output: None
|
||||
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Unified deployment affinity (session stickiness) for the Router.
|
||||
|
||||
Features (independently enable-able):
|
||||
1. Responses API continuity: when a `previous_response_id` is provided, route to the
|
||||
deployment that generated the original response (highest priority).
|
||||
2. API-key affinity: map an API key hash -> deployment id for a TTL and re-use that
|
||||
deployment for subsequent requests to the same router deployment model name
|
||||
(alias-safe, aligns to `model_map_information.model_map_key`).
|
||||
|
||||
This is designed to support "implicit prompt caching" scenarios (no explicit cache_control),
|
||||
where routing to a consistent deployment is still beneficial.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger, Span
|
||||
from litellm.responses.utils import ResponsesAPIRequestUtils
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
|
||||
class DeploymentAffinityCacheValue(TypedDict):
    """Shape of the value stored in the affinity cache."""

    # The pinned deployment's `model_info.id` (always stored as a string).
    model_id: str
|
||||
|
||||
|
||||
class DeploymentAffinityCheck(CustomLogger):
|
||||
"""
|
||||
Router deployment affinity callback.
|
||||
|
||||
NOTE: This is a Router-only callback intended to be wired through
|
||||
`Router(optional_pre_call_checks=[...])`.
|
||||
"""
|
||||
|
||||
CACHE_KEY_PREFIX = "deployment_affinity:v1"
|
||||
|
||||
def __init__(
    self,
    cache: DualCache,
    ttl_seconds: int,
    enable_user_key_affinity: bool,
    enable_responses_api_affinity: bool,
    enable_session_id_affinity: bool = False,
):
    """
    Args:
        cache: dual (in-memory + optional Redis) cache storing the affinity
            mappings.
        ttl_seconds: TTL applied to each affinity cache entry.
        enable_user_key_affinity: pin requests carrying the same API key hash
            to the same deployment.
        enable_responses_api_affinity: route follow-up Responses API requests
            (with `previous_response_id`) to the originating deployment.
        enable_session_id_affinity: pin requests sharing a `session_id` to
            the same deployment.
    """
    super().__init__()
    self.cache = cache
    self.ttl_seconds = ttl_seconds
    self.enable_user_key_affinity = enable_user_key_affinity
    self.enable_responses_api_affinity = enable_responses_api_affinity
    self.enable_session_id_affinity = enable_session_id_affinity
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_sha256_hex(value: str) -> bool:
|
||||
if len(value) != 64:
|
||||
return False
|
||||
try:
|
||||
int(value, 16)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
def _hash_user_key(user_key: str) -> str:
    """
    Return a stable, fixed-length key for `user_key`, safe to embed in cache keys.

    Raw API keys / user identifiers must not land in Redis keys (and thereby
    in logs/metrics). When the proxy already supplies a SHA-256 digest (e.g.
    `metadata.user_api_key_hash`), it is reused as-is (lower-cased) so cache
    entries stay correlatable with proxy logs; otherwise the value is hashed
    here.
    """
    already_hashed = DeploymentAffinityCheck._looks_like_sha256_hex(user_key)
    if already_hashed:
        # Avoid double-hashing proxy-provided digests.
        return user_key.lower()
    return hashlib.sha256(user_key.encode("utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _get_model_map_key_from_litellm_model_name(
|
||||
litellm_model_name: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Best-effort derivation of a stable "model map key" for affinity scoping.
|
||||
|
||||
The intent is to align with `standard_logging_payload.model_map_information.model_map_key`,
|
||||
which is typically the base model identifier (stable across deployments/endpoints).
|
||||
|
||||
Notes:
|
||||
- When the model name is in "provider/model" format, the provider prefix is stripped.
|
||||
- For Azure, the string after "azure/" is commonly an *Azure deployment name*, which may
|
||||
differ across instances. If `base_model` is not explicitly set, we skip deriving a
|
||||
model-map key from the model string to avoid generating unstable keys.
|
||||
"""
|
||||
if not litellm_model_name:
|
||||
return None
|
||||
|
||||
if "/" not in litellm_model_name:
|
||||
return litellm_model_name
|
||||
|
||||
provider_prefix, remainder = litellm_model_name.split("/", 1)
|
||||
if provider_prefix == "azure":
|
||||
return None
|
||||
|
||||
return remainder
|
||||
|
||||
@staticmethod
def _get_model_map_key_from_deployment(deployment: dict) -> Optional[str]:
    """
    Derive a stable model-map key from a router deployment dict.

    Precedence (first non-empty string wins):
    1. `deployment["model_name"]` — Router's canonical group name after alias
       resolution; stable across provider-specific deployments and aligned
       with `model_map_information.model_map_key` in standard logging.
    2. `model_info.base_model`, then `litellm_params.base_model` — important
       for Azure, where the litellm model string is a per-instance
       deployment name.
    3. Parsing `litellm_params.model` (provider prefix stripped; azure
       skipped — see `_get_model_map_key_from_litellm_model_name`).

    Returns:
        The derived key, or None when no stable key can be determined.
    """
    model_name = deployment.get("model_name")
    if isinstance(model_name, str) and model_name:
        return model_name

    model_info = deployment.get("model_info")
    if isinstance(model_info, dict):
        base_model = model_info.get("base_model")
        if isinstance(base_model, str) and base_model:
            return base_model

    litellm_params = deployment.get("litellm_params")
    if isinstance(litellm_params, dict):
        base_model = litellm_params.get("base_model")
        if isinstance(base_model, str) and base_model:
            return base_model
        litellm_model_name = litellm_params.get("model")
        if isinstance(litellm_model_name, str) and litellm_model_name:
            return (
                DeploymentAffinityCheck._get_model_map_key_from_litellm_model_name(
                    litellm_model_name
                )
            )

    return None
|
||||
|
||||
@staticmethod
def _get_stable_model_map_key_from_deployments(
    healthy_deployments: List[dict],
) -> Optional[str]:
    """
    Return the model-map key shared by every deployment in the set, or None.

    Affinity is only scoped by model-map key when the key is identical across
    the whole deployment set. This guards against keying on per-deployment
    identifiers such as Azure deployment names (when `base_model` is not
    configured). Returns None for an empty set, when any deployment yields no
    key, or when the keys disagree.
    """
    shared_key: Optional[str] = None
    for candidate in healthy_deployments:
        candidate_key = DeploymentAffinityCheck._get_model_map_key_from_deployment(
            candidate
        )
        if candidate_key is None:
            return None
        if shared_key is None:
            shared_key = candidate_key
        elif candidate_key != shared_key:
            # Keys disagree across the set -> no stable scoping possible.
            return None
    return shared_key
|
||||
|
||||
@staticmethod
|
||||
def _shorten_for_logs(value: str, keep: int = 8) -> str:
|
||||
if len(value) <= keep:
|
||||
return value
|
||||
return f"{value[:keep]}..."
|
||||
|
||||
@classmethod
def get_affinity_cache_key(cls, model_group: str, user_key: str) -> str:
    """Build the cache key for API-key -> deployment affinity."""
    # Key layout: "<prefix>:<model group>:<sha256 of user key>". The user key
    # is hashed so raw API keys never appear in cache keys/logs.
    hashed_user_key = cls._hash_user_key(user_key=user_key)
    return f"{cls.CACHE_KEY_PREFIX}:{model_group}:{hashed_user_key}"
|
||||
|
||||
@classmethod
def get_session_affinity_cache_key(cls, model_group: str, session_id: str) -> str:
    """Build the cache key for session-id -> deployment affinity."""
    # Unlike the user key, the session id is stored as-is (not hashed).
    return f"{cls.CACHE_KEY_PREFIX}:session:{model_group}:{session_id}"
|
||||
|
||||
@staticmethod
|
||||
def _get_user_key_from_metadata_dict(metadata: dict) -> Optional[str]:
|
||||
# NOTE: affinity is keyed on the *API key hash* provided by the proxy (not the
|
||||
# OpenAI `user` parameter, which is an end-user identifier).
|
||||
user_key = metadata.get("user_api_key_hash")
|
||||
if user_key is None:
|
||||
return None
|
||||
return str(user_key)
|
||||
|
||||
@staticmethod
|
||||
def _get_session_id_from_metadata_dict(metadata: dict) -> Optional[str]:
|
||||
session_id = metadata.get("session_id")
|
||||
if session_id is None:
|
||||
return None
|
||||
return str(session_id)
|
||||
|
||||
@staticmethod
|
||||
def _iter_metadata_dicts(request_kwargs: dict) -> List[dict]:
|
||||
"""
|
||||
Return all metadata dicts available on the request.
|
||||
|
||||
Depending on the endpoint, Router may populate `metadata` or `litellm_metadata`.
|
||||
Users may also send one or both, so we check both (rather than using `or`).
|
||||
"""
|
||||
metadata_dicts: List[dict] = []
|
||||
for key in ("litellm_metadata", "metadata"):
|
||||
md = request_kwargs.get(key)
|
||||
if isinstance(md, dict):
|
||||
metadata_dicts.append(md)
|
||||
return metadata_dicts
|
||||
|
||||
@staticmethod
def _get_user_key_from_request_kwargs(request_kwargs: dict) -> Optional[str]:
    """
    Extract a stable affinity key from request kwargs.

    Source (proxy): `metadata.user_api_key_hash` — the first metadata dict
    carrying one wins.

    Note: the OpenAI `user` parameter is an end-user identifier and is
    intentionally not used for deployment affinity.
    """
    # Check metadata dicts (Proxy usage); first hit wins.
    return next(
        (
            candidate
            for candidate in (
                DeploymentAffinityCheck._get_user_key_from_metadata_dict(metadata=md)
                for md in DeploymentAffinityCheck._iter_metadata_dicts(request_kwargs)
            )
            if candidate is not None
        ),
        None,
    )
|
||||
|
||||
@staticmethod
def _get_session_id_from_request_kwargs(request_kwargs: dict) -> Optional[str]:
    """Extract the session id from any metadata dict on the request (first hit wins)."""
    return next(
        (
            candidate
            for candidate in (
                DeploymentAffinityCheck._get_session_id_from_metadata_dict(metadata=md)
                for md in DeploymentAffinityCheck._iter_metadata_dicts(request_kwargs)
            )
            if candidate is not None
        ),
        None,
    )
|
||||
|
||||
@staticmethod
|
||||
def _find_deployment_by_model_id(
|
||||
healthy_deployments: List[dict], model_id: str
|
||||
) -> Optional[dict]:
|
||||
for deployment in healthy_deployments:
|
||||
model_info = deployment.get("model_info")
|
||||
if not isinstance(model_info, dict):
|
||||
continue
|
||||
deployment_model_id = model_info.get("id")
|
||||
if deployment_model_id is not None and str(deployment_model_id) == str(
|
||||
model_id
|
||||
):
|
||||
return deployment
|
||||
return None
|
||||
|
||||
async def async_filter_deployments(
    self,
    model: str,
    healthy_deployments: List,
    messages: Optional[List[AllMessageValues]],
    request_kwargs: Optional[dict] = None,
    parent_otel_span: Optional[Span] = None,
) -> List[dict]:
    """
    Optionally filter healthy deployments based on:
    1. `previous_response_id` (Responses API continuity) [highest priority]
    2. session-id affinity (when enabled)
    3. cached API-key deployment affinity

    Each stage returns a single-element list when it can pin a deployment;
    otherwise the full set is passed through unchanged (affinity is
    best-effort, never blocking).
    """
    request_kwargs = request_kwargs or {}
    typed_healthy_deployments = cast(List[dict], healthy_deployments)

    # 1) Responses API continuity (high priority)
    if self.enable_responses_api_affinity:
        previous_response_id = request_kwargs.get("previous_response_id")
        if previous_response_id is not None:
            # The deployment's model_id is encoded into the response id it issued.
            responses_model_id = (
                ResponsesAPIRequestUtils.get_model_id_from_response_id(
                    str(previous_response_id)
                )
            )
            if responses_model_id is not None:
                deployment = self._find_deployment_by_model_id(
                    healthy_deployments=typed_healthy_deployments,
                    model_id=responses_model_id,
                )
                if deployment is not None:
                    verbose_router_logger.debug(
                        "DeploymentAffinityCheck: previous_response_id pinning -> deployment=%s",
                        responses_model_id,
                    )
                    return [deployment]

    # Cache-backed affinity (stages 2 and 3) needs a model-map key that is
    # stable across the whole deployment set; otherwise skip them entirely.
    stable_model_map_key = self._get_stable_model_map_key_from_deployments(
        healthy_deployments=typed_healthy_deployments
    )
    if stable_model_map_key is None:
        return typed_healthy_deployments

    # 2) Session-id -> deployment affinity
    if self.enable_session_id_affinity:
        session_id = self._get_session_id_from_request_kwargs(
            request_kwargs=request_kwargs
        )
        if session_id is not None:
            session_cache_key = self.get_session_affinity_cache_key(
                model_group=stable_model_map_key, session_id=session_id
            )
            session_cache_result = await self.cache.async_get_cache(
                key=session_cache_key
            )

            # Cache value may be a DeploymentAffinityCacheValue dict or
            # (defensively) a raw model-id string.
            session_model_id: Optional[str] = None
            if isinstance(session_cache_result, dict):
                session_model_id = cast(
                    Optional[str], session_cache_result.get("model_id")
                )
            elif isinstance(session_cache_result, str):
                session_model_id = session_cache_result

            if session_model_id:
                session_deployment = self._find_deployment_by_model_id(
                    healthy_deployments=typed_healthy_deployments,
                    model_id=session_model_id,
                )
                if session_deployment is not None:
                    verbose_router_logger.debug(
                        "DeploymentAffinityCheck: session-id affinity hit -> deployment=%s session_id=%s",
                        session_model_id,
                        session_id,
                    )
                    return [session_deployment]
                else:
                    # Pinned deployment is no longer healthy: fall through to
                    # the remaining stages rather than failing the request.
                    verbose_router_logger.debug(
                        "DeploymentAffinityCheck: session-id pinned deployment=%s not found in healthy_deployments",
                        session_model_id,
                    )

    # 3) User key -> deployment affinity
    if not self.enable_user_key_affinity:
        return typed_healthy_deployments

    user_key = self._get_user_key_from_request_kwargs(request_kwargs=request_kwargs)
    if user_key is None:
        return typed_healthy_deployments

    cache_key = self.get_affinity_cache_key(
        model_group=stable_model_map_key, user_key=user_key
    )
    cache_result = await self.cache.async_get_cache(key=cache_key)

    model_id: Optional[str] = None
    if isinstance(cache_result, dict):
        model_id = cast(Optional[str], cache_result.get("model_id"))
    elif isinstance(cache_result, str):
        # Backwards / safety: allow raw string values.
        model_id = cache_result

    if not model_id:
        return typed_healthy_deployments

    deployment = self._find_deployment_by_model_id(
        healthy_deployments=typed_healthy_deployments,
        model_id=model_id,
    )
    if deployment is None:
        # Pinned deployment is unhealthy/removed: pass through the full set.
        verbose_router_logger.debug(
            "DeploymentAffinityCheck: pinned deployment=%s not found in healthy_deployments",
            model_id,
        )
        return typed_healthy_deployments

    verbose_router_logger.debug(
        "DeploymentAffinityCheck: api-key affinity hit -> deployment=%s user_key=%s",
        model_id,
        self._shorten_for_logs(user_key),
    )
    return [deployment]
|
||||
|
||||
async def async_pre_call_deployment_hook(
    self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
) -> Optional[dict]:
    """
    Persist/update the API-key -> deployment and session-id -> deployment
    mappings for this request.

    Why pre-call?
    - LiteLLM runs async success callbacks via a background logging worker for performance.
    - We want affinity to be immediately available for subsequent requests.

    Always returns None (the hook never modifies the request kwargs).
    """
    if not self.enable_user_key_affinity and not self.enable_session_id_affinity:
        return None

    user_key = None
    if self.enable_user_key_affinity:
        user_key = self._get_user_key_from_request_kwargs(request_kwargs=kwargs)

    session_id = None
    if self.enable_session_id_affinity:
        session_id = self._get_session_id_from_request_kwargs(request_kwargs=kwargs)

    # Nothing to key affinity on -> nothing to persist.
    if user_key is None and session_id is None:
        return None

    metadata_dicts = self._iter_metadata_dicts(kwargs)

    # Resolve the selected deployment's model_info: prefer the top-level
    # kwarg, fall back to the metadata dicts.
    model_info = kwargs.get("model_info")
    if not isinstance(model_info, dict):
        model_info = None

    if model_info is None:
        for metadata in metadata_dicts:
            maybe_model_info = metadata.get("model_info")
            if isinstance(maybe_model_info, dict):
                model_info = maybe_model_info
                break

    if model_info is None:
        # Router sets `model_info` after selecting a deployment. If it's missing, this is
        # likely a non-router call or a call path that doesn't support affinity.
        return None

    model_id = model_info.get("id")
    if not model_id:
        verbose_router_logger.warning(
            "DeploymentAffinityCheck: model_id missing; skipping affinity cache update."
        )
        return None

    # Scope affinity by the Router deployment model name (alias-safe, consistent across
    # heterogeneous providers, and matches standard logging's `model_map_key`).
    deployment_model_name: Optional[str] = None
    for metadata in metadata_dicts:
        maybe_deployment_model_name = metadata.get("deployment_model_name")
        if (
            isinstance(maybe_deployment_model_name, str)
            and maybe_deployment_model_name
        ):
            deployment_model_name = maybe_deployment_model_name
            break

    if not deployment_model_name:
        verbose_router_logger.warning(
            "DeploymentAffinityCheck: deployment_model_name missing; skipping affinity cache update. model_id=%s",
            model_id,
        )
        return None

    # Persist API-key affinity (best-effort; cache failures never block the call).
    if user_key is not None:
        try:
            cache_key = self.get_affinity_cache_key(
                model_group=deployment_model_name, user_key=user_key
            )
            await self.cache.async_set_cache(
                cache_key,
                DeploymentAffinityCacheValue(model_id=str(model_id)),
                ttl=self.ttl_seconds,
            )

            verbose_router_logger.debug(
                "DeploymentAffinityCheck: set affinity mapping model_map_key=%s deployment=%s ttl=%s user_key=%s",
                deployment_model_name,
                model_id,
                self.ttl_seconds,
                self._shorten_for_logs(user_key),
            )
        except Exception as e:
            # Non-blocking: affinity is a best-effort optimization.
            verbose_router_logger.debug(
                "DeploymentAffinityCheck: failed to set user key affinity cache. model_map_key=%s error=%s",
                deployment_model_name,
                e,
            )

    # Also persist Session-ID affinity if enabled and session-id is provided
    if session_id is not None:
        try:
            session_cache_key = self.get_session_affinity_cache_key(
                model_group=deployment_model_name, session_id=session_id
            )
            await self.cache.async_set_cache(
                session_cache_key,
                DeploymentAffinityCacheValue(model_id=str(model_id)),
                ttl=self.ttl_seconds,
            )
            verbose_router_logger.debug(
                "DeploymentAffinityCheck: set session affinity mapping model_map_key=%s deployment=%s ttl=%s session_id=%s",
                deployment_model_name,
                model_id,
                self.ttl_seconds,
                session_id,
            )
        except Exception as e:
            # Non-blocking, same as above.
            verbose_router_logger.debug(
                "DeploymentAffinityCheck: failed to set session affinity cache. model_map_key=%s error=%s",
                deployment_model_name,
                e,
            )

    return None
|
||||
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Encrypted-content-aware deployment affinity for the Router.
|
||||
|
||||
When Codex or other models use `store: false` with `include: ["reasoning.encrypted_content"]`,
|
||||
the response output items contain encrypted reasoning tokens tied to the originating
|
||||
organization's API key. If a follow-up request containing those items is routed to a
|
||||
different deployment (different org), OpenAI rejects it with an `invalid_encrypted_content`
|
||||
error because the organization_id doesn't match.
|
||||
|
||||
This callback solves the problem by encoding the originating deployment's ``model_id``
|
||||
into the response output items that carry ``encrypted_content``. Two encoding strategies:
|
||||
|
||||
1. **Items with IDs**: Encode model_id into the item ID itself (e.g., ``encitem_...``)
|
||||
2. **Items without IDs** (Codex): Wrap the encrypted_content with model_id metadata
|
||||
(e.g., ``litellm_enc:{base64_metadata};{original_encrypted_content}``)
|
||||
|
||||
The encoded model_id is decoded on the next request so the router can pin to the correct
|
||||
deployment without any cache lookup.
|
||||
|
||||
Response post-processing (encoding) is handled by
|
||||
``ResponsesAPIRequestUtils._update_encrypted_content_item_ids_in_response`` which is
|
||||
called inside ``_update_responses_api_response_id_with_model_id`` in ``responses/utils.py``.
|
||||
|
||||
Request pre-processing (ID/content restoration before forwarding to upstream) is handled by
|
||||
``ResponsesAPIRequestUtils._restore_encrypted_content_item_ids_in_input`` which is called
|
||||
in ``get_optional_params_responses_api``.
|
||||
|
||||
This pre-call check is responsible only for the routing decision: it reads the encoded
|
||||
``model_id`` from either item IDs or wrapped encrypted_content and pins the request to
|
||||
the matching deployment.
|
||||
|
||||
Safe to enable globally:
|
||||
- Only activates when encoded markers appear in the request ``input``.
|
||||
- No effect on embedding models, chat completions, or first-time requests.
|
||||
- No quota reduction -- first requests are fully load balanced.
|
||||
- No cache required.
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger, Span
|
||||
from litellm.responses.utils import ResponsesAPIRequestUtils
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class EncryptedContentAffinityCheck(CustomLogger):
|
||||
"""
|
||||
Routes follow-up Responses API requests to the deployment that produced
|
||||
the encrypted output items they reference.
|
||||
|
||||
The ``model_id`` is decoded directly from the litellm-encoded item IDs –
|
||||
no caching or TTL management needed.
|
||||
|
||||
Wired via ``Router(optional_pre_call_checks=["encrypted_content_affinity"])``.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    # Stateless check: no cache/TTL configuration — the model_id is decoded
    # directly from the request input items at routing time.
    super().__init__()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
def _extract_model_id_from_input(request_input: Any) -> Optional[str]:
    """
    Scan ``input`` items for litellm-encoded encrypted-content markers and
    return the ``model_id`` embedded in the first one found.

    Checks both:
    1. Encoded item IDs (encitem_...) - for clients that send IDs
    2. Wrapped encrypted_content (litellm_enc:...) - for clients like Codex that don't send IDs

    ``input`` can be:
    - a plain string -> no encoded markers (returns None)
    - a list of items -> check each item's ``id`` and ``encrypted_content`` fields
    """
    if not isinstance(request_input, list):
        return None

    for item in request_input:
        # Only dict-shaped items can carry an id / encrypted_content field.
        if not isinstance(item, dict):
            continue

        # First, try to decode from item ID (if present)
        item_id = item.get("id")
        if item_id and isinstance(item_id, str):
            # NOTE(review): assumes _decode_encrypted_item_id returns a dict
            # with a "model_id" key (or a falsy value on non-encoded ids) —
            # confirm against ResponsesAPIRequestUtils.
            decoded = ResponsesAPIRequestUtils._decode_encrypted_item_id(item_id)
            if decoded:
                return decoded.get("model_id")

        # If no encoded ID, check if encrypted_content itself is wrapped
        encrypted_content = item.get("encrypted_content")
        if encrypted_content and isinstance(encrypted_content, str):
            (
                model_id,
                _,
            ) = ResponsesAPIRequestUtils._unwrap_encrypted_content_with_model_id(
                encrypted_content
            )
            if model_id:
                return model_id

    # No encoded marker anywhere in the input.
    return None
|
||||
|
||||
@staticmethod
|
||||
def _find_deployment_by_model_id(
|
||||
healthy_deployments: List[dict], model_id: str
|
||||
) -> Optional[dict]:
|
||||
for deployment in healthy_deployments:
|
||||
model_info = deployment.get("model_info")
|
||||
if not isinstance(model_info, dict):
|
||||
continue
|
||||
deployment_model_id = model_info.get("id")
|
||||
if deployment_model_id is not None and str(deployment_model_id) == str(
|
||||
model_id
|
||||
):
|
||||
return deployment
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Request routing (pre-call filter)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def async_filter_deployments(
    self,
    model: str,
    healthy_deployments: List,
    messages: Optional[List[AllMessageValues]],
    request_kwargs: Optional[dict] = None,
    parent_otel_span: Optional[Span] = None,
) -> List[dict]:
    """
    If the request ``input`` contains litellm-encoded item IDs, decode the
    embedded ``model_id`` and pin the request to that deployment.

    Always flags the request for encrypted-content encoding on the response,
    even when no pinning happens (so the *next* request can be pinned).
    """
    request_kwargs = request_kwargs or {}
    typed_healthy_deployments = cast(List[dict], healthy_deployments)

    # Signal to the response post-processor that encrypted item IDs should be
    # encoded in the output of this request.
    # NOTE(review): assumes `litellm_metadata`, when present, is a dict — a
    # non-dict value would make the item assignment below raise; confirm
    # upstream callers.
    litellm_metadata = request_kwargs.setdefault("litellm_metadata", {})
    litellm_metadata["encrypted_content_affinity_enabled"] = True

    request_input = request_kwargs.get("input")
    model_id = self._extract_model_id_from_input(request_input)
    if not model_id:
        # No encoded marker -> first-time request, fully load balanced.
        return typed_healthy_deployments

    verbose_router_logger.debug(
        "EncryptedContentAffinityCheck: decoded model_id=%s from input item IDs",
        model_id,
    )

    deployment = self._find_deployment_by_model_id(
        healthy_deployments=typed_healthy_deployments,
        model_id=model_id,
    )
    if deployment is not None:
        verbose_router_logger.debug(
            "EncryptedContentAffinityCheck: pinning -> deployment=%s",
            model_id,
        )
        # Marker consumed by downstream request processing.
        request_kwargs["_encrypted_content_affinity_pinned"] = True
        return [deployment]

    # Originating deployment is gone/unhealthy: logged as an error because the
    # upstream provider may reject the encrypted content, but the request is
    # still allowed to proceed over the full set.
    verbose_router_logger.error(
        "EncryptedContentAffinityCheck: decoded deployment=%s not found in healthy_deployments",
        model_id,
    )
    return typed_healthy_deployments
|
||||
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
Enforce TPM/RPM rate limits set on model deployments.
|
||||
|
||||
This pre-call check ensures that model-level TPM/RPM limits are enforced
|
||||
across all requests, regardless of routing strategy.
|
||||
|
||||
When enabled via `enforce_model_rate_limits: true` in litellm_settings,
|
||||
requests that exceed the configured TPM/RPM limits will receive a 429 error.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.router import RouterErrors
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.utils import get_utc_datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class RoutingArgs:
    """Constants for model rate-limit tracking."""

    # Per-minute TPM/RPM counter keys expire after this many seconds.
    ttl: int = 60  # 1min (RPM/TPM expire key)
|
||||
|
||||
|
||||
class ModelRateLimitingCheck(CustomLogger):
|
||||
"""
|
||||
Pre-call check that enforces TPM/RPM limits on model deployments.
|
||||
|
||||
This check runs before each request and raises a RateLimitError
|
||||
if the deployment has exceeded its configured TPM or RPM limits.
|
||||
|
||||
Unlike the usage-based-routing strategy which uses limits for routing decisions,
|
||||
this check actively enforces those limits across ALL routing strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, dual_cache: DualCache):
    # Dual (in-memory + optional Redis) cache holding per-minute TPM/RPM counters.
    self.dual_cache = dual_cache
|
||||
|
||||
def _get_deployment_limits(
|
||||
self, deployment: Dict
|
||||
) -> tuple[Optional[int], Optional[int]]:
|
||||
"""
|
||||
Extract TPM and RPM limits from a deployment configuration.
|
||||
|
||||
Checks in order:
|
||||
1. Top-level 'tpm'/'rpm' fields
|
||||
2. litellm_params.tpm/rpm
|
||||
3. model_info.tpm/rpm
|
||||
|
||||
Returns:
|
||||
Tuple of (tpm_limit, rpm_limit)
|
||||
"""
|
||||
# Check top-level
|
||||
tpm = deployment.get("tpm")
|
||||
rpm = deployment.get("rpm")
|
||||
|
||||
# Check litellm_params
|
||||
if tpm is None:
|
||||
tpm = deployment.get("litellm_params", {}).get("tpm")
|
||||
if rpm is None:
|
||||
rpm = deployment.get("litellm_params", {}).get("rpm")
|
||||
|
||||
# Check model_info
|
||||
if tpm is None:
|
||||
tpm = deployment.get("model_info", {}).get("tpm")
|
||||
if rpm is None:
|
||||
rpm = deployment.get("model_info", {}).get("rpm")
|
||||
|
||||
return tpm, rpm
|
||||
|
||||
def _get_cache_keys(self, deployment: Dict, current_minute: str) -> tuple[str, str]:
|
||||
"""Get the cache keys for TPM and RPM tracking."""
|
||||
model_id = deployment.get("model_info", {}).get("id")
|
||||
deployment_name = deployment.get("litellm_params", {}).get("model")
|
||||
|
||||
tpm_key = f"{model_id}:{deployment_name}:tpm:{current_minute}"
|
||||
rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
|
||||
|
||||
return tpm_key, rpm_key
|
||||
|
||||
def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
|
||||
"""
|
||||
Synchronous pre-call check for model rate limits.
|
||||
|
||||
Raises RateLimitError if deployment exceeds TPM/RPM limits.
|
||||
"""
|
||||
try:
|
||||
tpm_limit, rpm_limit = self._get_deployment_limits(deployment)
|
||||
|
||||
# If no limits are set, allow the request
|
||||
if tpm_limit is None and rpm_limit is None:
|
||||
return deployment
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
tpm_key, rpm_key = self._get_cache_keys(deployment, current_minute)
|
||||
|
||||
model_id = deployment.get("model_info", {}).get("id")
|
||||
model_name = deployment.get("litellm_params", {}).get("model")
|
||||
model_group = deployment.get("model_name", "")
|
||||
|
||||
# Check TPM limit
|
||||
if tpm_limit is not None:
|
||||
# First check local cache
|
||||
current_tpm = self.dual_cache.get_cache(key=tpm_key, local_only=True)
|
||||
if current_tpm is not None and current_tpm >= tpm_limit:
|
||||
raise litellm.RateLimitError(
|
||||
message=f"Model rate limit exceeded. TPM limit={tpm_limit}, current usage={current_tpm}",
|
||||
llm_provider="",
|
||||
model=model_name,
|
||||
response=httpx.Response(
|
||||
status_code=429,
|
||||
content=f"{RouterErrors.user_defined_ratelimit_error.value} tpm limit={tpm_limit}. current usage={current_tpm}. id={model_id}, model_group={model_group}",
|
||||
headers={"retry-after": str(60)},
|
||||
request=httpx.Request(
|
||||
method="model_rate_limit_check",
|
||||
url="https://github.com/BerriAI/litellm",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Check RPM limit (atomic increment-first to avoid race conditions)
|
||||
if rpm_limit is not None:
|
||||
current_rpm = self.dual_cache.increment_cache(
|
||||
key=rpm_key, value=1, ttl=RoutingArgs.ttl
|
||||
)
|
||||
if current_rpm is not None and current_rpm > rpm_limit:
|
||||
raise litellm.RateLimitError(
|
||||
message=f"Model rate limit exceeded. RPM limit={rpm_limit}, current usage={current_rpm}",
|
||||
llm_provider="",
|
||||
model=model_name,
|
||||
response=httpx.Response(
|
||||
status_code=429,
|
||||
content=f"{RouterErrors.user_defined_ratelimit_error.value} rpm limit={rpm_limit}. current usage={current_rpm}. id={model_id}, model_group={model_group}",
|
||||
headers={"retry-after": str(60)},
|
||||
request=httpx.Request(
|
||||
method="model_rate_limit_check",
|
||||
url="https://github.com/BerriAI/litellm",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
return deployment
|
||||
|
||||
except litellm.RateLimitError:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_router_logger.debug(
|
||||
f"Error in ModelRateLimitingCheck.pre_call_check: {str(e)}"
|
||||
)
|
||||
# Don't fail the request if rate limit check fails
|
||||
return deployment
|
||||
|
||||
async def async_pre_call_check(
|
||||
self, deployment: Dict, parent_otel_span: Optional[Span] = None
|
||||
) -> Optional[Dict]:
|
||||
"""
|
||||
Async pre-call check for model rate limits.
|
||||
|
||||
Raises RateLimitError if deployment exceeds TPM/RPM limits.
|
||||
"""
|
||||
try:
|
||||
tpm_limit, rpm_limit = self._get_deployment_limits(deployment)
|
||||
|
||||
# If no limits are set, allow the request
|
||||
if tpm_limit is None and rpm_limit is None:
|
||||
return deployment
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
tpm_key, rpm_key = self._get_cache_keys(deployment, current_minute)
|
||||
|
||||
model_id = deployment.get("model_info", {}).get("id")
|
||||
model_name = deployment.get("litellm_params", {}).get("model")
|
||||
model_group = deployment.get("model_name", "")
|
||||
|
||||
# Check TPM limit
|
||||
if tpm_limit is not None:
|
||||
# First check local cache
|
||||
current_tpm = await self.dual_cache.async_get_cache(
|
||||
key=tpm_key, local_only=True
|
||||
)
|
||||
if current_tpm is not None and current_tpm >= tpm_limit:
|
||||
raise litellm.RateLimitError(
|
||||
message=f"Model rate limit exceeded. TPM limit={tpm_limit}, current usage={current_tpm}",
|
||||
llm_provider="",
|
||||
model=model_name,
|
||||
response=httpx.Response(
|
||||
status_code=429,
|
||||
content=f"{RouterErrors.user_defined_ratelimit_error.value} tpm limit={tpm_limit}. current usage={current_tpm}. id={model_id}, model_group={model_group}",
|
||||
headers={"retry-after": str(60)},
|
||||
request=httpx.Request(
|
||||
method="model_rate_limit_check",
|
||||
url="https://github.com/BerriAI/litellm",
|
||||
),
|
||||
),
|
||||
num_retries=0, # Don't retry - return 429 immediately
|
||||
)
|
||||
|
||||
# Check RPM limit (atomic increment-first to avoid race conditions)
|
||||
if rpm_limit is not None:
|
||||
current_rpm = await self.dual_cache.async_increment_cache(
|
||||
key=rpm_key,
|
||||
value=1,
|
||||
ttl=RoutingArgs.ttl,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
if current_rpm is not None and current_rpm > rpm_limit:
|
||||
raise litellm.RateLimitError(
|
||||
message=f"Model rate limit exceeded. RPM limit={rpm_limit}, current usage={current_rpm}",
|
||||
llm_provider="",
|
||||
model=model_name,
|
||||
response=httpx.Response(
|
||||
status_code=429,
|
||||
content=f"{RouterErrors.user_defined_ratelimit_error.value} rpm limit={rpm_limit}. current usage={current_rpm}. id={model_id}, model_group={model_group}",
|
||||
headers={"retry-after": str(60)},
|
||||
request=httpx.Request(
|
||||
method="model_rate_limit_check",
|
||||
url="https://github.com/BerriAI/litellm",
|
||||
),
|
||||
),
|
||||
num_retries=0, # Don't retry - return 429 immediately
|
||||
)
|
||||
|
||||
return deployment
|
||||
|
||||
except litellm.RateLimitError:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_router_logger.debug(
|
||||
f"Error in ModelRateLimitingCheck.async_pre_call_check: {str(e)}"
|
||||
)
|
||||
# Don't fail the request if rate limit check fails
|
||||
return deployment
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Track TPM usage after successful request.
|
||||
|
||||
This updates the TPM counter with the actual tokens used.
|
||||
Always tracks tokens - the pre-call check handles enforcement.
|
||||
"""
|
||||
try:
|
||||
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if standard_logging_object is None:
|
||||
return
|
||||
|
||||
model_id = standard_logging_object.get("model_id")
|
||||
if model_id is None:
|
||||
return
|
||||
|
||||
total_tokens = standard_logging_object.get("total_tokens", 0)
|
||||
model = standard_logging_object.get("hidden_params", {}).get(
|
||||
"litellm_model_name"
|
||||
)
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"[TPM TRACKING] model_id={model_id}, total_tokens={total_tokens}, model={model}"
|
||||
)
|
||||
|
||||
if not model or not total_tokens:
|
||||
return
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
tpm_key = f"{model_id}:{model}:tpm:{current_minute}"
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"[TPM TRACKING] Incrementing {tpm_key} by {total_tokens}"
|
||||
)
|
||||
|
||||
await self.dual_cache.async_increment_cache(
|
||||
key=tpm_key,
|
||||
value=total_tokens,
|
||||
ttl=RoutingArgs.ttl,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_router_logger.debug(
|
||||
f"Error in ModelRateLimitingCheck.async_log_success_event: {str(e)}"
|
||||
)
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Sync version of tracking TPM usage after successful request.
|
||||
Always tracks tokens - the pre-call check handles enforcement.
|
||||
"""
|
||||
try:
|
||||
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if standard_logging_object is None:
|
||||
return
|
||||
|
||||
model_id = standard_logging_object.get("model_id")
|
||||
if model_id is None:
|
||||
return
|
||||
|
||||
total_tokens = standard_logging_object.get("total_tokens", 0)
|
||||
model = standard_logging_object.get("hidden_params", {}).get(
|
||||
"litellm_model_name"
|
||||
)
|
||||
|
||||
if not model or not total_tokens:
|
||||
return
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
tpm_key = f"{model_id}:{model}:tpm:{current_minute}"
|
||||
|
||||
self.dual_cache.increment_cache(
|
||||
key=tpm_key,
|
||||
value=total_tokens,
|
||||
ttl=RoutingArgs.ttl,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_router_logger.debug(
|
||||
f"Error in ModelRateLimitingCheck.log_success_event: {str(e)}"
|
||||
)
|
||||
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Check if prompt caching is valid for a given deployment
|
||||
|
||||
Route to previously cached model id, if valid
|
||||
"""
|
||||
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger, Span
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import CallTypes, StandardLoggingPayload
|
||||
from litellm.utils import is_prompt_caching_valid_prompt
|
||||
|
||||
from ..prompt_caching_cache import PromptCachingCache
|
||||
|
||||
|
||||
class PromptCachingDeploymentCheck(CustomLogger):
    """
    Routes requests whose prompt is prompt-caching eligible back to the
    deployment that previously served the same prompt prefix, and records
    the serving deployment id after each successful completion call.
    """

    def __init__(self, cache: DualCache):
        self.cache = cache

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        # Affinity routing only applies to prompts long enough to be
        # cacheable (prompt > 1024 tokens).
        if messages is None or not is_prompt_caching_valid_prompt(
            messages=messages,
            model=model,
        ):
            return healthy_deployments

        prompt_cache = PromptCachingCache(cache=self.cache)
        model_id_dict = await prompt_cache.async_get_model_id(
            messages=cast(List[AllMessageValues], messages),
            tools=None,
        )
        if model_id_dict is None:
            return healthy_deployments

        cached_model_id = model_id_dict["model_id"]
        # Pin to the cached deployment when it is still healthy.
        matches = [
            candidate
            for candidate in healthy_deployments
            if candidate["model_info"]["id"] == cached_model_id
        ]
        if matches:
            return [matches[0]]

        return healthy_deployments

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )

        if standard_logging_object is None:
            return

        call_type = standard_logging_object["call_type"]

        # only use prompt caching for completion calls
        completion_call_types = (
            CallTypes.completion.value,
            CallTypes.acompletion.value,
            CallTypes.anthropic_messages.value,
        )
        if call_type not in completion_call_types:
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION or ANTHROPIC MESSAGE"
            )
            return

        model = standard_logging_object["model"]
        messages = standard_logging_object["messages"]
        model_id = standard_logging_object["model_id"]

        if messages is None or not isinstance(messages, list):
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST"
            )
            return
        if model_id is None:
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE"
            )
            return

        ## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider
        if is_prompt_caching_valid_prompt(
            model=model,
            messages=cast(List[AllMessageValues], messages),
        ):
            prompt_cache = PromptCachingCache(cache=self.cache)
            await prompt_cache.async_add_model_id(
                model_id=model_id,
                messages=messages,
                tools=None,  # [TODO]: add tools once standard_logging_object supports it
            )

        return
|
||||
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
For Responses API, we need routing affinity when a user sends a previous_response_id.
|
||||
|
||||
eg. If proxy admins are load balancing between N gpt-4.1-turbo deployments, and a user sends a previous_response_id,
|
||||
we want to route to the same gpt-4.1-turbo deployment.
|
||||
|
||||
This is different from the normal behavior of the router, which does not have routing affinity for previous_response_id.
|
||||
|
||||
|
||||
If previous_response_id is provided, route to the deployment that returned the previous response
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger, Span
|
||||
from litellm.responses.utils import ResponsesAPIRequestUtils
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class ResponsesApiDeploymentCheck(CustomLogger):
    """
    Deprecated routing-affinity check for the Responses API: when a request
    carries a previous_response_id, pin it to the deployment that produced
    that response.
    """

    def __init__(self) -> None:
        super().__init__()
        warnings.warn(
            (
                "ResponsesApiDeploymentCheck is deprecated. "
                "Use DeploymentAffinityCheck(enable_responses_api_affinity=True) instead."
            ),
            DeprecationWarning,
            stacklevel=2,
        )

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        previous_response_id = (request_kwargs or {}).get(
            "previous_response_id", None
        )
        if previous_response_id is None:
            # No affinity requested - leave the candidate list untouched.
            return healthy_deployments

        # The response id encodes the deployment that produced it.
        decoded_response = ResponsesAPIRequestUtils._decode_responses_api_response_id(
            response_id=previous_response_id,
        )
        target_model_id = decoded_response.get("model_id")
        if target_model_id is None:
            return healthy_deployments

        pinned = next(
            (
                deployment
                for deployment in healthy_deployments
                if deployment["model_info"]["id"] == target_model_id
            ),
            None,
        )
        # Fall back to normal load balancing when the original deployment
        # is no longer healthy/present.
        return [pinned] if pinned is not None else healthy_deployments
|
||||
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
|
||||
|
||||
# Import heavy/optional dependencies only for static type-checking; at
# runtime these names degrade to Any, which keeps opentelemetry optional and
# avoids an import cycle with litellm.router.
if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.router import Router

    litellm_router = Router
    Span = Union[_Span, Any]
else:
    Span = Any
    litellm_router = Any
|
||||
|
||||
|
||||
class PromptCachingCacheValue(TypedDict):
    # Deployment id (model_info.id) that served the cached prompt prefix.
    model_id: str
|
||||
|
||||
|
||||
class PromptCachingCache:
    """
    Wrapper around the router's DualCache that maps a prompt's cacheable
    prefix (messages/tools marked with `cache_control`) to the deployment id
    that last served it, enabling prompt-caching routing affinity.
    """

    def __init__(self, cache: DualCache):
        # Shared router cache (in-memory + optional distributed backend).
        self.cache = cache
        # NOTE(review): in_memory_cache is created but not used by any method
        # visible here - presumably reserved for local lookups; confirm usage.
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """Helper function to serialize Pydantic objects, dictionaries, or fallback to string."""
        if hasattr(obj, "dict"):
            # If the object is a Pydantic model, use its `dict()` method
            return obj.dict()
        elif isinstance(obj, dict):
            # If the object is a dictionary, serialize it with sorted keys
            return json.dumps(
                obj, sort_keys=True, separators=(",", ":")
            )  # Standardize serialization

        elif isinstance(obj, list):
            # Serialize lists by ensuring each element is handled properly
            return [PromptCachingCache.serialize_object(item) for item in obj]
        elif isinstance(obj, (int, float, bool)):
            return obj  # Keep primitive types as-is
        return str(obj)

    @staticmethod
    def extract_cacheable_prefix(
        messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """
        Extract the cacheable prefix from messages.

        The cacheable prefix is everything UP TO AND INCLUDING the LAST content block
        (across all messages) that has cache_control. This includes ALL blocks before
        the last cacheable block (even if they don't have cache_control).

        Args:
            messages: List of messages to extract cacheable prefix from

        Returns:
            List of messages containing only the cacheable prefix
        """
        if not messages:
            return messages

        # Find the last content block (across all messages) that has cache_control
        last_cacheable_message_idx = None
        last_cacheable_content_idx = None

        for msg_idx, message in enumerate(messages):
            content = message.get("content")

            # Check for cache_control at message level (when content is a string)
            # This handles the case where cache_control is a sibling of string content:
            # {"role": "user", "content": "...", "cache_control": {"type": "ephemeral"}}
            message_level_cache_control = message.get("cache_control")
            if (
                message_level_cache_control is not None
                and isinstance(message_level_cache_control, dict)
                and message_level_cache_control.get("type") == "ephemeral"
            ):
                last_cacheable_message_idx = msg_idx
                # Set to None to indicate the entire message content is cacheable
                # (not a specific content block index within a list)
                last_cacheable_content_idx = None

            # Also check for cache_control within content blocks (when content is a list)
            if not isinstance(content, list):
                continue

            for content_idx, content_block in enumerate(content):
                if isinstance(content_block, dict):
                    cache_control = content_block.get("cache_control")
                    if (
                        cache_control is not None
                        and isinstance(cache_control, dict)
                        and cache_control.get("type") == "ephemeral"
                    ):
                        last_cacheable_message_idx = msg_idx
                        last_cacheable_content_idx = content_idx

        # If no cacheable block found, return empty list (no cacheable prefix)
        if last_cacheable_message_idx is None:
            return []

        # Build the cacheable prefix: all messages up to and including the last cacheable message
        cacheable_prefix = []

        for msg_idx, message in enumerate(messages):
            if msg_idx < last_cacheable_message_idx:
                # Include entire message (comes before last cacheable block)
                cacheable_prefix.append(message)
            elif msg_idx == last_cacheable_message_idx:
                # Include message but only up to and including the last cacheable content block
                content = message.get("content")
                if isinstance(content, list) and last_cacheable_content_idx is not None:
                    # Create a copy of the message with only cacheable content blocks
                    message_copy = cast(
                        AllMessageValues,
                        {
                            **message,
                            "content": content[: last_cacheable_content_idx + 1],
                        },
                    )
                    cacheable_prefix.append(message_copy)
                else:
                    # Content is not a list or cacheable content idx is None, include full message
                    cacheable_prefix.append(message)
            else:
                # Message comes after last cacheable block, don't include
                break

        return cacheable_prefix

    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[str]:
        """
        Build a stable cache key from the cacheable prefix of *messages* plus
        *tools*; returns None when there is nothing cacheable to key on.
        """
        if messages is None and tools is None:
            return None

        # Extract cacheable prefix from messages (only include up to last cache_control block)
        cacheable_messages = None
        if messages is not None:
            cacheable_messages = PromptCachingCache.extract_cacheable_prefix(messages)
            # If no cacheable prefix found, return None (can't cache)
            if not cacheable_messages:
                return None

        # Use serialize_object for consistent and stable serialization
        data_to_hash = {}
        if cacheable_messages is not None:
            serialized_messages = PromptCachingCache.serialize_object(
                cacheable_messages
            )
            data_to_hash["messages"] = serialized_messages
        if tools is not None:
            serialized_tools = PromptCachingCache.serialize_object(tools)
            data_to_hash["tools"] = serialized_tools

        # Combine serialized data into a single string
        data_to_hash_str = json.dumps(
            data_to_hash,
            sort_keys=True,
            separators=(",", ":"),
        )

        # Create a hash of the serialized data for a stable cache key
        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
        return f"deployment:{hashed_data}:prompt_caching"

    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """Sync variant of async_add_model_id: record which deployment served
        this cacheable prompt prefix (TTL 300s)."""
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # If no cacheable prefix found, don't cache (can't generate cache key)
        if cache_key is None:
            return None

        self.cache.set_cache(
            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
        )
        return None

    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """Record which deployment served this cacheable prompt prefix so
        subsequent requests with the same prefix route back to it."""
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # If no cacheable prefix found, don't cache (can't generate cache key)
        if cache_key is None:
            return None

        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=300,  # store for 5 minutes
        )
        return None

    async def async_get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """
        Get model ID from cache using the cacheable prefix.

        The cache key is based on the cacheable prefix (everything up to and including
        the last cache_control block), so requests with the same cacheable prefix but
        different user messages will have the same cache key.
        """
        if messages is None and tools is None:
            return None

        # Generate cache key using cacheable prefix
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        if cache_key is None:
            return None

        # Perform cache lookup
        cache_result = await self.cache.async_get_cache(key=cache_key)
        return cache_result

    def get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """Sync variant of async_get_model_id."""
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # If no cacheable prefix found, return None (can't cache)
        if cache_key is None:
            return None

        return self.cache.get_cache(cache_key)
|
||||
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Helper functions to get/set num success and num failures per deployment
|
||||
|
||||
|
||||
set_deployment_failures_for_current_minute
|
||||
set_deployment_successes_for_current_minute
|
||||
|
||||
get_deployment_failures_for_current_minute
|
||||
get_deployment_successes_for_current_minute
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any

# Resolve the Router type only for static type-checkers; at runtime it
# degrades to Any so we avoid an import cycle with litellm.router.
if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
else:
    LitellmRouter = Any


def increment_deployment_successes_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> str:
    """
    In-Memory: Increments the number of successes for the current minute for a deployment_id

    Returns:
        The cache key used for the counter.
    """
    key = f"{deployment_id}:successes"
    # local_only + ttl=60 keeps this a cheap in-memory rolling-minute counter.
    litellm_router_instance.cache.increment_cache(
        local_only=True,
        key=key,
        value=1,
        ttl=60,
    )
    return key


def increment_deployment_failures_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> str:
    """
    In-Memory: Increments the number of failures for the current minute for a deployment_id

    Returns:
        The cache key used for the counter (now consistent with the
        successes counterpart, which already returned its key).
    """
    key = f"{deployment_id}:fails"
    litellm_router_instance.cache.increment_cache(
        local_only=True,
        key=key,
        value=1,
        ttl=60,
    )
    return key


def get_deployment_successes_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> int:
    """
    Returns the number of successes for the current minute for a deployment_id

    Returns 0 if no value found
    """
    key = f"{deployment_id}:successes"
    return (
        litellm_router_instance.cache.get_cache(
            local_only=True,
            key=key,
        )
        or 0
    )


def get_deployment_failures_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> int:
    """
    Returns the number of fails for the current minute for a deployment_id

    Returns 0 if no value found
    """
    key = f"{deployment_id}:fails"
    return (
        litellm_router_instance.cache.get_cache(
            local_only=True,
            key=key,
        )
        or 0
    )
|
||||
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Router utilities for Search API integration.
|
||||
|
||||
Handles search tool selection, load balancing, and fallback logic for search requests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import traceback
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
|
||||
|
||||
class SearchAPIRouter:
|
||||
"""
|
||||
Static utility class for routing search API calls through the LiteLLM router.
|
||||
|
||||
Provides methods for search tool selection, load balancing, and fallback handling.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def update_router_search_tools(router_instance: Any, search_tools: list):
|
||||
"""
|
||||
Update the router with search tools from the database.
|
||||
|
||||
This method is called by a cron job to sync search tools from DB to router.
|
||||
|
||||
Args:
|
||||
router_instance: The Router instance to update
|
||||
search_tools: List of search tool configurations from the database
|
||||
"""
|
||||
try:
|
||||
from litellm.types.router import SearchToolTypedDict
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"Adding {len(search_tools)} search tools to router"
|
||||
)
|
||||
|
||||
# Convert search tools to the format expected by the router
|
||||
router_search_tools: list = []
|
||||
for tool in search_tools:
|
||||
# Create dict that matches SearchToolTypedDict structure
|
||||
router_search_tool: SearchToolTypedDict = { # type: ignore
|
||||
"search_tool_id": tool.get("search_tool_id"),
|
||||
"search_tool_name": tool.get("search_tool_name"),
|
||||
"litellm_params": tool.get("litellm_params", {}),
|
||||
"search_tool_info": tool.get("search_tool_info"),
|
||||
}
|
||||
router_search_tools.append(router_search_tool)
|
||||
|
||||
# Update the router's search_tools list
|
||||
router_instance.search_tools = router_search_tools
|
||||
|
||||
verbose_router_logger.info(
|
||||
f"Successfully updated router with {len(router_search_tools)} search tool(s)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_router_logger.exception(
|
||||
f"Error updating router with search tools: {str(e)}"
|
||||
)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def get_matching_search_tools(
|
||||
router_instance: Any,
|
||||
search_tool_name: str,
|
||||
) -> list:
|
||||
"""
|
||||
Get all search tools matching the given name.
|
||||
|
||||
Args:
|
||||
router_instance: The Router instance
|
||||
search_tool_name: Name of the search tool to find
|
||||
|
||||
Returns:
|
||||
List of matching search tool configurations
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching search tools are found
|
||||
"""
|
||||
matching_tools = [
|
||||
tool
|
||||
for tool in router_instance.search_tools
|
||||
if tool.get("search_tool_name") == search_tool_name
|
||||
]
|
||||
|
||||
if not matching_tools:
|
||||
raise ValueError(
|
||||
f"Search tool '{search_tool_name}' not found in router.search_tools"
|
||||
)
|
||||
|
||||
return matching_tools
|
||||
|
||||
    @staticmethod
    async def async_search_with_fallbacks(
        router_instance: Any,
        original_function: Callable,
        **kwargs,
    ):
        """
        Helper function to make a search API call through the router with load balancing and fallbacks.
        Reuses the router's retry/fallback infrastructure.

        Args:
            router_instance: The Router instance
            original_function: The original litellm.asearch function
            **kwargs: Search parameters including search_tool_name, query, etc.

        Returns:
            SearchResponse from the search API

        Raises:
            ValueError: If neither search_tool_name nor model is provided.
            Exception: Re-raises whatever the underlying search call raised
                after scheduling an alert.
        """
        try:
            # Accept either `search_tool_name` or `model` as the tool selector.
            search_tool_name = kwargs.get("search_tool_name", kwargs.get("model"))

            if not search_tool_name:
                raise ValueError(
                    "search_tool_name or model parameter is required for search"
                )

            # Set up kwargs for the fallback system
            kwargs[
                "model"
            ] = search_tool_name  # Use model field for compatibility with fallback system
            kwargs["original_generic_function"] = original_function
            # Bind router_instance to the helper method using partial, so the
            # fallback machinery can call the helper with only call kwargs.
            kwargs["original_function"] = partial(
                SearchAPIRouter.async_search_with_fallbacks_helper,
                router_instance=router_instance,
            )

            # Update kwargs before fallbacks (for logging, metadata, etc)
            router_instance._update_kwargs_before_fallbacks(
                model=search_tool_name,
                kwargs=kwargs,
                metadata_variable_name="litellm_metadata",
            )

            available_search_tool_names = [
                tool.get("search_tool_name") for tool in router_instance.search_tools
            ]
            verbose_router_logger.debug(
                f"Inside SearchAPIRouter.async_search_with_fallbacks() - search_tool_name: {search_tool_name}, Available Search Tools: {available_search_tool_names}, kwargs: {kwargs}"
            )

            # Use the existing retry/fallback infrastructure
            response = await router_instance.async_function_with_fallbacks(**kwargs)
            return response

        except Exception as e:
            from litellm.router_utils.handle_error import send_llm_exception_alert

            # Fire the alert as a background task so the caller still sees
            # the original exception without waiting on alert delivery.
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=router_instance,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e
||||
|
||||
@staticmethod
|
||||
async def async_search_with_fallbacks_helper(
|
||||
router_instance: Any,
|
||||
model: str,
|
||||
original_generic_function: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Helper function for search API calls - selects a search tool and calls the original function.
|
||||
Called by async_function_with_fallbacks for each retry attempt.
|
||||
|
||||
Args:
|
||||
router_instance: The Router instance
|
||||
model: The search tool name (passed as model for compatibility)
|
||||
original_generic_function: The original litellm.asearch function
|
||||
**kwargs: Search parameters
|
||||
|
||||
Returns:
|
||||
SearchResponse from the selected search provider
|
||||
"""
|
||||
search_tool_name = model # model field contains the search_tool_name
|
||||
|
||||
try:
|
||||
# Find matching search tools
|
||||
matching_tools = SearchAPIRouter.get_matching_search_tools(
|
||||
router_instance=router_instance,
|
||||
search_tool_name=search_tool_name,
|
||||
)
|
||||
|
||||
# Simple random selection for load balancing across multiple providers with same name
|
||||
# For search tools, we use simple random choice since they don't have TPM/RPM constraints
|
||||
selected_tool = random.choice(matching_tools)
|
||||
|
||||
# Extract search provider and other params from litellm_params
|
||||
litellm_params = selected_tool.get("litellm_params", {})
|
||||
search_provider = litellm_params.get("search_provider")
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
|
||||
if not search_provider:
|
||||
raise ValueError(
|
||||
f"search_provider not found in litellm_params for search tool '{search_tool_name}'"
|
||||
)
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"Selected search tool with provider: {search_provider}"
|
||||
)
|
||||
|
||||
# Call the original search function with the provider config
|
||||
response = await original_generic_function(
|
||||
search_provider=search_provider,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(
|
||||
f"Error in SearchAPIRouter.async_search_with_fallbacks_helper for {search_tool_name}: {str(e)}"
|
||||
)
|
||||
raise e
|
||||
Reference in New Issue
Block a user