chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,136 @@
"""
Auto-Routing Strategy that works with a Semantic Router Config
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from litellm._logging import verbose_router_logger
from litellm.integrations.custom_logger import CustomLogger
if TYPE_CHECKING:
from semantic_router.routers.base import Route
from litellm.router import Router
from litellm.types.router import PreRoutingHookResponse
else:
Router = Any
PreRoutingHookResponse = Any
Route = Any
class AutoRouter(CustomLogger):
    """Routes requests to a deployment by semantic similarity of the last message.

    Wraps a `semantic_router.SemanticRouter` built from a JSON config and plugs
    into the LiteLLM Router via the pre-routing hook.
    """

    DEFAULT_AUTO_SYNC_VALUE = "local"

    def __init__(
        self,
        model_name: str,
        default_model: str,
        embedding_model: str,
        litellm_router_instance: "Router",
        auto_router_config_path: Optional[str] = None,
        auto_router_config: Optional[str] = None,
    ):
        """
        Auto-Router class that uses a semantic router to route requests to the appropriate model.

        Args:
            model_name: The name of the model to use for the auto-router. eg. if model = "auto-router1" then use this router.
            default_model: The default model to use if no route is found.
            embedding_model: The embedding model to use for the auto-router.
            litellm_router_instance: The instance of the LiteLLM Router.
            auto_router_config_path: The path to the router config file.
            auto_router_config: The config (JSON string) to use for the auto-router.
                You can either use this or auto_router_config_path, not both.
        """
        # Imported at runtime for the `Optional[SemanticRouter]` annotation below.
        from semantic_router.routers import SemanticRouter

        # Fix: `model_name` was previously accepted but never stored.
        self.model_name: str = model_name
        self.auto_router_config_path: Optional[str] = auto_router_config_path
        self.auto_router_config: Optional[str] = auto_router_config
        self.auto_sync_value = self.DEFAULT_AUTO_SYNC_VALUE
        self.loaded_routes: List[Route] = self._load_semantic_routing_routes()
        # Built lazily on first request (see async_pre_routing_hook) so the
        # embedding model is not called at startup.
        self.routelayer: Optional[SemanticRouter] = None
        self.default_model = default_model
        self.embedding_model: str = embedding_model
        self.litellm_router_instance: "Router" = litellm_router_instance

    def _load_semantic_routing_routes(self) -> List[Route]:
        """Load routes from the config path or the inline JSON config.

        Raises:
            ValueError: if neither config source was provided.
        """
        from semantic_router.routers import SemanticRouter

        if self.auto_router_config_path:
            return SemanticRouter.from_json(self.auto_router_config_path).routes
        elif self.auto_router_config:
            return self._load_auto_router_routes_from_config_json()
        else:
            raise ValueError("No router config provided")

    def _load_auto_router_routes_from_config_json(self) -> List[Route]:
        """Parse `self.auto_router_config` (a JSON string) into Route objects."""
        import json

        from semantic_router.routers.base import Route

        if self.auto_router_config is None:
            raise ValueError("No auto router config provided")

        loaded_config = json.loads(self.auto_router_config)
        return [
            Route(
                name=route.get("name"),
                description=route.get("description"),
                utterances=route.get("utterances", []),
                score_threshold=route.get("score_threshold"),
            )
            for route in loaded_config.get("routes", [])
        ]

    async def async_pre_routing_hook(
        self,
        model: str,
        request_kwargs: Dict,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
    ) -> Optional["PreRoutingHookResponse"]:
        """
        This hook is called before the routing decision is made.

        Used for the litellm auto-router to modify the request before the routing
        decision is made: the last message's content is matched against the
        configured routes and `model` is replaced by the best route's name (or
        `self.default_model` when the route has no name).
        """
        from semantic_router.routers import SemanticRouter
        from semantic_router.schema import RouteChoice

        from litellm.router_strategy.auto_router.litellm_encoder import (
            LiteLLMRouterEncoder,
        )
        from litellm.types.router import PreRoutingHookResponse

        # Fix: also bail on an empty list (previously `messages == []` crashed
        # on `messages[-1]` below). Returning None leaves the request untouched.
        if not messages:
            return None

        if self.routelayer is None:
            #######################
            # Create the route layer (lazy, once per AutoRouter instance)
            #######################
            self.routelayer = SemanticRouter(
                routes=self.loaded_routes,
                encoder=LiteLLMRouterEncoder(
                    litellm_router_instance=self.litellm_router_instance,
                    model_name=self.embedding_model,
                ),
                auto_sync=self.auto_sync_value,
            )

        # NOTE(review): assumes the last message carries the routing text and
        # that its "content" is a plain string - multimodal (list) content
        # would be passed through unchanged to the encoder; confirm upstream.
        user_message: Dict[str, str] = messages[-1]
        message_content: str = user_message.get("content", "")
        route_choice: Optional[Union[RouteChoice, List[RouteChoice]]] = self.routelayer(
            text=message_content
        )
        verbose_router_logger.debug(f"route_choice: {route_choice}")
        if isinstance(route_choice, RouteChoice):
            model = route_choice.name or self.default_model
        elif isinstance(route_choice, list) and route_choice:
            # Fix: an empty list previously raised IndexError; with no route
            # matched we keep the caller's `model`, same as the None case.
            model = route_choice[0].name or self.default_model

        return PreRoutingHookResponse(
            model=model,
            messages=messages,
        )

View File

@@ -0,0 +1,139 @@
from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import ConfigDict
from semantic_router.encoders import DenseEncoder
from semantic_router.encoders.base import AsymmetricDenseMixin
import litellm
if TYPE_CHECKING:
from litellm.router import Router
else:
Router = Any
def litellm_to_list(embeds: litellm.EmbeddingResponse) -> list[list[float]]:
    """Extract the raw embedding vectors from a LiteLLM embedding response.

    :param embeds: The LiteLLM embedding response.
    :return: One embedding (list of floats) per input document.
    """
    # A usable response must be truthy, of the expected type, and non-empty.
    if embeds and isinstance(embeds, litellm.EmbeddingResponse) and embeds.data:
        return [item["embedding"] for item in embeds.data]
    raise ValueError("No embeddings found in LiteLLM embedding response.")
class CustomDenseEncoder(DenseEncoder):
    """DenseEncoder that accepts and stores a `litellm_router_instance`.

    Pydantic `extra="allow"` lets subclasses attach arbitrary attributes.
    """

    model_config = ConfigDict(extra="allow")

    def __init__(self, litellm_router_instance: Optional["Router"] = None, **kwargs):
        # A router instance supplied via **kwargs wins over the positional
        # argument; pop it so the parent pydantic model never sees the field.
        router = kwargs.pop("litellm_router_instance", litellm_router_instance)
        super().__init__(**kwargs)
        self.litellm_router_instance = router
class LiteLLMRouterEncoder(CustomDenseEncoder, AsymmetricDenseMixin):
    """LiteLLM encoder class for generating embeddings using LiteLLM.

    The LiteLLMRouterEncoder class is a subclass of DenseEncoder and utilizes the LiteLLM Router SDK
    to generate embeddings for given documents. It supports all encoders supported by LiteLLM
    and supports customization of the score threshold for filtering or processing the embeddings.
    """

    type: str = "internal_litellm_router"

    def __init__(
        self,
        litellm_router_instance: "Router",
        model_name: str,
        score_threshold: Union[float, None] = None,
    ):
        """Initialize the LiteLLMEncoder.

        :param litellm_router_instance: The instance of the LiteLLM Router.
        :type litellm_router_instance: Router
        :param model_name: The name of the embedding model to use. Must use LiteLLM naming
            convention (e.g. "openai/text-embedding-3-small" or "mistral/mistral-embed").
        :type model_name: str
        :param score_threshold: The score threshold for the embeddings; defaults to 0.3.
        :type score_threshold: float
        """
        super().__init__(
            name=model_name,
            score_threshold=score_threshold if score_threshold is not None else 0.3,
        )
        self.model_name = model_name
        self.litellm_router_instance = litellm_router_instance

    def _embed(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Shared sync embedding path (queries and documents embed identically here)."""
        if self.litellm_router_instance is None:
            raise ValueError("litellm_router_instance is not set")
        try:
            embeds = self.litellm_router_instance.embedding(
                input=docs, model=self.model_name, **kwargs
            )
            return litellm_to_list(embeds)
        except Exception as e:
            raise ValueError(
                f"{self.type.capitalize()} API call failed. Error: {e}"
            ) from e

    async def _aembed(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Shared async embedding path (queries and documents embed identically here)."""
        if self.litellm_router_instance is None:
            raise ValueError("litellm_router_instance is not set")
        try:
            embeds = await self.litellm_router_instance.aembedding(
                input=docs, model=self.model_name, **kwargs
            )
            return litellm_to_list(embeds)
        except Exception as e:
            raise ValueError(
                f"{self.type.capitalize()} API call failed. Error: {e}"
            ) from e

    def __call__(self, docs: list[Any], **kwargs) -> list[list[float]]:
        """Encode a list of text documents into embeddings using LiteLLM.

        :param docs: List of text documents to encode.
        :return: List of embeddings for each document."""
        return self.encode_queries(docs, **kwargs)

    async def acall(self, docs: list[Any], **kwargs) -> list[list[float]]:
        """Encode a list of documents into embeddings using LiteLLM asynchronously.

        :param docs: List of documents to encode.
        :return: List of embeddings for each document."""
        return await self.aencode_queries(docs, **kwargs)

    # DRY fix: the four public encode methods previously duplicated the exact
    # same validation + try/except body; they now delegate to the helpers above.
    def encode_queries(self, docs: list[str], **kwargs) -> list[list[float]]:
        return self._embed(docs, **kwargs)

    def encode_documents(self, docs: list[str], **kwargs) -> list[list[float]]:
        return self._embed(docs, **kwargs)

    async def aencode_queries(self, docs: list[str], **kwargs) -> list[list[float]]:
        return await self._aembed(docs, **kwargs)

    async def aencode_documents(self, docs: list[str], **kwargs) -> list[list[float]]:
        return await self._aembed(docs, **kwargs)

View File

@@ -0,0 +1,261 @@
"""
Base class across routing strategies to abstract common functions like batch incrementing redis
"""
import asyncio
from abc import ABC
from typing import Dict, List, Optional, Set, Tuple, Union
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
from litellm.constants import DEFAULT_REDIS_SYNC_INTERVAL
class BaseRoutingStrategy(ABC):
    """Shared plumbing for routing strategies.

    Provides:
    - a queue of Redis increment operations, flushed periodically in one pipeline
    - a background task that reconciles the in-memory cache with Redis
    """

    def __init__(
        self,
        dual_cache: DualCache,
        should_batch_redis_writes: bool,
        default_sync_interval: Optional[Union[int, float]],
    ):
        """
        Args:
            dual_cache: combined in-memory + (optional) Redis cache.
            should_batch_redis_writes: when True, start the periodic sync task.
            default_sync_interval: seconds between syncs; falls back to
                DEFAULT_REDIS_SYNC_INTERVAL when None.
        """
        self.dual_cache = dual_cache
        # Pending increments, flushed to Redis in one pipeline per sync cycle.
        self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
        self._sync_task: Optional[asyncio.Task[None]] = None
        # Keys whose in-memory values must be reconciled with Redis.
        # Initialized before the sync task is started so the task never sees a
        # partially constructed instance.
        self.in_memory_keys_to_update: Set[str] = set()
        if should_batch_redis_writes:
            self.setup_sync_task(default_sync_interval)

    def setup_sync_task(self, default_sync_interval: Optional[Union[int, float]]):
        """Setup the sync task in a way that's compatible with FastAPI"""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (e.g. sync startup code) - create one so the
            # task can still be scheduled.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        self._sync_task = loop.create_task(
            self.periodic_sync_in_memory_spend_with_redis(
                default_sync_interval=default_sync_interval
            )
        )

    async def cleanup(self):
        """Cleanup method to be called when shutting down"""
        if self._sync_task is not None:
            self._sync_task.cancel()
            try:
                await self._sync_task
            except asyncio.CancelledError:
                pass

    async def _increment_value_list_in_current_window(
        self, increment_list: List[Tuple[str, int]], ttl: int
    ) -> List[float]:
        """
        Increment a list of (key, value) pairs in the current window.

        Returns the post-increment in-memory values, in input order.
        """
        results = []
        for key, value in increment_list:
            result = await self._increment_value_in_current_window(
                key=key, value=value, ttl=ttl
            )
            results.append(result)
        return results

    async def _increment_value_in_current_window(
        self, key: str, value: Union[int, float], ttl: int
    ):
        """
        Increment spend within existing budget window

        Runs once the budget start time exists in Redis Cache (on the 2nd and
        subsequent requests to the same provider):
        - Increments the spend in memory cache (so spend instantly updated in memory)
        - Queues the increment operation to Redis Pipeline (batched for performance;
          Redis keeps multi-instance LiteLLM deployments in sync)
        """
        result = await self.dual_cache.in_memory_cache.async_increment(
            key=key,
            value=value,
            ttl=ttl,
        )
        increment_op = RedisPipelineIncrementOperation(
            key=key,
            increment_value=value,
            ttl=ttl,
        )
        self.redis_increment_operation_queue.append(increment_op)
        # Remember the key so the periodic sync pulls its Redis value back.
        self.add_to_in_memory_keys_to_update(key=key)
        return result

    async def periodic_sync_in_memory_spend_with_redis(
        self, default_sync_interval: Optional[Union[int, float]]
    ):
        """
        Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds

        Required for multi-instance environment usage of provider budgets.
        Runs forever; errors are logged and the loop continues.
        """
        default_sync_interval = default_sync_interval or DEFAULT_REDIS_SYNC_INTERVAL
        while True:
            try:
                await self._sync_in_memory_spend_with_redis()
                await asyncio.sleep(
                    default_sync_interval
                )  # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync
            except Exception as e:
                verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
                await asyncio.sleep(
                    default_sync_interval
                )  # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying

    async def _push_in_memory_increments_to_redis(self):
        """
        Flush queued increments to Redis in one batched pipeline.

        - Compresses multiple increments for the same key into a single operation
        - Pushes the compressed batch via the Redis pipeline

        Only runs if Redis is initialized.

        Returns:
            Dict mapping key -> post-increment Redis value when the pipeline
            returned results, {} when it returned None, and None when there was
            nothing to push or the push failed.
        """
        try:
            if not self.dual_cache.redis_cache:
                return  # Redis is not initialized
            if len(self.redis_increment_operation_queue) == 0:
                return
            # BUGFIX: snapshot + reset the queue up front (atomic on the
            # single-threaded event loop). The old code only removed the first
            # occurrence of each key from the queue, so duplicate ops - whose
            # values were already folded into the pushed total - survived and
            # were pushed again on the next cycle, double-counting spend.
            # Resetting first also preserves ops enqueued during the await.
            pending_ops = self.redis_increment_operation_queue
            self.redis_increment_operation_queue = []
            # Compress operations for the same key into one increment
            compressed_ops: Dict[str, RedisPipelineIncrementOperation] = {}
            for op in pending_ops:
                existing = compressed_ops.get(op["key"])
                if existing is not None:
                    existing["increment_value"] += op["increment_value"]
                else:
                    compressed_ops[op["key"]] = op
            compressed_queue = list(compressed_ops.values())
            increment_result = (
                await self.dual_cache.redis_cache.async_increment_pipeline(
                    increment_list=compressed_queue,
                )
            )
            if increment_result is not None:
                return {
                    op["key"]: new_value
                    for op, new_value in zip(compressed_queue, increment_result)
                }
            return {}
        except Exception as e:
            verbose_router_logger.error(
                f"Error syncing in-memory cache with Redis: {str(e)}"
            )
            # Drop pending ops on failure, matching previous best-effort behavior.
            self.redis_increment_operation_queue = []

    def add_to_in_memory_keys_to_update(self, key: str):
        """Mark `key` as needing reconciliation on the next sync cycle."""
        self.in_memory_keys_to_update.add(key)

    def get_key_pattern_to_sync(self) -> Optional[str]:
        """
        Get the key pattern to sync

        Subclasses may override; None means "use the tracked key set instead".
        """
        return None

    def get_in_memory_keys_to_update(self) -> Set[str]:
        """Return the set of keys awaiting reconciliation (without clearing it)."""
        return self.in_memory_keys_to_update

    def get_and_reset_in_memory_keys_to_update(self) -> Set[str]:
        """Atomic get and reset in-memory keys to update"""
        keys = self.in_memory_keys_to_update
        self.in_memory_keys_to_update = set()
        return keys

    def reset_in_memory_keys_to_update(self):
        """Clear the set of keys awaiting reconciliation."""
        self.in_memory_keys_to_update = set()

    async def _sync_in_memory_spend_with_redis(self):
        """
        Ensures in-memory cache is updated with latest Redis values for all provider spends.

        Why do we need this?
        - Optimization to hit sub 100ms latency. Performance was impacted when
          Redis was used for read/write per request
        - To use provider budgets in a multi-instance environment, Redis syncs
          spend across all instances

        What this does:
        1. Snapshot the in-memory value of each tracked key
        2. Push all queued increments to Redis (gets back Redis totals)
        3. Merge: Redis total + any spend that arrived locally during the push
        """
        try:
            # No need to sync if Redis cache is not initialized
            if self.dual_cache.redis_cache is None:
                return
            cache_keys = (
                self.get_in_memory_keys_to_update()
            )  # if no pattern OR redis cache does not support scan_iter, use in-memory keys
            cache_keys_list = list(cache_keys)
            # 1. Snapshot in-memory values before pushing
            in_memory_before_dict = {}
            in_memory_before = (
                await self.dual_cache.in_memory_cache.async_batch_get_cache(
                    keys=cache_keys_list
                )
            )
            for k, v in zip(cache_keys_list, in_memory_before):
                in_memory_before_dict[k] = float(v or 0)
            # 2. Push queued increments to Redis
            redis_values = await self._push_in_memory_increments_to_redis()
            if redis_values is None:
                return
            # 3. Merge Redis totals with spend that arrived during the push
            for key in cache_keys_list:
                redis_val = float(redis_values.get(key, 0) or 0)
                before = float(in_memory_before_dict.get(key, 0) or 0)
                after = float(
                    await self.dual_cache.in_memory_cache.async_get_cache(key=key) or 0
                )
                # `delta` is spend recorded locally while the push was in flight.
                delta = after - before
                if after <= redis_val:
                    merged = redis_val + delta
                else:
                    # In-memory is ahead of Redis (increments not yet flushed);
                    # leave the local value untouched until the next cycle.
                    continue
                await self.dual_cache.in_memory_cache.async_set_cache(
                    key=key, value=merged
                )
        except Exception as e:
            verbose_router_logger.exception(
                f"Error syncing in-memory cache with Redis: {str(e)}"
            )

