chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Dynamic configuration class generator for JSON-based providers.
|
||||
"""
|
||||
|
||||
from typing import Any, Coroutine, List, Literal, Optional, Tuple, Union, overload
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
handle_messages_with_content_list_to_str_conversion,
|
||||
)
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
from .json_loader import SimpleProviderConfig
|
||||
|
||||
|
||||
def create_config_class(provider: SimpleProviderConfig):
|
||||
"""Generate config class dynamically from JSON configuration"""
|
||||
|
||||
# Choose base class
|
||||
base_class: type = (
|
||||
OpenAIGPTConfig if provider.base_class == "openai_gpt" else OpenAILikeChatConfig
|
||||
)
|
||||
|
||||
class JSONProviderConfig(base_class): # type: ignore[valid-type,misc]
|
||||
@overload
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
|
||||
) -> Coroutine[Any, Any, List[AllMessageValues]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _transform_messages(
|
||||
self,
|
||||
messages: List[AllMessageValues],
|
||||
model: str,
|
||||
is_async: Literal[False] = False,
|
||||
) -> List[AllMessageValues]:
|
||||
...
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues], model: str, is_async: bool = False
|
||||
) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
|
||||
"""Transform messages based on special_handling config"""
|
||||
|
||||
# Handle content list to string conversion if configured
|
||||
if provider.special_handling.get("convert_content_list_to_string"):
|
||||
messages = handle_messages_with_content_list_to_str_conversion(messages)
|
||||
|
||||
if is_async:
|
||||
return super()._transform_messages(
|
||||
messages=messages, model=model, is_async=True
|
||||
)
|
||||
else:
|
||||
return super()._transform_messages(
|
||||
messages=messages, model=model, is_async=False
|
||||
)
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Get API base and key from JSON config"""
|
||||
|
||||
# Resolve base URL
|
||||
resolved_base = api_base
|
||||
if not resolved_base and provider.api_base_env:
|
||||
resolved_base = get_secret_str(provider.api_base_env)
|
||||
if not resolved_base:
|
||||
resolved_base = provider.base_url
|
||||
|
||||
# Resolve API key
|
||||
resolved_key = api_key or get_secret_str(provider.api_key_env)
|
||||
|
||||
return resolved_base, resolved_key
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""Build complete URL for the API endpoint"""
|
||||
if not api_base:
|
||||
api_base = provider.base_url
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError(f"api_base is required for provider {provider.slug}")
|
||||
|
||||
if not api_base.endswith("/chat/completions"):
|
||||
api_base = f"{api_base}/chat/completions"
|
||||
|
||||
return api_base
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""Get supported OpenAI params, excluding tool-related params for models
|
||||
that don't support function calling."""
|
||||
from litellm.utils import supports_function_calling
|
||||
|
||||
supported_params = super().get_supported_openai_params(model=model)
|
||||
|
||||
_supports_fc = supports_function_calling(
|
||||
model=model, custom_llm_provider=provider.slug
|
||||
)
|
||||
|
||||
if not _supports_fc:
|
||||
tool_params = [
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"function_call",
|
||||
"functions",
|
||||
"parallel_tool_calls",
|
||||
]
|
||||
for param in tool_params:
|
||||
if param in supported_params:
|
||||
supported_params.remove(param)
|
||||
verbose_logger.debug(
|
||||
f"Model {model} on provider {provider.slug} does not support "
|
||||
f"function calling — removed tool-related params from supported params."
|
||||
)
|
||||
|
||||
return supported_params
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""Apply parameter mappings and constraints"""
|
||||
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
# Apply supported params
|
||||
for param, value in non_default_params.items():
|
||||
# Check parameter mappings first
|
||||
if param in provider.param_mappings:
|
||||
optional_params[provider.param_mappings[param]] = value
|
||||
elif param in supported_params:
|
||||
optional_params[param] = value
|
||||
|
||||
# Apply temperature constraints if present
|
||||
if "temperature" in optional_params:
|
||||
temp = optional_params["temperature"]
|
||||
constraints = provider.constraints
|
||||
|
||||
# Clamp to max
|
||||
if "temperature_max" in constraints:
|
||||
temp = min(temp, constraints["temperature_max"])
|
||||
|
||||
# Clamp to min
|
||||
if "temperature_min" in constraints:
|
||||
temp = max(temp, constraints["temperature_min"])
|
||||
|
||||
# Special case: temperature_min_with_n_gt_1
|
||||
if "temperature_min_with_n_gt_1" in constraints:
|
||||
n = optional_params.get("n", 1)
|
||||
if n > 1 and temp < constraints["temperature_min_with_n_gt_1"]:
|
||||
temp = constraints["temperature_min_with_n_gt_1"]
|
||||
|
||||
optional_params["temperature"] = temp
|
||||
|
||||
return optional_params
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return provider.slug
|
||||
|
||||
return JSONProviderConfig
|
||||
|
||||
|
||||
_responses_config_cache: dict = {}
|
||||
|
||||
|
||||
def create_responses_config_class(provider: SimpleProviderConfig):
|
||||
"""Generate a Responses API config class dynamically from JSON configuration.
|
||||
|
||||
Parallel to create_config_class() but for /v1/responses endpoints.
|
||||
Classes are cached per provider slug to avoid regeneration on every request.
|
||||
"""
|
||||
if provider.slug in _responses_config_cache:
|
||||
return _responses_config_cache[provider.slug]
|
||||
|
||||
from litellm.llms.openai_like.responses.transformation import (
|
||||
OpenAILikeResponsesConfig,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
class JSONProviderResponsesConfig(OpenAILikeResponsesConfig):
|
||||
@property
|
||||
def custom_llm_provider(self): # type: ignore[override]
|
||||
return provider.slug
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
litellm_params: Optional[GenericLiteLLMParams],
|
||||
) -> dict:
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
api_key = litellm_params.api_key or get_secret_str(provider.api_key_env)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
if not api_base:
|
||||
if provider.api_base_env:
|
||||
api_base = get_secret_str(provider.api_base_env)
|
||||
if not api_base:
|
||||
api_base = provider.base_url
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError(f"api_base is required for provider {provider.slug}")
|
||||
|
||||
api_base = api_base.rstrip("/")
|
||||
return f"{api_base}/responses"
|
||||
|
||||
_responses_config_cache[provider.slug] = JSONProviderResponsesConfig
|
||||
return JSONProviderResponsesConfig
|
||||
Reference in New Issue
Block a user