chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,722 @@
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union, cast
|
||||
|
||||
import aiohttp
|
||||
import httpx # type: ignore
|
||||
from aiohttp import ClientSession, FormData
|
||||
|
||||
import litellm
|
||||
import litellm.litellm_core_utils
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.base_llm.image_variations.transformation import (
|
||||
BaseImageVariationConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.llms.custom_httpx.aiohttp_transport import LiteLLMAiohttpTransport
|
||||
from litellm.types.llms.openai import FileTypes
|
||||
from litellm.types.utils import HttpHandlerRequestFields, ImageResponse, LlmProviders
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
DEFAULT_TIMEOUT = 600
|
||||
|
||||
|
||||
class BaseLLMAIOHTTPHandler:
|
||||
def __init__(
    self,
    client_session: Optional[aiohttp.ClientSession] = None,
    transport: Optional[LiteLLMAiohttpTransport] = None,
    connector: Optional[aiohttp.BaseConnector] = None,
):
    """Store caller-provided aiohttp resources, tracking ownership for cleanup.

    Any resource the caller did NOT supply is considered "owned" by this
    handler: close() / __del__ are then responsible for releasing it.
    """
    self.client_session = client_session
    self.transport = transport
    self.connector = connector

    # Ownership flags: True when we created (or will lazily create) the
    # resource ourselves and therefore must clean it up later.
    self._owns_session = client_session is None
    self._owns_transport = transport is None
    self._owns_connector = connector is None
|
||||
|
||||
def _get_or_create_transport(self) -> Optional[LiteLLMAiohttpTransport]:
    """Return the cached transport, lazily building one on first use."""
    if self.transport:
        return self.transport

    try:
        # Delegate construction to AsyncHTTPHandler so transport settings
        # stay consistent with the rest of the codebase.
        self.transport = AsyncHTTPHandler._create_aiohttp_transport()
    except Exception:
        # Could not build a transport; callers fall back to a plain session.
        return None

    self._owns_transport = True
    return self.transport
|
||||
|
||||
def _get_connector(self) -> Optional[aiohttp.BaseConnector]:
    """Return an explicit connector, or one borrowed from the transport's session."""
    if self.connector:
        return self.connector
    if not (self.transport and hasattr(self.transport, "client")):
        return None

    transport_client = self.transport.client
    if callable(transport_client):
        # A session *factory* carries no live connector we can reach into.
        return None
    # Borrow the session's connector when it exposes one.
    return getattr(transport_client, "connector", None)
|
||||
|
||||
def _create_client_session_with_transport(self) -> ClientSession:
    """Build a fresh ClientSession, preferring transport/connector configuration."""
    if self.transport and hasattr(self.transport, "_get_valid_client_session"):
        # The transport knows how to produce a loop-safe session; use it.
        return self.transport._get_valid_client_session()

    reusable_connector = self._get_connector()
    if reusable_connector:
        # Reuse the caller/transport-supplied connector.
        return aiohttp.ClientSession(connector=reusable_connector)

    # No configuration available: plain default session.
    return aiohttp.ClientSession()
|
||||
|
||||
def _get_async_client_session(
    self, dynamic_client_session: Optional[ClientSession] = None
) -> ClientSession:
    """Resolve the session to use: per-call override > cached > newly created."""
    if dynamic_client_session:
        return dynamic_client_session
    if self.client_session:
        return self.client_session

    # Lazily create and cache a session; since we created it, we own it
    # and must close it in close()/__del__.
    self.client_session = self._create_client_session_with_transport()
    self._owns_session = True
    return self.client_session
|
||||
|
||||
async def close(self):
    """Release the client session and transport — but only the ones we own."""
    session = self.client_session
    if session and self._owns_session and not session.closed:
        await session.close()

    transport = self.transport
    if transport and self._owns_transport and hasattr(transport, "aclose"):
        try:
            await transport.aclose()
        except Exception:
            # Best-effort teardown: transport close failures are not actionable.
            pass
|
||||
|
||||
def __del__(self):
    """
    Cleanup: close aiohttp session on instance destruction.

    Provides defense-in-depth for issue #12443 - ensures cleanup happens
    even if atexit handler doesn't run (abnormal termination).
    """
    # Only act when there is an open session that this handler owns.
    if (
        self.client_session is not None
        and not self.client_session.closed
        and self._owns_session
    ):
        try:
            import asyncio

            try:
                # NOTE(review): get_event_loop() is deprecated outside a
                # running loop on newer Pythons; the RuntimeError fallback
                # below covers the no-loop case, so behavior is preserved.
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    # Event loop is running - schedule cleanup task
                    # (fire-and-forget; __del__ cannot await).
                    asyncio.create_task(self.close())
                else:
                    # Event loop exists but not running - run cleanup
                    loop.run_until_complete(self.close())
            except RuntimeError:
                # No event loop available - create one for cleanup
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    loop.run_until_complete(self.close())
                finally:
                    loop.close()
        except Exception:
            # Silently ignore errors during __del__ to avoid issues
            # (raising from a finalizer only produces unraisable warnings).
            pass
|
||||
|
||||
async def _make_common_async_call(
    self,
    async_client_session: Optional[ClientSession],
    provider_config: BaseConfig,
    api_base: str,
    headers: dict,
    data: Optional[dict],
    timeout: Union[float, httpx.Timeout],
    litellm_params: dict,
    form_data: Optional[FormData] = None,
    stream: bool = False,
) -> aiohttp.ClientResponse:
    """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling.

    Fix: ``timeout`` was previously accepted but never applied to the aiohttp
    request (the sync twin passes it through); it is now translated into an
    ``aiohttp.ClientTimeout`` and enforced on the POST.
    """
    max_retry_on_unprocessable_entity_error = (
        provider_config.max_retry_on_unprocessable_entity_error
    )

    # Translate the float / httpx.Timeout into aiohttp's timeout object so
    # the configured timeout is actually enforced on the request.
    if isinstance(timeout, httpx.Timeout):
        client_timeout = aiohttp.ClientTimeout(
            sock_connect=timeout.connect,
            sock_read=timeout.read,
        )
    else:
        client_timeout = aiohttp.ClientTimeout(total=timeout)

    response: Optional[aiohttp.ClientResponse] = None
    async_client_session = self._get_async_client_session(
        dynamic_client_session=async_client_session
    )

    for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
        try:
            response = await async_client_session.post(
                url=api_base,
                headers=headers,
                json=data,
                data=form_data,
                timeout=client_timeout,
            )
            if not response.ok:
                response.raise_for_status()
        except aiohttp.ClientResponseError as e:
            # aiohttp errors carry `message`; expose it as `.text` so the
            # shared _handle_error can treat aiohttp/httpx errors uniformly.
            setattr(e, "text", e.message)
            raise self._handle_error(e=e, provider_config=provider_config)
        except Exception as e:
            raise self._handle_error(e=e, provider_config=provider_config)
        break

    if response is None:
        raise provider_config.get_error_class(
            error_message="No response from the API",
            status_code=422,
            headers={},
        )

    return response
|
||||
|
||||
def _make_common_sync_call(
    self,
    sync_httpx_client: HTTPHandler,
    provider_config: BaseConfig,
    api_base: str,
    headers: dict,
    data: dict,
    timeout: Optional[Union[float, httpx.Timeout]],
    litellm_params: dict,
    stream: bool = False,
    files: Optional[dict] = None,
    content: Any = None,
    params: Optional[dict] = None,
) -> httpx.Response:
    """POST to the provider, retrying "unprocessable entity"-style errors.

    On an httpx.HTTPStatusError, the provider config decides whether the
    request body can be rewritten and retried (up to
    max_retry_on_unprocessable_entity_error attempts); any other failure is
    normalized and re-raised via _handle_error.
    """
    max_retry_on_unprocessable_entity_error = (
        provider_config.max_retry_on_unprocessable_entity_error
    )

    response: Optional[httpx.Response] = None

    # At least one attempt even when the provider allows zero retries.
    for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
        try:
            response = sync_httpx_client.post(
                url=api_base,
                headers=headers,
                data=data,  # do not json dump the data here. let the individual endpoint handle this.
                timeout=timeout,
                stream=stream,
                files=files,
                content=content,
                params=params,
            )
        except httpx.HTTPStatusError as e:
            hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
            should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
                e=e, litellm_params=litellm_params
            )
            if should_retry and not hit_max_retry:
                # Let the provider rewrite the payload, then retry the POST.
                data = (
                    provider_config.transform_request_on_unprocessable_entity_error(
                        e=e, request_data=data
                    )
                )
                continue
            else:
                raise self._handle_error(e=e, provider_config=provider_config)
        except Exception as e:
            raise self._handle_error(e=e, provider_config=provider_config)
        # Successful attempt: stop retrying.
        break

    if response is None:
        raise provider_config.get_error_class(
            error_message="No response from the API",
            status_code=422,  # don't retry on this error
            headers={},
        )

    return response
