chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Volcengine LLM Provider
|
||||
Support for Volcengine (ByteDance) chat, embedding, and responses models.
|
||||
"""
|
||||
|
||||
from .chat.transformation import VolcEngineChatConfig
|
||||
from .common_utils import (
|
||||
VolcEngineError,
|
||||
get_volcengine_base_url,
|
||||
get_volcengine_headers,
|
||||
)
|
||||
from .embedding import VolcEngineEmbeddingConfig
|
||||
from .responses.transformation import VolcEngineResponsesAPIConfig
|
||||
|
||||
# For backward compatibility, keep the old class name
|
||||
VolcEngineConfig = VolcEngineChatConfig
|
||||
|
||||
__all__ = [
|
||||
"VolcEngineChatConfig",
|
||||
"VolcEngineConfig", # backward compatibility
|
||||
"VolcEngineEmbeddingConfig",
|
||||
"VolcEngineResponsesAPIConfig",
|
||||
"VolcEngineError",
|
||||
"get_volcengine_base_url",
|
||||
"get_volcengine_headers",
|
||||
]
|
||||
@@ -0,0 +1,109 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
|
||||
|
||||
|
||||
class VolcEngineChatConfig(OpenAILikeChatConfig):
    """
    Configuration for Volcengine (ByteDance Ark) chat-completion models.

    Reference: https://www.volcengine.com/docs/82379/1494384
    """

    # Supported OpenAI-style knobs, mirrored as class attributes per the
    # litellm config convention so get_config() can discover them.
    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        # litellm config convention: mirror every non-None constructor
        # argument onto the class so get_config() picks it up.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the effective config dict (delegates to the base class)."""
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI parameters accepted by Volcengine chat models."""
        return [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "top_logprobs",
            "max_completion_tokens",
            "max_tokens",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
            "function_call",
            "functions",
            "max_retries",
            "extra_headers",
            "thinking",
        ]  # works across all models

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
        replace_max_completion_tokens_with_max_tokens: bool = True,
    ) -> dict:
        """
        Map OpenAI-style params to the Volcengine request format.

        Delegates to the OpenAI-like base mapping, then relocates the
        non-standard ``thinking`` parameter into ``extra_body``.
        """
        optional_params = super().map_openai_params(
            non_default_params,
            optional_params,
            model,
            drop_params,
            replace_max_completion_tokens_with_max_tokens,
        )

        if "thinking" in optional_params:
            # Volcengine models use different defaults for `thinking`.
            # Reference: https://www.volcengine.com/docs/82379/1449737#0002
            thinking_value = optional_params.pop("thinking")

            # Forward `thinking` via extra_body only when it is a dict whose
            # "type" is one of the documented legal values; otherwise it is
            # dropped silently (isinstance already rejects None).
            if isinstance(thinking_value, dict) and thinking_value.get(
                "type", None
            ) in ["enabled", "disabled", "auto"]:
                optional_params.setdefault("extra_body", {})[
                    "thinking"
                ] = thinking_value
        return optional_params
|
||||
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Common utilities for Volcengine LLM provider
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
|
||||
class VolcEngineError(BaseLLMException):
    """Exception raised for errors returned by the Volcengine provider."""

    def __init__(
        self, status_code: int, message: str, headers: Optional[httpx.Headers] = None
    ):
        # Normalize missing headers to an empty httpx.Headers object so the
        # attribute is always usable.
        if headers is None:
            headers = httpx.Headers()
        self.status_code = status_code
        self.message = message
        self.headers = headers
        super().__init__(
            status_code=status_code, message=message, headers=dict(self.headers)
        )
|
||||
|
||||
|
||||
def get_volcengine_base_url(api_base: Optional[str] = None) -> str:
    """
    Resolve the base URL for Volcengine API calls.

    Args:
        api_base: Optional custom API base URL

    Returns:
        The caller-supplied base URL when truthy, otherwise the default
        Volcengine (Ark) endpoint.
    """
    default_base = "https://ark.cn-beijing.volces.com"
    return api_base or default_base
|
||||
|
||||
|
||||
def get_volcengine_headers(api_key: str, extra_headers: Optional[dict] = None) -> dict:
    """
    Build request headers for Volcengine API calls.

    Args:
        api_key: The API key for authentication
        extra_headers: Optional additional headers

    Returns:
        Dictionary of headers; caller-supplied extras override the defaults.
    """
    merged: dict = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    merged.update(extra_headers or {})
    return merged
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Volcengine Embedding Module
|
||||
"""
|
||||
|
||||
from .transformation import VolcEngineEmbeddingConfig
|
||||
|
||||
__all__ = ["VolcEngineEmbeddingConfig"]
|
||||
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Volcengine Embedding Transformation
|
||||
Transforms OpenAI embedding requests to Volcengine format
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
import httpx
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from ..common_utils import get_volcengine_base_url, get_volcengine_headers
|
||||
|
||||
|
||||
class VolcEngineEmbeddingConfig(BaseEmbeddingConfig):
    """
    Configuration class for Volcengine embedding models.
    Reference: https://ark.cn-beijing.volces.com/api/v3/embeddings
    """

    def __init__(
        self,
        encoding_format: Optional[str] = None,
    ) -> None:
        # litellm config convention: mirror every non-None constructor
        # argument onto the class so get_config() can discover it.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # Delegates to the base-class config collector.
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        Get the list of OpenAI parameters supported by Volcengine embedding models.

        Args:
            model: The model name

        Returns:
            List of supported parameter names
        """
        return [
            "encoding_format",
            "user",
            "extra_headers",
        ]

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for volcengine embedding API calls.

        Args:
            api_base: Optional custom API base URL
            api_key: API key (not used for URL construction)
            model: Model name (not used for URL construction)
            optional_params: Optional parameters (not used for URL construction)
            litellm_params: LiteLLM parameters (not used for URL construction)
            stream: Stream parameter (not used for URL construction)

        Returns:
            Complete URL for the embedding API endpoint
        """
        base_url = get_volcengine_base_url(api_base)
        # Construct the complete URL with /embeddings endpoint.
        # Avoid duplicating the version prefix when the caller already
        # supplied a base ending in /api/v3.
        if base_url.endswith("/api/v3"):
            return f"{base_url}/embeddings"
        else:
            return f"{base_url}/api/v3/embeddings"

    def map_openai_params(
        self,
        non_default_params: Dict[str, Any],
        optional_params: Dict[str, Any],
        model: str,
        drop_params: bool,
    ) -> Dict[str, Any]:
        """
        Map OpenAI embedding parameters to Volcengine format.

        Args:
            non_default_params: Parameters that are not default values
            optional_params: Optional parameters dict to update
            model: The model name
            drop_params: Whether to drop unsupported parameters

        Returns:
            Updated optional_params dict

        Raises:
            ValueError: when an unsupported parameter or encoding_format is
                passed and drop_params is False.
        """
        for param, value in non_default_params.items():
            if param == "encoding_format":
                # Volcengine supports: float, base64, null
                if value in ["float", "base64", None]:
                    optional_params["encoding_format"] = value
                else:
                    # Unsupported format: raise unless the caller opted into
                    # silently dropping bad params.
                    if not drop_params:
                        raise ValueError(
                            f"Unsupported encoding_format: {value}. Volcengine supports: float, base64, null"
                        )
            elif param == "user":
                # Keep user parameter as-is
                optional_params["user"] = value
            elif param in self.get_supported_openai_params(model):
                optional_params[param] = value
            elif not drop_params:
                raise ValueError(f"Unsupported parameter for Volcengine: {param}")

        return optional_params

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Transform embedding request to Volcengine format"""
        # Prepare request data (only the JSON body, not the full request).
        # A single string input is normalized into a one-element list.
        data = {
            "model": model,
            "input": input if isinstance(input, list) else [input],
        }

        # Add optional parameters from optional_params, skipping explicit
        # None values so they are not serialized into the body.
        if "encoding_format" in optional_params:
            encoding_format = optional_params["encoding_format"]
            if encoding_format is not None:
                data["encoding_format"] = encoding_format

        if "user" in optional_params:
            user = optional_params["user"]
            if user is not None:
                data["user"] = user

        return data

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """Transform Volcengine response to EmbeddingResponse"""
        try:
            response_json = raw_response.json()
        except Exception as e:
            raise ValueError(f"Failed to parse Volcengine response as JSON: {str(e)}")

        # Volcengine response format matches OpenAI format closely
        # Just need to ensure all required fields are present
        transformed_response = {
            "object": "list",
            "data": response_json.get("data", []),
            "model": response_json.get("model", model),
            "usage": response_json.get("usage", {}),
        }

        # Add id if present
        if "id" in response_json:
            transformed_response["id"] = response_json["id"]

        # Create EmbeddingResponse from transformed data
        return EmbeddingResponse(**transformed_response)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate environment and return headers"""
        # Get Volcengine headers; an API key is mandatory for auth.
        if api_key is None:
            raise ValueError("api_key is required for Volcengine authentication")
        volcengine_headers = get_volcengine_headers(api_key)
        # Provider headers take precedence over any caller-supplied ones.
        return {**headers, **volcengine_headers}

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Get error class for Volcengine errors"""
        from ..common_utils import VolcEngineError

        # Convert dict to httpx.Headers if needed
        if isinstance(headers, dict):
            headers = httpx.Headers(headers)
        return VolcEngineError(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
@@ -0,0 +1,569 @@
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from pydantic import fields as pyd_fields
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import process_response_headers
|
||||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
_safe_convert_created_field,
|
||||
)
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
ResponseInputParam,
|
||||
ResponsesAPIOptionalRequestParams,
|
||||
ResponsesAPIResponse,
|
||||
ResponsesAPIStreamingResponse,
|
||||
)
|
||||
from litellm.types.responses.main import DeleteResponseResult
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
from ..common_utils import (
|
||||
VolcEngineError,
|
||||
get_volcengine_base_url,
|
||||
get_volcengine_headers,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VolcEngineResponsesAPIConfig(OpenAIResponsesAPIConfig):
    """
    Responses API configuration for Volcengine (ByteDance Ark).

    Volcengine rejects undocumented request fields, so every outgoing request
    is filtered against the `_SUPPORTED_OPTIONAL_PARAMS` allowlist.  Streaming
    and non-streaming responses may omit fields the OpenAI pydantic models
    require, so parsing heuristically back-fills safe defaults.
    """

    # Allowlist of optional request params accepted by the provider (plus a
    # few LiteLLM-internal / plumbing keys that are stripped before sending).
    _SUPPORTED_OPTIONAL_PARAMS: List[str] = [
        # Doc-listed knobs
        "instructions",
        "max_output_tokens",
        "previous_response_id",
        "store",
        "reasoning",
        "stream",
        "temperature",
        "top_p",
        "text",
        "tools",
        "tool_choice",
        "max_tool_calls",
        "thinking",
        "caching",
        "expire_at",
        "context_management",
        # LiteLLM-internal metadata (not sent to provider)
        "metadata",
        # Request plumbing helpers
        "extra_headers",
        "extra_query",
        "extra_body",
        "timeout",
    ]

    @property
    def custom_llm_provider(self) -> LlmProviders:
        # Identifies this config's provider within litellm routing.
        return LlmProviders.VOLCENGINE

    def get_supported_openai_params(self, model: str) -> list:
        """
        Volcengine Responses API: only documented parameters are supported.
        """
        supported = ["input", "model"] + list(self._SUPPORTED_OPTIONAL_PARAMS)
        # Do not advertise internal-only metadata to callers; we still accept and drop it before send.
        if "metadata" in supported:
            supported.remove("metadata")
        return supported

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> VolcEngineError:
        """Wrap a provider error into a VolcEngineError with typed headers."""
        typed_headers: httpx.Headers = (
            headers
            if isinstance(headers, httpx.Headers)
            else httpx.Headers(headers or {})
        )
        return VolcEngineError(
            status_code=status_code,
            message=error_message,
            headers=typed_headers,
        )

    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Build auth headers for Volcengine Responses API.

        Raises:
            ValueError: when no API key can be resolved from params, the
                litellm module, or the ARK_API_KEY / VOLCENGINE_API_KEY env.
        """
        if litellm_params is None:
            litellm_params = GenericLiteLLMParams()
        elif isinstance(litellm_params, dict):
            litellm_params = GenericLiteLLMParams(**litellm_params)

        # Resolution order: explicit param > global litellm key > env secrets.
        api_key = (
            litellm_params.api_key
            or litellm.api_key
            or get_secret_str("ARK_API_KEY")
            or get_secret_str("VOLCENGINE_API_KEY")
        )

        if api_key is None:
            raise ValueError(
                "Volcengine API key is required. Set ARK_API_KEY / VOLCENGINE_API_KEY or pass api_key."
            )

        return get_volcengine_headers(api_key=api_key, extra_headers=headers)

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Construct Volcengine Responses API endpoint.
        """
        # Resolution order: explicit param > global litellm base > env > default.
        base_url = (
            api_base
            or litellm.api_base
            or get_secret_str("VOLCENGINE_API_BASE")
            or get_secret_str("ARK_API_BASE")
            or get_volcengine_base_url()
        )

        base_url = base_url.rstrip("/")

        # Avoid duplicating path segments the caller may already have included.
        if base_url.endswith("/responses"):
            return base_url
        if base_url.endswith("/api/v3"):
            return f"{base_url}/responses"
        return f"{base_url}/api/v3/responses"

    def map_openai_params(
        self,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """
        Volcengine Responses API aligns with OpenAI parameters.
        Remove parameters not supported by the public docs.
        """
        params = {
            key: value
            for key, value in dict(response_api_optional_params).items()
            if key in self._SUPPORTED_OPTIONAL_PARAMS
        }

        # LiteLLM metadata is internal-only; don't send to provider
        params.pop("metadata", None)

        # Volcengine docs do not list parallel_tool_calls; drop it to avoid backend errors.
        if "parallel_tool_calls" in params:
            verbose_logger.debug(
                "Volcengine Responses API: dropping unsupported 'parallel_tool_calls' param."
            )
            params.pop("parallel_tool_calls", None)

        return params

    def transform_responses_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Volcengine rejects any undocumented fields (including extra_body). Fail fast
        with clear errors and re-filter with the documented whitelist before delegating
        to the OpenAI base transformer.
        """
        allowed = set(self._SUPPORTED_OPTIONAL_PARAMS)

        sanitized_optional = {
            k: v
            for k, v in response_api_optional_request_params.items()
            if k in allowed
        }
        # Ensure metadata never reaches provider
        sanitized_optional.pop("metadata", None)
        sanitized_optional.pop("parallel_tool_calls", None)

        # If extra_body is provided, filter its keys against the same allowlist to avoid
        # leaking unsupported params to the provider.
        if isinstance(sanitized_optional.get("extra_body"), dict):
            filtered_body = {
                k: v
                for k, v in sanitized_optional["extra_body"].items()
                if k in allowed
            }
            if filtered_body:
                sanitized_optional["extra_body"] = filtered_body
            else:
                # Drop an extra_body that filtered down to nothing.
                sanitized_optional.pop("extra_body", None)

        return super().transform_responses_api_request(
            model=model,
            input=input,
            response_api_optional_request_params=sanitized_optional,
            litellm_params=litellm_params,
            headers=headers,
        )

    def transform_streaming_response(
        self,
        model: str,
        parsed_chunk: dict,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIStreamingResponse:
        """
        Volcengine may omit required fields; auto-fill them using event model defaults.
        """
        chunk = parsed_chunk

        # Patch missing response.output on response.* events
        if isinstance(chunk, dict):
            resp = chunk.get("response")
            if isinstance(resp, dict) and "output" not in resp:
                # Copy before mutating so the caller's parsed chunk is untouched.
                patched_chunk = dict(chunk)
                patched_resp = dict(resp)
                patched_resp["output"] = []
                patched_chunk["response"] = patched_resp
                chunk = patched_chunk

        # Pick the pydantic event model matching the chunk's "type" field.
        event_type = str(chunk.get("type")) if isinstance(chunk, dict) else None
        event_pydantic_model = OpenAIResponsesAPIConfig.get_event_model_class(
            event_type=event_type
        )

        patched_chunk = self._fill_missing_fields(chunk, event_pydantic_model)

        return event_pydantic_model(**patched_chunk)

    def transform_response_api_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """Parse a non-streaming Responses API response, tolerating schema drift."""
        try:
            logging_obj.post_call(
                original_response=raw_response.text,
                additional_args={"complete_input_dict": {}},
            )
            raw_response_json = raw_response.json()
            # Normalize created_at (provider may send a non-int timestamp).
            if "created_at" in raw_response_json:
                raw_response_json["created_at"] = _safe_convert_created_field(
                    raw_response_json["created_at"]
                )
        except Exception:
            raise VolcEngineError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        raw_response_headers = dict(raw_response.headers)
        processed_headers = process_response_headers(raw_response_headers)

        try:
            response = ResponsesAPIResponse(**raw_response_json)
        except Exception:
            # Validation failed; build the model without validation so the
            # caller still gets a usable object.
            verbose_logger.debug(
                "Volcengine Responses API: falling back to model_construct for response parsing."
            )
            response = ResponsesAPIResponse.model_construct(**raw_response_json)

        response._hidden_params["additional_headers"] = processed_headers
        response._hidden_params["headers"] = raw_response_headers
        return response

    #########################################################
    ########## DELETE RESPONSE API TRANSFORMATION ##############
    #########################################################
    def transform_delete_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        # DELETE /responses/{id} — no request body.
        url = f"{api_base}/{response_id}"
        data: Dict = {}
        return url, data

    def transform_delete_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteResponseResult:
        """Parse the delete-response payload, tolerating schema drift."""
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise VolcEngineError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        try:
            return DeleteResponseResult(**raw_response_json)
        except Exception:
            verbose_logger.debug(
                "Volcengine Responses API: falling back to model_construct for delete response parsing."
            )
            return DeleteResponseResult.model_construct(**raw_response_json)

    #########################################################
    ########## GET RESPONSE API TRANSFORMATION ###############
    #########################################################
    def transform_get_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        # GET /responses/{id} — no request body.
        url = f"{api_base}/{response_id}"
        data: Dict = {}
        return url, data

    def transform_get_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """Parse a retrieved response object and attach response headers."""
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise VolcEngineError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        raw_response_headers = dict(raw_response.headers)
        processed_headers = process_response_headers(raw_response_headers)

        response = ResponsesAPIResponse(**raw_response_json)
        response._hidden_params["additional_headers"] = processed_headers
        response._hidden_params["headers"] = raw_response_headers
        return response

    #########################################################
    ########## LIST INPUT ITEMS TRANSFORMATION #############
    #########################################################
    def transform_list_input_items_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        before: Optional[str] = None,
        include: Optional[List[str]] = None,
        limit: int = 20,
        order: Literal["asc", "desc"] = "desc",
    ) -> Tuple[str, Dict]:
        # GET /responses/{id}/input_items with pagination query params.
        url = f"{api_base}/{response_id}/input_items"
        params: Dict[str, Any] = {}
        if after is not None:
            params["after"] = after
        if before is not None:
            params["before"] = before
        if include:
            # Serialized as a comma-separated list per OpenAI convention.
            params["include"] = ",".join(include)
        if limit is not None:
            params["limit"] = limit
        if order is not None:
            params["order"] = order
        return url, params

    def transform_list_input_items_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Dict:
        # Pass the JSON payload through unchanged; wrap parse failures.
        try:
            return raw_response.json()
        except Exception:
            raise VolcEngineError(
                message=raw_response.text, status_code=raw_response.status_code
            )

    #########################################################
    ########## CANCEL RESPONSE API TRANSFORMATION ##########
    #########################################################
    def transform_cancel_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        # POST /responses/{id}/cancel — no request body.
        url = f"{api_base}/{response_id}/cancel"
        data: Dict = {}
        return url, data

    def transform_cancel_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """Parse a cancel-response payload and attach response headers."""
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise VolcEngineError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        raw_response_headers = dict(raw_response.headers)
        processed_headers = process_response_headers(raw_response_headers)

        response = ResponsesAPIResponse(**raw_response_json)
        response._hidden_params["additional_headers"] = processed_headers
        response._hidden_params["headers"] = raw_response_headers
        return response

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Volcengine Responses API supports native streaming; never fall back to fake stream.
        """
        return False

    @staticmethod
    def _fill_missing_fields(chunk: Any, event_model: Any) -> Dict[str, Any]:
        """
        Heuristically fill missing required fields with safe defaults based on the
        event model's field annotations. This keeps parsing tolerant of providers that
        omit non-essential fields.
        """
        if not isinstance(chunk, dict) or event_model is None:
            return chunk

        patched: Dict[str, Any] = dict(chunk)
        fields_map = getattr(event_model, "model_fields", {}) or {}

        for name, field in fields_map.items():
            if name in patched:
                # Field present: recurse into nested models/lists only.
                patched[name] = VolcEngineResponsesAPIConfig._maybe_fill_nested(
                    patched[name], field.annotation
                )
                continue

            # Explicit default or factory
            if (
                field.default is not pyd_fields.PydanticUndefined
                and field.default is not None
            ):
                patched[name] = field.default
                continue
            if (
                field.default_factory is not None
                and field.default_factory is not pyd_fields.PydanticUndefined
            ):
                patched[name] = field.default_factory()
                continue

            # Heuristic defaults for missing required fields
            patched[name] = VolcEngineResponsesAPIConfig._default_for_annotation(
                field.annotation
            )

        return patched

    @staticmethod
    def _default_for_annotation(annotation: Any) -> Any:
        """Guess a safe default value for a missing field from its annotation."""
        origin = get_origin(annotation)
        args = get_args(annotation)

        if annotation is int:
            return 0
        if annotation is list or origin is list:
            return []
        if origin is Union:
            # Prefer empty list when any option is a list
            if any((arg is list or get_origin(arg) is list) for arg in args):
                return []
            if type(None) in args:
                return None
        if origin is Union and type(None) in args:
            return None

        # Fallback to None when no safer guess exists
        return None

    @staticmethod
    def _maybe_fill_nested(value: Any, annotation: Any) -> Any:
        """
        Recursively fill nested dict/list structures based on the annotated model.
        """
        model_cls = VolcEngineResponsesAPIConfig._pick_model_class(annotation, value)
        args = get_args(annotation)

        if isinstance(value, dict) and model_cls is not None:
            return VolcEngineResponsesAPIConfig._fill_missing_fields(value, model_cls)

        if isinstance(value, list):
            # Attempt to fill list elements if we know the element annotation
            elem_ann: Any = args[0] if args else None
            if elem_ann is not None:
                return [
                    VolcEngineResponsesAPIConfig._maybe_fill_nested(v, elem_ann)
                    for v in value
                ]

        return value

    @staticmethod
    def _pick_model_class(annotation: Any, value: Any) -> Optional[Any]:
        """
        Choose the best-matching Pydantic model class for a nested dict.
        """
        candidates: List[Any] = []
        origin = get_origin(annotation)

        # A pydantic model has `model_fields`; collect direct and Union members.
        if hasattr(annotation, "model_fields"):
            candidates.append(annotation)
        if origin is Union:
            for arg in get_args(annotation):
                if hasattr(arg, "model_fields"):
                    candidates.append(arg)

        if not candidates:
            return None

        # Try to match by literal "type" field when available
        if isinstance(value, dict):
            v_type = value.get("type")
            for candidate in candidates:
                try:
                    type_field = candidate.model_fields.get("type")
                    if type_field is None:
                        continue
                    literal_ann = type_field.annotation
                    if get_origin(literal_ann) is Literal:
                        literal_values = get_args(literal_ann)
                        if v_type in literal_values:
                            return candidate
                except Exception:
                    continue

        # Fall back to the first candidate
        return candidates[0]

    def supports_native_websocket(self) -> bool:
        """VolcEngine does not support native WebSocket for Responses API"""
        return False
|
||||
Reference in New Issue
Block a user