chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,569 @@
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from pydantic import fields as pyd_fields
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import process_response_headers
|
||||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
_safe_convert_created_field,
|
||||
)
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
ResponseInputParam,
|
||||
ResponsesAPIOptionalRequestParams,
|
||||
ResponsesAPIResponse,
|
||||
ResponsesAPIStreamingResponse,
|
||||
)
|
||||
from litellm.types.responses.main import DeleteResponseResult
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
from ..common_utils import (
|
||||
VolcEngineError,
|
||||
get_volcengine_base_url,
|
||||
get_volcengine_headers,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VolcEngineResponsesAPIConfig(OpenAIResponsesAPIConfig):
|
||||
_SUPPORTED_OPTIONAL_PARAMS: List[str] = [
|
||||
# Doc-listed knobs
|
||||
"instructions",
|
||||
"max_output_tokens",
|
||||
"previous_response_id",
|
||||
"store",
|
||||
"reasoning",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"text",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"max_tool_calls",
|
||||
"thinking",
|
||||
"caching",
|
||||
"expire_at",
|
||||
"context_management",
|
||||
# LiteLLM-internal metadata (not sent to provider)
|
||||
"metadata",
|
||||
# Request plumbing helpers
|
||||
"extra_headers",
|
||||
"extra_query",
|
||||
"extra_body",
|
||||
"timeout",
|
||||
]
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> LlmProviders:
|
||||
return LlmProviders.VOLCENGINE
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
Volcengine Responses API: only documented parameters are supported.
|
||||
"""
|
||||
supported = ["input", "model"] + list(self._SUPPORTED_OPTIONAL_PARAMS)
|
||||
# Do not advertise internal-only metadata to callers; we still accept and drop it before send.
|
||||
if "metadata" in supported:
|
||||
supported.remove("metadata")
|
||||
return supported
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> VolcEngineError:
|
||||
typed_headers: httpx.Headers = (
|
||||
headers
|
||||
if isinstance(headers, httpx.Headers)
|
||||
else httpx.Headers(headers or {})
|
||||
)
|
||||
return VolcEngineError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=typed_headers,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
"""
|
||||
Build auth headers for Volcengine Responses API.
|
||||
"""
|
||||
if litellm_params is None:
|
||||
litellm_params = GenericLiteLLMParams()
|
||||
elif isinstance(litellm_params, dict):
|
||||
litellm_params = GenericLiteLLMParams(**litellm_params)
|
||||
|
||||
api_key = (
|
||||
litellm_params.api_key
|
||||
or litellm.api_key
|
||||
or get_secret_str("ARK_API_KEY")
|
||||
or get_secret_str("VOLCENGINE_API_KEY")
|
||||
)
|
||||
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"Volcengine API key is required. Set ARK_API_KEY / VOLCENGINE_API_KEY or pass api_key."
|
||||
)
|
||||
|
||||
return get_volcengine_headers(api_key=api_key, extra_headers=headers)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Construct Volcengine Responses API endpoint.
|
||||
"""
|
||||
base_url = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("VOLCENGINE_API_BASE")
|
||||
or get_secret_str("ARK_API_BASE")
|
||||
or get_volcengine_base_url()
|
||||
)
|
||||
|
||||
base_url = base_url.rstrip("/")
|
||||
|
||||
if base_url.endswith("/responses"):
|
||||
return base_url
|
||||
if base_url.endswith("/api/v3"):
|
||||
return f"{base_url}/responses"
|
||||
return f"{base_url}/api/v3/responses"
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
response_api_optional_params: ResponsesAPIOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
"""
|
||||
Volcengine Responses API aligns with OpenAI parameters.
|
||||
Remove parameters not supported by the public docs.
|
||||
"""
|
||||
params = {
|
||||
key: value
|
||||
for key, value in dict(response_api_optional_params).items()
|
||||
if key in self._SUPPORTED_OPTIONAL_PARAMS
|
||||
}
|
||||
|
||||
# LiteLLM metadata is internal-only; don't send to provider
|
||||
params.pop("metadata", None)
|
||||
|
||||
# Volcengine docs do not list parallel_tool_calls; drop it to avoid backend errors.
|
||||
if "parallel_tool_calls" in params:
|
||||
verbose_logger.debug(
|
||||
"Volcengine Responses API: dropping unsupported 'parallel_tool_calls' param."
|
||||
)
|
||||
params.pop("parallel_tool_calls", None)
|
||||
|
||||
return params
|
||||
|
||||
def transform_responses_api_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, ResponseInputParam],
|
||||
response_api_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Dict:
|
||||
"""
|
||||
Volcengine rejects any undocumented fields (including extra_body). Fail fast
|
||||
with clear errors and re-filter with the documented whitelist before delegating
|
||||
to the OpenAI base transformer.
|
||||
"""
|
||||
allowed = set(self._SUPPORTED_OPTIONAL_PARAMS)
|
||||
|
||||
sanitized_optional = {
|
||||
k: v
|
||||
for k, v in response_api_optional_request_params.items()
|
||||
if k in allowed
|
||||
}
|
||||
# Ensure metadata never reaches provider
|
||||
sanitized_optional.pop("metadata", None)
|
||||
sanitized_optional.pop("parallel_tool_calls", None)
|
||||
|
||||
# If extra_body is provided, filter its keys against the same allowlist to avoid
|
||||
# leaking unsupported params to the provider.
|
||||
if isinstance(sanitized_optional.get("extra_body"), dict):
|
||||
filtered_body = {
|
||||
k: v
|
||||
for k, v in sanitized_optional["extra_body"].items()
|
||||
if k in allowed
|
||||
}
|
||||
if filtered_body:
|
||||
sanitized_optional["extra_body"] = filtered_body
|
||||
else:
|
||||
sanitized_optional.pop("extra_body", None)
|
||||
|
||||
return super().transform_responses_api_request(
|
||||
model=model,
|
||||
input=input,
|
||||
response_api_optional_request_params=sanitized_optional,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def transform_streaming_response(
|
||||
self,
|
||||
model: str,
|
||||
parsed_chunk: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIStreamingResponse:
|
||||
"""
|
||||
Volcengine may omit required fields; auto-fill them using event model defaults.
|
||||
"""
|
||||
chunk = parsed_chunk
|
||||
|
||||
# Patch missing response.output on response.* events
|
||||
if isinstance(chunk, dict):
|
||||
resp = chunk.get("response")
|
||||
if isinstance(resp, dict) and "output" not in resp:
|
||||
patched_chunk = dict(chunk)
|
||||
patched_resp = dict(resp)
|
||||
patched_resp["output"] = []
|
||||
patched_chunk["response"] = patched_resp
|
||||
chunk = patched_chunk
|
||||
|
||||
event_type = str(chunk.get("type")) if isinstance(chunk, dict) else None
|
||||
event_pydantic_model = OpenAIResponsesAPIConfig.get_event_model_class(
|
||||
event_type=event_type
|
||||
)
|
||||
|
||||
patched_chunk = self._fill_missing_fields(chunk, event_pydantic_model)
|
||||
|
||||
return event_pydantic_model(**patched_chunk)
|
||||
|
||||
def transform_response_api_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIResponse:
|
||||
try:
|
||||
logging_obj.post_call(
|
||||
original_response=raw_response.text,
|
||||
additional_args={"complete_input_dict": {}},
|
||||
)
|
||||
raw_response_json = raw_response.json()
|
||||
if "created_at" in raw_response_json:
|
||||
raw_response_json["created_at"] = _safe_convert_created_field(
|
||||
raw_response_json["created_at"]
|
||||
)
|
||||
except Exception:
|
||||
raise VolcEngineError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
raw_response_headers = dict(raw_response.headers)
|
||||
processed_headers = process_response_headers(raw_response_headers)
|
||||
|
||||
try:
|
||||
response = ResponsesAPIResponse(**raw_response_json)
|
||||
except Exception:
|
||||
verbose_logger.debug(
|
||||
"Volcengine Responses API: falling back to model_construct for response parsing."
|
||||
)
|
||||
response = ResponsesAPIResponse.model_construct(**raw_response_json)
|
||||
|
||||
response._hidden_params["additional_headers"] = processed_headers
|
||||
response._hidden_params["headers"] = raw_response_headers
|
||||
return response
|
||||
|
||||
#########################################################
|
||||
########## DELETE RESPONSE API TRANSFORMATION ##############
|
||||
#########################################################
|
||||
def transform_delete_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
url = f"{api_base}/{response_id}"
|
||||
data: Dict = {}
|
||||
return url, data
|
||||
|
||||
def transform_delete_response_api_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> DeleteResponseResult:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise VolcEngineError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
try:
|
||||
return DeleteResponseResult(**raw_response_json)
|
||||
except Exception:
|
||||
verbose_logger.debug(
|
||||
"Volcengine Responses API: falling back to model_construct for delete response parsing."
|
||||
)
|
||||
return DeleteResponseResult.model_construct(**raw_response_json)
|
||||
|
||||
#########################################################
|
||||
########## GET RESPONSE API TRANSFORMATION ###############
|
||||
#########################################################
|
||||
def transform_get_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
url = f"{api_base}/{response_id}"
|
||||
data: Dict = {}
|
||||
return url, data
|
||||
|
||||
def transform_get_response_api_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIResponse:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise VolcEngineError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
raw_response_headers = dict(raw_response.headers)
|
||||
processed_headers = process_response_headers(raw_response_headers)
|
||||
|
||||
response = ResponsesAPIResponse(**raw_response_json)
|
||||
response._hidden_params["additional_headers"] = processed_headers
|
||||
response._hidden_params["headers"] = raw_response_headers
|
||||
return response
|
||||
|
||||
#########################################################
|
||||
########## LIST INPUT ITEMS TRANSFORMATION #############
|
||||
#########################################################
|
||||
def transform_list_input_items_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
) -> Tuple[str, Dict]:
|
||||
url = f"{api_base}/{response_id}/input_items"
|
||||
params: Dict[str, Any] = {}
|
||||
if after is not None:
|
||||
params["after"] = after
|
||||
if before is not None:
|
||||
params["before"] = before
|
||||
if include:
|
||||
params["include"] = ",".join(include)
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if order is not None:
|
||||
params["order"] = order
|
||||
return url, params
|
||||
|
||||
def transform_list_input_items_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> Dict:
|
||||
try:
|
||||
return raw_response.json()
|
||||
except Exception:
|
||||
raise VolcEngineError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
#########################################################
|
||||
########## CANCEL RESPONSE API TRANSFORMATION ##########
|
||||
#########################################################
|
||||
def transform_cancel_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
url = f"{api_base}/{response_id}/cancel"
|
||||
data: Dict = {}
|
||||
return url, data
|
||||
|
||||
def transform_cancel_response_api_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIResponse:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise VolcEngineError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
raw_response_headers = dict(raw_response.headers)
|
||||
processed_headers = process_response_headers(raw_response_headers)
|
||||
|
||||
response = ResponsesAPIResponse(**raw_response_json)
|
||||
response._hidden_params["additional_headers"] = processed_headers
|
||||
response._hidden_params["headers"] = raw_response_headers
|
||||
return response
|
||||
|
||||
def should_fake_stream(
|
||||
self,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Volcengine Responses API supports native streaming; never fall back to fake stream.
|
||||
"""
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _fill_missing_fields(chunk: Any, event_model: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Heuristically fill missing required fields with safe defaults based on the
|
||||
event model's field annotations. This keeps parsing tolerant of providers that
|
||||
omit non-essential fields.
|
||||
"""
|
||||
if not isinstance(chunk, dict) or event_model is None:
|
||||
return chunk
|
||||
|
||||
patched: Dict[str, Any] = dict(chunk)
|
||||
fields_map = getattr(event_model, "model_fields", {}) or {}
|
||||
|
||||
for name, field in fields_map.items():
|
||||
if name in patched:
|
||||
patched[name] = VolcEngineResponsesAPIConfig._maybe_fill_nested(
|
||||
patched[name], field.annotation
|
||||
)
|
||||
continue
|
||||
|
||||
# Explicit default or factory
|
||||
if (
|
||||
field.default is not pyd_fields.PydanticUndefined
|
||||
and field.default is not None
|
||||
):
|
||||
patched[name] = field.default
|
||||
continue
|
||||
if (
|
||||
field.default_factory is not None
|
||||
and field.default_factory is not pyd_fields.PydanticUndefined
|
||||
):
|
||||
patched[name] = field.default_factory()
|
||||
continue
|
||||
|
||||
# Heuristic defaults for missing required fields
|
||||
patched[name] = VolcEngineResponsesAPIConfig._default_for_annotation(
|
||||
field.annotation
|
||||
)
|
||||
|
||||
return patched
|
||||
|
||||
@staticmethod
|
||||
def _default_for_annotation(annotation: Any) -> Any:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if annotation is int:
|
||||
return 0
|
||||
if annotation is list or origin is list:
|
||||
return []
|
||||
if origin is Union:
|
||||
# Prefer empty list when any option is a list
|
||||
if any((arg is list or get_origin(arg) is list) for arg in args):
|
||||
return []
|
||||
if type(None) in args:
|
||||
return None
|
||||
if origin is Union and type(None) in args:
|
||||
return None
|
||||
|
||||
# Fallback to None when no safer guess exists
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _maybe_fill_nested(value: Any, annotation: Any) -> Any:
|
||||
"""
|
||||
Recursively fill nested dict/list structures based on the annotated model.
|
||||
"""
|
||||
model_cls = VolcEngineResponsesAPIConfig._pick_model_class(annotation, value)
|
||||
args = get_args(annotation)
|
||||
|
||||
if isinstance(value, dict) and model_cls is not None:
|
||||
return VolcEngineResponsesAPIConfig._fill_missing_fields(value, model_cls)
|
||||
|
||||
if isinstance(value, list):
|
||||
# Attempt to fill list elements if we know the element annotation
|
||||
elem_ann: Any = args[0] if args else None
|
||||
if elem_ann is not None:
|
||||
return [
|
||||
VolcEngineResponsesAPIConfig._maybe_fill_nested(v, elem_ann)
|
||||
for v in value
|
||||
]
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _pick_model_class(annotation: Any, value: Any) -> Optional[Any]:
|
||||
"""
|
||||
Choose the best-matching Pydantic model class for a nested dict.
|
||||
"""
|
||||
candidates: List[Any] = []
|
||||
origin = get_origin(annotation)
|
||||
|
||||
if hasattr(annotation, "model_fields"):
|
||||
candidates.append(annotation)
|
||||
if origin is Union:
|
||||
for arg in get_args(annotation):
|
||||
if hasattr(arg, "model_fields"):
|
||||
candidates.append(arg)
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Try to match by literal "type" field when available
|
||||
if isinstance(value, dict):
|
||||
v_type = value.get("type")
|
||||
for candidate in candidates:
|
||||
try:
|
||||
type_field = candidate.model_fields.get("type")
|
||||
if type_field is None:
|
||||
continue
|
||||
literal_ann = type_field.annotation
|
||||
if get_origin(literal_ann) is Literal:
|
||||
literal_values = get_args(literal_ann)
|
||||
if v_type in literal_values:
|
||||
return candidate
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Fall back to the first candidate
|
||||
return candidates[0]
|
||||
|
||||
def supports_native_websocket(self) -> bool:
|
||||
"""VolcEngine does not support native WebSocket for Responses API"""
|
||||
return False
|
||||
Reference in New Issue
Block a user