# lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/proxy/response_polling/polling_handler.py
"""
Response Polling Handler for Background Responses with Cache
"""
import json
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid4
from litellm.caching.redis_cache import RedisCache
from litellm.types.llms.openai import ResponsesAPIResponse, ResponsesAPIStatus


class ResponsePollingHandler:
    """Handles polling-based responses backed by a Redis cache."""

    CACHE_KEY_PREFIX = "litellm:polling:response:"
    POLLING_ID_PREFIX = "litellm_poll_"  # Clear prefix to identify polling IDs

    def __init__(self, redis_cache: Optional[RedisCache] = None, ttl: int = 3600):
        self.redis_cache = redis_cache
        self.ttl = ttl  # Time-to-live for cache entries (default: 1 hour)

    @classmethod
    def generate_polling_id(cls) -> str:
        """Generate a unique, clearly prefixed polling ID."""
        return f"{cls.POLLING_ID_PREFIX}{uuid4()}"

    @classmethod
    def is_polling_id(cls, response_id: str) -> bool:
        """Check whether a response_id is a polling ID."""
        return response_id.startswith(cls.POLLING_ID_PREFIX)

    @classmethod
    def get_cache_key(cls, polling_id: str) -> str:
        """Get the Redis cache key for a polling ID."""
        return f"{cls.CACHE_KEY_PREFIX}{polling_id}"
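
    # Illustrative shapes of the IDs and keys above (the UUID is made up):
    #   generate_polling_id() -> "litellm_poll_550e8400-e29b-41d4-a716-446655440000"
    #   get_cache_key(polling_id) -> "litellm:polling:response:litellm_poll_550e8400-..."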

    async def create_initial_state(
        self,
        polling_id: str,
        request_data: Dict[str, Any],
    ) -> ResponsesAPIResponse:
        """
        Create the initial state in Redis for a polling request.

        Uses the OpenAI ResponsesAPIResponse object:
        https://platform.openai.com/docs/api-reference/responses/object

        Args:
            polling_id: Unique identifier for this polling request
            request_data: Original request data

        Returns:
            ResponsesAPIResponse object following the OpenAI spec
        """
        created_timestamp = int(datetime.now(timezone.utc).timestamp())

        # Create an OpenAI-compliant response object
        response = ResponsesAPIResponse(
            id=polling_id,
            object="response",
            status="queued",  # OpenAI native status
            created_at=created_timestamp,
            output=[],
            metadata=request_data.get("metadata", {}),
            usage=None,
        )

        cache_key = self.get_cache_key(polling_id)
        if self.redis_cache:
            # Store the ResponsesAPIResponse directly in Redis
            await self.redis_cache.async_set_cache(
                key=cache_key,
                value=response.model_dump_json(),  # Pydantic v2 method
                ttl=self.ttl,
            )
            verbose_proxy_logger.debug(
                f"Created initial polling state for {polling_id} with TTL={self.ttl}s"
            )
        return response
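
    # Usage sketch (hedged, illustrative only, runs inside an async context):
    # assumes a configured RedisCache instance named `redis_cache`; the
    # request_data payload below is made up.
    #
    #     handler = ResponsePollingHandler(redis_cache=redis_cache, ttl=3600)
    #     polling_id = ResponsePollingHandler.generate_polling_id()
    #     initial = await handler.create_initial_state(
    #         polling_id=polling_id,
    #         request_data={"metadata": {"trace_id": "demo"}},
    #     )
    #     # initial.status == "queued"; clients can now poll with polling_id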

    async def update_state(
        self,
        polling_id: str,
        status: Optional[ResponsesAPIStatus] = None,
        usage: Optional[Dict] = None,
        error: Optional[Dict] = None,
        incomplete_details: Optional[Dict] = None,
        reasoning: Optional[Dict] = None,
        tool_choice: Optional[Any] = None,
        tools: Optional[list] = None,
        output: Optional[list] = None,
        # Additional ResponsesAPIResponse fields
        model: Optional[str] = None,
        instructions: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_output_tokens: Optional[int] = None,
        previous_response_id: Optional[str] = None,
        text: Optional[Dict] = None,
        truncation: Optional[str] = None,
        parallel_tool_calls: Optional[bool] = None,
        user: Optional[str] = None,
        store: Optional[bool] = None,
    ) -> None:
        """
        Update the polling state in Redis.

        Uses the OpenAI Response object format with native status types:
        https://platform.openai.com/docs/api-reference/responses/object

        Args:
            polling_id: Unique identifier for this polling request
            status: OpenAI ResponsesAPIStatus value
            usage: Usage information
            error: Error dict (automatically sets status to "failed")
            incomplete_details: Details for incomplete responses
            reasoning: Reasoning configuration from response.completed
            tool_choice: Tool choice configuration from response.completed
            tools: Tools list from response.completed
            output: Full output list that replaces the current output
            model: Model identifier
            instructions: System instructions
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            max_output_tokens: Maximum output tokens
            previous_response_id: ID of the previous response in the conversation
            text: Text configuration
            truncation: Truncation setting
            parallel_tool_calls: Whether parallel tool calls are enabled
            user: User identifier
            store: Whether to store the response
        """
        if not self.redis_cache:
            return

        cache_key = self.get_cache_key(polling_id)

        # Get the current state
        cached_state = await self.redis_cache.async_get_cache(cache_key)
        if not cached_state:
            verbose_proxy_logger.warning(
                f"No cached state found for polling_id: {polling_id}"
            )
            return

        # Parse the existing ResponsesAPIResponse from cache
        state = json.loads(cached_state)

        # Update the status (using OpenAI native status values)
        if status:
            state["status"] = status

        # Replace the full output list if provided
        if output is not None:
            state["output"] = output

        # Update usage
        if usage:
            state["usage"] = usage

        # Handle an error (sets status to OpenAI's "failed")
        if error:
            state["status"] = "failed"
            state["error"] = error  # Use OpenAI's 'error' field

        # Handle incomplete details
        if incomplete_details:
            state["incomplete_details"] = incomplete_details

        # Update reasoning, tool_choice, tools from response.completed
        if reasoning is not None:
            state["reasoning"] = reasoning
        if tool_choice is not None:
            state["tool_choice"] = tool_choice
        if tools is not None:
            state["tools"] = tools

        # Update additional ResponsesAPIResponse fields
        if model is not None:
            state["model"] = model
        if instructions is not None:
            state["instructions"] = instructions
        if temperature is not None:
            state["temperature"] = temperature
        if top_p is not None:
            state["top_p"] = top_p
        if max_output_tokens is not None:
            state["max_output_tokens"] = max_output_tokens
        if previous_response_id is not None:
            state["previous_response_id"] = previous_response_id
        if text is not None:
            state["text"] = text
        if truncation is not None:
            state["truncation"] = truncation
        if parallel_tool_calls is not None:
            state["parallel_tool_calls"] = parallel_tool_calls
        if user is not None:
            state["user"] = user
        if store is not None:
            state["store"] = store

        # Write back to the cache with the configured TTL
        await self.redis_cache.async_set_cache(
            key=cache_key,
            value=json.dumps(state),
            ttl=self.ttl,
        )

        output_count = len(state.get("output", []))
        verbose_proxy_logger.debug(
            f"Updated polling state for {polling_id}: "
            f"status={state['status']}, output_items={output_count}"
        )

    async def get_state(self, polling_id: str) -> Optional[Dict[str, Any]]:
        """Get the current polling state from Redis."""
        if not self.redis_cache:
            return None
        cache_key = self.get_cache_key(polling_id)
        cached_state = await self.redis_cache.async_get_cache(cache_key)
        if cached_state:
            return json.loads(cached_state)
        return None

    async def cancel_polling(self, polling_id: str) -> bool:
        """
        Cancel a polling request.

        Follows the OpenAI Response object format for the "cancelled" status.
        """
        await self.update_state(
            polling_id=polling_id,
            status="cancelled",
        )
        return True

    async def delete_polling(self, polling_id: str) -> bool:
        """Delete a polling request from the cache."""
        if not self.redis_cache:
            return False
        cache_key = self.get_cache_key(polling_id)
        # RedisCache's async_delete_cache handles both Redis and RedisCluster
        await self.redis_cache.async_delete_cache(cache_key)
        return True
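
    # Usage sketch (hedged) of the read/cancel/delete lifecycle, e.g. from the
    # proxy's GET/DELETE response endpoints (the surrounding handler shape is
    # assumed, not part of this module):
    #
    #     if ResponsePollingHandler.is_polling_id(response_id):
    #         state = await handler.get_state(response_id)
    #         if state is None:
    #             ...  # expired or unknown: surface a not-found error
    #         elif state["status"] in ("queued", "in_progress"):
    #             ...  # return the partial state; the client polls again later
    #         else:
    #             await handler.delete_polling(response_id)  # terminal: clean up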


def should_use_polling_for_request(
    background_mode: bool,
    polling_via_cache_enabled,  # Can be False, "all", or List[str]
    redis_cache,  # RedisCache or None
    model: str,
    llm_router,  # Router instance or None
    native_background_mode: Optional[List[str]] = None,  # Models that use native background mode
) -> bool:
    """
    Determine whether polling via cache should be used for a request.

    Args:
        background_mode: Whether background=true was set in the request
        polling_via_cache_enabled: Config value - False, "all", or a list of providers
        redis_cache: Redis cache instance (required for polling)
        model: Model name from the request (e.g., "gpt-5" or "openai/gpt-4o")
        llm_router: LiteLLM router instance for looking up model deployments
        native_background_mode: List of model names that should use the native
            provider background mode instead of polling via cache

    Returns:
        True if polling should be used, False otherwise
    """
    # All preconditions must be met
    if not (background_mode and polling_via_cache_enabled and redis_cache):
        return False

    # Models in the native_background_mode list use the provider's own background mode
    if native_background_mode and model in native_background_mode:
        verbose_proxy_logger.debug(
            f"Model {model} is in native_background_mode list, skipping polling via cache"
        )
        return False

    # "all" enables polling for every provider
    if polling_via_cache_enabled == "all":
        return True

    # Check whether the provider is in the enabled list
    if isinstance(polling_via_cache_enabled, list):
        # First, try to get the provider from a "provider/model" string
        if "/" in model:
            provider = model.split("/")[0]
            if provider in polling_via_cache_enabled:
                return True
        # Otherwise, check ALL deployments for this model_name in the router
        elif llm_router is not None:
            try:
                # Get all deployment indices for this model name
                indices = llm_router.model_name_to_deployment_indices.get(model, [])
                for idx in indices:
                    deployment_dict = llm_router.model_list[idx]
                    litellm_params = deployment_dict.get("litellm_params", {})
                    # Check custom_llm_provider first
                    dep_provider = litellm_params.get("custom_llm_provider")
                    # Then try to extract it from the model string (e.g., "openai/gpt-5")
                    if not dep_provider:
                        dep_model = litellm_params.get("model", "")
                        if "/" in dep_model:
                            dep_provider = dep_model.split("/")[0]
                    # If ANY deployment's provider matches, enable polling
                    if dep_provider and dep_provider in polling_via_cache_enabled:
                        verbose_proxy_logger.debug(
                            f"Polling enabled for model={model}, provider={dep_provider}"
                        )
                        return True
            except Exception as e:
                verbose_proxy_logger.debug(
                    f"Could not resolve provider for model {model}: {e}"
                )

    return False
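

if __name__ == "__main__":
    # Hedged self-check, not part of the proxy: exercises only the pure
    # decision logic of should_use_polling_for_request. A plain object()
    # stands in for the RedisCache instance, since only truthiness is
    # checked on this code path.
    _cache_stub = object()

    # background=true + "all" -> polling via cache
    assert should_use_polling_for_request(
        background_mode=True,
        polling_via_cache_enabled="all",
        redis_cache=_cache_stub,
        model="openai/gpt-4o",
        llm_router=None,
    )

    # No background mode -> never poll
    assert not should_use_polling_for_request(
        background_mode=False,
        polling_via_cache_enabled="all",
        redis_cache=_cache_stub,
        model="openai/gpt-4o",
        llm_router=None,
    )

    # Provider list: "openai/gpt-4o" resolves provider "openai" from the prefix
    assert should_use_polling_for_request(
        background_mode=True,
        polling_via_cache_enabled=["openai"],
        redis_cache=_cache_stub,
        model="openai/gpt-4o",
        llm_router=None,
    )

    # native_background_mode takes precedence over polling via cache
    assert not should_use_polling_for_request(
        background_mode=True,
        polling_via_cache_enabled="all",
        redis_cache=_cache_stub,
        model="gpt-5",
        llm_router=None,
        native_background_mode=["gpt-5"],
    )

    print("should_use_polling_for_request decision checks passed")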