chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,169 @@
|
||||
def show_missing_vars_in_env():
    """Render an HTML help page when required proxy env vars are unset.

    Returns:
        HTMLResponse (status 200) listing the missing variable names, or
        ``None`` when both DATABASE_URL and LITELLM_MASTER_KEY are configured.
    """
    from fastapi.responses import HTMLResponse

    from litellm.proxy.proxy_server import master_key, prisma_client

    # Collect every missing required variable so the page shows all at once.
    unset_vars = []
    if prisma_client is None:
        unset_vars.append("DATABASE_URL")
    if master_key is None:
        unset_vars.append("LITELLM_MASTER_KEY")

    if unset_vars:
        return HTMLResponse(
            content=missing_keys_form(missing_key_names=", ".join(unset_vars)),
            status_code=200,
        )
    return None
|
||||
|
||||
|
||||
def missing_keys_form(missing_key_names: str):
    """Render a static HTML setup page listing missing environment variables.

    Args:
        missing_key_names: Comma-separated variable names; interpolated
            verbatim into the "Missing Environment Variables" section.

    Returns:
        str: A complete HTML document with setup instructions and support links.
    """
    # NOTE: literal CSS braces are doubled ({{ }}) because this template is
    # rendered with str.format, which reserves single braces for placeholders.
    missing_keys_html_form = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            body {{
                font-family: Arial, sans-serif;
                background-color: #f4f4f9;
                color: #333;
                margin: 20px;
                line-height: 1.6;
            }}
            .container {{
                max-width: 800px;
                margin: auto;
                padding: 20px;
                background: #fff;
                border: 1px solid #ddd;
                border-radius: 5px;
                box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
            }}
            h1 {{
                font-size: 24px;
                margin-bottom: 20px;
            }}
            pre {{
                background: #f8f8f8;
                padding: 1px;
                border: 1px solid #ccc;
                border-radius: 4px;
                overflow-x: auto;
                font-size: 14px;
            }}
            .env-var {{
                font-weight: normal;
            }}
            .comment {{
                font-weight: normal;
                color: #777;
            }}
        </style>
        <title>Environment Setup Instructions</title>
    </head>
    <body>
        <div class="container">
            <h1>Environment Setup Instructions</h1>
            <p>Please add the following variables to your environment variables:</p>
            <pre>
    <span class="env-var">LITELLM_MASTER_KEY="sk-1234"</span> <span class="comment"># Your master key for the proxy server. Can use this to send /chat/completion requests etc</span>
    <span class="env-var">LITELLM_SALT_KEY="sk-XXXXXXXX"</span> <span class="comment"># Can NOT CHANGE THIS ONCE SET - It is used to encrypt/decrypt credentials stored in DB. If value of 'LITELLM_SALT_KEY' changes your models cannot be retrieved from DB</span>
    <span class="env-var">DATABASE_URL="postgres://..."</span> <span class="comment"># Need a postgres database? (Check out Supabase, Neon, etc)</span>
    <span class="comment">## OPTIONAL ##</span>
    <span class="env-var">PORT=4000</span> <span class="comment"># DO THIS FOR RENDER/RAILWAY</span>
    <span class="env-var">STORE_MODEL_IN_DB="True"</span> <span class="comment"># Allow storing models in db</span>
    </pre>
            <h1>Missing Environment Variables</h1>
            <p>{missing_keys}</p>
        </div>

        <div class="container">
            <h1>Need Help? Support</h1>
            <p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
            <p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
        </div>
    </body>
    </html>
    """
    return missing_keys_html_form.format(missing_keys=missing_key_names)
|
||||
|
||||
|
||||
def admin_ui_disabled():
    """Render the page shown when the Admin UI is disabled via DISABLE_ADMIN_UI.

    Returns:
        HTMLResponse (status 200) explaining how to re-enable the UI.
    """
    from fastapi.responses import HTMLResponse

    # NOTE: literal CSS braces are doubled ({{ }}) to match the sibling
    # missing_keys_form template style; this string is not format()-ed, the
    # doubled braces are rendered literally in the <style> block — presumably
    # harmless to browsers, but TODO confirm intended.
    ui_disabled_html = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            body {{
                font-family: Arial, sans-serif;
                background-color: #f4f4f9;
                color: #333;
                margin: 20px;
                line-height: 1.6;
            }}
            .container {{
                max-width: 800px;
                margin: auto;
                padding: 20px;
                background: #fff;
                border: 1px solid #ddd;
                border-radius: 5px;
                box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
            }}
            h1 {{
                font-size: 24px;
                margin-bottom: 20px;
            }}
            pre {{
                background: #f8f8f8;
                padding: 1px;
                border: 1px solid #ccc;
                border-radius: 4px;
                overflow-x: auto;
                font-size: 14px;
            }}
            .env-var {{
                font-weight: normal;
            }}
            .comment {{
                font-weight: normal;
                color: #777;
            }}
        </style>
        <title>Admin UI Disabled</title>
    </head>
    <body>
        <div class="container">
            <h1>Admin UI is Disabled</h1>
            <p>The Admin UI has been disabled by the administrator. To re-enable it, please update the following environment variable:</p>
            <pre>
    <span class="env-var">DISABLE_ADMIN_UI="False"</span> <span class="comment"># Set this to "False" to enable the Admin UI.</span>
    </pre>
            <p>After making this change, restart the application for it to take effect.</p>
        </div>

        <div class="container">
            <h1>Need Help? Support</h1>
            <p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
            <p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
        </div>
    </body>
    </html>
    """

    return HTMLResponse(
        content=ui_disabled_html,
        status_code=200,
    )
|
||||
@@ -0,0 +1,17 @@
|
||||
# LiteLLM ASCII banner
LITELLM_BANNER = """ ██╗ ██╗████████╗███████╗██╗ ██╗ ███╗ ███╗
██║ ██║╚══██╔══╝██╔════╝██║ ██║ ████╗ ████║
██║ ██║ ██║ █████╗ ██║ ██║ ██╔████╔██║
██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║
███████╗██║ ██║ ███████╗███████╗███████╗██║ ╚═╝ ██║
╚══════╝╚═╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝"""


def show_banner():
    """Display the LiteLLM CLI banner.

    Prefers ``click.echo`` when click is installed; otherwise falls back to a
    plain ``print``. Fix: the previous fallback printed only a blank line,
    silently dropping the banner when click was unavailable.
    """
    try:
        import click

        click.echo(f"\n{LITELLM_BANNER}\n")
    except ImportError:
        print(f"\n{LITELLM_BANNER}\n")  # noqa: T201
||||
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Event-driven cache coordinator to prevent cache stampede.
|
||||
|
||||
Use this when many requests can miss the same cache key at once (e.g. after
|
||||
expiry or restart). Without coordination, they would all run the expensive
|
||||
load (DB query, API call) in parallel and overload the backend.
|
||||
|
||||
This module ensures only one request performs the load; the rest wait for a
|
||||
signal and then read the freshly cached value. Reuse it for any cache-aside
|
||||
pattern: global spend, feature flags, config, or other shared read-through data.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Optional, Protocol, TypeVar
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AsyncCacheProtocol(Protocol):
    """Protocol for cache backends used by EventDrivenCacheCoordinator."""

    async def async_get_cache(self, key: str, **kwargs: Any) -> Any:
        """Return the value cached under ``key`` (None is used as a miss by callers)."""
        ...

    async def async_set_cache(self, key: str, value: Any, **kwargs: Any) -> Any:
        """Store ``value`` under ``key``."""
        ...
|
||||
|
||||
|
||||
class EventDrivenCacheCoordinator:
    """
    Coordinates a single in-flight load per logical resource to prevent cache stampede.

    Pattern:
    - First request: loads data (e.g. DB query), caches it, then signals waiters.
    - Other requests: wait for the signal, then read from cache.

    Create one instance per resource (e.g. one for global spend, one for feature flags).
    """

    def __init__(self, log_prefix: str = "[CACHE]"):
        # Guards _query_in_progress/_event; only held for the brief state
        # transitions, never across the load itself.
        self._lock = asyncio.Lock()
        # Non-None exactly while a load is in flight; waiters block on it.
        self._event: Optional[asyncio.Event] = None
        self._query_in_progress = False
        # An empty prefix disables all debug logging in this class.
        self._log_prefix = log_prefix

    async def _get_cached(
        self, cache_key: str, cache: AsyncCacheProtocol
    ) -> Optional[Any]:
        """Return value from cache if present, else None."""
        return await cache.async_get_cache(key=cache_key)

    def _log_cache_hit(self, value: T) -> None:
        # Debug-only; gated on the prefix so logging can be switched off.
        if self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Cache hit, value: %s", self._log_prefix, value
            )

    def _log_cache_miss(self) -> None:
        if self._log_prefix:
            verbose_proxy_logger.debug("%s Cache miss", self._log_prefix)

    async def _claim_role(self) -> Optional[asyncio.Event]:
        """
        Under lock: return event to wait on if load is in progress, else set us as loader and return None.
        """
        async with self._lock:
            if self._query_in_progress and self._event is not None:
                if self._log_prefix:
                    verbose_proxy_logger.debug(
                        "%s Load in flight, waiting for signal", self._log_prefix
                    )
                # Return the existing event so the caller waits as a follower.
                return self._event
            # We are the loader: publish a fresh event for others to wait on.
            self._query_in_progress = True
            self._event = asyncio.Event()
            if self._log_prefix:
                verbose_proxy_logger.debug(
                    "%s Starting load (will signal others when done)",
                    self._log_prefix,
                )
            return None

    async def _wait_for_signal_and_get(
        self,
        event: asyncio.Event,
        cache_key: str,
        cache: AsyncCacheProtocol,
    ) -> Optional[T]:
        """Wait for loader to finish, then read from cache."""
        await event.wait()
        if self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Signal received, reading from cache", self._log_prefix
            )
        # May still be None if the loader's load_fn raised or cached nothing.
        value: Optional[T] = await cache.async_get_cache(key=cache_key)
        if value is not None and self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Cache filled by other request, value: %s",
                self._log_prefix,
                value,
            )
        elif value is None and self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Signal received but cache still empty", self._log_prefix
            )
        return value

    async def _load_and_cache(
        self,
        cache_key: str,
        cache: AsyncCacheProtocol,
        load_fn: Callable[[], Awaitable[T]],
    ) -> Optional[T]:
        """Double-check cache, run load_fn, set cache, return value. Caller must call _signal_done in finally."""
        # Double-check: another task may have populated the cache between our
        # miss and winning the loader role.
        value = await cache.async_get_cache(key=cache_key)
        if value is not None:
            if self._log_prefix:
                verbose_proxy_logger.debug(
                    "%s Cache filled while acquiring lock, value: %s",
                    self._log_prefix,
                    value,
                )
            return value

        if self._log_prefix:
            verbose_proxy_logger.debug("%s Running load", self._log_prefix)
        start = time.perf_counter()
        value = await load_fn()
        elapsed_ms = (time.perf_counter() - start) * 1000
        if self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Load completed in %.2fms, result: %s",
                self._log_prefix,
                elapsed_ms,
                value,
            )

        # NOTE: a None result from load_fn is still written to the cache here;
        # backends that treat None as a miss will simply re-load next time.
        await cache.async_set_cache(key=cache_key, value=value)
        if self._log_prefix:
            verbose_proxy_logger.debug("%s Result cached", self._log_prefix)
        return value

    async def _signal_done(self) -> None:
        """Reset loader state and signal all waiters."""
        async with self._lock:
            self._query_in_progress = False
            if self._event is not None:
                if self._log_prefix:
                    verbose_proxy_logger.debug(
                        "%s Signaling all waiting requests", self._log_prefix
                    )
                self._event.set()
                # Drop the event so the next miss creates a fresh one.
                self._event = None

    async def get_or_load(
        self,
        cache_key: str,
        cache: AsyncCacheProtocol,
        load_fn: Callable[[], Awaitable[T]],
    ) -> Optional[T]:
        """
        Return cached value or load it once and signal waiters.

        - cache_key: Key to read/write in the cache.
        - cache: Object with async_get_cache(key) and async_set_cache(key, value).
        - load_fn: Async callable that performs the load (e.g. DB query). No args.
          Return value is cached and returned. If it raises, waiters are
          still signaled so they can retry or handle empty cache.

        Returns the value from cache or from load_fn, or None if load failed or
        cache was still empty after waiting.
        """
        value = await self._get_cached(cache_key, cache)
        if value is not None:
            self._log_cache_hit(value)
            return value

        self._log_cache_miss()
        event_to_wait = await self._claim_role()

        # Follower path: someone else is loading — wait, then read the cache.
        if event_to_wait is not None:
            return await self._wait_for_signal_and_get(event_to_wait, cache_key, cache)

        # Loader path: finally guarantees waiters are released even on error.
        try:
            result = await self._load_and_cache(cache_key, cache, load_fn)
            return result
        finally:
            await self._signal_done()
||||
@@ -0,0 +1,526 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
|
||||
from litellm.proxy.types_utils.utils import get_instance_fn
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingGuardrailInformation,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
blue_color_code = "\033[94m"
|
||||
reset_color_code = "\033[0m"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
|
||||
|
||||
def initialize_callbacks_on_proxy(  # noqa: PLR0915
    value: Any,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
    callback_specific_params: Optional[dict] = None,
):
    """
    Instantiate and register the callbacks configured for the proxy.

    Args:
        value: Either a list of callback names / CustomLogger instances, or a
            single custom-callback import path string.
        premium_user: Whether the deployment is licensed; enterprise-only
            callbacks raise when this is not True.
        config_file_path: Proxy config path, used to resolve custom callback imports.
        litellm_settings: The ``litellm_settings`` block of the proxy config.
        callback_specific_params: Optional per-callback init kwargs,
            e.g. ``{"presidio": {...}}``. Defaults to an empty dict.
            (Fixed: was a mutable default argument ``{}``.)

    Side effects:
        Extends or replaces ``litellm.callbacks``; for "prometheus", mounts
        the metrics endpoint.
    """
    from litellm.integrations.custom_logger import CustomLogger
    from litellm.litellm_core_utils.logging_callback_manager import (
        LoggingCallbackManager,
    )
    from litellm.proxy.proxy_server import prisma_client

    if callback_specific_params is None:
        callback_specific_params = {}

    verbose_proxy_logger.debug(
        f"{blue_color_code}initializing callbacks={value} on proxy{reset_color_code}"
    )
    if isinstance(value, list):
        imported_list: List[Any] = []
        for callback in value:  # ["presidio", <my-custom-callback>]
            # check if callback is a custom logger compatible callback
            if isinstance(callback, str):
                callback = LoggingCallbackManager._add_custom_callback_generic_api_str(
                    callback
                )
            if (
                isinstance(callback, str)
                and callback in litellm._known_custom_logger_compatible_callbacks
            ):
                imported_list.append(callback)
            elif isinstance(callback, str) and callback == "presidio":
                from litellm.proxy.guardrails.guardrail_hooks.presidio import (
                    _OPTIONAL_PresidioPIIMasking,
                )

                presidio_logging_only: Optional[bool] = litellm_settings.get(
                    "presidio_logging_only", None
                )
                if presidio_logging_only is not None:
                    presidio_logging_only = bool(
                        presidio_logging_only
                    )  # validate boolean given

                _presidio_params = {}
                if "presidio" in callback_specific_params and isinstance(
                    callback_specific_params["presidio"], dict
                ):
                    _presidio_params = callback_specific_params["presidio"]

                params: Dict[str, Any] = {
                    "logging_only": presidio_logging_only,
                    **_presidio_params,
                }
                pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
                imported_list.append(pii_masking_object)
            elif isinstance(callback, str) and callback == "llamaguard_moderations":
                try:
                    from litellm_enterprise.enterprise_callbacks.llama_guard import (
                        _ENTERPRISE_LlamaGuard,
                    )
                except ImportError:
                    # Fixed: message previously read "MissingTrying to use Llama Guard".
                    raise Exception(
                        "Trying to use Llama Guard"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Llama Guard"
                        + CommonProxyErrors.not_premium_user.value
                    )

                llama_guard_object = _ENTERPRISE_LlamaGuard()
                imported_list.append(llama_guard_object)
            elif isinstance(callback, str) and callback == "hide_secrets":
                try:
                    from litellm_enterprise.enterprise_callbacks.secret_detection import (
                        _ENTERPRISE_SecretDetection,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Secret Detection"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use secret hiding"
                        + CommonProxyErrors.not_premium_user.value
                    )

                _secret_detection_object = _ENTERPRISE_SecretDetection()
                imported_list.append(_secret_detection_object)
            elif isinstance(callback, str) and callback == "openai_moderations":
                try:
                    from enterprise.enterprise_hooks.openai_moderation import (
                        _ENTERPRISE_OpenAI_Moderation,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use OpenAI Moderations Check,"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use OpenAI Moderations Check"
                        + CommonProxyErrors.not_premium_user.value
                    )

                openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
                imported_list.append(openai_moderations_object)
            elif isinstance(callback, str) and callback == "lakera_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
                    lakeraAI_Moderation,
                )

                init_params = {}
                if "lakera_prompt_injection" in callback_specific_params:
                    init_params = callback_specific_params["lakera_prompt_injection"]
                lakera_moderations_object = lakeraAI_Moderation(**init_params)
                imported_list.append(lakera_moderations_object)
            elif isinstance(callback, str) and callback == "aporia_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.aporia_ai.aporia_ai import (
                    AporiaGuardrail,
                )

                aporia_guardrail_object = AporiaGuardrail()
                imported_list.append(aporia_guardrail_object)
            elif isinstance(callback, str) and callback == "google_text_moderation":
                try:
                    from enterprise.enterprise_hooks.google_text_moderation import (
                        _ENTERPRISE_GoogleTextModeration,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Google Text Moderation,"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Google Text Moderation"
                        + CommonProxyErrors.not_premium_user.value
                    )

                google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
                imported_list.append(google_text_moderation_obj)
            elif isinstance(callback, str) and callback == "llmguard_moderations":
                try:
                    from litellm_enterprise.enterprise_callbacks.llm_guard import (
                        _ENTERPRISE_LLMGuard,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Llm Guard"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Llm Guard"
                        + CommonProxyErrors.not_premium_user.value
                    )

                llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
                imported_list.append(llm_guard_moderation_obj)
            elif isinstance(callback, str) and callback == "blocked_user_check":
                try:
                    from enterprise.enterprise_hooks.blocked_user_list import (
                        _ENTERPRISE_BlockedUserList,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Blocked User List"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BlockedUser"
                        + CommonProxyErrors.not_premium_user.value
                    )

                blocked_user_list = _ENTERPRISE_BlockedUserList(
                    prisma_client=prisma_client
                )
                imported_list.append(blocked_user_list)
            elif isinstance(callback, str) and callback == "banned_keywords":
                try:
                    from enterprise.enterprise_hooks.banned_keywords import (
                        _ENTERPRISE_BannedKeywords,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Banned Keywords"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BannedKeyword"
                        + CommonProxyErrors.not_premium_user.value
                    )

                banned_keywords_obj = _ENTERPRISE_BannedKeywords()
                imported_list.append(banned_keywords_obj)
            elif isinstance(callback, str) and callback == "detect_prompt_injection":
                from litellm.proxy.hooks.prompt_injection_detection import (
                    _OPTIONAL_PromptInjectionDetection,
                )

                prompt_injection_params = None
                if "prompt_injection_params" in litellm_settings:
                    prompt_injection_params_in_config = litellm_settings[
                        "prompt_injection_params"
                    ]
                    prompt_injection_params = LiteLLMPromptInjectionParams(
                        **prompt_injection_params_in_config
                    )

                prompt_injection_detection_obj = _OPTIONAL_PromptInjectionDetection(
                    prompt_injection_params=prompt_injection_params,
                )
                imported_list.append(prompt_injection_detection_obj)
            elif isinstance(callback, str) and callback == "batch_redis_requests":
                from litellm.proxy.hooks.batch_redis_get import (
                    _PROXY_BatchRedisRequests,
                )

                batch_redis_obj = _PROXY_BatchRedisRequests()
                imported_list.append(batch_redis_obj)
            elif isinstance(callback, str) and callback == "azure_content_safety":
                from litellm.proxy.hooks.azure_content_safety import (
                    _PROXY_AzureContentSafety,
                )

                azure_content_safety_params = litellm_settings[
                    "azure_content_safety_params"
                ]
                # Resolve "os.environ/..." references into real secret values.
                for k, v in azure_content_safety_params.items():
                    if (
                        v is not None
                        and isinstance(v, str)
                        and v.startswith("os.environ/")
                    ):
                        azure_content_safety_params[k] = get_secret(v)

                azure_content_safety_obj = _PROXY_AzureContentSafety(
                    **azure_content_safety_params,
                )
                imported_list.append(azure_content_safety_obj)
            elif isinstance(callback, str) and callback == "websearch_interception":
                from litellm.integrations.websearch_interception.handler import (
                    WebSearchInterceptionLogger,
                )

                websearch_interception_obj = (
                    WebSearchInterceptionLogger.initialize_from_proxy_config(
                        litellm_settings=litellm_settings,
                        callback_specific_params=callback_specific_params,
                    )
                )
                imported_list.append(websearch_interception_obj)
            elif isinstance(callback, str) and callback == "datadog_cost_management":
                from litellm.integrations.datadog.datadog_cost_management import (
                    DatadogCostManagementLogger,
                )

                datadog_cost_management_obj = DatadogCostManagementLogger()
                imported_list.append(datadog_cost_management_obj)
            elif isinstance(callback, CustomLogger):
                imported_list.append(callback)
            else:
                # Fixed typo: "calback" -> "callback" in the debug message.
                verbose_proxy_logger.debug(
                    f"{blue_color_code} attempting to import custom callback={callback} {reset_color_code}"
                )
                imported_list.append(
                    get_instance_fn(
                        value=callback,
                        config_file_path=config_file_path,
                    )
                )
        if isinstance(litellm.callbacks, list):
            litellm.callbacks.extend(imported_list)
        else:
            litellm.callbacks = imported_list  # type: ignore

        if "prometheus" in value:
            from litellm.integrations.prometheus import PrometheusLogger

            PrometheusLogger._mount_metrics_endpoint()
    else:
        # Single custom-callback path string: import it and replace callbacks.
        litellm.callbacks = [
            get_instance_fn(
                value=value,
                config_file_path=config_file_path,
            )
        ]
    verbose_proxy_logger.debug(
        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
    )