|
||||
|
||||
async def async_completion(
    self,
    custom_llm_provider: str,
    provider_config: BaseConfig,
    api_base: str,
    headers: dict,
    data: dict,
    timeout: Union[float, httpx.Timeout],
    model: str,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    messages: list,
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    client: Optional[ClientSession] = None,
):
    """Async chat completion: POST the already-transformed request, then let
    the provider config turn the raw aiohttp response into a ModelResponse."""
    raw_api_response = await self._make_common_async_call(
        async_client_session=client,
        provider_config=provider_config,
        api_base=api_base,
        headers=headers,
        data=data,
        timeout=timeout,
        litellm_params=litellm_params,
        stream=False,
    )

    # Provider-specific parsing of the raw response into litellm's shape.
    return await provider_config.transform_response(  # type: ignore
        model=model,
        raw_response=raw_api_response,  # type: ignore
        model_response=model_response,
        logging_obj=logging_obj,
        api_key=api_key,
        request_data=data,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        encoding=encoding,
    )
|
||||
|
||||
def completion(
    self,
    model: str,
    messages: list,
    api_base: str,
    custom_llm_provider: str,
    model_response: ModelResponse,
    encoding,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    timeout: Union[float, httpx.Timeout],
    litellm_params: dict,
    acompletion: bool,
    stream: Optional[bool] = False,
    fake_stream: bool = False,
    api_key: Optional[str] = None,
    headers: Optional[dict] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler, ClientSession]] = None,
):
    """Entry point for chat completions routed through this handler.

    Resolves the provider config, validates/builds headers, transforms the
    request, then dispatches to one of three paths: async (returns the
    async_completion coroutine), sync streaming (returns a
    CustomStreamWrapper), or plain sync (returns the transformed response).

    Fix: the ``headers`` default was a mutable ``{}``; it is now ``None``
    (behavior-identical, since it is only ever consumed as ``headers or {}``),
    avoiding the shared-mutable-default pitfall.
    """
    provider_config = ProviderConfigManager.get_provider_chat_config(
        model=model, provider=litellm.LlmProviders(custom_llm_provider)
    )
    if provider_config is None:
        raise ValueError(
            f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
        )

    # get config from model, custom llm provider
    headers = provider_config.validate_environment(
        api_key=api_key,
        headers=headers or {},
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        api_base=api_base,
    )

    api_base = provider_config.get_complete_url(
        api_base=api_base,
        api_key=api_key,
        model=model,
        optional_params=optional_params,
        litellm_params=litellm_params,
        stream=stream,
    )

    data = provider_config.transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=headers,
    )

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key=api_key,
        additional_args={
            "complete_input_dict": data,
            "api_base": api_base,
            "headers": headers,
        },
    )

    if acompletion is True:
        # Returns a coroutine; the caller is expected to await it.
        return self.async_completion(
            custom_llm_provider=custom_llm_provider,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            data=data,
            timeout=timeout,
            model=model,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            # Only forward the client if it is the right type for async.
            client=(
                client
                if client is not None and isinstance(client, ClientSession)
                else None
            ),
        )

    if stream is True:
        if fake_stream is not True:
            # Real streaming: ask the provider API to stream.
            data["stream"] = stream
        completion_stream, headers = self.make_sync_call(
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,  # type: ignore
            data=data,
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            timeout=timeout,
            fake_stream=fake_stream,
            client=(
                client
                if client is not None and isinstance(client, HTTPHandler)
                else None
            ),
            litellm_params=litellm_params,
        )
        return CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
        )

    if client is None or not isinstance(client, HTTPHandler):
        sync_httpx_client = _get_httpx_client()
    else:
        sync_httpx_client = client

    response = self._make_common_sync_call(
        sync_httpx_client=sync_httpx_client,
        provider_config=provider_config,
        api_base=api_base,
        headers=headers,
        timeout=timeout,
        litellm_params=litellm_params,
        data=data,
    )
    return provider_config.transform_response(
        model=model,
        raw_response=response,
        model_response=model_response,
        logging_obj=logging_obj,
        api_key=api_key,
        request_data=data,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        encoding=encoding,
    )
|
||||
|
||||
def make_sync_call(
    self,
    provider_config: BaseConfig,
    api_base: str,
    headers: dict,
    data: dict,
    model: str,
    messages: list,
    logging_obj,
    litellm_params: dict,
    timeout: Union[float, httpx.Timeout],
    fake_stream: bool = False,
    client: Optional[HTTPHandler] = None,
) -> Tuple[Any, dict]:
    """Issue a (possibly fake-) streaming sync request.

    Returns (completion_stream, response_headers). With fake_stream=True the
    request is made non-streaming and the full JSON body is fed to the
    provider's response iterator, which then simulates a stream.
    """
    if client is None or not isinstance(client, HTTPHandler):
        sync_httpx_client = _get_httpx_client()
    else:
        sync_httpx_client = client
    # fake_stream means: don't actually stream over HTTP.
    stream = True
    if fake_stream is True:
        stream = False

    response = self._make_common_sync_call(
        sync_httpx_client=sync_httpx_client,
        provider_config=provider_config,
        api_base=api_base,
        headers=headers,
        data=data,
        timeout=timeout,
        litellm_params=litellm_params,
        stream=stream,
    )

    if fake_stream is True:
        # Whole body at once -> iterator over the parsed JSON.
        completion_stream = provider_config.get_model_response_iterator(
            streaming_response=response.json(), sync_stream=True
        )
    else:
        # True streaming -> iterate raw response lines.
        completion_stream = provider_config.get_model_response_iterator(
            streaming_response=response.iter_lines(), sync_stream=True
        )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return completion_stream, dict(response.headers)
