chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Base class for in memory buffer for database transactions
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._service_logger import ServiceLogging
|
||||
|
||||
service_logger_obj = (
|
||||
ServiceLogging()
|
||||
) # used for tracking metrics for In memory buffer, redis buffer, pod lock manager
|
||||
from litellm.constants import (
|
||||
LITELLM_ASYNCIO_QUEUE_MAXSIZE,
|
||||
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT,
|
||||
MAX_SIZE_IN_MEMORY_QUEUE,
|
||||
)
|
||||
|
||||
|
||||
class BaseUpdateQueue:
    """Base class for in memory buffer for database transactions"""

    def __init__(self):
        # Bounded asyncio queue: producers block once it holds
        # LITELLM_ASYNCIO_QUEUE_MAXSIZE items instead of growing unbounded.
        self.update_queue = asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE)
        self.MAX_SIZE_IN_MEMORY_QUEUE = MAX_SIZE_IN_MEMORY_QUEUE
        # Sanity-check the two thresholds: the aggregation trigger must sit
        # below the hard queue cap, otherwise it can never fire.
        if MAX_SIZE_IN_MEMORY_QUEUE >= LITELLM_ASYNCIO_QUEUE_MAXSIZE:
            verbose_proxy_logger.warning(
                "Misconfigured queue thresholds: MAX_SIZE_IN_MEMORY_QUEUE (%d) >= LITELLM_ASYNCIO_QUEUE_MAXSIZE (%d). "
                "The spend aggregation check will never trigger because the asyncio.Queue blocks at %d items. "
                "Set MAX_SIZE_IN_MEMORY_QUEUE to a value less than LITELLM_ASYNCIO_QUEUE_MAXSIZE (recommended: 80%% of it).",
                MAX_SIZE_IN_MEMORY_QUEUE,
                LITELLM_ASYNCIO_QUEUE_MAXSIZE,
                LITELLM_ASYNCIO_QUEUE_MAXSIZE,
            )

    async def add_update(self, update):
        """Enqueue an update."""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)
        await self._emit_new_item_added_to_queue_event(
            queue_size=self.update_queue.qsize()
        )

    async def flush_all_updates_from_in_memory_queue(self):
        """Get all updates from the queue."""
        drained = []
        while not self.update_queue.empty():
            # Circuit breaker to ensure we're not stuck dequeuing updates. Protect CPU utilization
            if len(drained) < MAX_IN_MEMORY_QUEUE_FLUSH_COUNT:
                drained.append(await self.update_queue.get())
                continue
            verbose_proxy_logger.debug(
                "Max in memory queue flush count reached, stopping flush"
            )
            break
        return drained

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """placeholder, emit event when a new item is added to the queue"""
        pass