|
||||
|
||||
|
||||
def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
    """Extract the model_group from the metadata embedded in litellm kwargs.

    Looks under ``litellm_params`` using whichever metadata key the request
    uses ("metadata" or "litellm_metadata"); returns None when absent.
    """
    litellm_params = kwargs.get("litellm_params") or {}
    metadata_key = get_metadata_variable_name_from_kwargs(kwargs)
    metadata = litellm_params.get(metadata_key) or {}
    return metadata.get("model_group")
|
||||
|
||||
|
||||
def get_model_group_from_request_data(data: dict) -> Optional[str]:
    """Return ``metadata["model_group"]`` from request data, or None if absent."""
    metadata = data.get("metadata") or {}
    return metadata.get("model_group")
|
||||
|
||||
|
||||
def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]:
    """
    Helper function to return x-litellm-key-remaining-tokens-{model_group} and x-litellm-key-remaining-requests-{model_group}

    Returns {} when api_key + model rpm/tpm limit is not set

    """
    metadata = data.get("metadata") or {}
    model_group = get_model_group_from_request_data(data)

    # The h11 package considers "/" or ":" invalid and raise a LocalProtocolError
    safe_group = (
        model_group.replace("/", "-").replace(":", "-") if model_group else None
    )

    headers: Dict[str, str] = {}
    # "requests" first, then "tokens" — preserves the original header order.
    for kind in ("requests", "tokens"):
        remaining = metadata.get(f"litellm-key-remaining-{kind}-{model_group}")
        if remaining:
            headers[f"x-litellm-key-remaining-{kind}-{safe_group}"] = remaining
    return headers
|
||||
|
||||
|
||||
def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
    """Build response headers describing guardrails/policies applied to a request.

    Reads "metadata" (falling back to "litellm_metadata") from request_data and
    returns a dict of x-litellm-* headers; empty dict when nothing applies.
    """
    meta = request_data.get("metadata") or request_data.get("litellm_metadata")
    if not isinstance(meta, dict):
        meta = {}

    headers: Dict = {}
    if "applied_guardrails" in meta:
        headers["x-litellm-applied-guardrails"] = ",".join(meta["applied_guardrails"])

    if "applied_policies" in meta:
        headers["x-litellm-applied-policies"] = ",".join(meta["applied_policies"])

    if "policy_sources" in meta:
        sources = meta["policy_sources"]
        if isinstance(sources, dict) and sources:
            # Use ';' as delimiter — matched_via reasons may contain commas
            headers["x-litellm-policy-sources"] = "; ".join(
                f"{name}={reason}" for name, reason in sources.items()
            )

    if "semantic-similarity" in meta:
        headers["x-litellm-semantic-similarity"] = str(meta["semantic-similarity"])

    pillar_headers = meta.get("pillar_response_headers")
    if isinstance(pillar_headers, dict):
        headers.update(pillar_headers)
    elif "pillar_flagged" in meta:
        headers["x-pillar-flagged"] = str(meta["pillar_flagged"]).lower()

    return headers
|
||||
|
||||
|
||||
def add_guardrail_to_applied_guardrails_header(
    request_data: Dict, guardrail_name: Optional[str]
):
    """Append guardrail_name to metadata["applied_guardrails"] in request_data.

    No-op when guardrail_name is None. Duplicates are allowed (no dedup).
    """
    if guardrail_name is None:
        return
    metadata = request_data.get("metadata") or {}
    metadata.setdefault("applied_guardrails", []).append(guardrail_name)
    # Write back so the key exists even when metadata was missing initially.
    request_data["metadata"] = metadata
|
||||
|
||||
|
||||
def add_policy_to_applied_policies_header(
    request_data: Dict, policy_name: Optional[str]
):
    """
    Add a policy name to the applied_policies list in request metadata.

    This is used to track which policies were applied to a request,
    similar to how applied_guardrails tracks guardrails.
    """
    if policy_name is None:
        return
    metadata = request_data.get("metadata") or {}
    applied = metadata.setdefault("applied_policies", [])
    # Unlike guardrails, policies are de-duplicated.
    if policy_name not in applied:
        applied.append(policy_name)
    # Write back so the key exists even when metadata was missing initially.
    request_data["metadata"] = metadata
|
||||
|
||||
|
||||
def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str, str]):
    """
    Store policy match reasons in metadata for x-litellm-policy-sources header.

    Args:
        request_data: The request data dict
        policy_sources: Map of policy_name -> matched_via reason
    """
    if not policy_sources:
        return
    metadata = request_data.get("metadata") or {}
    current = metadata.get("policy_sources", {})
    # Discard a malformed (non-dict) value rather than merging into it.
    if not isinstance(current, dict):
        current = {}
    current.update(policy_sources)
    metadata["policy_sources"] = current
    request_data["metadata"] = metadata
|
||||
|
||||
|
||||
def add_guardrail_response_to_standard_logging_object(
    litellm_logging_obj: Optional["LiteLLMLogging"],
    guardrail_response: StandardLoggingGuardrailInformation,
):
    """Append guardrail_response to the payload's guardrail_information list.

    Returns the updated standard logging payload, or None when the logging
    object or its payload is absent.
    """
    if litellm_logging_obj is None:
        return
    payload: Optional[StandardLoggingPayload] = (
        litellm_logging_obj.model_call_details.get("standard_logging_object")
    )
    if payload is None:
        return
    entries = payload.get("guardrail_information", [])
    if entries is None:
        entries = []
    entries.append(guardrail_response)
    payload["guardrail_information"] = entries

    return payload
|
||||
|
||||
|
||||
def get_metadata_variable_name_from_kwargs(
    kwargs: dict,
) -> Literal["metadata", "litellm_metadata"]:
    """
    Helper to return what the "metadata" field should be called in the request data

    - New endpoints return `litellm_metadata`
    - Old endpoints return `metadata`

    Context:
    - LiteLLM used `metadata` as an internal field for storing metadata
    - OpenAI then started using this field for their metadata
    - LiteLLM is now moving to using `litellm_metadata` for our metadata
    """
    if "litellm_metadata" in kwargs:
        return "litellm_metadata"
    return "metadata"
|
||||
|
||||
|
||||
def process_callback(
    _callback: str, callback_type: str, environment_variables: dict
) -> dict:
    """Process a single callback and return its data with environment variables"""
    required_vars = CustomLogger.get_callback_env_vars(_callback)

    # Resolve each required env var; unresolved vars are reported as None.
    resolved_vars: dict[str, str | None] = {
        var: environment_variables.get(var, None) for var in required_vars
    }

    return {"name": _callback, "variables": resolved_vars, "type": callback_type}
|
||||
|
||||
|
||||
def normalize_callback_names(callbacks: Optional[Iterable[Any]]) -> List[Any]:
    """Lowercase string callback names, passing non-string entries through.

    The previous annotation claimed a non-optional ``Iterable``, yet the body
    explicitly handled ``None`` — the annotation is corrected to match.

    Args:
        callbacks: Iterable of callback names/objects, or None.

    Returns:
        A list where every string entry is lowercased; an empty list when
        ``callbacks`` is None.
    """
    if callbacks is None:
        return []
    return [c.lower() if isinstance(c, str) else c for c in callbacks]
