chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,722 @@
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union, cast
import aiohttp
import httpx # type: ignore
from aiohttp import ClientSession, FormData
import litellm
import litellm.litellm_core_utils
import litellm.types
import litellm.types.utils
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.image_variations.transformation import (
BaseImageVariationConfig,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
)
from litellm.llms.custom_httpx.aiohttp_transport import LiteLLMAiohttpTransport
from litellm.types.llms.openai import FileTypes
from litellm.types.utils import HttpHandlerRequestFields, ImageResponse, LlmProviders
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
DEFAULT_TIMEOUT = 600
class BaseLLMAIOHTTPHandler:
    """
    aiohttp-backed handler for provider-agnostic LLM HTTP calls.

    Holds an (optionally borrowed) aiohttp ClientSession, transport and
    connector, and exposes chat-completion and image-variation entry points
    that route request/response transformation through a provider config.
    ``_owns_*`` flags record which resources were created by this instance
    so ``close()`` / ``__del__`` only tear down what this instance created.
    """

    def __init__(
        self,
        client_session: Optional[aiohttp.ClientSession] = None,
        transport: Optional[LiteLLMAiohttpTransport] = None,
        connector: Optional[aiohttp.BaseConnector] = None,
    ):
        self.client_session = client_session
        self._owns_session = (
            client_session is None
        )  # Track if we own the session for cleanup
        self.transport = transport
        self._owns_transport = (
            transport is None
        )  # Track if we own the transport for cleanup
        self.connector = connector
        self._owns_connector = (
            connector is None
        )  # Track if we own the connector for cleanup

    def _get_or_create_transport(self) -> Optional[LiteLLMAiohttpTransport]:
        """Get existing transport or create a new one if needed."""
        if self.transport:
            return self.transport
        # Create a transport using AsyncHTTPHandler's logic
        try:
            self.transport = AsyncHTTPHandler._create_aiohttp_transport()
            self._owns_transport = True
            return self.transport
        except Exception:
            # If transport creation fails, return None (will use direct session)
            return None

    def _get_connector(self) -> Optional[aiohttp.BaseConnector]:
        """Get or create a connector for the client session."""
        if self.connector:
            return self.connector
        elif self.transport and hasattr(self.transport, "client"):
            # Extract connector from transport if available
            client = self.transport.client
            if callable(client):
                # If client is a factory, we can't extract connector directly
                return None
            elif hasattr(client, "connector"):
                return client.connector
        return None

    def _create_client_session_with_transport(self) -> ClientSession:
        """Create a new client session using transport or connector configuration."""
        connector = self._get_connector()
        if self.transport and hasattr(self.transport, "_get_valid_client_session"):
            # Use transport's session creation if available
            session = self.transport._get_valid_client_session()
            return session
        elif connector:
            # Use provided connector
            session = aiohttp.ClientSession(connector=connector)
            return session
        else:
            # Default session creation
            session = aiohttp.ClientSession()
            return session

    def _get_async_client_session(
        self, dynamic_client_session: Optional[ClientSession] = None
    ) -> ClientSession:
        """Return the session for a request: per-call override first, then the
        cached session, else create (and take ownership of) a new one."""
        if dynamic_client_session:
            return dynamic_client_session
        elif self.client_session:
            return self.client_session
        else:
            # Create client session using transport/connector if available
            self.client_session = self._create_client_session_with_transport()
            self._owns_session = True  # We created this session, so we own it
            return self.client_session

    async def close(self):
        """Close the aiohttp client session and transport if we own them."""
        # Close client session if we own it
        if (
            self.client_session
            and not self.client_session.closed
            and self._owns_session
        ):
            await self.client_session.close()
        # Close transport if we own it
        if (
            self.transport
            and self._owns_transport
            and hasattr(self.transport, "aclose")
        ):
            try:
                await self.transport.aclose()
            except Exception:
                # Ignore errors during transport cleanup
                pass

    def __del__(self):
        """
        Cleanup: close aiohttp session on instance destruction.
        Provides defense-in-depth for issue #12443 - ensures cleanup happens
        even if atexit handler doesn't run (abnormal termination).
        """
        if (
            self.client_session is not None
            and not self.client_session.closed
            and self._owns_session
        ):
            try:
                import asyncio

                try:
                    loop = asyncio.get_event_loop()
                    if loop.is_running():
                        # Event loop is running - schedule cleanup task
                        asyncio.create_task(self.close())
                    else:
                        # Event loop exists but not running - run cleanup
                        loop.run_until_complete(self.close())
                except RuntimeError:
                    # No event loop available - create one for cleanup
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        loop.run_until_complete(self.close())
                    finally:
                        loop.close()
            except Exception:
                # Silently ignore errors during __del__ to avoid issues
                pass

    async def _make_common_async_call(
        self,
        async_client_session: Optional[ClientSession],
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
        data: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        form_data: Optional[FormData] = None,
        stream: bool = False,
    ) -> aiohttp.ClientResponse:
        """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
        # NOTE(review): `timeout`, `litellm_params` and `stream` are accepted
        # for signature parity with the sync path but are not forwarded to the
        # aiohttp request below — confirm whether a per-request ClientTimeout
        # should be applied here.
        max_retry_on_unprocessable_entity_error = (
            provider_config.max_retry_on_unprocessable_entity_error
        )
        response: Optional[aiohttp.ClientResponse] = None
        async_client_session = self._get_async_client_session(
            dynamic_client_session=async_client_session
        )
        # NOTE(review): unlike the sync path, this loop has no `continue` —
        # errors raise immediately and success breaks, so it always exits
        # after the first iteration; the 422-retry transform is sync-only.
        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
            try:
                response = await async_client_session.post(
                    url=api_base,
                    headers=headers,
                    json=data,
                    data=form_data,
                )
                if not response.ok:
                    response.raise_for_status()
            except aiohttp.ClientResponseError as e:
                # aiohttp errors expose `.message`; mirror httpx's `.text`
                # so _handle_error can treat both uniformly.
                setattr(e, "text", e.message)
                raise self._handle_error(e=e, provider_config=provider_config)
            except Exception as e:
                raise self._handle_error(e=e, provider_config=provider_config)
            break
        if response is None:
            raise provider_config.get_error_class(
                error_message="No response from the API",
                status_code=422,
                headers={},
            )
        return response

    def _make_common_sync_call(
        self,
        sync_httpx_client: HTTPHandler,
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
        data: dict,
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params: dict,
        stream: bool = False,
        files: Optional[dict] = None,
        content: Any = None,
        params: Optional[dict] = None,
    ) -> httpx.Response:
        """
        Shared sync POST with provider-aware error handling.

        On an HTTP status error, retries (up to the provider's configured
        maximum) when the provider config opts in, letting the provider
        rewrite the payload before each retry.
        """
        max_retry_on_unprocessable_entity_error = (
            provider_config.max_retry_on_unprocessable_entity_error
        )
        response: Optional[httpx.Response] = None
        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
            try:
                response = sync_httpx_client.post(
                    url=api_base,
                    headers=headers,
                    data=data,  # do not json dump the data here. let the individual endpoint handle this.
                    timeout=timeout,
                    stream=stream,
                    files=files,
                    content=content,
                    params=params,
                )
            except httpx.HTTPStatusError as e:
                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
                    e=e, litellm_params=litellm_params
                )
                if should_retry and not hit_max_retry:
                    # Give the provider a chance to rewrite the payload
                    # (e.g. drop unsupported params) and try again.
                    data = (
                        provider_config.transform_request_on_unprocessable_entity_error(
                            e=e, request_data=data
                        )
                    )
                    continue
                else:
                    raise self._handle_error(e=e, provider_config=provider_config)
            except Exception as e:
                raise self._handle_error(e=e, provider_config=provider_config)
            break
        if response is None:
            raise provider_config.get_error_class(
                error_message="No response from the API",
                status_code=422,  # don't retry on this error
                headers={},
            )
        return response

    async def async_completion(
        self,
        custom_llm_provider: str,
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        model: str,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        messages: list,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        client: Optional[ClientSession] = None,
    ):
        """Async, non-streaming completion: POST the transformed request and
        let the provider config turn the raw response into a ModelResponse."""
        _response = await self._make_common_async_call(
            async_client_session=client,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            data=data,
            timeout=timeout,
            litellm_params=litellm_params,
            stream=False,
        )
        _transformed_response = await provider_config.transform_response(  # type: ignore
            model=model,
            raw_response=_response,  # type: ignore
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )
        return _transformed_response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        model_response: ModelResponse,
        encoding,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        acompletion: bool,
        stream: Optional[bool] = False,
        fake_stream: bool = False,
        api_key: Optional[str] = None,
        headers: Optional[dict] = {},  # NOTE(review): mutable default — not mutated here, but safer as None
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler, ClientSession]] = None,
    ):
        """
        Chat-completion entry point.

        Resolves the provider config, validates the environment, transforms
        the request, then dispatches to one of three paths:
        - async: returns the ``async_completion`` coroutine (``acompletion``),
        - sync streaming: wraps the provider iterator in CustomStreamWrapper,
        - sync non-streaming: returns the transformed ModelResponse.

        Raises:
            ValueError: If no provider config exists for the model/provider.
        """
        provider_config = ProviderConfigManager.get_provider_chat_config(
            model=model, provider=litellm.LlmProviders(custom_llm_provider)
        )
        if provider_config is None:
            raise ValueError(
                f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
            )
        # get config from model, custom llm provider
        headers = provider_config.validate_environment(
            api_key=api_key,
            headers=headers or {},
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            api_base=api_base,
        )
        api_base = provider_config.get_complete_url(
            api_base=api_base,
            api_key=api_key,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params,
            stream=stream,
        )
        data = provider_config.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": headers,
            },
        )
        if acompletion is True:
            # Return the coroutine; the caller awaits it. Only pass `client`
            # through when it is actually an aiohttp ClientSession.
            return self.async_completion(
                custom_llm_provider=custom_llm_provider,
                provider_config=provider_config,
                api_base=api_base,
                headers=headers,
                data=data,
                timeout=timeout,
                model=model,
                model_response=model_response,
                logging_obj=logging_obj,
                api_key=api_key,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                client=(
                    client
                    if client is not None and isinstance(client, ClientSession)
                    else None
                ),
            )
        if stream is True:
            # Real streaming sends `stream: true` upstream; fake streaming
            # makes a normal request and replays the result as a stream.
            if fake_stream is not True:
                data["stream"] = stream
            completion_stream, headers = self.make_sync_call(
                provider_config=provider_config,
                api_base=api_base,
                headers=headers,  # type: ignore
                data=data,
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                timeout=timeout,
                fake_stream=fake_stream,
                client=(
                    client
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
                litellm_params=litellm_params,
            )
            return CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )
        if client is None or not isinstance(client, HTTPHandler):
            sync_httpx_client = _get_httpx_client()
        else:
            sync_httpx_client = client
        response = self._make_common_sync_call(
            sync_httpx_client=sync_httpx_client,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            timeout=timeout,
            litellm_params=litellm_params,
            data=data,
        )
        return provider_config.transform_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )

    def make_sync_call(
        self,
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
        data: dict,
        model: str,
        messages: list,
        logging_obj,
        litellm_params: dict,
        timeout: Union[float, httpx.Timeout],
        fake_stream: bool = False,
        client: Optional[HTTPHandler] = None,
    ) -> Tuple[Any, dict]:
        """
        Issue the (possibly fake-)streaming sync request.

        Returns:
            Tuple of (completion_stream iterator, response headers dict).
        """
        if client is None or not isinstance(client, HTTPHandler):
            sync_httpx_client = _get_httpx_client()
        else:
            sync_httpx_client = client
        stream = True
        if fake_stream is True:
            # Fake streaming: make a normal request and iterate the parsed JSON.
            stream = False
        response = self._make_common_sync_call(
            sync_httpx_client=sync_httpx_client,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            data=data,
            timeout=timeout,
            litellm_params=litellm_params,
            stream=stream,
        )
        if fake_stream is True:
            completion_stream = provider_config.get_model_response_iterator(
                streaming_response=response.json(), sync_stream=True
            )
        else:
            completion_stream = provider_config.get_model_response_iterator(
                streaming_response=response.iter_lines(), sync_stream=True
            )
        # LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response="first stream response received",
            additional_args={"complete_input_dict": data},
        )
        return completion_stream, dict(response.headers)

    async def async_image_variations(
        self,
        client: Optional[ClientSession],
        provider_config: BaseImageVariationConfig,
        api_base: str,
        headers: dict,
        data: HttpHandlerRequestFields,
        timeout: float,
        litellm_params: dict,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: str,
        model: Optional[str],
        image: FileTypes,
        optional_params: dict,
    ) -> ImageResponse:
        """Async image-variation call; multipart uploads are sent as aiohttp
        FormData, everything else as JSON."""
        # create aiohttp form data if files in data
        form_data: Optional[FormData] = None
        if "files" in data and "data" in data:
            form_data = FormData()
            # files entries are (filename, file_content, content_type) tuples
            for k, v in data["files"].items():
                form_data.add_field(k, v[1], filename=v[0], content_type=v[2])
            for key, value in data["data"].items():
                form_data.add_field(key, value)
        _response = await self._make_common_async_call(
            async_client_session=client,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            data=None if form_data is not None else cast(dict, data),
            form_data=form_data,
            timeout=timeout,
            litellm_params=litellm_params,
            stream=False,
        )
        ## LOGGING
        logging_obj.post_call(
            api_key=api_key,
            original_response=_response.text,
            additional_args={
                "headers": headers,
                "api_base": api_base,
            },
        )
        ## RESPONSE OBJECT
        return await provider_config.async_transform_response_image_variation(
            model=model,
            model_response=model_response,
            raw_response=_response,
            logging_obj=logging_obj,
            request_data=cast(dict, data),
            image=image,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=None,
            api_key=api_key,
        )

    def image_variations(
        self,
        model_response: ImageResponse,
        api_key: str,
        model: Optional[str],
        image: FileTypes,
        timeout: float,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        litellm_params: dict,
        print_verbose: Optional[Callable] = None,
        api_base: Optional[str] = None,
        aimage_variation: bool = False,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ) -> ImageResponse:
        """
        Image-variation entry point (sync, with an async dispatch path).

        Raises:
            ValueError: If ``model`` is None or no image-variation provider
                config exists for ``custom_llm_provider``.
        """
        if model is None:
            raise ValueError("model is required for non-openai image variations")
        provider_config = ProviderConfigManager.get_provider_image_variation_config(
            model=model,  # openai defaults to dall-e-2
            provider=LlmProviders(custom_llm_provider),
        )
        if provider_config is None:
            raise ValueError(
                f"image variation provider not found: {custom_llm_provider}."
            )
        api_base = provider_config.get_complete_url(
            api_base=api_base,
            api_key=api_key,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params,
            stream=False,
        )
        headers = provider_config.validate_environment(
            api_key=api_key,
            headers=headers or {},
            model=model,
            messages=[{"role": "user", "content": "test"}],
            optional_params=optional_params,
            litellm_params=litellm_params,
            api_base=api_base,
        )
        data = provider_config.transform_request_image_variation(
            model=model,
            image=image,
            optional_params=optional_params,
            headers=headers,
        )
        ## LOGGING
        logging_obj.pre_call(
            input="",
            api_key=api_key,
            additional_args={
                "headers": headers,
                "api_base": api_base,
                "complete_input_dict": data.copy(),
            },
        )
        if litellm_params.get("async_call", False):
            # Returns a coroutine despite the sync return annotation; the
            # caller is expected to await it (hence the type: ignore).
            return self.async_image_variations(
                api_base=api_base,
                data=data,
                headers=headers,
                model_response=model_response,
                logging_obj=logging_obj,
                model=model,
                timeout=timeout,
                client=client,
                optional_params=optional_params,
                litellm_params=litellm_params,
                image=image,
                provider_config=provider_config,
            )  # type: ignore
        if client is None or not isinstance(client, HTTPHandler):
            sync_httpx_client = _get_httpx_client()
        else:
            sync_httpx_client = client
        response = self._make_common_sync_call(
            sync_httpx_client=sync_httpx_client,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            timeout=timeout,
            litellm_params=litellm_params,
            stream=False,
            data=data.get("data") or {},
            files=data.get("files"),
            content=data.get("content"),
            params=data.get("params"),
        )
        ## LOGGING
        logging_obj.post_call(
            api_key=api_key,
            original_response=response.text,
            additional_args={
                "headers": headers,
                "api_base": api_base,
            },
        )
        ## RESPONSE OBJECT
        return provider_config.transform_response_image_variation(
            model=model,
            model_response=model_response,
            raw_response=response,
            logging_obj=logging_obj,
            request_data=cast(dict, data),
            image=image,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=None,
            api_key=api_key,
        )

    def _handle_error(self, e: Exception, provider_config: BaseConfig):
        """
        Normalize any exception into the provider's error class and raise it.

        Pulls status code, headers and body text off the exception (or its
        attached response) with safe fallbacks. Always raises; the signature
        lets callers write ``raise self._handle_error(...)``.
        """
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_text = getattr(e, "text", str(e))
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response:
            error_headers = getattr(error_response, "headers", None)
        if error_response and hasattr(error_response, "text"):
            error_text = getattr(error_response, "text", error_text)
        if error_headers:
            error_headers = dict(error_headers)
        else:
            error_headers = {}
        raise provider_config.get_error_class(
            error_message=error_text,
            status_code=status_code,
            headers=error_headers,
        )

View File

@@ -0,0 +1,388 @@
import asyncio
import contextlib
import os
import ssl
import typing
import urllib.request
from typing import Any, Callable, Dict, Optional, Union
import aiohttp
import aiohttp.client_exceptions
import aiohttp.http_exceptions
import httpx
from aiohttp.client import ClientResponse, ClientSession
import litellm
from litellm._logging import verbose_logger
from litellm.secret_managers.main import str_to_bool
# Mapping from aiohttp/asyncio exception types to their httpx equivalents.
# Order matters here, most specific exception first — map_aiohttp_exceptions
# iterates this dict to pick the narrowest matching target type.
AIOHTTP_EXC_MAP: Dict = {
    # Timeout related exceptions
    asyncio.TimeoutError: httpx.TimeoutException,
    aiohttp.ServerTimeoutError: httpx.TimeoutException,
    aiohttp.ConnectionTimeoutError: httpx.ConnectTimeout,
    aiohttp.SocketTimeoutError: httpx.ReadTimeout,
    # Proxy related exceptions
    aiohttp.ClientProxyConnectionError: httpx.ProxyError,
    # SSL related exceptions
    aiohttp.ClientConnectorCertificateError: httpx.ProtocolError,
    aiohttp.ClientSSLError: httpx.ProtocolError,
    aiohttp.ServerFingerprintMismatch: httpx.ProtocolError,
    # Network related exceptions
    aiohttp.ClientConnectorError: httpx.ConnectError,
    aiohttp.ClientOSError: httpx.ConnectError,
    # Payload / response-body exceptions (listed once — this key previously
    # appeared twice in the literal, which Python silently collapses)
    aiohttp.ClientPayloadError: httpx.ReadError,
    # Connection disconnection exceptions
    aiohttp.ServerDisconnectedError: httpx.ReadError,
    # Response related exceptions
    aiohttp.ClientConnectionError: httpx.NetworkError,
    aiohttp.ContentTypeError: httpx.ReadError,
    aiohttp.TooManyRedirects: httpx.TooManyRedirects,
    # URL related exceptions
    aiohttp.InvalidURL: httpx.InvalidURL,
    # Base exceptions
    aiohttp.ClientError: httpx.RequestError,
}
# Add client_exceptions module exceptions as a best-effort safety net
# (in current aiohttp this name is the same class as aiohttp.ClientPayloadError).
try:
    import aiohttp.client_exceptions

    AIOHTTP_EXC_MAP[aiohttp.client_exceptions.ClientPayloadError] = httpx.ReadError
except ImportError:
    pass
@contextlib.contextmanager
def map_aiohttp_exceptions() -> typing.Iterator[None]:
    """Re-raise aiohttp/asyncio errors from the wrapped block as their httpx
    equivalents, chaining the original exception as the cause."""
    try:
        yield
    except Exception as exc:
        # Scan the whole map, preferring the most specific (sub-)class match.
        best_match = None
        for source_type, target_type in AIOHTTP_EXC_MAP.items():
            if isinstance(exc, source_type) and (  # type: ignore
                best_match is None or issubclass(target_type, best_match)
            ):
                best_match = target_type
        if best_match is None:  # pragma: no cover
            raise
        raise best_match(str(exc)) from exc
class AiohttpResponseStream(httpx.AsyncByteStream):
    """Adapts an aiohttp ClientResponse body to httpx's AsyncByteStream interface."""

    # Size of chunks pulled from the aiohttp content stream (16 KiB).
    CHUNK_SIZE = 1024 * 16

    def __init__(self, aiohttp_response: ClientResponse) -> None:
        self._aiohttp_response = aiohttp_response

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        """Yield response bytes; common provider-side disconnects end the
        stream gracefully instead of raising (handler order matters here)."""
        try:
            async for chunk in self._aiohttp_response.content.iter_chunked(
                self.CHUNK_SIZE
            ):
                yield chunk
        except (
            aiohttp.ClientPayloadError,
            aiohttp.client_exceptions.ClientPayloadError,
        ) as e:
            # Handle incomplete transfers more gracefully
            # Log the error but don't re-raise if we've already yielded some data
            verbose_logger.debug(f"Transfer incomplete, but continuing: {e}")
            # If the error is due to incomplete transfer encoding, we can still
            # return what we've received so far, similar to how httpx handles it
            return
        except RuntimeError as e:
            # Some providers (e.g., SSE streams) may close the connection
            # causing aiohttp StreamReader to raise a generic RuntimeError
            # with message "Connection closed.". Treat this as a graceful
            # end-of-stream so downstream consumers don't error.
            if "Connection closed" in str(e):
                verbose_logger.debug(
                    "Upstream closed streaming connection; ending iterator gracefully"
                )
                return
            raise
        except aiohttp.http_exceptions.TransferEncodingError as e:
            # Handle transfer encoding errors gracefully
            verbose_logger.debug(f"Transfer encoding error, but continuing: {e}")
            return
        except Exception:
            # For other exceptions, use the normal mapping
            with map_aiohttp_exceptions():
                raise

    async def aclose(self) -> None:
        """Close the underlying aiohttp response (exits its async context)."""
        with map_aiohttp_exceptions():
            await self._aiohttp_response.__aexit__(None, None, None)
