chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
"""
GigaChat Provider for LiteLLM
GigaChat is Sber AI's large language model (Russia's leading LLM).
Supports:
- Chat completions (sync/async)
- Streaming (sync/async)
- Function calling / Tools
- Structured output via JSON schema (emulated through function calls)
- Image input (base64 and URL)
- Embeddings
API Documentation: https://developers.sber.ru/docs/ru/gigachat/api/overview
"""
from .chat.transformation import GigaChatConfig, GigaChatError
from .embedding.transformation import GigaChatEmbeddingConfig
__all__ = [
"GigaChatConfig",
"GigaChatEmbeddingConfig",
"GigaChatError",
]

View File

@@ -0,0 +1,245 @@
"""
GigaChat OAuth Authenticator
Handles OAuth 2.0 token management for GigaChat API.
Based on official GigaChat SDK authentication flow.
"""
import time
import uuid
from typing import Optional, Tuple
import httpx
from litellm._logging import verbose_logger
from litellm.caching.caching import InMemoryCache
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import LlmProviders
# GigaChat OAuth endpoint
GIGACHAT_AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
# Default scope for personal API access
GIGACHAT_SCOPE = "GIGACHAT_API_PERS"
# Token expiry buffer in milliseconds (refresh token 60s before expiry)
TOKEN_EXPIRY_BUFFER_MS = 60000
# Cache for access tokens
_token_cache = InMemoryCache()
class GigaChatAuthError(BaseLLMException):
    """Raised when GigaChat OAuth token acquisition fails.

    Inherits ``status_code`` / ``message`` handling from ``BaseLLMException``
    so callers can surface the upstream HTTP error unchanged.
    """

    pass
def _get_credentials() -> Optional[str]:
    """Resolve GigaChat credentials from the environment.

    Checks ``GIGACHAT_CREDENTIALS`` first and falls back to
    ``GIGACHAT_API_KEY``; returns ``None`` when neither is set.
    """
    primary = get_secret_str("GIGACHAT_CREDENTIALS")
    if primary:
        return primary
    return get_secret_str("GIGACHAT_API_KEY")
def _get_auth_url() -> str:
    """Return the OAuth endpoint URL, honoring a GIGACHAT_AUTH_URL override."""
    override = get_secret_str("GIGACHAT_AUTH_URL")
    return override if override else GIGACHAT_AUTH_URL
def _get_scope() -> str:
    """Return the OAuth scope, honoring a GIGACHAT_SCOPE override."""
    configured = get_secret_str("GIGACHAT_SCOPE")
    return configured if configured else GIGACHAT_SCOPE
def _get_http_client() -> HTTPHandler:
    """Get cached httpx client with SSL verification disabled."""
    # NOTE(review): ssl_verify=False — presumably because the Sber endpoints
    # use a CA that is absent from standard trust stores; confirm whether a
    # configurable CA bundle should be supported instead of disabling
    # verification outright.
    return _get_httpx_client(params={"ssl_verify": False})
def get_access_token(
    credentials: Optional[str] = None,
    scope: Optional[str] = None,
    auth_url: Optional[str] = None,
) -> str:
    """
    Return a valid GigaChat access token, refreshing via OAuth when the
    cached one is absent or about to expire.

    Args:
        credentials: Base64-encoded credentials (client_id:client_secret)
        scope: API scope (GIGACHAT_API_PERS, GIGACHAT_API_CORP, etc.)
        auth_url: OAuth endpoint URL

    Returns:
        Access token string

    Raises:
        GigaChatAuthError: If credentials are missing or authentication fails
    """
    creds = credentials if credentials else _get_credentials()
    if not creds:
        raise GigaChatAuthError(
            status_code=401,
            message="GigaChat credentials not provided. Set GIGACHAT_CREDENTIALS or GIGACHAT_API_KEY environment variable.",
        )
    resolved_scope = scope if scope else _get_scope()
    resolved_url = auth_url if auth_url else _get_auth_url()

    # Cache key derives from a credential prefix so distinct accounts do
    # not share tokens.
    cache_key = f"gigachat_token:{creds[:16]}"
    entry = _token_cache.get_cache(cache_key)
    if entry is not None:
        cached_token, cached_expiry = entry
        # Honor the refresh buffer: treat a nearly-expired token as expired.
        if time.time() * 1000 < cached_expiry - TOKEN_EXPIRY_BUFFER_MS:
            verbose_logger.debug("Using cached GigaChat access token")
            return cached_token

    token, expires_at = _request_token_sync(creds, resolved_scope, resolved_url)

    # Cache only while the token outlives the refresh buffer.
    ttl_seconds = (expires_at - TOKEN_EXPIRY_BUFFER_MS - time.time() * 1000) / 1000
    if ttl_seconds > 0:
        _token_cache.set_cache(cache_key, (token, expires_at), ttl=ttl_seconds)
    return token
