chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,270 @@
"""Support for OpenAI gpt-5 model family."""
from typing import Optional, Union
import litellm
from litellm.utils import _supports_factory
from .gpt_transformation import OpenAIGPTConfig
def _normalize_reasoning_effort_for_chat_completion(
value: Union[str, dict, None],
) -> Optional[str]:
"""Convert reasoning_effort to the string format expected by OpenAI chat completion API.
The chat completion API expects a simple string: 'none', 'low', 'medium', 'high', or 'xhigh'.
Config/deployments may pass the Responses API format: {'effort': 'high', 'summary': 'detailed'}.
"""
if value is None:
return None
if isinstance(value, str):
return value
if isinstance(value, dict) and "effort" in value:
return value["effort"]
return None
def _get_effort_level(value: Union[str, dict, None]) -> Optional[str]:
"""Extract the effective effort level from reasoning_effort (string or dict).
Use this for guards that compare effort level (e.g. xhigh validation, "none" checks).
Ensures dict inputs like {"effort": "none", "summary": "detailed"} are correctly
treated as effort="none" for validation purposes.
"""
if value is None:
return None
if isinstance(value, str):
return value
if isinstance(value, dict) and "effort" in value:
return value["effort"]
return None
class OpenAIGPT5Config(OpenAIGPTConfig):
    """Configuration for gpt-5 models including GPT-5-Codex variants.
    Handles OpenAI API quirks for the gpt-5 series like:
    - Mapping ``max_tokens`` -> ``max_completion_tokens``.
    - Dropping unsupported ``temperature`` values when requested.
    - Support for GPT-5-Codex models optimized for code generation.
    """
    @classmethod
    def is_model_gpt_5_model(cls, model: str) -> bool:
        """Return True for gpt-5 reasoning models (excluding the gpt-5-chat family)."""
        # gpt-5-chat* behaves like a regular chat model (supports temperature, etc.)
        # Don't route it through GPT-5 reasoning-specific parameter restrictions.
        return "gpt-5" in model and "gpt-5-chat" not in model
    @classmethod
    def is_model_gpt_5_search_model(cls, model: str) -> bool:
        """Check if the model is a GPT-5 search variant (e.g. gpt-5-search-api).
        Search-only models have a severely restricted parameter set compared to
        regular GPT-5 models. They are identified by name convention (contain
        both ``gpt-5`` and ``search``). Note: ``supports_web_search`` in model
        info is a *different* concept — it indicates a model can *use* web
        search as a tool, which many non-search-only models also support.
        """
        return "gpt-5" in model and "search" in model
    @classmethod
    def is_model_gpt_5_codex_model(cls, model: str) -> bool:
        """Check if the model is specifically a GPT-5 Codex variant."""
        return "gpt-5-codex" in model
    @classmethod
    def is_model_gpt_5_2_model(cls, model: str) -> bool:
        """Check if the model is a gpt-5.2 OR gpt-5.4 variant (including pro).
        NOTE(review): despite the ``gpt_5_2`` name, the return expression also
        matches ``gpt-5.4`` prefixes — presumably because the 5.4 family shares
        the 5.2-era parameter behavior. Confirm against callers before renaming.
        """
        # Strip any provider prefix such as "openai/" before matching.
        model_name = model.split("/")[-1]
        return model_name.startswith("gpt-5.2") or model_name.startswith("gpt-5.4")
    @classmethod
    def is_model_gpt_5_4_model(cls, model: str) -> bool:
        """Check if the model is a gpt-5.4 variant (including pro)."""
        model_name = model.split("/")[-1]
        return model_name.startswith("gpt-5.4")
    @classmethod
    def is_model_gpt_5_4_plus_model(cls, model: str) -> bool:
        """Check if the model is gpt-5.4 or newer (5.4, 5.5, 5.6, etc., including pro)."""
        model_name = model.split("/")[-1]
        if not model_name.startswith("gpt-5."):
            return False
        try:
            # e.g. "gpt-5.4-pro" -> "4-pro" -> "4". Non-numeric majors raise
            # ValueError below and are treated as "not 5.4+".
            version_str = model_name.replace("gpt-5.", "").split("-")[0]
            major = version_str.split(".")[0]
            return int(major) >= 4
        except (ValueError, IndexError):
            return False
    @classmethod
    def _supports_reasoning_effort_level(cls, model: str, level: str) -> bool:
        """Check if the model supports a specific reasoning_effort level.
        Looks up ``supports_{level}_reasoning_effort`` in the model map via
        the shared ``_supports_factory`` helper.
        Returns False for unknown models (safe fallback).
        """
        return _supports_factory(
            model=model,
            custom_llm_provider=None,
            key=f"supports_{level}_reasoning_effort",
        )
    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI params supported by this gpt-5 model variant.
        Search-only models get a small fixed whitelist; all other gpt-5 models
        start from the base GPT param list, add gpt-5-only params, and subtract
        params the gpt-5 reasoning API rejects.
        """
        if self.is_model_gpt_5_search_model(model):
            return [
                "max_tokens",
                "max_completion_tokens",
                "stream",
                "stream_options",
                "web_search_options",
                "service_tier",
                "safety_identifier",
                "response_format",
                "user",
                "store",
                "verbosity",
                "max_retries",
                "extra_headers",
            ]
        from litellm.utils import supports_tool_choice
        base_gpt_series_params = super().get_supported_openai_params(model=model)
        gpt_5_only_params = ["reasoning_effort", "verbosity"]
        base_gpt_series_params.extend(gpt_5_only_params)
        # NOTE(review): .remove raises ValueError if "tool_choice" is absent;
        # this relies on the base class always including it in its param list.
        if not supports_tool_choice(model=model):
            base_gpt_series_params.remove("tool_choice")
        non_supported_params = [
            "presence_penalty",
            "frequency_penalty",
            "stop",
            "logit_bias",
            "modalities",
            "prediction",
            "audio",
            "web_search_options",
        ]
        # gpt-5.1/5.2 support logprobs, top_p, top_logprobs when reasoning_effort="none"
        if not self._supports_reasoning_effort_level(model, "none"):
            non_supported_params.extend(["logprobs", "top_p", "top_logprobs"])
        return [
            param
            for param in base_gpt_series_params
            if param not in non_supported_params
        ]
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Validate/translate params for gpt-5 models, then delegate to the base mapper.
        Mutates both ``non_default_params`` and ``optional_params`` in place
        (popping dropped params, renaming max_tokens) before handing off to
        ``super()._map_openai_params``.
        """
        # Search-only models: the only translation needed is max_tokens renaming.
        if self.is_model_gpt_5_search_model(model):
            if "max_tokens" in non_default_params:
                optional_params["max_completion_tokens"] = non_default_params.pop(
                    "max_tokens"
                )
            return super()._map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                drop_params=drop_params,
            )
        # Get raw reasoning_effort and effective effort level for all guards.
        # Use effective_effort (extracted string) for xhigh validation, "none" checks, and
        # tool/sampling guards — dict inputs like {"effort": "none", "summary": "detailed"}
        # must be treated as effort="none" to avoid incorrect tool-drop or sampling errors.
        raw_reasoning_effort = non_default_params.get(
            "reasoning_effort"
        ) or optional_params.get("reasoning_effort")
        effective_effort = _get_effort_level(raw_reasoning_effort)
        # Normalize dict reasoning_effort to string for Chat Completions API.
        # Example: {"effort": "high", "summary": "detailed"} -> "high"
        if isinstance(raw_reasoning_effort, dict) and "effort" in raw_reasoning_effort:
            normalized = _normalize_reasoning_effort_for_chat_completion(
                raw_reasoning_effort
            )
            if normalized is not None:
                if "reasoning_effort" in non_default_params:
                    non_default_params["reasoning_effort"] = normalized
                if "reasoning_effort" in optional_params:
                    optional_params["reasoning_effort"] = normalized
        # Guard: 'xhigh' is only valid on models whose model-map entry says so.
        if effective_effort is not None and effective_effort == "xhigh":
            if not self._supports_reasoning_effort_level(model, "xhigh"):
                if litellm.drop_params or drop_params:
                    non_default_params.pop("reasoning_effort", None)
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message=(
                            "reasoning_effort='xhigh' is only supported for gpt-5.1-codex-max, gpt-5.2, and gpt-5.4+ models."
                        ),
                        status_code=400,
                    )
        ################################################################
        # max_tokens is not supported for gpt-5 models on OpenAI API
        # Relevant issue: https://github.com/BerriAI/litellm/issues/13381
        ################################################################
        if "max_tokens" in non_default_params:
            optional_params["max_completion_tokens"] = non_default_params.pop(
                "max_tokens"
            )
        # gpt-5.1/5.2 support logprobs, top_p, top_logprobs only when reasoning_effort="none"
        supports_none = self._supports_reasoning_effort_level(model, "none")
        if supports_none:
            sampling_params = ["logprobs", "top_logprobs", "top_p"]
            has_sampling = any(p in non_default_params for p in sampling_params)
            # effective_effort of None is treated like "none" (the default).
            if has_sampling and effective_effort not in (None, "none"):
                if litellm.drop_params or drop_params:
                    for p in sampling_params:
                        non_default_params.pop(p, None)
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message=(
                            "gpt-5.1/5.2/5.4 only support logprobs, top_p, top_logprobs when "
                            "reasoning_effort='none'. Current reasoning_effort='{}'. "
                            "To drop unsupported params set `litellm.drop_params = True`"
                        ).format(effective_effort),
                        status_code=400,
                    )
        if "temperature" in non_default_params:
            temperature_value: Optional[float] = non_default_params.pop("temperature")
            if temperature_value is not None:
                # models supporting reasoning_effort="none" also support flexible temperature
                if supports_none and (
                    effective_effort == "none" or effective_effort is None
                ):
                    optional_params["temperature"] = temperature_value
                elif temperature_value == 1:
                    # temperature=1 is the API default and always accepted.
                    optional_params["temperature"] = temperature_value
                elif litellm.drop_params or drop_params:
                    pass
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message=(
                            "gpt-5 models (including gpt-5-codex) don't support temperature={}. "
                            "Only temperature=1 is supported. "
                            "For gpt-5.1, temperature is supported when reasoning_effort='none' (or not specified, as it defaults to 'none'). "
                            "To drop unsupported params set `litellm.drop_params = True`"
                        ).format(temperature_value),
                        status_code=400,
                    )
        return super()._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

View File

