Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/integrations/gcs_bucket/gcs_bucket.py
2026-03-26 16:04:46 +08:00

420 lines
15 KiB
Python

import asyncio
import hashlib
import json
import os
import time
from litellm._uuid import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from urllib.parse import quote
from litellm._logging import verbose_logger
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.proxy._types import CommonProxyErrors
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardLoggingPayload
# VertexBase is imported only while type checking; at runtime the name is
# bound to ``Any`` instead, so the vertex module is not imported from here.
if TYPE_CHECKING:
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
else:
    VertexBase = Any
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
def __init__(self, bucket_name: Optional[str] = None) -> None:
    """
    Set up the GCS bucket logger: batch settings, the log queue and the
    background periodic-flush task.

    Batch settings are read from the environment:
        - ``GCS_BATCH_SIZE`` (falls back to ``GCS_DEFAULT_BATCH_SIZE``)
        - ``GCS_FLUSH_INTERVAL`` (falls back to ``GCS_DEFAULT_FLUSH_INTERVAL_SECONDS``)
        - ``GCS_USE_BATCHED_LOGGING`` (falls back to ``GCS_DEFAULT_USE_BATCHED_LOGGING``);
          only the case-insensitive string "true" enables batching.

    Args:
        bucket_name: Optional bucket name, forwarded to the base initializer.

    Raises:
        ValueError: If the proxy is not running as a premium user.
    """
    from litellm.proxy.proxy_server import premium_user

    # Fail fast: reject non-premium users before any queue or background task
    # exists, so a rejected construction leaves nothing running behind it.
    if premium_user is not True:
        raise ValueError(
            f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
        )

    super().__init__(bucket_name=bucket_name)

    self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
    self.flush_interval = int(
        os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
    )
    self.use_batched_logging = (
        os.getenv(
            "GCS_USE_BATCHED_LOGGING", str(GCS_DEFAULT_USE_BATCHED_LOGGING).lower()
        ).lower()
        == "true"
    )

    self.flush_lock = asyncio.Lock()
    # NOTE(review): the base initializer runs a second time to receive the
    # batch settings. ``bucket_name`` is passed again so that this second run
    # cannot replace an explicitly supplied bucket with a fallback value -
    # confirm against GCSBucketBase.__init__.
    super().__init__(
        bucket_name=bucket_name,
        flush_lock=self.flush_lock,
        batch_size=self.batch_size,
        flush_interval=self.flush_interval,
    )

    # Assigned after both base initializer calls; presumably the base class
    # sets its own ``log_queue`` of a different type (hence the type ignore).
    self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue(  # type: ignore[assignment]
        maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE
    )

    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    self._periodic_flush_task = asyncio.create_task(self.periodic_flush())
    AdditionalLoggingUtils.__init__(self)