async def get_access_token_async(
    credentials: Optional[str] = None,
    scope: Optional[str] = None,
    auth_url: Optional[str] = None,
) -> str:
    """Async variant of :func:`get_access_token` (same caching behavior)."""
    creds = credentials if credentials else _get_credentials()
    if not creds:
        raise GigaChatAuthError(
            status_code=401,
            message="GigaChat credentials not provided. Set GIGACHAT_CREDENTIALS or GIGACHAT_API_KEY environment variable.",
        )
    resolved_scope = scope if scope else _get_scope()
    resolved_url = auth_url if auth_url else _get_auth_url()

    # Cache lookup keyed by a credential prefix.
    cache_key = f"gigachat_token:{creds[:16]}"
    entry = _token_cache.get_cache(cache_key)
    if entry is not None:
        cached_token, cached_expiry = entry
        if time.time() * 1000 < cached_expiry - TOKEN_EXPIRY_BUFFER_MS:
            verbose_logger.debug("Using cached GigaChat access token")
            return cached_token

    token, expires_at = await _request_token_async(creds, resolved_scope, resolved_url)

    # Cache only while the token outlives the refresh buffer.
    ttl_seconds = (expires_at - TOKEN_EXPIRY_BUFFER_MS - time.time() * 1000) / 1000
    if ttl_seconds > 0:
        _token_cache.set_cache(cache_key, (token, expires_at), ttl=ttl_seconds)
    return token