|
||||
@@ -0,0 +1,437 @@
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
class CustomOpenAPISpec:
    """
    Handler for customizing OpenAPI specifications with Pydantic models
    for documentation purposes without runtime validation.
    """

    # POST paths that should document the chat-completion request body.
    CHAT_COMPLETION_PATHS = [
        "/v1/chat/completions",
        "/chat/completions",
        "/engines/{model}/chat/completions",
        "/openai/deployments/{model}/chat/completions",
    ]

    # POST paths that should document the embedding request body.
    EMBEDDING_PATHS = [
        "/v1/embeddings",
        "/embeddings",
        "/engines/{model}/embeddings",
        "/openai/deployments/{model}/embeddings",
    ]

    # POST paths that should document the Responses API request body.
    RESPONSES_API_PATHS = ["/v1/responses", "/responses"]

    @staticmethod
    def get_pydantic_schema(model_class: Type) -> Optional[Dict[str, Any]]:
        """
        Get JSON schema from a Pydantic model, handling both v1 and v2 APIs.

        Args:
            model_class: Pydantic model class

        Returns:
            JSON schema dict or None if failed
        """
        try:
            # Try Pydantic v2 method first
            return model_class.model_json_schema()  # type: ignore
        except AttributeError:
            try:
                # Fallback to Pydantic v1 method
                return model_class.schema()  # type: ignore
            except AttributeError:
                # If both methods fail, return None
                return None
        except Exception as e:
            # FastAPI 0.120+ may fail schema generation for certain types (e.g., openai.Timeout)
            # Log the error and return None to skip schema generation for this model
            verbose_proxy_logger.debug(
                f"Failed to generate schema for {model_class}: {e}"
            )
            return None

    @staticmethod
    def add_schema_to_components(
        openapi_schema: Dict[str, Any], schema_name: str, schema_def: Dict[str, Any]
    ) -> None:
        """
        Add a schema definition to the OpenAPI components/schemas section.

        Args:
            openapi_schema: The OpenAPI schema dict to modify
            schema_name: Name for the schema component
            schema_def: The schema definition
        """
        # Ensure components/schemas structure exists
        if "components" not in openapi_schema:
            openapi_schema["components"] = {}
        if "schemas" not in openapi_schema["components"]:
            openapi_schema["components"]["schemas"] = {}

        # Add the schema; routed through _move_defs_to_components so any
        # embedded $defs are hoisted and internal $refs rewritten as well.
        CustomOpenAPISpec._move_defs_to_components(
            openapi_schema, {schema_name: schema_def}
        )

    @staticmethod
    def add_request_body_to_paths(
        openapi_schema: Dict[str, Any], paths: List[str], schema_ref: str
    ) -> None:
        """
        Add request body with expanded form fields for better Swagger UI display.
        This keeps the request body but expands it to show individual fields in the UI.

        Args:
            openapi_schema: The OpenAPI schema dict to modify
            paths: List of paths to update
            schema_ref: Reference to the schema component (e.g., "#/components/schemas/ModelName")
        """
        for path in paths:
            if (
                path in openapi_schema.get("paths", {})
                and "post" in openapi_schema["paths"][path]
            ):
                # Get the actual schema to extract ALL field definitions
                schema_name = schema_ref.split("/")[
                    -1
                ]  # Extract "ProxyChatCompletionRequest" from the ref
                actual_schema = (
                    openapi_schema.get("components", {})
                    .get("schemas", {})
                    .get(schema_name, {})
                )
                schema_properties = actual_schema.get("properties", {})
                required_fields = actual_schema.get("required", [])

                # Extract $defs and add them to components/schemas
                # This fixes Pydantic v2 $defs not being resolvable in Swagger/OpenAPI
                if "$defs" in actual_schema:
                    CustomOpenAPISpec._move_defs_to_components(
                        openapi_schema, actual_schema["$defs"]
                    )

                # Create an expanded inline schema instead of just a $ref
                # This makes Swagger UI show all individual fields in the request body editor
                expanded_schema = {
                    "type": "object",
                    "required": required_fields,
                    "properties": {},
                }

                # Add all properties with their full definitions
                for field_name, field_def in schema_properties.items():
                    expanded_field = CustomOpenAPISpec._expand_field_definition(
                        field_def
                    )

                    # Rewrite $defs references to use components/schemas instead
                    expanded_field = CustomOpenAPISpec._rewrite_defs_refs(
                        expanded_field
                    )

                    # Add a simple example for the messages field
                    if field_name == "messages":
                        expanded_field["example"] = [
                            {"role": "user", "content": "Hello, how are you?"}
                        ]

                    expanded_schema["properties"][field_name] = expanded_field

                # Set the request body with the expanded schema
                openapi_schema["paths"][path]["post"]["requestBody"] = {
                    "required": True,
                    "content": {"application/json": {"schema": expanded_schema}},
                }

                # Keep any existing parameters (like path parameters) but remove conflicting query params
                if "parameters" in openapi_schema["paths"][path]["post"]:
                    existing_params = openapi_schema["paths"][path]["post"][
                        "parameters"
                    ]
                    # Only keep path parameters, remove query params that conflict with request body
                    filtered_params = [
                        param for param in existing_params if param.get("in") == "path"
                    ]
                    openapi_schema["paths"][path]["post"][
                        "parameters"
                    ] = filtered_params

    @staticmethod
    def _move_defs_to_components(
        openapi_schema: Dict[str, Any], defs: Dict[str, Any]
    ) -> None:
        """
        Move $defs from Pydantic v2 schema to OpenAPI components/schemas.
        This makes the definitions resolvable in Swagger/OpenAPI viewers.

        Args:
            openapi_schema: The OpenAPI schema dict to modify
            defs: The $defs dictionary from Pydantic schema
        """
        if not defs:
            return

        # Ensure components/schemas exists
        if "components" not in openapi_schema:
            openapi_schema["components"] = {}
        if "schemas" not in openapi_schema["components"]:
            openapi_schema["components"]["schemas"] = {}

        # Add each definition to components/schemas
        for def_name, def_schema in defs.items():
            # Recursively rewrite any nested $defs references within this definition
            rewritten_def = CustomOpenAPISpec._rewrite_defs_refs(def_schema)
            openapi_schema["components"]["schemas"][def_name] = rewritten_def

            # If this definition also has $defs, process them recursively
            if "$defs" in def_schema:
                CustomOpenAPISpec._move_defs_to_components(
                    openapi_schema, def_schema["$defs"]
                )

    @staticmethod
    def _rewrite_defs_refs(schema: Any) -> Any:
        """
        Recursively rewrite $ref values from #/$defs/... to #/components/schemas/...
        This converts Pydantic v2 references to OpenAPI-compatible references.

        Args:
            schema: Schema object to process (can be dict, list, or primitive)

        Returns:
            Schema with rewritten references
        """
        if isinstance(schema, dict):
            result = {}
            for key, value in schema.items():
                if (
                    key == "$ref"
                    and isinstance(value, str)
                    and value.startswith("#/$defs/")
                ):
                    # Rewrite the reference to use components/schemas
                    def_name = value.replace("#/$defs/", "")
                    result[key] = f"#/components/schemas/{def_name}"
                elif key == "$defs":
                    # Remove $defs from the schema since they're moved to components
                    continue
                else:
                    # Recursively process nested structures
                    result[key] = CustomOpenAPISpec._rewrite_defs_refs(value)
            return result
        elif isinstance(schema, list):
            return [CustomOpenAPISpec._rewrite_defs_refs(item) for item in schema]
        else:
            # Primitives (str/int/bool/None) pass through untouched.
            return schema

    @staticmethod
    def _extract_field_schema(field_def: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract a simple schema from a Pydantic field definition for parameter display.

        Args:
            field_def: Pydantic field definition

        Returns:
            Simplified schema for OpenAPI parameter
        """
        # Handle simple types
        if "type" in field_def:
            return {"type": field_def["type"]}

        # Handle anyOf (Optional fields in Pydantic v2)
        if "anyOf" in field_def:
            any_of = field_def["anyOf"]
            # Find the non-null type
            for option in any_of:
                if option.get("type") != "null":
                    return option
            # Fallback to string if all else fails
            return {"type": "string"}

        # Default fallback
        return {"type": "string"}

    @staticmethod
    def _expand_field_definition(field_def: Dict[str, Any]) -> Dict[str, Any]:
        """
        Expand a Pydantic field definition for inline use in OpenAPI schema.
        This creates a full field definition that Swagger UI can render as individual form fields.

        Args:
            field_def: Pydantic field definition

        Returns:
            Expanded field definition for OpenAPI schema
        """
        # Return a copy of the field definition as-is since Pydantic already
        # provides proper schemas; the copy protects the source from later mutation.
        return field_def.copy()

    @staticmethod
    def add_request_schema(
        openapi_schema: Dict[str, Any],
        model_class: Type,
        schema_name: str,
        paths: List[str],
        operation_name: str,
    ) -> Dict[str, Any]:
        """
        Generic method to add a request schema to OpenAPI specification.

        Args:
            openapi_schema: The OpenAPI schema dict to modify
            model_class: The Pydantic model class to get schema from
            schema_name: Name for the schema component
            paths: List of paths to add the request body to
            operation_name: Name of the operation for logging (e.g., "chat completion", "embedding")

        Returns:
            Modified OpenAPI schema
        """
        try:
            # Get the schema for the model class
            request_schema = CustomOpenAPISpec.get_pydantic_schema(model_class)

            # Only proceed if we successfully got the schema
            if request_schema is not None:
                # Add schema to components
                CustomOpenAPISpec.add_schema_to_components(
                    openapi_schema, schema_name, request_schema
                )

                # Add request body to specified endpoints
                CustomOpenAPISpec.add_request_body_to_paths(
                    openapi_schema, paths, f"#/components/schemas/{schema_name}"
                )

                verbose_proxy_logger.debug(
                    f"Successfully added {schema_name} schema to OpenAPI spec"
                )
            else:
                verbose_proxy_logger.debug(f"Could not get schema for {schema_name}")

        except Exception as e:
            # If schema addition fails, continue without it — docs degrade gracefully.
            verbose_proxy_logger.debug(
                f"Failed to add {operation_name} request schema: {str(e)}"
            )

        return openapi_schema

    @staticmethod
    def add_chat_completion_request_schema(
        openapi_schema: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation.
        This shows the request body in Swagger without runtime validation.

        Args:
            openapi_schema: The OpenAPI schema dict to modify

        Returns:
            Modified OpenAPI schema
        """
        try:
            from litellm.proxy._types import ProxyChatCompletionRequest

            return CustomOpenAPISpec.add_request_schema(
                openapi_schema=openapi_schema,
                model_class=ProxyChatCompletionRequest,
                schema_name="ProxyChatCompletionRequest",
                paths=CustomOpenAPISpec.CHAT_COMPLETION_PATHS,
                operation_name="chat completion",
            )
        except ImportError as e:
            verbose_proxy_logger.debug(
                f"Failed to import ProxyChatCompletionRequest: {str(e)}"
            )
            return openapi_schema

    @staticmethod
    def add_embedding_request_schema(openapi_schema: Dict[str, Any]) -> Dict[str, Any]:
        """
        Add EmbeddingRequest schema to embedding endpoints for documentation.
        This shows the request body in Swagger without runtime validation.

        Args:
            openapi_schema: The OpenAPI schema dict to modify

        Returns:
            Modified OpenAPI schema
        """
        try:
            from litellm.types.embedding import EmbeddingRequest

            return CustomOpenAPISpec.add_request_schema(
                openapi_schema=openapi_schema,
                model_class=EmbeddingRequest,
                schema_name="EmbeddingRequest",
                paths=CustomOpenAPISpec.EMBEDDING_PATHS,
                operation_name="embedding",
            )
        except ImportError as e:
            verbose_proxy_logger.debug(f"Failed to import EmbeddingRequest: {str(e)}")
            return openapi_schema

    @staticmethod
    def add_responses_api_request_schema(
        openapi_schema: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Add ResponsesAPIRequestParams schema to responses API endpoints for documentation.
        This shows the request body in Swagger without runtime validation.

        Args:
            openapi_schema: The OpenAPI schema dict to modify

        Returns:
            Modified OpenAPI schema
        """
        try:
            from litellm.types.llms.openai import ResponsesAPIRequestParams

            return CustomOpenAPISpec.add_request_schema(
                openapi_schema=openapi_schema,
                model_class=ResponsesAPIRequestParams,
                schema_name="ResponsesAPIRequestParams",
                paths=CustomOpenAPISpec.RESPONSES_API_PATHS,
                operation_name="responses API",
            )
        except ImportError as e:
            verbose_proxy_logger.debug(
                f"Failed to import ResponsesAPIRequestParams: {str(e)}"
            )
            return openapi_schema

    @staticmethod
    def add_llm_api_request_schema_body(
        openapi_schema: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Add LLM API request schema bodies to OpenAPI specification for documentation.

        Args:
            openapi_schema: The base OpenAPI schema

        Returns:
            OpenAPI schema with added request body schemas
        """
        # Add chat completion request schema
        openapi_schema = CustomOpenAPISpec.add_chat_completion_request_schema(
            openapi_schema
        )

        # Add embedding request schema
        openapi_schema = CustomOpenAPISpec.add_embedding_request_schema(openapi_schema)

        # Add responses API request schema
        openapi_schema = CustomOpenAPISpec.add_responses_api_request_schema(
            openapi_schema
        )

        return openapi_schema
|
||||
@@ -0,0 +1,832 @@
|
||||
# Start tracing memory allocations
|
||||
import asyncio
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tracemalloc
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from litellm import get_secret_str
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import PYTHON_GC_THRESHOLD
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Configure garbage collection thresholds from environment variables
|
||||
def configure_gc_thresholds():
    """Configure Python garbage collection thresholds from environment variables.

    Reads PYTHON_GC_THRESHOLD (expected format "gen0,gen1,gen2"), applies it via
    gc.set_threshold when valid, and logs the thresholds in effect afterwards.
    """
    gc_threshold_env = PYTHON_GC_THRESHOLD
    if gc_threshold_env:
        try:
            # Parse a threshold string like "1000,50,50" into three ints.
            thresholds = [int(x.strip()) for x in gc_threshold_env.split(",")]
        except ValueError as e:
            verbose_proxy_logger.warning(
                f"Failed to parse GC threshold: {gc_threshold_env}. Error: {e}"
            )
        else:
            if len(thresholds) == 3:
                gc.set_threshold(*thresholds)
                verbose_proxy_logger.info(f"GC thresholds set to: {thresholds}")
            else:
                verbose_proxy_logger.warning(
                    f"GC threshold not set: {gc_threshold_env}. Expected format: 'gen0,gen1,gen2'"
                )

    # Always log the thresholds actually in effect (configured or default).
    current_thresholds = gc.get_threshold()
    verbose_proxy_logger.info(
        f"Current GC thresholds: gen0={current_thresholds[0]}, gen1={current_thresholds[1]}, gen2={current_thresholds[2]}"
    )
|
||||
|
||||
|
||||
# Initialize GC configuration at import time so the thresholds apply process-wide.
configure_gc_thresholds()
|
||||
|
||||
|
||||
@router.get("/debug/asyncio-tasks")
async def get_active_tasks_stats():
    """
    Report the asyncio tasks currently alive in this event loop.

    Returns:
        total_active_tasks: int
        by_name: { coroutine_name: count }
    """
    MAX_TASKS_TO_CHECK = 5000

    # All tasks in this loop (including this endpoint's own task),
    # excluding those that have already finished.
    active_tasks = [task for task in asyncio.all_tasks() if not task.done()]

    # Group active tasks by coroutine name; the slice acts as a circuit
    # breaker so a pathological number of tasks can't stall the endpoint.
    counter: Counter = Counter()
    for task in active_tasks[:MAX_TASKS_TO_CHECK]:
        coro = task.get_coro()
        # Derive a human-readable name from the coroutine object.
        label = (
            getattr(coro, "__qualname__", None)
            or getattr(coro, "__name__", None)
            or repr(coro)
        )
        counter[label] += 1

    return {
        "total_active_tasks": len(active_tasks),
        "by_name": dict(counter),
    }
|
||||
|
||||
|
||||
# Optional memory profiling: only active when LITELLM_PROFILE=true.
# Requires the third-party `objgraph` package; tracemalloc snapshots
# are exposed via the /memory-usage endpoint defined below.
if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
    try:
        import objgraph  # type: ignore

        print("growth of objects")  # noqa
        objgraph.show_growth()
        print("\n\nMost common types")  # noqa
        objgraph.show_most_common_types()
        roots = objgraph.get_leaking_objects()
        print("\n\nLeaking objects")  # noqa
        objgraph.show_most_common_types(objects=roots)
    except ImportError:
        raise ImportError(
            "objgraph not found. Please install objgraph to use this feature."
        )

    # Keep up to 10 frames per allocation traceback for later snapshots.
    tracemalloc.start(10)

    @router.get("/memory-usage", include_in_schema=False)
    async def memory_usage():
        # Take a snapshot of the current memory usage
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics("lineno")
        verbose_proxy_logger.debug("TOP STATS: %s", top_stats)

        # Get the top 50 memory usage lines
        top_50 = top_stats[:50]
        result = []
        for stat in top_50:
            result.append(f"{stat.traceback.format(limit=10)}: {stat.size / 1024} KiB")

        return {"top_50_memory_usage": result}
|
||||
|
||||
|
||||
@router.get("/memory-usage-in-mem-cache", include_in_schema=False)
async def memory_usage_in_mem_cache(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Report the number of items held in each in-memory cache on the proxy server:

    1. user_api_key_cache
    2. router_cache
    3. proxy_logging_cache
    4. internal_usage_cache
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
    )

    def _item_count(in_memory_cache) -> int:
        # Each entry occupies a slot in cache_dict and, when it expires, ttl_dict.
        return len(in_memory_cache.cache_dict) + len(in_memory_cache.ttl_dict)

    # Router may not be configured yet; report 0 items in that case.
    num_items_in_llm_router_cache = (
        0 if llm_router is None else _item_count(llm_router.cache.in_memory_cache)
    )
    num_items_in_user_api_key_cache = _item_count(user_api_key_cache.in_memory_cache)
    num_items_in_proxy_logging_obj_cache = _item_count(
        proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache
    )

    return {
        "num_items_in_user_api_key_cache": num_items_in_user_api_key_cache,
        "num_items_in_llm_router_cache": num_items_in_llm_router_cache,
        "num_items_in_proxy_logging_obj_cache": num_items_in_proxy_logging_obj_cache,
    }
|
||||
|
||||
|
||||
@router.get("/memory-usage-in-mem-cache-items", include_in_schema=False)
async def memory_usage_in_mem_cache_items(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Dump the raw contents (cache entries and TTL maps) of every in-memory
    cache on the proxy server:

    1. user_api_key_cache
    2. router_cache
    3. proxy_logging_cache
    4. internal_usage_cache
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # Router may not be configured yet; fall back to empty dicts.
    router_cache_dict: dict = {}
    router_ttl_dict: dict = {}
    if llm_router is not None:
        router_cache_dict = llm_router.cache.in_memory_cache.cache_dict
        router_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict

    internal_cache = proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache

    return {
        "user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict,
        "user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict,
        "llm_router_cache": router_cache_dict,
        "llm_router_ttl": router_ttl_dict,
        "proxy_logging_obj_cache": internal_cache.cache_dict,
        "proxy_logging_obj_ttl": internal_cache.ttl_dict,
    }
|
||||
|
||||
|
||||
@router.get("/debug/memory/summary", include_in_schema=False)
async def get_memory_summary(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> Dict[str, Any]:
    """
    Get simplified memory usage summary for the proxy.

    Returns:
        - worker_pid: Process ID
        - status: Overall health based on memory usage
        - memory: Process memory usage and RAM info
        - caches: Cache item counts and descriptions
        - garbage_collector: GC status and pending object counts

    Example usage:
    curl http://localhost:4000/debug/memory/summary -H "Authorization: Bearer sk-1234"

    For detailed analysis, call GET /debug/memory/details
    For cache management, use the cache management endpoints
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # Get process memory info; default to healthy if psutil is unavailable.
    process_memory = {}
    health_status = "healthy"

    try:
        # psutil is an optional dependency — imported lazily so the endpoint
        # still works (minus memory info) when it isn't installed.
        import psutil

        process = psutil.Process()
        memory_info = process.memory_info()
        memory_mb = memory_info.rss / (1024 * 1024)
        memory_percent = process.memory_percent()

        process_memory = {
            "summary": f"{memory_mb:.1f} MB ({memory_percent:.1f}% of system memory)",
            "ram_usage_mb": round(memory_mb, 2),
            "system_memory_percent": round(memory_percent, 2),
        }

        # Check memory health status (>80% of system RAM = critical, >60% = warning)
        if memory_percent > 80:
            health_status = "critical"
        elif memory_percent > 60:
            health_status = "warning"
        else:
            health_status = "healthy"

    except ImportError:
        process_memory[
            "error"
        ] = "Install psutil for memory monitoring: pip install psutil"
    except Exception as e:
        process_memory["error"] = str(e)

    # Get cache information (item counts per in-memory cache)
    caches: Dict[str, Any] = {}
    total_cache_items = 0

    try:
        # User API key cache
        user_cache_items = len(user_api_key_cache.in_memory_cache.cache_dict)
        total_cache_items += user_cache_items
        caches["user_api_keys"] = {
            "count": user_cache_items,
            "count_readable": f"{user_cache_items:,}",
            "what_it_stores": "Validated API keys for faster authentication",
        }

        # Router cache (may be absent before the router is configured)
        if llm_router is not None:
            router_cache_items = len(llm_router.cache.in_memory_cache.cache_dict)
            total_cache_items += router_cache_items
            caches["llm_responses"] = {
                "count": router_cache_items,
                "count_readable": f"{router_cache_items:,}",
                "what_it_stores": "LLM responses for identical requests",
            }

        # Proxy logging cache
        logging_cache_items = len(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
        )
        total_cache_items += logging_cache_items
        caches["usage_tracking"] = {
            "count": logging_cache_items,
            "count_readable": f"{logging_cache_items:,}",
            "what_it_stores": "Usage metrics before database write",
        }

    except Exception as e:
        # Best-effort endpoint: surface the error in the payload rather than failing.
        caches["error"] = str(e)

    # Get garbage collector stats
    gc_enabled = gc.isenabled()
    objects_pending = gc.get_count()[0]  # gen-0 allocation count since last collection
    uncollectable = len(gc.garbage)

    gc_info = {
        "status": "enabled" if gc_enabled else "disabled",
        "objects_awaiting_collection": objects_pending,
    }

    # Add warning if garbage collection issues detected
    if uncollectable > 0:
        gc_info[
            "warning"
        ] = f"{uncollectable} uncollectable objects (possible memory leak)"

    return {
        "worker_pid": os.getpid(),
        "status": health_status,
        "memory": process_memory,
        "caches": {
            "total_items": total_cache_items,
            "breakdown": caches,
        },
        "garbage_collector": gc_info,
    }
|
||||
|
||||
|
||||
def _get_gc_statistics() -> Dict[str, Any]:
|
||||
"""Get garbage collector statistics."""
|
||||
return {
|
||||
"enabled": gc.isenabled(),
|
||||
"thresholds": {
|
||||
"generation_0": gc.get_threshold()[0],
|
||||
"generation_1": gc.get_threshold()[1],
|
||||
"generation_2": gc.get_threshold()[2],
|
||||
"explanation": "Number of allocations before automatic collection for each generation",
|
||||
},
|
||||
"current_counts": {
|
||||
"generation_0": gc.get_count()[0],
|
||||
"generation_1": gc.get_count()[1],
|
||||
"generation_2": gc.get_count()[2],
|
||||
"explanation": "Current number of allocated objects in each generation",
|
||||
},
|
||||
"collection_history": [
|
||||
{
|
||||
"generation": i,
|
||||
"total_collections": stat["collections"],
|
||||
"total_collected": stat["collected"],
|
||||
"uncollectable": stat["uncollectable"],
|
||||
}
|
||||
for i, stat in enumerate(gc.get_stats())
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _get_object_type_counts(top_n: int) -> Tuple[int, List[Dict[str, Any]]]:
|
||||
"""Count objects by type and return total count and top N types."""
|
||||
type_counts: Counter = Counter()
|
||||
total_objects = 0
|
||||
|
||||
for obj in gc.get_objects():
|
||||
total_objects += 1
|
||||
obj_type = type(obj).__name__
|
||||
type_counts[obj_type] += 1
|
||||
|
||||
top_object_types = [
|
||||
{"type": obj_type, "count": count, "count_readable": f"{count:,}"}
|
||||
for obj_type, count in type_counts.most_common(top_n)
|
||||
]
|
||||
|
||||
return total_objects, top_object_types
|
||||
|
||||
|
||||
def _get_uncollectable_objects_info() -> Dict[str, Any]:
|
||||
"""Get information about uncollectable objects (potential memory leaks)."""
|
||||
uncollectable = gc.garbage
|
||||
return {
|
||||
"count": len(uncollectable),
|
||||
"sample_types": [type(obj).__name__ for obj in uncollectable[:10]],
|
||||
"warning": "If count > 0, you may have reference cycles preventing garbage collection"
|
||||
if len(uncollectable) > 0
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
def _get_cache_memory_stats(
    user_api_key_cache, llm_router, proxy_logging_obj, redis_usage_cache
) -> Dict[str, Any]:
    """Calculate memory usage for all caches.

    NOTE: sys.getsizeof is shallow — it measures the dict container only,
    not the keys/values it references, so these numbers understate real usage.

    Args:
        user_api_key_cache: Cache of validated API keys (has .in_memory_cache).
        llm_router: The router instance, or None when not configured.
        proxy_logging_obj: Proxy logging object exposing internal_usage_cache.
        redis_usage_cache: Redis-backed usage cache, or None when disabled.

    Returns:
        Dict of per-cache stats (item counts, byte sizes, MB totals); on
        failure, includes an "error" key instead of raising.
    """
    cache_stats: Dict[str, Any] = {}
    try:
        # User API key cache
        user_cache_size = sys.getsizeof(user_api_key_cache.in_memory_cache.cache_dict)
        user_ttl_size = sys.getsizeof(user_api_key_cache.in_memory_cache.ttl_dict)
        cache_stats["user_api_key_cache"] = {
            "num_items": len(user_api_key_cache.in_memory_cache.cache_dict),
            "cache_dict_size_bytes": user_cache_size,
            "ttl_dict_size_bytes": user_ttl_size,
            "total_size_mb": round(
                (user_cache_size + user_ttl_size) / (1024 * 1024), 2
            ),
        }

        # Router cache (skipped when the router is not configured)
        if llm_router is not None:
            router_cache_size = sys.getsizeof(
                llm_router.cache.in_memory_cache.cache_dict
            )
            router_ttl_size = sys.getsizeof(llm_router.cache.in_memory_cache.ttl_dict)
            cache_stats["llm_router_cache"] = {
                "num_items": len(llm_router.cache.in_memory_cache.cache_dict),
                "cache_dict_size_bytes": router_cache_size,
                "ttl_dict_size_bytes": router_ttl_size,
                "total_size_mb": round(
                    (router_cache_size + router_ttl_size) / (1024 * 1024), 2
                ),
            }

        # Proxy logging cache
        logging_cache_size = sys.getsizeof(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
        )
        logging_ttl_size = sys.getsizeof(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict
        )
        cache_stats["proxy_logging_cache"] = {
            "num_items": len(
                proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
            ),
            "cache_dict_size_bytes": logging_cache_size,
            "ttl_dict_size_bytes": logging_ttl_size,
            "total_size_mb": round(
                (logging_cache_size + logging_ttl_size) / (1024 * 1024), 2
            ),
        }

        # Redis cache info
        if redis_usage_cache is not None:
            cache_stats["redis_usage_cache"] = {
                "enabled": True,
                "cache_type": type(redis_usage_cache).__name__,
            }
            # Try to get Redis connection pool info if available
            try:
                if (
                    hasattr(redis_usage_cache, "redis_client")
                    and redis_usage_cache.redis_client
                ):
                    if hasattr(redis_usage_cache.redis_client, "connection_pool"):
                        pool_info = redis_usage_cache.redis_client.connection_pool  # type: ignore
                        cache_stats["redis_usage_cache"]["connection_pool"] = {
                            "max_connections": pool_info.max_connections
                            if hasattr(pool_info, "max_connections")
                            else None,
                            "connection_class": pool_info.connection_class.__name__
                            if hasattr(pool_info, "connection_class")
                            else None,
                        }
            except Exception as e:
                # Pool introspection is best-effort; never fail the stats call.
                verbose_proxy_logger.debug(f"Error getting Redis pool info: {e}")
        else:
            cache_stats["redis_usage_cache"] = {"enabled": False}

    except Exception as e:
        # Diagnostics helper: report the failure in-band instead of raising.
        verbose_proxy_logger.debug(f"Error calculating cache stats: {e}")
        cache_stats["error"] = str(e)

    return cache_stats
|
||||
|
||||
|
||||
def _get_router_memory_stats(llm_router) -> Dict[str, Any]:
|
||||
"""Get memory usage statistics for LiteLLM router."""
|
||||
litellm_router_memory: Dict[str, Any] = {}
|
||||
try:
|
||||
if llm_router is not None:
|
||||
# Model list memory size
|
||||
if hasattr(llm_router, "model_list") and llm_router.model_list:
|
||||
model_list_size = sys.getsizeof(llm_router.model_list)
|
||||
litellm_router_memory["model_list"] = {
|
||||
"num_models": len(llm_router.model_list),
|
||||
"size_bytes": model_list_size,
|
||||
"size_mb": round(model_list_size / (1024 * 1024), 4),
|
||||
}
|
||||
|
||||
# Model names set
|
||||
if hasattr(llm_router, "model_names") and llm_router.model_names:
|
||||
model_names_size = sys.getsizeof(llm_router.model_names)
|
||||
litellm_router_memory["model_names_set"] = {
|
||||
"num_model_groups": len(llm_router.model_names),
|
||||
"size_bytes": model_names_size,
|
||||
"size_mb": round(model_names_size / (1024 * 1024), 4),
|
||||
}
|
||||
|
||||
# Deployment names list
|
||||
if hasattr(llm_router, "deployment_names") and llm_router.deployment_names:
|
||||
deployment_names_size = sys.getsizeof(llm_router.deployment_names)
|
||||
litellm_router_memory["deployment_names"] = {
|
||||
"num_deployments": len(llm_router.deployment_names),
|
||||
"size_bytes": deployment_names_size,
|
||||
"size_mb": round(deployment_names_size / (1024 * 1024), 4),
|
||||
}
|
||||
|
||||
# Deployment latency map
|
||||
if (
|
||||
hasattr(llm_router, "deployment_latency_map")
|
||||
and llm_router.deployment_latency_map
|
||||
):
|
||||
latency_map_size = sys.getsizeof(llm_router.deployment_latency_map)
|
||||
litellm_router_memory["deployment_latency_map"] = {
|
||||
"num_tracked_deployments": len(llm_router.deployment_latency_map),
|
||||
"size_bytes": latency_map_size,
|
||||
"size_mb": round(latency_map_size / (1024 * 1024), 4),
|
||||
}
|
||||
|
||||
# Fallback configuration
|
||||
if hasattr(llm_router, "fallbacks") and llm_router.fallbacks:
|
||||
fallbacks_size = sys.getsizeof(llm_router.fallbacks)
|
||||
litellm_router_memory["fallbacks"] = {
|
||||
"num_fallback_configs": len(llm_router.fallbacks),
|
||||
"size_bytes": fallbacks_size,
|
||||
"size_mb": round(fallbacks_size / (1024 * 1024), 4),
|
||||
}
|
||||
|
||||
# Total router object size
|
||||
router_obj_size = sys.getsizeof(llm_router)
|
||||
litellm_router_memory["router_object"] = {
|
||||
"size_bytes": router_obj_size,
|
||||
"size_mb": round(router_obj_size / (1024 * 1024), 4),
|
||||
}
|
||||
|
||||
else:
|
||||
litellm_router_memory = {"note": "Router not initialized"}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error getting router memory info: {e}")
|
||||
litellm_router_memory = {"error": str(e)}
|
||||
|
||||
return litellm_router_memory
|
||||
|
||||
|
||||
def _get_process_memory_info(
|
||||
worker_pid: int, include_process_info: bool
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get process-level memory information using psutil."""
|
||||
if not include_process_info:
|
||||
return None
|
||||
|
||||
try:
|
||||
import psutil
|
||||
|
||||
process = psutil.Process()
|
||||
memory_info = process.memory_info()
|
||||
ram_usage_mb = round(memory_info.rss / (1024 * 1024), 2)
|
||||
virtual_memory_mb = round(memory_info.vms / (1024 * 1024), 2)
|
||||
memory_percent = round(process.memory_percent(), 2)
|
||||
|
||||
return {
|
||||
"pid": worker_pid,
|
||||
"summary": f"Worker PID {worker_pid} using {ram_usage_mb:.1f} MB of RAM ({memory_percent:.1f}% of system memory)",
|
||||
"ram_usage": {
|
||||
"megabytes": ram_usage_mb,
|
||||
"description": "Actual physical RAM used by this process",
|
||||
},
|
||||
"virtual_memory": {
|
||||
"megabytes": virtual_memory_mb,
|
||||
"description": "Total virtual memory allocated (includes swapped memory)",
|
||||
},
|
||||
"system_memory_percent": {
|
||||
"percent": memory_percent,
|
||||
"description": "Percentage of total system RAM being used",
|
||||
},
|
||||
"open_file_handles": {
|
||||
"count": process.num_fds()
|
||||
if hasattr(process, "num_fds")
|
||||
else "N/A (Windows)",
|
||||
"description": "Number of open file descriptors/handles",
|
||||
},
|
||||
"threads": {
|
||||
"count": process.num_threads(),
|
||||
"description": "Number of active threads in this process",
|
||||
},
|
||||
}
|
||||
except ImportError:
|
||||
return {
|
||||
"pid": worker_pid,
|
||||
"error": "psutil not installed. Install with: pip install psutil",
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error getting process info: {e}")
|
||||
return {"pid": worker_pid, "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/debug/memory/details", include_in_schema=False)
async def get_memory_details(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
    top_n: int = Query(20, description="Number of top object types to return"),
    include_process_info: bool = Query(True, description="Include process memory info"),
) -> Dict[str, Any]:
    """
    Get detailed memory diagnostics for deep debugging.

    Returns:
    - worker_pid: Process ID
    - process_memory: RAM usage, virtual memory, file handles, threads
    - garbage_collector: GC thresholds, counts, collection history
    - objects: Total tracked objects and top object types
    - uncollectable: Objects that can't be garbage collected (potential leaks)
    - cache_memory: Memory usage of user_api_key, router, and logging caches
    - router_memory: Memory usage of router components (model_list, deployment_names, etc.)

    Query Parameters:
    - top_n: Number of top object types to return (default: 20)
    - include_process_info: Include process-level memory info using psutil (default: true)

    Example usage:
    curl "http://localhost:4000/debug/memory/details?top_n=30" -H "Authorization: Bearer sk-1234"

    All memory sizes are reported in both bytes and MB.
    """
    # Import at call time so we read the live proxy_server globals
    # (presumably populated during proxy startup — confirm against proxy_server).
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
        redis_usage_cache,
    )

    worker_pid = os.getpid()

    # Collect all diagnostics using helper functions
    gc_stats = _get_gc_statistics()
    total_objects, top_object_types = _get_object_type_counts(top_n)
    uncollectable_info = _get_uncollectable_objects_info()
    cache_stats = _get_cache_memory_stats(
        user_api_key_cache, llm_router, proxy_logging_obj, redis_usage_cache
    )
    litellm_router_memory = _get_router_memory_stats(llm_router)
    # process_info is None when include_process_info=False
    process_info = _get_process_memory_info(worker_pid, include_process_info)

    return {
        "worker_pid": worker_pid,
        "process_memory": process_info,
        "garbage_collector": gc_stats,
        "objects": {
            "total_tracked": total_objects,
            # human-readable thousands-separated count
            "total_tracked_readable": f"{total_objects:,}",
            "top_types": top_object_types,
        },
        "uncollectable": uncollectable_info,
        "cache_memory": cache_stats,
        "router_memory": litellm_router_memory,
    }
|
||||
|
||||
|
||||
@router.post("/debug/memory/gc/configure", include_in_schema=False)
async def configure_gc_thresholds_endpoint(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
    generation_0: int = Query(700, description="Generation 0 threshold (default: 700)"),
    generation_1: int = Query(10, description="Generation 1 threshold (default: 10)"),
    generation_2: int = Query(10, description="Generation 2 threshold (default: 10)"),
) -> Dict[str, Any]:
    """
    Configure Python garbage collection thresholds.

    Lower thresholds mean more frequent GC cycles (less memory, more CPU overhead).
    Higher thresholds mean less frequent GC cycles (more memory, less CPU overhead).

    Returns:
    - message: Confirmation message
    - previous_thresholds: Old threshold values
    - new_thresholds: New threshold values
    - objects_awaiting_collection: Current object count in gen-0
    - tip: Hint about when next collection will occur

    Query Parameters:
    - generation_0: Number of allocations before gen-0 collection (default: 700)
    - generation_1: Number of gen-0 collections before gen-1 collection (default: 10)
    - generation_2: Number of gen-1 collections before gen-2 collection (default: 10)

    Example for more aggressive collection:
    curl -X POST "http://localhost:4000/debug/memory/gc/configure?generation_0=500" -H "Authorization: Bearer sk-1234"

    Example for less aggressive collection:
    curl -X POST "http://localhost:4000/debug/memory/gc/configure?generation_0=1000" -H "Authorization: Bearer sk-1234"

    Monitor memory usage with GET /debug/memory/summary after changes.
    """
    # Get current thresholds for logging
    old_thresholds = gc.get_threshold()

    # Set new thresholds with error handling
    try:
        gc.set_threshold(generation_0, generation_1, generation_2)
        verbose_proxy_logger.info(
            f"GC thresholds updated from {old_thresholds} to "
            f"({generation_0}, {generation_1}, {generation_2})"
        )
    except Exception as e:
        # surface failure to the caller instead of silently keeping old thresholds
        verbose_proxy_logger.error(f"Failed to set GC thresholds: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to set GC thresholds: {str(e)}"
        )

    # Get current object count to show immediate impact
    current_count = gc.get_count()[0]

    # NOTE(review): if current_count already exceeds generation_0 the tip below
    # reports a negative allocation count — confirm this is acceptable for a
    # debug-only endpoint.
    return {
        "message": "GC thresholds updated",
        "previous_thresholds": f"{old_thresholds[0]}, {old_thresholds[1]}, {old_thresholds[2]}",
        "new_thresholds": f"{generation_0}, {generation_1}, {generation_2}",
        "objects_awaiting_collection": current_count,
        "tip": f"Next collection will run after {generation_0 - current_count} more allocations",
    }
|
||||
|
||||
|
||||
@router.get("/otel-spans", include_in_schema=False)
async def get_otel_spans():
    """Debug endpoint: list recorded OTEL span names, grouped by parent trace."""
    from litellm.proxy.proxy_server import open_telemetry_logger

    if open_telemetry_logger is None:
        return {
            "otel_spans": [],
            "spans_grouped_by_parent": {},
            "most_recent_parent": None,
        }

    exporter = open_telemetry_logger.OTEL_EXPORTER
    # Only in-memory test exporters expose get_finished_spans
    finished_spans = (
        exporter.get_finished_spans()  # type: ignore
        if hasattr(exporter, "get_finished_spans")
        else []
    )

    print("Spans: ", finished_spans)  # noqa

    latest_parent = None
    latest_start_time = 1000000
    grouped: dict = {}
    for sp in finished_spans:
        if sp.parent is None:
            continue
        trace_id = sp.parent.trace_id
        grouped.setdefault(trace_id, []).append(sp.name)

        # track the parent of the span that started most recently
        if sp.start_time > latest_start_time:
            latest_parent = trace_id
            latest_start_time = sp.start_time

    return {
        "otel_spans": [sp.name for sp in finished_spans],
        "spans_grouped_by_parent": grouped,
        "most_recent_parent": latest_parent,
    }
|
||||
|
||||
|
||||
# Helper functions for debugging
|
||||
def _apply_verbose_log_levels(level: int, include_package_logger: bool) -> None:
    """Set the router and proxy loggers (and optionally the package logger) to *level*."""
    from litellm._logging import (
        verbose_logger,
        verbose_proxy_logger,
        verbose_router_logger,
    )

    if include_package_logger:
        verbose_logger.setLevel(level=level)
    verbose_router_logger.setLevel(level=level)
    verbose_proxy_logger.setLevel(level=level)


def init_verbose_loggers():
    """Initialize LiteLLM verbose logger levels from the WORKER_CONFIG setting.

    WORKER_CONFIG may be a path to a config file (ignored here) or a JSON
    string of settings; only the ``debug`` / ``detailed_debug`` flags and the
    LITELLM_LOG env var are consulted. Never raises — failures are logged.
    """
    import logging

    try:
        worker_config = get_secret_str("WORKER_CONFIG")
        if worker_config is None:
            return
        # if it's a file path, log levels are configured elsewhere
        if os.path.isfile(worker_config):
            return
        # if not a file, assume it's a json string of settings
        _settings = json.loads(worker_config)
        if not isinstance(_settings, dict):
            return

        debug = _settings.get("debug", None)
        detailed_debug = _settings.get("detailed_debug", None)

        if debug is True:  # this needs to be first, so users can see Router init debug
            # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
            _apply_verbose_log_levels(logging.INFO, include_package_logger=True)
        if detailed_debug is True:
            _apply_verbose_log_levels(logging.DEBUG, include_package_logger=True)
        elif debug is False and detailed_debug is False:
            # users can control proxy debugging using env variable = 'LITELLM_LOG'
            # (os.environ.get with a default never returns None, so no None-check)
            litellm_log_setting = os.environ.get("LITELLM_LOG", "")
            if litellm_log_setting.upper() == "INFO":
                _apply_verbose_log_levels(logging.INFO, include_package_logger=False)
            elif litellm_log_setting.upper() == "DEBUG":
                _apply_verbose_log_levels(logging.DEBUG, include_package_logger=False)
    except Exception as e:
        logging.warning(f"Failed to init verbose loggers: {str(e)}")
|
||||
@@ -0,0 +1,122 @@
|
||||
import base64
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
def _get_salt_key():
    """
    Resolve the key used to encrypt/decrypt stored secrets.

    Prefers the LITELLM_SALT_KEY environment variable; falls back to the
    proxy's master key when no dedicated salt key is configured.
    """
    from litellm.proxy.proxy_server import master_key

    env_salt_key = os.getenv("LITELLM_SALT_KEY", None)

    # no dedicated salt key configured — reuse the master key
    return master_key if env_salt_key is None else env_salt_key
|
||||
|
||||
|
||||
def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None):
    """
    Encrypt *value* with the proxy salt key and return it URL-safe base64 encoded.

    Args:
        value: Plaintext to encrypt. Non-string values are returned unchanged.
        new_encryption_key: Optional key overriding the configured salt key
            (e.g. when re-encrypting during a key rotation).

    Returns:
        str: URL-safe base64 ciphertext, or the original value when it is not
        a string. Encryption errors propagate to the caller.
    """
    signing_key = new_encryption_key or _get_salt_key()

    if not isinstance(value, str):
        verbose_proxy_logger.debug(
            f"Invalid value type passed to encrypt_value: {type(value)} for Value: {value}\n Value must be a string"
        )
        # if it's not a string - do not encrypt it and return the value
        return value

    encrypted_value = encrypt_value(value=value, signing_key=signing_key)  # type: ignore
    # Use urlsafe_b64encode for URL-safe base64 encoding (replaces + with - and / with _)
    return base64.urlsafe_b64encode(encrypted_value).decode("utf-8")
|
||||
|
||||
|
||||
def decrypt_value_helper(
    value: str,
    key: str,  # this is just for debug purposes, showing the k,v pair that's invalid. not a signing key.
    exception_type: Literal["debug", "error"] = "error",
    return_original_value: bool = False,
):
    """
    Best-effort decrypt of a stored value using the configured salt key.

    Non-string inputs are returned unchanged. On failure, behavior depends on
    `exception_type` ("debug" logs quietly; "error" logs the full exception)
    and `return_original_value` (return the raw input instead of None).
    """
    signing_key = _get_salt_key()

    try:
        if isinstance(value, str):
            # Try URL-safe base64 decoding first (new format)
            # Fall back to standard base64 decoding for backwards compatibility (old format)
            try:
                decoded_b64 = base64.urlsafe_b64decode(value)
            except Exception:
                # If URL-safe decoding fails, try standard base64 decoding for backwards compatibility
                decoded_b64 = base64.b64decode(value)

            value = decrypt_value(value=decoded_b64, signing_key=signing_key)  # type: ignore
            return value

        # if it's not str - do not decrypt it, return the value
        return value
    except Exception as e:
        error_message = f"Error decrypting value for key: {key}, Did your master_key/salt key change recently? \nError: {str(e)}\nSet permanent salt key - https://docs.litellm.ai/docs/proxy/prod#5-set-litellm-salt-key"
        # "debug" callers treat decryption failure as expected and non-fatal
        if exception_type == "debug":
            verbose_proxy_logger.debug(error_message)
            return value if return_original_value else None

        verbose_proxy_logger.debug(
            f"Unable to decrypt value={value} for key: {key}, returning None"
        )
        if return_original_value:
            return value
        else:
            verbose_proxy_logger.exception(error_message)
            # [Non-Blocking Exception. - this should not block decrypting other values]
            return None
|
||||
|
||||
|
||||
def encrypt_value(value: str, signing_key: str):
    """
    Symmetrically encrypt *value* using a NaCl SecretBox.

    The signing key is hashed with SHA-256 to derive the 32-byte secret that
    SecretBox requires. The returned bytes include the random nonce that
    SecretBox generates and prepends internally.
    """
    import hashlib

    import nacl.secret

    # get 32 byte master key #
    hash_bytes = hashlib.sha256(signing_key.encode()).digest()

    # initialize secret box #
    box = nacl.secret.SecretBox(hash_bytes)

    # encode + encrypt message #
    value_bytes = value.encode("utf-8")
    return box.encrypt(value_bytes)
|
||||
|
||||
|
||||
def decrypt_value(value: bytes, signing_key: str) -> str:
    """
    Decrypt bytes produced by ``encrypt_value`` back to a UTF-8 string.

    An empty input is returned as an empty string without touching nacl.
    Decryption errors (e.g. nacl CryptoError on a wrong key or tampered
    ciphertext) propagate to the caller.
    """
    import hashlib

    import nacl.secret

    # get 32 byte master key #
    hash_bytes = hashlib.sha256(signing_key.encode()).digest()

    # initialize secret box #
    box = nacl.secret.SecretBox(hash_bytes)

    # gracefully handle empty values — nothing to decrypt
    if len(value) == 0:
        return ""

    return box.decrypt(value).decode("utf-8")
|
||||
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
Utility class for getting routes from a FastAPI app.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class GetRoutes:
    """Helpers for enumerating the routes registered on a FastAPI app."""

    @staticmethod
    def get_app_routes(
        route: BaseRoute,
        endpoint_route: Any,
    ) -> List[Dict[str, Any]]:
        """
        Get routes for a regular (non-mounted) route.

        Returns a single-element list describing the route's path, HTTP
        methods, name, and endpoint function name.
        """
        routes: List[Dict[str, Any]] = []
        route_info = {
            "path": getattr(route, "path", None),
            "methods": getattr(route, "methods", None),
            "name": getattr(route, "name", None),
            # Resolve the endpoint name via the safe helper (consistent with
            # get_routes_for_mounted_app) so callables without __name__ —
            # partials, callable instances — don't raise AttributeError.
            "endpoint": (
                GetRoutes._safe_get_endpoint_name(endpoint_route)
                if getattr(route, "endpoint", None)
                else None
            ),
        }
        routes.append(route_info)
        return routes

    @staticmethod
    def get_routes_for_mounted_app(
        route: BaseRoute,
    ) -> List[Dict[str, Any]]:
        """
        Get routes for a mounted sub-application.

        Walks the sub-app's route table and returns one entry per sub-route,
        with the mount path prefixed onto each sub-route path.
        """
        routes: List[Dict[str, Any]] = []
        mount_path = getattr(route, "path", "")
        sub_app = getattr(route, "app", None)
        if sub_app and hasattr(sub_app, "routes"):
            for sub_route in sub_app.routes:
                # Get endpoint - either from endpoint attribute or app attribute
                endpoint_func = getattr(sub_route, "endpoint", None) or getattr(
                    sub_route, "app", None
                )

                if endpoint_func is not None:
                    sub_route_path = getattr(sub_route, "path", "")
                    full_path = mount_path.rstrip("/") + sub_route_path

                    route_info = {
                        "path": full_path,
                        "methods": getattr(sub_route, "methods", ["GET", "POST"]),
                        "name": getattr(sub_route, "name", None),
                        "endpoint": GetRoutes._safe_get_endpoint_name(endpoint_func),
                        "mounted_app": True,
                    }
                    routes.append(route_info)
        return routes

    @staticmethod
    def _safe_get_endpoint_name(endpoint_function: Any) -> Optional[str]:
        """
        Safely get the name of the endpoint function.

        Falls back to the class name for callable instances; returns None
        when no name can be determined.
        """
        try:
            if hasattr(endpoint_function, "__name__"):
                return getattr(endpoint_function, "__name__")
            elif hasattr(endpoint_function, "__class__") and hasattr(
                endpoint_function.__class__, "__name__"
            ):
                return getattr(endpoint_function.__class__, "__name__")
            else:
                return None
        except Exception:
            verbose_logger.exception(
                f"Error getting endpoint name for route: {endpoint_function}"
            )
            return None
|
||||
@@ -0,0 +1,207 @@
|
||||
from litellm.proxy.common_utils.banner import LITELLM_BANNER
|
||||
|
||||
|
||||
def render_cli_sso_success_page() -> str:
    """
    Renders the CLI SSO authentication success page with minimal styling

    Returns:
        str: HTML content for the success page
    """
    # NOTE: this template is an f-string — literal CSS/JS braces are doubled
    # ({{ }}) and {LITELLM_BANNER} is interpolated into the banner <div>.
    # The page counts down 3 seconds in JS and then calls window.close().

    html_content = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>CLI Authentication Successful - LiteLLM</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
                background-color: #f8fafc;
                margin: 0;
                padding: 20px;
                display: flex;
                justify-content: center;
                align-items: center;
                min-height: 100vh;
                color: #1e293b;
            }}

            .container {{
                background-color: #fff;
                padding: 40px;
                border-radius: 8px;
                box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
                width: 450px;
                max-width: 100%;
                text-align: center;
            }}

            .logo-container {{
                margin-bottom: 20px;
            }}

            .logo {{
                font-size: 24px;
                font-weight: 600;
                color: #1e293b;
            }}

            h1 {{
                margin: 0 0 10px;
                color: #1e293b;
                font-size: 28px;
                font-weight: 600;
            }}

            .subtitle {{
                color: #64748b;
                margin: 0 0 30px;
                font-size: 16px;
            }}

            .banner {{
                background-color: #f8fafc;
                color: #334155;
                font-family: 'Courier New', Consolas, monospace;
                font-size: 10px;
                line-height: 1.1;
                white-space: pre;
                padding: 20px;
                border-radius: 6px;
                margin: 20px 0;
                text-align: center;
                border: 1px solid #e2e8f0;
                overflow-x: auto;
            }}

            .success-box {{
                background-color: #f8fafc;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 30px;
                border: 1px solid #e2e8f0;
            }}

            .success-header {{
                display: flex;
                align-items: center;
                justify-content: center;
                margin-bottom: 12px;
                color: #1e293b;
                font-weight: 600;
                font-size: 16px;
            }}

            .success-header svg {{
                margin-right: 8px;
            }}

            .success-box p {{
                color: #64748b;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}

            .instructions {{
                background-color: #f8fafc;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 20px;
                border: 1px solid #e2e8f0;
            }}

            .instructions-header {{
                display: flex;
                align-items: center;
                justify-content: center;
                margin-bottom: 12px;
                color: #1e293b;
                font-weight: 600;
                font-size: 16px;
            }}

            .instructions-header svg {{
                margin-right: 8px;
            }}

            .instructions p {{
                color: #64748b;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}

            .countdown {{
                color: #64748b;
                font-size: 14px;
                font-weight: 500;
                padding: 12px;
                background-color: #f8fafc;
                border-radius: 6px;
                border: 1px solid #e2e8f0;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="logo-container">
                <div class="logo">
                    🚅 LiteLLM
                </div>
            </div>

            <div class="banner">{LITELLM_BANNER}</div>

            <h1>Authentication Successful!</h1>
            <p class="subtitle">Your CLI authentication is complete.</p>

            <div class="success-box">
                <div class="success-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <path d="M9 12l2 2 4-4"></path>
                        <circle cx="12" cy="12" r="10"></circle>
                    </svg>
                    CLI Authentication Complete
                </div>
                <p>Your LiteLLM CLI has been successfully authenticated and is ready to use.</p>
            </div>

            <div class="instructions">
                <div class="instructions-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <circle cx="12" cy="12" r="10"></circle>
                        <line x1="12" y1="16" x2="12" y2="12"></line>
                        <line x1="12" y1="8" x2="12.01" y2="8"></line>
                    </svg>
                    Next Steps
                </div>
                <p>Return to your terminal - the CLI will automatically detect the successful authentication.</p>
                <p>You can now use LiteLLM CLI commands with your authenticated session.</p>
            </div>

            <div class="countdown" id="countdown">This window will close in 3 seconds...</div>
        </div>

        <script>
            let seconds = 3;
            const countdownElement = document.getElementById('countdown');

            const countdown = setInterval(function() {{
                seconds--;
                if (seconds > 0) {{
                    countdownElement.textContent = `This window will close in ${{seconds}} second${{seconds === 1 ? '' : 's'}}...`;
                }} else {{
                    countdownElement.textContent = 'Closing...';
                    clearInterval(countdown);
                    window.close();
                }}
            }}, 1000);
        </script>
    </body>
    </html>
    """
    return html_content
|
||||
@@ -0,0 +1,284 @@
|
||||
# JWT display template for SSO debug callback.
# Plain (non-f) string: single braces are literal CSS/JS braces.
# The literal token SSO_DATA is presumably substituted server-side with the
# JSON payload before the page is served — TODO confirm against the
# /sso/debug callback handler.
jwt_display_template = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>LiteLLM SSO Debug - JWT Information</title>
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
            background-color: #f8fafc;
            margin: 0;
            padding: 20px;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            color: #333;
        }

        .container {
            background-color: #fff;
            padding: 40px;
            border-radius: 8px;
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
            width: 800px;
            max-width: 100%;
        }

        .logo-container {
            text-align: center;
            margin-bottom: 30px;
        }

        .logo {
            font-size: 24px;
            font-weight: 600;
            color: #1e293b;
        }

        h2 {
            margin: 0 0 10px;
            color: #1e293b;
            font-size: 28px;
            font-weight: 600;
            text-align: center;
        }

        .subtitle {
            color: #64748b;
            margin: 0 0 20px;
            font-size: 16px;
            text-align: center;
        }

        .info-box {
            background-color: #f1f5f9;
            border-radius: 6px;
            padding: 20px;
            margin-bottom: 30px;
            border-left: 4px solid #2563eb;
        }

        .success-box {
            background-color: #f0fdf4;
            border-radius: 6px;
            padding: 20px;
            margin-bottom: 30px;
            border-left: 4px solid #16a34a;
        }

        .info-header {
            display: flex;
            align-items: center;
            margin-bottom: 12px;
            color: #1e40af;
            font-weight: 600;
            font-size: 16px;
        }

        .success-header {
            display: flex;
            align-items: center;
            margin-bottom: 12px;
            color: #166534;
            font-weight: 600;
            font-size: 16px;
        }

        .info-header svg, .success-header svg {
            margin-right: 8px;
        }

        .data-container {
            margin-top: 20px;
        }

        .data-row {
            display: flex;
            border-bottom: 1px solid #e2e8f0;
            padding: 12px 0;
        }

        .data-row:last-child {
            border-bottom: none;
        }

        .data-label {
            font-weight: 500;
            color: #334155;
            width: 180px;
            flex-shrink: 0;
        }

        .data-value {
            color: #475569;
            word-break: break-all;
        }

        .jwt-container {
            background-color: #f8fafc;
            border-radius: 6px;
            padding: 15px;
            margin-top: 20px;
            overflow-x: auto;
            border: 1px solid #e2e8f0;
        }

        .jwt-text {
            font-family: monospace;
            white-space: pre-wrap;
            word-break: break-all;
            margin: 0;
            color: #334155;
        }

        .back-button {
            display: inline-block;
            background-color: #6466E9;
            color: #fff;
            text-decoration: none;
            padding: 10px 16px;
            border-radius: 6px;
            font-weight: 500;
            margin-top: 20px;
            text-align: center;
        }

        .back-button:hover {
            background-color: #4138C2;
            text-decoration: none;
        }

        .buttons {
            display: flex;
            gap: 10px;
            margin-top: 20px;
        }

        .copy-button {
            background-color: #e2e8f0;
            color: #334155;
            border: none;
            padding: 8px 12px;
            border-radius: 4px;
            cursor: pointer;
            font-size: 14px;
            display: flex;
            align-items: center;
        }

        .copy-button:hover {
            background-color: #cbd5e1;
        }

        .copy-button svg {
            margin-right: 6px;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="logo-container">
            <div class="logo">
                🚅 LiteLLM
            </div>
        </div>
        <h2>SSO Debug Information</h2>
        <p class="subtitle">Results from the SSO authentication process.</p>

        <div class="success-box">
            <div class="success-header">
                <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                    <path d="M22 11.08V12a10 10 0 1 1-5.93-9.14"></path>
                    <polyline points="22 4 12 14.01 9 11.01"></polyline>
                </svg>
                Authentication Successful
            </div>
            <p>The SSO authentication completed successfully. Below is the information returned by the provider.</p>
        </div>

        <div class="data-container" id="userData">
            <!-- Data will be inserted here by JavaScript -->
        </div>

        <div class="info-box">
            <div class="info-header">
                <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                    <circle cx="12" cy="12" r="10"></circle>
                    <line x1="12" y1="16" x2="12" y2="12"></line>
                    <line x1="12" y1="8" x2="12.01" y2="8"></line>
                </svg>
                JSON Representation
            </div>
            <div class="jwt-container">
                <pre class="jwt-text" id="jsonData">Loading...</pre>
            </div>
            <div class="buttons">
                <button class="copy-button" onclick="copyToClipboard('jsonData')">
                    <svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
                        <path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
                    </svg>
                    Copy to Clipboard
                </button>
            </div>
        </div>

        <a href="/sso/debug/login" class="back-button">
            Try Another SSO Login
        </a>
    </div>

    <script>
        // This will be populated with the actual data from the server
        const userData = SSO_DATA;

        function renderUserData() {
            const container = document.getElementById('userData');
            const jsonDisplay = document.getElementById('jsonData');

            // Format JSON with indentation for display
            jsonDisplay.textContent = JSON.stringify(userData, null, 2);

            // Clear container
            container.innerHTML = '';

            // Add each key-value pair to the UI
            for (const [key, value] of Object.entries(userData)) {
                if (typeof value !== 'object' || value === null) {
                    const row = document.createElement('div');
                    row.className = 'data-row';

                    const label = document.createElement('div');
                    label.className = 'data-label';
                    label.textContent = key;

                    const dataValue = document.createElement('div');
                    dataValue.className = 'data-value';
                    dataValue.textContent = value !== null ? value : 'null';

                    row.appendChild(label);
                    row.appendChild(dataValue);
                    container.appendChild(row);
                }
            }
        }

        function copyToClipboard(elementId) {
            const text = document.getElementById(elementId).textContent;
            navigator.clipboard.writeText(text).then(() => {
                alert('Copied to clipboard!');
            }).catch(err => {
                console.error('Could not copy text: ', err);
            });
        }

        // Render the data when the page loads
        document.addEventListener('DOMContentLoaded', renderUserData);
    </script>
</body>
</html>
"""
|
||||
@@ -0,0 +1,269 @@
|
||||
import os
|
||||
|
||||
from litellm.proxy.utils import get_custom_url
|
||||
|
||||
# Legacy login form POST target: PROXY_BASE_URL (+ optional SERVER_ROOT_PATH) + "/login".
url_to_redirect_to = os.getenv("PROXY_BASE_URL", "")
server_root_path = os.getenv("SERVER_ROOT_PATH", "")
if server_root_path != "":
    url_to_redirect_to += server_root_path
url_to_redirect_to += "/login"
# URL of the newer UI login page, advertised in the deprecation banner below.
new_ui_login_url = get_custom_url("", "ui/login")
||||
def build_ui_login_form(show_deprecation_banner: bool = False) -> str:
    """Return the HTML document for the legacy LiteLLM admin-UI login page.

    The page POSTs ``username``/``password`` to the module-level
    ``url_to_redirect_to`` (derived from PROXY_BASE_URL / SERVER_ROOT_PATH).

    Args:
        show_deprecation_banner: When True, a banner is rendered at the top of
            the form pointing users at the newer UI login page
            (module-level ``new_ui_login_url``).

    Returns:
        A complete, self-contained HTML page as a string.
    """
    # Banner markup is injected verbatim into the <form>; empty string when disabled.
    banner_html = (
        f"""
        <div class="deprecation-banner">
            <strong>Deprecated:</strong> Logging in with username and password on this page is deprecated.
            Please use the <a href="{new_ui_login_url}">new login page</a> instead.
            This page will be dedicated to signing in via SSO in the future.
        </div>
        """
        if show_deprecation_banner
        else ""
    )

    # NOTE: doubled braces ({{ }}) below are f-string escapes for literal CSS braces.
    return f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>LiteLLM Login</title>
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
            background-color: #f8fafc;
            margin: 0;
            padding: 20px;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            color: #333;
        }}

        form {{
            background-color: #fff;
            padding: 40px;
            border-radius: 8px;
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
            width: 450px;
            max-width: 100%;
        }}

        .logo-container {{
            text-align: center;
            margin-bottom: 30px;
        }}

        .logo {{
            font-size: 24px;
            font-weight: 600;
            color: #1e293b;
        }}

        h2 {{
            margin: 0 0 10px;
            color: #1e293b;
            font-size: 28px;
            font-weight: 600;
            text-align: center;
        }}

        .subtitle {{
            color: #64748b;
            margin: 0 0 20px;
            font-size: 16px;
            text-align: center;
        }}

        .info-box {{
            background-color: #f1f5f9;
            border-radius: 6px;
            padding: 20px;
            margin-bottom: 30px;
            border-left: 4px solid #2563eb;
        }}

        .info-header {{
            display: flex;
            align-items: center;
            margin-bottom: 12px;
            color: #1e40af;
            font-weight: 600;
            font-size: 16px;
        }}

        .info-header svg {{
            margin-right: 8px;
        }}

        .info-box p {{
            color: #475569;
            margin: 8px 0;
            line-height: 1.5;
            font-size: 14px;
        }}

        label {{
            display: block;
            margin-bottom: 8px;
            font-weight: 500;
            color: #334155;
            font-size: 14px;
        }}

        .required {{
            color: #dc2626;
            margin-left: 2px;
        }}

        input[type="text"],
        input[type="password"] {{
            width: 100%;
            padding: 10px 14px;
            margin-bottom: 20px;
            box-sizing: border-box;
            border: 1px solid #e2e8f0;
            border-radius: 6px;
            font-size: 15px;
            color: #1e293b;
            background-color: #fff;
            transition: border-color 0.2s, box-shadow 0.2s;
        }}

        input[type="text"]:focus,
        input[type="password"]:focus {{
            outline: none;
            border-color: #3b82f6;
            box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.2);
        }}

        .toggle-password {{
            display: flex;
            align-items: center;
            margin-top: -15px;
            margin-bottom: 20px;
        }}

        .toggle-password input[type="checkbox"] {{
            margin-right: 8px;
            vertical-align: middle;
            width: 16px;
            height: 16px;
        }}

        .toggle-password label {{
            margin-bottom: 0;
            font-size: 14px;
            cursor: pointer;
            line-height: 1;
        }}

        input[type="submit"] {{
            background-color: #6466E9;
            color: #fff;
            cursor: pointer;
            font-weight: 500;
            border: none;
            padding: 10px 16px;
            transition: background-color 0.2s;
            border-radius: 6px;
            margin-top: 10px;
            font-size: 14px;
            width: 100%;
        }}

        input[type="submit"]:hover {{
            background-color: #4138C2;
        }}

        a {{
            color: #3b82f6;
            text-decoration: none;
        }}

        a:hover {{
            text-decoration: underline;
        }}

        code {{
            background-color: #f1f5f9;
            padding: 2px 4px;
            border-radius: 4px;
            font-family: monospace;
            font-size: 13px;
            color: #334155;
        }}

        .help-text {{
            color: #64748b;
            font-size: 14px;
            margin-top: -12px;
            margin-bottom: 20px;
        }}

        .deprecation-banner {{
            background-color: #fee2e2;
            border: 1px solid #ef4444;
            color: #991b1b;
            padding: 14px 16px;
            border-radius: 6px;
            margin-bottom: 20px;
            font-size: 14px;
            line-height: 1.5;
        }}

        .deprecation-banner a {{
            color: #991b1b;
            font-weight: 600;
            text-decoration: underline;
        }}
    </style>
</head>
<body>
    <form action="{url_to_redirect_to}" method="post">
        {banner_html}
        <div class="logo-container">
            <div class="logo">
                🚅 LiteLLM
            </div>
        </div>
        <h2>Login</h2>
        <p class="subtitle">Access your LiteLLM Admin UI.</p>
        <div class="info-box">
            <div class="info-header">
                <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                    <circle cx="12" cy="12" r="10"></circle>
                    <line x1="12" y1="16" x2="12" y2="12"></line>
                    <line x1="12" y1="8" x2="12.01" y2="8"></line>
                </svg>
                Default Credentials
            </div>
            <p>By default, Username is <code>admin</code> and Password is your set LiteLLM Proxy <code>MASTER_KEY</code>.</p>
            <p>Need to set UI credentials or SSO? <a href="https://docs.litellm.ai/docs/proxy/ui" target="_blank">Check the documentation</a>.</p>
        </div>
        <label for="username">Username<span class="required">*</span></label>
        <input type="text" id="username" name="username" required placeholder="Enter your username" autocomplete="username">

        <label for="password">Password<span class="required">*</span></label>
        <input type="password" id="password" name="password" required placeholder="Enter your password" autocomplete="current-password">
        <div class="toggle-password">
            <input type="checkbox" id="show-password" onclick="togglePasswordVisibility()">
            <label for="show-password">Show password</label>
        </div>
        <input type="submit" value="Login">
    </form>
    <script>
        function togglePasswordVisibility() {{
            var passwordField = document.getElementById("password");
            passwordField.type = passwordField.type === "password" ? "text" : "password";
        }}
    </script>
</body>
</html>
"""
|
||||
|
||||
|
||||
# Deprecated username/password page served to clients: always show the banner
# pointing at the new UI login page.
html_form = build_ui_login_form(show_deprecation_banner=True)
|
||||
@@ -0,0 +1,522 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Collection, Dict, List, Optional
|
||||
|
||||
import orjson
|
||||
from fastapi import Request, UploadFile, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ProxyException
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
get_metadata_variable_name_from_kwargs,
|
||||
)
|
||||
from litellm.types.router import Deployment
|
||||
|
||||
|
||||
async def _read_request_body(request: Optional[Request]) -> Dict:
    """
    Safely read the request body and parse it as JSON.

    The parsed result is cached on ``request.scope`` (via
    ``_safe_set_request_parsed_body``) so repeated calls within the same
    request do not re-read or re-parse the body.

    Parameters:
    - request: The request object to read the body from

    Returns:
    - dict: Parsed request data as a dictionary or an empty dictionary if parsing fails

    Raises:
    - ProxyException: 400 error when the body is present but is not valid JSON
      even after stripping invalid surrogates.
    """
    try:
        if request is None:
            return {}

        # Check if we already read and parsed the body
        _cached_request_body: Optional[dict] = _safe_get_request_parsed_body(
            request=request
        )
        if _cached_request_body is not None:
            return _cached_request_body

        _request_headers: dict = _safe_get_request_headers(request=request)
        content_type = _request_headers.get("content-type", "")

        if "form" in content_type:
            # Form-encoded payload (multipart or urlencoded).
            parsed_body = dict(await request.form())
            # "metadata" may arrive as a JSON string inside the form; decode it.
            if "metadata" in parsed_body and isinstance(parsed_body["metadata"], str):
                parsed_body["metadata"] = json.loads(parsed_body["metadata"])
        else:
            # Read the request body
            body = await request.body()

            # Return empty dict if body is empty or None
            if not body:
                parsed_body = {}
            else:
                try:
                    parsed_body = orjson.loads(body)
                except orjson.JSONDecodeError as e:
                    # orjson rejects lone/unpaired surrogates; retry with the
                    # more forgiving stdlib json after stripping them.
                    # First decode bytes to string if needed
                    body_str = body.decode("utf-8") if isinstance(body, bytes) else body

                    # Replace invalid surrogate pairs
                    # This regex finds incomplete surrogate pairs
                    body_str = re.sub(
                        r"[\uD800-\uDBFF](?![\uDC00-\uDFFF])", "", body_str
                    )
                    # This regex finds low surrogates without high surrogates
                    body_str = re.sub(
                        r"(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]", "", body_str
                    )

                    try:
                        parsed_body = json.loads(body_str)
                    except json.JSONDecodeError:
                        # If both orjson and json.loads fail, throw a proper error
                        verbose_proxy_logger.error(
                            f"Invalid JSON payload received: {str(e)}"
                        )
                        raise ProxyException(
                            message=f"Invalid JSON payload: {str(e)}",
                            type="invalid_request_error",
                            param="request_body",
                            code=status.HTTP_400_BAD_REQUEST,
                        )

        # Cache the parsed result
        _safe_set_request_parsed_body(request=request, parsed_body=parsed_body)
        return parsed_body

    except (json.JSONDecodeError, orjson.JSONDecodeError, ProxyException) as e:
        # Re-raise ProxyException as-is
        verbose_proxy_logger.error(f"Invalid JSON payload received: {str(e)}")
        raise
    except Exception as e:
        # Catch unexpected errors to avoid crashes
        verbose_proxy_logger.exception(
            "Unexpected error reading request body - {}".format(e)
        )
        return {}
||||
|
||||
|
||||
def _safe_get_request_parsed_body(request: Optional[Request]) -> Optional[dict]:
    """Return the body previously cached on ``request.scope``, if any.

    The cache entry (written by ``_safe_set_request_parsed_body``) is a tuple of
    ``(accepted_keys, parsed_body)``; a fresh dict containing only the accepted
    keys is returned. Returns None when nothing valid was cached.
    """
    if request is None or not hasattr(request, "scope"):
        return None

    cached = request.scope["parsed_body"] if "parsed_body" in request.scope else None
    if not isinstance(cached, tuple):
        return None

    accepted_keys, full_body = cached
    return {field: full_body[field] for field in accepted_keys}
||||
|
||||
|
||||
def _safe_get_request_query_params(request: Optional[Request]) -> Dict:
    """Best-effort copy of the request's query parameters.

    Never raises: returns an empty dict when the request is missing, exposes no
    ``query_params`` attribute, or reading them fails for any reason (failures
    are only logged at debug level).
    """
    if request is None:
        return {}
    try:
        if not hasattr(request, "query_params"):
            return {}
        return dict(request.query_params)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error reading request query params - {}".format(e)
        )
        return {}
||||
|
||||
|
||||
def _safe_set_request_parsed_body(
    request: Optional[Request],
    parsed_body: dict,
) -> None:
    """Cache ``parsed_body`` on ``request.scope`` for later retrieval.

    Stored as ``(accepted_keys, parsed_body)`` so ``_safe_get_request_parsed_body``
    can rebuild exactly the keys that existed at caching time. Never raises;
    failures are logged at debug level only.
    """
    if request is None:
        return
    try:
        snapshot_keys = tuple(parsed_body.keys())
        request.scope["parsed_body"] = (snapshot_keys, parsed_body)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error setting request parsed body - {}".format(e)
        )
||||
|
||||
|
||||
def _safe_get_request_headers(request: Optional[Request]) -> dict:
    """
    [Non-Blocking] Safely get the request headers.

    The computed dict is cached on ``request.state`` so ``dict(request.headers)``
    is built at most once per request.

    Warning: Callers must NOT mutate the returned dict — it is shared across
    all callers within the same request via the cache.
    """
    if request is None:
        return {}

    req_state = getattr(request, "state", None)

    # Fast path: a previous call already cached a headers dict on state.
    cached_headers = getattr(req_state, "_cached_headers", None)
    if isinstance(cached_headers, dict):
        return cached_headers
    if cached_headers is not None:
        # Something other than a dict was cached; ignore it and rebuild.
        verbose_proxy_logger.debug(
            "Unexpected cached request headers type - {}".format(type(cached_headers))
        )

    try:
        header_map = dict(request.headers)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error reading request headers - {}".format(e)
        )
        header_map = {}

    try:
        if req_state is not None:
            req_state._cached_headers = header_map
    except Exception:
        pass  # request.state may not be available in all contexts

    return header_map
||||
|
||||
|
||||
def check_file_size_under_limit(
    request_data: dict,
    file: UploadFile,
    router_model_names: Collection[str],
) -> bool:
    """
    Check if any files passed in request are under max_file_size_mb

    Side effect: records the file size (in MB) on
    ``request_data["metadata"]["file_size_in_mb"]``.

    Returns True -> when file size is under max_file_size_mb limit
    Raises ProxyException -> when file size is over max_file_size_mb limit or not a premium_user
    """
    from litellm.proxy.proxy_server import (
        CommonProxyErrors,
        ProxyException,
        llm_router,
        premium_user,
    )

    file_contents_size = file.size or 0
    file_content_size_in_mb = file_contents_size / (1024 * 1024)
    if "metadata" not in request_data:
        request_data["metadata"] = {}
    request_data["metadata"]["file_size_in_mb"] = file_content_size_in_mb
    max_file_size_mb = None

    # Use .get(): "model" may legitimately be absent from the request body; a
    # missing model should skip the per-deployment limit lookup, not KeyError.
    request_model = request_data.get("model")
    if llm_router is not None and request_model in router_model_names:
        try:
            deployment: Optional[
                Deployment
            ] = llm_router.get_deployment_by_model_group_name(
                model_group_name=request_model
            )
            if (
                deployment
                and deployment.litellm_params is not None
                and deployment.litellm_params.max_file_size_mb is not None
            ):
                max_file_size_mb = deployment.litellm_params.max_file_size_mb
        except Exception as e:
            # Limit lookup is best-effort; a router error must not block the request.
            verbose_proxy_logger.error(
                "Got error when checking file size: %s", (str(e))
            )

    if max_file_size_mb is not None:
        verbose_proxy_logger.debug(
            "Checking file size, file content size=%s, max_file_size_mb=%s",
            file_content_size_in_mb,
            max_file_size_mb,
        )
        # Per-deployment file size limits are an enterprise-only feature.
        if not premium_user:
            raise ProxyException(
                message=f"Tried setting max_file_size_mb for /audio/transcriptions. {CommonProxyErrors.not_premium_user.value}",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )
        if file_content_size_in_mb > max_file_size_mb:
            raise ProxyException(
                message=f"File size is too large. Please check your file size. Passed file size: {file_content_size_in_mb} MB. Max file size: {max_file_size_mb} MB",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )

    return True
||||
|
||||
|
||||
async def get_form_data(request: Request) -> Dict[str, Any]:
    """
    Read form data from request

    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
    """
    raw_form = dict(await request.form())

    normalized: dict[str, Any] = {}
    for field_name, field_value in raw_form.items():
        if not field_name.endswith("[]"):
            normalized[field_name] = field_value
            continue
        # Array-style key: strip the "[]" suffix and collect values into a list.
        normalized.setdefault(field_name[:-2], []).append(field_value)
    return normalized
||||
|
||||
|
||||
async def convert_upload_files_to_file_data(
    form_data: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Convert FastAPI UploadFile objects to file data tuples for litellm.

    File-like values (detected via a ``read`` attribute, i.e. duck-typed) are
    replaced with ``(filename, content, content_type)`` tuples — the format
    expected by httpx and litellm's HTTP handlers. A single file is wrapped in
    a one-element list for consistency; every other value passes through
    unchanged.

    Args:
        form_data: Dictionary containing form data with potential UploadFile objects

    Returns:
        Dictionary with UploadFile objects converted to file data tuples

    Example:
        ```python
        form_data = await get_form_data(request)
        data = await convert_upload_files_to_file_data(form_data)
        # data["files"] is now [(filename, content, content_type), ...]
        ```
    """

    async def _as_file_tuple(upload):
        # Read the whole file once and package it for httpx.
        content = await upload.read()
        return (upload.filename, content, upload.content_type)

    converted: Dict[str, Any] = {}
    for field, value in form_data.items():
        if isinstance(value, list) and value and hasattr(value[0], "read"):
            # List of uploaded files.
            converted[field] = [await _as_file_tuple(item) for item in value]
        elif not isinstance(value, list) and hasattr(value, "read"):
            # Single uploaded file - wrap in a list for consistency.
            converted[field] = [await _as_file_tuple(value)]
        else:
            # Regular form field (or empty / non-file list).
            converted[field] = value
    return converted
||||
|
||||
|
||||
async def get_request_body(request: Request) -> Dict[str, Any]:
    """
    Read the request body and parse it as JSON or form data.

    Returns an empty dict for non-POST requests.

    Raises:
        ValueError: when a POST carries an unsupported content type.
    """
    if request.method != "POST":
        return {}

    content_type = request.headers.get("content-type", "")
    # Match on the media type only — clients routinely append parameters,
    # e.g. "application/json; charset=utf-8", which an exact `==` would
    # wrongly reject (the form branches already matched by substring).
    if "application/json" in content_type:
        return await _read_request_body(request)
    if (
        "multipart/form-data" in content_type
        or "application/x-www-form-urlencoded" in content_type
    ):
        return await get_form_data(request)
    raise ValueError(
        f"Unsupported content type: {request.headers.get('content-type')}"
    )
||||
|
||||
|
||||
def extract_nested_form_metadata(
    form_data: Dict[str, Any], prefix: str = "litellm_metadata["
) -> Dict[str, Any]:
    """
    Reconstruct nested metadata from form keys written in bracket notation.

    Some SDKs/clients encode nested dicts in form data as e.g.
    ``litellm_metadata[spend_logs_metadata][owner] = "john"`` instead of JSON;
    this rebuilds ``{"spend_logs_metadata": {"owner": "john"}}``. Keys that do
    not start with ``prefix`` are ignored, as are UploadFile values (files do
    not belong in metadata).

    Args:
        form_data: Dictionary containing form data (from request.form())
        prefix: The prefix to look for in form keys (default: "litellm_metadata[")

    Returns:
        Dictionary with the nested structure reconstructed from bracket notation.

    Example:
        ``{"litellm_metadata[tags]": "production", "other": "x"}``
        becomes ``{"tags": "production"}``.
    """
    if not form_data:
        return {}

    result: Dict[str, Any] = {}

    for raw_key, raw_value in form_data.items():
        # Only string keys carrying the expected prefix participate.
        if not isinstance(raw_key, str) or not raw_key.startswith(prefix):
            continue

        if isinstance(raw_value, UploadFile):
            # Files must never leak into metadata.
            verbose_proxy_logger.warning(
                f"Skipping UploadFile in metadata extraction for key: {raw_key}"
            )
            continue

        try:
            # "litellm_metadata[a][b]" -> "a][b" -> ["a", "b"]
            path = raw_key.replace(prefix, "").rstrip("]").split("][")

            if not path or not path[0]:
                verbose_proxy_logger.warning(
                    f"Invalid metadata key format (empty path): {raw_key}"
                )
                continue

            # Walk/create the intermediate dicts; abort this key on a type clash.
            node: Any = result
            clashed = False
            for segment in path[:-1]:
                if not isinstance(node, dict):
                    verbose_proxy_logger.warning(
                        f"Cannot create nested path - intermediate value is not a dict at: {segment}"
                    )
                    clashed = True
                    break
                node = node.setdefault(segment, {})

            if clashed:
                continue
            if isinstance(node, dict):
                node[path[-1]] = raw_value
            else:
                verbose_proxy_logger.warning(
                    f"Cannot set value - parent is not a dict for key: {raw_key}"
                )
        except Exception as e:
            verbose_proxy_logger.error(f"Error parsing metadata key '{raw_key}': {str(e)}")
            continue

    return result
||||
|
||||
|
||||
def get_tags_from_request_body(request_body: dict) -> List[str]:
    """
    Extract tags from the request body's metadata block and from its top level.

    Args:
        request_body: The request body dictionary

    Returns:
        List of tag names (strings), empty list if no valid tags found.
        Metadata tags come first, then top-level tags; containers that are not
        lists and entries that are not strings are dropped.
    """
    metadata_key = get_metadata_variable_name_from_kwargs(request_body)
    metadata = request_body.get(metadata_key) or {}

    collected: List[str] = []
    # Only list-typed "tags" containers are honored.
    for candidate in (metadata.get("tags", []), request_body.get("tags", [])):
        if isinstance(candidate, list):
            collected.extend(candidate)

    return [tag for tag in collected if isinstance(tag, str)]
||||
|
||||
|
||||
def populate_request_with_path_params(request_data: dict, request: Request) -> dict:
    """
    Copy FastAPI path params and query params into the request payload so downstream checks
    (e.g. vector store RBAC, organization RBAC) see them the same way as body params.

    Since path_params may not be available during dependency injection,
    we fall back to parsing the URL path directly for known patterns.

    Args:
        request_data: The request data dictionary to populate
        request: The FastAPI Request object

    Returns:
        dict: Updated request_data with path parameters and query parameters added
    """
    # Query params first (relevant for GET requests); body values win on conflict.
    for param_name, param_value in _safe_get_request_query_params(request).items():
        request_data.setdefault(param_name, param_value)

    # Preferred source: FastAPI-populated path_params, when available.
    path_params = getattr(request, "path_params", None)
    if isinstance(path_params, dict) and path_params:
        for param_name, param_value in path_params.items():
            if param_name == "vector_store_id":
                # Keep both the scalar and the list form in sync, without
                # clobbering anything the body already provided.
                request_data.setdefault("vector_store_id", param_value)
                current_ids = request_data.get("vector_store_ids")
                if isinstance(current_ids, list):
                    if param_value not in current_ids:
                        current_ids.append(param_value)
                else:
                    request_data["vector_store_ids"] = [param_value]
            else:
                request_data.setdefault(param_name, param_value)
        verbose_proxy_logger.debug(
            f"populate_request_with_path_params: Found path_params, vector_store_ids={request_data.get('vector_store_ids')}"
        )
        return request_data

    # Fallback: parse the URL path directly to extract vector_store_id
    _add_vector_store_id_from_path(request_data=request_data, request=request)

    return request_data
||||
|
||||
|
||||
def _add_vector_store_id_from_path(request_data: dict, request: Request) -> None:
    """
    Parse the request path for /vector_stores/{vector_store_id}/... segments.

    When found, populate both ``vector_store_id`` and ``vector_store_ids`` on
    ``request_data`` without clobbering values already present.

    Args:
        request_data: The request data dictionary to populate
        request: The FastAPI Request object
    """
    url_path = request.url.path
    match = re.search(r"/vector_stores/([^/]+)/", url_path)
    if not match:
        verbose_proxy_logger.debug(
            f"populate_request_with_path_params: No vector_store_id present in path={url_path}"
        )
        return

    store_id = match.group(1)
    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Extracted vector_store_id={store_id} from path={url_path}"
    )
    request_data.setdefault("vector_store_id", store_id)

    known_ids = request_data.get("vector_store_ids")
    if isinstance(known_ids, list):
        if store_id not in known_ids:
            known_ids.append(store_id)
    else:
        request_data["vector_store_ids"] = [store_id]
    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Updated request_data with vector_store_ids={request_data.get('vector_store_ids')}"
    )
||||
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Key Rotation Manager - Automated key rotation based on rotation schedules
|
||||
|
||||
Handles finding keys that need rotation based on their individual schedules.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import (
|
||||
LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
|
||||
LITELLM_KEY_ROTATION_GRACE_PERIOD,
|
||||
)
|
||||
from litellm.proxy._types import (
|
||||
GenerateKeyResponse,
|
||||
LiteLLM_VerificationToken,
|
||||
RegenerateKeyRequest,
|
||||
)
|
||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_calculate_key_rotation_time,
|
||||
regenerate_key_fn,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
class KeyRotationManager:
|
||||
"""
|
||||
Manages automated key rotation based on individual key rotation schedules.
|
||||
"""
|
||||
|
||||
def __init__(self, prisma_client: PrismaClient):
|
||||
self.prisma_client = prisma_client
|
||||
|
||||
async def process_rotations(self):
|
||||
"""
|
||||
Main entry point - find and rotate keys that are due for rotation
|
||||
"""
|
||||
try:
|
||||
verbose_proxy_logger.info("Starting scheduled key rotation check...")
|
||||
|
||||
# Clean up expired deprecated keys first
|
||||
await self._cleanup_expired_deprecated_keys()
|
||||
|
||||
# Find keys that are due for rotation
|
||||
keys_to_rotate = await self._find_keys_needing_rotation()
|
||||
|
||||
if not keys_to_rotate:
|
||||
verbose_proxy_logger.debug("No keys are due for rotation at this time")
|
||||
return
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"Found {len(keys_to_rotate)} keys due for rotation"
|
||||
)
|
||||
|
||||
# Rotate each key
|
||||
for key in keys_to_rotate:
|
||||
try:
|
||||
await self._rotate_key(key)
|
||||
key_identifier = key.key_name or (
|
||||
key.token[:8] + "..." if key.token else "unknown"
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"Successfully rotated key: {key_identifier}"
|
||||
)
|
||||
except Exception as e:
|
||||
key_identifier = key.key_name or (
|
||||
key.token[:8] + "..." if key.token else "unknown"
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
f"Failed to rotate key {key_identifier}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Key rotation process failed: {e}")
|
||||
|
||||
async def _find_keys_needing_rotation(self) -> List[LiteLLM_VerificationToken]:
|
||||
"""
|
||||
Find keys that are due for rotation based on their key_rotation_at timestamp.
|
||||
|
||||
Logic:
|
||||
- Key has auto_rotate = true
|
||||
- key_rotation_at is null (needs initial setup) OR key_rotation_at <= now
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
keys_with_rotation = (
|
||||
await self.prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={
|
||||
"auto_rotate": True, # Only keys marked for auto rotation
|
||||
"OR": [
|
||||
{
|
||||
"key_rotation_at": None
|
||||
}, # Keys that need initial rotation time setup
|
||||
{
|
||||
"key_rotation_at": {"lte": now}
|
||||
}, # Keys where rotation time has passed
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return keys_with_rotation
|
||||
|
||||
async def _cleanup_expired_deprecated_keys(self) -> None:
|
||||
"""
|
||||
Remove deprecated key entries whose revoke_at has passed.
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
result = await self.prisma_client.db.litellm_deprecatedverificationtoken.delete_many(
|
||||
where={"revoke_at": {"lt": now}}
|
||||
)
|
||||
if result > 0:
|
||||
verbose_proxy_logger.debug(
|
||||
"Cleaned up %s expired deprecated key(s)", result
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
"Deprecated key cleanup skipped (table may not exist): %s", e
|
||||
)
|
||||
|
||||
def _should_rotate_key(self, key: LiteLLM_VerificationToken, now: datetime) -> bool:
    """
    Return True when `key` is due for rotation at time `now`.

    A key without a rotation_interval is never rotated. When an interval is
    configured but no rotation time has been scheduled yet, rotate right away
    so the schedule gets initialized.
    """
    if not key.rotation_interval:
        return False

    scheduled_at = key.key_rotation_at
    # No schedule yet -> rotate immediately (rotation sets the next time);
    # otherwise rotate once the scheduled time has been reached.
    return scheduled_at is None or now >= scheduled_at
|
||||
|
||||
async def _rotate_key(self, key: LiteLLM_VerificationToken):
    """
    Rotate a single key via the existing regenerate_key_fn and fire the
    key-rotation hook.

    Steps:
    1. Build a RegenerateKeyRequest for the key (with an optional grace period
       so the old token keeps working during cutover).
    2. Regenerate the key acting as the internal-jobs service account.
    3. On success, stamp the NEW token row (regenerate_key_fn creates a new
       token) with updated rotation bookkeeping: count, last/next rotation.
    4. Invoke KeyManagementEventHooks.async_key_rotated_hook for
       notifications, audit logs, etc.

    Args:
        key: The verification token row that is due for rotation.
    """
    # Create regenerate request with grace period for seamless cutover
    regenerate_request = RegenerateKeyRequest(
        key=key.token or "",
        key_alias=key.key_alias,  # Pass key alias to ensure correct secret is updated in AWS Secrets Manager
        grace_period=LITELLM_KEY_ROTATION_GRACE_PERIOD or None,
    )

    # Create a system user for key rotation (local import avoids a circular import)
    from litellm.proxy._types import UserAPIKeyAuth

    system_user = UserAPIKeyAuth.get_litellm_internal_jobs_user_api_key_auth()

    # Use existing regenerate key function
    response = await regenerate_key_fn(
        data=regenerate_request,
        user_api_key_dict=system_user,
        litellm_changed_by=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
    )

    # Update the NEW key with rotation info (regenerate_key_fn creates a new token)
    if (
        isinstance(response, GenerateKeyResponse)
        and response.token_id
        and key.rotation_interval
    ):
        # Calculate next rotation time using helper function
        now = datetime.now(timezone.utc)
        next_rotation_time = _calculate_key_rotation_time(key.rotation_interval)
        await self.prisma_client.db.litellm_verificationtoken.update(
            where={"token": response.token_id},
            data={
                "rotation_count": (key.rotation_count or 0) + 1,
                "last_rotation_at": now,
                "key_rotation_at": next_rotation_time,
            },
        )

    # Call the existing rotation hook for notifications, audit logs, etc.
    if isinstance(response, GenerateKeyResponse):
        await KeyManagementEventHooks.async_key_rotated_hook(
            data=regenerate_request,
            existing_key_row=key,
            response=response,
            user_api_key_dict=system_user,
            litellm_changed_by=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
        )
|
||||
@@ -0,0 +1,178 @@
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
def get_file_contents_from_s3(bucket_name, object_key):
    """
    Fetch an object from S3 and parse it as YAML.

    Returns the parsed config, or None when the object cannot be retrieved
    or the AWS SDK is not installed.
    """
    try:
        # v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc
        import boto3
        from botocore.credentials import Credentials

        from litellm.main import bedrock_converse_chat_completion

        credentials: Credentials = bedrock_converse_chat_completion.get_credentials()
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,  # only set when using temporary credentials
        )
        verbose_proxy_logger.debug(
            f"Retrieving {object_key} from S3 bucket: {bucket_name}"
        )
        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
        verbose_proxy_logger.debug(f"Response: {response}")

        # Decode the body and parse it straight into a config dict.
        raw_yaml = response["Body"].read().decode("utf-8")
        verbose_proxy_logger.debug("File contents retrieved from S3")

        return yaml.safe_load(raw_yaml)

    except ImportError as e:
        # most likely the user is not running the litellm docker container
        verbose_proxy_logger.error(f"ImportError: {str(e)}")
    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
        return None