@@ -0,0 +1,48 @@
"""
Support for GPT-4o audio Family
OpenAI Doc: https://platform.openai.com/docs/guides/audio/quickstart?audio-generation-quickstart-example=audio-in&lang=python
"""
import litellm
from .gpt_transformation import OpenAIGPTConfig
class OpenAIGPTAudioConfig(OpenAIGPTConfig):
    """
    Reference: https://platform.openai.com/docs/guides/audio
    """
    @classmethod
    def get_config(cls):
        # No audio-specific config fields; defer entirely to the base class.
        return super().get_config()
    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the `gpt-audio` models
        """
        # Everything the base GPT config supports, plus the `audio` param.
        return [*super().get_supported_openai_params(model=model), "audio"]
    def is_model_gpt_audio_model(self, model: str) -> bool:
        # Must both be a known OpenAI chat model and carry "audio" in its name.
        return "audio" in model and model in litellm.open_ai_chat_completion_models
    def _map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Pure pass-through: audio models need no extra param translation.
        return super()._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

View File

@@ -0,0 +1,819 @@
"""
Support for gpt model family
"""
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Coroutine,
Iterator,
List,
Literal,
Optional,
Tuple,
Union,
cast,
overload,
)
import httpx
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
_extract_reasoning_content,
_handle_invalid_parallel_tool_calls,
_should_convert_tool_call_to_json_mode,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import get_tool_call_names
from litellm.litellm_core_utils.prompt_templates.image_handling import (
async_convert_url_to_base64,
convert_url_to_base64,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionFileObject,
ChatCompletionFileObjectFile,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
OpenAIChatCompletionChoices,
OpenAIMessageContentListBlock,
)
from litellm.types.utils import (
ChatCompletionMessageToolCall,
Choices,
Function,
Message,
ModelResponse,
ModelResponseStream,
)
from litellm.utils import convert_to_model_response_object
from ..common_utils import OpenAIError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.base_llm.base_utils import BaseTokenCounter
from litellm.types.llms.openai import ChatCompletionToolParam
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create
    The class `OpenAIConfig` provides configuration for the OpenAI's Chat API interface. Below are the parameters:
    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
    - `function_call` (string or object): This optional parameter controls how the model calls functions.
    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """
    # Add a class variable to track if this is the base class
    _is_base_class = True
    # These are numeric API params (see class docstring); annotated as float
    # where the API accepts non-integers.
    frequency_penalty: Optional[float] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    response_format: Optional[dict] = None
    def __init__(
        self,
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        locals_ = locals().copy()
        # NOTE(review): this sets attributes on the *class*, not the instance,
        # so non-None constructor args are shared by all instances of the
        # subclass. This appears to be the established config pattern here
        # (get_config reads class state) — confirm before changing.
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        self.__class__._is_base_class = False
    @classmethod
    def get_config(cls):
        return super().get_config()
    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI params supported for ``model``.
        Starts from a list valid across all OpenAI chat models, then adds
        params that only some models accept (response_format, user).
        """
        base_params = [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "top_logprobs",
            "max_tokens",
            "max_completion_tokens",
            "modalities",
            "prediction",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
            "function_call",
            "functions",
            "max_retries",
            "extra_headers",
            "parallel_tool_calls",
            "audio",
            "web_search_options",
            "service_tier",
            "safety_identifier",
            "prompt_cache_key",
            "prompt_cache_retention",
            "store",
        ]  # works across all models
        model_specific_params = []
        if (
            model != "gpt-3.5-turbo-16k" and model != "gpt-4"
        ):  # gpt-4 does not support 'response_format'
            model_specific_params.append("response_format")
        # Normalize model name for responses API (e.g., "responses/gpt-4.1" -> "gpt-4.1")
        model_for_check = (
            model.split("responses/", 1)[1] if "responses/" in model else model
        )
        if (
            model_for_check in litellm.open_ai_chat_completion_models
        ) or model_for_check in litellm.open_ai_text_completion_models:
            model_specific_params.append(
                "user"
            )  # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
        return base_params + model_specific_params
    def _map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        If any supported_openai_params are in non_default_params, add them to optional_params, so they are use in API call
        Args:
            non_default_params (dict): Non-default parameters to filter.
            optional_params (dict): Optional parameters to update.
            model (str): Model name for parameter support check.
        Returns:
            dict: Updated optional_params with supported non-default parameters.
        """
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Public entry point; subclasses override this to add model-specific
        # validation before delegating to _map_openai_params.
        return self._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
    def contains_pdf_url(self, content_item: ChatCompletionFileObjectFile) -> bool:
        """Return True when ``file_id`` holds an http(s)/www URL rather than an uploaded-file id."""
        potential_pdf_url_starts = ["https://", "http://", "www."]
        file_id = content_item.get("file_id")
        if file_id and any(
            file_id.startswith(start) for start in potential_pdf_url_starts
        ):
            return True
        return False
    def _handle_pdf_url(
        self, content_item: ChatCompletionFileObjectFile
    ) -> ChatCompletionFileObjectFile:
        """Fetch the URL in ``file_id`` and inline it as base64 ``file_data``.
        Works on a copy; the original content_item is not mutated.
        """
        content_copy = content_item.copy()
        file_id = content_copy.get("file_id")
        if file_id is not None:
            base64_data = convert_url_to_base64(file_id)
            content_copy["file_data"] = base64_data
            # The API requires a filename alongside inline file_data.
            content_copy["filename"] = "my_file.pdf"
            content_copy.pop("file_id")
        return content_copy
    async def _async_handle_pdf_url(
        self, content_item: ChatCompletionFileObjectFile
    ) -> ChatCompletionFileObjectFile:
        """Async variant of ``_handle_pdf_url``.
        NOTE(review): unlike the sync version, this mutates ``content_item``
        in place rather than copying — confirm callers don't rely on the input
        staying unchanged.
        """
        file_id = content_item.get("file_id")
        if file_id is not None:  # check for file id being url done in _handle_pdf_url
            base64_data = await async_convert_url_to_base64(file_id)
            content_item["file_data"] = base64_data
            content_item["filename"] = "my_file.pdf"
            content_item.pop("file_id")
        return content_item
    def _common_file_data_check(
        self, content_item: ChatCompletionFileObjectFile
    ) -> ChatCompletionFileObjectFile:
        # Supply a default filename whenever file_data is present without one.
        file_data = content_item.get("file_data")
        filename = content_item.get("filename")
        if file_data is not None and filename is None:
            content_item["filename"] = "my_file.pdf"
        return content_item
    def _apply_common_transform_content_item(
        self,
        content_item: OpenAIMessageContentListBlock,
    ) -> OpenAIMessageContentListBlock:
        """Normalize a content block before sending to OpenAI.
        - image_url given as a plain string is wrapped into {"url": ...}.
        - litellm-specific keys (currently just "format") are stripped from
          image_url and file objects, since OpenAI would reject them.
        """
        litellm_specific_params = {"format"}
        if content_item.get("type") == "image_url":
            content_item = cast(ChatCompletionImageObject, content_item)
            if isinstance(content_item["image_url"], str):
                content_item["image_url"] = {
                    "url": content_item["image_url"],
                }
            elif isinstance(content_item["image_url"], dict):
                new_image_url_obj = ChatCompletionImageUrlObject(
                    **{  # type: ignore
                        k: v
                        for k, v in content_item["image_url"].items()
                        if k not in litellm_specific_params
                    }
                )
                content_item["image_url"] = new_image_url_obj
        elif content_item.get("type") == "file":
            content_item = cast(ChatCompletionFileObject, content_item)
            file_obj = content_item["file"]
            new_file_obj = ChatCompletionFileObjectFile(
                **{  # type: ignore
                    k: v
                    for k, v in file_obj.items()
                    if k not in litellm_specific_params
                }
            )
            content_item["file"] = new_file_obj
        return content_item
    def _transform_content_item(
        self,
        content_item: OpenAIMessageContentListBlock,
    ) -> OpenAIMessageContentListBlock:
        """Apply common normalization, then inline PDF URLs for file blocks (sync)."""
        content_item = self._apply_common_transform_content_item(content_item)
        content_item_type = content_item.get("type")
        potential_file_obj = content_item.get("file")
        if content_item_type == "file" and potential_file_obj:
            file_obj = cast(ChatCompletionFileObjectFile, potential_file_obj)
            content_item_typed = cast(ChatCompletionFileObject, content_item)
            if self.contains_pdf_url(file_obj):
                file_obj = self._handle_pdf_url(file_obj)
            file_obj = self._common_file_data_check(file_obj)
            content_item_typed["file"] = file_obj
            content_item = content_item_typed
        return content_item
    async def _async_transform_content_item(
        self, content_item: OpenAIMessageContentListBlock, is_async: bool = False
    ) -> OpenAIMessageContentListBlock:
        """Async twin of ``_transform_content_item``.
        NOTE: the ``is_async`` parameter is unused in this body.
        """
        content_item = self._apply_common_transform_content_item(content_item)
        content_item_type = content_item.get("type")
        potential_file_obj = content_item.get("file")
        if content_item_type == "file" and potential_file_obj:
            file_obj = cast(ChatCompletionFileObjectFile, potential_file_obj)
            content_item_typed = cast(ChatCompletionFileObject, content_item)
            if self.contains_pdf_url(file_obj):
                file_obj = await self._async_handle_pdf_url(file_obj)
            file_obj = self._common_file_data_check(file_obj)
            content_item_typed["file"] = file_obj
            content_item = content_item_typed
        return content_item
    # fmt: off
    @overload
    def _transform_messages(
        self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
    ) -> Coroutine[Any, Any, List[AllMessageValues]]:
        ...
    @overload
    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
        is_async: Literal[False] = False,
    ) -> List[AllMessageValues]:
        ...
    # fmt: on
    def _transform_messages(
        self, messages: List[AllMessageValues], model: str, is_async: bool = False
    ) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
        """OpenAI no longer supports image_url as a string, so we need to convert it to a dict"""
        # Only user messages with list-style content are transformed; the
        # transformation rewrites content items in place and returns the same
        # ``messages`` list. With is_async=True a coroutine is returned instead.
        async def _async_transform():
            for message in messages:
                message_content = message.get("content")
                message_role = message.get("role")
                if (
                    message_role == "user"
                    and message_content
                    and isinstance(message_content, list)
                ):
                    message_content_types = cast(
                        List[OpenAIMessageContentListBlock], message_content
                    )
                    for i, content_item in enumerate(message_content_types):
                        message_content_types[
                            i
                        ] = await self._async_transform_content_item(
                            cast(OpenAIMessageContentListBlock, content_item),
                        )
            return messages
        if is_async:
            return _async_transform()
        else:
            for message in messages:
                message_content = message.get("content")
                message_role = message.get("role")
                if (
                    message_role == "user"
                    and message_content
                    and isinstance(message_content, list)
                ):
                    message_content_types = cast(
                        List[OpenAIMessageContentListBlock], message_content
                    )
                    # NOTE: iterates message_content but writes through
                    # message_content_types — both name the same list object,
                    # so this is equivalent to the async branch above.
                    for i, content_item in enumerate(message_content):
                        message_content_types[i] = self._transform_content_item(
                            cast(OpenAIMessageContentListBlock, content_item)
                        )
            return messages
    def remove_cache_control_flag_from_messages_and_tools(
        self,
        model: str,  # allows overrides to selectively run this
        messages: List[AllMessageValues],
        tools: Optional[List["ChatCompletionToolParam"]] = None,
    ) -> Tuple[List[AllMessageValues], Optional[List["ChatCompletionToolParam"]]]:
        """Strip litellm's ``cache_control`` keys from messages and tools in place."""
        from litellm.litellm_core_utils.prompt_templates.common_utils import (
            filter_value_from_dict,
        )
        from litellm.types.llms.openai import ChatCompletionToolParam
        for i, message in enumerate(messages):
            messages[i] = cast(
                AllMessageValues, filter_value_from_dict(message, "cache_control")  # type: ignore
            )
        if tools is not None:
            for i, tool in enumerate(tools):
                tools[i] = cast(
                    ChatCompletionToolParam,
                    filter_value_from_dict(tool, "cache_control"),  # type: ignore
                )
        return messages, tools
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the overall request to be sent to the API.
        Returns:
            dict: The transformed request. Sent as the body of the API call.
        """
        messages = self._transform_messages(messages=messages, model=model)
        messages, tools = self.remove_cache_control_flag_from_messages_and_tools(
            model=model, messages=messages, tools=optional_params.get("tools", [])
        )
        # Only write tools back when non-empty, so an absent "tools" key stays absent.
        if tools is not None and len(tools) > 0:
            optional_params["tools"] = tools
        # max_retries is a client-side litellm knob, never part of the API body.
        optional_params.pop("max_retries", None)
        return {
            "model": model,
            "messages": messages,
            **optional_params,
        }
    async def async_transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Async request transform: awaits message transformation, then builds the body.
        Subclasses (detected via ``_is_base_class``) are routed back through
        their overridden ``transform_request`` so provider-specific behavior runs.
        """
        transformed_messages = await self._transform_messages(
            messages=messages, model=model, is_async=True
        )
        (
            transformed_messages,
            tools,
        ) = self.remove_cache_control_flag_from_messages_and_tools(
            model=model,
            messages=transformed_messages,
            tools=optional_params.get("tools", []),
        )
        if tools is not None and len(tools) > 0:
            optional_params["tools"] = tools
        if self.__class__._is_base_class:
            return {
                "model": model,
                "messages": transformed_messages,
                **optional_params,
            }
        else:
            ## allow for any object specific behaviour to be handled
            return self.transform_request(
                model, messages, optional_params, litellm_params, headers
            )
    def _passed_in_tools(self, optional_params: dict) -> bool:
        # True when the caller supplied any "tools" value (even an empty list counts as passed-in only if non-None).
        return optional_params.get("tools", None) is not None
    def _check_and_fix_if_content_is_tool_call(
        self, content: str, optional_params: dict
    ) -> Optional[ChatCompletionMessageToolCall]:
        """
        Check if the content is a tool call
        """
        # Some providers emit the tool call as a JSON string in message
        # content; reconstruct a proper tool-call object when the JSON names
        # a tool the caller actually passed in. Any parse failure -> None.
        import json
        if not self._passed_in_tools(optional_params):
            return None
        tool_call_names = get_tool_call_names(optional_params.get("tools", []))
        try:
            json_content = json.loads(content)
            if (
                json_content.get("type") == "function"
                and json_content.get("name") in tool_call_names
            ):
                return ChatCompletionMessageToolCall(
                    function=Function(
                        name=json_content.get("name"),
                        arguments=json_content.get("arguments"),
                    )
                )
        except Exception:
            return None
        return None
    def _get_finish_reason(self, message: Message, received_finish_reason: str) -> str:
        # Presence of tool calls overrides whatever finish_reason the API sent.
        if message.tool_calls is not None:
            return "tool_calls"
        else:
            return received_finish_reason
    def _transform_choices(
        self,
        choices: List[OpenAIChatCompletionChoices],
        json_mode: Optional[bool] = None,
        optional_params: Optional[dict] = None,
    ) -> List[Choices]:
        """Convert raw API choice dicts into litellm ``Choices`` objects.
        Handles: invalid parallel tool calls, tool calls embedded in content
        strings, and json_mode conversion of a single tool call back into
        plain content.
        """
        transformed_choices = []
        for choice in choices:
            ## HANDLE JSON MODE - anthropic returns single function call]
            tool_calls = choice["message"].get("tool_calls", None)
            new_tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
            message_content = choice["message"].get("content", None)
            if tool_calls is not None:
                _openai_tool_calls = []
                for _tc in tool_calls:
                    _openai_tc = ChatCompletionMessageToolCall(**_tc)  # type: ignore
                    _openai_tool_calls.append(_openai_tc)
                fixed_tool_calls = _handle_invalid_parallel_tool_calls(
                    _openai_tool_calls
                )
                if fixed_tool_calls is not None:
                    new_tool_calls = fixed_tool_calls
            elif (
                optional_params is not None
                and message_content
                and isinstance(message_content, str)
            ):
                # No explicit tool_calls: the content string itself may be a
                # serialized tool call (provider quirk) — try to recover it.
                new_tool_call = self._check_and_fix_if_content_is_tool_call(
                    message_content, optional_params
                )
                if new_tool_call is not None:
                    choice["message"]["content"] = None  # remove the content
                    new_tool_calls = [new_tool_call]
            translated_message: Optional[Message] = None
            finish_reason: Optional[str] = None
            if new_tool_calls and _should_convert_tool_call_to_json_mode(
                tool_calls=new_tool_calls,
                convert_tool_call_to_json_mode=json_mode,
            ):
                # to support response_format on claude models
                json_mode_content_str: Optional[str] = (
                    str(new_tool_calls[0]["function"].get("arguments", "")) or None
                )
                if json_mode_content_str is not None:
                    translated_message = Message(content=json_mode_content_str)
                    finish_reason = "stop"
            if translated_message is None:
                ## get the reasoning content
                (
                    reasoning_content,
                    content_str,
                ) = _extract_reasoning_content(cast(dict, choice["message"]))
                translated_message = Message(
                    role="assistant",
                    content=content_str,
                    reasoning_content=reasoning_content,
                    thinking_blocks=None,
                    tool_calls=new_tool_calls,
                )
            if finish_reason is None:
                finish_reason = choice["finish_reason"]
            translated_choice = Choices(
                finish_reason=finish_reason,
                index=choice["index"],
                message=translated_message,
                logprobs=None,
                enhancements=None,
            )
            # Re-derive finish_reason from the translated message (tool calls
            # force "tool_calls"), then map to litellm's canonical values.
            translated_choice.finish_reason = map_finish_reason(
                self._get_finish_reason(translated_message, choice["finish_reason"])
            )
            transformed_choices.append(translated_choice)
        return transformed_choices
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the response from the API.
        Returns:
            dict: The transformed response.
        """
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )
        ## RESPONSE OBJECT
        try:
            completion_response = raw_response.json()
        except Exception as e:
            # Non-JSON body (e.g. HTML error page): surface it as an OpenAIError
            # with the original text so callers can debug.
            response_headers = getattr(raw_response, "headers", None)
            raise OpenAIError(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )
        raw_response_headers = dict(raw_response.headers)
        final_response_obj = convert_to_model_response_object(
            response_object=completion_response,
            model_response_object=model_response,
            hidden_params={"headers": raw_response_headers},
            _response_headers=raw_response_headers,
        )
        return cast(ModelResponse, final_response_obj)
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap an error into the provider-specific exception type."""
        return OpenAIError(
            status_code=status_code,
            message=error_message,
            headers=cast(httpx.Headers, headers),
        )
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the API call.
        Returns:
            str: The complete URL for the API call.
        """
        if api_base is None:
            api_base = "https://api.openai.com"
        endpoint = "chat/completions"
        # Remove trailing slash from api_base if present
        api_base = api_base.rstrip("/")
        # Check if endpoint is already in the api_base
        if endpoint in api_base:
            return api_base
        return f"{api_base}/{endpoint}"
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Attach auth and content-type headers; returns the mutated headers dict."""
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        # Ensure Content-Type is set to application/json
        if "content-type" not in headers and "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return headers
    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Calls OpenAI's `/v1/models` endpoint and returns the list of models.
        """
        if api_base is None:
            api_base = "https://api.openai.com"
        if api_key is None:
            api_key = get_secret_str("OPENAI_API_KEY")
        # Strip api_base to just the base URL (scheme + host + port)
        parsed_url = httpx.URL(api_base)
        base_url = f"{parsed_url.scheme}://{parsed_url.host}"
        if parsed_url.port:
            base_url += f":{parsed_url.port}"
        response = litellm.module_level_client.get(
            url=f"{base_url}/v1/models",
            headers={"Authorization": f"Bearer {api_key}"},
        )
        if response.status_code != 200:
            raise Exception(f"Failed to get models: {response.text}")
        models = response.json()["data"]
        return [model["id"] for model in models]
    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        # Resolution order: explicit arg > litellm globals > env secret.
        return (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
    @staticmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        # Resolution order: explicit arg > litellm global > env vars > default.
        return (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_BASE_URL")
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
    @staticmethod
    def get_base_model(model: Optional[str] = None) -> Optional[str]:
        # Identity for OpenAI; providers with model aliases override this.
        return model
    def get_token_counter(self) -> Optional["BaseTokenCounter"]:
        """Return the OpenAI token counter (imported lazily to avoid cycles)."""
        from litellm.llms.openai.responses.count_tokens.token_counter import (
            OpenAITokenCounter,
        )
        return OpenAITokenCounter()
    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        """Build the streaming-chunk parser for this provider."""
        return OpenAIChatCompletionStreamingHandler(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
    def _map_reasoning_to_reasoning_content(self, choices: list) -> list:
        """
        Rename the ``reasoning`` delta field to ``reasoning_content``.
        Some OpenAI-compatible providers (e.g., GLM-5, hosted_vllm) stream
        delta.reasoning, while LiteLLM expects delta.reasoning_content.
        Mutates the choice deltas in place and returns the same list.
        """
        for entry in choices:
            delta = entry.get("delta", {})
            if "reasoning" in delta:
                delta["reasoning_content"] = delta.pop("reasoning")
        return choices
    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """Parse one raw streaming chunk dict into a ModelResponseStream."""
        try:
            mapped_choices = self._map_reasoning_to_reasoning_content(
                chunk.get("choices", [])
            )
            stream_kwargs = {
                "id": chunk["id"],
                "object": "chat.completion.chunk",
                "created": chunk.get("created"),
                "model": chunk.get("model"),
                "choices": mapped_choices,
            }
            # Only forward usage when the provider actually sent a value.
            usage = chunk.get("usage")
            if usage is not None:
                stream_kwargs["usage"] = usage
            return ModelResponseStream(**stream_kwargs)
        except Exception as e:
            raise e

View File

@@ -0,0 +1,3 @@
Translation of OpenAI `/chat/completions` input and output to a custom guardrail.
This enables guardrails to be applied to OpenAI `/chat/completions` requests and responses.

View File

@@ -0,0 +1,12 @@
"""OpenAI Chat Completions message handler for Unified Guardrails."""
from litellm.llms.openai.chat.guardrail_translation.handler import (
OpenAIChatCompletionsHandler,
)
from litellm.types.utils import CallTypes
# Route both the sync and async OpenAI chat-completions call types to the
# handler that translates requests/responses for unified guardrails.
guardrail_translation_mappings = {
    CallTypes.completion: OpenAIChatCompletionsHandler,
    CallTypes.acompletion: OpenAIChatCompletionsHandler,
}

__all__ = ["guardrail_translation_mappings"]

View File

@@ -0,0 +1,808 @@
"""
OpenAI Chat Completions Message Handler for Unified Guardrails
This module provides a class-based handler for OpenAI-format chat completions.
The class methods can be overridden for custom behavior.
Pattern Overview:
-----------------
1. Extract text content from messages/responses (both string and list formats)
2. Create async tasks to apply guardrails to each text segment
3. Track mappings to know where each response belongs
4. Apply guardrail responses back to the original structure
This pattern can be replicated for other message formats (e.g., Anthropic).
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.main import stream_chunk_builder
from litellm.types.llms.openai import ChatCompletionToolParam
from litellm.types.utils import (
Choices,
GenericGuardrailAPIInputs,
ModelResponse,
ModelResponseStream,
StreamingChoices,
)
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
class OpenAIChatCompletionsHandler(BaseTranslation):
"""
Handler for processing OpenAI chat completions messages with guardrails.
This class provides methods to:
1. Process input messages (pre-call hook)
2. Process output responses (post-call hook)
Methods can be overridden to customize behavior for different message formats.
"""
    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input messages by applying guardrails to text content.

        Extracts all texts, images, and tool calls from ``data["messages"]``,
        sends them to the guardrail in one batched call, then writes the
        (possibly rewritten) values back into ``data`` in place.

        Args:
            data: OpenAI /chat/completions request body; ``messages``,
                ``tools``, and ``model`` are read, ``messages``/``tools``
                may be mutated.
            guardrail_to_apply: Guardrail whose ``apply_guardrail`` is called.
            litellm_logging_obj: Optional logging object, forwarded as-is.

        Returns:
            The same ``data`` dict, with guardrailed content applied.
        """
        messages = data.get("messages")
        if messages is None:
            # Nothing to scan (non-chat payload) -- pass through untouched.
            return data
        texts_to_check: List[str] = []
        images_to_check: List[str] = []
        tool_calls_to_check: List[ChatCompletionToolParam] = []
        text_task_mappings: List[Tuple[int, Optional[int]]] = []
        tool_call_task_mappings: List[Tuple[int, int]] = []
        # text_task_mappings: Track (message_index, content_index) for each text
        # content_index is None for string content, int for list content
        # tool_call_task_mappings: Track (message_index, tool_call_index) for each tool call
        # Step 1: Extract all text content, images, and tool calls
        for msg_idx, message in enumerate(messages):
            self._extract_inputs(
                message=message,
                msg_idx=msg_idx,
                texts_to_check=texts_to_check,
                images_to_check=images_to_check,
                tool_calls_to_check=tool_calls_to_check,
                text_task_mappings=text_task_mappings,
                tool_call_task_mappings=tool_call_task_mappings,
            )
        # Step 2: Apply guardrail to all texts and tool calls in batch
        if texts_to_check or tool_calls_to_check:
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            # Optional fields are only attached when non-empty so the
            # guardrail receives a minimal input dict.
            if images_to_check:
                inputs["images"] = images_to_check
            if tool_calls_to_check:
                inputs["tool_calls"] = tool_calls_to_check  # type: ignore
            if messages:
                inputs[
                    "structured_messages"
                ] = messages  # pass the openai /chat/completions messages to the guardrail, as-is
            # Pass tools (function definitions) to the guardrail
            tools = data.get("tools")
            if tools:
                inputs["tools"] = tools
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            guardrailed_tool_calls = guardrailed_inputs.get("tool_calls", [])
            guardrailed_tools = guardrailed_inputs.get("tools")
            if guardrailed_tools is not None:
                data["tools"] = guardrailed_tools
            # Step 3: Map guardrail responses back to original message structure
            if guardrailed_texts and texts_to_check:
                await self._apply_guardrail_responses_to_input_texts(
                    messages=messages,
                    responses=guardrailed_texts,
                    task_mappings=text_task_mappings,
                )
            # Step 4: Apply guardrailed tool calls back to messages
            if guardrailed_tool_calls:
                # Note: The guardrail may modify tool_calls_to_check in place
                # or we may need to handle returned tool calls differently
                await self._apply_guardrail_responses_to_input_tool_calls(
                    messages=messages,
                    tool_calls=guardrailed_tool_calls,  # type: ignore
                    task_mappings=tool_call_task_mappings,
                )
        verbose_proxy_logger.debug(
            "OpenAI Chat Completions: Processed input messages: %s", messages
        )
        return data
def extract_request_tool_names(self, data: dict) -> List[str]:
"""Extract tool names from OpenAI chat completions request (tools[].function.name, functions[].name)."""
names: List[str] = []
for tool in data.get("tools") or []:
if isinstance(tool, dict) and tool.get("type") == "function":
fn = tool.get("function")
if isinstance(fn, dict) and fn.get("name"):
names.append(str(fn["name"]))
for fn in data.get("functions") or []:
if isinstance(fn, dict) and fn.get("name"):
names.append(str(fn["name"]))
return names
def _extract_inputs(
self,
message: Dict[str, Any],
msg_idx: int,
texts_to_check: List[str],
images_to_check: List[str],
tool_calls_to_check: List[ChatCompletionToolParam],
text_task_mappings: List[Tuple[int, Optional[int]]],
tool_call_task_mappings: List[Tuple[int, int]],
) -> None:
"""
Extract text content, images, and tool calls from a message.
Override this method to customize text/image/tool call extraction logic.
"""
content = message.get("content", None)
if content is not None:
if isinstance(content, str):
# Simple string content
texts_to_check.append(content)
text_task_mappings.append((msg_idx, None))
elif isinstance(content, list):
# List content (e.g., multimodal with text and images)
for content_idx, content_item in enumerate(content):
# Extract text
text_str = content_item.get("text", None)
if text_str is not None:
texts_to_check.append(text_str)
text_task_mappings.append((msg_idx, int(content_idx)))
# Extract images (image_url)
if content_item.get("type") == "image_url":
image_url = content_item.get("image_url", {})
if isinstance(image_url, dict):
url = image_url.get("url")
if url:
images_to_check.append(url)
elif isinstance(image_url, str):
images_to_check.append(image_url)
# Extract tool calls (typically in assistant messages)
tool_calls = message.get("tool_calls", None)
if tool_calls is not None and isinstance(tool_calls, list):
for tool_call_idx, tool_call in enumerate(tool_calls):
if isinstance(tool_call, dict):
# Add the full tool call object to the list
tool_calls_to_check.append(cast(ChatCompletionToolParam, tool_call))
tool_call_task_mappings.append((msg_idx, int(tool_call_idx)))
async def _apply_guardrail_responses_to_input_texts(
self,
messages: List[Dict[str, Any]],
responses: List[str],
task_mappings: List[Tuple[int, Optional[int]]],
) -> None:
"""
Apply guardrail responses back to input message text content.
Override this method to customize how text responses are applied.
"""
for task_idx, guardrail_response in enumerate(responses):
mapping = task_mappings[task_idx]
msg_idx = cast(int, mapping[0])
content_idx_optional = cast(Optional[int], mapping[1])
# Handle content
content = messages[msg_idx].get("content", None)
if content is None:
continue
if isinstance(content, str) and content_idx_optional is None:
# Replace string content with guardrail response
messages[msg_idx]["content"] = guardrail_response
elif isinstance(content, list) and content_idx_optional is not None:
# Replace specific text item in list content
messages[msg_idx]["content"][content_idx_optional][
"text"
] = guardrail_response
async def _apply_guardrail_responses_to_input_tool_calls(
self,
messages: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
task_mappings: List[Tuple[int, int]],
) -> None:
"""
Apply guardrailed tool calls back to input messages.
The guardrail may have modified the tool_calls list in place,
so we apply the modified tool calls back to the original messages.
Override this method to customize how tool call responses are applied.
"""
for task_idx, (msg_idx, tool_call_idx) in enumerate(task_mappings):
if task_idx < len(tool_calls):
guardrailed_tool_call = tool_calls[task_idx]
message_tool_calls = messages[msg_idx].get("tool_calls", None)
if message_tool_calls is not None and isinstance(
message_tool_calls, list
):
if tool_call_idx < len(message_tool_calls):
# Replace the tool call with the guardrailed version
message_tool_calls[tool_call_idx] = guardrailed_tool_call
    async def process_output_response(
        self,
        response: "ModelResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response by applying guardrails to text content.

        Args:
            response: LiteLLM ModelResponse object (mutated in place)
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata to pass to guardrails

        Returns:
            Modified response with guardrail applied to content

        Response Format Support:
            - String content: choice.message.content = "text here"
            - List content: choice.message.content = [{"type": "text", "text": "text here"}, ...]
        """
        # Step 0: Check if response has any text content to process
        if not self._has_text_content(response):
            verbose_proxy_logger.warning(
                "OpenAI Chat Completions: No text content in response, skipping guardrail"
            )
            return response
        texts_to_check: List[str] = []
        images_to_check: List[str] = []
        tool_calls_to_check: List[Dict[str, Any]] = []
        text_task_mappings: List[Tuple[int, Optional[int]]] = []
        tool_call_task_mappings: List[Tuple[int, int]] = []
        # text_task_mappings: Track (choice_index, content_index) for each text
        # content_index is None for string content, int for list content
        # tool_call_task_mappings: Track (choice_index, tool_call_index) for each tool call
        # Step 1: Extract all text content, images, and tool calls from response choices
        for choice_idx, choice in enumerate(response.choices):
            self._extract_output_text_images_and_tool_calls(
                choice=choice,
                choice_idx=choice_idx,
                texts_to_check=texts_to_check,
                images_to_check=images_to_check,
                tool_calls_to_check=tool_calls_to_check,
                text_task_mappings=text_task_mappings,
                tool_call_task_mappings=tool_call_task_mappings,
            )
        # Step 2: Apply guardrail to all texts and tool calls in batch
        if texts_to_check or tool_calls_to_check:
            # Create a request_data dict with response info and user API key metadata
            request_data: dict = {"response": response}
            # Add user API key metadata with prefixed keys
            user_metadata = self.transform_user_api_key_dict_to_metadata(
                user_api_key_dict
            )
            if user_metadata:
                request_data["litellm_metadata"] = user_metadata
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            if images_to_check:
                inputs["images"] = images_to_check
            if tool_calls_to_check:
                inputs["tool_calls"] = tool_calls_to_check  # type: ignore
            # Include model information from the response if available
            if hasattr(response, "model") and response.model:
                inputs["model"] = response.model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=request_data,
                input_type="response",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            # Step 3: Map guardrail responses back to original response structure
            if guardrailed_texts and texts_to_check:
                await self._apply_guardrail_responses_to_output_texts(
                    response=response,
                    responses=guardrailed_texts,
                    task_mappings=text_task_mappings,
                )
            # Step 4: Apply guardrailed tool calls back to response
            # NOTE(review): unlike the input path (which writes back the
            # guardrail's returned "tool_calls"), this writes back
            # tool_calls_to_check, relying on the guardrail to have mutated
            # that list in place -- confirm this asymmetry is intended.
            if tool_calls_to_check:
                await self._apply_guardrail_responses_to_output_tool_calls(
                    response=response,
                    tool_calls=tool_calls_to_check,
                    task_mappings=tool_call_task_mappings,
                )
        verbose_proxy_logger.debug(
            "OpenAI Chat Completions: Processed output response: %s", response
        )
        return response
    async def process_output_streaming_response(
        self,
        responses_so_far: List["ModelResponseStream"],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> List["ModelResponseStream"]:
        """
        Process output streaming responses by applying guardrails to text content.

        If the stream has already ended (a finish_reason is present), the
        chunks are assembled into a full ModelResponse and run through
        ``process_output_response``. Otherwise the accumulated delta text is
        combined per choice, guardrailed in one batch, and written back
        across the chunks.

        Args:
            responses_so_far: List of LiteLLM ModelResponseStream objects
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata to pass to guardrails

        Returns:
            Modified list of responses with guardrail applied to content

        Response Format Support:
            - String content: delta/message content = "text here"
            - List content: delta/message content = [{"type": "text", "text": "text here"}, ...]
        """
        # check if the stream has ended
        has_stream_ended = False
        for chunk in responses_so_far:
            if chunk.choices and chunk.choices[0].finish_reason is not None:
                has_stream_ended = True
                break
        if has_stream_ended:
            # convert to model response
            model_response = cast(
                ModelResponse,
                stream_chunk_builder(
                    chunks=responses_so_far, logging_obj=litellm_logging_obj
                ),
            )
            # run process_output_response
            # NOTE(review): content rewrites applied to `model_response` are
            # not copied back into `responses_so_far` -- on this path only
            # exceptions raised by the guardrail take effect, since the
            # original chunks are returned unchanged. Confirm intended.
            await self.process_output_response(
                response=model_response,
                guardrail_to_apply=guardrail_to_apply,
                litellm_logging_obj=litellm_logging_obj,
                user_api_key_dict=user_api_key_dict,
            )
            return responses_so_far
        # Step 0: Check if any response has text content to process
        has_any_text_content = False
        for response in responses_so_far:
            if self._has_text_content(response):
                has_any_text_content = True
                break
        if not has_any_text_content:
            verbose_proxy_logger.warning(
                "OpenAI Chat Completions: No text content in streaming responses, skipping guardrail"
            )
            return responses_so_far
        # Step 1: Combine all streaming chunks into complete text per choice
        # For streaming, we need to concatenate all delta.content across all chunks
        # Key: (choice_idx, content_idx), Value: combined text
        combined_texts = self._combine_streaming_texts(responses_so_far)
        # Step 2: Create lists for guardrail processing
        texts_to_check: List[str] = []
        images_to_check: List[str] = []
        task_mappings: List[Tuple[int, Optional[int]]] = []
        # Track (choice_index, content_index) for each combined text
        for (map_choice_idx, map_content_idx), combined_text in combined_texts.items():
            texts_to_check.append(combined_text)
            task_mappings.append((map_choice_idx, map_content_idx))
        # Step 3: Apply guardrail to all combined texts in batch
        if texts_to_check:
            # Create a request_data dict with response info and user API key metadata
            request_data: dict = {"responses": responses_so_far}
            # Add user API key metadata with prefixed keys
            user_metadata = self.transform_user_api_key_dict_to_metadata(
                user_api_key_dict
            )
            if user_metadata:
                request_data["litellm_metadata"] = user_metadata
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            # images_to_check is never populated on this path; the branch is
            # kept for parity with the non-streaming flow.
            if images_to_check:
                inputs["images"] = images_to_check
            # Include model information from the first response if available
            if (
                responses_so_far
                and hasattr(responses_so_far[0], "model")
                and responses_so_far[0].model
            ):
                inputs["model"] = responses_so_far[0].model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=request_data,
                input_type="response",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            # Step 4: Apply guardrailed text back to all streaming chunks
            # For each choice, replace the combined text across all chunks
            await self._apply_guardrail_responses_to_output_streaming(
                responses=responses_so_far,
                guardrailed_texts=guardrailed_texts,
                task_mappings=task_mappings,
            )
        verbose_proxy_logger.debug(
            "OpenAI Chat Completions: Processed output streaming responses: %s",
            responses_so_far,
        )
        return responses_so_far
def _combine_streaming_texts(
self, responses_so_far: List["ModelResponseStream"]
) -> Dict[Tuple[int, Optional[int]], str]:
"""
Combine all streaming chunks into complete text per choice.
For streaming, we need to concatenate all delta.content across all chunks.
Args:
responses_so_far: List of LiteLLM ModelResponseStream objects
Returns:
Dict mapping (choice_idx, content_idx) to combined text string
"""
combined_texts: Dict[Tuple[int, Optional[int]], str] = {}
for response_idx, response in enumerate(responses_so_far):
for choice_idx, choice in enumerate(response.choices):
if isinstance(choice, litellm.StreamingChoices):
content = choice.delta.content
elif isinstance(choice, litellm.Choices):
content = choice.message.content
else:
continue
if content is None:
continue
if isinstance(content, str):
# String content - accumulate for this choice
str_key: Tuple[int, Optional[int]] = (choice_idx, None)
if str_key not in combined_texts:
combined_texts[str_key] = ""
combined_texts[str_key] += content
elif isinstance(content, list):
# List content - accumulate for each content item
for content_idx, content_item in enumerate(content):
text_str = content_item.get("text")
if text_str:
list_key: Tuple[int, Optional[int]] = (
choice_idx,
content_idx,
)
if list_key not in combined_texts:
combined_texts[list_key] = ""
combined_texts[list_key] += text_str
return combined_texts
def _has_text_content(
self, response: Union["ModelResponse", "ModelResponseStream"]
) -> bool:
"""
Check if response has any text content or tool calls to process.
Override this method to customize text content detection.
"""
from litellm.types.utils import ModelResponse, ModelResponseStream
if isinstance(response, ModelResponse):
for choice in response.choices:
if isinstance(choice, litellm.Choices):
# Check for text content
if choice.message.content and isinstance(
choice.message.content, str
):
return True
# Check for tool calls
if choice.message.tool_calls and isinstance(
choice.message.tool_calls, list
):
if len(choice.message.tool_calls) > 0:
return True
elif isinstance(response, ModelResponseStream):
for streaming_choice in response.choices:
if isinstance(streaming_choice, litellm.StreamingChoices):
# Check for text content
if streaming_choice.delta.content and isinstance(
streaming_choice.delta.content, str
):
return True
# Check for tool calls
if streaming_choice.delta.tool_calls and isinstance(
streaming_choice.delta.tool_calls, list
):
if len(streaming_choice.delta.tool_calls) > 0:
return True
return False
def _extract_output_text_images_and_tool_calls(
self,
choice: Union[Choices, StreamingChoices],
choice_idx: int,
texts_to_check: List[str],
images_to_check: List[str],
tool_calls_to_check: List[Dict[str, Any]],
text_task_mappings: List[Tuple[int, Optional[int]]],
tool_call_task_mappings: List[Tuple[int, int]],
) -> None:
"""
Extract text content, images, and tool calls from a response choice.
Override this method to customize text/image/tool call extraction logic.
"""
verbose_proxy_logger.debug(
"OpenAI Chat Completions: Processing choice: %s", choice
)
# Determine content source and tool calls based on choice type
content = None
tool_calls: Optional[List[Any]] = None
if isinstance(choice, litellm.Choices):
content = choice.message.content
tool_calls = choice.message.tool_calls
elif isinstance(choice, litellm.StreamingChoices):
content = choice.delta.content
tool_calls = choice.delta.tool_calls
else:
# Unknown choice type, skip processing
return
# Process content if it exists
if content and isinstance(content, str):
# Simple string content
texts_to_check.append(content)
text_task_mappings.append((choice_idx, None))
elif content and isinstance(content, list):
# List content (e.g., multimodal response)
for content_idx, content_item in enumerate(content):
# Extract text
content_text = content_item.get("text")
if content_text:
texts_to_check.append(content_text)
text_task_mappings.append((choice_idx, int(content_idx)))
# Extract images
if content_item.get("type") == "image_url":
image_url = content_item.get("image_url", {})
if isinstance(image_url, dict):
url = image_url.get("url")
if url:
images_to_check.append(url)
# Process tool calls if they exist
if tool_calls is not None and isinstance(tool_calls, list):
for tool_call_idx, tool_call in enumerate(tool_calls):
# Convert tool call to dict format for guardrail processing
tool_call_dict = self._convert_tool_call_to_dict(tool_call)
if tool_call_dict:
tool_calls_to_check.append(tool_call_dict)
tool_call_task_mappings.append((choice_idx, int(tool_call_idx)))
def _convert_tool_call_to_dict(
self, tool_call: Union[Dict[str, Any], Any]
) -> Optional[Dict[str, Any]]:
"""
Convert a tool call object to dictionary format.
Tool calls can be either dict or object depending on the type.
"""
if isinstance(tool_call, dict):
return tool_call
elif hasattr(tool_call, "id") and hasattr(tool_call, "function"):
# Convert object to dict
function = tool_call.function
function_dict = {}
if hasattr(function, "name"):
function_dict["name"] = function.name
if hasattr(function, "arguments"):
function_dict["arguments"] = function.arguments
tool_call_dict = {
"id": tool_call.id if hasattr(tool_call, "id") else None,
"type": tool_call.type if hasattr(tool_call, "type") else "function",
"function": function_dict,
}
return tool_call_dict
return None
async def _apply_guardrail_responses_to_output_texts(
self,
response: "ModelResponse",
responses: List[str],
task_mappings: List[Tuple[int, Optional[int]]],
) -> None:
"""
Apply guardrail text responses back to output response.
Override this method to customize how text responses are applied.
"""
for task_idx, guardrail_response in enumerate(responses):
mapping = task_mappings[task_idx]
choice_idx = cast(int, mapping[0])
content_idx_optional = cast(Optional[int], mapping[1])
choice = cast(Choices, response.choices[choice_idx])
# Handle content
content = choice.message.content
if content is None:
continue
if isinstance(content, str) and content_idx_optional is None:
# Replace string content with guardrail response
choice.message.content = guardrail_response
elif isinstance(content, list) and content_idx_optional is not None:
# Replace specific text item in list content
choice.message.content[content_idx_optional]["text"] = guardrail_response # type: ignore
async def _apply_guardrail_responses_to_output_tool_calls(
self,
response: "ModelResponse",
tool_calls: List[Dict[str, Any]],
task_mappings: List[Tuple[int, int]],
) -> None:
"""
Apply guardrailed tool calls back to output response.
The guardrail may have modified the tool_calls list in place,
so we apply the modified tool calls back to the original response.
Override this method to customize how tool call responses are applied.
"""
for task_idx, (choice_idx, tool_call_idx) in enumerate(task_mappings):
if task_idx < len(tool_calls):
guardrailed_tool_call = tool_calls[task_idx]
choice = cast(Choices, response.choices[choice_idx])
choice_tool_calls = choice.message.tool_calls
if choice_tool_calls is not None and isinstance(
choice_tool_calls, list
):
if tool_call_idx < len(choice_tool_calls):
# Update the tool call with guardrailed version
existing_tool_call = choice_tool_calls[tool_call_idx]
# Update object attributes (output responses always have typed objects)
if "function" in guardrailed_tool_call:
func_dict = guardrailed_tool_call["function"]
if "arguments" in func_dict:
existing_tool_call.function.arguments = func_dict[
"arguments"
]
if "name" in func_dict:
existing_tool_call.function.name = func_dict["name"]
    async def _apply_guardrail_responses_to_output_streaming(
        self,
        responses: List["ModelResponseStream"],
        guardrailed_texts: List[str],
        task_mappings: List[Tuple[int, Optional[int]]],
    ) -> None:
        """
        Apply guardrail responses back to output streaming responses.

        For streaming responses, the guardrailed text (which is the combined text from all chunks)
        is placed in the first chunk, and subsequent chunks are cleared, so the
        concatenation of all chunks equals the guardrailed text exactly once.

        Args:
            responses: List of ModelResponseStream objects to modify
            guardrailed_texts: List of guardrailed text responses (combined from all chunks)
            task_mappings: List of tuples (choice_idx, content_idx)

        Override this method to customize how responses are applied to streaming responses.
        """
        # Build a mapping of what guardrailed text to use for each (choice_idx, content_idx)
        guardrail_map: Dict[Tuple[int, Optional[int]], str] = {}
        for task_idx, guardrail_response in enumerate(guardrailed_texts):
            mapping = task_mappings[task_idx]
            choice_idx = cast(int, mapping[0])
            content_idx_optional = cast(Optional[int], mapping[1])
            guardrail_map[(choice_idx, content_idx_optional)] = guardrail_response
        # Track which choices we've already set the guardrailed text for
        # Key: (choice_idx, content_idx), Value: boolean (True if already set)
        already_set: Dict[Tuple[int, Optional[int]], bool] = {}
        # Iterate through all responses and update content
        # NOTE(review): choices are keyed by their position within each
        # chunk; assumes choice order is stable across chunks -- confirm.
        for response_idx, response in enumerate(responses):
            for choice_idx_in_response, choice in enumerate(response.choices):
                if isinstance(choice, litellm.StreamingChoices):
                    content = choice.delta.content
                elif isinstance(choice, litellm.Choices):
                    content = choice.message.content
                else:
                    continue
                if content is None:
                    continue
                if isinstance(content, str):
                    # String content
                    str_key: Tuple[int, Optional[int]] = (choice_idx_in_response, None)
                    if str_key in guardrail_map:
                        if str_key not in already_set:
                            # First chunk - set the complete guardrailed text
                            if isinstance(choice, litellm.StreamingChoices):
                                choice.delta.content = guardrail_map[str_key]
                            elif isinstance(choice, litellm.Choices):
                                choice.message.content = guardrail_map[str_key]
                            already_set[str_key] = True
                        else:
                            # Subsequent chunks - clear the content
                            if isinstance(choice, litellm.StreamingChoices):
                                choice.delta.content = ""
                            elif isinstance(choice, litellm.Choices):
                                choice.message.content = ""
                elif isinstance(content, list):
                    # List content - handle each content item
                    for content_idx, content_item in enumerate(content):
                        if "text" in content_item:
                            list_key: Tuple[int, Optional[int]] = (
                                choice_idx_in_response,
                                content_idx,
                            )
                            if list_key in guardrail_map:
                                if list_key not in already_set:
                                    # First chunk - set the complete guardrailed text
                                    content_item["text"] = guardrail_map[list_key]
                                    already_set[list_key] = True
                                else:
                                    # Subsequent chunks - clear the text
                                    content_item["text"] = ""

View File

@@ -0,0 +1,3 @@
"""
LLM Calling done in `openai/openai.py`
"""

View File

@@ -0,0 +1,179 @@
"""
Support for o1/o3 model family
https://platform.openai.com/docs/guides/reasoning
Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
"""
from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload
import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from litellm.utils import (
supports_function_calling,
supports_parallel_function_calling,
supports_response_schema,
supports_system_messages,
)
from .gpt_transformation import OpenAIGPTConfig
class OpenAIOSeriesConfig(OpenAIGPTConfig):
    """
    Parameter and message handling for OpenAI o-series reasoning models.

    Reference: https://platform.openai.com/docs/guides/reasoning
    """

    @classmethod
    def get_config(cls):
        # No o-series-specific config keys; defer to the base GPT config.
        return super().get_config()

    def translate_developer_role_to_system_role(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """
        O-series models support the `developer` role natively, so messages
        are returned unchanged (overriding whatever translation the base
        class applies).
        """
        return messages

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the given model.

        Starts from the base GPT param list, adds `reasoning_effort`, then
        removes sampling/logprob controls and any capability-gated params
        (tools, response_format, ...) the concrete model does not support.
        """
        all_openai_params = super().get_supported_openai_params(model=model)
        # Sampling and logprob controls are not accepted by o-series models.
        non_supported_params = [
            "logprobs",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "top_logprobs",
        ]
        o_series_only_param = ["reasoning_effort"]
        all_openai_params.extend(o_series_only_param)
        try:
            model, custom_llm_provider, api_base, api_key = get_llm_provider(
                model=model
            )
        except Exception:
            # Provider inference can fail for bare model names; assume openai.
            verbose_logger.debug(
                f"Unable to infer model provider for model={model}, defaulting to openai for o1 supported param check"
            )
            custom_llm_provider = "openai"
        # Narrow the list further based on per-model capability lookups.
        _supports_function_calling = supports_function_calling(
            model, custom_llm_provider
        )
        _supports_response_schema = supports_response_schema(model, custom_llm_provider)
        _supports_parallel_tool_calls = supports_parallel_function_calling(
            model, custom_llm_provider
        )
        if not _supports_function_calling:
            non_supported_params.append("tools")
            non_supported_params.append("tool_choice")
            non_supported_params.append("function_call")
            non_supported_params.append("functions")
        if not _supports_parallel_tool_calls:
            non_supported_params.append("parallel_tool_calls")
        if not _supports_response_schema:
            non_supported_params.append("response_format")
        return [
            param for param in all_openai_params if param not in non_supported_params
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        """
        Translate generic OpenAI params for o-series models.

        - `max_tokens` is renamed to `max_completion_tokens`.
        - `temperature` is only forwarded when it equals 1; other values are
          silently dropped when drop_params is enabled, else raise
          UnsupportedParamsError.
        """
        if "max_tokens" in non_default_params:
            optional_params["max_completion_tokens"] = non_default_params.pop(
                "max_tokens"
            )
        if "temperature" in non_default_params:
            temperature_value: Optional[float] = non_default_params.pop("temperature")
            if temperature_value is not None:
                if temperature_value == 1:
                    optional_params["temperature"] = temperature_value
                else:
                    ## UNSUPPORTED TEMPERATURE VALUE (only temperature=1 allowed)
                    if litellm.drop_params is True or drop_params is True:
                        pass
                    else:
                        raise litellm.utils.UnsupportedParamsError(
                            message="O-series models don't support temperature={}. Only temperature=1 is supported. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
                                temperature_value
                            ),
                            status_code=400,
                        )
        return super()._map_openai_params(
            non_default_params, optional_params, model, drop_params
        )

    def is_model_o_series_model(self, model: str) -> bool:
        """
        Heuristic: the segment after any provider prefix must look like
        "o<digit>..." and be a registered OpenAI chat-completion model.
        """
        model = model.split("/")[-1]  # could be "openai/o3" or "o3"
        return (
            len(model) > 1
            and model[0] == "o"
            and model[1].isdigit()
            and model in litellm.open_ai_chat_completion_models
        )

    @overload
    def _transform_messages(
        self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
    ) -> Coroutine[Any, Any, List[AllMessageValues]]:
        ...

    @overload
    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
        is_async: Literal[False] = False,
    ) -> List[AllMessageValues]:
        ...

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str, is_async: bool = False
    ) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
        """
        Handles limitations of the o-series model family.

        Rewrites `system` messages to `user` messages (in place) when the
        concrete model does not support system messages, then delegates to
        the base transformation (sync or async per `is_async`).
        """
        _supports_system_messages = supports_system_messages(model, "openai")
        for i, message in enumerate(messages):
            if message["role"] == "system" and not _supports_system_messages:
                new_message = ChatCompletionUserMessage(
                    content=message["content"], role="user"
                )
                messages[i] = new_message  # Replace the old message with the new one
        if is_async:
            return super()._transform_messages(
                messages, model, is_async=cast(Literal[True], True)
            )
        else:
            return super()._transform_messages(
                messages, model, is_async=cast(Literal[False], False)
            )

View File

@@ -0,0 +1,290 @@
"""
Common helpers / utils across all OpenAI endpoints
"""
import hashlib
import inspect
import json
import os
import ssl
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Tuple,
Union,
)
import httpx
import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
if TYPE_CHECKING:
from aiohttp import ClientSession
import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
AsyncHTTPHandler,
get_ssl_configuration,
)
def _get_client_init_params(cls: type) -> Tuple[str, ...]:
"""Extract __init__ parameter names (excluding 'self') from a class."""
return tuple(p for p in inspect.signature(cls.__init__).parameters if p != "self") # type: ignore[misc]
# Snapshot the SDK clients' __init__ parameter names once at import time.
# These feed BaseOpenAILLM.get_openai_client_cache_key so cached clients are
# keyed on every parameter the underlying SDK constructor accepts.
_OPENAI_INIT_PARAMS: Tuple[str, ...] = _get_client_init_params(OpenAI)
_AZURE_OPENAI_INIT_PARAMS: Tuple[str, ...] = _get_client_init_params(AzureOpenAI)
class OpenAIError(BaseLLMException):
    """Exception wrapper for OpenAI API failures.

    Guarantees ``request`` and ``response`` attributes are always populated,
    synthesizing placeholder httpx objects when the caller does not supply
    them, so downstream error handling never has to null-check.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        body: Optional[dict] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers
        # Fall back to synthetic request/response objects when the transport
        # layer never produced real ones (e.g. validation errors).
        self.request = (
            request
            if request
            else httpx.Request(method="POST", url="https://api.openai.com/v1")
        )
        self.response = (
            response
            if response
            else httpx.Response(status_code=status_code, request=self.request)
        )
        super().__init__(
            status_code=status_code,
            message=self.message,
            headers=self.headers,
            request=self.request,
            response=self.response,
            body=body,
        )
####### Error Handling Utils for OpenAI API #######################
###################################################################
def drop_params_from_unprocessable_entity_error(
    e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
    data: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper function to read OpenAI UnprocessableEntityError and drop the params that raised an error from the error message.

    Args:
        e (UnprocessableEntityError): The UnprocessableEntityError exception
        data (Dict[str, Any]): The original data dictionary containing all parameters

    Returns:
        Dict[str, Any]: A new dictionary with invalid parameters removed
    """
    invalid_params: List[str] = []

    # Extract the error body from either exception shape.
    if isinstance(e, httpx.HTTPStatusError):
        error_body = e.response.json().get("error", {})
    else:
        error_body = e.body

    if isinstance(error_body, dict) and error_body.get("message"):
        message = error_body.get("message", {})
        if isinstance(message, str):
            # The message may itself be a JSON-encoded validation payload.
            try:
                message = json.loads(message)
            except json.JSONDecodeError:
                message = {"detail": message}
        detail = message.get("detail")
        # Expect a FastAPI-style validation detail: a list of dicts with a
        # two-element 'loc' like ["body", "<param_name>"].
        if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict):
            for error_dict in detail:
                loc = error_dict.get("loc")
                if isinstance(loc, list) and len(loc) == 2:
                    invalid_params.append(loc[1])

    return {key: value for key, value in data.items() if key not in invalid_params}
class BaseOpenAILLM:
    """
    Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
    """

    @staticmethod
    def get_cached_openai_client(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
        """Retrieves the OpenAI client from the in-memory cache based on the client initialization parameters"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
        return _cached_client

    @staticmethod
    def set_cached_openai_client(
        openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
        client_type: Literal["openai", "azure"],
        client_initialization_params: dict,
    ):
        """Stores the OpenAI client in the in-memory cache for _DEFAULT_TTL_FOR_HTTPX_CLIENTS SECONDS"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        litellm.in_memory_llm_clients_cache.set_cache(
            key=_cache_key,
            value=openai_client,
            ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
        )

    @staticmethod
    def get_openai_client_cache_key(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> str:
        """Creates a cache key for the OpenAI client based on the client initialization parameters.

        The raw API key never appears in the key — only its SHA-256 digest.
        """
        hashed_api_key = None
        if client_initialization_params.get("api_key") is not None:
            hash_object = hashlib.sha256(
                client_initialization_params.get("api_key", "").encode()
            )
            # Hexadecimal representation of the hash
            hashed_api_key = hash_object.hexdigest()
        # Create a more readable cache key using a list of key-value pairs
        key_parts = [
            f"hashed_api_key={hashed_api_key}",
            f"is_async={client_initialization_params.get('is_async')}",
        ]
        # Params litellm layers on top of the SDK's own __init__ signature.
        LITELLM_CLIENT_SPECIFIC_PARAMS = (
            "timeout",
            "max_retries",
            "organization",
            "api_base",
        )
        openai_client_fields = (
            BaseOpenAILLM.get_openai_client_initialization_param_fields(
                client_type=client_type
            )
            + LITELLM_CLIENT_SPECIFIC_PARAMS
        )
        # Missing params serialize as "param=None", keeping keys stable.
        for param in openai_client_fields:
            key_parts.append(f"{param}={client_initialization_params.get(param)}")
        _cache_key = ",".join(key_parts)
        return _cache_key

    @staticmethod
    def get_openai_client_initialization_param_fields(
        client_type: Literal["openai", "azure"]
    ) -> Tuple[str, ...]:
        """Returns a tuple of fields that are used to initialize the OpenAI client"""
        # Pre-computed at import time from the SDK clients' __init__ signatures.
        if client_type == "openai":
            return _OPENAI_INIT_PARAMS
        else:
            return _AZURE_OPENAI_INIT_PARAMS

    @staticmethod
    def _get_async_http_client(
        shared_session: Optional["ClientSession"] = None,
    ) -> Optional[httpx.AsyncClient]:
        """Build (or reuse) the async httpx client for async OpenAI SDK clients.

        Precedence: a user-supplied ``litellm.aclient_session``, then a mock
        transport when ``litellm.network_mock`` is set (tests), otherwise a
        fresh client configured with the unified SSL settings.
        """
        if litellm.aclient_session is not None:
            return litellm.aclient_session
        # Test/mock mode: short-circuit all network traffic.
        if getattr(litellm, "network_mock", False):
            from litellm.llms.custom_httpx.mock_transport import MockOpenAITransport

            return httpx.AsyncClient(transport=MockOpenAITransport())
        # Get unified SSL configuration
        ssl_config = get_ssl_configuration()
        return httpx.AsyncClient(
            verify=ssl_config,
            # ssl_config is either an ssl.SSLContext or a bool; route it to
            # the matching transport parameter.
            transport=AsyncHTTPHandler._create_async_transport(
                ssl_context=ssl_config
                if isinstance(ssl_config, ssl.SSLContext)
                else None,
                ssl_verify=ssl_config if isinstance(ssl_config, bool) else None,
                shared_session=shared_session,
            ),
            follow_redirects=True,
        )

    @staticmethod
    def _get_sync_http_client() -> Optional[httpx.Client]:
        """Build (or reuse) the sync httpx client for sync OpenAI SDK clients.

        Same precedence as ``_get_async_http_client``: user session, mock
        transport, then a fresh SSL-configured client.
        """
        if litellm.client_session is not None:
            return litellm.client_session
        # Test/mock mode: short-circuit all network traffic.
        if getattr(litellm, "network_mock", False):
            from litellm.llms.custom_httpx.mock_transport import MockOpenAITransport

            return httpx.Client(transport=MockOpenAITransport())
        # Get unified SSL configuration
        ssl_config = get_ssl_configuration()
        return httpx.Client(
            verify=ssl_config,
            follow_redirects=True,
        )
class OpenAICredentials(NamedTuple):
    """Resolved OpenAI connection settings."""

    # Fully-resolved base URL; resolution falls back to the public endpoint,
    # so this is never None.
    api_base: str
    # API key; None when no key is configured anywhere.
    api_key: Optional[str]
    # Optional OpenAI organization id.
    organization: Optional[str]
def get_openai_credentials(
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    organization: Optional[str] = None,
) -> OpenAICredentials:
    """Resolve OpenAI credentials from params, litellm globals, and env vars.

    Precedence per field: explicit argument, then the litellm module global,
    then environment variables, then (for api_base only) the public endpoint.
    """
    key = (
        api_key
        or litellm.api_key
        or litellm.openai_key
        or os.getenv("OPENAI_API_KEY")
    )
    org = (
        organization
        or litellm.organization
        or os.getenv("OPENAI_ORGANIZATION", None)
        or None  # normalize falsy values (e.g. empty string) to None
    )
    base = (
        api_base
        or litellm.api_base
        or os.getenv("OPENAI_BASE_URL")
        or os.getenv("OPENAI_API_BASE")
        or "https://api.openai.com/v1"
    )
    return OpenAICredentials(api_base=base, api_key=key, organization=org)

View File

@@ -0,0 +1,158 @@
# OpenAI Text Completion Guardrail Translation Handler
Handler for processing OpenAI's text completion endpoint (`/v1/completions`) with guardrails.
## Overview
This handler processes text completion requests by:
1. Extracting the text prompt(s) from the request
2. Applying guardrails to the prompt text(s)
3. Updating the request with the guardrailed prompt(s)
4. Applying guardrails to the completion output text
## Data Format
### Input Format
**Single Prompt:**
```json
{
"model": "gpt-3.5-turbo-instruct",
"prompt": "Say this is a test",
"max_tokens": 7,
"temperature": 0
}
```
**Multiple Prompts (Batch):**
```json
{
"model": "gpt-3.5-turbo-instruct",
"prompt": [
"Tell me a joke",
"Write a poem"
],
"max_tokens": 50
}
```
### Output Format
```json
{
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "gpt-3.5-turbo-instruct",
"choices": [
{
"text": "\n\nThis is indeed a test",
"index": 0,
"logprobs": null,
"finish_reason": "length"
}
],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
}
```
## Usage
The handler is automatically discovered and applied when guardrails are used with the text completion endpoint.
### Example: Using Guardrails with Text Completion
```bash
curl -X POST 'http://localhost:4000/v1/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "gpt-3.5-turbo-instruct",
"prompt": "Say this is a test",
"guardrails": ["content_moderation"],
"max_tokens": 7
}'
```
The guardrail will be applied to both:
- **Input**: The prompt text before sending to the LLM
- **Output**: The completion text in the response
### Example: PII Masking in Prompts and Completions
```bash
curl -X POST 'http://localhost:4000/v1/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "gpt-3.5-turbo-instruct",
"prompt": "My name is John Doe and my email is john@example.com",
"guardrails": ["mask_pii"],
"metadata": {
"guardrails": ["mask_pii"]
}
}'
```
### Example: Batch Prompts with Guardrails
```bash
curl -X POST 'http://localhost:4000/v1/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "gpt-3.5-turbo-instruct",
"prompt": [
"Tell me about AI",
"What is machine learning?"
],
"guardrails": ["content_filter"],
"max_tokens": 100
}'
```
## Implementation Details
### Input Processing
- **Field**: `prompt` (string or list of strings)
- **Processing**:
- String prompts: Apply guardrail directly
- List prompts: Apply guardrail to each string in the list
- **Result**: Updated prompt(s) in request
### Output Processing
- **Field**: `choices[*].text` (string)
- **Processing**: Applies guardrail to each completion text
- **Result**: Updated completion texts in response
### Supported Prompt Types
1. **String**: Single prompt as a string
2. **List of Strings**: Multiple prompts for batch completion
3. **List of Lists**: Token-based prompts (passed through unchanged)
## Extension
Override these methods to customize behavior:
- `process_input_messages()`: Customize how prompts are processed
- `process_output_response()`: Customize how completion texts are processed
## Supported Call Types
- `CallTypes.text_completion` - Synchronous text completion
- `CallTypes.atext_completion` - Asynchronous text completion
## Notes
- The handler processes both input prompts and output completion texts
- List prompts are processed individually (each string in the list)
- Non-string prompt items (e.g., token lists) are passed through unchanged
- Both sync and async call types use the same handler

View File

@@ -0,0 +1,13 @@
"""OpenAI Text Completion handler for Unified Guardrails."""
from litellm.llms.openai.completion.guardrail_translation.handler import (
OpenAITextCompletionHandler,
)
from litellm.types.utils import CallTypes
# Map litellm call types to the handler that applies guardrails to them.
# The sync and async text-completion call types share one handler class.
guardrail_translation_mappings = {
    CallTypes.text_completion: OpenAITextCompletionHandler,
    CallTypes.atext_completion: OpenAITextCompletionHandler,
}

__all__ = ["guardrail_translation_mappings", "OpenAITextCompletionHandler"]

View File

@@ -0,0 +1,194 @@
"""
OpenAI Text Completion Handler for Unified Guardrails
This module provides guardrail translation support for OpenAI's text completion endpoint.
The handler processes the 'prompt' parameter for guardrails.
"""
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.utils import TextCompletionResponse
class OpenAITextCompletionHandler(BaseTranslation):
    """
    Handler for processing OpenAI text completion requests with guardrails.

    This class provides methods to:
    1. Process input prompt (pre-call hook)
    2. Process output response (post-call hook)

    The handler specifically processes the 'prompt' parameter which can be:
    - A single string
    - A list of strings (for batch completions)

    Non-string prompt entries (e.g. token-id lists) are passed through
    unchanged.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input prompt by applying guardrails to text content.

        Args:
            data: Request data dictionary containing 'prompt' parameter
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object forwarded to the guardrail

        Returns:
            Modified data with guardrails applied to prompt (mutated in place)
        """
        prompt = data.get("prompt")
        if prompt is None:
            verbose_proxy_logger.debug(
                "OpenAI Text Completion: No prompt found in request data"
            )
            return data

        if isinstance(prompt, str):
            # Single string prompt
            inputs = GenericGuardrailAPIInputs(texts=[prompt])
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            # Keep the original prompt if the guardrail returned no texts.
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            data["prompt"] = guardrailed_texts[0] if guardrailed_texts else prompt
            verbose_proxy_logger.debug(
                "OpenAI Text Completion: Applied guardrail to string prompt. "
                "Original length: %d, New length: %d",
                len(prompt),
                len(data["prompt"]),
            )
        elif isinstance(prompt, list):
            # List of string prompts (batch completion)
            texts_to_check = []
            text_indices = []  # Track which prompts are strings
            for idx, p in enumerate(prompt):
                if isinstance(p, str):
                    texts_to_check.append(p)
                    text_indices.append(idx)
            if texts_to_check:
                inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
                # Include model information if available
                model = data.get("model")
                if model:
                    inputs["model"] = model
                guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                    inputs=inputs,
                    request_data=data,
                    input_type="request",
                    logging_obj=litellm_logging_obj,
                )
                guardrailed_texts = guardrailed_inputs.get("texts", [])
                # Replace guardrailed texts back at their original positions;
                # non-string entries keep their slots untouched.
                for guardrail_idx, prompt_idx in enumerate(text_indices):
                    if guardrail_idx < len(guardrailed_texts):
                        data["prompt"][prompt_idx] = guardrailed_texts[guardrail_idx]
                        verbose_proxy_logger.debug(
                            "OpenAI Text Completion: Applied guardrail to prompt[%d]. "
                            "Original length: %d, New length: %d",
                            prompt_idx,
                            len(texts_to_check[guardrail_idx]),
                            len(guardrailed_texts[guardrail_idx]),
                        )
        else:
            verbose_proxy_logger.warning(
                "OpenAI Text Completion: Unexpected prompt type: %s. Expected string or list.",
                type(prompt),
            )
        return data

    async def process_output_response(
        self,
        response: "TextCompletionResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response by applying guardrails to completion text.

        Args:
            response: Text completion response object
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata to pass to guardrails

        Returns:
            Modified response with guardrails applied to completion text
            (choices are mutated in place)
        """
        if not hasattr(response, "choices") or not response.choices:
            verbose_proxy_logger.debug(
                "OpenAI Text Completion: No choices in response to process"
            )
            return response
        # Collect all texts to check
        texts_to_check = []
        choice_indices = []
        for idx, choice in enumerate(response.choices):
            if hasattr(choice, "text") and isinstance(choice.text, str):
                texts_to_check.append(choice.text)
                choice_indices.append(idx)
        # Apply guardrails in batch
        if texts_to_check:
            # Create a request_data dict with response info and user API key metadata
            request_data: dict = {"response": response}
            # Add user API key metadata with prefixed keys
            user_metadata = self.transform_user_api_key_dict_to_metadata(
                user_api_key_dict
            )
            if user_metadata:
                request_data["litellm_metadata"] = user_metadata
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            # Include model information from the response if available
            if hasattr(response, "model") and response.model:
                inputs["model"] = response.model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=request_data,
                input_type="response",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            # Apply guardrailed texts back to choices
            for guardrail_idx, choice_idx in enumerate(choice_indices):
                if guardrail_idx < len(guardrailed_texts):
                    original_text = response.choices[choice_idx].text
                    response.choices[choice_idx].text = guardrailed_texts[guardrail_idx]
                    verbose_proxy_logger.debug(
                        "OpenAI Text Completion: Applied guardrail to choice[%d] text. "
                        "Original length: %d, New length: %d",
                        choice_idx,
                        len(original_text),
                        len(guardrailed_texts[guardrail_idx]),
                    )
        return response

View File

@@ -0,0 +1,318 @@
import json
from typing import Callable, List, Optional, Union
from openai import AsyncOpenAI, OpenAI
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base import BaseLLM
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
from litellm.utils import ProviderConfigManager
from ..common_utils import BaseOpenAILLM, OpenAIError
from .transformation import OpenAITextCompletionConfig
class OpenAITextCompletion(BaseLLM):
    """Wrapper around the OpenAI SDK's `/v1/completions` endpoint.

    Supports sync/async and streaming/non-streaming calls. SDK and transport
    errors are normalized into ``OpenAIError``, preserving status code and
    response headers where available.
    """

    openai_text_completion_global_config = OpenAITextCompletionConfig()

    def __init__(self) -> None:
        super().__init__()

    def validate_environment(self, api_key):
        """Build request headers, attaching a bearer token when api_key is given."""
        headers = {
            "content-type": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def completion(
        self,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        timeout: float,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        print_verbose: Optional[Callable] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        litellm_params=None,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ):
        """Entry point for text-completion calls.

        Dispatches based on ``acompletion`` and ``optional_params["stream"]``:
        returns a coroutine/async generator for async paths, a sync generator
        for streaming, or a ``TextCompletionResponse`` otherwise.

        Raises:
            OpenAIError: on missing model/messages or any provider error.
        """
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")
            # don't send max retries to the api, if set
            provider_config = ProviderConfigManager.get_provider_text_completion_config(
                model=model,
                provider=LlmProviders(custom_llm_provider),
            )
            data = provider_config.transform_text_completion_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                headers=headers,
            )
            # max_retries is an SDK client option, not an API field; pull it
            # out of the payload (default 2) before sending.
            max_retries = data.pop("max_retries", 2)
            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )
            if acompletion is True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        api_key=api_key,
                        data=data,
                        headers=headers,
                        model_response=model_response,
                        model=model,
                        timeout=timeout,
                        max_retries=max_retries,
                        client=client,
                        organization=organization,
                    )
                else:
                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
            elif optional_params.get("stream", False):
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    api_key=api_key,
                    data=data,
                    headers=headers,
                    model_response=model_response,
                    model=model,
                    timeout=timeout,
                    max_retries=max_retries,  # type: ignore
                    client=client,
                    organization=organization,
                )
            else:
                # Sync, non-streaming path.
                if client is None:
                    openai_client = OpenAI(
                        api_key=api_key,
                        base_url=api_base,
                        http_client=litellm.client_session,
                        timeout=timeout,
                        max_retries=max_retries,  # type: ignore
                        organization=organization,
                    )
                else:
                    openai_client = client
                raw_response = openai_client.completions.with_raw_response.create(**data)  # type: ignore
                response = raw_response.parse()
                response_json = response.model_dump()
                ## LOGGING
                logging_obj.post_call(
                    api_key=api_key,
                    original_response=response_json,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )
                ## RESPONSE OBJECT
                return TextCompletionResponse(**response_json)
        except Exception as e:
            # Normalize any SDK/transport error into OpenAIError, preserving
            # status code and headers when the exception carries them.
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

    async def acompletion(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        timeout: float,
        max_retries: int,
        organization: Optional[str] = None,
        client=None,
    ):
        """Async, non-streaming completion call.

        Returns a ``TextCompletionResponse``; the raw provider JSON is stashed
        on ``_hidden_params.original_response``.

        Raises:
            OpenAIError: on any provider/transport error.
        """
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=BaseOpenAILLM._get_async_http_client(),
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                )
            else:
                openai_aclient = client
            raw_response = await openai_aclient.completions.with_raw_response.create(
                **data
            )
            response = raw_response.parse()
            response_json = response.model_dump()
            ## LOGGING
            logging_obj.post_call(
                api_key=api_key,
                original_response=response,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                },
            )
            ## RESPONSE OBJECT
            response_obj = TextCompletionResponse(**response_json)
            # Keep the untouched provider payload available for callers that
            # want the original response (e.g. litellm.text_completion()).
            response_obj._hidden_params.original_response = json.dumps(response_json)
            return response_obj
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

    def streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        api_base: Optional[str] = None,
        max_retries=None,
        client=None,
        organization=None,
    ):
        """Sync streaming completion call.

        Yields chunks via ``CustomStreamWrapper``; wraps errors from both the
        initial request and mid-stream iteration into ``OpenAIError``.
        """
        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,  # type: ignore
                organization=organization,
            )
        else:
            openai_client = client
        try:
            raw_response = openai_client.completions.with_raw_response.create(**data)
            response = raw_response.parse()
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        try:
            for chunk in streamwrapper:
                yield chunk
        except Exception as e:
            # Errors raised mid-stream are normalized the same way as
            # request-time errors.
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

    async def async_streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        max_retries: int,
        api_base: Optional[str] = None,
        client=None,
        organization=None,
    ):
        """Async streaming completion call.

        Yields transformed chunks via ``CustomStreamWrapper``. Note: unlike
        ``streaming``, the initial ``create`` call here is not wrapped in
        try/except — request-time SDK errors propagate as raised.
        """
        if client is None:
            openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
            )
        else:
            openai_client = client

        raw_response = await openai_client.completions.with_raw_response.create(**data)
        response = raw_response.parse()
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        try:
            async for transformed_chunk in streamwrapper:
                yield transformed_chunk
        except Exception as e:
            # Normalize mid-stream errors into OpenAIError.
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