|
||||
|
||||
async def async_image_variations(
    self,
    client: Optional[ClientSession],
    provider_config: BaseImageVariationConfig,
    api_base: str,
    headers: dict,
    data: HttpHandlerRequestFields,
    timeout: float,
    litellm_params: dict,
    model_response: ImageResponse,
    logging_obj: LiteLLMLoggingObj,
    api_key: str,
    model: Optional[str],
    image: FileTypes,
    optional_params: dict,
) -> ImageResponse:
    """POST an image-variation request asynchronously and transform the result.

    When the request carries file uploads, they are repackaged as aiohttp
    multipart FormData (and the JSON body is suppressed); otherwise `data`
    is sent as JSON.
    """
    # create aiohttp form data if files in data
    form_data: Optional[FormData] = None
    if "files" in data and "data" in data:
        form_data = FormData()
        # assumes each files value is a (filename, fileobj, content_type)
        # tuple, matching httpx's files convention — TODO confirm upstream.
        for k, v in data["files"].items():
            form_data.add_field(k, v[1], filename=v[0], content_type=v[2])

        for key, value in data["data"].items():
            form_data.add_field(key, value)

    _response = await self._make_common_async_call(
        async_client_session=client,
        provider_config=provider_config,
        api_base=api_base,
        headers=headers,
        # JSON body and multipart body are mutually exclusive.
        data=None if form_data is not None else cast(dict, data),
        form_data=form_data,
        timeout=timeout,
        litellm_params=litellm_params,
        stream=False,
    )

    ## LOGGING
    logging_obj.post_call(
        api_key=api_key,
        original_response=_response.text,
        additional_args={
            "headers": headers,
            "api_base": api_base,
        },
    )

    ## RESPONSE OBJECT
    return await provider_config.async_transform_response_image_variation(
        model=model,
        model_response=model_response,
        raw_response=_response,
        logging_obj=logging_obj,
        request_data=cast(dict, data),
        image=image,
        optional_params=optional_params,
        litellm_params=litellm_params,
        encoding=None,
        api_key=api_key,
    )
|
||||
|
||||
def image_variations(
    self,
    model_response: ImageResponse,
    api_key: str,
    model: Optional[str],
    image: FileTypes,
    timeout: float,
    custom_llm_provider: str,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    print_verbose: Optional[Callable] = None,
    api_base: Optional[str] = None,
    aimage_variation: bool = False,
    logger_fn=None,
    client=None,
    organization: Optional[str] = None,
    headers: Optional[dict] = None,
) -> ImageResponse:
    """Entry point for image variations (non-OpenAI providers).

    Resolves the provider's image-variation config, builds URL/headers,
    transforms the request, then either returns the async coroutine (when
    litellm_params["async_call"] is set) or performs the sync HTTP call and
    transforms the response.

    Raises:
        ValueError: if `model` is None or no provider config exists.
    """
    if model is None:
        raise ValueError("model is required for non-openai image variations")

    provider_config = ProviderConfigManager.get_provider_image_variation_config(
        model=model,  # openai defaults to dall-e-2
        provider=LlmProviders(custom_llm_provider),
    )

    if provider_config is None:
        raise ValueError(
            f"image variation provider not found: {custom_llm_provider}."
        )

    api_base = provider_config.get_complete_url(
        api_base=api_base,
        api_key=api_key,
        model=model,
        optional_params=optional_params,
        litellm_params=litellm_params,
        stream=False,
    )

    headers = provider_config.validate_environment(
        api_key=api_key,
        headers=headers or {},
        model=model,
        # Placeholder messages: validate_environment requires them, but image
        # variation has no real chat messages.
        messages=[{"role": "user", "content": "test"}],
        optional_params=optional_params,
        litellm_params=litellm_params,
        api_base=api_base,
    )

    data = provider_config.transform_request_image_variation(
        model=model,
        image=image,
        optional_params=optional_params,
        headers=headers,
    )

    ## LOGGING
    logging_obj.pre_call(
        input="",
        api_key=api_key,
        additional_args={
            "headers": headers,
            "api_base": api_base,
            # Copy so later mutations don't alter the logged snapshot.
            "complete_input_dict": data.copy(),
        },
    )

    if litellm_params.get("async_call", False):
        # Returns a coroutine despite the sync return annotation, hence the
        # type: ignore; callers on the async path are expected to await it.
        return self.async_image_variations(
            api_base=api_base,
            data=data,
            headers=headers,
            model_response=model_response,
            logging_obj=logging_obj,
            model=model,
            timeout=timeout,
            client=client,
            optional_params=optional_params,
            litellm_params=litellm_params,
            image=image,
            provider_config=provider_config,
        )  # type: ignore

    if client is None or not isinstance(client, HTTPHandler):
        sync_httpx_client = _get_httpx_client()
    else:
        sync_httpx_client = client

    response = self._make_common_sync_call(
        sync_httpx_client=sync_httpx_client,
        provider_config=provider_config,
        api_base=api_base,
        headers=headers,
        timeout=timeout,
        litellm_params=litellm_params,
        stream=False,
        # HttpHandlerRequestFields splits into the handler's kwargs here.
        data=data.get("data") or {},
        files=data.get("files"),
        content=data.get("content"),
        params=data.get("params"),
    )

    ## LOGGING
    logging_obj.post_call(
        api_key=api_key,
        original_response=response.text,
        additional_args={
            "headers": headers,
            "api_base": api_base,
        },
    )

    ## RESPONSE OBJECT
    return provider_config.transform_response_image_variation(
        model=model,
        model_response=model_response,
        raw_response=response,
        logging_obj=logging_obj,
        request_data=cast(dict, data),
        image=image,
        optional_params=optional_params,
        litellm_params=litellm_params,
        encoding=None,
        api_key=api_key,
    )
|
||||
|
||||
def _handle_error(self, e: Exception, provider_config: BaseConfig):
    """Normalize any client exception into the provider's error class and raise it.

    Always raises; callers use the `raise self._handle_error(...)` idiom so
    static analysis sees the call site as terminating too.
    """
    status_code = getattr(e, "status_code", 500)
    message = getattr(e, "text", str(e))
    response_headers = getattr(e, "headers", None)

    # httpx-style exceptions nest the interesting bits under `.response`.
    attached_response = getattr(e, "response", None)
    if attached_response:
        if response_headers is None:
            response_headers = getattr(attached_response, "headers", None)
        message = getattr(attached_response, "text", message)

    raise provider_config.get_error_class(
        error_message=message,
        status_code=status_code,
        headers=dict(response_headers) if response_headers else {},
    )
