"""
|
||
|
|
Batch Rate Limiter Hook
|
||
|
|
|
||
|
|
This hook implements rate limiting for batch API requests by:
|
||
|
|
1. Reading batch input files to count requests and estimate tokens at submission
|
||
|
|
2. Validating actual usage from output files when batches complete
|
||
|
|
3. Integrating with the existing parallel request limiter infrastructure
|
||
|
|
|
||
|
|
## Integration & Calling
|
||
|
|
This hook is automatically registered and called by the proxy system.
|
||
|
|
See BATCH_RATE_LIMITER_INTEGRATION.md for complete integration details.
|
||
|
|
|
||
|
|
Quick summary:
|
||
|
|
- Add to PROXY_HOOKS in litellm/proxy/hooks/__init__.py
|
||
|
|
- Gets auto-instantiated on proxy startup via _add_proxy_hooks()
|
||
|
|
- async_pre_call_hook() fires on POST /v1/batches (batch submission)
|
||
|
|
- async_log_success_event() fires on GET /v1/batches/{id} (batch completion)
|
||
|
|
"""

from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

from fastapi import HTTPException
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.batch_utils import (
    _get_batch_job_input_file_usage,
    _get_file_content_as_dictionary,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitDescriptor as _RateLimitDescriptor,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitStatus as _RateLimitStatus,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        _PROXY_MaxParallelRequestsHandler_v3 as _ParallelRequestLimiter,
    )
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.router import Router as _Router

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    Router = _Router
    ParallelRequestLimiter = _ParallelRequestLimiter
    RateLimitStatus = _RateLimitStatus
    RateLimitDescriptor = _RateLimitDescriptor
else:
    Span = Any
    InternalUsageCache = Any
    Router = Any
    ParallelRequestLimiter = Any
    RateLimitStatus = Dict[str, Any]
    RateLimitDescriptor = Dict[str, Any]
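
# At runtime the aliases above intentionally degrade to Any / plain dicts: the
# real classes are only needed for static type checking, and importing them
# here would pull heavy proxy modules into every import of this file.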


class BatchFileUsage(BaseModel):
    """
    Internal model for batch file usage tracking, used for batch rate limiting
    """

    total_tokens: int
    request_count: int
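
# Illustration: a batch input file is JSONL, one request per line, e.g.
#
#   {"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {...}}
#   {"custom_id": "req-2", "method": "POST", "url": "/v1/chat/completions", "body": {...}}
#
# Two such lines yield BatchFileUsage(request_count=2, total_tokens=N), where N
# is the estimated prompt-token total computed by _get_batch_job_input_file_usage().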


class _PROXY_BatchRateLimiter(CustomLogger):
    """
    Rate limiter for batch API requests.

    Handles rate limiting at two points:
    1. Batch submission - reads input file and reserves capacity
    2. Batch completion - reads output file and adjusts for actual usage
    """

    def __init__(
        self,
        internal_usage_cache: InternalUsageCache,
        parallel_request_limiter: ParallelRequestLimiter,
    ):
        """
        Initialize the batch rate limiter.

        Note: These dependencies are automatically injected by ProxyLogging._add_proxy_hooks()
        when this hook is registered in PROXY_HOOKS. See BATCH_RATE_LIMITER_INTEGRATION.md.

        Args:
            internal_usage_cache: Cache for storing rate limit data (auto-injected)
            parallel_request_limiter: Existing rate limiter to integrate with (needs custom injection)
        """
        self.internal_usage_cache = internal_usage_cache
        self.parallel_request_limiter = parallel_request_limiter
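
    # Manual wiring sketch (normally done automatically by _add_proxy_hooks();
    # mainly useful in unit tests). Variable names are illustrative:
    #
    #   limiter = _PROXY_BatchRateLimiter(
    #       internal_usage_cache=internal_usage_cache,
    #       parallel_request_limiter=parallel_request_limiter_v3,
    #   )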

    def _raise_rate_limit_error(
        self,
        status: "RateLimitStatus",
        descriptors: List["RateLimitDescriptor"],
        batch_usage: BatchFileUsage,
        limit_type: str,
    ) -> None:
        """Raise HTTPException for rate limit exceeded."""
        from datetime import datetime, timezone

        # Find the descriptor for this status
        descriptor_index = next(
            (
                i
                for i, d in enumerate(descriptors)
                if d.get("key") == status.get("descriptor_key")
            ),
            0,
        )
        descriptor: RateLimitDescriptor = (
            descriptors[descriptor_index]
            if descriptors
            else {"key": "", "value": "", "rate_limit": None}
        )

        # Use an explicit UTC timestamp so the "UTC" label in the message is accurate
        now = datetime.now(timezone.utc).timestamp()
        window_size = self.parallel_request_limiter.window_size
        reset_time = now + window_size
        reset_time_formatted = datetime.fromtimestamp(
            reset_time, tz=timezone.utc
        ).strftime("%Y-%m-%d %H:%M:%S UTC")

        remaining_display = max(0, status["limit_remaining"])
        current_limit = status["current_limit"]

        if limit_type == "requests":
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.request_count} requests but only {remaining_display} requests remaining "
                f"out of {current_limit} RPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )
        else:  # tokens
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.total_tokens} tokens but only {remaining_display} tokens remaining "
                f"out of {current_limit} TPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )

        raise HTTPException(
            status_code=429,
            detail=detail,
            headers={
                "retry-after": str(window_size),
                "rate_limit_type": limit_type,
                "reset_at": reset_time_formatted,
            },
        )
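
    # The resulting 429 looks roughly like this (all values illustrative):
    #
    #   HTTP/1.1 429 Too Many Requests
    #   retry-after: 60
    #   rate_limit_type: tokens
    #   reset_at: 2025-01-01 00:01:00 UTC
    #
    #   {"detail": "Batch rate limit exceeded for api_key: <hashed-key>. Batch
    #    contains 50000 tokens but only 40000 tokens remaining out of 100000
    #    TPM limit. Limit resets at: 2025-01-01 00:01:00 UTC"}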

    async def _check_and_increment_batch_counters(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: Dict,
        batch_usage: BatchFileUsage,
    ) -> None:
        """
        Check rate limits and increment counters by the batch amounts.

        Raises HTTPException if any limit would be exceeded.
        """
        from litellm.types.caching import RedisPipelineIncrementOperation

        # Create descriptors and check if batch would exceed limits
        descriptors = self.parallel_request_limiter._create_rate_limit_descriptors(
            user_api_key_dict=user_api_key_dict,
            data=data,
            rpm_limit_type=None,
            tpm_limit_type=None,
            model_has_failures=False,
        )

        # Check current usage without incrementing
        rate_limit_response = await self.parallel_request_limiter.should_rate_limit(
            descriptors=descriptors,
            parent_otel_span=user_api_key_dict.parent_otel_span,
            read_only=True,
        )

        # Verify batch won't exceed any limits
        for status in rate_limit_response["statuses"]:
            rate_limit_type = status["rate_limit_type"]
            limit_remaining = status["limit_remaining"]

            required_capacity = (
                batch_usage.request_count
                if rate_limit_type == "requests"
                else batch_usage.total_tokens
                if rate_limit_type == "tokens"
                else 0
            )

            if required_capacity > limit_remaining:
                self._raise_rate_limit_error(
                    status, descriptors, batch_usage, rate_limit_type
                )

        # Build pipeline operations for batch increments
        # Reuse the same keys that descriptors check
        pipeline_operations: List[RedisPipelineIncrementOperation] = []

        for descriptor in descriptors:
            key = descriptor["key"]
            value = descriptor["value"]
            rate_limit = descriptor.get("rate_limit")

            if rate_limit is None:
                continue

            # Add RPM increment if limit is set
            if rate_limit.get("requests_per_unit") is not None:
                rpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="requests"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=rpm_key,
                        increment_value=batch_usage.request_count,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

            # Add TPM increment if limit is set
            if rate_limit.get("tokens_per_unit") is not None:
                tpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="tokens"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=tpm_key,
                        increment_value=batch_usage.total_tokens,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

        # Execute increments
        if pipeline_operations:
            await self.parallel_request_limiter.async_increment_tokens_with_ttl_preservation(
                pipeline_operations=pipeline_operations,
                parent_otel_span=user_api_key_dict.parent_otel_span,
            )
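
    # Worked example of the capacity check above (numbers illustrative): with a
    # 100_000 TPM limit and 60_000 tokens already counted in the current window,
    # should_rate_limit() reports limit_remaining=40_000, so a batch estimated
    # at 50_000 tokens is rejected up front rather than partially admitted.
    # Note the read-only check and the pipelined increment are two separate
    # steps, so concurrent submissions can race between them.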

    async def count_input_file_usage(
        self,
        file_id: str,
        custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
        user_api_key_dict: Optional[UserAPIKeyAuth] = None,
    ) -> BatchFileUsage:
        """
        Count number of requests and tokens in a batch input file.

        Args:
            file_id: The file ID to read
            custom_llm_provider: The custom LLM provider to use for token encoding
            user_api_key_dict: User authentication information for file access (required for managed files)

        Returns:
            BatchFileUsage with total_tokens and request_count
        """
        try:
            # Check if this is a managed file (base64 encoded unified file ID)
            from litellm.proxy.openai_files_endpoints.common_utils import (
                _is_base64_encoded_unified_file_id,
            )

            # Managed files require bypassing the HTTP endpoint (which runs access-check hooks)
            # and calling the managed files hook directly with the user's credentials.
            is_managed_file = _is_base64_encoded_unified_file_id(file_id)
            if is_managed_file and user_api_key_dict is not None:
                file_content = await self._fetch_managed_file_content(
                    file_id=file_id,
                    user_api_key_dict=user_api_key_dict,
                )
            else:
                # For non-managed files, use the standard litellm.afile_content
                file_content = await litellm.afile_content(
                    file_id=file_id,
                    custom_llm_provider=custom_llm_provider,
                    user_api_key_dict=user_api_key_dict,
                )

            file_content_as_dict = _get_file_content_as_dictionary(file_content.content)

            input_file_usage = _get_batch_job_input_file_usage(
                file_content_dictionary=file_content_as_dict,
                custom_llm_provider=custom_llm_provider,
            )
            request_count = len(file_content_as_dict)
            return BatchFileUsage(
                total_tokens=input_file_usage.total_tokens,
                request_count=request_count,
            )

        except Exception as e:
            verbose_proxy_logger.error(
                f"Error counting input file usage for {file_id}: {str(e)}"
            )
            raise
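
    # Usage sketch (the limiter instance and file ID are illustrative):
    #
    #   usage = await limiter.count_input_file_usage(file_id="file-abc123")
    #   print(usage.request_count, usage.total_tokens)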

    async def _fetch_managed_file_content(
        self,
        file_id: str,
        user_api_key_dict: UserAPIKeyAuth,
    ) -> Any:
        """
        Fetch file content from managed files hook.

        This is needed for managed files because they require proper user context
        to verify file ownership and access permissions.

        Args:
            file_id: The managed file ID (base64 encoded)
            user_api_key_dict: User authentication information

        Returns:
            HttpxBinaryResponseContent with the file content
        """
        from litellm.llms.base_llm.files.transformation import BaseFileEndpoints

        # Import proxy_server dependencies at runtime to avoid circular imports
        try:
            from litellm.proxy.proxy_server import llm_router, proxy_logging_obj
        except ImportError as e:
            raise ValueError(
                f"Cannot import proxy_server dependencies: {str(e)}. "
                "Managed files require proxy_server to be initialized."
            )

        # Get the managed files hook
        if proxy_logging_obj is None:
            raise ValueError(
                "proxy_logging_obj not available. Cannot access managed files hook."
            )

        managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
        if managed_files_obj is None:
            raise ValueError(
                "Managed files hook not found. Cannot access managed file."
            )

        if not isinstance(managed_files_obj, BaseFileEndpoints):
            raise ValueError("Managed files hook is not a BaseFileEndpoints instance.")

        if llm_router is None:
            raise ValueError("llm_router not available. Cannot access managed files.")

        # Use the managed files hook to get file content
        # This properly handles user permissions and file ownership
        file_content = await managed_files_obj.afile_content(
            file_id=file_id,
            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            llm_router=llm_router,
        )

        return file_content

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: Any,
        data: Dict,
        call_type: str,
    ) -> Union[Exception, str, Dict, None]:
        """
        Pre-call hook for batch operations.

        Only handles batch creation (acreate_batch):
        - Reads input file
        - Counts tokens and requests
        - Reserves rate limit capacity via parallel_request_limiter

        Args:
            user_api_key_dict: User authentication information
            cache: Cache instance (not used directly)
            data: Request data
            call_type: Type of call being made

        Returns:
            Modified data dict or None

        Raises:
            HTTPException: 429 if rate limit would be exceeded
        """
        # Only handle batch creation
        if call_type != "acreate_batch":
            verbose_proxy_logger.debug(
                f"Batch rate limiter: skipping rate limiting for call type: {call_type}"
            )
            return data

        verbose_proxy_logger.debug(
            "Batch rate limiter: Handling batch creation rate limiting"
        )

        try:
            # Extract input_file_id from data
            input_file_id = data.get("input_file_id")
            if not input_file_id:
                verbose_proxy_logger.debug(
                    "No input_file_id in batch request, skipping rate limiting"
                )
                return data

            # Get custom_llm_provider for token counting
            custom_llm_provider = data.get("custom_llm_provider", "openai")

            # Count tokens and requests from input file
            verbose_proxy_logger.debug(
                f"Counting tokens from batch input file: {input_file_id}"
            )
            batch_usage = await self.count_input_file_usage(
                file_id=input_file_id,
                custom_llm_provider=custom_llm_provider,
                user_api_key_dict=user_api_key_dict,
            )

            verbose_proxy_logger.debug(
                f"Batch input file usage - Tokens: {batch_usage.total_tokens}, "
                f"Requests: {batch_usage.request_count}"
            )

            # Store batch usage in data for later reference
            data["_batch_token_count"] = batch_usage.total_tokens
            data["_batch_request_count"] = batch_usage.request_count

            # Check limits, then increment counters by the batch amounts
            # (two steps, not one atomic operation).
            # This will raise HTTPException if limits are exceeded.
            await self._check_and_increment_batch_counters(
                user_api_key_dict=user_api_key_dict,
                data=data,
                batch_usage=batch_usage,
            )

            verbose_proxy_logger.debug(
                "Batch rate limit check passed, counters incremented"
            )
            return data

        except HTTPException:
            # Re-raise HTTP exceptions (rate limit exceeded)
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in batch rate limiting: {str(e)}", exc_info=True
            )
            # Don't block the request if rate limiting fails
            return data
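
# End-to-end trigger sketch: submitting a batch through the proxy fires
# async_pre_call_hook with call_type="acreate_batch" (endpoint per the module
# docstring; the API key and file ID are illustrative):
#
#   curl http://localhost:4000/v1/batches \
#     -H "Authorization: Bearer sk-..." \
#     -H "Content-Type: application/json" \
#     -d '{
#       "input_file_id": "file-abc123",
#       "endpoint": "/v1/chat/completions",
#       "completion_window": "24h"
#     }'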