|
||||
|
||||
|
||||
async def get_config_file_contents_from_gcs(bucket_name, object_key):
    """
    Fetch an object from GCS and parse it as YAML.

    Returns the parsed config, or None on any failure.
    """
    try:
        from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger

        gcs_bucket = GCSBucketLogger(
            bucket_name=bucket_name,
        )
        raw_bytes = await gcs_bucket.download_gcs_object(object_key)
        if raw_bytes is None:
            raise Exception(f"File contents are None for {object_key}")

        # download_gcs_object returns bytes; decode before YAML parsing.
        return yaml.safe_load(raw_bytes.decode("utf-8"))

    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
        return None
|
||||
|
||||
|
||||
def download_python_file_from_s3(
    bucket_name: str,
    object_key: str,
    local_file_path: str,
) -> bool:
    """
    Download a Python file from S3 and save it to local filesystem.

    Args:
        bucket_name (str): S3 bucket name
        object_key (str): S3 object key (file path in bucket)
        local_file_path (str): Local path where file should be saved

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # boto3 is an optional dependency - import lazily so environments
        # without it can still load this module.
        import boto3
        from botocore.credentials import Credentials

        from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM

        base_aws_llm = BaseAWSLLM()

        # Resolve AWS credentials via the shared helper (env vars, IAM role, etc.)
        credentials: Credentials = base_aws_llm.get_credentials()
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,  # optional, for temporary credentials
        )

        verbose_proxy_logger.debug(
            f"Downloading Python file {object_key} from S3 bucket: {bucket_name}"
        )
        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)

        # Read the file contents
        file_contents = response["Body"].read().decode("utf-8")
        verbose_proxy_logger.debug(f"File contents: {file_contents}")

        # Ensure the target directory exists. os.path.dirname() returns "" for
        # a bare filename, and os.makedirs("") raises FileNotFoundError, so
        # only create a directory when there is one to create.
        target_dir = os.path.dirname(local_file_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Write to local file
        with open(local_file_path, "w") as f:
            f.write(file_contents)

        verbose_proxy_logger.debug(
            f"Python file downloaded successfully to {local_file_path}"
        )
        return True

    except ImportError as e:
        verbose_proxy_logger.error(f"ImportError: {str(e)}")
        return False
    except Exception as e:
        verbose_proxy_logger.exception(f"Error downloading Python file: {str(e)}")
        return False
|
||||
|
||||
|
||||
async def download_python_file_from_gcs(
    bucket_name: str,
    object_key: str,
    local_file_path: str,
) -> bool:
    """
    Download a Python file from GCS and save it to local filesystem.

    Args:
        bucket_name (str): GCS bucket name
        object_key (str): GCS object key (file path in bucket)
        local_file_path (str): Local path where file should be saved

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger

        gcs_bucket = GCSBucketLogger(
            bucket_name=bucket_name,
        )
        file_contents = await gcs_bucket.download_gcs_object(object_key)
        if file_contents is None:
            raise Exception(f"File contents are None for {object_key}")

        # file_contents is a bytes object, decode it
        file_contents = file_contents.decode("utf-8")

        # Ensure the target directory exists. os.path.dirname() returns "" for
        # a bare filename, and os.makedirs("") raises FileNotFoundError, so
        # only create a directory when there is one to create.
        target_dir = os.path.dirname(local_file_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Write to local file
        with open(local_file_path, "w") as f:
            f.write(file_contents)

        verbose_proxy_logger.debug(
            f"Python file downloaded successfully to {local_file_path}"
        )
        return True

    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error downloading Python file from GCS: {str(e)}"
        )
        return False