|
||||
@@ -0,0 +1,155 @@
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
|
||||
from litellm.proxy._types import BaseDailySpendTransaction
|
||||
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
|
||||
BaseUpdateQueue,
|
||||
service_logger_obj,
|
||||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
|
||||
class DailySpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for daily spend updates that should be committed to the database

    To add a new daily spend update transaction, use the following format:
    daily_spend_update_queue.add_update({
        "user1_date_api_key_model_custom_llm_provider": {
            "spend": 10,
            "prompt_tokens": 100,
            "completion_tokens": 100,
        }
    })

    Queue contains a list of daily spend update transactions

    eg
    queue = [
        {
            "user1_date_api_key_model_custom_llm_provider": {
                "spend": 10,
                "prompt_tokens": 100,
                "completion_tokens": 100,
                "api_requests": 100,
                "successful_requests": 100,
                "failed_requests": 100,
            }
        },
        {
            "user2_date_api_key_model_custom_llm_provider": {
                "spend": 10,
                "prompt_tokens": 100,
                "completion_tokens": 100,
                "api_requests": 100,
                "successful_requests": 100,
                "failed_requests": 100,
            }
        }
    ]
    """

    def __init__(self):
        super().__init__()
        # Re-declare the queue with a precise element type; same maxsize as the base.
        self.update_queue: asyncio.Queue[
            Dict[str, BaseDailySpendTransaction]
        ] = asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE)

    async def add_update(self, update: Dict[str, BaseDailySpendTransaction]):
        """Enqueue an update; aggregate the buffer in place once it hits capacity."""
        # Fix: delegate to the base class instead of calling put() directly.
        # The base add_update performs the same debug log + put, and also calls
        # _emit_new_item_added_to_queue_event — previously this override skipped
        # it, so the subclass's gauge emitter (below) was never invoked.
        await super().add_update(update)
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self):
        """
        Combine all updates in the queue into a single update.
        This is used to reduce the size of the in-memory queue.
        """
        updates: List[
            Dict[str, BaseDailySpendTransaction]
        ] = await self.flush_all_updates_from_in_memory_queue()
        aggregated_updates = self.get_aggregated_daily_spend_update_transactions(
            updates
        )
        # A single combined entry replaces everything that was flushed.
        await self.update_queue.put(aggregated_updates)

    async def flush_and_get_aggregated_daily_spend_update_transactions(
        self,
    ) -> Dict[str, BaseDailySpendTransaction]:
        """Get all updates from the queue and return all updates aggregated by daily_transaction_key. Works for both user and team spend updates."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        if len(updates) > 0:
            verbose_proxy_logger.info(
                "Spend tracking - flushed %d daily spend update items from in-memory queue",
                len(updates),
            )
        aggregated_daily_spend_update_transactions = (
            DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
                updates
            )
        )
        verbose_proxy_logger.debug(
            "Aggregated daily spend update transactions: %s",
            aggregated_daily_spend_update_transactions,
        )
        return aggregated_daily_spend_update_transactions

    @staticmethod
    def get_aggregated_daily_spend_update_transactions(
        updates: List[Dict[str, BaseDailySpendTransaction]],
    ) -> Dict[str, BaseDailySpendTransaction]:
        """Aggregate updates by daily_transaction_key."""
        aggregated_daily_spend_update_transactions: Dict[
            str, BaseDailySpendTransaction
        ] = {}
        for _update in updates:
            for _key, payload in _update.items():
                if _key in aggregated_daily_spend_update_transactions:
                    daily_transaction = aggregated_daily_spend_update_transactions[_key]
                    daily_transaction["spend"] += payload["spend"]
                    daily_transaction["prompt_tokens"] += payload["prompt_tokens"]
                    daily_transaction["completion_tokens"] += payload[
                        "completion_tokens"
                    ]
                    daily_transaction["api_requests"] += payload["api_requests"]
                    daily_transaction["successful_requests"] += payload[
                        "successful_requests"
                    ]
                    daily_transaction["failed_requests"] += payload["failed_requests"]

                    # Optional metrics may be missing or None in a payload;
                    # `or 0` normalizes both cases before summing.
                    daily_transaction["cache_read_input_tokens"] = (
                        payload.get("cache_read_input_tokens", 0) or 0
                    ) + daily_transaction.get("cache_read_input_tokens", 0)

                    daily_transaction["cache_creation_input_tokens"] = (
                        payload.get("cache_creation_input_tokens", 0) or 0
                    ) + daily_transaction.get("cache_creation_input_tokens", 0)

                else:
                    # deepcopy so later in-place += aggregation never mutates
                    # the caller's original payload dict.
                    aggregated_daily_spend_update_transactions[_key] = deepcopy(payload)
        return aggregated_daily_spend_update_transactions

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """Emit a fire-and-forget gauge event with the current in-memory queue size."""
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )
|
||||
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
from litellm._uuid import uuid
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
|
||||
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ProxyLogging = Any
|
||||
else:
|
||||
ProxyLogging = Any
|
||||
|
||||
|
||||
class PodLockManager:
    """
    Manager for acquiring and releasing locks for cron jobs using Redis.

    Ensures that only one pod can run a cron job at a time.
    """

    def __init__(self, redis_cache: Optional[RedisCache] = None):
        # Unique identity for this pod; stored as the lock value so that
        # ownership can be verified before releasing.
        self.pod_id = str(uuid.uuid4())
        self.redis_cache = redis_cache

    @staticmethod
    def get_redis_lock_key(cronjob_id: str) -> str:
        """Return the Redis key under which the lock for `cronjob_id` lives."""
        return f"cronjob_lock:{cronjob_id}"

    async def acquire_lock(
        self,
        cronjob_id: str,
    ) -> Optional[bool]:
        """
        Attempt to acquire the lock for a specific cron job using Redis.
        Uses the SET command with NX and EX options to ensure atomicity.

        Args:
            cronjob_id: The ID of the cron job to lock

        Returns:
            True when this pod holds the lock (newly acquired or already held),
            False when another pod holds it or an error occurred,
            None when no Redis cache is configured.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
            return None
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
            # and with an expiration (EX) to avoid deadlocks.
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
            acquired = await self.redis_cache.async_set_cache(
                lock_key,
                self.pod_id,
                nx=True,
                ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
            )
            if acquired:
                verbose_proxy_logger.info(
                    "Pod %s successfully acquired Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
                # Fix: emit the acquired-lock gauge event on fresh acquisition
                # too. Previously it was only emitted on the "already holds the
                # lock" path below, so first-time acquisitions never updated
                # the pod-lock gauge.
                self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                return True
            else:
                # Check if the current pod already holds the lock
                current_value = await self.redis_cache.async_get_cache(lock_key)
                if current_value is not None:
                    if isinstance(current_value, bytes):
                        current_value = current_value.decode("utf-8")
                    if current_value == self.pod_id:
                        verbose_proxy_logger.info(
                            "Pod %s already holds the Redis lock for cronjob_id=%s",
                            self.pod_id,
                            cronjob_id,
                        )
                        self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                        return True
                    else:
                        verbose_proxy_logger.info(
                            "Spend tracking - pod %s could not acquire lock for cronjob_id=%s, "
                            "held by pod %s. Spend updates in Redis will wait for the leader pod to commit.",
                            self.pod_id,
                            cronjob_id,
                            current_value,
                        )
                return False
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error acquiring Redis lock for {cronjob_id}: {e}"
            )
            return False

    async def release_lock(
        self,
        cronjob_id: str,
    ):
        """
        Release the lock if the current pod holds it.
        Uses get and delete commands to ensure that only the owner can release the lock.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
            return
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to release Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)

            # Only delete when the stored value matches our pod_id — otherwise
            # we would release a lock some other pod legitimately holds.
            current_value = await self.redis_cache.async_get_cache(lock_key)
            if current_value is not None:
                if isinstance(current_value, bytes):
                    current_value = current_value.decode("utf-8")
                if current_value == self.pod_id:
                    result = await self.redis_cache.async_delete_cache(lock_key)
                    if result == 1:
                        verbose_proxy_logger.info(
                            "Pod %s successfully released Redis lock for cronjob_id=%s",
                            self.pod_id,
                            cronjob_id,
                        )
                        self._emit_released_lock_event(
                            cronjob_id=cronjob_id,
                            pod_id=self.pod_id,
                        )
                    else:
                        verbose_proxy_logger.warning(
                            "Spend tracking - pod %s failed to release Redis lock for cronjob_id=%s. "
                            "Lock will expire after TTL=%ds.",
                            self.pod_id,
                            cronjob_id,
                            DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                        )
                else:
                    verbose_proxy_logger.debug(
                        "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
                        self.pod_id,
                        cronjob_id,
                        current_value,
                    )
            else:
                verbose_proxy_logger.debug(
                    "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
                    self.pod_id,
                    cronjob_id,
                )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error releasing Redis lock for {cronjob_id}: {e}"
            )

    @staticmethod
    def _emit_acquired_lock_event(cronjob_id: str, pod_id: str):
        """Fire-and-forget gauge event: lock held (value 1) by pod_id for cronjob_id."""
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type="_emit_acquired_lock_event",
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": 1,
                },
            )
        )

    @staticmethod
    def _emit_released_lock_event(cronjob_id: str, pod_id: str):
        """Fire-and-forget gauge event: lock released (value 0) by pod_id for cronjob_id."""
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type="_emit_released_lock_event",
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": 0,
                },
            )
        )
