from typing import ( TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, get_args, get_origin, ) import httpx from pydantic import fields as pyd_fields import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.core_helpers import process_response_headers from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import ( _safe_convert_created_field, ) from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import ( ResponseInputParam, ResponsesAPIOptionalRequestParams, ResponsesAPIResponse, ResponsesAPIStreamingResponse, ) from litellm.types.responses.main import DeleteResponseResult from litellm.types.router import GenericLiteLLMParams from litellm.types.utils import LlmProviders from ..common_utils import ( VolcEngineError, get_volcengine_base_url, get_volcengine_headers, ) if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj LiteLLMLoggingObj = _LiteLLMLoggingObj else: LiteLLMLoggingObj = Any class VolcEngineResponsesAPIConfig(OpenAIResponsesAPIConfig): _SUPPORTED_OPTIONAL_PARAMS: List[str] = [ # Doc-listed knobs "instructions", "max_output_tokens", "previous_response_id", "store", "reasoning", "stream", "temperature", "top_p", "text", "tools", "tool_choice", "max_tool_calls", "thinking", "caching", "expire_at", "context_management", # LiteLLM-internal metadata (not sent to provider) "metadata", # Request plumbing helpers "extra_headers", "extra_query", "extra_body", "timeout", ] @property def custom_llm_provider(self) -> LlmProviders: return LlmProviders.VOLCENGINE def get_supported_openai_params(self, model: str) -> list: """ Volcengine Responses API: only documented parameters are supported. """ supported = ["input", "model"] + list(self._SUPPORTED_OPTIONAL_PARAMS) # Do not advertise internal-only metadata to callers; we still accept and drop it before send. if "metadata" in supported: supported.remove("metadata") return supported def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> VolcEngineError: typed_headers: httpx.Headers = ( headers if isinstance(headers, httpx.Headers) else httpx.Headers(headers or {}) ) return VolcEngineError( status_code=status_code, message=error_message, headers=typed_headers, ) def validate_environment( self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams] ) -> dict: """ Build auth headers for Volcengine Responses API. """ if litellm_params is None: litellm_params = GenericLiteLLMParams() elif isinstance(litellm_params, dict): litellm_params = GenericLiteLLMParams(**litellm_params) api_key = ( litellm_params.api_key or litellm.api_key or get_secret_str("ARK_API_KEY") or get_secret_str("VOLCENGINE_API_KEY") ) if api_key is None: raise ValueError( "Volcengine API key is required. Set ARK_API_KEY / VOLCENGINE_API_KEY or pass api_key." ) return get_volcengine_headers(api_key=api_key, extra_headers=headers) def get_complete_url( self, api_base: Optional[str], litellm_params: dict, ) -> str: """ Construct Volcengine Responses API endpoint. """ base_url = ( api_base or litellm.api_base or get_secret_str("VOLCENGINE_API_BASE") or get_secret_str("ARK_API_BASE") or get_volcengine_base_url() ) base_url = base_url.rstrip("/") if base_url.endswith("/responses"): return base_url if base_url.endswith("/api/v3"): return f"{base_url}/responses" return f"{base_url}/api/v3/responses" def map_openai_params( self, response_api_optional_params: ResponsesAPIOptionalRequestParams, model: str, drop_params: bool, ) -> Dict: """ Volcengine Responses API aligns with OpenAI parameters. Remove parameters not supported by the public docs. """ params = { key: value for key, value in dict(response_api_optional_params).items() if key in self._SUPPORTED_OPTIONAL_PARAMS } # LiteLLM metadata is internal-only; don't send to provider params.pop("metadata", None) # Volcengine docs do not list parallel_tool_calls; drop it to avoid backend errors. if "parallel_tool_calls" in params: verbose_logger.debug( "Volcengine Responses API: dropping unsupported 'parallel_tool_calls' param." ) params.pop("parallel_tool_calls", None) return params def transform_responses_api_request( self, model: str, input: Union[str, ResponseInputParam], response_api_optional_request_params: Dict, litellm_params: GenericLiteLLMParams, headers: dict, ) -> Dict: """ Volcengine rejects any undocumented fields (including extra_body). Fail fast with clear errors and re-filter with the documented whitelist before delegating to the OpenAI base transformer. """ allowed = set(self._SUPPORTED_OPTIONAL_PARAMS) sanitized_optional = { k: v for k, v in response_api_optional_request_params.items() if k in allowed } # Ensure metadata never reaches provider sanitized_optional.pop("metadata", None) sanitized_optional.pop("parallel_tool_calls", None) # If extra_body is provided, filter its keys against the same allowlist to avoid # leaking unsupported params to the provider. if isinstance(sanitized_optional.get("extra_body"), dict): filtered_body = { k: v for k, v in sanitized_optional["extra_body"].items() if k in allowed } if filtered_body: sanitized_optional["extra_body"] = filtered_body else: sanitized_optional.pop("extra_body", None) return super().transform_responses_api_request( model=model, input=input, response_api_optional_request_params=sanitized_optional, litellm_params=litellm_params, headers=headers, ) def transform_streaming_response( self, model: str, parsed_chunk: dict, logging_obj: LiteLLMLoggingObj, ) -> ResponsesAPIStreamingResponse: """ Volcengine may omit required fields; auto-fill them using event model defaults. """ chunk = parsed_chunk # Patch missing response.output on response.* events if isinstance(chunk, dict): resp = chunk.get("response") if isinstance(resp, dict) and "output" not in resp: patched_chunk = dict(chunk) patched_resp = dict(resp) patched_resp["output"] = [] patched_chunk["response"] = patched_resp chunk = patched_chunk event_type = str(chunk.get("type")) if isinstance(chunk, dict) else None event_pydantic_model = OpenAIResponsesAPIConfig.get_event_model_class( event_type=event_type ) patched_chunk = self._fill_missing_fields(chunk, event_pydantic_model) return event_pydantic_model(**patched_chunk) def transform_response_api_response( self, model: str, raw_response: httpx.Response, logging_obj: LiteLLMLoggingObj, ) -> ResponsesAPIResponse: try: logging_obj.post_call( original_response=raw_response.text, additional_args={"complete_input_dict": {}}, ) raw_response_json = raw_response.json() if "created_at" in raw_response_json: raw_response_json["created_at"] = _safe_convert_created_field( raw_response_json["created_at"] ) except Exception: raise VolcEngineError( message=raw_response.text, status_code=raw_response.status_code ) raw_response_headers = dict(raw_response.headers) processed_headers = process_response_headers(raw_response_headers) try: response = ResponsesAPIResponse(**raw_response_json) except Exception: verbose_logger.debug( "Volcengine Responses API: falling back to model_construct for response parsing." ) response = ResponsesAPIResponse.model_construct(**raw_response_json) response._hidden_params["additional_headers"] = processed_headers response._hidden_params["headers"] = raw_response_headers return response ######################################################### ########## DELETE RESPONSE API TRANSFORMATION ############## ######################################################### def transform_delete_response_api_request( self, response_id: str, api_base: str, litellm_params: GenericLiteLLMParams, headers: dict, ) -> Tuple[str, Dict]: url = f"{api_base}/{response_id}" data: Dict = {} return url, data def transform_delete_response_api_response( self, raw_response: httpx.Response, logging_obj: LiteLLMLoggingObj, ) -> DeleteResponseResult: try: raw_response_json = raw_response.json() except Exception: raise VolcEngineError( message=raw_response.text, status_code=raw_response.status_code ) try: return DeleteResponseResult(**raw_response_json) except Exception: verbose_logger.debug( "Volcengine Responses API: falling back to model_construct for delete response parsing." ) return DeleteResponseResult.model_construct(**raw_response_json) ######################################################### ########## GET RESPONSE API TRANSFORMATION ############### ######################################################### def transform_get_response_api_request( self, response_id: str, api_base: str, litellm_params: GenericLiteLLMParams, headers: dict, ) -> Tuple[str, Dict]: url = f"{api_base}/{response_id}" data: Dict = {} return url, data def transform_get_response_api_response( self, raw_response: httpx.Response, logging_obj: LiteLLMLoggingObj, ) -> ResponsesAPIResponse: try: raw_response_json = raw_response.json() except Exception: raise VolcEngineError( message=raw_response.text, status_code=raw_response.status_code ) raw_response_headers = dict(raw_response.headers) processed_headers = process_response_headers(raw_response_headers) response = ResponsesAPIResponse(**raw_response_json) response._hidden_params["additional_headers"] = processed_headers response._hidden_params["headers"] = raw_response_headers return response ######################################################### ########## LIST INPUT ITEMS TRANSFORMATION ############# ######################################################### def transform_list_input_items_request( self, response_id: str, api_base: str, litellm_params: GenericLiteLLMParams, headers: dict, after: Optional[str] = None, before: Optional[str] = None, include: Optional[List[str]] = None, limit: int = 20, order: Literal["asc", "desc"] = "desc", ) -> Tuple[str, Dict]: url = f"{api_base}/{response_id}/input_items" params: Dict[str, Any] = {} if after is not None: params["after"] = after if before is not None: params["before"] = before if include: params["include"] = ",".join(include) if limit is not None: params["limit"] = limit if order is not None: params["order"] = order return url, params def transform_list_input_items_response( self, raw_response: httpx.Response, logging_obj: LiteLLMLoggingObj, ) -> Dict: try: return raw_response.json() except Exception: raise VolcEngineError( message=raw_response.text, status_code=raw_response.status_code ) ######################################################### ########## CANCEL RESPONSE API TRANSFORMATION ########## ######################################################### def transform_cancel_response_api_request( self, response_id: str, api_base: str, litellm_params: GenericLiteLLMParams, headers: dict, ) -> Tuple[str, Dict]: url = f"{api_base}/{response_id}/cancel" data: Dict = {} return url, data def transform_cancel_response_api_response( self, raw_response: httpx.Response, logging_obj: LiteLLMLoggingObj, ) -> ResponsesAPIResponse: try: raw_response_json = raw_response.json() except Exception: raise VolcEngineError( message=raw_response.text, status_code=raw_response.status_code ) raw_response_headers = dict(raw_response.headers) processed_headers = process_response_headers(raw_response_headers) response = ResponsesAPIResponse(**raw_response_json) response._hidden_params["additional_headers"] = processed_headers response._hidden_params["headers"] = raw_response_headers return response def should_fake_stream( self, model: Optional[str], stream: Optional[bool], custom_llm_provider: Optional[str] = None, ) -> bool: """ Volcengine Responses API supports native streaming; never fall back to fake stream. """ return False @staticmethod def _fill_missing_fields(chunk: Any, event_model: Any) -> Dict[str, Any]: """ Heuristically fill missing required fields with safe defaults based on the event model's field annotations. This keeps parsing tolerant of providers that omit non-essential fields. """ if not isinstance(chunk, dict) or event_model is None: return chunk patched: Dict[str, Any] = dict(chunk) fields_map = getattr(event_model, "model_fields", {}) or {} for name, field in fields_map.items(): if name in patched: patched[name] = VolcEngineResponsesAPIConfig._maybe_fill_nested( patched[name], field.annotation ) continue # Explicit default or factory if ( field.default is not pyd_fields.PydanticUndefined and field.default is not None ): patched[name] = field.default continue if ( field.default_factory is not None and field.default_factory is not pyd_fields.PydanticUndefined ): patched[name] = field.default_factory() continue # Heuristic defaults for missing required fields patched[name] = VolcEngineResponsesAPIConfig._default_for_annotation( field.annotation ) return patched @staticmethod def _default_for_annotation(annotation: Any) -> Any: origin = get_origin(annotation) args = get_args(annotation) if annotation is int: return 0 if annotation is list or origin is list: return [] if origin is Union: # Prefer empty list when any option is a list if any((arg is list or get_origin(arg) is list) for arg in args): return [] if type(None) in args: return None if origin is Union and type(None) in args: return None # Fallback to None when no safer guess exists return None @staticmethod def _maybe_fill_nested(value: Any, annotation: Any) -> Any: """ Recursively fill nested dict/list structures based on the annotated model. """ model_cls = VolcEngineResponsesAPIConfig._pick_model_class(annotation, value) args = get_args(annotation) if isinstance(value, dict) and model_cls is not None: return VolcEngineResponsesAPIConfig._fill_missing_fields(value, model_cls) if isinstance(value, list): # Attempt to fill list elements if we know the element annotation elem_ann: Any = args[0] if args else None if elem_ann is not None: return [ VolcEngineResponsesAPIConfig._maybe_fill_nested(v, elem_ann) for v in value ] return value @staticmethod def _pick_model_class(annotation: Any, value: Any) -> Optional[Any]: """ Choose the best-matching Pydantic model class for a nested dict. """ candidates: List[Any] = [] origin = get_origin(annotation) if hasattr(annotation, "model_fields"): candidates.append(annotation) if origin is Union: for arg in get_args(annotation): if hasattr(arg, "model_fields"): candidates.append(arg) if not candidates: return None # Try to match by literal "type" field when available if isinstance(value, dict): v_type = value.get("type") for candidate in candidates: try: type_field = candidate.model_fields.get("type") if type_field is None: continue literal_ann = type_field.annotation if get_origin(literal_ann) is Literal: literal_values = get_args(literal_ann) if v_type in literal_values: return candidate except Exception: continue # Fall back to the first candidate return candidates[0] def supports_native_websocket(self) -> bool: """VolcEngine does not support native WebSocket for Responses API""" return False