View File

@@ -0,0 +1,898 @@
"""
Provider budget limiting
Use this if you want to set $ budget limits for each provider.
Note: This is a filter, like tag-routing. Meaning it will accept healthy deployments and then filter out deployments that have exceeded their budget limit.
This means you can use this with weighted-pick, lowest-latency, simple-shuffle, routing etc
Example:
```
openai:
budget_limit: 0.000000000001
time_period: 1d
anthropic:
budget_limit: 100
time_period: 7d
```
"""
import asyncio
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple, Union
import litellm
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs
from litellm.router_utils.cooldown_callbacks import (
_get_prometheus_logger_from_callbacks,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params, RouterErrors
from litellm.types.utils import BudgetConfig
from litellm.types.utils import BudgetConfig as GenericBudgetInfo
from litellm.types.utils import GenericBudgetConfigType, StandardLoggingPayload
DEFAULT_REDIS_SYNC_INTERVAL = 1
class _LiteLLMParamsDictView:
"""
Lightweight attribute view over `litellm_params` dict.
This avoids pydantic construction in request hot-path while preserving
attribute-style access used by `litellm.get_llm_provider(...)`.
"""
__slots__ = ("_params",)
def __init__(self, params: Dict[str, Any]):
self._params = params
def __getattr__(self, key: str) -> Any:
return self._params.get(key)
def __getitem__(self, key: str) -> Any:
return self._params.get(key)
def __contains__(self, key: str) -> bool:
return key in self._params
def get(self, key: str, default: Any = None) -> Any:
return self._params.get(key, default)
def keys(self):
return self._params.keys()
def values(self):
return self._params.values()
def items(self):
return self._params.items()
def __iter__(self):
return iter(self._params)
def __len__(self) -> int:
return len(self._params)
def dict(self) -> Dict[str, Any]:
return dict(self._params)
def model_dump(self) -> Dict[str, Any]:
return dict(self._params)
class RouterBudgetLimiting(CustomLogger):
def __init__(
    self,
    dual_cache: DualCache,
    provider_budget_config: Optional[dict],
    model_list: Optional[
        Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
    ] = None,
):
    """
    Args:
        dual_cache: combined in-memory + (optional) Redis cache used for spend tracking.
        provider_budget_config: mapping of provider -> budget config (see module docstring).
        model_list: router model list, scanned for per-deployment budget configs.
    """
    self.dual_cache = dual_cache
    self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
    # Background task that keeps in-memory spend in sync with Redis.
    # FIX: guard against being constructed outside a running event loop
    # (e.g. sync startup code or tests) where create_task raises RuntimeError;
    # mirrors BaseRoutingStrategy.setup_sync_task's handling.
    try:
        asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
    except RuntimeError:
        verbose_router_logger.debug(
            "RouterBudgetLimiting: no running event loop; skipping periodic Redis sync task"
        )
    self.provider_budget_config: Optional[
        GenericBudgetConfigType
    ] = provider_budget_config
    self.deployment_budget_config: Optional[GenericBudgetConfigType] = None
    self.tag_budget_config: Optional[GenericBudgetConfigType] = None
    self._init_provider_budgets()
    self._init_deployment_budgets(model_list=model_list)
    self._init_tag_budgets()
    # Add self to litellm callbacks if it's a list
    if isinstance(litellm.callbacks, list):
        litellm.logging_callback_manager.add_litellm_callback(self)  # type: ignore
async def async_filter_deployments(
    self,
    model: str,
    healthy_deployments: List,
    messages: Optional[List[AllMessageValues]],
    request_kwargs: Optional[dict] = None,
    parent_otel_span: Optional[Span] = None,  # type: ignore
) -> List[dict]:
    """
    Filter out deployments that have exceeded their provider budget limit.

    Example:
    if deployment = openai/gpt-3.5-turbo
    and openai spend > openai budget limit
    then skip this deployment

    Args:
        model: model group name (not used by this filter; part of the hook signature).
        healthy_deployments: deployments that passed earlier checks; a single
            dict is also accepted and normalized to a one-item list.
        messages: request messages (not used by this filter; part of the hook signature).
        request_kwargs: original request kwargs; used to extract request tags.
        parent_otel_span: tracing span forwarded to the batched cache read.

    Returns:
        Deployments still within provider/deployment/tag budgets.

    Raises:
        ValueError: when every deployment is over budget.
    """
    # If a single deployment is passed, convert it to a list
    if isinstance(healthy_deployments, dict):
        healthy_deployments = [healthy_deployments]
    # Don't do any filtering if there are no healthy deployments
    if len(healthy_deployments) == 0:
        return healthy_deployments
    potential_deployments: List[Dict] = []
    # Collect every spend cache key up-front so all spend values can be
    # fetched in a single batched cache read below.
    (
        cache_keys,
        provider_configs,
        deployment_configs,
        deployment_providers,
    ) = await self._async_get_cache_keys_for_router_budget_limiting(
        healthy_deployments=healthy_deployments,
        request_kwargs=request_kwargs,
    )
    # Single cache read for all spend values
    if len(cache_keys) > 0:
        _current_spends = await self.dual_cache.async_batch_get_cache(
            keys=cache_keys,
            parent_otel_span=parent_otel_span,
        )
        # Missing cache entries are treated as zero spend.
        current_spends: List = _current_spends or [0.0] * len(cache_keys)
        # Map spends to their respective keys
        spend_map: Dict[str, float] = {}
        for idx, key in enumerate(cache_keys):
            spend_map[key] = float(current_spends[idx] or 0.0)
        (
            potential_deployments,
            deployment_above_budget_info,
        ) = self._filter_out_deployments_above_budget(
            healthy_deployments=healthy_deployments,
            provider_configs=provider_configs,
            deployment_configs=deployment_configs,
            deployment_providers=deployment_providers,
            spend_map=spend_map,
            potential_deployments=potential_deployments,
            request_tags=_get_tags_from_request_kwargs(
                request_kwargs=request_kwargs
            ),
        )
        if len(potential_deployments) == 0:
            raise ValueError(
                f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
            )
        return potential_deployments
    else:
        # No budget keys to check -> nothing to filter.
        return healthy_deployments
def _filter_out_deployments_above_budget(
    self,
    potential_deployments: List[Dict[str, Any]],
    healthy_deployments: List[Dict[str, Any]],
    provider_configs: Dict[str, GenericBudgetInfo],
    deployment_configs: Dict[str, GenericBudgetInfo],
    deployment_providers: List[Optional[str]],
    spend_map: Dict[str, float],
    request_tags: List[str],
) -> Tuple[List[Dict[str, Any]], str]:
    """
    Filter out deployments that have exceeded their budget limit.

    The following budget checks are run here:
    - Provider budget
    - Deployment budget
    - Request tags budget

    Returns:
        Tuple[List[Dict[str, Any]], str]:
            - A tuple containing the filtered deployments
            - A string containing debug information about deployments that exceeded their budget limit.
    """
    # Filter deployments based on both provider and deployment budgets
    deployment_above_budget_info: str = ""
    for idx, deployment in enumerate(healthy_deployments):
        is_within_budget = True
        # Check provider budget
        if self.provider_budget_config:
            # Prefer the provider resolved during cache-key collection;
            # fall back to resolving it from the deployment itself.
            if idx < len(deployment_providers):
                provider = deployment_providers[idx]
            else:
                provider = self._get_llm_provider_for_deployment(deployment)
            if provider in provider_configs:
                config = provider_configs[provider]
                # BUGFIX: a provider config without `max_budget` previously
                # `continue`d the deployment loop, which skipped the final
                # append and silently dropped the deployment. "No limit"
                # means nothing to enforce - fall through to the next checks.
                if config.max_budget is not None:
                    current_spend = spend_map.get(
                        f"provider_spend:{provider}:{config.budget_duration}", 0.0
                    )
                    self._track_provider_remaining_budget_prometheus(
                        provider=provider,
                        spend=current_spend,
                        budget_limit=config.max_budget,
                    )
                    if config.max_budget and current_spend >= config.max_budget:
                        debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.max_budget}"
                        deployment_above_budget_info += f"{debug_msg}\n"
                        is_within_budget = False
                        continue
        # Check deployment budget
        if self.deployment_budget_config and is_within_budget:
            _model_name = deployment.get("model_name")
            _litellm_params = deployment.get("litellm_params") or {}
            _litellm_model_name = _litellm_params.get("model")
            model_id = deployment.get("model_info", {}).get("id")
            if model_id in deployment_configs:
                config = deployment_configs[model_id]
                current_spend = spend_map.get(
                    f"deployment_spend:{model_id}:{config.budget_duration}", 0.0
                )
                if config.max_budget and current_spend >= config.max_budget:
                    # BUGFIX: message previously compared spend against
                    # `config.budget_duration` instead of `config.max_budget`.
                    debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.max_budget}"
                    verbose_router_logger.debug(debug_msg)
                    deployment_above_budget_info += f"{debug_msg}\n"
                    is_within_budget = False
                    continue
        # Check tag budget
        if self.tag_budget_config and is_within_budget:
            for _tag in request_tags:
                _tag_budget_config = self._get_budget_config_for_tag(_tag)
                if _tag_budget_config:
                    _tag_spend = spend_map.get(
                        f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}",
                        0.0,
                    )
                    if (
                        _tag_budget_config.max_budget
                        and _tag_spend >= _tag_budget_config.max_budget
                    ):
                        debug_msg = f"Exceeded budget for tag='{_tag}', tag_spend={_tag_spend}, tag_budget_limit={_tag_budget_config.max_budget}"
                        verbose_router_logger.debug(debug_msg)
                        deployment_above_budget_info += f"{debug_msg}\n"
                        is_within_budget = False
                        continue
        if is_within_budget:
            potential_deployments.append(deployment)
    return potential_deployments, deployment_above_budget_info
async def _async_get_cache_keys_for_router_budget_limiting(
    self,
    healthy_deployments: List[Dict[str, Any]],
    request_kwargs: Optional[Dict] = None,
) -> Tuple[
    List[str],
    Dict[str, GenericBudgetInfo],
    Dict[str, GenericBudgetInfo],
    List[Optional[str]],
]:
    """
    Returns list of cache keys to fetch from router cache for budget limiting and provider and deployment configs

    Returns:
        Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo], List[Optional[str]]]:
            - List of cache keys to fetch from router cache for budget limiting (de-duplicated, order preserved)
            - Dict of provider budget configs `provider_configs`
            - Dict of deployment budget configs `deployment_configs`
            - List of resolved providers aligned by deployment index `deployment_providers`
    """
    cache_keys: List[str] = []
    provider_configs: Dict[str, GenericBudgetInfo] = {}
    deployment_configs: Dict[str, GenericBudgetInfo] = {}
    deployment_providers: List[Optional[str]] = []
    for deployment in healthy_deployments:
        # Check provider budgets
        if self.provider_budget_config:
            provider = self._get_llm_provider_for_deployment(deployment)
            deployment_providers.append(provider)
            if provider is not None:
                budget_config = self._get_budget_config_for_provider(provider)
                if (
                    budget_config is not None
                    and budget_config.budget_duration is not None
                ):
                    provider_configs[provider] = budget_config
                    cache_keys.append(
                        f"provider_spend:{provider}:{budget_config.budget_duration}"
                    )
        # Check deployment budgets
        if self.deployment_budget_config:
            model_id = deployment.get("model_info", {}).get("id")
            if model_id is not None:
                budget_config = self._get_budget_config_for_deployment(model_id)
                if budget_config is not None:
                    deployment_configs[model_id] = budget_config
                    cache_keys.append(
                        f"deployment_spend:{model_id}:{budget_config.budget_duration}"
                    )
    # Check tag budgets.
    # FIX: tag keys depend only on the request, not on any deployment - they
    # were previously recomputed and re-appended once per deployment, producing
    # duplicate cache keys in the batched read.
    if self.tag_budget_config:
        request_tags = _get_tags_from_request_kwargs(
            request_kwargs=request_kwargs
        )
        for _tag in request_tags:
            _tag_budget_config = self._get_budget_config_for_tag(_tag)
            if _tag_budget_config:
                cache_keys.append(
                    f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
                )
    # De-duplicate while preserving order: several deployments can share a
    # provider, and batch-reading the same key twice is wasted work.
    cache_keys = list(dict.fromkeys(cache_keys))
    return (
        cache_keys,
        provider_configs,
        deployment_configs,
        deployment_providers,
    )
async def _get_or_set_budget_start_time(
    self, start_time_key: str, current_time: float, ttl_seconds: int
) -> float:
    """Return the stored budget-window start time, seeding it on first use.

    Looks up `start_time_key` (e.g. `provider_budget_start_time:{provider}`)
    in the dual cache; when absent, writes `current_time` under that key with
    `ttl_seconds` and returns it. Otherwise returns the cached value as float.
    """
    existing_start = await self.dual_cache.async_get_cache(start_time_key)
    if existing_start is not None:
        return float(existing_start)
    await self.dual_cache.async_set_cache(
        key=start_time_key, value=current_time, ttl=ttl_seconds
    )
    return current_time
async def _handle_new_budget_window(
    self,
    spend_key: str,
    start_time_key: str,
    current_time: float,
    response_cost: float,
    ttl_seconds: int,
) -> float:
    """
    Open a fresh budget window: reset spend and stamp the window start time.

    Entered when:
    - the budget does not exist in cache yet, or
    - the previous budget window expired.

    Writes two cache entries that expire together after `ttl_seconds`:
    - `spend_key` (e.g. `provider_spend:{provider}:1d`) seeded with this
      request's `response_cost`
    - `start_time_key` (e.g. `provider_budget_start_time:{provider}`) set to
      `current_time`, marking the start of the new window
    """
    for cache_key, cache_value in (
        (spend_key, response_cost),
        (start_time_key, current_time),
    ):
        await self.dual_cache.async_set_cache(
            key=cache_key, value=cache_value, ttl=ttl_seconds
        )
    return current_time