|
||||
@@ -0,0 +1,677 @@
|
||||
"""
|
||||
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
|
||||
|
||||
This is to prevent deadlocks and improve reliability
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import RedisCache
|
||||
from litellm.constants import (
|
||||
MAX_REDIS_BUFFER_DEQUEUE_COUNT,
|
||||
REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_UPDATE_BUFFER_KEY,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import (
|
||||
DailyAgentSpendTransaction,
|
||||
DailyEndUserSpendTransaction,
|
||||
DailyOrganizationSpendTransaction,
|
||||
DailyTagSpendTransaction,
|
||||
DailyTeamSpendTransaction,
|
||||
DailyUserSpendTransaction,
|
||||
DBSpendUpdateTransactions,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
|
||||
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
|
||||
DailySpendUpdateQueue,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.caching import (
|
||||
RedisPipelineLpopOperation,
|
||||
RedisPipelineRpushOperation,
|
||||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
else:
|
||||
PrismaClient = Any
|
||||
|
||||
|
||||
class RedisUpdateBuffer:
|
||||
"""
|
||||
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
|
||||
|
||||
This is to prevent deadlocks and improve reliability
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        redis_cache: Optional[RedisCache] = None,
    ):
        """
        Args:
            redis_cache: shared Redis client; when None, every buffer
                operation on this instance becomes a no-op.
        """
        self.redis_cache = redis_cache
