chore: initial snapshot for gitea/github upload
@@ -0,0 +1,4 @@
Logic specific to `litellm.completion`.

Includes:
- Bridge for transforming completion requests into Responses API requests
@@ -0,0 +1,3 @@
from .litellm_responses_transformation import responses_api_bridge

__all__ = ["responses_api_bridge"]
@@ -0,0 +1,3 @@
from .handler import responses_api_bridge

__all__ = ["responses_api_bridge"]
@@ -0,0 +1,331 @@
"""
Handler for transforming /chat/completions API requests into litellm.responses requests.
"""

from typing import TYPE_CHECKING, Any, Coroutine, Optional, Union

from typing_extensions import TypedDict

from litellm.types.llms.openai import ResponsesAPIResponse

if TYPE_CHECKING:
    from litellm import CustomStreamWrapper, LiteLLMLoggingObj, ModelResponse


class ResponsesToCompletionBridgeHandlerInputKwargs(TypedDict):
    model: str
    messages: list
    optional_params: dict
    litellm_params: dict
    headers: dict
    model_response: "ModelResponse"
    logging_obj: "LiteLLMLoggingObj"
    custom_llm_provider: str


class ResponsesToCompletionBridgeHandler:
    def __init__(self):
        from .transformation import LiteLLMResponsesTransformationHandler

        super().__init__()
        self.transformation_handler = LiteLLMResponsesTransformationHandler()

    @staticmethod
    def _resolve_stream_flag(optional_params: dict, litellm_params: dict) -> bool:
        stream = optional_params.get("stream")
        if stream is None:
            stream = litellm_params.get("stream", False)
        return bool(stream)

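    # Illustrative behavior of the precedence above (a sketch, not part of the
    # original file): an explicit "stream" in optional_params wins; otherwise
    # litellm_params is consulted, defaulting to False.
    #
    #   _resolve_stream_flag({"stream": True}, {})                  # -> True
    #   _resolve_stream_flag({}, {"stream": True})                  # -> True
    #   _resolve_stream_flag({"stream": False}, {"stream": True})   # -> False
    #   _resolve_stream_flag({}, {})                                # -> False
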
    @staticmethod
    def _coerce_response_object(
        response_obj: Any,
        hidden_params: Optional[dict],
    ) -> "ResponsesAPIResponse":
        if isinstance(response_obj, ResponsesAPIResponse):
            response = response_obj
        elif isinstance(response_obj, dict):
            try:
                response = ResponsesAPIResponse(**response_obj)
            except Exception:
                # Payloads that fail pydantic validation still yield a usable
                # object via model_construct.
                response = ResponsesAPIResponse.model_construct(**response_obj)
        else:
            raise ValueError("Unexpected responses stream payload")

        if hidden_params:
            existing = getattr(response, "_hidden_params", None)
            if not isinstance(existing, dict) or not existing:
                setattr(response, "_hidden_params", dict(hidden_params))
            else:
                # Merge without overwriting keys the response already carries.
                for key, value in hidden_params.items():
                    existing.setdefault(key, value)
        return response

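    # A sketch of how the coercion above behaves (field values here are
    # hypothetical placeholders, not a complete Responses API payload):
    #
    #   resp = _coerce_response_object(
    #       {"id": "resp_123", "output": []},   # raw dict from a stream event
    #       hidden_params={"api_base": "..."},  # attached as _hidden_params
    #   )
    #   isinstance(resp, ResponsesAPIResponse)  # -> True
    #
    # Existing _hidden_params keys win; stream-level metadata is only filled
    # in via setdefault.
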
    def _collect_response_from_stream(self, stream_iter: Any) -> "ResponsesAPIResponse":
        # Drain the stream; the iterator accumulates the completed response
        # on itself as it consumes events.
        for _ in stream_iter:
            pass

        completed = getattr(stream_iter, "completed_response", None)
        response_obj = getattr(completed, "response", None) if completed else None
        if response_obj is None:
            raise ValueError("Stream ended without a completed response")

        hidden_params = getattr(stream_iter, "_hidden_params", None)
        response = self._coerce_response_object(response_obj, hidden_params)
        if not isinstance(response, ResponsesAPIResponse):
            raise ValueError("Stream completed response is invalid")
        return response

    async def _collect_response_from_stream_async(
        self, stream_iter: Any
    ) -> "ResponsesAPIResponse":
        # Async twin of _collect_response_from_stream.
        async for _ in stream_iter:
            pass

        completed = getattr(stream_iter, "completed_response", None)
        response_obj = getattr(completed, "response", None) if completed else None
        if response_obj is None:
            raise ValueError("Stream ended without a completed response")

        hidden_params = getattr(stream_iter, "_hidden_params", None)
        response = self._coerce_response_object(response_obj, hidden_params)
        if not isinstance(response, ResponsesAPIResponse):
            raise ValueError("Stream completed response is invalid")
        return response

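    # The drain-and-collect helpers above rely on a duck-typed contract: once
    # iteration finishes, the stream object exposes completed_response with a
    # .response attribute. A minimal stand-in (hypothetical, for illustration
    # only; not a real LiteLLM stream class):
    #
    #   class _FakeCompleted:
    #       def __init__(self, response):
    #           self.response = response
    #
    #   class _FakeStream:
    #       def __init__(self, final_response):
    #           self.completed_response = _FakeCompleted(final_response)
    #           self._hidden_params = {"litellm_call_id": "..."}
    #       def __iter__(self):
    #           return iter(())  # no intermediate events
    #
    #   # _collect_response_from_stream(_FakeStream(payload)) then returns the
    #   # coerced ResponsesAPIResponse for `payload`.
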
    def validate_input_kwargs(
        self, kwargs: dict
    ) -> ResponsesToCompletionBridgeHandlerInputKwargs:
        from litellm import LiteLLMLoggingObj
        from litellm.types.utils import ModelResponse

        model = kwargs.get("model")
        if not isinstance(model, str):
            raise ValueError("model is required and must be a string")

        custom_llm_provider = kwargs.get("custom_llm_provider")
        if not isinstance(custom_llm_provider, str):
            raise ValueError("custom_llm_provider is required and must be a string")

        messages = kwargs.get("messages")
        if not isinstance(messages, list):
            raise ValueError("messages is required and must be a list")

        optional_params = kwargs.get("optional_params")
        if not isinstance(optional_params, dict):
            raise ValueError("optional_params is required and must be a dict")

        litellm_params = kwargs.get("litellm_params")
        if not isinstance(litellm_params, dict):
            raise ValueError("litellm_params is required and must be a dict")

        headers = kwargs.get("headers")
        if not isinstance(headers, dict):
            raise ValueError("headers is required and must be a dict")

        model_response = kwargs.get("model_response")
        if not isinstance(model_response, ModelResponse):
            raise ValueError("model_response is required and must be a ModelResponse")

        logging_obj = kwargs.get("logging_obj")
        if not isinstance(logging_obj, LiteLLMLoggingObj):
            raise ValueError("logging_obj is required and must be a LiteLLMLoggingObj")

        return ResponsesToCompletionBridgeHandlerInputKwargs(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            model_response=model_response,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
        )

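    # A sketch of the minimal kwargs this validator accepts; every value here
    # is a hypothetical placeholder (logging_obj in particular is normally
    # constructed by litellm, not by callers):
    #
    #   from litellm.types.utils import ModelResponse
    #
    #   validated = responses_api_bridge.validate_input_kwargs({
    #       "model": "gpt-4o",
    #       "messages": [{"role": "user", "content": "hi"}],
    #       "optional_params": {},
    #       "litellm_params": {},
    #       "headers": {},
    #       "model_response": ModelResponse(),
    #       "logging_obj": logging_obj,
    #       "custom_llm_provider": "openai",
    #   })
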
    def completion(
        self, *args, **kwargs
    ) -> Union[
        Coroutine[Any, Any, Union["ModelResponse", "CustomStreamWrapper"]],
        "ModelResponse",
        "CustomStreamWrapper",
    ]:
        if kwargs.get("acompletion") is True:
            return self.acompletion(**kwargs)

        from litellm import responses
        from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper

        validated_kwargs = self.validate_input_kwargs(kwargs)
        model = validated_kwargs["model"]
        messages = validated_kwargs["messages"]
        optional_params = validated_kwargs["optional_params"]
        litellm_params = validated_kwargs["litellm_params"]
        headers = validated_kwargs["headers"]
        model_response = validated_kwargs["model_response"]
        logging_obj = validated_kwargs["logging_obj"]
        custom_llm_provider = validated_kwargs["custom_llm_provider"]

        request_data = self.transformation_handler.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            litellm_logging_obj=logging_obj,
            client=kwargs.get("client"),
        )

        result = responses(
            **request_data,
        )

        stream = self._resolve_stream_flag(optional_params, litellm_params)
        if isinstance(result, ResponsesAPIResponse):
            # Non-streaming response: transform it directly.
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=result,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        elif not stream:
            # The provider returned a stream but the caller did not ask for
            # one: drain it into a single response, then transform.
            responses_api_response = self._collect_response_from_stream(result)
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=responses_api_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        else:
            # Streaming requested: wrap the responses stream in a
            # chat-completions-shaped stream.
            completion_stream = self.transformation_handler.get_model_response_iterator(
                streaming_response=result,  # type: ignore
                sync_stream=True,
                json_mode=kwargs.get("json_mode"),
            )
            streamwrapper = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )
            return self._apply_post_stream_processing(
                streamwrapper, model, custom_llm_provider
            )

    async def acompletion(
        self, *args, **kwargs
    ) -> Union["ModelResponse", "CustomStreamWrapper"]:
        from litellm import aresponses
        from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper

        validated_kwargs = self.validate_input_kwargs(kwargs)
        model = validated_kwargs["model"]
        messages = validated_kwargs["messages"]
        optional_params = validated_kwargs["optional_params"]
        litellm_params = validated_kwargs["litellm_params"]
        headers = validated_kwargs["headers"]
        model_response = validated_kwargs["model_response"]
        logging_obj = validated_kwargs["logging_obj"]
        custom_llm_provider = validated_kwargs["custom_llm_provider"]

        request_data = self.transformation_handler.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            litellm_logging_obj=logging_obj,
        )

        result = await aresponses(
            **request_data,
            aresponses=True,
        )

        stream = self._resolve_stream_flag(optional_params, litellm_params)
        if isinstance(result, ResponsesAPIResponse):
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=result,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        elif not stream:
            responses_api_response = await self._collect_response_from_stream_async(
                result
            )
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=responses_api_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        else:
            completion_stream = self.transformation_handler.get_model_response_iterator(
                streaming_response=result,  # type: ignore
                sync_stream=False,
                json_mode=kwargs.get("json_mode"),
            )
            streamwrapper = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )
            return self._apply_post_stream_processing(
                streamwrapper, model, custom_llm_provider
            )

    @staticmethod
    def _apply_post_stream_processing(
        stream: "CustomStreamWrapper",
        model: str,
        custom_llm_provider: str,
    ) -> Any:
        """Apply provider-specific post-stream processing if available."""
        from litellm.types.utils import LlmProviders
        from litellm.utils import ProviderConfigManager

        try:
            provider_config = ProviderConfigManager.get_provider_chat_config(
                model=model, provider=LlmProviders(custom_llm_provider)
            )
        except (ValueError, KeyError):
            # Unknown provider string: hand the stream back unchanged.
            return stream

        if provider_config is not None:
            return provider_config.post_stream_processing(stream)
        return stream


responses_api_bridge = ResponsesToCompletionBridgeHandler()
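
# A hedged usage sketch (not part of the original file): litellm.completion is
# expected to route to this bridge internally, so calling it directly as below
# is illustrative only. model_response and logging_obj are placeholders for
# objects litellm normally constructs itself.
#
#   response = responses_api_bridge.completion(
#       model="gpt-4o",
#       messages=[{"role": "user", "content": "hello"}],
#       optional_params={"stream": False},
#       litellm_params={},
#       headers={},
#       model_response=model_response,
#       logging_obj=logging_obj,
#       custom_llm_provider="openai",
#   )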