chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,12 @@
# GCS (Google Cloud Storage) Bucket Logging on LiteLLM Gateway

This folder contains the GCS Bucket Logging integration for LiteLLM Gateway.

## Folder Structure

- `gcs_bucket.py`: This is the main file that handles failure/success logging to GCS Bucket
- `gcs_bucket_base.py`: This file contains the GCSBucketBase class which handles Authentication for GCS Buckets

## Further Reading

- [Doc setting up GCS Bucket Logging on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/observability/gcs_bucket_integration)
- [Doc on Key / Team Based logging with GCS](https://docs.litellm.ai/docs/proxy/team_logging)
@@ -0,0 +1,419 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from litellm._uuid import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
|
||||
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
from litellm.types.integrations.gcs_bucket import *
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
else:
|
||||
VertexBase = Any
|
||||
|
||||
|
||||
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
    """
    Logs LiteLLM success/failure payloads to a Google Cloud Storage bucket.

    Payloads are queued in-memory (`self.log_queue`) and flushed either by the
    periodic flush task or immediately when the queue reaches its max size.
    When batched logging is enabled (env `GCS_USE_BATCHED_LOGGING`), queued
    payloads sharing the same bucket/credentials are combined into a single
    NDJSON object per flush, reducing GCS API calls.
    """

    def __init__(self, bucket_name: Optional[str] = None) -> None:
        from litellm.proxy.proxy_server import premium_user

        # Premium-only feature: fail fast *before* spawning the periodic
        # flush task, so a rejected init does not leak a background task or
        # leave a half-constructed logger behind.
        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )

        # Batching configuration (overridable via environment variables).
        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
        self.flush_interval = int(
            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
        )
        self.use_batched_logging = (
            os.getenv(
                "GCS_USE_BATCHED_LOGGING", str(GCS_DEFAULT_USE_BATCHED_LOGGING).lower()
            ).lower()
            == "true"
        )
        self.flush_lock = asyncio.Lock()

        # BUGFIX: initialize the base class exactly once, passing both the
        # bucket name and the batching kwargs. Previously super().__init__()
        # was called twice — the second call omitted bucket_name, which made
        # GCSBucketBase re-run and reset BUCKET_NAME from the environment,
        # silently dropping a caller-supplied bucket_name.
        super().__init__(
            bucket_name=bucket_name,
            flush_lock=self.flush_lock,
            batch_size=self.batch_size,
            flush_interval=self.flush_interval,
        )

        self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue(  # type: ignore[assignment]
            maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE
        )
        # NOTE: requires a running event loop at construction time.
        asyncio.create_task(self.periodic_flush())
        AdditionalLoggingUtils.__init__(self)

    #### ASYNC ####
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a successful request's standard logging payload for upload."""
        from litellm.proxy.proxy_server import premium_user

        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if logging_payload is None:
                raise ValueError("standard_logging_object not found in kwargs")
            # When queue is at maxsize, flush immediately to make room (no blocking, no data dropped)
            if self.log_queue.full():
                await self.flush_queue()
            await self.log_queue.put(
                GCSLogQueueItem(
                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
                )
            )

        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a failed request's standard logging payload for upload."""
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )

            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if logging_payload is None:
                raise ValueError("standard_logging_object not found in kwargs")
            # When queue is at maxsize, flush immediately to make room (no blocking, no data dropped)
            if self.log_queue.full():
                await self.flush_queue()
            await self.log_queue.put(
                GCSLogQueueItem(
                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
                )
            )

        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    def _drain_queue_batch(self) -> List[GCSLogQueueItem]:
        """
        Drain items from the queue (non-blocking), respecting batch_size limit.

        This prevents unbounded queue growth when processing is slower than log accumulation.

        Returns:
            List of items to process, up to batch_size items
        """
        items_to_process: List[GCSLogQueueItem] = []
        while len(items_to_process) < self.batch_size:
            try:
                items_to_process.append(self.log_queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        return items_to_process

    def _generate_batch_object_name(self, date_str: str, batch_id: str) -> str:
        """
        Generate object name for a batched log file.
        Format: {date}/batch-{batch_id}.ndjson
        """
        return f"{date_str}/batch-{batch_id}.ndjson"

    def _get_config_key(self, kwargs: Dict[str, Any]) -> str:
        """
        Extract a synchronous grouping key from kwargs to group items by GCS config.
        This allows us to batch items with the same bucket/credentials together.

        Returns a string key that uniquely identifies the GCS config combination.
        This key may contain sensitive information (bucket names, paths) - use _sanitize_config_key()
        for logging purposes.
        """
        standard_callback_dynamic_params = (
            kwargs.get("standard_callback_dynamic_params", None) or {}
        )

        bucket_name = (
            standard_callback_dynamic_params.get("gcs_bucket_name", None)
            or self.BUCKET_NAME
            or "default"
        )
        path_service_account = (
            standard_callback_dynamic_params.get("gcs_path_service_account", None)
            or self.path_service_account_json
            or "default"
        )

        return f"{bucket_name}|{path_service_account}"

    def _sanitize_config_key(self, config_key: str) -> str:
        """
        Create a sanitized version of the config key for logging.
        Uses a hash to avoid exposing sensitive bucket names or service account paths.

        Returns a short hash prefix for safe logging.
        """
        hash_obj = hashlib.sha256(config_key.encode("utf-8"))
        return f"config-{hash_obj.hexdigest()[:8]}"

    def _group_items_by_config(
        self, items: List[GCSLogQueueItem]
    ) -> Dict[str, List[GCSLogQueueItem]]:
        """
        Group items by their GCS config (bucket + credentials).
        This ensures items with different configs are processed separately.

        Returns a dict mapping config_key -> list of items with that config.
        """
        grouped: Dict[str, List[GCSLogQueueItem]] = {}
        for item in items:
            config_key = self._get_config_key(item["kwargs"])
            if config_key not in grouped:
                grouped[config_key] = []
            grouped[config_key].append(item)
        return grouped

    def _combine_payloads_to_ndjson(self, items: List[GCSLogQueueItem]) -> str:
        """
        Combine multiple log payloads into newline-delimited JSON (NDJSON) format.
        Each line is a valid JSON object representing one log entry.
        """
        lines = []
        for item in items:
            logging_payload = item["payload"]
            json_line = json.dumps(logging_payload, default=str, ensure_ascii=False)
            lines.append(json_line)
        return "\n".join(lines)

    async def _send_grouped_batch(
        self, items: List[GCSLogQueueItem], config_key: str
    ) -> Tuple[int, int]:
        """
        Send a batch of items that share the same GCS config.

        The batch is uploaded as a single NDJSON object; all items in the
        group succeed or fail together.

        Returns:
            (success_count, error_count)
        """
        if not items:
            return (0, 0)

        # All items share the same config, so the first item's kwargs are
        # sufficient to resolve bucket/credentials for the whole group.
        first_kwargs = items[0]["kwargs"]

        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                first_kwargs
            )

            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]

            current_date = self._get_object_date_from_datetime(
                datetime.now(timezone.utc)
            )
            # millisecond timestamp + random suffix keeps batch names unique
            batch_id = f"{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"
            object_name = self._generate_batch_object_name(current_date, batch_id)
            combined_payload = self._combine_payloads_to_ndjson(items)

            await self._log_json_data_on_gcs(
                headers=headers,
                bucket_name=bucket_name,
                object_name=object_name,
                logging_payload=combined_payload,
            )

            success_count = len(items)
            error_count = 0
            return (success_count, error_count)

        except Exception as e:
            success_count = 0
            error_count = len(items)
            verbose_logger.exception(
                f"GCS Bucket error logging batch payload to GCS bucket: {str(e)}"
            )
            return (success_count, error_count)

    async def _send_individual_logs(self, items: List[GCSLogQueueItem]) -> None:
        """
        Send each log individually as separate GCS objects (legacy behavior).
        This is used when GCS_USE_BATCHED_LOGGING is disabled.
        """
        for item in items:
            await self._send_single_log_item(item)

    async def _send_single_log_item(self, item: GCSLogQueueItem) -> None:
        """
        Send a single log item to GCS as an individual object.
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                item["kwargs"]
            )

            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]

            object_name = self._get_object_name(
                kwargs=item["kwargs"],
                logging_payload=item["payload"],
                response_obj=item["response_obj"],
            )

            await self._log_json_data_on_gcs(
                headers=headers,
                bucket_name=bucket_name,
                object_name=object_name,
                logging_payload=item["payload"],
            )
        except Exception as e:
            verbose_logger.exception(
                f"GCS Bucket error logging individual payload to GCS bucket: {str(e)}"
            )

    async def async_send_batch(self):
        """
        Process queued logs - sends logs to GCS Bucket.

        If `GCS_USE_BATCHED_LOGGING` is enabled (default), batches multiple log payloads
        into single GCS object uploads (NDJSON format), dramatically reducing API calls.

        If disabled, sends each log individually as separate GCS objects (legacy behavior).
        """
        items_to_process = self._drain_queue_batch()

        if not items_to_process:
            return

        if self.use_batched_logging:
            grouped_items = self._group_items_by_config(items_to_process)

            for config_key, group_items in grouped_items.items():
                await self._send_grouped_batch(group_items, config_key)
        else:
            await self._send_individual_logs(items_to_process)

    def _get_object_name(
        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
    ) -> str:
        """
        Get the object name to use for the current payload
        """
        current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
        if logging_payload.get("error_str", None) is not None:
            object_name = self._generate_failure_object_name(
                request_date_str=current_date,
            )
        else:
            object_name = self._generate_success_object_name(
                request_date_str=current_date,
                response_id=response_obj.get("id", ""),
            )

        # used for testing
        _litellm_params = kwargs.get("litellm_params", None) or {}
        _metadata = _litellm_params.get("metadata", None) or {}
        if "gcs_log_id" in _metadata:
            object_name = _metadata["gcs_log_id"]

        return object_name

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """
        Get the request and response payload for a given `request_id`
        Tries current day, next day, and previous day until it finds the payload

        Raises:
            ValueError: if start_time_utc is None (required to locate the
            dated object path).
        """
        if start_time_utc is None:
            raise ValueError(
                "start_time_utc is required for getting a payload from GCS Bucket"
            )

        # Objects are stored under a date prefix; try the adjacent days to
        # tolerate timezone/clock skew between request and upload time.
        dates_to_try = [
            start_time_utc,
            start_time_utc + timedelta(days=1),
            start_time_utc - timedelta(days=1),
        ]
        date_str = None
        for date in dates_to_try:
            try:
                date_str = self._get_object_date_from_datetime(datetime_obj=date)
                object_name = self._generate_success_object_name(
                    request_date_str=date_str,
                    response_id=request_id,
                )
                encoded_object_name = quote(object_name, safe="")
                response = await self.download_gcs_object(encoded_object_name)

                if response is not None:
                    loaded_response = json.loads(response)
                    return loaded_response
            except Exception as e:
                verbose_logger.debug(
                    f"Failed to fetch payload for date {date_str}: {str(e)}"
                )
                continue

        return None

    def _generate_success_object_name(
        self,
        request_date_str: str,
        response_id: str,
    ) -> str:
        """Object name for a successful request: {date}/{response_id}."""
        return f"{request_date_str}/{response_id}"

    def _generate_failure_object_name(
        self,
        request_date_str: str,
    ) -> str:
        """Object name for a failed request: {date}/failure-{random hex}."""
        return f"{request_date_str}/failure-{uuid.uuid4().hex}"

    def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
        """Format a datetime as the YYYY-MM-DD date prefix used in object names."""
        return datetime_obj.strftime("%Y-%m-%d")

    async def flush_queue(self):
        """
        Override flush_queue to work with asyncio.Queue.
        """
        await self.async_send_batch()
        self.last_flush_time = time.time()

    async def periodic_flush(self):
        """
        Override periodic_flush to work with asyncio.Queue.

        Runs forever; flushes every `self.flush_interval` seconds.
        """
        while True:
            await asyncio.sleep(self.flush_interval)
            verbose_logger.debug(
                f"GCS Bucket periodic flush after {self.flush_interval} seconds"
            )
            await self.flush_queue()

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        """Health checks are not supported for the GCS Bucket integration."""
        raise NotImplementedError("GCS Bucket does not support health check")
|
||||
@@ -0,0 +1,347 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_mock_client import (
|
||||
should_use_gcs_mock,
|
||||
create_mock_gcs_client,
|
||||
mock_vertex_auth_methods,
|
||||
)
|
||||
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.gcs_bucket import *
|
||||
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
else:
|
||||
VertexBase = Any
|
||||
IAM_AUTH_KEY = "IAM_AUTH"
|
||||
|
||||
|
||||
class GCSBucketBase(CustomBatchLogger):
    """
    Shared base for GCS bucket integrations.

    Handles Vertex-based authentication, request-header construction, and raw
    object upload/download/delete against the GCS JSON API.
    """

    def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
        # Mock mode (GCS_MOCK=true) patches the HTTP/auth layers so the full
        # code path can run without making real network calls.
        self.is_mock_mode = should_use_gcs_mock()

        if self.is_mock_mode:
            mock_vertex_auth_methods()
            create_mock_gcs_client()

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        _path_service_account = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
        # Caller-supplied bucket name wins over the environment variable.
        _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
        self.path_service_account_json: Optional[str] = _path_service_account
        self.BUCKET_NAME: Optional[str] = _bucket_name
        # Cache of VertexBase instances keyed by credentials string
        # (see _get_in_memory_key_for_vertex_instance).
        self.vertex_instances: Dict[str, VertexBase] = {}
        super().__init__(**kwargs)

    async def construct_request_headers(
        self,
        service_account_json: Optional[str],
        vertex_instance: Optional[VertexBase] = None,
    ) -> Dict[str, str]:
        """
        Build Authorization/Content-Type headers for GCS API calls using a
        Vertex access token derived from `service_account_json` (or IAM auth
        when no credentials are provided).
        """
        from litellm import vertex_chat_completion

        if vertex_instance is None:
            vertex_instance = vertex_chat_completion

        _auth_header, vertex_project = await vertex_instance._ensure_access_token_async(
            credentials=service_account_json,
            project_id=None,
            custom_llm_provider="vertex_ai",
        )

        auth_header, _ = vertex_instance._get_token_and_url(
            model="gcs-bucket",
            auth_header=_auth_header,
            vertex_credentials=service_account_json,
            vertex_project=vertex_project,
            vertex_location=None,
            gemini_api_key=None,
            stream=None,
            custom_llm_provider="vertex_ai",
            api_base=None,
        )
        verbose_logger.debug("constructed auth_header %s", auth_header)
        headers = {
            "Authorization": f"Bearer {auth_header}",  # auth_header
            "Content-Type": "application/json",
        }

        return headers

    def sync_construct_request_headers(self) -> Dict[str, str]:
        """
        Construct request headers for GCS API calls
        """
        from litellm import vertex_chat_completion

        # Get project_id from environment if available, otherwise None
        # This helps support use of this library to auth to pull secrets
        # from Secret Manager.
        project_id = os.getenv("GOOGLE_SECRET_MANAGER_PROJECT_ID")

        _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
            credentials=self.path_service_account_json,
            project_id=project_id,
            custom_llm_provider="vertex_ai",
        )

        auth_header, _ = vertex_chat_completion._get_token_and_url(
            model="gcs-bucket",
            auth_header=_auth_header,
            vertex_credentials=self.path_service_account_json,
            vertex_project=vertex_project,
            vertex_location=None,
            gemini_api_key=None,
            stream=None,
            custom_llm_provider="vertex_ai",
            api_base=None,
        )
        verbose_logger.debug("constructed auth_header %s", auth_header)
        headers = {
            "Authorization": f"Bearer {auth_header}",  # auth_header
            "Content-Type": "application/json",
        }

        return headers

    def _handle_folders_in_bucket_name(
        self,
        bucket_name: str,
        object_name: str,
    ) -> Tuple[str, str]:
        """
        Handles when the user passes a bucket name with a folder postfix


        Example:
            - Bucket name: "my-bucket/my-folder/dev"
            - Object name: "my-object"
            - Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"

        """
        if "/" in bucket_name:
            bucket_name, prefix = bucket_name.split("/", 1)
            object_name = f"{prefix}/{object_name}"
            return bucket_name, object_name
        return bucket_name, object_name

    async def get_gcs_logging_config(
        self, kwargs: Optional[Dict[str, Any]] = None
    ) -> GCSLoggingConfig:
        """
        This function is used to get the GCS logging config for the GCS Bucket Logger.
        It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
        If no dynamic parameters are provided, it uses the default values.

        Note: the default for `kwargs` is None (not a mutable `{}` literal)
        to avoid the shared-mutable-default pitfall; the body normalizes it.
        """
        if kwargs is None:
            kwargs = {}

        standard_callback_dynamic_params: Optional[
            StandardCallbackDynamicParams
        ] = kwargs.get("standard_callback_dynamic_params", None)

        bucket_name: str
        path_service_account: Optional[str]
        if standard_callback_dynamic_params is not None:
            verbose_logger.debug("Using dynamic GCS logging")
            verbose_logger.debug(
                "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
            )

            _bucket_name: Optional[str] = (
                standard_callback_dynamic_params.get("gcs_bucket_name", None)
                or self.BUCKET_NAME
            )
            _path_service_account: Optional[str] = (
                standard_callback_dynamic_params.get("gcs_path_service_account", None)
                or self.path_service_account_json
            )

            if _bucket_name is None:
                raise ValueError(
                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
                )
            bucket_name = _bucket_name
            path_service_account = _path_service_account
            vertex_instance = await self.get_or_create_vertex_instance(
                credentials=path_service_account
            )
        else:
            # If no dynamic parameters, use the default instance
            if self.BUCKET_NAME is None:
                raise ValueError(
                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
                )
            bucket_name = self.BUCKET_NAME
            path_service_account = self.path_service_account_json
            vertex_instance = await self.get_or_create_vertex_instance(
                credentials=path_service_account
            )

        return GCSLoggingConfig(
            bucket_name=bucket_name,
            vertex_instance=vertex_instance,
            path_service_account=path_service_account,
        )

    async def get_or_create_vertex_instance(
        self, credentials: Optional[str]
    ) -> VertexBase:
        """
        This function is used to get the Vertex instance for the GCS Bucket Logger.
        It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
        """
        from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

        _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
        if _in_memory_key not in self.vertex_instances:
            vertex_instance = VertexBase()
            await vertex_instance._ensure_access_token_async(
                credentials=credentials,
                project_id=None,
                custom_llm_provider="vertex_ai",
            )
            self.vertex_instances[_in_memory_key] = vertex_instance
        return self.vertex_instances[_in_memory_key]

    def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
        """
        Returns key to use for caching the Vertex instance in-memory.

        When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.

        - If a credentials string is provided, it is used as the key.
        - If no credentials string is provided, "IAM_AUTH" is used as the key.
        """
        return credentials or IAM_AUTH_KEY

    async def download_gcs_object(self, object_name: str, **kwargs):
        """
        Download an object from GCS.

        https://cloud.google.com/storage/docs/downloading-objects#download-object-json

        Returns the raw object bytes, or None on any error (best-effort).
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs=kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            bucket_name, object_name = self._handle_folders_in_bucket_name(
                bucket_name=bucket_name,
                object_name=object_name,
            )

            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"

            # Send the GET request to download the object
            response = await self.async_httpx_client.get(url=url, headers=headers)

            if response.status_code != 200:
                verbose_logger.error(
                    "GCS object download error: %s", str(response.text)
                )
                return None

            verbose_logger.debug(
                "GCS object download response status code: %s", response.status_code
            )

            # Return the content of the downloaded object
            return response.content

        except Exception as e:
            verbose_logger.error("GCS object download error: %s", str(e))
            return None

    async def delete_gcs_object(self, object_name: str, **kwargs):
        """
        Delete an object from GCS.

        Returns the response body text on success, or None on any error
        (best-effort).
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs=kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            bucket_name, object_name = self._handle_folders_in_bucket_name(
                bucket_name=bucket_name,
                object_name=object_name,
            )

            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"

            # Send the DELETE request to delete the object
            response = await self.async_httpx_client.delete(url=url, headers=headers)

            # BUGFIX: the previous check used `(!= 200) or (!= 204)`, which is
            # always true, so every successful delete was logged as an error
            # and returned None. GCS delete returns 204 No Content on success.
            if response.status_code not in (200, 204):
                verbose_logger.error(
                    "GCS object delete error: %s, status code: %s",
                    str(response.text),
                    response.status_code,
                )
                return None

            verbose_logger.debug(
                "GCS object delete response status code: %s, response: %s",
                response.status_code,
                response.text,
            )

            # Return the body of the delete response (empty on 204)
            return response.text

        except Exception as e:
            # BUGFIX: message previously said "download error" in the delete path
            verbose_logger.error("GCS object delete error: %s", str(e))
            return None

    async def _log_json_data_on_gcs(
        self,
        headers: Dict[str, str],
        bucket_name: str,
        object_name: str,
        logging_payload: Union[StandardLoggingPayload, str],
    ):
        """
        Helper function to make POST request to GCS Bucket in the specified bucket.

        `logging_payload` may be a pre-serialized string (e.g. NDJSON batch)
        or a StandardLoggingPayload dict, which is JSON-serialized here.
        """
        if isinstance(logging_payload, str):
            json_logged_payload = logging_payload
        else:
            json_logged_payload = json.dumps(logging_payload, default=str)

        bucket_name, object_name = self._handle_folders_in_bucket_name(
            bucket_name=bucket_name,
            object_name=object_name,
        )

        response = await self.async_httpx_client.post(
            headers=headers,
            url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
            data=json_logged_payload,
        )

        if response.status_code != 200:
            verbose_logger.error("GCS Bucket logging error: %s", str(response.text))

        verbose_logger.debug("GCS Bucket response %s", response)
        verbose_logger.debug("GCS Bucket status code %s", response.status_code)
        verbose_logger.debug("GCS Bucket response.text %s", response.text)

        return response.json()
|
||||
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Mock client for GCS Bucket integration testing.
|
||||
|
||||
This module intercepts GCS API calls and Vertex AI auth calls, returning successful
|
||||
mock responses, allowing full code execution without making actual network calls.
|
||||
|
||||
Usage:
|
||||
Set GCS_MOCK=true in environment variables or config to enable mock mode.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.mock_client_factory import (
|
||||
MockClientConfig,
|
||||
create_mock_client_factory,
|
||||
MockResponse,
|
||||
)
|
||||
|
||||
import os

# Use factory for POST handler: intercepts POST uploads to
# storage.googleapis.com and returns a canned successful object resource.
_config = MockClientConfig(
    name="GCS",
    env_var="GCS_MOCK",
    default_latency_ms=150,
    default_status_code=200,
    default_json_data={"kind": "storage#object", "name": "mock-object"},
    url_matchers=["storage.googleapis.com"],
    patch_async_handler=True,
    patch_sync_client=False,
)

_create_mock_gcs_post, should_use_gcs_mock = create_mock_client_factory(_config)

# Store original methods for GET/DELETE (GCS-specific)
_original_async_handler_get = None
_original_async_handler_delete = None
_mocks_initialized = False

# Default mock latency in seconds (simulates network round-trip)
# Typical GCS API calls take 100-300ms for uploads, 50-150ms for GET/DELETE
# NOTE: use a plain `import os` instead of the previous
# `__import__("os")` inline hack.
_MOCK_LATENCY_SECONDS = float(os.getenv("GCS_MOCK_LATENCY_MS", "150")) / 1000.0
|
||||
|
||||
|
||||
async def _mock_async_handler_get(
    self, url, params=None, headers=None, follow_redirects=None
):
    """Monkey-patched AsyncHTTPHandler.get that intercepts GCS calls."""
    # Anything that isn't a GCS API call is forwarded to the real handler.
    is_gcs_call = isinstance(url, str) and "storage.googleapis.com" in url
    if not is_gcs_call:
        if _original_async_handler_get is None:
            raise RuntimeError("Original AsyncHTTPHandler.get not available")
        return await _original_async_handler_get(
            self,
            url=url,
            params=params,
            headers=headers,
            follow_redirects=follow_redirects,
        )

    verbose_logger.info(f"[GCS MOCK] GET to {url}")
    # Simulate the network round-trip before answering.
    await asyncio.sleep(_MOCK_LATENCY_SECONDS)

    # Return a minimal but valid StandardLoggingPayload JSON string as bytes
    # This matches what GCS returns when downloading with ?alt=media
    mock_payload = {
        "id": "mock-request-id",
        "trace_id": "mock-trace-id",
        "call_type": "completion",
        "stream": False,
        "response_cost": 0.0,
        "status": "success",
        "status_fields": {"llm_api_status": "success"},
        "custom_llm_provider": "mock",
        "total_tokens": 0,
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "startTime": 0.0,
        "endTime": 0.0,
        "completionStartTime": 0.0,
        "response_time": 0.0,
        "model_map_information": {"model": "mock-model"},
        "model": "mock-model",
        "model_id": None,
        "model_group": None,
        "api_base": "https://api.mock.com",
        "metadata": {},
        "cache_hit": None,
        "cache_key": None,
        "saved_cache_cost": 0.0,
        "request_tags": [],
        "end_user": None,
        "requester_ip_address": None,
        "messages": None,
        "response": None,
        "error_str": None,
        "error_information": None,
        "model_parameters": {},
        "hidden_params": {},
        "guardrail_information": None,
        "standard_built_in_tools_params": None,
    }
    return MockResponse(
        status_code=200,
        json_data=mock_payload,
        url=url,
        elapsed_seconds=_MOCK_LATENCY_SECONDS,
    )
|
||||
|
||||
|
||||
async def _mock_async_handler_delete(
    self,
    url,
    data=None,
    json=None,
    params=None,
    headers=None,
    timeout=None,
    stream=False,
    content=None,
):
    """Monkey-patched AsyncHTTPHandler.delete that intercepts GCS calls."""
    # Anything that is not a GCS API call is forwarded untouched to the
    # real (pre-patch) handler.
    targets_gcs = isinstance(url, str) and "storage.googleapis.com" in url
    if not targets_gcs:
        if _original_async_handler_delete is None:
            raise RuntimeError("Original AsyncHTTPHandler.delete not available")
        return await _original_async_handler_delete(
            self,
            url=url,
            data=data,
            json=json,
            params=params,
            headers=headers,
            timeout=timeout,
            stream=stream,
            content=content,
        )

    verbose_logger.info(f"[GCS MOCK] DELETE to {url}")
    # Simulate network round-trip latency.
    await asyncio.sleep(_MOCK_LATENCY_SECONDS)
    # DELETE returns 204 No Content with empty body (not JSON)
    return MockResponse(
        status_code=204,
        json_data=None,  # Empty body for DELETE
        url=url,
        elapsed_seconds=_MOCK_LATENCY_SECONDS,
    )
|
||||
|
||||
|
||||
def create_mock_gcs_client():
    """
    Monkey-patch AsyncHTTPHandler methods to intercept GCS calls.

    AsyncHTTPHandler is used by LiteLLM's get_async_httpx_client() which is what
    GCSBucketBase uses for making API calls.

    This function is idempotent - it only initializes mocks once, even if called multiple times.
    """
    global _original_async_handler_get, _original_async_handler_delete, _mocks_initialized

    # The POST handler comes from its own factory and may be (re)installed freely.
    _create_mock_gcs_post()

    # GET/DELETE were already patched on an earlier call — nothing left to do.
    if _mocks_initialized:
        return

    verbose_logger.debug("[GCS MOCK] Initializing GCS GET/DELETE handlers...")

    # Patch GET and DELETE handlers (GCS-specific).
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    if _original_async_handler_get is None:
        # Stash the real method and swap the mock in, in one step.
        _original_async_handler_get, AsyncHTTPHandler.get = (  # type: ignore
            AsyncHTTPHandler.get,
            _mock_async_handler_get,
        )
        verbose_logger.debug("[GCS MOCK] Patched AsyncHTTPHandler.get")

    if _original_async_handler_delete is None:
        _original_async_handler_delete, AsyncHTTPHandler.delete = (  # type: ignore
            AsyncHTTPHandler.delete,
            _mock_async_handler_delete,
        )
        verbose_logger.debug("[GCS MOCK] Patched AsyncHTTPHandler.delete")

    verbose_logger.debug(
        f"[GCS MOCK] Mock latency set to {_MOCK_LATENCY_SECONDS*1000:.0f}ms"
    )
    verbose_logger.debug("[GCS MOCK] GCS mock client initialization complete")

    _mocks_initialized = True
|
||||
|
||||
|
||||
def mock_vertex_auth_methods():
    """
    Monkey-patch Vertex AI auth methods to return fake tokens.
    This prevents auth failures when GCS_MOCK is enabled.

    This function is idempotent - it only patches once, even if called multiple times.
    """
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

    # Already patched on a previous call: the originals are stashed and the
    # mocks installed, so skip everything. (Previously only the stashing was
    # guarded; the mocks were redefined and re-assigned on every call, which
    # contradicted the documented once-only behavior.)
    if hasattr(VertexBase, "_original_ensure_access_token_async"):
        return

    # Stash the original methods so they can be restored later if needed.
    setattr(
        VertexBase,
        "_original_ensure_access_token_async",
        VertexBase._ensure_access_token_async,
    )
    setattr(
        VertexBase, "_original_ensure_access_token", VertexBase._ensure_access_token
    )
    setattr(
        VertexBase, "_original_get_token_and_url", VertexBase._get_token_and_url
    )

    async def _mock_ensure_access_token_async(
        self, credentials, project_id, custom_llm_provider
    ):
        """Mock async auth method - returns fake token."""
        verbose_logger.debug(
            "[GCS MOCK] Vertex AI auth: _ensure_access_token_async called"
        )
        return ("mock-gcs-token", "mock-project-id")

    def _mock_ensure_access_token(
        self, credentials, project_id, custom_llm_provider
    ):
        """Mock sync auth method - returns fake token."""
        verbose_logger.debug(
            "[GCS MOCK] Vertex AI auth: _ensure_access_token called"
        )
        return ("mock-gcs-token", "mock-project-id")

    def _mock_get_token_and_url(
        self,
        model,
        auth_header,
        vertex_credentials,
        vertex_project,
        vertex_location,
        gemini_api_key,
        stream,
        custom_llm_provider,
        api_base,
    ):
        """Mock get_token_and_url - returns fake token."""
        verbose_logger.debug("[GCS MOCK] Vertex AI auth: _get_token_and_url called")
        return ("mock-gcs-token", "https://storage.googleapis.com")

    # Patch the methods
    VertexBase._ensure_access_token_async = _mock_ensure_access_token_async  # type: ignore
    VertexBase._ensure_access_token = _mock_ensure_access_token  # type: ignore
    VertexBase._get_token_and_url = _mock_get_token_and_url  # type: ignore

    verbose_logger.debug("[GCS MOCK] Patched Vertex AI auth methods")
|
||||
|
||||
|
||||
# should_use_gcs_mock is already created by the factory
|
||||
Reference in New Issue
Block a user