|
||||
|
||||
@staticmethod
|
||||
def _should_commit_spend_updates_to_redis() -> bool:
|
||||
"""
|
||||
Checks if the Pod should commit spend updates to Redis
|
||||
|
||||
This setting enables buffering database transactions in Redis
|
||||
to improve reliability and reduce database contention
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
_use_redis_transaction_buffer: Optional[
|
||||
Union[bool, str]
|
||||
] = general_settings.get("use_redis_transaction_buffer", False)
|
||||
if isinstance(_use_redis_transaction_buffer, str):
|
||||
_use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
|
||||
if _use_redis_transaction_buffer is None:
|
||||
return False
|
||||
return _use_redis_transaction_buffer
|
||||
|
||||
async def _store_transactions_in_redis(
|
||||
self,
|
||||
transactions: Any,
|
||||
redis_key: str,
|
||||
service_type: ServiceTypes,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to store transactions in Redis and emit an event
|
||||
|
||||
Args:
|
||||
transactions: The transactions to store
|
||||
redis_key: The Redis key to store under
|
||||
service_type: The service type for event emission
|
||||
"""
|
||||
if transactions is None or len(transactions) == 0:
|
||||
return
|
||||
|
||||
list_of_transactions = [safe_dumps(transactions)]
|
||||
if self.redis_cache is None:
|
||||
return
|
||||
try:
|
||||
current_redis_buffer_size = await self.redis_cache.async_rpush(
|
||||
key=redis_key,
|
||||
values=list_of_transactions,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"Spend tracking - pushed spend updates to Redis buffer. "
|
||||
"redis_key=%s, buffer_size=%s",
|
||||
redis_key,
|
||||
current_redis_buffer_size,
|
||||
)
|
||||
await self._emit_new_item_added_to_redis_buffer_event(
|
||||
queue_size=current_redis_buffer_size,
|
||||
service=service_type,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"Spend tracking - failed to push spend updates to Redis (redis_key=%s). "
|
||||
"Error: %s",
|
||||
redis_key,
|
||||
str(e),
|
||||
)
|
||||
|
||||
    async def store_in_memory_spend_updates_in_redis(
        self,
        spend_update_queue: SpendUpdateQueue,
        daily_spend_update_queue: DailySpendUpdateQueue,
        daily_team_spend_update_queue: DailySpendUpdateQueue,
        daily_org_spend_update_queue: DailySpendUpdateQueue,
        daily_end_user_spend_update_queue: DailySpendUpdateQueue,
        daily_agent_spend_update_queue: DailySpendUpdateQueue,
        daily_tag_spend_update_queue: DailySpendUpdateQueue,
    ):
        """
        Stores the in-memory spend updates to Redis

        Stores the following in memory data structures in Redis:
            - SpendUpdateQueue - Key, User, Team, TeamMember, Org, EndUser Spend updates
            - DailySpendUpdateQueue - Daily Spend updates Aggregate view

        For SpendUpdateQueue:
            Each transaction is a dict stored as following:
            - key is the entity id
            - value is the spend amount

            ```
            Redis List:
            key_list_transactions:
            [
                "0929880201": 1.2,
                "0929880202": 0.01,
                "0929880203": 0.001,
            ]
            ```

        For DailySpendUpdateQueue:
            Each transaction is a Dict[str, DailyUserSpendTransaction] stored as following:
            - key is the daily_transaction_key
            - value is the DailyUserSpendTransaction

            ```
            Redis List:
            daily_spend_update_transactions:
            [
                {
                    "user_keyhash_1_model_1": {
                        "spend": 1.2,
                        "prompt_tokens": 1000,
                        "completion_tokens": 1000,
                        "api_requests": 1000,
                        "successful_requests": 1000,
                    },
                }
            ]
            ```
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug(
                "redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
            )
            return

        # Get all transactions
        # NOTE: each flush_* call drains its in-memory queue; from this point
        # on, the data only exists in local variables until pushed to Redis.
        db_spend_update_transactions = (
            await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
        )
        daily_spend_update_transactions = (
            await daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_team_spend_update_transactions = (
            await daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_org_spend_update_transactions = (
            await daily_org_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_end_user_spend_update_transactions = (
            await daily_end_user_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_agent_spend_update_transactions = (
            await daily_agent_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_tag_spend_update_transactions = (
            await daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )

        verbose_proxy_logger.debug(
            "ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
        )
        verbose_proxy_logger.debug(
            "ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions
        )

        # Build a list of rpush operations, skipping empty/None transaction sets
        # Each tuple: (payload, destination Redis list key, service type for the gauge event)
        _queue_configs: List[Tuple[Any, str, ServiceTypes]] = [
            (
                db_spend_update_transactions,
                REDIS_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_spend_update_transactions,
                REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_team_spend_update_transactions,
                REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_org_spend_update_transactions,
                REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_ORG_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_end_user_spend_update_transactions,
                REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_END_USER_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_agent_spend_update_transactions,
                REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_AGENT_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_tag_spend_update_transactions,
                REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_TAG_SPEND_UPDATE_QUEUE,
            ),
        ]

        rpush_list: List[RedisPipelineRpushOperation] = []
        # service_types stays index-aligned with rpush_list so pipeline results
        # can be matched back to the right gauge below.
        service_types: List[ServiceTypes] = []
        for transactions, redis_key, service_type in _queue_configs:
            if transactions is None or len(transactions) == 0:
                continue
            rpush_list.append(
                RedisPipelineRpushOperation(
                    key=redis_key,
                    values=[safe_dumps(transactions)],
                )
            )
            service_types.append(service_type)

        if len(rpush_list) == 0:
            return

        # Single pipeline round-trip for every non-empty buffer.
        result_lengths = await self.redis_cache.async_rpush_pipeline(
            rpush_list=rpush_list,
        )

        # Emit gauge events for each queue
        for i, queue_size in enumerate(result_lengths):
            if i < len(service_types):
                await self._emit_new_item_added_to_redis_buffer_event(
                    queue_size=queue_size,
                    service=service_types[i],
                )
|
||||
|
||||
@staticmethod
|
||||
def _number_of_transactions_to_store_in_redis(
|
||||
db_spend_update_transactions: DBSpendUpdateTransactions,
|
||||
) -> int:
|
||||
"""
|
||||
Gets the number of transactions to store in Redis
|
||||
"""
|
||||
num_transactions = 0
|
||||
for v in db_spend_update_transactions.values():
|
||||
if isinstance(v, dict):
|
||||
num_transactions += len(v)
|
||||
return num_transactions
|
||||
|
||||
@staticmethod
|
||||
def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Removes the specified prefix from the keys of a dictionary.
|
||||
"""
|
||||
return {key.replace(prefix, "", 1): value for key, value in data.items()}
|
||||
|
||||
    async def get_all_update_transactions_from_redis_buffer(
        self,
    ) -> Optional[DBSpendUpdateTransactions]:
        """
        Gets all the update transactions from Redis

        On Redis we store a list of transactions as a JSON string

        eg.
        [
            DBSpendUpdateTransactions(
                user_list_transactions={
                    "user_id_1": 1.2,
                    "user_id_2": 0.01,
                },
                end_user_list_transactions={},
                key_list_transactions={
                    "0929880201": 1.2,
                    "0929880202": 0.01,
                },
                team_list_transactions={},
                team_member_list_transactions={},
                org_list_transactions={},
            ),
            DBSpendUpdateTransactions(
                user_list_transactions={
                    "user_id_3": 1.2,
                    "user_id_4": 0.01,
                },
                end_user_list_transactions={},
                key_list_transactions={
                    "key_id_1": 1.2,
                    "key_id_2": 0.01,
                },
                team_list_transactions={},
                team_member_list_transactions={},
                org_list_transactions={},
            ),
        ]

        Returns:
            One combined DBSpendUpdateTransactions covering every popped batch,
            or None when Redis is unconfigured or the buffer is empty.
        """
        if self.redis_cache is None:
            return None
        # LPOP removes the batches from Redis — once popped, the caller is
        # responsible for committing them to the database.
        list_of_transactions = await self.redis_cache.async_lpop(
            key=REDIS_UPDATE_BUFFER_KEY,
            count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
        )
        if list_of_transactions is None:
            return None

        verbose_proxy_logger.info(
            "Spend tracking - popped %d spend update batches from Redis buffer (key=%s). "
            "These items are now removed from Redis and must be committed to DB.",
            len(list_of_transactions) if isinstance(list_of_transactions, list) else 1,
            REDIS_UPDATE_BUFFER_KEY,
        )

        # Parse the list of transactions from JSON strings
        parsed_transactions = self._parse_list_of_transactions(list_of_transactions)

        # If there are no transactions, return None
        if len(parsed_transactions) == 0:
            return None

        # Combine all transactions into a single transaction
        combined_transaction = self._combine_list_of_transactions(parsed_transactions)

        return combined_transaction