class AiohttpTransport(httpx.AsyncBaseTransport):
    """httpx async transport backed by an aiohttp ClientSession (or a factory
    that produces one). ``aclose()`` closes the session only when owned."""

    def __init__(
        self,
        client: Union[ClientSession, Callable[[], ClientSession]],
        owns_session: bool = True,
    ) -> None:
        self.client = client
        # Whether aclose() is responsible for closing the underlying session.
        self._owns_session = owns_session
        #########################################################
        # Class variables for proxy settings
        #########################################################
        # Per-host cache of resolved proxy URLs (None = direct connection).
        self.proxy_cache: Dict[str, Optional[str]] = {}

    async def aclose(self) -> None:
        """Close the owned ClientSession; borrowed sessions and factories are
        left untouched."""
        if not self._owns_session:
            return
        if isinstance(self.client, ClientSession):
            await self.client.close()
class LiteLLMAiohttpTransport(AiohttpTransport):
    """
    LiteLLM wrapper around AiohttpTransport to handle %-encodings in URLs
    and event loop lifecycle issues in CI/CD environments
    Credit to: https://github.com/karpetrosyan/httpx-aiohttp for this implementation
    """

    def __init__(
        self,
        client: Union[ClientSession, Callable[[], ClientSession]],
        ssl_verify: Optional[Union[bool, ssl.SSLContext]] = None,
        owns_session: bool = True,
    ):
        self.client = client
        self._ssl_verify = ssl_verify  # Store for per-request SSL override
        super().__init__(client=client, owns_session=owns_session)
        # Store the client factory for recreating sessions when needed
        # (only set when the caller passed a factory, not a live session).
        if callable(client):
            self._client_factory = client

    def _get_valid_client_session(self) -> ClientSession:
        """
        Helper to get a valid ClientSession for the current event loop.
        This handles the case where the session was created in a different
        event loop that may have been closed (common in CI/CD environments).
        """
        from aiohttp.client import ClientSession

        # If we don't have a client or it's not a ClientSession, create one
        if not isinstance(self.client, ClientSession):
            if hasattr(self, "_client_factory") and callable(self._client_factory):
                self.client = self._client_factory()
            else:
                self.client = ClientSession()
            # Don't return yet - check if the newly created session is valid
        # Check if the session itself is closed
        if self.client.closed:
            verbose_logger.debug("Session is closed, creating new session")
            # Create a new session
            if hasattr(self, "_client_factory") and callable(self._client_factory):
                self.client = self._client_factory()
            else:
                self.client = ClientSession()
            return self.client
        # Check if the existing session is still valid for the current event loop
        try:
            # NOTE: _loop is a private aiohttp attribute; absence reads as
            # "unknown" and forces recreation below.
            session_loop = getattr(self.client, "_loop", None)
            current_loop = asyncio.get_running_loop()
            # If session is from a different or closed loop, recreate it
            if (
                session_loop is None
                or session_loop != current_loop
                or session_loop.is_closed()
            ):
                # Close old session to prevent leaks
                old_session = self.client
                try:
                    if not old_session.closed:
                        try:
                            asyncio.create_task(old_session.close())
                        except RuntimeError:
                            # Different event loop - can't schedule task, rely on GC
                            verbose_logger.debug(
                                "Old session from different loop, relying on GC"
                            )
                except Exception as e:
                    verbose_logger.debug(f"Error closing old session: {e}")
                # Create a new session in the current event loop
                if hasattr(self, "_client_factory") and callable(self._client_factory):
                    self.client = self._client_factory()
                else:
                    self.client = ClientSession()
        except (RuntimeError, AttributeError):
            # If we can't check the loop or session is invalid, recreate it
            if hasattr(self, "_client_factory") and callable(self._client_factory):
                self.client = self._client_factory()
            else:
                self.client = ClientSession()
        return self.client

    async def _make_aiohttp_request(
        self,
        client_session: ClientSession,
        request: httpx.Request,
        timeout: dict,
        proxy: Optional[str],
        sni_hostname: Optional[str],
        ssl_verify: Optional[Union[bool, ssl.SSLContext]] = None,
    ) -> ClientResponse:
        """
        Helper function to make an aiohttp request with the given parameters.
        Args:
            client_session: The aiohttp ClientSession to use
            request: The httpx Request to send
            timeout: Timeout settings dict with 'connect', 'read', 'pool' keys
            proxy: Optional proxy URL
            sni_hostname: Optional SNI hostname for SSL
            ssl_verify: Optional SSL verification setting (False to disable, SSLContext for custom)
        Returns:
            ClientResponse from aiohttp
        """
        from aiohttp import ClientTimeout
        from yarl import URL as YarlURL

        try:
            data = request.content
        except httpx.RequestNotRead:
            # Body hasn't been buffered; hand aiohttp the stream instead.
            data = request.stream  # type: ignore
            request.headers.pop("transfer-encoding", None)  # handled by aiohttp
        # Only pass ssl kwarg when explicitly configured, to avoid
        # overriding the session/connector defaults with None (which is
        # not a valid value for aiohttp's ssl parameter).
        request_kwargs: Dict[str, Any] = {
            "method": request.method,
            # encoded=True preserves %-escapes instead of re-encoding the URL.
            "url": YarlURL(str(request.url), encoded=True),
            "headers": request.headers,
            "data": data,
            "allow_redirects": False,
            "auto_decompress": False,
            "timeout": ClientTimeout(
                sock_connect=timeout.get("connect"),
                sock_read=timeout.get("read"),
                connect=timeout.get("pool"),
            ),
            "proxy": proxy,
            "server_hostname": sni_hostname,
        }
        if ssl_verify is not None:
            request_kwargs["ssl"] = ssl_verify
        # Enter the response context manually; AiohttpResponseStream.aclose()
        # is responsible for exiting it later.
        response = await client_session.request(**request_kwargs).__aenter__()
        return response

    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Send an httpx.Request via aiohttp, retrying once on a session that
        closed between validation and use, and wrap the result for httpx."""
        timeout = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname")
        # Use helper to ensure we have a valid session for the current event loop
        client_session = self._get_valid_client_session()
        # Resolve proxy settings from environment variables
        proxy = await self._get_proxy_settings(request)
        # Use stored SSL configuration for per-request override
        ssl_config = self._ssl_verify
        try:
            with map_aiohttp_exceptions():
                response = await self._make_aiohttp_request(
                    client_session=client_session,
                    request=request,
                    timeout=timeout,
                    proxy=proxy,
                    sni_hostname=sni_hostname,
                    ssl_verify=ssl_config,
                )
        except RuntimeError as e:
            # Handle the case where session was closed between our check and actual use
            if "Session is closed" in str(e):
                verbose_logger.debug(
                    f"Session closed during request, retrying with new session: {e}"
                )
                # Force creation of a new session
                if hasattr(self, "_client_factory") and callable(self._client_factory):
                    self.client = self._client_factory()
                else:
                    self.client = ClientSession()
                client_session = self.client
                # Retry the request with the new session
                with map_aiohttp_exceptions():
                    response = await self._make_aiohttp_request(
                        client_session=client_session,
                        request=request,
                        timeout=timeout,
                        proxy=proxy,
                        sni_hostname=sni_hostname,
                        ssl_verify=ssl_config,
                    )
            else:
                # Re-raise if it's a different RuntimeError
                raise
        return httpx.Response(
            status_code=response.status,
            headers=response.headers,
            stream=AiohttpResponseStream(response),
            request=request,
        )

    async def _get_proxy_settings(self, request: httpx.Request):
        """Best-effort proxy resolution from env vars, unless trust-env is
        disabled via litellm config or DISABLE_AIOHTTP_TRUST_ENV."""
        proxy = None
        if not (
            litellm.disable_aiohttp_trust_env
            or str_to_bool(os.getenv("DISABLE_AIOHTTP_TRUST_ENV", "False"))
        ):
            try:
                proxy = self._proxy_from_env(request.url)
            except Exception as e:  # pragma: no cover - best effort
                verbose_logger.debug(f"Error reading proxy env: {e}")
        return proxy

    def _proxy_from_env(self, url: httpx.URL) -> typing.Optional[str]:
        """
        Return proxy URL from env for the given request URL
        Only check the proxy env settings once, this is a costly operation for CPU % usage
        ."""
        #########################################################
        # Check if we've already checked the proxy env settings
        #########################################################
        proxy_cache_key = url.host
        if proxy_cache_key in self.proxy_cache:
            return self.proxy_cache[proxy_cache_key]
        proxies = urllib.request.getproxies()
        if urllib.request.proxy_bypass(url.host):
            proxy_url = None
        else:
            proxy = proxies.get(url.scheme) or proxies.get("all")
            # Bare "host:port" entries need a scheme for aiohttp.
            if proxy and "://" not in proxy:
                proxy = f"http://{proxy}"
            proxy_url = proxy
        self.proxy_cache[proxy_cache_key] = proxy_url
        return proxy_url

View File

@@ -0,0 +1,99 @@
"""
Utility functions for cleaning up async HTTP clients to prevent resource leaks.
"""
import asyncio
async def close_litellm_async_clients():
    """
    Close all cached async HTTP clients to prevent resource leaks.

    Walks litellm's in-memory client cache and closes any aiohttp-backed
    handlers or httpx clients that are still open, then closes the global
    ``base_llm_aiohttp_handler`` instance (issue #12443).
    """
    # Imported lazily to avoid a circular import at module load time.
    import litellm
    from litellm.llms.custom_httpx.aiohttp_handler import BaseLLMAIOHTTPHandler

    cached_clients = getattr(litellm.in_memory_llm_clients_cache, "cache_dict", {})
    for cached in cached_clients.values():
        # aiohttp-based handlers (aiohttp_openai provider) expose close().
        if isinstance(cached, BaseLLMAIOHTTPHandler) and hasattr(cached, "close"):
            try:
                await cached.close()
            except Exception:
                pass  # best-effort cleanup
        elif hasattr(cached, "client"):
            # AsyncHTTPHandler instances (Gemini and others) wrap an httpx
            # client that may itself sit on an aiohttp transport.
            httpx_client = cached.client
            transport = getattr(httpx_client, "_transport", None)
            if transport is not None and hasattr(transport, "aclose"):
                try:
                    await transport.aclose()
                except Exception:
                    pass  # best-effort cleanup
            # Also close the httpx client itself.
            if hasattr(httpx_client, "aclose") and not httpx_client.is_closed:
                try:
                    await httpx_client.aclose()
                except Exception:
                    pass  # best-effort cleanup
        elif hasattr(cached, "aclose"):
            # Anything else that exposes an async close hook.
            try:
                await cached.aclose()
            except Exception:
                pass  # best-effort cleanup

    # Close the module-global aiohttp handler used by Gemini and other
    # aiohttp-based providers (issue #12443).
    global_handler = getattr(litellm, "base_llm_aiohttp_handler", None)
    if isinstance(global_handler, BaseLLMAIOHTTPHandler) and hasattr(
        global_handler, "close"
    ):
        try:
            await global_handler.close()
        except Exception:
            pass  # best-effort cleanup
def register_async_client_cleanup():
    """
    Register an atexit hook that closes all cached async HTTP clients.

    Guarantees async client cleanup runs when the program exits normally.
    """
    import atexit

    def _run_cleanup_at_exit():
        # The main event loop is usually already closed by the time atexit
        # handlers run, so always build a fresh loop for the cleanup
        # coroutine rather than calling get_event_loop() (fixes issue #12443).
        try:
            cleanup_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(cleanup_loop)
            try:
                cleanup_loop.run_until_complete(close_litellm_async_clients())
            finally:
                cleanup_loop.close()
        except Exception:
            # Never let an exit handler raise.
            pass

    atexit.register(_run_cleanup_at_exit)

View File

@@ -0,0 +1,419 @@
"""
Generic container file handler for LiteLLM.
This module provides a single generic handler that can process any container file
endpoint defined in endpoints.json, eliminating the need for individual handler methods.
"""
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Type, Union
import httpx
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.containers.main import (
ContainerFileListResponse,
ContainerFileObject,
DeleteContainerFileResponse,
)
from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
# Response type mapping
# Maps the "response_type" string declared for each endpoint in
# endpoints.json to the concrete response class used to parse the
# provider's JSON payload. Endpoints whose response_type is absent from
# this dict fall back to returning the raw JSON dict.
RESPONSE_TYPES: Dict[str, Type] = {
    "ContainerFileListResponse": ContainerFileListResponse,
    "ContainerFileObject": ContainerFileObject,
    "DeleteContainerFileResponse": DeleteContainerFileResponse,
}
def _load_endpoints_config() -> Dict:
    """Load the endpoints configuration from JSON file."""
    # endpoints.json lives three levels up, under the containers package.
    config_file = Path(__file__).parent.parent.parent / "containers" / "endpoints.json"
    return json.loads(config_file.read_text())
def _get_endpoint_config(endpoint_name: str) -> Optional[Dict]:
    """Get config for a specific endpoint by name.

    Matches against either the sync or async endpoint name; returns None
    when no endpoint entry matches.
    """
    for candidate in _load_endpoints_config()["endpoints"]:
        if endpoint_name in (candidate["name"], candidate["async_name"]):
            return candidate
    return None
def _build_url(
api_base: str,
path_template: str,
path_params: Dict[str, str],
) -> str:
"""Build the full URL by substituting path parameters.
The api_base from get_complete_url already includes /containers,
so we need to strip that prefix from the path_template.
"""
# api_base ends with /containers, path_template starts with /containers
# So we need to strip /containers from the path
if path_template.startswith("/containers"):
path_template = path_template[len("/containers") :]
url = f"{api_base.rstrip('/')}{path_template}"
for param, value in path_params.items():
url = url.replace(f"{{{param}}}", value)
return url
def _build_query_params(
query_param_names: list,
kwargs: Dict[str, Any],
) -> Dict[str, str]:
"""Build query parameters from kwargs."""
params = {}
for param_name in query_param_names:
value = kwargs.get(param_name)
if value is not None:
params[param_name] = str(value) if not isinstance(value, str) else value
return params
def _prepare_multipart_file_upload(
    file: Any,
    headers: Dict[str, Any],
) -> tuple:
    """
    Prepare file and headers for multipart upload.

    Returns:
        Tuple of (files_dict, headers_without_content_type)
    """
    from litellm.litellm_core_utils.prompt_templates.common_utils import (
        extract_file_data,
    )

    file_data = extract_file_data(file)
    files = {
        "file": (
            file_data.get("filename") or "file",
            file_data.get("content") or b"",
            file_data.get("content_type") or "application/octet-stream",
        )
    }
    # Drop any explicit content-type so httpx can set the multipart
    # boundary header itself.
    sanitized_headers = dict(headers)
    sanitized_headers.pop("content-type", None)
    sanitized_headers.pop("Content-Type", None)
    return files, sanitized_headers
class GenericContainerHandler:
    """
    Generic handler for container file API endpoints.

    This single handler can process any endpoint defined in endpoints.json,
    eliminating the need for individual handler methods per endpoint.

    Request preparation (endpoint lookup, URL/header/query construction,
    pre-call logging) and response parsing are shared between the sync and
    async paths via ``_prepare_request`` / ``_process_response``; the two
    paths differ only in which HTTP client they use and whether the calls
    are awaited.
    """

    def handle(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        timeout: Union[float, httpx.Timeout] = 600,
        _is_async: bool = False,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        **kwargs,
    ) -> Union[Any, Coroutine[Any, Any, Any]]:
        """
        Generic handler for any container file endpoint.

        Args:
            endpoint_name: Name of the endpoint (e.g., "list_container_files")
            container_provider_config: Provider-specific configuration
            litellm_params: LiteLLM parameters including api_key, api_base
            logging_obj: Logging object for request logging
            extra_headers: Additional HTTP headers
            extra_query: Additional query parameters
            timeout: Request timeout
            _is_async: Whether to make async request (returns a coroutine)
            client: Optional HTTP client
            **kwargs: Path params and query params (e.g., container_id, file_id, after, limit)
        """
        if _is_async:
            return self._async_handle(
                endpoint_name=endpoint_name,
                container_provider_config=container_provider_config,
                litellm_params=litellm_params,
                logging_obj=logging_obj,
                extra_headers=extra_headers,
                extra_query=extra_query,
                timeout=timeout,
                client=client,
                **kwargs,
            )
        return self._sync_handle(
            endpoint_name=endpoint_name,
            container_provider_config=container_provider_config,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            extra_headers=extra_headers,
            extra_query=extra_query,
            timeout=timeout,
            client=client,
            **kwargs,
        )

    def _prepare_request(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]],
        extra_query: Optional[Dict[str, Any]],
        kwargs: Dict[str, Any],
    ) -> tuple:
        """Build everything both transports need for one request.

        Returns:
            Tuple of (endpoint_config, url, headers, query_params).

        Raises:
            ValueError: If ``endpoint_name`` is not defined in endpoints.json.
        """
        endpoint_config = _get_endpoint_config(endpoint_name)
        if not endpoint_config:
            raise ValueError(f"Unknown endpoint: {endpoint_name}")
        # Provider-validated auth headers, then caller overrides on top.
        headers = container_provider_config.validate_environment(
            headers=extra_headers or {},
            api_key=litellm_params.get("api_key", None),
        )
        if extra_headers:
            headers.update(extra_headers)
        api_base = container_provider_config.get_complete_url(
            api_base=litellm_params.get("api_base", None),
            litellm_params=dict(litellm_params),
        )
        # Substitute path params declared for this endpoint into the URL.
        path_params = {
            p: kwargs.get(p, "") for p in endpoint_config.get("path_params", [])
        }
        url = _build_url(api_base, endpoint_config["path"], path_params)
        # Declared query params from kwargs, then caller-supplied extras.
        query_params = _build_query_params(
            endpoint_config.get("query_params", []), kwargs
        )
        if extra_query:
            query_params.update(extra_query)
        # Log request before sending.
        logging_obj.pre_call(
            input="",
            api_key="",
            additional_args={
                "api_base": url,
                "headers": headers,
                "params": query_params,
            },
        )
        return endpoint_config, url, headers, query_params

    def _process_response(self, response: Any, endpoint_config: Dict) -> Any:
        """Parse an endpoint response shared by the sync and async paths.

        Returns raw bytes for binary endpoints, a typed response object when
        the endpoint declares a known response_type, otherwise the raw JSON.

        Raises:
            BaseLLMException: If the JSON payload contains an "error" key.
        """
        # For binary responses, return raw content without JSON parsing.
        if endpoint_config.get("returns_binary", False):
            return response.content
        response_json = response.json()
        if "error" in response_json:
            from litellm.llms.base_llm.chat.transformation import BaseLLMException

            error_msg = response_json.get("error", {}).get(
                "message", str(response_json)
            )
            raise BaseLLMException(
                status_code=response.status_code,
                message=error_msg,
                headers=dict(response.headers),
            )
        response_type = RESPONSE_TYPES.get(endpoint_config["response_type"])
        if response_type:
            return response_type(**response_json)
        return response_json

    def _sync_handle(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        timeout: Union[float, httpx.Timeout] = 600,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        **kwargs,
    ) -> Any:
        """Synchronous request handler."""
        endpoint_config, url, headers, query_params = self._prepare_request(
            endpoint_name=endpoint_name,
            container_provider_config=container_provider_config,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            extra_headers=extra_headers,
            extra_query=extra_query,
            kwargs=kwargs,
        )
        # Only reuse a caller-supplied client when it is the sync flavor.
        if client is None or not isinstance(client, HTTPHandler):
            http_client = _get_httpx_client(
                params={"ssl_verify": litellm_params.get("ssl_verify", None)}
            )
        else:
            http_client = client
        # NOTE(review): `timeout` is accepted but not forwarded to the client
        # calls below (unchanged from the original behavior) — confirm whether
        # the handler methods accept a per-request timeout and wire it through.
        method = endpoint_config["method"].upper()
        if method == "GET":
            response = http_client.get(url=url, headers=headers, params=query_params)
        elif method == "DELETE":
            response = http_client.delete(
                url=url, headers=headers, params=query_params
            )
        elif method == "POST":
            if endpoint_config.get("is_multipart", False) and "file" in kwargs:
                files, headers = _prepare_multipart_file_upload(
                    kwargs["file"], headers
                )
                response = http_client.post(
                    url=url, headers=headers, params=query_params, files=files
                )
            else:
                response = http_client.post(
                    url=url, headers=headers, params=query_params
                )
        else:
            raise ValueError(f"Unsupported HTTP method: {method}")
        return self._process_response(response, endpoint_config)

    async def _async_handle(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        timeout: Union[float, httpx.Timeout] = 600,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        **kwargs,
    ) -> Any:
        """Asynchronous request handler."""
        endpoint_config, url, headers, query_params = self._prepare_request(
            endpoint_name=endpoint_name,
            container_provider_config=container_provider_config,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            extra_headers=extra_headers,
            extra_query=extra_query,
            kwargs=kwargs,
        )
        # Only reuse a caller-supplied client when it is the async flavor.
        if client is None or not isinstance(client, AsyncHTTPHandler):
            http_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.OPENAI,
                params={"ssl_verify": litellm_params.get("ssl_verify", None)},
            )
        else:
            http_client = client
        # NOTE(review): `timeout` is accepted but not forwarded to the client
        # calls below (unchanged from the original behavior) — confirm whether
        # the handler methods accept a per-request timeout and wire it through.
        method = endpoint_config["method"].upper()
        if method == "GET":
            response = await http_client.get(
                url=url, headers=headers, params=query_params
            )
        elif method == "DELETE":
            response = await http_client.delete(
                url=url, headers=headers, params=query_params
            )
        elif method == "POST":
            if endpoint_config.get("is_multipart", False) and "file" in kwargs:
                files, headers = _prepare_multipart_file_upload(
                    kwargs["file"], headers
                )
                response = await http_client.post(
                    url=url, headers=headers, params=query_params, files=files
                )
            else:
                response = await http_client.post(
                    url=url, headers=headers, params=query_params
                )
        else:
            raise ValueError(f"Unsupported HTTP method: {method}")
        return self._process_response(response, endpoint_config)


# Singleton instance
generic_container_handler = GenericContainerHandler()

View File

@@ -0,0 +1,61 @@
import os
from typing import Optional, Union
import httpx
try:
    from litellm._version import version
except Exception:
    # Fallback when the version module is unavailable (e.g. source checkout).
    version = "0.0.0"


def get_default_headers() -> dict:
    """
    Get default headers for HTTP requests.

    - Default: `User-Agent: litellm/{version}`
    - Override: set `LITELLM_USER_AGENT` to fully override the header value.
    """
    override = os.environ.get("LITELLM_USER_AGENT")
    agent = override if override is not None else f"litellm/{version}"
    return {"User-Agent": agent}
class HTTPHandler:
    """Thin async wrapper around a pooled ``httpx.AsyncClient``."""

    def __init__(self, concurrent_limit=1000):
        # One pool cap governs both total and keep-alive connections.
        pool_limits = httpx.Limits(
            max_connections=concurrent_limit,
            max_keepalive_connections=concurrent_limit,
        )
        self.client = httpx.AsyncClient(
            limits=pool_limits,
            headers=get_default_headers(),
        )

    async def close(self):
        """Dispose of the underlying connection pool."""
        await self.client.aclose()

    async def get(
        self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
    ):
        """Issue a GET request and return the raw response."""
        return await self.client.get(url, params=params, headers=headers)

    async def post(
        self,
        url: str,
        data: Optional[Union[dict, str]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ):
        """Issue a POST request and return the raw response."""
        try:
            return await self.client.post(
                url, data=data, params=params, headers=headers  # type: ignore
            )
        except Exception as e:
            raise e

View File

@@ -0,0 +1,93 @@
"""
Mock httpx transport that returns valid OpenAI ChatCompletion responses.
Activated via `litellm_settings: { network_mock: true }`.
Intercepts at the httpx transport layer — the lowest point before bytes hit the wire —
so the full proxy -> router -> OpenAI SDK -> httpx path is exercised.
"""
import json
import time
import uuid
from typing import Tuple
import httpx
# ---------------------------------------------------------------------------
# Pre-built response templates
# ---------------------------------------------------------------------------
def _mock_id() -> str:
return f"chatcmpl-mock-{uuid.uuid4().hex[:8]}"
def _chat_completion_json(model: str) -> dict:
    """Return a minimal valid ChatCompletion object."""
    # Single canned assistant turn; token counts are nominal placeholders.
    only_choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": "Mock response",
        },
        "finish_reason": "stop",
    }
    token_usage = {
        "prompt_tokens": 1,
        "completion_tokens": 1,
        "total_tokens": 2,
    }
    return {
        "id": _mock_id(),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [only_choice],
        "usage": token_usage,
    }
# ---------------------------------------------------------------------------
# Transport
# ---------------------------------------------------------------------------

# Minimal response headers shared by all canned responses; the body is
# always a JSON-encoded ChatCompletion object.
_JSON_HEADERS = {
    "content-type": "application/json",
}
class MockOpenAITransport(httpx.AsyncBaseTransport, httpx.BaseTransport):
    """
    httpx transport that returns canned OpenAI ChatCompletion responses.
    Supports both async (AsyncOpenAI) and sync (OpenAI) SDK paths.

    Both handler entry points delegate to ``_build_response`` so the sync
    and async paths cannot drift apart.
    """

    @staticmethod
    def _parse_request(request: httpx.Request) -> Tuple[str, bool]:
        """Extract model from the request body.

        Returns (model, False); the boolean is currently unused by callers.
        Falls back to "mock-model" for unparseable or non-object bodies.
        """
        try:
            body = json.loads(request.content)
        except (json.JSONDecodeError, ValueError):
            return ("mock-model", False)
        # Valid JSON that is not an object (e.g. a bare list) would raise
        # AttributeError on .get below — treat it like an unparseable body.
        if not isinstance(body, dict):
            return ("mock-model", False)
        model = body.get("model", "mock-model")
        return (model, False)

    def _build_response(self, request: httpx.Request) -> httpx.Response:
        """Build a canned 200 ChatCompletion response for the request's model."""
        model, _ = self._parse_request(request)
        body = json.dumps(_chat_completion_json(model)).encode()
        return httpx.Response(
            status_code=200,
            headers=_JSON_HEADERS,
            content=body,
        )

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        return self._build_response(request)

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        return self._build_response(request)