async def _increment_spend_in_current_window(
    self, spend_key: str, response_cost: float, ttl: int
):
    """
    Record spend inside an already-open budget window.

    Runs once the budget start time exists in Redis Cache (on the 2nd and
    subsequent requests to the same provider):
    - bumps the in-memory cache immediately so local reads see the new spend
    - enqueues the same increment for the batched Redis pipeline, which keeps
      other LiteLLM instances in sync without a per-request Redis round-trip
    """
    await self.dual_cache.in_memory_cache.async_increment(
        key=spend_key, value=response_cost, ttl=ttl
    )
    queued_op = RedisPipelineIncrementOperation(
        key=spend_key, increment_value=response_cost, ttl=ttl
    )
    self.redis_increment_operation_queue.append(queued_op)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Record the cost of a successful request against every applicable budget.

    Increments spend, in order, for:
    - the provider budget (if configured for the request's provider)
    - the deployment budget (if configured for this model_id)
    - each request tag's budget (if configured)

    Raises:
        ValueError: when the standard logging payload or custom_llm_provider is
            missing from kwargs - both are required to attribute spend.
    """
    verbose_router_logger.debug("in RouterBudgetLimiting.async_log_success_event")
    standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
        "standard_logging_object", None
    )
    if standard_logging_payload is None:
        raise ValueError("standard_logging_payload is required")
    response_cost: float = standard_logging_payload.get("response_cost", 0)
    model_id: str = str(standard_logging_payload.get("model_id", ""))
    custom_llm_provider: str = kwargs.get("litellm_params", {}).get(
        "custom_llm_provider", None
    )
    if custom_llm_provider is None:
        raise ValueError("custom_llm_provider is required")
    budget_config = self._get_budget_config_for_provider(custom_llm_provider)
    if budget_config:
        # increment spend for provider
        spend_key = (
            f"provider_spend:{custom_llm_provider}:{budget_config.budget_duration}"
        )
        start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
        await self._increment_spend_for_key(
            budget_config=budget_config,
            spend_key=spend_key,
            start_time_key=start_time_key,
            response_cost=response_cost,
        )
    deployment_budget_config = self._get_budget_config_for_deployment(model_id)
    if deployment_budget_config:
        # increment spend for specific deployment id
        deployment_spend_key = f"deployment_spend:{model_id}:{deployment_budget_config.budget_duration}"
        deployment_start_time_key = f"deployment_budget_start_time:{model_id}"
        await self._increment_spend_for_key(
            budget_config=deployment_budget_config,
            spend_key=deployment_spend_key,
            start_time_key=deployment_start_time_key,
            response_cost=response_cost,
        )
    # Tag budgets: one increment per configured tag on the request.
    request_tags = _get_tags_from_request_kwargs(kwargs)
    if len(request_tags) > 0:
        for _tag in request_tags:
            _tag_budget_config = self._get_budget_config_for_tag(_tag)
            if _tag_budget_config:
                _tag_spend_key = (
                    f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
                )
                _tag_start_time_key = f"tag_budget_start_time:{_tag}"
                await self._increment_spend_for_key(
                    budget_config=_tag_budget_config,
                    spend_key=_tag_spend_key,
                    start_time_key=_tag_start_time_key,
                    response_cost=response_cost,
                )
async def _increment_spend_for_key(
    self,
    budget_config: GenericBudgetInfo,
    spend_key: str,
    start_time_key: str,
    response_cost: float,
):
    """
    Add ``response_cost`` to the spend tracked under ``spend_key``, respecting
    the budget window.

    Three cases:
    - No window started yet -> start a new window and record the first spend.
    - Window expired -> reset spend/start-time, record this spend in the new window.
    - Window active -> increment spend, TTL aligned to the window's remaining time.

    No-op when the budget config has no duration (nothing to window against).
    """
    if budget_config.budget_duration is None:
        return
    current_time = datetime.now(timezone.utc).timestamp()
    ttl_seconds = duration_in_seconds(budget_config.budget_duration)
    budget_start = await self._get_or_set_budget_start_time(
        start_time_key=start_time_key,
        current_time=current_time,
        ttl_seconds=ttl_seconds,
    )

    if budget_start is None:
        # First spend for this provider
        budget_start = await self._handle_new_budget_window(
            spend_key=spend_key,
            start_time_key=start_time_key,
            current_time=current_time,
            response_cost=response_cost,
            ttl_seconds=ttl_seconds,
        )
    elif (current_time - budget_start) > ttl_seconds:
        # Budget window expired - reset everything
        verbose_router_logger.debug("Budget window expired - resetting everything")
        budget_start = await self._handle_new_budget_window(
            spend_key=spend_key,
            start_time_key=start_time_key,
            current_time=current_time,
            response_cost=response_cost,
            ttl_seconds=ttl_seconds,
        )
    else:
        # Within existing window - increment spend.
        # The TTL shrinks to the time left in the window so the spend key
        # expires exactly when the window does.
        remaining_time = ttl_seconds - (current_time - budget_start)
        ttl_for_increment = int(remaining_time)

        await self._increment_spend_in_current_window(
            spend_key=spend_key, response_cost=response_cost, ttl=ttl_for_increment
        )

    verbose_router_logger.debug(
        f"Incremented spend for {spend_key} by {response_cost}"
    )
async def periodic_sync_in_memory_spend_with_redis(self):
    """
    Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds

    Required for multi-instance environment usage of provider budgets

    Runs forever; callers are expected to schedule this as a background task.
    Errors are logged and the loop retries after the same interval.
    """
    while True:
        try:
            await self._sync_in_memory_spend_with_redis()
            await asyncio.sleep(
                DEFAULT_REDIS_SYNC_INTERVAL
            )  # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync
        except Exception as e:
            verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
            await asyncio.sleep(
                DEFAULT_REDIS_SYNC_INTERVAL
            )  # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying
async def _push_in_memory_increments_to_redis(self):
    """
    How this works:
    - async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
    - This function pushes all increments to Redis in a batched pipeline to optimize performance

    Only runs if Redis is initialized
    """
    try:
        if not self.dual_cache.redis_cache:
            return  # Redis is not initialized

        verbose_router_logger.debug(
            "Pushing Redis Increment Pipeline for queue: %s",
            self.redis_increment_operation_queue,
        )
        if len(self.redis_increment_operation_queue) > 0:
            # The queue is re-bound below (not mutated), so the task keeps a
            # reference to the old list and new increments land in a fresh one.
            # NOTE(review): the task reference is not retained, so it could be
            # GC'd before completing - confirm whether drops are acceptable here.
            asyncio.create_task(
                self.dual_cache.redis_cache.async_increment_pipeline(
                    increment_list=self.redis_increment_operation_queue,
                )
            )

            self.redis_increment_operation_queue = []

    except Exception as e:
        verbose_router_logger.error(
            f"Error syncing in-memory cache with Redis: {str(e)}"
        )
async def _sync_in_memory_spend_with_redis(self):
    """
    Ensures in-memory cache is updated with latest Redis values for all provider spends.

    Why Do we need this?
    - Optimization to hit sub 100ms latency. Performance was impacted when redis was used for read/write per request
    - Use provider budgets in multi-instance environment, we use Redis to sync spend across all instances

    What this does:
    1. Push all provider spend increments to Redis
    2. Fetch all current provider spend from Redis to update in-memory cache
    """
    try:
        # No need to sync if Redis cache is not initialized
        if self.dual_cache.redis_cache is None:
            return

        # 1. Push all provider spend increments to Redis
        await self._push_in_memory_increments_to_redis()

        # 2. Fetch all current provider spend from Redis to update in-memory cache
        # Collect every spend key we track: provider-, deployment-, and tag-scoped.
        cache_keys = []

        if self.provider_budget_config is not None:
            for provider, config in self.provider_budget_config.items():
                if config is None:
                    continue
                cache_keys.append(
                    f"provider_spend:{provider}:{config.budget_duration}"
                )

        if self.deployment_budget_config is not None:
            for model_id, config in self.deployment_budget_config.items():
                if config is None:
                    continue
                cache_keys.append(
                    f"deployment_spend:{model_id}:{config.budget_duration}"
                )

        if self.tag_budget_config is not None:
            for tag, config in self.tag_budget_config.items():
                if config is None:
                    continue
                cache_keys.append(f"tag_spend:{tag}:{config.budget_duration}")

        # Batch fetch current spend values from Redis
        redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
            key_list=cache_keys
        )

        # Update in-memory cache with Redis values
        if isinstance(redis_values, dict):  # Check if redis_values is a dictionary
            for key, value in redis_values.items():
                if value is not None:
                    await self.dual_cache.in_memory_cache.async_set_cache(
                        key=key, value=float(value)
                    )
                    verbose_router_logger.debug(
                        f"Updated in-memory cache for {key}: {value}"
                    )

    except Exception as e:
        verbose_router_logger.error(
            f"Error syncing in-memory cache with Redis: {str(e)}"
        )
def _get_budget_config_for_deployment(
self,
model_id: str,
) -> Optional[GenericBudgetInfo]:
if self.deployment_budget_config is None:
return None
return self.deployment_budget_config.get(model_id, None)
def _get_budget_config_for_provider(
self, provider: str
) -> Optional[GenericBudgetInfo]:
if self.provider_budget_config is None:
return None
return self.provider_budget_config.get(provider, None)
def _get_budget_config_for_tag(self, tag: str) -> Optional[GenericBudgetInfo]:
if self.tag_budget_config is None:
return None
return self.tag_budget_config.get(tag, None)
def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]:
    """
    Resolve the custom LLM provider (e.g. "openai") for a deployment dict.

    Accepts ``litellm_params`` as either a typed ``LiteLLM_Params`` object or a
    plain dict (wrapped in ``_LiteLLMParamsDictView`` for provider resolution).

    Returns:
        The provider name, or None (after logging) if resolution fails.
    """
    try:
        deployment_litellm_params = deployment.get("litellm_params") or {}
        if isinstance(deployment_litellm_params, LiteLLM_Params):
            model = deployment_litellm_params.model or ""
            provider_resolution_params: Any = deployment_litellm_params
        elif isinstance(deployment_litellm_params, dict):
            model = deployment_litellm_params.get("model") or ""
            provider_resolution_params = _LiteLLMParamsDictView(
                deployment_litellm_params
            )
        else:
            # Unknown container type: resolve with empty params; lookup will
            # most likely raise and we fall through to the except below.
            model = ""
            provider_resolution_params = _LiteLLMParamsDictView({})
        _, custom_llm_provider, _, _ = litellm.get_llm_provider(
            model=str(model),
            litellm_params=provider_resolution_params,
        )
    except Exception:
        verbose_router_logger.error(
            f"Error getting LLM provider for deployment: {deployment}"
        )
        return None
    return custom_llm_provider
def _track_provider_remaining_budget_prometheus(
    self, provider: str, spend: float, budget_limit: float
):
    """
    Optional helper - emit provider remaining budget metric to Prometheus.

    This is helpful for debugging and monitoring provider budget limits.
    No-op when no Prometheus logger is registered in the callbacks.
    """
    prometheus_logger = _get_prometheus_logger_from_callbacks()
    if not prometheus_logger:
        return
    prometheus_logger.track_provider_remaining_budget(
        provider=provider,
        spend=spend,
        budget_limit=budget_limit,
    )
async def _get_current_provider_spend(self, provider: str) -> Optional[float]:
    """
    GET the current spend for a provider from cache

    used for GET /provider/budgets endpoint in spend_management_endpoints.py

    Args:
        provider (str): The provider to get spend for (e.g., "openai", "anthropic")

    Returns:
        Optional[float]: The current spend for the provider. None when the
        provider has no budget config; 0.0 when a config exists but no spend
        has been recorded yet.
    """
    budget_config = self._get_budget_config_for_provider(provider)
    if budget_config is None:
        return None

    spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"

    if self.dual_cache.redis_cache:
        # use Redis as source of truth since that has spend across all instances
        current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key)
    else:
        # use in-memory cache if Redis is not initialized
        current_spend = await self.dual_cache.async_get_cache(spend_key)
    return float(current_spend) if current_spend is not None else 0.0
async def _get_current_provider_budget_reset_at(
    self, provider: str
) -> Optional[str]:
    """
    Return an ISO-8601 UTC timestamp for when the provider's budget window
    resets, derived from the spend key's remaining TTL.

    Returns None if the provider has no budget config or the key has no TTL.
    """
    budget_config = self._get_budget_config_for_provider(provider)
    if budget_config is None:
        return None

    spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
    if self.dual_cache.redis_cache:
        # Prefer Redis TTL - authoritative across all instances.
        ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key)
    else:
        ttl_seconds = await self.dual_cache.async_get_ttl(spend_key)

    if ttl_seconds is None:
        return None

    return (datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)).isoformat()
async def _init_provider_budget_in_cache(
    self, provider: str, budget_config: GenericBudgetInfo
):
    """
    Initialize provider budget in cache by storing the following keys if they don't exist:
    - provider_spend:{provider}:{budget_config.budget_duration} - stores the current spend
    - provider_budget_start_time:{provider} - stores the start time of the budget window

    Existing values are left untouched so a restart does not reset live budgets.
    """
    spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
    start_time_key = f"provider_budget_start_time:{provider}"
    ttl_seconds: Optional[int] = None
    if budget_config.budget_duration is not None:
        ttl_seconds = duration_in_seconds(budget_config.budget_duration)

    budget_start = await self.dual_cache.async_get_cache(start_time_key)
    if budget_start is None:
        budget_start = datetime.now(timezone.utc).timestamp()
        await self.dual_cache.async_set_cache(
            key=start_time_key, value=budget_start, ttl=ttl_seconds
        )

    # Renamed from `_spend_key`: this variable holds the cached spend *value*,
    # not a cache key - the old name was misleading.
    existing_spend = await self.dual_cache.async_get_cache(spend_key)
    if existing_spend is None:
        await self.dual_cache.async_set_cache(
            key=spend_key, value=0.0, ttl=ttl_seconds
        )
@staticmethod
def should_init_router_budget_limiter(
    provider_budget_config: Optional[dict],
    model_list: Optional[
        Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
    ] = None,
):
    """
    Returns `True` if the router budget routing settings are set and RouterBudgetLimiting should be initialized

    Either:
    - provider_budget_config is set
    - budgets are set for deployments in the model_list
    - tag_budget_config is set
    """
    if provider_budget_config is not None:
        return True
    if litellm.tag_budget_config is not None:
        return True
    if model_list is None:
        return False
    for _model in model_list:
        _litellm_params = _model.get("litellm_params", {})
        # Fix: check max_budget against None explicitly (previously a truthy
        # check, so a configured max_budget of 0 was treated as unset and the
        # limiter was skipped), matching the budget_duration check.
        if (
            _litellm_params.get("max_budget") is not None
            or _litellm_params.get("budget_duration") is not None
        ):
            return True
    return False
def _init_provider_budgets(self):
    """
    Normalize `self.provider_budget_config` entries to `GenericBudgetInfo`
    and seed each provider's spend/start-time keys in the cache.

    Raises:
        ValueError: if a provider is present with a `None` budget config.
    """
    if self.provider_budget_config is not None:
        # cast elements of provider_budget_config to GenericBudgetInfo
        for provider, config in self.provider_budget_config.items():
            if config is None:
                raise ValueError(
                    f"No budget config found for provider {provider}, provider_budget_config: {self.provider_budget_config}"
                )

            if not isinstance(config, GenericBudgetInfo):
                self.provider_budget_config[provider] = GenericBudgetInfo(
                    budget_limit=config.get("budget_limit"),
                    time_period=config.get("time_period"),
                )
            # NOTE(review): fire-and-forget task; requires a running event loop
            # at construction time and the reference is not retained - confirm
            # callers always construct this inside a running loop.
            asyncio.create_task(
                self._init_provider_budget_in_cache(
                    provider=provider,
                    budget_config=self.provider_budget_config[provider],
                )
            )
        # Fix: typo in log message ("Initalized" -> "Initialized").
        verbose_router_logger.debug(
            f"Initialized Provider budget config: {self.provider_budget_config}"
        )
def _init_deployment_budgets(
    self,
    model_list: Optional[
        Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
    ] = None,
):
    """
    Scan ``model_list`` for deployments that declare both ``max_budget`` and
    ``budget_duration`` and register a ``GenericBudgetInfo`` for each, keyed
    by the deployment's ``model_info.id``.

    Deployments missing max_budget, budget_duration, or a model id are skipped.
    """
    if model_list is None:
        return
    for _model in model_list:
        _litellm_params = _model.get("litellm_params", {})
        _model_info: Dict = _model.get("model_info") or {}
        _model_id = _model_info.get("id")
        _max_budget = _litellm_params.get("max_budget")
        _budget_duration = _litellm_params.get("budget_duration")

        verbose_router_logger.debug(
            f"Init Deployment Budget: max_budget: {_max_budget}, budget_duration: {_budget_duration}, model_id: {_model_id}"
        )
        if (
            _max_budget is not None
            and _budget_duration is not None
            and _model_id is not None
        ):
            _budget_config = GenericBudgetInfo(
                time_period=_budget_duration,
                budget_limit=_max_budget,
            )
            # Lazily create the mapping on first deployment budget found.
            if self.deployment_budget_config is None:
                self.deployment_budget_config = {}
            self.deployment_budget_config[_model_id] = _budget_config

            verbose_router_logger.debug(
                f"Initialized Deployment Budget Config: {self.deployment_budget_config}"
            )
def _init_tag_budgets(self):
    """
    Build ``self.tag_budget_config`` from the global ``litellm.tag_budget_config``.

    Enterprise-only feature.

    Raises:
        ValueError: if tag budgets are configured but the user is not premium.
    """
    if litellm.tag_budget_config is None:
        return
    # Imported lazily to avoid a circular import with the proxy server module.
    from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

    if premium_user is not True:
        raise ValueError(
            f"Tag budgets are an Enterprise only feature, {CommonProxyErrors.not_premium_user}"
        )
    if self.tag_budget_config is None:
        self.tag_budget_config = {}
    for _tag, _tag_budget_config in litellm.tag_budget_config.items():
        if isinstance(_tag_budget_config, dict):
            # Allow raw dicts in config; coerce to the typed BudgetConfig first.
            _tag_budget_config = BudgetConfig(**_tag_budget_config)
        _generic_budget_config = GenericBudgetInfo(
            time_period=_tag_budget_config.budget_duration,
            budget_limit=_tag_budget_config.max_budget,
        )
        self.tag_budget_config[_tag] = _generic_budget_config
    verbose_router_logger.debug(
        f"Initialized Tag Budget Config: {self.tag_budget_config}"
    )

View File

@@ -0,0 +1,162 @@
# Complexity Router
A rule-based routing strategy that classifies requests by complexity and routes them to appropriate models - with zero API calls and sub-millisecond latency.
## Overview
Unlike the semantic `auto_router` which uses embedding-based matching, the `complexity_router` uses weighted rule-based scoring across multiple dimensions to classify request complexity. This approach:
- **Zero external API calls** - all scoring is local
- **Sub-millisecond latency** - typically <1ms per classification
- **Predictable behavior** - rule-based scoring is deterministic
- **Fully configurable** - weights, thresholds, and keyword lists can be customized
## How It Works
The router scores each request across 7 dimensions:
| Dimension | Description | Weight |
|-----------|-------------|--------|
| `tokenCount` | Short prompts = simple, long = complex | 0.10 |
| `codePresence` | Code keywords (function, class, etc.) | 0.30 |
| `reasoningMarkers` | "step by step", "think through", etc. | 0.25 |
| `technicalTerms` | Domain complexity indicators | 0.25 |
| `simpleIndicators` | "what is", "define" (negative weight) | 0.05 |
| `multiStepPatterns` | "first...then", numbered steps | 0.03 |
| `questionComplexity` | Multiple question marks | 0.02 |
The weighted sum is mapped to tiers using configurable boundaries:
| Tier | Score Range | Typical Use |
|------|-------------|-------------|
| SIMPLE | < 0.15 | Basic questions, greetings |
| MEDIUM | 0.15 (inclusive) – 0.35 | Standard queries |
| COMPLEX | 0.35 (inclusive) – 0.60 | Technical, multi-part requests |
| REASONING | ≥ 0.60 | Chain-of-thought, analysis |
## Configuration
### Basic Configuration
```yaml
model_list:
- model_name: smart-router
litellm_params:
model: auto_router/complexity_router
complexity_router_config:
tiers:
SIMPLE: gpt-4o-mini
MEDIUM: gpt-4o
COMPLEX: claude-sonnet-4
REASONING: o1-preview
```
### Full Configuration
```yaml
model_list:
- model_name: smart-router
litellm_params:
model: auto_router/complexity_router
complexity_router_config:
# Tier to model mapping
tiers:
SIMPLE: gpt-4o-mini
MEDIUM: gpt-4o
COMPLEX: claude-sonnet-4
REASONING: o1-preview
# Tier boundaries (normalized scores)
tier_boundaries:
simple_medium: 0.15
medium_complex: 0.35
complex_reasoning: 0.60
# Token count thresholds
token_thresholds:
simple: 15 # Below this = "short" (default: 15)
complex: 400 # Above this = "long" (default: 400)
# Dimension weights (must sum to ~1.0)
dimension_weights:
tokenCount: 0.10
codePresence: 0.30
reasoningMarkers: 0.25
technicalTerms: 0.25
simpleIndicators: 0.05
multiStepPatterns: 0.03
questionComplexity: 0.02
# Override default keyword lists
code_keywords:
- function
- class
- def
- async
- database
reasoning_keywords:
- step by step
- think through
- analyze
# Fallback model if tier cannot be determined
default_model: gpt-4o
```
## Usage
Once configured, use the model name like any other:
```python
import litellm
response = litellm.completion(
model="smart-router", # Your complexity_router model name
messages=[{"role": "user", "content": "What is 2+2?"}]
)
# Routes to SIMPLE tier (gpt-4o-mini)
response = litellm.completion(
model="smart-router",
messages=[{"role": "user", "content": "Think step by step: analyze the performance implications of implementing a distributed consensus algorithm for our microservices architecture."}]
)
# Routes to REASONING tier (o1-preview)
```
## Special Behaviors
### Reasoning Override
If 2+ reasoning markers are detected in the user message, the request is automatically routed to the REASONING tier regardless of the weighted score. This ensures complex reasoning tasks get the appropriate model.
### System Prompt Handling
Reasoning markers in the system prompt do **not** trigger the reasoning override. This prevents system prompts like "Think step by step before answering" from forcing all requests to the reasoning tier.
### Code Detection
Technical code keywords are detected case-insensitively and include:
- Language keywords: `function`, `class`, `def`, `const`, `let`, `var`
- Operations: `import`, `export`, `return`, `async`, `await`
- Infrastructure: `database`, `api`, `endpoint`, `docker`, `kubernetes`
- Actions: `debug`, `implement`, `refactor`, `optimize`
## Performance
- **Classification time**: <1ms typical
- **Memory usage**: Minimal (compiled regex patterns + keyword sets)
- **No external dependencies**: Works offline with no API calls
## Comparison with auto_router
| Feature | complexity_router | auto_router |
|---------|-------------------|-------------|
| Classification | Rule-based scoring | Semantic embedding |
| Latency | <1ms | ~100-500ms (embedding API) |
| API Calls | None | Requires embedding model |
| Training | None | Requires utterance examples |
| Customization | Weights, keywords, thresholds | Utterance examples |
| Best For | Cost optimization | Intent routing |
Use `complexity_router` when you want to optimize costs by routing simple queries to cheaper models. Use `auto_router` when you need semantic intent matching (e.g., routing "customer support" queries to a specialized model).

View File

@@ -0,0 +1,22 @@
"""
Complexity-based Auto Router
A rule-based routing strategy that uses weighted scoring across multiple dimensions
to classify requests by complexity and route them to appropriate models.
No external API calls - all scoring is local and <1ms.
"""
from litellm.router_strategy.complexity_router.complexity_router import ComplexityRouter
from litellm.router_strategy.complexity_router.config import (
ComplexityTier,
DEFAULT_COMPLEXITY_CONFIG,
ComplexityRouterConfig,
)
# Public API of the complexity_router package.
__all__ = [
    "ComplexityRouter",
    "ComplexityTier",
    "DEFAULT_COMPLEXITY_CONFIG",
    "ComplexityRouterConfig",
]

View File

@@ -0,0 +1,410 @@
"""
Complexity-based Auto Router
A rule-based routing strategy that uses weighted scoring across multiple dimensions
to classify requests by complexity and route them to appropriate models.
No external API calls - all scoring is local and <1ms.
Inspired by ClawRouter: https://github.com/BlockRunAI/ClawRouter
"""
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from litellm._logging import verbose_router_logger
from litellm.integrations.custom_logger import CustomLogger
from .config import (
DEFAULT_CODE_KEYWORDS,
DEFAULT_REASONING_KEYWORDS,
DEFAULT_SIMPLE_KEYWORDS,
DEFAULT_TECHNICAL_KEYWORDS,
ComplexityRouterConfig,
ComplexityTier,
)
if TYPE_CHECKING:
from litellm.router import Router
from litellm.types.router import PreRoutingHookResponse
else:
Router = Any
PreRoutingHookResponse = Any
class DimensionScore:
    """Represents a score for a single dimension with optional signal.

    Attributes:
        name: Dimension identifier (e.g. "codePresence").
        score: Raw (unweighted) score for the dimension.
        signal: Optional human-readable reason the dimension fired, for debugging.
    """

    __slots__ = ("name", "score", "signal")

    def __init__(self, name: str, score: float, signal: Optional[str] = None):
        self.name = name
        self.score = score
        self.signal = signal

    def __repr__(self) -> str:
        # Added for debuggability when logging dimension breakdowns.
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"score={self.score!r}, signal={self.signal!r})"
        )
class ComplexityRouter(CustomLogger):
    """
    Rule-based complexity router that classifies requests and routes to appropriate models.

    Handles requests in <1ms with zero external API calls by using weighted scoring
    across multiple dimensions:
    - Token count (short=simple, long=complex)
    - Code presence (code keywords → complex)
    - Reasoning markers ("step by step", "think through" → reasoning tier)
    - Technical terms (domain complexity)
    - Simple indicators ("what is", "define" → simple, negative weight)
    - Multi-step patterns ("first...then", numbered steps)
    - Question complexity (multiple questions)

    Routing happens via ``async_pre_routing_hook``, which rewrites the requested
    model to the tier-appropriate model before the router picks a deployment.
    """
def __init__(
    self,
    model_name: str,
    litellm_router_instance: "Router",
    complexity_router_config: Optional[Dict[str, Any]] = None,
    default_model: Optional[str] = None,
):
    """
    Initialize ComplexityRouter.

    Args:
        model_name: The name of the model/deployment using this router.
        litellm_router_instance: The LiteLLM Router instance.
        complexity_router_config: Optional configuration dict from proxy config.
        default_model: Optional default model to use if tier cannot be determined.
    """
    self.model_name = model_name
    self.litellm_router_instance = litellm_router_instance

    # Parse config - always create a new instance to avoid singleton mutation
    if complexity_router_config:
        self.config = ComplexityRouterConfig(**complexity_router_config)
    else:
        self.config = ComplexityRouterConfig()

    # Override default_model if provided
    if default_model:
        self.config.default_model = default_model

    # Build effective keyword lists (use config overrides or defaults)
    self.code_keywords = self.config.code_keywords or DEFAULT_CODE_KEYWORDS
    self.reasoning_keywords = (
        self.config.reasoning_keywords or DEFAULT_REASONING_KEYWORDS
    )
    self.technical_keywords = (
        self.config.technical_keywords or DEFAULT_TECHNICAL_KEYWORDS
    )
    self.simple_keywords = self.config.simple_keywords or DEFAULT_SIMPLE_KEYWORDS

    # Pre-compile regex patterns for efficiency
    # Use non-greedy .*? to prevent ReDoS on pathological inputs
    self._multi_step_patterns = [
        re.compile(r"first.*?then", re.IGNORECASE),
        re.compile(r"step\s*\d", re.IGNORECASE),
        re.compile(r"\d+\.\s"),  # numbered list items: "1. ", "2. "
        re.compile(r"[a-z]\)\s", re.IGNORECASE),  # lettered list items: "a) "
    ]

    verbose_router_logger.debug(
        f"ComplexityRouter initialized for {model_name} with tiers: {self.config.tiers}"
    )
def _estimate_tokens(self, text: str) -> int:
"""
Estimate token count from text.
Uses a simple heuristic: ~4 characters per token on average.
"""
return len(text) // 4
def _score_token_count(self, estimated_tokens: int) -> DimensionScore:
    """Score by prompt length: very short leans simple, very long leans complex."""
    limits = self.config.token_thresholds
    short_cutoff = limits.get("simple", 15)
    long_cutoff = limits.get("complex", 400)

    if estimated_tokens < short_cutoff:
        return DimensionScore(
            "tokenCount", -1.0, f"short ({estimated_tokens} tokens)"
        )
    if estimated_tokens > long_cutoff:
        return DimensionScore(
            "tokenCount", 1.0, f"long ({estimated_tokens} tokens)"
        )
    return DimensionScore("tokenCount", 0, None)
def _keyword_matches(self, text: str, keyword: str) -> bool:
"""
Check if a keyword matches in text using word boundary matching.
For single-word keywords, uses regex word boundaries to avoid
false positives (e.g., "error" matching "terrorism", "class" matching "classical").
For multi-word phrases, uses substring matching.
"""
kw_lower = keyword.lower()
# For single-word keywords, use word boundary matching to avoid false positives
# e.g., "api" should not match "capital", "error" should not match "terrorism"
if " " not in kw_lower:
pattern = r"\b" + re.escape(kw_lower) + r"\b"
return bool(re.search(pattern, text))
# For multi-word phrases, substring matching is fine
return kw_lower in text
def _score_keyword_match(
    self,
    text: str,
    keywords: List[str],
    name: str,
    signal_label: str,
    thresholds: Tuple[int, int],  # (low, high)
    scores: Tuple[float, float, float],  # (none, low, high)
) -> Tuple[DimensionScore, int]:
    """Score based on keyword matches using word boundary matching.

    Returns:
        Tuple of (DimensionScore, match_count) so callers can reuse the count.
    """
    low_threshold, high_threshold = thresholds
    score_none, score_low, score_high = scores

    hits = [kw for kw in keywords if self._keyword_matches(text, kw)]
    hit_count = len(hits)

    if hit_count >= high_threshold:
        chosen_score = score_high
    elif hit_count >= low_threshold:
        chosen_score = score_low
    else:
        # Below the low threshold: dimension stays silent.
        return DimensionScore(name, score_none, None), hit_count

    # Signal lists at most the first three matching keywords for brevity.
    signal = f"{signal_label} ({', '.join(hits[:3])})"
    return DimensionScore(name, chosen_score, signal), hit_count
def _score_multi_step(self, text: str) -> DimensionScore:
    """Score multi-step structure ("first...then", numbered or lettered lists)."""
    has_steps = any(pattern.search(text) for pattern in self._multi_step_patterns)
    if has_steps:
        return DimensionScore("multiStepPatterns", 0.5, "multi-step")
    return DimensionScore("multiStepPatterns", 0, None)
def _score_question_complexity(self, text: str) -> DimensionScore:
    """Score by question-mark count: more than three signals a multi-question prompt."""
    question_marks = text.count("?")
    if question_marks <= 3:
        return DimensionScore("questionComplexity", 0, None)
    return DimensionScore("questionComplexity", 0.5, f"{question_marks} questions")
def classify(
    self, prompt: str, system_prompt: Optional[str] = None
) -> Tuple[ComplexityTier, float, List[str]]:
    """
    Classify a prompt by complexity.

    Args:
        prompt: The user's prompt/message.
        system_prompt: Optional system prompt for context.

    Returns:
        Tuple of (tier, score, signals) where:
        - tier: The ComplexityTier (SIMPLE, MEDIUM, COMPLEX, REASONING)
        - score: The raw weighted score
        - signals: List of triggered signals for debugging
    """
    # Combine text for analysis.
    # System prompt is intentionally included in code/technical/simple scoring
    # because it provides deployment-level context (e.g., "You are a Python assistant"
    # signals that code-capable models are appropriate). Reasoning markers use
    # user_text only to prevent system prompts from forcing REASONING tier.
    full_text = f"{system_prompt or ''} {prompt}".lower()
    user_text = prompt.lower()

    # Estimate tokens
    estimated_tokens = self._estimate_tokens(prompt)

    # Score all dimensions, capturing match counts where needed
    code_score, _ = self._score_keyword_match(
        full_text,
        self.code_keywords,
        "codePresence",
        "code",
        (1, 2),  # 1 keyword -> low score, 2+ -> high score
        (0, 0.5, 1.0),
    )
    reasoning_score, reasoning_match_count = self._score_keyword_match(
        user_text,
        self.reasoning_keywords,
        "reasoningMarkers",
        "reasoning",
        (1, 2),
        (0, 0.7, 1.0),
    )
    technical_score, _ = self._score_keyword_match(
        full_text,
        self.technical_keywords,
        "technicalTerms",
        "technical",
        (2, 4),  # technical terms need 2+ hits before they count at all
        (0, 0.5, 1.0),
    )
    simple_score, _ = self._score_keyword_match(
        full_text,
        self.simple_keywords,
        "simpleIndicators",
        "simple",
        (1, 2),
        (0, -1.0, -1.0),  # negative: simple indicators pull the score down
    )

    dimensions: List[DimensionScore] = [
        self._score_token_count(estimated_tokens),
        code_score,
        reasoning_score,
        technical_score,
        simple_score,
        self._score_multi_step(full_text),
        self._score_question_complexity(prompt),
    ]

    # Collect signals
    signals = [d.signal for d in dimensions if d.signal is not None]

    # Compute weighted score
    weights = self.config.dimension_weights
    weighted_score = sum(d.score * weights.get(d.name, 0) for d in dimensions)

    # Check for reasoning override (2+ reasoning markers)
    # Reuse match count from _score_keyword_match to avoid scanning twice
    if reasoning_match_count >= 2:
        return ComplexityTier.REASONING, weighted_score, signals

    # Map score to tier (each boundary is the inclusive lower bound of the
    # next tier: score >= complex_reasoning lands in REASONING, etc.)
    boundaries = self.config.tier_boundaries
    simple_medium = boundaries.get("simple_medium", 0.15)
    medium_complex = boundaries.get("medium_complex", 0.35)
    complex_reasoning = boundaries.get("complex_reasoning", 0.60)

    if weighted_score < simple_medium:
        tier = ComplexityTier.SIMPLE
    elif weighted_score < medium_complex:
        tier = ComplexityTier.MEDIUM
    elif weighted_score < complex_reasoning:
        tier = ComplexityTier.COMPLEX
    else:
        tier = ComplexityTier.REASONING

    return tier, weighted_score, signals
def get_model_for_tier(self, tier: ComplexityTier) -> str:
    """
    Resolve the model name configured for ``tier``.

    Resolution order: explicit tier mapping -> config.default_model ->
    the MEDIUM tier's model -> ValueError if none of those are set.

    Args:
        tier: The complexity tier (enum member or raw tier-name string).

    Returns:
        The model name configured for that tier.

    Raises:
        ValueError: when no mapping, default, or MEDIUM fallback exists.
    """
    tier_key = tier.value if isinstance(tier, ComplexityTier) else tier

    configured = self.config.tiers.get(tier_key)
    if configured:
        return configured

    fallback = self.config.default_model
    if fallback:
        return fallback

    medium_model = self.config.tiers.get(ComplexityTier.MEDIUM.value)
    if medium_model:
        return medium_model

    raise ValueError(
        f"No model configured for tier {tier_key} and no default_model set"
    )
async def async_pre_routing_hook(
    self,
    model: str,
    request_kwargs: Dict,
    messages: Optional[List[Dict[str, Any]]] = None,
    input: Optional[Union[str, List]] = None,
    specific_deployment: Optional[bool] = False,
) -> Optional["PreRoutingHookResponse"]:
    """
    Pre-routing hook called before the routing decision.

    Classifies the request by complexity and returns the appropriate model.

    Args:
        model: The original model name requested.
        request_kwargs: The request kwargs.
        messages: The messages in the request.
        input: Optional input for embeddings.
        specific_deployment: Whether a specific deployment was requested.

    Returns:
        PreRoutingHookResponse with the routed model, or None if no routing needed.
    """
    from litellm.types.router import PreRoutingHookResponse

    if messages is None or len(messages) == 0:
        verbose_router_logger.debug(
            "ComplexityRouter: No messages provided, skipping routing"
        )
        return None

    # Extract the last user message and the last system prompt.
    # Iterating in reverse and keeping the first hit per role yields the
    # most recent message of each kind.
    user_message: Optional[str] = None
    system_prompt: Optional[str] = None

    for msg in reversed(messages):
        role = msg.get("role", "")
        content = msg.get("content") or ""
        # content may be a list of content parts (e.g. [{"type": "text", "text": "..."}])
        if isinstance(content, list):
            text_parts = [
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            ]
            content = " ".join(text_parts).strip()
        if isinstance(content, str) and content:
            if role == "user" and user_message is None:
                user_message = content
            elif role == "system" and system_prompt is None:
                system_prompt = content

    if user_message is None:
        verbose_router_logger.debug(
            "ComplexityRouter: No user message found, routing to default model"
        )
        # No text to classify - fall back to default (or the MEDIUM tier model).
        return PreRoutingHookResponse(
            model=self.config.default_model
            or self.get_model_for_tier(ComplexityTier.MEDIUM),
            messages=messages,
        )

    # Classify the request
    tier, score, signals = self.classify(user_message, system_prompt)

    # Get the model for this tier
    routed_model = self.get_model_for_tier(tier)

    verbose_router_logger.info(
        f"ComplexityRouter: tier={tier.value}, score={score:.3f}, "
        f"signals={signals}, routed_model={routed_model}"
    )

    return PreRoutingHookResponse(
        model=routed_model,
        messages=messages,
    )

View File

@@ -0,0 +1,255 @@
"""
Configuration for the Complexity Router.
Contains default keyword lists, weights, tier boundaries, and configuration classes.
All values are configurable via proxy config.yaml.
"""
from enum import Enum
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
class ComplexityTier(str, Enum):
    """Complexity tiers for routing decisions.

    Ordered from least to most demanding; the string values double as the
    keys used in the tier-to-model config mapping.
    """

    SIMPLE = "SIMPLE"  # basic questions, greetings
    MEDIUM = "MEDIUM"  # standard queries
    COMPLEX = "COMPLEX"  # technical / multi-part requests
    REASONING = "REASONING"  # explicit chain-of-thought / analysis
# ─── Default Keyword Lists ───
# Note: Keywords should be full words/phrases to avoid substring false positives.
# The matching logic uses word boundary detection for single-word keywords.

# Keywords whose presence suggests a code-related request (language keywords,
# operations, infrastructure, and coding actions).
DEFAULT_CODE_KEYWORDS: List[str] = [
    "function",
    "class",
    "def",
    "const",
    "let",
    "var",
    "import",
    "export",
    "return",
    "async",
    "await",
    "try",
    "catch",
    "exception",
    "error",
    "debug",
    "api",
    "endpoint",
    "request",
    "response",
    "database",
    "sql",
    "query",
    "schema",
    "algorithm",
    "implement",
    "refactor",
    "optimize",
    "python",
    "javascript",
    "typescript",
    "java",
    "rust",
    "golang",
    "react",
    "vue",
    "angular",
    "node",
    "docker",
    "kubernetes",
    "git",
    "commit",
    "merge",
    "branch",
    "pull request",
]
# Phrases that explicitly request step-by-step reasoning or analysis;
# two or more hits in the user message force the REASONING tier.
DEFAULT_REASONING_KEYWORDS: List[str] = [
    "step by step",
    "think through",
    "let's think",
    "reason through",
    "analyze this",
    "break down",
    "explain your reasoning",
    "show your work",
    "chain of thought",
    "think carefully",
    "consider all",
    "evaluate",
    "pros and cons",
    "compare and contrast",
    "weigh the options",
    "logical",
    "deduce",
    "infer",
    "conclude",
]
# Domain-complexity indicators (systems, ML, security, networking).
DEFAULT_TECHNICAL_KEYWORDS: List[str] = [
    "architecture",
    "distributed",
    "scalable",
    "microservice",
    "machine learning",
    "neural network",
    "deep learning",
    "encryption",
    "authentication",
    "authorization",
    "performance",
    "latency",
    "throughput",
    "benchmark",
    "concurrency",
    "parallel",
    "threading",
    "memory",
    "cpu",
    "gpu",
    "optimization",
    "protocol",
    "tcp",
    "http",
    "grpc",
    "websocket",
    "container",
    "orchestration",
    # Note: "async", "kubernetes", "docker" are in DEFAULT_CODE_KEYWORDS
]
# Indicators of simple lookups/greetings; these carry negative weight and
# pull the complexity score toward the SIMPLE tier.
DEFAULT_SIMPLE_KEYWORDS: List[str] = [
    "what is",
    "what's",
    "define",
    "definition of",
    "who is",
    "who was",
    "when did",
    "when was",
    "where is",
    "where was",
    "how many",
    "how much",
    "yes or no",
    "true or false",
    "simple",
    "brief",
    "short",
    "quick",
    "hello",
    "hi",
    "hey",
    "thanks",
    "thank you",
    "goodbye",
    "bye",
    "okay",
    # Note: "ok" removed due to false positives (matches "token", "book", etc.)
]
# ─── Default Dimension Weights ───
# Weights are intended to sum to ~1.0; each dimension's raw score is
# multiplied by its weight before summing.
DEFAULT_DIMENSION_WEIGHTS: Dict[str, float] = {
    "tokenCount": 0.10,  # Reduced - length is less important than content
    "codePresence": 0.30,  # High - code requests need capable models
    "reasoningMarkers": 0.25,  # High - explicit reasoning requests
    "technicalTerms": 0.25,  # High - technical content matters
    "simpleIndicators": 0.05,  # Low - don't over-penalize simple patterns
    "multiStepPatterns": 0.03,
    "questionComplexity": 0.02,
}

# ─── Default Tier Boundaries ───
# Normalized-score cutoffs; each value is the inclusive lower bound of the
# next tier up.
DEFAULT_TIER_BOUNDARIES: Dict[str, float] = {
    "simple_medium": 0.15,  # Lower threshold to catch more MEDIUM cases
    "medium_complex": 0.35,  # Lower threshold to catch technical COMPLEX cases
    "complex_reasoning": 0.60,  # Reasoning tier reserved for explicit reasoning markers
}

# ─── Default Token Thresholds ───
# Estimated-token cutoffs used by the tokenCount dimension.
DEFAULT_TOKEN_THRESHOLDS: Dict[str, int] = {
    "simple": 15,  # Only very short prompts (<15 tokens) are penalized
    "complex": 400,  # Long prompts (>400 tokens) get complexity boost
}

# ─── Default Tier to Model Mapping ───
# Fallback mapping used when the proxy config does not override `tiers`.
DEFAULT_TIER_MODELS: Dict[str, str] = {
    "SIMPLE": "gpt-4o-mini",
    "MEDIUM": "gpt-4o",
    "COMPLEX": "claude-sonnet-4-20250514",
    "REASONING": "claude-sonnet-4-20250514",
}
# ─── Config Model ───
class ComplexityRouterConfig(BaseModel):
    """Configuration for the ComplexityRouter.

    Every field has a sensible default taken from the module-level DEFAULT_*
    constants, so an empty ``ComplexityRouterConfig()`` is fully usable.
    Defaults are copied via ``default_factory`` so instances never share
    mutable state with the module-level constants.
    """

    # Tier to model mapping
    tiers: Dict[str, str] = Field(
        default_factory=lambda: DEFAULT_TIER_MODELS.copy(),
        description="Mapping of complexity tiers to model names",
    )

    # Tier boundaries (normalized scores)
    tier_boundaries: Dict[str, float] = Field(
        default_factory=lambda: DEFAULT_TIER_BOUNDARIES.copy(),
        description="Score boundaries between tiers",
    )

    # Token count thresholds
    token_thresholds: Dict[str, int] = Field(
        default_factory=lambda: DEFAULT_TOKEN_THRESHOLDS.copy(),
        description="Token count thresholds for simple/complex classification",
    )

    # Dimension weights
    dimension_weights: Dict[str, float] = Field(
        default_factory=lambda: DEFAULT_DIMENSION_WEIGHTS.copy(),
        description="Weights for each scoring dimension",
    )

    # Keyword lists (overridable)
    # NOTE(review): None presumably falls back to the module-level DEFAULT_*
    # keyword lists — confirm against the ComplexityRouter implementation.
    code_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating code-related content",
    )
    reasoning_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating reasoning-required content",
    )
    technical_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating technical content",
    )
    simple_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating simple/basic queries",
    )

    # Default model if scoring fails
    default_model: Optional[str] = Field(
        default=None,
        description="Default model to use if tier cannot be determined",
    )

    model_config = ConfigDict(extra="allow")  # Allow additional fields


# Combined default config
DEFAULT_COMPLEXITY_CONFIG = ComplexityRouterConfig()

View File

@@ -0,0 +1 @@
# Evaluation suite for ComplexityRouter

View File

@@ -0,0 +1,343 @@
"""
Evaluation suite for the ComplexityRouter.
Tests the router's ability to correctly classify prompts into complexity tiers.
Run with: python -m litellm.router_strategy.complexity_router.evals.eval_complexity_router
"""
import os
# Add parent to path for imports
import sys
# ruff: noqa: T201
from dataclasses import dataclass
from typing import List, Optional, Tuple
from unittest.mock import MagicMock
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
)
from litellm.router_strategy.complexity_router.complexity_router import ComplexityRouter
from litellm.router_strategy.complexity_router.config import ComplexityTier
@dataclass
class EvalCase:
    """A single evaluation case."""

    # Prompt fed to the router's classify() method.
    prompt: str
    # Tier the router is expected to pick.
    expected_tier: ComplexityTier
    # Human-readable label used in the eval report.
    description: str
    # Optional system prompt passed alongside the user prompt.
    system_prompt: Optional[str] = None
    # Allow some flexibility - if actual tier is in acceptable_tiers, still passes
    acceptable_tiers: Optional[List[ComplexityTier]] = None


# ─── Evaluation Dataset ───
EVAL_CASES: List[EvalCase] = [
    # === SIMPLE tier cases ===
    EvalCase(
        prompt="Hello!",
        expected_tier=ComplexityTier.SIMPLE,
        description="Basic greeting",
    ),
    EvalCase(
        prompt="What is Python?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple definition question",
    ),
    EvalCase(
        prompt="Who is Elon Musk?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple factual question",
    ),
    EvalCase(
        prompt="What's the capital of France?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple geography question",
    ),
    EvalCase(
        prompt="Thanks for your help!",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple thank you",
    ),
    EvalCase(
        prompt="Define machine learning",
        expected_tier=ComplexityTier.SIMPLE,
        description="Definition request",
    ),
    EvalCase(
        prompt="When was the iPhone released?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple date question",
    ),
    EvalCase(
        prompt="How many planets are in our solar system?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple count question",
    ),
    EvalCase(
        prompt="Yes",
        expected_tier=ComplexityTier.SIMPLE,
        description="Single word response",
    ),
    EvalCase(
        prompt="What time is it in Tokyo?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple time zone question",
    ),
    # === MEDIUM tier cases ===
    EvalCase(
        prompt="Explain how REST APIs work and when to use them",
        expected_tier=ComplexityTier.MEDIUM,
        description="Technical explanation",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="Write a short poem about the ocean",
        expected_tier=ComplexityTier.MEDIUM,
        description="Creative writing - short",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="Summarize the main differences between SQL and NoSQL databases",
        expected_tier=ComplexityTier.MEDIUM,
        description="Technical comparison",
        acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
    ),
    EvalCase(
        prompt="What are the benefits of using TypeScript over JavaScript?",
        expected_tier=ComplexityTier.MEDIUM,
        description="Technical comparison question",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="Help me debug this error: TypeError: Cannot read property 'map' of undefined",
        expected_tier=ComplexityTier.MEDIUM,
        description="Debugging help",
        acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
    ),
    # === COMPLEX tier cases ===
    EvalCase(
        prompt="Design a distributed microservice architecture for a high-throughput "
        "real-time data processing pipeline with Kubernetes orchestration, "
        "implementing proper authentication and encryption protocols",
        expected_tier=ComplexityTier.COMPLEX,
        description="Complex architecture design",
        acceptable_tiers=[ComplexityTier.COMPLEX, ComplexityTier.REASONING],
    ),
    EvalCase(
        prompt="Write a Python function that implements a binary search tree with "
        "insert, delete, and search operations. Include proper error handling "
        "and optimize for memory efficiency.",
        expected_tier=ComplexityTier.COMPLEX,
        description="Complex coding task",
        acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
    ),
    EvalCase(
        prompt="Explain the differences between TCP and UDP protocols, including "
        "use cases for each, performance implications, and how they handle "
        "packet loss in distributed systems",
        expected_tier=ComplexityTier.COMPLEX,
        description="Deep technical explanation",
        acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
    ),
    EvalCase(
        prompt="Create a comprehensive database schema for an e-commerce platform "
        "that handles users, products, orders, payments, shipping, reviews, "
        "and inventory management with proper indexing strategies",
        expected_tier=ComplexityTier.COMPLEX,
        description="Complex database design",
        acceptable_tiers=[
            ComplexityTier.MEDIUM,
            ComplexityTier.COMPLEX,
            ComplexityTier.REASONING,
        ],
    ),
    EvalCase(
        prompt="Implement a rate limiter using the token bucket algorithm in Python "
        "that supports multiple rate limit tiers and can be used across "
        "distributed systems with Redis as the backend",
        expected_tier=ComplexityTier.COMPLEX,
        description="Complex distributed systems coding",
        acceptable_tiers=[
            ComplexityTier.MEDIUM,
            ComplexityTier.COMPLEX,
            ComplexityTier.REASONING,
        ],
    ),
    # === REASONING tier cases ===
    EvalCase(
        prompt="Think step by step about how to solve this: A farmer has 17 sheep. "
        "All but 9 die. How many are left? Explain your reasoning.",
        expected_tier=ComplexityTier.REASONING,
        description="Explicit reasoning request",
    ),
    EvalCase(
        prompt="Let's think through this carefully. Analyze the pros and cons of "
        "microservices vs monolithic architecture for a startup with 5 engineers. "
        "Consider scalability, development speed, and operational complexity.",
        expected_tier=ComplexityTier.REASONING,
        description="Multiple reasoning markers + analysis",
    ),
    EvalCase(
        prompt="Reason through this problem: If I have a function that's O(n^2) and "
        "I need to process 1 million items, what are my options to optimize it? "
        "Walk me through each approach step by step.",
        expected_tier=ComplexityTier.REASONING,
        description="Algorithm reasoning",
    ),
    EvalCase(
        prompt="I need you to think carefully and analyze this code for potential "
        "security vulnerabilities. Consider injection attacks, authentication "
        "bypasses, and data exposure risks. Show your reasoning process.",
        expected_tier=ComplexityTier.REASONING,
        description="Security analysis with reasoning",
        acceptable_tiers=[ComplexityTier.COMPLEX, ComplexityTier.REASONING],
    ),
    EvalCase(
        prompt="Step by step, explain your reasoning as you evaluate whether we should "
        "use PostgreSQL or MongoDB for our new project. Consider our requirements: "
        "complex queries, high write volume, and eventual consistency is acceptable.",
        expected_tier=ComplexityTier.REASONING,
        description="Database decision with explicit reasoning",
    ),
    # === Edge cases / regression tests ===
    # Guard against substring matches on keyword lists ("api" in "capital",
    # "try" in "country"/"poetry", etc.).
    EvalCase(
        prompt="What is the capital of France?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Regression: 'capital' should not trigger 'api' keyword",
    ),
    EvalCase(
        prompt="I tried to book a flight but the entry form wasn't working",
        expected_tier=ComplexityTier.SIMPLE,
        description="Regression: 'tried' and 'entry' should not trigger code keywords",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="The poetry of digital art is fascinating",
        expected_tier=ComplexityTier.SIMPLE,
        description="Regression: 'poetry' should not trigger 'try' keyword",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="Can you recommend a good book about country music history?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Regression: 'country' should not trigger 'try' keyword",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
]
def run_eval() -> Tuple[int, int, List[dict]]:
    """
    Run the evaluation suite.

    Returns:
        Tuple of (passed, total, failures)
    """
    # Router under test; the litellm Router dependency is mocked since
    # classification itself does not need a live router.
    router = ComplexityRouter(
        model_name="eval-router",
        litellm_router_instance=MagicMock(),
    )

    passed = 0
    total = len(EVAL_CASES)
    failures: List[dict] = []

    # (The file-level "# ruff: noqa: T201" already silences print warnings.)
    print("=" * 70)
    print("COMPLEXITY ROUTER EVALUATION")
    print("=" * 70)
    print()

    for case_number, case in enumerate(EVAL_CASES, 1):
        tier, score, signals = router.classify(case.prompt, case.system_prompt)

        # A case passes on an exact tier match, or when the actual tier is one
        # of the explicitly allowed alternatives.
        is_pass = tier == case.expected_tier or (
            case.acceptable_tiers is not None and tier in case.acceptable_tiers
        )

        if is_pass:
            passed += 1
            status = "✓ PASS"
        else:
            status = "✗ FAIL"
            shown_prompt = (
                case.prompt[:80] + "..." if len(case.prompt) > 80 else case.prompt
            )
            failures.append(
                {
                    "case": case_number,
                    "description": case.description,
                    "prompt": shown_prompt,
                    "expected": case.expected_tier.value,
                    "actual": tier.value,
                    "score": round(score, 3),
                    "signals": signals,
                    "acceptable": (
                        [t.value for t in case.acceptable_tiers]
                        if case.acceptable_tiers
                        else None
                    ),
                }
            )

        # Per-case report line(s)
        print(f"[{case_number:2d}] {status} | {case.description}")
        print(
            f"     Expected: {case.expected_tier.value:10s} | Got: {tier.value:10s} | Score: {score:+.3f}"
        )
        if signals:
            print(f"     Signals: {', '.join(signals)}")
        if not is_pass:
            print(f"     Prompt: {case.prompt[:60]}...")
        print()

    # Summary
    print("=" * 70)
    print(f"RESULTS: {passed}/{total} passed ({100*passed/total:.1f}%)")
    print("=" * 70)

    if failures:
        print("\nFAILURES:")
        print("-" * 70)
        for failure in failures:
            print(f"Case {failure['case']}: {failure['description']}")
            print(
                f"  Expected: {failure['expected']}, Got: {failure['actual']} (score: {failure['score']})"
            )
            print(f"  Signals: {failure['signals']}")
            if failure["acceptable"]:
                print(f"  Acceptable: {failure['acceptable']}")
            print()

    return passed, total, failures
def main():
    """Entry point: run the eval and exit non-zero on a failing pass rate."""
    passed, total, _failures = run_eval()
    pass_rate = passed / total

    # Hard gate at 80% (exit 1), soft warning below 90% (exit 0).
    if pass_rate < 0.80:
        print(f"\n❌ EVAL FAILED: Pass rate {pass_rate:.1%} is below 80% threshold")
        sys.exit(1)
    if pass_rate < 0.90:
        print(f"\n⚠️  EVAL WARNING: Pass rate {pass_rate:.1%} is below 90%")
    else:
        print(f"\n✅ EVAL PASSED: Pass rate {pass_rate:.1%}")
    sys.exit(0)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,250 @@
#### What this does ####
# identifies least busy deployment
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic
import random
from typing import Optional
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
class LeastBusyLoggingHandler(CustomLogger):
    """
    Routing strategy: pick the deployment with the fewest in-flight requests.

    The in-flight count lives in the router cache under
    "{model_group}_request_count" as a dict of deployment id -> count. It is
    incremented in log_pre_api_call (just before a request is made) and
    decremented in the success/failure callbacks.

    Refactor notes: the four near-identical counter-update callbacks now share
    _update_request_count / _async_update_request_count; a dead
    `is None` check after `.get(id, 0)` and a stale "cast to int" comment were
    removed. Behavior is unchanged.
    """

    # Test-only instrumentation: when test_flag is set, completed callback
    # invocations are counted in logged_success / logged_failure.
    test_flag: bool = False
    logged_success: int = 0
    logged_failure: int = 0

    def __init__(self, router_cache: DualCache):
        self.router_cache = router_cache

    @staticmethod
    def _request_identifiers(kwargs) -> Optional[tuple]:
        """
        Extract (model_group, deployment_id) from request kwargs.

        Returns None when metadata / model_group / id are missing, in which
        case no traffic accounting is done. Integer ids are normalized to
        strings so cache keys stay consistent.
        """
        metadata = kwargs["litellm_params"].get("metadata")
        if metadata is None:
            return None
        model_group = metadata.get("model_group", None)
        deployment_id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
        if model_group is None or deployment_id is None:
            return None
        if isinstance(deployment_id, int):
            deployment_id = str(deployment_id)
        return model_group, deployment_id

    def _update_request_count(self, kwargs, delta: int) -> bool:
        """
        Apply `delta` to the deployment's in-flight counter.

        Returns True when the counter was actually updated (identifiers were
        present), False otherwise.
        """
        identifiers = self._request_identifiers(kwargs)
        if identifiers is None:
            return False
        model_group, deployment_id = identifiers
        request_count_api_key = f"{model_group}_request_count"
        request_count_dict = (
            self.router_cache.get_cache(key=request_count_api_key) or {}
        )
        request_count_dict[deployment_id] = (
            request_count_dict.get(deployment_id, 0) + delta
        )
        self.router_cache.set_cache(
            key=request_count_api_key, value=request_count_dict
        )
        return True

    async def _async_update_request_count(self, kwargs, delta: int) -> bool:
        """Async variant of _update_request_count."""
        identifiers = self._request_identifiers(kwargs)
        if identifiers is None:
            return False
        model_group, deployment_id = identifiers
        request_count_api_key = f"{model_group}_request_count"
        request_count_dict = (
            await self.router_cache.async_get_cache(key=request_count_api_key) or {}
        )
        request_count_dict[deployment_id] = (
            request_count_dict.get(deployment_id, 0) + delta
        )
        await self.router_cache.async_set_cache(
            key=request_count_api_key, value=request_count_dict
        )
        return True

    def log_pre_api_call(self, model, messages, kwargs):
        """
        Log when a model is being used: increment the in-flight count for the
        deployment about to be called. Caching based on model group.
        """
        try:
            self._update_request_count(kwargs, delta=1)
        except Exception:
            # best-effort accounting; never fail the request path
            pass

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Decrement the in-flight count when a request completes."""
        try:
            if self._update_request_count(kwargs, delta=-1) and self.test_flag:
                ### TESTING ###
                self.logged_success += 1
        except Exception:
            pass

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Decrement the in-flight count when a request fails."""
        try:
            if self._update_request_count(kwargs, delta=-1) and self.test_flag:
                ### TESTING ###
                self.logged_failure += 1
        except Exception:
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Async: decrement the in-flight count when a request completes."""
        try:
            if await self._async_update_request_count(kwargs, delta=-1) and self.test_flag:
                ### TESTING ###
                self.logged_success += 1
        except Exception:
            pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Async: decrement the in-flight count when a request fails."""
        try:
            if await self._async_update_request_count(kwargs, delta=-1) and self.test_flag:
                ### TESTING ###
                self.logged_failure += 1
        except Exception:
            pass

    def _get_available_deployments(
        self,
        healthy_deployments: list,
        all_deployments: dict,
    ):
        """
        Helper to get deployments using least busy strategy.

        `all_deployments` maps deployment id -> in-flight request count; any
        healthy deployment missing from it is seeded with 0.
        """
        for deployment in healthy_deployments:
            ## if healthy deployment not yet used
            all_deployments.setdefault(deployment["model_info"]["id"], 0)

        # pick the deployment id with the least in-flight traffic (first
        # minimum wins on ties, matching the previous explicit loop)
        min_deployment = min(
            all_deployments, key=all_deployments.__getitem__, default=None
        )
        if min_deployment is not None:
            # map the chosen id back to its deployment dict
            for deployment in healthy_deployments:
                if deployment["model_info"]["id"] == min_deployment:
                    return deployment
        # chosen id not among healthy deployments (stale cache entry), or no
        # traffic data at all -> fall back to a random healthy deployment
        return random.choice(healthy_deployments)

    def get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
    ):
        """
        Sync helper to get deployments using least busy strategy
        """
        request_count_api_key = f"{model_group}_request_count"
        all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
        return self._get_available_deployments(
            healthy_deployments=healthy_deployments,
            all_deployments=all_deployments,
        )

    async def async_get_available_deployments(
        self, model_group: str, healthy_deployments: list
    ):
        """
        Async helper to get deployments using least busy strategy
        """
        request_count_api_key = f"{model_group}_request_count"
        all_deployments = (
            await self.router_cache.async_get_cache(key=request_count_api_key) or {}
        )
        return self._get_available_deployments(
            healthy_deployments=healthy_deployments,
            all_deployments=all_deployments,
        )

View File

@@ -0,0 +1,330 @@
#### What this does ####
# picks based on response time (for streaming, this is time to first token)
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Union
import litellm
from litellm import ModelResponse, token_counter, verbose_logger
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
class LowestCostLoggingHandler(CustomLogger):
    """
    Routing strategy: pick the healthy deployment with the lowest
    (input + output) cost per token that is still under its tpm/rpm limits.

    Per-minute usage is tracked in the router cache under "{model_group}_map":

        {
            "{model_group}_map": {
                id: {
                    "{date}-{hour}-{minute}": {"tpm": 34, "rpm": 3}
                }
            }
        }

    Fixes vs previous version:
      - removed discarded `float(response_ms / completion_tokens)` expressions
        that raised ZeroDivisionError (and silently skipped the usage update)
        whenever a response reported zero completion tokens
      - removed dead code (`float("inf")`, unused `_items`, a `None` check on a
        value that could never be None) and a mutable `{}` default argument
      - corrected the copy-pasted exception-log location in the async handler
    """

    # Test-only instrumentation counters.
    test_flag: bool = False
    logged_success: int = 0
    logged_failure: int = 0

    def __init__(self, router_cache: DualCache, routing_args: Optional[dict] = None):
        # `routing_args` is accepted for signature parity with the other
        # routing strategies but is currently unused.
        self.router_cache = router_cache

    @staticmethod
    def _request_identifiers(kwargs) -> Optional[tuple]:
        """
        Extract (model_group, deployment_id) from request kwargs.

        Returns None when metadata / model_group / id are missing; integer ids
        are normalized to strings so cache keys stay consistent.
        """
        metadata = kwargs["litellm_params"].get("metadata")
        if metadata is None:
            return None
        model_group = metadata.get("model_group", None)
        deployment_id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
        if model_group is None or deployment_id is None:
            return None
        if isinstance(deployment_id, int):
            deployment_id = str(deployment_id)
        return model_group, deployment_id

    @staticmethod
    def _precise_minute() -> str:
        """Current wall-clock minute formatted as '{Y-m-d}-{H}-{M}'."""
        now = datetime.now()
        return f"{now.strftime('%Y-%m-%d')}-{now.strftime('%H')}-{now.strftime('%M')}"

    @staticmethod
    def _total_tokens(response_obj) -> int:
        """Total tokens reported by a ModelResponse's usage block, else 0."""
        if isinstance(response_obj, ModelResponse):
            _usage = getattr(response_obj, "usage", None)
            if _usage is not None and isinstance(_usage, litellm.Usage):
                return _usage.total_tokens
        return 0

    @staticmethod
    def _increment_usage(
        request_count_dict: dict,
        deployment_id: str,
        precise_minute: str,
        total_tokens: int,
    ) -> None:
        """Add one request and `total_tokens` to the deployment's minute bucket."""
        minute_map = request_count_dict.setdefault(deployment_id, {}).setdefault(
            precise_minute, {}
        )
        ## TPM
        minute_map["tpm"] = minute_map.get("tpm", 0) + total_tokens
        ## RPM
        minute_map["rpm"] = minute_map.get("rpm", 0) + 1

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Update usage on success
        """
        try:
            identifiers = self._request_identifiers(kwargs)
            if identifiers is None:
                return
            model_group, deployment_id = identifiers
            cost_key = f"{model_group}_map"
            request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
            self._increment_usage(
                request_count_dict,
                deployment_id,
                self._precise_minute(),
                self._total_tokens(response_obj),
            )
            self.router_cache.set_cache(key=cost_key, value=request_count_dict)

            ### TESTING ###
            if self.test_flag:
                self.logged_success += 1
        except Exception as e:
            verbose_logger.exception(
                "litellm.router_strategy.lowest_cost.py::log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Update cost usage on success (async variant of log_success_event)
        """
        try:
            identifiers = self._request_identifiers(kwargs)
            if identifiers is None:
                return
            model_group, deployment_id = identifiers
            cost_key = f"{model_group}_map"
            request_count_dict = (
                await self.router_cache.async_get_cache(key=cost_key) or {}
            )
            self._increment_usage(
                request_count_dict,
                deployment_id,
                self._precise_minute(),
                self._total_tokens(response_obj),
            )
            await self.router_cache.async_set_cache(
                key=cost_key, value=request_count_dict
            )  # reset map within window

            ### TESTING ###
            if self.test_flag:
                self.logged_success += 1
        except Exception as e:
            # was mislabeled as prompt_injection_detection.py::async_pre_call_hook
            verbose_logger.exception(
                "litellm.router_strategy.lowest_cost.py::async_log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )

    @staticmethod
    def _deployment_cost(_deployment: dict) -> float:
        """
        (input + output) cost per token for a deployment.

        Explicit litellm_params overrides win; otherwise litellm.model_cost is
        consulted; models absent from the cost map fall back to a deliberately
        high 5.0 + 5.0 so they are only chosen as a last resort.
        """
        litellm_params = _deployment.get("litellm_params", {})
        model_cost_map = litellm.model_cost.get(litellm_params.get("model"), {})
        input_cost = litellm_params.get("input_cost_per_token", None)
        if not input_cost:
            input_cost = model_cost_map.get("input_cost_per_token", 5.0)
        output_cost = litellm_params.get("output_cost_per_token", None)
        if not output_cost:
            output_cost = model_cost_map.get("output_cost_per_token", 5.0)
        return input_cost + output_cost

    async def async_get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        request_kwargs: Optional[Dict] = None,
    ):
        """
        Returns a deployment with the lowest cost that is still under its
        tpm/rpm limits for the current minute, or None if every candidate is
        over a limit.
        """
        cost_key = f"{model_group}_map"
        all_deployments = (
            await self.router_cache.async_get_cache(key=cost_key) or {}
        )
        precise_minute = self._precise_minute()

        for d in healthy_deployments:
            ## if healthy deployment not yet used, seed an empty usage bucket
            if d["model_info"]["id"] not in all_deployments:
                all_deployments[d["model_info"]["id"]] = {
                    precise_minute: {"tpm": 0, "rpm": 0},
                }

        try:
            input_tokens = token_counter(messages=messages, text=input)
        except Exception:
            input_tokens = 0

        ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
        potential_deployments = []
        # api_base -> cost; kept purely so a user can debug (via langfuse,
        # slack, ...) why the router picked a specific deployment.
        _cost_per_deployment = {}
        for item, item_map in all_deployments.items():
            _deployment = next(
                (m for m in healthy_deployments if item == m["model_info"]["id"]),
                None,
            )
            if _deployment is None:
                continue  # cached entry for a deployment that is no longer healthy

            _deployment_tpm = (
                _deployment.get("tpm", None)
                or _deployment.get("litellm_params", {}).get("tpm", None)
                or _deployment.get("model_info", {}).get("tpm", None)
                or float("inf")
            )
            _deployment_rpm = (
                _deployment.get("rpm", None)
                or _deployment.get("litellm_params", {}).get("rpm", None)
                or _deployment.get("model_info", {}).get("rpm", None)
                or float("inf")
            )

            item_cost = self._deployment_cost(_deployment)
            item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
            item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
            verbose_router_logger.debug(
                f"item_cost: {item_cost}, item_tpm: {item_tpm}, item_rpm: {item_rpm}, model_id: {_deployment.get('model_info', {}).get('id')}"
            )

            _deployment_api_base = _deployment.get("litellm_params", {}).get(
                "api_base", ""
            )
            if _deployment_api_base is not None:
                _cost_per_deployment[_deployment_api_base] = item_cost

            # respect user-provided tpm/rpm limits for this minute
            if (
                item_tpm + input_tokens > _deployment_tpm
                or item_rpm + 1 > _deployment_rpm
            ):
                continue
            potential_deployments.append((_deployment, item_cost))

        if len(potential_deployments) == 0:
            return None

        # min() returns the first minimum, matching sorted(...)[0] on ties
        selected_deployment, _ = min(potential_deployments, key=lambda entry: entry[1])
        return selected_deployment

View File

@@ -0,0 +1,627 @@
#### What this does ####
# picks based on response time (for streaming, this is time to first token)
import random
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import litellm
from litellm import ModelResponse, token_counter, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import safe_divide_seconds
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.utils import LiteLLMPydanticObjectBase
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class RoutingArgs(LiteLLMPydanticObjectBase):
    # Tunables for the lowest-latency routing strategy.
    ttl: float = 1 * 60 * 60  # 1 hour - cache TTL for the per-group latency map
    # NOTE(review): presumably a tolerance band above the lowest latency —
    # confirm against the deployment-selection code (not visible here).
    lowest_latency_buffer: float = 0
    max_latency_list_size: int = 10  # bounded history of per-request latency samples
class LowestLatencyLoggingHandler(CustomLogger):
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
def __init__(self, router_cache: DualCache, routing_args: dict = {}):
self.router_cache = router_cache
self.routing_args = RoutingArgs(**routing_args)
def log_success_event( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time
):
try:
"""
Update latency usage on success
"""
metadata_field = self._select_metadata_field(kwargs)
if kwargs["litellm_params"].get(metadata_field) is None:
pass
else:
model_group = kwargs["litellm_params"][metadata_field].get(
"model_group", None
)
id = (kwargs["litellm_params"].get("model_info") or {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
# ------------
# Setup values
# ------------
"""
{
{model_group}_map: {
id: {
"latency": [..]
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
}
"""
latency_key = f"{model_group}_map"
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
response_ms = end_time - start_time
time_to_first_token_response_time = None
if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
# only log ttft for streaming request
time_to_first_token_response_time = (
kwargs.get("completion_start_time", end_time) - start_time
)
final_value: Union[float, timedelta] = response_ms
time_to_first_token: Optional[float] = None
total_tokens = 0
if isinstance(response_obj, ModelResponse):
_usage = getattr(response_obj, "usage", None)
if _usage is not None:
completion_tokens = _usage.completion_tokens
total_tokens = _usage.total_tokens
# Handle both timedelta and float response times
if isinstance(response_ms, timedelta):
response_seconds = response_ms.total_seconds()
else:
response_seconds = response_ms
final_value = safe_divide_seconds(
response_seconds, completion_tokens
)
if final_value is not None:
final_value = float(final_value)
else:
final_value = response_seconds
if time_to_first_token_response_time is not None:
if isinstance(time_to_first_token_response_time, timedelta):
ttft_seconds = (
time_to_first_token_response_time.total_seconds()
)
else:
ttft_seconds = time_to_first_token_response_time
time_to_first_token = safe_divide_seconds(
ttft_seconds, completion_tokens
)
# ------------
# Update usage
# ------------
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
request_count_dict = (
self.router_cache.get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
if id not in request_count_dict:
request_count_dict[id] = {}
## Latency
if (
len(request_count_dict[id].get("latency", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault("latency", []).append(final_value)
else:
request_count_dict[id]["latency"] = request_count_dict[id][
"latency"
][: self.routing_args.max_latency_list_size - 1] + [final_value]
## Time to first token
if time_to_first_token is not None:
if (
len(request_count_dict[id].get("time_to_first_token", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault(
"time_to_first_token", []
).append(time_to_first_token)
else:
request_count_dict[id][
"time_to_first_token"
] = request_count_dict[id]["time_to_first_token"][
: self.routing_args.max_latency_list_size - 1
] + [
time_to_first_token
]
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
## TPM
request_count_dict[id][precise_minute]["tpm"] = (
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
)
## RPM
request_count_dict[id][precise_minute]["rpm"] = (
request_count_dict[id][precise_minute].get("rpm", 0) + 1
)
self.router_cache.set_cache(
key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
) # reset map within window
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""
Check if Timeout Error, if timeout set deployment latency -> 100
"""
try:
metadata_field = self._select_metadata_field(kwargs)
_exception = kwargs.get("exception", None)
if isinstance(_exception, litellm.Timeout):
if kwargs["litellm_params"].get(metadata_field) is None:
pass
else:
model_group = kwargs["litellm_params"][metadata_field].get(
"model_group", None
)
id = (kwargs["litellm_params"].get("model_info") or {}).get(
"id", None
)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
# ------------
# Setup values
# ------------
"""
{
{model_group}_map: {
id: {
"latency": [..]
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
}
"""
latency_key = f"{model_group}_map"
request_count_dict = (
await self.router_cache.async_get_cache(key=latency_key) or {}
)
if id not in request_count_dict:
request_count_dict[id] = {}
## Latency - give 1000s penalty for failing
if (
len(request_count_dict[id].get("latency", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault("latency", []).append(1000.0)
else:
request_count_dict[id]["latency"] = request_count_dict[id][
"latency"
][: self.routing_args.max_latency_list_size - 1] + [1000.0]
await self.router_cache.async_set_cache(
key=latency_key,
value=request_count_dict,
ttl=self.routing_args.ttl,
) # reset map within window
else:
# do nothing if it's not a timeout error
return
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
pass
    async def async_log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        """
        Async success hook: record latency statistics for the deployment that
        served this request.

        Updates, under the "{model_group}_map" cache key:
          - "latency": bounded list of per-token (or total) response times
          - "time_to_first_token": bounded list, streaming requests only
          - per-minute "tpm"/"rpm" counters keyed by "YYYY-MM-DD-HH-MM"

        Exceptions are logged and swallowed so logging never fails a request.
        """
        try:
            """
            Update latency usage on success
            """
            metadata_field = self._select_metadata_field(kwargs)
            if kwargs["litellm_params"].get(metadata_field) is None:
                pass
            else:
                model_group = kwargs["litellm_params"][metadata_field].get(
                    "model_group", None
                )
                # model_info may be present-but-None; guard before .get()
                id = (kwargs["litellm_params"].get("model_info") or {}).get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)
                # ------------
                # Setup values
                # ------------
                """
                {
                    {model_group}_map: {
                        id: {
                            "latency": [..]
                            "time_to_first_token": [..]
                            f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
                        }
                    }
                }
                """
                latency_key = f"{model_group}_map"
                current_date = datetime.now().strftime("%Y-%m-%d")
                current_hour = datetime.now().strftime("%H")
                current_minute = datetime.now().strftime("%M")
                precise_minute = f"{current_date}-{current_hour}-{current_minute}"
                response_ms = end_time - start_time
                time_to_first_token_response_time = None
                if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
                    # only log ttft for streaming request
                    time_to_first_token_response_time = (
                        kwargs.get("completion_start_time", end_time) - start_time
                    )
                # final_value is the total latency unless we can normalize it
                # to seconds-per-completion-token below
                final_value: Union[float, timedelta] = response_ms
                total_tokens = 0
                time_to_first_token: Optional[float] = None
                if isinstance(response_obj, ModelResponse):
                    _usage = getattr(response_obj, "usage", None)
                    if _usage is not None:
                        completion_tokens = _usage.completion_tokens
                        total_tokens = _usage.total_tokens
                        # Handle both timedelta and float response times
                        if isinstance(response_ms, timedelta):
                            response_seconds = response_ms.total_seconds()
                        else:
                            response_seconds = response_ms
                        # per-token latency; None when completion_tokens is 0
                        final_value = safe_divide_seconds(
                            response_seconds, completion_tokens
                        )
                        if final_value is not None:
                            final_value = float(final_value)
                        else:
                            # fall back to total latency when division failed
                            final_value = response_ms
                        if time_to_first_token_response_time is not None:
                            if isinstance(time_to_first_token_response_time, timedelta):
                                ttft_seconds = (
                                    time_to_first_token_response_time.total_seconds()
                                )
                            else:
                                ttft_seconds = time_to_first_token_response_time
                            # normalize ttft to seconds-per-completion-token
                            time_to_first_token = safe_divide_seconds(
                                ttft_seconds, completion_tokens
                            )
                # ------------
                # Update usage
                # ------------
                parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
                # local_only=True: read only the in-memory copy of the cache
                request_count_dict = (
                    await self.router_cache.async_get_cache(
                        key=latency_key,
                        parent_otel_span=parent_otel_span,
                        local_only=True,
                    )
                    or {}
                )
                if id not in request_count_dict:
                    request_count_dict[id] = {}
                ## Latency
                # keep at most max_latency_list_size samples (sliding window)
                if (
                    len(request_count_dict[id].get("latency", []))
                    < self.routing_args.max_latency_list_size
                ):
                    request_count_dict[id].setdefault("latency", []).append(final_value)
                else:
                    request_count_dict[id]["latency"] = request_count_dict[id][
                        "latency"
                    ][: self.routing_args.max_latency_list_size - 1] + [final_value]
                ## Time to first token
                if time_to_first_token is not None:
                    if (
                        len(request_count_dict[id].get("time_to_first_token", []))
                        < self.routing_args.max_latency_list_size
                    ):
                        request_count_dict[id].setdefault(
                            "time_to_first_token", []
                        ).append(time_to_first_token)
                    else:
                        request_count_dict[id][
                            "time_to_first_token"
                        ] = request_count_dict[id]["time_to_first_token"][
                            : self.routing_args.max_latency_list_size - 1
                        ] + [
                            time_to_first_token
                        ]
                if precise_minute not in request_count_dict[id]:
                    request_count_dict[id][precise_minute] = {}
                ## TPM
                request_count_dict[id][precise_minute]["tpm"] = (
                    request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
                )
                ## RPM
                request_count_dict[id][precise_minute]["rpm"] = (
                    request_count_dict[id][precise_minute].get("rpm", 0) + 1
                )
                await self.router_cache.async_set_cache(
                    key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
                )  # reset map within window
                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            verbose_logger.exception(
                "litellm.router_strategy.lowest_latency.py::async_log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            pass
    def _get_available_deployments(  # noqa: PLR0915
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        request_kwargs: Optional[Dict] = None,
        request_count_dict: Optional[Dict] = None,
    ):
        """Common logic for both sync and async get_available_deployments

        Selection steps:
          1. seed unseen healthy deployments with zero usage,
          2. shuffle the usage map so ties don't always favor the same entry,
          3. skip deployments that would exceed their tpm/rpm limits,
          4. average each deployment's latency samples (ttft when streaming),
          5. pick randomly among deployments within ``lowest_latency_buffer``
             of the best average latency.

        Returns the chosen deployment dict, or None if none qualify
        (or if no usage map was supplied).
        """
        # -----------------------
        # Find lowest used model
        # ----------------------
        _latency_per_deployment = {}
        lowest_latency = float("inf")
        current_date = datetime.now().strftime("%Y-%m-%d")
        current_hour = datetime.now().strftime("%H")
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
        deployment = None
        if request_count_dict is None:  # base case
            return
        all_deployments = request_count_dict
        for d in healthy_deployments:
            ## if healthy deployment not yet used
            if d["model_info"]["id"] not in all_deployments:
                all_deployments[d["model_info"]["id"]] = {
                    "latency": [0],
                    precise_minute: {"tpm": 0, "rpm": 0},
                }
        try:
            input_tokens = token_counter(messages=messages, text=input)
        except Exception:
            input_tokens = 0
        # randomly sample from all_deployments, incase all deployments have latency=0.0
        _items = all_deployments.items()
        _all_deployments = random.sample(list(_items), len(_items))
        all_deployments = dict(_all_deployments)
        ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
        potential_deployments = []
        for item, item_map in all_deployments.items():
            ## get the item from model list
            _deployment = None
            for m in healthy_deployments:
                if item == m["model_info"]["id"]:
                    _deployment = m
            if _deployment is None:
                continue  # skip to next one
            # tpm/rpm limits may live at the top level, in litellm_params,
            # or in model_info; unset means unlimited
            _deployment_tpm = (
                _deployment.get("tpm", None)
                or _deployment.get("litellm_params", {}).get("tpm", None)
                or _deployment.get("model_info", {}).get("tpm", None)
                or float("inf")
            )
            _deployment_rpm = (
                _deployment.get("rpm", None)
                or _deployment.get("litellm_params", {}).get("rpm", None)
                or _deployment.get("model_info", {}).get("rpm", None)
                or float("inf")
            )
            item_latency = item_map.get("latency", [])
            item_ttft_latency = item_map.get("time_to_first_token", [])
            item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
            item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
            # get average latency or average ttft (depending on streaming/non-streaming)
            # NOTE(review): assumes the selected sample list is non-empty;
            # an empty list here would raise ZeroDivisionError - verify
            total: float = 0.0
            use_ttft = (
                request_kwargs is not None
                and request_kwargs.get("stream", None) is not None
                and request_kwargs["stream"] is True
                and len(item_ttft_latency) > 0
            )
            if use_ttft:
                for _call_latency in item_ttft_latency:
                    if isinstance(_call_latency, float):
                        total += _call_latency
                item_latency = total / len(item_ttft_latency)
            else:
                for _call_latency in item_latency:
                    if isinstance(_call_latency, float):
                        total += _call_latency
                item_latency = total / len(item_latency)
            # -------------- #
            # Debugging Logic
            # -------------- #
            # We use _latency_per_deployment to log to langfuse, slack - this is not used to make a decision on routing
            # this helps a user to debug why the router picked a specfic deployment #
            _deployment_api_base = _deployment.get("litellm_params", {}).get(
                "api_base", ""
            )
            if _deployment_api_base is not None:
                _latency_per_deployment[_deployment_api_base] = item_latency
            # -------------- #
            # End of Debugging Logic
            # -------------- #
            if (
                item_tpm + input_tokens > _deployment_tpm
                or item_rpm + 1 > _deployment_rpm
            ):  # if user passed in tpm / rpm in the model_list
                continue
            else:
                potential_deployments.append((_deployment, item_latency))
        if len(potential_deployments) == 0:
            return None
        # Sort potential deployments by latency
        sorted_deployments = sorted(potential_deployments, key=lambda x: x[1])
        # Find lowest latency deployment
        lowest_latency = sorted_deployments[0][1]
        # Find deployments within buffer of lowest latency
        buffer = self.routing_args.lowest_latency_buffer * lowest_latency
        valid_deployments = [
            x for x in sorted_deployments if x[1] <= lowest_latency + buffer
        ]
        # Pick a random deployment from valid deployments
        random_valid_deployment = random.choice(valid_deployments)
        deployment = random_valid_deployment[0]
        # expose per-deployment latencies to the caller for debugging/logging
        metadata_field = self._select_metadata_field(request_kwargs)
        if request_kwargs is not None and metadata_field in request_kwargs:
            request_kwargs[metadata_field][
                "_latency_per_deployment"
            ] = _latency_per_deployment
        return deployment
async def async_get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
):
# get list of potential deployments
latency_key = f"{model_group}_map"
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
request_count_dict = (
await self.router_cache.async_get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
return self._get_available_deployments(
model_group,
healthy_deployments,
messages,
input,
request_kwargs,
request_count_dict,
)
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
):
"""
Returns a deployment with the lowest latency
"""
# get list of potential deployments
latency_key = f"{model_group}_map"
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
request_count_dict = (
self.router_cache.get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
return self._get_available_deployments(
model_group,
healthy_deployments,
messages,
input,
request_kwargs,
request_count_dict,
)

View File

@@ -0,0 +1,249 @@
#### What this does ####
# identifies lowest tpm deployment
import traceback
from datetime import datetime
from typing import Dict, List, Optional, Union
from litellm import token_counter
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMPydanticObjectBase
from litellm.utils import print_verbose
class RoutingArgs(LiteLLMPydanticObjectBase):
    """Tunable parameters for the lowest-TPM routing strategy."""

    # seconds before a per-minute RPM/TPM cache key expires (1 minute)
    ttl: int = 60
class LowestTPMLoggingHandler(CustomLogger):
    """
    Routing strategy that tracks per-deployment TPM/RPM usage in a shared
    cache and picks the deployment with the lowest TPM usage that is still
    under its configured TPM/RPM limits.
    """

    test_flag: bool = False  # when True, count logged successes (for tests)
    logged_success: int = 0
    logged_failure: int = 0
    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour

    def __init__(self, router_cache: DualCache, routing_args: Optional[dict] = None):
        """
        Args:
            router_cache: shared cache storing the per-minute tpm/rpm counters.
            routing_args: optional overrides for RoutingArgs, e.g. {"ttl": 120}.
                None (not a mutable ``{}`` default) means "use defaults".
        """
        self.router_cache = router_cache
        self.routing_args = RoutingArgs(**(routing_args or {}))

    @staticmethod
    def _deployment_limit(deployment: dict, key: str) -> float:
        """Resolve a tpm/rpm limit from deployment config.

        Checks the top level, then litellm_params, then model_info;
        the first non-None value wins. Unset means unlimited (inf).
        """
        for source in (
            deployment,
            deployment.get("litellm_params", {}),
            deployment.get("model_info", {}),
        ):
            value = source.get(key)
            if value is not None:
                return value
        return float("inf")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Update TPM/RPM usage on success (sync path).

        Increments the "{model_group}:tpm:{HH-MM}" and
        "{model_group}:rpm:{HH-MM}" cache maps for this deployment id.
        Errors are logged and swallowed so logging never fails a request.
        """
        try:
            if "litellm_params" not in kwargs or kwargs["litellm_params"] is None:
                return
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )
                # model_info can be explicitly None - guard before .get()
                # (previously crashed with AttributeError in that case)
                id = (kwargs["litellm_params"].get("model_info") or {}).get(
                    "id", None
                )
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)
                # mirror the async path: skip responses without usage info
                if "usage" not in response_obj:
                    return
                total_tokens = response_obj["usage"]["total_tokens"]
                # ------------
                # Setup values
                # ------------
                current_minute = datetime.now().strftime("%H-%M")
                tpm_key = f"{model_group}:tpm:{current_minute}"
                rpm_key = f"{model_group}:rpm:{current_minute}"
                # ------------
                # Update usage
                # ------------
                ## TPM
                request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
                self.router_cache.set_cache(
                    key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )
                ## RPM
                request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
                request_count_dict[id] = request_count_dict.get(id, 0) + 1
                self.router_cache.set_cache(
                    key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )
                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            # fixed log message: this is the sync handler, not
            # async_log_success_event
            verbose_router_logger.error(
                "litellm.router_strategy.lowest_tpm_rpm.py::log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_router_logger.debug(traceback.format_exc())
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Update TPM/RPM usage on success (async path).

        Same bookkeeping as log_success_event(), using the async cache API.
        """
        try:
            if "litellm_params" not in kwargs or kwargs["litellm_params"] is None:
                return
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )
                model_info = kwargs["litellm_params"].get("model_info")
                id = None
                if model_info is not None and isinstance(model_info, dict):
                    id = model_info.get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)
                if "usage" not in response_obj:
                    return
                total_tokens = response_obj["usage"]["total_tokens"]
                # ------------
                # Setup values
                # ------------
                current_minute = datetime.now().strftime("%H-%M")
                tpm_key = f"{model_group}:tpm:{current_minute}"
                rpm_key = f"{model_group}:rpm:{current_minute}"
                # ------------
                # Update usage
                # ------------
                ## TPM
                request_count_dict = (
                    await self.router_cache.async_get_cache(key=tpm_key) or {}
                )
                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
                await self.router_cache.async_set_cache(
                    key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )
                ## RPM
                request_count_dict = (
                    await self.router_cache.async_get_cache(key=rpm_key) or {}
                )
                request_count_dict[id] = request_count_dict.get(id, 0) + 1
                await self.router_cache.async_set_cache(
                    key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )
                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            verbose_router_logger.exception(
                "litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_router_logger.debug(traceback.format_exc())
            pass

    def get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
    ):
        """
        Returns a deployment with the lowest TPM/RPM usage.

        Deployments that would exceed their configured tpm/rpm limits
        (top-level, litellm_params, or model_info) are skipped; unseen
        healthy deployments count as zero usage. Returns None when every
        deployment is over its limits.
        """
        # get list of potential deployments
        verbose_router_logger.debug(
            f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
        )
        current_minute = datetime.now().strftime("%H-%M")
        tpm_key = f"{model_group}:tpm:{current_minute}"
        rpm_key = f"{model_group}:rpm:{current_minute}"
        tpm_dict = self.router_cache.get_cache(key=tpm_key)
        rpm_dict = self.router_cache.get_cache(key=rpm_key)
        verbose_router_logger.debug(
            f"tpm_key={tpm_key}, tpm_dict: {tpm_dict}, rpm_dict: {rpm_dict}"
        )
        try:
            input_tokens = token_counter(messages=messages, text=input)
        except Exception:
            input_tokens = 0
        verbose_router_logger.debug(f"input_tokens={input_tokens}")
        # -----------------------
        # Find lowest used model
        # ----------------------
        lowest_tpm = float("inf")
        if tpm_dict is None:  # base case - none of the deployments have been used
            # initialize a tpm dict with {model_id: 0}
            tpm_dict = {}
            for deployment in healthy_deployments:
                tpm_dict[deployment["model_info"]["id"]] = 0
        else:
            for d in healthy_deployments:
                ## if healthy deployment not yet used
                if d["model_info"]["id"] not in tpm_dict:
                    tpm_dict[d["model_info"]["id"]] = 0
        all_deployments = tpm_dict
        deployment = None
        for item, item_tpm in all_deployments.items():
            ## get the item from model list
            _deployment = None
            for m in healthy_deployments:
                if item == m["model_info"]["id"]:
                    _deployment = m
            if _deployment is None:
                continue  # skip to next one
            _deployment_tpm = self._deployment_limit(_deployment, "tpm")
            _deployment_rpm = self._deployment_limit(_deployment, "rpm")
            if item_tpm + input_tokens > _deployment_tpm:
                continue
            elif (rpm_dict is not None and item in rpm_dict) and (
                rpm_dict[item] + 1 >= _deployment_rpm
            ):
                continue
            elif item_tpm < lowest_tpm:
                lowest_tpm = item_tpm
                deployment = _deployment
        print_verbose("returning picked lowest tpm/rpm deployment.")
        return deployment

View File

@@ -0,0 +1,668 @@
#### What this does ####
# identifies lowest tpm deployment
import random
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm import token_counter
from litellm._logging import verbose_logger, verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.router import RouterErrors
from litellm.types.utils import LiteLLMPydanticObjectBase, StandardLoggingPayload
from litellm.utils import get_utc_datetime, print_verbose
from .base_routing_strategy import BaseRoutingStrategy
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class RoutingArgs(LiteLLMPydanticObjectBase):
    """Tunable parameters for the v2 lowest-TPM routing strategy."""

    # seconds before a per-minute RPM/TPM cache key expires (1 minute)
    ttl: int = 60
class LowestTPMLoggingHandler_v2(BaseRoutingStrategy, CustomLogger):
"""
Updated version of TPM/RPM Logging.
Meant to work across instances.
Caches individual models, not model_groups
Uses batch get (redis.mget)
Increments tpm/rpm limit using redis.incr
"""
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(self, router_cache: DualCache, routing_args: dict = {}):
self.router_cache = router_cache
self.routing_args = RoutingArgs(**routing_args)
BaseRoutingStrategy.__init__(
self,
dual_cache=router_cache,
should_batch_redis_writes=True,
default_sync_interval=0.1,
)
    def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
        """
        Pre-call check + update model rpm

        Increments this deployment's per-minute RPM counter and raises if the
        configured RPM limit would be exceeded. Checks the local cache first
        to avoid unnecessary redis round-trips.

        Returns - deployment

        Raises - RateLimitError if deployment over defined RPM limit
        """
        try:
            # ------------
            # Setup values
            # ------------
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            model_id = deployment.get("model_info", {}).get("id")
            deployment_name = deployment.get("litellm_params", {}).get("model")
            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
            local_result = self.router_cache.get_cache(
                key=rpm_key, local_only=True
            )  # check local result first
            # resolve the configured rpm limit: top level -> litellm_params
            # -> model_info; unset means unlimited
            deployment_rpm = None
            if deployment_rpm is None:
                deployment_rpm = deployment.get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("model_info", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = float("inf")
            if local_result is not None and local_result >= deployment_rpm:
                # already over the limit locally - fail fast without redis
                raise litellm.RateLimitError(
                    message="Deployment over defined rpm limit={}. current usage={}".format(
                        deployment_rpm, local_result
                    ),
                    llm_provider="",
                    model=deployment.get("litellm_params", {}).get("model"),
                    response=httpx.Response(
                        status_code=429,
                        content="{} rpm limit={}. current usage={}. id={}, model_group={}. Get the model info by calling 'router.get_model_info(id)".format(
                            RouterErrors.user_defined_ratelimit_error.value,
                            deployment_rpm,
                            local_result,
                            model_id,
                            deployment.get("model_name", ""),
                        ),
                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )
            else:
                # if local result below limit, check redis ## prevent unnecessary redis checks
                result = self.router_cache.increment_cache(
                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                )
                if result is not None and result > deployment_rpm:
                    raise litellm.RateLimitError(
                        message="Deployment over defined rpm limit={}. current usage={}".format(
                            deployment_rpm, result
                        ),
                        llm_provider="",
                        model=deployment.get("litellm_params", {}).get("model"),
                        response=httpx.Response(
                            status_code=429,
                            content="{} rpm limit={}. current usage={}".format(
                                RouterErrors.user_defined_ratelimit_error.value,
                                deployment_rpm,
                                result,
                            ),
                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                        ),
                    )
            return deployment
        except Exception as e:
            if isinstance(e, litellm.RateLimitError):
                raise e
            return deployment  # don't fail calls if eg. redis fails to connect
    async def async_pre_call_check(
        self, deployment: Dict, parent_otel_span: Optional[Span]
    ) -> Optional[Dict]:
        """
        Pre-call check + update model rpm
        - Used inside semaphore
        - raise rate limit error if deployment over limit

        Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994

        Returns - deployment

        Raises - RateLimitError if deployment over defined RPM limit

        NOTE(review): parent_otel_span is accepted but not used in this body -
        confirm whether it should be threaded into the cache calls.
        """
        try:
            # ------------
            # Setup values
            # ------------
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            model_id = deployment.get("model_info", {}).get("id")
            deployment_name = deployment.get("litellm_params", {}).get("model")
            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
            local_result = await self.router_cache.async_get_cache(
                key=rpm_key, local_only=True
            )  # check local result first
            # resolve the configured rpm limit: top level -> litellm_params
            # -> model_info; unset means unlimited
            deployment_rpm = None
            if deployment_rpm is None:
                deployment_rpm = deployment.get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("model_info", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = float("inf")
            if local_result is not None and local_result >= deployment_rpm:
                # already over the limit locally - fail fast without redis
                raise litellm.RateLimitError(
                    message="Deployment over defined rpm limit={}. current usage={}".format(
                        deployment_rpm, local_result
                    ),
                    llm_provider="",
                    model=deployment.get("litellm_params", {}).get("model"),
                    response=httpx.Response(
                        status_code=429,
                        content="{} rpm limit={}. current usage={}".format(
                            RouterErrors.user_defined_ratelimit_error.value,
                            deployment_rpm,
                            local_result,
                        ),
                        headers={"retry-after": str(60)},  # type: ignore
                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                    num_retries=deployment.get("num_retries"),
                )
            else:
                # if local result below limit, check redis ## prevent unnecessary redis checks
                result = await self._increment_value_in_current_window(
                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                )
                if result is not None and result > deployment_rpm:
                    raise litellm.RateLimitError(
                        message="Deployment over defined rpm limit={}. current usage={}".format(
                            deployment_rpm, result
                        ),
                        llm_provider="",
                        model=deployment.get("litellm_params", {}).get("model"),
                        response=httpx.Response(
                            status_code=429,
                            content="{} rpm limit={}. current usage={}".format(
                                RouterErrors.user_defined_ratelimit_error.value,
                                deployment_rpm,
                                result,
                            ),
                            headers={"retry-after": str(60)},  # type: ignore
                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                        ),
                        num_retries=deployment.get("num_retries"),
                    )
            return deployment
        except Exception as e:
            if isinstance(e, litellm.RateLimitError):
                raise e
            return deployment  # don't fail calls if eg. redis fails to connect
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM/RPM usage on success
"""
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if standard_logging_object is None:
raise ValueError("standard_logging_object not passed in.")
model_group = standard_logging_object.get("model_group")
model = standard_logging_object["hidden_params"].get("litellm_model_name")
id = standard_logging_object.get("model_id")
if model_group is None or id is None or model is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = standard_logging_object.get("total_tokens")
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"{id}:{model}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
## TPM
self.router_cache.increment_cache(
key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.lowest_tpm_rpm_v2.py::log_success_event(): Exception occured - {}".format(
str(e)
)
)
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM usage on success
"""
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if standard_logging_object is None:
raise ValueError("standard_logging_object not passed in.")
model_group = standard_logging_object.get("model_group")
model = standard_logging_object["hidden_params"]["litellm_model_name"]
id = standard_logging_object.get("model_id")
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = standard_logging_object.get("total_tokens")
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"{id}:{model}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
## TPM
await self.router_cache.async_increment_cache(
key=tpm_key,
value=total_tokens,
ttl=self.routing_args.ttl,
parent_otel_span=parent_otel_span,
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.lowest_tpm_rpm_v2.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
pass
def _return_potential_deployments(
self,
healthy_deployments: List[Dict],
all_deployments: Dict,
input_tokens: int,
rpm_dict: Dict,
):
lowest_tpm = float("inf")
potential_deployments = [] # if multiple deployments have the same low value
deployment_lookup = {
deployment.get("model_info", {}).get("id"): deployment
for deployment in healthy_deployments
}
for item, item_tpm in all_deployments.items():
## get the item from model list
item = item.split(":")[0]
_deployment = deployment_lookup.get(item)
if _deployment is None:
continue # skip to next one
elif item_tpm is None:
continue # skip if unhealthy deployment
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = float("inf")
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = float("inf")
if item_tpm + input_tokens > _deployment_tpm:
continue
elif (
(rpm_dict is not None and item in rpm_dict)
and rpm_dict[item] is not None
and (rpm_dict[item] + 1 >= _deployment_rpm)
):
continue
elif item_tpm == lowest_tpm:
potential_deployments.append(_deployment)
elif item_tpm < lowest_tpm:
lowest_tpm = item_tpm
potential_deployments = [_deployment]
return potential_deployments
def _common_checks_available_deployment( # noqa: PLR0915
self,
model_group: str,
healthy_deployments: list,
tpm_keys: list,
tpm_values: Optional[list],
rpm_keys: list,
rpm_values: Optional[list],
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
) -> Optional[dict]:
"""
Common checks for get available deployment, across sync + async implementations
"""
if tpm_values is None or rpm_values is None:
return None
tpm_dict = {} # {model_id: 1, ..}
for idx, key in enumerate(tpm_keys):
tpm_dict[tpm_keys[idx].split(":")[0]] = tpm_values[idx]
rpm_dict = {} # {model_id: 1, ..}
for idx, key in enumerate(rpm_keys):
rpm_dict[rpm_keys[idx].split(":")[0]] = rpm_values[idx]
try:
input_tokens = token_counter(messages=messages, text=input)
except Exception:
input_tokens = 0
verbose_router_logger.debug(f"input_tokens={input_tokens}")
# -----------------------
# Find lowest used model
# ----------------------
if tpm_dict is None: # base case - none of the deployments have been used
# initialize a tpm dict with {model_id: 0}
tpm_dict = {}
for deployment in healthy_deployments:
tpm_dict[deployment["model_info"]["id"]] = 0
else:
for d in healthy_deployments:
## if healthy deployment not yet used
tpm_key = d["model_info"]["id"]
if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
tpm_dict[tpm_key] = 0
all_deployments = tpm_dict
potential_deployments = self._return_potential_deployments(
healthy_deployments=healthy_deployments,
all_deployments=all_deployments,
input_tokens=input_tokens,
rpm_dict=rpm_dict,
)
print_verbose("returning picked lowest tpm/rpm deployment.")
if len(potential_deployments) > 0:
return random.choice(potential_deployments)
else:
return None
async def async_get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
"""
Async implementation of get deployments.
Reduces time to retrieve the tpm/rpm values from cache
"""
# get list of potential deployments
verbose_router_logger.debug(
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_keys = []
rpm_keys = []
for m in healthy_deployments:
if isinstance(m, dict):
id = m.get("model_info", {}).get(
"id"
) # a deployment should always have an 'id'. this is set in router.py
deployment_name = m.get("litellm_params", {}).get("model")
tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)
combined_tpm_rpm_keys = tpm_keys + rpm_keys
combined_tpm_rpm_values = await self.router_cache.async_batch_get_cache(
keys=combined_tpm_rpm_keys
) # [1, 2, None, ..]
if combined_tpm_rpm_values is not None:
tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]
rpm_values = combined_tpm_rpm_values[len(tpm_keys) :]
else:
tpm_values = None
rpm_values = None
deployment = self._common_checks_available_deployment(
model_group=model_group,
healthy_deployments=healthy_deployments,
tpm_keys=tpm_keys,
tpm_values=tpm_values,
rpm_keys=rpm_keys,
rpm_values=rpm_values,
messages=messages,
input=input,
)
try:
assert deployment is not None
return deployment
except Exception:
### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
deployment_dict = {}
for index, _deployment in enumerate(healthy_deployments):
if isinstance(_deployment, dict):
id = _deployment.get("model_info", {}).get("id")
### GET DEPLOYMENT TPM LIMIT ###
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = float("inf")
### GET CURRENT TPM ###
current_tpm = tpm_values[index] if tpm_values else 0
### GET DEPLOYMENT TPM LIMIT ###
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = float("inf")
### GET CURRENT RPM ###
current_rpm = rpm_values[index] if rpm_values else 0
deployment_dict[id] = {
"current_tpm": current_tpm,
"tpm_limit": _deployment_tpm,
"current_rpm": current_rpm,
"rpm_limit": _deployment_rpm,
}
raise litellm.RateLimitError(
message=f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}",
llm_provider="",
model=model_group,
response=httpx.Response(
status_code=429,
content="",
headers={"retry-after": str(60)}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
parent_otel_span: Optional[Span] = None,
):
"""
Returns a deployment with the lowest TPM/RPM usage.
"""
# get list of potential deployments
verbose_router_logger.debug(
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_keys = []
rpm_keys = []
for m in healthy_deployments:
if isinstance(m, dict):
id = m.get("model_info", {}).get(
"id"
) # a deployment should always have an 'id'. this is set in router.py
deployment_name = m.get("litellm_params", {}).get("model")
tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)
tpm_values = self.router_cache.batch_get_cache(
keys=tpm_keys, parent_otel_span=parent_otel_span
) # [1, 2, None, ..]
rpm_values = self.router_cache.batch_get_cache(
keys=rpm_keys, parent_otel_span=parent_otel_span
) # [1, 2, None, ..]
deployment = self._common_checks_available_deployment(
model_group=model_group,
healthy_deployments=healthy_deployments,
tpm_keys=tpm_keys,
tpm_values=tpm_values,
rpm_keys=rpm_keys,
rpm_values=rpm_values,
messages=messages,
input=input,
)
try:
assert deployment is not None
return deployment
except Exception:
### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
deployment_dict = {}
for index, _deployment in enumerate(healthy_deployments):
if isinstance(_deployment, dict):
id = _deployment.get("model_info", {}).get("id")
### GET DEPLOYMENT TPM LIMIT ###
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = float("inf")
### GET CURRENT TPM ###
current_tpm = tpm_values[index] if tpm_values else 0
### GET DEPLOYMENT TPM LIMIT ###
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = float("inf")
### GET CURRENT RPM ###
current_rpm = rpm_values[index] if rpm_values else 0
deployment_dict[id] = {
"current_tpm": current_tpm,
"tpm_limit": _deployment_tpm,
"current_rpm": current_rpm,
"rpm_limit": _deployment_rpm,
}
raise ValueError(
f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}"
)

View File

@@ -0,0 +1,64 @@
"""
Returns a random deployment from the list of healthy deployments.
If weights are provided, it will return a deployment based on the weights.
"""
import random
from typing import TYPE_CHECKING, Any, Dict, List, Union
from litellm._logging import verbose_router_logger
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
def simple_shuffle(
    llm_router_instance: LitellmRouter,
    healthy_deployments: Union[List[Any], Dict[Any, Any]],
    model: str,
) -> Dict:
    """
    Returns a random deployment from the list of healthy deployments.

    If weights are provided, it will return a deployment based on the weights.

    If users pass `rpm` or `tpm`, we do a random weighted pick - based on `rpm`/`tpm`.

    Args:
        llm_router_instance: LitellmRouter instance
        healthy_deployments: List of healthy deployments
        model: Model name

    Returns:
        Dict: A single healthy deployment
    """
    ############## Check if 'weight' or 'rpm' or 'tpm' param set for a weighted pick #################
    # NOTE: only the first deployment is inspected to decide which param drives the pick
    for weight_by in ["weight", "rpm", "tpm"]:
        # use .get(..., {}) so a deployment without litellm_params doesn't crash
        weight = healthy_deployments[0].get("litellm_params", {}).get(weight_by, None)
        if weight is None:
            continue

        weights = [
            m["litellm_params"].get(weight_by, 0) for m in healthy_deployments
        ]
        verbose_router_logger.debug(f"\nweight {weights}")
        total_weight = sum(weights)
        if total_weight <= 0:
            # all weights zero/negative: a weighted pick is undefined
            # (and would raise ZeroDivisionError) - fall back to uniform pick
            continue
        weights = [weight / total_weight for weight in weights]
        verbose_router_logger.debug(f"\n weights {weights} by {weight_by}")
        # Perform weighted random pick
        selected_index = random.choices(range(len(weights)), weights=weights)[0]
        verbose_router_logger.debug(f"\n selected index, {selected_index}")
        deployment = healthy_deployments[selected_index]
        verbose_router_logger.info(
            f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
        )
        return deployment or deployment[0]

    ############## No RPM/TPM passed, we do a random pick #################
    item = random.choice(healthy_deployments)
    return item or item[0]

View File

@@ -0,0 +1,164 @@
"""
Use this to route requests between Teams
- If tags in request is a subset of tags in deployment, return deployment
- if deployments are set with default tags, return all default deployment
- If no default_deployments are set, return all deployments
"""
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from litellm._logging import verbose_logger
from litellm.types.router import RouterErrors
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
def is_valid_deployment_tag(
    deployment_tags: List[str], request_tags: List[str], match_any: bool = True
) -> bool:
    """
    Check if a tag is valid, the matching can be either any or all based on `match_any` flag
    """
    # a request without tags can never match a tagged deployment
    if not request_tags:
        return False

    available = set(deployment_tags)
    requested = set(request_tags)

    # match_any: at least one shared tag; otherwise every requested tag must be present
    matched = (
        bool(available & requested) if match_any else requested.issubset(available)
    )
    if not matched:
        return False

    verbose_logger.debug(
        "adding deployment with tags: %s, request tags: %s for match_any=%s",
        deployment_tags,
        request_tags,
        match_any,
    )
    return True
async def get_deployments_for_tag(
    llm_router_instance: LitellmRouter,
    model: str,  # used to raise the correct error
    healthy_deployments: Union[List[Any], Dict[Any, Any]],
    request_kwargs: Optional[Dict[Any, Any]] = None,
    metadata_variable_name: Literal["metadata", "litellm_metadata"] = "metadata",
):
    """
    Returns a list of deployments that match the requested model and tags in the request.

    Executes tag based filtering based on the tags in request metadata and the tags on the deployments

    - Tagged request: return deployments whose tags match the request tags
      (match semantics come from `llm_router_instance.tag_filtering_match_any`);
      deployments tagged "default" act as fallback when nothing matches.
    - Untagged request: return deployments tagged "default" if any exist,
      otherwise all healthy deployments.

    Raises:
        ValueError: if the request carries tags but no deployment (and no
            "default" deployment) matches them.
    """
    if llm_router_instance.enable_tag_filtering is not True:
        return healthy_deployments

    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tag: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    if healthy_deployments is None:
        verbose_logger.debug(
            "get_deployments_for_tag: healthy_deployments is None returning healthy_deployments"
        )
        return healthy_deployments

    verbose_logger.debug(
        "request metadata: %s", request_kwargs.get(metadata_variable_name)
    )
    if metadata_variable_name in request_kwargs:
        # guard: metadata may be explicitly None (matches _get_tags_from_request_kwargs)
        metadata = request_kwargs[metadata_variable_name] or {}
        request_tags = metadata.get("tags")
        match_any = llm_router_instance.tag_filtering_match_any

        new_healthy_deployments = []
        default_deployments = []
        if request_tags:
            verbose_logger.debug(
                "get_deployments_for_tag routing: router_keys: %s", request_tags
            )
            # example this can be router_keys=["free", "custom"]
            # get the deployments that match the request tags
            for deployment in healthy_deployments:
                # guard: litellm_params may be missing or explicitly None
                deployment_litellm_params = deployment.get("litellm_params") or {}
                deployment_tags = deployment_litellm_params.get("tags")

                verbose_logger.debug(
                    "deployment: %s, deployment_router_keys: %s",
                    deployment,
                    deployment_tags,
                )

                if deployment_tags is None:
                    continue

                if is_valid_deployment_tag(deployment_tags, request_tags, match_any):
                    new_healthy_deployments.append(deployment)

                if "default" in deployment_tags:
                    default_deployments.append(deployment)

            if len(new_healthy_deployments) == 0 and len(default_deployments) == 0:
                raise ValueError(
                    f"{RouterErrors.no_deployments_with_tag_routing.value}. Passed model={model} and tags={request_tags}"
                )

            return (
                new_healthy_deployments
                if len(new_healthy_deployments) > 0
                else default_deployments
            )

    # for Untagged requests use default deployments if set
    _default_deployments_with_tags = []
    for deployment in healthy_deployments:
        # guard: tags may be missing or explicitly None
        deployment_tags = (deployment.get("litellm_params") or {}).get("tags") or []
        if "default" in deployment_tags:
            _default_deployments_with_tags.append(deployment)

    if len(_default_deployments_with_tags) > 0:
        return _default_deployments_with_tags

    # if no default deployment is found, return healthy_deployments
    verbose_logger.debug(
        "no tier found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments
def _get_tags_from_request_kwargs(
request_kwargs: Optional[Dict[Any, Any]] = None,
metadata_variable_name: Literal["metadata", "litellm_metadata"] = "metadata",
) -> List[str]:
"""
Helper to get tags from request kwargs
Args:
request_kwargs: The request kwargs to get tags from
Returns:
List[str]: The tags from the request kwargs
"""
if request_kwargs is None:
return []
if metadata_variable_name in request_kwargs:
metadata = request_kwargs[metadata_variable_name] or {}
tags = metadata.get("tags", [])
return tags if tags is not None else []
elif "litellm_params" in request_kwargs:
litellm_params = request_kwargs["litellm_params"] or {}
_metadata = litellm_params.get(metadata_variable_name, {}) or {}
tags = _metadata.get("tags", [])
return tags if tags is not None else []
return []