|
||||
|
||||
    async def get_all_transactions_from_redis_buffer_pipeline(
        self,
    ) -> Tuple[
        Optional[DBSpendUpdateTransactions],
        Optional[Dict[str, DailyUserSpendTransaction]],
        Optional[Dict[str, DailyTeamSpendTransaction]],
        Optional[Dict[str, DailyOrganizationSpendTransaction]],
        Optional[Dict[str, DailyEndUserSpendTransaction]],
        Optional[Dict[str, DailyAgentSpendTransaction]],
        Optional[Dict[str, DailyTagSpendTransaction]],
    ]:
        """
        Drains all 7 Redis buffer queues in a single pipeline round-trip.

        Returns a 7-tuple of parsed results in this order:
            0: DBSpendUpdateTransactions
            1: daily user spend
            2: daily team spend
            3: daily org spend
            4: daily end-user spend
            5: daily agent spend
            6: daily tag spend
        """
        if self.redis_cache is None:
            return None, None, None, None, None, None, None

        # One LPOP per buffer key. The order of this list MUST match the slot
        # order documented above — results are matched by position below.
        lpop_list: List[RedisPipelineLpopOperation] = [
            RedisPipelineLpopOperation(
                key=REDIS_UPDATE_BUFFER_KEY, count=MAX_REDIS_BUFFER_DEQUEUE_COUNT
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
        ]

        raw_results = await self.redis_cache.async_lpop_pipeline(lpop_list=lpop_list)

        # Pad with None if pipeline returned fewer results than expected
        while len(raw_results) < 7:
            raw_results.append(None)

        # Slot 0: DBSpendUpdateTransactions
        db_spend: Optional[DBSpendUpdateTransactions] = None
        if raw_results[0] is not None:
            parsed = self._parse_list_of_transactions(raw_results[0])
            if len(parsed) > 0:
                db_spend = self._combine_list_of_transactions(parsed)

        # Slots 1-6: daily spend categories
        # Each raw result is a list of JSON-encoded Dict[str, ...] batches;
        # decode then merge them by daily_transaction_key.
        daily_results: List[Optional[Dict[str, Any]]] = []
        for slot in range(1, 7):
            if raw_results[slot] is None:
                daily_results.append(None)
            else:
                list_of_daily = [json.loads(t) for t in raw_results[slot]]  # type: ignore
                aggregated = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
                    list_of_daily
                )
                daily_results.append(aggregated)

        return (
            db_spend,
            cast(Optional[Dict[str, DailyUserSpendTransaction]], daily_results[0]),
            cast(Optional[Dict[str, DailyTeamSpendTransaction]], daily_results[1]),
            cast(
                Optional[Dict[str, DailyOrganizationSpendTransaction]], daily_results[2]
            ),
            cast(Optional[Dict[str, DailyEndUserSpendTransaction]], daily_results[3]),
            cast(Optional[Dict[str, DailyAgentSpendTransaction]], daily_results[4]),
            cast(Optional[Dict[str, DailyTagSpendTransaction]], daily_results[5]),
        )
|
||||
|
||||
async def get_all_daily_spend_update_transactions_from_redis_buffer(
|
||||
self,
|
||||
) -> Optional[Dict[str, DailyUserSpendTransaction]]:
|
||||
"""
|
||||
Gets all the daily spend update transactions from Redis
|
||||
"""
|
||||
if self.redis_cache is None:
|
||||
return None
|
||||
list_of_transactions = await self.redis_cache.async_lpop(
|
||||
key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
|
||||
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
|
||||
)
|
||||
if list_of_transactions is None:
|
||||
return None
|
||||
list_of_daily_spend_update_transactions = [
|
||||
json.loads(transaction) for transaction in list_of_transactions
|
||||
]
|
||||
return cast(
|
||||
Dict[str, DailyUserSpendTransaction],
|
||||
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
|
||||
list_of_daily_spend_update_transactions
|
||||
),
|
||||
)
|
||||
|
||||
async def get_all_daily_team_spend_update_transactions_from_redis_buffer(
    self,
) -> Optional[Dict[str, DailyTeamSpendTransaction]]:
    """Drain and aggregate the daily per-team spend updates buffered in Redis.

    Returns None when no Redis cache is configured or the buffer is empty.
    """
    if self.redis_cache is None:
        # No Redis configured -> nothing buffered to drain.
        return None
    raw_items = await self.redis_cache.async_lpop(
        key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
        count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
    )
    if raw_items is None:
        return None
    decoded = [json.loads(raw) for raw in raw_items]
    aggregated = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
        decoded
    )
    return cast(Dict[str, DailyTeamSpendTransaction], aggregated)
|
||||
|
||||
async def get_all_daily_org_spend_update_transactions_from_redis_buffer(
    self,
) -> Optional[Dict[str, DailyOrganizationSpendTransaction]]:
    """Drain and aggregate the daily per-organization spend updates buffered in Redis.

    Returns None when no Redis cache is configured or the buffer is empty.
    """
    if self.redis_cache is None:
        # No Redis configured -> nothing buffered to drain.
        return None
    raw_items = await self.redis_cache.async_lpop(
        key=REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
        count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
    )
    if raw_items is None:
        return None
    decoded = [json.loads(raw) for raw in raw_items]
    aggregated = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
        decoded
    )
    return cast(Dict[str, DailyOrganizationSpendTransaction], aggregated)
|
||||
|
||||
async def get_all_daily_end_user_spend_update_transactions_from_redis_buffer(
    self,
) -> Optional[Dict[str, DailyEndUserSpendTransaction]]:
    """Drain and aggregate the daily per-end-user spend updates buffered in Redis.

    Returns None when no Redis cache is configured or the buffer is empty.
    """
    if self.redis_cache is None:
        # No Redis configured -> nothing buffered to drain.
        return None
    raw_items = await self.redis_cache.async_lpop(
        key=REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
        count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
    )
    if raw_items is None:
        return None
    decoded = [json.loads(raw) for raw in raw_items]
    aggregated = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
        decoded
    )
    return cast(Dict[str, DailyEndUserSpendTransaction], aggregated)
|
||||
|
||||
async def get_all_daily_agent_spend_update_transactions_from_redis_buffer(
    self,
) -> Optional[Dict[str, DailyAgentSpendTransaction]]:
    """Drain and aggregate the daily per-agent spend updates buffered in Redis.

    Returns None when no Redis cache is configured or the buffer is empty.
    """
    if self.redis_cache is None:
        # No Redis configured -> nothing buffered to drain.
        return None
    raw_items = await self.redis_cache.async_lpop(
        key=REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
        count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
    )
    if raw_items is None:
        return None
    decoded = [json.loads(raw) for raw in raw_items]
    aggregated = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
        decoded
    )
    return cast(Dict[str, DailyAgentSpendTransaction], aggregated)
|
||||
|
||||
async def get_all_daily_tag_spend_update_transactions_from_redis_buffer(
    self,
) -> Optional[Dict[str, DailyTagSpendTransaction]]:
    """Drain and aggregate the daily per-tag spend updates buffered in Redis.

    Returns None when no Redis cache is configured or the buffer is empty.
    """
    if self.redis_cache is None:
        # No Redis configured -> nothing buffered to drain.
        return None
    raw_items = await self.redis_cache.async_lpop(
        key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
        count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
    )
    if raw_items is None:
        return None
    decoded = [json.loads(raw) for raw in raw_items]
    aggregated = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
        decoded
    )
    return cast(Dict[str, DailyTagSpendTransaction], aggregated)