|
||||
|
||||
|
||||
# # Example usage
|
||||
# bucket_name = 'litellm-proxy'
|
||||
# object_key = 'litellm_proxy_config.yaml'
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Contains utils used by OpenAI compatible endpoints
|
||||
"""
|
||||
|
||||
from typing import Optional, Set
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
|
||||
SENSITIVE_DATA_MASKER = SensitiveDataMasker()
|
||||
|
||||
|
||||
def remove_sensitive_info_from_deployment(
    deployment_dict: dict,
    excluded_keys: Optional[Set[str]] = None,
) -> dict:
    """
    Strip raw credentials from a deployment dict and mask the remaining params.

    Args:
        deployment_dict (dict): Deployment to sanitize; mutated in place.
        excluded_keys (Optional[Set[str]]): Keys exempt from masking (exact match).

    Returns:
        dict: The deployment dictionary with sensitive information removed.
    """
    litellm_params = deployment_dict["litellm_params"]

    # Drop raw credentials outright - these must never leave the server.
    for credential_key in (
        "api_key",
        "client_secret",
        "vertex_credentials",
        "aws_access_key_id",
        "aws_secret_access_key",
    ):
        litellm_params.pop(credential_key, None)

    # Mask whatever sensitive-looking values remain.
    deployment_dict["litellm_params"] = SENSITIVE_DATA_MASKER.mask_dict(
        litellm_params, excluded_keys=excluded_keys
    )

    return deployment_dict
|
||||
|
||||
|
||||
async def get_custom_llm_provider_from_request_body(request: Request) -> Optional[str]:
    """
    Get the `custom_llm_provider` from the request body

    Safely reads the request body
    """
    request_body: dict = await _read_request_body(request=request) or {}
    # dict.get returns None when the key is absent - same contract as before.
    return request_body.get("custom_llm_provider")
|
||||
|
||||
|
||||
def get_custom_llm_provider_from_request_query(request: Request) -> Optional[str]:
    """
    Get the `custom_llm_provider` from the request query parameters

    Safely reads the request query parameters
    """
    # QueryParams.get returns None when the parameter is absent.
    return request.query_params.get("custom_llm_provider")
|
||||
|
||||
|
||||
def get_custom_llm_provider_from_request_headers(request: Request) -> Optional[str]:
    """
    Get the `custom_llm_provider` from the request header `custom-llm-provider`
    """
    # Headers.get is case-insensitive and returns None when the header is absent.
    return request.headers.get("custom-llm-provider")
|
||||
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Utility module for handling OpenAPI schema generation compatibility with FastAPI 0.120+.
|
||||
|
||||
FastAPI 0.120+ has stricter schema generation that fails on certain types like openai.Timeout.
|
||||
This module provides a compatibility layer to handle these cases gracefully.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
def get_openapi_schema_with_compat(
    get_openapi_func,
    title: str,
    version: str,
    description: str,
    routes: list,
) -> Dict[str, Any]:
    """
    Generate OpenAPI schema with compatibility handling for FastAPI 0.120+.

    This function patches Pydantic's schema generation to handle non-serializable
    types like openai.Timeout that cause PydanticSchemaGenerationError in
    FastAPI 0.120+. The patch is applied only for the duration of the call and
    always restored afterwards (see the try/finally below).

    Args:
        get_openapi_func: The FastAPI get_openapi function
        title: API title
        version: API version
        description: API description
        routes: List of routes

    Returns:
        OpenAPI schema dictionary. Falls back to a minimal empty schema if
        schema generation still fails after patching.
    """
    # FastAPI 0.120+ may fail schema generation for certain types (e.g., openai.Timeout)
    # Patch Pydantic's schema generation to handle unknown types gracefully
    try:
        # NOTE(review): pydantic._internal is a private module - this import may
        # break on a future pydantic release, hence the AttributeError fallback below.
        from pydantic._internal._generate_schema import GenerateSchema
        from pydantic_core import core_schema

        # Store original method so it can be restored after generation
        original_unknown_type_schema = GenerateSchema._unknown_type_schema

        def patched_unknown_type_schema(self, obj):
            """Patch to handle openai.Timeout and other non-serializable types"""
            # Check if it's openai.Timeout or similar types
            obj_str = str(obj)
            obj_module = getattr(obj, "__module__", "")

            if (obj_module == "openai" and "Timeout" in obj_str) or (
                hasattr(obj, "__name__")
                and obj.__name__ == "Timeout"
                and obj_module == "openai"
            ):
                # Return a simple string schema for Timeout types
                return core_schema.str_schema()

            # For other unknown types, try to return a default schema
            # This prevents the error from propagating
            try:
                return core_schema.any_schema()
            except Exception:
                # Last resort: return string schema
                return core_schema.str_schema()

        # Apply patch (process-wide, so it must be undone in the finally block)
        setattr(GenerateSchema, "_unknown_type_schema", patched_unknown_type_schema)

        try:
            openapi_schema = get_openapi_func(
                title=title,
                version=version,
                description=description,
                routes=routes,
            )
        finally:
            # Restore original method even if schema generation raised
            setattr(
                GenerateSchema, "_unknown_type_schema", original_unknown_type_schema
            )

        return openapi_schema

    except (ImportError, AttributeError) as e:
        # If patching fails, try normal generation with error handling
        verbose_proxy_logger.debug(
            f"Could not patch Pydantic schema generation: {e}. Trying normal generation."
        )
        try:
            return get_openapi_func(
                title=title,
                version=version,
                description=description,
                routes=routes,
            )
        except Exception as pydantic_error:
            # Check if it's a PydanticSchemaGenerationError by checking the error type name
            # This avoids import issues if PydanticSchemaGenerationError is not available
            error_type_name = type(pydantic_error).__name__
            if (
                error_type_name == "PydanticSchemaGenerationError"
                or "PydanticSchemaGenerationError" in str(type(pydantic_error))
            ):
                # If we still get the error, log it and return minimal schema
                verbose_proxy_logger.warning(
                    f"PydanticSchemaGenerationError during schema generation: {pydantic_error}"
                )
                return {
                    "openapi": "3.0.0",
                    "info": {
                        "title": title,
                        "version": version,
                        "description": description or "",
                    },
                    "paths": {},
                    "components": {"schemas": {}},
                }
            else:
                # Re-raise if it's a different error
                raise
|
||||
@@ -0,0 +1,214 @@
|
||||
# Performance Utilities Documentation
|
||||
|
||||
This module provides performance monitoring and profiling functionality for LiteLLM proxy server using `cProfile` and `line_profiler`.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Line Profiler Usage](#line-profiler-usage)
|
||||
- [Example 1: Wrapping a function directly](#example-1-wrapping-a-function-directly)
|
||||
- [Example 2: Wrapping a module function dynamically](#example-2-wrapping-a-module-function-dynamically)
|
||||
- [Example 3: Manual stats collection](#example-3-manual-stats-collection)
|
||||
- [Example 4: Analyzing the profile output](#example-4-analyzing-the-profile-output)
|
||||
- [Example 5: Using in a decorator pattern](#example-5-using-in-a-decorator-pattern)
|
||||
- [cProfile Usage](#cprofile-usage)
|
||||
- [Installation](#installation)
|
||||
- [Notes](#notes)
|
||||
|
||||
## Line Profiler Usage
|
||||
|
||||
### Example 1: Wrapping a function directly
|
||||
|
||||
This is how it's used in `litellm/utils.py` to profile `wrapper_async`:
|
||||
|
||||
```python
|
||||
from litellm.proxy.common_utils.performance_utils import (
|
||||
register_shutdown_handler,
|
||||
wrap_function_directly,
|
||||
)
|
||||
|
||||
def client(original_function):
|
||||
@wraps(original_function)
|
||||
async def wrapper_async(*args, **kwargs):
|
||||
# ... function implementation ...
|
||||
pass
|
||||
|
||||
# Wrap the function with line_profiler
|
||||
wrapper_async = wrap_function_directly(wrapper_async)
|
||||
|
||||
# Register shutdown handler to collect stats on server shutdown
|
||||
register_shutdown_handler(output_file="wrapper_async_line_profile.lprof")
|
||||
|
||||
return wrapper_async
|
||||
```
|
||||
|
||||
### Example 2: Wrapping a module function dynamically
|
||||
|
||||
```python
|
||||
import my_module
|
||||
from litellm.proxy.common_utils.performance_utils import (
|
||||
wrap_function_with_line_profiler,
|
||||
register_shutdown_handler,
|
||||
)
|
||||
|
||||
# Wrap a function in a module
|
||||
wrap_function_with_line_profiler(my_module, "expensive_function")
|
||||
|
||||
# Register shutdown handler
|
||||
register_shutdown_handler(output_file="my_profile.lprof")
|
||||
|
||||
# Now all calls to my_module.expensive_function will be profiled
|
||||
my_module.expensive_function()
|
||||
```
|
||||
|
||||
### Example 3: Manual stats collection
|
||||
|
||||
```python
|
||||
from litellm.proxy.common_utils.performance_utils import (
|
||||
wrap_function_directly,
|
||||
collect_line_profiler_stats,
|
||||
)
|
||||
|
||||
def my_function():
|
||||
# ... implementation ...
|
||||
pass
|
||||
|
||||
# Wrap the function
|
||||
my_function = wrap_function_directly(my_function)
|
||||
|
||||
# Run your code
|
||||
my_function()
|
||||
|
||||
# Collect stats manually (instead of waiting for shutdown)
|
||||
collect_line_profiler_stats(output_file="manual_profile.lprof")
|
||||
```
|
||||
|
||||
### Example 4: Analyzing the profile output
|
||||
|
||||
After running your code, analyze the `.lprof` file:
|
||||
|
||||
```bash
|
||||
# View the profile
|
||||
python -m line_profiler wrapper_async_line_profile.lprof
|
||||
|
||||
# Save to text file
|
||||
python -m line_profiler wrapper_async_line_profile.lprof > profile_report.txt
|
||||
```
|
||||
|
||||
The output shows:
|
||||
- **Line #**: Line number in the source file
|
||||
- **Hits**: Number of times the line was executed
|
||||
- **Time**: Total time spent on that line (in microseconds)
|
||||
- **Per Hit**: Average time per execution
|
||||
- **% Time**: Percentage of total function time
|
||||
- **Line Contents**: The actual source code
|
||||
|
||||
Example output:
|
||||
```
|
||||
Timer unit: 1e-06 s
|
||||
|
||||
Total time: 3.73697 s
|
||||
File: litellm/utils.py
|
||||
Function: client.<locals>.wrapper_async at line 1657
|
||||
|
||||
Line # Hits Time Per Hit % Time Line Contents
|
||||
==============================================================
|
||||
1657 @wraps(original_function)
|
||||
1658 async def wrapper_async(*args, **kwargs):
|
||||
1659 2005 7577.1 3.8 0.2 print_args_passed_to_litellm(...)
|
||||
1763 2005 1351909.0 674.3 36.2 result = await original_function(*args, **kwargs)
|
||||
1846 4010 1543688.1 385.0 41.3 update_response_metadata(...)
|
||||
```
|
||||
|
||||
### Example 5: Using in a decorator pattern
|
||||
|
||||
```python
|
||||
from litellm.proxy.common_utils.performance_utils import (
|
||||
wrap_function_directly,
|
||||
register_shutdown_handler,
|
||||
)
|
||||
|
||||
def profile_decorator(func):
|
||||
# Wrap the function
|
||||
profiled_func = wrap_function_directly(func)
|
||||
|
||||
# Register shutdown handler (only once)
|
||||
if not hasattr(profile_decorator, '_registered'):
|
||||
register_shutdown_handler(output_file="decorated_functions.lprof")
|
||||
profile_decorator._registered = True
|
||||
|
||||
return profiled_func
|
||||
|
||||
@profile_decorator
|
||||
async def my_async_function():
|
||||
# This function will be profiled
|
||||
pass
|
||||
```
|
||||
|
||||
## cProfile Usage
|
||||
|
||||
### Example: Using the profile_endpoint decorator
|
||||
|
||||
```python
|
||||
from litellm.proxy.common_utils.performance_utils import profile_endpoint
|
||||
|
||||
@profile_endpoint(sampling_rate=0.1) # Profile 10% of requests
|
||||
async def my_endpoint():
|
||||
# ... implementation ...
|
||||
pass
|
||||
```
|
||||
|
||||
The `sampling_rate` parameter controls what percentage of requests are profiled:
|
||||
- `1.0`: Profile all requests (100%)
|
||||
- `0.1`: Profile 1 in 10 requests (10%)
|
||||
- `0.0`: Profile no requests (0%)
|
||||
|
||||
## Installation
|
||||
|
||||
`line_profiler` must be installed to use the line profiling functionality:
|
||||
|
||||
```bash
|
||||
pip install line_profiler
|
||||
```
|
||||
|
||||
On Windows with Python 3.14+, you may need to install Microsoft Visual C++ Build Tools to compile `line_profiler` from source.
|
||||
|
||||
## Notes
|
||||
|
||||
- The profiler aggregates stats by source code location, so multiple instances of the same function (e.g., closures) will be profiled together
|
||||
- Stats are automatically collected on server shutdown via `atexit` handler when using `register_shutdown_handler()`
|
||||
- You can also manually collect stats using `collect_line_profiler_stats()`
|
||||
- The line profiler raises an `ImportError` if `line_profiler` is not installed (as configured in `litellm/utils.py`)
|
||||
|
||||
## API Reference
|
||||
|
||||
### `wrap_function_directly(func: Callable) -> Callable`
|
||||
|
||||
Wrap a function directly with line_profiler. This is the recommended way to profile functions, especially closures or functions created dynamically.
|
||||
|
||||
**Raises:**
|
||||
- `ImportError`: If line_profiler is not available
|
||||
- `RuntimeError`: If line_profiler cannot be enabled or function cannot be wrapped
|
||||
|
||||
### `wrap_function_with_line_profiler(module: Any, function_name: str) -> bool`
|
||||
|
||||
Dynamically wrap a function in a module with line_profiler.
|
||||
|
||||
**Returns:** `True` if wrapping was successful, `False` otherwise
|
||||
|
||||
### `collect_line_profiler_stats(output_file: Optional[str] = None) -> None`
|
||||
|
||||
Collect and save line_profiler statistics. If `output_file` is provided, saves to file. Otherwise, prints to stdout.
|
||||
|
||||
### `register_shutdown_handler(output_file: Optional[str] = None) -> None`
|
||||
|
||||
Register an `atexit` handler that will automatically save profiling statistics when the Python process exits. Safe to call multiple times (only registers once).
|
||||
|
||||
**Default output file:** `line_profile_stats.lprof` if not specified
|
||||
|
||||
### `profile_endpoint(sampling_rate: float = 1.0)`
|
||||
|
||||
Decorator to sample endpoint hits and save to a profile file using cProfile.
|
||||
|
||||
**Args:**
|
||||
- `sampling_rate`: Rate of requests to profile (0.0 to 1.0)
|
||||
|
||||
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
Performance utilities for LiteLLM proxy server.
|
||||
|
||||
This module provides performance monitoring and profiling functionality for endpoint
|
||||
performance analysis using cProfile with configurable sampling rates, and line_profiler
|
||||
for line-by-line profiling.
|
||||
|
||||
See performance_utils.md for detailed usage examples and documentation.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import cProfile
|
||||
import functools
|
||||
import inspect
|
||||
import threading
|
||||
from pathlib import Path as PathLib
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
# Global profiling state
|
||||
_profile_lock = threading.Lock()
|
||||
_profiler = None
|
||||
_last_profile_file_path = None
|
||||
_sample_counter = 0
|
||||
_sample_counter_lock = threading.Lock()
|
||||
|
||||
# Global line_profiler state
|
||||
_line_profiler: Optional[Any] = None
|
||||
_line_profiler_lock = threading.Lock()
|
||||
_wrapped_functions: dict[str, Callable] = {} # Store original functions
|
||||
|
||||
|
||||
def _should_sample(profile_sampling_rate: float) -> bool:
    """Decide whether the current request falls into the sampling window."""
    if profile_sampling_rate >= 1.0:
        # Sample everything
        return True
    if profile_sampling_rate <= 0.0:
        # Sampling disabled
        return False

    # Deterministic counter-based sampling: a rate of 0.1 selects every 10th call.
    global _sample_counter
    period = int(1.0 / profile_sampling_rate)
    with _sample_counter_lock:
        _sample_counter += 1
        return (_sample_counter % period) == 0
|
||||
|
||||
|
||||
def _start_profiling(profile_sampling_rate: float) -> None:
    """Create and enable the process-wide cProfile profiler (idempotent)."""
    global _profiler
    with _profile_lock:
        if _profiler is not None:
            # Already profiling - nothing to do.
            return
        _profiler = cProfile.Profile()
        _profiler.enable()
        verbose_proxy_logger.info(
            f"Profiling started with sampling rate: {profile_sampling_rate}"
        )
|
||||
|
||||
|
||||
def _start_profiling_for_request(profile_sampling_rate: float) -> bool:
    """Start global profiling for this request when sampling selects it.

    Returns True when the request is being profiled, False otherwise.
    """
    if not _should_sample(profile_sampling_rate):
        return False
    _start_profiling(profile_sampling_rate)
    return True
|
||||
|
||||
|
||||
def _save_stats(profile_file: PathLib) -> None:
    """Persist in-flight cProfile data to `profile_file` without ending the session."""
    with _profile_lock:
        if _profiler is None:
            return
        try:
            # The profiler must be paused while its stats are written out.
            _profiler.disable()
            _profiler.dump_stats(str(profile_file))
            verbose_proxy_logger.debug(f"Profiling stats saved to {profile_file}")
        except Exception as e:
            verbose_proxy_logger.error(f"Error saving profiling stats: {e}")
        finally:
            # Resume profiling whether or not the dump succeeded.
            try:
                _profiler.enable()
            except Exception:
                pass
|
||||
|
||||
|
||||
def profile_endpoint(sampling_rate: float = 1.0):
    """Decorator to sample endpoint hits and save to a profile file.

    Args:
        sampling_rate: Rate of requests to profile (0.0 to 1.0)
            - 1.0: Profile all requests (100%)
            - 0.1: Profile 1 in 10 requests (10%)
            - 0.0: Profile no requests (0%)
    """

    def decorator(func):
        profile_output = PathLib("endpoint_profile.pstat")

        def _begin_request() -> bool:
            """Record the output path and start profiling if this hit is sampled."""
            global _last_profile_file_path
            _last_profile_file_path = profile_output
            return _start_profiling_for_request(sampling_rate)

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                sampled = _begin_request()
                try:
                    result = await func(*args, **kwargs)
                except Exception:
                    # Flush stats even when the endpoint raised.
                    if sampled:
                        _save_stats(profile_output)
                    raise
                if sampled:
                    _save_stats(profile_output)
                return result

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            sampled = _begin_request()
            try:
                result = func(*args, **kwargs)
            except Exception:
                # Flush stats even when the endpoint raised.
                if sampled:
                    _save_stats(profile_output)
                raise
            if sampled:
                _save_stats(profile_output)
            return result

        return sync_wrapper

    return decorator
|
||||
|
||||
|
||||
def enable_line_profiler() -> None:
    """Lazily create the module-wide LineProfiler instance (idempotent).

    Importing inside the function keeps line_profiler an optional dependency.

    Raises:
        ImportError: If line_profiler is not available
    """
    from line_profiler import LineProfiler

    global _line_profiler
    with _line_profiler_lock:
        if _line_profiler is not None:
            return
        _line_profiler = LineProfiler()
        verbose_proxy_logger.info("Line profiler enabled")
|
||||
|
||||
|
||||
def wrap_function_with_line_profiler(module: Any, function_name: str) -> bool:
    """Replace ``module.function_name`` with a line_profiler-instrumented version.

    Args:
        module: The module containing the function
        function_name: Name of the function to wrap

    Returns:
        True if wrapping was successful, False otherwise
    """
    try:
        enable_line_profiler()  # raises ImportError when line_profiler is absent
    except ImportError:
        return False

    if _line_profiler is None:
        return False

    try:
        target = getattr(module, function_name, None)
        if target is None:
            verbose_proxy_logger.warning(
                f"Function {function_name} not found in module {module.__name__}"
            )
            return False

        # Remember the unwrapped function the first time so it is not lost
        # if this is called again for the same name.
        if function_name not in _wrapped_functions:
            _wrapped_functions[function_name] = target

        # Swap in the profiled version.
        setattr(module, function_name, _line_profiler(target))

        verbose_proxy_logger.info(
            f"Wrapped {module.__name__}.{function_name} with line_profiler"
        )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            f"Error wrapping {function_name} with line_profiler: {e}"
        )
        return False
|
||||
|
||||
|
||||
def wrap_function_directly(func: Callable) -> Callable:
    """Wrap ``func`` with line_profiler and return the instrumented callable.

    Preferred entry point for profiling closures and dynamically created
    functions (like wrapper_async in litellm/utils.py).

    Args:
        func: The function to wrap

    Returns:
        The wrapped function that will be profiled when called

    Raises:
        ImportError: If line_profiler is not available
        RuntimeError: If line_profiler cannot be enabled or function cannot be wrapped
    """
    import warnings

    enable_line_profiler()  # Will raise ImportError if not available

    if _line_profiler is None:
        raise RuntimeError("Line profiler was not initialized")

    with warnings.catch_warnings():
        # line_profiler warns about functools.wraps-style __wrapped__ attrs;
        # profiling the wrapper is exactly what we want, so silence it.
        warnings.filterwarnings(
            "ignore", message=".*__wrapped__.*", category=UserWarning
        )
        _line_profiler.add_function(func)
        instrumented = _line_profiler(func)

    verbose_proxy_logger.info(f"Wrapped function {func.__name__} with line_profiler")
    return instrumented
|
||||
|
||||
|
||||
def collect_line_profiler_stats(output_file: Optional[str] = None) -> None:
    """Collect line_profiler statistics and persist or log them.

    Can be invoked manually at any time; it is also invoked automatically at
    interpreter exit when register_shutdown_handler() was used.

    Args:
        output_file: Destination path for the raw stats dump. When None, the
            formatted stats are emitted through the proxy logger instead.
    """
    global _line_profiler

    with _line_profiler_lock:
        if _line_profiler is None:
            verbose_proxy_logger.debug("Line profiler not enabled, nothing to collect")
            return

        try:
            if not output_file:
                # No destination given: render the stats in memory and log them.
                from io import StringIO

                buffer = StringIO()
                _line_profiler.print_stats(stream=buffer)
                verbose_proxy_logger.info("Line profiler stats:\n" + buffer.getvalue())
            else:
                # Persist the raw stats to the requested path.
                target = PathLib(output_file)
                _line_profiler.dump_stats(str(target))
                verbose_proxy_logger.info(f"Line profiler stats saved to {target}")
        except Exception as e:
            verbose_proxy_logger.error(f"Error collecting line profiler stats: {e}")
|
||||
|
||||
|
||||
def register_shutdown_handler(output_file: Optional[str] = None) -> None:
    """Register a shutdown handler to collect line_profiler stats.

    This registers an atexit handler that will automatically save profiling
    statistics when the Python process exits. Safe to call multiple times
    (only registers once; later calls are no-ops, even with a different path).

    Args:
        output_file: Optional path to save stats on shutdown.
                     Defaults to 'line_profile_stats.lprof'
    """
    if output_file is None:
        output_file = "line_profile_stats.lprof"

    # Fix: the docstring promised idempotence, but every call previously
    # registered a fresh atexit handler, producing duplicate stat dumps.
    # A function attribute keeps the guard local to this block.
    if getattr(register_shutdown_handler, "_registered", False):
        return
    register_shutdown_handler._registered = True  # type: ignore[attr-defined]

    def shutdown_handler():
        # Deferred until interpreter exit; writes/logs the collected stats.
        collect_line_profiler_stats(output_file=output_file)

    atexit.register(shutdown_handler)
    verbose_proxy_logger.debug(
        f"Registered line_profiler shutdown handler for {output_file}"
    )
|
||||
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
This file is used to store the state variables of the proxy server.
|
||||
|
||||
Example: `spend_logs_row_count` is used to store the number of rows in the `LiteLLM_SpendLogs` table.
|
||||
"""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from litellm.proxy._types import ProxyStateVariables
|
||||
|
||||
|
||||
class ProxyState:
    """
    Holds proxy-wide state variables and exposes typed get/set accessors.
    """

    # Note: mypy does not recognize when we fetch ProxyStateVariables.annotations.keys(), so we also need to add the valid keys here
    valid_keys_literal = Literal["spend_logs_row_count"]

    def __init__(self) -> None:
        # Backing store; ProxyStateVariables behaves like a dict at runtime.
        self.proxy_state_variables: ProxyStateVariables = ProxyStateVariables(
            spend_logs_row_count=0,
        )

    def get_proxy_state_variable(
        self,
        variable_name: valid_keys_literal,
    ) -> Any:
        """Return the current value for *variable_name*, or None when unset."""
        return self.proxy_state_variables.get(variable_name)

    def set_proxy_state_variable(
        self,
        variable_name: valid_keys_literal,
        value: Any,
    ) -> None:
        """Store *value* under *variable_name* in the proxy state."""
        self.proxy_state_variables[variable_name] = value
