1310 lines
51 KiB
Python
1310 lines
51 KiB
Python
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
import time
|
|||
|
|
import traceback
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Any, Dict, List, Optional
|
|||
|
|
|
|||
|
|
import httpx
|
|||
|
|
|
|||
|
|
import litellm
|
|||
|
|
from litellm.constants import (
|
|||
|
|
LITELLM_MAX_STREAMING_DURATION_SECONDS,
|
|||
|
|
STREAM_SSE_DONE_STRING,
|
|||
|
|
)
|
|||
|
|
from litellm.litellm_core_utils.asyncify import run_async_function
|
|||
|
|
from litellm.litellm_core_utils.core_helpers import process_response_headers
|
|||
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
|||
|
|
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
|
|||
|
|
from litellm.litellm_core_utils.llm_response_utils.response_metadata import (
|
|||
|
|
update_response_metadata,
|
|||
|
|
)
|
|||
|
|
from litellm.litellm_core_utils.thread_pool_executor import executor
|
|||
|
|
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
|||
|
|
from litellm.responses.utils import ResponsesAPIRequestUtils
|
|||
|
|
from litellm.types.llms.openai import (
|
|||
|
|
OutputTextDeltaEvent,
|
|||
|
|
ResponseAPIUsage,
|
|||
|
|
ResponseCompletedEvent,
|
|||
|
|
ResponsesAPIRequestParams,
|
|||
|
|
ResponsesAPIResponse,
|
|||
|
|
ResponsesAPIStreamEvents,
|
|||
|
|
ResponsesAPIStreamingResponse,
|
|||
|
|
)
|
|||
|
|
from litellm.types.utils import CallTypes
|
|||
|
|
from litellm.utils import CustomStreamWrapper, async_post_call_success_deployment_hook
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BaseResponsesAPIStreamingIterator:
    """
    Base class for streaming iterators that process responses from the Responses API.

    This class contains shared logic for both synchronous and asynchronous iterators:
    SSE chunk parsing, provider-specific transformation, completed-response capture,
    post-streaming deployment hooks, and success/failure logging dispatch.
    """

    def __init__(
        self,
        response: httpx.Response,
        model: str,
        responses_api_provider_config: BaseResponsesAPIConfig,
        logging_obj: LiteLLMLoggingObj,
        litellm_metadata: Optional[Dict[str, Any]] = None,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict[str, Any]] = None,
        call_type: Optional[str] = None,
    ):
        """
        Args:
            response: raw streaming HTTP response whose body is iterated as SSE lines.
            model: model name used for transformation, cost and timeout messages.
            responses_api_provider_config: provider config that transforms raw chunks.
            logging_obj: LiteLLM logging object; supplies start_time, call details,
                and the success/failure handlers invoked by subclasses.
            litellm_metadata: optional router metadata (model_info id, affinity flags).
            custom_llm_provider: provider name recorded in hidden params / timeouts.
            request_data: original request payload, forwarded to deployment hooks.
            call_type: string form of CallTypes used by post-call hooks.
        """
        self.response = response
        self.model = model
        self.logging_obj = logging_obj
        # True once the SSE "[DONE]" marker is seen or the underlying stream ends.
        self.finished = False
        self.responses_api_provider_config = responses_api_provider_config
        # Populated when a RESPONSE_COMPLETED event is observed.
        self.completed_response: Optional[ResponsesAPIStreamingResponse] = None
        self.start_time = getattr(logging_obj, "start_time", datetime.now())
        self._failure_handled = False  # Track if failure handler has been called
        # Wall-clock reference for the max-streaming-duration guard.
        self._stream_created_time: float = time.time()

        # track request context for hooks
        self.litellm_metadata = litellm_metadata
        self.custom_llm_provider = custom_llm_provider
        self.request_data: Dict[str, Any] = request_data or {}
        self.call_type: Optional[str] = call_type

        # set hidden params for response headers (e.g., x-litellm-model-id)
        # This matches the stream wrapper in litellm/litellm_core_utils/streaming_handler.py
        _api_base = get_api_base(
            model=model or "",
            optional_params=self.logging_obj.model_call_details.get(
                "litellm_params", {}
            ),
        )
        _model_info: Dict = (
            litellm_metadata.get("model_info", {}) if litellm_metadata else {}
        )
        self._hidden_params = {
            "model_id": _model_info.get("id", None),
            "api_base": _api_base,
            "custom_llm_provider": custom_llm_provider,
        }
        self._hidden_params["additional_headers"] = process_response_headers(
            self.response.headers or {}
        )  # GUARANTEE OPENAI HEADERS IN RESPONSE

    def _check_max_streaming_duration(self) -> None:
        """Raise litellm.Timeout if the stream has exceeded LITELLM_MAX_STREAMING_DURATION_SECONDS."""
        # A None constant disables the guard entirely.
        if LITELLM_MAX_STREAMING_DURATION_SECONDS is None:
            return
        elapsed = time.time() - self._stream_created_time
        if elapsed > LITELLM_MAX_STREAMING_DURATION_SECONDS:
            raise litellm.Timeout(
                message=f"Stream exceeded max streaming duration of {LITELLM_MAX_STREAMING_DURATION_SECONDS}s (elapsed {elapsed:.1f}s)",
                model=self.model or "",
                llm_provider=self.custom_llm_provider or "",
            )

    def _process_chunk(self, chunk) -> Optional[ResponsesAPIStreamingResponse]:
        """Process a single chunk of data from the stream.

        Returns the transformed streaming event, or None when the chunk is
        empty, un-parseable, the SSE "[DONE]" marker, or not a JSON object.
        Sets ``self.finished`` on the "[DONE]" marker and captures the
        completed response when a RESPONSE_COMPLETED event arrives.
        """
        if not chunk:
            return None

        # Handle SSE format (data: {...})
        chunk = CustomStreamWrapper._strip_sse_data_from_chunk(chunk)
        if chunk is None:
            return None

        # Handle "[DONE]" marker
        if chunk == STREAM_SSE_DONE_STRING:
            self.finished = True
            return None

        try:
            # Parse the JSON chunk
            parsed_chunk = json.loads(chunk)

            # Format as ResponsesAPIStreamingResponse
            if isinstance(parsed_chunk, dict):
                openai_responses_api_chunk = (
                    self.responses_api_provider_config.transform_streaming_response(
                        model=self.model,
                        parsed_chunk=parsed_chunk,
                        logging_obj=self.logging_obj,
                    )
                )

                # if "response" in parsed_chunk, then encode litellm specific information like custom_llm_provider
                response_object = getattr(openai_responses_api_chunk, "response", None)
                if response_object:
                    response = ResponsesAPIRequestUtils._update_responses_api_response_id_with_model_id(
                        responses_api_response=response_object,
                        litellm_metadata=self.litellm_metadata,
                        custom_llm_provider=self.custom_llm_provider,
                    )
                    setattr(openai_responses_api_chunk, "response", response)

                # Wrap encrypted_content in streaming events (output_item.added, output_item.done)
                # so the model id rides along with the encrypted payload for affinity routing.
                if self.litellm_metadata and self.litellm_metadata.get(
                    "encrypted_content_affinity_enabled"
                ):
                    event_type = getattr(openai_responses_api_chunk, "type", None)
                    if event_type in (
                        ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED,
                        ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
                    ):
                        item = getattr(openai_responses_api_chunk, "item", None)
                        if item:
                            encrypted_content = getattr(item, "encrypted_content", None)
                            if encrypted_content and isinstance(encrypted_content, str):
                                model_id = (
                                    self.litellm_metadata.get("model_info", {}).get(
                                        "id"
                                    )
                                    if self.litellm_metadata
                                    else None
                                )
                                if model_id:
                                    wrapped_content = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id(
                                        encrypted_content, model_id
                                    )
                                    setattr(item, "encrypted_content", wrapped_content)

                # Store the completed response
                if (
                    openai_responses_api_chunk
                    and getattr(openai_responses_api_chunk, "type", None)
                    == ResponsesAPIStreamEvents.RESPONSE_COMPLETED
                ):
                    self.completed_response = openai_responses_api_chunk
                    # Add cost to usage object if include_cost_in_streaming_usage is True
                    if (
                        litellm.include_cost_in_streaming_usage
                        and self.logging_obj is not None
                    ):
                        response_obj: Optional[ResponsesAPIResponse] = getattr(
                            openai_responses_api_chunk, "response", None
                        )
                        if response_obj:
                            usage_obj: Optional[ResponseAPIUsage] = getattr(
                                response_obj, "usage", None
                            )
                            if usage_obj is not None:
                                try:
                                    cost: Optional[
                                        float
                                    ] = self.logging_obj._response_cost_calculator(
                                        result=response_obj
                                    )
                                    if cost is not None:
                                        setattr(usage_obj, "cost", cost)
                                except Exception:
                                    # If cost calculation fails, continue without cost
                                    pass

                # NOTE: invoked on every dict chunk; the base implementation is a
                # no-op and subclasses only act once completed_response is set.
                self._handle_logging_completed_response()

                return openai_responses_api_chunk

            return None
        except json.JSONDecodeError:
            # If we can't parse the chunk, continue
            return None
        except Exception as e:
            # Trigger failure hooks before re-raising
            # This ensures failures are logged even when _process_chunk is called directly
            self._handle_failure(e)
            raise

    def _handle_logging_completed_response(self):
        """Base implementation - should be overridden by subclasses"""
        pass

    async def _call_post_streaming_deployment_hook(self, chunk):
        """
        Allow callbacks to modify streaming chunks before returning (parity with chat).

        Any exception anywhere in hook dispatch is swallowed and the (possibly
        partially mutated) chunk is returned unchanged — hooks are best-effort.
        """
        try:
            # Align with chat pipeline: use logging_obj model_call_details + call_type
            typed_call_type: Optional[CallTypes] = None
            if self.call_type is not None:
                try:
                    typed_call_type = CallTypes(self.call_type)
                except ValueError:
                    typed_call_type = None
            if typed_call_type is None:
                # Fall back to the call_type recorded on the logging object.
                try:
                    typed_call_type = CallTypes(
                        getattr(self.logging_obj, "call_type", None)
                    )
                except Exception:
                    typed_call_type = None

            request_data = self.request_data or getattr(
                self.logging_obj, "model_call_details", {}
            )
            callbacks = getattr(litellm, "callbacks", None) or []
            hooks_ran = False
            for callback in callbacks:
                if hasattr(callback, "async_post_call_streaming_deployment_hook"):
                    hooks_ran = True
                    result = await callback.async_post_call_streaming_deployment_hook(
                        request_data=request_data,
                        response_chunk=chunk,
                        call_type=typed_call_type,
                    )
                    # Hooks may return None to mean "no change".
                    if result is not None:
                        chunk = result
            if hooks_ran:
                # Marker attribute used to detect that hooks actually executed.
                setattr(chunk, "_post_streaming_hooks_ran", True)
            return chunk
        except Exception:
            return chunk

    async def call_post_streaming_hooks_for_testing(self, chunk):
        """
        Helper to invoke streaming deployment hooks explicitly (used in tests).
        """
        return await self._call_post_streaming_deployment_hook(chunk)

    def _run_post_success_hooks(self, end_time: datetime):
        """
        Run post-call deployment hooks and update metadata similar to chat pipeline.

        No-op until a completed response has been captured. Every stage is
        wrapped in try/except so logging problems never break the stream.
        """
        if self.completed_response is None:
            return

        # Build the request payload hooks receive: request_data first, then
        # model_call_details layered on top (details win on key collisions).
        request_payload: Dict[str, Any] = {}
        if isinstance(self.request_data, dict):
            request_payload.update(self.request_data)
        try:
            if hasattr(self.logging_obj, "model_call_details"):
                request_payload.update(self.logging_obj.model_call_details)
        except Exception:
            pass
        if "litellm_params" not in request_payload:
            try:
                request_payload["litellm_params"] = getattr(
                    self.logging_obj, "model_call_details", {}
                ).get("litellm_params", {})
            except Exception:
                request_payload["litellm_params"] = {}

        try:
            update_response_metadata(
                result=self.completed_response,
                logging_obj=self.logging_obj,
                model=self.model,
                kwargs=request_payload,
                start_time=self.start_time,
                end_time=end_time,
            )
        except Exception:
            # Non-blocking
            pass

        try:
            typed_call_type: Optional[CallTypes] = None
            if self.call_type is not None:
                try:
                    typed_call_type = CallTypes(self.call_type)
                except ValueError:
                    typed_call_type = None
        except Exception:
            typed_call_type = None
        if typed_call_type is None:
            # Default to the responses call type when none was provided/valid.
            try:
                typed_call_type = CallTypes.responses
            except Exception:
                typed_call_type = None

        try:
            # Call synchronously; async hook will be executed via asyncio.run in a new loop
            run_async_function(
                async_function=async_post_call_success_deployment_hook,
                request_data=request_payload,
                response=self.completed_response,
                call_type=typed_call_type,
            )
        except Exception:
            pass

    def _handle_failure(self, exception: Exception):
        """
        Trigger failure handlers before bubbling the exception.
        Only calls handlers once even if called multiple times.
        """
        # Prevent double-calling failure handlers
        if self._failure_handled:
            return
        self._failure_handled = True

        traceback_exception = traceback.format_exc()
        # Async failure handler runs synchronously in a fresh event loop.
        try:
            run_async_function(
                async_function=self.logging_obj.async_failure_handler,
                exception=exception,
                traceback_exception=traceback_exception,
                start_time=self.start_time,
                end_time=datetime.now(),
            )
        except Exception:
            pass

        # Sync failure handler is offloaded to the shared thread pool.
        try:
            executor.submit(
                self.logging_obj.failure_handler,
                exception,
                traceback_exception,
                self.start_time,
                datetime.now(),
            )
        except Exception:
            pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def call_post_streaming_hooks_for_testing(iterator, chunk):
    """
    Module-level helper for tests to ensure hooks can be invoked even if the iterator is wrapped.

    Looks up ``_call_post_streaming_deployment_hook`` on *iterator*; when the
    attribute is absent the chunk is returned untouched.
    """
    hook = getattr(iterator, "_call_post_streaming_deployment_hook", None)
    return chunk if hook is None else await hook(chunk)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
    """
    Async iterator for processing streaming responses from the Responses API.

    Consumes the httpx response body line by line (SSE), transforms each line
    via the base class, and applies post-streaming deployment hooks inline on
    the running event loop.
    """

    def __init__(
        self,
        response: httpx.Response,
        model: str,
        responses_api_provider_config: BaseResponsesAPIConfig,
        logging_obj: LiteLLMLoggingObj,
        litellm_metadata: Optional[Dict[str, Any]] = None,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict[str, Any]] = None,
        call_type: Optional[str] = None,
    ):
        super().__init__(
            response,
            model,
            responses_api_provider_config,
            logging_obj,
            litellm_metadata,
            custom_llm_provider,
            request_data,
            call_type,
        )
        # Async line iterator over the raw SSE body.
        self.stream_iterator = response.aiter_lines()

    def __aiter__(self):
        return self

    async def __anext__(self) -> ResponsesAPIStreamingResponse:
        """Return the next transformed streaming event, skipping empty/ignored chunks.

        Raises StopAsyncIteration at end of stream; any other error triggers the
        failure handlers before being re-raised.
        """
        try:
            # Guard before reading — raises litellm.Timeout when exceeded.
            self._check_max_streaming_duration()
            while True:
                # Get the next chunk from the stream
                try:
                    chunk = await self.stream_iterator.__anext__()
                except StopAsyncIteration:
                    self.finished = True
                    raise StopAsyncIteration

                # Re-check after a potentially long await on the network.
                self._check_max_streaming_duration()
                result = self._process_chunk(chunk)

                if self.finished:
                    raise StopAsyncIteration
                elif result is not None:
                    # Await hook directly instead of run_async_function
                    # (which spawns a thread + event loop per call)
                    result = await self._call_post_streaming_deployment_hook(
                        chunk=result,
                    )
                    return result
                # If result is None, continue the loop to get the next chunk

        except StopAsyncIteration:
            # Normal end of stream - don't log as failure
            raise
        except httpx.HTTPError as e:
            # Handle HTTP errors
            self.finished = True
            self._handle_failure(e)
            raise e
        except Exception as e:
            self.finished = True
            self._handle_failure(e)
            raise e

    def _handle_logging_completed_response(self):
        """Handle logging for completed responses in async context"""
        # Create a copy for logging to avoid modifying the response object that will be returned to the user
        # The logging handlers may transform usage from Responses API format (input_tokens/output_tokens)
        # to chat completion format (prompt_tokens/completion_tokens) for internal logging
        # Use model_dump + model_validate instead of deepcopy to avoid pickle errors with
        # Pydantic ValidatorIterator when response contains tool_choice with allowed_tools (fixes #17192)
        logging_response = self.completed_response
        if self.completed_response is not None and hasattr(
            self.completed_response, "model_dump"
        ):
            try:
                logging_response = type(self.completed_response).model_validate(
                    self.completed_response.model_dump()
                )
            except Exception:
                # Fallback to original if serialization fails
                pass

        # Async success handler is scheduled on the current loop (fire-and-forget).
        asyncio.create_task(
            self.logging_obj.async_success_handler(
                result=logging_response,
                start_time=self.start_time,
                end_time=datetime.now(),
                cache_hit=None,
            )
        )

        # Sync success handler runs on the shared thread pool.
        executor.submit(
            self.logging_obj.success_handler,
            result=logging_response,
            cache_hit=None,
            start_time=self.start_time,
            end_time=datetime.now(),
        )
        self._run_post_success_hooks(end_time=datetime.now())
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
    """
    Synchronous iterator for processing streaming responses from the Responses API.

    Mirrors ResponsesAPIStreamingIterator but iterates the httpx body with
    ``iter_lines()`` and bridges async hooks/handlers via run_async_function.
    """

    def __init__(
        self,
        response: httpx.Response,
        model: str,
        responses_api_provider_config: BaseResponsesAPIConfig,
        logging_obj: LiteLLMLoggingObj,
        litellm_metadata: Optional[Dict[str, Any]] = None,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict[str, Any]] = None,
        call_type: Optional[str] = None,
    ):
        super().__init__(
            response,
            model,
            responses_api_provider_config,
            logging_obj,
            litellm_metadata,
            custom_llm_provider,
            request_data,
            call_type,
        )
        # Blocking line iterator over the raw SSE body.
        self.stream_iterator = response.iter_lines()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next transformed streaming event, skipping empty/ignored chunks.

        Raises StopIteration at end of stream; any other error triggers the
        failure handlers before being re-raised.
        """
        try:
            # Guard before reading — raises litellm.Timeout when exceeded.
            self._check_max_streaming_duration()
            while True:
                # Get the next chunk from the stream
                try:
                    chunk = next(self.stream_iterator)
                except StopIteration:
                    self.finished = True
                    raise StopIteration

                # Re-check after a potentially slow blocking read.
                self._check_max_streaming_duration()
                result = self._process_chunk(chunk)

                if self.finished:
                    raise StopIteration
                elif result is not None:
                    # Sync path: use run_async_function for the hook
                    result = run_async_function(
                        async_function=self._call_post_streaming_deployment_hook,
                        chunk=result,
                    )
                    return result
                # If result is None, continue the loop to get the next chunk

        except StopIteration:
            # Normal end of stream - don't log as failure
            raise
        except httpx.HTTPError as e:
            # Handle HTTP errors
            self.finished = True
            self._handle_failure(e)
            raise e
        except Exception as e:
            self.finished = True
            self._handle_failure(e)
            raise e

    def _handle_logging_completed_response(self):
        """Handle logging for completed responses in sync context"""
        # Create a copy for logging to avoid modifying the response object that will be returned to the user
        # The logging handlers may transform usage from Responses API format (input_tokens/output_tokens)
        # to chat completion format (prompt_tokens/completion_tokens) for internal logging
        # Use model_dump + model_validate instead of deepcopy to avoid pickle errors with
        # Pydantic ValidatorIterator when response contains tool_choice with allowed_tools (fixes #17192)
        logging_response = self.completed_response
        if self.completed_response is not None and hasattr(
            self.completed_response, "model_dump"
        ):
            try:
                logging_response = type(self.completed_response).model_validate(
                    self.completed_response.model_dump()
                )
            except Exception:
                # Fallback to original if serialization fails
                pass

        # No running loop here — execute the async success handler to completion.
        run_async_function(
            async_function=self.logging_obj.async_success_handler,
            result=logging_response,
            start_time=self.start_time,
            end_time=datetime.now(),
            cache_hit=None,
        )

        # Sync success handler runs on the shared thread pool.
        executor.submit(
            self.logging_obj.success_handler,
            result=logging_response,
            cache_hit=None,
            start_time=self.start_time,
            end_time=datetime.now(),
        )
        self._run_post_success_hooks(end_time=datetime.now())
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
    """
    Mock iterator—fake a stream by slicing the full response text into
    5 char deltas, then emit a completed event.

    Models like o1-pro don't support streaming, so we fake it.
    """

    # Number of characters emitted per synthetic delta event.
    CHUNK_SIZE = 5

    def __init__(
        self,
        response: httpx.Response,
        model: str,
        responses_api_provider_config: BaseResponsesAPIConfig,
        logging_obj: LiteLLMLoggingObj,
        litellm_metadata: Optional[Dict[str, Any]] = None,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict[str, Any]] = None,
        call_type: Optional[str] = None,
    ):
        super().__init__(
            response=response,
            model=model,
            responses_api_provider_config=responses_api_provider_config,
            logging_obj=logging_obj,
            litellm_metadata=litellm_metadata,
            custom_llm_provider=custom_llm_provider,
            request_data=request_data,
            call_type=call_type,
        )

        # one-time transform
        transformed = (
            self.responses_api_provider_config.transform_response_api_response(
                model=self.model,
                raw_response=response,
                logging_obj=logging_obj,
            )
        )
        full_text = self._collect_text(transformed)

        # build a list of 5‑char delta events
        deltas = [
            OutputTextDeltaEvent(
                type=ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA,
                delta=full_text[i : i + self.CHUNK_SIZE],
                item_id=transformed.id,
                output_index=0,
                content_index=0,
            )
            for i in range(0, len(full_text), self.CHUNK_SIZE)
        ]

        # Add cost to usage object if include_cost_in_streaming_usage is True
        if litellm.include_cost_in_streaming_usage and logging_obj is not None:
            usage_obj: Optional[ResponseAPIUsage] = getattr(transformed, "usage", None)
            if usage_obj is not None:
                try:
                    cost: Optional[float] = logging_obj._response_cost_calculator(
                        result=transformed
                    )
                    if cost is not None:
                        setattr(usage_obj, "cost", cost)
                except Exception:
                    # If cost calculation fails, continue without cost
                    pass

        # append the completed event
        self._events = deltas + [
            ResponseCompletedEvent(
                type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
                response=transformed,
            )
        ]
        # Cursor into the precomputed event list, shared by sync/async iteration.
        self._idx = 0

    def __aiter__(self):
        return self

    async def __anext__(self) -> ResponsesAPIStreamingResponse:
        """Yield the next precomputed event; StopAsyncIteration when exhausted."""
        if self._idx >= len(self._events):
            raise StopAsyncIteration
        evt = self._events[self._idx]
        self._idx += 1
        return evt

    def __iter__(self):
        return self

    def __next__(self) -> ResponsesAPIStreamingResponse:
        """Yield the next precomputed event; StopIteration when exhausted."""
        if self._idx >= len(self._events):
            raise StopIteration
        evt = self._events[self._idx]
        self._idx += 1
        return evt

    def _collect_text(self, resp: ResponsesAPIResponse) -> str:
        """Concatenate the text of all "message" output items in *resp*."""
        out = ""
        for out_item in resp.output:
            item_type = getattr(out_item, "type", None)
            if item_type == "message":
                for c in getattr(out_item, "content", []):
                    out += c.text
        return out
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# WebSocket mode streaming (bidirectional forwarding)
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
from litellm._logging import verbose_logger
|
|||
|
|
from litellm.litellm_core_utils.thread_pool_executor import executor as _ws_executor
|
|||
|
|
|
|||
|
|
# Event types retained by ResponsesWebSocketStreaming._should_store_event for
# logging: lifecycle (created / terminal) and error events only — per-token
# delta events are not stored.
RESPONSES_WS_LOGGED_EVENT_TYPES = [
    "response.created",
    "response.completed",
    "response.failed",
    "response.incomplete",
    "error",
]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ResponsesWebSocketStreaming:
    """
    Manages bidirectional WebSocket forwarding for the Responses API
    WebSocket mode (wss://.../v1/responses).

    Unlike the Realtime API, the Responses API WebSocket mode:
    - Uses response.create as the client-to-server event
    - Streams back the same events as the HTTP streaming Responses API
    - Supports previous_response_id for incremental continuation
    - Supports generate: false for warmup
    - One response at a time per connection (sequential, no multiplexing)
    """

    def __init__(
        self,
        websocket: Any,
        backend_ws: Any,
        logging_obj: LiteLLMLoggingObj,
        user_api_key_dict: Optional[Any] = None,
        request_data: Optional[Dict] = None,
    ):
        """
        Args:
            websocket: client-facing connection; must expose ``send_text`` /
                ``receive_text`` (presumably a Starlette/FastAPI WebSocket —
                TODO confirm against caller).
            backend_ws: provider-facing connection; must expose ``recv``,
                ``send`` and ``close`` (websockets-library style).
            logging_obj: LiteLLM logging object used for pre/success logging.
            user_api_key_dict: optional auth context, stored but not read here.
            request_data: original request payload, stored for reference.
        """
        self.websocket = websocket
        self.backend_ws = backend_ws
        self.logging_obj = logging_obj
        self.user_api_key_dict = user_api_key_dict
        self.request_data: Dict = request_data or {}
        # Backend/client events retained for logging (filtered by _should_store_event).
        self.messages: list[Dict] = []
        # User text inputs extracted from client response.create events.
        self.input_messages: list[Dict[str, str]] = []

    def _should_store_event(self, event_obj: dict) -> bool:
        """Return True for lifecycle/error events worth keeping for logging."""
        return event_obj.get("type") in RESPONSES_WS_LOGGED_EVENT_TYPES

    def _store_event(self, event: Any) -> None:
        """Record *event* (bytes/str JSON or dict) in self.messages if loggable."""
        if isinstance(event, bytes):
            event = event.decode("utf-8")
        if isinstance(event, str):
            try:
                event_obj = json.loads(event)
            except (json.JSONDecodeError, TypeError):
                # Non-JSON frames are silently ignored.
                return
        else:
            event_obj = event

        if self._should_store_event(event_obj):
            self.messages.append(event_obj)

    def _collect_input_from_client_event(self, message: Any) -> None:
        """Extract user input content from response.create for logging."""
        try:
            if isinstance(message, str):
                msg_obj = json.loads(message)
            elif isinstance(message, dict):
                msg_obj = message
            else:
                return

            # Only response.create carries user input.
            if msg_obj.get("type") != "response.create":
                return

            input_items = msg_obj.get("input", [])
            # "input" may be a bare string ...
            if isinstance(input_items, str):
                self.input_messages.append({"role": "user", "content": input_items})
                return

            # ... or a list of message items whose content is a string or a
            # list of typed parts; only input_text parts are collected.
            if isinstance(input_items, list):
                for item in input_items:
                    if not isinstance(item, dict):
                        continue
                    if item.get("type") == "message" and item.get("role") == "user":
                        content = item.get("content", [])
                        if isinstance(content, str):
                            self.input_messages.append(
                                {"role": "user", "content": content}
                            )
                        elif isinstance(content, list):
                            for c in content:
                                if (
                                    isinstance(c, dict)
                                    and c.get("type") == "input_text"
                                ):
                                    text = c.get("text", "")
                                    if text:
                                        self.input_messages.append(
                                            {"role": "user", "content": text}
                                        )
        except (json.JSONDecodeError, AttributeError, TypeError):
            pass

    def _store_input(self, message: Any) -> None:
        """Collect input for logging and run the pre-call logging hook."""
        self._collect_input_from_client_event(message)
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    async def _log_messages(self) -> None:
        """Dispatch accumulated events to async + sync success handlers."""
        if not self.logging_obj:
            return
        if self.input_messages:
            self.logging_obj.model_call_details["messages"] = self.input_messages
        if self.messages:
            # Fire-and-forget async handler; sync handler on the thread pool.
            asyncio.create_task(self.logging_obj.async_success_handler(self.messages))
            _ws_executor.submit(self.logging_obj.success_handler, self.messages)

    async def backend_to_client(self) -> None:
        """Forward events from backend WebSocket to the client."""
        import websockets

        try:
            while True:
                try:
                    # decode=False returns raw bytes where supported;
                    # TypeError fallback presumably covers websockets versions
                    # whose recv() takes no such kwarg — TODO confirm.
                    raw_response = await self.backend_ws.recv(decode=False)  # type: ignore[union-attr]
                except TypeError:
                    raw_response = await self.backend_ws.recv()  # type: ignore[union-attr, assignment]

                if isinstance(raw_response, bytes):
                    response_str = raw_response.decode("utf-8")
                else:
                    response_str = raw_response

                self._store_event(response_str)
                await self.websocket.send_text(response_str)

        except websockets.exceptions.ConnectionClosed as e:  # type: ignore
            verbose_logger.debug("Responses WS backend connection closed: %s", e)
        except Exception as e:
            verbose_logger.exception("Error in responses WS backend_to_client: %s", e)
        finally:
            # Always flush logging, however the forwarding loop ended.
            await self._log_messages()

    async def client_to_backend(self) -> None:
        """Forward response.create events from client to backend."""
        try:
            while True:
                message = await self.websocket.receive_text()

                self._store_input(message)
                self._store_event(message)
                await self.backend_ws.send(message)  # type: ignore[union-attr]

        except Exception as e:
            # Client disconnects surface here; logged at debug, not re-raised.
            verbose_logger.debug("Responses WS client_to_backend ended: %s", e)

    async def bidirectional_forward(self) -> None:
        """Run both forwarding directions concurrently."""
        forward_task = asyncio.create_task(self.backend_to_client())
        try:
            await self.client_to_backend()
        except Exception:
            pass
        finally:
            # Tear down the backend->client task, then close the backend socket.
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass
            try:
                await self.backend_ws.close()
            except Exception:
                pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
# Managed WebSocket mode (HTTP-backed, provider-agnostic)
# ---------------------------------------------------------------------------

# All keys accepted by a Responses API request (required + optional), used to
# whitelist fields from arbitrary `response.create` payloads.
_RESPONSE_CREATE_PARAMS: frozenset = (
    ResponsesAPIRequestParams.__required_keys__
    | ResponsesAPIRequestParams.__optional_keys__
)

# Connection-level kwargs that must NOT be forwarded into per-request
# `litellm.aresponses` calls (logging/bookkeeping objects, proxy auth, mode flags).
_MANAGED_WS_SKIP_KWARGS: frozenset = frozenset(
    {
        "litellm_logging_obj",
        "litellm_call_id",
        "aresponses",
        "_aresponses_websocket",
        "user_api_key_dict",
    }
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ManagedResponsesWebSocketHandler:
|
|||
|
|
"""
|
|||
|
|
Handles Responses API WebSocket mode for providers that do not expose a
|
|||
|
|
native ``wss://`` responses endpoint.
|
|||
|
|
|
|||
|
|
Instead of proxying to a provider WebSocket, this handler:
|
|||
|
|
- Listens for ``response.create`` events from the client
|
|||
|
|
- Makes HTTP streaming calls via ``litellm.aresponses(stream=True)``
|
|||
|
|
- Serialises and forwards every streaming event back over the WebSocket
|
|||
|
|
- Supports ``previous_response_id`` for multi-turn conversations via
|
|||
|
|
in-memory session tracking (avoids async DB-write timing issues)
|
|||
|
|
- Supports sequential requests over a single persistent connection
|
|||
|
|
|
|||
|
|
This makes every provider that LiteLLM can reach over HTTP available on
|
|||
|
|
the WebSocket transport without any provider-specific changes.
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
websocket: Any,
|
|||
|
|
model: str,
|
|||
|
|
logging_obj: "LiteLLMLoggingObj",
|
|||
|
|
user_api_key_dict: Optional[Any] = None,
|
|||
|
|
litellm_metadata: Optional[Dict[str, Any]] = None,
|
|||
|
|
api_key: Optional[str] = None,
|
|||
|
|
api_base: Optional[str] = None,
|
|||
|
|
timeout: Optional[float] = None,
|
|||
|
|
custom_llm_provider: Optional[str] = None,
|
|||
|
|
**kwargs: Any,
|
|||
|
|
) -> None:
|
|||
|
|
self.websocket = websocket
|
|||
|
|
self.model = model
|
|||
|
|
self.logging_obj = logging_obj
|
|||
|
|
self.user_api_key_dict = user_api_key_dict
|
|||
|
|
self.litellm_metadata: Dict[str, Any] = litellm_metadata or {}
|
|||
|
|
self.api_key = api_key
|
|||
|
|
self.api_base = api_base
|
|||
|
|
self.timeout = timeout
|
|||
|
|
self.custom_llm_provider = custom_llm_provider
|
|||
|
|
# Carry through safe pass-through kwargs (e.g. extra_headers)
|
|||
|
|
self.extra_kwargs: Dict[str, Any] = {
|
|||
|
|
k: v for k, v in kwargs.items() if k not in _MANAGED_WS_SKIP_KWARGS
|
|||
|
|
}
|
|||
|
|
# In-memory session history: response_id → full accumulated message list.
|
|||
|
|
# Keyed by the DECODED (pre-encoding) response ID from response.completed.
|
|||
|
|
# This avoids the async DB-write race condition where spend logs haven't
|
|||
|
|
# been committed yet when the next response.create arrives.
|
|||
|
|
self._session_history: Dict[str, List[Dict[str, Any]]] = {}
|
|||
|
|
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
# Internal helpers
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _serialize_chunk(chunk: Any) -> Optional[str]:
|
|||
|
|
"""Serialize a streaming chunk to a JSON string for WebSocket transmission."""
|
|||
|
|
try:
|
|||
|
|
if hasattr(chunk, "model_dump_json"):
|
|||
|
|
return chunk.model_dump_json(exclude_none=True)
|
|||
|
|
if hasattr(chunk, "model_dump"):
|
|||
|
|
return json.dumps(chunk.model_dump(exclude_none=True), default=str)
|
|||
|
|
if isinstance(chunk, dict):
|
|||
|
|
return json.dumps(chunk, default=str)
|
|||
|
|
return json.dumps(str(chunk))
|
|||
|
|
except Exception as exc:
|
|||
|
|
verbose_logger.debug(
|
|||
|
|
"ManagedResponsesWS: failed to serialize chunk: %s", exc
|
|||
|
|
)
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
async def _send_error(self, message: str, error_type: str = "server_error") -> None:
|
|||
|
|
try:
|
|||
|
|
await self.websocket.send_text(
|
|||
|
|
json.dumps(
|
|||
|
|
{"type": "error", "error": {"type": error_type, "message": message}}
|
|||
|
|
)
|
|||
|
|
)
|
|||
|
|
except Exception:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
def _get_history_messages(self, previous_response_id: str) -> List[Dict[str, Any]]:
|
|||
|
|
"""
|
|||
|
|
Return accumulated message history for *previous_response_id*.
|
|||
|
|
|
|||
|
|
The key is the *decoded* response ID (the raw provider response ID before
|
|||
|
|
LiteLLM base64-encodes it into the ``resp_...`` format).
|
|||
|
|
"""
|
|||
|
|
decoded = ResponsesAPIRequestUtils._decode_responses_api_response_id(
|
|||
|
|
previous_response_id
|
|||
|
|
)
|
|||
|
|
raw_id = decoded.get("response_id", previous_response_id)
|
|||
|
|
return list(self._session_history.get(raw_id, []))
|
|||
|
|
|
|||
|
|
    def _store_history(self, response_id: str, messages: List[Dict[str, Any]]) -> None:
        """
        Store the complete accumulated message history for *response_id*.

        Replaces any prior value — callers are responsible for passing the full
        history (prior turns + current input + new output).
        """
        # NOTE(review): entries are never evicted, so a long-lived connection
        # with many turns grows memory unboundedly — confirm this is acceptable
        # for expected WebSocket session lifetimes.
        self._session_history[response_id] = messages
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _extract_response_id(completed_event: Dict[str, Any]) -> Optional[str]:
|
|||
|
|
"""
|
|||
|
|
Pull the raw (decoded) response ID out of a ``response.completed`` event.
|
|||
|
|
Returns *None* if the event doesn't contain a usable ID.
|
|||
|
|
"""
|
|||
|
|
resp_obj = completed_event.get("response", {})
|
|||
|
|
encoded_id: Optional[str] = (
|
|||
|
|
resp_obj.get("id") if isinstance(resp_obj, dict) else None
|
|||
|
|
)
|
|||
|
|
if not encoded_id:
|
|||
|
|
return None
|
|||
|
|
decoded = ResponsesAPIRequestUtils._decode_responses_api_response_id(encoded_id)
|
|||
|
|
return decoded.get("response_id", encoded_id)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _extract_output_messages(
|
|||
|
|
completed_event: Dict[str, Any]
|
|||
|
|
) -> List[Dict[str, Any]]:
|
|||
|
|
"""
|
|||
|
|
Convert the output items in a ``response.completed`` event into
|
|||
|
|
Responses API message dicts suitable for the next turn's ``input``.
|
|||
|
|
"""
|
|||
|
|
resp_obj = completed_event.get("response", {})
|
|||
|
|
if not isinstance(resp_obj, dict):
|
|||
|
|
return []
|
|||
|
|
messages: List[Dict[str, Any]] = []
|
|||
|
|
for item in resp_obj.get("output", []) or []:
|
|||
|
|
if not isinstance(item, dict):
|
|||
|
|
continue
|
|||
|
|
item_type = item.get("type")
|
|||
|
|
role = item.get("role", "assistant")
|
|||
|
|
if item_type == "message":
|
|||
|
|
content_parts = item.get("content") or []
|
|||
|
|
text_parts = [
|
|||
|
|
p.get("text", "")
|
|||
|
|
for p in content_parts
|
|||
|
|
if isinstance(p, dict) and p.get("type") in ("output_text", "text")
|
|||
|
|
]
|
|||
|
|
text = "".join(text_parts)
|
|||
|
|
if text:
|
|||
|
|
messages.append(
|
|||
|
|
{
|
|||
|
|
"type": "message",
|
|||
|
|
"role": role,
|
|||
|
|
"content": [{"type": "output_text", "text": text}],
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
elif item_type == "function_call":
|
|||
|
|
messages.append(item)
|
|||
|
|
return messages
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _input_to_messages(input_val: Any) -> List[Dict[str, Any]]:
|
|||
|
|
"""
|
|||
|
|
Normalise the ``input`` field of a ``response.create`` event to a list
|
|||
|
|
of Responses API message dicts.
|
|||
|
|
"""
|
|||
|
|
if isinstance(input_val, str):
|
|||
|
|
return [
|
|||
|
|
{
|
|||
|
|
"type": "message",
|
|||
|
|
"role": "user",
|
|||
|
|
"content": [{"type": "input_text", "text": input_val}],
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
if isinstance(input_val, list):
|
|||
|
|
return [item for item in input_val if isinstance(item, dict)]
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
# _process_response_create sub-methods
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
async def _parse_message(self, raw_message: str) -> Optional[Dict[str, Any]]:
|
|||
|
|
"""Parse raw WS text; return the message dict or None (JSON error / ignored type)."""
|
|||
|
|
try:
|
|||
|
|
msg_obj = json.loads(raw_message)
|
|||
|
|
except json.JSONDecodeError:
|
|||
|
|
await self._send_error(
|
|||
|
|
"Invalid JSON in response.create event", "invalid_request_error"
|
|||
|
|
)
|
|||
|
|
return None
|
|||
|
|
if msg_obj.get("type") != "response.create":
|
|||
|
|
# Silently ignore non-response.create messages (e.g. warmup pings)
|
|||
|
|
return None
|
|||
|
|
return msg_obj
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _build_base_call_kwargs(msg_obj: Dict[str, Any]) -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
Extract Responses API params from the event, handling both wire formats:
|
|||
|
|
Nested: {"type": "response.create", "response": {"input": [...], ...}}
|
|||
|
|
Flat: {"type": "response.create", "input": [...], "model": "...", ...}
|
|||
|
|
"""
|
|||
|
|
nested = msg_obj.get("response")
|
|||
|
|
response_params: Dict[str, Any] = (
|
|||
|
|
nested
|
|||
|
|
if isinstance(nested, dict) and nested
|
|||
|
|
else {k: v for k, v in msg_obj.items() if k != "type"}
|
|||
|
|
)
|
|||
|
|
return {
|
|||
|
|
param: response_params[param]
|
|||
|
|
for param in _RESPONSE_CREATE_PARAMS
|
|||
|
|
if param in response_params and response_params[param] is not None
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
def _apply_history(
|
|||
|
|
self,
|
|||
|
|
call_kwargs: Dict[str, Any],
|
|||
|
|
previous_response_id: Optional[str],
|
|||
|
|
current_messages: List[Dict[str, Any]],
|
|||
|
|
prior_history: List[Dict[str, Any]],
|
|||
|
|
) -> None:
|
|||
|
|
"""Prepend in-memory turn history, or fall back to DB-based reconstruction."""
|
|||
|
|
if not previous_response_id:
|
|||
|
|
return
|
|||
|
|
if prior_history:
|
|||
|
|
call_kwargs["input"] = prior_history + current_messages
|
|||
|
|
verbose_logger.debug(
|
|||
|
|
"ManagedResponsesWS: prepended %d history messages for previous_response_id=%s",
|
|||
|
|
len(prior_history),
|
|||
|
|
previous_response_id,
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
verbose_logger.debug(
|
|||
|
|
"ManagedResponsesWS: no in-memory history for previous_response_id=%s; "
|
|||
|
|
"falling back to DB-based session reconstruction",
|
|||
|
|
previous_response_id,
|
|||
|
|
)
|
|||
|
|
# Fall back to DB-based session reconstruction (may work for
|
|||
|
|
# cross-connection multi-turn when spend logs are committed)
|
|||
|
|
call_kwargs["previous_response_id"] = previous_response_id
|
|||
|
|
|
|||
|
|
def _inject_credentials(
|
|||
|
|
self, call_kwargs: Dict[str, Any], event_model: Optional[str]
|
|||
|
|
) -> None:
|
|||
|
|
"""Inject connection-level credentials and metadata into call_kwargs."""
|
|||
|
|
if self.api_key is not None:
|
|||
|
|
call_kwargs["api_key"] = self.api_key
|
|||
|
|
if self.api_base is not None:
|
|||
|
|
call_kwargs["api_base"] = self.api_base
|
|||
|
|
if self.timeout is not None:
|
|||
|
|
call_kwargs["timeout"] = self.timeout
|
|||
|
|
# Only propagate custom_llm_provider when no per-request model override exists.
|
|||
|
|
# If the payload specifies a different model, let litellm re-resolve the
|
|||
|
|
# provider so we don't accidentally force the wrong backend.
|
|||
|
|
if self.custom_llm_provider is not None and not event_model:
|
|||
|
|
call_kwargs["custom_llm_provider"] = self.custom_llm_provider
|
|||
|
|
if self.litellm_metadata:
|
|||
|
|
call_kwargs["litellm_metadata"] = dict(self.litellm_metadata)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _update_proxy_request(call_kwargs: Dict[str, Any], model: str) -> None:
|
|||
|
|
"""Update proxy_server_request body so spend logs record the full request."""
|
|||
|
|
proxy_server_request = (call_kwargs.get("litellm_metadata") or {}).get(
|
|||
|
|
"proxy_server_request"
|
|||
|
|
) or {}
|
|||
|
|
if not isinstance(proxy_server_request, dict):
|
|||
|
|
return
|
|||
|
|
body = dict(proxy_server_request.get("body") or {})
|
|||
|
|
body["input"] = call_kwargs.get("input")
|
|||
|
|
body["store"] = call_kwargs.get("store")
|
|||
|
|
body["model"] = model
|
|||
|
|
for k in ("tools", "tool_choice", "instructions", "metadata"):
|
|||
|
|
if k in call_kwargs and call_kwargs[k] is not None:
|
|||
|
|
body[k] = call_kwargs[k]
|
|||
|
|
proxy_server_request = {**proxy_server_request, "body": body}
|
|||
|
|
if "litellm_metadata" not in call_kwargs:
|
|||
|
|
call_kwargs["litellm_metadata"] = {}
|
|||
|
|
call_kwargs["litellm_metadata"]["proxy_server_request"] = proxy_server_request
|
|||
|
|
call_kwargs.setdefault("litellm_params", {})
|
|||
|
|
call_kwargs["litellm_params"]["proxy_server_request"] = proxy_server_request
|
|||
|
|
|
|||
|
|
async def _stream_and_forward(
|
|||
|
|
self, model: str, call_kwargs: Dict[str, Any]
|
|||
|
|
) -> Optional[Dict[str, Any]]:
|
|||
|
|
"""
|
|||
|
|
Stream ``litellm.aresponses`` and forward every chunk over the WebSocket.
|
|||
|
|
|
|||
|
|
Captures the ``response.completed`` event type from the chunk object
|
|||
|
|
directly (before serialization) to avoid a redundant JSON round-trip on
|
|||
|
|
every chunk. Returns the completed event dict, or ``None``.
|
|||
|
|
"""
|
|||
|
|
completed_event: Optional[Dict[str, Any]] = None
|
|||
|
|
stream_response = await litellm.aresponses(model=model, **call_kwargs)
|
|||
|
|
async for chunk in stream_response: # type: ignore[union-attr]
|
|||
|
|
if chunk is None:
|
|||
|
|
continue
|
|||
|
|
# Read type from the object before serializing to avoid double JSON parse
|
|||
|
|
chunk_type = getattr(chunk, "type", None) or (
|
|||
|
|
chunk.get("type") if isinstance(chunk, dict) else None
|
|||
|
|
)
|
|||
|
|
serialized = self._serialize_chunk(chunk)
|
|||
|
|
if serialized is None:
|
|||
|
|
continue
|
|||
|
|
if chunk_type == "response.completed" and completed_event is None:
|
|||
|
|
try:
|
|||
|
|
completed_event = json.loads(serialized)
|
|||
|
|
except Exception:
|
|||
|
|
pass
|
|||
|
|
try:
|
|||
|
|
await self.websocket.send_text(serialized)
|
|||
|
|
except Exception as send_exc:
|
|||
|
|
verbose_logger.debug(
|
|||
|
|
"ManagedResponsesWS: error sending chunk to client: %s", send_exc
|
|||
|
|
)
|
|||
|
|
return completed_event # Client disconnected
|
|||
|
|
return completed_event
|
|||
|
|
|
|||
|
|
def _save_turn_history(
|
|||
|
|
self,
|
|||
|
|
completed_event: Optional[Dict[str, Any]],
|
|||
|
|
prior_history: List[Dict[str, Any]],
|
|||
|
|
current_messages: List[Dict[str, Any]],
|
|||
|
|
) -> None:
|
|||
|
|
"""Store this turn in in-memory history for future previous_response_id lookups."""
|
|||
|
|
if completed_event is None:
|
|||
|
|
return
|
|||
|
|
new_response_id = self._extract_response_id(completed_event)
|
|||
|
|
if not new_response_id:
|
|||
|
|
return
|
|||
|
|
output_msgs = self._extract_output_messages(completed_event)
|
|||
|
|
all_messages = prior_history + current_messages + output_msgs
|
|||
|
|
self._store_history(new_response_id, all_messages)
|
|||
|
|
verbose_logger.debug(
|
|||
|
|
"ManagedResponsesWS: stored %d messages for response_id=%s",
|
|||
|
|
len(all_messages),
|
|||
|
|
new_response_id,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
# Core request handler
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
    async def _process_response_create(self, raw_message: str) -> None:
        """
        Parse one ``response.create`` event, call ``litellm.aresponses(stream=True)``,
        and forward every streaming event to the client.

        Multi-turn support via in-memory session history
        ------------------------------------------------
        When ``previous_response_id`` is present in the event:
        1. Look up the accumulated message history in ``self._session_history``
           (keyed by the decoded provider response ID).
        2. Prepend those messages to the current ``input`` so the model has full
           conversation context.
        3. After the stream completes, extract the new response ID and output
           messages from ``response.completed`` and store them in
           ``self._session_history`` for the next turn.

        This in-memory approach avoids the async DB-write race condition that
        occurs when spend logs haven't been committed by the time the second
        ``response.create`` arrives over the same WebSocket connection.
        """
        msg_obj = await self._parse_message(raw_message)
        if msg_obj is None:
            # Malformed JSON (error already sent to client) or an ignorable event.
            return

        call_kwargs = self._build_base_call_kwargs(msg_obj)
        call_kwargs["stream"] = True

        # Per-event model override wins over the connection-level model.
        event_model: Optional[str] = call_kwargs.pop("model", None)
        model = event_model or self.model

        previous_response_id: Optional[str] = call_kwargs.pop(
            "previous_response_id", None
        )
        current_messages = self._input_to_messages(call_kwargs.get("input"))

        # Fetch history once; reused in both _apply_history and _save_turn_history
        prior_history = (
            self._get_history_messages(previous_response_id)
            if previous_response_id
            else []
        )

        self._apply_history(
            call_kwargs, previous_response_id, current_messages, prior_history
        )
        self._inject_credentials(call_kwargs, event_model)
        self._update_proxy_request(call_kwargs, model)
        # Applied last so connection-level pass-through kwargs (e.g.
        # extra_headers) take precedence over per-event values.
        call_kwargs.update(self.extra_kwargs)

        try:
            completed_event = await self._stream_and_forward(model, call_kwargs)
        except Exception as exc:
            verbose_logger.exception(
                "ManagedResponsesWS: error processing response.create: %s", exc
            )
            # Surface the failure to the client instead of failing silently.
            await self._send_error(str(exc))
            return

        self._save_turn_history(completed_event, prior_history, current_messages)
|
|||
|
|
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
# Main entry point
|
|||
|
|
# ------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
async def run(self) -> None:
|
|||
|
|
"""
|
|||
|
|
Main loop: accept ``response.create`` events sequentially and handle
|
|||
|
|
each one before waiting for the next message.
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
while True:
|
|||
|
|
try:
|
|||
|
|
message = await self.websocket.receive_text()
|
|||
|
|
except Exception as exc:
|
|||
|
|
verbose_logger.debug(
|
|||
|
|
"ManagedResponsesWS: client disconnected: %s", exc
|
|||
|
|
)
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
await self._process_response_create(message)
|
|||
|
|
|
|||
|
|
except Exception as exc:
|
|||
|
|
verbose_logger.exception("ManagedResponsesWS: unexpected error: %s", exc)
|
|||
|
|
await self._send_error(f"Internal server error: {exc}")
|