|
||||
|
||||
@staticmethod
def _parse_list_of_transactions(
    list_of_transactions: Union[Any, List[Any]],
) -> List[DBSpendUpdateTransactions]:
    """Decode the JSON transaction payload(s) popped from Redis.

    Redis may hand back either a single raw entry or a list of raw
    entries; both shapes are normalized to a list of parsed transactions.
    """
    if not isinstance(list_of_transactions, list):
        return [json.loads(list_of_transactions)]
    return [json.loads(raw) for raw in list_of_transactions]
|
||||
|
||||
@staticmethod
def _combine_list_of_transactions(
    list_of_transactions: List[DBSpendUpdateTransactions],
) -> DBSpendUpdateTransactions:
    """Fold a list of transaction objects into one combined DBSpendUpdateTransactions.

    Per-entity amounts are summed across all input transactions for every
    transaction category (user, end-user, key, team, team-member, org, tag,
    agent).
    """
    # Every transaction category that gets merged.
    transaction_fields = (
        "user_list_transactions",
        "end_user_list_transactions",
        "key_list_transactions",
        "team_list_transactions",
        "team_member_list_transactions",
        "org_list_transactions",
        "tag_list_transactions",
        "agent_list_transactions",
    )

    # Start from an all-empty transaction object and fold each input into it.
    merged = DBSpendUpdateTransactions(
        user_list_transactions={},
        end_user_list_transactions={},
        key_list_transactions={},
        team_list_transactions={},
        team_member_list_transactions={},
        org_list_transactions={},
        tag_list_transactions={},
        agent_list_transactions={},
    )

    for txn in list_of_transactions:
        for field_name in transaction_fields:
            entries = txn.get(field_name)
            if not entries:
                # Missing or empty category -> nothing to merge.
                continue
            target = merged[field_name]  # type: ignore
            for entity_id, amount in entries.items():  # type: ignore
                target[entity_id] = target.get(entity_id, 0) + amount  # type: ignore

    return merged
|
||||
|
||||
async def _emit_new_item_added_to_redis_buffer_event(
    self,
    service: ServiceTypes,
    queue_size: int,
):
    """Record the current Redis buffer size as a gauge metric.

    Fire-and-forget: the success hook is scheduled as a task so the caller
    is never blocked on metric emission.
    """
    gauge_metadata = {
        "gauge_labels": service,
        "gauge_value": queue_size,
    }
    hook_coro = service_logger_obj.async_service_success_hook(
        service=service,
        duration=0,
        call_type="_emit_new_item_added_to_queue_event",
        event_metadata=gauge_metadata,
    )
    asyncio.create_task(hook_coro)
