chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,221 @@
"""
Max Iterations Limiter for LiteLLM Proxy.
Enforces a per-session cap on the number of LLM calls an agentic loop can make.
Callers send a `session_id` with each request (via `x-litellm-session-id` header
or `metadata.session_id`), and this hook counts calls per session. When the count
exceeds `max_iterations` (configured in agent litellm_params or key metadata), returns 429.
Works across multiple proxy instances via DualCache (in-memory + Redis).
Follows the same pattern as parallel_request_limiter_v3.py.
"""
import os
from typing import TYPE_CHECKING, Any, Optional, Union
from fastapi import HTTPException
from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
if TYPE_CHECKING:
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
InternalUsageCache = _InternalUsageCache
else:
InternalUsageCache = Any
# Redis Lua script for atomic increment with TTL.
# Returns the new count after increment.
# Only sets EXPIRE on first increment (when count becomes 1).
MAX_ITERATIONS_INCREMENT_SCRIPT = """
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local current = redis.call('INCR', key)
if current == 1 then
redis.call('EXPIRE', key, ttl)
end
return current
"""
# Default TTL for session iteration counters (1 hour)
DEFAULT_MAX_ITERATIONS_TTL = 3600
class _PROXY_MaxIterationsHandler(CustomLogger):
    """
    Pre-call hook that caps the number of LLM calls an agentic loop may make
    per session.

    Configuration:
        - max_iterations: set in agent litellm_params (preferred),
          e.g. litellm_params={"max_iterations": 25}.
          Falls back to key metadata max_iterations for backwards compatibility.
        - session_id: supplied by the caller via the x-litellm-session-id
          header or metadata.session_id in the request body.

    Cache key pattern:
        {session_iterations:<session_id>}:count

    Multi-instance support:
        A Redis Lua script performs the increment atomically across proxy
        instances (same pattern as parallel_request_limiter_v3); when Redis
        is unavailable, counting degrades to the local in-memory cache.
    """

    def __init__(self, internal_usage_cache: InternalUsageCache):
        self.internal_usage_cache = internal_usage_cache
        # Counter TTL (seconds); env var wins over the module default.
        self.ttl = int(
            os.getenv("LITELLM_MAX_ITERATIONS_TTL", DEFAULT_MAX_ITERATIONS_TTL)
        )
        # Pre-register the atomic increment script when Redis is configured;
        # a None script routes every increment through the in-memory fallback.
        redis_cache = self.internal_usage_cache.dual_cache.redis_cache
        self.increment_script = (
            None
            if redis_cache is None
            else redis_cache.async_register_script(MAX_ITERATIONS_INCREMENT_SCRIPT)
        )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Enforce the per-session iteration cap before the API call is made.

        Extracts session_id from request metadata and max_iterations from the
        agent's litellm_params (or key metadata). Increments the session
        counter and raises HTTP 429 once the count exceeds the limit.
        Returns None (no-op) when either the session id or the limit is absent.
        """
        # No session id -> nothing to count against.
        session_id = self._get_session_id(data)
        if session_id is None:
            return None

        # No configured limit -> limiter disabled for this request.
        max_iterations = self._get_max_iterations(user_api_key_dict)
        if max_iterations is None:
            return None

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, max_iterations=%s",
            session_id,
            max_iterations,
        )

        # Bump the counter first, then compare — the increment itself is the
        # record of this attempt.
        count = await self._increment_and_get(self._make_cache_key(session_id))
        if count > max_iterations:
            raise HTTPException(
                status_code=429,
                detail=(
                    f"Max iterations exceeded for session {session_id}. "
                    f"Current count: {count}, max_iterations: {max_iterations}."
                ),
            )

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, count=%s/%s",
            session_id,
            count,
            max_iterations,
        )
        return None

    def _get_session_id(self, data: dict) -> Optional[str]:
        """
        Pull the caller-supplied session id out of the request payload.

        Checks `metadata` first, then `litellm_metadata` (used by the
        /thread and /assistant endpoints). Returns None when neither
        container carries a session_id.
        """
        for field in ("metadata", "litellm_metadata"):
            container = data.get(field) or {}
            sid = container.get("session_id")
            if sid is not None:
                return str(sid)
        return None

    def _get_max_iterations(self, user_api_key_dict: UserAPIKeyAuth) -> Optional[int]:
        """
        Resolve the max_iterations limit for this request.

        Agent litellm_params take precedence; key metadata remains as the
        backwards-compatible fallback. Returns None when no limit is set.
        """
        agent_id = user_api_key_dict.agent_id
        if agent_id is not None:
            # Imported lazily to avoid a module-level import cycle.
            from litellm.proxy.agent_endpoints.agent_registry import (
                global_agent_registry,
            )

            agent = global_agent_registry.get_agent_by_id(agent_id=agent_id)
            if agent is not None:
                limit = (agent.litellm_params or {}).get("max_iterations")
                if limit is not None:
                    return int(limit)

        # Backwards-compatible fallback: limit stored on the key's metadata.
        limit = (user_api_key_dict.metadata or {}).get("max_iterations")
        return int(limit) if limit is not None else None

    def _make_cache_key(self, session_id: str) -> str:
        """
        Build the counter key for a session.

        The {session_iterations:<session_id>} hash-tag keeps every key for a
        given session on the same Redis Cluster slot.
        """
        return f"{{session_iterations:{session_id}}}:count"

    async def _increment_and_get(self, cache_key: str) -> int:
        """
        Atomically bump the session counter and return the new total.

        Prefers the registered Redis Lua script (atomic across instances);
        any Redis error is logged and the local in-memory path is used instead.
        """
        script = self.increment_script
        if script is not None:
            try:
                return int(await script(keys=[cache_key], args=[self.ttl]))
            except Exception as e:
                verbose_proxy_logger.warning(
                    "MaxIterationsHandler: Redis failed, falling back to in-memory: %s",
                    str(e),
                )
        return await self._in_memory_increment(cache_key)

    async def _in_memory_increment(self, cache_key: str) -> int:
        """
        Read-modify-write the counter in the local cache with TTL.

        Not atomic across instances — acceptable only as the degraded
        fallback when Redis is unreachable.
        """
        prior = await self.internal_usage_cache.async_get_cache(
            key=cache_key,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        updated = 1 if prior is None else int(prior) + 1
        await self.internal_usage_cache.async_set_cache(
            key=cache_key,
            value=updated,
            ttl=self.ttl,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        return updated