#### ASYNC ####
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Queue the standard logging payload of a successful call for upload to GCS.

    The payload is read from ``kwargs["standard_logging_object"]``. When the
    queue is full, one flush is awaited before the new item is enqueued.

    Raises:
        ValueError: If the proxy is not running as a premium user. Any other
            error raised while queuing is caught and logged instead.
    """
    from litellm.proxy.proxy_server import premium_user

    if premium_user is not True:
        raise ValueError(
            f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
        )

    try:
        verbose_logger.debug(
            "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
            kwargs,
            response_obj,
        )

        standard_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if standard_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")

        # A full queue is flushed first so the put below has room.
        if self.log_queue.full():
            await self.flush_queue()

        queue_item = GCSLogQueueItem(
            payload=standard_payload, kwargs=kwargs, response_obj=response_obj
        )
        await self.log_queue.put(queue_item)
    except Exception as err:
        verbose_logger.exception("GCS Bucket logging error: " + str(err))
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """
    Queue the standard logging payload of a failed call for upload to GCS.

    The payload is read from ``kwargs["standard_logging_object"]``. When the
    queue is full, one flush is awaited before the new item is enqueued.
    Every error raised here is caught and logged rather than propagated.
    """
    try:
        verbose_logger.debug(
            "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
            kwargs,
            response_obj,
        )

        standard_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if standard_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")

        # A full queue is flushed first so the put below has room.
        if self.log_queue.full():
            await self.flush_queue()

        queue_item = GCSLogQueueItem(
            payload=standard_payload, kwargs=kwargs, response_obj=response_obj
        )
        await self.log_queue.put(queue_item)
    except Exception as err:
        verbose_logger.exception("GCS Bucket logging error: " + str(err))
def _drain_queue_batch(self) -> List[GCSLogQueueItem]:
    """
    Pull up to ``self.batch_size`` items off the log queue without blocking.

    Returns:
        The drained items in queue order; fewer than ``batch_size`` (possibly
        none) when the queue runs empty first.
    """
    drained: List[GCSLogQueueItem] = []
    while True:
        if len(drained) >= self.batch_size:
            break
        try:
            next_item = self.log_queue.get_nowait()
        except asyncio.QueueEmpty:
            # Nothing left to take; return what was collected so far.
            break
        drained.append(next_item)
    return drained
def _generate_batch_object_name(self, date_str: str, batch_id: str) -> str:
    """
    Build the object name under which one batched log file is stored.

    The name has the shape ``{date}/batch-{batch_id}.ndjson``.
    """
    return "{}/batch-{}.ndjson".format(date_str, batch_id)
def _get_config_key(self, kwargs: Dict[str, Any]) -> str:
    """
    Build the synchronous grouping key that identifies an item's GCS config.

    The key is ``"{bucket}|{service_account_path}"``. Each part is taken from
    ``kwargs["standard_callback_dynamic_params"]`` when set there, otherwise
    from this logger's own attribute, otherwise the literal ``"default"``.

    The key can contain bucket names and credential paths, so pass it through
    ``_sanitize_config_key()`` before logging it.
    """
    dynamic_params = kwargs.get("standard_callback_dynamic_params", None) or {}

    bucket = dynamic_params.get("gcs_bucket_name", None)
    if not bucket:
        bucket = self.BUCKET_NAME or "default"

    service_account = dynamic_params.get("gcs_path_service_account", None)
    if not service_account:
        service_account = self.path_service_account_json or "default"

    return "{}|{}".format(bucket, service_account)
def _sanitize_config_key(self, config_key: str) -> str:
    """
    Return a log-safe stand-in for a config key.

    The result is ``"config-"`` followed by the first 8 hex characters of the
    key's SHA-256 digest, so the bucket name and credential path inside the
    key are not written to logs.
    """
    digest = hashlib.sha256(config_key.encode("utf-8")).hexdigest()
    return "config-" + digest[:8]
def _group_items_by_config(
    self, items: List[GCSLogQueueItem]
) -> Dict[str, List[GCSLogQueueItem]]:
    """
    Partition queued items by the config key of their ``kwargs``.

    Returns:
        A dict mapping each config key to its items, with keys in
        first-seen order and items in their original order.
    """
    buckets: Dict[str, List[GCSLogQueueItem]] = {}
    for queued in items:
        key = self._get_config_key(queued["kwargs"])
        buckets.setdefault(key, []).append(queued)
    return buckets
def _combine_payloads_to_ndjson(self, items: List[GCSLogQueueItem]) -> str:
    """
    Serialize the payload of every item as newline-delimited JSON (NDJSON).

    Each payload becomes one JSON line; lines are separated by ``"\\n"`` with
    no trailing newline. Values that JSON cannot encode are passed through
    ``str``, and non-ASCII characters are written unescaped.
    """
    return "\n".join(
        json.dumps(entry["payload"], default=str, ensure_ascii=False)
        for entry in items
    )
async def _send_grouped_batch(
    self, items: List[GCSLogQueueItem], config_key: str
) -> Tuple[int, int]:
    """
    Upload a group of items that share one GCS config as a single NDJSON object.

    The GCS config is resolved from the first item's ``kwargs``. ``config_key``
    is accepted but not read by this method.

    Returns:
        ``(success_count, error_count)``: ``(len(items), 0)`` when the upload
        succeeds, ``(0, len(items))`` when any step raises (the error is
        logged), and ``(0, 0)`` for an empty list.
    """
    if not items:
        return (0, 0)

    representative_kwargs = items[0]["kwargs"]
    try:
        config: GCSLoggingConfig = await self.get_gcs_logging_config(
            representative_kwargs
        )
        request_headers = await self.construct_request_headers(
            vertex_instance=config["vertex_instance"],
            service_account_json=config["path_service_account"],
        )
        target_bucket = config["bucket_name"]

        today = self._get_object_date_from_datetime(datetime.now(timezone.utc))
        # Millisecond timestamp plus a short random suffix keeps names unique.
        unique_id = f"{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"
        batch_object = self._generate_batch_object_name(today, unique_id)
        ndjson_body = self._combine_payloads_to_ndjson(items)

        await self._log_json_data_on_gcs(
            headers=request_headers,
            bucket_name=target_bucket,
            object_name=batch_object,
            logging_payload=ndjson_body,
        )
    except Exception as err:
        verbose_logger.exception(
            f"GCS Bucket error logging batch payload to GCS bucket: {str(err)}"
        )
        return (0, len(items))
    return (len(items), 0)
async def _send_individual_logs(self, items: List[GCSLogQueueItem]) -> None:
    """
    Upload every item as its own GCS object, one after another (legacy behavior).

    Used when ``GCS_USE_BATCHED_LOGGING`` is disabled.
    """
    for queued in items:
        await self._send_single_log_item(queued)
async def _send_single_log_item(self, item: GCSLogQueueItem) -> None:
    """
    Upload one queued item to GCS as an individual object.

    The GCS config is resolved from the item's ``kwargs`` and the object name
    from ``_get_object_name``. Any error is caught and logged, not raised.
    """
    try:
        config: GCSLoggingConfig = await self.get_gcs_logging_config(item["kwargs"])
        request_headers = await self.construct_request_headers(
            vertex_instance=config["vertex_instance"],
            service_account_json=config["path_service_account"],
        )
        target_bucket = config["bucket_name"]
        target_object = self._get_object_name(
            kwargs=item["kwargs"],
            logging_payload=item["payload"],
            response_obj=item["response_obj"],
        )
        await self._log_json_data_on_gcs(
            headers=request_headers,
            bucket_name=target_bucket,
            object_name=target_object,
            logging_payload=item["payload"],
        )
    except Exception as err:
        verbose_logger.exception(
            f"GCS Bucket error logging individual payload to GCS bucket: {str(err)}"
        )
async def async_send_batch(self):
    """
    Drain one batch from the queue and send it to the GCS bucket.

    With ``use_batched_logging`` enabled, the drained items are grouped by GCS
    config and each group is uploaded as a single NDJSON object. Otherwise
    every item is uploaded as its own object (legacy behavior). Returns
    without doing anything when the queue is empty.
    """
    pending = self._drain_queue_batch()
    if not pending:
        return

    if not self.use_batched_logging:
        await self._send_individual_logs(pending)
        return

    for key, members in self._group_items_by_config(pending).items():
        await self._send_grouped_batch(members, key)
def _get_object_name(
    self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
) -> str:
    """
    Get the object name to use for the current payload.

    Resolution order:
        1. ``kwargs["litellm_params"]["metadata"]["gcs_log_id"]`` when present
           (used for testing) - returned as-is.
        2. A failure object name when the payload carries an ``error_str``.
        3. A success object name built from the response id.

    Args:
        kwargs: Call kwargs; only ``litellm_params.metadata`` is read.
        logging_payload: Payload being logged; ``error_str`` marks a failure.
        response_obj: Response whose ``id`` names a success object. May be
            ``None`` or an object without ``.get``.

    Returns:
        The GCS object name for this payload.
    """
    # used for testing
    # The override is resolved first: it does not depend on the date, on the
    # payload or on the response, so it must not fail (or burn a uuid) on them.
    litellm_params = kwargs.get("litellm_params", None) or {}
    metadata = litellm_params.get("metadata", None) or {}
    if "gcs_log_id" in metadata:
        return metadata["gcs_log_id"]

    current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
    if logging_payload.get("error_str", None) is not None:
        return self._generate_failure_object_name(
            request_date_str=current_date,
        )

    response_get = getattr(response_obj, "get", None)
    if callable(response_get):
        response_id = response_get("id", "")
    else:
        # No usable response object: previously this raised AttributeError and
        # the log was dropped by the caller's error handler.
        # NOTE(review): falls back to the payload's own "id" - presumably the
        # same request identifier; confirm against StandardLoggingPayload.
        response_id = logging_payload.get("id", "")
    return self._generate_success_object_name(
        request_date_str=current_date,
        response_id=response_id,
    )
async def get_request_response_payload(
    self,
    request_id: str,
    start_time_utc: Optional[datetime],
    end_time_utc: Optional[datetime],
) -> Optional[dict]:
    """
    Fetch the stored request/response payload for ``request_id`` from GCS.

    The object is looked up under the date of ``start_time_utc``, then the
    following day, then the preceding day; the first object found is parsed
    as JSON and returned. A failure for one date is logged at debug level
    and the next date is tried. ``end_time_utc`` is not used.

    Returns:
        The decoded payload, or ``None`` when no date yields an object.

    Raises:
        ValueError: If ``start_time_utc`` is ``None``.
    """
    if start_time_utc is None:
        raise ValueError(
            "start_time_utc is required for getting a payload from GCS Bucket"
        )

    # Same day first, then the next day, then the previous day.
    candidate_days = [
        start_time_utc + timedelta(days=offset) for offset in (0, 1, -1)
    ]

    # Kept outside the loop so the debug message can still report the last
    # successfully formatted date if formatting itself fails.
    date_str = None
    for candidate in candidate_days:
        try:
            date_str = self._get_object_date_from_datetime(datetime_obj=candidate)
            encoded_object_name = quote(
                self._generate_success_object_name(
                    request_date_str=date_str,
                    response_id=request_id,
                ),
                safe="",
            )
            raw = await self.download_gcs_object(encoded_object_name)
            if raw is not None:
                return json.loads(raw)
        except Exception as err:
            verbose_logger.debug(
                f"Failed to fetch payload for date {date_str}: {str(err)}"
            )
    return None
def _generate_success_object_name(
    self,
    request_date_str: str,
    response_id: str,
) -> str:
    """
    Build the object name of a successful call: ``{date}/{response_id}``.
    """
    return "{}/{}".format(request_date_str, response_id)
def _generate_failure_object_name(
    self,
    request_date_str: str,
) -> str:
    """
    Build the object name of a failed call: ``{date}/failure-{random hex}``.

    The suffix is the hex form of a freshly generated UUID4, so every call
    yields a new name.
    """
    random_suffix = uuid.uuid4().hex
    return "{}/failure-{}".format(request_date_str, random_suffix)
def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
    """
    Format ``datetime_obj`` as the ``YYYY-MM-DD`` date prefix of object names.
    """
    date_format = "%Y-%m-%d"
    return datetime_obj.strftime(date_format)
async def flush_queue(self):
    """
    Flush override for the ``asyncio.Queue``-backed log queue.

    Sends one batch via ``async_send_batch`` and then records the current
    time in ``last_flush_time``.
    """
    await self.async_send_batch()
    flushed_at = time.time()
    self.last_flush_time = flushed_at
async def periodic_flush(self):
    """
    Override periodic_flush to work with asyncio.Queue.

    Runs forever: sleeps ``flush_interval`` seconds, then flushes the queue.
    An error raised by one flush is logged and the loop carries on, so a
    single failed flush cannot silently end periodic flushing. Cancellation
    still stops the loop.
    """
    while True:
        await asyncio.sleep(self.flush_interval)
        verbose_logger.debug(
            f"GCS Bucket periodic flush after {self.flush_interval} seconds"
        )
        try:
            await self.flush_queue()
        except asyncio.CancelledError:
            # Re-raised explicitly so cancellation is honored even on Python
            # versions where CancelledError still derives from Exception.
            raise
        except Exception as e:
            verbose_logger.exception(
                f"GCS Bucket periodic flush error: {str(e)}"
            )
async def async_health_check(self) -> IntegrationHealthCheckStatus:
    """
    Health checks are not supported by this logger.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("GCS Bucket does not support health check")