|
||||
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
RBAC utility helpers for feature-level access control.
|
||||
|
||||
These helpers are used by agent and vector store endpoints to enforce
|
||||
proxy-admin-configurable toggles that restrict access for internal users.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
||||
|
||||
FeatureName = Literal["agents", "vector_stores"]
|
||||
|
||||
|
||||
async def check_feature_access_for_user(
    user_api_key_dict: UserAPIKeyAuth,
    feature_name: FeatureName,
) -> None:
    """
    Enforce proxy-admin-configured toggles for the given feature.

    Raises HTTP 403 when general_settings disables *feature_name* for the
    user's role; returns silently when access is allowed.

    Args:
        user_api_key_dict: The authenticated user.
        feature_name: Either "agents" or "vector_stores".
    """
    admin_roles = (
        LitellmUserRoles.PROXY_ADMIN,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
        LitellmUserRoles.PROXY_ADMIN.value,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
    )
    # Proxy admins (and view-only admins) are never blocked.
    if user_api_key_dict.user_role in admin_roles:
        return

    # Imported at call time so the current proxy_server globals are read
    # (re-importing fetches the attributes' present values).
    from litellm.proxy.proxy_server import (
        general_settings,
        prisma_client,
        user_api_key_cache,
    )

    disable_flag = f"disable_{feature_name}_for_internal_users"
    allow_team_admins_flag = f"allow_{feature_name}_for_team_admins"

    if not general_settings.get(disable_flag, False):
        # Feature is not disabled — allow all authenticated users.
        return

    # Feature is disabled. Check if team/org admins are exempted.
    if general_settings.get(allow_team_admins_flag, False):
        from litellm.proxy.management_endpoints.common_utils import (
            _user_has_admin_privileges,
        )

        if await _user_has_admin_privileges(
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
        ):
            return

    raise HTTPException(
        status_code=403,
        detail={
            "error": f"Access to {feature_name} is disabled for your role. Contact your proxy admin."
        },
    )