def _request_token_sync(
    credentials: str,
    scope: str,
    auth_url: str,
) -> Tuple[str, int]:
    """
    Exchange Basic credentials for an access token (sync).

    Returns:
        Tuple of (access_token, expires_at_ms)

    Raises:
        GigaChatAuthError: On an HTTP error status or a transport failure.
    """
    request_headers = {
        "Authorization": f"Basic {credentials}",
        # RqUID: mandatory per-request UUID required by the Sber OAuth API.
        "RqUID": str(uuid.uuid4()),
        "Content-Type": "application/x-www-form-urlencoded",
    }
    verbose_logger.debug(f"Requesting GigaChat access token from {auth_url}")
    try:
        response = _get_http_client().post(
            auth_url,
            headers=request_headers,
            data={"scope": scope},
            timeout=30,
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        raise GigaChatAuthError(
            status_code=e.response.status_code,
            message=f"GigaChat authentication failed: {e.response.text}",
        )
    except httpx.RequestError as e:
        raise GigaChatAuthError(
            status_code=500,
            message=f"GigaChat authentication request failed: {str(e)}",
        )
    return _parse_token_response(response)
async def _request_token_async(
    credentials: str,
    scope: str,
    auth_url: str,
) -> Tuple[str, int]:
    """Async version of :func:`_request_token_sync`."""
    request_headers = {
        "Authorization": f"Basic {credentials}",
        # RqUID: mandatory per-request UUID required by the Sber OAuth API.
        "RqUID": str(uuid.uuid4()),
        "Content-Type": "application/x-www-form-urlencoded",
    }
    verbose_logger.debug(f"Requesting GigaChat access token from {auth_url}")
    try:
        client = get_async_httpx_client(
            llm_provider=LlmProviders.GIGACHAT,
            params={"ssl_verify": False},
        )
        response = await client.post(
            auth_url,
            headers=request_headers,
            data={"scope": scope},
            timeout=30,
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        raise GigaChatAuthError(
            status_code=e.response.status_code,
            message=f"GigaChat authentication failed: {e.response.text}",
        )
    except httpx.RequestError as e:
        raise GigaChatAuthError(
            status_code=500,
            message=f"GigaChat authentication request failed: {str(e)}",
        )
    return _parse_token_response(response)
def _parse_token_response(response: httpx.Response) -> Tuple[str, int]:
    """Parse the OAuth token response body.

    GigaChat returns either ``{'tok': ..., 'exp': ...}`` or the standard
    ``{'access_token': ..., 'expires_at': ...}`` shape.

    Returns:
        Tuple of (access_token, expires_at_ms); ``expires_at`` is a unix
        timestamp in milliseconds.

    Raises:
        GigaChatAuthError: If the body contains no access token.
    """
    data = response.json()
    # Explicit None checks (instead of `or`-chaining) so a legitimate falsy
    # value in one field cannot be shadowed by the other.
    access_token = data.get("tok")
    if access_token is None:
        access_token = data.get("access_token")
    expires_at = data.get("exp")
    if expires_at is None:
        expires_at = data.get("expires_at")
    if not access_token:
        raise GigaChatAuthError(
            status_code=500,
            message=f"Invalid token response: {data}",
        )
    # Bug fix: a missing expiry used to be returned as None, which made the
    # `now < expires_at - buffer` comparison in the callers raise TypeError.
    # Treat a missing/unparsable expiry as "already expired" (0) so the
    # token is simply never cached.
    try:
        expires_at = int(expires_at)
    except (TypeError, ValueError):
        expires_at = 0
    verbose_logger.debug("GigaChat access token obtained successfully")
    return access_token, expires_at

View File

@@ -0,0 +1,12 @@
"""
GigaChat Chat Module
"""
from .transformation import GigaChatConfig, GigaChatError
from .streaming import GigaChatModelResponseIterator
__all__ = [
"GigaChatConfig",
"GigaChatError",
"GigaChatModelResponseIterator",
]

View File

@@ -0,0 +1,137 @@
"""
GigaChat Streaming Response Handler
"""
import json
import uuid
from typing import Any, Optional
from litellm.types.llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
)
from litellm.types.utils import GenericStreamingChunk
class GigaChatModelResponseIterator:
    """Iterator for GigaChat streaming responses.

    Wraps a sync or async iterator of raw SSE lines (or already-decoded
    chunk dicts) and yields ``GenericStreamingChunk`` objects in LiteLLM's
    provider-agnostic streaming format.
    """

    def __init__(
        self,
        streaming_response: Any,
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        # streaming_response: iterable (sync or async) of SSE strings or
        # pre-parsed chunk dicts. sync_stream is accepted for interface
        # parity but not consulted; iteration mode is chosen by the caller
        # using __next__ vs __anext__.
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.json_mode = json_mode

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """Parse a single streaming chunk from GigaChat.

        Translates a GigaChat delta into a ``GenericStreamingChunk``; a
        ``function_call`` delta is converted into an OpenAI-style tool-call
        chunk with a freshly generated call id.
        """
        text = ""
        tool_use: Optional[ChatCompletionToolCallChunk] = None
        is_finished = False
        finish_reason: Optional[str] = None
        choices = chunk.get("choices", [])
        if not choices:
            # Keep-alive / metadata chunk with no choices: emit an empty delta.
            return GenericStreamingChunk(
                text="",
                tool_use=None,
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
            )
        choice = choices[0]
        delta = choice.get("delta", {})
        finish_reason = choice.get("finish_reason")
        # Extract text content (coerce explicit None to "").
        text = delta.get("content", "") or ""
        # GigaChat signals a function call via finish_reason +
        # delta.function_call; map it to OpenAI tool_calls semantics.
        if finish_reason == "function_call" and delta.get("function_call"):
            func_call = delta["function_call"]
            args = func_call.get("arguments", {})
            if isinstance(args, dict):
                # OpenAI tool calls carry arguments as a JSON string.
                args = json.dumps(args, ensure_ascii=False)
            tool_use = ChatCompletionToolCallChunk(
                id=f"call_{uuid.uuid4().hex[:24]}",
                type="function",
                function=ChatCompletionToolCallFunctionChunk(
                    name=func_call.get("name", ""),
                    arguments=args,
                ),
                index=0,
            )
            finish_reason = "tool_calls"
        if finish_reason is not None:
            is_finished = True
        return GenericStreamingChunk(
            text=text,
            tool_use=tool_use,
            is_finished=is_finished,
            finish_reason=finish_reason or "",
            usage=None,
            index=choice.get("index", 0),
        )

    def __iter__(self):
        return self

    def __next__(self) -> GenericStreamingChunk:
        try:
            chunk = self.response_iterator.__next__()
            if isinstance(chunk, str):
                # Parse SSE format: "data: {...}"
                if chunk.startswith("data: "):
                    chunk = chunk[6:]
                if chunk.strip() == "[DONE]":
                    raise StopIteration
                try:
                    chunk = json.loads(chunk)
                except json.JSONDecodeError:
                    # Unparsable line: emit an empty chunk rather than
                    # aborting the whole stream.
                    return GenericStreamingChunk(
                        text="",
                        tool_use=None,
                        is_finished=False,
                        finish_reason="",
                        usage=None,
                        index=0,
                    )
            return self.chunk_parser(chunk)
        except StopIteration:
            raise

    def __aiter__(self):
        return self

    async def __anext__(self) -> GenericStreamingChunk:
        try:
            chunk = await self.response_iterator.__anext__()
            if isinstance(chunk, str):
                # Parse SSE format: "data: {...}"
                if chunk.startswith("data: "):
                    chunk = chunk[6:]
                if chunk.strip() == "[DONE]":
                    raise StopAsyncIteration
                try:
                    chunk = json.loads(chunk)
                except json.JSONDecodeError:
                    # Unparsable line: emit an empty chunk rather than
                    # aborting the whole stream.
                    return GenericStreamingChunk(
                        text="",
                        tool_use=None,
                        is_finished=False,
                        finish_reason="",
                        usage=None,
                        index=0,
                    )
            return self.chunk_parser(chunk)
        except StopAsyncIteration:
            raise

View File

@@ -0,0 +1,510 @@
"""
GigaChat Chat Transformation
Transforms OpenAI-format requests to GigaChat format and back.
"""
import json
import time
import uuid
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
from litellm._logging import verbose_logger
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from ..authenticator import get_access_token
from ..file_handler import upload_file_sync
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
# GigaChat API endpoint
GIGACHAT_BASE_URL = "https://gigachat.devices.sberbank.ru/api/v1"
def is_valid_json(value: str) -> bool:
    """Return True when *value* parses as a JSON document, else False."""
    try:
        json.loads(value)
        return True
    except json.JSONDecodeError:
        return False
class GigaChatError(BaseLLMException):
    """Raised for GigaChat chat-completions API errors.

    Carries ``status_code``/``message``/``headers`` via ``BaseLLMException``.
    """

    pass
class GigaChatConfig(BaseConfig):
    """
    Configuration class for GigaChat API.

    GigaChat is Sber's (Russia's largest bank) LLM API.

    Supported parameters:
        temperature: Sampling temperature (0-2, default 0.87)
        top_p: Nucleus sampling parameter
        max_tokens: Maximum tokens to generate
        repetition_penalty: Repetition penalty factor
        profanity_check: Enable content filtering
        stream: Enable streaming
    """

    # Provider parameters, stored at class level per litellm convention.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None
    profanity_check: Optional[bool] = None

    def __init__(
        self,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        profanity_check: Optional[bool] = None,
    ) -> None:
        # litellm convention: persist explicitly-passed params as class
        # attributes so they act as defaults for subsequent requests.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        # Instance variables for current request context — populated in
        # validate_environment(), consumed by _upload_image().
        self._current_credentials: Optional[str] = None
        self._current_api_base: Optional[str] = None

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Get complete API URL for chat completions.

        Resolution order: explicit api_base > GIGACHAT_API_BASE env var >
        hardcoded default.
        """
        base = api_base or get_secret_str("GIGACHAT_API_BASE") or GIGACHAT_BASE_URL
        return f"{base}/chat/completions"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Set up headers with OAuth token.

        Exchanges the configured credentials for a bearer token (cached by
        the authenticator module) and mutates/returns ``headers``.
        """
        # Get access token. api_key doubles as the base64 OAuth credentials.
        credentials = (
            api_key
            or get_secret_str("GIGACHAT_CREDENTIALS")
            or get_secret_str("GIGACHAT_API_KEY")
        )
        access_token = get_access_token(credentials=credentials)
        # Store credentials for image uploads performed during message
        # transformation.
        self._current_credentials = credentials
        self._current_api_base = api_base
        headers["Authorization"] = f"Bearer {access_token}"
        headers["Content-Type"] = "application/json"
        headers["Accept"] = "application/json"
        return headers

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return list of supported OpenAI parameters."""
        return [
            "stream",
            "temperature",
            "top_p",
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "tools",
            "tool_choice",
            "functions",
            "function_call",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI parameters to GigaChat parameters.

        Notable translations: tools -> functions, tool_choice ->
        function_call, and response_format json_schema is emulated through
        a forced function call (flagged via the private
        ``_structured_output`` marker consumed by transform_response).
        """
        for param, value in non_default_params.items():
            if param == "stream":
                optional_params["stream"] = value
            elif param == "temperature":
                # GigaChat: temperature 0 means use top_p=0 instead
                if value == 0:
                    optional_params["top_p"] = 0
                else:
                    optional_params["temperature"] = value
            elif param == "top_p":
                optional_params["top_p"] = value
            elif param in ("max_tokens", "max_completion_tokens"):
                optional_params["max_tokens"] = value
            elif param == "stop":
                # GigaChat doesn't support stop sequences
                pass
            elif param == "tools":
                # Convert tools to functions format
                optional_params["functions"] = self._convert_tools_to_functions(value)
            elif param == "tool_choice":
                # Map OpenAI tool_choice to GigaChat function_call
                mapped_choice = self._map_tool_choice(value)
                if mapped_choice is not None:
                    optional_params["function_call"] = mapped_choice
            elif param == "functions":
                optional_params["functions"] = value
            elif param == "function_call":
                optional_params["function_call"] = value
            elif param == "response_format":
                # Handle structured output via function calling
                if value.get("type") == "json_schema":
                    json_schema = value.get("json_schema", {})
                    schema_name = json_schema.get("name", "structured_output")
                    schema = json_schema.get("schema", {})
                    function_def = {
                        "name": schema_name,
                        "description": f"Output structured response: {schema_name}",
                        "parameters": schema,
                    }
                    if "functions" not in optional_params:
                        optional_params["functions"] = []
                    optional_params["functions"].append(function_def)
                    # Force the model to call the schema function.
                    optional_params["function_call"] = {"name": schema_name}
                    optional_params["_structured_output"] = True
        return optional_params

    def _convert_tools_to_functions(self, tools: List[dict]) -> List[dict]:
        """Convert OpenAI tools format to GigaChat functions format.

        Non-function tool entries are silently dropped.
        """
        functions = []
        for tool in tools:
            if tool.get("type") == "function":
                func = tool.get("function", {})
                functions.append(
                    {
                        "name": func.get("name", ""),
                        "description": func.get("description", ""),
                        "parameters": func.get("parameters", {}),
                    }
                )
        return functions

    def _map_tool_choice(
        self, tool_choice: Union[str, dict]
    ) -> Optional[Union[str, dict]]:
        """
        Map OpenAI tool_choice to GigaChat function_call format.

        OpenAI format:
            - "auto": Call zero, one, or multiple functions (default)
            - "required": Call one or more functions
            - "none": Don't call any functions
            - {"type": "function", "function": {"name": "get_weather"}}: Force specific function

        GigaChat format:
            - "none": Disable function calls
            - "auto": Automatic mode (default)
            - {"name": "get_weather"}: Force specific function

        Args:
            tool_choice: OpenAI tool_choice value

        Returns:
            GigaChat function_call value or None
        """
        if tool_choice == "none":
            return "none"
        elif tool_choice == "auto":
            return "auto"
        elif tool_choice == "required":
            # GigaChat doesn't have a direct "required" equivalent
            # Use "auto" as the closest behavior
            return "auto"
        elif isinstance(tool_choice, dict):
            # OpenAI format: {"type": "function", "function": {"name": "func_name"}}
            # GigaChat format: {"name": "func_name"}
            if tool_choice.get("type") == "function":
                func_name = tool_choice.get("function", {}).get("name")
                if func_name:
                    return {"name": func_name}
        # Default to None (don't set function_call)
        return None

    def _upload_image(self, image_url: str) -> Optional[str]:
        """
        Upload image to GigaChat and return file_id.

        Best-effort: upload errors are logged and swallowed so a failed
        image does not abort the whole chat request.

        Args:
            image_url: URL or base64 data URL of the image

        Returns:
            file_id string or None if upload failed
        """
        try:
            return upload_file_sync(
                image_url=image_url,
                credentials=self._current_credentials,
                api_base=self._current_api_base,
            )
        except Exception as e:
            verbose_logger.error(f"Failed to upload image: {e}")
            return None

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Transform OpenAI request to GigaChat format."""
        # Transform messages
        giga_messages = self._transform_messages(messages)
        # Build request; strip the litellm "gigachat/" provider prefix.
        request_data = {
            "model": model.replace("gigachat/", ""),
            "messages": giga_messages,
        }
        # Add optional params (only the keys GigaChat understands).
        for key in [
            "temperature",
            "top_p",
            "max_tokens",
            "stream",
            "repetition_penalty",
            "profanity_check",
        ]:
            if key in optional_params:
                request_data[key] = optional_params[key]
        # Add functions if present
        if "functions" in optional_params:
            request_data["functions"] = optional_params["functions"]
        if "function_call" in optional_params:
            request_data["function_call"] = optional_params["function_call"]
        return request_data

    def _transform_messages(self, messages: List[AllMessageValues]) -> List[dict]:
        """Transform OpenAI messages to GigaChat format.

        Handles role mapping (developer->system, late system->user,
        tool->function), multimodal content extraction (images become
        uploaded attachments), and tool_calls -> function_call conversion.
        """
        transformed = []
        for i, msg in enumerate(messages):
            message = dict(msg)
            # Remove unsupported fields
            message.pop("name", None)
            # Transform roles
            role = message.get("role", "user")
            if role == "developer":
                message["role"] = "system"
            elif role == "system" and i > 0:
                # GigaChat only allows system message as first message
                message["role"] = "user"
            elif role == "tool":
                message["role"] = "function"
                # GigaChat function results must be JSON-encoded strings.
                content = message.get("content", "")
                if not isinstance(content, str) or not is_valid_json(content):
                    message["content"] = json.dumps(content, ensure_ascii=False)
            # Handle None content
            if message.get("content") is None:
                message["content"] = ""
            # Handle list content (multimodal) - extract text and images
            content = message.get("content")
            if isinstance(content, list):
                texts = []
                attachments = []
                for part in content:
                    if isinstance(part, dict):
                        if part.get("type") == "text":
                            texts.append(part.get("text", ""))
                        elif part.get("type") == "image_url":
                            # Extract image URL and upload to GigaChat
                            image_url = part.get("image_url", {})
                            if isinstance(image_url, str):
                                url = image_url
                            else:
                                url = image_url.get("url", "")
                            if url:
                                file_id = self._upload_image(url)
                                if file_id:
                                    attachments.append(file_id)
                message["content"] = "\n".join(texts) if texts else ""
                if attachments:
                    message["attachments"] = attachments
            # Transform tool_calls to function_call. Only the first tool
            # call is forwarded — GigaChat accepts a single function_call.
            tool_calls = message.get("tool_calls")
            if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
                tool_call = tool_calls[0]
                func = tool_call.get("function", {})
                args = func.get("arguments", "{}")
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        args = {}
                message["function_call"] = {
                    "name": func.get("name", ""),
                    "arguments": args,
                }
                message.pop("tool_calls", None)
            transformed.append(message)
        return transformed

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Transform GigaChat response to OpenAI format.

        A function_call finish is converted either into plain content (when
        the request used the structured-output emulation) or into OpenAI
        tool_calls.

        Raises:
            GigaChatError: If the response body is not valid JSON.
        """
        try:
            response_json = raw_response.json()
        except Exception:
            raise GigaChatError(
                status_code=raw_response.status_code,
                message=f"Invalid JSON response: {raw_response.text}",
            )
        # Set by map_openai_params when response_format was emulated.
        is_structured_output = optional_params.get("_structured_output", False)
        choices = []
        for choice in response_json.get("choices", []):
            message_data = choice.get("message", {})
            finish_reason = choice.get("finish_reason", "stop")
            # Transform function_call to tool_calls or content
            if finish_reason == "function_call" and message_data.get("function_call"):
                func_call = message_data["function_call"]
                args = func_call.get("arguments", {})
                if is_structured_output:
                    # Convert to content for structured output
                    if isinstance(args, dict):
                        content = json.dumps(args, ensure_ascii=False)
                    else:
                        content = str(args)
                    message_data["content"] = content
                    message_data.pop("function_call", None)
                    message_data.pop("functions_state_id", None)
                    finish_reason = "stop"
                else:
                    # Convert to tool_calls format
                    if isinstance(args, dict):
                        args = json.dumps(args, ensure_ascii=False)
                    message_data["tool_calls"] = [
                        {
                            "id": f"call_{uuid.uuid4().hex[:24]}",
                            "type": "function",
                            "function": {
                                "name": func_call.get("name", ""),
                                "arguments": args,
                            },
                        }
                    ]
                    message_data.pop("function_call", None)
                    finish_reason = "tool_calls"
            # Clean up GigaChat-specific fields
            message_data.pop("functions_state_id", None)
            choices.append(
                Choices(
                    index=choice.get("index", 0),
                    message=Message(
                        role=message_data.get("role", "assistant"),
                        content=message_data.get("content"),
                        tool_calls=message_data.get("tool_calls"),
                    ),
                    finish_reason=finish_reason,
                )
            )
        # Build usage
        usage_data = response_json.get("usage", {})
        usage = Usage(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
            total_tokens=usage_data.get("total_tokens", 0),
        )
        model_response.id = response_json.get("id", f"chatcmpl-{uuid.uuid4().hex[:12]}")
        model_response.created = response_json.get("created", int(time.time()))
        model_response.model = model
        model_response.choices = choices  # type: ignore
        setattr(model_response, "usage", usage)
        return model_response

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: Union[dict, httpx.Headers],
    ) -> BaseLLMException:
        """Return GigaChat error class."""
        return GigaChatError(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        """Return streaming response iterator."""
        # Local import avoids a circular import with the streaming module.
        from .streaming import GigaChatModelResponseIterator

        return GigaChatModelResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

View File

@@ -0,0 +1,7 @@
"""
GigaChat Embedding Module
"""
from .transformation import GigaChatEmbeddingConfig
__all__ = ["GigaChatEmbeddingConfig"]

View File

@@ -0,0 +1,212 @@
"""
GigaChat Embedding Transformation
Transforms OpenAI /v1/embeddings format to GigaChat format.
API Documentation: https://developers.sber.ru/docs/ru/gigachat/api/reference/rest/post-embeddings
"""
import types
from typing import List, Optional, Tuple, Union
import httpx
from litellm import LlmProviders
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse
from ..authenticator import get_access_token
# GigaChat API endpoint
GIGACHAT_BASE_URL = "https://gigachat.devices.sberbank.ru/api/v1"
class GigaChatEmbeddingError(BaseLLMException):
    """Raised for GigaChat embeddings API errors.

    Carries ``status_code``/``message`` via ``BaseLLMException``.
    """

    pass
class GigaChatEmbeddingConfig(BaseEmbeddingConfig):
    """
    Configuration class for GigaChat Embeddings API.

    GigaChat embeddings endpoint: POST /api/v1/embeddings
    """

    def __init__(self) -> None:
        pass

    @classmethod
    def get_config(cls):
        # Expose non-callable, non-dunder class attributes as the config
        # dict (litellm convention shared by other provider configs).
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        """GigaChat embeddings don't support additional parameters."""
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI params to GigaChat format (no special mapping needed)."""
        return optional_params

    def _get_openai_compatible_provider_info(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
    ) -> Tuple[str, Optional[str], Optional[str]]:
        """
        Returns provider info for GigaChat.

        Returns:
            Tuple of (custom_llm_provider, api_base, dynamic_api_key)
        """
        api_base = api_base or GIGACHAT_BASE_URL
        return LlmProviders.GIGACHAT.value, api_base, api_key

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Get the complete URL for embeddings endpoint."""
        base = api_base or GIGACHAT_BASE_URL
        return f"{base}/embeddings"

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform OpenAI embedding request to GigaChat format.

        GigaChat format:
            {
                "model": "Embeddings",
                "input": ["text1", "text2", ...]
            }
        """
        # Normalize input to list
        if isinstance(input, str):
            input_list: list = [input]
        elif isinstance(input, list):
            input_list = input
        else:
            input_list = [input]
        # Remove gigachat/ prefix from model if present
        if model.startswith("gigachat/"):
            model = model[9:]
        return {
            "model": model,
            "input": input_list,
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """
        Transform GigaChat embedding response to OpenAI format.

        GigaChat returns:
            {
                "object": "list",
                "data": [{"object": "embedding", "embedding": [...], "index": 0, "usage": {...}}],
                "model": "Embeddings"
            }

        GigaChat reports usage per embedding; OpenAI reports one aggregate
        usage object, so the per-item values are summed and stripped.
        """
        response_json = raw_response.json()
        # Log response
        logging_obj.post_call(
            input=request_data.get("input"),
            api_key=api_key,
            additional_args={"complete_input_dict": request_data},
            original_response=response_json,
        )
        # Calculate total tokens from individual embeddings
        total_tokens = 0
        if "data" in response_json:
            for emb in response_json["data"]:
                if "usage" in emb and "prompt_tokens" in emb["usage"]:
                    total_tokens += emb["usage"]["prompt_tokens"]
                # Remove usage from individual embeddings (not part of OpenAI format)
                if "usage" in emb:
                    del emb["usage"]
        # Set overall usage
        response_json["usage"] = {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        }
        return EmbeddingResponse(**response_json)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Set up headers with OAuth token for GigaChat.

        ``api_key`` is forwarded as the OAuth credentials; caller-supplied
        headers take precedence over the defaults.
        """
        # Get access token via OAuth
        access_token = get_access_token(api_key)
        default_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {access_token}",
        }
        return {**default_headers, **headers}

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return GigaChat-specific error class."""
        return GigaChatEmbeddingError(
            status_code=status_code,
            message=error_message,
        )

View File

@@ -0,0 +1,211 @@
"""
GigaChat File Handler
Handles file uploads to GigaChat API for image processing.
GigaChat requires files to be uploaded first, then referenced by file_id.
"""
import base64
import hashlib
import re
import uuid
from typing import Dict, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import LlmProviders
from .authenticator import get_access_token, get_access_token_async
# GigaChat API endpoint
GIGACHAT_BASE_URL = "https://gigachat.devices.sberbank.ru/api/v1"
# Simple in-memory cache for file IDs
_file_cache: Dict[str, str] = {}
def _get_url_hash(url: str) -> str:
    """Return a stable cache key for *url*: hex SHA-256 of its UTF-8 bytes."""
    digest = hashlib.sha256(url.encode())
    return digest.hexdigest()
def _parse_data_url(data_url: str) -> Optional[Tuple[bytes, str, str]]:
    """
    Parse a base64 ``data:`` URL into its binary payload.

    Args:
        data_url: String of the form ``data:<mime>;base64,<payload>``.

    Returns:
        Tuple of (content_bytes, content_type, extension), or None when the
        string is not a data URL or its payload is not decodable base64.
    """
    # re.DOTALL lets the payload span multiple lines — base64 blobs are
    # often line-wrapped, and `.` would otherwise stop at the first newline
    # and silently truncate the image.
    match = re.match(r"data:([^;]+);base64,(.+)", data_url, re.DOTALL)
    if not match:
        return None
    content_type = match.group(1)
    try:
        content_bytes = base64.b64decode(match.group(2))
    except ValueError:
        # binascii.Error (a ValueError subclass) on malformed padding used
        # to escape to the caller; treat bad base64 like "not a data URL".
        return None
    # Derive a file extension from the MIME subtype; fall back to jpg.
    ext = content_type.split("/")[-1].split(";")[0] or "jpg"
    return content_bytes, content_type, ext
def _download_image_sync(url: str) -> Tuple[bytes, str, str]:
    """Download image from URL synchronously.

    Returns:
        Tuple of (content_bytes, content_type, extension). The extension is
        derived from the Content-Type header, defaulting to image/jpeg.

    Raises:
        httpx.HTTPStatusError: On a non-2xx response (via raise_for_status).
    """
    client = _get_httpx_client(params={"ssl_verify": False})
    response = client.get(url)
    response.raise_for_status()
    content_type = response.headers.get("content-type", "image/jpeg")
    # MIME subtype -> file extension; strip any ";charset=..." suffix.
    ext = content_type.split("/")[-1].split(";")[0] or "jpg"
    return response.content, content_type, ext
async def _download_image_async(url: str) -> Tuple[bytes, str, str]:
    """Download image from URL asynchronously.

    Returns:
        Tuple of (content_bytes, content_type, extension). The extension is
        derived from the Content-Type header, defaulting to image/jpeg.

    Raises:
        httpx.HTTPStatusError: On a non-2xx response (via raise_for_status).
    """
    client = get_async_httpx_client(
        llm_provider=LlmProviders.GIGACHAT,
        params={"ssl_verify": False},
    )
    response = await client.get(url)
    response.raise_for_status()
    content_type = response.headers.get("content-type", "image/jpeg")
    # MIME subtype -> file extension; strip any ";charset=..." suffix.
    ext = content_type.split("/")[-1].split(";")[0] or "jpg"
    return response.content, content_type, ext
def upload_file_sync(
    image_url: str,
    credentials: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Optional[str]:
    """
    Upload file to GigaChat and return file_id (sync).

    Results are memoized in the module-level ``_file_cache``, keyed by the
    SHA-256 of the source URL, so the same image is uploaded only once per
    process. Best-effort: every failure is logged and mapped to None.

    Args:
        image_url: URL or base64 data URL of the image
        credentials: GigaChat credentials for auth
        api_base: Optional custom API base URL

    Returns:
        file_id string or None if upload failed
    """
    url_hash = _get_url_hash(image_url)
    # Check cache
    if url_hash in _file_cache:
        verbose_logger.debug(f"Image found in cache: {url_hash[:16]}...")
        return _file_cache[url_hash]
    try:
        # Get image data: either decode an inline data URL or fetch it.
        parsed = _parse_data_url(image_url)
        if parsed:
            content_bytes, content_type, ext = parsed
            verbose_logger.debug("Decoded base64 image")
        else:
            verbose_logger.debug(f"Downloading image from URL: {image_url[:80]}...")
            content_bytes, content_type, ext = _download_image_sync(image_url)
        # Random filename — GigaChat only needs a plausible extension.
        filename = f"{uuid.uuid4()}.{ext}"
        # Get access token
        access_token = get_access_token(credentials)
        # Upload to GigaChat
        base_url = api_base or GIGACHAT_BASE_URL
        upload_url = f"{base_url}/files"
        client = _get_httpx_client(params={"ssl_verify": False})
        response = client.post(
            upload_url,
            headers={"Authorization": f"Bearer {access_token}"},
            files={"file": (filename, content_bytes, content_type)},
            data={"purpose": "general"},
            timeout=60,
        )
        response.raise_for_status()
        result = response.json()
        file_id = result.get("id")
        if file_id:
            _file_cache[url_hash] = file_id
            verbose_logger.debug(f"File uploaded successfully, file_id: {file_id}")
        return file_id
    except Exception as e:
        # Deliberate best-effort: image upload failure must not abort the
        # chat request; the caller treats None as "no attachment".
        verbose_logger.error(f"Error uploading file to GigaChat: {e}")
        return None
async def upload_file_async(
    image_url: str,
    credentials: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Optional[str]:
    """
    Upload file to GigaChat and return file_id (async).

    Async counterpart of ``upload_file_sync``; shares the same module-level
    ``_file_cache`` keyed by SHA-256 of the source URL. Best-effort: every
    failure is logged and mapped to None.

    Args:
        image_url: URL or base64 data URL of the image
        credentials: GigaChat credentials for auth
        api_base: Optional custom API base URL

    Returns:
        file_id string or None if upload failed
    """
    url_hash = _get_url_hash(image_url)
    # Check cache
    if url_hash in _file_cache:
        verbose_logger.debug(f"Image found in cache: {url_hash[:16]}...")
        return _file_cache[url_hash]
    try:
        # Get image data: either decode an inline data URL or fetch it.
        parsed = _parse_data_url(image_url)
        if parsed:
            content_bytes, content_type, ext = parsed
            verbose_logger.debug("Decoded base64 image")
        else:
            verbose_logger.debug(f"Downloading image from URL: {image_url[:80]}...")
            content_bytes, content_type, ext = await _download_image_async(image_url)
        # Random filename — GigaChat only needs a plausible extension.
        filename = f"{uuid.uuid4()}.{ext}"
        # Get access token
        access_token = await get_access_token_async(credentials)
        # Upload to GigaChat
        base_url = api_base or GIGACHAT_BASE_URL
        upload_url = f"{base_url}/files"
        client = get_async_httpx_client(
            llm_provider=LlmProviders.GIGACHAT,
            params={"ssl_verify": False},
        )
        response = await client.post(
            upload_url,
            headers={"Authorization": f"Bearer {access_token}"},
            files={"file": (filename, content_bytes, content_type)},
            data={"purpose": "general"},
            timeout=60,
        )
        response.raise_for_status()
        result = response.json()
        file_id = result.get("id")
        if file_id:
            _file_cache[url_hash] = file_id
            verbose_logger.debug(f"File uploaded successfully, file_id: {file_id}")
        return file_id
    except Exception as e:
        # Deliberate best-effort: image upload failure must not abort the
        # chat request; the caller treats None as "no attachment".
        verbose_logger.error(f"Error uploading file to GigaChat: {e}")
        return None