|
||||
@@ -0,0 +1,172 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import RedisCache
|
||||
from litellm.constants import (
|
||||
SPEND_LOG_CLEANUP_BATCH_SIZE,
|
||||
SPEND_LOG_CLEANUP_JOB_NAME,
|
||||
SPEND_LOG_RUN_LOOPS,
|
||||
)
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
class SpendLogCleanup:
    """
    Handles cleaning up old spend logs based on maximum retention period.
    Deletes logs in batches to prevent timeouts.
    Uses PodLockManager to ensure only one pod runs cleanup in multi-pod deployments.
    """

    def __init__(self, general_settings=None, redis_cache: Optional[RedisCache] = None):
        # NOTE(review): `redis_cache` is accepted but never used; the pod lock
        # manager below is taken from proxy_logging_obj instead — confirm the
        # parameter is still needed.
        self.batch_size = SPEND_LOG_CLEANUP_BATCH_SIZE
        # Populated by _should_delete_spend_logs(); stays None until a valid
        # retention setting has been parsed.
        self.retention_seconds: Optional[int] = None
        # Imported lazily to avoid a circular import with proxy_server.
        from litellm.proxy.proxy_server import general_settings as default_settings

        self.general_settings = general_settings or default_settings
        from litellm.proxy.proxy_server import proxy_logging_obj

        pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager
        self.pod_lock_manager = pod_lock_manager
        verbose_proxy_logger.info(
            f"SpendLogCleanup initialized with batch size: {self.batch_size}"
        )

    def _should_delete_spend_logs(self) -> bool:
        """
        Determines if logs should be deleted based on the max retention period in settings.

        Side effect: on success, stores the parsed retention period (in
        seconds) in self.retention_seconds.
        """
        retention_setting = self.general_settings.get(
            "maximum_spend_logs_retention_period"
        )
        verbose_proxy_logger.info(f"Checking retention setting: {retention_setting}")

        if retention_setting is None:
            verbose_proxy_logger.info("No retention setting found")
            return False

        try:
            if isinstance(retention_setting, int):
                # Bare integers are tolerated for backwards compatibility but
                # interpreted as days, e.g. 3 -> "3d".
                verbose_proxy_logger.warning(
                    f"maximum_spend_logs_retention_period is an integer ({retention_setting}); treating as days. "
                    "Use a string like '3d' to be explicit."
                )
                retention_setting = f"{retention_setting}d"
            self.retention_seconds = duration_in_seconds(retention_setting)
            verbose_proxy_logger.info(
                f"Retention period set to {self.retention_seconds} seconds"
            )
            return True
        except ValueError as e:
            # Invalid duration string (e.g. "3x") -> skip cleanup entirely.
            verbose_proxy_logger.warning(
                f"Invalid maximum_spend_logs_retention_period value: {retention_setting}, error: {str(e)}"
            )
            return False

    async def _delete_old_logs(
        self, prisma_client: PrismaClient, cutoff_date: datetime
    ) -> int:
        """
        Helper method to delete old logs in batches.
        Returns the total number of logs deleted.

        Deletes rows from LiteLLM_SpendLogs whose startTime is before
        cutoff_date, at most self.batch_size per round trip and at most
        SPEND_LOG_RUN_LOOPS batches per invocation; remaining rows are left
        for the next scheduled run.
        """
        total_deleted = 0
        run_count = 0
        while True:
            # NOTE(review): "1,00,000" is a hardcoded figure in the message —
            # the actual cap is SPEND_LOG_RUN_LOOPS * batch_size. Also, the
            # strict `>` lets one extra batch run beyond SPEND_LOG_RUN_LOOPS;
            # confirm both are intentional.
            if run_count > SPEND_LOG_RUN_LOOPS:
                verbose_proxy_logger.info(
                    "Max logs deleted - 1,00,000, rest of the logs will be deleted in next run"
                )
                break
            # Step 1: Find logs and delete them in one go without fetching to application
            # Delete in batches, limited by self.batch_size
            deleted_count = await prisma_client.db.execute_raw(
                """
                DELETE FROM "LiteLLM_SpendLogs"
                WHERE "request_id" IN (
                    SELECT "request_id" FROM "LiteLLM_SpendLogs"
                    WHERE "startTime" < $1::timestamptz
                    LIMIT $2
                )
                """,
                cutoff_date,
                self.batch_size,
            )
            verbose_proxy_logger.info(f"Deleted {deleted_count} logs in this batch")

            if deleted_count == 0:
                # Nothing left older than the cutoff — done.
                verbose_proxy_logger.info(
                    f"No more logs to delete. Total deleted: {total_deleted}"
                )
                break

            total_deleted += deleted_count
            run_count += 1

            # Add a small sleep to prevent overwhelming the database
            await asyncio.sleep(0.1)

        return total_deleted

    async def cleanup_old_spend_logs(self, prisma_client: PrismaClient) -> None:
        """
        Main cleanup function. Deletes old spend logs in batches.
        If pod_lock_manager is available, ensures only one pod runs cleanup.
        If no pod_lock_manager, runs cleanup without distributed locking.
        """
        lock_acquired = False
        try:
            verbose_proxy_logger.info(f"Cleanup job triggered at {datetime.now()}")

            if not self._should_delete_spend_logs():
                # No (valid) retention setting configured — nothing to do.
                return

            if self.retention_seconds is None:
                # Defensive: _should_delete_spend_logs() returning True should
                # have set this.
                verbose_proxy_logger.error(
                    "Retention seconds is None, cannot proceed with cleanup"
                )
                return

            # If we have a pod lock manager, try to acquire the lock
            if self.pod_lock_manager and self.pod_lock_manager.redis_cache:
                # acquire_lock may return None; normalize to a strict bool so
                # the finally block can reason about it.
                lock_acquired = (
                    await self.pod_lock_manager.acquire_lock(
                        cronjob_id=SPEND_LOG_CLEANUP_JOB_NAME,
                    )
                    or False
                )
                verbose_proxy_logger.info(
                    f"Lock acquisition attempt: {'successful' if lock_acquired else 'failed'} at {datetime.now()}"
                )

                if not lock_acquired:
                    verbose_proxy_logger.info("Another pod is already running cleanup")
                    return

            # Everything with startTime before this instant - retention window
            # is eligible for deletion. Uses timezone-aware UTC.
            cutoff_date = datetime.now(timezone.utc) - timedelta(
                seconds=float(self.retention_seconds)
            )
            verbose_proxy_logger.info(
                f"Deleting logs older than {cutoff_date.isoformat()}"
            )

            # Perform the actual deletion
            total_deleted = await self._delete_old_logs(prisma_client, cutoff_date)
            verbose_proxy_logger.info(f"Deleted {total_deleted} logs")

        except Exception as e:
            verbose_proxy_logger.error(f"Error during cleanup: {str(e)}")
            return  # Return after error handling
        finally:
            # Only release the lock if it was actually acquired
            if (
                lock_acquired
                and self.pod_lock_manager
                and self.pod_lock_manager.redis_cache
            ):
                await self.pod_lock_manager.release_lock(
                    cronjob_id=SPEND_LOG_CLEANUP_JOB_NAME
                )
                verbose_proxy_logger.info("Released cleanup lock")