|
||||
@@ -0,0 +1,13 @@
|
||||
import json
from functools import lru_cache
from typing import Optional

from litellm.constants import _REALTIME_BODY_CACHE_SIZE
|
||||
|
||||
|
||||
@lru_cache(maxsize=_REALTIME_BODY_CACHE_SIZE)
def _realtime_request_body(model: Optional[str]) -> bytes:
    """
    Generate the realtime websocket request body. Cached with LRU semantics to avoid repeated
    serialization work while keeping memory usage bounded.

    Fix: build the payload with ``json.dumps`` instead of an f-string so a
    model name containing quotes or backslashes is properly escaped and the
    body is always valid JSON. For ordinary model names the emitted bytes are
    identical to the previous f-string output (json.dumps uses ": " too).
    """
    return json.dumps({"model": model or ""}).encode()
|
||||
@@ -0,0 +1,619 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_BudgetTableFull,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_VerificationToken,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
|
||||
class ResetBudgetJob:
    """
    Resets the budget for all the keys, users, and teams that need it

    Each reset_* method catches its own failures and reports success/failure
    telemetry through fire-and-forget ``asyncio.create_task`` calls, so one
    table's failure does not abort the resets for the other tables.
    """

    def __init__(self, proxy_logging_obj: ProxyLogging, prisma_client: PrismaClient):
        # proxy_logging_obj carries service_logging_obj, used for the
        # success/failure telemetry hooks below.
        self.proxy_logging_obj: ProxyLogging = proxy_logging_obj
        self.prisma_client: PrismaClient = prisma_client

    async def reset_budget(
        self,
    ) -> None:
        """
        Gets all the non-expired keys for a db, which need spend to be reset

        Resets their spend

        Updates db
        """
        if self.prisma_client is not None:
            ### RESET KEY BUDGET ###
            await self.reset_budget_for_litellm_keys()

            ### RESET USER BUDGET ###
            await self.reset_budget_for_litellm_users()

            ## Reset Team Budget
            await self.reset_budget_for_litellm_teams()

            ### RESET ENDUSER (Customer) BUDGET and corresponding Budget duration ###
            await self.reset_budget_for_litellm_budget_table()

    async def reset_budget_for_litellm_team_members(
        self, budgets_to_reset: List[LiteLLM_BudgetTableFull]
    ):
        """
        Resets the budget for all LiteLLM Team Members if their budget has expired

        Zeroes the spend of every team-membership row whose budget_id is in
        the given list of expiring budgets. Returns the prisma update result.
        """
        return await self.prisma_client.db.litellm_teammembership.update_many(
            where={
                "budget_id": {
                    "in": [
                        budget.budget_id
                        for budget in budgets_to_reset
                        if budget.budget_id is not None
                    ]
                }
            },
            data={
                "spend": 0,
            },
        )

    async def reset_budget_for_keys_linked_to_budgets(
        self, budgets_to_reset: List[LiteLLM_BudgetTableFull]
    ):
        """
        Resets the spend for keys linked to budget tiers that are being reset.

        This handles keys that have budget_id but no budget_duration set on the key
        itself. Keys with budget_id rely on their linked budget tier's reset schedule
        rather than having their own budget_duration.

        Keys that have their own budget_duration are already handled by
        reset_budget_for_litellm_keys() and are excluded here to avoid
        double-resetting.
        """
        budget_ids = [
            budget.budget_id
            for budget in budgets_to_reset
            if budget.budget_id is not None
        ]
        if not budget_ids:
            return

        return await self.prisma_client.db.litellm_verificationtoken.update_many(
            where={
                "budget_id": {"in": budget_ids},
                "budget_duration": None,  # only keys without their own reset schedule
                "spend": {"gt": 0},  # only reset keys that have accumulated spend
            },
            data={
                "spend": 0,
            },
        )

    async def reset_budget_for_litellm_budget_table(self) -> None:
        """
        Resets the budget for all LiteLLM End-Users (Customers), and Team Members if their budget has expired
        The corresponding Budget duration is also updated.
        """

        # NOTE(review): this method uses timezone-aware now(), while the
        # key/user/team resets below use naive datetime.utcnow() — confirm
        # the Prisma filters treat both consistently.
        now = datetime.now(timezone.utc)
        start_time = time.time()
        endusers_to_reset: Optional[List[LiteLLM_EndUserTable]] = None
        budgets_to_reset: Optional[List[LiteLLM_BudgetTableFull]] = None
        updated_endusers: List[LiteLLM_EndUserTable] = []
        failed_endusers = []
        try:
            budgets_to_reset = await self.prisma_client.get_data(
                table_name="budget", query_type="find_all", reset_at=now
            )

            if budgets_to_reset is not None and len(budgets_to_reset) > 0:
                # _reset_budget_reset_at_date mutates each budget in place and
                # returns it, so rebinding the loop variable is redundant but
                # harmless — the list itself carries the updated objects.
                for budget in budgets_to_reset:
                    budget = await ResetBudgetJob._reset_budget_reset_at_date(
                        budget, now
                    )

                await self.prisma_client.update_data(
                    query_type="update_many",
                    data_list=budgets_to_reset,
                    table_name="budget",
                )

                endusers_to_reset = await self.prisma_client.get_data(
                    table_name="enduser",
                    query_type="find_all",
                    budget_id_list=[
                        budget.budget_id
                        for budget in budgets_to_reset
                        if budget.budget_id is not None
                    ],
                )

                await self.reset_budget_for_litellm_team_members(
                    budgets_to_reset=budgets_to_reset
                )

                await self.reset_budget_for_keys_linked_to_budgets(
                    budgets_to_reset=budgets_to_reset
                )

            if endusers_to_reset is not None and len(endusers_to_reset) > 0:
                for enduser in endusers_to_reset:
                    try:
                        updated_enduser = (
                            await ResetBudgetJob._reset_budget_for_enduser(
                                enduser=enduser
                            )
                        )
                        if updated_enduser is not None:
                            updated_endusers.append(updated_enduser)
                        else:
                            failed_endusers.append(
                                {
                                    "enduser": enduser,
                                    "error": "Returned None without exception",
                                }
                            )
                    except Exception as e:
                        failed_endusers.append({"enduser": enduser, "error": str(e)})
                        verbose_proxy_logger.exception(
                            "Failed to reset budget for enduser: %s", enduser
                        )

                verbose_proxy_logger.debug(
                    "Updated users %s",
                    json.dumps(updated_endusers, indent=4, default=str),
                )

                # NOTE(review): unlike the key/user/team resets below, this
                # update runs even when updated_endusers is empty.
                await self.prisma_client.update_data(
                    query_type="update_many",
                    data_list=updated_endusers,
                    table_name="enduser",
                )

            end_time = time.time()
            if len(failed_endusers) > 0:  # If any endusers failed to reset
                raise Exception(
                    f"Failed to reset {len(failed_endusers)} endusers: {json.dumps(failed_endusers, default=str)}"
                )

            # Fire-and-forget telemetry; not awaited so the reset flow is not blocked.
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    call_type="reset_budget_budget_table",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_budgets_found": (
                            len(budgets_to_reset) if budgets_to_reset else 0
                        ),
                        "budgets_found": json.dumps(
                            budgets_to_reset, indent=4, default=str
                        ),
                        "num_endusers_found": (
                            len(endusers_to_reset) if endusers_to_reset else 0
                        ),
                        "endusers_found": json.dumps(
                            endusers_to_reset, indent=4, default=str
                        ),
                        "num_endusers_updated": len(updated_endusers),
                        "endusers_updated": json.dumps(
                            updated_endusers, indent=4, default=str
                        ),
                        "num_endusers_failed": len(failed_endusers),
                        "endusers_failed": json.dumps(
                            failed_endusers, indent=4, default=str
                        ),
                    },
                )
            )
        except Exception as e:
            end_time = time.time()
            # NOTE(review): the success hook reports call_type
            # "reset_budget_budget_table" but this failure path reports
            # "reset_budget_endusers" — confirm the mismatch is intentional.
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    error=e,
                    call_type="reset_budget_endusers",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_budgets_found": (
                            len(budgets_to_reset) if budgets_to_reset else 0
                        ),
                        "budgets_found": json.dumps(
                            budgets_to_reset, indent=4, default=str
                        ),
                        "num_endusers_found": (
                            len(endusers_to_reset) if endusers_to_reset else 0
                        ),
                        "endusers_found": json.dumps(
                            endusers_to_reset, indent=4, default=str
                        ),
                    },
                )
            )
            verbose_proxy_logger.exception("Failed to reset budget for endusers: %s", e)

    async def reset_budget_for_litellm_keys(self) -> None:
        """
        Resets the budget for all the litellm keys

        Catches Exceptions and logs them
        """
        # NOTE(review): naive UTC here vs. timezone-aware now() in
        # reset_budget_for_litellm_budget_table — confirm intended.
        now = datetime.utcnow()
        start_time = time.time()
        keys_to_reset: Optional[List[LiteLLM_VerificationToken]] = None
        try:
            keys_to_reset = await self.prisma_client.get_data(
                table_name="key", query_type="find_all", expires=now, reset_at=now
            )
            verbose_proxy_logger.debug(
                "Keys to reset %s", json.dumps(keys_to_reset, indent=4, default=str)
            )
            updated_keys: List[LiteLLM_VerificationToken] = []
            failed_keys = []
            if keys_to_reset is not None and len(keys_to_reset) > 0:
                for key in keys_to_reset:
                    try:
                        updated_key = await ResetBudgetJob._reset_budget_for_key(
                            key=key, current_time=now
                        )
                        if updated_key is not None:
                            updated_keys.append(updated_key)
                        else:
                            failed_keys.append(
                                {"key": key, "error": "Returned None without exception"}
                            )
                    except Exception as e:
                        failed_keys.append({"key": key, "error": str(e)})
                        verbose_proxy_logger.exception(
                            "Failed to reset budget for key: %s", key
                        )

                verbose_proxy_logger.debug(
                    "Updated keys %s", json.dumps(updated_keys, indent=4, default=str)
                )

                if updated_keys:
                    await self.prisma_client.update_data(
                        query_type="update_many",
                        data_list=updated_keys,
                        table_name="key",
                    )

            end_time = time.time()
            if len(failed_keys) > 0:  # If any keys failed to reset
                raise Exception(
                    f"Failed to reset {len(failed_keys)} keys: {json.dumps(failed_keys, default=str)}"
                )

            # Fire-and-forget telemetry; not awaited.
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    call_type="reset_budget_keys",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
                        "keys_found": json.dumps(keys_to_reset, indent=4, default=str),
                        "num_keys_updated": len(updated_keys),
                        "keys_updated": json.dumps(updated_keys, indent=4, default=str),
                        "num_keys_failed": len(failed_keys),
                        "keys_failed": json.dumps(failed_keys, indent=4, default=str),
                    },
                )
            )
        except Exception as e:
            end_time = time.time()
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    error=e,
                    call_type="reset_budget_keys",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
                        "keys_found": json.dumps(keys_to_reset, indent=4, default=str),
                    },
                )
            )
            verbose_proxy_logger.exception("Failed to reset budget for keys: %s", e)

    async def reset_budget_for_litellm_users(self) -> None:
        """
        Resets the budget for all LiteLLM Internal Users if their budget has expired
        """
        now = datetime.utcnow()
        start_time = time.time()
        users_to_reset: Optional[List[LiteLLM_UserTable]] = None
        try:
            users_to_reset = await self.prisma_client.get_data(
                table_name="user", query_type="find_all", reset_at=now
            )
            updated_users: List[LiteLLM_UserTable] = []
            failed_users = []
            if users_to_reset is not None and len(users_to_reset) > 0:
                for user in users_to_reset:
                    try:
                        updated_user = await ResetBudgetJob._reset_budget_for_user(
                            user=user, current_time=now
                        )
                        if updated_user is not None:
                            updated_users.append(updated_user)
                        else:
                            failed_users.append(
                                {
                                    "user": user,
                                    "error": "Returned None without exception",
                                }
                            )
                    except Exception as e:
                        failed_users.append({"user": user, "error": str(e)})
                        verbose_proxy_logger.exception(
                            "Failed to reset budget for user: %s", user
                        )

                verbose_proxy_logger.debug(
                    "Updated users %s", json.dumps(updated_users, indent=4, default=str)
                )
                if updated_users:
                    await self.prisma_client.update_data(
                        query_type="update_many",
                        data_list=updated_users,
                        table_name="user",
                    )

            end_time = time.time()
            if len(failed_users) > 0:  # If any users failed to reset
                raise Exception(
                    f"Failed to reset {len(failed_users)} users: {json.dumps(failed_users, default=str)}"
                )

            # Fire-and-forget telemetry; not awaited.
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    call_type="reset_budget_users",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_users_found": len(users_to_reset) if users_to_reset else 0,
                        "users_found": json.dumps(
                            users_to_reset, indent=4, default=str
                        ),
                        "num_users_updated": len(updated_users),
                        "users_updated": json.dumps(
                            updated_users, indent=4, default=str
                        ),
                        "num_users_failed": len(failed_users),
                        "users_failed": json.dumps(failed_users, indent=4, default=str),
                    },
                )
            )
        except Exception as e:
            end_time = time.time()
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    error=e,
                    call_type="reset_budget_users",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_users_found": len(users_to_reset) if users_to_reset else 0,
                        "users_found": json.dumps(
                            users_to_reset, indent=4, default=str
                        ),
                    },
                )
            )
            verbose_proxy_logger.exception("Failed to reset budget for users: %s", e)

    async def reset_budget_for_litellm_teams(self) -> None:
        """
        Resets the budget for all LiteLLM Internal Teams if their budget has expired
        """
        now = datetime.utcnow()
        start_time = time.time()
        teams_to_reset: Optional[List[LiteLLM_TeamTable]] = None
        try:
            teams_to_reset = await self.prisma_client.get_data(
                table_name="team", query_type="find_all", reset_at=now
            )
            updated_teams: List[LiteLLM_TeamTable] = []
            failed_teams = []
            if teams_to_reset is not None and len(teams_to_reset) > 0:
                for team in teams_to_reset:
                    try:
                        updated_team = await ResetBudgetJob._reset_budget_for_team(
                            team=team, current_time=now
                        )
                        if updated_team is not None:
                            updated_teams.append(updated_team)
                        else:
                            failed_teams.append(
                                {
                                    "team": team,
                                    "error": "Returned None without exception",
                                }
                            )
                    except Exception as e:
                        failed_teams.append({"team": team, "error": str(e)})
                        verbose_proxy_logger.exception(
                            "Failed to reset budget for team: %s", team
                        )

                verbose_proxy_logger.debug(
                    "Updated teams %s", json.dumps(updated_teams, indent=4, default=str)
                )
                if updated_teams:
                    await self.prisma_client.update_data(
                        query_type="update_many",
                        data_list=updated_teams,
                        table_name="team",
                    )

            end_time = time.time()
            if len(failed_teams) > 0:  # If any teams failed to reset
                raise Exception(
                    f"Failed to reset {len(failed_teams)} teams: {json.dumps(failed_teams, default=str)}"
                )

            # Fire-and-forget telemetry; not awaited.
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    call_type="reset_budget_teams",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
                        "teams_found": json.dumps(
                            teams_to_reset, indent=4, default=str
                        ),
                        "num_teams_updated": len(updated_teams),
                        "teams_updated": json.dumps(
                            updated_teams, indent=4, default=str
                        ),
                        "num_teams_failed": len(failed_teams),
                        "teams_failed": json.dumps(failed_teams, indent=4, default=str),
                    },
                )
            )
        except Exception as e:
            end_time = time.time()
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    error=e,
                    call_type="reset_budget_teams",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
                        "teams_found": json.dumps(
                            teams_to_reset, indent=4, default=str
                        ),
                    },
                )
            )
            verbose_proxy_logger.exception("Failed to reset budget for teams: %s", e)

    @staticmethod
    async def _reset_budget_common(
        item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken],
        current_time: datetime,
        item_type: Literal["key", "team", "user"],
    ):
        """
        In-place, updates spend=0, and sets budget_reset_at to current_time + budget_duration

        Common logic for resetting budget for a team, user, or key

        NOTE(review): the body never reads ``current_time`` — the new reset
        date comes from get_budget_reset_time(), which anchors to the
        standardized reset schedule rather than "now + duration". Confirm the
        docstring/parameter are kept only for interface stability.
        """
        try:
            item.spend = 0.0
            if hasattr(item, "budget_duration") and item.budget_duration is not None:
                # Get standardized reset time based on budget duration
                from litellm.proxy.common_utils.timezone_utils import (
                    get_budget_reset_time,
                )

                item.budget_reset_at = get_budget_reset_time(
                    budget_duration=item.budget_duration
                )
            return item
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error resetting budget for %s: %s. Item: %s", item_type, e, item
            )
            raise e

    @staticmethod
    async def _reset_budget_for_team(
        team: LiteLLM_TeamTable, current_time: datetime
    ) -> Optional[LiteLLM_TeamTable]:
        # Mutates `team` in place via the common helper, then returns it.
        await ResetBudgetJob._reset_budget_common(
            item=team, current_time=current_time, item_type="team"
        )
        return team

    @staticmethod
    async def _reset_budget_for_user(
        user: LiteLLM_UserTable, current_time: datetime
    ) -> Optional[LiteLLM_UserTable]:
        # Mutates `user` in place via the common helper, then returns it.
        await ResetBudgetJob._reset_budget_common(
            item=user, current_time=current_time, item_type="user"
        )
        return user

    @staticmethod
    async def _reset_budget_for_enduser(
        enduser: LiteLLM_EndUserTable,
    ) -> Optional[LiteLLM_EndUserTable]:
        # End-users only get their spend zeroed; their reset schedule lives on
        # the linked budget row (see _reset_budget_reset_at_date).
        try:
            enduser.spend = 0.0
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error resetting budget for enduser: %s. Item: %s", e, enduser
            )
            raise e
        return enduser

    @staticmethod
    async def _reset_budget_reset_at_date(
        budget: LiteLLM_BudgetTableFull, current_time: datetime
    ) -> LiteLLM_BudgetTableFull:
        """Advance ``budget.budget_reset_at`` in place by the budget's duration."""
        try:
            if budget.budget_duration is not None:
                from litellm.litellm_core_utils.duration_parser import (
                    duration_in_seconds,
                )

                duration_s = duration_in_seconds(duration=budget.budget_duration)

                # Fallback for existing budgets that do not have a budget_reset_at date set, ensuring the duration is taken into account
                if (
                    budget.budget_reset_at is None
                    and budget.created_at + timedelta(seconds=duration_s) > current_time
                ):
                    budget.budget_reset_at = budget.created_at + timedelta(
                        seconds=duration_s
                    )
                else:
                    budget.budget_reset_at = current_time + timedelta(
                        seconds=duration_s
                    )
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error resetting budget_reset_at for budget: %s. Item: %s", e, budget
            )
            raise e
        return budget

    @staticmethod
    async def _reset_budget_for_key(
        key: LiteLLM_VerificationToken, current_time: datetime
    ) -> Optional[LiteLLM_VerificationToken]:
        # Mutates `key` in place via the common helper, then returns it.
        await ResetBudgetJob._reset_budget_common(
            item=key, current_time=current_time, item_type="key"
        )
        return key
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from litellm.exceptions import LITELLM_EXCEPTION_TYPES
|
||||
|
||||
|
||||
# Schema model for error payloads returned by the proxy. No docstring is added
# on purpose: a pydantic model docstring would leak into the generated OpenAPI
# description and change the published schema.
class ErrorResponse(BaseModel):
    # The body carries a single "detail" object; the nested "error" dict in
    # the example shows the expected message/type/param/code shape.
    detail: Dict[str, Any] = Field(
        ...,
        example={  # type: ignore
            "error": {
                "message": "Error message",
                "type": "error_type",
                "param": "error_param",
                "code": "error_code",
            }
        },
    )