|
||||
@@ -0,0 +1,388 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import ssl
|
||||
import typing
|
||||
import urllib.request
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import aiohttp.client_exceptions
|
||||
import aiohttp.http_exceptions
|
||||
import httpx
|
||||
from aiohttp.client import ClientResponse, ClientSession
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
# Mapping from aiohttp exception types to their httpx equivalents, consumed
# by map_aiohttp_exceptions() below (which picks the most specific match).
# Fix: aiohttp.ClientPayloadError appeared twice in this literal with the
# same value; a dict keeps only one entry, so the duplicate was dead code
# and has been removed.
AIOHTTP_EXC_MAP: Dict = {
    # Order matters here, most specific exception first
    # Timeout related exceptions
    asyncio.TimeoutError: httpx.TimeoutException,
    aiohttp.ServerTimeoutError: httpx.TimeoutException,
    aiohttp.ConnectionTimeoutError: httpx.ConnectTimeout,
    aiohttp.SocketTimeoutError: httpx.ReadTimeout,
    # Proxy related exceptions
    aiohttp.ClientProxyConnectionError: httpx.ProxyError,
    # SSL related exceptions
    aiohttp.ClientConnectorCertificateError: httpx.ProtocolError,
    aiohttp.ClientSSLError: httpx.ProtocolError,
    aiohttp.ServerFingerprintMismatch: httpx.ProtocolError,
    # Network related exceptions
    aiohttp.ClientConnectorError: httpx.ConnectError,
    aiohttp.ClientOSError: httpx.ConnectError,
    aiohttp.ClientPayloadError: httpx.ReadError,
    # Connection disconnection exceptions
    aiohttp.ServerDisconnectedError: httpx.ReadError,
    # Response related exceptions
    aiohttp.ClientConnectionError: httpx.NetworkError,
    aiohttp.ContentTypeError: httpx.ReadError,
    aiohttp.TooManyRedirects: httpx.TooManyRedirects,
    # URL related exceptions
    aiohttp.InvalidURL: httpx.InvalidURL,
    # Base exceptions
    aiohttp.ClientError: httpx.RequestError,
}

# Add client_exceptions module exceptions
try:
    import aiohttp.client_exceptions

    AIOHTTP_EXC_MAP[aiohttp.client_exceptions.ClientPayloadError] = httpx.ReadError
except ImportError:
    pass
|
||||
|
||||
|
||||
@contextlib.contextmanager
def map_aiohttp_exceptions() -> typing.Iterator[None]:
    """Context manager translating aiohttp exceptions into httpx equivalents.

    Scans AIOHTTP_EXC_MAP for every type the raised exception matches and
    re-raises as the most specific httpx class found, chaining the original.
    Unmapped exceptions propagate unchanged.
    """
    try:
        yield
    except Exception as exc:
        best_match = None
        for aiohttp_exc_type, httpx_exc_type in AIOHTTP_EXC_MAP.items():
            if not isinstance(exc, aiohttp_exc_type):  # type: ignore
                continue
            # Prefer the most specific httpx class among all matches.
            if best_match is None or issubclass(httpx_exc_type, best_match):
                best_match = httpx_exc_type

        if best_match is None:  # pragma: no cover
            raise

        raise best_match(str(exc)) from exc
|
||||
|
||||
|
||||
class AiohttpResponseStream(httpx.AsyncByteStream):
    """Adapts an aiohttp ClientResponse body into an httpx AsyncByteStream,
    treating several mid-stream disconnect errors as a graceful end-of-stream."""

    # Bytes pulled from the aiohttp content reader per iteration.
    CHUNK_SIZE = 1024 * 16

    def __init__(self, aiohttp_response: ClientResponse) -> None:
        self._aiohttp_response = aiohttp_response

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        try:
            async for chunk in self._aiohttp_response.content.iter_chunked(
                self.CHUNK_SIZE
            ):
                yield chunk
        except (
            aiohttp.ClientPayloadError,
            aiohttp.client_exceptions.ClientPayloadError,
        ) as e:
            # Handle incomplete transfers more gracefully
            # Log the error but don't re-raise if we've already yielded some data
            verbose_logger.debug(f"Transfer incomplete, but continuing: {e}")
            # If the error is due to incomplete transfer encoding, we can still
            # return what we've received so far, similar to how httpx handles it
            return
        except RuntimeError as e:
            # Some providers (e.g., SSE streams) may close the connection
            # causing aiohttp StreamReader to raise a generic RuntimeError
            # with message "Connection closed.". Treat this as a graceful
            # end-of-stream so downstream consumers don't error.
            if "Connection closed" in str(e):
                verbose_logger.debug(
                    "Upstream closed streaming connection; ending iterator gracefully"
                )
                return
            # Any other RuntimeError is a real bug: propagate.
            raise
        except aiohttp.http_exceptions.TransferEncodingError as e:
            # Handle transfer encoding errors gracefully
            verbose_logger.debug(f"Transfer encoding error, but continuing: {e}")
            return
        except Exception:
            # For other exceptions, use the normal mapping
            # (re-raised inside the context manager so it gets translated
            # into the corresponding httpx exception type).
            with map_aiohttp_exceptions():
                raise

    async def aclose(self) -> None:
        """Release the underlying aiohttp response, mapping close-time errors."""
        with map_aiohttp_exceptions():
            await self._aiohttp_response.__aexit__(None, None, None)
|
||||
|
||||
|
||||
class AiohttpTransport(httpx.AsyncBaseTransport):
    """httpx transport backed by an aiohttp ClientSession (or a session factory)."""

    def __init__(
        self,
        client: Union[ClientSession, Callable[[], ClientSession]],
        owns_session: bool = True,
    ) -> None:
        self.client = client
        # When True, aclose() is responsible for closing the session.
        self._owns_session = owns_session

        #########################################################
        # Class variables for proxy settings
        #########################################################
        self.proxy_cache: Dict[str, Optional[str]] = {}

    async def aclose(self) -> None:
        """Close the underlying session if we own it; factories are never closed."""
        if not self._owns_session:
            return
        if isinstance(self.client, ClientSession):
            await self.client.close()