|
||||
@@ -0,0 +1,246 @@
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
|
||||
from litellm.proxy._types import (
|
||||
DBSpendUpdateTransactions,
|
||||
Litellm_EntityType,
|
||||
SpendUpdateQueueItem,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
|
||||
BaseUpdateQueue,
|
||||
service_logger_obj,
|
||||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
|
||||
class SpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for spend updates that should be committed to the database.

    Each queued item is a ``SpendUpdateQueueItem`` carrying an entity_type,
    entity_id and response_cost.  When the queue grows past
    ``MAX_SIZE_IN_MEMORY_QUEUE`` the entries are aggregated in place to bound
    memory usage.
    """

    def __init__(self):
        super().__init__()
        # Re-create the queue with an explicit item type for static checkers;
        # the maxsize mirrors the base class configuration.
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue(
            maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE
        )

    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """Flush all updates from the queue and return all updates aggregated by entity type."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        if len(updates) > 0:
            verbose_proxy_logger.info(
                "Spend tracking - flushed %d spend update items from in-memory queue",
                len(updates),
            )
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)

    async def add_update(self, update: SpendUpdateQueueItem):
        """Enqueue an update to the spend update queue.

        NOTE(review): unlike ``BaseUpdateQueue.add_update`` this override does
        not emit the queue-size gauge event — confirm that is intentional.
        """
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

        # if the queue is full, aggregate the updates to bound memory usage
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self):
        """Concatenate all updates in the queue to reduce the size of in-memory queue"""
        updates: List[
            SpendUpdateQueueItem
        ] = await self.flush_all_updates_from_in_memory_queue()
        aggregated_updates = self._get_aggregated_spend_update_queue_item(updates)
        for update in aggregated_updates:
            await self.update_queue.put(update)
        return

    def _get_aggregated_spend_update_queue_item(
        self, updates: List[SpendUpdateQueueItem]
    ) -> List[SpendUpdateQueueItem]:
        """
        Reduce the in-memory queue size by summing ``response_cost`` for
        entries that share the same ``entity_type`` + ``entity_id``.

        e.g. two updates for user "123" with response_cost 100 and 200 become
        a single update for user "123" with response_cost 300.
        """
        verbose_proxy_logger.debug(
            "Aggregating spend updates, current queue size: %s",
            self.update_queue.qsize(),
        )
        # Key = "entity_type:entity_id" -> combined SpendUpdateQueueItem
        _in_memory_map: Dict[str, SpendUpdateQueueItem] = {}
        for update in updates:
            _key = f"{update.get('entity_type')}:{update.get('entity_id')}"
            if _key not in _in_memory_map:
                # avoid mutating caller-owned dicts while aggregating queue entries
                _in_memory_map[_key] = update.copy()
            else:
                current_cost = _in_memory_map[_key].get("response_cost", 0) or 0
                update_cost = update.get("response_cost", 0) or 0
                _in_memory_map[_key]["response_cost"] = current_cost + update_cost

        # Insertion order of the map preserves first-seen order of each key.
        aggregated_spend_updates: List[SpendUpdateQueueItem] = list(
            _in_memory_map.values()
        )
        verbose_proxy_logger.debug(
            "Aggregated spend updates: %s", aggregated_spend_updates
        )
        return aggregated_spend_updates

    def get_aggregated_db_spend_update_transactions(
        self, updates: List[SpendUpdateQueueItem]
    ) -> DBSpendUpdateTransactions:
        """Aggregate updates by entity type into a DBSpendUpdateTransactions object.

        Updates with a missing or unknown entity_type are skipped (logged at
        debug level).  Costs for the same entity_id are summed.
        """
        # Initialize all transaction lists as empty dicts
        db_spend_update_transactions = DBSpendUpdateTransactions(
            user_list_transactions={},
            end_user_list_transactions={},
            key_list_transactions={},
            team_list_transactions={},
            team_member_list_transactions={},
            org_list_transactions={},
            tag_list_transactions={},
            agent_list_transactions={},
        )

        # Map entity types to their corresponding transaction dictionary keys
        entity_type_to_dict_key = {
            Litellm_EntityType.USER: "user_list_transactions",
            Litellm_EntityType.END_USER: "end_user_list_transactions",
            Litellm_EntityType.KEY: "key_list_transactions",
            Litellm_EntityType.TEAM: "team_list_transactions",
            Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
            Litellm_EntityType.ORGANIZATION: "org_list_transactions",
            Litellm_EntityType.TAG: "tag_list_transactions",
            Litellm_EntityType.AGENT: "agent_list_transactions",
        }

        for update in updates:
            entity_type = update.get("entity_type")
            entity_id = update.get("entity_id") or ""
            response_cost = update.get("response_cost") or 0

            if entity_type is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is None",
                    update,
                )
                continue

            dict_key = entity_type_to_dict_key.get(entity_type)
            if dict_key is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
                    update,
                )
                continue  # Skip unknown entity types

            # dict_key is always one of the literal keys of
            # DBSpendUpdateTransactions, so direct indexing replaces the
            # previous eight-branch if/elif chain.
            transactions_dict = db_spend_update_transactions[dict_key]  # type: ignore

            if transactions_dict is None:
                transactions_dict = {}
                db_spend_update_transactions[dict_key] = transactions_dict  # type: ignore

            transactions_dict[entity_id] = (
                transactions_dict.get(entity_id, 0) + response_cost
            )

        return db_spend_update_transactions

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """Emit the in-memory queue size as a gauge metric (fire-and-forget)."""
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )
|
||||
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
In-memory buffer for tool registry upserts.
|
||||
|
||||
Unlike SpendUpdateQueue (which aggregates increments), ToolDiscoveryQueue
|
||||
uses set-deduplication: each unique tool_name is only queued once per flush
|
||||
cycle (~30s). The seen-set is cleared on every flush so that call_count
|
||||
increments in subsequent cycles rather than stopping after the first flush.
|
||||
"""
|
||||
|
||||
from typing import List, Set
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ToolDiscoveryQueueItem
|
||||
|
||||
|
||||
class ToolDiscoveryQueue:
    """
    In-memory buffer for tool registry upserts.

    Deduplicates by tool_name within each flush cycle: a tool is only queued
    once per ~30s batch, so call_count increments once per flush cycle the
    tool appears in (not once per invocation, but not once per pod lifetime
    either). The seen-set is cleared on flush so subsequent batches can
    re-count the same tool.
    """

    def __init__(self) -> None:
        # Names already queued in the current flush cycle.
        self._seen_tool_names: Set[str] = set()
        # Items waiting for the next flush, in arrival order.
        self._pending: List[ToolDiscoveryQueueItem] = []

    def add_update(self, item: ToolDiscoveryQueueItem) -> None:
        """Enqueue a tool discovery item if tool_name has not been seen before."""
        name = item.get("tool_name", "")
        # Ignore items without a usable tool name.
        if not name:
            return
        if name in self._seen_tool_names:
            # Already queued this cycle — dedupe.
            verbose_proxy_logger.debug(
                "ToolDiscoveryQueue: skipping already-seen tool %s", name
            )
            return
        self._seen_tool_names.add(name)
        self._pending.append(item)
        verbose_proxy_logger.debug(
            "ToolDiscoveryQueue: queued new tool %s (origin=%s)",
            name,
            item.get("origin"),
        )

    def flush(self) -> List[ToolDiscoveryQueueItem]:
        """Return and clear all pending items.

        Also resets the seen-set so the next flush cycle can re-count the
        same tools.
        """
        drained = self._pending
        self._pending = []
        self._seen_tool_names.clear()
        return drained
|
||||
Reference in New Issue
Block a user