|
||||
|
||||
|
||||
# Define a function to get the status code
|
||||
def get_status_code(exception):
    """Return the HTTP status code advertised for an exception class.

    Uses the class's own ``status_code`` attribute when present; otherwise
    falls back to 408 for Timeout, 503 for APIConnectionError, and 500.
    """
    _missing = object()  # sentinel so an explicit status_code=None is honored
    status = getattr(exception, "status_code", _missing)
    if status is not _missing:
        return status
    # Default status codes for exceptions without a status_code attribute
    fallback = {"Timeout": 408, "APIConnectionError": 503}
    return fallback.get(exception.__name__, 500)
|
||||
|
||||
|
||||
# Build the OpenAPI error-response map: one entry per distinct HTTP status
# across all LiteLLM exception types (later types with the same status code
# overwrite earlier ones, per dict-comprehension semantics).
ERROR_RESPONSES = {
    get_status_code(exc): {
        "model": ErrorResponse,
        "description": exc.__doc__ or exc.__name__,
    }
    for exc in LITELLM_EXCEPTION_TYPES
}

# Guarantee a generic 500 entry is always present.
ERROR_RESPONSES.setdefault(
    500,
    {
        "model": ErrorResponse,
        "description": "Internal Server Error",
    },
)
|
||||
@@ -0,0 +1,29 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.duration_parser import get_next_standardized_reset_time
|
||||
|
||||
|
||||
def get_budget_reset_timezone():
    """
    Return the budget-reset timezone configured in litellm_settings,
    defaulting to "UTC" when unset or falsy.

    litellm_settings values are set as attributes on the litellm module
    by proxy_server.py at startup (via setattr(litellm, key, value)).
    """
    configured = getattr(litellm, "timezone", None)
    if configured:
        return configured
    return "UTC"
|
||||
|
||||
|
||||
def get_budget_reset_time(budget_duration: str):
    """
    Compute the next standardized budget reset time for *budget_duration*,
    anchored to the configured budget-reset timezone (falls back to UTC).
    """
    return get_next_standardized_reset_time(
        duration=budget_duration,
        current_time=datetime.now(timezone.utc),
        timezone_str=get_budget_reset_timezone(),
    )
|
||||
Reference in New Issue
Block a user