|
||||
|
||||
|
||||
class LiteLLMAiohttpTransport(AiohttpTransport):
|
||||
"""
|
||||
LiteLLM wrapper around AiohttpTransport to handle %-encodings in URLs
|
||||
and event loop lifecycle issues in CI/CD environments
|
||||
|
||||
Credit to: https://github.com/karpetrosyan/httpx-aiohttp for this implementation
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    client: Union[ClientSession, Callable[[], ClientSession]],
    ssl_verify: Optional[Union[bool, ssl.SSLContext]] = None,
    owns_session: bool = True,
):
    """Wrap a session (or session factory), remembering SSL overrides.

    ssl_verify is kept so individual requests can override the session's
    default SSL behavior.
    """
    self.client = client
    self._ssl_verify = ssl_verify  # Store for per-request SSL override
    super().__init__(client=client, owns_session=owns_session)
    # Remember the factory so a dead session can be rebuilt when the
    # original event loop has gone away.
    if callable(client):
        self._client_factory = client
||||
|
||||
def _get_valid_client_session(self) -> ClientSession:
    """
    Helper to get a valid ClientSession for the current event loop.

    This handles the case where the session was created in a different
    event loop that may have been closed (common in CI/CD environments).

    Fix: the factory-or-default session construction was repeated verbatim
    in four places; it is now a single local helper.
    """
    from aiohttp.client import ClientSession

    def _build_session() -> ClientSession:
        # Prefer the stored factory; fall back to a default session.
        factory = getattr(self, "_client_factory", None)
        if callable(factory):
            return factory()
        return ClientSession()

    # If we don't have a client or it's not a ClientSession, create one.
    if not isinstance(self.client, ClientSession):
        self.client = _build_session()
        # Don't return yet - check if the newly created session is valid.

    # Check if the session itself is closed.
    if self.client.closed:
        verbose_logger.debug("Session is closed, creating new session")
        self.client = _build_session()
        return self.client

    # Check if the existing session is still valid for the current event loop.
    try:
        session_loop = getattr(self.client, "_loop", None)
        current_loop = asyncio.get_running_loop()

        # If session is from a different or closed loop, recreate it.
        if (
            session_loop is None
            or session_loop != current_loop
            or session_loop.is_closed()
        ):
            # Close old session to prevent leaks.
            old_session = self.client
            try:
                if not old_session.closed:
                    try:
                        asyncio.create_task(old_session.close())
                    except RuntimeError:
                        # Different event loop - can't schedule task, rely on GC.
                        verbose_logger.debug(
                            "Old session from different loop, relying on GC"
                        )
            except Exception as e:
                verbose_logger.debug(f"Error closing old session: {e}")

            # Create a new session in the current event loop.
            self.client = _build_session()

    except (RuntimeError, AttributeError):
        # If we can't check the loop or session is invalid, recreate it.
        self.client = _build_session()

    return self.client
|
||||
|
||||
async def _make_aiohttp_request(
    self,
    client_session: ClientSession,
    request: httpx.Request,
    timeout: dict,
    proxy: Optional[str],
    sni_hostname: Optional[str],
    ssl_verify: Optional[Union[bool, ssl.SSLContext]] = None,
) -> ClientResponse:
    """
    Helper function to make an aiohttp request with the given parameters.

    Args:
        client_session: The aiohttp ClientSession to use
        request: The httpx Request to send
        timeout: Timeout settings dict with 'connect', 'read', 'pool' keys
        proxy: Optional proxy URL
        sni_hostname: Optional SNI hostname for SSL
        ssl_verify: Optional SSL verification setting (False to disable, SSLContext for custom)

    Returns:
        ClientResponse from aiohttp
    """
    from aiohttp import ClientTimeout
    from yarl import URL as YarlURL

    # Prefer the fully-read request body; if httpx hasn't read the content
    # yet, hand aiohttp the lazy stream instead and drop the
    # transfer-encoding header (aiohttp manages chunking itself).
    try:
        data = request.content
    except httpx.RequestNotRead:
        data = request.stream  # type: ignore
        request.headers.pop("transfer-encoding", None)  # handled by aiohttp

    # Only pass ssl kwarg when explicitly configured, to avoid
    # overriding the session/connector defaults with None (which is
    # not a valid value for aiohttp's ssl parameter).
    request_kwargs: Dict[str, Any] = {
        "method": request.method,
        # encoded=True keeps httpx's already percent-encoded URL intact
        # (prevents yarl from re-quoting it).
        "url": YarlURL(str(request.url), encoded=True),
        "headers": request.headers,
        "data": data,
        # Redirects and decompression are left to the httpx layer above.
        "allow_redirects": False,
        "auto_decompress": False,
        # Map httpx timeout keys onto aiohttp's ClientTimeout fields.
        # NOTE(review): httpx's 'pool' timeout is mapped to aiohttp's
        # 'connect' (queueing + connection) — confirm this is intentional.
        "timeout": ClientTimeout(
            sock_connect=timeout.get("connect"),
            sock_read=timeout.get("read"),
            connect=timeout.get("pool"),
        ),
        "proxy": proxy,
        "server_hostname": sni_hostname,
    }
    if ssl_verify is not None:
        request_kwargs["ssl"] = ssl_verify

    # Enter the request context manager manually: the response body is
    # streamed by the caller, so the response must outlive this helper
    # and is closed later by the response-stream wrapper.
    response = await client_session.request(**request_kwargs).__aenter__()

    return response
|
||||
|
||||
async def handle_async_request(
    self,
    request: httpx.Request,
) -> httpx.Response:
    """
    Transport entry point: send the httpx request over aiohttp.

    Validates the cached ClientSession against the current event loop,
    resolves proxy settings, performs the request, and retries exactly
    once with a brand-new session if the session turned out to be closed
    between validation and use.
    """
    timeout = request.extensions.get("timeout", {})
    sni_hostname = request.extensions.get("sni_hostname")

    # Use helper to ensure we have a valid session for the current event loop
    client_session = self._get_valid_client_session()

    # Resolve proxy settings from environment variables
    proxy = await self._get_proxy_settings(request)

    # Use stored SSL configuration for per-request override
    ssl_config = self._ssl_verify

    try:
        with map_aiohttp_exceptions():
            response = await self._make_aiohttp_request(
                client_session=client_session,
                request=request,
                timeout=timeout,
                proxy=proxy,
                sni_hostname=sni_hostname,
                ssl_verify=ssl_config,
            )
    except RuntimeError as e:
        # Handle the case where session was closed between our check and actual use
        if "Session is closed" in str(e):
            verbose_logger.debug(
                f"Session closed during request, retrying with new session: {e}"
            )
            # Force creation of a new session
            if hasattr(self, "_client_factory") and callable(self._client_factory):
                self.client = self._client_factory()
            else:
                self.client = ClientSession()
            client_session = self.client

            # Retry the request with the new session
            with map_aiohttp_exceptions():
                response = await self._make_aiohttp_request(
                    client_session=client_session,
                    request=request,
                    timeout=timeout,
                    proxy=proxy,
                    sni_hostname=sni_hostname,
                    ssl_verify=ssl_config,
                )
        else:
            # Re-raise if it's a different RuntimeError
            raise

    # Wrap the aiohttp response into httpx's response type; the body is
    # streamed lazily through AiohttpResponseStream.
    return httpx.Response(
        status_code=response.status,
        headers=response.headers,
        stream=AiohttpResponseStream(response),
        request=request,
    )
|
||||
|
||||
async def _get_proxy_settings(self, request: httpx.Request):
    """
    Resolve a proxy URL for this request from the environment.

    Returns None when trust-env is disabled (via litellm setting or the
    DISABLE_AIOHTTP_TRUST_ENV env var) or when proxy resolution fails.
    """
    trust_env_disabled = litellm.disable_aiohttp_trust_env or str_to_bool(
        os.getenv("DISABLE_AIOHTTP_TRUST_ENV", "False")
    )
    if trust_env_disabled:
        return None

    try:
        return self._proxy_from_env(request.url)
    except Exception as e:  # pragma: no cover - best effort
        verbose_logger.debug(f"Error reading proxy env: {e}")
    return None
|
||||
|
||||
def _proxy_from_env(self, url: httpx.URL) -> typing.Optional[str]:
    """
    Return the proxy URL from the environment for the given request URL.

    Reading the proxy env settings is a costly operation CPU-wise, so
    results are memoized in ``self.proxy_cache``.

    Bug fix: the cache key now includes the URL scheme. Previously the
    cache was keyed on host only, so an http and https request to the
    same host shared one cached proxy even though ``getproxies()`` can
    return different proxies per scheme.

    Returns:
        The proxy URL to use, or None when no proxy applies (including
        hosts matched by the no_proxy / proxy-bypass rules).
    """
    #########################################################
    # Check if we've already resolved a proxy for this URL.
    # Key on (scheme, host): proxies.get(url.scheme) below is
    # scheme-dependent, so host alone is not a safe cache key.
    #########################################################
    proxy_cache_key = (url.scheme, url.host)

    if proxy_cache_key in self.proxy_cache:
        return self.proxy_cache[proxy_cache_key]

    proxies = urllib.request.getproxies()
    if urllib.request.proxy_bypass(url.host):
        # Host matched no_proxy / bypass rules — cache the negative result too.
        proxy_url = None
    else:
        proxy = proxies.get(url.scheme) or proxies.get("all")
        # getproxies() may return bare host:port entries; normalize to a URL.
        if proxy and "://" not in proxy:
            proxy = f"http://{proxy}"
        proxy_url = proxy

    self.proxy_cache[proxy_cache_key] = proxy_url

    return proxy_url
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Utility functions for cleaning up async HTTP clients to prevent resource leaks.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
|
||||
async def close_litellm_async_clients():
    """
    Close all cached async HTTP clients to prevent resource leaks.

    This function iterates through all cached clients in litellm's in-memory cache
    and closes any aiohttp client sessions that are still open. Also closes the
    global base_llm_aiohttp_handler instance (issue #12443).

    All close errors are deliberately swallowed: this is best-effort
    cleanup, typically running at interpreter exit.
    """
    # Import here to avoid circular import
    import litellm
    from litellm.llms.custom_httpx.aiohttp_handler import BaseLLMAIOHTTPHandler

    # The cache holds heterogeneous objects (aiohttp handlers, httpx-based
    # handlers, raw clients); each branch below handles one shape.
    cache_dict = getattr(litellm.in_memory_llm_clients_cache, "cache_dict", {})

    for key, handler in cache_dict.items():
        # Handle BaseLLMAIOHTTPHandler instances (aiohttp_openai provider)
        if isinstance(handler, BaseLLMAIOHTTPHandler) and hasattr(handler, "close"):
            try:
                await handler.close()
            except Exception:
                # Silently ignore errors during cleanup
                pass

        # Handle AsyncHTTPHandler instances (used by Gemini and other providers)
        elif hasattr(handler, "client"):
            client = handler.client
            # Check if the httpx client has an aiohttp transport; close the
            # transport first so its sessions/connectors are released.
            if hasattr(client, "_transport") and hasattr(client._transport, "aclose"):
                try:
                    await client._transport.aclose()
                except Exception:
                    # Silently ignore errors during cleanup
                    pass
            # Also close the httpx client itself
            if hasattr(client, "aclose") and not client.is_closed:
                try:
                    await client.aclose()
                except Exception:
                    # Silently ignore errors during cleanup
                    pass

        # Handle any other objects with aclose method
        elif hasattr(handler, "aclose"):
            try:
                await handler.aclose()
            except Exception:
                # Silently ignore errors during cleanup
                pass

    # Close the global base_llm_aiohttp_handler instance (issue #12443)
    # This is used by Gemini and other providers that use aiohttp
    if hasattr(litellm, "base_llm_aiohttp_handler"):
        base_handler = getattr(litellm, "base_llm_aiohttp_handler", None)
        if isinstance(base_handler, BaseLLMAIOHTTPHandler) and hasattr(
            base_handler, "close"
        ):
            try:
                await base_handler.close()
            except Exception:
                # Silently ignore errors during cleanup
                pass
|
||||
|
||||
|
||||
def register_async_client_cleanup():
    """
    Register the async client cleanup function to run at exit.

    This ensures that all async HTTP clients are properly closed when the
    program exits.
    """
    import atexit

    def _run_cleanup_in_fresh_loop():
        # At exit time the main event loop is often already closed, and
        # get_event_loop() may be unavailable — so always build a brand-new
        # loop for the cleanup coroutine (fixes issue #12443).
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(close_litellm_async_clients())
        finally:
            # Clean up the loop we created
            loop.close()

    def cleanup_wrapper():
        """Exit hook: run cleanup, never letting errors escape atexit."""
        try:
            _run_cleanup_in_fresh_loop()
        except Exception:
            # Silently ignore errors during cleanup to avoid exit handler failures
            pass

    atexit.register(cleanup_wrapper)
|
||||
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Generic container file handler for LiteLLM.
|
||||
|
||||
This module provides a single generic handler that can process any container file
|
||||
endpoint defined in endpoints.json, eliminating the need for individual handler methods.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Type, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.containers.main import (
|
||||
ContainerFileListResponse,
|
||||
ContainerFileObject,
|
||||
DeleteContainerFileResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
|
||||
|
||||
|
||||
# Response type mapping
|
||||
RESPONSE_TYPES: Dict[str, Type] = {
|
||||
"ContainerFileListResponse": ContainerFileListResponse,
|
||||
"ContainerFileObject": ContainerFileObject,
|
||||
"DeleteContainerFileResponse": DeleteContainerFileResponse,
|
||||
}
|
||||
|
||||
|
||||
def _load_endpoints_config() -> Dict:
    """Load the endpoints configuration from JSON file."""
    # endpoints.json lives three package levels up, under containers/.
    config_path = Path(__file__).parent.parent.parent / "containers" / "endpoints.json"
    return json.loads(config_path.read_text())
|
||||
|
||||
|
||||
def _get_endpoint_config(endpoint_name: str) -> Optional[Dict]:
    """Get config for a specific endpoint, matched by its sync or async name."""
    endpoints = _load_endpoints_config()["endpoints"]
    return next(
        (
            entry
            for entry in endpoints
            if endpoint_name in (entry["name"], entry["async_name"])
        ),
        None,
    )
|
||||
|
||||
|
||||
def _build_url(
|
||||
api_base: str,
|
||||
path_template: str,
|
||||
path_params: Dict[str, str],
|
||||
) -> str:
|
||||
"""Build the full URL by substituting path parameters.
|
||||
|
||||
The api_base from get_complete_url already includes /containers,
|
||||
so we need to strip that prefix from the path_template.
|
||||
"""
|
||||
# api_base ends with /containers, path_template starts with /containers
|
||||
# So we need to strip /containers from the path
|
||||
if path_template.startswith("/containers"):
|
||||
path_template = path_template[len("/containers") :]
|
||||
|
||||
url = f"{api_base.rstrip('/')}{path_template}"
|
||||
for param, value in path_params.items():
|
||||
url = url.replace(f"{{{param}}}", value)
|
||||
return url
|
||||
|
||||
|
||||
def _build_query_params(
|
||||
query_param_names: list,
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Dict[str, str]:
|
||||
"""Build query parameters from kwargs."""
|
||||
params = {}
|
||||
for param_name in query_param_names:
|
||||
value = kwargs.get(param_name)
|
||||
if value is not None:
|
||||
params[param_name] = str(value) if not isinstance(value, str) else value
|
||||
return params
|
||||
|
||||
|
||||
def _prepare_multipart_file_upload(
    file: Any,
    headers: Dict[str, Any],
) -> tuple:
    """
    Prepare file and headers for multipart upload.

    Returns:
        Tuple of (files_dict, headers_without_content_type)
    """
    from litellm.litellm_core_utils.prompt_templates.common_utils import (
        extract_file_data,
    )

    file_data = extract_file_data(file)
    files = {
        "file": (
            file_data.get("filename") or "file",
            file_data.get("content") or b"",
            file_data.get("content_type") or "application/octet-stream",
        )
    }

    # httpx sets the multipart content-type (with boundary) itself, so any
    # preset content-type header must be removed from the outgoing headers.
    sanitized_headers = headers.copy()
    for header_name in ("content-type", "Content-Type"):
        sanitized_headers.pop(header_name, None)

    return files, sanitized_headers
|
||||
|
||||
|
||||
class GenericContainerHandler:
    """
    Generic handler for container file API endpoints.

    This single handler can process any endpoint defined in endpoints.json,
    eliminating the need for individual handler methods per endpoint.
    """

    def handle(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        timeout: Union[float, httpx.Timeout] = 600,
        _is_async: bool = False,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        **kwargs,
    ) -> Union[Any, Coroutine[Any, Any, Any]]:
        """
        Generic handler for any container file endpoint.

        Dispatches to the sync or async implementation; when _is_async is
        True the returned value is a coroutine the caller must await.

        Args:
            endpoint_name: Name of the endpoint (e.g., "list_container_files")
            container_provider_config: Provider-specific configuration
            litellm_params: LiteLLM parameters including api_key, api_base
            logging_obj: Logging object for request logging
            extra_headers: Additional HTTP headers
            extra_query: Additional query parameters
            timeout: Request timeout
            _is_async: Whether to make async request
            client: Optional HTTP client
            **kwargs: Path params and query params (e.g., container_id, file_id, after, limit)
        """
        if _is_async:
            return self._async_handle(
                endpoint_name=endpoint_name,
                container_provider_config=container_provider_config,
                litellm_params=litellm_params,
                logging_obj=logging_obj,
                extra_headers=extra_headers,
                extra_query=extra_query,
                timeout=timeout,
                client=client,
                **kwargs,
            )

        return self._sync_handle(
            endpoint_name=endpoint_name,
            container_provider_config=container_provider_config,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            extra_headers=extra_headers,
            extra_query=extra_query,
            timeout=timeout,
            client=client,
            **kwargs,
        )

    def _sync_handle(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        timeout: Union[float, httpx.Timeout] = 600,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        **kwargs,
    ) -> Any:
        """Synchronous request handler.

        Resolves the endpoint definition from endpoints.json, builds URL /
        headers / query params, issues the HTTP call, and parses the
        response into the endpoint's declared response type.

        NOTE(review): the ``timeout`` parameter is accepted but not passed
        to any HTTP call below — confirm whether the project HTTP client
        applies a default or whether this should be forwarded.
        """
        endpoint_config = _get_endpoint_config(endpoint_name)
        if not endpoint_config:
            raise ValueError(f"Unknown endpoint: {endpoint_name}")

        # Get HTTP client (fall back to a fresh one if the caller passed
        # none, or passed an async client to the sync path)
        if client is None or not isinstance(client, HTTPHandler):
            http_client = _get_httpx_client(
                params={"ssl_verify": litellm_params.get("ssl_verify", None)}
            )
        else:
            http_client = client

        # Build request headers (auth from provider config, then caller extras)
        headers = container_provider_config.validate_environment(
            headers=extra_headers or {},
            api_key=litellm_params.get("api_key", None),
        )
        if extra_headers:
            headers.update(extra_headers)

        api_base = container_provider_config.get_complete_url(
            api_base=litellm_params.get("api_base", None),
            litellm_params=dict(litellm_params),
        )

        # Build URL with path params (missing params substitute as "")
        path_params = {
            p: kwargs.get(p, "") for p in endpoint_config.get("path_params", [])
        }
        url = _build_url(api_base, endpoint_config["path"], path_params)

        # Build query params
        query_params = _build_query_params(
            endpoint_config.get("query_params", []), kwargs
        )
        if extra_query:
            query_params.update(extra_query)

        # Log request
        logging_obj.pre_call(
            input="",
            api_key="",
            additional_args={
                "api_base": url,
                "headers": headers,
                "params": query_params,
            },
        )

        # Make request
        method = endpoint_config["method"].upper()
        returns_binary = endpoint_config.get("returns_binary", False)
        is_multipart = endpoint_config.get("is_multipart", False)

        try:
            if method == "GET":
                response = http_client.get(
                    url=url, headers=headers, params=query_params
                )
            elif method == "DELETE":
                response = http_client.delete(
                    url=url, headers=headers, params=query_params
                )
            elif method == "POST":
                if is_multipart and "file" in kwargs:
                    # Multipart upload: content-type header is stripped so
                    # httpx can set the boundary itself.
                    files, headers = _prepare_multipart_file_upload(
                        kwargs["file"], headers
                    )
                    response = http_client.post(
                        url=url, headers=headers, params=query_params, files=files
                    )
                else:
                    response = http_client.post(
                        url=url, headers=headers, params=query_params
                    )
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            # For binary responses, return raw content
            if returns_binary:
                return response.content

            # Check for error response (providers may return 200 with an
            # "error" payload)
            response_json = response.json()
            if "error" in response_json:
                from litellm.llms.base_llm.chat.transformation import BaseLLMException

                error_msg = response_json.get("error", {}).get(
                    "message", str(response_json)
                )
                raise BaseLLMException(
                    status_code=response.status_code,
                    message=error_msg,
                    headers=dict(response.headers),
                )

            # Parse response into the endpoint's declared type, falling
            # back to the raw dict for unknown types
            response_type = RESPONSE_TYPES.get(endpoint_config["response_type"])
            if response_type:
                return response_type(**response_json)
            return response_json

        except Exception as e:
            # NOTE(review): re-raise-as-is; kept for symmetry with the
            # async path, but adds nothing over letting the exception
            # propagate.
            raise e

    async def _async_handle(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        timeout: Union[float, httpx.Timeout] = 600,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        **kwargs,
    ) -> Any:
        """Asynchronous request handler.

        Mirrors ``_sync_handle`` step for step, using the shared async
        httpx client instead of a sync one.

        NOTE(review): ``timeout`` is likewise unused here — see the note
        on ``_sync_handle``.
        """
        endpoint_config = _get_endpoint_config(endpoint_name)
        if not endpoint_config:
            raise ValueError(f"Unknown endpoint: {endpoint_name}")

        # Get HTTP client (fall back to the shared async client if the
        # caller passed none, or passed a sync client to the async path)
        if client is None or not isinstance(client, AsyncHTTPHandler):
            http_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.OPENAI,
                params={"ssl_verify": litellm_params.get("ssl_verify", None)},
            )
        else:
            http_client = client

        # Build request headers (auth from provider config, then caller extras)
        headers = container_provider_config.validate_environment(
            headers=extra_headers or {},
            api_key=litellm_params.get("api_key", None),
        )
        if extra_headers:
            headers.update(extra_headers)

        api_base = container_provider_config.get_complete_url(
            api_base=litellm_params.get("api_base", None),
            litellm_params=dict(litellm_params),
        )

        # Build URL with path params (missing params substitute as "")
        path_params = {
            p: kwargs.get(p, "") for p in endpoint_config.get("path_params", [])
        }
        url = _build_url(api_base, endpoint_config["path"], path_params)

        # Build query params
        query_params = _build_query_params(
            endpoint_config.get("query_params", []), kwargs
        )
        if extra_query:
            query_params.update(extra_query)

        # Log request
        logging_obj.pre_call(
            input="",
            api_key="",
            additional_args={
                "api_base": url,
                "headers": headers,
                "params": query_params,
            },
        )

        # Make request
        method = endpoint_config["method"].upper()
        returns_binary = endpoint_config.get("returns_binary", False)
        is_multipart = endpoint_config.get("is_multipart", False)

        try:
            if method == "GET":
                response = await http_client.get(
                    url=url, headers=headers, params=query_params
                )
            elif method == "DELETE":
                response = await http_client.delete(
                    url=url, headers=headers, params=query_params
                )
            elif method == "POST":
                if is_multipart and "file" in kwargs:
                    # Multipart upload: content-type header is stripped so
                    # httpx can set the boundary itself.
                    files, headers = _prepare_multipart_file_upload(
                        kwargs["file"], headers
                    )
                    response = await http_client.post(
                        url=url, headers=headers, params=query_params, files=files
                    )
                else:
                    response = await http_client.post(
                        url=url, headers=headers, params=query_params
                    )
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            # For binary responses, return raw content
            if returns_binary:
                return response.content

            # Check for error response (providers may return 200 with an
            # "error" payload)
            response_json = response.json()
            if "error" in response_json:
                from litellm.llms.base_llm.chat.transformation import BaseLLMException

                error_msg = response_json.get("error", {}).get(
                    "message", str(response_json)
                )
                raise BaseLLMException(
                    status_code=response.status_code,
                    message=error_msg,
                    headers=dict(response.headers),
                )

            # Parse response into the endpoint's declared type, falling
            # back to the raw dict for unknown types
            response_type = RESPONSE_TYPES.get(endpoint_config["response_type"])
            if response_type:
                return response_type(**response_json)
            return response_json

        except Exception as e:
            # NOTE(review): re-raise-as-is; see the matching note in
            # ``_sync_handle``.
            raise e


# Singleton instance shared by the module's public entry points
generic_container_handler = GenericContainerHandler()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
try:
|
||||
from litellm._version import version
|
||||
except Exception:
|
||||
version = "0.0.0"
|
||||
|
||||
|
||||
def get_default_headers() -> dict:
    """
    Get default headers for HTTP requests.

    - Default: `User-Agent: litellm/{version}`
    - Override: set `LITELLM_USER_AGENT` to fully override the header value.
    """
    user_agent = os.environ.get("LITELLM_USER_AGENT")
    if user_agent is None:
        # Env override absent (note: an empty string still counts as an
        # explicit override) — fall back to the versioned default.
        user_agent = f"litellm/{version}"
    return {"User-Agent": user_agent}
|
||||
|
||||
|
||||
class HTTPHandler:
    """
    Thin wrapper around a pooled ``httpx.AsyncClient``.

    NOTE(review): despite the synchronous-sounding name, every request
    method here is a coroutine — callers must ``await`` them. The name is
    kept for backward compatibility with existing callers.
    """

    def __init__(self, concurrent_limit=1000):
        # Default headers (User-Agent) applied to every request.
        headers = get_default_headers()
        # Create a client with a connection pool sized by concurrent_limit.
        self.client = httpx.AsyncClient(
            limits=httpx.Limits(
                max_connections=concurrent_limit,
                max_keepalive_connections=concurrent_limit,
            ),
            headers=headers,
        )

    async def close(self):
        """Close the underlying client and release all pooled connections."""
        await self.client.aclose()

    async def get(
        self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
    ):
        """Issue a GET request and return the raw httpx response."""
        return await self.client.get(url, params=params, headers=headers)

    async def post(
        self,
        url: str,
        data: Optional[Union[dict, str]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ):
        """Issue a POST request and return the raw httpx response.

        The previous ``try/except Exception as e: raise e`` wrapper was
        removed: it re-raised the exception unchanged and only added a
        redundant traceback frame.
        """
        return await self.client.post(
            url, data=data, params=params, headers=headers  # type: ignore
        )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Mock httpx transport that returns valid OpenAI ChatCompletion responses.
|
||||
|
||||
Activated via `litellm_settings: { network_mock: true }`.
|
||||
Intercepts at the httpx transport layer — the lowest point before bytes hit the wire —
|
||||
so the full proxy -> router -> OpenAI SDK -> httpx path is exercised.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-built response templates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_id() -> str:
|
||||
return f"chatcmpl-mock-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def _chat_completion_json(model: str) -> dict:
    """Return a minimal valid ChatCompletion object."""
    # Build the pieces separately so each section reads at a glance.
    message = {
        "role": "assistant",
        "content": "Mock response",
    }
    choice = {
        "index": 0,
        "message": message,
        "finish_reason": "stop",
    }
    usage = {
        "prompt_tokens": 1,
        "completion_tokens": 1,
        "total_tokens": 2,
    }
    return {
        "id": _mock_id(),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
        "usage": usage,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transport
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Minimal response headers for the canned JSON bodies returned by the
# mock transport below.
_JSON_HEADERS = {
    "content-type": "application/json",
}
|
||||
|
||||
|
||||
class MockOpenAITransport(httpx.AsyncBaseTransport, httpx.BaseTransport):
    """
    httpx transport that returns canned OpenAI ChatCompletion responses.

    Supports both async (AsyncOpenAI) and sync (OpenAI) SDK paths; both
    entry points delegate to one shared response builder.
    """

    @staticmethod
    def _parse_request(request: httpx.Request) -> Tuple[str, bool]:
        """Extract model from the request body."""
        try:
            payload = json.loads(request.content)
        except (json.JSONDecodeError, ValueError):
            # Unparseable body — fall back to a placeholder model name.
            return ("mock-model", False)
        return (payload.get("model", "mock-model"), False)

    def _canned_response(self, request: httpx.Request) -> httpx.Response:
        """Build the mock 200 ChatCompletion response for a request."""
        model, _ = self._parse_request(request)
        body = json.dumps(_chat_completion_json(model)).encode()
        return httpx.Response(
            status_code=200,
            headers=_JSON_HEADERS,
            content=body,
        )

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        return self._canned_response(request)

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        return self._canned_response(request)
|
||||
Reference in New Issue
Block a user