chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
from .anthropic_messages.transformation import BaseAnthropicMessagesConfig
|
||||
from .audio_transcription.transformation import BaseAudioTranscriptionConfig
|
||||
from .batches.transformation import BaseBatchesConfig
|
||||
from .chat.transformation import BaseConfig
|
||||
from .embedding.transformation import BaseEmbeddingConfig
|
||||
from .image_edit.transformation import BaseImageEditConfig
|
||||
from .image_generation.transformation import BaseImageGenerationConfig
|
||||
|
||||
# Public re-exports of the base_llm package, listed alphabetically.
__all__ = [
    "BaseAnthropicMessagesConfig",
    "BaseAudioTranscriptionConfig",
    "BaseBatchesConfig",
    "BaseConfig",
    "BaseEmbeddingConfig",
    "BaseImageEditConfig",
    "BaseImageGenerationConfig",
]
|
||||
@@ -0,0 +1,122 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseAnthropicMessagesConfig(ABC):
    """Provider-agnostic contract for the Anthropic `/v1/messages` API.

    Concrete provider configs implement the abstract hooks to map between
    LiteLLM's unified request/response shapes and the provider wire format.

    NOTE(review): `validate_anthropic_messages_environment` and
    `get_complete_url` are documented OPTIONAL and carry default bodies, yet
    are declared `@abstractmethod` — subclasses must still override them
    (they may call `super()` to reuse the defaults).
    """

    @abstractmethod
    def validate_anthropic_messages_environment(  # use different name because return type is different from base config's validate_environment
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        """
        OPTIONAL

        Validate the environment for the request

        Returns:
        - headers: dict
        - api_base: Optional[str] - If the provider needs to update the api_base, return it here. Otherwise, return None.
        """
        return headers, api_base

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    @abstractmethod
    def get_supported_anthropic_messages_params(self, model: str) -> list:
        """Return the request params supported by *model* for this provider."""
        pass

    @abstractmethod
    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Build the provider-specific request body from the unified inputs."""
        pass

    @abstractmethod
    def transform_anthropic_messages_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> AnthropicMessagesResponse:
        """Parse the provider's raw HTTP response into an AnthropicMessagesResponse."""
        pass

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """
        OPTIONAL

        Sign the request, providers like Bedrock need to sign the request before sending it to the API

        For all other providers, this is a no-op and we just return the headers
        """
        return headers, None

    def get_async_streaming_response_iterator(
        self,
        model: str,
        httpx_response: httpx.Response,
        request_body: dict,
        litellm_logging_obj: LiteLLMLoggingObj,
    ) -> AsyncIterator:
        """Return an async iterator over streaming chunks.

        Must be overridden by providers that support streaming.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> "BaseLLMException":
        # Default error mapping: wrap everything in the generic BaseLLMException.
        # Imported locally to avoid an import cycle with chat.transformation.
        from litellm.llms.base_llm.chat.transformation import BaseLLMException

        return BaseLLMException(
            message=error_message, status_code=status_code, headers=headers
        )
|
||||
@@ -0,0 +1,172 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIAudioTranscriptionOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import FileTypes, ModelResponse, TranscriptionResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
@dataclass
class AudioTranscriptionRequestData:
    """
    Structured data for audio transcription requests.

    Attributes:
        data: The request data (form data for multipart, json data for regular requests)
        files: Optional files dict for multipart form data
        content_type: Optional content type override
    """

    # Raw body: a dict for form/json payloads, or bytes for binary uploads.
    data: Union[dict, bytes]
    # Multipart `files` mapping; None when the request is not multipart.
    files: Optional[dict] = None
    # Explicit Content-Type header override; None lets the HTTP client decide.
    content_type: Optional[str] = None
|
||||
|
||||
|
||||
class BaseAudioTranscriptionConfig(BaseConfig, ABC):
    """Base config for providers that support audio transcription.

    Inherits the generic chat `BaseConfig` interface; the chat-oriented
    `transform_request` / `transform_response` hooks deliberately raise
    here because transcription uses its own transform pair
    (`transform_audio_transcription_request` / `..._response`).
    """

    @abstractmethod
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIAudioTranscriptionOptionalParams]:
        """Return the OpenAI transcription params *model* supports for this provider."""
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    @abstractmethod
    def transform_audio_transcription_request(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        litellm_params: dict,
    ) -> AudioTranscriptionRequestData:
        """Build the provider-specific transcription request payload."""
        raise NotImplementedError(
            "AudioTranscriptionConfig needs a request transformation for audio transcription models"
        )

    def transform_audio_transcription_response(
        self,
        raw_response: httpx.Response,
    ) -> TranscriptionResponse:
        """Parse the provider's raw HTTP response into a TranscriptionResponse."""
        raise NotImplementedError(
            "AudioTranscriptionConfig does not need a response transformation for audio transcription models"
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Chat-style hook inherited from BaseConfig; unused for transcription.
        raise NotImplementedError(
            "AudioTranscriptionConfig does not need a request transformation for audio transcription models"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # Chat-style hook inherited from BaseConfig; unused for transcription.
        raise NotImplementedError(
            "AudioTranscriptionConfig does not need a response transformation for audio transcription models"
        )

    def get_provider_specific_params(
        self,
        model: str,
        optional_params: dict,
        openai_params: List[OpenAIAudioTranscriptionOptionalParams],
    ) -> dict:
        """
        Get provider specific parameters that are not OpenAI compatible

        eg. if user passes `diarize=True`, we need to pass `diarize` to the provider
        but `diarize` is not an OpenAI parameter, so we need to handle it here
        """
        provider_specific_params = {}
        for key, value in optional_params.items():
            # Skip None values
            if value is None:
                continue

            # Skip excluded parameters
            if self._should_exclude_param(
                param_name=key,
                model=model,
            ):
                continue

            # Add the parameter to the provider specific params
            provider_specific_params[key] = value

        return provider_specific_params

    def _should_exclude_param(
        self,
        param_name: str,
        model: str,
    ) -> bool:
        """
        Determines if a parameter should be excluded from the query string.

        Args:
            param_name: Parameter name
            model: Model name

        Returns:
            True if the parameter should be excluded
        """
        # Parameters handled elsewhere or internal to litellm.
        # NOTE(review): an earlier comment here said "not relevant to Deepgram
        # API" — this is the provider-agnostic base class; confirm intent.
        excluded_params = {
            "model",  # Already in the URL path
            "OPENAI_TRANSCRIPTION_PARAMS",  # Internal litellm parameter
        }

        # Skip if it's an excluded parameter
        if param_name in excluded_params:
            return True

        # Skip if it's an OpenAI-specific parameter that we handle separately
        if param_name in self.get_supported_openai_params(model):
            return True

        return False
|
||||
@@ -0,0 +1,264 @@
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import (
|
||||
Choices,
|
||||
Delta,
|
||||
GenericStreamingChunk,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
|
||||
def convert_model_response_to_streaming(
    model_response: ModelResponse,
) -> ModelResponseStream:
    """Re-shape a completed ModelResponse into a single streaming chunk.

    Each non-streaming choice's `message` is copied field-for-field into a
    `Delta`, and the top-level metadata (id / created / model) is carried
    over so the result looks like one `chat.completion.chunk`.

    Args:
        model_response: The ModelResponse to convert

    Returns:
        ModelResponseStream: A streaming chunk version of the response

    Raises:
        ValueError: If the conversion fails
    """
    try:
        chunk_choices: List[StreamingChoices] = [
            StreamingChoices(
                index=choice.index,
                delta=Delta(**cast(Choices, choice).message.model_dump()),
                finish_reason=choice.finish_reason,
            )
            for choice in model_response.choices
        ]
        return ModelResponseStream(
            id=model_response.id,
            object="chat.completion.chunk",
            created=model_response.created,
            model=model_response.model,
            choices=chunk_choices,
        )
    except Exception as e:
        raise ValueError(
            f"Failed to convert ModelResponse to ModelResponseStream: {model_response}. Error: {e}"
        )
|
||||
|
||||
|
||||
class BaseModelResponseIterator:
    """Adapts a raw (SSE-style) streaming response into parsed chunk objects.

    Supports both synchronous (`__iter__`/`__next__`) and asynchronous
    (`__aiter__`/`__anext__`) iteration. Subclasses override `chunk_parser`
    to turn a decoded JSON dict into a provider-specific chunk.
    """

    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        # Underlying stream of str/bytes chunks (e.g. httpx line iterator).
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.json_mode = json_mode
        # NOTE(review): `sync_stream` is accepted but never stored or used here.

    def chunk_parser(
        self, chunk: dict
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        # Default no-op parser: emits an empty, unfinished chunk.
        # Subclasses override this with provider-specific parsing.
        return GenericStreamingChunk(
            text="",
            is_finished=False,
            finish_reason="",
            usage=None,
            index=0,
            tool_use=None,
        )

    # Sync iterator
    def __iter__(self):
        return self

    @staticmethod
    def _string_to_dict_parser(str_line: str) -> Optional[dict]:
        """Strip the SSE `data:` prefix and JSON-decode; None if not valid JSON."""
        stripped_json_chunk: Optional[dict] = None
        stripped_chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(
            str_line
        )
        try:
            if stripped_chunk is not None:
                stripped_json_chunk = json.loads(stripped_chunk)
            else:
                stripped_json_chunk = None
        except json.JSONDecodeError:
            stripped_json_chunk = None
        return stripped_json_chunk

    def _handle_string_chunk(
        self, str_line: str
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        """Parse one SSE line: `[DONE]` ends the stream, valid JSON is delegated
        to `chunk_parser`, anything else yields an empty placeholder chunk."""
        # chunk is a str at this point
        stripped_json_chunk = BaseModelResponseIterator._string_to_dict_parser(
            str_line=str_line
        )
        if "[DONE]" in str_line:
            return GenericStreamingChunk(
                text="",
                is_finished=True,
                finish_reason="stop",
                usage=None,
                index=0,
                tool_use=None,
            )
        elif stripped_json_chunk:
            return self.chunk_parser(chunk=stripped_json_chunk)
        else:
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
                tool_use=None,
            )

    def __next__(self):
        # Loop so that blank keep-alive lines can be skipped without
        # emitting a chunk for them.
        while True:
            try:
                chunk = self.response_iterator.__next__()
            except StopIteration:
                raise StopIteration
            except ValueError as e:
                raise RuntimeError(f"Error receiving chunk from stream: {e}")

            try:
                str_line = chunk
                if isinstance(chunk, bytes):  # Handle binary data
                    str_line = chunk.decode("utf-8")  # Convert bytes to string
                    # Drop any prelude before the SSE `data:` marker.
                    index = str_line.find("data:")
                    if index != -1:
                        str_line = str_line[index:]

                # Skip empty lines (common in SSE streams between events).
                # Only apply to str chunks — non-string objects (e.g. Pydantic
                # BaseModel events from the Responses API) must pass through.
                if isinstance(str_line, str) and (
                    not str_line or not str_line.strip()
                ):
                    continue

                # chunk is a str at this point
                return self._handle_string_chunk(str_line=str_line)
            except StopIteration:
                raise StopIteration
            except ValueError as e:
                raise RuntimeError(
                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
                )

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        # Mirrors __next__ exactly, but over the async stream.
        while True:
            try:
                chunk = await self.async_response_iterator.__anext__()

            except StopAsyncIteration:
                raise StopAsyncIteration
            except ValueError as e:
                raise RuntimeError(f"Error receiving chunk from stream: {e}")

            try:
                str_line = chunk
                if isinstance(chunk, bytes):  # Handle binary data
                    str_line = chunk.decode("utf-8")  # Convert bytes to string
                    # Drop any prelude before the SSE `data:` marker.
                    index = str_line.find("data:")
                    if index != -1:
                        str_line = str_line[index:]

                # Skip empty lines (common in SSE streams between events).
                # Only apply to str chunks — non-string objects (e.g. Pydantic
                # BaseModel events from the Responses API) must pass through.
                if isinstance(str_line, str) and (
                    not str_line or not str_line.strip()
                ):
                    continue

                # chunk is a str at this point
                chunk = self._handle_string_chunk(str_line=str_line)

                return chunk
            except StopAsyncIteration:
                raise StopAsyncIteration
            except ValueError as e:
                raise RuntimeError(
                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
                )
|
||||
|
||||
|
||||
class MockResponseIterator:  # for returning ai21 streaming responses
    """Wraps one completed ModelResponse and yields it exactly once as a
    streaming chunk, via either sync or async iteration."""

    def __init__(
        self, model_response: ModelResponse, json_mode: Optional[bool] = False
    ):
        self.model_response = model_response
        self.json_mode = json_mode
        # Flipped after the single chunk has been emitted.
        self.is_done = False

    def _chunk_parser(self, chunk_data: ModelResponse) -> ModelResponseStream:
        # Delegate to the module-level converter.
        return convert_model_response_to_streaming(chunk_data)

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self._chunk_parser(self.model_response)

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self._chunk_parser(self.model_response)
|
||||
|
||||
|
||||
class FakeStreamResponseIterator:
    """Yields a single parsed chunk from a pre-computed model response.

    NOTE(review): `chunk_parser` is decorated `@abstractmethod` but the class
    does not inherit from `ABC`, so instantiation is NOT actually blocked —
    subclasses are simply expected to override it.
    """

    def __init__(self, model_response, json_mode: Optional[bool] = False):
        self.model_response = model_response
        self.json_mode = json_mode
        # Flipped after the single chunk has been emitted.
        self.is_done = False

    # Sync iterator
    def __iter__(self):
        return self

    @abstractmethod
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """Convert the stored response into a streaming chunk (subclass hook)."""
        pass

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.chunk_parser(self.model_response)

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.chunk_parser(self.model_response)
|
||||
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Utility functions for base LLM classes.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from openai.lib import _parsing, _pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
|
||||
from litellm.types.utils import Message, ProviderSpecificModelInfo, TokenCountResponse
|
||||
|
||||
|
||||
class BaseTokenCounter(ABC):
    """Interface for provider-backed token-counting implementations."""

    @abstractmethod
    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """Count tokens for the given messages/contents via the provider's API.

        Returns None when counting is unsupported for this input.
        """
        pass

    @abstractmethod
    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Returns True if we should use this API for token counting for the selected `custom_llm_provider`
        """
        return False
|
||||
|
||||
|
||||
class BaseLLMModelInfo(ABC):
    """Per-provider metadata interface: model listing, credentials, and
    provider-wide capability defaults."""

    def get_provider_info(
        self,
        model: str,
    ) -> Optional[ProviderSpecificModelInfo]:
        """
        Default values all models of this provider support.
        """
        return None

    @abstractmethod
    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Returns a list of models supported by this provider.
        """
        return []

    @staticmethod
    @abstractmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        """Resolve the API key (explicit arg takes precedence over env/config)."""
        pass

    @staticmethod
    @abstractmethod
    def get_api_base(
        api_base: Optional[str] = None,
    ) -> Optional[str]:
        """Resolve the API base URL (explicit arg takes precedence over env/config)."""
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate credentials/config and return the headers to send."""
        pass

    @staticmethod
    @abstractmethod
    def get_base_model(model: str) -> Optional[str]:
        """
        Returns the base model name from the given model name.

        Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
        This function will return `anthropic.claude-3-opus-20240229-v1:0`
        """
        pass

    def get_token_counter(self) -> Optional[BaseTokenCounter]:
        """
        Factory method to create a token counter for this provider.

        Returns:
            Optional TokenCounterInterface implementation for this provider,
            or None if token counting is not supported.
        """
        return None
|
||||
|
||||
|
||||
def _convert_tool_response_to_message(
    tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[Message]:
    """
    In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format

    """
    ## HANDLE JSON MODE - anthropic returns single function call
    arguments_str: Optional[str] = tool_calls[0]["function"].get("arguments")
    if arguments_str is None:
        return None

    try:
        parsed_args = json.loads(arguments_str)
    except json.JSONDecodeError:
        # json decode error does occur, return the original tool response str
        return Message(content=arguments_str)

    if isinstance(parsed_args, dict) and (values := parsed_args.get("values")) is not None:
        return Message(content=json.dumps(values))

    # a lot of the times the `values` key is not present in the tool response
    # relevant issue: https://github.com/BerriAI/litellm/issues/6741
    return Message(content=json.dumps(parsed_args))
|
||||
|
||||
|
||||
def _dict_to_response_format_helper(
|
||||
response_format: dict, ref_template: Optional[str] = None
|
||||
) -> dict:
|
||||
if ref_template is not None and response_format.get("type") == "json_schema":
|
||||
# Deep copy to avoid modifying original
|
||||
modified_format = copy.deepcopy(response_format)
|
||||
schema = modified_format["json_schema"]["schema"]
|
||||
|
||||
# Update all $ref values in the schema
|
||||
def update_refs(schema):
|
||||
stack = [(schema, [])]
|
||||
visited = set()
|
||||
|
||||
while stack:
|
||||
obj, path = stack.pop()
|
||||
obj_id = id(obj)
|
||||
|
||||
if obj_id in visited:
|
||||
continue
|
||||
visited.add(obj_id)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
if "$ref" in obj:
|
||||
ref_path = obj["$ref"]
|
||||
model_name = ref_path.split("/")[-1]
|
||||
obj["$ref"] = ref_template.format(model=model_name)
|
||||
|
||||
for k, v in obj.items():
|
||||
if isinstance(v, (dict, list)):
|
||||
stack.append((v, path + [k]))
|
||||
|
||||
elif isinstance(obj, list):
|
||||
for i, item in enumerate(obj):
|
||||
if isinstance(item, (dict, list)):
|
||||
stack.append((item, path + [i]))
|
||||
|
||||
update_refs(schema)
|
||||
return modified_format
|
||||
return response_format
|
||||
|
||||
|
||||
def type_to_response_format_param(
    response_format: Optional[Union[Type[BaseModel], dict]],
    ref_template: Optional[str] = None,
) -> Optional[dict]:
    """
    Re-implementation of openai's 'type_to_response_format_param' function

    Used for converting pydantic object to api schema.
    """
    if response_format is None:
        return None

    if isinstance(response_format, dict):
        return _dict_to_response_format_helper(response_format, ref_template)

    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
    # a safe default behaviour but we know that at this point the `response_format`
    # can only be a `type`
    if not _parsing._completions.is_basemodel_type(response_format):
        raise TypeError(f"Unsupported response_format type - {response_format}")

    # With a custom ref template, use pydantic's own schema generation;
    # otherwise defer to openai's strict-schema converter.
    if ref_template is not None:
        json_schema = response_format.model_json_schema(ref_template=ref_template)
    else:
        json_schema = _pydantic.to_strict_json_schema(response_format)

    return {
        "type": "json_schema",
        "json_schema": {
            "schema": json_schema,
            "name": response_format.__name__,
            "strict": True,
        },
    }
|
||||
|
||||
|
||||
def map_developer_role_to_system_role(
    messages: List[AllMessageValues],
) -> List[AllMessageValues]:
    """
    Translate `developer` role to `system` role for non-OpenAI providers.
    """
    translated: List[AllMessageValues] = []
    for message in messages:
        if message["role"] != "developer":
            translated.append(message)
            continue
        verbose_logger.debug(
            "Translating developer role to system role for non-OpenAI providers."
        )  # ensure user knows what's happening with their input.
        translated.append({"role": "system", "content": message["content"]})
    return translated
|
||||
@@ -0,0 +1,218 @@
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from httpx import Headers
|
||||
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
CreateBatchRequest,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch, LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
class BaseBatchesConfig(ABC):
    """
    Abstract base class for batch processing configurations across different LLM providers.

    This class defines the interface that all provider-specific batch configurations
    must implement to work with LiteLLM's unified batch processing system.
    """

    def __init__(self):
        pass

    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Return the LLM provider type for this configuration."""
        pass

    @classmethod
    def get_config(cls):
        """Get configuration dictionary for this class."""
        # Collect class-level data attributes only: skip dunders, ABC
        # bookkeeping, callables/descriptors, and unset (None) values.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate and prepare environment-specific headers and parameters.

        Args:
            headers: HTTP headers dictionary
            model: Model name
            messages: List of messages
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters
            api_key: API key
            api_base: API base URL

        Returns:
            Updated headers dictionary
        """
        pass

    @abstractmethod
    def get_complete_batch_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: Dict,
        litellm_params: Dict,
        data: CreateBatchRequest,
    ) -> str:
        """
        Get the complete URL for batch creation request.

        Args:
            api_base: Base API URL
            api_key: API key
            model: Model name
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters
            data: Batch creation request data

        Returns:
            Complete URL for the batch request
        """
        pass

    @abstractmethod
    def transform_create_batch_request(
        self,
        model: str,
        create_batch_data: CreateBatchRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[bytes, str, Dict[str, Any]]:
        """
        Transform the batch creation request to provider-specific format.

        Args:
            model: Model name
            create_batch_data: Batch creation request data
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters

        Returns:
            Transformed request data
        """
        pass

    @abstractmethod
    def transform_create_batch_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """
        Transform provider-specific batch response to LiteLLM format.

        Args:
            model: Model name
            raw_response: Raw HTTP response
            logging_obj: Logging object
            litellm_params: LiteLLM parameters

        Returns:
            LiteLLM batch object
        """
        pass

    @abstractmethod
    def transform_retrieve_batch_request(
        self,
        batch_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[bytes, str, Dict[str, Any]]:
        """
        Transform the batch retrieval request to provider-specific format.

        Args:
            batch_id: Batch ID to retrieve
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters

        Returns:
            Transformed request data
        """
        pass

    @abstractmethod
    def transform_retrieve_batch_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """
        Transform provider-specific batch retrieval response to LiteLLM format.

        Args:
            model: Model name
            raw_response: Raw HTTP response
            logging_obj: Logging object
            litellm_params: LiteLLM parameters

        Returns:
            LiteLLM batch object
        """
        pass

    @abstractmethod
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> "BaseLLMException":
        """
        Get the appropriate error class for this provider.

        Args:
            error_message: Error message
            status_code: HTTP status code
            headers: Response headers

        Returns:
            Provider-specific exception class
        """
        pass
|
||||
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Bridge for transforming API requests to another API requests
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm import LiteLLMLoggingObj, ModelResponse
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class CompletionTransformationBridge(ABC):
    """
    Contract for bridging /chat/completions traffic onto a different API:
    outgoing requests are rewritten into the target API's format, and the
    target API's responses are rewritten back into chat-completions shape.
    """

    @abstractmethod
    def transform_request(
        self,
        model: str,
        messages: List["AllMessageValues"],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
        litellm_logging_obj: "LiteLLMLoggingObj",
    ) -> dict:
        """Rewrite a /chat/completions request into the target API's request format."""
        ...

    @abstractmethod
    def transform_response(
        self,
        model: str,
        raw_response: "BaseModel",  # the response from the other API
        model_response: "ModelResponse",
        logging_obj: "LiteLLMLoggingObj",
        request_data: dict,
        messages: List["AllMessageValues"],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> "ModelResponse":
        """Rewrite the target API's response back into a /chat/completions response."""
        ...

    @abstractmethod
    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], "ModelResponse"],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> "BaseModelResponseIterator":
        """Return an iterator over the streamed response."""
        ...
|
||||
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
Common base config for all LLM providers
|
||||
"""
|
||||
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.constants import DEFAULT_MAX_TOKENS, RESPONSE_FORMAT_TOOL_NAME
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionToolChoiceFunctionParam,
|
||||
ChatCompletionToolChoiceObjectParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ..base_utils import (
|
||||
map_developer_role_to_system_role,
|
||||
type_to_response_format_param,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseLLMException(Exception):
    """
    Common exception type raised by provider transformation configs.

    Carries the HTTP status code, message, headers, optional body, and the
    originating httpx request/response. When the caller does not supply a
    request or response, placeholder objects are synthesized so downstream
    error handlers can always rely on these attributes being present.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        body: Optional[dict] = None,
    ):
        self.status_code = status_code
        self.message: str = message
        self.headers = headers
        # Placeholder request points at the litellm docs URL when none given.
        self.request = request or httpx.Request(
            method="POST", url="https://docs.litellm.ai/docs"
        )
        # Placeholder response reuses the status code and the request above.
        self.response = response or httpx.Response(
            status_code=status_code, request=self.request
        )
        self.body = body
        # Call the base class constructor with the human-readable message.
        super().__init__(self.message)
|
||||
|
||||
|
||||
class BaseConfig(ABC):
    """
    Common base config for all LLM chat providers.

    Defines the transformation contract every provider config implements:
    mapping OpenAI-style params to the provider's format, building/signing
    the HTTP request, and translating the raw response into a ModelResponse.
    """

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        """
        Return the config's plain data attributes as a dict, filtering out
        dunders, ABC bookkeeping, callables (including mocks), and None values.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not k.startswith("_is_base_class")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                    property,
                ),
            )
            and v is not None
            and not callable(v)  # Filter out any callable objects including mocks
        }

    def get_json_schema_from_pydantic_object(
        self, response_format: Optional[Union[Type[BaseModel], dict]]
    ) -> Optional[dict]:
        """Convert a pydantic model (or pass a dict through) into an OpenAI-style response_format param."""
        return type_to_response_format_param(response_format=response_format)

    def is_thinking_enabled(self, non_default_params: dict) -> bool:
        """True if extended thinking is requested via `thinking` or `reasoning_effort`."""
        # `or {}` also guards against an explicit `thinking=None` value, which
        # `.get("thinking", {})` would not catch.
        return (
            (non_default_params.get("thinking") or {}).get("type") == "enabled"
            or non_default_params.get("reasoning_effort") is not None
        )

    def is_max_tokens_in_request(self, non_default_params: dict) -> bool:
        """
        OpenAI spec allows max_tokens or max_completion_tokens to be specified.
        """
        return (
            "max_tokens" in non_default_params
            or "max_completion_tokens" in non_default_params
        )

    def update_optional_params_with_thinking_tokens(
        self, non_default_params: dict, optional_params: dict
    ):
        """
        Handles scenario where max tokens is not specified. For anthropic models
        (anthropic api/bedrock/vertex ai), this requires having the max tokens
        being set and being greater than the thinking token budget.

        Checks 'non_default_params' for 'max_tokens'/'max_completion_tokens'
        and 'optional_params' for the thinking token budget.

        If thinking is enabled and neither 'max_tokens' nor
        'max_completion_tokens' is specified, set 'max_tokens' to the thinking
        token budget + DEFAULT_MAX_TOKENS.
        """
        is_thinking_enabled = self.is_thinking_enabled(optional_params)
        if is_thinking_enabled and (
            "max_tokens" not in non_default_params
            and "max_completion_tokens" not in non_default_params
        ):
            # Bugfix: thinking can be enabled via `reasoning_effort` alone, in
            # which case optional_params has no "thinking" key; use .get() so
            # this does not raise a KeyError.
            thinking_token_budget = cast(
                dict, optional_params.get("thinking") or {}
            ).get("budget_tokens", None)
            if thinking_token_budget is not None:
                optional_params["max_tokens"] = (
                    thinking_token_budget + DEFAULT_MAX_TOKENS
                )

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Returns True if the model/provider should fake stream
        """
        return False

    def _add_tools_to_optional_params(self, optional_params: dict, tools: List) -> dict:
        """
        Helper util to add tools to optional_params, appending to any tools
        that are already present.
        """
        if "tools" not in optional_params:
            optional_params["tools"] = tools
        else:
            optional_params["tools"] = [
                *optional_params["tools"],
                *tools,
            ]
        return optional_params

    def translate_developer_role_to_system_role(
        self,
        messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """
        Translate `developer` role to `system` role for non-OpenAI providers.

        Overriden by OpenAI/Azure
        """
        return map_developer_role_to_system_role(messages=messages)

    def should_retry_llm_api_inside_llm_translation_on_http_error(
        self, e: httpx.HTTPStatusError, litellm_params: dict
    ) -> bool:
        """
        Returns True if the model/provider should retry the LLM API on UnprocessableEntityError

        Overriden by azure ai - where different models support different parameters
        """
        return False

    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        """
        Transform the request data on UnprocessableEntityError
        """
        return request_data

    @property
    def max_retry_on_unprocessable_entity_error(self) -> int:
        """
        Returns the max retry count for UnprocessableEntityError

        Used if `should_retry_llm_api_inside_llm_translation_on_http_error` is True
        """
        return 0

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI params this provider supports for `model`."""
        pass

    def _add_response_format_to_tools(
        self,
        optional_params: dict,
        value: dict,
        is_response_format_supported: bool,
        enforce_tool_choice: bool = True,
    ) -> dict:
        """
        Follow similar approach to anthropic - translate to a single tool call.

        When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
        - You usually want to provide a single tool
        - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
        - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.

        This is used to translate response_format to a tool call, for
        models/APIs that don't support response_format directly.
        """
        json_schema: Optional[dict] = None
        if "response_schema" in value:
            json_schema = value["response_schema"]
        elif "json_schema" in value:
            json_schema = value["json_schema"]["schema"]

        if json_schema and not is_response_format_supported:
            # Force the model to answer through a single synthetic tool whose
            # parameters are the requested JSON schema.
            _tool_choice = ChatCompletionToolChoiceObjectParam(
                type="function",
                function=ChatCompletionToolChoiceFunctionParam(
                    name=RESPONSE_FORMAT_TOOL_NAME
                ),
            )

            _tool = ChatCompletionToolParam(
                type="function",
                function=ChatCompletionToolParamFunctionChunk(
                    name=RESPONSE_FORMAT_TOOL_NAME, parameters=json_schema
                ),
            )

            optional_params.setdefault("tools", [])
            optional_params["tools"].append(_tool)
            if enforce_tool_choice:
                optional_params["tool_choice"] = _tool_choice

            optional_params["json_mode"] = True
        elif is_response_format_supported:
            optional_params["response_format"] = value
        return optional_params

    @abstractmethod
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI-style params onto the provider's optional_params."""
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate credentials and return the headers to send."""
        pass

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """
        Some providers like Bedrock require signing the request. The sign request
        function needs access to `request_data` and the complete url.

        Args:
            headers: dict
            optional_params: dict
            request_data: dict - the request body being sent in http request
            api_base: str - the complete url being sent in http request

        Returns:
            Tuple of (signed headers, optional signed body). The returned
            headers are sent as the http request headers; the optional bytes
            replace the request body when the provider requires a signed body.
        """
        return headers, None

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`

        Raises:
            ValueError: If api_base is None.
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the provider-specific request body."""
        pass

    async def async_transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Override to allow for http requests on async calls - e.g. converting url to base64

        Currently only used by openai.py
        """
        return self.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

    @abstractmethod
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: "ModelResponse",
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> "ModelResponse":
        """Translate the provider's raw HTTP response into a ModelResponse."""
        pass

    @abstractmethod
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return the provider-specific exception for a failed request."""
        pass

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], "ModelResponse"],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        """Return an iterator over streamed chunks; default implementation is a no-op."""
        pass

    async def get_async_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[AsyncHTTPHandler] = None,
        json_mode: Optional[bool] = None,
        signed_json_body: Optional[bytes] = None,
    ) -> "CustomStreamWrapper":
        """Async custom stream wrapper; only meaningful when `has_custom_stream_wrapper` is True."""
        raise NotImplementedError

    def get_sync_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        json_mode: Optional[bool] = None,
        signed_json_body: Optional[bytes] = None,
    ) -> "CustomStreamWrapper":
        """Sync custom stream wrapper; only meaningful when `has_custom_stream_wrapper` is True."""
        raise NotImplementedError

    @property
    def custom_llm_provider(self) -> Optional[str]:
        # Optional provider slug override; None by default.
        return None

    @property
    def has_custom_stream_wrapper(self) -> bool:
        # Default: no provider-specific stream wrapper.
        return False

    @property
    def supports_stream_param_in_request_body(self) -> bool:
        """
        Some providers like Bedrock invoke do not support the stream parameter in the request body.

        By default, this is true for almost all providers.
        """
        return True

    def post_stream_processing(self, stream: Any) -> Any:
        """Hook for providers to post-process streaming responses. Default: pass-through."""
        return stream

    def calculate_additional_costs(
        self, model: str, prompt_tokens: int, completion_tokens: int
    ) -> Optional[dict]:
        """
        Calculate any additional costs beyond standard token costs.

        This is used for provider-specific infrastructure costs, routing fees, etc.

        Args:
            model: The model name
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens

        Returns:
            Optional dictionary with cost names and amounts, e.g.:
            {"Infrastructure Fee": 0.001, "Routing Cost": 0.0005}
            Returns None if no additional costs apply.
        """
        return None
|
||||
@@ -0,0 +1,75 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseTextCompletionConfig(BaseConfig, ABC):
    """Base transformation config for the legacy /completions (text completion) API."""

    @abstractmethod
    def transform_text_completion_request(
        self,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the provider-specific text-completion request body."""
        return {}

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Chat-style request transformation is unsupported; use transform_text_completion_request."""
        # Bugfix: the error text previously referenced "AudioTranscriptionConfig"
        # and "audio transcription models" — a copy-paste from the audio config.
        raise NotImplementedError(
            "TextCompletionConfig does not need a request transformation for text completion models"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Chat-style response transformation is unsupported for text completion configs."""
        raise NotImplementedError(
            "TextCompletionConfig does not need a response transformation for text completion models"
        )
|
||||
@@ -0,0 +1,268 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.containers.main import ContainerCreateOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.types.containers.main import (
|
||||
ContainerFileListResponse as _ContainerFileListResponse,
|
||||
)
|
||||
from litellm.types.containers.main import (
|
||||
ContainerListResponse as _ContainerListResponse,
|
||||
)
|
||||
from litellm.types.containers.main import ContainerObject as _ContainerObject
|
||||
from litellm.types.containers.main import (
|
||||
DeleteContainerResult as _DeleteContainerResult,
|
||||
)
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
ContainerObject = _ContainerObject
|
||||
DeleteContainerResult = _DeleteContainerResult
|
||||
ContainerListResponse = _ContainerListResponse
|
||||
ContainerFileListResponse = _ContainerFileListResponse
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
ContainerObject = Any
|
||||
DeleteContainerResult = Any
|
||||
ContainerListResponse = Any
|
||||
ContainerFileListResponse = Any
|
||||
|
||||
|
||||
class BaseContainerConfig(ABC):
    """
    Base transformation config for the Containers API.

    Subclasses translate container create/list/retrieve/delete and
    container-file operations to and from the provider's wire format.
    """

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        """
        Return the config's plain data attributes as a dict, filtering out
        dunders, ABC bookkeeping, function/method objects, and None values.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_openai_params(self) -> list:
        """Return the OpenAI-style params supported for container creation."""
        pass

    @abstractmethod
    def map_openai_params(
        self,
        container_create_optional_params: ContainerCreateOptionalRequestParams,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI-style optional params onto the provider's format."""
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        api_key: str | None = None,
    ) -> dict:
        """Validate credentials and return the headers to send."""
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        api_base: str | None,
        litellm_params: dict,
    ) -> str:
        """Get the complete url for the request.

        OPTIONAL - Some providers need `model` in `api_base`.

        Raises:
            ValueError: If api_base is None.
        """
        if api_base is None:
            msg = "api_base is required"
            raise ValueError(msg)
        return api_base

    @abstractmethod
    def transform_container_create_request(
        self,
        name: str,
        container_create_optional_request_params: dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> dict:
        """Transform the container creation request.

        Returns:
            dict: Request data for container creation.
        """
        ...

    @abstractmethod
    def transform_container_create_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerObject:
        """Transform the container creation response."""
        ...

    @abstractmethod
    def transform_container_list_request(
        self,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: str | None = None,
        limit: int | None = None,
        order: str | None = None,
        extra_query: dict[str, Any] | None = None,
    ) -> tuple[str, dict]:
        """Transform the container list request into a URL and params.

        Returns:
            tuple[str, dict]: (url, params) for the container list request.
        """
        ...

    @abstractmethod
    def transform_container_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerListResponse:
        """Transform the container list response."""
        ...

    @abstractmethod
    def transform_container_retrieve_request(
        self,
        container_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> tuple[str, dict]:
        """Transform the container retrieve request into a URL and data/params.

        Returns:
            tuple[str, dict]: (url, params) for the container retrieve request.
        """
        ...

    @abstractmethod
    def transform_container_retrieve_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerObject:
        """Transform the container retrieve response."""
        ...

    @abstractmethod
    def transform_container_delete_request(
        self,
        container_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> tuple[str, dict]:
        """Transform the container delete request into a URL and data.

        Returns:
            tuple[str, dict]: (url, data) for the container delete request.
        """
        ...

    @abstractmethod
    def transform_container_delete_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteContainerResult:
        """Transform the container delete response."""
        ...

    @abstractmethod
    def transform_container_file_list_request(
        self,
        container_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: str | None = None,
        limit: int | None = None,
        order: str | None = None,
        extra_query: dict[str, Any] | None = None,
    ) -> tuple[str, dict]:
        """Transform the container file list request into a URL and params.

        Returns:
            tuple[str, dict]: (url, params) for the container file list request.
        """
        ...

    @abstractmethod
    def transform_container_file_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerFileListResponse:
        """Transform the container file list response."""
        ...

    @abstractmethod
    def transform_container_file_content_request(
        self,
        container_id: str,
        file_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> tuple[str, dict]:
        """Transform the container file content request into a URL and params.

        Returns:
            tuple[str, dict]: (url, params) for the container file content request.
        """
        ...

    @abstractmethod
    def transform_container_file_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> bytes:
        """Transform the container file content response.

        Returns:
            bytes: The raw file content.
        """
        ...

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: dict | httpx.Headers,
    ) -> BaseLLMException:
        """Return the provider-specific exception for a failed request."""
        # Local import (presumably to avoid a circular import) — confirm.
        from ..chat.transformation import BaseLLMException

        # Bugfix: return the exception (per the method name and return
        # annotation) instead of raising it here.
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
@@ -0,0 +1,89 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||
from litellm.types.utils import EmbeddingResponse, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseEmbeddingConfig(BaseConfig, ABC):
    """Base transformation config for embedding providers."""

    @abstractmethod
    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the provider-specific embedding request body."""
        return {}

    @abstractmethod
    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """Translate the provider's raw HTTP response into an EmbeddingResponse."""
        return model_response

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL override: build the full request URL (some providers need
        `model` in `api_base`). Defaults to the given base, or "".
        """
        if api_base:
            return api_base
        return ""

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Chat-style request transformation is not part of the embedding contract."""
        raise NotImplementedError(
            "EmbeddingConfig does not need a request transformation for chat models"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Chat-style response transformation is not part of the embedding contract."""
        raise NotImplementedError(
            "EmbeddingConfig does not need a response transformation for chat models"
        )
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Base configuration for Evals API
|
||||
"""
|
||||
|
||||
from .transformation import BaseEvalsAPIConfig
|
||||
|
||||
__all__ = ["BaseEvalsAPIConfig"]
|
||||
@@ -0,0 +1,542 @@
|
||||
"""
|
||||
Base configuration class for Evals API
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.openai_evals import (
|
||||
CancelEvalResponse,
|
||||
CancelRunResponse,
|
||||
CreateEvalRequest,
|
||||
CreateRunRequest,
|
||||
DeleteEvalResponse,
|
||||
Eval,
|
||||
ListEvalsParams,
|
||||
ListEvalsResponse,
|
||||
ListRunsParams,
|
||||
ListRunsResponse,
|
||||
Run,
|
||||
RunDeleteResponse,
|
||||
UpdateEvalRequest,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseEvalsAPIConfig(ABC):
|
||||
"""Base configuration for Evals API providers"""
|
||||
|
||||
    def __init__(self):
        """No base-level initialization is required."""
        pass
|
||||
|
||||
    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Identifier of the provider this evals config implements."""
        pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_environment(
|
||||
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
"""
|
||||
Validate and update headers with provider-specific requirements
|
||||
|
||||
Args:
|
||||
headers: Base headers dictionary
|
||||
litellm_params: LiteLLM parameters
|
||||
|
||||
Returns:
|
||||
Updated headers dictionary
|
||||
"""
|
||||
return headers
|
||||
|
||||
@abstractmethod
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
endpoint: str,
|
||||
eval_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for the API request
|
||||
|
||||
Args:
|
||||
api_base: Base API URL
|
||||
endpoint: API endpoint (e.g., 'evals', 'evals/{id}')
|
||||
eval_id: Optional eval ID for specific eval operations
|
||||
|
||||
Returns:
|
||||
Complete URL
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required")
|
||||
return f"{api_base}/v1/{endpoint}"
|
||||
|
||||
@abstractmethod
|
||||
def transform_create_eval_request(
|
||||
self,
|
||||
create_request: CreateEvalRequest,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Dict:
|
||||
"""
|
||||
Transform create eval request to provider-specific format
|
||||
|
||||
Args:
|
||||
create_request: Eval creation parameters
|
||||
litellm_params: LiteLLM parameters
|
||||
headers: Request headers
|
||||
|
||||
Returns:
|
||||
Provider-specific request body
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_create_eval_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> Eval:
|
||||
"""
|
||||
Transform provider response to Eval object
|
||||
|
||||
Args:
|
||||
raw_response: Raw HTTP response
|
||||
logging_obj: Logging object
|
||||
|
||||
Returns:
|
||||
Eval object
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
def transform_list_evals_request(
    self,
    list_params: ListEvalsParams,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform list-evals request parameters into provider format.

    Args:
        list_params: List parameters (pagination, filters).
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, query_params).
    """
    pass
@abstractmethod
def transform_list_evals_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> ListEvalsResponse:
    """
    Parse the provider's list-evals HTTP response into a ListEvalsResponse.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        ListEvalsResponse object.
    """
    pass
@abstractmethod
def transform_get_eval_request(
    self,
    eval_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform a get-eval request into provider format.

    Args:
        eval_id: Eval ID.
        api_base: Base API URL.
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, headers).
    """
    pass
@abstractmethod
def transform_get_eval_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> Eval:
    """
    Parse the provider's get-eval HTTP response into an Eval object.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        Eval object.
    """
    pass
@abstractmethod
def transform_update_eval_request(
    self,
    eval_id: str,
    update_request: UpdateEvalRequest,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict, Dict]:
    """
    Transform an update-eval request into provider format.

    Args:
        eval_id: Eval ID.
        update_request: Update parameters.
        api_base: Base API URL.
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, headers, body).
    """
    pass
@abstractmethod
def transform_update_eval_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> Eval:
    """
    Parse the provider's update-eval HTTP response into an Eval object.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        Eval object.
    """
    pass
@abstractmethod
def transform_delete_eval_request(
    self,
    eval_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform a delete-eval request into provider format.

    Args:
        eval_id: Eval ID.
        api_base: Base API URL.
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, headers).
    """
    pass
@abstractmethod
def transform_delete_eval_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> DeleteEvalResponse:
    """
    Parse the provider's delete-eval HTTP response into a DeleteEvalResponse.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        DeleteEvalResponse object.
    """
    pass
@abstractmethod
def transform_cancel_eval_request(
    self,
    eval_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict, Dict]:
    """
    Transform a cancel-eval request into provider format.

    Args:
        eval_id: Eval ID.
        api_base: Base API URL.
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, headers, body).
    """
    pass
@abstractmethod
def transform_cancel_eval_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> CancelEvalResponse:
    """
    Parse the provider's cancel-eval HTTP response into a CancelEvalResponse.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        CancelEvalResponse object.
    """
    pass
||||
# Run API Transformations
|
||||
@abstractmethod
def transform_create_run_request(
    self,
    eval_id: str,
    create_request: CreateRunRequest,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform a create-run request into the provider-specific format.

    Args:
        eval_id: Eval ID the run belongs to.
        create_request: Run creation parameters.
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, request_body).
    """
    pass
@abstractmethod
def transform_create_run_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> Run:
    """
    Parse the provider's create-run HTTP response into a Run object.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        Run object.
    """
    pass
@abstractmethod
def transform_list_runs_request(
    self,
    eval_id: str,
    list_params: ListRunsParams,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform list-runs request parameters into provider format.

    Args:
        eval_id: Eval ID whose runs are listed.
        list_params: List parameters (pagination, filters).
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, query_params).
    """
    pass
@abstractmethod
def transform_list_runs_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> ListRunsResponse:
    """
    Parse the provider's list-runs HTTP response into a ListRunsResponse.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        ListRunsResponse object.
    """
    pass
@abstractmethod
def transform_get_run_request(
    self,
    eval_id: str,
    run_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform a get-run request into provider format.

    Args:
        eval_id: Eval ID.
        run_id: Run ID.
        api_base: Base API URL.
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, headers).
    """
    pass
@abstractmethod
def transform_get_run_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> Run:
    """
    Parse the provider's get-run HTTP response into a Run object.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        Run object.
    """
    pass
@abstractmethod
def transform_cancel_run_request(
    self,
    eval_id: str,
    run_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict, Dict]:
    """
    Transform a cancel-run request into provider format.

    Args:
        eval_id: Eval ID.
        run_id: Run ID.
        api_base: Base API URL.
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, headers, body).
    """
    pass
@abstractmethod
def transform_cancel_run_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> CancelRunResponse:
    """
    Parse the provider's cancel-run HTTP response into a CancelRunResponse.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        CancelRunResponse object.
    """
    pass
@abstractmethod
def transform_delete_run_request(
    self,
    eval_id: str,
    run_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict, Dict]:
    """
    Transform a delete-run request into provider format.

    Args:
        eval_id: Eval ID.
        run_id: Run ID.
        api_base: Base API URL.
        litellm_params: LiteLLM parameters.
        headers: Request headers.

    Returns:
        Tuple of (url, headers, body).
    """
    pass
@abstractmethod
def transform_delete_run_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> "RunDeleteResponse":
    """
    Parse the provider's delete-run HTTP response into a RunDeleteResponse.

    Args:
        raw_response: Raw HTTP response from the provider.
        logging_obj: Logging object.

    Returns:
        RunDeleteResponse object.
    """
    pass
def get_error_class(
    self,
    error_message: str,
    status_code: int,
    headers: dict,
) -> Exception:
    """Build the provider-appropriate exception for a failed request.

    Default implementation wraps the error details in a BaseLLMException;
    providers may override to raise more specific exception types.
    """
    error_kwargs = {
        "status_code": status_code,
        "message": error_message,
        "headers": headers,
    }
    return BaseLLMException(**error_kwargs)
|
||||
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
Azure Blob Storage backend implementation for file storage.
|
||||
|
||||
This module implements the Azure Blob Storage backend for storing files
|
||||
in Azure Data Lake Storage Gen2. It inherits from AzureBlobStorageLogger
|
||||
to reuse all authentication and Azure Storage operations.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
|
||||
from .storage_backend import BaseFileStorageBackend
|
||||
from litellm.integrations.azure_storage.azure_storage import AzureBlobStorageLogger
|
||||
|
||||
|
||||
class AzureBlobStorageBackend(BaseFileStorageBackend, AzureBlobStorageLogger):
    """
    Azure Blob Storage backend implementation.

    Inherits from AzureBlobStorageLogger to reuse:
    - Authentication (account key and Azure AD)
    - Service client management
    - Token management
    - All Azure Storage helper methods

    Reads configuration from the same environment variables as AzureBlobStorageLogger.

    NOTE(review): this class relies on inherited members not visible here
    (get_service_client, set_valid_azure_ad_token, _create_file, _flush_data,
    azure_auth_token, azure_storage_account_key/name/file_system) — verify
    against AzureBlobStorageLogger before changing behavior.
    """

    def __init__(self, **kwargs):
        """
        Initialize Azure Blob Storage backend.

        Inherits all functionality from AzureBlobStorageLogger which handles:
        - Reading environment variables
        - Authentication (account key and Azure AD)
        - Service client management
        - Token management

        Environment variables (same as AzureBlobStorageLogger):
        - AZURE_STORAGE_ACCOUNT_NAME (required)
        - AZURE_STORAGE_FILE_SYSTEM (required)
        - AZURE_STORAGE_ACCOUNT_KEY (optional, if using account key auth)
        - AZURE_STORAGE_TENANT_ID (optional, if using Azure AD)
        - AZURE_STORAGE_CLIENT_ID (optional, if using Azure AD)
        - AZURE_STORAGE_CLIENT_SECRET (optional, if using Azure AD)

        Note: We skip periodic_flush since we're not using this as a logger.
        """
        # Initialize AzureBlobStorageLogger (handles all auth and config).
        # Called explicitly (not via super()) to target the logger base in
        # this multiple-inheritance setup.
        AzureBlobStorageLogger.__init__(self, **kwargs)

        # Disable logging functionality - we're only using this for file storage
        # The periodic_flush task will be created but will do nothing since we override it

    async def periodic_flush(self):
        """
        Override to do nothing - we're not using this as a logger.
        This prevents the periodic flush task from doing any work.
        """
        # Do nothing - this class is used for file storage, not logging
        return

    async def async_log_success_event(self, *args, **kwargs):
        """
        Override to do nothing - we're not using this as a logger.
        """
        # Do nothing - this class is used for file storage, not logging
        pass

    async def async_log_failure_event(self, *args, **kwargs):
        """
        Override to do nothing - we're not using this as a logger.
        """
        # Do nothing - this class is used for file storage, not logging
        pass

    def _generate_file_name(
        self, original_filename: str, file_naming_strategy: str
    ) -> str:
        """Generate file name based on naming strategy.

        Strategies: "original_filename" (percent-encoded original name),
        "timestamp" (millisecond epoch + original extension), anything
        else falls back to "uuid" (uuid4 + original extension).
        """
        if file_naming_strategy == "original_filename":
            # Use original filename, but sanitize it.
            # NOTE: safe="" percent-encodes every reserved char, including
            # '/' and '.', so the stored name is fully URL-safe.
            return quote(original_filename, safe="")
        elif file_naming_strategy == "timestamp":
            # Use timestamp (milliseconds since epoch) plus the original
            # file extension, if any.
            extension = (
                original_filename.split(".")[-1] if "." in original_filename else ""
            )
            timestamp = int(time.time() * 1000)  # milliseconds
            return f"{timestamp}.{extension}" if extension else str(timestamp)
        else:  # default to "uuid"
            # Use a random UUID plus the original file extension, if any.
            extension = (
                original_filename.split(".")[-1] if "." in original_filename else ""
            )
            file_uuid = str(uuid.uuid4())
            return f"{file_uuid}.{extension}" if extension else file_uuid

    async def upload_file(
        self,
        file_content: bytes,
        filename: str,
        content_type: str,
        path_prefix: Optional[str] = None,
        file_naming_strategy: str = "uuid",
    ) -> str:
        """
        Upload a file to Azure Blob Storage.

        Returns the blob URL in format: https://{account}.blob.core.windows.net/{container}/{path}

        NOTE(review): content_type is accepted per the BaseFileStorageBackend
        interface but is not forwarded to either upload path here — confirm
        whether a Content-Type header should be set on the stored blob.
        """
        try:
            # Generate file name according to the chosen strategy
            file_name = self._generate_file_name(filename, file_naming_strategy)

            # Build full path (prefix is optional)
            if path_prefix:
                # Remove leading/trailing slashes and normalize
                prefix = path_prefix.strip("/")
                full_path = f"{prefix}/{file_name}"
            else:
                full_path = file_name

            # Auth mode selects the upload mechanism: account key -> SDK,
            # otherwise Azure AD token -> REST API.
            if self.azure_storage_account_key:
                # Use Azure SDK with account key (reuse logger's method)
                storage_url = await self._upload_file_with_account_key(
                    file_content=file_content,
                    full_path=full_path,
                )
            else:
                # Use REST API with Azure AD token (reuse logger's methods)
                storage_url = await self._upload_file_with_azure_ad(
                    file_content=file_content,
                    full_path=full_path,
                )

            verbose_logger.debug(
                f"Successfully uploaded file to Azure Blob Storage: {storage_url}"
            )
            return storage_url

        except Exception as e:
            verbose_logger.exception(
                f"Error uploading file to Azure Blob Storage: {str(e)}"
            )
            raise

    async def _upload_file_with_account_key(
        self, file_content: bytes, full_path: str
    ) -> str:
        """Upload file using Azure SDK with account key authentication."""
        # Reuse the logger's service client method
        service_client = await self.get_service_client()
        file_system_client = service_client.get_file_system_client(
            file_system=self.azure_storage_file_system
        )

        # Create filesystem (container) if it doesn't exist
        if not await file_system_client.exists():
            await file_system_client.create_file_system()
            verbose_logger.debug(
                f"Created filesystem: {self.azure_storage_file_system}"
            )

        # Extract directory and filename (similar to logger's pattern)
        path_parts = full_path.split("/")
        if len(path_parts) > 1:
            directory_path = "/".join(path_parts[:-1])
            file_name = path_parts[-1]

            # Create directory if needed (like logger does)
            directory_client = file_system_client.get_directory_client(directory_path)
            if not await directory_client.exists():
                await directory_client.create_directory()
                verbose_logger.debug(f"Created directory: {directory_path}")

            # Get file client from directory (same pattern as logger)
            file_client = directory_client.get_file_client(file_name)
        else:
            # No directory, create file directly in root
            file_client = file_system_client.get_file_client(full_path)

        # Create, append, and flush (same pattern as logger's upload_to_azure_data_lake_with_azure_account_key)
        await file_client.create_file()
        await file_client.append_data(
            data=file_content, offset=0, length=len(file_content)
        )
        await file_client.flush_data(position=len(file_content), offset=0)

        # Return blob URL (not DFS URL) — callers get the blob endpoint even
        # though the upload went through the Data Lake (DFS) API.
        blob_url = f"https://{self.azure_storage_account_name}.blob.core.windows.net/{self.azure_storage_file_system}/{full_path}"
        return blob_url

    async def _upload_file_with_azure_ad(
        self, file_content: bytes, full_path: str
    ) -> str:
        """Upload file using REST API with Azure AD authentication."""
        # Reuse the logger's token management (refreshes self.azure_auth_token)
        await self.set_valid_azure_ad_token()

        # Local import keeps module import light
        from litellm.llms.custom_httpx.http_handler import (
            get_async_httpx_client,
            httpxSpecialProvider,
        )

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        # Use DFS endpoint for upload
        base_url = f"https://{self.azure_storage_account_name}.dfs.core.windows.net/{self.azure_storage_file_system}/{full_path}"

        # Execute 3-step upload process: create, append, flush
        # Reuse the logger's helper methods
        await self._create_file(async_client, base_url)
        # Append data - logger's _append_data expects string, so we create our own for bytes
        await self._append_data_bytes(async_client, base_url, file_content)
        await self._flush_data(async_client, base_url, len(file_content))

        # Return blob URL (not DFS URL)
        blob_url = f"https://{self.azure_storage_account_name}.blob.core.windows.net/{self.azure_storage_file_system}/{full_path}"
        return blob_url

    async def _append_data_bytes(self, client, base_url: str, file_content: bytes):
        """Append binary data to file using REST API.

        Bytes counterpart of the inherited _append_data (which expects str).
        Assumes the file was just created, so the write position is 0.
        """
        from litellm.constants import AZURE_STORAGE_MSFT_VERSION

        headers = {
            "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
            "Content-Type": "application/octet-stream",
            "Authorization": f"Bearer {self.azure_auth_token}",
        }
        response = await client.patch(
            f"{base_url}?action=append&position=0",
            headers=headers,
            content=file_content,
        )
        response.raise_for_status()

    async def download_file(self, storage_url: str) -> bytes:
        """
        Download a file from Azure Blob Storage.

        Args:
            storage_url: Blob URL in format: https://{account}.blob.core.windows.net/{container}/{path}

        Returns:
            bytes: File content

        Raises:
            ValueError: If storage_url is not a recognizable blob URL.
        """
        try:
            # Parse blob URL to extract path
            # URL format: https://{account}.blob.core.windows.net/{container}/{path}
            if ".blob.core.windows.net/" not in storage_url:
                raise ValueError(f"Invalid Azure Blob Storage URL: {storage_url}")

            # Extract path after container name
            container_and_path = storage_url.split(".blob.core.windows.net/", 1)[1]
            path_parts = container_and_path.split("/", 1)
            if len(path_parts) < 2:
                raise ValueError(
                    f"Invalid Azure Blob Storage URL format: {storage_url}"
                )
            file_path = path_parts[1]  # Path after container name

            # Same auth-mode split as upload_file
            if self.azure_storage_account_key:
                # Use Azure SDK (reuse logger's service client)
                return await self._download_file_with_account_key(file_path)
            else:
                # Use REST API (reuse logger's token management)
                return await self._download_file_with_azure_ad(file_path)

        except Exception as e:
            verbose_logger.exception(
                f"Error downloading file from Azure Blob Storage: {str(e)}"
            )
            raise

    async def _download_file_with_account_key(self, file_path: str) -> bytes:
        """Download file using Azure SDK with account key."""
        # Reuse the logger's service client method
        service_client = await self.get_service_client()
        file_system_client = service_client.get_file_system_client(
            file_system=self.azure_storage_file_system
        )
        # Ensure filesystem exists (should already exist, but check for safety)
        if not await file_system_client.exists():
            raise ValueError(
                f"Filesystem {self.azure_storage_file_system} does not exist"
            )
        file_client = file_system_client.get_file_client(file_path)
        # Download file and read the full content into memory
        download_response = await file_client.download_file()
        file_content = await download_response.readall()
        return file_content

    async def _download_file_with_azure_ad(self, file_path: str) -> bytes:
        """Download file using REST API with Azure AD token."""
        # Reuse the logger's token management
        await self.set_valid_azure_ad_token()

        from litellm.llms.custom_httpx.http_handler import (
            get_async_httpx_client,
            httpxSpecialProvider,
        )
        from litellm.constants import AZURE_STORAGE_MSFT_VERSION

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        # Use blob endpoint for download (simpler than DFS)
        blob_url = f"https://{self.azure_storage_account_name}.blob.core.windows.net/{self.azure_storage_file_system}/{file_path}"

        headers = {
            "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
            "Authorization": f"Bearer {self.azure_auth_token}",
        }

        response = await async_client.get(blob_url, headers=headers)
        response.raise_for_status()
        return response.content
|
||||
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Base storage backend interface for file storage backends.
|
||||
|
||||
This module defines the abstract base class that all file storage backends
|
||||
(e.g., Azure Blob Storage, S3, GCS) must implement.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BaseFileStorageBackend(ABC):
    """
    Abstract base class for file storage backends.

    All storage backends (Azure Blob Storage, S3, GCS, etc.) must implement
    these methods to provide a consistent interface for file operations.
    Deletion is optional; the default delete_file is a no-op.
    """

    @abstractmethod
    async def upload_file(
        self,
        file_content: bytes,
        filename: str,
        content_type: str,
        path_prefix: Optional[str] = None,
        file_naming_strategy: str = "uuid",
    ) -> str:
        """
        Upload a file to the storage backend.

        Args:
            file_content: The file content as bytes
            filename: Original filename (may be used for naming strategy)
            content_type: MIME type of the file
            path_prefix: Optional path prefix for organizing files
            file_naming_strategy: Strategy for naming files ("uuid", "timestamp", "original_filename")

        Returns:
            str: The storage URL where the file can be accessed/downloaded

        Raises:
            Exception: If upload fails
        """
        pass

    @abstractmethod
    async def download_file(self, storage_url: str) -> bytes:
        """
        Download a file from the storage backend.

        Args:
            storage_url: The storage URL returned from upload_file

        Returns:
            bytes: The file content

        Raises:
            Exception: If download fails
        """
        pass

    async def delete_file(self, storage_url: str) -> None:
        """
        Delete a file from the storage backend.

        This is optional and can be overridden by backends that support deletion.
        Default implementation does nothing.

        Args:
            storage_url: The storage URL of the file to delete

        Raises:
            Exception: If deletion fails
        """
        # Default implementation: no-op
        # Backends can override if they support deletion
        pass
|
||||
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Factory for creating storage backend instances.
|
||||
|
||||
This module provides a factory function to instantiate the correct storage backend
|
||||
based on the backend type. Backends use the same configuration as their corresponding
|
||||
callbacks (e.g., azure_storage uses the same env vars as AzureBlobStorageLogger).
|
||||
"""
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
from .azure_blob_storage_backend import AzureBlobStorageBackend
|
||||
from .storage_backend import BaseFileStorageBackend
|
||||
|
||||
|
||||
def get_storage_backend(backend_type: str) -> BaseFileStorageBackend:
    """
    Factory function to create a storage backend instance.

    Backends are configured using the same environment variables as their
    corresponding callbacks. For example, "azure_storage" uses the same
    env vars as AzureBlobStorageLogger.

    Args:
        backend_type: Backend type identifier (e.g., "azure_storage")

    Returns:
        BaseFileStorageBackend: Instance of the appropriate storage backend

    Raises:
        ValueError: If backend_type is not supported
    """
    verbose_logger.debug(f"Creating storage backend: type={backend_type}")

    # Registry keeps the dispatch and the "supported types" error message in
    # sync — add new backends here instead of extending an if/elif chain.
    supported_backends = {
        "azure_storage": AzureBlobStorageBackend,
    }

    backend_cls = supported_backends.get(backend_type)
    if backend_cls is None:
        raise ValueError(
            f"Unsupported storage backend type: {backend_type}. "
            "Supported types: " + ", ".join(sorted(supported_backends))
        )
    return backend_cls()
|
||||
@@ -0,0 +1,261 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai.types.file_deleted import FileDeleted
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.files import TwoStepFileUploadConfig
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
CreateFileRequest,
|
||||
FileContentRequest,
|
||||
OpenAICreateFileRequestOptionalParams,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilesPurpose,
|
||||
)
|
||||
from litellm.types.utils import LlmProviders, ModelResponse
|
||||
|
||||
from ..chat.transformation import BaseConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.router import Router as _Router
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
Span = Any
|
||||
Router = _Router
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
Span = Any
|
||||
Router = Any
|
||||
|
||||
|
||||
class BaseFilesConfig(BaseConfig):
    """Base configuration for provider Files API support.

    Providers implement the abstract transform_* hooks to translate
    OpenAI-style file operations (create / retrieve / delete / list /
    content) to and from their native wire formats. The inherited chat
    transform_request/transform_response hooks are not applicable here
    and raise NotImplementedError.
    """

    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Provider identifier for this config."""
        pass

    @property
    def file_upload_http_method(self) -> str:
        """
        HTTP method to use for file uploads.
        Override this in provider configs if they need different methods.
        Default is POST (used by most providers like OpenAI, Anthropic).
        S3-based providers like Bedrock should return "PUT".
        """
        return "POST"

    @abstractmethod
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAICreateFileRequestOptionalParams]:
        """Return the OpenAI create-file optional params supported for `model`."""
        pass

    def get_complete_file_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        data: CreateFileRequest,
    ):
        """Build the files-endpoint URL.

        Default implementation delegates to get_complete_url; `data` is
        unused here but available to providers whose URL depends on the
        request payload.
        """
        return self.get_complete_url(
            api_base=api_base,
            api_key=api_key,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )

    @abstractmethod
    def transform_create_file_request(
        self,
        model: str,
        create_file_data: CreateFileRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[dict, str, bytes, "TwoStepFileUploadConfig"]:
        """
        Transform OpenAI-style file creation request into provider-specific format.

        Returns:
            - dict: For pre-signed single-step uploads (e.g., Bedrock S3)
            - str/bytes: For traditional file uploads
            - TwoStepFileUploadConfig: For two-step upload process (e.g., Manus, GCS)
        """
        pass

    @abstractmethod
    def transform_create_file_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> OpenAIFileObject:
        """Transform the provider's create-file response into OpenAI format."""
        pass

    @abstractmethod
    def transform_retrieve_file_request(
        self,
        file_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Transform file retrieve request into provider-specific format."""
        pass

    @abstractmethod
    def transform_retrieve_file_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> OpenAIFileObject:
        """Transform file retrieve response into OpenAI format."""
        pass

    @abstractmethod
    def transform_delete_file_request(
        self,
        file_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Transform file delete request into provider-specific format."""
        pass

    @abstractmethod
    def transform_delete_file_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> "FileDeleted":
        """Transform file delete response into OpenAI format."""
        pass

    @abstractmethod
    def transform_list_files_request(
        self,
        purpose: Optional[str],
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Transform file list request into provider-specific format."""
        pass

    @abstractmethod
    def transform_list_files_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> List[OpenAIFileObject]:
        """Transform file list response into OpenAI format."""
        pass

    @abstractmethod
    def transform_file_content_request(
        self,
        file_content_request: "FileContentRequest",
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Transform file content request into provider-specific format."""
        pass

    @abstractmethod
    def transform_file_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> "HttpxBinaryResponseContent":
        """Transform file content response into OpenAI format."""
        pass

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Chat request transformation is not applicable to the Files API."""
        # Fixed: message previously referenced "AudioTranscriptionConfig"
        # (copy-paste from the audio transcription base config).
        raise NotImplementedError(
            "BaseFilesConfig does not need a request transformation for files endpoints"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Chat response transformation is not applicable to the Files API."""
        # Fixed: message previously referenced "AudioTranscriptionConfig"
        # (copy-paste from the audio transcription base config).
        raise NotImplementedError(
            "BaseFilesConfig does not need a response transformation for files endpoints"
        )
|
||||
|
||||
|
||||
class BaseFileEndpoints(ABC):
    """Async interface for OpenAI-compatible file endpoints routed through LiteLLM.

    Implementations handle create/retrieve/list/delete/content operations,
    fanning out across router deployments where a ``target_model_names_list``
    or ``llm_router`` is supplied.
    """

    @abstractmethod
    async def acreate_file(
        self,
        create_file_request: CreateFileRequest,
        llm_router: Router,
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
        user_api_key_dict: UserAPIKeyAuth,
    ) -> OpenAIFileObject:
        """Create a file for each target model via the router; return the unified file object."""
        pass

    @abstractmethod
    async def afile_retrieve(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Optional[Router] = None,
    ) -> OpenAIFileObject:
        """Retrieve file metadata by (unified) file id."""
        pass

    @abstractmethod
    async def afile_list(
        self,
        purpose: Optional[OpenAIFilesPurpose],
        litellm_parent_otel_span: Optional[Span],
        **data: Dict,
    ) -> List[OpenAIFileObject]:
        """List files, optionally filtered by purpose; extra kwargs are forwarded to the provider."""
        pass

    @abstractmethod
    async def afile_delete(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> OpenAIFileObject:
        # NOTE(review): returns OpenAIFileObject rather than a FileDeleted-style
        # marker — confirm that callers expect the file object here.
        """Delete a file by (unified) file id across the router's deployments."""
        pass

    @abstractmethod
    async def afile_content(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> "HttpxBinaryResponseContent":
        """Download the raw content of a file by (unified) file id."""
        pass
|
||||
@@ -0,0 +1,211 @@
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.google_genai.main import (
|
||||
GenerateContentConfigDict,
|
||||
GenerateContentContentListUnionDict,
|
||||
GenerateContentResponse,
|
||||
ToolConfigDict,
|
||||
)
|
||||
else:
|
||||
GenerateContentConfigDict = Any
|
||||
GenerateContentContentListUnionDict = Any
|
||||
GenerateContentResponse = Any
|
||||
LiteLLMLoggingObj = Any
|
||||
ToolConfigDict = Any
|
||||
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
|
||||
class BaseGoogleGenAIGenerateContentConfig(ABC):
    """Base configuration class for Google GenAI generate_content functionality"""

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        # Expose non-callable, non-dunder class attributes as the config dict
        # (same filtering convention as the other Base*Config classes).
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_generate_content_optional_params(self, model: str) -> List[str]:
        """
        Get the list of supported Google GenAI parameters for the model.

        Args:
            model: The model name

        Returns:
            List of supported parameter names
        """
        raise NotImplementedError(
            "get_supported_generate_content_optional_params is not implemented"
        )

    @abstractmethod
    def map_generate_content_optional_params(
        self,
        generate_content_config_dict: GenerateContentConfigDict,
        model: str,
    ) -> Dict[str, Any]:
        """
        Map Google GenAI parameters to provider-specific format.

        Args:
            generate_content_config_dict: Generate-content config parameters
            model: The model name

        Returns:
            Mapped parameters for the provider
        """
        raise NotImplementedError(
            "map_generate_content_optional_params is not implemented"
        )

    @abstractmethod
    def validate_environment(
        self,
        api_key: Optional[str],
        headers: Optional[dict],
        model: str,
        litellm_params: Optional[Union[GenericLiteLLMParams, dict]],
    ) -> dict:
        """
        Validate the environment and return headers for the request.

        Args:
            api_key: API key
            headers: Existing headers
            model: The model name
            litellm_params: LiteLLM parameters

        Returns:
            Updated headers
        """
        raise NotImplementedError("validate_environment is not implemented")

    def sync_get_auth_token_and_url(
        self,
        api_base: Optional[str],
        model: str,
        litellm_params: dict,
        stream: bool,
    ) -> Tuple[dict, str]:
        """
        Sync version of get_auth_token_and_url.

        Args:
            api_base: Base API URL
            model: The model name
            litellm_params: LiteLLM parameters
            stream: Whether this is a streaming call

        Returns:
            Tuple of headers and API base
        """
        raise NotImplementedError("sync_get_auth_token_and_url is not implemented")

    async def get_auth_token_and_url(
        self,
        api_base: Optional[str],
        model: str,
        litellm_params: dict,
        stream: bool,
    ) -> Tuple[dict, str]:
        """
        Get auth headers and the complete URL for the request.

        Args:
            api_base: Base API URL
            model: The model name
            litellm_params: LiteLLM parameters
            stream: Whether this is a streaming call

        Returns:
            Tuple of headers and API base
        """
        raise NotImplementedError("get_auth_token_and_url is not implemented")

    @abstractmethod
    def transform_generate_content_request(
        self,
        model: str,
        contents: GenerateContentContentListUnionDict,
        tools: Optional[ToolConfigDict],
        generate_content_config_dict: Dict,
        system_instruction: Optional[Any] = None,
    ) -> dict:
        """
        Transform the request parameters for the generate content API.

        Args:
            model: The model name
            contents: Input contents
            tools: Tools
            generate_content_config_dict: Generation config parameters
            system_instruction: Optional system instruction

        Returns:
            Transformed request data
        """
        pass

    @abstractmethod
    def transform_generate_content_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> GenerateContentResponse:
        """
        Transform the raw response from the generate content API.

        Args:
            model: The model name
            raw_response: Raw HTTP response
            logging_obj: LiteLLM logging object for the call

        Returns:
            Transformed response data
        """
        pass

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> Exception:
        """
        Get the appropriate exception class for the error.

        Args:
            error_message: Error message
            status_code: HTTP status code
            headers: Response headers

        Returns:
            Exception instance
        """
        # Local import avoids a circular import with chat.transformation.
        from litellm.llms.base_llm.chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
@@ -0,0 +1,107 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
class BaseTranslation(ABC):
    """Guardrail translation interface plus shared metadata helpers."""

    @staticmethod
    def transform_user_api_key_dict_to_metadata(
        user_api_key_dict: Optional[Any],
    ) -> Dict[str, Any]:
        """
        Build a metadata dict from user_api_key_dict with prefixed keys.

        Keys such as 'user_id' become 'user_api_key_user_id' so the origin of
        each metadata entry is unambiguous. None values and internal
        (underscore-prefixed) fields are dropped; unrecognized inputs yield {}.

        Args:
            user_api_key_dict: UserAPIKeyAuth object or dict with user information

        Returns:
            Dict with keys prefixed with 'user_api_key_'
        """
        if user_api_key_dict is None:
            return {}

        # Pydantic objects expose model_dump(); anything else is used as-is.
        raw = (
            user_api_key_dict.model_dump()
            if hasattr(user_api_key_dict, "model_dump")
            else user_api_key_dict
        )
        if not isinstance(raw, dict):
            return {}

        prefix = "user_api_key_"
        return {
            (key if key.startswith(prefix) else f"{prefix}{key}"): value
            for key, value in raw.items()
            if value is not None and not key.startswith("_")
        }

    @abstractmethod
    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
    ) -> Any:
        """
        Run the given guardrail over the request's input messages.

        Note: user_api_key_dict metadata should be available in the data dict.
        """
        pass

    @abstractmethod
    async def process_output_response(
        self,
        response: Any,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> Any:
        """
        Run the given guardrail over the LLM's output response.

        Args:
            response: The response object from the LLM
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata (passed separately since response doesn't contain it)
        """
        pass

    async def process_output_streaming_response(
        self,
        responses_so_far: List[Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> Any:
        """
        Run guardrails over accumulated streaming chunks.

        Optional hook — the default is a pass-through of the chunks seen so far.
        """
        return responses_so_far

    def extract_request_tool_names(self, data: dict) -> List[str]:
        """
        Extract tool names from the request body for allowlist/policy checks.

        Tool-capable handlers override this; the default reports no tools.
        """
        return []
|
||||
@@ -0,0 +1,130 @@
|
||||
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import httpx
from httpx._types import RequestFiles

from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import FileTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.utils import ImageResponse as _ImageResponse
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
ImageResponse = _ImageResponse
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
ImageResponse = Any
|
||||
|
||||
|
||||
class BaseImageEditConfig(ABC):
    """Base configuration class for provider image-edit transformations."""

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        # Expose non-callable, non-dunder class attributes as the config dict
        # (same filtering convention as the other Base*Config classes).
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI image-edit params supported for `model`."""
        pass

    @abstractmethod
    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """Map OpenAI image-edit params to the provider's format; drop unsupported ones when `drop_params`."""
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Validate credentials/config and return the headers for the request."""
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_image_edit_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles]:
        """Build the provider request: returns (form/json data, files to upload)."""
        pass

    @abstractmethod
    def transform_image_edit_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ImageResponse:
        """Transform the provider's raw HTTP response into an OpenAI-style ImageResponse."""
        pass

    def use_multipart_form_data(self) -> bool:
        """
        Return True if the provider uses multipart/form-data for image edit requests.
        Return False if the provider uses JSON requests.

        Default is True for backwards compatibility with OpenAI-style providers.
        """
        return True

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Map an HTTP error to a LiteLLM exception.

        NOTE(review): raises rather than returns despite the return annotation —
        matches the sibling image configs; confirm this is intended.
        """
        # Local import avoids a circular import with chat.transformation.
        from ..chat.transformation import BaseLLMException

        raise BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
@@ -0,0 +1,112 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageGenerationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseImageGenerationConfig(ABC):
    """Base configuration class for provider image-generation transformations."""

    @abstractmethod
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        """Return the OpenAI image-generation params supported for `model`."""
        pass

    @abstractmethod
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI params to the provider's format; drop unsupported ones when `drop_params`."""
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate credentials/config and return the headers for the request (default: none)."""
        return {}

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Map an HTTP error to a LiteLLM exception (raises, matching sibling configs)."""
        raise BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def transform_image_generation_request(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the provider request body for image generation.

        Raises:
            NotImplementedError: providers must supply their own implementation.
        """
        # Fixed copy-pasted message that wrongly referenced ImageVariationConfig.
        raise NotImplementedError(
            "image generation providers must implement 'transform_image_generation_request'"
        )

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """Transform the provider's raw HTTP response into an OpenAI-style ImageResponse.

        Raises:
            NotImplementedError: providers must supply their own implementation.
        """
        # Fixed copy-pasted message that wrongly referenced ImageVariationConfig.
        raise NotImplementedError(
            "image generation providers must implement 'transform_image_generation_response'"
        )

    def use_multipart_form_data(self) -> bool:
        """
        Returns True if this provider requires multipart/form-data instead of JSON.

        Override this method in subclasses that need form-data (e.g., Stability AI).
        """
        return False
|
||||
@@ -0,0 +1,134 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageVariationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
FileTypes,
|
||||
HttpHandlerRequestFields,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseImageVariationConfig(BaseConfig, ABC):
    """Base configuration class for provider image-variation transformations.

    Inherits the shared chat BaseConfig interface but routes real work through
    the dedicated image-variation transform methods below.
    """

    @abstractmethod
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageVariationOptionalParams]:
        """Return the OpenAI image-variation params supported for `model`."""
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    @abstractmethod
    def transform_request_image_variation(
        self,
        model: Optional[str],
        image: FileTypes,
        optional_params: dict,
        headers: dict,
    ) -> HttpHandlerRequestFields:
        """Build the provider request fields for an image-variation call."""
        pass

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate credentials/config and return the headers for the request (default: none)."""
        return {}

    @abstractmethod
    async def async_transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: ClientResponse,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Async variant: transform an aiohttp response into an OpenAI-style ImageResponse."""
        pass

    @abstractmethod
    def transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Transform an httpx response into an OpenAI-style ImageResponse."""
        pass

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Chat-style request transformation — unsupported; see transform_request_image_variation.

        Raises:
            NotImplementedError: always.
        """
        # Fixed typo in the message ("implementa" -> "implements").
        raise NotImplementedError(
            "ImageVariationConfig implements 'transform_request_image_variation' for image variation models"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Chat-style response transformation — unsupported; see transform_response_image_variation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "ImageVariationConfig implements 'transform_response_image_variation' for image variation models"
        )
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Base classes for Interactions API implementations."""
|
||||
|
||||
from litellm.llms.base_llm.interactions.transformation import BaseInteractionsAPIConfig
|
||||
|
||||
__all__ = ["BaseInteractionsAPIConfig"]
|
||||
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
Base transformation class for Interactions API implementations.
|
||||
|
||||
This follows the same pattern as BaseResponsesAPIConfig for the Responses API.
|
||||
|
||||
Per OpenAPI spec (https://ai.google.dev/static/api/interactions.openapi.json):
|
||||
- Create: POST /{api_version}/interactions
|
||||
- Get: GET /{api_version}/interactions/{interaction_id}
|
||||
- Delete: DELETE /{api_version}/interactions/{interaction_id}
|
||||
"""
|
||||
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.interactions import (
|
||||
CancelInteractionResult,
|
||||
DeleteInteractionResult,
|
||||
InteractionInput,
|
||||
InteractionsAPIOptionalRequestParams,
|
||||
InteractionsAPIResponse,
|
||||
InteractionsAPIStreamingResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
class BaseInteractionsAPIConfig(ABC):
    """
    Base configuration class for Google Interactions API implementations.

    Per OpenAPI spec, the Interactions API supports two types of interactions:
    - Model interactions (with model parameter)
    - Agent interactions (with agent parameter)

    Implementations should override the abstract methods to provide
    provider-specific transformations for requests and responses.
    """

    def __init__(self):
        pass

    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Return the LLM provider identifier."""
        pass

    @classmethod
    def get_config(cls):
        # Expose non-callable, non-dunder class attributes as the config dict
        # (same filtering convention as the other Base*Config classes).
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_params(self, model: str) -> List[str]:
        """
        Return the list of supported parameters for the given model.
        """
        pass

    @abstractmethod
    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate and prepare environment settings including headers.
        """
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        model: Optional[str],
        agent: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the interaction request.

        Per OpenAPI spec: POST /{api_version}/interactions

        Args:
            api_base: Base URL for the API
            model: The model name (for model interactions)
            agent: The agent name (for agent interactions)
            litellm_params: LiteLLM parameters
            stream: Whether this is a streaming request

        Returns:
            The complete URL for the request
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_request(
        self,
        model: Optional[str],
        agent: Optional[str],
        input: Optional[InteractionInput],
        optional_params: InteractionsAPIOptionalRequestParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Transform the input request into the provider's expected format.

        Per OpenAPI spec, the request body should be either:
        - CreateModelInteractionParams (with model)
        - CreateAgentInteractionParams (with agent)

        Args:
            model: The model name (for model interactions)
            agent: The agent name (for agent interactions)
            input: The input content (string, content object, or list)
            optional_params: Optional parameters for the request
            litellm_params: LiteLLM-specific parameters
            headers: Request headers

        Returns:
            The transformed request body as a dictionary
        """
        pass

    @abstractmethod
    def transform_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> InteractionsAPIResponse:
        """
        Transform the raw HTTP response into an InteractionsAPIResponse.

        Per OpenAPI spec, the response is an Interaction object.
        """
        pass

    @abstractmethod
    def transform_streaming_response(
        self,
        model: Optional[str],
        parsed_chunk: dict,
        logging_obj: LiteLLMLoggingObj,
    ) -> InteractionsAPIStreamingResponse:
        """
        Transform a parsed streaming response chunk into an InteractionsAPIStreamingResponse.

        Per OpenAPI spec, streaming uses SSE with various event types.
        """
        pass

    # =========================================================
    # GET INTERACTION TRANSFORMATION
    # =========================================================

    @abstractmethod
    def transform_get_interaction_request(
        self,
        interaction_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the get interaction request into URL and query params.

        Per OpenAPI spec: GET /{api_version}/interactions/{interaction_id}

        Returns:
            Tuple of (URL, query_params)
        """
        pass

    @abstractmethod
    def transform_get_interaction_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> InteractionsAPIResponse:
        """
        Transform the get interaction response.
        """
        pass

    # =========================================================
    # DELETE INTERACTION TRANSFORMATION
    # =========================================================

    @abstractmethod
    def transform_delete_interaction_request(
        self,
        interaction_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the delete interaction request into URL and body.

        Per OpenAPI spec: DELETE /{api_version}/interactions/{interaction_id}

        Returns:
            Tuple of (URL, request_body)
        """
        pass

    @abstractmethod
    def transform_delete_interaction_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        interaction_id: str,
    ) -> DeleteInteractionResult:
        """
        Transform the delete interaction response.
        """
        pass

    # =========================================================
    # CANCEL INTERACTION TRANSFORMATION
    # =========================================================

    @abstractmethod
    def transform_cancel_interaction_request(
        self,
        interaction_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the cancel interaction request into URL and body.

        Returns:
            Tuple of (URL, request_body)
        """
        pass

    @abstractmethod
    def transform_cancel_interaction_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> CancelInteractionResult:
        """
        Transform the cancel interaction response.
        """
        pass

    # =========================================================
    # ERROR HANDLING
    # =========================================================

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """
        Get the appropriate exception class for an error.

        NOTE(review): raises rather than returns despite the return annotation —
        matches the sibling Base*Config classes; confirm this is intended.
        """
        # Local import avoids a circular import with chat.transformation.
        from ..chat.transformation import BaseLLMException

        raise BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Returns True if litellm should fake a stream for the given model.

        Override in subclasses if the provider doesn't support native streaming.
        """
        return False
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Managed Resources Module
|
||||
|
||||
This module provides base classes and utilities for managing resources
|
||||
(files, vector stores, etc.) with target_model_names support.
|
||||
|
||||
The BaseManagedResource class provides common functionality for:
|
||||
- Storing unified resource IDs with model mappings
|
||||
- Retrieving resources by unified ID
|
||||
- Deleting resources across multiple models
|
||||
- Creating resources for multiple models
|
||||
- Filtering deployments based on model mappings
|
||||
"""
|
||||
|
||||
from .base_managed_resource import BaseManagedResource
|
||||
from .utils import (
|
||||
decode_unified_id,
|
||||
encode_unified_id,
|
||||
extract_model_id_from_unified_id,
|
||||
extract_provider_resource_id_from_unified_id,
|
||||
extract_resource_type_from_unified_id,
|
||||
extract_target_model_names_from_unified_id,
|
||||
extract_unified_uuid_from_unified_id,
|
||||
generate_unified_id_string,
|
||||
is_base64_encoded_unified_id,
|
||||
parse_unified_id,
|
||||
)
|
||||
|
||||
# Public API of the managed-resources package, listed alphabetically.
__all__ = [
    "BaseManagedResource",
    "decode_unified_id",
    "encode_unified_id",
    "extract_model_id_from_unified_id",
    "extract_provider_resource_id_from_unified_id",
    "extract_resource_type_from_unified_id",
    "extract_target_model_names_from_unified_id",
    "extract_unified_uuid_from_unified_id",
    "generate_unified_id_string",
    "is_base64_encoded_unified_id",
    "parse_unified_id",
]
|
||||
@@ -0,0 +1,607 @@
|
||||
# What is this?
|
||||
## Base class for managing resources (files, vector stores, etc.) with target_model_names support
|
||||
## This provides common functionality for creating, retrieving, and managing resources across multiple models
|
||||
|
||||
import base64
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.utils import SpecialEnums
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
|
||||
from litellm.proxy.utils import PrismaClient as _PrismaClient
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
InternalUsageCache = _InternalUsageCache
|
||||
PrismaClient = _PrismaClient
|
||||
Router = _Router
|
||||
else:
|
||||
Span = Any
|
||||
InternalUsageCache = Any
|
||||
PrismaClient = Any
|
||||
Router = Any
|
||||
|
||||
# Generic type for resource objects
|
||||
ResourceObjectType = TypeVar("ResourceObjectType")
|
||||
|
||||
|
||||
class BaseManagedResource(ABC, Generic[ResourceObjectType]):
    """
    Base class for managing resources with target_model_names support.

    This class provides common functionality for:
    - Storing unified resource IDs with model mappings
    - Retrieving resources by unified ID
    - Deleting resources across multiple models
    - Creating resources for multiple models
    - Filtering deployments based on model mappings

    Subclasses should implement:
    - resource_type: str property
    - table_name: str property
    - create_resource_for_model: method to create resource on a specific model
    - get_unified_resource_id_format: method to generate unified ID format
    """

    def __init__(
        self,
        internal_usage_cache: InternalUsageCache,
        prisma_client: PrismaClient,
    ):
        self.internal_usage_cache = internal_usage_cache
        self.prisma_client = prisma_client

    # ============================================================================
    # ABSTRACT METHODS
    # ============================================================================

    @property
    @abstractmethod
    def resource_type(self) -> str:
        """
        Return the resource type identifier (e.g., 'file', 'vector_store', 'vector_store_file').
        Used for logging and unified ID generation.
        """
        pass

    @property
    @abstractmethod
    def table_name(self) -> str:
        """
        Return the database table name for this resource type.
        Example: 'litellm_managedfiletable', 'litellm_managedvectorstoretable'
        """
        pass

    @abstractmethod
    def get_unified_resource_id_format(
        self,
        resource_object: ResourceObjectType,
        target_model_names_list: List[str],
    ) -> str:
        """
        Generate the format string for the unified resource ID.

        This should return a string that will be base64 encoded.
        Example for files:
            "litellm_proxy:application/json;unified_id,{uuid};target_model_names,{models};..."

        Args:
            resource_object: The resource object returned from the provider
            target_model_names_list: List of target model names

        Returns:
            Format string to be base64 encoded
        """
        pass

    @abstractmethod
    async def create_resource_for_model(
        self,
        llm_router: Router,
        model: str,
        request_data: Dict[str, Any],
        litellm_parent_otel_span: Span,
    ) -> ResourceObjectType:
        """
        Create a resource for a specific model.

        Args:
            llm_router: LiteLLM router instance
            model: Model name to create resource for
            request_data: Request data for resource creation
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            Resource object from the provider
        """
        pass

    # ============================================================================
    # COMMON STORAGE OPERATIONS
    # ============================================================================

    async def store_unified_resource_id(
        self,
        unified_resource_id: str,
        resource_object: Optional[ResourceObjectType],
        litellm_parent_otel_span: Optional[Span],
        model_mappings: Dict[str, str],
        user_api_key_dict: UserAPIKeyAuth,
        additional_db_fields: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Store unified resource ID with model mappings in cache and database.

        Args:
            unified_resource_id: The unified resource ID (base64 encoded)
            resource_object: The resource object to store (can be None)
            litellm_parent_otel_span: OpenTelemetry span for tracing
            model_mappings: Dictionary mapping model_id -> provider_resource_id
            user_api_key_dict: User API key authentication details
            additional_db_fields: Additional fields to store in database
        """
        verbose_logger.info(
            f"Storing LiteLLM Managed {self.resource_type} with id={unified_resource_id} in cache"
        )

        # Prepare cache data (model_mappings kept as a dict here; the DB copy
        # below is JSON-serialized).
        cache_data = {
            "unified_resource_id": unified_resource_id,
            "resource_object": resource_object,
            "model_mappings": model_mappings,
            "flat_model_resource_ids": list(model_mappings.values()),
            "created_by": user_api_key_dict.user_id,
            "updated_by": user_api_key_dict.user_id,
        }

        # Add additional fields if provided
        if additional_db_fields:
            cache_data.update(additional_db_fields)

        # Only cache when there is an actual object to serve from cache;
        # the DB row below is written either way.
        if resource_object is not None:
            await self.internal_usage_cache.async_set_cache(
                key=unified_resource_id,
                value=cache_data,
                litellm_parent_otel_span=litellm_parent_otel_span,
            )

        # Prepare database data
        db_data = {
            "unified_resource_id": unified_resource_id,
            "model_mappings": json.dumps(model_mappings),
            "flat_model_resource_ids": list(model_mappings.values()),
            "created_by": user_api_key_dict.user_id,
            "updated_by": user_api_key_dict.user_id,
        }

        # Add resource object if available
        if resource_object is not None:
            # Handle both dict and Pydantic models
            if hasattr(resource_object, "model_dump_json"):
                db_data["resource_object"] = resource_object.model_dump_json()  # type: ignore
            elif isinstance(resource_object, dict):
                db_data["resource_object"] = json.dumps(resource_object)

            # Extract storage metadata from hidden params if present
            hidden_params = getattr(resource_object, "_hidden_params", {}) or {}
            if "storage_backend" in hidden_params:
                db_data["storage_backend"] = hidden_params["storage_backend"]
            if "storage_url" in hidden_params:
                db_data["storage_url"] = hidden_params["storage_url"]

        # Add additional fields to database
        if additional_db_fields:
            db_data.update(additional_db_fields)

        # Store in database
        table = getattr(self.prisma_client.db, self.table_name)
        result = await table.create(data=db_data)

        verbose_logger.debug(
            f"LiteLLM Managed {self.resource_type} with id={unified_resource_id} stored in db: {result}"
        )

    async def get_unified_resource_id(
        self,
        unified_resource_id: str,
        litellm_parent_otel_span: Optional[Span] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Retrieve unified resource by ID from cache or database.

        Args:
            unified_resource_id: The unified resource ID to retrieve
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            Dictionary containing resource data or None if not found
        """
        # Check cache first
        result = cast(
            Optional[dict],
            await self.internal_usage_cache.async_get_cache(
                key=unified_resource_id,
                litellm_parent_otel_span=litellm_parent_otel_span,
            ),
        )

        if result:
            return result

        # Cache miss -> fall back to the database
        table = getattr(self.prisma_client.db, self.table_name)
        db_object = await table.find_first(
            where={"unified_resource_id": unified_resource_id}
        )

        if db_object:
            return db_object.model_dump()

        return None

    async def delete_unified_resource_id(
        self,
        unified_resource_id: str,
        litellm_parent_otel_span: Optional[Span] = None,
    ) -> Optional[ResourceObjectType]:
        """
        Delete unified resource from cache and database.

        Args:
            unified_resource_id: The unified resource ID to delete
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            The deleted resource's stored resource_object.

        Raises:
            Exception: If no resource with the given ID exists in the database.
        """
        # Get old value from database
        table = getattr(self.prisma_client.db, self.table_name)
        initial_value = await table.find_first(
            where={"unified_resource_id": unified_resource_id}
        )

        if initial_value is None:
            raise Exception(
                f"LiteLLM Managed {self.resource_type} with id={unified_resource_id} not found"
            )

        # Delete from cache (setting the key to None evicts the cached entry)
        await self.internal_usage_cache.async_set_cache(
            key=unified_resource_id,
            value=None,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )

        # Delete from database
        await table.delete(where={"unified_resource_id": unified_resource_id})

        return initial_value.resource_object

    async def can_user_access_unified_resource_id(
        self,
        unified_resource_id: str,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_parent_otel_span: Optional[Span] = None,
    ) -> bool:
        """
        Check if user has access to the unified resource ID.

        Uses get_unified_resource_id() which checks cache first before hitting
        the database, avoiding direct DB queries in the critical request path.

        Args:
            unified_resource_id: The unified resource ID to check
            user_api_key_dict: User API key authentication details
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            True if user has access (they created the resource), False otherwise
        """
        user_id = user_api_key_dict.user_id

        # Use cached method instead of direct DB query
        resource = await self.get_unified_resource_id(
            unified_resource_id, litellm_parent_otel_span
        )

        if resource:
            return resource.get("created_by") == user_id

        return False

    # ============================================================================
    # MODEL MAPPING OPERATIONS
    # ============================================================================

    async def get_model_resource_id_mapping(
        self,
        resource_ids: List[str],
        litellm_parent_otel_span: Span,
    ) -> Dict[str, Dict[str, str]]:
        """
        Get model-specific resource IDs for a list of unified resource IDs.

        Args:
            resource_ids: List of unified resource IDs
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            Dictionary mapping unified_resource_id -> model_id -> provider_resource_id

        Example:
            {
                "unified_resource_id_1": {
                    "model_id_1": "provider_resource_id_1",
                    "model_id_2": "provider_resource_id_2"
                }
            }
        """
        resource_id_mapping: Dict[str, Dict[str, str]] = {}

        for resource_id in resource_ids:
            # Get unified resource from cache/db
            unified_resource_object = await self.get_unified_resource_id(
                resource_id, litellm_parent_otel_span
            )

            if unified_resource_object:
                model_mappings = unified_resource_object.get("model_mappings", {})

                # DB rows store model_mappings as a JSON string, the cache
                # stores a dict — handle both.
                if isinstance(model_mappings, str):
                    model_mappings = json.loads(model_mappings)

                resource_id_mapping[resource_id] = model_mappings

        return resource_id_mapping

    # ============================================================================
    # RESOURCE CREATION OPERATIONS
    # ============================================================================

    async def create_resource_for_each_model(
        self,
        llm_router: Router,
        request_data: Dict[str, Any],
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
    ) -> List[ResourceObjectType]:
        """
        Create a resource for each model in the target list.

        Args:
            llm_router: LiteLLM router instance
            request_data: Request data for resource creation
            target_model_names_list: List of target model names
            litellm_parent_otel_span: OpenTelemetry span for tracing

        Returns:
            List of resource objects created for each model

        Raises:
            Exception: If the router has not been initialized.
        """
        if llm_router is None:
            raise Exception("LLM Router not initialized. Ensure models added to proxy.")

        responses = []
        for model in target_model_names_list:
            individual_response = await self.create_resource_for_model(
                llm_router=llm_router,
                model=model,
                request_data=request_data,
                litellm_parent_otel_span=litellm_parent_otel_span,
            )
            responses.append(individual_response)
        return responses

    def generate_unified_resource_id(
        self,
        resource_objects: List[ResourceObjectType],
        target_model_names_list: List[str],
    ) -> str:
        """
        Generate a unified resource ID from multiple resource objects.

        Args:
            resource_objects: List of resource objects from different models
            target_model_names_list: List of target model names

        Returns:
            Base64 encoded unified resource ID
        """
        # Use the first resource object to generate the format
        unified_id_format = self.get_unified_resource_id_format(
            resource_object=resource_objects[0],
            target_model_names_list=target_model_names_list,
        )

        # Convert to URL-safe base64 and strip padding
        base64_unified_id = (
            base64.urlsafe_b64encode(unified_id_format.encode()).decode().rstrip("=")
        )

        return base64_unified_id

    def extract_model_mappings_from_responses(
        self,
        resource_objects: List[ResourceObjectType],
    ) -> Dict[str, str]:
        """
        Extract model mappings from resource objects.

        Args:
            resource_objects: List of resource objects from different models

        Returns:
            Dictionary mapping model_id -> provider_resource_id
        """
        model_mappings: Dict[str, str] = {}

        for resource_object in resource_objects:
            # Get hidden params if available
            hidden_params = getattr(resource_object, "_hidden_params", {}) or {}
            model_resource_id_mapping = hidden_params.get("model_resource_id_mapping")

            if model_resource_id_mapping and isinstance(
                model_resource_id_mapping, dict
            ):
                model_mappings.update(model_resource_id_mapping)

        return model_mappings

    # ============================================================================
    # DEPLOYMENT FILTERING
    # ============================================================================

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        request_kwargs: Optional[Dict] = None,
        parent_otel_span: Optional[Span] = None,
        resource_id_key: str = "resource_id",
    ) -> List[Dict]:
        """
        Filter deployments based on model mappings for a resource.

        This is used by the router to select only deployments that have
        the resource available.

        Args:
            model: Model name
            healthy_deployments: List of healthy deployments
            request_kwargs: Request kwargs containing resource_id and mappings
            parent_otel_span: OpenTelemetry span for tracing
            resource_id_key: Key to use for resource ID in request_kwargs

        Returns:
            Filtered list of deployments (unfiltered if no mapping applies)
        """
        if request_kwargs is None:
            return healthy_deployments

        resource_id = cast(Optional[str], request_kwargs.get(resource_id_key))
        model_resource_id_mapping = cast(
            Optional[Dict[str, Dict[str, str]]],
            request_kwargs.get("model_resource_id_mapping"),
        )

        allowed_model_ids = []
        if resource_id and model_resource_id_mapping:
            model_id_dict = model_resource_id_mapping.get(resource_id, {})
            allowed_model_ids = list(model_id_dict.keys())

        # No mapping information -> do not restrict routing
        if len(allowed_model_ids) == 0:
            return healthy_deployments

        return [
            deployment
            for deployment in healthy_deployments
            if deployment.get("model_info", {}).get("id") in allowed_model_ids
        ]

    # ============================================================================
    # UTILITY METHODS
    # ============================================================================

    def get_unified_id_prefix(self) -> str:
        """
        Get the prefix for unified IDs for this resource type.

        Returns:
            Prefix string (e.g., "litellm_proxy:")
        """
        # NOTE(review): the managed-*file* prefix is returned for every
        # resource type — confirm this is intended for non-file resources.
        return SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value

    async def list_user_resources(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        limit: Optional[int] = None,
        after: Optional[str] = None,
        additional_filters: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        List resources created by a user.

        Args:
            user_api_key_dict: User API key authentication details
            limit: Maximum number of resources to return (default 20)
            after: Cursor for pagination (returns rows with id > after)
            additional_filters: Additional filters to apply

        Returns:
            Dictionary with list of resources and pagination info
        """
        where_clause: Dict[str, Any] = {}

        # Filter by user who created the resource
        if user_api_key_dict.user_id:
            where_clause["created_by"] = user_api_key_dict.user_id

        if after:
            where_clause["id"] = {"gt": after}

        # Add additional filters
        if additional_filters:
            where_clause.update(additional_filters)

        # Fetch resources
        fetch_limit = limit or 20
        table = getattr(self.prisma_client.db, self.table_name)
        resources = await table.find_many(
            where=where_clause,
            take=fetch_limit,
            order={"created_at": "desc"},
        )

        resource_objects: List[Any] = []
        for resource in resources:
            try:
                # Stop once we have enough
                if len(resource_objects) >= (limit or 20):
                    break

                # Parse resource object
                resource_data = resource.resource_object
                if isinstance(resource_data, str):
                    resource_data = json.loads(resource_data)

                # Set unified ID
                if hasattr(resource_data, "id"):
                    resource_data.id = resource.unified_resource_id
                elif isinstance(resource_data, dict):
                    resource_data["id"] = resource.unified_resource_id

                resource_objects.append(resource_data)

            except Exception as e:
                verbose_logger.warning(
                    f"Failed to parse {self.resource_type} object "
                    f"{resource.unified_resource_id}: {e}"
                )
                continue

        def _obj_id(obj: Any) -> Optional[str]:
            # BUG FIX: entries can be plain dicts (json.loads output, given an
            # "id" key above) or objects with an ``id`` attribute. The previous
            # unconditional ``.id`` attribute access raised AttributeError for
            # dict entries — support both shapes here.
            if isinstance(obj, dict):
                return obj.get("id")
            return getattr(obj, "id", None)

        return {
            "object": "list",
            "data": resource_objects,
            "first_id": _obj_id(resource_objects[0]) if resource_objects else None,
            "last_id": _obj_id(resource_objects[-1]) if resource_objects else None,
            "has_more": len(resource_objects) == (limit or 20),
        }
|
||||
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
Utility functions for managed resources.
|
||||
|
||||
This module provides common utility functions that can be used across
|
||||
different managed resource types (files, vector stores, etc.).
|
||||
"""
|
||||
|
||||
import base64
|
||||
import re
|
||||
from typing import List, Optional, Union, Literal
|
||||
|
||||
|
||||
def is_base64_encoded_unified_id(
    resource_id: str,
    prefix: str = "litellm_proxy:",
) -> Union[str, Literal[False]]:
    """
    Check whether *resource_id* is a base64 encoded unified ID.

    Args:
        resource_id: The resource ID to check
        prefix: Prefix a decoded unified ID must start with

    Returns:
        The decoded string if it is a valid unified ID, ``False`` otherwise.
    """
    # Non-string inputs can never be unified IDs
    if not isinstance(resource_id, str):
        return False

    # Restore the base64 padding that was stripped at encode time
    candidate = resource_id + "=" * (-len(resource_id) % 4)

    try:
        text = base64.urlsafe_b64decode(candidate).decode()
    except Exception:
        return False

    return text if text.startswith(prefix) else False
|
||||
|
||||
|
||||
def extract_target_model_names_from_unified_id(
    unified_id: str,
) -> List[str]:
    """
    Extract target model names from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        List of target model names (empty list if none found / invalid input)

    Example:
        unified_id = "litellm_proxy:vector_store;unified_id,uuid;target_model_names,gpt-4,gemini-2.0"
        returns: ["gpt-4", "gemini-2.0"]
    """
    try:
        # Non-string inputs have no model names
        if not isinstance(unified_id, str):
            return []

        # Work on the decoded form when the input is base64 encoded
        decoded = is_base64_encoded_unified_id(unified_id)
        haystack = decoded if decoded else unified_id

        # Grab the comma-separated model list following the marker
        found = re.search(r"target_model_names,([^;]+)", haystack)
        if not found:
            return []

        return [name.strip() for name in found.group(1).split(",")]
    except Exception:
        return []
|
||||
|
||||
|
||||
def extract_resource_type_from_unified_id(
    unified_id: str,
) -> Optional[str]:
    """
    Extract resource type from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        Resource type string or None

    Example:
        unified_id = "litellm_proxy:vector_store;unified_id,uuid;..."
        returns: "vector_store"
    """
    try:
        if not isinstance(unified_id, str):
            return None

        # Work on the decoded form when the input is base64 encoded
        decoded = is_base64_encoded_unified_id(unified_id)
        haystack = decoded if decoded else unified_id

        # The resource type sits between the prefix and the first semicolon
        found = re.search(r"litellm_proxy:([^;]+)", haystack)
        return found.group(1).strip() if found else None
    except Exception:
        return None
|
||||
|
||||
|
||||
def extract_unified_uuid_from_unified_id(
    unified_id: str,
) -> Optional[str]:
    """
    Extract the UUID from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        UUID string or None

    Example:
        unified_id = "litellm_proxy:vector_store;unified_id,abc-123;..."
        returns: "abc-123"
    """
    try:
        if not isinstance(unified_id, str):
            return None

        # Work on the decoded form when the input is base64 encoded
        decoded = is_base64_encoded_unified_id(unified_id)
        haystack = decoded if decoded else unified_id

        # The UUID follows the "unified_id," marker
        found = re.search(r"unified_id,([^;]+)", haystack)
        return found.group(1).strip() if found else None
    except Exception:
        return None
|
||||
|
||||
|
||||
def extract_model_id_from_unified_id(
    unified_id: str,
) -> Optional[str]:
    """
    Extract model ID from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        Model ID string or None

    Example:
        unified_id = "litellm_proxy:vector_store;...;model_id,gpt-4-model-id;..."
        returns: "gpt-4-model-id"
    """
    try:
        if not isinstance(unified_id, str):
            return None

        # Work on the decoded form when the input is base64 encoded
        decoded = is_base64_encoded_unified_id(unified_id)
        haystack = decoded if decoded else unified_id

        # The model ID follows the "model_id," marker
        found = re.search(r"model_id,([^;]+)", haystack)
        return found.group(1).strip() if found else None
    except Exception:
        return None
|
||||
|
||||
|
||||
def extract_provider_resource_id_from_unified_id(
    unified_id: str,
) -> Optional[str]:
    """
    Extract provider resource ID from a unified resource ID.

    Args:
        unified_id: The unified resource ID (decoded or encoded)

    Returns:
        Provider resource ID string or None

    Example:
        unified_id = "litellm_proxy:vector_store;...;resource_id,vs_abc123;..."
        returns: "vs_abc123"
    """
    try:
        if not isinstance(unified_id, str):
            return None

        # Work on the decoded form when the input is base64 encoded
        decoded = is_base64_encoded_unified_id(unified_id)
        haystack = decoded if decoded else unified_id

        # Different resource types use different markers for the provider ID;
        # first pattern that matches wins.
        for marker in (
            r"resource_id,([^;]+)",
            r"vector_store_id,([^;]+)",
            r"file_id,([^;]+)",
        ):
            found = re.search(marker, haystack)
            if found:
                return found.group(1).strip()

        return None
    except Exception:
        return None
|
||||
|
||||
|
||||
def generate_unified_id_string(
    resource_type: str,
    unified_uuid: str,
    target_model_names: List[str],
    provider_resource_id: str,
    model_id: str,
    additional_fields: Optional[dict] = None,
) -> str:
    """
    Generate a unified ID string (before base64 encoding).

    Args:
        resource_type: Type of resource (e.g., "vector_store", "file")
        unified_uuid: UUID for this unified resource
        target_model_names: List of target model names
        provider_resource_id: Resource ID from the provider
        model_id: Model ID from the router
        additional_fields: Additional fields to include in the ID

    Returns:
        Unified ID string (not yet base64 encoded)

    Example:
        generate_unified_id_string(
            resource_type="vector_store",
            unified_uuid="abc-123",
            target_model_names=["gpt-4", "gemini"],
            provider_resource_id="vs_xyz",
            model_id="model-id-123",
        )
        returns: "litellm_proxy:vector_store;unified_id,abc-123;target_model_names,gpt-4,gemini;resource_id,vs_xyz;model_id,model-id-123"
    """
    # Each segment is a "key,value" pair; segments are joined with ";".
    segments = [
        f"litellm_proxy:{resource_type}",
        f"unified_id,{unified_uuid}",
        "target_model_names," + ",".join(target_model_names),
        f"resource_id,{provider_resource_id}",
        f"model_id,{model_id}",
    ]

    # Append any caller-supplied extras in the same "key,value" form
    if additional_fields:
        segments.extend(f"{key},{value}" for key, value in additional_fields.items())

    return ";".join(segments)
|
||||
|
||||
|
||||
def encode_unified_id(unified_id_string: str) -> str:
    """
    Encode a unified ID string to base64.

    Args:
        unified_id_string: The unified ID string to encode

    Returns:
        Base64 encoded unified ID (URL-safe, padding stripped)
    """
    raw = unified_id_string.encode()
    encoded = base64.urlsafe_b64encode(raw).decode()
    # Padding is stripped so the ID is safe to embed in URLs/paths
    return encoded.rstrip("=")
|
||||
|
||||
|
||||
def decode_unified_id(encoded_unified_id: str) -> Optional[str]:
    """
    Decode a base64 encoded unified ID.

    Args:
        encoded_unified_id: The base64 encoded unified ID

    Returns:
        Decoded unified ID string or None if invalid
    """
    try:
        # Restore the padding stripped by encode_unified_id
        padding = "=" * (-len(encoded_unified_id) % 4)
        decoded = base64.urlsafe_b64decode(encoded_unified_id + padding).decode()
    except Exception:
        return None

    # Only accept strings carrying the expected unified-ID prefix
    return decoded if decoded.startswith("litellm_proxy:") else None
|
||||
|
||||
|
||||
def parse_unified_id(
    unified_id: str,
) -> Optional[dict]:
    """
    Parse a unified ID into its components.

    Args:
        unified_id: The unified ID (encoded or decoded)

    Returns:
        Dictionary with parsed components or None if invalid

    Example:
        {
            "resource_type": "vector_store",
            "unified_uuid": "abc-123",
            "target_model_names": ["gpt-4", "gemini"],
            "provider_resource_id": "vs_xyz",
            "model_id": "model-id-123"
        }
    """
    try:
        # Accept either an encoded ID (decode it) or an already-decoded one
        decoded = decode_unified_id(unified_id)
        if not decoded:
            if not unified_id.startswith("litellm_proxy:"):
                return None
            decoded = unified_id

        # Delegate each field to its dedicated extractor
        return {
            "resource_type": extract_resource_type_from_unified_id(decoded),
            "unified_uuid": extract_unified_uuid_from_unified_id(decoded),
            "target_model_names": extract_target_model_names_from_unified_id(decoded),
            "provider_resource_id": extract_provider_resource_id_from_unified_id(
                decoded
            ),
            "model_id": extract_model_id_from_unified_id(decoded),
        }
    except Exception:
        return None
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Base OCR transformation module."""
|
||||
from .transformation import (
|
||||
BaseOCRConfig,
|
||||
DocumentType,
|
||||
OCRPage,
|
||||
OCRPageDimensions,
|
||||
OCRPageImage,
|
||||
OCRRequestData,
|
||||
OCRResponse,
|
||||
OCRUsageInfo,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseOCRConfig",
|
||||
"DocumentType",
|
||||
"OCRResponse",
|
||||
"OCRPage",
|
||||
"OCRPageDimensions",
|
||||
"OCRPageImage",
|
||||
"OCRUsageInfo",
|
||||
"OCRRequestData",
|
||||
]
|
||||
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Base OCR transformation configuration.
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
# DocumentType for OCR - providers always receive a dict with
# type="document_url" or type="image_url" (str values only).
# File-type inputs are preprocessed to this format in litellm/ocr/main.py,
# so transform_ocr_request implementations never see raw file objects.
DocumentType = Dict[str, str]
|
||||
|
||||
|
||||
class OCRPageDimensions(LiteLLMPydanticObjectBase):
    """Page dimensions from OCR response."""

    dpi: Optional[int] = None  # dots-per-inch, when the provider reports it
    height: Optional[int] = None  # page height (units as reported by the provider)
    width: Optional[int] = None  # page width (units as reported by the provider)
|
||||
|
||||
|
||||
class OCRPageImage(LiteLLMPydanticObjectBase):
    """Image extracted from OCR page."""

    image_base64: Optional[str] = None  # base64-encoded image payload, if returned
    bbox: Optional[Dict[str, Any]] = None  # bounding box of the image on the page

    # Allow provider-specific extra fields to pass through unchanged.
    model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class OCRPage(LiteLLMPydanticObjectBase):
    """Single page from OCR response."""

    index: int  # zero-based page index within the document
    markdown: str  # extracted page content rendered as markdown
    images: Optional[List[OCRPageImage]] = None  # images found on the page
    dimensions: Optional[OCRPageDimensions] = None  # page size metadata

    # Allow provider-specific extra fields to pass through unchanged.
    model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class OCRUsageInfo(LiteLLMPydanticObjectBase):
    """Usage information from OCR response."""

    pages_processed: Optional[int] = None  # number of pages the provider billed/processed
    doc_size_bytes: Optional[int] = None  # size of the input document, if reported

    # Allow provider-specific extra fields to pass through unchanged.
    model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class OCRResponse(LiteLLMPydanticObjectBase):
    """
    Standard OCR response format.
    Standardized to Mistral OCR format - other providers should transform to this format.
    """

    pages: List[OCRPage]  # one entry per processed page
    model: str  # model that produced the OCR output
    document_annotation: Optional[Any] = None  # provider-specific whole-document annotation
    usage_info: Optional[OCRUsageInfo] = None  # billing/usage metadata
    object: str = "ocr"  # response type discriminator

    # Allow provider-specific extra fields to pass through unchanged.
    model_config = {"extra": "allow"}

    # Define private attributes using PrivateAttr so they are excluded from
    # serialization/validation but still available on instances.
    _hidden_params: dict = PrivateAttr(default_factory=dict)
|
||||
|
||||
|
||||
class OCRRequestData(LiteLLMPydanticObjectBase):
    """OCR request data structure.

    Carrier for the provider-specific HTTP payload produced by
    transform_ocr_request: a JSON/bytes body and/or multipart files.
    """

    data: Optional[Union[Dict, bytes]] = None  # request body (JSON dict or raw bytes)
    files: Optional[Dict[str, Any]] = None  # multipart file uploads, if any
|
||||
|
||||
|
||||
class BaseOCRConfig:
    """
    Base configuration for OCR transformations.
    Handles provider-agnostic OCR operations.

    Provider implementations override the transform/url/env methods; the
    async variants default to their sync counterparts.
    """

    def __init__(self) -> None:
        pass

    def get_supported_ocr_params(self, model: str) -> list:
        """
        Get supported OCR parameters for this provider.
        Override this method in provider-specific implementations.

        Default: no parameters supported.
        """
        return []

    def map_ocr_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
    ) -> dict:
        """Map OCR parameters to provider-specific parameters.

        Default: pass `optional_params` through unchanged.
        """
        return optional_params

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> Dict:
        """
        Validate environment and return headers.
        Override in provider-specific implementations.

        Default: return the caller-supplied headers unchanged (no auth added).
        """
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """
        Get complete URL for OCR endpoint.
        Override in provider-specific implementations.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("get_complete_url must be implemented by provider")

    def transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Transform OCR request to provider-specific format.
        Override in provider-specific implementations.

        Note: By the time this method is called, any file-type documents have already
        been converted to document_url/image_url format with base64 data URIs by
        the preprocessing in litellm/ocr/main.py.

        Args:
            model: Model name
            document: Document to process - always a dict with type="document_url" or type="image_url"
            optional_params: Optional parameters for the request
            headers: Request headers

        Returns:
            OCRRequestData with data and files fields

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "transform_ocr_request must be implemented by provider"
        )

    async def async_transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Async transform OCR request to provider-specific format.
        Optional method - providers can override if they need async transformations
        (e.g., Azure AI for URL-to-base64 conversion).

        Default implementation falls back to sync transform_ocr_request.

        Args:
            model: Model name
            document: Document to process (Mistral format dict, or file path, bytes, etc.)
            optional_params: Optional parameters for the request
            headers: Request headers

        Returns:
            OCRRequestData with data and files fields
        """
        # Default implementation: call sync version
        return self.transform_ocr_request(
            model=model,
            document=document,
            optional_params=optional_params,
            headers=headers,
            **kwargs,
        )

    def transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> OCRResponse:
        """
        Transform provider-specific OCR response to standard format.
        Override in provider-specific implementations.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "transform_ocr_response must be implemented by provider"
        )

    async def async_transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> OCRResponse:
        """
        Async transform provider-specific OCR response to standard format.
        Optional method - providers can override if they need async transformations
        (e.g., Azure Document Intelligence for async operation polling).

        Default implementation falls back to sync transform_ocr_response.

        Args:
            model: Model name
            raw_response: Raw HTTP response
            logging_obj: Logging object

        Returns:
            OCRResponse in standard format
        """
        # Default implementation: call sync version
        return self.transform_ocr_response(
            model=model,
            raw_response=raw_response,
            logging_obj=logging_obj,
            **kwargs,
        )

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: dict,
    ) -> Exception:
        """Get appropriate error class for the provider.

        Default: a generic BaseLLMException carrying the status/message/headers.
        """
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
@@ -0,0 +1,139 @@
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
from ..base_utils import BaseLLMModelInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx import URL, Headers, Response
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.utils import CostResponseTypes
|
||||
|
||||
from ..chat.transformation import BaseLLMException
|
||||
|
||||
|
||||
class BasePassthroughConfig(BaseLLMModelInfo):
    """Provider interface for passthrough endpoints: URL construction,
    optional request signing, and logging of streamed / non-streamed responses."""

    @abstractmethod
    def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
        """
        Check if the request is a streaming request
        """
        pass

    def format_url(
        self,
        endpoint: str,
        base_target_url: str,
        request_query_params: Optional[dict],
    ) -> "URL":
        """
        Helper function to add query params to the url
        Args:
            endpoint: str - the endpoint to add to the url
            base_target_url: str - the base url to add the endpoint to
            request_query_params: Optional[dict] - the query params to add to the url
        Returns:
            httpx.URL - the formatted url
        """
        from urllib.parse import urlencode

        import httpx

        # Normalize slashes so joining never produces "//" or drops a "/".
        base = base_target_url.rstrip("/")
        endpoint = endpoint.lstrip("/")
        full_url = f"{base}/{endpoint}"

        url = httpx.URL(full_url)

        if request_query_params:
            url = url.copy_with(query=urlencode(request_query_params).encode("ascii"))

        return url

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        endpoint: str,
        request_query_params: Optional[dict],
        litellm_params: dict,
    ) -> Tuple["URL", str]:
        """
        Get the complete url for the request
        Returns:
            - complete_url: URL - the complete url for the request
            - base_target_url: str - the base url to add the endpoint to. Useful for auth headers.
        """
        pass

    def sign_request(
        self,
        headers: dict,
        litellm_params: dict,
        request_data: Optional[dict],
        api_base: str,
        model: Optional[str] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """
        Some providers like Bedrock require signing the request. The sign request function needs access to `request_data` and `complete_url`
        Args:
            headers: dict
            litellm_params: dict
            request_data: dict - the request body being sent in http request
            api_base: str - the complete url being sent in http request
        Returns:
            dict - the signed headers

        Update the headers with the signed headers in this function. The return values will be sent as headers in the http request.

        Default: no signing — headers returned unchanged, no signed body.
        """
        return headers, None

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, "Headers"]
    ) -> "BaseLLMException":
        # Generic fallback; providers may override for richer error mapping.
        from litellm.llms.base_llm.chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code, message=error_message, headers=headers
        )

    def logging_non_streaming_response(
        self,
        model: str,
        custom_llm_provider: str,
        httpx_response: "Response",
        request_data: dict,
        logging_obj: "LiteLLMLoggingObj",
        endpoint: str,
    ) -> Optional["CostResponseTypes"]:
        # Default: no logging transformation; returns None implicitly.
        pass

    def handle_logging_collected_chunks(
        self,
        all_chunks: List[str],
        litellm_logging_obj: "LiteLLMLoggingObj",
        model: str,
        custom_llm_provider: str,
        endpoint: str,
    ) -> Optional["CostResponseTypes"]:
        # Default: streamed chunks are not turned into a loggable response.
        return None

    def _convert_raw_bytes_to_str_lines(self, raw_bytes: List[bytes]) -> List[str]:
        """
        Converts a list of raw bytes into a list of string lines, similar to aiter_lines()

        Args:
            raw_bytes: List of bytes chunks from aiter.bytes()

        Returns:
            List of string lines, with each line being a complete data: {} chunk
        """
        # Combine all bytes and decode to string
        combined_str = b"".join(raw_bytes).decode("utf-8")

        # Split by newlines and filter out empty lines
        lines = [line.strip() for line in combined_str.split("\n") if line.strip()]

        return lines
|
||||
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Base transformation class for realtime HTTP endpoints (client_secrets, realtime_calls).
|
||||
|
||||
These are HTTP (not WebSocket) endpoints used by the WebRTC flow:
|
||||
POST /v1/realtime/client_secrets — obtains a short-lived ephemeral key
|
||||
POST /v1/realtime/calls — exchanges an SDP offer using that key
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class BaseRealtimeHTTPConfig(ABC):
    """
    Abstract base for provider-specific realtime HTTP credential / URL logic.

    Implement one subclass per provider (OpenAI, Azure, ...).
    """

    # ------------------------------------------------------------------ #
    # Credential resolution                                              #
    # ------------------------------------------------------------------ #

    @abstractmethod
    def get_api_base(
        self,
        api_base: Optional[str],
        **kwargs,
    ) -> str:
        """
        Resolve the provider API base URL.

        Resolution order (provider-specific):
            explicit api_base -> litellm.api_base -> env var -> hard-coded default
        """

    @abstractmethod
    def get_api_key(
        self,
        api_key: Optional[str],
        **kwargs,
    ) -> str:
        """
        Resolve the provider API key.

        Resolution order (provider-specific):
            explicit api_key -> litellm.api_key -> env var -> ""
        """

    # ------------------------------------------------------------------ #
    # client_secrets endpoint                                            #
    # ------------------------------------------------------------------ #

    @abstractmethod
    def get_complete_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """Return the full URL for POST /realtime/client_secrets."""

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Build and return the request headers for the client_secrets call.

        Merge `headers` (caller-supplied extras) with auth / content-type
        headers required by this provider.
        """

    # ------------------------------------------------------------------ #
    # realtime_calls endpoint                                            #
    # ------------------------------------------------------------------ #

    def get_realtime_calls_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """Return the full URL for POST /realtime/calls (SDP exchange).

        Default: OpenAI-style path appended to the (slash-stripped) api_base.
        """
        base = (api_base or "").rstrip("/")
        return f"{base}/v1/realtime/calls"

    def get_realtime_calls_headers(self, ephemeral_key: str) -> dict:
        """
        Build headers for the realtime_calls POST.

        The Bearer token here is the ephemeral key obtained from
        client_secrets, not the long-lived provider key.
        """
        return {
            "Authorization": f"Bearer {ephemeral_key}",
        }

    # ------------------------------------------------------------------ #
    # Error handling                                                     #
    # ------------------------------------------------------------------ #

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ):
        """
        Map HTTP errors to LiteLLM exception types.

        Default: generic exception. Override in subclasses for provider-specific
        error mapping (e.g., Azure uses different error codes).
        """
        from litellm.llms.base_llm.chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
@@ -0,0 +1,83 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.realtime import (
|
||||
RealtimeResponseTransformInput,
|
||||
RealtimeResponseTypedDict,
|
||||
)
|
||||
|
||||
from ..chat.transformation import BaseLLMException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseRealtimeConfig(ABC):
    """Provider interface for the realtime (WebSocket) API: environment/url
    resolution plus request/response message transformation."""

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        pass

    @abstractmethod
    def get_complete_url(
        self, api_base: Optional[str], model: str, api_key: Optional[str] = None
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        # NOTE(review): this raises rather than returns despite the
        # `-> BaseLLMException` annotation — preserved as-is; confirm callers
        # expect the exception to propagate from here.
        raise BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    @abstractmethod
    def transform_realtime_request(
        self,
        message: str,
        model: str,
        session_configuration_request: Optional[str] = None,
    ) -> List[str]:
        pass

    def requires_session_configuration(
        self,
    ) -> bool:  # initial configuration message sent to setup the realtime session
        return False

    def session_configuration_request(
        self, model: str
    ) -> Optional[str]:  # message sent to setup the realtime session
        return None

    @abstractmethod
    def transform_realtime_response(
        self,
        message: Union[str, bytes],
        model: str,
        logging_obj: LiteLLMLoggingObj,
        realtime_response_transform_input: RealtimeResponseTransformInput,
    ) -> RealtimeResponseTypedDict:  # message sent to setup the realtime session
        """
        Keep this stateless - leave the state management (e.g. tracking current_output_item_id, current_response_id, current_conversation_id, current_delta_chunks) to the caller.
        """
        pass
|
||||
@@ -0,0 +1,134 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.rerank import RerankBilledUnits, RerankResponse
|
||||
from litellm.types.utils import ModelInfo
|
||||
|
||||
from ..chat.transformation import BaseLLMException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseRerankConfig(ABC):
    """Provider interface for rerank endpoints: environment validation,
    request/response transformation, param mapping, and cost calculation."""

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> dict:
        pass

    @abstractmethod
    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Dict,
        headers: dict,
    ) -> dict:
        return {}

    @abstractmethod
    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        # NOTE(review): mutable default args ({}) in this abstract signature —
        # harmless while the defaults are never mutated, but subclasses should
        # prefer Optional[dict] = None.
        return model_response

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[dict] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    @abstractmethod
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        pass

    @abstractmethod
    def map_cohere_rerank_params(
        self,
        non_default_params: dict,
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> Dict:
        pass

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        # NOTE(review): raises rather than returns despite the annotation —
        # preserved as-is; confirm callers expect the exception to propagate.
        raise BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def calculate_rerank_cost(
        self,
        model: str,
        custom_llm_provider: Optional[str] = None,
        billed_units: Optional[RerankBilledUnits] = None,
        model_info: Optional[ModelInfo] = None,
    ) -> Tuple[float, float]:
        """
        Calculates the cost per query for a given rerank model.

        Input:
            - model: str, the model name without provider prefix
            - custom_llm_provider: str, the provider used for the model. If provided, used to check if the litellm model info is for that provider.
            - billed_units: RerankBilledUnits, billed units (search_units) reported by the provider
            - model_info: ModelInfo, the model info for the given model

        Returns:
            Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
        """

        # Without a per-query price and billed units there is nothing to bill.
        if (
            model_info is None
            or "input_cost_per_query" not in model_info
            or model_info["input_cost_per_query"] is None
            or billed_units is None
        ):
            return 0.0, 0.0

        search_units = billed_units.get("search_units")

        if search_units is None:
            return 0.0, 0.0

        # Rerank billing is query-side only; completion cost is always 0.
        prompt_cost = model_info["input_cost_per_query"] * search_units

        return prompt_cost, 0.0
|
||||
@@ -0,0 +1,283 @@
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import (
|
||||
ResponseInputParam,
|
||||
ResponsesAPIOptionalRequestParams,
|
||||
ResponsesAPIResponse,
|
||||
ResponsesAPIStreamingResponse,
|
||||
)
|
||||
from litellm.types.responses.main import *
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
class BaseResponsesAPIConfig(ABC):
    """Provider interface for the Responses API: param mapping, request/response
    transformation, plus delete/get/list/cancel/compact operations.

    NOTE(review): `List` and `Literal` used below are not in this module's
    visible `typing` import — presumably provided by the star import from
    litellm.types.responses.main; verify.
    """

    def __init__(self):
        pass

    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        pass

    @classmethod
    def get_config(cls):
        # Collect class-level config attributes, skipping dunders, ABC
        # bookkeeping, callables, and unset (None) values.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        pass

    @abstractmethod
    def map_openai_params(
        self,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        pass

    @abstractmethod
    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_responses_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        pass

    @abstractmethod
    def transform_response_api_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        pass

    @abstractmethod
    def transform_streaming_response(
        self,
        model: str,
        parsed_chunk: dict,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIStreamingResponse:
        """
        Transform a parsed streaming response chunk into a ResponsesAPIStreamingResponse
        """
        pass

    #########################################################
    ########## DELETE RESPONSE API TRANSFORMATION ##############
    #########################################################
    @abstractmethod
    def transform_delete_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        pass

    @abstractmethod
    def transform_delete_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteResponseResult:
        pass

    #########################################################
    ########## END DELETE RESPONSE API TRANSFORMATION #######
    #########################################################

    #########################################################
    ########## GET RESPONSE API TRANSFORMATION ###############
    #########################################################
    @abstractmethod
    def transform_get_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        pass

    @abstractmethod
    def transform_get_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        pass

    #########################################################
    ########## LIST INPUT ITEMS API TRANSFORMATION ##########
    #########################################################
    @abstractmethod
    def transform_list_input_items_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        before: Optional[str] = None,
        include: Optional[List[str]] = None,
        limit: int = 20,
        order: Literal["asc", "desc"] = "desc",
    ) -> Tuple[str, Dict]:
        pass

    @abstractmethod
    def transform_list_input_items_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Dict:
        pass

    #########################################################
    ########## END GET RESPONSE API TRANSFORMATION ##########
    #########################################################

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        from ..chat.transformation import BaseLLMException

        # NOTE(review): raises rather than returns despite the annotation —
        # preserved as-is; confirm callers expect the exception to propagate.
        raise BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Returns True if litellm should fake a stream for the given model and stream value"""
        return False

    def supports_native_websocket(self) -> bool:
        """
        Returns True if the provider has a native WebSocket endpoint for Responses API.

        Providers with native websocket support can connect directly to wss:// endpoints.
        Providers without native support will use the ManagedResponsesWebSocketHandler
        which makes HTTP streaming calls and forwards events over the websocket.

        Default: False (use managed websocket handler)
        """
        return False

    #########################################################
    ########## CANCEL RESPONSE API TRANSFORMATION ##########
    #########################################################
    @abstractmethod
    def transform_cancel_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        pass

    @abstractmethod
    def transform_cancel_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        pass

    #########################################################
    ########## END CANCEL RESPONSE API TRANSFORMATION #######
    #########################################################

    #########################################################
    ########## COMPACT RESPONSE API TRANSFORMATION ##########
    #########################################################
    @abstractmethod
    def transform_compact_response_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        pass

    @abstractmethod
    def transform_compact_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        pass

    #########################################################
    ########## END COMPACT RESPONSE API TRANSFORMATION ######
    #########################################################
|
||||
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Base Search API module.
|
||||
"""
|
||||
from litellm.llms.base_llm.search.transformation import (
|
||||
BaseSearchConfig,
|
||||
SearchResponse,
|
||||
SearchResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseSearchConfig",
|
||||
"SearchResponse",
|
||||
"SearchResult",
|
||||
]
|
||||
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Base Search transformation configuration.
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class SearchResult(LiteLLMPydanticObjectBase):
    """Single search result, standardized to the Perplexity result shape."""

    title: str  # result title
    url: str  # link to the result
    snippet: str  # short text excerpt from the page
    date: Optional[str] = None  # publication date, when the provider reports one
    last_updated: Optional[str] = None  # last-updated date, when reported

    # Providers may attach extra fields; accept them instead of failing validation.
    model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class SearchResponse(LiteLLMPydanticObjectBase):
    """
    Standard Search response format.

    Standardized to Perplexity Search format - other providers should transform to this format.
    """

    results: List[SearchResult]  # ordered list of search hits
    object: str = "search"  # response-type discriminator

    # Allow provider-specific extra fields to pass through untouched.
    model_config = {"extra": "allow"}

    # Define private attributes using PrivateAttr
    # _hidden_params carries internal metadata without appearing in the
    # serialized response body.
    _hidden_params: dict = PrivateAttr(default_factory=dict)
|
||||
|
||||
|
||||
class BaseSearchConfig:
    """
    Provider-agnostic base class for Search API transformations.

    Concrete providers subclass this and override the request/response hooks.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def ui_friendly_name() -> str:
        """Human-readable provider name; override per provider."""
        return "Unknown Search Provider"

    def get_http_method(self) -> Literal["GET", "POST"]:
        """
        HTTP verb used for search requests.

        Defaults to 'POST'; providers that use GET override this.
        """
        return "POST"

    @staticmethod
    def get_supported_perplexity_optional_params() -> set:
        """
        Parameter names belonging to the unified (Perplexity-style) search spec.

        Providers translate from these standard names into their own params.
        """
        return {
            "max_results",
            "search_domain_filter",
            "country",
            "max_tokens_per_page",
        }

    def validate_environment(
        self,
        headers: Dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ) -> Dict:
        """Validate credentials/environment and return headers; override per provider."""
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        optional_params: dict,
        data: Optional[Union[Dict, List[Dict]]] = None,
        **kwargs,
    ) -> str:
        """
        Build the full search-endpoint URL for this provider.

        Args:
            api_base: Base URL for the API.
            optional_params: Optional parameters for the request.
            data: Transformed request body from transform_search_request().
                GET-style providers (e.g., Google PSE) read it to build query
                strings in the URL. May be a dict or a list of dicts.
            **kwargs: Additional keyword arguments.

        Returns:
            Complete URL for the search endpoint.

        Note:
            Must be overridden by each provider.
        """
        raise NotImplementedError("get_complete_url must be implemented by provider")

    def transform_search_request(
        self,
        query: Union[str, List[str]],
        optional_params: dict,
        **kwargs,
    ) -> Union[Dict, List[Dict]]:
        """
        Convert a unified search request into the provider's wire format.

        Args:
            query: Search query (string or list of strings).
            optional_params: Optional parameters for the request.

        Returns:
            Provider-specific request data.
        """
        raise NotImplementedError(
            "transform_search_request must be implemented by provider"
        )

    def transform_search_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> SearchResponse:
        """
        Convert a provider-specific search reply into the standard SearchResponse.

        Must be overridden by each provider.
        """
        raise NotImplementedError(
            "transform_search_response must be implemented by provider"
        )

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: dict,
    ) -> Exception:
        """Build and return the provider-appropriate exception for an error reply."""
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Base Skills API configuration"""
|
||||
|
||||
from .transformation import BaseSkillsAPIConfig
|
||||
|
||||
__all__ = ["BaseSkillsAPIConfig"]
|
||||
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Base configuration class for Skills API
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.anthropic_skills import (
|
||||
CreateSkillRequest,
|
||||
DeleteSkillResponse,
|
||||
ListSkillsParams,
|
||||
ListSkillsResponse,
|
||||
Skill,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseSkillsAPIConfig(ABC):
    """Base configuration for Skills API providers"""

    def __init__(self):
        pass

    @property
    @abstractmethod
    def custom_llm_provider(self) -> LlmProviders:
        """Identify which LLM provider this config implements."""
        pass

    @abstractmethod
    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate and update headers with provider-specific requirements

        Args:
            headers: Base headers dictionary
            litellm_params: LiteLLM parameters

        Returns:
            Updated headers dictionary
        """
        return headers

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        endpoint: str,
        skill_id: Optional[str] = None,
    ) -> str:
        """
        Get the complete URL for the API request

        Args:
            api_base: Base API URL
            endpoint: API endpoint (e.g., 'skills', 'skills/{id}')
            skill_id: Optional skill ID for specific skill operations

        Returns:
            Complete URL

        Raises:
            ValueError: if api_base is not provided.
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return f"{api_base}/v1/{endpoint}"

    @abstractmethod
    def transform_create_skill_request(
        self,
        create_request: CreateSkillRequest,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Transform create skill request to provider-specific format

        Args:
            create_request: Skill creation parameters
            litellm_params: LiteLLM parameters
            headers: Request headers

        Returns:
            Provider-specific request body
        """
        pass

    @abstractmethod
    def transform_create_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Skill:
        """
        Transform provider response to Skill object

        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object

        Returns:
            Skill object
        """
        pass

    @abstractmethod
    def transform_list_skills_request(
        self,
        list_params: ListSkillsParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform list skills request parameters

        Args:
            list_params: List parameters (pagination, filters)
            litellm_params: LiteLLM parameters
            headers: Request headers

        Returns:
            Tuple of (url, query_params)
        """
        pass

    @abstractmethod
    def transform_list_skills_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ListSkillsResponse:
        """
        Transform provider response to ListSkillsResponse

        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object

        Returns:
            ListSkillsResponse object
        """
        pass

    @abstractmethod
    def transform_get_skill_request(
        self,
        skill_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform get skill request

        Args:
            skill_id: Skill ID
            api_base: Base API URL
            litellm_params: LiteLLM parameters
            headers: Request headers

        Returns:
            Tuple of (url, headers)
        """
        pass

    @abstractmethod
    def transform_get_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Skill:
        """
        Transform provider response to Skill object

        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object

        Returns:
            Skill object
        """
        pass

    @abstractmethod
    def transform_delete_skill_request(
        self,
        skill_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform delete skill request

        Args:
            skill_id: Skill ID
            api_base: Base API URL
            litellm_params: LiteLLM parameters
            headers: Request headers

        Returns:
            Tuple of (url, headers)
        """
        pass

    @abstractmethod
    def transform_delete_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteSkillResponse:
        """
        Transform provider response to DeleteSkillResponse

        Args:
            raw_response: Raw HTTP response
            logging_obj: Logging object

        Returns:
            DeleteSkillResponse object
        """
        pass

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: dict,
    ) -> Exception:
        """Get appropriate error class for the provider."""
        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
@@ -0,0 +1,149 @@
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import (
|
||||
HttpxBinaryResponseContent as _HttpxBinaryResponseContent,
|
||||
)
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
HttpxBinaryResponseContent = _HttpxBinaryResponseContent
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
HttpxBinaryResponseContent = Any
|
||||
|
||||
|
||||
class TextToSpeechRequestData(TypedDict, total=False):
    """
    Structured return type for text-to-speech transformations.

    This ensures a consistent interface across all TTS providers.
    Providers should set exactly ONE body field: dict_body or ssml_body.
    NOTE(review): the original docstring also mentioned ``text_body``, but no
    such key is declared on this TypedDict - confirm before relying on it.
    """

    dict_body: Dict[str, Any]  # JSON request body (e.g., OpenAI TTS)
    ssml_body: str  # SSML/XML string body (e.g., Azure AVA TTS)
    headers: Dict[str, str]  # Provider-specific headers to merge with base headers
|
||||
|
||||
|
||||
class BaseTextToSpeechConfig(ABC):
    """
    Abstract base class for provider text-to-speech (TTS) configurations.

    Subclasses map OpenAI-style TTS parameters onto a provider's wire format
    and parse the provider's audio response back into a standard object.
    """

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        """Return class-level config values (non-callable, non-dunder, non-None attributes)."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        """
        Get list of OpenAI TTS parameters supported by this provider
        """
        pass

    @abstractmethod
    def map_openai_params(
        self,
        model: str,
        optional_params: Dict,
        voice: Optional[Union[str, Dict]] = None,
        drop_params: bool = False,
        kwargs: Dict = {},
    ) -> Tuple[Optional[str], Dict]:
        """
        Map OpenAI TTS parameters to provider-specific parameters.

        NOTE(review): the mutable default ``kwargs: Dict = {}`` is kept for
        interface compatibility; implementations must not mutate it.

        Returns:
            Tuple of (mapped_voice, mapped_params).
        """
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate environment and return headers
        """
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete url for the request.

        Raises:
            ValueError: if ``api_base`` is not provided.
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_text_to_speech_request(
        self,
        model: str,
        input: str,
        voice: Optional[str],
        optional_params: Dict,
        litellm_params: Dict,
        headers: dict,
    ) -> TextToSpeechRequestData:
        """
        Transform request to provider-specific format.

        Returns:
            TextToSpeechRequestData: A structured dict containing:
            - dict_body or ssml_body: the request body (JSON dict or SSML/XML string)
            - headers: provider-specific headers to merge with base headers
        """
        pass

    @abstractmethod
    def transform_text_to_speech_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> "HttpxBinaryResponseContent":
        """
        Transform provider response to standard format
        """
        pass

    def get_error_class(
        self, error_message: str, status_code: int, headers: Dict
    ) -> BaseLLMException:
        """
        Build and RETURN the provider-appropriate exception for an error response.

        Fix: this previously *raised* the exception instead of returning it,
        contradicting the method name, the return annotation, and the sibling
        configs (e.g. BaseSearchConfig.get_error_class) which return the
        exception for the caller to raise.
        """
        from ..chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
||||
@@ -0,0 +1,163 @@
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
VECTOR_STORE_OPENAI_PARAMS,
|
||||
BaseVectorStoreAuthCredentials,
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreIndexEndpoints,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
class BaseVectorStoreConfig:
    """
    Base configuration contract for provider-specific vector store implementations.

    NOTE(review): methods below are marked @abstractmethod but the class does
    not inherit abc.ABC, so abstractness is not enforced at instantiation -
    confirm whether that is intentional before changing the base class.
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[VECTOR_STORE_OPENAI_PARAMS]:
        """OpenAI params supported for this model; default is none."""
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        drop_params: bool,
    ) -> dict:
        """Map non-default OpenAI params to provider params; default passthrough."""
        return optional_params

    @abstractmethod
    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """Extract provider auth credentials from litellm params."""
        pass

    @abstractmethod
    def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
        """Return the provider's vector store endpoints grouped by operation type."""
        pass

    @abstractmethod
    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict]:
        """Build (url, request_body) for a vector store search call."""
        pass

    async def atransform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict]:
        """
        Optional async version of transform_search_vector_store_request.
        If not implemented, the handler will fall back to the sync version.
        Providers that need to make async calls (e.g., generating embeddings) should override this.
        """
        # Default implementation: call the sync version
        return self.transform_search_vector_store_request(
            vector_store_id=vector_store_id,
            query=query,
            vector_store_search_optional_params=vector_store_search_optional_params,
            api_base=api_base,
            litellm_logging_obj=litellm_logging_obj,
            litellm_params=litellm_params,
        )

    @abstractmethod
    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """Parse the provider search response into a VectorStoreSearchResponse."""
        pass

    @abstractmethod
    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict]:
        """Build (url, request_body) for a vector store create call."""
        pass

    @abstractmethod
    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        """Parse the provider create response into a VectorStoreCreateResponse."""
        pass

    @abstractmethod
    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Validate and return request headers for the provider."""
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`

        Raises:
            ValueError: if ``api_base`` is not provided.
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """
        Build and RETURN the provider-appropriate exception for an error response.

        Fix: this previously *raised* the exception instead of returning it,
        contradicting the return annotation and the sibling configs
        (e.g. BaseSearchConfig) which return the exception for callers to raise.
        """
        from ..chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def sign_request(
        self,
        headers: dict,
        optional_params: Dict,
        request_data: Dict,
        api_base: str,
        api_key: Optional[str] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """Optionally sign or modify the request before sending.

        Providers like AWS Bedrock require SigV4 signing. Providers that don't
        require any signing can simply return the headers unchanged and ``None``
        for the signed body.
        """
        return headers, None

    def calculate_vector_store_cost(
        self,
        response: VectorStoreSearchResponse,
    ) -> Tuple[float, float]:
        """Return (prompt_cost, completion_cost) for a search call; default free."""
        return 0.0, 0.0
|
||||
@@ -0,0 +1,226 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_store_files import (
|
||||
VectorStoreFileAuthCredentials,
|
||||
VectorStoreFileChunkingStrategy,
|
||||
VectorStoreFileContentResponse,
|
||||
VectorStoreFileCreateRequest,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileListQueryParams,
|
||||
VectorStoreFileListResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileUpdateRequest,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
class BaseVectorStoreFilesConfig(ABC):
    """Base configuration contract for provider-specific vector store file implementations."""

    def get_supported_openai_params(
        self,
        operation: str,
    ) -> Tuple[str, ...]:
        """Return the set of OpenAI params supported for the given operation."""

        return tuple()

    def map_openai_params(
        self,
        *,
        operation: str,
        non_default_params: Dict[str, Any],
        optional_params: Dict[str, Any],
        drop_params: bool,
    ) -> Dict[str, Any]:
        """Map non-default OpenAI params to provider-specific params."""

        return optional_params

    @abstractmethod
    def get_auth_credentials(
        self, litellm_params: Dict[str, Any]
    ) -> VectorStoreFileAuthCredentials:
        """Extract provider auth credentials from litellm params."""
        ...

    @abstractmethod
    def get_vector_store_file_endpoints_by_type(
        self,
    ) -> Dict[str, Tuple[Tuple[str, str], ...]]:
        """Return the provider's file endpoints grouped by operation type."""
        ...

    @abstractmethod
    def validate_environment(
        self,
        *,
        headers: Dict[str, str],
        litellm_params: Optional[GenericLiteLLMParams],
    ) -> Dict[str, str]:
        """Validate and return request headers for the provider."""
        return {}

    @abstractmethod
    def get_complete_url(
        self,
        *,
        api_base: Optional[str],
        vector_store_id: str,
        litellm_params: Dict[str, Any],
    ) -> str:
        """
        Build the base URL for file operations on the given vector store.

        Raises:
            ValueError: if ``api_base`` is not provided.
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base

    @abstractmethod
    def transform_create_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        create_request: VectorStoreFileCreateRequest,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, request_body) for a file-create call."""
        ...

    @abstractmethod
    def transform_create_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse the provider file-create response into a VectorStoreFileObject."""
        ...

    @abstractmethod
    def transform_list_vector_store_files_request(
        self,
        *,
        vector_store_id: str,
        query_params: VectorStoreFileListQueryParams,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, query_params) for a file-list call."""
        ...

    @abstractmethod
    def transform_list_vector_store_files_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileListResponse:
        """Parse the provider file-list response into a VectorStoreFileListResponse."""
        ...

    @abstractmethod
    def transform_retrieve_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, params) for a file-retrieve call."""
        ...

    @abstractmethod
    def transform_retrieve_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse the provider file-retrieve response into a VectorStoreFileObject."""
        ...

    @abstractmethod
    def transform_retrieve_vector_store_file_content_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, params) for a file-content retrieve call."""
        ...

    @abstractmethod
    def transform_retrieve_vector_store_file_content_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileContentResponse:
        """Parse the provider file-content response into a VectorStoreFileContentResponse."""
        ...

    @abstractmethod
    def transform_update_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        update_request: VectorStoreFileUpdateRequest,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, request_body) for a file-update call."""
        ...

    @abstractmethod
    def transform_update_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse the provider file-update response into a VectorStoreFileObject."""
        ...

    @abstractmethod
    def transform_delete_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, params) for a file-delete call."""
        ...

    @abstractmethod
    def transform_delete_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileDeleteResponse:
        """Parse the provider file-delete response into a VectorStoreFileDeleteResponse."""
        ...

    def get_error_class(
        self,
        *,
        error_message: str,
        status_code: int,
        headers: Union[Dict[str, Any], httpx.Headers],
    ) -> BaseLLMException:
        """
        Build and RETURN the provider-appropriate exception for an error response.

        Fix: this previously *raised* the exception instead of returning it,
        contradicting the return annotation and the sibling configs
        (e.g. BaseSearchConfig) which return the exception for callers to raise.
        """
        from ..chat.transformation import BaseLLMException

        return BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def sign_request(
        self,
        *,
        headers: Dict[str, str],
        optional_params: Dict[str, Any],
        request_data: Dict[str, Any],
        api_base: str,
        api_key: Optional[str] = None,
    ) -> Tuple[Dict[str, str], Optional[bytes]]:
        """Optionally sign the request (e.g., AWS SigV4); default is no signing."""
        return headers, None

    def prepare_chunking_strategy(
        self,
        chunking_strategy: Optional[VectorStoreFileChunkingStrategy],
    ) -> Optional[VectorStoreFileChunkingStrategy]:
        """Adjust the chunking strategy for the provider; default passthrough."""
        return chunking_strategy
|
||||
@@ -0,0 +1,277 @@
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
from litellm.types.responses.main import *
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.videos.main import VideoCreateOptionalRequestParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.types.videos.main import VideoObject as _VideoObject
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
VideoObject = _VideoObject
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
VideoObject = Any
|
||||
|
||||
|
||||
class BaseVideoConfig(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not k.startswith("_abc")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def map_openai_params(
|
||||
self,
|
||||
video_create_optional_params: VideoCreateOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
litellm_params: Optional[GenericLiteLLMParams] = None,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
def get_complete_url(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required")
|
||||
return api_base
|
||||
|
||||
@abstractmethod
|
||||
def transform_video_create_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
api_base: str,
|
||||
video_create_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[Dict, RequestFiles, str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_video_create_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
request_data: Optional[Dict] = None,
|
||||
) -> VideoObject:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_video_content_request(
|
||||
self,
|
||||
video_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
variant: Optional[str] = None,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Transform the video content request into a URL and data/params
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict]: (url, params) for the video content request
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_video_content_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> bytes:
|
||||
pass
|
||||
|
||||
async def async_transform_video_content_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> bytes:
    """
    Async variant of ``transform_video_content_response``.

    Optional method - providers that need genuinely asynchronous work
    while resolving the download (e.g. RunwayML fetching the video from
    a CloudFront URL) can override this.

    Default implementation falls back to sync transform_video_content_response.

    Args:
        raw_response: Raw HTTP response
        logging_obj: Logging object

    Returns:
        Video content as bytes
    """
    # No async work is required by default; delegate to the sync transform.
    return self.transform_video_content_response(
        raw_response=raw_response, logging_obj=logging_obj
    )
|
||||
|
||||
@abstractmethod
def transform_video_remix_request(
    self,
    video_id: str,
    prompt: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
    extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
    """
    Build the URL and request body for remixing an existing video.

    Returns:
        Tuple[str, Dict]: (url, data) for the video remix request
    """
    ...
|
||||
|
||||
@abstractmethod
def transform_video_remix_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Optional[str] = None,
) -> VideoObject:
    """
    Parse the provider's video-remix HTTP response into a VideoObject.
    """
    ...
|
||||
|
||||
@abstractmethod
def transform_video_list_request(
    self,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
    after: Optional[str] = None,
    limit: Optional[int] = None,
    order: Optional[str] = None,
    extra_query: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
    """
    Build the URL and query params for listing videos (with optional
    pagination/ordering).

    Returns:
        Tuple[str, Dict]: (url, params) for the video list request
    """
    ...
|
||||
|
||||
@abstractmethod
def transform_video_list_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Optional[str] = None,
) -> Dict[str, str]:
    """
    Parse the provider's video-list HTTP response into a dict.
    """
    ...
|
||||
|
||||
@abstractmethod
def transform_video_delete_request(
    self,
    video_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Build the URL and payload for deleting a video.

    Returns:
        Tuple[str, Dict]: (url, data) for the video delete request
    """
    ...
|
||||
|
||||
@abstractmethod
def transform_video_delete_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> VideoObject:
    """
    Parse the provider's video-delete HTTP response into a VideoObject.
    """
    ...
|
||||
|
||||
@abstractmethod
def transform_video_status_retrieve_request(
    self,
    video_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Build the URL and query params for checking a video's status.

    Returns:
        Tuple[str, Dict]: (url, params) for the video retrieve request
    """
    ...
|
||||
|
||||
@abstractmethod
def transform_video_status_retrieve_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Optional[str] = None,
) -> VideoObject:
    """
    Parse the provider's video-status HTTP response into a VideoObject.
    """
    ...
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Map a provider error response to a LiteLLM exception.

    Args:
        error_message: Error text extracted from the provider response.
        status_code: HTTP status code of the failed request.
        headers: Response headers from the provider.

    Returns:
        A ``BaseLLMException`` describing the failure; callers are
        expected to ``raise`` the returned exception.
    """
    from ..chat.transformation import BaseLLMException

    # Fix: the signature promises to *return* the exception (the caller
    # pattern is `raise config.get_error_class(...)`); the previous body
    # raised it directly, contradicting its own return annotation.
    return BaseLLMException(
        status_code=status_code,
        message=error_message,
        headers=headers,
    )
|
||||
Reference in New Issue
Block a user