View File

@@ -0,0 +1,158 @@
"""
Support for gpt model family
"""
from typing import List, Optional, Union
from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
from ..chat.gpt_transformation import OpenAIGPTConfig
from .utils import _transform_prompt
class OpenAITextCompletionConfig(BaseTextCompletionConfig, OpenAIGPTConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/completions/create

    The class `OpenAITextCompletionConfig` provides configuration for the OpenAI's text completion API interface. Below are the parameters:

    - `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.

    - `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.

    - `frequency_penalty` (number or null): Defaults to 0. It is a numbers from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.

    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.

    - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.

    - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.

    - `temperature` (number or null): This optional parameter defines the sampling temperature to use.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    best_of: Optional[int] = None
    echo: Optional[bool] = None
    # float, not int: the API accepts values in [-2.0, 2.0] (see docstring).
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[dict] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    # float, not int: the API accepts values in [-2.0, 2.0] (see docstring).
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    suffix: Optional[str] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        logprobs: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        # Assigns provided values onto the CLASS (not the instance) — this
        # mirrors the config pattern used by other litellm config classes so
        # get_config() can read them.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def convert_to_chat_model_response_object(
        self,
        response_object: Optional[TextCompletionResponse] = None,
        model_response_object: Optional[ModelResponse] = None,
    ):
        """Convert a text-completion response into a chat ModelResponse.

        Each completion's `text` becomes an assistant `message`; usage, id,
        and model are copied over when present.

        Raises:
            ValueError: if either argument is None.
        """
        try:
            ## RESPONSE OBJECT
            if response_object is None or model_response_object is None:
                raise ValueError("Error in response object format")
            choice_list: List[Choices] = []
            for idx, choice in enumerate(response_object["choices"]):
                message = Message(
                    content=choice["text"],
                    role="assistant",
                )
                # NOTE: `choice` is rebound here from the source dict to the
                # converted Choices object.
                choice = Choices(
                    finish_reason=choice["finish_reason"],
                    index=idx,
                    message=message,
                    logprobs=choice.get("logprobs", None),
                )
                choice_list.append(choice)
            model_response_object.choices = choice_list  # type: ignore
            if "usage" in response_object:
                setattr(model_response_object, "usage", response_object["usage"])
            if "id" in response_object:
                model_response_object.id = response_object["id"]
            if "model" in response_object:
                model_response_object.model = response_object["model"]
            model_response_object._hidden_params[
                "original_response"
            ] = response_object  # track original response, if users make a litellm.text_completion() request, we can return the original response
            return model_response_object
        except Exception as e:
            raise e

    def get_supported_openai_params(self, model: str) -> List:
        """Return the OpenAI params this provider accepts for *model*."""
        return [
            "functions",
            "function_call",
            "temperature",
            "top_p",
            "n",
            "stream",
            "stream_options",
            "stop",
            "max_tokens",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "response_format",
            "seed",
            "tools",
            "tool_choice",
            "max_retries",
            "logprobs",
            "top_logprobs",
            "extra_headers",
        ]

    def transform_text_completion_request(
        self,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the `/v1/completions` request body from chat-style messages."""
        prompt = _transform_prompt(messages)
        return {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }

View File

@@ -0,0 +1,50 @@
from typing import List, Union, cast
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.types.llms.openai import (
AllMessageValues,
AllPromptValues,
OpenAITextCompletionUserMessage,
)
def is_tokens_or_list_of_tokens(value: List):
    """Return True if *value* is a list of token ids (list[int]) or a batch
    of token-id lists (list[list[int]]).

    NOTE: an empty list vacuously satisfies the first check and returns True.
    """
    if not isinstance(value, list):
        return False
    # Single sequence of token ids, e.g. [1, 2, 3].
    if all(isinstance(element, int) for element in value):
        return True
    # Batch of token-id sequences, e.g. [[1, 2], [3]].
    return all(
        isinstance(element, list)
        and all(isinstance(token, int) for token in element)
        for element in value
    )
def _transform_prompt(
    messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
) -> AllPromptValues:
    """Collapse chat-style *messages* into an OpenAI text-completion prompt.

    A single message whose content is token ids (or a batch of token-id
    lists) is passed through untouched; any other single message is
    stringified. Multiple messages become a list of strings — token arrays
    are only valid as a one-message array.
    """
    if len(messages) != 1:
        prompt_parts: List[str] = []
        for message in messages:
            prompt_parts.append(
                convert_content_list_to_str(cast(AllMessageValues, message))
            )
        return cast(AllPromptValues, prompt_parts)
    message_content = messages[0].get("content")
    if (
        message_content
        and isinstance(message_content, list)
        and is_tokens_or_list_of_tokens(message_content)
    ):
        # Raw token ids: forward as-is.
        return cast(AllPromptValues, message_content)
    return convert_content_list_to_str(cast(AllMessageValues, messages[0]))

View File

@@ -0,0 +1,343 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import httpx
import litellm
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
StandardBuiltInToolCostTracking,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.containers.main import (
ContainerCreateOptionalRequestParams,
ContainerFileListResponse,
ContainerListResponse,
ContainerObject,
DeleteContainerResult,
)
from litellm.types.router import GenericLiteLLMParams
from ...base_llm.containers.transformation import BaseContainerConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from ...base_llm.chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
class OpenAIContainerConfig(BaseContainerConfig):
    """Configuration class for OpenAI container API.

    Translates LiteLLM container operations (create / list / retrieve /
    delete, plus container-file listing and content download) into OpenAI
    ``/v1/containers`` HTTP requests and parses the responses back into
    LiteLLM types.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _build_list_query_params(
        after: Optional[str],
        limit: Optional[int],
        order: Optional[str],
        extra_query: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Build the after/limit/order query params shared by the container
        list and container-file list endpoints."""
        params: Dict[str, Any] = {}
        if after is not None:
            params["after"] = after
        if limit is not None:
            params["limit"] = str(limit)  # query-string values must be strings
        if order is not None:
            params["order"] = order
        # Caller-supplied extra query parameters win over the standard ones.
        if extra_query:
            params.update(extra_query)
        return params

    def get_supported_openai_params(self) -> list:
        """Get the list of supported OpenAI parameters for container API."""
        return [
            "name",
            "expires_after",
            "file_ids",
            "extra_headers",
        ]

    def map_openai_params(
        self,
        container_create_optional_params: ContainerCreateOptionalRequestParams,
        drop_params: bool,
    ) -> Dict:
        """No mapping applied since inputs are in OpenAI spec already."""
        return dict(container_create_optional_params)

    def validate_environment(
        self,
        headers: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        """Attach the Authorization header, resolving the API key from the
        argument, litellm globals, or the OPENAI_API_KEY env var."""
        api_key = (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        # NOTE(review): if no key resolves this sends "Bearer None" and the
        # API answers 401 instead of failing early; kept as-is for parity.
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
            },
        )
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """Get the complete URL for OpenAI container API."""
        api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_BASE_URL")
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        return f"{api_base.rstrip('/')}/containers"

    def transform_container_create_request(
        self,
        name: str,
        container_create_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Transform the container creation request for OpenAI API."""
        # extra_headers are transport-level and must not end up in the JSON body.
        container_create_optional_request_params = {
            k: v
            for k, v in container_create_optional_request_params.items()
            if k not in ["extra_headers"]
        }
        return {
            "name": name,
            **container_create_optional_request_params,
        }

    def transform_container_create_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerObject:
        """Transform the OpenAI container creation response.

        Also records the cost: each container creation counts as one code
        interpreter session (https://platform.openai.com/docs/pricing).
        """
        response_data = raw_response.json()
        container_obj = ContainerObject(**response_data)  # type: ignore[arg-type]
        container_cost = StandardBuiltInToolCostTracking.get_cost_for_code_interpreter(
            sessions=1,
            provider="openai",
        )
        # Surface the cost via the hidden-params response headers channel.
        if (
            not hasattr(container_obj, "_hidden_params")
            or container_obj._hidden_params is None
        ):
            container_obj._hidden_params = {}
        if "additional_headers" not in container_obj._hidden_params:
            container_obj._hidden_params["additional_headers"] = {}
        container_obj._hidden_params["additional_headers"][
            "llm_provider-x-litellm-response-cost"
        ] = container_cost
        return container_obj

    def transform_container_list_request(
        self,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        limit: Optional[int] = None,
        order: Optional[str] = None,
        extra_query: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """Transform the container list request for OpenAI API.

        OpenAI API expects the following request:
        - GET /v1/containers
        """
        params = self._build_list_query_params(after, limit, order, extra_query)
        return api_base, params

    def transform_container_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerListResponse:
        """Transform the OpenAI container list response."""
        response_data = raw_response.json()
        return ContainerListResponse(**response_data)  # type: ignore[arg-type]

    def transform_container_retrieve_request(
        self,
        container_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform the OpenAI container retrieve request (GET, no body)."""
        url = f"{api_base.rstrip('/')}/{container_id}"
        data: Dict[str, Any] = {}
        return url, data

    def transform_container_retrieve_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerObject:
        """Transform the OpenAI container retrieve response."""
        response_data = raw_response.json()
        return ContainerObject(**response_data)  # type: ignore[arg-type]

    def transform_container_delete_request(
        self,
        container_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform the container delete request for OpenAI API.

        OpenAI API expects the following request:
        - DELETE /v1/containers/{container_id}
        """
        url = f"{api_base.rstrip('/')}/{container_id}"
        data: Dict[str, Any] = {}
        return url, data

    def transform_container_delete_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteContainerResult:
        """Transform the OpenAI container delete response."""
        response_data = raw_response.json()
        return DeleteContainerResult(**response_data)  # type: ignore[arg-type]

    def transform_container_file_list_request(
        self,
        container_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        limit: Optional[int] = None,
        order: Optional[str] = None,
        extra_query: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """Transform the container file list request for OpenAI API.

        OpenAI API expects the following request:
        - GET /v1/containers/{container_id}/files
        """
        url = f"{api_base.rstrip('/')}/{container_id}/files"
        params = self._build_list_query_params(after, limit, order, extra_query)
        return url, params

    def transform_container_file_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerFileListResponse:
        """Transform the OpenAI container file list response."""
        response_data = raw_response.json()
        return ContainerFileListResponse(**response_data)  # type: ignore[arg-type]

    def transform_container_file_content_request(
        self,
        container_id: str,
        file_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform the container file content request for OpenAI API.

        OpenAI API expects the following request:
        - GET /v1/containers/{container_id}/files/{file_id}/content
        """
        url = f"{api_base.rstrip('/')}/{container_id}/files/{file_id}/content"
        params: Dict[str, Any] = {}
        return url, params

    def transform_container_file_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> bytes:
        """Transform the OpenAI container file content response.

        Returns the raw binary content of the file.
        """
        return raw_response.content

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: Union[dict, httpx.Headers],
    ) -> BaseLLMException:
        """Build the provider exception for an HTTP error response.

        Bug fix: this previously ``raise``d the exception instead of
        returning it, so callers following the ``raise config.get_error_class(...)``
        convention could never wrap or augment the error object.
        """
        from ...base_llm.chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

View File

@@ -0,0 +1,177 @@
"""
Helper util for handling openai-specific cost calculation
- e.g.: prompt caching
"""
from typing import Literal, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import CallTypes, ModelInfo, Usage
from litellm.utils import get_model_info
def cost_router(call_type: CallTypes) -> Literal["cost_per_token", "cost_per_second"]:
    """Pick the OpenAI cost function for *call_type*.

    Transcription (sync or async) is billed by audio duration; everything
    else is billed per token.
    """
    if call_type in (CallTypes.atranscription, CallTypes.transcription):
        return "cost_per_second"
    return "cost_per_token"
def cost_per_token(
    model: str, usage: Usage, service_tier: Optional[str] = None
) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - usage: LiteLLM Usage block, containing anthropic caching information
        - service_tier: optional OpenAI service tier (may alter per-token pricing)

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    # All OpenAI per-token pricing logic (prompt caching, audio tokens,
    # service tiers) lives in the shared generic helper; the ~40 lines of
    # commented-out legacy math that used to trail this return have been
    # removed as dead code.
    return generic_cost_per_token(
        model=model,
        usage=usage,
        custom_llm_provider="openai",
        service_tier=service_tier,
    )
def cost_per_second(
    model: str, custom_llm_provider: Optional[str], duration: float = 0.0
) -> Tuple[float, float]:
    """
    Calculates the cost per second for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - custom_llm_provider: str, the custom llm provider
        - duration: float, the duration of the response in seconds

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    model_info = get_model_info(
        model=model, custom_llm_provider=custom_llm_provider or "openai"
    )
    prompt_cost = 0.0
    completion_cost = 0.0
    # Prefer output-side pricing (speech/audio generation); fall back to
    # input-side pricing (e.g. transcription) when that's all we have.
    output_rate = model_info.get("output_cost_per_second")
    input_rate = model_info.get("input_cost_per_second")
    if "output_cost_per_second" in model_info and output_rate is not None:
        verbose_logger.debug(
            f"For model={model} - output_cost_per_second: {model_info.get('output_cost_per_second')}; duration: {duration}"
        )
        completion_cost = output_rate * duration
    elif "input_cost_per_second" in model_info and input_rate is not None:
        verbose_logger.debug(
            f"For model={model} - input_cost_per_second: {model_info.get('input_cost_per_second')}; duration: {duration}"
        )
        prompt_cost = input_rate * duration
        completion_cost = 0.0
    return prompt_cost, completion_cost
def video_generation_cost(
    model: str,
    duration_seconds: float,
    custom_llm_provider: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> float:
    """
    Calculates the cost for video generation based on duration in seconds.

    Input:
        - model: str, the model name without provider prefix
        - duration_seconds: float, the duration of the generated video in seconds
        - custom_llm_provider: str, the custom llm provider
        - model_info: Optional[dict], deployment-level model info containing
            custom video pricing. When provided, skips the global
            get_model_info() lookup so that deployment-specific pricing is used.

    Returns:
        float - total_cost_in_usd
    """
    if model_info is None:
        model_info = get_model_info(
            model=model, custom_llm_provider=custom_llm_provider or "openai"
        )
    # Video-specific per-second pricing takes precedence.
    video_rate = model_info.get("output_cost_per_video_per_second")
    if video_rate is not None:
        verbose_logger.debug(
            f"For model={model} - output_cost_per_video_per_second: {video_rate}; duration: {duration_seconds}"
        )
        return video_rate * duration_seconds
    # Fall back to the generic per-second output rate.
    generic_rate = model_info.get("output_cost_per_second")
    if generic_rate is not None:
        verbose_logger.debug(
            f"For model={model} - output_cost_per_second: {generic_rate}; duration: {duration_seconds}"
        )
        return generic_rate * duration_seconds
    # No pricing configured at all: warn and bill nothing.
    verbose_logger.warning(
        f"No cost information found for video model {model}. Please add pricing to model_prices_and_context_window.json"
    )
    return 0.0

View File

@@ -0,0 +1,13 @@
"""OpenAI Embeddings handler for Unified Guardrails."""
from litellm.llms.openai.embeddings.guardrail_translation.handler import (
OpenAIEmbeddingsHandler,
)
from litellm.types.utils import CallTypes
# Route both the sync and async embedding call types through the same
# guardrail-translation handler.
guardrail_translation_mappings = {
    CallTypes.embedding: OpenAIEmbeddingsHandler,
    CallTypes.aembedding: OpenAIEmbeddingsHandler,
}

# Explicit public API for `from ... import *` consumers.
__all__ = ["guardrail_translation_mappings", "OpenAIEmbeddingsHandler"]

View File

@@ -0,0 +1,179 @@
"""
OpenAI Embeddings Handler for Unified Guardrails
This module provides guardrail translation support for OpenAI's embeddings endpoint.
The handler processes the 'input' parameter for guardrails.
"""
from typing import TYPE_CHECKING, Any, List, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.utils import EmbeddingResponse
class OpenAIEmbeddingsHandler(BaseTranslation):
    """Guardrail translation for OpenAI's embeddings endpoint.

    Pre-call, the ``input`` field is run through the configured guardrail:
    a single string or a list of strings is guarded (and possibly rewritten
    in place on the request data), while token-id inputs (``list[int]`` or
    ``list[list[int]]``) are passed through untouched. Post-call processing
    is a no-op, because embedding responses contain numeric vectors rather
    than text.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """Apply the guardrail to the request's ``input`` field.

        Dispatches on the input's type (string / list / other) and returns
        the (possibly modified) request data.
        """
        embedding_input = data.get("input")
        if embedding_input is None:
            verbose_proxy_logger.debug(
                "OpenAI Embeddings: No input found in request data"
            )
            return data
        if isinstance(embedding_input, str):
            return await self._process_string_input(
                data, embedding_input, guardrail_to_apply, litellm_logging_obj
            )
        if isinstance(embedding_input, list):
            return await self._process_list_input(
                data, embedding_input, guardrail_to_apply, litellm_logging_obj
            )
        verbose_proxy_logger.warning(
            "OpenAI Embeddings: Unexpected input type: %s. Expected string or list.",
            type(embedding_input),
        )
        return data

    async def _process_string_input(
        self,
        data: dict,
        input_data: str,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any],
    ) -> dict:
        """Run the guardrail over a single string input and write it back."""
        guardrail_inputs = GenericGuardrailAPIInputs(texts=[input_data])
        model_name = data.get("model")
        if model_name:
            guardrail_inputs["model"] = model_name
        guardrail_output = await guardrail_to_apply.apply_guardrail(
            inputs=guardrail_inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        updated_texts = guardrail_output.get("texts")
        if updated_texts:
            data["input"] = updated_texts[0]
            verbose_proxy_logger.debug(
                "OpenAI Embeddings: Applied guardrail to string input. "
                "Original length: %d, New length: %d",
                len(input_data),
                len(data["input"]),
            )
        return data

    async def _process_list_input(
        self,
        data: dict,
        input_data: List[Union[str, int, List[int]]],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any],
    ) -> dict:
        """Run the guardrail over a batch input if (and only if) it holds strings.

        The first element is used as a probe: token-id batches are skipped,
        anything else unexpected is logged and left alone.
        """
        if not input_data:
            return data
        probe = input_data[0]
        if isinstance(probe, (int, list)):
            # Token-id inputs carry no text for a guardrail to inspect.
            verbose_proxy_logger.debug(
                "OpenAI Embeddings: Input is token IDs, skipping guardrail processing"
            )
            return data
        if not isinstance(probe, str):
            verbose_proxy_logger.warning(
                "OpenAI Embeddings: Unexpected input list item type: %s",
                type(probe),
            )
            return data
        guardrail_inputs = GenericGuardrailAPIInputs(texts=input_data)  # type: ignore
        model_name = data.get("model")
        if model_name:
            guardrail_inputs["model"] = model_name
        guardrail_output = await guardrail_to_apply.apply_guardrail(
            inputs=guardrail_inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        updated_texts = guardrail_output.get("texts")
        if updated_texts:
            data["input"] = updated_texts
            verbose_proxy_logger.debug(
                "OpenAI Embeddings: Applied guardrail to %d inputs",
                len(updated_texts),
            )
        return data

    async def process_output_response(
        self,
        response: "EmbeddingResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """No-op for embeddings output, kept for interface consistency.

        Embedding responses carry vectors rather than text, so there is
        nothing for a text guardrail to inspect; the response is returned
        unchanged.
        """
        verbose_proxy_logger.debug(
            "OpenAI Embeddings: Output response processing skipped - "
            "embeddings contain vectors, not text"
        )
        return response

View File

@@ -0,0 +1,7 @@
"""
OpenAI Evals API configuration
"""
from .transformation import OpenAIEvalsConfig
__all__ = ["OpenAIEvalsConfig"]

View File

@@ -0,0 +1,426 @@
"""
OpenAI Evals API configuration and transformations
"""
from typing import Any, Dict, Optional, Tuple
import httpx
from litellm._logging import verbose_logger
from litellm.llms.base_llm.evals.transformation import (
BaseEvalsAPIConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai_evals import (
CancelEvalResponse,
CancelRunResponse,
CreateEvalRequest,
CreateRunRequest,
DeleteEvalResponse,
Eval,
ListEvalsParams,
ListEvalsResponse,
ListRunsParams,
ListRunsResponse,
Run,
RunDeleteResponse,
UpdateEvalRequest,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
class OpenAIEvalsConfig(BaseEvalsAPIConfig):
"""OpenAI-specific Evals API configuration"""
@property
def custom_llm_provider(self) -> LlmProviders:
return LlmProviders.OPENAI
def validate_environment(
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
"""Add OpenAI-specific headers"""
import litellm
from litellm.secret_managers.main import get_secret_str
# Get API key following OpenAI pattern
api_key = None
if litellm_params:
api_key = litellm_params.api_key
api_key = (
api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
if not api_key:
raise ValueError("OPENAI_API_KEY is required for Evals API")
# Add required headers
headers["Authorization"] = f"Bearer {api_key}"
headers["Content-Type"] = "application/json"
return headers
def get_complete_url(
self,
api_base: Optional[str],
endpoint: str,
eval_id: Optional[str] = None,
) -> str:
"""Get complete URL for OpenAI Evals API"""
if api_base is None:
api_base = "https://api.openai.com"
if eval_id:
return f"{api_base}/v1/evals/{eval_id}"
return f"{api_base}/v1/{endpoint}"
def transform_create_eval_request(
self,
create_request: CreateEvalRequest,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Dict:
"""Transform create eval request for OpenAI"""
verbose_logger.debug("Transforming create eval request: %s", create_request)
# OpenAI expects the request body directly
request_body = {k: v for k, v in create_request.items() if v is not None}
return request_body
def transform_create_eval_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> Eval:
"""Transform OpenAI response to Eval object"""
response_json = raw_response.json()
verbose_logger.debug("Transforming create eval response: %s", response_json)
return Eval(**response_json)
def transform_list_evals_request(
self,
list_params: ListEvalsParams,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""Transform list evals request for OpenAI"""
api_base = "https://api.openai.com"
if litellm_params and litellm_params.api_base:
api_base = litellm_params.api_base
url = self.get_complete_url(api_base=api_base, endpoint="evals")
# Build query parameters
query_params: Dict[str, Any] = {}
if "limit" in list_params and list_params["limit"]:
query_params["limit"] = list_params["limit"]
if "after" in list_params and list_params["after"]:
query_params["after"] = list_params["after"]
if "before" in list_params and list_params["before"]:
query_params["before"] = list_params["before"]
if "order" in list_params and list_params["order"]:
query_params["order"] = list_params["order"]
if "order_by" in list_params and list_params["order_by"]:
query_params["order_by"] = list_params["order_by"]
verbose_logger.debug(
"List evals request made to OpenAI Evals endpoint with params: %s",
query_params,
)
return url, query_params
def transform_list_evals_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> ListEvalsResponse:
"""Transform OpenAI response to ListEvalsResponse"""
response_json = raw_response.json()
verbose_logger.debug("Transforming list evals response: %s", response_json)
return ListEvalsResponse(**response_json)
def transform_get_eval_request(
self,
eval_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""Transform get eval request for OpenAI"""
url = self.get_complete_url(
api_base=api_base, endpoint="evals", eval_id=eval_id
)
verbose_logger.debug("Get eval request - URL: %s", url)
return url, headers
def transform_get_eval_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> Eval:
"""Transform OpenAI response to Eval object"""
response_json = raw_response.json()
verbose_logger.debug("Transforming get eval response: %s", response_json)
return Eval(**response_json)
def transform_update_eval_request(
self,
eval_id: str,
update_request: UpdateEvalRequest,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict, Dict]:
"""Transform update eval request for OpenAI"""
url = self.get_complete_url(
api_base=api_base, endpoint="evals", eval_id=eval_id
)
# Build request body
request_body = {k: v for k, v in update_request.items() if v is not None}
verbose_logger.debug(
"Update eval request - URL: %s, body: %s", url, request_body
)
return url, headers, request_body
def transform_update_eval_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> Eval:
"""Transform OpenAI response to Eval object"""
response_json = raw_response.json()
verbose_logger.debug("Transforming update eval response: %s", response_json)
return Eval(**response_json)
def transform_delete_eval_request(
self,
eval_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""Transform delete eval request for OpenAI"""
url = self.get_complete_url(
api_base=api_base, endpoint="evals", eval_id=eval_id
)
verbose_logger.debug("Delete eval request - URL: %s", url)
return url, headers
def transform_delete_eval_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> DeleteEvalResponse:
"""Transform OpenAI response to DeleteEvalResponse"""
response_json = raw_response.json()
verbose_logger.debug("Transforming delete eval response: %s", response_json)
return DeleteEvalResponse(**response_json)
def transform_cancel_eval_request(
self,
eval_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict, Dict]:
"""Transform cancel eval request for OpenAI"""
url = f"{self.get_complete_url(api_base=api_base, endpoint='evals', eval_id=eval_id)}/cancel"
# Empty body for cancel request
request_body: Dict[str, Any] = {}
verbose_logger.debug("Cancel eval request - URL: %s", url)
return url, headers, request_body
def transform_cancel_eval_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> CancelEvalResponse:
"""Transform OpenAI response to CancelEvalResponse"""
response_json = raw_response.json()
verbose_logger.debug("Transforming cancel eval response: %s", response_json)
return CancelEvalResponse(**response_json)
# Run API Transformations
def transform_create_run_request(
self,
eval_id: str,
create_request: CreateRunRequest,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""Transform create run request for OpenAI"""
api_base = "https://api.openai.com"
if litellm_params and litellm_params.api_base:
api_base = litellm_params.api_base
url = f"{api_base}/v1/evals/{eval_id}/runs"
# Build request body
request_body = {k: v for k, v in create_request.items() if v is not None}
verbose_logger.debug(
"Create run request - URL: %s, body: %s", url, request_body
)
return url, request_body
def transform_create_run_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> Run:
"""Transform OpenAI response to Run object"""
response_json = raw_response.json()
verbose_logger.debug("Transforming create run response: %s", response_json)
return Run(**response_json)
def transform_list_runs_request(
self,
eval_id: str,
list_params: ListRunsParams,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""Transform list runs request for OpenAI"""
api_base = "https://api.openai.com"
if litellm_params and litellm_params.api_base:
api_base = litellm_params.api_base
url = f"{api_base}/v1/evals/{eval_id}/runs"
# Build query parameters
query_params: Dict[str, Any] = {}
if "limit" in list_params and list_params["limit"]:
query_params["limit"] = list_params["limit"]
if "after" in list_params and list_params["after"]:
query_params["after"] = list_params["after"]
if "before" in list_params and list_params["before"]:
query_params["before"] = list_params["before"]
if "order" in list_params and list_params["order"]:
query_params["order"] = list_params["order"]
verbose_logger.debug(
"List runs request made to OpenAI Evals endpoint with params: %s",
query_params,
)
return url, query_params
def transform_list_runs_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> ListRunsResponse:
"""Transform OpenAI response to ListRunsResponse"""
response_json = raw_response.json()
verbose_logger.debug("Transforming list runs response: %s", response_json)
return ListRunsResponse(**response_json)
def transform_get_run_request(
self,
eval_id: str,
run_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""Transform get run request for OpenAI"""
url = f"{api_base}/v1/evals/{eval_id}/runs/{run_id}"
verbose_logger.debug("Get run request - URL: %s", url)
return url, headers
def transform_get_run_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> Run:
"""Transform OpenAI response to Run object"""
response_json = raw_response.json()
verbose_logger.debug("Transforming get run response: %s", response_json)
return Run(**response_json)
def transform_cancel_run_request(
self,
eval_id: str,
run_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict, Dict]:
"""Transform cancel run request for OpenAI"""
url = f"{api_base}/v1/evals/{eval_id}/runs/{run_id}/cancel"
# Empty body for cancel request
request_body: Dict[str, Any] = {}
verbose_logger.debug("Cancel run request - URL: %s", url)
return url, headers, request_body
def transform_cancel_run_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> CancelRunResponse:
"""Transform OpenAI response to CancelRunResponse"""
response_json = raw_response.json()
verbose_logger.debug("Transforming cancel run response: %s", response_json)
return CancelRunResponse(**response_json)
def transform_delete_run_request(
self,
eval_id: str,
run_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict, Dict]:
"""Transform delete run request for OpenAI"""
url = f"{api_base}/v1/evals/{eval_id}/runs/{run_id}"
# Empty body for delete request
request_body: Dict[str, Any] = {}
verbose_logger.debug("Delete run request - URL: %s", url)
return url, headers, request_body
def transform_delete_run_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> RunDeleteResponse:
    """Parse the raw OpenAI HTTP response into a RunDeleteResponse."""
    payload = raw_response.json()
    verbose_logger.debug("Transforming delete run response: %s", payload)
    return RunDeleteResponse(**payload)

View File

@@ -0,0 +1,278 @@
from typing import Any, Coroutine, Optional, Union, cast
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from litellm._logging import verbose_logger
from litellm.types.utils import LiteLLMFineTuningJob
class OpenAIFineTuningAPI:
    """
    OpenAI fine-tuning job operations: create, cancel, list, retrieve.

    Each public method resolves a sync or async OpenAI client, validates it,
    and either performs the call directly (sync) or returns a coroutine for
    the caller to await (async).

    Fix: ``retrieve_fine_tuning_job`` previously rejected ``AsyncAzureOpenAI``
    clients (it only accepted ``AsyncOpenAI``), inconsistent with the other
    three methods. Validation is now shared via ``_get_validated_client``.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
        _is_async: bool = False,
        api_version: Optional[str] = None,
        litellm_params: Optional[dict] = None,
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
        """
        Return *client* when provided; otherwise construct a fresh
        (Async)OpenAI client from the remaining arguments.

        NOTE: ``locals()`` is captured before any other local is bound -- it
        forwards every non-None argument (except ``self``, ``client`` and
        ``_is_async``) to the client constructor, renaming ``api_base`` to
        ``base_url``.
        """
        received_args = locals()
        openai_client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client" or k == "_is_async":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            if _is_async is True:
                openai_client = AsyncOpenAI(**data)
            else:
                openai_client = OpenAI(**data)  # type: ignore
        else:
            openai_client = client
        return openai_client

    def _get_validated_client(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ],
    ) -> Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]:
        """
        Resolve a client and enforce the sync/async contract shared by every
        public method.

        Raises:
            ValueError: when no client could be constructed, or when an async
                call received a client that is neither ``AsyncOpenAI`` nor
                ``AsyncAzureOpenAI``.
        """
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
            api_version=api_version,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )
        if _is_async is True and not isinstance(
            openai_client, (AsyncOpenAI, AsyncAzureOpenAI)
        ):
            raise ValueError(
                "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
            )
        return openai_client

    async def acreate_fine_tuning_job(
        self,
        create_fine_tuning_job_data: dict,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    ) -> LiteLLMFineTuningJob:
        """Async create: forwards kwargs to the OpenAI SDK and normalizes the result."""
        response = await openai_client.fine_tuning.jobs.create(
            **create_fine_tuning_job_data
        )
        return LiteLLMFineTuningJob(**response.model_dump())

    def create_fine_tuning_job(
        self,
        _is_async: bool,
        create_fine_tuning_job_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
        """Create a fine-tuning job; returns a coroutine when ``_is_async`` is True."""
        openai_client = self._get_validated_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.acreate_fine_tuning_job(  # type: ignore
                create_fine_tuning_job_data=create_fine_tuning_job_data,
                openai_client=cast(AsyncOpenAI, openai_client),
            )
        verbose_logger.debug(
            "creating fine tuning job, args= %s", create_fine_tuning_job_data
        )
        response = cast(OpenAI, openai_client).fine_tuning.jobs.create(
            **create_fine_tuning_job_data
        )
        return LiteLLMFineTuningJob(**response.model_dump())

    async def acancel_fine_tuning_job(
        self,
        fine_tuning_job_id: str,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    ) -> LiteLLMFineTuningJob:
        """Async cancel: cancels the job and normalizes the result."""
        response = await openai_client.fine_tuning.jobs.cancel(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return LiteLLMFineTuningJob(**response.model_dump())

    def cancel_fine_tuning_job(
        self,
        _is_async: bool,
        fine_tuning_job_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
        """Cancel a fine-tuning job; returns a coroutine when ``_is_async`` is True."""
        openai_client = self._get_validated_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.acancel_fine_tuning_job(  # type: ignore
                fine_tuning_job_id=fine_tuning_job_id,
                openai_client=cast(AsyncOpenAI, openai_client),
            )
        verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id)
        response = cast(OpenAI, openai_client).fine_tuning.jobs.cancel(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return LiteLLMFineTuningJob(**response.model_dump())

    async def alist_fine_tuning_jobs(
        self,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """Async list: returns the SDK's paginated response as-is."""
        response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit)  # type: ignore
        return response

    def list_fine_tuning_jobs(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """List fine-tuning jobs (paginated via ``after``/``limit``)."""
        openai_client = self._get_validated_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.alist_fine_tuning_jobs(  # type: ignore
                after=after,
                limit=limit,
                openai_client=cast(AsyncOpenAI, openai_client),
            )
        verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
        response = openai_client.fine_tuning.jobs.list(after=after, limit=limit)  # type: ignore
        return response

    async def aretrieve_fine_tuning_job(
        self,
        fine_tuning_job_id: str,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    ) -> LiteLLMFineTuningJob:
        """Async retrieve: fetches the job and normalizes the result."""
        response = await openai_client.fine_tuning.jobs.retrieve(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return LiteLLMFineTuningJob(**response.model_dump())

    def retrieve_fine_tuning_job(
        self,
        _is_async: bool,
        fine_tuning_job_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
        """
        Retrieve a fine-tuning job; returns a coroutine when ``_is_async`` is True.

        Now accepts ``AsyncAzureOpenAI`` clients like the other methods
        (previously only ``AsyncOpenAI`` passed validation).
        """
        openai_client = self._get_validated_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.aretrieve_fine_tuning_job(  # type: ignore
                fine_tuning_job_id=fine_tuning_job_id,
                openai_client=cast(AsyncOpenAI, openai_client),
            )
        verbose_logger.debug("retrieving fine tuning job, id= %s", fine_tuning_job_id)
        response = cast(OpenAI, openai_client).fine_tuning.jobs.retrieve(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return LiteLLMFineTuningJob(**response.model_dump())

View File

@@ -0,0 +1,29 @@
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from .dalle2_transformation import DallE2ImageEditConfig
from .transformation import OpenAIImageEditConfig
# Public API of this package; callers should use get_openai_image_edit_config
# to pick the model-appropriate config at runtime.
__all__ = [
    "OpenAIImageEditConfig",
    "DallE2ImageEditConfig",
    "get_openai_image_edit_config",
]
def get_openai_image_edit_config(model: str) -> BaseImageEditConfig:
    """
    Select the OpenAI image edit config for *model*.

    Args:
        model: The model name (e.g., "dall-e-2", "gpt-image-1")

    Returns:
        DallE2ImageEditConfig for dall-e-2; the standard OpenAIImageEditConfig
        for everything else (gpt-image-1, etc.).
    """
    # Normalize away case and separators so "DALL-E-2", "dall_e_2", ... match.
    normalized = model.lower().replace("-", "").replace("_", "")
    if normalized == "dalle2":
        return DallE2ImageEditConfig()
    return OpenAIImageEditConfig()

View File

@@ -0,0 +1,104 @@
from io import BufferedReader
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from httpx._types import RequestFiles
import litellm
from litellm.images.utils import ImageEditRequestUtils
from litellm.types.images.main import ImageEditRequestParams
from litellm.types.llms.openai import FileTypes
from litellm.types.router import GenericLiteLLMParams
from .transformation import OpenAIImageEditConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class DallE2ImageEditConfig(OpenAIImageEditConfig):
    """
    DALL-E-2 specific configuration for the image edit API.

    DALL-E-2 accepts exactly one input image (never an array) and uses the
    singular "image" multipart field name instead of "image[]".
    """

    def transform_image_edit_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles]:
        """
        Build the multipart payload for a DALL-E-2 image edit.

        Images and masks are sent as files; every other parameter travels as
        form data. Raises BadRequestError when more than one image is given.
        """
        params: Dict[str, Any] = {
            "model": model,
            **image_edit_optional_request_params,
        }
        if image is not None:
            params["image"] = image
        if prompt is not None:
            params["prompt"] = prompt
        typed_request = ImageEditRequestParams(**params)
        request_dict = cast(Dict, typed_request)

        # Split the request: image/mask go out as multipart files, the rest
        # as plain form data.
        raw_image = request_dict.get("image")
        raw_mask = request_dict.get("mask")
        form_data = {
            k: v for k, v in request_dict.items() if k not in ("image", "mask")
        }
        multipart_files: List[Tuple[str, Any]] = []

        if raw_image is not None:
            images = raw_image if isinstance(raw_image, list) else [raw_image]
            if len(images) > 1:
                raise litellm.BadRequestError(
                    message="DALL-E-2 only supports editing a single image. Please provide one image.",
                    model=model,
                    llm_provider="openai",
                )
            for single_image in images:
                if single_image is not None:
                    # DALL-E-2 expects the singular "image" field name.
                    self._add_image_to_files(
                        files_list=multipart_files,
                        image=single_image,
                        field_name="image",
                    )

        if raw_mask is not None:
            # A list-shaped mask collapses to its first element.
            if isinstance(raw_mask, list):
                raw_mask = raw_mask[0] if raw_mask else None
            if raw_mask is not None:
                mask_content_type: str = ImageEditRequestUtils.get_image_content_type(
                    raw_mask
                )
                if isinstance(raw_mask, BufferedReader):
                    multipart_files.append(
                        ("mask", (raw_mask.name, raw_mask, mask_content_type))
                    )
                else:
                    multipart_files.append(
                        ("mask", ("mask.png", raw_mask, mask_content_type))
                    )
        return form_data, multipart_files

View File

@@ -0,0 +1,202 @@
from io import BufferedReader
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
import httpx
from httpx._types import RequestFiles
import litellm
from litellm.images.utils import ImageEditRequestUtils
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.images.main import (
ImageEditOptionalRequestParams,
ImageEditRequestParams,
)
from litellm.types.llms.openai import FileTypes
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import ImageResponse
from ..common_utils import OpenAIError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class OpenAIImageEditConfig(BaseImageEditConfig):
    """
    Base configuration for OpenAI image edit API.
    Used for models like gpt-image-1 that support multiple images
    (images are sent under the repeated "image[]" multipart field).
    """
    def get_supported_openai_params(self, model: str) -> list:
        """
        All OpenAI Image Edits params are supported
        """
        return [
            "image",
            "prompt",
            "background",
            "input_fidelity",
            "mask",
            "model",
            "n",
            "quality",
            "response_format",
            "size",
            "user",
            "extra_headers",
            "extra_query",
            "extra_body",
            "timeout",
        ]
    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """No mapping applied since inputs are in OpenAI spec already"""
        return dict(image_edit_optional_params)
    def _add_image_to_files(
        self,
        files_list: List[Tuple[str, Any]],
        image: Any,
        field_name: str,
    ) -> None:
        """Add an image to the files list with appropriate content type"""
        # BufferedReader objects carry a real filename; anything else
        # (bytes, file-like without a name) is sent as a generic "image.png".
        image_content_type = ImageEditRequestUtils.get_image_content_type(image)
        if isinstance(image, BufferedReader):
            files_list.append((field_name, (image.name, image, image_content_type)))
        else:
            files_list.append((field_name, ("image.png", image, image_content_type)))
    def transform_image_edit_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles]:
        """
        Transform image edit request to OpenAI API format.
        Handles multipart/form-data for images. Uses "image[]" field name
        to support multiple images (e.g., for gpt-image-1).

        Returns:
            Tuple of (form data dict, multipart files list).
        """
        # Build request params, only including non-None values
        request_params = {
            "model": model,
            **image_edit_optional_request_params,
        }
        if image is not None:
            request_params["image"] = image
        if prompt is not None:
            request_params["prompt"] = prompt
        request = ImageEditRequestParams(**request_params)
        request_dict = cast(Dict, request)
        #########################################################
        # Separate images and masks as `files` and send other parameters as `data`
        #########################################################
        _image_list = request_dict.get("image")
        _mask = request_dict.get("mask")
        data_without_files = {
            k: v for k, v in request_dict.items() if k not in ["image", "mask"]
        }
        files_list: List[Tuple[str, Any]] = []
        # Handle image parameter — a single image is wrapped into a list so
        # both single- and multi-image inputs share one code path.
        if _image_list is not None:
            image_list = (
                [_image_list] if not isinstance(_image_list, list) else _image_list
            )
            for _image in image_list:
                if _image is not None:
                    self._add_image_to_files(
                        files_list=files_list,
                        image=_image,
                        field_name="image[]",
                    )
        # Handle mask parameter if provided
        if _mask is not None:
            # Handle case where mask can be a list (extract first mask)
            if isinstance(_mask, list):
                _mask = _mask[0] if _mask else None
            if _mask is not None:
                mask_content_type: str = ImageEditRequestUtils.get_image_content_type(
                    _mask
                )
                if isinstance(_mask, BufferedReader):
                    files_list.append(("mask", (_mask.name, _mask, mask_content_type)))
                else:
                    files_list.append(("mask", ("mask.png", _mask, mask_content_type)))
        return data_without_files, files_list
    def transform_image_edit_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ImageResponse:
        """No transform applied since outputs are in OpenAI spec already"""
        try:
            raw_response_json = raw_response.json()
        except Exception:
            # Non-JSON body (e.g. HTML error page) — surface it as an OpenAIError.
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        return ImageResponse(**raw_response_json)
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Attach the bearer token, resolving the key from args/env in priority order."""
        api_key = (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers
    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the endpoint for the OpenAI image edits API
        """
        api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_BASE_URL")
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        # Remove trailing slashes
        api_base = api_base.rstrip("/")
        return f"{api_base}/images/edits"

View File

@@ -0,0 +1,28 @@
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from .dall_e_2_transformation import DallE2ImageGenerationConfig
from .dall_e_3_transformation import DallE3ImageGenerationConfig
from .gpt_transformation import GPTImageGenerationConfig
from .guardrail_translation import (
OpenAIImageGenerationHandler,
guardrail_translation_mappings,
)
# Public API of this package; get_openai_image_generation_config selects the
# model-appropriate config at runtime.
__all__ = [
    "DallE2ImageGenerationConfig",
    "DallE3ImageGenerationConfig",
    "GPTImageGenerationConfig",
    "OpenAIImageGenerationHandler",
    "guardrail_translation_mappings",
]
def get_openai_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """Select the image-generation config for *model*; an empty model means dall-e-2."""
    if model == "" or model.startswith("dall-e-2"):
        return DallE2ImageGenerationConfig()
    if model.startswith("dall-e-3"):
        return DallE3ImageGenerationConfig()
    # gpt-image-1 and any other model use the GPT image config.
    return GPTImageGenerationConfig()

View File

@@ -0,0 +1,69 @@
"""
Cost calculator for OpenAI image generation models (gpt-image-1, gpt-image-1-mini)
These models use token-based pricing instead of pixel-based pricing like DALL-E.
"""
from typing import Optional
from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import ImageResponse, Usage
def cost_calculator(
    model: str,
    image_response: ImageResponse,
    custom_llm_provider: Optional[str] = None,
) -> float:
    """
    Compute the USD cost for token-priced OpenAI image models
    (gpt-image-1, gpt-image-1-mini).

    Args:
        model: The model name (e.g., "gpt-image-1", "gpt-image-1-mini")
        image_response: The ImageResponse containing usage data
        custom_llm_provider: Optional provider name (defaults to "openai")

    Returns:
        float: Total cost in USD; 0.0 when no usage data is available.
    """
    usage = getattr(image_response, "usage", None)
    if usage is None:
        verbose_logger.debug(
            f"No usage data available for {model}, cannot calculate token-based cost"
        )
        return 0.0
    if isinstance(usage, Usage) and usage.completion_tokens_details is not None:
        # Already chat-completion-shaped (set upstream in convert_to_image_response).
        normalized_usage = usage
    else:
        # ImageUsage mirrors ResponseAPIUsage, so reuse the Responses-API helper
        # to convert into chat-completion usage.
        from litellm.responses.utils import ResponseAPILoggingUtils

        normalized_usage = (
            ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(usage)
        )
    prompt_cost, completion_cost = generic_cost_per_token(
        model=model,
        usage=normalized_usage,
        custom_llm_provider=custom_llm_provider or "openai",
    )
    total_cost = prompt_cost + completion_cost
    verbose_logger.debug(
        f"OpenAI gpt-image cost calculation for {model}: "
        f"prompt_cost=${prompt_cost:.6f}, completion_cost=${completion_cost:.6f}, "
        f"total=${total_cost:.6f}"
    )
    return total_cost

View File

@@ -0,0 +1,87 @@
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
from litellm.types.utils import ImageResponse
from litellm.utils import convert_to_model_response_object
if TYPE_CHECKING:
from litellm.litellm_core_utils.logging import Logging as LiteLLMLoggingObj
class DallE2ImageGenerationConfig(BaseImageGenerationConfig):
    """
    Image generation config for OpenAI dall-e-2.
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        """Optional parameters dall-e-2 accepts."""
        return ["n", "response_format", "quality", "size", "user"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported params into optional_params; drop or reject the rest."""
        supported_params = self.get_supported_openai_params(model)
        for k, value in non_default_params.items():
            if k in optional_params:
                continue  # already set — do not overwrite
            if k in supported_params:
                optional_params[k] = value
            elif not drop_params:
                raise ValueError(
                    f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
                )
        return optional_params

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: "LiteLLMLoggingObj",
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """Parse the raw response, log the call, and echo request settings back."""
        payload = raw_response.json()
        ## LOGGING
        logging_obj.post_call(
            input=request_data.get("prompt", ""),
            api_key=api_key,
            additional_args={"complete_input_dict": request_data},
            original_response=payload,
        )
        image_response: ImageResponse = convert_to_model_response_object(  # type: ignore
            response_object=payload,
            model_response_object=model_response,
            response_type="image_generation",
        )
        # Attach the effective request settings (fallbacks when unset).
        image_response.size = optional_params.get("size", "1024x1024")
        image_response.quality = optional_params.get("quality", "standard")
        image_response.output_format = optional_params.get("output_format", "png")
        return image_response

View File

@@ -0,0 +1,87 @@
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
from litellm.types.utils import ImageResponse
from litellm.utils import convert_to_model_response_object
if TYPE_CHECKING:
from litellm.litellm_core_utils.logging import Logging as LiteLLMLoggingObj
class DallE3ImageGenerationConfig(BaseImageGenerationConfig):
    """
    OpenAI dall-e-3 image generation config
    """
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        # dall-e-3 additionally supports "style" compared to dall-e-2.
        return ["n", "response_format", "quality", "size", "user", "style"]
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported params into optional_params; drop or reject unsupported ones."""
        supported_params = self.get_supported_openai_params(model)
        for k in non_default_params.keys():
            if k not in optional_params.keys():
                if k in supported_params:
                    optional_params[k] = non_default_params[k]
                elif drop_params:
                    pass
                else:
                    raise ValueError(
                        f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
                    )
        return optional_params
    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: "LiteLLMLoggingObj",
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """Parse the raw response, log the call, and attach request metadata."""
        response = raw_response.json()
        stringified_response = response
        ## LOGGING
        logging_obj.post_call(
            input=request_data.get("prompt", ""),
            api_key=api_key,
            additional_args={"complete_input_dict": request_data},
            original_response=stringified_response,
        )
        image_response: ImageResponse = convert_to_model_response_object(  # type: ignore
            response_object=stringified_response,
            model_response_object=model_response,
            response_type="image_generation",
        )
        # set optional params (fallbacks apply only when the caller did not set them)
        image_response.size = optional_params.get(
            "size", "1024x1024"
        )  # fallback matches the API default size
        image_response.quality = optional_params.get(
            "quality", "hd"
        )  # NOTE(review): "hd" fallback is this config's choice — the OpenAI API default for dall-e-3 is "standard"; confirm intent
        image_response.output_format = optional_params.get(
            "output_format", "png"
        )  # dall-e-3 outputs png
        return image_response

View File

@@ -0,0 +1,96 @@
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
from litellm.types.utils import ImageResponse
from litellm.utils import convert_to_model_response_object
if TYPE_CHECKING:
from litellm.litellm_core_utils.logging import Logging as LiteLLMLoggingObj
class GPTImageGenerationConfig(BaseImageGenerationConfig):
    """
    OpenAI gpt-image-1 image generation config.

    Fix: the response transform previously read the output format from the
    unsupported "response_format" key (a copy-paste from the DALL-E configs'
    structure) — since "response_format" is not in this config's supported
    params it could never be populated, so a user-supplied "output_format"
    was silently ignored. It now reads "output_format".
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        """Optional parameters gpt-image-1 accepts (note: no "response_format")."""
        return [
            "background",
            "moderation",
            "n",
            "output_compression",
            "output_format",
            "quality",
            "size",
            "user",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported params into optional_params; drop or reject unsupported ones."""
        supported_params = self.get_supported_openai_params(model)
        for k in non_default_params.keys():
            if k not in optional_params.keys():
                if k in supported_params:
                    optional_params[k] = non_default_params[k]
                elif drop_params:
                    pass
                else:
                    raise ValueError(
                        f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
                    )
        return optional_params

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: "LiteLLMLoggingObj",
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """Parse the raw response, log the call, and attach request metadata."""
        response = raw_response.json()
        stringified_response = response
        ## LOGGING
        logging_obj.post_call(
            input=request_data.get("prompt", ""),
            api_key=api_key,
            additional_args={"complete_input_dict": request_data},
            original_response=stringified_response,
        )
        image_response: ImageResponse = convert_to_model_response_object(  # type: ignore
            response_object=stringified_response,
            model_response_object=model_response,
            response_type="image_generation",
        )
        # Attach the effective request settings (fallbacks when unset).
        image_response.size = optional_params.get(
            "size", "1024x1024"
        )  # API default size
        image_response.quality = optional_params.get(
            "quality", "high"
        )  # gpt-image-1 quality fallback
        image_response.output_format = optional_params.get(
            "output_format", "png"
        )  # fixed: was reading the unsupported "response_format" key
        return image_response

View File

@@ -0,0 +1,106 @@
# OpenAI Image Generation Guardrail Translation Handler
Handler that applies guardrails to requests sent to OpenAI's image generation endpoint.
## Overview
This handler processes image generation requests by:
1. Extracting the text prompt from the request
2. Applying guardrails to the prompt text
3. Updating the request with the guardrailed prompt
## Data Format
### Input Format
```json
{
"model": "dall-e-3",
"prompt": "A cute baby sea otter",
"n": 1,
"size": "1024x1024",
"quality": "standard"
}
```
### Output Format
```json
{
"created": 1589478378,
"data": [
{
"url": "https://...",
"revised_prompt": "A cute baby sea otter..."
}
]
}
```
## Usage
The handler is automatically discovered and applied when guardrails are used with the image generation endpoint.
### Example: Using Guardrails with Image Generation
```bash
curl -X POST 'http://localhost:4000/v1/images/generations' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "dall-e-3",
"prompt": "A cute baby sea otter wearing a hat",
"guardrails": ["content_moderation"],
"size": "1024x1024"
}'
```
The guardrail will be applied to the prompt text before the image generation request is sent to the provider.
### Example: PII Masking in Prompts
```bash
curl -X POST 'http://localhost:4000/v1/images/generations' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "dall-e-3",
"prompt": "Generate an image of John Doe at john@example.com",
"guardrails": ["mask_pii"],
"metadata": {
"guardrails": ["mask_pii"]
}
}'
```
## Implementation Details
### Input Processing
- **Field**: `prompt` (string)
- **Processing**: Applies guardrail to prompt text
- **Result**: Updated prompt in request
### Output Processing
- **Processing**: Not applicable (images don't contain text to guardrail)
- **Result**: Response returned unchanged
## Extension
Override these methods to customize behavior:
- `process_input_messages()`: Customize how the prompt is processed
- `process_output_response()`: Add custom processing for image metadata if needed
## Supported Call Types
- `CallTypes.image_generation` - Synchronous image generation
- `CallTypes.aimage_generation` - Asynchronous image generation
## Notes
- The handler only processes the `prompt` parameter
- Output processing is a no-op since images don't contain text
- Both sync and async call types use the same handler

View File

@@ -0,0 +1,13 @@
"""OpenAI Image Generation handler for Unified Guardrails."""
from litellm.llms.openai.image_generation.guardrail_translation.handler import (
OpenAIImageGenerationHandler,
)
from litellm.types.utils import CallTypes
# Route both the sync and async image-generation call types through the same
# handler — the prompt guardrailing logic is identical for both.
guardrail_translation_mappings = {
    CallTypes.image_generation: OpenAIImageGenerationHandler,
    CallTypes.aimage_generation: OpenAIImageGenerationHandler,
}
__all__ = ["guardrail_translation_mappings", "OpenAIImageGenerationHandler"]

View File

@@ -0,0 +1,110 @@
"""
OpenAI Image Generation Handler for Unified Guardrails
This module provides guardrail translation support for OpenAI's image generation endpoint.
The handler processes the 'prompt' parameter for guardrails.
"""
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import ImageResponse
class OpenAIImageGenerationHandler(BaseTranslation):
    """
    Applies unified guardrails to OpenAI image generation calls.

    Only the text ``prompt`` field is guardrailed on input; image bytes in
    the response carry no text, so output processing is a pass-through.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Run the guardrail over ``data['prompt']`` and write the result back.

        Args:
            data: Request data dictionary containing the 'prompt' parameter.
            guardrail_to_apply: The guardrail instance to apply.
            litellm_logging_obj: Optional logging object forwarded to the guardrail.

        Returns:
            The (possibly modified) request data.
        """
        prompt = data.get("prompt")
        if prompt is None:
            verbose_proxy_logger.debug(
                "OpenAI Image Generation: No prompt found in request data"
            )
            return data
        if not isinstance(prompt, str):
            # Only string prompts are guardrailed; anything else passes through.
            verbose_proxy_logger.debug(
                "OpenAI Image Generation: Unexpected prompt type: %s. Expected string.",
                type(prompt),
            )
            return data
        inputs = GenericGuardrailAPIInputs(texts=[prompt])
        model = data.get("model")
        if model:
            inputs["model"] = model
        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        guardrailed_texts = guardrailed_inputs.get("texts", [])
        data["prompt"] = guardrailed_texts[0] if guardrailed_texts else prompt
        verbose_proxy_logger.debug(
            "OpenAI Image Generation: Applied guardrail to prompt. "
            "Original length: %d, New length: %d",
            len(prompt),
            len(data["prompt"]),
        )
        return data

    async def process_output_response(
        self,
        response: "ImageResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Return the image response unchanged.

        Image responses contain no text to guardrail; override this method if
        custom metadata processing is ever needed.
        """
        verbose_proxy_logger.debug(
            "OpenAI Image Generation: Output processing not needed for image responses"
        )
        return response

View File

@@ -0,0 +1,244 @@
"""
OpenAI Image Variations Handler
"""
from typing import Callable, Optional
import httpx
from openai import AsyncOpenAI, OpenAI
import litellm
from litellm.types.utils import FileTypes, ImageResponse, LlmProviders
from litellm.utils import ProviderConfigManager
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ...custom_httpx.llm_http_handler import LiteLLMLoggingObj
from ..common_utils import OpenAIError
class OpenAIImageVariationsHandler:
def get_sync_client(
self,
client: Optional[OpenAI],
init_client_params: dict,
):
if client is None:
openai_client = OpenAI(
**init_client_params,
)
else:
openai_client = client
return openai_client
def get_async_client(
self, client: Optional[AsyncOpenAI], init_client_params: dict
) -> AsyncOpenAI:
if client is None:
openai_client = AsyncOpenAI(
**init_client_params,
)
else:
openai_client = client
return openai_client
async def async_image_variations(
self,
api_key: str,
api_base: str,
organization: Optional[str],
client: Optional[AsyncOpenAI],
data: dict,
headers: dict,
model: Optional[str],
timeout: Optional[float],
max_retries: int,
logging_obj: LiteLLMLoggingObj,
model_response: ImageResponse,
optional_params: dict,
litellm_params: dict,
image: FileTypes,
provider_config: BaseImageVariationConfig,
) -> ImageResponse:
try:
init_client_params = {
"api_key": api_key,
"base_url": api_base,
"http_client": litellm.client_session,
"timeout": timeout,
"max_retries": max_retries, # type: ignore
"organization": organization,
}
client = self.get_async_client(
client=client, init_client_params=init_client_params
)
raw_response = await client.images.with_raw_response.create_variation(**data) # type: ignore
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
api_key=api_key,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return provider_config.transform_response_image_variation(
model=model,
model_response=ImageResponse(**response_json),
raw_response=httpx.Response(
status_code=200,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
),
logging_obj=logging_obj,
request_data=data,
image=image,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=None,
api_key=api_key,
)
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
def image_variations(
self,
model_response: ImageResponse,
api_key: str,
api_base: str,
model: Optional[str],
image: FileTypes,
timeout: Optional[float],
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
litellm_params: dict,
print_verbose: Optional[Callable] = None,
logger_fn=None,
client=None,
organization: Optional[str] = None,
headers: Optional[dict] = None,
) -> ImageResponse:
try:
provider_config = ProviderConfigManager.get_provider_image_variation_config(
model=model or "", # openai defaults to dall-e-2
provider=LlmProviders.OPENAI,
)
if provider_config is None:
raise ValueError(
f"image variation provider not found: {custom_llm_provider}."
)
max_retries = optional_params.pop("max_retries", 2)
data = provider_config.transform_request_image_variation(
model=model,
image=image,
optional_params=optional_params,
headers=headers or {},
)
json_data = data.get("data")
if not json_data:
raise ValueError(
f"data field is required, for openai image variations. Got={data}"
)
## LOGGING
logging_obj.pre_call(
input="",
api_key=api_key,
additional_args={
"headers": headers,
"api_base": api_base,
"complete_input_dict": data,
},
)
if litellm_params.get("async_call", False):
return self.async_image_variations(
api_base=api_base,
data=json_data,
headers=headers or {},
model_response=model_response,
api_key=api_key,
logging_obj=logging_obj,
model=model,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
provider_config=provider_config,
image=image,
optional_params=optional_params,
litellm_params=litellm_params,
) # type: ignore
init_client_params = {
"api_key": api_key,
"base_url": api_base,
"http_client": litellm.client_session,
"timeout": timeout,
"max_retries": max_retries, # type: ignore
"organization": organization,
}
client = self.get_sync_client(
client=client, init_client_params=init_client_params
)
raw_response = client.images.with_raw_response.create_variation(**json_data) # type: ignore
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
api_key=api_key,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return provider_config.transform_response_image_variation(
model=model,
model_response=ImageResponse(**response_json),
raw_response=httpx.Response(
status_code=200,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
),
logging_obj=logging_obj,
request_data=json_data,
image=image,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=None,
api_key=api_key,
)
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)

View File

@@ -0,0 +1,82 @@
from typing import Any, List, Optional, Union
from aiohttp import ClientResponse
from httpx import Headers, Response
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.image_variations.transformation import LiteLLMLoggingObj
from litellm.types.llms.openai import OpenAIImageVariationOptionalParams
from litellm.types.utils import FileTypes, HttpHandlerRequestFields, ImageResponse
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ..common_utils import OpenAIError
class OpenAIImageVariationConfig(BaseImageVariationConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIImageVariationOptionalParams]:
return ["n", "size", "response_format", "user"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
optional_params.update(non_default_params)
return optional_params
def transform_request_image_variation(
self,
model: Optional[str],
image: FileTypes,
optional_params: dict,
headers: dict,
) -> HttpHandlerRequestFields:
return {
"data": {
"image": image,
**optional_params,
}
}
async def async_transform_response_image_variation(
self,
model: Optional[str],
raw_response: ClientResponse,
model_response: ImageResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
image: FileTypes,
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
) -> ImageResponse:
return model_response
def transform_response_image_variation(
self,
model: Optional[str],
raw_response: Response,
model_response: ImageResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
image: FileTypes,
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
) -> ImageResponse:
return model_response
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return OpenAIError(
status_code=status_code,
message=error_message,
headers=headers,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,167 @@
"""
This file contains the calling OpenAI's `/v1/realtime` endpoint.
This requires websockets, and is currently only supported on LiteLLM Proxy.
"""
from typing import Any, Optional, cast
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES
from litellm.types.realtime import RealtimeQueryParams
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
from ....llms.custom_httpx.http_handler import get_shared_realtime_ssl_context
from ..openai import OpenAIChatCompletion
class OpenAIRealtime(OpenAIChatCompletion):
"""
Base handler for OpenAI-compatible realtime WebSocket connections.
Subclasses can override template methods to customize:
- _get_default_api_base(): Default API base URL
- _get_additional_headers(): Extra headers beyond Authorization
- _get_ssl_config(): SSL configuration for WebSocket connection
"""
def _get_default_api_base(self) -> str:
"""
Get the default API base URL for this provider.
Override this in subclasses to set provider-specific defaults.
"""
return "https://api.openai.com/"
def _get_additional_headers(self, api_key: str) -> dict:
"""
Get additional headers beyond Authorization.
Override this in subclasses to customize headers (e.g., remove OpenAI-Beta).
Args:
api_key: API key for authentication
Returns:
Dictionary of additional headers
"""
return {
"Authorization": f"Bearer {api_key}",
"OpenAI-Beta": "realtime=v1",
}
def _get_ssl_config(self, url: str) -> Any:
"""
Get SSL configuration for WebSocket connection.
Override this in subclasses to customize SSL behavior.
Args:
url: WebSocket URL (ws:// or wss://)
Returns:
SSL configuration (None, True, or SSLContext)
"""
if url.startswith("ws://"):
return None
# Use the shared SSL context which respects custom CA certs and SSL settings
ssl_config = get_shared_realtime_ssl_context()
# If ssl_config is False (ssl_verify=False), websockets library needs True instead
# to establish connection without verification (False would fail)
if ssl_config is False:
return True
return ssl_config
def _construct_url(self, api_base: str, query_params: RealtimeQueryParams) -> str:
"""
Construct the backend websocket URL with all query parameters (including 'model').
"""
from httpx import URL
api_base = api_base.replace("https://", "wss://")
api_base = api_base.replace("http://", "ws://")
url = URL(api_base)
# Set the correct path
url = url.copy_with(path="/v1/realtime")
# Include all query parameters including 'model'
if query_params:
url = url.copy_with(params=query_params)
return str(url)
async def async_realtime(
self,
model: str,
websocket: Any,
logging_obj: LiteLLMLogging,
api_base: Optional[str] = None,
api_key: Optional[str] = None,
client: Optional[Any] = None,
timeout: Optional[float] = None,
query_params: Optional[RealtimeQueryParams] = None,
user_api_key_dict: Optional[Any] = None,
litellm_metadata: Optional[dict] = None,
**kwargs: Any,
):
import websockets
from websockets.asyncio.client import ClientConnection
if api_base is None:
api_base = self._get_default_api_base()
if api_key is None:
raise ValueError("api_key is required for OpenAI realtime calls")
# Use all query params if provided, else fallback to just model
if query_params is None:
query_params = {"model": model}
url = self._construct_url(api_base, query_params)
try:
# Get provider-specific SSL configuration
ssl_config = self._get_ssl_config(url)
# Get provider-specific headers
headers = self._get_additional_headers(api_key)
# Log a masked request preview consistent with other endpoints.
logging_obj.pre_call(
input=None,
api_key=api_key,
additional_args={
"api_base": url,
"headers": headers,
"complete_input_dict": {"query_params": query_params},
},
)
async with websockets.connect( # type: ignore
url,
additional_headers=headers, # type: ignore
max_size=REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
ssl=ssl_config,
) as backend_ws:
realtime_streaming = RealTimeStreaming(
websocket,
cast(ClientConnection, backend_ws),
logging_obj,
user_api_key_dict=user_api_key_dict,
request_data={"litellm_metadata": litellm_metadata or {}},
)
await realtime_streaming.bidirectional_forward()
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
await websocket.close(code=e.status_code, reason=str(e))
except Exception as e:
try:
await websocket.close(
code=1011, reason=f"Internal server error: {str(e)}"
)
except RuntimeError as close_error:
if "already completed" in str(close_error) or "websocket.close" in str(
close_error
):
# The WebSocket is already closed or the response is completed, so we can ignore this error
pass
else:
# If it's a different RuntimeError, we might want to log it or handle it differently
raise Exception(
f"Unexpected error while closing WebSocket: {close_error}"
)

View File

@@ -0,0 +1,54 @@
"""OpenAI realtime HTTP transformation config (client_secrets + realtime_calls)."""
from typing import Optional
import litellm
from litellm.llms.base_llm.realtime.http_transformation import BaseRealtimeHTTPConfig
from litellm.secret_managers.main import get_secret_str
class OpenAIRealtimeHTTPConfig(BaseRealtimeHTTPConfig):
def get_api_base(self, api_base: Optional[str], **kwargs) -> str:
return (
api_base
or litellm.api_base
or get_secret_str("OPENAI_API_BASE")
or "https://api.openai.com"
)
def get_api_key(self, api_key: Optional[str], **kwargs) -> str:
return (
api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
or ""
)
def get_complete_url(
self, api_base: Optional[str], model: str, api_version: Optional[str] = None
) -> str:
base = self.get_api_base(api_base).rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
return f"{base}/v1/realtime/client_secrets"
def get_realtime_calls_url(
self, api_base: Optional[str], model: str, api_version: Optional[str] = None
) -> str:
base = self.get_api_base(api_base).rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
return f"{base}/v1/realtime/calls"
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
return {
**headers,
"Authorization": f"Bearer {api_key or ''}",
"Content-Type": "application/json",
}

View File

@@ -0,0 +1,19 @@
"""
OpenAI Responses API token counting implementation.
"""
from litellm.llms.openai.responses.count_tokens.handler import (
OpenAICountTokensHandler,
)
from litellm.llms.openai.responses.count_tokens.token_counter import (
OpenAITokenCounter,
)
from litellm.llms.openai.responses.count_tokens.transformation import (
OpenAICountTokensConfig,
)
__all__ = [
"OpenAICountTokensHandler",
"OpenAICountTokensConfig",
"OpenAITokenCounter",
]

View File

@@ -0,0 +1,107 @@
"""
OpenAI Responses API token counting handler.
Uses httpx for HTTP requests to OpenAI's /v1/responses/input_tokens endpoint.
"""
import json
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.llms.openai.common_utils import OpenAIError
from litellm.llms.openai.responses.count_tokens.transformation import (
OpenAICountTokensConfig,
)
class OpenAICountTokensHandler(OpenAICountTokensConfig):
"""
Handler for OpenAI Responses API token counting requests.
"""
async def handle_count_tokens_request(
self,
model: str,
input: Union[str, List[Any]],
api_key: str,
api_base: Optional[str] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
instructions: Optional[str] = None,
) -> Dict[str, Any]:
"""
Handle a token counting request to OpenAI's Responses API.
Returns:
Dictionary containing {"input_tokens": <number>}
Raises:
OpenAIError: If the API request fails
"""
try:
self.validate_request(model, input)
verbose_logger.debug(
f"Processing OpenAI CountTokens request for model: {model}"
)
request_body = self.transform_request_to_count_tokens(
model=model,
input=input,
tools=tools,
instructions=instructions,
)
endpoint_url = self.get_openai_count_tokens_endpoint(api_base)
verbose_logger.debug(f"Making request to: {endpoint_url}")
headers = self.get_required_headers(api_key)
async_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.OPENAI
)
request_timeout = (
timeout if timeout is not None else litellm.request_timeout
)
response = await async_client.post(
endpoint_url,
headers=headers,
json=request_body,
timeout=request_timeout,
)
verbose_logger.debug(f"Response status: {response.status_code}")
if response.status_code != 200:
error_text = response.text
verbose_logger.error(f"OpenAI API error: {error_text}")
raise OpenAIError(
status_code=response.status_code,
message=error_text,
)
openai_response = response.json()
verbose_logger.debug(f"OpenAI response: {openai_response}")
return openai_response
except OpenAIError:
raise
except httpx.HTTPStatusError as e:
verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
raise OpenAIError(
status_code=e.response.status_code,
message=e.response.text,
)
except (httpx.RequestError, json.JSONDecodeError, ValueError) as e:
verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
raise OpenAIError(
status_code=500,
message=f"CountTokens processing error: {str(e)}",
)

View File

@@ -0,0 +1,118 @@
"""
OpenAI Token Counter implementation using the Responses API /input_tokens endpoint.
"""
import os
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_logger
from litellm.llms.base_llm.base_utils import BaseTokenCounter
from litellm.llms.openai.common_utils import OpenAIError
from litellm.llms.openai.responses.count_tokens.handler import (
OpenAICountTokensHandler,
)
from litellm.llms.openai.responses.count_tokens.transformation import (
OpenAICountTokensConfig,
)
from litellm.types.utils import LlmProviders, TokenCountResponse
# Global handler instance - reuse across all token counting requests
openai_count_tokens_handler = OpenAICountTokensHandler()
class OpenAITokenCounter(BaseTokenCounter):
"""Token counter implementation for OpenAI provider using the Responses API."""
def should_use_token_counting_api(
self,
custom_llm_provider: Optional[str] = None,
) -> bool:
return custom_llm_provider == LlmProviders.OPENAI.value
async def count_tokens(
self,
model_to_use: str,
messages: Optional[List[Dict[str, Any]]],
contents: Optional[List[Dict[str, Any]]],
deployment: Optional[Dict[str, Any]] = None,
request_model: str = "",
tools: Optional[List[Dict[str, Any]]] = None,
system: Optional[Any] = None,
) -> Optional[TokenCountResponse]:
"""
Count tokens using OpenAI's Responses API /input_tokens endpoint.
"""
if not messages:
return None
deployment = deployment or {}
litellm_params = deployment.get("litellm_params", {})
# Get OpenAI API key from deployment config or environment
api_key = litellm_params.get("api_key")
if not api_key:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
verbose_logger.warning("No OpenAI API key found for token counting")
return None
api_base = litellm_params.get("api_base")
# Convert chat messages to Responses API input format
input_items, instructions = OpenAICountTokensConfig.messages_to_responses_input(
messages
)
# Use system param if instructions not extracted from messages
if instructions is None and system is not None:
instructions = system if isinstance(system, str) else str(system)
# If no input items were produced (e.g., system-only messages), fall back to local counting
if not input_items:
return None
try:
result = await openai_count_tokens_handler.handle_count_tokens_request(
model=model_to_use,
input=input_items if input_items is not None else [],
api_key=api_key,
api_base=api_base,
tools=tools,
instructions=instructions,
)
if result is not None:
return TokenCountResponse(
total_tokens=result.get("input_tokens", 0),
request_model=request_model,
model_used=model_to_use,
tokenizer_type="openai_api",
original_response=result,
)
except OpenAIError as e:
verbose_logger.warning(
f"OpenAI CountTokens API error: status={e.status_code}, message={e.message}"
)
return TokenCountResponse(
total_tokens=0,
request_model=request_model,
model_used=model_to_use,
tokenizer_type="openai_api",
error=True,
error_message=e.message,
status_code=e.status_code,
)
except Exception as e:
verbose_logger.warning(f"Error calling OpenAI CountTokens API: {e}")
return TokenCountResponse(
total_tokens=0,
request_model=request_model,
model_used=model_to_use,
tokenizer_type="openai_api",
error=True,
error_message=str(e),
status_code=500,
)
return None

View File

@@ -0,0 +1,160 @@
"""
OpenAI Responses API token counting transformation logic.
This module handles the transformation of requests to OpenAI's /v1/responses/input_tokens endpoint.
"""
from typing import Any, Dict, List, Optional, Union
class OpenAICountTokensConfig:
"""
Configuration and transformation logic for OpenAI Responses API token counting.
OpenAI Responses API Token Counting Specification:
- Endpoint: POST https://api.openai.com/v1/responses/input_tokens
- Response: {"input_tokens": <number>}
"""
def get_openai_count_tokens_endpoint(self, api_base: Optional[str] = None) -> str:
base = api_base or "https://api.openai.com/v1"
base = base.rstrip("/")
return f"{base}/responses/input_tokens"
def transform_request_to_count_tokens(
self,
model: str,
input: Union[str, List[Any]],
tools: Optional[List[Dict[str, Any]]] = None,
instructions: Optional[str] = None,
) -> Dict[str, Any]:
"""
Transform request to OpenAI Responses API token counting format.
The Responses API uses `input` (not `messages`) and `instructions` (not `system`).
"""
request: Dict[str, Any] = {
"model": model,
"input": input,
}
if instructions is not None:
request["instructions"] = instructions
if tools is not None:
request["tools"] = self._transform_tools_for_responses_api(tools)
return request
def get_required_headers(self, api_key: str) -> Dict[str, str]:
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
def validate_request(self, model: str, input: Union[str, List[Any]]) -> None:
if not model:
raise ValueError("model parameter is required")
if not input:
raise ValueError("input parameter is required")
@staticmethod
def _transform_tools_for_responses_api(
tools: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""
Transform OpenAI chat tools format to Responses API tools format.
Chat format: {"type": "function", "function": {"name": "...", "parameters": {...}}}
Responses format: {"type": "function", "name": "...", "parameters": {...}}
"""
transformed = []
for tool in tools:
if tool.get("type") == "function" and "function" in tool:
func = tool["function"]
item: Dict[str, Any] = {
"type": "function",
"name": func.get("name", ""),
"description": func.get("description", ""),
"parameters": func.get("parameters", {}),
}
if "strict" in func:
item["strict"] = func["strict"]
transformed.append(item)
else:
# Pass through non-function tools (e.g., web_search, file_search)
transformed.append(tool)
return transformed
@staticmethod
def messages_to_responses_input(
messages: List[Dict[str, Any]],
) -> tuple:
"""
Convert standard chat messages format to OpenAI Responses API input format.
Returns:
(input_items, instructions) tuple where instructions is extracted
from system/developer messages.
"""
input_items: List[Dict[str, Any]] = []
instructions_parts: List[str] = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content") or ""
if role in ("system", "developer"):
# Extract system/developer messages as instructions
if isinstance(content, str):
instructions_parts.append(content)
elif isinstance(content, list):
# Handle content blocks - extract text
text_parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
elif isinstance(block, str):
text_parts.append(block)
instructions_parts.append("\n".join(text_parts))
elif role == "user":
if isinstance(content, list):
# Extract text from content blocks for Responses API
text_parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
elif isinstance(block, str):
text_parts.append(block)
content = "\n".join(text_parts)
input_items.append({"role": "user", "content": content})
elif role == "assistant":
# Map tool_calls to Responses API function_call items
tool_calls = msg.get("tool_calls")
if content:
input_items.append({"role": "assistant", "content": content})
if tool_calls:
for tc in tool_calls:
func = tc.get("function", {})
input_items.append(
{
"type": "function_call",
"call_id": tc.get("id", ""),
"name": func.get("name", ""),
"arguments": func.get("arguments", ""),
}
)
elif not content:
input_items.append({"role": "assistant", "content": content})
elif role == "tool":
input_items.append(
{
"type": "function_call_output",
"call_id": msg.get("tool_call_id", ""),
"output": content if isinstance(content, str) else str(content),
}
)
instructions = "\n".join(instructions_parts) if instructions_parts else None
return input_items, instructions

View File

@@ -0,0 +1,119 @@
# OpenAI Responses API Guardrail Translation Handler
This module provides guardrail translation support for the OpenAI Responses API format.
## Overview
The `OpenAIResponsesHandler` class handles the translation of guardrail operations for both input and output of the Responses API. It follows the same pattern as the Chat Completions handler but is adapted for the Responses API's specific data structures.
## Responses API Format
### Input Format
The Responses API accepts input in two formats:
1. **String input**: Simple text string
```python
{"input": "Hello world", "model": "gpt-4"}
```
2. **List input**: Array of message objects (ResponseInputParam)
```python
{
"input": [
{
"role": "user",
"content": "Hello", # Can be string or list of content items
"type": "message"
}
],
"model": "gpt-4"
}
```
### Output Format
The Responses API returns a `ResponsesAPIResponse` object with:
```python
{
"id": "resp_123",
"output": [
{
"type": "message",
"id": "msg_123",
"status": "completed",
"role": "assistant",
"content": [
{
"type": "output_text",
"text": "Assistant response",
"annotations": []
}
]
}
]
}
```
## Usage
The handler is automatically discovered and registered for `CallTypes.responses` and `CallTypes.aresponses`.
### Example
```python
from litellm.llms import get_guardrail_translation_mapping
from litellm.types.utils import CallTypes
# Get the handler
handler_class = get_guardrail_translation_mapping(CallTypes.responses)
handler = handler_class()
# Process input
data = {"input": "User message", "model": "gpt-4"}
processed_data = await handler.process_input_messages(data, guardrail_instance)
# Process output
response = await litellm.aresponses(**processed_data)
processed_response = await handler.process_output_response(response, guardrail_instance)
```
## Key Methods
### `process_input_messages(data, guardrail_to_apply)`
Processes input data by:
1. Handling both string and list input formats
2. Extracting text content from messages
3. Applying guardrails to text content in parallel
4. Mapping guardrail responses back to the original structure
### `process_output_response(response, guardrail_to_apply)`
Processes output response by:
1. Extracting text from output items' content
2. Applying guardrails to all text content in parallel
3. Replacing original text with guardrailed versions
## Extending the Handler
The handler can be customized by overriding these methods:
- `_extract_input_text_and_create_tasks()`: Customize input text extraction logic
- `_apply_guardrail_responses_to_input()`: Customize how guardrail responses are applied to input
- `_extract_output_text_and_create_tasks()`: Customize output text extraction logic
- `_apply_guardrail_responses_to_output()`: Customize how guardrail responses are applied to output
- `_has_text_content()`: Customize text content detection
## Testing
Comprehensive tests are available in `tests/llm_translation/test_openai_responses_guardrail_handler.py`:
```bash
pytest tests/llm_translation/test_openai_responses_guardrail_handler.py -v
```
## Implementation Details
- **Parallel Processing**: All text content is processed in parallel using `asyncio.gather()`
- **Mapping Tracking**: Uses tuples to track the location of each text segment for accurate replacement
- **Type Safety**: Handles both Pydantic objects and dict representations
- **Multimodal Support**: Properly handles mixed content with text and other media types

View File

@@ -0,0 +1,12 @@
"""OpenAI Responses API handler for Unified Guardrails."""
from litellm.llms.openai.responses.guardrail_translation.handler import (
OpenAIResponsesHandler,
)
from litellm.types.utils import CallTypes
guardrail_translation_mappings = {
CallTypes.responses: OpenAIResponsesHandler,
CallTypes.aresponses: OpenAIResponsesHandler,
}
__all__ = ["guardrail_translation_mappings"]

View File

@@ -0,0 +1,760 @@
"""
OpenAI Responses API Handler for Unified Guardrails
This module provides a class-based handler for OpenAI Responses API format.
The class methods can be overridden for custom behavior.
Pattern Overview:
-----------------
1. Extract text content from input/output (both string and list formats)
2. Create async tasks to apply guardrails to each text segment
3. Track mappings to know where each response belongs
4. Apply guardrail responses back to the original structure
Responses API Format:
---------------------
Input: Union[str, List[Dict]] where each dict has:
- role: str
- content: Union[str, List[Dict]] (can have text items)
- type: str (e.g., "message")
Output: response.output is List[GenericResponseOutputItem] where each has:
- type: str (e.g., "message")
- id: str
- status: str
- role: str
- content: List[OutputText] where OutputText has:
- type: str (e.g., "output_text")
- text: str
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from pydantic import BaseModel
from litellm._logging import verbose_proxy_logger
from litellm.completion_extras.litellm_responses_transformation.transformation import (
LiteLLMResponsesTransformationHandler,
OpenAiResponsesToChatCompletionStreamIterator,
)
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from litellm.types.llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionToolParam,
)
from litellm.types.responses.main import (
GenericResponseOutputItem,
OutputFunctionToolCall,
OutputText,
)
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.llms.openai import ResponseInputParam
from litellm.types.utils import ResponsesAPIResponse
class OpenAIResponsesHandler(BaseTranslation):
    """
    Handler for processing OpenAI Responses API with guardrails.
    This class provides methods to:
    1. Process input (pre-call hook)
    2. Process output response (post-call hook)
    Methods can be overridden to customize behavior for different message formats.
    """
    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input by applying guardrails to text content.
        Handles both string input and list of message objects.
        """
        input_data: Optional[Union[str, "ResponseInputParam"]] = data.get("input")
        tools_to_check: List[ChatCompletionToolParam] = []
        if input_data is None:
            return data
        # Chat-completion-shaped view of the input, passed alongside the raw
        # texts so guardrails that need role/structure can use it.
        structured_messages = (
            LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
                input=input_data,
                responses_api_request=data,
            )
        )
        # Handle simple string input
        if isinstance(input_data, str):
            inputs = GenericGuardrailAPIInputs(texts=[input_data])
            original_tools: List[Dict[str, Any]] = []
            # Extract and transform tools if present
            if "tools" in data and data["tools"]:
                original_tools = list(data["tools"])
                self._extract_and_transform_tools(data["tools"], tools_to_check)
            if tools_to_check:
                inputs["tools"] = tools_to_check
            if structured_messages:
                inputs["structured_messages"] = structured_messages  # type: ignore
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            # Replace the request input with the (possibly modified) guardrailed text.
            data["input"] = guardrailed_texts[0] if guardrailed_texts else input_data
            self._apply_guardrailed_tools_to_data(
                data, original_tools, guardrailed_inputs.get("tools")
            )
            verbose_proxy_logger.debug("OpenAI Responses API: Processed string input")
            return data
        # Handle list input (ResponseInputParam)
        if not isinstance(input_data, list):
            return data
        texts_to_check: List[str] = []
        images_to_check: List[str] = []
        # task_mappings[i] = (message index, content index or None when the
        # message content was a plain string) for texts_to_check[i].
        task_mappings: List[Tuple[int, Optional[int]]] = []
        original_tools_list: List[Dict[str, Any]] = list(data.get("tools") or [])
        # Step 1: Extract all text content, images, and tools
        for msg_idx, message in enumerate(input_data):
            self._extract_input_text_and_images(
                message=message,
                msg_idx=msg_idx,
                texts_to_check=texts_to_check,
                images_to_check=images_to_check,
                task_mappings=task_mappings,
            )
        # Extract and transform tools if present
        if "tools" in data and data["tools"]:
            self._extract_and_transform_tools(data["tools"], tools_to_check)
        # Step 2: Apply guardrail to all texts in batch
        if texts_to_check:
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            if images_to_check:
                inputs["images"] = images_to_check
            if tools_to_check:
                inputs["tools"] = tools_to_check
            if structured_messages:
                inputs["structured_messages"] = structured_messages  # type: ignore
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            self._apply_guardrailed_tools_to_data(
                data,
                original_tools_list,
                guardrailed_inputs.get("tools"),
            )
            # Step 3: Map guardrail responses back to original input structure
            await self._apply_guardrail_responses_to_input(
                messages=input_data,
                responses=guardrailed_texts,
                task_mappings=task_mappings,
            )
        verbose_proxy_logger.debug(
            "OpenAI Responses API: Processed input messages: %s", input_data
        )
        return data
    def extract_request_tool_names(self, data: dict) -> List[str]:
        """Extract tool names from Responses API request (tools[].name for function, tools[].server_label for mcp)."""
        names: List[str] = []
        for tool in data.get("tools") or []:
            if not isinstance(tool, dict):
                continue
            if tool.get("type") == "function" and tool.get("name"):
                names.append(str(tool["name"]))
            elif tool.get("type") == "mcp" and tool.get("server_label"):
                names.append(str(tool["server_label"]))
        return names
    def _extract_and_transform_tools(
        self,
        tools: List[Dict[str, Any]],
        tools_to_check: List[ChatCompletionToolParam],
    ) -> None:
        """
        Extract and transform tools from Responses API format to Chat Completion format.
        Uses the LiteLLM transformation function to convert Responses API tools
        to Chat Completion tools that can be passed to guardrails.

        Appends the transformed tools to ``tools_to_check`` in place.
        """
        if tools is not None and isinstance(tools, list):
            # Transform Responses API tools to Chat Completion tools
            (
                transformed_tools,
                _,
            ) = LiteLLMCompletionResponsesConfig.transform_responses_api_tools_to_chat_completion_tools(
                tools  # type: ignore
            )
            tools_to_check.extend(
                cast(List[ChatCompletionToolParam], transformed_tools)
            )
    def _remap_tools_to_responses_api_format(
        self, guardrailed_tools: List[Any]
    ) -> List[Dict[str, Any]]:
        """
        Remap guardrail-returned tools (Chat Completion format) back to
        Responses API request tool format.
        """
        return LiteLLMCompletionResponsesConfig.transform_chat_completion_tool_params_to_responses_api_tools(
            guardrailed_tools  # type: ignore
        )
    def _merge_tools_after_guardrail(
        self,
        original_tools: List[Dict[str, Any]],
        remapped: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Merge remapped guardrailed tools with original tools that were not sent
        to the guardrail (e.g. web_search, web_search_preview), preserving order.
        """
        if not original_tools:
            return remapped
        result: List[Dict[str, Any]] = []
        j = 0
        for tool in original_tools:
            if isinstance(tool, dict) and tool.get("type") in (
                "web_search",
                "web_search_preview",
            ):
                # web_search tools were never sent to the guardrail; keep them verbatim.
                result.append(tool)
            else:
                # Consume the next guardrailed tool in order for every other slot.
                if j < len(remapped):
                    result.append(remapped[j])
                    j += 1
        return result
    def _apply_guardrailed_tools_to_data(
        self,
        data: dict,
        original_tools: List[Dict[str, Any]],
        guardrailed_tools: Optional[List[Any]],
    ) -> None:
        """Remap guardrailed tools to Responses API format and merge with original, then set data['tools']."""
        # guardrailed_tools is None when the guardrail did not return tools; in
        # that case data["tools"] is left untouched.
        if guardrailed_tools is not None:
            remapped = self._remap_tools_to_responses_api_format(guardrailed_tools)
            data["tools"] = self._merge_tools_after_guardrail(original_tools, remapped)
    def _extract_input_text_and_images(
        self,
        message: Any,  # Can be Dict[str, Any] or ResponseInputParam
        msg_idx: int,
        texts_to_check: List[str],
        images_to_check: List[str],
        task_mappings: List[Tuple[int, Optional[int]]],
    ) -> None:
        """
        Extract text content and images from an input message.
        Override this method to customize text/image extraction logic.

        Appends to texts_to_check / images_to_check / task_mappings in place;
        task_mappings records (msg_idx, content_idx or None) per extracted text.
        """
        content = message.get("content", None)
        if content is None:
            return
        if isinstance(content, str):
            # Simple string content
            texts_to_check.append(content)
            task_mappings.append((msg_idx, None))
        elif isinstance(content, list):
            # List content (e.g., multimodal with text and images)
            for content_idx, content_item in enumerate(content):
                if isinstance(content_item, dict):
                    # Extract text
                    text_str = content_item.get("text", None)
                    if text_str is not None:
                        texts_to_check.append(text_str)
                        task_mappings.append((msg_idx, int(content_idx)))
                    # Extract images
                    if content_item.get("type") == "image_url":
                        image_url = content_item.get("image_url", {})
                        if isinstance(image_url, dict):
                            url = image_url.get("url")
                            if url:
                                images_to_check.append(url)
    async def _apply_guardrail_responses_to_input(
        self,
        messages: Any,  # Can be List[Dict[str, Any]] or ResponseInputParam
        responses: List[str],
        task_mappings: List[Tuple[int, Optional[int]]],
    ) -> None:
        """
        Apply guardrail responses back to input messages.
        Override this method to customize how responses are applied.

        Mutates ``messages`` in place using the (msg_idx, content_idx) mappings
        recorded during extraction.
        """
        for task_idx, guardrail_response in enumerate(responses):
            mapping = task_mappings[task_idx]
            msg_idx = cast(int, mapping[0])
            content_idx_optional = cast(Optional[int], mapping[1])
            content = messages[msg_idx].get("content", None)
            if content is None:
                continue
            if isinstance(content, str) and content_idx_optional is None:
                # Replace string content with guardrail response
                messages[msg_idx]["content"] = guardrail_response
            elif isinstance(content, list) and content_idx_optional is not None:
                # Replace specific text item in list content
                if isinstance(messages[msg_idx]["content"][content_idx_optional], dict):
                    messages[msg_idx]["content"][content_idx_optional][
                        "text"
                    ] = guardrail_response
    async def process_output_response(
        self,
        response: "ResponsesAPIResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response by applying guardrails to text content and tool calls.
        Args:
            response: LiteLLM ResponsesAPIResponse object
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata to pass to guardrails
        Returns:
            Modified response with guardrail applied to content
        Response Format Support:
            - response.output is a list of output items
            - Each output item can be:
                * GenericResponseOutputItem with a content list of OutputText objects
                * ResponseFunctionToolCall with tool call data
            - Each OutputText object has a text field
        """
        texts_to_check: List[str] = []
        images_to_check: List[str] = []
        tool_calls_to_check: List[ChatCompletionToolCallChunk] = []
        task_mappings: List[Tuple[int, int]] = []
        # Track (output_item_index, content_index) for each text
        # Handle both dict and Pydantic object responses
        if isinstance(response, dict):
            response_output = response.get("output", [])
        elif hasattr(response, "output"):
            response_output = response.output or []
        else:
            verbose_proxy_logger.debug(
                "OpenAI Responses API: No output found in response"
            )
            return response
        if not response_output:
            verbose_proxy_logger.debug("OpenAI Responses API: Empty output in response")
            return response
        # Step 1: Extract all text content and tool calls from response output
        for output_idx, output_item in enumerate(response_output):
            self._extract_output_text_and_images(
                output_item=output_item,
                output_idx=output_idx,
                texts_to_check=texts_to_check,
                images_to_check=images_to_check,
                task_mappings=task_mappings,
                tool_calls_to_check=tool_calls_to_check,
            )
        # Step 2: Apply guardrail to all texts in batch
        if texts_to_check or tool_calls_to_check:
            # Create a request_data dict with response info and user API key metadata
            request_data: dict = {"response": response}
            # Add user API key metadata with prefixed keys
            user_metadata = self.transform_user_api_key_dict_to_metadata(
                user_api_key_dict
            )
            if user_metadata:
                request_data["litellm_metadata"] = user_metadata
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            if images_to_check:
                inputs["images"] = images_to_check
            if tool_calls_to_check:
                inputs["tool_calls"] = tool_calls_to_check
            # Include model information from the response if available
            response_model = None
            if isinstance(response, dict):
                response_model = response.get("model")
            elif hasattr(response, "model"):
                response_model = getattr(response, "model", None)
            if response_model:
                inputs["model"] = response_model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=request_data,
                input_type="response",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            # Step 3: Map guardrail responses back to original response structure
            await self._apply_guardrail_responses_to_output(
                response=response,
                responses=guardrailed_texts,
                task_mappings=task_mappings,
            )
        verbose_proxy_logger.debug(
            "OpenAI Responses API: Processed output response: %s", response
        )
        return response
    async def process_output_streaming_response(
        self,
        responses_so_far: List[Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> List[Any]:
        """
        Process output streaming response by applying guardrails to text content.

        NOTE(review): in every branch below the guardrail result
        (``_guardrailed_inputs``) is discarded and ``responses_so_far`` is
        returned unmodified — streaming guardrails can only block (by raising),
        not rewrite content. Confirm this is intentional.
        """
        final_chunk = responses_so_far[-1]
        if final_chunk.get("type") == "response.output_item.done":
            # convert openai response to model response
            model_response_stream = OpenAiResponsesToChatCompletionStreamIterator.translate_responses_chunk_to_openai_stream(
                final_chunk
            )
            tool_calls = model_response_stream.choices[0].delta.tool_calls
            if tool_calls:
                inputs = GenericGuardrailAPIInputs()
                inputs["tool_calls"] = cast(
                    List[ChatCompletionToolCallChunk], tool_calls
                )
                # Include model information if available
                if (
                    hasattr(model_response_stream, "model")
                    and model_response_stream.model
                ):
                    inputs["model"] = model_response_stream.model
                _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                    inputs=inputs,
                    request_data={},
                    input_type="response",
                    logging_obj=litellm_logging_obj,
                )
            return responses_so_far
        elif final_chunk.get("type") == "response.completed":
            # convert openai response to model response
            outputs = final_chunk.get("response", {}).get("output", [])
            model_response_choices = LiteLLMResponsesTransformationHandler._convert_response_output_to_choices(
                output_items=outputs,
                handle_raw_dict_callback=None,
            )
            if model_response_choices:
                tool_calls = model_response_choices[0].message.tool_calls
                text = model_response_choices[0].message.content
                guardrail_inputs = GenericGuardrailAPIInputs()
                if text:
                    guardrail_inputs["texts"] = [text]
                if tool_calls:
                    guardrail_inputs["tool_calls"] = cast(
                        List[ChatCompletionToolCallChunk], tool_calls
                    )
                # Include model information from the response if available
                response_model = final_chunk.get("response", {}).get("model")
                if response_model:
                    guardrail_inputs["model"] = response_model
                if tool_calls or text:
                    _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                        inputs=guardrail_inputs,
                        request_data={},
                        input_type="response",
                        logging_obj=litellm_logging_obj,
                    )
                return responses_so_far
            else:
                verbose_proxy_logger.debug(
                    "Skipping output guardrail - model response has no choices"
                )
        # model_response_stream = OpenAiResponsesToChatCompletionStreamIterator.translate_responses_chunk_to_openai_stream(final_chunk)
        # tool_calls = model_response_stream.choices[0].tool_calls
        # convert openai response to model response
        # Fallback: guardrail the accumulated text of all chunks seen so far.
        string_so_far = self.get_streaming_string_so_far(responses_so_far)
        inputs = GenericGuardrailAPIInputs(texts=[string_so_far])
        # Try to get model from the final chunk if available
        if isinstance(final_chunk, dict):
            response_model = (
                final_chunk.get("response", {}).get("model")
                if isinstance(final_chunk.get("response"), dict)
                else None
            )
            if response_model:
                inputs["model"] = response_model
        _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data={},
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        return responses_so_far
    def _check_streaming_has_ended(self, responses_so_far: List[Any]) -> bool:
        """
        Check if the streaming has ended.

        NOTE(review): this accesses ``response.choices[0].finish_reason``
        (chat-completion-style objects), while the streaming handler above
        treats chunks as Responses API event dicts — confirm callers pass
        translated chunks here.
        """
        return all(
            response.choices[0].finish_reason is not None
            for response in responses_so_far
        )
    def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
        """
        Get the string so far from the responses so far.

        Concatenates the "text" key of each chunk dict; chunks without a
        "text" key (non-delta events) contribute nothing.
        """
        return "".join([response.get("text", "") for response in responses_so_far])
    def _has_text_content(self, response: "ResponsesAPIResponse") -> bool:
        """
        Check if response has any text content to process.
        Override this method to customize text content detection.
        """
        if not hasattr(response, "output") or response.output is None:
            return False
        for output_item in response.output:
            if isinstance(output_item, BaseModel):
                # Normalize arbitrary pydantic output items into the generic
                # shape; items that fail validation are skipped.
                try:
                    generic_response_output_item = (
                        GenericResponseOutputItem.model_validate(
                            output_item.model_dump()
                        )
                    )
                    if generic_response_output_item.content:
                        output_item = generic_response_output_item
                except Exception:
                    continue
            if isinstance(output_item, (GenericResponseOutputItem, dict)):
                content = (
                    output_item.content
                    if isinstance(output_item, GenericResponseOutputItem)
                    else output_item.get("content", [])
                )
                if content:
                    for content_item in content:
                        # Check if it's an OutputText with text
                        if isinstance(content_item, OutputText):
                            if content_item.text:
                                return True
                        elif isinstance(content_item, dict):
                            if content_item.get("text"):
                                return True
        return False
    def _extract_output_text_and_images(
        self,
        output_item: Any,
        output_idx: int,
        texts_to_check: List[str],
        images_to_check: List[str],
        task_mappings: List[Tuple[int, int]],
        tool_calls_to_check: Optional[List[ChatCompletionToolCallChunk]] = None,
    ) -> None:
        """
        Extract text content, images, and tool calls from a response output item.
        Override this method to customize text/image/tool extraction logic.

        Tool calls are recognized in three shapes: OutputFunctionToolCall,
        any pydantic model with type == "function_call", or a plain dict with
        type == "function_call".
        """
        # Check if this is a tool call (OutputFunctionToolCall)
        if isinstance(output_item, OutputFunctionToolCall):
            if tool_calls_to_check is not None:
                tool_call_dict = LiteLLMCompletionResponsesConfig.convert_response_function_tool_call_to_chat_completion_tool_call(
                    tool_call_item=output_item,
                    index=output_idx,
                )
                tool_calls_to_check.append(
                    cast(ChatCompletionToolCallChunk, tool_call_dict)
                )
            return
        elif (
            isinstance(output_item, BaseModel)
            and hasattr(output_item, "type")
            and getattr(output_item, "type") == "function_call"
        ):
            if tool_calls_to_check is not None:
                tool_call_dict = LiteLLMCompletionResponsesConfig.convert_response_function_tool_call_to_chat_completion_tool_call(
                    tool_call_item=output_item,
                    index=output_idx,
                )
                tool_calls_to_check.append(
                    cast(ChatCompletionToolCallChunk, tool_call_dict)
                )
            return
        elif (
            isinstance(output_item, dict) and output_item.get("type") == "function_call"
        ):
            # Handle dict representation of tool call
            if tool_calls_to_check is not None:
                # Convert dict to ResponseFunctionToolCall for processing
                try:
                    tool_call_obj = ResponseFunctionToolCall(**output_item)
                    tool_call_dict = LiteLLMCompletionResponsesConfig.convert_response_function_tool_call_to_chat_completion_tool_call(
                        tool_call_item=tool_call_obj,
                        index=output_idx,
                    )
                    tool_calls_to_check.append(
                        cast(ChatCompletionToolCallChunk, tool_call_dict)
                    )
                except Exception:
                    # Malformed tool-call dicts are silently skipped.
                    pass
            return
        # Handle both GenericResponseOutputItem and dict
        content: Optional[Union[List[OutputText], List[dict]]] = None
        if isinstance(output_item, BaseModel):
            try:
                output_item_dump = output_item.model_dump()
                generic_response_output_item = GenericResponseOutputItem.model_validate(
                    output_item_dump
                )
                if generic_response_output_item.content:
                    content = generic_response_output_item.content
            except Exception:
                # Try to extract content directly from output_item if validation fails
                if hasattr(output_item, "content") and output_item.content:  # type: ignore
                    content = output_item.content  # type: ignore
                else:
                    return
        elif isinstance(output_item, dict):
            content = output_item.get("content", [])
        else:
            return
        if not content:
            return
        verbose_proxy_logger.debug(
            "OpenAI Responses API: Processing output item: %s", output_item
        )
        # Iterate through content items (list of OutputText objects)
        for content_idx, content_item in enumerate(content):
            # Handle both OutputText objects and dicts
            if isinstance(content_item, OutputText):
                text_content = content_item.text
            elif isinstance(content_item, dict):
                text_content = content_item.get("text")
            else:
                continue
            if text_content:
                texts_to_check.append(text_content)
                task_mappings.append((output_idx, int(content_idx)))
    async def _apply_guardrail_responses_to_output(
        self,
        response: "ResponsesAPIResponse",
        responses: List[str],
        task_mappings: List[Tuple[int, int]],
    ) -> None:
        """
        Apply guardrail responses back to output response.
        Override this method to customize how responses are applied.

        Mutates ``response`` in place using the (output_idx, content_idx)
        mappings recorded during extraction.
        """
        # Handle both dict and Pydantic object responses
        if isinstance(response, dict):
            response_output = response.get("output", [])
        elif hasattr(response, "output"):
            response_output = response.output or []
        else:
            return
        for task_idx, guardrail_response in enumerate(responses):
            mapping = task_mappings[task_idx]
            output_idx = cast(int, mapping[0])
            content_idx = cast(int, mapping[1])
            if output_idx >= len(response_output):
                continue
            output_item = response_output[output_idx]
            # Handle both GenericResponseOutputItem, BaseModel, and dict
            if isinstance(output_item, GenericResponseOutputItem):
                if output_item.content and content_idx < len(output_item.content):
                    content_item = output_item.content[content_idx]
                    if isinstance(content_item, OutputText):
                        content_item.text = guardrail_response
                    elif isinstance(content_item, dict):
                        content_item["text"] = guardrail_response
            elif isinstance(output_item, BaseModel):
                # Handle other Pydantic models by converting to GenericResponseOutputItem
                try:
                    generic_item = GenericResponseOutputItem.model_validate(
                        output_item.model_dump()
                    )
                    if generic_item.content and content_idx < len(generic_item.content):
                        content_item = generic_item.content[content_idx]
                        if isinstance(content_item, OutputText):
                            content_item.text = guardrail_response
                        # Update the original response output
                        # (the validated copy above is discarded; the mutation
                        # below on the original item is what the caller sees)
                        if hasattr(output_item, "content") and output_item.content:  # type: ignore
                            original_content = output_item.content[content_idx]  # type: ignore
                            if hasattr(original_content, "text"):
                                original_content.text = guardrail_response  # type: ignore
                except Exception:
                    pass
            elif isinstance(output_item, dict):
                content = output_item.get("content", [])
                if content and content_idx < len(content):
                    if isinstance(content[content_idx], dict):
                        content[content_idx]["text"] = guardrail_response
                    elif hasattr(content[content_idx], "text"):
                        content[content_idx].text = guardrail_response

View File

@@ -0,0 +1,580 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast, get_type_hints
import httpx
from openai.types.responses import ResponseReasoningItem
from pydantic import BaseModel, ValidationError
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
_safe_convert_created_field,
)
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import *
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
from ..common_utils import OpenAIError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
    @property
    def custom_llm_provider(self) -> LlmProviders:
        # Identifies this config as the native OpenAI provider for dispatch logic.
        return LlmProviders.OPENAI
def get_supported_openai_params(self, model: str) -> list:
"""
All OpenAI Responses API params are supported
"""
supported_params = get_type_hints(ResponsesAPIRequestParams).keys()
return list(
set(
[
"input",
"model",
"extra_headers",
"extra_query",
"extra_body",
"timeout",
]
+ list(supported_params)
)
)
def map_openai_params(
self,
response_api_optional_params: ResponsesAPIOptionalRequestParams,
model: str,
drop_params: bool,
) -> Dict:
"""No mapping applied since inputs are in OpenAI spec already"""
return dict(response_api_optional_params)
def transform_responses_api_request(
self,
model: str,
input: Union[str, ResponseInputParam],
response_api_optional_request_params: Dict,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Dict:
"""No transform applied since inputs are in OpenAI spec already"""
input = self._validate_input_param(input)
final_request_params = dict(
ResponsesAPIRequestParams(
model=model, input=input, **response_api_optional_request_params
)
)
return final_request_params
def _validate_input_param(
self, input: Union[str, ResponseInputParam]
) -> Union[str, ResponseInputParam]:
"""
Ensure all input fields if pydantic are converted to dict
OpenAI API Fails when we try to JSON dumps specific input pydantic fields.
This function ensures all input fields are converted to dict.
"""
if isinstance(input, list):
validated_input = []
for item in input:
# if it's pydantic, convert to dict
if isinstance(item, BaseModel):
validated_input.append(item.model_dump(exclude_none=True))
elif isinstance(item, dict):
# Handle reasoning items specifically to filter out status=None
if item.get("type") == "reasoning":
verbose_logger.debug(f"Handling reasoning item: {item}")
# Type assertion since we know it's a dict at this point
dict_item = cast(Dict[str, Any], item)
filtered_item = self._handle_reasoning_item(dict_item)
else:
# For other dict items, just pass through
filtered_item = cast(Dict[str, Any], item)
validated_input.append(filtered_item)
else:
validated_input.append(item)
return validated_input # type: ignore
# Input is expected to be either str or List, no single BaseModel expected
return input
    def _handle_reasoning_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle reasoning items specifically to filter out status=None using OpenAI's model.
        Issue: https://github.com/BerriAI/litellm/issues/13484
        OpenAI API does not accept ReasoningItem(status=None), so we need to:
        1. Check if the item is a reasoning type
        2. Create a ResponseReasoningItem object with the item data
        3. Convert it back to dict with exclude_none=True to filter None values
        """
        if item.get("type") == "reasoning":
            try:
                # Ensure required fields are present for ResponseReasoningItem
                item_data = dict(item)
                if "summary" not in item_data:
                    # Synthesize a summary from reasoning_content, truncated to 100 chars.
                    item_data["summary"] = (
                        item_data.get("reasoning_content", "")[:100] + "..."
                        if len(item_data.get("reasoning_content", "")) > 100
                        else item_data.get("reasoning_content", "")
                    )
                # Create ResponseReasoningItem object from the item data
                reasoning_item = ResponseReasoningItem(**item_data)
                # Convert back to dict with exclude_none=True to exclude None fields
                dict_reasoning_item = reasoning_item.model_dump(exclude_none=True)
                return dict_reasoning_item
            except Exception as e:
                verbose_logger.debug(
                    f"Failed to create ResponseReasoningItem, falling back to manual filtering: {e}"
                )
                # Fallback: manually filter out known None fields
                # (keeps a key if its value is non-None OR the key is not one of
                # the three known-problematic fields)
                filtered_item = {
                    k: v
                    for k, v in item.items()
                    if v is not None
                    or k not in {"status", "content", "encrypted_content"}
                }
                return filtered_item
        return item
    def transform_response_api_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """No transform applied since outputs are in OpenAI spec already"""
        try:
            logging_obj.post_call(
                original_response=raw_response.text,
                additional_args={"complete_input_dict": {}},
            )
            raw_response_json = raw_response.json()
            # created_at may arrive in a non-numeric form; normalize it.
            raw_response_json["created_at"] = _safe_convert_created_field(
                raw_response_json["created_at"]
            )
        except Exception:
            # Non-JSON (or otherwise unparsable) bodies become an OpenAIError
            # carrying the raw text and status code.
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        raw_response_headers = dict(raw_response.headers)
        processed_headers = process_response_headers(raw_response_headers)
        try:
            response = ResponsesAPIResponse(**raw_response_json)
        except Exception:
            # Fall back to model_construct so a schema mismatch does not drop
            # an otherwise-valid provider response.
            verbose_logger.debug(
                f"Error constructing ResponsesAPIResponse: {raw_response_json}, using model_construct"
            )
            response = ResponsesAPIResponse.model_construct(**raw_response_json)
        # Store processed headers in additional_headers so they get returned to the client
        response._hidden_params["additional_headers"] = processed_headers
        response._hidden_params["headers"] = raw_response_headers
        return response
def validate_environment(
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
litellm_params = litellm_params or GenericLiteLLMParams()
api_key = (
litellm_params.api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
headers.update(
{
"Authorization": f"Bearer {api_key}",
}
)
return headers
def get_complete_url(
self,
api_base: Optional[str],
litellm_params: dict,
) -> str:
"""
Get the endpoint for OpenAI responses API
"""
api_base = (
api_base
or litellm.api_base
or get_secret_str("OPENAI_BASE_URL")
or get_secret_str("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
# Remove trailing slashes
api_base = api_base.rstrip("/")
return f"{api_base}/responses"
    def transform_streaming_response(
        self,
        model: str,
        parsed_chunk: dict,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIStreamingResponse:
        """
        Transform a parsed streaming response chunk into a ResponsesAPIStreamingResponse
        """
        # Convert the dictionary to a properly typed ResponsesAPIStreamingResponse
        verbose_logger.debug("Raw OpenAI Chunk=%s", parsed_chunk)
        event_type = str(parsed_chunk.get("type"))
        event_pydantic_model = OpenAIResponsesAPIConfig.get_event_model_class(
            event_type=event_type
        )
        # Some OpenAI-compatible providers send error.code: null; coalesce so validation succeeds.
        try:
            error_obj = parsed_chunk.get("error")
            if isinstance(error_obj, dict) and error_obj.get("code") is None:
                # Shallow-copy before patching so the caller's dict is untouched.
                parsed_chunk = dict(parsed_chunk)
                parsed_chunk["error"] = dict(error_obj)
                parsed_chunk["error"]["code"] = "unknown_error"
        except Exception:
            verbose_logger.debug("Failed to coalesce error.code in parsed_chunk")
        try:
            return event_pydantic_model(**parsed_chunk)
        except ValidationError:
            # Lenient fallback: deliver the chunk even if it fails strict validation.
            verbose_logger.debug(
                "Pydantic validation failed for %s with chunk %s, "
                "falling back to model_construct",
                event_pydantic_model.__name__,
                parsed_chunk,
            )
            return event_pydantic_model.model_construct(**parsed_chunk)
    @staticmethod
    def get_event_model_class(event_type: str) -> Any:
        """
        Returns the appropriate event model class based on the event type.
        Args:
            event_type (str): The type of event from the response chunk
        Returns:
            Any: The corresponding event model class (GenericEvent for
            unrecognized event types, so unknown events are passed through
            rather than rejected)
        """
        event_models = {
            ResponsesAPIStreamEvents.RESPONSE_CREATED: ResponseCreatedEvent,
            ResponsesAPIStreamEvents.RESPONSE_IN_PROGRESS: ResponseInProgressEvent,
            ResponsesAPIStreamEvents.RESPONSE_COMPLETED: ResponseCompletedEvent,
            ResponsesAPIStreamEvents.RESPONSE_FAILED: ResponseFailedEvent,
            ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE: ResponseIncompleteEvent,
            ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED: OutputItemAddedEvent,
            ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: OutputItemDoneEvent,
            ResponsesAPIStreamEvents.CONTENT_PART_ADDED: ContentPartAddedEvent,
            ResponsesAPIStreamEvents.CONTENT_PART_DONE: ContentPartDoneEvent,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA: OutputTextDeltaEvent,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_ANNOTATION_ADDED: OutputTextAnnotationAddedEvent,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_DONE: OutputTextDoneEvent,
            ResponsesAPIStreamEvents.REFUSAL_DELTA: RefusalDeltaEvent,
            ResponsesAPIStreamEvents.REFUSAL_DONE: RefusalDoneEvent,
            ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA: FunctionCallArgumentsDeltaEvent,
            ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE: FunctionCallArgumentsDoneEvent,
            ResponsesAPIStreamEvents.FILE_SEARCH_CALL_IN_PROGRESS: FileSearchCallInProgressEvent,
            ResponsesAPIStreamEvents.FILE_SEARCH_CALL_SEARCHING: FileSearchCallSearchingEvent,
            ResponsesAPIStreamEvents.FILE_SEARCH_CALL_COMPLETED: FileSearchCallCompletedEvent,
            ResponsesAPIStreamEvents.WEB_SEARCH_CALL_IN_PROGRESS: WebSearchCallInProgressEvent,
            ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING: WebSearchCallSearchingEvent,
            ResponsesAPIStreamEvents.WEB_SEARCH_CALL_COMPLETED: WebSearchCallCompletedEvent,
            ResponsesAPIStreamEvents.MCP_LIST_TOOLS_IN_PROGRESS: MCPListToolsInProgressEvent,
            ResponsesAPIStreamEvents.MCP_LIST_TOOLS_COMPLETED: MCPListToolsCompletedEvent,
            ResponsesAPIStreamEvents.MCP_LIST_TOOLS_FAILED: MCPListToolsFailedEvent,
            ResponsesAPIStreamEvents.MCP_CALL_IN_PROGRESS: MCPCallInProgressEvent,
            ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DELTA: MCPCallArgumentsDeltaEvent,
            ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DONE: MCPCallArgumentsDoneEvent,
            ResponsesAPIStreamEvents.MCP_CALL_COMPLETED: MCPCallCompletedEvent,
            ResponsesAPIStreamEvents.MCP_CALL_FAILED: MCPCallFailedEvent,
            ResponsesAPIStreamEvents.IMAGE_GENERATION_PARTIAL_IMAGE: ImageGenerationPartialImageEvent,
            ResponsesAPIStreamEvents.ERROR: ErrorEvent,
            # Shell tool events: passthrough as GenericEvent so payload is preserved
            ResponsesAPIStreamEvents.SHELL_CALL_IN_PROGRESS: GenericEvent,
            ResponsesAPIStreamEvents.SHELL_CALL_COMPLETED: GenericEvent,
            ResponsesAPIStreamEvents.SHELL_CALL_OUTPUT: GenericEvent,
        }
        model_class = event_models.get(cast(ResponsesAPIStreamEvents, event_type))
        if not model_class:
            return GenericEvent
        return model_class
def should_fake_stream(
self,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
if stream is not True:
return False
if model is not None:
try:
if (
litellm.utils.supports_native_streaming(
model=model,
custom_llm_provider=custom_llm_provider,
)
is False
):
return True
except Exception as e:
verbose_logger.debug(
f"Error getting model info in OpenAIResponsesAPIConfig: {e}"
)
return False
def supports_native_websocket(self) -> bool:
"""OpenAI supports native WebSocket for Responses API"""
return True
#########################################################
########## DELETE RESPONSE API TRANSFORMATION ##############
#########################################################
def transform_delete_response_api_request(
self,
response_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""
Transform the delete response API request into a URL and data
OpenAI API expects the following request
- DELETE /v1/responses/{response_id}
"""
url = f"{api_base}/{response_id}"
data: Dict = {}
return url, data
def transform_delete_response_api_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> DeleteResponseResult:
"""
Transform the delete response API response into a DeleteResponseResult
"""
try:
raw_response_json = raw_response.json()
except Exception:
raise OpenAIError(
message=raw_response.text, status_code=raw_response.status_code
)
return DeleteResponseResult(**raw_response_json)
#########################################################
########## GET RESPONSE API TRANSFORMATION ###############
#########################################################
def transform_get_response_api_request(
self,
response_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""
Transform the get response API request into a URL and data
OpenAI API expects the following request
- GET /v1/responses/{response_id}
"""
url = f"{api_base}/{response_id}"
data: Dict = {}
return url, data
def transform_get_response_api_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIResponse:
"""
Transform the get response API response into a ResponsesAPIResponse
"""
try:
raw_response_json = raw_response.json()
except Exception:
raise OpenAIError(
message=raw_response.text, status_code=raw_response.status_code
)
raw_response_headers = dict(raw_response.headers)
processed_headers = process_response_headers(raw_response_headers)
response = ResponsesAPIResponse(**raw_response_json)
response._hidden_params["additional_headers"] = processed_headers
response._hidden_params["headers"] = raw_response_headers
return response
#########################################################
########## LIST INPUT ITEMS TRANSFORMATION #############
#########################################################
def transform_list_input_items_request(
self,
response_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
after: Optional[str] = None,
before: Optional[str] = None,
include: Optional[List[str]] = None,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
) -> Tuple[str, Dict]:
url = f"{api_base}/{response_id}/input_items"
params: Dict[str, Any] = {}
if after is not None:
params["after"] = after
if before is not None:
params["before"] = before
if include:
params["include"] = ",".join(include)
if limit is not None:
params["limit"] = limit
if order is not None:
params["order"] = order
return url, params
def transform_list_input_items_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> Dict:
try:
return raw_response.json()
except Exception:
raise OpenAIError(
message=raw_response.text, status_code=raw_response.status_code
)
#########################################################
########## CANCEL RESPONSE API TRANSFORMATION ##########
#########################################################
def transform_cancel_response_api_request(
self,
response_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""
Transform the cancel response API request into a URL and data
OpenAI API expects the following request
- POST /v1/responses/{response_id}/cancel
"""
url = f"{api_base}/{response_id}/cancel"
data: Dict = {}
return url, data
def transform_cancel_response_api_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIResponse:
"""
Transform the cancel response API response into a ResponsesAPIResponse
"""
try:
raw_response_json = raw_response.json()
except Exception:
raise OpenAIError(
message=raw_response.text, status_code=raw_response.status_code
)
raw_response_headers = dict(raw_response.headers)
processed_headers = process_response_headers(raw_response_headers)
response = ResponsesAPIResponse(**raw_response_json)
response._hidden_params["additional_headers"] = processed_headers
response._hidden_params["headers"] = raw_response_headers
return response
#########################################################
########## COMPACT RESPONSE API TRANSFORMATION ##########
#########################################################
def transform_compact_response_api_request(
self,
model: str,
input: Union[str, ResponseInputParam],
response_api_optional_request_params: Dict,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""
Transform the compact response API request into a URL and data
OpenAI API expects the following request
- POST /v1/responses/compact
"""
# Preserve query params (e.g., api-version) while appending /compact.
parsed_url = httpx.URL(api_base)
compact_path = parsed_url.path.rstrip("/") + "/compact"
url = str(parsed_url.copy_with(path=compact_path))
input = self._validate_input_param(input)
data = dict(
ResponsesAPIRequestParams(
model=model, input=input, **response_api_optional_request_params
)
)
return url, data
def transform_compact_response_api_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIResponse:
"""
Transform the compact response API response into a ResponsesAPIResponse
"""
try:
logging_obj.post_call(
original_response=raw_response.text,
additional_args={"complete_input_dict": {}},
)
raw_response_json = raw_response.json()
raw_response_json["created_at"] = _safe_convert_created_field(
raw_response_json["created_at"]
)
except Exception:
raise OpenAIError(
message=raw_response.text, status_code=raw_response.status_code
)
raw_response_headers = dict(raw_response.headers)
processed_headers = process_response_headers(raw_response_headers)
try:
response = ResponsesAPIResponse(**raw_response_json)
except Exception:
verbose_logger.debug(
f"Error constructing ResponsesAPIResponse: {raw_response_json}, using model_construct"
)
response = ResponsesAPIResponse.model_construct(**raw_response_json)
response._hidden_params["additional_headers"] = processed_headers
response._hidden_params["headers"] = raw_response_headers
return response

View File

@@ -0,0 +1,178 @@
# OpenAI Text-to-Speech Guardrail Translation Handler
Handler for processing OpenAI's text-to-speech endpoint (`/v1/audio/speech`) with guardrails.
## Overview
This handler processes text-to-speech requests by:
1. Extracting the input text from the request
2. Applying guardrails to the input text
3. Updating the request with the guardrailed text
4. Returning the output unchanged (audio is binary, not text)
## Data Format
### Input Format
```json
{
"model": "tts-1",
"input": "The quick brown fox jumped over the lazy dog.",
"voice": "alloy",
"response_format": "mp3",
"speed": 1.0
}
```
### Output Format
The output is binary audio data (MP3, WAV, etc.), not text, so it cannot be guardrailed.
## Usage
The handler is automatically discovered and applied when guardrails are used with the text-to-speech endpoint.
### Example: Using Guardrails with Text-to-Speech
```bash
curl -X POST 'http://localhost:4000/v1/audio/speech' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "tts-1",
"input": "The quick brown fox jumped over the lazy dog.",
"voice": "alloy",
"guardrails": ["content_moderation"]
}' \
--output speech.mp3
```
The guardrail will be applied to the input text before the text-to-speech conversion.
### Example: PII Masking in TTS Input
```bash
curl -X POST 'http://localhost:4000/v1/audio/speech' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "tts-1",
"input": "Please call John Doe at john@example.com",
"voice": "nova",
"guardrails": ["mask_pii"]
}' \
--output speech.mp3
```
The audio will say: "Please call [NAME_REDACTED] at [EMAIL_REDACTED]"
### Example: Content Filtering Before TTS
```bash
curl -X POST 'http://localhost:4000/v1/audio/speech' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "tts-1-hd",
"input": "This is the text that will be spoken",
"voice": "shimmer",
"guardrails": ["content_filter"]
}' \
--output speech.mp3
```
## Implementation Details
### Input Processing
- **Field**: `input` (string)
- **Processing**: Applies guardrail to input text
- **Result**: Updated input text in request
### Output Processing
- **Processing**: Not applicable (audio is binary data)
- **Result**: Response returned unchanged
## Use Cases
1. **PII Protection**: Remove personally identifiable information before converting to speech
2. **Content Filtering**: Remove inappropriate content before TTS conversion
3. **Compliance**: Ensure text meets requirements before voice synthesis
4. **Text Sanitization**: Clean up text before audio generation
## Extension
Override these methods to customize behavior:
- `process_input_messages()`: Customize how input text is processed
- `process_output_response()`: Currently a no-op, but can be overridden if needed
## Supported Call Types
- `CallTypes.speech` - Synchronous text-to-speech
- `CallTypes.aspeech` - Asynchronous text-to-speech
## Notes
- Only the input text is processed by guardrails
- Output processing is a no-op since audio cannot be text-guardrailed
- Both sync and async call types use the same handler
- Works with all TTS models (tts-1, tts-1-hd, etc.)
- Works with all voice options
## Common Patterns
### Remove PII Before TTS
```python
import litellm
from pathlib import Path
speech_file_path = Path(__file__).parent / "speech.mp3"
response = litellm.speech(
model="tts-1",
voice="alloy",
input="Hi, this is John Doe calling from john@company.com",
guardrails=["mask_pii"],
)
response.stream_to_file(speech_file_path)
# Audio will have PII masked
```
### Content Moderation Before TTS
```python
import litellm
from pathlib import Path
speech_file_path = Path(__file__).parent / "speech.mp3"
response = litellm.speech(
model="tts-1-hd",
voice="nova",
input="Your text here",
guardrails=["content_moderation"],
)
response.stream_to_file(speech_file_path)
```
### Async TTS with Guardrails
```python
import litellm
import asyncio
from pathlib import Path
async def generate_speech():
speech_file_path = Path(__file__).parent / "speech.mp3"
response = await litellm.aspeech(
model="tts-1",
voice="echo",
input="Text to convert to speech",
guardrails=["pii_mask"],
)
response.stream_to_file(speech_file_path)
asyncio.run(generate_speech())
```

View File

@@ -0,0 +1,13 @@
"""OpenAI Text-to-Speech handler for Unified Guardrails."""
from litellm.llms.openai.speech.guardrail_translation.handler import (
OpenAITextToSpeechHandler,
)
from litellm.types.utils import CallTypes
guardrail_translation_mappings = {
CallTypes.speech: OpenAITextToSpeechHandler,
CallTypes.aspeech: OpenAITextToSpeechHandler,
}
__all__ = ["guardrail_translation_mappings", "OpenAITextToSpeechHandler"]

View File

@@ -0,0 +1,108 @@
"""
OpenAI Text-to-Speech Handler for Unified Guardrails
This module provides guardrail translation support for OpenAI's text-to-speech endpoint.
The handler processes the 'input' text parameter (output is audio, so no text to guardrail).
"""
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.llms.openai import HttpxBinaryResponseContent
class OpenAITextToSpeechHandler(BaseTranslation):
    """
    Guardrail translation handler for OpenAI text-to-speech requests.

    Only the request's ``input`` text can be guardrailed; the response is
    binary audio, so output processing is a deliberate no-op.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Run the configured guardrail over the request's ``input`` text.

        Args:
            data: Request payload; expected to carry an ``input`` string.
            guardrail_to_apply: The guardrail instance to apply.
            litellm_logging_obj: Optional logging object forwarded to the
                guardrail.

        Returns:
            The same ``data`` dict, with ``input`` replaced by the
            guardrailed text when one was produced.
        """
        original_text = data.get("input")
        if original_text is None:
            verbose_proxy_logger.debug(
                "OpenAI Text-to-Speech: No input text found in request data"
            )
            return data
        if not isinstance(original_text, str):
            verbose_proxy_logger.debug(
                "OpenAI Text-to-Speech: Unexpected input type: %s. Expected string.",
                type(original_text),
            )
            return data
        guardrail_inputs = GenericGuardrailAPIInputs(texts=[original_text])
        # Forward the (voice) model name when the request specifies one.
        requested_model = data.get("model")
        if requested_model:
            guardrail_inputs["model"] = requested_model
        processed = await guardrail_to_apply.apply_guardrail(
            inputs=guardrail_inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        processed_texts = processed.get("texts", [])
        data["input"] = processed_texts[0] if processed_texts else original_text
        verbose_proxy_logger.debug(
            "OpenAI Text-to-Speech: Applied guardrail to input text. "
            "Original length: %d, New length: %d",
            len(original_text),
            len(data["input"]),
        )
        return data

    async def process_output_response(
        self,
        response: "HttpxBinaryResponseContent",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        No-op: TTS output is binary audio, so there is no text to guardrail.

        Args:
            response: Binary audio response.
            guardrail_to_apply: The guardrail instance (unused).
            litellm_logging_obj: Optional logging object (unused).
            user_api_key_dict: User API key metadata (unused).

        Returns:
            The response, unchanged.
        """
        verbose_proxy_logger.debug(
            "OpenAI Text-to-Speech: Output processing not applicable "
            "(output is audio data, not text)"
        )
        return response

View File

@@ -0,0 +1,41 @@
from typing import List
from litellm.llms.base_llm.audio_transcription.transformation import (
AudioTranscriptionRequestData,
)
from litellm.types.llms.openai import OpenAIAudioTranscriptionOptionalParams
from litellm.types.utils import FileTypes
from .whisper_transformation import OpenAIWhisperAudioTranscriptionConfig
class OpenAIGPTAudioTranscriptionConfig(OpenAIWhisperAudioTranscriptionConfig):
    """Transcription config for the ``gpt-4o-transcribe`` model family."""

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIAudioTranscriptionOptionalParams]:
        """
        OpenAI params accepted by the ``gpt-4o-transcribe`` models.

        Unlike the whisper config, this list includes ``include`` and omits
        ``timestamp_granularities``.
        """
        supported: List[OpenAIAudioTranscriptionOptionalParams] = [
            "language",
            "prompt",
            "response_format",
            "temperature",
            "include",
        ]
        return supported

    def transform_audio_transcription_request(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        litellm_params: dict,
    ) -> AudioTranscriptionRequestData:
        """
        Assemble the transcription payload: model + file + optional params.

        No response_format coercion is applied here (unlike the whisper base).
        """
        payload = {"model": model, "file": audio_file, **optional_params}
        return AudioTranscriptionRequestData(data=payload)

View File

@@ -0,0 +1,159 @@
# OpenAI Audio Transcription Guardrail Translation Handler
Handler for processing OpenAI's audio transcription endpoint (`/v1/audio/transcriptions`) with guardrails.
## Overview
This handler processes audio transcription responses by:
1. Applying guardrails to the transcribed text output
2. Returning the input unchanged (since input is an audio file, not text)
## Data Format
### Input Format
The input is an audio file, which cannot be guardrailed (it's binary data, not text).
```json
{
"model": "whisper-1",
"file": "<audio file>",
"response_format": "json",
"language": "en"
}
```
### Output Format
```json
{
"text": "This is the transcribed text from the audio file."
}
```
Or with additional metadata:
```json
{
"text": "This is the transcribed text from the audio file.",
"duration": 3.5,
"language": "en"
}
```
## Usage
The handler is automatically discovered and applied when guardrails are used with the audio transcription endpoint.
### Example: Using Guardrails with Audio Transcription
```bash
curl -X POST 'http://localhost:4000/v1/audio/transcriptions' \
-H 'Authorization: Bearer your-api-key' \
-F 'file=@audio.mp3' \
-F 'model=whisper-1' \
  -F 'guardrails=["mask_pii"]'
```
The guardrail will be applied to the **output** transcribed text only.
### Example: PII Masking in Transcribed Text
```bash
curl -X POST 'http://localhost:4000/v1/audio/transcriptions' \
-H 'Authorization: Bearer your-api-key' \
-F 'file=@meeting_recording.mp3' \
-F 'model=whisper-1' \
-F 'guardrails=["mask_pii"]' \
-F 'response_format=json'
```
If the audio contains: "My name is John Doe and my email is john@example.com"
The transcription output will be: "My name is [NAME_REDACTED] and my email is [EMAIL_REDACTED]"
### Example: Content Moderation on Transcriptions
```bash
curl -X POST 'http://localhost:4000/v1/audio/transcriptions' \
-H 'Authorization: Bearer your-api-key' \
-F 'file=@audio.wav' \
-F 'model=whisper-1' \
-F 'guardrails=["content_moderation"]'
```
## Implementation Details
### Input Processing
- **Status**: Not applicable
- **Reason**: Input is an audio file (binary data), not text
- **Result**: Request data returned unchanged
### Output Processing
- **Field**: `text` (string)
- **Processing**: Applies guardrail to the transcribed text
- **Result**: Updated text in response
## Use Cases
1. **PII Protection**: Automatically redact personally identifiable information from transcriptions
2. **Content Filtering**: Remove or flag inappropriate content in transcribed audio
3. **Compliance**: Ensure transcriptions meet regulatory requirements
4. **Data Sanitization**: Clean up transcriptions before storage or further processing
## Extension
Override these methods to customize behavior:
- `process_output_response()`: Customize how transcribed text is processed
- `process_input_messages()`: Currently a no-op, but can be overridden if needed
## Supported Call Types
- `CallTypes.transcription` - Synchronous audio transcription
- `CallTypes.atranscription` - Asynchronous audio transcription
## Notes
- Input processing is a no-op since audio files cannot be text-guardrailed
- Only the transcribed text output is processed
- Guardrails apply after transcription is complete
- Both sync and async call types use the same handler
- Works with all Whisper models and response formats
## Common Patterns
### Transcribe and Redact PII
```python
import litellm
response = litellm.transcription(
model="whisper-1",
file=open("interview.mp3", "rb"),
guardrails=["mask_pii"],
)
# response.text will have PII redacted
print(response.text)
```
### Async Transcription with Guardrails
```python
import litellm
import asyncio
async def transcribe_with_guardrails():
response = await litellm.atranscription(
model="whisper-1",
file=open("audio.mp3", "rb"),
guardrails=["content_filter"],
)
return response.text
text = asyncio.run(transcribe_with_guardrails())
```

View File

@@ -0,0 +1,13 @@
"""OpenAI Audio Transcription handler for Unified Guardrails."""
from litellm.llms.openai.transcriptions.guardrail_translation.handler import (
OpenAIAudioTranscriptionHandler,
)
from litellm.types.utils import CallTypes
guardrail_translation_mappings = {
CallTypes.transcription: OpenAIAudioTranscriptionHandler,
CallTypes.atranscription: OpenAIAudioTranscriptionHandler,
}
__all__ = ["guardrail_translation_mappings", "OpenAIAudioTranscriptionHandler"]

View File

@@ -0,0 +1,117 @@
"""
OpenAI Audio Transcription Handler for Unified Guardrails
This module provides guardrail translation support for OpenAI's audio transcription endpoint.
The handler processes the output transcribed text (input is audio, so no text to guardrail).
"""
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import TranscriptionResponse
class OpenAIAudioTranscriptionHandler(BaseTranslation):
    """
    Guardrail translation handler for OpenAI audio transcription responses.

    Only the transcribed ``text`` in the response can be guardrailed; the
    request carries an audio file, so input processing is a deliberate no-op.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        No-op: transcription input is an audio file, not text.

        Args:
            data: Request data dictionary containing the audio file.
            guardrail_to_apply: The guardrail instance (unused).
            litellm_logging_obj: Optional logging object (unused).

        Returns:
            The request data, unchanged.
        """
        verbose_proxy_logger.debug(
            "OpenAI Audio Transcription: Input processing not applicable "
            "(input is audio file, not text)"
        )
        return data

    async def process_output_response(
        self,
        response: "TranscriptionResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Run the configured guardrail over the transcribed text.

        Args:
            response: Transcription response carrying the transcribed text.
            guardrail_to_apply: The guardrail instance to apply.
            litellm_logging_obj: Optional logging object.
            user_api_key_dict: User API key metadata to pass to guardrails.

        Returns:
            The same response, with ``text`` replaced by the guardrailed
            text when one was produced.
        """
        if not hasattr(response, "text") or response.text is None:
            verbose_proxy_logger.debug(
                "OpenAI Audio Transcription: No text in response to process"
            )
            return response
        if not isinstance(response.text, str):
            verbose_proxy_logger.debug(
                "OpenAI Audio Transcription: Unexpected text type: %s. Expected string.",
                type(response.text),
            )
            return response
        before_text = response.text
        # Hand the guardrail the response plus any caller identity metadata.
        call_context: dict = {"response": response}
        key_metadata = self.transform_user_api_key_dict_to_metadata(
            user_api_key_dict
        )
        if key_metadata:
            call_context["litellm_metadata"] = key_metadata
        guardrail_inputs = GenericGuardrailAPIInputs(texts=[before_text])
        # Include model information from the response if available.
        if hasattr(response, "model") and response.model:
            guardrail_inputs["model"] = response.model
        processed = await guardrail_to_apply.apply_guardrail(
            inputs=guardrail_inputs,
            request_data=call_context,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        processed_texts = processed.get("texts", [])
        response.text = processed_texts[0] if processed_texts else before_text
        verbose_proxy_logger.debug(
            "OpenAI Audio Transcription: Applied guardrail to transcribed text. "
            "Original length: %d, New length: %d",
            len(before_text),
            len(response.text),
        )
        return response

View File

@@ -0,0 +1,231 @@
from typing import TYPE_CHECKING, Optional, Union, cast
import httpx
from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel
import litellm
if TYPE_CHECKING:
from aiohttp import ClientSession
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
from litellm.types.utils import FileTypes
from litellm.utils import (
TranscriptionResponse,
convert_to_model_response_object,
extract_duration_from_srt_or_vtt,
)
from ..openai import OpenAIChatCompletion
class OpenAIAudioTranscription(OpenAIChatCompletion):
    """Sync and async handlers for OpenAI's `/v1/audio/transcriptions` endpoint."""

    # Audio Transcriptions
    async def make_openai_audio_transcriptions_request(
        self,
        openai_aclient: AsyncOpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ):
        """
        Call ``audio.transcriptions.with_raw_response.create`` and return
        ``(headers, parsed_response)``.

        The raw-response wrapper is always used on the async path so response
        headers are available to the caller.
        """
        try:
            raw_response = (
                await openai_aclient.audio.transcriptions.with_raw_response.create(
                    **data, timeout=timeout
                )
            )  # type: ignore
            headers = dict(raw_response.headers)
            response = raw_response.parse()
            return headers, response
        except Exception as e:
            raise e

    def make_sync_openai_audio_transcriptions_request(
        self,
        openai_client: OpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ):
        """
        Helper to:
        - call openai_client.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
        - call openai_client.audio.transcriptions.create by default

        Returns:
            Tuple of (response headers or None, parsed response).
        """
        try:
            if litellm.return_response_headers is True:
                raw_response = (
                    openai_client.audio.transcriptions.with_raw_response.create(
                        **data, timeout=timeout
                    )
                )  # type: ignore
                headers = dict(raw_response.headers)
                response = raw_response.parse()
                return headers, response
            else:
                response = openai_client.audio.transcriptions.create(**data, timeout=timeout)  # type: ignore
                return None, response
        except Exception as e:
            raise e

    def audio_transcriptions(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        max_retries: int,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        api_base: Optional[str],
        client=None,
        atranscription: bool = False,
        provider_config: Optional[BaseAudioTranscriptionConfig] = None,
        shared_session: Optional["ClientSession"] = None,
    ) -> TranscriptionResponse:
        """
        Handle an audio transcription request.

        When ``atranscription`` is True this delegates to
        ``async_audio_transcriptions`` (returning a coroutine); otherwise the
        request is made synchronously.
        """
        # Let a provider-specific config shape the payload when present.
        if provider_config is not None:
            transformed_data = provider_config.transform_audio_transcription_request(
                model=model,
                audio_file=audio_file,
                optional_params=optional_params,
                litellm_params=litellm_params,
            )
            data = cast(dict, transformed_data.data)
        else:
            data = {"model": model, "file": audio_file, **optional_params}
        if atranscription is True:
            return self.async_audio_transcriptions(  # type: ignore
                audio_file=audio_file,
                data=data,
                model_response=model_response,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                client=client,
                max_retries=max_retries,
                logging_obj=logging_obj,
                shared_session=shared_session,
            )
        openai_client: OpenAI = self._get_openai_client(  # type: ignore
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=None,
            api_key=openai_client.api_key,
            additional_args={
                "api_base": openai_client._base_url._uri_reference,
                # Fix: this is the synchronous path (was logged as True).
                "atranscription": False,
                "complete_input_dict": data,
            },
        )
        _, response = self.make_sync_openai_audio_transcriptions_request(
            openai_client=openai_client,
            data=data,
            timeout=timeout,
        )
        if isinstance(response, BaseModel):
            stringified_response = response.model_dump()
        else:
            # Match the async path: plain-text (srt/vtt) responses carry no
            # usage metadata, so extract duration for cost calculation.
            duration = extract_duration_from_srt_or_vtt(response)
            stringified_response = TranscriptionResponse(text=response).model_dump()
            stringified_response["_audio_transcription_duration"] = duration
        ## LOGGING
        logging_obj.post_call(
            input=get_audio_file_name(audio_file),
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=stringified_response,
        )
        hidden_params = {"model": model, "custom_llm_provider": "openai"}
        final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
        return final_response

    async def async_audio_transcriptions(
        self,
        audio_file: FileTypes,
        data: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        shared_session: Optional["ClientSession"] = None,
    ):
        """
        Async transcription call; mirrors the sync path of
        ``audio_transcriptions``.

        Raises:
            Exception: re-raises any client error after logging it.
        """
        try:
            openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                is_async=True,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                shared_session=shared_session,
            )
            ## LOGGING
            logging_obj.pre_call(
                input=None,
                api_key=openai_aclient.api_key,
                additional_args={
                    "api_base": openai_aclient._base_url._uri_reference,
                    "atranscription": True,
                    "complete_input_dict": data,
                },
            )
            headers, response = await self.make_openai_audio_transcriptions_request(
                openai_aclient=openai_aclient,
                data=data,
                timeout=timeout,
            )
            logging_obj.model_call_details["response_headers"] = headers
            if isinstance(response, BaseModel):
                stringified_response = response.model_dump()
            else:
                duration = extract_duration_from_srt_or_vtt(response)
                stringified_response = TranscriptionResponse(text=response).model_dump()
                stringified_response["_audio_transcription_duration"] = duration
            ## LOGGING
            logging_obj.post_call(
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            # Extract the actual model from data instead of hardcoding "whisper-1"
            actual_model = data.get("model", "whisper-1")
            hidden_params = {"model": actual_model, "custom_llm_provider": "openai"}
            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
        except Exception as e:
            ## LOGGING
            # Fix: previously logged the `input` builtin; log the file name.
            logging_obj.post_call(
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                original_response=str(e),
            )
            raise e

View File

@@ -0,0 +1,150 @@
from typing import List, Optional, Union
from httpx import Headers, Response
from litellm.llms.base_llm.audio_transcription.transformation import (
AudioTranscriptionRequestData,
BaseAudioTranscriptionConfig,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIAudioTranscriptionOptionalParams,
)
from litellm.types.utils import FileTypes, TranscriptionResponse
from ..common_utils import OpenAIError
class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
    """Request/response handling for OpenAI's ``whisper-1`` transcription API."""

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Build the full ``/audio/transcriptions`` endpoint URL.

        A trailing slash on ``api_base`` is stripped, and ``/v1`` is inserted
        unless the base already ends with it.
        """
        base = api_base.rstrip("/") if api_base else ""
        if base and base.endswith("/v1"):
            return f"{base}/audio/transcriptions"
        return f"{base}/v1/audio/transcriptions"

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIAudioTranscriptionOptionalParams]:
        """OpenAI params accepted by the ``whisper-1`` models."""
        supported: List[OpenAIAudioTranscriptionOptionalParams] = [
            "language",
            "prompt",
            "response_format",
            "temperature",
            "timestamp_granularities",
        ]
        return supported

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy every supported request param into ``optional_params``."""
        allowed = set(self.get_supported_openai_params(model))
        optional_params.update(
            {key: value for key, value in non_default_params.items() if key in allowed}
        )
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Attach the Bearer auth header, falling back to OPENAI_API_KEY."""
        resolved_key = api_key or get_secret_str("OPENAI_API_KEY")
        headers.update({"Authorization": f"Bearer {resolved_key}"})
        return headers

    def transform_audio_transcription_request(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        litellm_params: dict,
    ) -> AudioTranscriptionRequestData:
        """
        Assemble the transcription payload for the multipart request.

        An unset / ``text`` / ``json`` response_format is upgraded to
        ``verbose_json`` so the reply carries ``duration``.
        """
        payload = {"model": model, "file": audio_file, **optional_params}
        if "response_format" not in payload or payload["response_format"] in (
            "text",
            "json",
        ):
            # ensures 'duration' is received - used for cost calculation
            payload["response_format"] = "verbose_json"
        return AudioTranscriptionRequestData(data=payload)

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap provider errors in OpenAIError."""
        return OpenAIError(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def transform_audio_transcription_response(
        self,
        raw_response: Response,
    ) -> TranscriptionResponse:
        """
        Parse the raw HTTP reply into a TranscriptionResponse.

        Raises:
            ValueError: if the body is not JSON, or if the JSON shares no
                fields with TranscriptionResponse.
        """
        try:
            payload = raw_response.json()
        except Exception as e:
            raise ValueError(
                f"Error transforming response to json: {str(e)}\nResponse: {raw_response.text}"
            )
        known_fields = TranscriptionResponse.model_fields.keys()
        if any(field in payload for field in known_fields):
            return TranscriptionResponse(**payload)
        raise ValueError(
            "Invalid response format. Received response does not match the expected format. Got: ",
            payload,
        )

View File

@@ -0,0 +1,258 @@
from typing import Any, Dict, Optional, Tuple, cast
import httpx
import litellm
from litellm.llms.base_llm.vector_store_files.transformation import (
BaseVectorStoreFilesConfig,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_store_files import (
VectorStoreFileAuthCredentials,
VectorStoreFileContentResponse,
VectorStoreFileCreateRequest,
VectorStoreFileDeleteResponse,
VectorStoreFileListQueryParams,
VectorStoreFileListResponse,
VectorStoreFileObject,
VectorStoreFileUpdateRequest,
)
from litellm.utils import add_openai_metadata
def _clean_dict(source: Dict[str, Any]) -> Dict[str, Any]:
return {k: v for k, v in source.items() if v is not None}
class OpenAIVectorStoreFilesConfig(BaseVectorStoreFilesConfig):
    """Request/response transformation for the OpenAI Vector Store Files API.

    Covers the create, list, retrieve, retrieve-content, update, and delete
    operations under ``/vector_stores/{vector_store_id}/files``.
    """

    # Header required to opt in to the OpenAI Assistants v2 API surface.
    ASSISTANTS_HEADER_KEY = "OpenAI-Beta"
    ASSISTANTS_HEADER_VALUE = "assistants=v2"

    def get_auth_credentials(
        self, litellm_params: Dict[str, Any]
    ) -> VectorStoreFileAuthCredentials:
        """Build bearer-auth headers from ``litellm_params``.

        Raises:
            ValueError: if no ``api_key`` is present in ``litellm_params``.
        """
        api_key = litellm_params.get("api_key")
        if api_key is None:
            raise ValueError("api_key is required")
        return {
            "headers": {
                "Authorization": f"Bearer {api_key}",
            }
        }

    def get_vector_store_file_endpoints_by_type(
        self,
    ) -> Dict[str, Tuple[Tuple[str, str], ...]]:
        """Classify the file endpoints as read or write operations."""
        return {
            "read": (
                ("GET", "/vector_stores/{vector_store_id}/files"),
                ("GET", "/vector_stores/{vector_store_id}/files/{file_id}"),
                (
                    "GET",
                    "/vector_stores/{vector_store_id}/files/{file_id}/content",
                ),
            ),
            "write": (
                ("POST", "/vector_stores/{vector_store_id}/files"),
                ("POST", "/vector_stores/{vector_store_id}/files/{file_id}"),
                ("DELETE", "/vector_stores/{vector_store_id}/files/{file_id}"),
            ),
        }

    def validate_environment(
        self,
        *,
        headers: Dict[str, str],
        litellm_params: Optional[GenericLiteLLMParams],
    ) -> Dict[str, str]:
        """Populate auth/content-type headers plus the Assistants v2 header.

        API key resolution order: ``litellm_params.api_key`` ->
        ``litellm.api_key`` -> ``litellm.openai_key`` -> ``OPENAI_API_KEY``
        env var.
        """
        litellm_params = litellm_params or GenericLiteLLMParams()
        api_key = (
            litellm_params.api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )
        # Do not clobber a caller-provided OpenAI-Beta header.
        if self.ASSISTANTS_HEADER_KEY not in headers:
            headers[self.ASSISTANTS_HEADER_KEY] = self.ASSISTANTS_HEADER_VALUE
        return headers

    def get_complete_url(
        self,
        *,
        api_base: Optional[str],
        vector_store_id: str,
        litellm_params: Dict[str, Any],
    ) -> str:
        """Resolve the base files URL for the given vector store.

        Base URL resolution order: explicit ``api_base`` ->
        ``litellm.api_base`` -> ``OPENAI_BASE_URL`` -> ``OPENAI_API_BASE`` ->
        the public OpenAI endpoint.
        """
        base_url = (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_BASE_URL")
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        base_url = base_url.rstrip("/")
        return f"{base_url}/vector_stores/{vector_store_id}/files"

    def transform_create_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        create_request: VectorStoreFileCreateRequest,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, payload) for attaching a file to a vector store.

        ``attributes`` are sanitized via ``add_openai_metadata``; if the
        sanitizer returns None the key is dropped from the payload entirely.
        """
        payload: Dict[str, Any] = _clean_dict(dict(create_request))
        attributes = payload.get("attributes")
        if isinstance(attributes, dict):
            filtered_attributes = add_openai_metadata(attributes)
            if filtered_attributes is not None:
                payload["attributes"] = filtered_attributes
            else:
                payload.pop("attributes", None)
        url = api_base
        return url, payload

    def transform_create_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse the create response body; parse failures become provider errors."""
        try:
            return cast(VectorStoreFileObject, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_list_vector_store_files_request(
        self,
        *,
        vector_store_id: str,
        query_params: VectorStoreFileListQueryParams,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, query params) for listing files; None params are dropped."""
        params = _clean_dict(dict(query_params))
        return api_base, params

    def transform_list_vector_store_files_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileListResponse:
        """Parse the list response body; parse failures become provider errors."""
        try:
            return cast(VectorStoreFileListResponse, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_retrieve_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, empty params) for retrieving a single file record."""
        return f"{api_base}/{file_id}", {}

    def transform_retrieve_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse the retrieve response body; parse failures become provider errors."""
        try:
            return cast(VectorStoreFileObject, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_retrieve_vector_store_file_content_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, empty params) for fetching a file's parsed content."""
        return f"{api_base}/{file_id}/content", {}

    def transform_retrieve_vector_store_file_content_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileContentResponse:
        """Parse the content response body; parse failures become provider errors."""
        try:
            return cast(VectorStoreFileContentResponse, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_update_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        update_request: VectorStoreFileUpdateRequest,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, payload) for updating a file's attributes.

        Unlike create, ``None`` values are kept (the API may treat them as
        explicit clears); ``attributes`` are sanitized via
        ``add_openai_metadata`` and dropped if the sanitizer returns None.
        """
        payload: Dict[str, Any] = dict(update_request)
        attributes = payload.get("attributes")
        if isinstance(attributes, dict):
            filtered_attributes = add_openai_metadata(attributes)
            if filtered_attributes is not None:
                payload["attributes"] = filtered_attributes
            else:
                payload.pop("attributes", None)
        return f"{api_base}/{file_id}", payload

    def transform_update_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse the update response body; parse failures become provider errors."""
        try:
            return cast(VectorStoreFileObject, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_delete_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, empty body) for detaching a file from a vector store."""
        return f"{api_base}/{file_id}", {}

    def transform_delete_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileDeleteResponse:
        """Parse the delete response body; parse failures become provider errors."""
        try:
            return cast(VectorStoreFileDeleteResponse, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

View File

@@ -0,0 +1,176 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import httpx
import litellm
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
BaseVectorStoreAuthCredentials,
VectorStoreCreateOptionalRequestParams,
VectorStoreCreateRequest,
VectorStoreCreateResponse,
VectorStoreIndexEndpoints,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchRequest,
VectorStoreSearchResponse,
)
from litellm.utils import add_openai_metadata
# Resolve the logging type only during static type checking; at runtime the
# alias degrades to ``Any`` so the litellm logging module is never imported here.
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any
class OpenAIVectorStoreConfig(BaseVectorStoreConfig):
    """Request/response transformation for the OpenAI Vector Stores API
    (search and create operations under ``/vector_stores``).
    """

    # Header required to opt in to the OpenAI Assistants v2 API surface.
    ASSISTANTS_HEADER_KEY = "OpenAI-Beta"
    ASSISTANTS_HEADER_VALUE = "assistants=v2"

    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """Build bearer-auth headers from ``litellm_params``.

        Raises:
            ValueError: if no ``api_key`` is present in ``litellm_params``.
        """
        api_key = litellm_params.get("api_key")
        if api_key is None:
            raise ValueError("api_key is required")
        return {
            "headers": {
                "Authorization": f"Bearer {api_key}",
            },
        }

    def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
        """Classify vector store endpoints: search is a read, create is a write."""
        return {
            "read": [("GET", "/vector_stores/{index_name}/search")],
            "write": [("POST", "/vector_stores")],
        }

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Populate auth/content-type headers plus the Assistants v2 header.

        API key resolution order: ``litellm_params.api_key`` ->
        ``litellm.api_key`` -> ``litellm.openai_key`` -> ``OPENAI_API_KEY``
        env var.
        """
        litellm_params = litellm_params or GenericLiteLLMParams()
        api_key = (
            litellm_params.api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )
        #########################################################
        # Ensure OpenAI Assistants header is included
        # (never clobbers a caller-provided value)
        #########################################################
        if self.ASSISTANTS_HEADER_KEY not in headers:
            headers.update(
                {
                    self.ASSISTANTS_HEADER_KEY: self.ASSISTANTS_HEADER_VALUE,
                }
            )
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the Base endpoint for OpenAI Vector Stores API.

        Base URL resolution order: explicit ``api_base`` ->
        ``litellm.api_base`` -> ``OPENAI_BASE_URL`` -> ``OPENAI_API_BASE`` ->
        the public OpenAI endpoint.
        """
        api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_BASE_URL")
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        # Remove trailing slashes
        api_base = api_base.rstrip("/")
        return f"{api_base}/vector_stores"

    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict]:
        """Build (url, body) for POST /vector_stores/{id}/search.

        Optional search params that are absent default to None in the typed
        request body.
        """
        url = f"{api_base}/{vector_store_id}/search"
        typed_request_body = VectorStoreSearchRequest(
            query=query,
            filters=vector_store_search_optional_params.get("filters", None),
            max_num_results=vector_store_search_optional_params.get(
                "max_num_results", None
            ),
            ranking_options=vector_store_search_optional_params.get(
                "ranking_options", None
            ),
            rewrite_query=vector_store_search_optional_params.get(
                "rewrite_query", None
            ),
        )
        dict_request_body = cast(dict, typed_request_body)
        return url, dict_request_body

    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """Parse the search response body; parse failures become provider errors."""
        try:
            response_json = response.json()
            return VectorStoreSearchResponse(**response_json)
        except Exception as e:
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict]:
        """Build (url, body) for POST /vector_stores.

        ``metadata`` is sanitized via ``add_openai_metadata`` before being
        placed in the request body.
        """
        url = api_base  # Base URL for creating vector stores
        metadata = vector_store_create_optional_params.get("metadata", None)
        metadata_payload = add_openai_metadata(metadata)
        typed_request_body = VectorStoreCreateRequest(
            name=vector_store_create_optional_params.get("name", None),
            file_ids=vector_store_create_optional_params.get("file_ids", None),
            expires_after=vector_store_create_optional_params.get(
                "expires_after", None
            ),
            chunking_strategy=vector_store_create_optional_params.get(
                "chunking_strategy", None
            ),
            metadata=metadata_payload,
        )
        dict_request_body = cast(dict, typed_request_body)
        return url, dict_request_body

    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        """Parse the create response body; parse failures become provider errors."""
        try:
            response_json = response.json()
            return VectorStoreCreateResponse(**response_json)
        except Exception as e:
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )

View File

@@ -0,0 +1,447 @@
from io import BufferedReader
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import httpx
from httpx._types import RequestFiles
import litellm
from litellm.llms.base_llm.videos.transformation import BaseVideoConfig
from litellm.llms.openai.image_edit.transformation import ImageEditRequestUtils
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import CreateVideoRequest
from litellm.types.router import GenericLiteLLMParams
from litellm.types.videos.main import VideoCreateOptionalRequestParams, VideoObject
from litellm.types.videos.utils import (
encode_video_id_with_provider,
extract_original_video_id,
)
# Resolve the logging and exception types only during static type checking;
# at runtime both aliases degrade to ``Any`` so the heavier modules are not
# imported here (``get_error_class`` re-imports BaseLLMException locally).
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    from ...base_llm.chat.transformation import BaseLLMException as _BaseLLMException

    LiteLLMLoggingObj = _LiteLLMLoggingObj
    BaseLLMException = _BaseLLMException
else:
    LiteLLMLoggingObj = Any
    BaseLLMException = Any
class OpenAIVideoConfig(BaseVideoConfig):
    """
    Configuration class for OpenAI video generation.

    Transforms requests/responses for the OpenAI `/videos` endpoints
    (create, remix, list, retrieve, delete, and content download). Video ids
    returned to callers are wrapped with the provider name via
    ``encode_video_id_with_provider`` and unwrapped again with
    ``extract_original_video_id`` before being sent upstream.
    """

    def __init__(self):
        super().__init__()

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the list of supported OpenAI parameters for video generation.
        """
        return [
            "model",
            "prompt",
            "input_reference",
            "seconds",
            "size",
            "user",
            "extra_headers",
        ]

    def map_openai_params(
        self,
        video_create_optional_params: VideoCreateOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """No mapping applied since inputs are in OpenAI spec already"""
        return dict(video_create_optional_params)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        litellm_params: Optional[GenericLiteLLMParams] = None,
    ) -> dict:
        """Attach the bearer Authorization header.

        API key resolution order: explicit ``api_key`` ->
        ``litellm_params.api_key`` -> ``litellm.api_key`` ->
        ``litellm.openai_key`` -> ``OPENAI_API_KEY`` env var.
        """
        # Use api_key from litellm_params if available, otherwise fall back to other sources
        if litellm_params and litellm_params.api_key:
            api_key = api_key or litellm_params.api_key
        api_key = (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for OpenAI video generation.
        """
        if api_base is None:
            api_base = "https://api.openai.com/v1"
        return f"{api_base.rstrip('/')}/videos"

    @staticmethod
    def _duration_usage(video_obj: VideoObject) -> Dict[str, Any]:
        """Derive a usage dict (``duration_seconds``) from ``video_obj.seconds``.

        The videos API does not return a usage object, so cost calculation
        relies on the clip duration when it is present and numeric. Shared by
        the create and remix response transforms, which previously duplicated
        this logic inline.
        """
        usage_data: Dict[str, Any] = {}
        if video_obj and getattr(video_obj, "seconds", None):
            try:
                usage_data["duration_seconds"] = float(video_obj.seconds)
            except (ValueError, TypeError):
                pass
        return usage_data

    def transform_video_create_request(
        self,
        model: str,
        prompt: str,
        api_base: str,
        video_create_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles, str]:
        """
        Transform the video creation request for OpenAI API.

        Returns a tuple of (form data without file fields, multipart files
        list, request url).
        """
        # Remove model/extra_headers/prompt from optional params — they are
        # handled explicitly below.
        video_create_optional_request_params = {
            k: v
            for k, v in video_create_optional_request_params.items()
            if k not in ["model", "extra_headers", "prompt"]
        }
        # Create the request data
        video_create_request = CreateVideoRequest(
            model=model, prompt=prompt, **video_create_optional_request_params
        )
        request_dict = cast(Dict, video_create_request)

        # input_reference is uploaded as a multipart file, not a form field.
        _input_reference = video_create_optional_request_params.get("input_reference")
        data_without_files = {
            k: v for k, v in request_dict.items() if k not in ["input_reference"]
        }
        files_list: List[Tuple[str, Any]] = []
        if _input_reference is not None:
            self._add_image_to_files(
                files_list=files_list,
                image=_input_reference,
                field_name="input_reference",
            )
        return data_without_files, files_list, api_base

    def transform_video_create_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict] = None,
    ) -> VideoObject:
        """Transform the OpenAI video creation response."""
        response_data = raw_response.json()
        video_obj = VideoObject(**response_data)  # type: ignore[arg-type]
        if custom_llm_provider and video_obj.id:
            video_obj.id = encode_video_id_with_provider(
                video_obj.id, custom_llm_provider, model
            )
        # The create API returns no usage; synthesize one from the duration.
        video_obj.usage = self._duration_usage(video_obj)
        return video_obj

    def transform_video_content_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        variant: Optional[str] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video content request for OpenAI API.

        OpenAI API expects the following request:
        - GET /v1/videos/{video_id}/content
        - GET /v1/videos/{video_id}/content?variant=thumbnail
        """
        original_video_id = extract_original_video_id(video_id)
        # Construct the URL for video content download
        url = f"{api_base.rstrip('/')}/{original_video_id}/content"
        if variant is not None:
            url = f"{url}?variant={variant}"
        # No additional data needed for GET content request
        data: Dict[str, Any] = {}
        return url, data

    def transform_video_remix_request(
        self,
        video_id: str,
        prompt: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video remix request for OpenAI API.

        OpenAI API expects the following request:
        - POST /v1/videos/{video_id}/remix
        """
        original_video_id = extract_original_video_id(video_id)
        # Construct the URL for video remix
        url = f"{api_base.rstrip('/')}/{original_video_id}/remix"
        # Prepare the request data
        data = {"prompt": prompt}
        # Add any extra body parameters
        if extra_body:
            data.update(extra_body)
        return url, data

    def transform_video_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> bytes:
        """Transform the OpenAI video content download response (raw bytes)."""
        return raw_response.content

    def transform_video_remix_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> VideoObject:
        """
        Transform the OpenAI video remix response.
        """
        response_data = raw_response.json()
        video_obj = VideoObject(**response_data)  # type: ignore[arg-type]
        if custom_llm_provider and video_obj.id:
            video_obj.id = encode_video_id_with_provider(
                video_obj.id, custom_llm_provider, None
            )
        # The remix API returns no usage; synthesize one from the duration.
        video_obj.usage = self._duration_usage(video_obj)
        return video_obj

    def transform_video_list_request(
        self,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        limit: Optional[int] = None,
        order: Optional[str] = None,
        extra_query: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video list request for OpenAI API.

        OpenAI API expects the following request:
        - GET /v1/videos
        """
        # Use the api_base directly for video list
        url = api_base
        # Prepare query parameters
        params = {}
        if after is not None:
            # Decode the wrapped video ID back to the original provider ID
            params["after"] = extract_original_video_id(after)
        if limit is not None:
            params["limit"] = str(limit)
        if order is not None:
            params["order"] = order
        # Add any extra query parameters
        if extra_query:
            params.update(extra_query)
        return url, params

    def transform_video_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Wrap every listed video id (and the pagination cursors) with the
        provider name so callers can round-trip ids through litellm.

        Note: return annotation corrected from ``Dict[str, str]`` — the body
        returns the parsed JSON payload, whose values are not all strings.
        """
        response_data = raw_response.json()
        if custom_llm_provider and "data" in response_data:
            for video_obj in response_data.get("data", []):
                if isinstance(video_obj, dict) and "id" in video_obj:
                    video_obj["id"] = encode_video_id_with_provider(
                        video_obj["id"],
                        custom_llm_provider,
                        video_obj.get("model"),
                    )
            # Encode pagination cursor IDs so they remain consistent
            # with the wrapped data[].id format
            data_list = response_data.get("data", [])
            if response_data.get("first_id"):
                first_model = None
                if data_list and isinstance(data_list[0], dict):
                    first_model = data_list[0].get("model")
                response_data["first_id"] = encode_video_id_with_provider(
                    response_data["first_id"],
                    custom_llm_provider,
                    first_model,
                )
            if response_data.get("last_id"):
                last_model = None
                if data_list and isinstance(data_list[-1], dict):
                    last_model = data_list[-1].get("model")
                response_data["last_id"] = encode_video_id_with_provider(
                    response_data["last_id"],
                    custom_llm_provider,
                    last_model,
                )
        return response_data

    def transform_video_delete_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the video delete request for OpenAI API.

        OpenAI API expects the following request:
        - DELETE /v1/videos/{video_id}
        """
        original_video_id = extract_original_video_id(video_id)
        # Construct the URL for video delete
        url = f"{api_base.rstrip('/')}/{original_video_id}"
        # No data needed for DELETE request
        data: Dict[str, Any] = {}
        return url, data

    def transform_video_delete_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> VideoObject:
        """
        Transform the OpenAI video delete response.
        """
        response_data = raw_response.json()
        # Transform the response data
        video_obj = VideoObject(**response_data)  # type: ignore[arg-type]
        return video_obj

    def transform_video_status_retrieve_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the OpenAI video retrieve request.
        """
        # Extract the original video_id (remove provider encoding if present)
        original_video_id = extract_original_video_id(video_id)
        # For video retrieve, we just need to construct the URL
        url = f"{api_base.rstrip('/')}/{original_video_id}"
        # No additional data needed for GET request
        data: Dict[str, Any] = {}
        return url, data

    def transform_video_status_retrieve_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> VideoObject:
        """
        Transform the OpenAI video retrieve response.
        """
        response_data = raw_response.json()
        # Transform the response data
        video_obj = VideoObject(**response_data)  # type: ignore[arg-type]
        if custom_llm_provider and video_obj.id:
            video_obj.id = encode_video_id_with_provider(
                video_obj.id, custom_llm_provider, None
            )
        return video_obj

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Build the provider error type for this config.

        Fixed to *return* the exception instead of raising it: the return
        annotation and every sibling config in this module expect
        ``raise self.get_error_class(...)`` at the call site.
        """
        from ...base_llm.chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def _add_image_to_files(
        self,
        files_list: List[Tuple[str, Any]],
        image: Any,
        field_name: str,
    ) -> None:
        """Add an image to the files list with appropriate content type"""
        image_content_type = ImageEditRequestUtils.get_image_content_type(image)
        if isinstance(image, BufferedReader):
            # Preserve the real filename for file handles.
            files_list.append((field_name, (image.name, image, image_content_type)))
        else:
            # Raw bytes/objects get a synthetic filename.
            files_list.append(
                (field_name, ("input_reference.png", image, image_content_type))
            )