chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
# Integrations
|
||||
|
||||
This folder contains the logging integrations for LiteLLM,
|
||||
|
||||
e.g. logging to Datadog, Langfuse, Prometheus, S3, GCS Bucket, etc.
|
||||
@@ -0,0 +1,46 @@
|
||||
# Slack Alerting on LiteLLM Gateway
|
||||
|
||||
This folder contains the Slack Alerting integration for LiteLLM Gateway.
|
||||
|
||||
## Folder Structure
|
||||
|
||||
- `slack_alerting.py`: This is the main file that handles sending different types of alerts
|
||||
- `batching_handler.py`: Handles batching and sending HTTPX POST requests to Slack. Queued alerts are flushed every 10s, or sooner once more than X events are queued, so that LiteLLM keeps good performance under high traffic.
|
||||
- `types.py`: This file contains the AlertType enum which is used to define the different types of alerts that can be sent to Slack.
|
||||
- `utils.py`: This file contains common utils used specifically for slack alerting
|
||||
|
||||
## Budget Alert Types
|
||||
|
||||
The `budget_alert_types.py` module provides a flexible framework for handling different types of budget alerts:
|
||||
|
||||
- `BaseBudgetAlertType`: An abstract base class with abstract methods that all alert types must implement:
|
||||
- `get_event_group()`: Returns the Litellm_EntityType for the alert
|
||||
- `get_event_message()`: Returns the message prefix for the alert
|
||||
- `get_id(user_info)`: Returns the ID to use for caching/tracking the alert
|
||||
|
||||
Concrete implementations include:
|
||||
- `ProxyBudgetAlert`: Alerting for proxy-level budget concerns
|
||||
- `SoftBudgetAlert`: Alerting when soft budgets are crossed
|
||||
- `UserBudgetAlert`: Alerting for user-level budget concerns
|
||||
- `TeamBudgetAlert`: Alerting for team-level budget concerns
|
||||
- `TokenBudgetAlert`: Alerting for API key budget concerns
|
||||
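- `ProjectedLimitExceededAlert`: Alerting when projected spend will exceed budget
- `OrganizationBudgetAlert`: Alerting for organization-level budget concerns
- `ProjectBudgetAlert`: Alerting for project-level budget concerns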
|
||||
|
||||
Use the `get_budget_alert_type()` factory function to get the appropriate alert type class for a given alert type string:
|
||||
|
||||
```python
|
||||
from litellm.integrations.SlackAlerting.budget_alert_types import get_budget_alert_type
|
||||
|
||||
# Get the appropriate handler
|
||||
budget_alert_class = get_budget_alert_type("user_budget")
|
||||
|
||||
# Use the handler methods
|
||||
event_group = budget_alert_class.get_event_group() # Returns Litellm_EntityType.USER
|
||||
event_message = budget_alert_class.get_event_message() # Returns "User Budget: "
|
||||
cache_id = budget_alert_class.get_id(user_info) # Returns user_id
|
||||
```
|
||||
|
||||
To add a new budget alert type, simply create a new class that extends `BaseBudgetAlertType` and implements all the required methods, then add it to the dictionary in the `get_budget_alert_type()` function.
|
||||
|
||||
## Further Reading
|
||||
- [Doc setting up Alerting on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/alerting)
|
||||
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Handles Batching + sending Httpx Post requests to slack
|
||||
|
||||
Slack alerts are sent every 10s or when events are greater than X events
|
||||
|
||||
see custom_batch_logger.py for more details / defaults
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .slack_alerting import SlackAlerting as _SlackAlerting
|
||||
|
||||
SlackAlertingType = _SlackAlerting
|
||||
else:
|
||||
SlackAlertingType = Any
|
||||
|
||||
|
||||
def squash_payloads(queue):
    """
    Collapse queued Slack alerts so each (url, alert_type) pair is sent once.

    Args:
        queue: Iterable of queued alert dicts. Every item must provide the
            "url" and "alert_type" keys.

    Returns:
        A dict keyed by ``(url, alert_type)``. Each value is
        ``{"item": <first queued item for that key>, "count": <items squashed>}``.
        An empty queue yields an empty dict.

    Raises:
        KeyError: if an item is missing "url" or "alert_type".
    """
    squashed = {}
    for item in queue:
        # Same key shape for every queue length (the old single-item shortcut
        # used the literal string "key" instead of the tuple).
        _key = (item["url"], item["alert_type"])
        entry = squashed.get(_key)
        if entry is None:
            # The first payload seen for a key is the one that gets sent;
            # later duplicates only bump the counter.
            squashed[_key] = {"item": item, "count": 1}
        else:
            entry["count"] += 1
    return squashed
|
||||
|
||||
|
||||
def _print_alerting_payload_warning(
|
||||
payload: dict, slackAlertingInstance: SlackAlertingType
|
||||
):
|
||||
"""
|
||||
Print the payload to the console when
|
||||
slackAlertingInstance.alerting_args.log_to_console is True
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/7372
|
||||
"""
|
||||
if slackAlertingInstance.alerting_args.log_to_console is True:
|
||||
verbose_proxy_logger.warning(payload)
|
||||
|
||||
|
||||
async def send_to_webhook(slackAlertingInstance: "SlackAlertingType", item, count):
    """
    Send a single (possibly squashed) slack alert to its webhook.

    Args:
        slackAlertingInstance: Object exposing ``async_http_handler.post`` and
            ``alerting_args.log_to_console``.
        item: Queued alert dict; "url" and "headers" are read, and "payload"
            is the JSON body (defaults to ``{}``).
        count: Number of alerts squashed into this item. When greater than 1
            the payload text is prefixed with ``[Num Alerts: <count>]``.

    Failures (non-200 responses or raised exceptions) are only logged at debug
    level; this coroutine does not re-raise them.
    """
    import json

    payload = item.get("payload", {})
    if isinstance(payload, dict):
        # Shallow copy: the count prefix must not leak back into the queued
        # item, otherwise sending the same item again would stack prefixes.
        payload = dict(payload)
    try:
        if count > 1:
            payload["text"] = f"[Num Alerts: {count}]\n\n{payload['text']}"

        response = await slackAlertingInstance.async_http_handler.post(
            url=item["url"],
            headers=item["headers"],
            data=json.dumps(payload),
        )
        if response.status_code != 200:
            verbose_proxy_logger.debug(
                f"Error sending slack alert to url={item['url']}. Error={response.text}"
            )
    except Exception as e:
        # Best effort: alert delivery problems must not propagate to the caller.
        verbose_proxy_logger.debug(f"Error sending slack alert: {str(e)}")
    finally:
        # Runs on success and failure alike so the payload can still be
        # surfaced on the console when that option is enabled.
        _print_alerting_payload_warning(
            payload, slackAlertingInstance=slackAlertingInstance
        )
|
||||
@@ -0,0 +1,115 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal
|
||||
|
||||
from litellm.proxy._types import CallInfo
|
||||
|
||||
|
||||
class BaseBudgetAlertType(ABC):
    """Interface every budget alert type has to implement."""

    @abstractmethod
    def get_event_message(self) -> str:
        """Return the event message for this alert type."""

    @abstractmethod
    def get_id(self, user_info: "CallInfo") -> str:
        """Return the ID used for caching/tracking this alert."""
|
||||
|
||||
|
||||
class ProxyBudgetAlert(BaseBudgetAlertType):
    """Proxy budget alert; always tracked under the constant id "default_id"."""

    def get_event_message(self) -> str:
        return "Proxy Budget: "

    def get_id(self, user_info: "CallInfo") -> str:
        # user_info is not consulted for this alert type.
        return "default_id"
|
||||
|
||||
|
||||
class SoftBudgetAlert(BaseBudgetAlertType):
    """Soft budget alert, tracked by ``user_info.token``."""

    def get_event_message(self) -> str:
        return "Soft Budget Crossed: "

    def get_id(self, user_info: "CallInfo") -> str:
        token = user_info.token
        # Falsy tokens (None, "") fall back to the shared default id.
        return token if token else "default_id"
|
||||
|
||||
|
||||
class UserBudgetAlert(BaseBudgetAlertType):
    """User budget alert, tracked by ``user_info.user_id``."""

    def get_event_message(self) -> str:
        return "User Budget: "

    def get_id(self, user_info: "CallInfo") -> str:
        user_id = user_info.user_id
        # Falsy ids (None, "") fall back to the shared default id.
        return user_id if user_id else "default_id"
|
||||
|
||||
|
||||
class TeamBudgetAlert(BaseBudgetAlertType):
    """Team budget alert, tracked by ``user_info.team_id``."""

    def get_event_message(self) -> str:
        return "Team Budget: "

    def get_id(self, user_info: "CallInfo") -> str:
        team_id = user_info.team_id
        # Falsy ids (None, "") fall back to the shared default id.
        return team_id if team_id else "default_id"
|
||||
|
||||
|
||||
class OrganizationBudgetAlert(BaseBudgetAlertType):
    """Organization budget alert, tracked by ``user_info.organization_id``."""

    def get_event_message(self) -> str:
        return "Organization Budget: "

    def get_id(self, user_info: "CallInfo") -> str:
        organization_id = user_info.organization_id
        # Falsy ids (None, "") fall back to the shared default id.
        return organization_id if organization_id else "default_id"
|
||||
|
||||
|
||||
class TokenBudgetAlert(BaseBudgetAlertType):
    """Key budget alert, tracked by ``user_info.token``."""

    def get_event_message(self) -> str:
        return "Key Budget: "

    def get_id(self, user_info: "CallInfo") -> str:
        token = user_info.token
        # Falsy tokens (None, "") fall back to the shared default id.
        return token if token else "default_id"
|
||||
|
||||
|
||||
class ProjectedLimitExceededAlert(BaseBudgetAlertType):
    """Projected-limit-exceeded alert, tracked by ``user_info.token``."""

    def get_event_message(self) -> str:
        # Unlike the sibling alert types this message has no trailing ": ".
        return "Key Budget: Projected Limit Exceeded"

    def get_id(self, user_info: "CallInfo") -> str:
        token = user_info.token
        # Falsy tokens (None, "") fall back to the shared default id.
        return token if token else "default_id"
|
||||
|
||||
|
||||
class ProjectBudgetAlert(BaseBudgetAlertType):
    """Project budget alert, tracked by ``user_info.token``."""

    def get_event_message(self) -> str:
        return "Project Budget: "

    def get_id(self, user_info: "CallInfo") -> str:
        # NOTE(review): keyed on the token, same as the key-level alerts;
        # confirm whether a project-specific identifier was intended.
        token = user_info.token
        return token if token else "default_id"
|
||||
|
||||
|
||||
def get_budget_alert_type(
    type: Literal[
        "token_budget",
        "user_budget",
        "soft_budget",
        "max_budget_alert",
        "team_budget",
        "organization_budget",
        "proxy_budget",
        "projected_limit_exceeded",
        "project_budget",
    ],
) -> BaseBudgetAlertType:
    """
    Factory function returning a new handler instance for a budget alert type.

    Args:
        type: Alert type string. "max_budget_alert" and "token_budget" both
            map to ``TokenBudgetAlert``.

    Returns:
        A fresh instance of the matching ``BaseBudgetAlertType`` subclass.
        Unrecognised values fall back to ``ProxyBudgetAlert``.
    """
    # Map to classes rather than instances so that only the requested handler
    # is constructed, instead of all of them on every call.
    alert_type_classes = {
        "proxy_budget": ProxyBudgetAlert,
        "soft_budget": SoftBudgetAlert,
        "user_budget": UserBudgetAlert,
        "max_budget_alert": TokenBudgetAlert,
        "team_budget": TeamBudgetAlert,
        "organization_budget": OrganizationBudgetAlert,
        "token_budget": TokenBudgetAlert,
        "projected_limit_exceeded": ProjectedLimitExceededAlert,
        "project_budget": ProjectBudgetAlert,
    }

    alert_cls = alert_type_classes.get(type, ProxyBudgetAlert)
    return alert_cls()
|
||||
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Class to check for LLM API hanging requests
|
||||
|
||||
|
||||
Notes:
|
||||
- Do not create tasks that sleep, that can saturate the event loop
|
||||
- Do not store large objects (eg. messages in memory) that can increase RAM usage
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
|
||||
from litellm.types.integrations.slack_alerting import (
|
||||
HANGING_ALERT_BUFFER_TIME_SECONDS,
|
||||
MAX_OLDEST_HANGING_REQUESTS_TO_CHECK,
|
||||
HangingRequestData,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
|
||||
else:
|
||||
SlackAlerting = Any
|
||||
|
||||
|
||||
class AlertingHangingRequestCheck:
    """
    Keeps a TTL cache of submitted requests and sends a Slack alert for each
    cached request that has no recorded request status.
    """

    def __init__(self, slack_alerting_object: SlackAlerting):
        self.slack_alerting_object = slack_alerting_object
        # Default TTL for cache entries: alerting threshold plus a fixed buffer.
        default_ttl = int(
            self.slack_alerting_object.alerting_threshold
            + HANGING_ALERT_BUFFER_TIME_SECONDS
        )
        self.hanging_request_cache = InMemoryCache(default_ttl=default_ttl)

    async def add_request_to_hanging_request_check(
        self,
        request_data: Optional[dict] = None,
    ):
        """
        Register a request in the hanging request cache, i.e. the set of
        request ids that gets periodically checked for hanging requests.

        Does nothing when ``request_data`` is None.
        """
        if request_data is None:
            return

        request_metadata = get_litellm_metadata_from_kwargs(kwargs=request_data)
        model = request_data.get("model", "")

        # The api_base is only resolved when a dict deployment is attached.
        api_base: Optional[str] = None
        deployment = request_data.get("deployment", None)
        if isinstance(deployment, dict):
            api_base = litellm.get_api_base(
                model=model,
                optional_params=deployment.get("litellm_params", {}),
            )

        # Only small identifying fields are stored, never the request body.
        hanging_request_data = HangingRequestData(
            request_id=request_data.get("litellm_call_id", ""),
            model=model,
            api_base=api_base,
            key_alias=request_metadata.get("user_api_key_alias", ""),
            team_alias=request_metadata.get("user_api_key_team_alias", ""),
        )

        entry_ttl = int(
            self.slack_alerting_object.alerting_threshold
            + HANGING_ALERT_BUFFER_TIME_SECONDS
        )
        await self.hanging_request_cache.async_set_cache(
            key=hanging_request_data.request_id,
            value=hanging_request_data,
            ttl=entry_ttl,
        )
        return

    async def send_alerts_for_hanging_requests(self):
        """
        Inspect the oldest cached requests and alert on those without a status.

        Up to MAX_OLDEST_HANGING_REQUESTS_TO_CHECK keys are fetched from the
        cache. A request whose "request_status:<id>" entry exists in the
        proxy's internal usage cache is dropped from the hanging cache; every
        other request triggers a Slack alert.
        """
        from litellm.proxy.proxy_server import proxy_logging_obj

        # Without the internal usage cache there is no way to look up statuses.
        if proxy_logging_obj.internal_usage_cache is None:
            return

        oldest_request_ids = await self.hanging_request_cache.async_get_oldest_n_keys(
            n=MAX_OLDEST_HANGING_REQUESTS_TO_CHECK,
        )

        for request_id in oldest_request_ids:
            hanging_request_data: Optional[HangingRequestData] = (
                await self.hanging_request_cache.async_get_cache(key=request_id)
            )
            if hanging_request_data is None:
                continue

            status_key = f"request_status:{hanging_request_data.request_id}"
            request_status = (
                await proxy_logging_obj.internal_usage_cache.async_get_cache(
                    key=status_key,
                    litellm_parent_otel_span=None,
                    local_only=True,
                )
            )

            if request_status is not None:
                # A recorded status means the request succeeded or failed, so
                # it is not hanging and no longer needs to be tracked.
                self.hanging_request_cache._remove_key(key=request_id)
                continue

            # NOTE(review): no age check is visible here; any cached request
            # lacking a status is alerted on. Confirm that this is intended.
            await self.send_hanging_request_alert(
                hanging_request_data=hanging_request_data
            )

        return

    async def check_for_hanging_requests(self):
        """
        Background loop that re-checks the hanging request cache forever.

        Sleeps alerting_threshold / 2 seconds between two checks.
        """
        while True:
            verbose_proxy_logger.debug("Checking for hanging requests....")
            await self.send_alerts_for_hanging_requests()
            await asyncio.sleep(self.slack_alerting_object.alerting_threshold / 2)

    async def send_hanging_request_alert(
        self,
        hanging_request_data: HangingRequestData,
    ):
        """
        Build the message for one hanging request and send it as a
        "Medium" level ``llm_requests_hanging`` alert.
        """
        from litellm.integrations.SlackAlerting.slack_alerting import AlertType

        request_info = (
            f"Request Model: `{hanging_request_data.model}`\n"
            f"API Base: `{hanging_request_data.api_base}`\n"
            f"Key Alias: `{hanging_request_data.key_alias}`\n"
            f"Team Alias: `{hanging_request_data.team_alias}`"
        )
        alerting_message = f"`Requests are hanging - {self.slack_alerting_object.alerting_threshold}s+ request time`"

        await self.slack_alerting_object.send_alert(
            message=alerting_message + "\n" + request_info,
            level="Medium",
            alert_type=AlertType.llm_requests_hanging,
            alerting_metadata=hanging_request_data.alerting_metadata or {},
            request_model=hanging_request_data.model,
            api_base=hanging_request_data.api_base,
        )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Utils used for slack alerting
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import AlertType
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _Logging
|
||||
|
||||
Logging = _Logging
|
||||
else:
|
||||
Logging = Any
|
||||
|
||||
|
||||
def process_slack_alerting_variables(
    alert_to_webhook_url: Optional[Dict["AlertType", Union[List[str], str]]]
) -> Optional[Dict["AlertType", Union[List[str], str]]]:
    """
    Resolve secret references inside ``alert_to_webhook_url``.

    Any webhook url containing "os.environ/" (e.g.
    os.environ/SLACK_WEBHOOK_URL_1) is replaced by the value returned from
    ``get_secret``; every other url is kept as is. Values may be a single url
    or a list of urls.

    The mapping is updated in place and the same object is returned; ``None``
    is passed through unchanged.

    Raises:
        ValueError: if a referenced secret does not resolve to a string.
    """
    if alert_to_webhook_url is None:
        return None

    def _resolve(webhook_url: str) -> str:
        # Urls without a secret reference are already final.
        if "os.environ/" not in webhook_url:
            return webhook_url
        _env_value = get_secret(secret_name=webhook_url)
        if not isinstance(_env_value, str):
            raise ValueError(
                f"Invalid webhook url value for: {webhook_url}. Got type={type(_env_value)}"
            )
        return _env_value

    # Only existing keys are reassigned, so the dict size is stable while iterating.
    for alert_type, webhook_urls in alert_to_webhook_url.items():
        if isinstance(webhook_urls, list):
            alert_to_webhook_url[alert_type] = [_resolve(url) for url in webhook_urls]
        else:
            alert_to_webhook_url[alert_type] = _resolve(webhook_urls)

    return alert_to_webhook_url
|
||||
|
||||
|
||||
async def _add_langfuse_trace_id_to_alert(
    request_data: Optional[dict] = None,
) -> Optional[str]:
    """
    Returns the langfuse trace url for a request, or None if it cannot be built.

    None is returned when:
    - "langfuse" is not among the registered callbacks,
    - ``request_data`` is None or carries no "litellm_logging_obj",
    - no trace id is available after all lookup attempts,
    - the logging object has no langfuse callback object.

    The trace id is requested up to 3 times, waiting 3s between attempts.
    (Which ids the lookup considers - existing_trace_id / trace_id /
    litellm_call_id - is decided by ``_get_trace_id``, not here.)
    """
    # Only run if langfuse is added as a callback
    if "langfuse" not in litellm.logging_callback_manager._get_all_callbacks():
        return None

    if request_data is None:
        return None
    litellm_logging_obj: Optional[Logging] = request_data.get(
        "litellm_logging_obj", None
    )
    if litellm_logging_obj is None:
        return None

    max_attempts = 3
    trace_id: Optional[str] = None
    for attempt in range(max_attempts):
        trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse")
        if trace_id is not None:
            break
        # Wait before retrying, but not after the final attempt: there is
        # nothing left to wait for at that point.
        if attempt < max_attempts - 1:
            await asyncio.sleep(3)

    if trace_id is None:
        # Never hand out a ".../trace/None" link.
        return None

    langfuse_object = litellm_logging_obj._get_callback_object(service_name="langfuse")
    if langfuse_object is None:
        return None

    base_url = langfuse_object.Langfuse.base_url
    return f"{base_url}/trace/{trace_id}"
|
||||
@@ -0,0 +1 @@
|
||||
from . import *
|
||||
@@ -0,0 +1,440 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SpanAttributes:
    """String keys for attributes recorded on a span."""

    # Output of the span and the MIME type of that output. If the MIME type is
    # unspecified, the value is plain text by default. If the type is JSON,
    # the value is a string representing a JSON object.
    OUTPUT_VALUE = "output.value"
    OUTPUT_MIME_TYPE = "output.mime_type"

    # Input of the span and the MIME type of that input. If the MIME type is
    # unspecified, the value is plain text by default. If the type is JSON,
    # the value is a string representing a JSON object.
    INPUT_VALUE = "input.value"
    INPUT_MIME_TYPE = "input.mime_type"

    # A list of objects containing embedding data, including the vector and
    # the represented piece of text.
    EMBEDDING_EMBEDDINGS = "embedding.embeddings"
    # The name of the embedding model.
    EMBEDDING_MODEL_NAME = "embedding.model_name"

    # For models and APIs that support function calling. Records attributes
    # such as the function name and the arguments to the called function.
    LLM_FUNCTION_CALL = "llm.function_call"
    # Invocation parameters passed to the LLM or API, such as the model name,
    # temperature, etc.
    LLM_INVOCATION_PARAMETERS = "llm.invocation_parameters"
    # Messages provided to a chat API.
    LLM_INPUT_MESSAGES = "llm.input_messages"
    # Messages received from a chat API.
    LLM_OUTPUT_MESSAGES = "llm.output_messages"
    # The name of the model being used.
    LLM_MODEL_NAME = "llm.model_name"
    # The provider of the model, such as OpenAI, Azure, Google, etc.
    LLM_PROVIDER = "llm.provider"
    # The AI product as identified by the client or server.
    LLM_SYSTEM = "llm.system"
    # Prompts provided to a completions API.
    LLM_PROMPTS = "llm.prompts"
    # The prompt template as a Python f-string.
    LLM_PROMPT_TEMPLATE = "llm.prompt_template.template"
    # A list of input variables to the prompt template.
    LLM_PROMPT_TEMPLATE_VARIABLES = "llm.prompt_template.variables"
    # The version of the prompt template being used.
    LLM_PROMPT_TEMPLATE_VERSION = "llm.prompt_template.version"

    # Number of tokens in the prompt.
    LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt"
    # Number of tokens in the prompt that were written to cache.
    LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = "llm.token_count.prompt_details.cache_write"
    # Number of tokens in the prompt that were read from cache.
    LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = "llm.token_count.prompt_details.cache_read"
    # The number of audio input tokens presented in the prompt.
    LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = "llm.token_count.prompt_details.audio"
    # Number of tokens in the completion.
    LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"
    # Number of tokens used for reasoning steps in the completion.
    LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = "llm.token_count.completion_details.reasoning"
    # The number of audio tokens generated by the model.
    LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = "llm.token_count.completion_details.audio"
    # Total number of tokens, including both prompt and completion.
    LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total"

    # List of tools that are advertised to the LLM to be able to call.
    LLM_TOOLS = "llm.tools"

    # Name of the tool being used.
    TOOL_NAME = "tool.name"
    # Description of the tool's purpose, typically used to select the tool.
    TOOL_DESCRIPTION = "tool.description"
    # Parameters of the tool represented as a dictionary JSON string, e.g.
    # see https://platform.openai.com/docs/guides/gpt/function-calling
    TOOL_PARAMETERS = "tool.parameters"

    RETRIEVAL_DOCUMENTS = "retrieval.documents"

    # Metadata attributes are used to store user-defined key-value pairs.
    # For example, LangChain uses metadata to store user-defined attributes
    # for a chain.
    METADATA = "metadata"

    # Custom categorical tags for the span.
    TAG_TAGS = "tag.tags"

    OPENINFERENCE_SPAN_KIND = "openinference.span.kind"

    # The id of the session.
    SESSION_ID = "session.id"
    # The id of the user.
    USER_ID = "user.id"

    # The vendor or origin of the prompt, e.g. a prompt library, a
    # specialized service, etc.
    PROMPT_VENDOR = "prompt.vendor"
    # A vendor-specific id used to locate the prompt.
    PROMPT_ID = "prompt.id"
    # A vendor-specific url used to locate the prompt.
    PROMPT_URL = "prompt.url"
|
||||
|
||||
|
||||
class MessageAttributes:
    """
    Attributes for a message sent to or from an LLM
    """

    # The role of the message, such as "user", "agent", "function".
    MESSAGE_ROLE = "message.role"
    # The content of the message to or from the llm, must be a string.
    MESSAGE_CONTENT = "message.content"
    # The message contents to the llm, it is an array of
    # `message_content` prefixed attributes.
    MESSAGE_CONTENTS = "message.contents"
    # The name of the message, often used to identify the function
    # that was used to generate the message.
    MESSAGE_NAME = "message.name"
    # The tool calls generated by the model, such as function calls.
    MESSAGE_TOOL_CALLS = "message.tool_calls"
    # The function name that is a part of the message list.
    # This is populated for role 'function' or 'agent' as a mechanism to
    # identify the function that was called during the execution of a tool.
    MESSAGE_FUNCTION_CALL_NAME = "message.function_call_name"
    # The JSON string representing the arguments passed to the function
    # during a function call.
    MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = "message.function_call_arguments_json"
    # The id of the tool call.
    MESSAGE_TOOL_CALL_ID = "message.tool_call_id"
    # The reasoning summary from the model's chain-of-thought process.
    MESSAGE_REASONING_SUMMARY = "message.reasoning_summary"
|
||||
|
||||
|
||||
class MessageContentAttributes:
    """
    Attributes for the contents of user messages sent to an LLM.
    """

    # The type of the content, such as "text" or "image".
    MESSAGE_CONTENT_TYPE = "message_content.type"
    # The text content of the message, if the type is "text".
    MESSAGE_CONTENT_TEXT = "message_content.text"
    # The image content of the message, if the type is "image".
    # An image can be made available to the model by passing a link to
    # the image or by passing the base64 encoded image directly in the
    # request.
    MESSAGE_CONTENT_IMAGE = "message_content.image"
|
||||
|
||||
|
||||
class ImageAttributes:
    """
    Attributes for images
    """

    # An http or base64 image url.
    IMAGE_URL = "image.url"
|
||||
|
||||
|
||||
class AudioAttributes:
    """
    Attributes for audio
    """

    # The url to an audio file.
    AUDIO_URL = "audio.url"
    # The mime type of the audio file.
    AUDIO_MIME_TYPE = "audio.mime_type"
    # The transcript of the audio file.
    AUDIO_TRANSCRIPT = "audio.transcript"
|
||||
|
||||
|
||||
class DocumentAttributes:
    """
    Attributes for a document.
    """

    # The id of the document.
    DOCUMENT_ID = "document.id"
    # The score of the document.
    DOCUMENT_SCORE = "document.score"
    # The content of the document.
    DOCUMENT_CONTENT = "document.content"
    # The metadata of the document represented as a dictionary
    # JSON string, e.g. `"{ 'title': 'foo' }"`
    DOCUMENT_METADATA = "document.metadata"
|
||||
|
||||
|
||||
class RerankerAttributes:
    """
    Attributes for a reranker
    """

    # List of documents as input to the reranker.
    RERANKER_INPUT_DOCUMENTS = "reranker.input_documents"
    # List of documents as output from the reranker.
    RERANKER_OUTPUT_DOCUMENTS = "reranker.output_documents"
    # Query string for the reranker.
    RERANKER_QUERY = "reranker.query"
    # Model name of the reranker.
    RERANKER_MODEL_NAME = "reranker.model_name"
    # Top K parameter of the reranker.
    RERANKER_TOP_K = "reranker.top_k"
|
||||
|
||||
|
||||
class EmbeddingAttributes:
    """
    Attributes for an embedding
    """

    # The text represented by the embedding.
    EMBEDDING_TEXT = "embedding.text"
    # The embedding vector.
    EMBEDDING_VECTOR = "embedding.vector"
|
||||
|
||||
|
||||
class ToolCallAttributes:
    """
    Attributes for a tool call
    """

    # The id of the tool call.
    TOOL_CALL_ID = "tool_call.id"
    # The name of the function that is being called during a tool call.
    TOOL_CALL_FUNCTION_NAME = "tool_call.function.name"
    # The JSON string representing the arguments passed to the function
    # during a tool call.
    TOOL_CALL_FUNCTION_ARGUMENTS_JSON = "tool_call.function.arguments"
|
||||
|
||||
|
||||
class ToolAttributes:
    """
    Attributes for a tool
    """

    # The json schema of a tool input. It is RECOMMENDED that this be in the
    # OpenAI tool calling format: https://platform.openai.com/docs/assistants/tools
    TOOL_JSON_SCHEMA = "tool.json_schema"
|
||||
|
||||
|
||||
class OpenInferenceSpanKindValues(Enum):
    """Closed set of span kind values; each member's value equals its name."""

    TOOL = "TOOL"
    CHAIN = "CHAIN"
    LLM = "LLM"
    RETRIEVER = "RETRIEVER"
    EMBEDDING = "EMBEDDING"
    AGENT = "AGENT"
    RERANKER = "RERANKER"
    UNKNOWN = "UNKNOWN"
    GUARDRAIL = "GUARDRAIL"
    EVALUATOR = "EVALUATOR"
|
||||
|
||||
|
||||
class OpenInferenceMimeTypeValues(Enum):
    """MIME type values: plain text and JSON."""

    TEXT = "text/plain"
    JSON = "application/json"
|
||||
|
||||
|
||||
class OpenInferenceLLMSystemValues(Enum):
    """Closed set of LLM system identifiers (lower-case strings)."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    COHERE = "cohere"
    MISTRALAI = "mistralai"
    VERTEXAI = "vertexai"
|
||||
|
||||
|
||||
class OpenInferenceLLMProviderValues(Enum):
    """Closed set of LLM provider identifiers (lower-case strings)."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    COHERE = "cohere"
    MISTRALAI = "mistralai"
    GOOGLE = "google"
    AZURE = "azure"
    AWS = "aws"
|
||||
|
||||
|
||||
class ErrorAttributes:
    """
    Attributes for error information in spans.

    These attributes follow OpenTelemetry semantic conventions for exceptions
    and are used to record error information from StandardLoggingPayloadErrorInformation.
    """

    # The type/class of the error (e.g., 'ValueError', 'OpenAIError', 'RateLimitError').
    # Corresponds to StandardLoggingPayloadErrorInformation.error_class
    ERROR_TYPE = "error.type"

    # The error message describing what went wrong.
    # Corresponds to StandardLoggingPayloadErrorInformation.error_message
    ERROR_MESSAGE = "error.message"

    # The error code (e.g., HTTP status code like '500', '429', or provider-specific codes).
    # Corresponds to StandardLoggingPayloadErrorInformation.error_code
    ERROR_CODE = "error.code"

    # The full stack trace of the error.
    # Corresponds to StandardLoggingPayloadErrorInformation.traceback
    ERROR_STACK_TRACE = "error.stack_trace"

    # The LLM provider where the error occurred (e.g., 'openai', 'anthropic', 'azure').
    # Corresponds to StandardLoggingPayloadErrorInformation.llm_provider
    ERROR_LLM_PROVIDER = "error.llm_provider"
|
||||
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Base class for Additional Logging Utils for CustomLoggers
|
||||
|
||||
- Health Check for the logging util
|
||||
- Get Request / Response Payload for the logging util
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
|
||||
|
||||
class AdditionalLoggingUtils(ABC):
    """
    Abstract interface for the extra utilities a logging integration exposes:
    a health check and a lookup of the logged request/response payload.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        """
        Check if the service is healthy

        Subclasses must override this; the base implementation does nothing.
        """

    @abstractmethod
    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """
        Get the request and response payload for a given `request_id`

        Subclasses must override this; the base implementation returns None.
        """
        return None
|
||||
@@ -0,0 +1,3 @@
|
||||
from .agentops import AgentOps
|
||||
|
||||
# Names exported by a star-import of this package.
__all__ = ["AgentOps"]
|
||||
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
AgentOps integration for LiteLLM - Provides OpenTelemetry tracing for LLM calls
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
|
||||
|
||||
@dataclass
class AgentOpsConfig:
    """Settings consumed by the AgentOps integration."""

    # URL the traces are exported to.
    endpoint: str = "https://otlp.agentops.cloud/v1/traces"
    # AgentOps API key; exchanged for a JWT at `auth_endpoint`.
    api_key: Optional[str] = None
    service_name: Optional[str] = None
    deployment_environment: Optional[str] = None
    # URL used to exchange the API key for a JWT.
    auth_endpoint: str = "https://api.agentops.ai/v3/auth/token"

    @classmethod
    def from_env(cls):
        """
        Build a config from environment variables.

        Reads AGENTOPS_API_KEY, AGENTOPS_SERVICE_NAME (default "agentops") and
        AGENTOPS_ENVIRONMENT (default "production"); both endpoints are fixed.
        """
        env = os.environ
        settings = {
            "endpoint": "https://otlp.agentops.cloud/v1/traces",
            "api_key": env.get("AGENTOPS_API_KEY"),
            "service_name": env.get("AGENTOPS_SERVICE_NAME", "agentops"),
            "deployment_environment": env.get("AGENTOPS_ENVIRONMENT", "production"),
            "auth_endpoint": "https://api.agentops.ai/v3/auth/token",
        }
        return cls(**settings)
|
||||
|
||||
|
||||
class AgentOps(OpenTelemetry):
    """
    AgentOps integration - built on top of OpenTelemetry

    Example usage:
        ```python
        import litellm

        litellm.success_callback = ["agentops"]

        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
        )
        ```
    """

    def __init__(
        self,
        config: Optional[AgentOpsConfig] = None,
    ):
        """
        Set up the OpenTelemetry exporter for AgentOps.

        Args:
            config: AgentOps settings. When omitted, they are read from the
                environment via `AgentOpsConfig.from_env()`.
        """
        if config is None:
            config = AgentOpsConfig.from_env()

        # Prefetch JWT token for authentication
        jwt_token = None
        project_id = None
        if config.api_key:
            try:
                response = self._fetch_auth_token(config.api_key, config.auth_endpoint)
                jwt_token = response.get("token")
                project_id = response.get("project_id")
            except Exception as e:
                # The prefetch stays best-effort (setup continues without an
                # Authorization header), but the failure is now logged instead
                # of being silently discarded.
                from litellm._logging import verbose_logger

                verbose_logger.warning(
                    f"AgentOps: failed to fetch auth token from {config.auth_endpoint}: {e}"
                )

        headers = f"Authorization=Bearer {jwt_token}" if jwt_token else None

        otel_config = OpenTelemetryConfig(
            exporter="otlp_http", endpoint=config.endpoint, headers=headers
        )

        # Initialize OpenTelemetry with our config
        super().__init__(config=otel_config, callback_name="agentops")

        # Set AgentOps-specific resource attributes
        resource_attrs = {
            "service.name": config.service_name or "litellm",
            "deployment.environment": config.deployment_environment or "production",
            "telemetry.sdk.name": "agentops",
        }

        if project_id:
            resource_attrs["project.id"] = project_id

        self.resource_attributes = resource_attrs

    def _fetch_auth_token(self, api_key: str, auth_endpoint: str) -> Dict[str, Any]:
        """
        Fetch JWT authentication token from AgentOps API

        Args:
            api_key: AgentOps API key
            auth_endpoint: Authentication endpoint

        Returns:
            Dict containing JWT token and project ID

        Raises:
            Exception: if the endpoint answers with a status other than 200.
        """
        headers = {
            "Content-Type": "application/json",
            "Connection": "keep-alive",
        }

        client = _get_httpx_client()
        try:
            response = client.post(
                url=auth_endpoint,
                headers=headers,
                json={"api_key": api_key},
                timeout=10,
            )

            if response.status_code != 200:
                raise Exception(f"Failed to fetch auth token: {response.text}")

            return response.json()
        finally:
            # Always release the client, including on the error path.
            client.close()
|
||||
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
This hook is used to inject cache control directives into the messages of a chat completion.
|
||||
|
||||
Users can define
|
||||
- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
from litellm.integrations.prompt_management_base import PromptManagementClient
|
||||
from litellm.types.integrations.anthropic_cache_control_hook import (
|
||||
CacheControlInjectionPoint,
|
||||
CacheControlMessageInjectionPoint,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AnthropicCacheControlHook(CustomPromptManagement):
    """
    Hook that injects cache control directives into chat completion messages.

    Callers pass `cache_control_injection_points` in the completion params;
    each point with `location == "message"` selects messages by `index` or by
    `role` and attaches its `control` value (default `{"type": "ephemeral"}`).
    """

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Apply cache control directives based on specified injection points.

        `cache_control_injection_points` is popped from `non_default_params`,
        so the returned params no longer contain it.

        Returns:
        - model: str - the model to use
        - messages: List[AllMessageValues] - messages with applied cache controls
        - non_default_params: dict - params with any global cache controls
        """
        # Extract cache control injection points
        injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
            "cache_control_injection_points", []
        )
        if not injection_points:
            return model, messages, non_default_params

        # Create a deep copy of messages to avoid modifying the original list
        processed_messages = copy.deepcopy(messages)

        # Process message-level cache controls
        for point in injection_points:
            if point.get("location") == "message":
                point = cast(CacheControlMessageInjectionPoint, point)
                processed_messages = self._process_message_injection(
                    point=point, messages=processed_messages
                )

        return model, processed_messages, non_default_params

    @staticmethod
    def _process_message_injection(
        point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """
        Process message-level cache control injection.

        An `index` on the point takes precedence over `role`. Messages are
        modified in place and the same list is returned.
        """
        control: ChatCompletionCachedContent = point.get(
            "control", None
        ) or ChatCompletionCachedContent(type="ephemeral")

        _targetted_index: Optional[Union[int, str]] = point.get("index", None)
        targetted_index: Optional[int] = None
        if isinstance(_targetted_index, str):
            try:
                targetted_index = int(_targetted_index)
            except ValueError:
                # Previously swallowed silently; the point then falls through
                # to role-based targeting (or does nothing if no role is set).
                verbose_logger.warning(
                    f"AnthropicCacheControlHook: Provided index {_targetted_index!r} is not a valid integer. "
                    "Ignoring the index for this injection point."
                )
        else:
            targetted_index = _targetted_index

        targetted_role = point.get("role", None)

        # Case 1: Target by specific index
        if targetted_index is not None:
            original_index = targetted_index
            # Handle negative indices (convert to positive)
            if targetted_index < 0:
                targetted_index += len(messages)

            if 0 <= targetted_index < len(messages):
                messages[
                    targetted_index
                ] = AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                    messages[targetted_index], control
                )
            else:
                verbose_logger.warning(
                    f"AnthropicCacheControlHook: Provided index {original_index} is out of bounds for message list of length {len(messages)}. "
                    f"Targeted index was {targetted_index}. Skipping cache control injection for this point."
                )
        # Case 2: Target by role
        elif targetted_role is not None:
            for msg in messages:
                if msg.get("role") == targetted_role:
                    # The helper mutates `msg` in place, so its return value
                    # does not need to be rebound.
                    AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                        message=msg, control=control
                    )
        return messages

    @staticmethod
    def _safe_insert_cache_control_in_message(
        message: AllMessageValues, control: ChatCompletionCachedContent
    ) -> AllMessageValues:
        """
        Safe way to insert cache control in a message

        OpenAI Message content can be either:
            - string
            - list of objects

        This method handles inserting cache control in both cases.
        Per Anthropic's API specification, when using multiple content blocks,
        only the last content block can have cache_control.

        The message is modified in place and returned; any other content type
        leaves it untouched.
        """
        message_content = message.get("content", None)

        # 1. if string, insert cache control in the message
        if isinstance(message_content, str):
            message["cache_control"] = control  # type: ignore
        # 2. list of objects - only apply to last item per Anthropic spec
        elif isinstance(message_content, list):
            if len(message_content) > 0 and isinstance(message_content[-1], dict):
                message_content[-1]["cache_control"] = control  # type: ignore
        return message

    @property
    def integration_name(self) -> str:
        """Return the integration name for this hook."""
        return "anthropic_cache_control_hook"

    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """Always return False since this is not a true prompt management system."""
        return False

    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """Not used - this hook only modifies messages, doesn't fetch prompts."""
        return PromptManagementClient(
            prompt_id=prompt_id,
            prompt_template=[],
            prompt_template_model=None,
            prompt_template_optional_params=None,
            completed_messages=None,
        )

    async def async_compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """Not used - this hook only modifies messages, doesn't fetch prompts."""
        return self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_spec=prompt_spec,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )

    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        litellm_logging_obj: LiteLLMLoggingObj,
        prompt_spec: Optional[PromptSpec] = None,
        tools: Optional[List[Dict]] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Async version - delegates to sync since no async operations needed."""
        return self.get_chat_completion_prompt(
            model=model,
            messages=messages,
            non_default_params=non_default_params,
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
            prompt_spec=prompt_spec,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
            ignore_prompt_manager_model=ignore_prompt_manager_model,
            ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
        )

    @staticmethod
    def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
        """Return True when `cache_control_injection_points` is present and non-empty."""
        if non_default_params.get("cache_control_injection_points", None):
            return True
        return False

    @staticmethod
    def get_custom_logger_for_anthropic_cache_control_hook(
        non_default_params: Dict,
    ) -> Optional[CustomLogger]:
        """
        Return the logger initialised for "anthropic_cache_control_hook" when the
        params request cache control injection, otherwise None.
        """
        # Imported inside the function rather than at module level.
        from litellm.litellm_core_utils.litellm_logging import (
            _init_custom_logger_compatible_class,
        )

        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
            non_default_params
        ):
            return _init_custom_logger_compatible_class(
                logging_integration="anthropic_cache_control_hook",
                internal_usage_cache=None,
                llm_router=None,
            )
        return None
|
||||
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
Send logs to Argilla for annotation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.argilla import (
|
||||
SUPPORTED_PAYLOAD_FIELDS,
|
||||
ArgillaCredentialsObject,
|
||||
ArgillaItem,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
def is_serializable(value):
    """
    Return False when `value` is a coroutine, a plain function, a generator
    or a `BaseModel` instance; return True for anything else.
    """
    if isinstance(value, BaseModel):
        return False
    runtime_only_types = (
        types.CoroutineType,
        types.FunctionType,
        types.GeneratorType,
    )
    return not isinstance(value, runtime_only_types)
|
||||
|
||||
|
||||
class ArgillaLogger(CustomBatchLogger):
    """
    Batch logger that queues LiteLLM request/response records and sends them
    to an Argilla dataset for annotation.

    Records are built from the `standard_logging_object` using the mapping in
    `litellm.argilla_transformation_object` (Argilla field name -> payload key).
    """

    def __init__(
        self,
        argilla_api_key: Optional[str] = None,
        argilla_dataset_name: Optional[str] = None,
        argilla_base_url: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            argilla_api_key: API key; falls back to the ARGILLA_API_KEY env var.
            argilla_dataset_name: dataset name; falls back to ARGILLA_DATASET_NAME.
            argilla_base_url: server URL; falls back to ARGILLA_BASE_URL.
            **kwargs: forwarded to `CustomBatchLogger.__init__`.

        Raises:
            Exception: if `litellm.argilla_transformation_object` is unset or
                invalid, or no API key is available.
        """
        if litellm.argilla_transformation_object is None:
            raise Exception(
                "'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
            )
        self.validate_argilla_transformation_object(
            litellm.argilla_transformation_object
        )
        self.argilla_transformation_object = litellm.argilla_transformation_object
        self.default_credentials = self.get_credentials_from_env(
            argilla_api_key=argilla_api_key,
            argilla_dataset_name=argilla_dataset_name,
            argilla_base_url=argilla_base_url,
        )
        self.sampling_rate: float = self._parse_sampling_rate(
            os.getenv("ARGILLA_SAMPLING_RATE")
        )

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        _batch_size = (
            os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
        )
        # NOTE(review): batch_size is assigned before super().__init__(); confirm
        # that CustomBatchLogger.__init__ does not overwrite it.
        if _batch_size:
            self.batch_size = int(_batch_size)
        asyncio.create_task(self.periodic_flush())
        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)

    @staticmethod
    def _parse_sampling_rate(raw_value: Optional[str]) -> float:
        """
        Parse the ARGILLA_SAMPLING_RATE value into a float.

        Returns 1.0 (log everything) when the value is unset or not a number.
        Fractional values such as "0.5" are accepted; the previous
        `str.isdigit()` gate rejected them and silently fell back to 1.0.
        """
        if raw_value is None:
            return 1.0
        try:
            return float(raw_value.strip())
        except ValueError:
            verbose_logger.warning(
                f"Argilla: invalid ARGILLA_SAMPLING_RATE={raw_value!r}, defaulting to 1.0"
            )
            return 1.0

    def _get_bulk_records_url(self) -> str:
        """Build the bulk-records URL for the configured dataset."""
        # Strip a trailing slash so the default base URL ("http://localhost:6900/")
        # does not produce a "//api/..." path.
        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"].rstrip("/")
        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
        return f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"

    def validate_argilla_transformation_object(
        self, argilla_transformation_object: Dict[str, Any]
    ):
        """
        Check that the transformation object is a dict whose values are all
        members of SUPPORTED_PAYLOAD_FIELDS.

        Raises:
            Exception: on a non-dict object or an unsupported value.
        """
        if not isinstance(argilla_transformation_object, dict):
            raise Exception(
                "'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
            )

        for v in argilla_transformation_object.values():
            if v not in SUPPORTED_PAYLOAD_FIELDS:
                raise Exception(
                    f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
                )

    def get_credentials_from_env(
        self,
        argilla_api_key: Optional[str],
        argilla_dataset_name: Optional[str],
        argilla_base_url: Optional[str],
    ) -> ArgillaCredentialsObject:
        """
        Resolve credentials from the arguments, then env vars, then defaults.

        The dataset name is looked up on the server; when a match is returned,
        its `id` replaces the name in the returned credentials.

        Raises:
            Exception: if no API key is supplied or set in ARGILLA_API_KEY.
        """
        _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
        if _credentials_api_key is None:
            raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")

        # The literal defaults make these two values always non-None, so the
        # former `is None` checks were unreachable and have been dropped.
        _credentials_base_url = (
            argilla_base_url
            or os.getenv("ARGILLA_BASE_URL")
            or "http://localhost:6900/"
        )
        _credentials_dataset_name = (
            argilla_dataset_name
            or os.getenv("ARGILLA_DATASET_NAME")
            or "litellm-completion"
        )

        # Avoid a double slash when the base URL ends with "/".
        _lookup_base_url = _credentials_base_url.rstrip("/")
        dataset_response = litellm.module_level_client.get(
            url=f"{_lookup_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
            headers={"X-Argilla-Api-Key": _credentials_api_key},
        )
        json_response = dataset_response.json()
        if (
            "items" in json_response
            and isinstance(json_response["items"], list)
            and len(json_response["items"]) > 0
        ):
            _credentials_dataset_name = json_response["items"][0]["id"]

        return ArgillaCredentialsObject(
            ARGILLA_API_KEY=_credentials_api_key,
            ARGILLA_BASE_URL=_credentials_base_url,
            ARGILLA_DATASET_NAME=_credentials_dataset_name,
        )

    def get_chat_messages(
        self, payload: StandardLoggingPayload
    ) -> List[Dict[str, Any]]:
        """
        Return the payload's messages as a list of dicts.

        A single dict is wrapped in a list.

        Raises:
            Exception: if messages are missing or in any other format
                (including an empty list).
        """
        payload_messages = payload.get("messages", None)

        if payload_messages is None:
            raise Exception("No chat messages found in payload.")

        if (
            isinstance(payload_messages, list)
            and len(payload_messages) > 0
            and isinstance(payload_messages[0], dict)
        ):
            return payload_messages
        elif isinstance(payload_messages, dict):
            return [payload_messages]
        else:
            raise Exception(f"Invalid chat messages format: {payload_messages}")

    def get_str_response(self, payload: StandardLoggingPayload) -> str:
        """
        Return the response text from the payload.

        A string response is returned as-is; for a dict, the content of the
        first choice's message is returned ("" when absent).

        Raises:
            Exception: if the response is None or neither str nor dict.
        """
        response = payload["response"]

        if response is None:
            raise Exception("No response found in payload.")

        if isinstance(response, str):
            return response
        elif isinstance(response, dict):
            # Guard against an empty `choices` list / missing message, which
            # previously raised IndexError / AttributeError.
            choices = response.get("choices") or [{}]
            message = choices[0].get("message") or {}
            return message.get("content", "")
        else:
            raise Exception(f"Invalid response format: {response}")

    def _prepare_log_data(
        self, kwargs, response_obj, start_time, end_time
    ) -> Optional[ArgillaItem]:
        """
        Build the Argilla record for one call.

        Raises:
            Exception: if `standard_logging_object` is missing from kwargs, or
                its messages/response cannot be extracted.
        """
        payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )

        if payload is None:
            raise Exception("Error logging request payload. Payload=none.")

        argilla_message = self.get_chat_messages(payload)
        argilla_response = self.get_str_response(payload)
        argilla_item: ArgillaItem = {"fields": {}}
        for k, v in self.argilla_transformation_object.items():
            if v == "messages":
                argilla_item["fields"][k] = argilla_message
            elif v == "response":
                argilla_item["fields"][k] = argilla_response
            else:
                argilla_item["fields"][k] = payload.get(v, None)

        return argilla_item

    def _send_batch(self):
        """
        Synchronously send the queued records and clear the queue.

        Does not raise; errors are logged, and on an exception the queue is
        left untouched.
        """
        if not self.log_queue:
            return

        url = self._get_bulk_records_url()
        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
        headers = {"X-Argilla-Api-Key": argilla_api_key}

        try:
            # Same {"items": [...]} body shape as async_send_batch; the bare
            # list sent previously did not match it.
            response = litellm.module_level_client.post(
                url=url,
                json={"items": self.log_queue},
                headers=headers,
            )

            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )

            self.log_queue.clear()
        except Exception:
            verbose_logger.exception("Argilla Layer Error - Error sending batch.")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a successful call (subject to sampling) and flush when the batch is full."""
        try:
            # Use the Argilla sampling rate; this path previously read
            # LANGSMITH_SAMPLING_RATE.
            sampling_rate = self.sampling_rate
            random_sample = random.random()
            if random_sample > sampling_rate:
                verbose_logger.info(
                    "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                        sampling_rate, random_sample
                    )
                )
                return  # Skip logging
            verbose_logger.debug(
                "Argilla Sync Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            if data is None:
                return

            self.log_queue.append(data)
            verbose_logger.debug(
                f"Argilla, event added to queue. Will flush in {self.flush_interval} seconds..."
            )

            if len(self.log_queue) >= self.batch_size:
                self._send_batch()

        except Exception:
            verbose_logger.exception("Argilla Layer Error - log_success_event error")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Queue a successful call (subject to sampling) and flush when the batch is full.

        Each `CustomLogger` in `litellm.callbacks` may modify or drop the record
        through `async_dataset_hook` before it is queued.
        """
        try:
            sampling_rate = self.sampling_rate
            random_sample = random.random()
            if random_sample > sampling_rate:
                verbose_logger.info(
                    "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                        sampling_rate, random_sample
                    )
                )
                return  # Skip logging
            verbose_logger.debug(
                "Argilla Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )

            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)

            ## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
            for callback in litellm.callbacks:
                if isinstance(callback, CustomLogger):
                    try:
                        if data is None:
                            break
                        data = await callback.async_dataset_hook(data, payload)
                    except NotImplementedError:
                        pass

            if data is None:
                return

            self.log_queue.append(data)
            verbose_logger.debug(
                "Argilla logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Argilla Layer Error - error logging async success event."
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a failed call (subject to sampling) and flush when the batch is full."""
        sampling_rate = self.sampling_rate
        random_sample = random.random()
        if random_sample > sampling_rate:
            verbose_logger.info(
                "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                    sampling_rate, random_sample
                )
            )
            return  # Skip logging
        verbose_logger.info("Argilla Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Argilla logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Argilla Layer Error - error logging async failure event."
            )

    async def async_send_batch(self):
        """
        Send the records in self.log_queue to the Argilla bulk records endpoint.

        Returns: None

        Raises: Does not raise an exception, will only verbose_logger.exception()
        """
        if not self.log_queue:
            return

        url = self._get_bulk_records_url()
        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
        headers = {"X-Argilla-Api-Key": argilla_api_key}

        try:
            response = await self.async_httpx_client.put(
                url=url,
                data=json.dumps(
                    {
                        "items": self.log_queue,
                    }
                ),
                headers=headers,
                timeout=60000,
            )
            response.raise_for_status()

            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    "Batch of %s runs successfully created", len(self.log_queue)
                )
        except httpx.HTTPStatusError:
            verbose_logger.exception("Argilla HTTP Error")
        except Exception:
            verbose_logger.exception("Argilla Layer Error")
|
||||
@@ -0,0 +1,210 @@
|
||||
# Arize Phoenix Prompt Management Integration
|
||||
|
||||
This integration enables using prompt versions from Arize Phoenix with LiteLLM's completion function.
|
||||
|
||||
## Features
|
||||
|
||||
- Fetch prompt versions from Arize Phoenix API
|
||||
- Workspace-based access control through Arize Phoenix permissions
|
||||
- Mustache/Handlebars-style variable templating (`{{variable}}`)
|
||||
- Support for multi-message chat templates
|
||||
- Automatic model and parameter configuration from prompt metadata
|
||||
- OpenAI and Anthropic provider parameter support
|
||||
|
||||
## Configuration
|
||||
|
||||
Configure Arize Phoenix access in your application:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
# Configure Arize Phoenix access
|
||||
# api_base should include your workspace, e.g., "https://app.phoenix.arize.com/s/your-workspace/v1"
|
||||
api_key = "your-arize-phoenix-token"
|
||||
api_base = "https://app.phoenix.arize.com/s/krrishdholakia/v1"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
# Use with completion
|
||||
response = litellm.completion(
|
||||
model="arize/gpt-4o",
|
||||
prompt_id="UHJvbXB0VmVyc2lvbjox", # Your prompt version ID
|
||||
prompt_variables={"question": "What is artificial intelligence?"},
|
||||
api_key="your-arize-phoenix-token",
|
||||
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
### With Additional Messages
|
||||
|
||||
You can also combine prompt templates with additional messages:
|
||||
|
||||
```python
|
||||
response = litellm.completion(
|
||||
model="arize/gpt-4o",
|
||||
prompt_id="UHJvbXB0VmVyc2lvbjox",
|
||||
prompt_variables={"question": "Explain quantum computing"},
|
||||
api_key="your-arize-phoenix-token",
|
||||
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
|
||||
messages=[
|
||||
{"role": "user", "content": "Please keep your response under 100 words."}
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
### Direct Manager Usage
|
||||
|
||||
You can also use the prompt manager directly:
|
||||
|
||||
```python
|
||||
from litellm.integrations.arize.arize_phoenix_prompt_manager import ArizePhoenixPromptManager
|
||||
|
||||
# Initialize the manager
|
||||
manager = ArizePhoenixPromptManager(
|
||||
api_key="your-arize-phoenix-token",
|
||||
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
|
||||
prompt_id="UHJvbXB0VmVyc2lvbjox",
|
||||
)
|
||||
|
||||
# Get rendered messages
|
||||
messages, metadata = manager.get_prompt_template(
|
||||
prompt_id="UHJvbXB0VmVyc2lvbjox",
|
||||
prompt_variables={"question": "What is machine learning?"}
|
||||
)
|
||||
|
||||
print("Rendered messages:", messages)
|
||||
print("Metadata:", metadata)
|
||||
```
|
||||
|
||||
## Prompt Format
|
||||
|
||||
Arize Phoenix prompts support the following structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"data": {
|
||||
"description": "A chatbot prompt",
|
||||
"model_provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"template": {
|
||||
"type": "chat",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are a chatbot"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "{{question}}"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"template_type": "CHAT",
|
||||
"template_format": "MUSTACHE",
|
||||
"invocation_parameters": {
|
||||
"type": "openai",
|
||||
"openai": {
|
||||
"temperature": 1.0
|
||||
}
|
||||
},
|
||||
"id": "UHJvbXB0VmVyc2lvbjox"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Variable Substitution
|
||||
|
||||
Variables in your prompt templates use Mustache/Handlebars syntax:
|
||||
- `{{variable_name}}` - Simple variable substitution
|
||||
|
||||
Example:
|
||||
```
|
||||
Template: "Hello {{name}}, your order {{order_id}} is ready!"
|
||||
Variables: {"name": "Alice", "order_id": "12345"}
|
||||
Result: "Hello Alice, your order 12345 is ready!"
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### ArizePhoenixPromptManager
|
||||
|
||||
Main class for managing Arize Phoenix prompts.
|
||||
|
||||
**Methods:**
|
||||
- `get_prompt_template(prompt_id, prompt_variables)` - Get and render a prompt template
|
||||
- `get_available_prompts()` - List available prompt IDs
|
||||
- `reload_prompts()` - Reload prompts from Arize Phoenix
|
||||
|
||||
### ArizePhoenixClient
|
||||
|
||||
Low-level client for Arize Phoenix API.
|
||||
|
||||
**Methods:**
|
||||
- `get_prompt_version(prompt_version_id)` - Fetch a prompt version
|
||||
- `test_connection()` - Test API connection
|
||||
|
||||
## Error Handling
|
||||
|
||||
The integration provides detailed error messages:
|
||||
|
||||
- **404**: Prompt version not found
|
||||
- **401**: Authentication failed (check your access token)
|
||||
- **403**: Access denied (check workspace permissions)
|
||||
|
||||
Example:
|
||||
```python
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="arize/gpt-4o",
|
||||
prompt_id="invalid-id",
|
||||
api_key="your-arize-phoenix-token",
|
||||
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
```
|
||||
|
||||
## Getting Your Prompt Version ID and API Base
|
||||
|
||||
1. Log in to Arize Phoenix
|
||||
2. Navigate to your workspace
|
||||
3. Go to the Prompts section
|
||||
4. Select a prompt version
|
||||
5. The ID will be in the URL: `/s/{workspace}/v1/prompt_versions/{PROMPT_VERSION_ID}`
|
||||
|
||||
Your `api_base` should be: `https://app.phoenix.arize.com/s/{workspace}/v1`
|
||||
|
||||
For example:
|
||||
- Workspace: `krrishdholakia`
|
||||
- API Base: `https://app.phoenix.arize.com/s/krrishdholakia/v1`
|
||||
- Prompt Version ID: `UHJvbXB0VmVyc2lvbjox`
|
||||
|
||||
You can also fetch it via API:
|
||||
```bash
|
||||
curl -L -X GET 'https://app.phoenix.arize.com/s/krrishdholakia/v1/prompt_versions/UHJvbXB0VmVyc2lvbjox' \
|
||||
-H 'Authorization: Bearer YOUR_TOKEN'
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
- LiteLLM Issues: https://github.com/BerriAI/litellm/issues
|
||||
- Arize Phoenix Docs: https://docs.arize.com/phoenix
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
|
||||
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
|
||||
|
||||
from .arize_phoenix_prompt_manager import ArizePhoenixPromptManager
|
||||
|
||||
# Global instances
|
||||
global_arize_config: Optional[dict] = None
|
||||
|
||||
|
||||
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Build an ``ArizePhoenixPromptManager`` from the given prompt parameters.

    The API key is read from ``litellm_params.api_key`` and falls back to the
    ``PHOENIX_API_KEY`` environment variable. The API base has no fallback.
    Every other field returned by ``litellm_params.model_dump`` is forwarded
    to the manager unchanged.

    Args:
        litellm_params: Prompt parameters (``api_key``, ``api_base``,
            ``prompt_id`` plus any extra fields).
        prompt_spec: Not used by this initializer.

    Returns:
        The constructed prompt manager.

    Raises:
        ValueError: If no API key or no API base could be resolved.
    """
    resolved_key = getattr(litellm_params, "api_key", None)
    if not resolved_key:
        resolved_key = os.environ.get("PHOENIX_API_KEY")
    resolved_base = getattr(litellm_params, "api_base", None)

    if not (resolved_key and resolved_base):
        raise ValueError(
            "api_key and api_base are required for Arize Phoenix prompt integration"
        )

    manager_kwargs = {
        "api_key": resolved_key,
        "api_base": resolved_base,
        "prompt_id": getattr(litellm_params, "prompt_id", None),
    }
    # The explicitly resolved fields are excluded from the dump, so the
    # values above are never overwritten.
    manager_kwargs.update(
        litellm_params.model_dump(exclude={"api_key", "api_base", "prompt_id"})
    )
    return ArizePhoenixPromptManager(**manager_kwargs)
|
||||
|
||||
|
||||
# Maps the Arize Phoenix integration identifier to its prompt initializer.
prompt_initializer_registry = {
|
||||
SupportedPromptIntegrations.ARIZE_PHOENIX.value: prompt_initializer,
|
||||
}
|
||||
@@ -0,0 +1,502 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.opentelemetry_utils.base_otel_llm_obs_attributes import (
|
||||
BaseLLMObsOTELAttributes,
|
||||
safe_set_attribute,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
from litellm.integrations._types.open_inference import (
|
||||
MessageAttributes,
|
||||
ImageAttributes,
|
||||
SpanAttributes,
|
||||
AudioAttributes,
|
||||
EmbeddingAttributes,
|
||||
OpenInferenceSpanKindValues,
|
||||
)
|
||||
|
||||
|
||||
class ArizeOTELAttributes(BaseLLMObsOTELAttributes):
    """Writes request and response messages onto a span as OpenInference attributes."""

    @staticmethod
    @override
    def set_messages(span: "Span", kwargs: Dict[str, Any]):
        """
        Record the request messages from ``kwargs["messages"]`` on the span.

        The content of the last message becomes the span's input value; every
        message additionally gets indexed role/content attributes. Nothing is
        written when there are no messages.
        """
        request_messages = kwargs.get("messages")
        if not request_messages:
            return

        # for /chat/completions
        # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
        safe_set_attribute(
            span,
            SpanAttributes.INPUT_VALUE,
            request_messages[-1].get("content", ""),
        )

        # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page.
        for position, message in enumerate(request_messages):
            base_key = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{position}"
            safe_set_attribute(
                span,
                f"{base_key}.{MessageAttributes.MESSAGE_ROLE}",
                message.get("role"),
            )
            safe_set_attribute(
                span,
                f"{base_key}.{MessageAttributes.MESSAGE_CONTENT}",
                message.get("content", ""),
            )

    @staticmethod
    @override
    def set_response_output_messages(span: "Span", response_obj):
        """
        Record the response choices' messages on the span.

        Args:
            span: The OpenTelemetry span to set attributes on
            response_obj: The response object containing choices with messages
        """
        from litellm.integrations._types.open_inference import (
            MessageAttributes,
            SpanAttributes,
        )

        for position, choice in enumerate(response_obj.get("choices", [])):
            reply = choice.get("message", {})
            # Written once per choice, so the last choice's content is what
            # remains as the span's output value.
            safe_set_attribute(
                span, SpanAttributes.OUTPUT_VALUE, reply.get("content", "")
            )

            # This shows up under `output_messages` tab on the span page.
            base_key = f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{position}"
            safe_set_attribute(
                span,
                f"{base_key}.{MessageAttributes.MESSAGE_ROLE}",
                reply.get("role"),
            )
            safe_set_attribute(
                span,
                f"{base_key}.{MessageAttributes.MESSAGE_CONTENT}",
                reply.get("content", ""),
            )
|
||||
|
||||
|
||||
def _set_response_attributes(span: "Span", response_obj):
    """
    Set response output and token usage attributes on the span.

    Does nothing when ``response_obj`` has no ``get`` attribute. Otherwise
    each output-specific setter is run in a fixed order, followed by the
    usage setter.
    """
    if not hasattr(response_obj, "get"):
        return

    # (setter, attribute namespace) pairs, applied in this order.
    output_setters = (
        (_set_choice_outputs, MessageAttributes),
        (_set_image_outputs, ImageAttributes),
        (_set_audio_outputs, AudioAttributes),
        (_set_embedding_outputs, EmbeddingAttributes),
        (_set_structured_outputs, MessageAttributes),
    )
    for apply_setter, attr_namespace in output_setters:
        apply_setter(span, response_obj, attr_namespace, SpanAttributes)

    _set_usage_outputs(span, response_obj, SpanAttributes)
|
||||
|
||||
|
||||
def _set_choice_outputs(span: "Span", response_obj, msg_attrs, span_attrs):
    """
    Record each entry of ``response_obj["choices"]`` on the span.

    For every choice the message content is written as the span's output
    value (so the last choice wins) and as an indexed output message together
    with its role.
    """
    for position, choice in enumerate(response_obj.get("choices", [])):
        reply = choice.get("message", {})
        reply_content = reply.get("content", "")
        base_key = f"{span_attrs.LLM_OUTPUT_MESSAGES}.{position}"

        safe_set_attribute(span, span_attrs.OUTPUT_VALUE, reply_content)
        safe_set_attribute(
            span, f"{base_key}.{msg_attrs.MESSAGE_ROLE}", reply.get("role")
        )
        safe_set_attribute(
            span, f"{base_key}.{msg_attrs.MESSAGE_CONTENT}", reply_content
        )
|
||||
|
||||
|
||||
def _set_image_outputs(span: "Span", response_obj, image_attrs, span_attrs):
    """
    Record image results from ``response_obj["data"]`` on the span.

    Each item contributes its ``url``; when the url is ``None`` and a
    ``b64_json`` payload is present, a ``data:image/png;base64,`` URL is built
    from it. Items without a usable URL are skipped. The URL of the item at
    index 0 also becomes the span's output value.
    """
    for position, entry in enumerate(response_obj.get("data", [])):
        url = entry.get("url")
        if url is None:
            encoded = entry.get("b64_json")
            if encoded:
                url = f"data:image/png;base64,{encoded}"

        if url:
            if position == 0:
                safe_set_attribute(span, span_attrs.OUTPUT_VALUE, url)
            safe_set_attribute(span, f"{image_attrs.IMAGE_URL}.{position}", url)
|
||||
|
||||
|
||||
def _set_audio_outputs(span: "Span", response_obj, audio_attrs, span_attrs):
    """
    Record audio results from ``response_obj["audio"]`` on the span.

    For each item the ``url`` is used, or a ``data:audio/wav;base64,`` URL
    built from ``b64_json`` when the url is ``None``. The first item's URL is
    also written as the span's output value. ``mime_type`` and ``transcript``
    are recorded per item when truthy.
    """
    for position, entry in enumerate(response_obj.get("audio", [])):
        url = entry.get("url")
        if url is None:
            encoded = entry.get("b64_json")
            if encoded:
                url = f"data:audio/wav;base64,{encoded}"

        if url:
            if position == 0:
                safe_set_attribute(span, span_attrs.OUTPUT_VALUE, url)
            safe_set_attribute(span, f"{audio_attrs.AUDIO_URL}.{position}", url)

        mime_type = entry.get("mime_type")
        if mime_type:
            safe_set_attribute(
                span, f"{audio_attrs.AUDIO_MIME_TYPE}.{position}", mime_type
            )

        transcript = entry.get("transcript")
        if transcript:
            safe_set_attribute(
                span, f"{audio_attrs.AUDIO_TRANSCRIPT}.{position}", transcript
            )
|
||||
|
||||
|
||||
def _set_embedding_outputs(span: "Span", response_obj, embedding_attrs, span_attrs):
    """
    Record embedding results from ``response_obj["data"]`` on the span.

    Each truthy ``embedding`` is stringified and stored under an indexed
    vector attribute; the first one is also written as the span's output
    value. A truthy ``text`` field is stored per item as well.
    """
    for position, entry in enumerate(response_obj.get("data", [])):
        vector = entry.get("embedding")
        if vector:
            vector_text = str(vector)
            if position == 0:
                safe_set_attribute(span, span_attrs.OUTPUT_VALUE, vector_text)
            safe_set_attribute(
                span,
                f"{embedding_attrs.EMBEDDING_VECTOR}.{position}",
                vector_text,
            )

        source_text = entry.get("text")
        if source_text:
            safe_set_attribute(
                span,
                f"{embedding_attrs.EMBEDDING_TEXT}.{position}",
                str(source_text),
            )
|
||||
|
||||
|
||||
def _set_structured_outputs(span: "Span", response_obj, msg_attrs, span_attrs):
    """
    Record typed items from ``response_obj["output"]`` on the span.

    Items without a ``type`` attribute are ignored. ``reasoning`` items write
    each summary's text to the same per-item attribute, so the last summary
    wins. ``message`` items write the text of their first content part
    (empty string when there is none) and their role, defaulting to
    ``"assistant"``.
    """
    for position, item in enumerate(response_obj.get("output", [])):
        if not hasattr(item, "type"):
            continue

        base_key = f"{span_attrs.LLM_OUTPUT_MESSAGES}.{position}"
        kind = item.type

        if kind == "reasoning" and hasattr(item, "summary"):
            for summary_part in item.summary:
                if not hasattr(summary_part, "text"):
                    continue
                safe_set_attribute(
                    span,
                    f"{base_key}.{msg_attrs.MESSAGE_REASONING_SUMMARY}",
                    summary_part.text,
                )
        elif kind == "message" and hasattr(item, "content"):
            parts = item.content
            if parts and len(parts) > 0:
                text = getattr(parts[0], "text", "")
            else:
                text = ""
            role = getattr(item, "role", "assistant")

            safe_set_attribute(span, span_attrs.OUTPUT_VALUE, text)
            safe_set_attribute(span, f"{base_key}.{msg_attrs.MESSAGE_CONTENT}", text)
            safe_set_attribute(span, f"{base_key}.{msg_attrs.MESSAGE_ROLE}", role)
|
||||
|
||||
|
||||
def _set_usage_outputs(span: "Span", response_obj, span_attrs):
    """
    Record token-usage counters from ``response_obj["usage"]`` on the span.

    Nothing is written when there is no (truthy) usage object. Otherwise the
    total token count is always written, and the completion, prompt and
    reasoning counts are written only when truthy. Completion and prompt
    counts fall back to ``output_tokens`` / ``input_tokens`` when
    ``completion_tokens`` / ``prompt_tokens`` are empty.
    """
    usage = response_obj and response_obj.get("usage")
    if not usage:
        return

    safe_set_attribute(
        span, span_attrs.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens")
    )

    completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens")
    if completion_tokens:
        safe_set_attribute(
            span, span_attrs.LLM_TOKEN_COUNT_COMPLETION, completion_tokens
        )

    prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens")
    if prompt_tokens:
        safe_set_attribute(span, span_attrs.LLM_TOKEN_COUNT_PROMPT, prompt_tokens)

    # ``output_tokens_details`` can be present but None. A ``{}`` default
    # passed to ``.get`` only covers a missing key, so chaining ``.get`` on
    # the result would raise AttributeError. Normalise to an empty mapping.
    details = usage.get("output_tokens_details") or {}
    if hasattr(details, "get"):
        reasoning_tokens = details.get("reasoning_tokens")
    else:
        # Details supplied as a plain object rather than a mapping.
        reasoning_tokens = getattr(details, "reasoning_tokens", None)

    if reasoning_tokens:
        safe_set_attribute(
            span,
            span_attrs.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
            reasoning_tokens,
        )
|
||||
|
||||
|
||||
def _infer_open_inference_span_kind(call_type: Optional[str]) -> str:
    """
    Map a LiteLLM call type to an OpenInference span kind value.

    The call type is lower-cased and checked against keyword groups in a
    fixed order; the first matching group decides the kind. An empty call
    type, or one matching no group, yields the UNKNOWN kind.
    """
    kinds = OpenInferenceSpanKindValues
    if not call_type:
        return kinds.UNKNOWN.value

    name = str(call_type).lower()

    def mentions(*keywords: str) -> bool:
        # True when any keyword occurs as a substring of the call type.
        return any(keyword in name for keyword in keywords)

    # Order matters: earlier groups take precedence over later ones.
    if mentions("embed"):
        return kinds.EMBEDDING.value
    if mentions("rerank"):
        return kinds.RERANKER.value
    if mentions("search"):
        return kinds.RETRIEVER.value
    if mentions("moderation", "guardrail"):
        return kinds.GUARDRAIL.value
    if name in ("call_mcp_tool", "mcp") or name.endswith("tool"):
        return kinds.TOOL.value
    if mentions("asend_message", "a2a", "assistant"):
        return kinds.AGENT.value
    if mentions(
        "completion",
        "chat",
        "image",
        "audio",
        "speech",
        "transcription",
        "generate_content",
        "response",
        "videos",
        "realtime",
        "pass_through",
        "anthropic_messages",
        "ocr",
    ):
        return kinds.LLM.value
    if mentions("file", "batch", "container", "fine_tuning_job"):
        return kinds.CHAIN.value

    return kinds.UNKNOWN.value
|
||||
|
||||
|
||||
def _set_tool_attributes(
    span: "Span", optional_tools: Optional[list], metadata_tools: Optional[list]
):
    """
    Set tool attributes on the span from optional_params or tool call metadata.

    Args:
        span: Span to write attributes to.
        optional_tools: Tool definitions from the request's optional params.
            Only dict entries with a dict ``function`` field are used; their
            name, description and JSON-encoded parameters are written under
            ``LLM_TOOLS.<idx>``.
        metadata_tools: Tool entries from request metadata. Only dict entries
            are used; their name and description are written under
            ``LLM_INVOCATION_PARAMETERS.tools.<idx>``.
    """
    if optional_tools:
        for idx, tool in enumerate(optional_tools):
            if not isinstance(tool, dict):
                continue
            function = tool.get("function")
            if not isinstance(function, dict) or not function:
                continue

            tool_name = function.get("name")
            if tool_name:
                safe_set_attribute(
                    span, f"{SpanAttributes.LLM_TOOLS}.{idx}.name", tool_name
                )

            tool_description = function.get("description")
            if tool_description:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_TOOLS}.{idx}.description",
                    tool_description,
                )

            params = function.get("parameters")
            if params is not None:
                # ``default=str`` keeps the output identical for plain JSON
                # data but stops a non-serialisable value in a tool schema
                # from raising TypeError and aborting the caller before the
                # remaining span attributes are written.
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_TOOLS}.{idx}.parameters",
                    json.dumps(params, default=str),
                )

    if metadata_tools and isinstance(metadata_tools, list):
        for idx, tool in enumerate(metadata_tools):
            if not isinstance(tool, dict):
                continue

            tool_name = tool.get("name")
            if tool_name:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_INVOCATION_PARAMETERS}.tools.{idx}.name",
                    tool_name,
                )

            tool_description = tool.get("description")
            if tool_description:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_INVOCATION_PARAMETERS}.tools.{idx}.description",
                    tool_description,
                )
|
||||
|
||||
|
||||
def set_attributes(
    span: "Span", kwargs, response_obj, attributes: Type[BaseLLMObsOTELAttributes]
):
    """
    Populates span with OpenInference-compliant LLM attributes for Arize and Phoenix tracing.

    Reads ``optional_params``, ``litellm_params`` and
    ``standard_logging_object`` from ``kwargs`` and writes metadata, request,
    tool, span-kind, message, model-parameter and response attributes, in
    that order. Any exception is logged and recorded on the span instead of
    being propagated.

    Args:
        span: Span to write attributes to.
        kwargs: Logging kwargs of the call.
        response_obj: Response of the call.
        attributes: Attribute-writer class whose ``set_messages`` is used for
            the request messages.
    """
    try:
        optional_params = _sanitize_optional_params(kwargs.get("optional_params"))
        litellm_params = kwargs.get("litellm_params", {}) or {}
        payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if payload is None:
            raise ValueError("standard_logging_object not found in kwargs")

        metadata = payload.get("metadata") if payload else None
        _set_metadata_attributes(span, metadata, SpanAttributes)

        tools_from_metadata = _extract_metadata_tools(metadata)
        tools_from_params = _extract_optional_tools(optional_params)

        call_type = payload.get("call_type")
        _set_request_attributes(
            span=span,
            kwargs=kwargs,
            standard_logging_payload=payload,
            optional_params=optional_params,
            litellm_params=litellm_params,
            response_obj=response_obj,
            span_attrs=SpanAttributes,
        )

        span_kind = _infer_open_inference_span_kind(call_type=call_type)
        _set_tool_attributes(span, tools_from_params, tools_from_metadata)
        # Any call that carries tools is reported with the TOOL span kind.
        if tools_from_params or tools_from_metadata:
            span_kind = OpenInferenceSpanKindValues.TOOL.value

        safe_set_attribute(span, SpanAttributes.OPENINFERENCE_SPAN_KIND, span_kind)
        attributes.set_messages(span, kwargs)

        model_params = payload.get("model_parameters") if payload else None
        _set_model_params(span, model_params, SpanAttributes)

        _set_response_attributes(span=span, response_obj=response_obj)

    except Exception as e:
        verbose_logger.error(
            f"[Arize/Phoenix] Failed to set OpenInference span attributes: {e}"
        )
        if hasattr(span, "record_exception"):
            span.record_exception(e)
|
||||
|
||||
|
||||
def _sanitize_optional_params(optional_params: Optional[dict]) -> dict:
    """
    Return ``optional_params`` without its ``secret_fields`` entry.

    The key is removed from the passed dict in place and that same dict is
    returned. Any non-dict input yields a new empty dict.
    """
    if isinstance(optional_params, dict):
        optional_params.pop("secret_fields", None)
        return optional_params
    return {}
|
||||
|
||||
|
||||
def _set_metadata_attributes(span: "Span", metadata: Optional[Any], span_attrs) -> None:
    """Write the serialised ``metadata`` to the span's metadata attribute, unless it is ``None``."""
    if metadata is None:
        return
    serialised = safe_dumps(metadata)
    safe_set_attribute(span, span_attrs.METADATA, serialised)
|
||||
|
||||
|
||||
def _extract_metadata_tools(metadata: Optional[Any]) -> Optional[list]:
    """
    Return ``metadata["llm"]["tools"]`` when both levels are dicts.

    Returns ``None`` when ``metadata`` is not a dict, when its ``llm`` entry
    is not a dict, or when that entry has no ``tools`` key.
    """
    llm_section = metadata.get("llm") if isinstance(metadata, dict) else None
    return llm_section.get("tools") if isinstance(llm_section, dict) else None
|
||||
|
||||
|
||||
def _extract_optional_tools(optional_params: dict) -> Optional[list]:
    """Return the ``tools`` entry of ``optional_params``, or ``None`` if it is not a dict or has none."""
    if not isinstance(optional_params, dict):
        return None
    return optional_params.get("tools")
|
||||
|
||||
|
||||
def _set_request_attributes(
    span: "Span",
    kwargs,
    standard_logging_payload: StandardLoggingPayload,
    optional_params: dict,
    litellm_params: dict,
    response_obj,
    span_attrs,
):
    """
    Write request-level attributes (and response id/model) onto the span.

    Args:
        span: Span to write attributes to.
        kwargs: Logging kwargs; ``model`` is read from here.
        standard_logging_payload: Source of the ``call_type``.
        optional_params: Request parameters (``max_tokens``, ``temperature``,
            ``top_p``, ``stream``, ``user``).
        litellm_params: Source of ``custom_llm_provider`` (``"Unknown"`` when
            absent).
        response_obj: Response; its ``id`` and ``model`` are recorded when
            truthy.
        span_attrs: Namespace providing the span attribute keys.
    """
    if kwargs.get("model"):
        safe_set_attribute(span, span_attrs.LLM_MODEL_NAME, kwargs.get("model"))

    safe_set_attribute(
        span, "llm.request.type", standard_logging_payload.get("call_type")
    )
    safe_set_attribute(
        span,
        span_attrs.LLM_PROVIDER,
        litellm_params.get("custom_llm_provider", "Unknown"),
    )

    # Compare against None rather than truthiness: 0 / 0.0 are legitimate
    # explicit settings (e.g. temperature=0) and must not be dropped.
    for param_name in ("max_tokens", "temperature", "top_p"):
        param_value = optional_params.get(param_name)
        if param_value is not None:
            safe_set_attribute(span, f"llm.request.{param_name}", param_value)

    safe_set_attribute(
        span, "llm.is_streaming", str(optional_params.get("stream", False))
    )

    if optional_params.get("user"):
        safe_set_attribute(span, "llm.user", optional_params.get("user"))

    if response_obj and response_obj.get("id"):
        safe_set_attribute(span, "llm.response.id", response_obj.get("id"))
    if response_obj and response_obj.get("model"):
        safe_set_attribute(span, "llm.response.model", response_obj.get("model"))
|
||||
|
||||
|
||||
def _set_model_params(span: "Span", model_params: Optional[dict], span_attrs) -> None:
    """
    Record the model parameters on the span.

    The whole dict is serialised into the invocation-parameters attribute; a
    truthy ``user`` entry is additionally written as the user id. Nothing is
    written when ``model_params`` is empty or ``None``.
    """
    if not model_params:
        return

    serialised = safe_dumps(model_params)
    safe_set_attribute(span, span_attrs.LLM_INVOCATION_PARAMETERS, serialised)

    user_id = model_params.get("user")
    if user_id:
        safe_set_attribute(span, span_attrs.USER_ID, user_id)
|
||||
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Arize AI is OTEL compatible.
|
||||
|
||||
This file contains Arize AI-specific helper functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from litellm.integrations.arize import _utils
|
||||
from litellm.integrations.arize._utils import ArizeOTELAttributes
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||
from litellm.types.integrations.arize import ArizeConfig
|
||||
from litellm.types.services import ServiceLoggerPayload
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.types.integrations.arize import Protocol as _Protocol
|
||||
|
||||
Protocol = _Protocol
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Protocol = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
class ArizeLogger(OpenTelemetry):
    """
    Arize logger that sends traces to an Arize endpoint.

    Creates its own dedicated TracerProvider so it can coexist with the
    generic ``otel`` callback (or any other OTEL-based integration) without
    fighting over the global ``opentelemetry.trace`` TracerProvider singleton.
    """

    def _init_tracing(self, tracer_provider):
        """
        Override to always create a *private* TracerProvider for Arize.

        An explicitly supplied ``tracer_provider`` is used as-is; otherwise a
        new provider is built from this logger's resource and span processor.

        See ArizePhoenixLogger._init_tracing for full rationale.
        """
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.trace import SpanKind

        if tracer_provider is None:
            tracer_provider = TracerProvider(
                resource=self._get_litellm_resource(self.config)
            )
            tracer_provider.add_span_processor(self._get_span_processor())

        self.tracer = tracer_provider.get_tracer("litellm")
        self.span_kind = SpanKind

    def _init_otel_logger_on_litellm_proxy(self):
        """
        Override: Arize should NOT overwrite the proxy's
        ``open_telemetry_logger``. That attribute is reserved for the
        primary ``otel`` callback which handles proxy-level parent spans.
        """
        pass

    def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
        """Set the Arize span attributes; delegates to ``set_arize_attributes``."""
        ArizeLogger.set_arize_attributes(span, kwargs, response_obj)

    @staticmethod
    def set_arize_attributes(span: Span, kwargs, response_obj):
        """Populate ``span`` via ``_utils.set_attributes`` using ``ArizeOTELAttributes``."""
        _utils.set_attributes(span, kwargs, response_obj, ArizeOTELAttributes)

    @staticmethod
    def get_arize_config() -> ArizeConfig:
        """
        Build the Arize configuration from environment variables.

        ``ARIZE_ENDPOINT`` (gRPC) takes precedence over
        ``ARIZE_HTTP_ENDPOINT`` (HTTP). When neither is set, the gRPC
        protocol is used with ``https://otlp.arize.com/v1``.

        Returns:
            ArizeConfig: The space id/key, API key, protocol, endpoint and
            project name read from the environment.
        """
        read_env = os.environ.get
        grpc_endpoint = read_env("ARIZE_ENDPOINT")
        http_endpoint = read_env("ARIZE_HTTP_ENDPOINT")

        protocol: Protocol
        if grpc_endpoint:
            protocol, endpoint = "otlp_grpc", grpc_endpoint
        elif http_endpoint:
            protocol, endpoint = "otlp_http", http_endpoint
        else:
            protocol, endpoint = "otlp_grpc", "https://otlp.arize.com/v1"

        return ArizeConfig(
            space_id=read_env("ARIZE_SPACE_ID"),
            space_key=read_env("ARIZE_SPACE_KEY"),
            api_key=read_env("ARIZE_API_KEY"),
            protocol=protocol,
            endpoint=endpoint,
            project_name=read_env("ARIZE_PROJECT_NAME"),
        )

    async def async_service_success_hook(
        self,
        payload: ServiceLoggerPayload,
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[datetime, float]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
        pass

    async def async_service_failure_hook(
        self,
        payload: ServiceLoggerPayload,
        error: Optional[str] = "",
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[float, datetime]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
        pass

    # def create_litellm_proxy_request_started_span(
    #     self,
    #     start_time: datetime,
    #     headers: dict,
    # ):
    #     """Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
    #     pass

    async def async_health_check(self):
        """
        Performs a health check for Arize integration.

        Only checks that credentials are present in the environment: at least
        one of space id / space key, plus an API key.

        Returns:
            dict: ``status`` plus either ``message`` (healthy) or
            ``error_message`` (unhealthy).
        """
        try:
            config = self.get_arize_config()

            if not (config.space_id or config.space_key):
                problem = (
                    "ARIZE_SPACE_ID or ARIZE_SPACE_KEY environment variable not set"
                )
            elif not config.api_key:
                problem = "ARIZE_API_KEY environment variable not set"
            else:
                return {
                    "status": "healthy",
                    "message": "Arize credentials are configured properly",
                }

            return {"status": "unhealthy", "error_message": problem}

        except Exception as e:
            return {
                "status": "unhealthy",
                "error_message": f"Arize health check failed: {str(e)}",
            }

    def construct_dynamic_otel_headers(
        self, standard_callback_dynamic_params: StandardCallbackDynamicParams
    ) -> Optional[dict]:
        """
        Construct dynamic Arize headers from standard callback dynamic params

        This is used for team/key based logging.

        Returns:
            dict: A dictionary of dynamic Arize headers
        """
        # (dynamic param, header) pairs, applied in order. Both space params
        # feed the same `arize-space-id` header, so `arize_space_key` (the
        # suggested param) overrides `arize_space_id` when both are given.
        header_sources = (
            ("arize_space_id", "arize-space-id"),
            ("arize_space_key", "arize-space-id"),
            ("arize_api_key", "api_key"),
        )

        dynamic_headers = {}
        for param_name, header_name in header_sources:
            param_value = standard_callback_dynamic_params.get(param_name)
            if param_value:
                dynamic_headers[header_name] = param_value

        return dynamic_headers
|
||||
@@ -0,0 +1,360 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.arize import _utils
|
||||
from litellm.integrations.arize._utils import ArizeOTELAttributes
|
||||
from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.trace import Span as _Span
|
||||
from opentelemetry.trace import SpanKind
|
||||
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry as _OpenTelemetry
|
||||
from litellm.integrations.opentelemetry import (
|
||||
OpenTelemetryConfig as _OpenTelemetryConfig,
|
||||
)
|
||||
from litellm.types.integrations.arize import Protocol as _Protocol
|
||||
|
||||
Protocol = _Protocol
|
||||
OpenTelemetryConfig = _OpenTelemetryConfig
|
||||
Span = Union[_Span, Any]
|
||||
OpenTelemetry = _OpenTelemetry
|
||||
else:
|
||||
Protocol = Any
|
||||
OpenTelemetryConfig = Any
|
||||
Span = Any
|
||||
TracerProvider = Any
|
||||
SpanKind = Any
|
||||
# Import OpenTelemetry at runtime
|
||||
try:
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||
except ImportError:
|
||||
OpenTelemetry = None # type: ignore
|
||||
|
||||
|
||||
ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://otlp.arize.com/v1/traces"
|
||||
|
||||
|
||||
class ArizePhoenixLogger(OpenTelemetry):  # type: ignore
    """
    Arize Phoenix logger that sends traces to a Phoenix endpoint.

    Creates its own dedicated TracerProvider so it can coexist with the
    generic ``otel`` callback (or any other OTEL-based integration) without
    fighting over the global ``opentelemetry.trace`` TracerProvider singleton.
    """

    def _init_tracing(self, tracer_provider):
        """
        Override to always create a *private* TracerProvider for Arize Phoenix.

        The base ``OpenTelemetry._init_tracing`` falls back to the global
        TracerProvider when one already exists.  That causes whichever
        integration initialises second to silently reuse the first one's
        exporter, so spans only reach one destination.

        By creating our own provider we guarantee Arize Phoenix always gets
        its own exporter pipeline, regardless of initialisation order.

        Args:
            tracer_provider: Optional provider supplied by the caller.  When
                given it is used as-is; otherwise a new one is built from
                ``self.config``.
        """
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.trace import SpanKind

        if tracer_provider is not None:
            # Explicitly supplied (e.g. in tests) — honour it.
            self.tracer = tracer_provider.get_tracer("litellm")
            self.span_kind = SpanKind
            return

        # Always create a dedicated provider — never touch the global one.
        provider = TracerProvider(resource=self._get_litellm_resource(self.config))
        provider.add_span_processor(self._get_span_processor())
        self.tracer = provider.get_tracer("litellm")
        self.span_kind = SpanKind
        verbose_logger.debug(
            "ArizePhoenixLogger: Created dedicated TracerProvider "
            "(endpoint=%s, exporter=%s)",
            self.config.endpoint,
            self.config.exporter,
        )

    def _init_otel_logger_on_litellm_proxy(self):
        """
        Override: Arize Phoenix should NOT overwrite the proxy's
        ``open_telemetry_logger``.  That attribute is reserved for the
        primary ``otel`` callback which handles proxy-level parent spans.
        """
        pass

    def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
        """Populate ``span`` with the Phoenix attribute set for this call."""
        ArizePhoenixLogger.set_arize_phoenix_attributes(span, kwargs, response_obj)
        return

    @staticmethod
    def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
        """
        Set the Arize attributes on ``span`` and tag it with a project name.

        The project name comes from request metadata when present (see
        ``_get_dynamic_project_name``); otherwise from the environment-based
        config returned by ``get_arize_phoenix_config``.
        """
        from litellm.integrations.opentelemetry_utils.base_otel_llm_obs_attributes import (
            safe_set_attribute,
        )

        _utils.set_attributes(span, kwargs, response_obj, ArizeOTELAttributes)

        # Dynamic project name: check metadata first, then fall back to env var config
        dynamic_project_name = ArizePhoenixLogger._get_dynamic_project_name(kwargs)
        if dynamic_project_name:
            safe_set_attribute(span, "openinference.project.name", dynamic_project_name)
        else:
            # Fall back to static config from env var
            config = ArizePhoenixLogger.get_arize_phoenix_config()
            if config.project_name:
                safe_set_attribute(
                    span, "openinference.project.name", config.project_name
                )

        return

    @staticmethod
    def _get_dynamic_project_name(kwargs) -> Optional[str]:
        """
        Retrieve dynamic Phoenix project name from request metadata.

        Users can set `metadata.phoenix_project_name` in their request to route
        traces to different Phoenix projects dynamically.

        Lookup order: ``kwargs["standard_logging_object"]["metadata"]`` first,
        then ``kwargs["litellm_params"]["metadata"]`` (SDK usage).

        Returns:
            The project name as a string, or None when neither location
            carries a truthy ``phoenix_project_name``.
        """
        for source_key in ("standard_logging_object", "litellm_params"):
            container = kwargs.get(source_key)
            if not isinstance(container, dict):
                continue
            metadata = container.get("metadata")
            if isinstance(metadata, dict):
                project_name = metadata.get("phoenix_project_name")
                if project_name:
                    return str(project_name)

        return None

    def _get_phoenix_context(self, kwargs):
        """
        Build a trace context for Phoenix's dedicated TracerProvider.

        The base ``_get_span_context`` returns parent spans from the global
        TracerProvider (the ``otel`` callback).  Those spans live on a
        *different* TracerProvider, so they won't appear in Phoenix — using
        them as parents just creates broken links.

        Instead we:
        1. Honour an incoming ``traceparent`` HTTP header (distributed tracing).
        2. In proxy mode, create our *own* parent span on Phoenix's tracer
           so the hierarchy is visible end-to-end inside Phoenix.
        3. In SDK (non-proxy) mode, just return the traceparent context (or
           None) and no parent span, so the request span is a root span.

        Returns:
            ``(context, parent_span)``.  ``parent_span`` is None in SDK mode;
            when it is not None the caller is responsible for ending it.
        """
        from opentelemetry import trace

        litellm_params = kwargs.get("litellm_params", {}) or {}
        proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
        headers = proxy_server_request.get("headers", {}) or {}

        # Propagate distributed trace context if the caller sent a traceparent
        traceparent_ctx = (
            self.get_traceparent_from_header(headers=headers)
            if headers.get("traceparent")
            else None
        )

        is_proxy_mode = bool(proxy_server_request)

        if is_proxy_mode:
            # Create a parent span on Phoenix's own tracer so both parent
            # and child are exported to Phoenix.
            # Fall back to api_call_start_time both when "start_time" is
            # absent and when it is present but None.
            start_time_val = kwargs.get("start_time")
            if start_time_val is None:
                start_time_val = kwargs.get("api_call_start_time")
            parent_span = self.tracer.start_span(
                name="litellm_proxy_request",
                start_time=self._to_ns(start_time_val)
                if start_time_val is not None
                else None,
                context=traceparent_ctx,
                kind=self.span_kind.SERVER,
            )
            ctx = trace.set_span_in_context(parent_span)
            return ctx, parent_span

        # SDK mode — no parent span needed
        return traceparent_ctx, None

    def _handle_success(self, kwargs, response_obj, start_time, end_time):
        """
        Override to always create spans on ArizePhoenixLogger's dedicated TracerProvider.

        The base class's ``_get_span_context`` would find the parent span created by
        the ``otel`` callback on the *global* TracerProvider.  That span is invisible
        in Phoenix (different exporter pipeline), so we ignore it and build our own
        hierarchy via ``_get_phoenix_context``.

        Spans started here are ended in ``finally`` blocks so that an error
        while annotating them does not leave a span open; the error itself
        still propagates to the caller.
        """
        from opentelemetry.trace import Status, StatusCode

        verbose_logger.debug(
            "ArizePhoenixLogger: Logging kwargs: %s, OTEL config settings=%s",
            kwargs,
            self.config,
        )

        ctx, parent_span = self._get_phoenix_context(kwargs)
        try:
            # Create litellm_request span (child of our parent when in proxy mode)
            span = self.tracer.start_span(
                name=self._get_span_name(kwargs),
                start_time=self._to_ns(start_time),
                context=ctx,
            )
            try:
                span.set_status(Status(StatusCode.OK))
                self.set_attributes(span, kwargs, response_obj)

                # Raw-request sub-span (if enabled) — must be created before
                # ending the parent span so the hierarchy is valid.
                self._maybe_log_raw_request(
                    kwargs, response_obj, start_time, end_time, span
                )
            finally:
                span.end(end_time=self._to_ns(end_time))

            # Guardrail span
            self._create_guardrail_span(kwargs=kwargs, context=ctx)

            # Annotate our proxy parent span (closed in the finally below)
            if parent_span is not None:
                parent_span.set_status(Status(StatusCode.OK))
                self.set_attributes(parent_span, kwargs, response_obj)
        finally:
            if parent_span is not None:
                parent_span.end(end_time=self._to_ns(end_time))

        # Metrics & cost recording
        self._record_metrics(kwargs, response_obj, start_time, end_time)

        # Semantic logs
        if self.config.enable_events:
            self._emit_semantic_logs(kwargs, response_obj, span)

    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
        """
        Override to always create failure spans on ArizePhoenixLogger's dedicated
        TracerProvider.  Mirrors ``_handle_success`` but sets ERROR status and
        records the exception on both the request span and the proxy parent
        span.  No metrics or semantic logs are emitted on this path.
        """
        from opentelemetry.trace import Status, StatusCode

        verbose_logger.debug(
            "ArizePhoenixLogger: Failure - Logging kwargs: %s, OTEL config settings=%s",
            kwargs,
            self.config,
        )

        ctx, parent_span = self._get_phoenix_context(kwargs)
        try:
            # Create litellm_request span (child of our parent when in proxy mode)
            span = self.tracer.start_span(
                name=self._get_span_name(kwargs),
                start_time=self._to_ns(start_time),
                context=ctx,
            )
            try:
                span.set_status(Status(StatusCode.ERROR))
                self.set_attributes(span, kwargs, response_obj)
                self._record_exception_on_span(span=span, kwargs=kwargs)
            finally:
                span.end(end_time=self._to_ns(end_time))

            # Guardrail span
            self._create_guardrail_span(kwargs=kwargs, context=ctx)

            # Annotate our proxy parent span (closed in the finally below)
            if parent_span is not None:
                parent_span.set_status(Status(StatusCode.ERROR))
                self.set_attributes(parent_span, kwargs, response_obj)
                self._record_exception_on_span(span=parent_span, kwargs=kwargs)
        finally:
            if parent_span is not None:
                parent_span.end(end_time=self._to_ns(end_time))

    @staticmethod
    def get_arize_phoenix_config() -> ArizePhoenixConfig:
        """
        Retrieves the Arize Phoenix configuration based on environment variables.

        Environment variables read:
            PHOENIX_API_KEY: bearer token; an empty value is treated as unset.
            PHOENIX_COLLECTOR_HTTP_ENDPOINT: preferred collector endpoint.
            PHOENIX_COLLECTOR_ENDPOINT: used when the HTTP variable is empty.
            PHOENIX_PROJECT_NAME: project name, ``"default"`` when unset.

        Returns:
            ArizePhoenixConfig: A Pydantic model containing Arize Phoenix configuration.

        Raises:
            ValueError: the endpoint points at app.phoenix.arize.com and no
                API key is configured.
        """
        # `or None` so that PHOENIX_API_KEY="" does not yield an empty bearer
        # token or slip past the Phoenix Cloud check below.
        api_key = os.environ.get("PHOENIX_API_KEY") or None

        collector_endpoint = os.environ.get(
            "PHOENIX_COLLECTOR_HTTP_ENDPOINT"
        ) or os.environ.get("PHOENIX_COLLECTOR_ENDPOINT")

        endpoint = None
        protocol: Protocol = "otlp_http"

        if collector_endpoint:
            # Parse the endpoint to determine protocol
            if collector_endpoint.startswith("grpc://") or (
                ":4317" in collector_endpoint and "/v1/traces" not in collector_endpoint
            ):
                endpoint = collector_endpoint
                protocol = "otlp_grpc"
            else:
                # Phoenix Cloud endpoints (app.phoenix.arize.com) include the space in the URL
                if "app.phoenix.arize.com" in collector_endpoint:
                    endpoint = collector_endpoint
                    protocol = "otlp_http"
                # For other HTTP endpoints, ensure they have the correct path
                elif "/v1/traces" not in collector_endpoint:
                    if collector_endpoint.endswith("/v1"):
                        endpoint = collector_endpoint + "/traces"
                    elif collector_endpoint.endswith("/"):
                        endpoint = f"{collector_endpoint}v1/traces"
                    else:
                        endpoint = f"{collector_endpoint}/v1/traces"
                else:
                    endpoint = collector_endpoint
                protocol = "otlp_http"
        else:
            # If no endpoint specified, self hosted phoenix
            endpoint = "http://localhost:6006/v1/traces"
            protocol = "otlp_http"
            verbose_logger.debug(
                "No PHOENIX_COLLECTOR_ENDPOINT found, using default local Phoenix endpoint: %s",
                endpoint,
            )

        otlp_auth_headers = None
        if api_key is not None:
            otlp_auth_headers = f"Authorization=Bearer {api_key}"
        elif "app.phoenix.arize.com" in endpoint:
            # Phoenix Cloud requires an API key
            raise ValueError(
                "PHOENIX_API_KEY must be set when using Phoenix Cloud (app.phoenix.arize.com)."
            )

        project_name = os.environ.get("PHOENIX_PROJECT_NAME", "default")

        return ArizePhoenixConfig(
            otlp_auth_headers=otlp_auth_headers,
            protocol=protocol,
            endpoint=endpoint,
            project_name=project_name,
        )

    ## cannot suppress additional proxy server spans, removed previous methods.

    async def async_health_check(self):
        """
        Report whether Phoenix credentials are configured.

        Only inspects the environment-derived config; no request is sent to
        the collector.

        Returns:
            A dict with ``status`` of ``"healthy"`` or ``"unhealthy"`` and a
            ``message`` / ``error_message`` explaining the result.
        """
        try:
            config = self.get_arize_phoenix_config()
        except ValueError as e:
            # Invalid configuration is an unhealthy state, not a crash.
            return {
                "status": "unhealthy",
                "error_message": str(e),
            }

        if not config.otlp_auth_headers:
            return {
                "status": "unhealthy",
                "error_message": "PHOENIX_API_KEY environment variable not set",
            }

        return {
            "status": "healthy",
            "message": "Arize-Phoenix credentials are configured properly",
        }
|
||||
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Arize Phoenix API client for fetching prompt versions from Arize Phoenix.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
|
||||
class ArizePhoenixClient:
    """
    Client for interacting with Arize Phoenix API to fetch prompt versions.

    Supports:
    - Authentication with Bearer tokens
    - Fetching prompt versions
    - Direct API base URL configuration
    """

    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
        """
        Initialize the Arize Phoenix client.

        Args:
            api_key: Arize Phoenix API token
            api_base: Base URL for the Arize Phoenix API (e.g., 'https://app.phoenix.arize.com/s/workspace/v1').
                A trailing ``/v1`` and trailing slashes are both optional.

        Raises:
            ValueError: ``api_key`` or ``api_base`` is missing or empty.
        """
        self.api_key = api_key
        self.api_base = api_base

        if not self.api_key:
            raise ValueError("api_key is required")

        if not self.api_base:
            raise ValueError("api_base is required")

        # Set up authentication headers
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
        }

        # Initialize HTTPHandler
        self.http_handler = HTTPHandler(disable_default_headers=True)

    def _rest_url(self, *segments: str) -> str:
        """
        Build a ``<api_base>/v1/<segments...>`` URL.

        ``/v1`` is appended only when ``api_base`` does not already end with
        it, so both base-URL conventions produce the same request path.  Each
        segment is percent-encoded so an identifier cannot alter the path.
        """
        from urllib.parse import quote

        base = (self.api_base or "").rstrip("/")
        if not base.endswith("/v1"):
            base = f"{base}/v1"
        encoded = [quote(str(segment), safe="") for segment in segments]
        return "/".join([base, *encoded])

    def get_prompt_version(self, prompt_version_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch a prompt version from Arize Phoenix.

        Args:
            prompt_version_id: The ID of the prompt version to fetch

        Returns:
            Dictionary containing prompt version data, or None if not found

        Raises:
            Exception: on 401/403 responses, any other HTTP error, or any
                non-HTTP failure.  The original error is attached as
                ``__cause__``.
        """
        url = self._rest_url("prompt_versions", prompt_version_id)

        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()

            data = response.json()
            return data.get("data")

        except Exception as e:
            # Check if it's an HTTP error
            response = getattr(e, "response", None)
            if response is not None and hasattr(response, "status_code"):
                if response.status_code == 404:
                    return None
                elif response.status_code == 403:
                    raise Exception(
                        f"Access denied to prompt version '{prompt_version_id}'. Check your Arize Phoenix permissions."
                    ) from e
                elif response.status_code == 401:
                    raise Exception(
                        "Authentication failed. Check your Arize Phoenix API key and permissions."
                    ) from e
                else:
                    raise Exception(
                        f"Failed to fetch prompt version '{prompt_version_id}': {e}"
                    ) from e
            else:
                raise Exception(
                    f"Error fetching prompt version '{prompt_version_id}': {e}"
                ) from e

    def test_connection(self) -> bool:
        """
        Test the connection to the Arize Phoenix API.

        Returns:
            True if connection is successful, False otherwise
        """
        try:
            # Try to access the prompt_versions endpoint to test connection
            url = self._rest_url("prompt_versions")
            response = self.http_handler.client.get(url, headers=self.headers)
            response.raise_for_status()
            return True
        except Exception:
            return False

    def close(self):
        """Close the HTTP handler to free resources."""
        if hasattr(self, "http_handler"):
            self.http_handler.close()
|
||||
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
Arize Phoenix prompt manager that integrates with LiteLLM's prompt management system.
|
||||
Fetches prompt versions from Arize Phoenix and provides workspace-based access control.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from jinja2 import DictLoader, Environment, select_autoescape
|
||||
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
from litellm.integrations.prompt_management_base import (
|
||||
PromptManagementBase,
|
||||
PromptManagementClient,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
from .arize_phoenix_client import ArizePhoenixClient
|
||||
|
||||
|
||||
class ArizePhoenixPromptTemplate:
    """
    In-memory record of one prompt version loaded from Arize Phoenix.

    Holds the raw chat messages plus a handful of convenience attributes
    copied out of the ``metadata`` dict (model, provider, sampling settings,
    description, template format).
    """

    def __init__(
        self,
        template_id: str,
        messages: List[Dict[str, Any]],
        metadata: Dict[str, Any],
        model: Optional[str] = None,
    ):
        """
        Args:
            template_id: Identifier this template is stored under.
            messages: Message dicts exactly as they will later be rendered.
            metadata: Source of every derived attribute below.
            model: Explicit model name; when falsy, ``metadata["model_name"]``
                is used instead.
        """
        lookup = metadata.get

        self.template_id = template_id
        self.messages = messages
        self.metadata = metadata

        # An explicit model wins over whatever the metadata says.
        self.model = model if model else lookup("model_name")
        self.model_provider = lookup("model_provider")

        # Sampling settings; None when the metadata does not carry them.
        self.temperature = lookup("temperature")
        self.max_tokens = lookup("max_tokens")

        self.invocation_parameters = lookup("invocation_parameters", {})
        self.description = lookup("description", "")
        self.template_format = lookup("template_format", "MUSTACHE")

    def __repr__(self):
        """Short debug representation showing the template id and model."""
        ident, model_name = self.template_id, self.model
        return f"ArizePhoenixPromptTemplate(id='{ident}', model='{model_name}')"
|
||||
|
||||
|
||||
class ArizePhoenixTemplateManager:
    """
    Manager for loading and rendering prompt templates from Arize Phoenix.

    Supports:
    - Fetching prompt versions from Arize Phoenix API
    - Mustache/Handlebars-style ``{{ variable }}`` templating (using Jinja2)
    - Model configuration and invocation parameters
    - Multi-message chat templates
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        prompt_id: Optional[str] = None,
    ):
        """
        Args:
            api_key: Arize Phoenix API token, passed to ``ArizePhoenixClient``.
            api_base: Arize Phoenix API base URL, passed to ``ArizePhoenixClient``.
            prompt_id: Optional prompt version to load immediately.
        """
        self.api_key = api_key
        self.api_base = api_base
        self.prompt_id = prompt_id
        self.prompts: Dict[str, ArizePhoenixPromptTemplate] = {}
        self.arize_client = ArizePhoenixClient(
            api_key=self.api_key, api_base=self.api_base
        )

        self.jinja_env = Environment(
            loader=DictLoader({}),
            # Rendered text goes into chat messages, not HTML.  Autoescaping
            # would rewrite characters such as ' < & in prompt variables into
            # HTML entities, so it is disabled.
            autoescape=False,
            # Use Mustache/Handlebars-style delimiters
            variable_start_string="{{",
            variable_end_string="}}",
            block_start_string="{%",
            block_end_string="%}",
            comment_start_string="{#",
            comment_end_string="#}",
        )

        # Load prompt from Arize Phoenix if prompt_id is provided
        if self.prompt_id:
            self._load_prompt_from_arize(self.prompt_id)

    def _load_prompt_from_arize(self, prompt_version_id: str) -> None:
        """
        Load a specific prompt version from Arize Phoenix into ``self.prompts``.

        Raises:
            Exception: the version was not found or could not be fetched or
                parsed.  The underlying error is attached as ``__cause__``.
        """
        try:
            # Fetch the prompt version from Arize Phoenix
            prompt_data = self.arize_client.get_prompt_version(prompt_version_id)
            if not prompt_data:
                raise ValueError(f"Prompt version '{prompt_version_id}' not found")

            self.prompts[prompt_version_id] = self._parse_prompt_data(
                prompt_data, prompt_version_id
            )
        except Exception as e:
            raise Exception(
                f"Failed to load prompt version '{prompt_version_id}' from Arize Phoenix: {e}"
            ) from e

    @staticmethod
    def _select_provider_params(invocation_params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Pick the provider-specific parameter dict out of ``invocation_params``.

        Preference order: the ``"openai"`` entry, then ``"anthropic"``, then
        the first value that is a dict.  Returns ``{}`` when nothing usable
        is found.
        """
        if "openai" in invocation_params:
            selected = invocation_params["openai"]
        elif "anthropic" in invocation_params:
            selected = invocation_params["anthropic"]
        else:
            # Try to find any nested provider params
            selected = next(
                (value for value in invocation_params.values() if isinstance(value, dict)),
                {},
            )
        return selected if isinstance(selected, dict) else {}

    def _parse_prompt_data(
        self, data: Dict[str, Any], prompt_version_id: str
    ) -> ArizePhoenixPromptTemplate:
        """
        Parse Arize Phoenix prompt data and extract messages and metadata.

        Missing or null ``template`` / ``invocation_parameters`` fields are
        treated as empty.
        """
        template_data = data.get("template") or {}
        messages = template_data.get("messages") or []

        # Extract invocation parameters and the provider-specific subset
        invocation_params = data.get("invocation_parameters") or {}
        provider_params = self._select_provider_params(invocation_params)

        # Build metadata dictionary
        metadata = {
            "model_name": data.get("model_name"),
            "model_provider": data.get("model_provider"),
            "description": data.get("description", ""),
            "template_type": data.get("template_type"),
            "template_format": data.get("template_format", "MUSTACHE"),
            "invocation_parameters": invocation_params,
            "temperature": provider_params.get("temperature"),
            "max_tokens": provider_params.get("max_tokens"),
        }

        return ArizePhoenixPromptTemplate(
            template_id=prompt_version_id,
            messages=messages,
            metadata=metadata,
        )

    def _render_text(self, text: Optional[str], variables: Dict[str, Any]) -> str:
        """Render one text fragment with Jinja2 (Mustache-style ``{{ }}``)."""
        return self.jinja_env.from_string(text or "").render(**variables)

    def render_template(
        self, template_id: str, variables: Optional[Dict[str, Any]] = None
    ) -> List[AllMessageValues]:
        """
        Render a template with the given variables and return formatted messages.

        For each stored message:
        - string ``content`` is rendered directly;
        - a list made only of text parts is rendered part by part and joined
          with a single space into one string;
        - a list containing any other part type is returned as a list in
          which text parts become ``{"type": "text", "text": ...}`` and the
          other parts are passed through unchanged.

        Raises:
            ValueError: ``template_id`` has not been loaded.
        """
        if template_id not in self.prompts:
            raise ValueError(f"Template '{template_id}' not found")

        template = self.prompts[template_id]
        context = variables or {}
        rendered_messages: List[AllMessageValues] = []

        for message in template.messages:
            role = message.get("role", "user")
            content = message.get("content", [])

            final_content: Any
            if isinstance(content, str):
                final_content = self._render_text(content, context)
            else:
                # (is_text, value) pairs so text and pass-through parts can
                # be told apart after rendering.
                rendered_parts = []
                for part in content or []:
                    if isinstance(part, str):
                        rendered_parts.append((True, self._render_text(part, context)))
                    elif isinstance(part, dict) and part.get("type") == "text":
                        rendered_parts.append(
                            (True, self._render_text(part.get("text"), context))
                        )
                    else:
                        # Non-text parts cannot be templated; keep them as-is.
                        rendered_parts.append((False, part))

                if all(is_text for is_text, _ in rendered_parts):
                    # Combine rendered content
                    final_content = " ".join(value for _, value in rendered_parts)
                else:
                    final_content = [
                        {"type": "text", "text": value} if is_text else value
                        for is_text, value in rendered_parts
                    ]

            rendered_messages.append(
                {"role": role, "content": final_content}  # type: ignore
            )

        return rendered_messages

    def get_template(self, template_id: str) -> Optional[ArizePhoenixPromptTemplate]:
        """Get a template by ID, or None when it has not been loaded."""
        return self.prompts.get(template_id)

    def list_templates(self) -> List[str]:
        """List all loaded template IDs."""
        return list(self.prompts.keys())
|
||||
|
||||
|
||||
class ArizePhoenixPromptManager(CustomPromptManagement):
|
||||
"""
|
||||
Arize Phoenix prompt manager that integrates with LiteLLM's prompt management system.
|
||||
|
||||
This class enables using prompt versions from Arize Phoenix with the
|
||||
litellm completion() function by implementing the PromptManagementBase interface.
|
||||
|
||||
Usage:
|
||||
# Configure Arize Phoenix access
|
||||
arize_config = {
|
||||
"workspace": "your-workspace",
|
||||
"access_token": "your-token",
|
||||
}
|
||||
|
||||
# Use with completion
|
||||
response = litellm.completion(
|
||||
model="arize/gpt-4o",
|
||||
prompt_id="UHJvbXB0VmVyc2lvbjox",
|
||||
prompt_variables={"question": "What is AI?"},
|
||||
arize_config=arize_config,
|
||||
messages=[{"role": "user", "content": "This will be combined with the prompt"}]
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    prompt_id: Optional[str] = None,
    **kwargs,
):
    """
    Store the Arize Phoenix connection settings.

    Args:
        api_key: Arize Phoenix API token.
        api_base: Base URL of the Arize Phoenix API.
        prompt_id: Optional default prompt version ID.
        **kwargs: Forwarded unchanged to the parent initializer.
    """
    super().__init__(**kwargs)

    # The template manager is built on first access of ``prompt_manager``.
    self._prompt_manager: Optional[ArizePhoenixTemplateManager] = None

    self.prompt_id = prompt_id
    self.api_base = api_base
    self.api_key = api_key
|
||||
|
||||
@property
def integration_name(self) -> str:
    """Name of this integration: the prefix in model names such as 'arize/gpt-4o'."""
    name = "arize"
    return name
|
||||
|
||||
@property
def prompt_manager(self) -> ArizePhoenixTemplateManager:
    """
    Return the template manager, constructing it on first use.

    The instance is cached in ``self._prompt_manager`` and reused until that
    attribute is reset to None.
    """
    cached = self._prompt_manager
    if cached is not None:
        return cached

    self._prompt_manager = ArizePhoenixTemplateManager(
        api_key=self.api_key,
        api_base=self.api_base,
        prompt_id=self.prompt_id,
    )
    return self._prompt_manager
|
||||
|
||||
def get_prompt_template(
    self,
    prompt_id: str,
    prompt_variables: Optional[Dict[str, Any]] = None,
) -> Tuple[List[AllMessageValues], Dict[str, Any]]:
    """
    Get a prompt template and render it with variables.

    Args:
        prompt_id: The ID of the prompt version
        prompt_variables: Variables to substitute in the template

    Returns:
        Tuple of (rendered_messages, metadata).  ``metadata`` always has the
        keys ``model``, ``temperature`` and ``max_tokens`` (values may be
        None), plus any further provider-specific invocation parameters.

    Raises:
        ValueError: no template is loaded under ``prompt_id``.
    """
    manager = self.prompt_manager
    template = manager.get_template(prompt_id)
    if not template:
        raise ValueError(f"Prompt template '{prompt_id}' not found")

    # Render the template
    rendered_messages = manager.render_template(prompt_id, prompt_variables or {})

    # Extract metadata
    metadata: Dict[str, Any] = {
        "model": template.model,
        "temperature": template.temperature,
        "max_tokens": template.max_tokens,
    }

    # Locate the provider-specific parameters the same way the template's
    # temperature/max_tokens were located: "openai", then "anthropic", then
    # the first nested dict.
    invocation_params = template.invocation_parameters or {}
    if "openai" in invocation_params:
        provider_params = invocation_params["openai"]
    elif "anthropic" in invocation_params:
        provider_params = invocation_params["anthropic"]
    else:
        provider_params = next(
            (value for value in invocation_params.values() if isinstance(value, dict)),
            {},
        )
    if not isinstance(provider_params, dict):
        provider_params = {}

    # Add any additional parameters without overriding the keys set above
    for key, value in provider_params.items():
        metadata.setdefault(key, value)

    return rendered_messages, metadata
|
||||
|
||||
def pre_call_hook(
    self,
    user_id: Optional[str],
    messages: List[AllMessageValues],
    function_call: Optional[Union[Dict[str, Any], str]] = None,
    litellm_params: Optional[Dict[str, Any]] = None,
    prompt_id: Optional[str] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Tuple[List[AllMessageValues], Optional[Dict[str, Any]]]:
    """
    Pre-call hook that processes the prompt template before making the LLM call.

    When ``prompt_id`` is given, the rendered template messages are prepended
    to ``messages`` and the template's model / sampling parameters are written
    into ``litellm_params`` (subject to ``ignore_prompt_manager_model`` and
    ``ignore_prompt_manager_optional_params``).  Parameters the template does
    not define are left untouched.

    Any error is logged and the original messages are returned, so a prompt
    problem never fails the call.

    Returns:
        Tuple of (messages, litellm_params).
    """
    if not prompt_id:
        return messages, litellm_params

    try:
        # Get the rendered messages and metadata
        rendered_messages, prompt_metadata = self.get_prompt_template(
            prompt_id, prompt_variables
        )

        # Prepend rendered messages to existing messages
        final_messages = (
            rendered_messages + messages if rendered_messages else messages
        )

        # Update litellm_params with prompt metadata
        if litellm_params is None:
            litellm_params = {}

        # Apply model and parameters from prompt metadata
        if prompt_metadata.get("model") and not self.ignore_prompt_manager_model:
            litellm_params["model"] = prompt_metadata["model"]

        if not self.ignore_prompt_manager_optional_params:
            for param in (
                "temperature",
                "max_tokens",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
            ):
                # Skip None: the metadata always carries temperature and
                # max_tokens keys, and a None must not overwrite a value
                # the caller already supplied.
                value = prompt_metadata.get(param)
                if value is not None:
                    litellm_params[param] = value

        return final_messages, litellm_params

    except Exception as e:
        # Log error but don't fail the call
        import litellm

        litellm._logging.verbose_proxy_logger.error(
            f"Error in Arize Phoenix prompt pre_call_hook: {e}"
        )
        return messages, litellm_params
|
||||
|
||||
def get_available_prompts(self) -> List[str]:
    """Return the IDs of the prompt templates the template manager currently holds."""
    manager = self.prompt_manager
    return manager.list_templates()
|
||||
|
||||
def reload_prompts(self) -> None:
    """
    Reload prompts from Arize Phoenix.

    Drops the cached template manager so every prompt version is fetched
    again.  When a default ``prompt_id`` is configured the manager is rebuilt
    right away, so fetch errors surface here; otherwise it is rebuilt on
    next use.
    """
    # Always discard the cache: versions loaded on demand must be
    # refreshable even when no default prompt_id is configured.
    self._prompt_manager = None
    if self.prompt_id:
        # Accessing the property rebuilds the manager and reloads prompt_id.
        _ = self.prompt_manager
|
||||
|
||||
def should_run_prompt_management(
    self,
    prompt_id: Optional[str],
    prompt_spec: Optional[PromptSpec],
    dynamic_callback_params: StandardCallbackDynamicParams,
) -> bool:
    """
    Tell the caller whether prompt management applies to this request.

    For Arize Phoenix the answer is unconditionally True; none of the
    arguments are inspected.  Loading of the prompt itself happens in
    ``_compile_prompt_helper``.
    """
    return True
|
||||
|
||||
def _compile_prompt_helper(
    self,
    prompt_id: Optional[str],
    prompt_spec: Optional[PromptSpec],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
) -> PromptManagementClient:
    """
    Compile an Arize Phoenix prompt template into a PromptManagementClient structure.

    This method:
    1. Loads the prompt version from Arize Phoenix (if not already loaded)
    2. Renders it with the provided variables
    3. Returns formatted chat messages
    4. Extracts model and optional parameters from metadata

    Only parameters that actually have a value are returned as optional
    params; ``prompt_spec``, ``dynamic_callback_params``, ``prompt_label``
    and ``prompt_version`` are accepted but not used here.

    Raises:
        ValueError: ``prompt_id`` is None, or loading/rendering failed (the
            underlying error is attached as ``__cause__``).
    """
    if prompt_id is None:
        raise ValueError("prompt_id is required for Arize Phoenix prompt manager")

    try:
        # Load the prompt from Arize Phoenix if not already loaded
        if prompt_id not in self.prompt_manager.prompts:
            self.prompt_manager._load_prompt_from_arize(prompt_id)

        # Get the rendered messages and metadata
        rendered_messages, prompt_metadata = self.get_prompt_template(
            prompt_id, prompt_variables
        )

        # Extract model from metadata (if specified)
        template_model = prompt_metadata.get("model")

        # Extract optional parameters from metadata.  None values are
        # skipped: the metadata always carries temperature and max_tokens
        # keys, and a None would otherwise be forwarded as if the template
        # had set it.
        optional_params = {
            param: prompt_metadata[param]
            for param in (
                "temperature",
                "max_tokens",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
            )
            if prompt_metadata.get(param) is not None
        }

        return PromptManagementClient(
            prompt_id=prompt_id,
            prompt_template=rendered_messages,
            prompt_template_model=template_model,
            prompt_template_optional_params=optional_params,
            completed_messages=None,
        )

    except Exception as e:
        raise ValueError(f"Error compiling prompt '{prompt_id}': {e}") from e
|
||||
|
||||
async def async_compile_prompt_helper(
    self,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_spec: Optional[PromptSpec] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
) -> PromptManagementClient:
    """
    Awaitable counterpart of ``_compile_prompt_helper``.

    No asynchronous work happens here: after validating ``prompt_id`` the
    call is handed straight to the synchronous ``_compile_prompt_helper``
    and its result is returned unchanged.

    Raises:
        ValueError: If ``prompt_id`` is None.
    """
    if prompt_id is None:
        raise ValueError("prompt_id is required for Arize Phoenix prompt manager")

    # Forward every argument by keyword; the sync helper orders its
    # parameters differently (prompt_spec comes second there).
    forwarded = {
        "prompt_id": prompt_id,
        "prompt_spec": prompt_spec,
        "prompt_variables": prompt_variables,
        "dynamic_callback_params": dynamic_callback_params,
        "prompt_label": prompt_label,
        "prompt_version": prompt_version,
    }
    return self._compile_prompt_helper(**forwarded)
|
||||
|
||||
def get_chat_completion_prompt(
    self,
    model: str,
    messages: List[AllMessageValues],
    non_default_params: dict,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_spec: Optional[PromptSpec] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
    ignore_prompt_manager_model: Optional[bool] = False,
    ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
    """
    Resolve the chat completion prompt for an Arize Phoenix managed prompt.

    All work is delegated to ``PromptManagementBase.get_chat_completion_prompt``,
    which is invoked explicitly on the base class with ``self``; the
    arguments are passed through untouched and its result is returned as-is
    (annotated as a ``(model, messages, params)`` tuple).
    """
    # Explicit base-class call rather than super().
    # NOTE(review): presumably to pick this implementation regardless of the MRO — confirm.
    base_impl = PromptManagementBase.get_chat_completion_prompt
    keyword_options = {
        "prompt_spec": prompt_spec,
        "prompt_label": prompt_label,
        "prompt_version": prompt_version,
        "ignore_prompt_manager_model": ignore_prompt_manager_model,
        "ignore_prompt_manager_optional_params": ignore_prompt_manager_optional_params,
    }
    return base_impl(
        self,
        model,
        messages,
        non_default_params,
        prompt_id,
        prompt_variables,
        dynamic_callback_params,
        **keyword_options,
    )
|
||||
@@ -0,0 +1,105 @@
|
||||
import datetime
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class AthinaLogger:
    """
    Logs LiteLLM completion calls to Athina's inference logging endpoint.

    Configuration is read from the environment at construction time:
    ``ATHINA_API_KEY`` (sent as the ``athina-api-key`` header) and
    ``ATHINA_BASE_URL`` (defaults to ``https://log.athina.ai``).
    """

    def __init__(self):
        import os

        self.athina_api_key = os.getenv("ATHINA_API_KEY")
        self.headers = {
            "athina-api-key": self.athina_api_key,
            "Content-Type": "application/json",
        }
        self.athina_logging_url = (
            os.getenv("ATHINA_BASE_URL", "https://log.athina.ai")
            + "/api/v1/log/inference"
        )
        # Keys copied from the request metadata into the logged payload
        # when present.
        self.additional_keys = [
            "environment",
            "prompt_slug",
            "customer_id",
            "customer_user_id",
            "session_id",
            "external_reference_id",
            "context",
            "expected_response",
            "user_query",
            "tags",
            "user_feedback",
            "model_options",
            "custom_attributes",
        ]

    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
        """
        Build the Athina payload for one call and POST it.

        Args:
            kwargs: Call kwargs. Keys read here: ``stream``,
                ``complete_streaming_response``, ``model``, ``messages``,
                ``optional_params`` and ``litellm_params["metadata"]``.
            response_obj: Response exposing ``model_dump()``; may be falsy.
            start_time: Call start; used for ``response_time`` only when it
                is a ``datetime.datetime``.
            end_time: Call end; same condition as ``start_time``.
            print_verbose: Callable that receives status/error messages.

        Never raises: any failure is reported through ``print_verbose``.
        Streaming calls are skipped until ``complete_streaming_response``
        is available in ``kwargs``.
        """
        import json
        import traceback

        try:
            if kwargs.get("stream", False):
                if "complete_streaming_response" not in kwargs:
                    # Skip logging if the completion response is not available
                    return
                # Log the completion response in streaming mode
                completion_response = kwargs["complete_streaming_response"]
                response_json = (
                    completion_response.model_dump() if completion_response else {}
                )
            else:
                # Log the completion response in non streaming mode
                response_json = response_obj.model_dump() if response_obj else {}

            # "usage" can be present but None in a dumped response; treat
            # that like a missing key instead of failing the whole log call.
            usage = response_json.get("usage") or {}
            data = {
                "language_model_id": kwargs.get("model"),
                "request": kwargs,
                "response": response_json,
                "prompt_tokens": usage.get("prompt_tokens"),
                "completion_tokens": usage.get("completion_tokens"),
                "total_tokens": usage.get("total_tokens"),
            }

            if isinstance(end_time, datetime.datetime) and isinstance(
                start_time, datetime.datetime
            ):
                # Latency in whole milliseconds
                data["response_time"] = int(
                    (end_time - start_time).total_seconds() * 1000
                )

            if "messages" in kwargs:
                data["prompt"] = kwargs.get("messages", None)

            # Directly add tools or functions if present
            optional_params = kwargs.get("optional_params") or {}
            data.update(
                (key, value)
                for key, value in optional_params.items()
                if key in ("tools", "functions")
            )

            # Add additional metadata keys
            metadata = (kwargs.get("litellm_params") or {}).get("metadata") or {}
            for key in self.additional_keys:
                if key in metadata:
                    data[key] = metadata[key]

            response = litellm.module_level_client.post(
                self.athina_logging_url,
                headers=self.headers,
                # default=str: anything not JSON-serialisable is stringified
                data=json.dumps(data, default=str),
            )
            if response.status_code != 200:
                print_verbose(
                    f"Athina Logger Error - {response.text}, {response.status_code}"
                )
            else:
                print_verbose(f"Athina Logger Succeeded - {response.text}")
        except Exception as e:
            # Logging is best-effort: report the failure, never propagate it.
            print_verbose(
                f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}"
            )
|
||||
@@ -0,0 +1,3 @@
|
||||
from litellm.integrations.azure_sentinel.azure_sentinel import AzureSentinelLogger
|
||||
|
||||
__all__ = ["AzureSentinelLogger"]
|
||||
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Azure Sentinel Integration - sends logs to Azure Log Analytics using Logs Ingestion API
|
||||
|
||||
Azure Sentinel uses Log Analytics workspaces for data storage. This integration sends
|
||||
LiteLLM logs to the Log Analytics workspace using the Azure Monitor Logs Ingestion API.
|
||||
|
||||
Reference API: https://learn.microsoft.com/en-us/azure/azure-monitor/logs/logs-ingestion-api-overview
|
||||
|
||||
`async_log_success_event` - used by litellm proxy to send logs to Azure Sentinel
|
||||
`async_log_failure_event` - used by litellm proxy to send failure logs to Azure Sentinel
|
||||
|
||||
For batching specific details see CustomBatchLogger class
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import traceback
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class AzureSentinelLogger(CustomBatchLogger):
    """
    Logger that sends LiteLLM logs to Azure Sentinel via Azure Monitor Logs Ingestion API

    Payloads are buffered in ``self.log_queue`` and posted as a single JSON
    array to the Data Collection Rule stream endpoint. Requests are
    authenticated with an OAuth2 client-credentials bearer token that is
    cached on the instance until shortly before it expires.
    """

    def __init__(
        self,
        dcr_immutable_id: Optional[str] = None,
        stream_name: Optional[str] = None,
        endpoint: Optional[str] = None,
        tenant_id: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Azure Sentinel logger using Logs Ingestion API

        Args:
            dcr_immutable_id (str, optional): Data Collection Rule (DCR) Immutable ID.
                If not provided, will use AZURE_SENTINEL_DCR_IMMUTABLE_ID env var.
            stream_name (str, optional): Stream name from DCR (e.g., "Custom-LiteLLM").
                If not provided, will use AZURE_SENTINEL_STREAM_NAME env var or default to "Custom-LiteLLM".
            endpoint (str, optional): Data Collection Endpoint (DCE) or DCR ingestion endpoint.
                If not provided, will use AZURE_SENTINEL_ENDPOINT env var.
            tenant_id (str, optional): Azure Tenant ID for OAuth2 authentication.
                If not provided, will use AZURE_SENTINEL_TENANT_ID or AZURE_TENANT_ID env var.
            client_id (str, optional): Azure Client ID (Application ID) for OAuth2 authentication.
                If not provided, will use AZURE_SENTINEL_CLIENT_ID or AZURE_CLIENT_ID env var.
            client_secret (str, optional): Azure Client Secret for OAuth2 authentication.
                If not provided, will use AZURE_SENTINEL_CLIENT_SECRET or AZURE_CLIENT_SECRET env var.
            **kwargs: Forwarded to ``CustomBatchLogger.__init__``.

        Raises:
            ValueError: If any required setting is missing from both the
                arguments and the environment.

        Note:
            ``asyncio.create_task`` is called here, so the logger has to be
            constructed while an event loop is running.
        """
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        self.dcr_immutable_id = dcr_immutable_id or os.getenv(
            "AZURE_SENTINEL_DCR_IMMUTABLE_ID"
        )
        self.stream_name = stream_name or os.getenv(
            "AZURE_SENTINEL_STREAM_NAME", "Custom-LiteLLM"
        )
        self.endpoint = endpoint or os.getenv("AZURE_SENTINEL_ENDPOINT")
        self.tenant_id = (
            tenant_id
            or os.getenv("AZURE_SENTINEL_TENANT_ID")
            or os.getenv("AZURE_TENANT_ID")
        )
        self.client_id = (
            client_id
            or os.getenv("AZURE_SENTINEL_CLIENT_ID")
            or os.getenv("AZURE_CLIENT_ID")
        )
        self.client_secret = (
            client_secret
            or os.getenv("AZURE_SENTINEL_CLIENT_SECRET")
            or os.getenv("AZURE_CLIENT_SECRET")
        )

        if not self.dcr_immutable_id:
            raise ValueError(
                "AZURE_SENTINEL_DCR_IMMUTABLE_ID is required. Set it as an environment variable or pass dcr_immutable_id parameter."
            )
        if not self.endpoint:
            raise ValueError(
                "AZURE_SENTINEL_ENDPOINT is required. Set it as an environment variable or pass endpoint parameter."
            )
        if not self.tenant_id:
            raise ValueError(
                "AZURE_SENTINEL_TENANT_ID or AZURE_TENANT_ID is required. Set it as an environment variable or pass tenant_id parameter."
            )
        if not self.client_id:
            raise ValueError(
                "AZURE_SENTINEL_CLIENT_ID or AZURE_CLIENT_ID is required. Set it as an environment variable or pass client_id parameter."
            )
        if not self.client_secret:
            raise ValueError(
                "AZURE_SENTINEL_CLIENT_SECRET or AZURE_CLIENT_SECRET is required. Set it as an environment variable or pass client_secret parameter."
            )

        # Build API endpoint: {Endpoint}/dataCollectionRules/{DCR Immutable ID}/streams/{Stream Name}?api-version=2023-01-01
        self.api_endpoint = f"{self.endpoint.rstrip('/')}/dataCollectionRules/{self.dcr_immutable_id}/streams/{self.stream_name}?api-version=2023-01-01"

        # OAuth2 scope for Azure Monitor
        self.oauth_scope = "https://monitor.azure.com/.default"
        self.oauth_token: Optional[str] = None
        self.oauth_token_expires_at: Optional[float] = None

        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        asyncio.create_task(self.periodic_flush())
        self.log_queue: List[StandardLoggingPayload] = []

    async def _get_oauth_token(self) -> str:
        """
        Get OAuth2 Bearer token for Azure Monitor Logs Ingestion API

        A cached token is reused until 60 seconds before its recorded
        expiry; otherwise a new one is requested with the client
        credentials flow and cached on the instance.

        Returns:
            Bearer token string

        Raises:
            Exception: If the token endpoint does not answer 200, or the
                response carries no ``access_token``.
        """
        import time

        # Check if we have a valid cached token
        if (
            self.oauth_token
            and self.oauth_token_expires_at
            and time.time() < self.oauth_token_expires_at - 60
        ):  # Refresh 60 seconds before expiry
            return self.oauth_token

        # Get new token using client credentials flow.
        # These values are validated in __init__; the asserts only narrow
        # the Optional types.
        assert self.tenant_id is not None, "tenant_id is required"
        assert self.client_id is not None, "client_id is required"
        assert self.client_secret is not None, "client_secret is required"

        token_url = (
            f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token"
        )

        token_data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": self.oauth_scope,
            "grant_type": "client_credentials",
        }

        response = await self.async_httpx_client.post(
            url=token_url,
            data=token_data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        if response.status_code != 200:
            raise Exception(
                f"Failed to get OAuth2 token: {response.status_code} - {response.text}"
            )

        token_response = response.json()
        self.oauth_token = token_response.get("access_token")
        # Fall back to one hour when the response omits expires_in
        expires_in = token_response.get("expires_in", 3600)

        if not self.oauth_token:
            raise Exception("OAuth2 token response did not contain access_token")

        # Cache token expiry time
        self.oauth_token_expires_at = time.time() + expires_in

        return self.oauth_token

    async def _queue_payload_and_flush_if_full(self, kwargs, debug_message: str):
        """
        Shared body of the success and failure hooks.

        Appends ``kwargs["standard_logging_object"]`` to the queue (warning
        and returning when it is absent) and sends the batch once the queue
        holds at least ``self.batch_size`` entries. Errors are logged, never
        raised.
        """
        try:
            verbose_logger.debug(debug_message, kwargs)
            standard_logging_payload = kwargs.get("standard_logging_object", None)

            if standard_logging_payload is None:
                verbose_logger.warning(
                    "Azure Sentinel: standard_logging_object not found in kwargs"
                )
                return

            self.log_queue.append(standard_logging_payload)

            if len(self.log_queue) >= self.batch_size:
                await self.async_send_batch()

        except Exception as e:
            verbose_logger.exception(
                f"Azure Sentinel Layer Error - {str(e)}\n{traceback.format_exc()}"
            )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log success events to Azure Sentinel

        - Gets StandardLoggingPayload from kwargs
        - Adds to batch queue
        - Sends the batch once the queue reaches ``self.batch_size``

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        await self._queue_payload_and_flush_if_full(
            kwargs,
            "Azure Sentinel: Logging - Enters logging function for model %s",
        )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log failure events to Azure Sentinel

        - Gets StandardLoggingPayload from kwargs
        - Adds to batch queue
        - Sends the batch once the queue reaches ``self.batch_size``

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        await self._queue_payload_and_flush_if_full(
            kwargs,
            "Azure Sentinel: Logging - Enters failure logging function for model %s",
        )

    async def async_send_batch(self):
        """
        Sends the batch of logs to Azure Monitor Logs Ingestion API

        The queued entries are detached from ``self.log_queue`` before the
        first ``await``. Entries appended while the token request or the
        upload is in flight therefore stay queued for the next batch, and
        overlapping calls never post the same entries twice. A batch that
        fails to send is dropped (the error is logged), so the queue cannot
        grow without bound.

        NOTE(review): a caller outside this class may still clear the queue
        after this method returns — confirm against CustomBatchLogger.

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        try:
            # Snapshot + clear run without yielding to the event loop, so no
            # other coroutine can append in between.
            batch = list(self.log_queue)
            self.log_queue.clear()
            if not batch:
                return

            verbose_logger.debug(
                "Azure Sentinel - about to flush %s events", len(batch)
            )

            from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

            # Get OAuth2 token
            bearer_token = await self._get_oauth_token()

            # Convert the batch to the JSON array format expected by Logs Ingestion API
            # Each log entry should be a JSON object in the array
            body = safe_dumps(batch)

            # Set headers for Logs Ingestion API
            headers = {
                "Authorization": f"Bearer {bearer_token}",
                "Content-Type": "application/json",
            }

            # Send the request
            response = await self.async_httpx_client.post(
                url=self.api_endpoint, data=body.encode("utf-8"), headers=headers
            )

            if response.status_code not in [200, 204]:
                verbose_logger.error(
                    "Azure Sentinel API error: status_code=%s, response=%s",
                    response.status_code,
                    response.text,
                )
                raise Exception(
                    f"Failed to send logs to Azure Sentinel: {response.status_code} - {response.text}"
                )

            verbose_logger.debug(
                "Azure Sentinel: Response from API status_code: %s",
                response.status_code,
            )

        except Exception as e:
            verbose_logger.exception(
                f"Azure Sentinel Error sending batch API - {str(e)}\n{traceback.format_exc()}"
            )
|
||||
@@ -0,0 +1,179 @@
|
||||
{
|
||||
"id": "chatcmpl-2299b6a2-82a3-465a-b47c-04e685a2227f",
|
||||
"trace_id": "97311c60-9a61-4f48-a814-70139ee57868",
|
||||
"call_type": "acompletion",
|
||||
"cache_hit": null,
|
||||
"stream": true,
|
||||
"status": "success",
|
||||
"custom_llm_provider": "openai",
|
||||
"saved_cache_cost": 0.0,
|
||||
"startTime": 1766000068.28466,
|
||||
"endTime": 1766000070.07935,
|
||||
"completionStartTime": 1766000070.07935,
|
||||
"response_time": 1.79468512535095,
|
||||
"model": "gpt-4o",
|
||||
"metadata": {
|
||||
"user_api_key_hash": null,
|
||||
"user_api_key_alias": null,
|
||||
"user_api_key_team_id": null,
|
||||
"user_api_key_org_id": null,
|
||||
"user_api_key_user_id": null,
|
||||
"user_api_key_team_alias": null,
|
||||
"user_api_key_user_email": null,
|
||||
"spend_logs_metadata": null,
|
||||
"requester_ip_address": null,
|
||||
"requester_metadata": null,
|
||||
"user_api_key_end_user_id": null,
|
||||
"prompt_management_metadata": null,
|
||||
"applied_guardrails": [],
|
||||
"mcp_tool_call_metadata": null,
|
||||
"vector_store_request_metadata": null,
|
||||
"guardrail_information": null
|
||||
},
|
||||
"cache_key": null,
|
||||
"response_cost": 0.00022500000000000002,
|
||||
"total_tokens": 30,
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"request_tags": [],
|
||||
"end_user": "",
|
||||
"api_base": "",
|
||||
"model_group": "",
|
||||
"model_id": "",
|
||||
"requester_ip_address": null,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, world!"
|
||||
}
|
||||
],
|
||||
"response": {
|
||||
"id": "chatcmpl-2299b6a2-82a3-465a-b47c-04e685a2227f",
|
||||
"created": 1742855151,
|
||||
"model": "gpt-4o",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": null,
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "hi",
|
||||
"role": "assistant",
|
||||
"tool_calls": null,
|
||||
"function_call": null,
|
||||
"provider_specific_fields": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 10,
|
||||
"total_tokens": 30,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
},
|
||||
"model_parameters": {},
|
||||
"hidden_params": {
|
||||
"model_id": null,
|
||||
"cache_key": null,
|
||||
"api_base": "https://api.openai.com",
|
||||
"response_cost": 0.00022500000000000002,
|
||||
"additional_headers": {},
|
||||
"litellm_overhead_time_ms": null,
|
||||
"batch_models": null,
|
||||
"litellm_model_name": "gpt-4o"
|
||||
},
|
||||
"model_map_information": {
|
||||
"model_map_key": "gpt-4o",
|
||||
"model_map_value": {
|
||||
"key": "gpt-4o",
|
||||
"max_tokens": 16384,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 16384,
|
||||
"input_cost_per_token": 2.5e-06,
|
||||
"cache_creation_input_token_cost": null,
|
||||
"cache_read_input_token_cost": 1.25e-06,
|
||||
"input_cost_per_character": null,
|
||||
"input_cost_per_token_above_128k_tokens": null,
|
||||
"input_cost_per_query": null,
|
||||
"input_cost_per_second": null,
|
||||
"input_cost_per_audio_token": null,
|
||||
"input_cost_per_token_batches": 1.25e-06,
|
||||
"output_cost_per_token_batches": 5e-06,
|
||||
"output_cost_per_token": 1e-05,
|
||||
"output_cost_per_audio_token": null,
|
||||
"output_cost_per_character": null,
|
||||
"output_cost_per_token_above_128k_tokens": null,
|
||||
"output_cost_per_character_above_128k_tokens": null,
|
||||
"output_cost_per_second": null,
|
||||
"output_cost_per_image": null,
|
||||
"output_vector_size": null,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_assistant_prefill": false,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_audio_input": false,
|
||||
"supports_audio_output": false,
|
||||
"supports_pdf_input": false,
|
||||
"supports_embedding_image_input": false,
|
||||
"supports_native_streaming": null,
|
||||
"supports_web_search": true,
|
||||
"search_context_cost_per_query": {
|
||||
"search_context_size_low": 0.03,
|
||||
"search_context_size_medium": 0.035,
|
||||
"search_context_size_high": 0.05
|
||||
},
|
||||
"tpm": null,
|
||||
"rpm": null,
|
||||
"supported_openai_params": [
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"modalities",
|
||||
"prediction",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"seed",
|
||||
"stop",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"function_call",
|
||||
"functions",
|
||||
"max_retries",
|
||||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
"audio",
|
||||
"response_format",
|
||||
"user"
|
||||
]
|
||||
}
|
||||
},
|
||||
"error_str": null,
|
||||
"error_information": {
|
||||
"error_code": "",
|
||||
"error_class": "",
|
||||
"llm_provider": "",
|
||||
"traceback": "",
|
||||
"error_message": ""
|
||||
},
|
||||
"response_cost_failure_debug_info": null,
|
||||
"guardrail_information": null,
|
||||
"standard_built_in_tools_params": {
|
||||
"web_search_options": null,
|
||||
"file_search": null
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,400 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from litellm._uuid import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import _DEFAULT_TTL_FOR_HTTPX_CLIENTS, AZURE_STORAGE_MSFT_VERSION
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.azure.common_utils import get_azure_ad_token_from_entra_id
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class AzureBlobStorageLogger(CustomBatchLogger):
|
||||
def __init__(
    self,
    **kwargs,
):
    """
    Configure the logger from environment variables and start periodic flushing.

    Environment variables read:
        AZURE_STORAGE_ACCOUNT_NAME (required)
        AZURE_STORAGE_FILE_SYSTEM (required)
        AZURE_STORAGE_ACCOUNT_KEY (optional)
        AZURE_STORAGE_TENANT_ID, AZURE_STORAGE_CLIENT_ID,
        AZURE_STORAGE_CLIENT_SECRET (optional here)

    Args:
        **kwargs: Forwarded to the parent ``__init__``.

    Raises:
        ValueError: If a required environment variable is missing. Any
            exception raised during setup is logged and re-raised.

    Note:
        ``asyncio.create_task`` is called here, so the logger has to be
        constructed while an event loop is running.
    """
    try:
        verbose_logger.debug(
            "AzureBlobStorageLogger: in init azure blob storage logger"
        )

        # Env Variables used for Azure Storage Authentication
        self.tenant_id = os.getenv("AZURE_STORAGE_TENANT_ID")
        self.client_id = os.getenv("AZURE_STORAGE_CLIENT_ID")
        self.client_secret = os.getenv("AZURE_STORAGE_CLIENT_SECRET")
        self.azure_storage_account_key: Optional[str] = os.getenv(
            "AZURE_STORAGE_ACCOUNT_KEY"
        )

        # Required Env Variables for Azure Storage
        _azure_storage_account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
        if not _azure_storage_account_name:
            raise ValueError(
                "Missing required environment variable: AZURE_STORAGE_ACCOUNT_NAME"
            )
        self.azure_storage_account_name: str = _azure_storage_account_name
        _azure_storage_file_system = os.getenv("AZURE_STORAGE_FILE_SYSTEM")
        if not _azure_storage_file_system:
            raise ValueError(
                "Missing required environment variable: AZURE_STORAGE_FILE_SYSTEM"
            )
        self.azure_storage_file_system: str = _azure_storage_file_system
        self._service_client = None
        # Time that the azure service client expires, in order to reset the connection pool and keep it fresh
        self._service_client_timeout: Optional[float] = None

        # Internal variables used for Token based authentication
        self.azure_auth_token: Optional[
            str
        ] = None  # the Azure AD token to use for Azure Storage API requests
        self.token_expiry: Optional[
            datetime
        ] = None  # the expiry time of the current Azure AD token

        self.flush_lock = asyncio.Lock()
        self.log_queue: List[StandardLoggingPayload] = []
        super().__init__(**kwargs, flush_lock=self.flush_lock)

        # Schedule the background flush last: by now the lock, the queue and
        # the parent's state all exist, and a failure earlier in this method
        # can no longer leave a task running against a half-built logger.
        asyncio.create_task(self.periodic_flush())
    except Exception as e:
        verbose_logger.exception(
            f"AzureBlobStorageLogger: Got exception on init AzureBlobStorageLogger client {str(e)}"
        )
        raise
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Queue the standard logging payload of a successful call.

    The premium-user check runs first; then
    ``kwargs["standard_logging_object"]`` is appended to ``self.log_queue``.
    Nothing is uploaded here.

    Raises:
        Nothing. Every failure, including a missing payload, is reported
        through ``verbose_logger.exception`` and swallowed.
    """
    try:
        self._premium_user_check()
        verbose_logger.debug(
            "AzureBlobStorageLogger: Logging - Enters logging function for model %s",
            kwargs,
        )

        queued_entry: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if queued_entry is not None:
            self.log_queue.append(queued_entry)
        else:
            # Raised only so the handler below logs it with a traceback.
            raise ValueError("standard_logging_payload is not set")
    except Exception as layer_error:
        verbose_logger.exception(
            f"AzureBlobStorageLogger Layer Error - {str(layer_error)}"
        )
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """
    Queue the standard logging payload of a failed call.

    The premium-user check runs first; then
    ``kwargs["standard_logging_object"]`` is appended to ``self.log_queue``.
    Nothing is uploaded here.

    Raises:
        Nothing. Every failure, including a missing payload, is reported
        through ``verbose_logger.exception`` and swallowed.
    """
    try:
        self._premium_user_check()
        verbose_logger.debug(
            "AzureBlobStorageLogger: Logging - Enters logging function for model %s",
            kwargs,
        )

        queued_entry: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if queued_entry is not None:
            self.log_queue.append(queued_entry)
        else:
            # Raised only so the handler below logs it with a traceback.
            raise ValueError("standard_logging_payload is not set")
    except Exception as layer_error:
        verbose_logger.exception(
            f"AzureBlobStorageLogger Layer Error - {str(layer_error)}"
        )
|
||||
|
||||
async def async_send_batch(self):
    """
    Sends the in memory logs queue to Azure Blob Storage

    Each queued payload is uploaded in turn with
    ``async_upload_payload_to_azure_blob_storage``. The first upload that
    raises stops the loop; the error is logged and not propagated. The
    queue itself is not modified here.

    Raises:
        Raises a NON Blocking verbose_logger.exception if an error occurs
    """
    try:
        if not self.log_queue:
            # An empty queue is a normal condition, not an error, so it is
            # reported at debug level under this logger's own name.
            verbose_logger.debug(
                "AzureBlobStorageLogger: log_queue is empty, nothing to flush"
            )
            return

        verbose_logger.debug(
            "AzureBlobStorageLogger - about to flush %s events",
            len(self.log_queue),
        )

        for payload in self.log_queue:
            await self.async_upload_payload_to_azure_blob_storage(payload=payload)

    except Exception as e:
        verbose_logger.exception(
            f"AzureBlobStorageLogger Error sending batch API - {str(e)}"
        )
|
||||
|
||||
async def async_upload_payload_to_azure_blob_storage(
    self, payload: StandardLoggingPayload
):
    """
    Upload one payload to Azure Blob Storage.

    When an account key is configured the work is handed to
    ``upload_to_azure_data_lake_with_azure_account_key``. Otherwise an
    Azure AD token is ensured and the payload is written as
    ``<id>.json`` in three requests:
    1. Create file resource
    2. Append data
    3. Flush the data

    Raises:
        Any exception from the upload, after logging it.
    """
    try:
        if self.azure_storage_account_key:
            await self.upload_to_azure_data_lake_with_azure_account_key(
                payload=payload
            )
            return

        # Get a valid token instead of always requesting a new one
        await self.set_valid_azure_ad_token()
        http_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        # Add newline for each log entry
        serialized = safe_dumps(payload) + "\n"
        encoded = serialized.encode("utf-8")

        # Fall back to a random name when the payload carries no id
        filename = f"{payload.get('id') or str(uuid.uuid4())}.json"
        base_url = f"https://{self.azure_storage_account_name}.dfs.core.windows.net/{self.azure_storage_file_system}/{filename}"

        # Execute the 3-step upload process; the flush position is the
        # byte length of the encoded payload.
        await self._create_file(http_client, base_url)
        await self._append_data(http_client, base_url, serialized)
        await self._flush_data(http_client, base_url, len(encoded))

        verbose_logger.debug(
            f"Successfully uploaded log to Azure Blob Storage: {filename}"
        )
    except Exception as e:
        verbose_logger.exception(f"Error uploading to Azure Blob Storage: {str(e)}")
        raise e
|
||||
|
||||
async def _create_file(self, client: AsyncHTTPHandler, base_url: str):
    """
    Step 1 of the upload: create the file resource.

    Issues ``PUT {base_url}?resource=file`` with the current bearer token
    and raises on a non-success status. Failures are logged and re-raised.
    """
    try:
        verbose_logger.debug(f"Creating file resource at: {base_url}")
        request_headers = {}
        request_headers["x-ms-version"] = AZURE_STORAGE_MSFT_VERSION
        request_headers["Content-Length"] = "0"
        request_headers["Authorization"] = f"Bearer {self.azure_auth_token}"

        create_url = f"{base_url}?resource=file"
        put_response = await client.put(create_url, headers=request_headers)
        put_response.raise_for_status()
        verbose_logger.debug("Successfully created file resource")
    except Exception as create_error:
        verbose_logger.exception(
            f"Error creating file resource: {str(create_error)}"
        )
        raise
|
||||
|
||||
async def _append_data(
    self, client: AsyncHTTPHandler, base_url: str, json_payload: str
):
    """
    Step 2 of the upload: append the payload to the file.

    Issues ``PATCH {base_url}?action=append&position=0`` with
    ``json_payload`` as the body and raises on a non-success status.
    Failures are logged and re-raised.
    """
    try:
        verbose_logger.debug(f"Appending data to file: {base_url}")
        request_headers = {}
        request_headers["x-ms-version"] = AZURE_STORAGE_MSFT_VERSION
        request_headers["Content-Type"] = "application/json"
        request_headers["Authorization"] = f"Bearer {self.azure_auth_token}"

        # Always written at offset 0.
        append_url = f"{base_url}?action=append&position=0"
        patch_response = await client.patch(
            append_url,
            headers=request_headers,
            data=json_payload,
        )
        patch_response.raise_for_status()
        verbose_logger.debug("Successfully appended data")
    except Exception as append_error:
        verbose_logger.exception(f"Error appending data: {str(append_error)}")
        raise
|
||||
|
||||
async def _flush_data(self, client: AsyncHTTPHandler, base_url: str, position: int):
    """
    Step 3 of the upload: flush the appended data.

    Issues ``PATCH {base_url}?action=flush&position={position}`` and raises
    on a non-success status. Failures are logged and re-raised.

    Args:
        client: HTTP handler used for the request.
        base_url: URL of the file resource.
        position: Value placed in the ``position`` query parameter (the
            caller in this file passes the encoded payload length).
    """
    try:
        verbose_logger.debug(f"Flushing data at position {position}")
        request_headers = {}
        request_headers["x-ms-version"] = AZURE_STORAGE_MSFT_VERSION
        request_headers["Content-Length"] = "0"
        request_headers["Authorization"] = f"Bearer {self.azure_auth_token}"

        flush_url = f"{base_url}?action=flush&position={position}"
        patch_response = await client.patch(flush_url, headers=request_headers)
        patch_response.raise_for_status()
        verbose_logger.debug("Successfully flushed data")
    except Exception as flush_error:
        verbose_logger.exception(f"Error flushing data: {str(flush_error)}")
        raise
|
||||
|
||||
####### Helper methods to managing Authentication to Azure Storage #######
|
||||
##########################################################################
|
||||
|
||||
async def set_valid_azure_ad_token(self):
    """
    Ensure ``self.azure_auth_token`` holds a usable Azure AD token.

    A new token is fetched when:
    - ``_azure_ad_token_is_expired()`` reports the current one as expired
    - no token is set

    After a refresh ``self.token_expiry`` is set to one hour from now.
    """
    must_refresh = (
        self._azure_ad_token_is_expired() or self.azure_auth_token is None
    )
    if not must_refresh:
        return

    verbose_logger.debug("Azure AD token needs refresh")
    self.azure_auth_token = self.get_azure_ad_token_from_azure_storage(
        tenant_id=self.tenant_id,
        client_id=self.client_id,
        client_secret=self.client_secret,
    )
    # The lifetime is assumed, not read from the token: typically 1 hour
    one_hour = timedelta(hours=1)
    self.token_expiry = datetime.now() + one_hour
    verbose_logger.debug(f"New token will expire at {self.token_expiry}")
|
||||
|
||||
def get_azure_ad_token_from_azure_storage(
    self,
    tenant_id: Optional[str],
    client_id: Optional[str],
    client_secret: Optional[str],
) -> str:
    """
    Gets Azure AD token to use for Azure Storage API requests

    Args:
        tenant_id: Azure AD tenant ID; must not be None.
        client_id: Application (client) ID; must not be None.
        client_secret: Application client secret; must not be None.

    Returns:
        The token produced by the provider returned from
        ``get_azure_ad_token_from_entra_id`` for the
        ``https://storage.azure.com/.default`` scope.

    Raises:
        ValueError: If any of the three credentials is None.
    """
    # Fail fast on missing credentials before doing anything else.
    required_settings = (
        (tenant_id, "AZURE_STORAGE_TENANT_ID"),
        (client_id, "AZURE_STORAGE_CLIENT_ID"),
        (client_secret, "AZURE_STORAGE_CLIENT_SECRET"),
    )
    for value, env_var_name in required_settings:
        if value is None:
            raise ValueError(f"Missing required environment variable: {env_var_name}")

    verbose_logger.debug("Getting Azure AD Token from Azure Storage")
    # Security: never write the client secret or the issued bearer token to the
    # logs; only the non-secret identifiers are recorded.
    verbose_logger.debug("tenant_id %s, client_id %s", tenant_id, client_id)

    token_provider = get_azure_ad_token_from_entra_id(
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret=client_secret,
        scope="https://storage.azure.com/.default",
    )
    token = token_provider()

    verbose_logger.debug("Successfully retrieved Azure AD token for Azure Storage")

    return token
|
||||
|
||||
def _azure_ad_token_is_expired(self):
    """
    Report whether the stored Azure AD token should be treated as expired.

    Returns:
        True when both a token and an expiry are stored and the expiry is no
        more than five minutes away (or already past); False otherwise,
        including when either value is missing.
    """
    if not (self.azure_auth_token and self.token_expiry):
        return False

    # Treat the token as expired five minutes early so it is renewed before
    # the recorded expiry is actually reached.
    refresh_deadline = datetime.now() + timedelta(minutes=5)
    if refresh_deadline >= self.token_expiry:
        verbose_logger.debug("Azure AD token is expired. Requesting new token")
        return True
    return False
|
||||
|
||||
def _premium_user_check(self):
    """
    Guard that only lets premium users through.

    Raises:
        ValueError: If ``premium_user`` from the proxy server module is not
            exactly ``True``.
    """
    # Imported at call time rather than at module import time.
    from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

    if premium_user is True:
        return

    raise ValueError(
        f"AzureBlobStorageLogger is only available for premium users. {CommonProxyErrors.not_premium_user}"
    )
|
||||
|
||||
async def get_service_client(self):
    """
    Return the cached ``DataLakeServiceClient``, creating one when needed.

    A cached client is reused until its recorded timeout has passed. Once it
    has, the old client is closed and a new one is built so the logger can
    recover from connection issues.

    Returns:
        The client stored in ``self._service_client``.
    """
    # expire old clients to recover from connection issues
    # The client is stale once the current time has reached its timeout. The
    # previous `timeout > time.time()` check was inverted: it closed clients
    # that were still valid and kept the ones that had expired.
    if (
        self._service_client
        and self._service_client_timeout
        and self._service_client_timeout <= time.time()
    ):
        await self._service_client.close()
        self._service_client = None

    if not self._service_client:
        # Imported only when a client actually has to be built.
        from azure.storage.filedatalake.aio import DataLakeServiceClient

        self._service_client = DataLakeServiceClient(
            account_url=f"https://{self.azure_storage_account_name}.dfs.core.windows.net",
            credential=self.azure_storage_account_key,
        )
        self._service_client_timeout = time.time() + _DEFAULT_TTL_FOR_HTTPX_CLIENTS
    return self._service_client
|
||||
|
||||
async def upload_to_azure_data_lake_with_azure_account_key(
    self, payload: StandardLoggingPayload
):
    """
    Write one logging payload to Azure Data Lake through the Azure SDK.

    This is used when Azure Storage Account Key is set - Azure Storage Account Key does not work directly with Azure Rest API

    The payload is serialized with ``safe_dumps`` and stored as
    ``<YYYY-MM-DD>/<id>.json``. The date comes from the current local time, and
    a random UUID is used when the payload has no truthy ``id``. Errors raised
    during the directory/file operations are logged and not re-raised.

    Args:
        payload: The logging payload to persist.
    """
    # Obtain the (possibly cached) service client, then the file system client.
    data_lake_service = await self.get_service_client()
    fs_client = data_lake_service.get_file_system_client(
        file_system=self.azure_storage_file_system
    )

    try:
        from datetime import datetime

        # One directory per day, created on first use.
        today = datetime.now().strftime("%Y-%m-%d")
        day_directory = fs_client.get_directory_client(today)

        directory_exists = await day_directory.exists()
        if not directory_exists:
            await day_directory.create_directory()
            verbose_logger.debug(f"Created directory: {today}")

        # One JSON file per payload.
        file_name = f"{payload.get('id') or str(uuid.uuid4())}.json"
        target_file = day_directory.get_file_client(file_name)
        await target_file.create_file()

        # Serialize once; the byte length is needed for both append and flush.
        encoded_payload = safe_dumps(payload).encode("utf-8")
        payload_size = len(encoded_payload)

        await target_file.append_data(
            data=encoded_payload, offset=0, length=payload_size
        )
        # Flush the content to finalize the file
        await target_file.flush_data(position=payload_size, offset=0)

        verbose_logger.debug(
            f"Successfully uploaded and wrote to {today}/{file_name}"
        )
    except Exception as e:
        # Best effort: a failed upload is logged and not propagated.
        verbose_logger.exception(f"Error occurred: {str(e)}")
|
||||
@@ -0,0 +1,317 @@
|
||||
# LiteLLM BitBucket Prompt Management
|
||||
|
||||
A powerful prompt management system for LiteLLM that fetches `.prompt` files from BitBucket repositories. This enables team-based prompt management with BitBucket's built-in access control and version control capabilities.
|
||||
|
||||
## Features
|
||||
|
||||
- **🏢 Team-based access control**: Leverage BitBucket's workspace and repository permissions
|
||||
- **📁 Repository-based prompt storage**: Store prompts in BitBucket repositories
|
||||
- **🔐 Multiple authentication methods**: Support for access tokens and basic auth
|
||||
- **🎯 YAML frontmatter**: Define model, parameters, and schemas in file headers
|
||||
- **🔧 Handlebars templating**: Use `{{variable}}` syntax with Jinja2 backend
|
||||
- **✅ Input validation**: Automatic validation against defined schemas
|
||||
- **🔗 LiteLLM integration**: Works seamlessly with `litellm.completion()`
|
||||
- **💬 Smart message parsing**: Converts prompts to proper chat messages
|
||||
- **⚙️ Parameter extraction**: Automatically applies model settings from prompts
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Set up BitBucket Repository
|
||||
|
||||
Create a repository in your BitBucket workspace and add `.prompt` files:
|
||||
|
||||
```
|
||||
your-repo/
|
||||
├── prompts/
|
||||
│ ├── chat_assistant.prompt
|
||||
│ ├── code_reviewer.prompt
|
||||
│ └── data_analyst.prompt
|
||||
```
|
||||
|
||||
### 2. Create a `.prompt` file
|
||||
|
||||
Create a file called `prompts/chat_assistant.prompt`:
|
||||
|
||||
```yaml
|
||||
---
|
||||
model: gpt-4
|
||||
temperature: 0.7
|
||||
max_tokens: 150
|
||||
input:
|
||||
schema:
|
||||
user_message: string
|
||||
system_context?: string
|
||||
---
|
||||
|
||||
{% if system_context %}System: {{system_context}}
|
||||
|
||||
{% endif %}User: {{user_message}}
|
||||
```
|
||||
|
||||
### 3. Configure BitBucket Access
|
||||
|
||||
#### Option A: Access Token (Recommended)
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
# Configure BitBucket access
|
||||
bitbucket_config = {
|
||||
"workspace": "your-workspace",
|
||||
"repository": "your-repo",
|
||||
"access_token": "your-access-token",
|
||||
"branch": "main" # optional, defaults to main
|
||||
}
|
||||
|
||||
# Set global BitBucket configuration
|
||||
litellm.set_global_bitbucket_config(bitbucket_config)
|
||||
```
|
||||
|
||||
#### Option B: Basic Authentication
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
# Configure BitBucket access with basic auth
|
||||
bitbucket_config = {
|
||||
"workspace": "your-workspace",
|
||||
"repository": "your-repo",
|
||||
"username": "your-username",
|
||||
"access_token": "your-app-password", # Use app password for basic auth
|
||||
"auth_method": "basic",
|
||||
"branch": "main"
|
||||
}
|
||||
|
||||
litellm.set_global_bitbucket_config(bitbucket_config)
|
||||
```
|
||||
|
||||
### 4. Use with LiteLLM
|
||||
|
||||
```python
|
||||
# Use with completion - the model prefix 'bitbucket/' tells LiteLLM to use BitBucket prompt management
|
||||
response = litellm.completion(
|
||||
model="bitbucket/gpt-4", # The actual model comes from the .prompt file
|
||||
prompt_id="prompts/chat_assistant", # Location of the prompt file
|
||||
prompt_variables={
|
||||
"user_message": "What is machine learning?",
|
||||
"system_context": "You are a helpful AI tutor."
|
||||
},
|
||||
# Any additional messages will be appended after the prompt
|
||||
messages=[{"role": "user", "content": "Please explain it simply."}]
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Proxy Server Configuration
|
||||
|
||||
### 1. Create a `.prompt` file
|
||||
|
||||
Create `prompts/hello.prompt`:
|
||||
|
||||
```yaml
|
||||
---
|
||||
model: gpt-4
|
||||
temperature: 0.7
|
||||
---
|
||||
System: You are a helpful assistant.
|
||||
|
||||
User: {{user_message}}
|
||||
```
|
||||
|
||||
### 2. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: my-bitbucket-model
|
||||
litellm_params:
|
||||
model: bitbucket/gpt-4
|
||||
prompt_id: "prompts/hello"
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
global_bitbucket_config:
|
||||
workspace: "your-workspace"
|
||||
repository: "your-repo"
|
||||
access_token: "your-access-token"
|
||||
branch: "main"
|
||||
```
|
||||
|
||||
### 3. Start the proxy
|
||||
|
||||
```bash
|
||||
litellm --config config.yaml --detailed_debug
|
||||
```
|
||||
|
||||
### 4. Test it!
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-d '{
|
||||
"model": "my-bitbucket-model",
|
||||
"messages": [{"role": "user", "content": "IGNORED"}],
|
||||
"prompt_variables": {
|
||||
"user_message": "What is the capital of France?"
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## Prompt File Format
|
||||
|
||||
### Basic Structure
|
||||
|
||||
```yaml
|
||||
---
|
||||
# Model configuration
|
||||
model: gpt-4
|
||||
temperature: 0.7
|
||||
max_tokens: 500
|
||||
|
||||
# Input schema (optional)
|
||||
input:
|
||||
schema:
|
||||
user_message: string
|
||||
system_context?: string
|
||||
---
|
||||
|
||||
System: You are a helpful {{role}} assistant.
|
||||
|
||||
User: {{user_message}}
|
||||
```
|
||||
|
||||
### Advanced Features
|
||||
|
||||
**Multi-role conversations:**
|
||||
|
||||
```yaml
|
||||
---
|
||||
model: gpt-4
|
||||
temperature: 0.3
|
||||
---
|
||||
System: You are a helpful coding assistant.
|
||||
|
||||
User: {{user_question}}
|
||||
```
|
||||
|
||||
**Dynamic model selection:**
|
||||
|
||||
```yaml
|
||||
---
|
||||
model: "{{preferred_model}}" # Model can be a variable
|
||||
temperature: 0.7
|
||||
---
|
||||
System: You are a helpful assistant specialized in {{domain}}.
|
||||
|
||||
User: {{user_message}}
|
||||
```
|
||||
|
||||
## Team-Based Access Control
|
||||
|
||||
BitBucket's built-in permission system provides team-based access control:
|
||||
|
||||
1. **Workspace-level permissions**: Control access to entire workspaces
|
||||
2. **Repository-level permissions**: Control access to specific repositories
|
||||
3. **Branch-level permissions**: Control access to specific branches
|
||||
4. **User and group management**: Manage team members and their access levels
|
||||
|
||||
### Setting up Team Access
|
||||
|
||||
1. **Create workspaces for each team**:
|
||||
```
|
||||
team-a-prompts/
|
||||
team-b-prompts/
|
||||
team-c-prompts/
|
||||
```
|
||||
|
||||
2. **Configure repository permissions**:
|
||||
- Grant read access to team members
|
||||
- Grant write access to prompt maintainers
|
||||
- Use branch protection rules for production prompts
|
||||
|
||||
3. **Use different access tokens**:
|
||||
- Each team can have their own access token
|
||||
- Tokens can be scoped to specific repositories
|
||||
- Use app passwords for additional security
|
||||
|
||||
## API Reference
|
||||
|
||||
### BitBucket Configuration
|
||||
|
||||
```python
|
||||
bitbucket_config = {
|
||||
"workspace": str, # Required: BitBucket workspace name
|
||||
"repository": str, # Required: Repository name
|
||||
"access_token": str, # Required: BitBucket access token or app password
|
||||
"branch": str, # Optional: Branch to fetch from (default: "main")
|
||||
"base_url": str, # Optional: Custom BitBucket API URL
|
||||
"auth_method": str, # Optional: "token" or "basic" (default: "token")
|
||||
"username": str, # Optional: Username for basic auth
|
||||
# Note: set "base_url" only when the API base URL is not https://api.bitbucket.org/2.0
|
||||
}
|
||||
```
|
||||
|
||||
### LiteLLM Integration
|
||||
|
||||
```python
|
||||
response = litellm.completion(
|
||||
model="bitbucket/<base_model>", # required (e.g., bitbucket/gpt-4)
|
||||
prompt_id=str, # required - the .prompt filename without extension
|
||||
prompt_variables=dict, # optional - variables for template rendering
|
||||
bitbucket_config=dict, # optional - BitBucket configuration (if not set globally)
|
||||
messages=list, # optional - additional messages
|
||||
)
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The BitBucket integration provides detailed error messages for common issues:
|
||||
|
||||
- **Authentication errors**: Invalid access tokens or credentials
|
||||
- **Permission errors**: Insufficient access to workspace/repository
|
||||
- **File not found**: Missing .prompt files
|
||||
- **Network errors**: Connection issues with BitBucket API
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Access Token Security**: Store access tokens securely using environment variables or secret management systems
|
||||
2. **Repository Permissions**: Use BitBucket's permission system to control access
|
||||
3. **Branch Protection**: Protect main branches from unauthorized changes
|
||||
4. **Audit Logging**: BitBucket provides audit logs for all repository access
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **"Access denied" errors**: Check your BitBucket permissions for the workspace and repository
|
||||
2. **"Authentication failed" errors**: Verify your access token or credentials
|
||||
3. **"File not found" errors**: Ensure the .prompt file exists in the specified branch
|
||||
4. **Template rendering errors**: Check your Handlebars syntax in the .prompt file
|
||||
|
||||
### Debug Mode
|
||||
|
||||
Enable debug logging to troubleshoot issues:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
litellm.set_verbose = True
|
||||
|
||||
# Your BitBucket prompt calls will now show detailed logs
|
||||
response = litellm.completion(
|
||||
model="bitbucket/gpt-4",
|
||||
prompt_id="your_prompt",
|
||||
prompt_variables={"key": "value"}
|
||||
)
|
||||
```
|
||||
|
||||
## Migration from File-Based Prompts
|
||||
|
||||
If you're currently using file-based prompts with the dotprompt integration, you can easily migrate to BitBucket:
|
||||
|
||||
1. **Upload your .prompt files** to a BitBucket repository
|
||||
2. **Update your configuration** to use BitBucket instead of local files
|
||||
3. **Set up team access** using BitBucket's permission system
|
||||
4. **Update your code** to use `bitbucket/` model prefix instead of `dotprompt/`
|
||||
|
||||
This provides better collaboration, version control, and team-based access control for your prompts.
|
||||
@@ -0,0 +1,66 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bitbucket_prompt_manager import BitBucketPromptManager
|
||||
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
|
||||
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
|
||||
|
||||
from .bitbucket_prompt_manager import BitBucketPromptManager
|
||||
|
||||
# Global instances
# Module-level default for the BitBucket configuration. Note that
# set_global_bitbucket_config() in this module assigns the attribute on the
# `litellm` package and does not rebind this name.
global_bitbucket_config: Optional[dict] = None
|
||||
|
||||
|
||||
def set_global_bitbucket_config(config: dict) -> None:
    """
    Store the BitBucket configuration used for prompt management on ``litellm``.

    Args:
        config: Dictionary containing BitBucket configuration
            - workspace: BitBucket workspace name
            - repository: Repository name
            - access_token: BitBucket access token
            - branch: Branch to fetch prompts from (default: main)
    """
    # Imported at call time rather than when this module is imported.
    import litellm as litellm_module

    setattr(litellm_module, "global_bitbucket_config", config)
|
||||
|
||||
|
||||
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Initialize a prompt from a BitBucket repository.

    Args:
        litellm_params: Object whose ``bitbucket_config`` and ``prompt_id``
            attributes are read; missing attributes are treated as None.
        prompt_spec: Not used by this initializer.

    Returns:
        A ``BitBucketPromptManager`` built from the config and prompt id.

    Raises:
        ValueError: If ``bitbucket_config`` is missing or empty.
        Exception: Any error raised while constructing the manager propagates
            unchanged.
    """
    bitbucket_config = getattr(litellm_params, "bitbucket_config", None)
    prompt_id = getattr(litellm_params, "prompt_id", None)

    if not bitbucket_config:
        raise ValueError(
            "bitbucket_config is required for BitBucket prompt integration"
        )

    # The previous `try: ... except Exception as e: raise e` wrapper only
    # re-raised the same exception, so it has been removed.
    return BitBucketPromptManager(
        bitbucket_config=bitbucket_config,
        prompt_id=prompt_id,
    )
|
||||
|
||||
|
||||
# Maps the BitBucket integration identifier to the function that builds its
# prompt manager.
prompt_initializer_registry = {
    SupportedPromptIntegrations.BITBUCKET.value: prompt_initializer,
}

# Export public API
__all__ = [
    "BitBucketPromptManager",
    "set_global_bitbucket_config",
    "global_bitbucket_config",
]
|
||||
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
BitBucket API client for fetching .prompt files from BitBucket repositories.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
|
||||
class BitBucketClient:
    """
    Client for interacting with BitBucket API to fetch .prompt files.

    Supports:
    - Authentication with access tokens
    - Fetching file contents from repositories
    - Team-based access control through BitBucket permissions
    - Branch-specific file fetching
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the BitBucket client.

        Args:
            config: Dictionary containing:
                - workspace: BitBucket workspace name
                - repository: Repository name
                - access_token: BitBucket access token (or app password)
                - branch: Branch to fetch from (default: main)
                - base_url: Custom BitBucket API base URL (optional)
                - auth_method: Authentication method ('token' or 'basic', default: 'token')
                - username: Username for basic auth (optional)

        Raises:
            ValueError: If workspace, repository, or access_token is missing or empty.
        """
        self.workspace = config.get("workspace")
        self.repository = config.get("repository")
        self.access_token = config.get("access_token")
        self.branch = config.get("branch", "main")
        # The documented key is "base_url". The previous lookup used an empty
        # string key, so a custom API URL was silently ignored. A missing or
        # empty value falls back to the public BitBucket Cloud API.
        self.base_url = config.get("base_url") or "https://api.bitbucket.org/2.0"
        self.auth_method = config.get("auth_method", "token")
        self.username = config.get("username")

        if not all([self.workspace, self.repository, self.access_token]):
            raise ValueError("workspace, repository, and access_token are required")

        # Set up authentication headers
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

        if self.auth_method == "basic" and self.username:
            # Use basic auth with username and app password
            credentials = f"{self.username}:{self.access_token}"
            encoded_credentials = base64.b64encode(credentials.encode()).decode()
            self.headers["Authorization"] = f"Basic {encoded_credentials}"
        else:
            # Use token-based authentication (default). This branch is also
            # taken when auth_method is "basic" but no username was supplied.
            self.headers["Authorization"] = f"Bearer {self.access_token}"

        # Initialize HTTPHandler
        self.http_handler = HTTPHandler()

    def get_file_content(self, file_path: str) -> Optional[str]:
        """
        Fetch the content of a file from the BitBucket repository.

        Args:
            file_path: Path to the file in the repository

        Returns:
            File content as string, or None if file not found

        Raises:
            Exception: On 401/403 responses, other HTTP errors, or any other
                failure while fetching the file.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}/src/{self.branch}/{file_path}"

        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()

            # Text responses are returned as-is.
            if response.headers.get("content-type", "").startswith("text/"):
                return response.text
            else:
                # For any other content type, first try to interpret the body
                # as base64 and fall back to the raw text if that fails.
                # NOTE(review): the original comment stated that BitBucket
                # returns base64-encoded content - confirm against the API.
                try:
                    return base64.b64decode(response.content).decode("utf-8")
                except Exception:
                    return response.text

        except Exception as e:
            # Check if it's an HTTP error
            if hasattr(e, "response") and hasattr(e.response, "status_code"):
                if e.response.status_code == 404:
                    return None
                elif e.response.status_code == 403:
                    raise Exception(
                        f"Access denied to file '{file_path}'. Check your BitBucket permissions for workspace '{self.workspace}' and repository '{self.repository}'."
                    ) from e
                elif e.response.status_code == 401:
                    raise Exception(
                        "Authentication failed. Check your BitBucket access token and permissions."
                    ) from e
                else:
                    raise Exception(f"Failed to fetch file '{file_path}': {e}") from e
            else:
                raise Exception(f"Error fetching file '{file_path}': {e}") from e

    def list_files(
        self, directory_path: str = "", file_extension: str = ".prompt"
    ) -> List[str]:
        """
        List files in a directory with a specific extension.

        Only the entries in the "values" list of the single response are
        inspected.

        Args:
            directory_path: Directory path in the repository (empty for root)
            file_extension: File extension to filter by (default: .prompt)

        Returns:
            List of file paths; an empty list when the directory returns 404.

        Raises:
            Exception: On 401/403 responses, other HTTP errors, or any other
                failure while listing the directory.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}/src/{self.branch}/{directory_path}"

        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()

            data = response.json()
            # Keep only file entries whose path has the requested extension.
            return [
                item.get("path", "")
                for item in data.get("values", [])
                if item.get("type") == "commit_file"
                and item.get("path", "").endswith(file_extension)
            ]

        except Exception as e:
            # Check if it's an HTTP error
            if hasattr(e, "response") and hasattr(e.response, "status_code"):
                if e.response.status_code == 404:
                    return []
                elif e.response.status_code == 403:
                    raise Exception(
                        f"Access denied to directory '{directory_path}'. Check your BitBucket permissions for workspace '{self.workspace}' and repository '{self.repository}'."
                    ) from e
                elif e.response.status_code == 401:
                    raise Exception(
                        "Authentication failed. Check your BitBucket access token and permissions."
                    ) from e
                else:
                    raise Exception(
                        f"Failed to list files in '{directory_path}': {e}"
                    ) from e
            else:
                raise Exception(
                    f"Error listing files in '{directory_path}': {e}"
                ) from e

    def get_repository_info(self) -> Dict[str, Any]:
        """
        Get information about the repository.

        Returns:
            Dictionary containing repository information

        Raises:
            Exception: If the request fails for any reason.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}"

        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            raise Exception(f"Failed to get repository info: {e}") from e

    def test_connection(self) -> bool:
        """
        Test the connection to the BitBucket repository.

        Returns:
            True if connection is successful, False otherwise
        """
        try:
            self.get_repository_info()
            return True
        except Exception:
            # Any failure counts as "not connected".
            return False

    def get_branches(self) -> List[Dict[str, Any]]:
        """
        Get list of branches in the repository.

        Returns:
            List of branch information dictionaries (the "values" list of the
            response, or an empty list when that key is absent).

        Raises:
            Exception: If the request fails for any reason.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}/refs/branches"

        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()

            data = response.json()
            return data.get("values", [])
        except Exception as e:
            raise Exception(f"Failed to get branches: {e}") from e

    def get_file_metadata(self, file_path: str) -> Optional[Dict[str, Any]]:
        """
        Get metadata about a file (size, last modified, etc.).

        Args:
            file_path: Path to the file in the repository

        Returns:
            Dictionary with the "content_type", "content_length" and
            "last_modified" response headers, or None if file not found

        Raises:
            Exception: On non-404 HTTP errors or any other failure.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}/src/{self.branch}/{file_path}"

        try:
            # Use GET with Range header to get just the headers (HEAD equivalent)
            headers = self.headers.copy()
            headers["Range"] = "bytes=0-0"  # Request only first byte to get headers

            response = self.http_handler.get(url, headers=headers)
            response.raise_for_status()

            return {
                "content_type": response.headers.get("content-type"),
                "content_length": response.headers.get("content-length"),
                "last_modified": response.headers.get("last-modified"),
            }
        except Exception as e:
            # Check if it's an HTTP error
            if hasattr(e, "response") and hasattr(e.response, "status_code"):
                if e.response.status_code == 404:
                    return None
                raise Exception(
                    f"Failed to get file metadata for '{file_path}': {e}"
                ) from e
            else:
                raise Exception(
                    f"Error getting file metadata for '{file_path}': {e}"
                ) from e

    def close(self):
        """Close the HTTP handler to free resources."""
        # __init__ may have raised before the handler was created.
        if hasattr(self, "http_handler"):
            self.http_handler.close()
|
||||
@@ -0,0 +1,584 @@
|
||||
"""
|
||||
BitBucket prompt manager that integrates with LiteLLM's prompt management system.
|
||||
Fetches .prompt files from BitBucket repositories and provides team-based access control.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from jinja2 import DictLoader, Environment, select_autoescape
|
||||
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
from litellm.integrations.prompt_management_base import (
|
||||
PromptManagementBase,
|
||||
PromptManagementClient,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
from .bitbucket_client import BitBucketClient
|
||||
|
||||
|
||||
class BitBucketPromptTemplate:
    """
    A single prompt template loaded from BitBucket, plus its parsed metadata.

    Frequently used metadata entries (model, temperature, max_tokens and the
    input schema) are exposed as attributes. Every metadata entry other than
    "model", "input" and "content" is also collected in ``optional_params``.
    """

    def __init__(
        self,
        template_id: str,
        content: str,
        metadata: Dict[str, Any],
        model: Optional[str] = None,
    ):
        self.template_id = template_id
        self.content = content
        self.metadata = metadata

        # An explicit model argument takes precedence over the metadata entry.
        self.model = model if model else metadata.get("model")
        self.temperature = metadata.get("temperature")
        self.max_tokens = metadata.get("max_tokens")

        input_section = metadata.get("input", {})
        self.input_schema = input_section.get("schema", {})

        # Everything that is not structural metadata is kept as a parameter.
        reserved_keys = ("model", "input", "content")
        extra_params = {}
        for key, value in metadata.items():
            if key in reserved_keys:
                continue
            extra_params[key] = value
        self.optional_params = extra_params

    def __repr__(self):
        return f"BitBucketPromptTemplate(id='{self.template_id}', model='{self.model}')"
|
||||
|
||||
|
||||
class BitBucketTemplateManager:
|
||||
"""
|
||||
Manager for loading and rendering .prompt files from BitBucket repositories.
|
||||
|
||||
Supports:
|
||||
- Fetching .prompt files from BitBucket repositories
|
||||
- Team-based access control through BitBucket permissions
|
||||
- YAML frontmatter for metadata
|
||||
- Handlebars-style templating (using Jinja2)
|
||||
- Input/output schema validation
|
||||
- Model configuration
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bitbucket_config: Dict[str, Any],
|
||||
prompt_id: Optional[str] = None,
|
||||
):
|
||||
self.bitbucket_config = bitbucket_config
|
||||
self.prompt_id = prompt_id
|
||||
self.prompts: Dict[str, BitBucketPromptTemplate] = {}
|
||||
self.bitbucket_client = BitBucketClient(bitbucket_config)
|
||||
|
||||
self.jinja_env = Environment(
|
||||
loader=DictLoader({}),
|
||||
autoescape=select_autoescape(["html", "xml"]),
|
||||
# Use Handlebars-style delimiters to match Dotprompt spec
|
||||
variable_start_string="{{",
|
||||
variable_end_string="}}",
|
||||
block_start_string="{%",
|
||||
block_end_string="%}",
|
||||
comment_start_string="{#",
|
||||
comment_end_string="#}",
|
||||
)
|
||||
|
||||
# Load prompts from BitBucket if prompt_id is provided
|
||||
if self.prompt_id:
|
||||
self._load_prompt_from_bitbucket(self.prompt_id)
|
||||
|
||||
def _load_prompt_from_bitbucket(self, prompt_id: str) -> None:
|
||||
"""Load a specific .prompt file from BitBucket."""
|
||||
try:
|
||||
# Fetch the .prompt file from BitBucket
|
||||
prompt_content = self.bitbucket_client.get_file_content(
|
||||
f"{prompt_id}.prompt"
|
||||
)
|
||||
|
||||
if prompt_content:
|
||||
template = self._parse_prompt_file(prompt_content, prompt_id)
|
||||
self.prompts[prompt_id] = template
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load prompt '{prompt_id}' from BitBucket: {e}")
|
||||
|
||||
def _parse_prompt_file(
|
||||
self, content: str, prompt_id: str
|
||||
) -> BitBucketPromptTemplate:
|
||||
"""Parse a .prompt file content and extract metadata and template."""
|
||||
# Split frontmatter and content
|
||||
if content.startswith("---"):
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) >= 3:
|
||||
frontmatter_str = parts[1].strip()
|
||||
template_content = parts[2].strip()
|
||||
else:
|
||||
frontmatter_str = ""
|
||||
template_content = content
|
||||
else:
|
||||
frontmatter_str = ""
|
||||
template_content = content
|
||||
|
||||
# Parse YAML frontmatter
|
||||
metadata: Dict[str, Any] = {}
|
||||
if frontmatter_str:
|
||||
try:
|
||||
import yaml
|
||||
|
||||
metadata = yaml.safe_load(frontmatter_str) or {}
|
||||
except ImportError:
|
||||
# Fallback to basic parsing if PyYAML is not available
|
||||
metadata = self._parse_yaml_basic(frontmatter_str)
|
||||
except Exception:
|
||||
metadata = {}
|
||||
|
||||
return BitBucketPromptTemplate(
|
||||
template_id=prompt_id,
|
||||
content=template_content,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _parse_yaml_basic(self, yaml_str: str) -> Dict[str, Any]:
|
||||
"""Basic YAML parser for simple cases when PyYAML is not available."""
|
||||
result: Dict[str, Any] = {}
|
||||
for line in yaml_str.split("\n"):
|
||||
line = line.strip()
|
||||
if ":" in line and not line.startswith("#"):
|
||||
key, value = line.split(":", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Try to parse value as appropriate type
|
||||
if value.lower() in ["true", "false"]:
|
||||
result[key] = value.lower() == "true"
|
||||
elif value.isdigit():
|
||||
result[key] = int(value)
|
||||
elif value.replace(".", "").isdigit():
|
||||
result[key] = float(value)
|
||||
else:
|
||||
result[key] = value.strip("\"'")
|
||||
return result
|
||||
|
||||
def render_template(
    self, template_id: str, variables: Optional[Dict[str, Any]] = None
) -> str:
    """Render the stored template ``template_id`` with ``variables``.

    Raises:
        ValueError: If no template with that ID has been loaded.
    """
    if template_id in self.prompts:
        stored = self.prompts[template_id]
        compiled = self.jinja_env.from_string(stored.content)
        context = variables if variables else {}
        return compiled.render(**context)

    raise ValueError(f"Template '{template_id}' not found")
|
||||
|
||||
def get_template(self, template_id: str) -> Optional[BitBucketPromptTemplate]:
    """Look up a loaded template by ID; returns None when it is not loaded."""
    loaded_templates = self.prompts
    return loaded_templates.get(template_id)
|
||||
|
||||
def list_templates(self) -> List[str]:
    """Return the IDs of every template currently loaded."""
    return [template_id for template_id in self.prompts]
|
||||
|
||||
|
||||
class BitBucketPromptManager(CustomPromptManagement):
|
||||
"""
|
||||
BitBucket prompt manager that integrates with LiteLLM's prompt management system.
|
||||
|
||||
This class enables using .prompt files from BitBucket repositories with the
|
||||
litellm completion() function by implementing the PromptManagementBase interface.
|
||||
|
||||
Usage:
|
||||
# Configure BitBucket access
|
||||
bitbucket_config = {
|
||||
"workspace": "your-workspace",
|
||||
"repository": "your-repo",
|
||||
"access_token": "your-token",
|
||||
"branch": "main" # optional, defaults to main
|
||||
}
|
||||
|
||||
# Use with completion
|
||||
response = litellm.completion(
|
||||
model="bitbucket/gpt-4",
|
||||
prompt_id="my_prompt",
|
||||
prompt_variables={"variable": "value"},
|
||||
bitbucket_config=bitbucket_config,
|
||||
messages=[{"role": "user", "content": "This will be combined with the prompt"}]
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bitbucket_config: Dict[str, Any],
|
||||
prompt_id: Optional[str] = None,
|
||||
):
|
||||
self.bitbucket_config = bitbucket_config
|
||||
self.prompt_id = prompt_id
|
||||
self._prompt_manager: Optional[BitBucketTemplateManager] = None
|
||||
|
||||
@property
def integration_name(self) -> str:
    """Name of this integration, as used in model names like 'bitbucket/gpt-4'."""
    name = "bitbucket"
    return name
|
||||
|
||||
@property
def prompt_manager(self) -> BitBucketTemplateManager:
    """Return the cached template manager, building it on first access."""
    manager = self._prompt_manager
    if manager is not None:
        return manager

    manager = BitBucketTemplateManager(
        bitbucket_config=self.bitbucket_config,
        prompt_id=self.prompt_id,
    )
    self._prompt_manager = manager
    return manager
|
||||
|
||||
def get_prompt_template(
    self,
    prompt_id: str,
    prompt_variables: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
    """Render a loaded prompt template and collect its call parameters.

    Args:
        prompt_id: The ID of the prompt template.
        prompt_variables: Variables to substitute in the template.

    Returns:
        Tuple of (rendered_prompt, metadata). ``metadata`` always holds
        ``model``, ``temperature`` and ``max_tokens`` (whatever the template
        carries for them), followed by the template's ``optional_params``.

    Raises:
        ValueError: If the template is not loaded.
    """
    manager = self.prompt_manager

    template = manager.get_template(prompt_id)
    if not template:
        raise ValueError(f"Prompt template '{prompt_id}' not found")

    rendered_prompt = manager.render_template(prompt_id, prompt_variables or {})

    metadata = dict(
        model=template.model,
        temperature=template.temperature,
        max_tokens=template.max_tokens,
    )
    # optional_params is applied last, so it wins on a key collision.
    metadata.update(template.optional_params)

    return rendered_prompt, metadata
|
||||
|
||||
def pre_call_hook(
    self,
    user_id: Optional[str],
    messages: List[AllMessageValues],
    function_call: Optional[Union[Dict[str, Any], str]] = None,
    litellm_params: Optional[Dict[str, Any]] = None,
    prompt_id: Optional[str] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Tuple[List[AllMessageValues], Optional[Dict[str, Any]]]:
    """Apply a prompt template to the request before the LLM call.

    Without a ``prompt_id`` the inputs are returned unchanged. Otherwise the
    template is rendered and parsed into messages, which replace ``messages``.
    If parsing yields nothing, the rendered text is prepended as a user
    message instead. The template's model and sampling parameters are written
    into ``litellm_params``.

    Any error is logged and the original ``messages`` are returned, so a
    broken template never fails the call.

    Returns:
        Tuple of (messages, litellm_params).
    """
    if not prompt_id:
        return messages, litellm_params

    try:
        rendered_prompt, prompt_metadata = self.get_prompt_template(
            prompt_id, prompt_variables
        )
        parsed_messages = self._parse_prompt_to_messages(rendered_prompt)

        if parsed_messages:
            # Parsed template messages replace the original messages.
            final_messages: List[AllMessageValues] = parsed_messages
        else:
            final_messages = [
                {"role": "user", "content": rendered_prompt}  # type: ignore
            ] + messages

        if litellm_params is None:
            litellm_params = {}

        if prompt_metadata.get("model"):
            litellm_params["model"] = prompt_metadata["model"]

        for param in (
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
        ):
            # The metadata dict always carries temperature/max_tokens keys,
            # with None when the template does not set them. Skip None so an
            # unset template value cannot erase a caller-supplied one.
            value = prompt_metadata.get(param)
            if value is not None:
                litellm_params[param] = value

        return final_messages, litellm_params

    except Exception as e:
        # Log error but don't fail the call
        import litellm

        litellm._logging.verbose_proxy_logger.error(
            f"Error in BitBucket prompt pre_call_hook: {e}"
        )
        return messages, litellm_params
|
||||
|
||||
def _parse_prompt_to_messages(self, prompt_content: str) -> List[AllMessageValues]:
|
||||
"""
|
||||
Parse prompt content into a list of messages.
|
||||
Handles both simple prompts and multi-role conversations.
|
||||
"""
|
||||
messages = []
|
||||
lines = prompt_content.strip().split("\n")
|
||||
current_role = None
|
||||
current_content = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Check for role indicators
|
||||
if line.lower().startswith("system:"):
|
||||
if current_role and current_content:
|
||||
messages.append(
|
||||
{
|
||||
"role": current_role,
|
||||
"content": "\n".join(current_content).strip(),
|
||||
} # type: ignore
|
||||
)
|
||||
current_role = "system"
|
||||
current_content = [line[7:].strip()] # Remove "System:" prefix
|
||||
elif line.lower().startswith("user:"):
|
||||
if current_role and current_content:
|
||||
messages.append(
|
||||
{
|
||||
"role": current_role,
|
||||
"content": "\n".join(current_content).strip(),
|
||||
} # type: ignore
|
||||
)
|
||||
current_role = "user"
|
||||
current_content = [line[5:].strip()] # Remove "User:" prefix
|
||||
elif line.lower().startswith("assistant:"):
|
||||
if current_role and current_content:
|
||||
messages.append(
|
||||
{
|
||||
"role": current_role,
|
||||
"content": "\n".join(current_content).strip(),
|
||||
} # type: ignore
|
||||
)
|
||||
current_role = "assistant"
|
||||
current_content = [line[10:].strip()] # Remove "Assistant:" prefix
|
||||
else:
|
||||
# Continue building current message
|
||||
current_content.append(line)
|
||||
|
||||
# Add the last message
|
||||
if current_role and current_content:
|
||||
messages.append(
|
||||
{"role": current_role, "content": "\n".join(current_content).strip()}
|
||||
)
|
||||
|
||||
# If no role indicators found, treat as a single user message
|
||||
if not messages and prompt_content.strip():
|
||||
messages = [{"role": "user", "content": prompt_content.strip()}] # type: ignore
|
||||
|
||||
return messages # type: ignore
|
||||
|
||||
def post_call_hook(
    self,
    user_id: Optional[str],
    response: Any,
    input_messages: List[AllMessageValues],
    function_call: Optional[Union[Dict[str, Any], str]] = None,
    litellm_params: Optional[Dict[str, Any]] = None,
    prompt_id: Optional[str] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Any:
    """Post-call hook; returns ``response`` unchanged (no post-processing)."""
    unchanged_response = response
    return unchanged_response
|
||||
|
||||
def get_available_prompts(self) -> List[str]:
    """Return the prompt IDs known to the underlying template manager."""
    manager = self.prompt_manager
    return manager.list_templates()
|
||||
|
||||
def reload_prompts(self) -> None:
    """Reload prompts from BitBucket; does nothing when no prompt_id is set."""
    if not self.prompt_id:
        return

    # Clear the cached manager, then access the property again to rebuild it.
    self._prompt_manager = None
    self.prompt_manager
|
||||
|
||||
def should_run_prompt_management(
    self,
    prompt_id: Optional[str],
    prompt_spec: Optional[PromptSpec],
    dynamic_callback_params: StandardCallbackDynamicParams,
) -> bool:
    """Report whether prompt management should run for this request.

    Returns True whenever a ``prompt_id`` is supplied. ``prompt_spec`` and
    ``dynamic_callback_params`` are not consulted.
    """
    has_prompt_id = prompt_id is not None
    return has_prompt_id
|
||||
|
||||
def _compile_prompt_helper(
|
||||
self,
|
||||
prompt_id: Optional[str],
|
||||
prompt_spec: Optional[PromptSpec],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
prompt_label: Optional[str] = None,
|
||||
prompt_version: Optional[int] = None,
|
||||
) -> PromptManagementClient:
|
||||
"""
|
||||
Compile a BitBucket prompt template into a PromptManagementClient structure.
|
||||
|
||||
This method:
|
||||
1. Loads the prompt template from BitBucket
|
||||
2. Renders it with the provided variables
|
||||
3. Converts the rendered text into chat messages
|
||||
4. Extracts model and optional parameters from metadata
|
||||
"""
|
||||
if prompt_id is None:
|
||||
raise ValueError("prompt_id is required for BitBucket prompt manager")
|
||||
|
||||
try:
|
||||
# Load the prompt from BitBucket if not already loaded
|
||||
if prompt_id not in self.prompt_manager.prompts:
|
||||
self.prompt_manager._load_prompt_from_bitbucket(prompt_id)
|
||||
|
||||
# Get the rendered prompt and metadata
|
||||
rendered_prompt, prompt_metadata = self.get_prompt_template(
|
||||
prompt_id, prompt_variables
|
||||
)
|
||||
|
||||
# Convert rendered content to chat messages
|
||||
messages = self._parse_prompt_to_messages(rendered_prompt)
|
||||
|
||||
# Extract model from metadata (if specified)
|
||||
template_model = prompt_metadata.get("model")
|
||||
|
||||
# Extract optional parameters from metadata
|
||||
optional_params = {}
|
||||
for param in [
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
]:
|
||||
if param in prompt_metadata:
|
||||
optional_params[param] = prompt_metadata[param]
|
||||
|
||||
return PromptManagementClient(
|
||||
prompt_id=prompt_id,
|
||||
prompt_template=messages,
|
||||
prompt_template_model=template_model,
|
||||
prompt_template_optional_params=optional_params,
|
||||
completed_messages=None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error compiling prompt '{prompt_id}': {e}")
|
||||
|
||||
async def async_compile_prompt_helper(
    self,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_spec: Optional[PromptSpec] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
) -> PromptManagementClient:
    """Async entry point that runs the synchronous ``_compile_prompt_helper``.

    The sync helper is called directly, without awaiting anything.

    Raises:
        ValueError: If ``prompt_id`` is None.
    """
    if prompt_id is None:
        raise ValueError("prompt_id is required for BitBucket prompt manager")

    forwarded = dict(
        prompt_id=prompt_id,
        prompt_spec=prompt_spec,
        prompt_variables=prompt_variables,
        dynamic_callback_params=dynamic_callback_params,
        prompt_label=prompt_label,
        prompt_version=prompt_version,
    )
    return self._compile_prompt_helper(**forwarded)
|
||||
|
||||
def get_chat_completion_prompt(
    self,
    model: str,
    messages: List[AllMessageValues],
    non_default_params: dict,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_spec: Optional[PromptSpec] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
    ignore_prompt_manager_model: Optional[bool] = False,
    ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
    """
    Get chat completion prompt from BitBucket and return processed model, messages, and parameters.

    Delegates to ``PromptManagementBase.get_chat_completion_prompt``, called
    explicitly on that class with ``self``.

    Returns:
        Tuple of (model, messages, params), as produced by the base implementation.
    """
    # NOTE(review): ignore_prompt_manager_model and
    # ignore_prompt_manager_optional_params are accepted but not forwarded
    # here, while the async variant forwards both. Confirm whether the base
    # sync method accepts them before passing them through.
    return PromptManagementBase.get_chat_completion_prompt(
        self,
        model,
        messages,
        non_default_params,
        prompt_id,
        prompt_variables,
        dynamic_callback_params,
        prompt_spec=prompt_spec,
        prompt_label=prompt_label,
        prompt_version=prompt_version,
    )
|
||||
|
||||
async def async_get_chat_completion_prompt(
    self,
    model: str,
    messages: List[AllMessageValues],
    non_default_params: dict,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    litellm_logging_obj: LiteLLMLoggingObj,
    prompt_spec: Optional[PromptSpec] = None,
    tools: Optional[List[Dict]] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
    ignore_prompt_manager_model: Optional[bool] = False,
    ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
    """
    Async version - delegates to PromptManagementBase async implementation.

    Every argument is forwarded unchanged. ``model``, ``messages`` and
    ``non_default_params`` are passed positionally, the rest by keyword.

    Returns:
        Tuple of (model, messages, params), as produced by the base implementation.
    """
    # Called explicitly on PromptManagementBase (not via super()), passing self.
    return await PromptManagementBase.async_get_chat_completion_prompt(
        self,
        model,
        messages,
        non_default_params,
        prompt_id=prompt_id,
        prompt_variables=prompt_variables,
        litellm_logging_obj=litellm_logging_obj,
        dynamic_callback_params=dynamic_callback_params,
        prompt_spec=prompt_spec,
        tools=tools,
        prompt_label=prompt_label,
        prompt_version=prompt_version,
        ignore_prompt_manager_model=ignore_prompt_manager_model,
        ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
    )
|
||||
@@ -0,0 +1,422 @@
|
||||
# What is this?
|
||||
## Log success + failure events to Braintrust
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.integrations.braintrust_mock_client import (
|
||||
should_use_braintrust_mock,
|
||||
create_mock_braintrust_client,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.utils import print_verbose
|
||||
|
||||
# Default Braintrust REST API root. BraintrustLogger uses it only when neither
# the ``api_base`` argument nor the BRAINTRUST_API_BASE env var is set.
API_BASE = "https://api.braintrustdata.com/v1"
|
||||
|
||||
|
||||
def get_utc_datetime():
    """Return the current UTC time.

    The result is timezone-aware when the ``datetime`` module exposes ``UTC``.
    Otherwise it is the naive value from ``datetime.utcnow()``.
    """
    import datetime as dt

    utc_zone = getattr(dt, "UTC", None)
    if utc_zone is None:
        return datetime.utcnow()  # type: ignore
    return datetime.now(utc_zone)  # type: ignore
|
||||
|
||||
|
||||
class BraintrustLogger(CustomLogger):
|
||||
def __init__(
    self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> None:
    """Set up credentials, headers, the project-ID cache and HTTP handlers.

    Args:
        api_key: Braintrust API key; falls back to BRAINTRUST_API_KEY.
        api_base: API root; falls back to BRAINTRUST_API_BASE, then API_BASE.

    Raises:
        Exception: From ``validate_environment`` when no API key is available.
    """
    super().__init__()

    mock_enabled = should_use_braintrust_mock()
    self.is_mock_mode = mock_enabled
    if mock_enabled:
        # Install the mock client before any HTTP handler is created below.
        create_mock_braintrust_client()
        verbose_logger.info(
            "[BRAINTRUST MOCK] Braintrust logger initialized in mock mode"
        )

    # The API key is validated in mock mode as well.
    self.validate_environment(api_key=api_key)

    self.api_base = api_base or os.getenv("BRAINTRUST_API_BASE") or API_BASE
    self.default_project_id = None
    self.api_key: str = api_key or os.getenv("BRAINTRUST_API_KEY")  # type: ignore
    self.headers = {
        "Authorization": "Bearer " + self.api_key,
        "Content-Type": "application/json",
    }

    # Cache mapping project names to IDs
    self._project_id_cache: Dict[str, str] = {}

    self.global_braintrust_http_handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.LoggingCallback
    )
    self.global_braintrust_sync_http_handler = HTTPHandler()
|
||||
|
||||
def validate_environment(self, api_key: Optional[str]):
    """Ensure a Braintrust API key is available.

    The key may come from ``api_key`` or from the BRAINTRUST_API_KEY
    environment variable.

    Raises:
        Exception: If neither source provides a key.
    """
    key_available = (
        api_key is not None or os.getenv("BRAINTRUST_API_KEY", None) is not None
    )
    missing_keys = [] if key_available else ["BRAINTRUST_API_KEY"]

    if missing_keys:
        raise Exception("Missing keys={} in environment.".format(missing_keys))
|
||||
|
||||
def get_project_id_sync(self, project_name: str) -> str:
    """Resolve a project name to its ID, caching the result.

    On a cache miss, POSTs ``{"name": project_name}`` to ``{api_base}/project``
    and stores the ``id`` field of the JSON response in the cache.

    Raises:
        Exception: If the request fails with an HTTP status error.
    """
    cache = self._project_id_cache

    if project_name not in cache:
        try:
            response = self.global_braintrust_sync_http_handler.post(
                f"{self.api_base}/project",
                headers=self.headers,
                json={"name": project_name},
            )
            cache[project_name] = response.json()["id"]
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")

    return cache[project_name]
|
||||
|
||||
async def get_project_id_async(self, project_name: str) -> str:
    """Async counterpart of ``get_project_id_sync``.

    On a cache miss, POSTs ``{"name": project_name}`` to
    ``{api_base}/project/register`` and caches the returned ``id``.

    Raises:
        Exception: If the request fails with an HTTP status error.
    """
    cache = self._project_id_cache

    if project_name not in cache:
        try:
            # NOTE(review): this posts to /project/register while the sync
            # variant posts to /project. Confirm which endpoint is intended.
            response = await self.global_braintrust_http_handler.post(
                f"{self.api_base}/project/register",
                headers=self.headers,
                json={"name": project_name},
            )
            cache[project_name] = response.json()["id"]
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")

    return cache[project_name]
|
||||
|
||||
async def create_default_project_and_experiment(self):
    """POST the default "litellm" project and remember its ID.

    Stores the ``id`` field of the JSON response in ``default_project_id``.
    """
    response = await self.global_braintrust_http_handler.post(
        f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
    )
    self.default_project_id = response.json()["id"]
|
||||
|
||||
def create_sync_default_project_and_experiment(self):
    """Synchronously POST the default "litellm" project and remember its ID.

    Stores the ``id`` field of the JSON response in ``default_project_id``.
    """
    response = self.global_braintrust_sync_http_handler.post(
        f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
    )
    self.default_project_id = response.json()["id"]
|
||||
|
||||
def log_success_event(  # noqa: PLR0915
    self, kwargs, response_obj, start_time, end_time
):
    """Log one successful call to Braintrust as a project-log event (sync).

    The target project comes from ``litellm_params["metadata"]``: an explicit
    ``project_id``, else a ``project_name`` resolved through the API, else the
    default "litellm" project. String-valued metadata entries are merged into
    a copy of the standard logging payload. Span fields (``span_id``,
    ``root_span_id``, ``span_parents``, ``span_name``) are honoured when set.

    Args:
        kwargs: Call details; reads ``litellm_call_id``, ``messages``,
            ``standard_logging_object``, ``litellm_params``, ``call_type``,
            ``response_cost``, ``api_call_start_time`` and
            ``completion_start_time``.
        response_obj: The response object of the call (may be None).
        start_time: Call start time (``.timestamp()`` is used).
        end_time: Call end time (``.timestamp()`` is used).

    Raises:
        Exception: With the response body when Braintrust returns an HTTP
            error status. Other errors propagate unchanged.
    """
    verbose_logger.debug("REACHES BRAINTRUST SUCCESS")

    litellm_call_id = kwargs.get("litellm_call_id")
    # Work on a copy: dynamic metadata is merged in below and must not leak
    # into the object stored in kwargs. ``or {}`` also covers an explicit None.
    standard_logging_object = dict(kwargs.get("standard_logging_object") or {})
    input_messages = kwargs.get("messages")

    output = None
    choices = []
    if response_obj is not None and (
        kwargs.get("call_type", None) == "embedding"
        or isinstance(response_obj, litellm.EmbeddingResponse)
    ):
        output = None  # embeddings are not logged as output
    elif isinstance(response_obj, litellm.ModelResponse):
        choices = response_obj["choices"]
    elif isinstance(response_obj, litellm.TextCompletionResponse):
        choices = response_obj.choices
    elif isinstance(response_obj, litellm.ImageResponse):
        # Convert items exposing .dict() so the payload stays JSON-serializable.
        output = [
            item.dict() if hasattr(item, "dict") else item
            for item in (response_obj["data"] or [])
        ]

    litellm_params = kwargs.get("litellm_params", {}) or {}
    dynamic_metadata = litellm_params.get("metadata", {}) or {}

    # Get project_id from metadata or create default if needed
    project_id = dynamic_metadata.get("project_id")
    if project_id is None:
        project_name = dynamic_metadata.get("project_name")
        project_id = self.get_project_id_sync(project_name) if project_name else None

    if project_id is None:
        if self.default_project_id is None:
            self.create_sync_default_project_and_experiment()
        project_id = self.default_project_id

    tags = []
    if isinstance(dynamic_metadata, dict):
        default_tags = litellm.langfuse_default_tags
        for key, value in dynamic_metadata.items():
            # Tags reuse the default-tag list configured for Langfuse.
            if isinstance(default_tags, list) and key in default_tags:
                tags.append(f"{key}:{value}")

            # support logging dynamic metadata to braintrust
            if isinstance(value, str) and key not in standard_logging_object:
                standard_logging_object[key] = value

    cost = kwargs.get("response_cost", None)

    metrics: Optional[dict] = None
    usage_obj = getattr(response_obj, "usage", None)
    if usage_obj and isinstance(usage_obj, litellm.Usage):
        metrics = {
            "prompt_tokens": usage_obj.prompt_tokens,
            "completion_tokens": usage_obj.completion_tokens,
            "total_tokens": usage_obj.total_tokens,
            "total_cost": cost,
            "start": start_time.timestamp(),
            "end": end_time.timestamp(),
        }

        # Time to first token needs both timestamps. The previous
        # end - start value was the total latency, not TTFT.
        api_call_start_time = kwargs.get("api_call_start_time")
        completion_start_time = kwargs.get("completion_start_time")
        if api_call_start_time is not None and completion_start_time is not None:
            metrics["time_to_first_token"] = (
                completion_start_time.timestamp() - api_call_start_time.timestamp()
            )

    # Allow metadata override for span name
    span_name = dynamic_metadata.get("span_name", "Chat Completion")

    # span_parents may be a comma-separated string; a list is used as-is.
    span_parents = dynamic_metadata.get("span_parents")
    if isinstance(span_parents, str):
        span_parents = [s.strip() for s in span_parents.split(",") if s.strip()]

    span_attributes = {
        "span_id": dynamic_metadata.get("span_id"),
        "root_span_id": dynamic_metadata.get("root_span_id"),
        "span_parents": span_parents,
    }

    request_data = {
        "id": litellm_call_id,
        "input": input_messages,
        "metadata": standard_logging_object,
        "span_attributes": {"name": span_name, "type": "llm"},
    }

    # Braintrust cannot specify 'tags' for non-root spans
    if dynamic_metadata.get("root_span_id") is None:
        request_data["tags"] = tags

    # Only add span attributes that are set (not None or otherwise falsy)
    for key, value in span_attributes.items():
        if value:
            request_data[key] = value

    # ``choices`` starts as [], so test truthiness: an ``is not None`` check
    # is always true and would never let ``output`` through.
    if choices:
        request_data["output"] = [choice.dict() for choice in choices]
    else:
        request_data["output"] = output

    if metrics is not None:
        request_data["metrics"] = metrics

    try:
        self.global_braintrust_sync_http_handler.post(
            url=f"{self.api_base}/project_logs/{project_id}/insert",
            json={"events": [request_data]},
            headers=self.headers,
        )
        if self.is_mock_mode:
            print_verbose("[BRAINTRUST MOCK] Sync event successfully mocked")
    except httpx.HTTPStatusError as e:
        raise Exception(e.response.text) from e
|
||||
|
||||
async def async_log_success_event(  # noqa: PLR0915
    self, kwargs, response_obj, start_time, end_time
):
    """Log one successful call to Braintrust as a project-log event (async).

    The target project comes from ``litellm_params["metadata"]``: an explicit
    ``project_id``, else a ``project_name`` resolved through the API, else the
    default "litellm" project. String-valued metadata entries are merged into
    a copy of the standard logging payload. Span fields (``span_id``,
    ``root_span_id``, ``span_parents``, ``span_name``) are honoured when set.

    Args:
        kwargs: Call details; reads ``litellm_call_id``, ``messages``,
            ``standard_logging_object``, ``litellm_params``, ``call_type``,
            ``response_cost``, ``api_call_start_time`` and
            ``completion_start_time``.
        response_obj: The response object of the call (may be None).
        start_time: Call start time (``.timestamp()`` is used).
        end_time: Call end time (``.timestamp()`` is used).

    Raises:
        Exception: With the response body when Braintrust returns an HTTP
            error status. Other errors propagate unchanged.
    """
    verbose_logger.debug("REACHES BRAINTRUST SUCCESS")

    litellm_call_id = kwargs.get("litellm_call_id")
    # Work on a copy: dynamic metadata is merged in below and must not leak
    # into the object stored in kwargs. ``or {}`` also covers an explicit None.
    standard_logging_object = dict(kwargs.get("standard_logging_object") or {})
    input_messages = kwargs.get("messages")

    output = None
    choices = []
    if response_obj is not None and (
        kwargs.get("call_type", None) == "embedding"
        or isinstance(response_obj, litellm.EmbeddingResponse)
    ):
        output = None  # embeddings are not logged as output
    elif isinstance(response_obj, litellm.ModelResponse):
        choices = response_obj["choices"]
    elif isinstance(response_obj, litellm.TextCompletionResponse):
        choices = response_obj.choices
    elif isinstance(response_obj, litellm.ImageResponse):
        # Convert items exposing .dict() so the payload stays JSON-serializable.
        output = [
            item.dict() if hasattr(item, "dict") else item
            for item in (response_obj["data"] or [])
        ]

    # ``or {}`` mirrors the sync variant: litellm_params may be present but None.
    litellm_params = kwargs.get("litellm_params", {}) or {}
    dynamic_metadata = litellm_params.get("metadata", {}) or {}

    # Get project_id from metadata or create default if needed
    project_id = dynamic_metadata.get("project_id")
    if project_id is None:
        project_name = dynamic_metadata.get("project_name")
        project_id = (
            await self.get_project_id_async(project_name) if project_name else None
        )

    if project_id is None:
        if self.default_project_id is None:
            await self.create_default_project_and_experiment()
        project_id = self.default_project_id

    tags = []
    if isinstance(dynamic_metadata, dict):
        default_tags = litellm.langfuse_default_tags
        for key, value in dynamic_metadata.items():
            # Tags reuse the default-tag list configured for Langfuse.
            if isinstance(default_tags, list) and key in default_tags:
                tags.append(f"{key}:{value}")

            # support logging dynamic metadata to braintrust
            if isinstance(value, str) and key not in standard_logging_object:
                standard_logging_object[key] = value

    cost = kwargs.get("response_cost", None)

    metrics: Optional[dict] = None
    usage_obj = getattr(response_obj, "usage", None)
    if usage_obj and isinstance(usage_obj, litellm.Usage):
        metrics = {
            "prompt_tokens": usage_obj.prompt_tokens,
            "completion_tokens": usage_obj.completion_tokens,
            "total_tokens": usage_obj.total_tokens,
            "total_cost": cost,
            "start": start_time.timestamp(),
            "end": end_time.timestamp(),
        }

        # Time to first token is only known when both timestamps are present.
        api_call_start_time = kwargs.get("api_call_start_time")
        completion_start_time = kwargs.get("completion_start_time")
        if api_call_start_time is not None and completion_start_time is not None:
            metrics["time_to_first_token"] = (
                completion_start_time.timestamp() - api_call_start_time.timestamp()
            )

    # Allow metadata override for span name
    span_name = dynamic_metadata.get("span_name", "Chat Completion")

    # span_parents may be a comma-separated string; a list is used as-is.
    span_parents = dynamic_metadata.get("span_parents")
    if isinstance(span_parents, str):
        span_parents = [s.strip() for s in span_parents.split(",") if s.strip()]

    span_attributes = {
        "span_id": dynamic_metadata.get("span_id"),
        "root_span_id": dynamic_metadata.get("root_span_id"),
        "span_parents": span_parents,
    }

    request_data = {
        "id": litellm_call_id,
        "input": input_messages,
        "metadata": standard_logging_object,
        "span_attributes": {"name": span_name, "type": "llm"},
    }

    # Braintrust cannot specify 'tags' for non-root spans
    if dynamic_metadata.get("root_span_id") is None:
        request_data["tags"] = tags

    # Only add span attributes that are set (not None or otherwise falsy)
    for key, value in span_attributes.items():
        if value:
            request_data[key] = value

    # ``choices`` starts as [], so test truthiness: an ``is not None`` check
    # is always true and would never let ``output`` through.
    if choices:
        request_data["output"] = [choice.dict() for choice in choices]
    else:
        request_data["output"] = output

    if metrics is not None:
        request_data["metrics"] = metrics

    try:
        await self.global_braintrust_http_handler.post(
            url=f"{self.api_base}/project_logs/{project_id}/insert",
            json={"events": [request_data]},
            headers=self.headers,
        )
        if self.is_mock_mode:
            print_verbose("[BRAINTRUST MOCK] Async event successfully mocked")
    except httpx.HTTPStatusError as e:
        raise Exception(e.response.text) from e
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
    # No Braintrust-specific failure handling here: the call is passed
    # straight to the parent class implementation.
    return super().log_failure_event(kwargs, response_obj, start_time, end_time)
|
||||
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Mock HTTP client for Braintrust integration testing.
|
||||
|
||||
This module intercepts Braintrust API calls and returns successful mock responses,
|
||||
allowing full code execution without making actual network calls.
|
||||
|
||||
Usage:
|
||||
Set BRAINTRUST_MOCK=true in environment variables or config to enable mock mode.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.mock_client_factory import (
|
||||
MockClientConfig,
|
||||
MockResponse,
|
||||
create_mock_client_factory,
|
||||
)
|
||||
|
||||
# Use factory for should_use_mock and MockResponse
|
||||
# Braintrust uses both HTTPHandler (sync) and AsyncHTTPHandler (async)
|
||||
# Braintrust needs endpoint-specific responses, so we use custom HTTPHandler.post patching
|
||||
# Shared mock configuration handed to the generic mock-client factory below.
_config = MockClientConfig(
    "BRAINTRUST",
    "BRAINTRUST_MOCK",
    default_latency_ms=100,
    default_status_code=200,
    default_json_data={"id": "mock-project-id", "status": "success"},
    url_matchers=[
        ".braintrustdata.com",
        "braintrustdata.com",
        ".braintrust.dev",
        "braintrust.dev",
    ],
    patch_async_handler=True,  # Patch AsyncHTTPHandler.post for async calls
    patch_sync_client=False,  # HTTPHandler uses self.client.send(), not self.client.post()
    patch_http_handler=False,  # We use custom patching for endpoint-specific responses
)

# Get should_use_mock and create_mock_client from factory
# We need to call the factory's create_mock_client to patch AsyncHTTPHandler.post
(
    create_mock_braintrust_factory_client,
    should_use_braintrust_mock,
) = create_mock_client_factory(_config)

# Store original HTTPHandler.post method (Braintrust-specific for sync calls with custom logic)
_original_http_handler_post = None
# Module-level flag, initially False.
# NOTE(review): presumably set once patching has been applied — the code that
# updates it is outside this view; confirm.
_mocks_initialized = False

# Default mock latency in seconds
# Read once at import time from BRAINTRUST_MOCK_LATENCY_MS (milliseconds,
# default "100"). A non-numeric value makes float() raise ValueError here.
_MOCK_LATENCY_SECONDS = float(os.getenv("BRAINTRUST_MOCK_LATENCY_MS", "100")) / 1000.0
|
||||
|
||||
|
||||
def _is_braintrust_url(url: str) -> bool:
|
||||
"""Check if URL is a Braintrust API URL."""
|
||||
if not isinstance(url, str):
|
||||
return False
|
||||
|
||||
parsed = urlparse(url)
|
||||
host = (parsed.hostname or "").lower()
|
||||
|
||||
if not host:
|
||||
return False
|
||||
|
||||
return (
|
||||
host == "braintrustdata.com"
|
||||
or host.endswith(".braintrustdata.com")
|
||||
or host == "braintrust.dev"
|
||||
or host.endswith(".braintrust.dev")
|
||||
)
|
||||
|
||||
|
||||
def _mock_http_handler_post(
    self,
    url,
    data=None,
    json=None,
    params=None,
    headers=None,
    timeout=None,
    stream=False,
    files=None,
    content=None,
    logging_obj=None,
):
    """Monkey-patched HTTPHandler.post that intercepts Braintrust calls with endpoint-specific responses.

    Braintrust URLs are answered locally with a ``MockResponse`` after sleeping
    for ``_MOCK_LATENCY_SECONDS``; every other URL is forwarded unchanged to the
    saved original ``HTTPHandler.post``.

    Args:
        self: The ``HTTPHandler`` instance the patched method is bound to.
        url: Target URL of the POST request.
        data, json, params, headers, timeout, stream, files, content,
        logging_obj: Passed through untouched to the original ``post`` for
            non-Braintrust URLs.  ``json`` is additionally read for the
            project ``name`` on project endpoints.

    Returns:
        A ``MockResponse`` for Braintrust URLs, otherwise whatever the original
        ``HTTPHandler.post`` returns.

    Raises:
        RuntimeError: If a non-Braintrust URL is posted before the original
            ``HTTPHandler.post`` has been saved.
    """
    # Only mock Braintrust API calls
    if isinstance(url, str) and _is_braintrust_url(url):
        verbose_logger.info(f"[BRAINTRUST MOCK] POST to {url}")
        time.sleep(_MOCK_LATENCY_SECONDS)

        # "/project_logs" must be tested before "/project": the former contains
        # the latter as a substring, so the reverse order makes the
        # log-insertion branch unreachable.
        if "/project_logs" in url:
            # Log insertion endpoint
            mock_data = {"status": "success"}
        elif "/project" in url:
            # Project creation/retrieval/register endpoint
            project_name = (
                json.get("name", "litellm") if isinstance(json, dict) else "litellm"
            )
            mock_data = {"id": f"mock-project-id-{project_name}", "name": project_name}
        else:
            mock_data = _config.default_json_data

        return MockResponse(
            status_code=_config.default_status_code,
            json_data=mock_data,
            url=url,
            elapsed_seconds=_MOCK_LATENCY_SECONDS,
        )

    if _original_http_handler_post is not None:
        return _original_http_handler_post(
            self,
            url=url,
            data=data,
            json=json,
            params=params,
            headers=headers,
            timeout=timeout,
            stream=stream,
            files=files,
            content=content,
            logging_obj=logging_obj,
        )

    raise RuntimeError("Original HTTPHandler.post not available")
|
||||
|
||||
|
||||
def create_mock_braintrust_client():
    """
    Install the Braintrust mocks for both sync and async HTTP handlers.

    Sync side: HTTPHandler.post is replaced with _mock_http_handler_post.  This
    is done by hand rather than through the factory's patch_http_handler option
    because HTTPHandler.post goes through self.client.send() (not
    self.client.post()) and Braintrust needs different responses for /project
    and /project_logs (similar to Helicone).

    Async side: the factory's own initializer is called so that
    AsyncHTTPHandler.post is patched as well.

    Safe to call repeatedly - the work is performed only on the first call.
    """
    global _original_http_handler_post, _mocks_initialized

    if not _mocks_initialized:
        verbose_logger.debug("[BRAINTRUST MOCK] Initializing Braintrust mock client...")

        from litellm.llms.custom_httpx.http_handler import HTTPHandler

        # Save the real method exactly once so pass-through calls keep working.
        already_patched = _original_http_handler_post is not None
        if not already_patched:
            _original_http_handler_post = HTTPHandler.post
            HTTPHandler.post = _mock_http_handler_post  # type: ignore
            verbose_logger.debug("[BRAINTRUST MOCK] Patched HTTPHandler.post")

        # Required for async calls: the factory patches AsyncHTTPHandler.post.
        create_mock_braintrust_factory_client()

        verbose_logger.debug(
            f"[BRAINTRUST MOCK] Mock latency set to {_MOCK_LATENCY_SECONDS*1000:.0f}ms"
        )
        verbose_logger.debug(
            "[BRAINTRUST MOCK] Braintrust mock client initialization complete"
        )

        _mocks_initialized = True
|
||||
@@ -0,0 +1,458 @@
|
||||
[
|
||||
{
|
||||
"id": "arize",
|
||||
"displayName": "Arize",
|
||||
"logo": "arize.png",
|
||||
"supports_key_team_logging": true,
|
||||
"dynamic_params": {
|
||||
"arize_api_key": {
|
||||
"type": "password",
|
||||
"ui_name": "API Key",
|
||||
"description": "Arize API key for authentication",
|
||||
"required": true
|
||||
},
|
||||
"arize_space_id": {
|
||||
"type": "password",
|
||||
"ui_name": "Space ID",
|
||||
"description": "Arize Space ID to identify your workspace",
|
||||
"required": true
|
||||
}
|
||||
},
|
||||
"description": "Arize Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "braintrust",
|
||||
"displayName": "Braintrust",
|
||||
"logo": "braintrust.png",
|
||||
"supports_key_team_logging": false,
|
||||
"dynamic_params": {
|
||||
"braintrust_api_key": {
|
||||
"type": "password",
|
||||
"ui_name": "API Key",
|
||||
"description": "Braintrust API key for authentication",
|
||||
"required": true
|
||||
},
|
||||
"braintrust_project_name": {
|
||||
"type": "text",
|
||||
"ui_name": "Project Name",
|
||||
"description": "Name of the Braintrust project to log to",
|
||||
"required": true
|
||||
}
|
||||
},
|
||||
"description": "Braintrust Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "generic_api",
|
||||
"displayName": "Custom Callback API",
|
||||
"logo": "custom.svg",
|
||||
"supports_key_team_logging": true,
|
||||
"dynamic_params": {
|
||||
"GENERIC_LOGGER_ENDPOINT": {
|
||||
"type": "text",
|
||||
"ui_name": "Callback URL",
|
||||
"description": "Your custom webhook/API endpoint URL to receive logs",
|
||||
"required": true
|
||||
},
|
||||
"GENERIC_LOGGER_HEADERS": {
|
||||
"type": "text",
|
||||
"ui_name": "Headers",
|
||||
"description": "Custom HTTP headers as a comma-separated string (e.g., Authorization: Bearer token, Content-Type: application/json)",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"description": "Custom Callback API Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "datadog",
|
||||
"displayName": "Datadog",
|
||||
"logo": "datadog.png",
|
||||
"supports_key_team_logging": false,
|
||||
"dynamic_params": {
|
||||
"dd_api_key": {
|
||||
"type": "password",
|
||||
"ui_name": "API Key",
|
||||
"description": "Datadog API key for authentication",
|
||||
"required": true
|
||||
},
|
||||
"dd_site": {
|
||||
"type": "text",
|
||||
"ui_name": "Site",
|
||||
"description": "Datadog site URL (e.g., us5.datadoghq.com)",
|
||||
"required": true
|
||||
}
|
||||
},
|
||||
"description": "Datadog Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "datadog_metrics",
|
||||
"displayName": "Datadog Metrics",
|
||||
"logo": "datadog.png",
|
||||
"supports_key_team_logging": false,
|
||||
"dynamic_params": {
|
||||
"dd_api_key": {
|
||||
"type": "password",
|
||||
"ui_name": "API Key",
|
||||
"description": "Datadog API key for authentication",
|
||||
"required": true
|
||||
},
|
||||
"dd_site": {
|
||||
"type": "text",
|
||||
"ui_name": "Site",
|
||||
"description": "Datadog site URL (e.g., us5.datadoghq.com)",
|
||||
"required": true
|
||||
}
|
||||
},
|
||||
"description": "Datadog Custom Metrics Integration"
|
||||
},
|
||||
{
|
||||
"id": "datadog_cost_management",
|
||||
"displayName": "Datadog Cost Management",
|
||||
"logo": "datadog.png",
|
||||
"supports_key_team_logging": false,
|
||||
"dynamic_params": {
|
||||
"dd_api_key": {
|
||||
"type": "password",
|
||||
"ui_name": "API Key",
|
||||
"description": "Datadog API key for authentication",
|
||||
"required": true
|
||||
},
|
||||
"dd_app_key": {
|
||||
"type": "password",
|
||||
"ui_name": "App Key",
|
||||
"description": "Datadog Application Key for Cloud Cost Management",
|
||||
"required": true
|
||||
},
|
||||
"dd_site": {
|
||||
"type": "text",
|
||||
"ui_name": "Site",
|
||||
"description": "Datadog site URL (e.g., us5.datadoghq.com)",
|
||||
"required": true
|
||||
}
|
||||
},
|
||||
"description": "Datadog Cloud Cost Management Integration"
|
||||
},
|
||||
{
|
||||
"id": "lago",
|
||||
"displayName": "Lago",
|
||||
"logo": "lago.svg",
|
||||
"supports_key_team_logging": false,
|
||||
"dynamic_params": {
|
||||
"lago_api_url": {
|
||||
"type": "text",
|
||||
"ui_name": "API URL",
|
||||
"description": "Lago API base URL",
|
||||
"required": true
|
||||
},
|
||||
"lago_api_key": {
|
||||
"type": "password",
|
||||
"ui_name": "API Key",
|
||||
"description": "Lago API key for authentication",
|
||||
"required": true
|
||||
}
|
||||
},
|
||||
"description": "Lago Billing Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "langfuse",
|
||||
"displayName": "Langfuse",
|
||||
"logo": "langfuse.png",
|
||||
"supports_key_team_logging": true,
|
||||
"dynamic_params": {
|
||||
"langfuse_public_key": {
|
||||
"type": "text",
|
||||
"ui_name": "Public Key",
|
||||
"description": "Langfuse public key",
|
||||
"required": true
|
||||
},
|
||||
"langfuse_secret_key": {
|
||||
"type": "password",
|
||||
"ui_name": "Secret Key",
|
||||
"description": "Langfuse secret key for authentication",
|
||||
"required": true
|
||||
},
|
||||
"langfuse_host": {
|
||||
"type": "text",
|
||||
"ui_name": "Host URL",
|
||||
"description": "Langfuse host URL (default: https://cloud.langfuse.com)",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"description": "Langfuse v2 Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "langfuse_otel",
|
||||
"displayName": "Langfuse OTEL",
|
||||
"logo": "langfuse.png",
|
||||
"supports_key_team_logging": true,
|
||||
"dynamic_params": {
|
||||
"langfuse_public_key": {
|
||||
"type": "text",
|
||||
"ui_name": "Public Key",
|
||||
"description": "Langfuse public key",
|
||||
"required": true
|
||||
},
|
||||
"langfuse_secret_key": {
|
||||
"type": "password",
|
||||
"ui_name": "Secret Key",
|
||||
"description": "Langfuse secret key for authentication",
|
||||
"required": true
|
||||
},
|
||||
"langfuse_host": {
|
||||
"type": "text",
|
||||
"ui_name": "Host URL",
|
||||
"description": "Langfuse host URL (default: https://cloud.langfuse.com)",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"description": "Langfuse v3 OTEL Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "langsmith",
|
||||
"displayName": "LangSmith",
|
||||
"logo": "langsmith.png",
|
||||
"supports_key_team_logging": true,
|
||||
"dynamic_params": {
|
||||
"langsmith_api_key": {
|
||||
"type": "password",
|
||||
"ui_name": "API Key",
|
||||
"description": "LangSmith API key for authentication",
|
||||
"required": true
|
||||
},
|
||||
"langsmith_project": {
|
||||
"type": "text",
|
||||
"ui_name": "Project Name",
|
||||
"description": "LangSmith project name (default: litellm-completion)",
|
||||
"required": false
|
||||
},
|
||||
"langsmith_base_url": {
|
||||
"type": "text",
|
||||
"ui_name": "Base URL",
|
||||
"description": "LangSmith base URL (default: https://api.smith.langchain.com)",
|
||||
"required": false
|
||||
},
|
||||
"langsmith_sampling_rate": {
|
||||
"type": "number",
|
||||
"ui_name": "Sampling Rate",
|
||||
"description": "Sampling rate for logging (0.0 to 1.0, default: 1.0)",
|
||||
"required": false
|
||||
},
|
||||
"langsmith_tenant_id": {
|
||||
"type": "text",
|
||||
"ui_name": "Tenant ID",
|
||||
"description": "LangSmith tenant ID for organization-scoped API keys (required when using org-scoped keys)",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"description": "Langsmith Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "openmeter",
|
||||
"displayName": "OpenMeter",
|
||||
"logo": "openmeter.png",
|
||||
"supports_key_team_logging": false,
|
||||
"dynamic_params": {
|
||||
"openmeter_api_key": {
|
||||
"type": "password",
|
||||
"ui_name": "API Key",
|
||||
"description": "OpenMeter API key for authentication",
|
||||
"required": true
|
||||
},
|
||||
"openmeter_base_url": {
|
||||
"type": "text",
|
||||
"ui_name": "Base URL",
|
||||
"description": "OpenMeter base URL (default: https://openmeter.cloud)",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"description": "OpenMeter Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "otel",
|
||||
"displayName": "Open Telemetry",
|
||||
"logo": "otel.png",
|
||||
"supports_key_team_logging": false,
|
||||
"dynamic_params": {
|
||||
"otel_endpoint": {
|
||||
"type": "text",
|
||||
"ui_name": "Endpoint URL",
|
||||
"description": "OpenTelemetry collector endpoint URL",
|
||||
"required": true
|
||||
},
|
||||
"otel_headers": {
|
||||
"type": "text",
|
||||
"ui_name": "Headers",
|
||||
"description": "Headers for OTEL exporter (e.g., x-honeycomb-team=YOUR_API_KEY)",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"description": "OpenTelemetry Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "s3",
|
||||
"displayName": "S3",
|
||||
"logo": "aws.svg",
|
||||
"supports_key_team_logging": false,
|
||||
"dynamic_params": {
|
||||
"s3_bucket_name": {
|
||||
"type": "text",
|
||||
"ui_name": "Bucket Name",
|
||||
"description": "AWS S3 bucket name to store logs",
|
||||
"required": true
|
||||
},
|
||||
"s3_region_name": {
|
||||
"type": "text",
|
||||
"ui_name": "AWS Region",
|
||||
"description": "AWS region name (e.g., us-east-1)",
|
||||
"required": false
|
||||
},
|
||||
"s3_aws_access_key_id": {
|
||||
"type": "password",
|
||||
"ui_name": "AWS Access Key ID",
|
||||
"description": "AWS access key ID for authentication",
|
||||
"required": false
|
||||
},
|
||||
"s3_aws_secret_access_key": {
|
||||
"type": "password",
|
||||
"ui_name": "AWS Secret Access Key",
|
||||
"description": "AWS secret access key for authentication",
|
||||
"required": false
|
||||
},
|
||||
"s3_aws_session_token": {
|
||||
"type": "password",
|
||||
"ui_name": "AWS Session Token",
|
||||
"description": "AWS session token for temporary credentials",
|
||||
"required": false
|
||||
},
|
||||
"s3_endpoint_url": {
|
||||
"type": "text",
|
||||
"ui_name": "S3 Endpoint URL",
|
||||
"description": "Custom S3 endpoint URL (for MinIO or custom S3-compatible services)",
|
||||
"required": false
|
||||
},
|
||||
"s3_path": {
|
||||
"type": "text",
|
||||
"ui_name": "S3 Path Prefix",
|
||||
"description": "Path prefix within the bucket for organizing logs",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"description": "S3 Bucket (AWS) Logging Integration"
|
||||
},
|
||||
{
|
||||
"id": "sqs",
|
||||
"displayName": "SQS",
|
||||
"logo": "aws.svg",
|
||||
"supports_key_team_logging": false,
|
||||
"dynamic_params": {
|
||||
"sqs_queue_url": {
|
||||
"type": "text",
|
||||
"ui_name": "Queue URL",
|
||||
"description": "AWS SQS Queue URL",
|
||||
"required": true
|
||||
},
|
||||
"sqs_region_name": {
|
||||
"type": "text",
|
||||
"ui_name": "AWS Region",
|
||||
"description": "AWS region name (e.g., us-east-1)",
|
||||
"required": false
|
||||
},
|
||||
"sqs_aws_access_key_id": {
|
||||
"type": "password",
|
||||
"ui_name": "AWS Access Key ID",
|
||||
"description": "AWS access key ID for authentication",
|
||||
"required": false
|
||||
},
|
||||
"sqs_aws_secret_access_key": {
|
||||
"type": "password",
|
||||
"ui_name": "AWS Secret Access Key",
|
||||
"description": "AWS secret access key for authentication",
|
||||
"required": false
|
||||
},
|
||||
"sqs_aws_session_token": {
|
||||
"type": "password",
|
||||
"ui_name": "AWS Session Token",
|
||||
"description": "AWS session token for temporary credentials",
|
||||
"required": false
|
||||
},
|
||||
"sqs_aws_session_name": {
|
||||
"type": "text",
|
||||
"ui_name": "AWS Session Name",
|
||||
"description": "Name for AWS session",
|
||||
"required": false
|
||||
},
|
||||
"sqs_aws_profile_name": {
|
||||
"type": "text",
|
||||
"ui_name": "AWS Profile Name",
|
||||
"description": "AWS profile name from credentials file",
|
||||
"required": false
|
||||
},
|
||||
"sqs_aws_role_name": {
|
||||
"type": "text",
|
||||
"ui_name": "AWS Role Name",
|
||||
"description": "AWS IAM role name to assume",
|
||||
"required": false
|
||||
},
|
||||
"sqs_aws_web_identity_token": {
|
||||
"type": "password",
|
||||
"ui_name": "AWS Web Identity Token",
|
||||
"description": "AWS web identity token for authentication",
|
||||
"required": false
|
||||
},
|
||||
"sqs_aws_sts_endpoint": {
|
||||
"type": "text",
|
||||
"ui_name": "AWS STS Endpoint",
|
||||
"description": "AWS STS endpoint URL",
|
||||
"required": false
|
||||
},
|
||||
"sqs_endpoint_url": {
|
||||
"type": "text",
|
||||
"ui_name": "SQS Endpoint URL",
|
||||
"description": "Custom SQS endpoint URL (for LocalStack or custom endpoints)",
|
||||
"required": false
|
||||
},
|
||||
"sqs_api_version": {
|
||||
"type": "text",
|
||||
"ui_name": "API Version",
|
||||
"description": "SQS API version",
|
||||
"required": false
|
||||
},
|
||||
"sqs_use_ssl": {
|
||||
"type": "boolean",
|
||||
"ui_name": "Use SSL",
|
||||
"description": "Whether to use SSL for SQS connections",
|
||||
"required": false
|
||||
},
|
||||
"sqs_verify": {
|
||||
"type": "boolean",
|
||||
"ui_name": "Verify SSL",
|
||||
"description": "Whether to verify SSL certificates",
|
||||
"required": false
|
||||
},
|
||||
"sqs_strip_base64_files": {
|
||||
"type": "boolean",
|
||||
"ui_name": "Strip Base64 Files",
|
||||
"description": "Remove base64-encoded files from logs to reduce payload size",
|
||||
"required": false
|
||||
},
|
||||
"sqs_aws_use_application_level_encryption": {
|
||||
"type": "boolean",
|
||||
"ui_name": "Use Application-Level Encryption",
|
||||
"description": "Enable application-level encryption for SQS messages",
|
||||
"required": false
|
||||
},
|
||||
"sqs_app_encryption_key_b64": {
|
||||
"type": "password",
|
||||
"ui_name": "Encryption Key (Base64)",
|
||||
"description": "Base64-encoded encryption key for application-level encryption",
|
||||
"required": false
|
||||
},
|
||||
"sqs_app_encryption_aad": {
|
||||
"type": "text",
|
||||
"ui_name": "Encryption AAD",
|
||||
"description": "Additional authenticated data for encryption",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"description": "SQS Queue (AWS) Logging Integration"
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,422 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import CLOUDZERO_EXPORT_INTERVAL_MINUTES
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
else:
|
||||
AsyncIOScheduler = Any
|
||||
|
||||
|
||||
class CloudZeroLogger(CustomLogger):
    """
    CloudZero Logger for exporting LiteLLM usage data to CloudZero AnyCost API.

    Environment Variables:
        CLOUDZERO_API_KEY: CloudZero API key for authentication
        CLOUDZERO_CONNECTION_ID: CloudZero connection ID for data submission
        CLOUDZERO_TIMEZONE: Timezone for date handling (default: UTC)
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        connection_id: Optional[str] = None,
        timezone: Optional[str] = None,
        **kwargs,
    ):
        """Initialize CloudZero logger with configuration from parameters or environment variables.

        Args:
            api_key: CloudZero API key; falls back to CLOUDZERO_API_KEY.
            connection_id: CloudZero connection ID; falls back to
                CLOUDZERO_CONNECTION_ID.
            timezone: Timezone name; falls back to CLOUDZERO_TIMEZONE, then "UTC".
            **kwargs: Forwarded to ``CustomLogger.__init__``.
        """
        super().__init__(**kwargs)

        # Get configuration from parameters first, fall back to environment variables
        self.api_key = api_key or os.getenv("CLOUDZERO_API_KEY")
        self.connection_id = connection_id or os.getenv("CLOUDZERO_CONNECTION_ID")
        self.timezone = timezone or os.getenv("CLOUDZERO_TIMEZONE", "UTC")
        verbose_logger.debug(
            f"CloudZero Logger initialized with connection ID: {self.connection_id}, timezone: {self.timezone}"
        )

    async def initialize_cloudzero_export_job(self):
        """
        Handler for initializing CloudZero export job.

        Runs when CloudZero logger starts up.

        - If redis cache is available, we use the pod lock manager to acquire a lock and export the data.
            - Ensures only one pod exports the data at a time.
        - If redis cache is not available, we export the data directly.
        """
        from litellm.constants import (
            CLOUDZERO_EXPORT_USAGE_DATA_JOB_NAME,
        )
        from litellm.proxy.proxy_server import proxy_logging_obj

        pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager

        # if using redis, ensure only one pod exports the data at a time
        if pod_lock_manager and pod_lock_manager.redis_cache:
            if await pod_lock_manager.acquire_lock(
                cronjob_id=CLOUDZERO_EXPORT_USAGE_DATA_JOB_NAME
            ):
                try:
                    await self._hourly_usage_data_export()
                finally:
                    # Always release, even when the export raises.
                    await pod_lock_manager.release_lock(
                        cronjob_id=CLOUDZERO_EXPORT_USAGE_DATA_JOB_NAME
                    )
        else:
            # if not using redis, export the data directly
            await self._hourly_usage_data_export()

    async def _hourly_usage_data_export(self):
        """
        Exports the most recent usage data to CloudZero.

        Start time: now minus twice CLOUDZERO_EXPORT_INTERVAL_MINUTES
        End time: current time (UTC)

        The window is deliberately two intervals wide; the export uses the
        "replace_hourly" operation.
        """
        from datetime import timedelta, timezone

        from litellm.constants import CLOUDZERO_MAX_FETCHED_DATA_RECORDS

        current_time_utc = datetime.now(timezone.utc)
        # Mitigates the possibility of missing spend if an hour is skipped due to a restart in an ephemeral environment
        window_start_utc = current_time_utc - timedelta(
            minutes=CLOUDZERO_EXPORT_INTERVAL_MINUTES * 2
        )
        await self.export_usage_data(
            limit=CLOUDZERO_MAX_FETCHED_DATA_RECORDS,
            operation="replace_hourly",
            start_time_utc=window_start_utc,
            end_time_utc=current_time_utc,
        )

    async def export_usage_data(
        self,
        limit: Optional[int] = None,
        operation: str = "replace_hourly",
        start_time_utc: Optional[datetime] = None,
        end_time_utc: Optional[datetime] = None,
    ):
        """
        Exports the usage data to CloudZero.

        - Reads data from the DB
        - Transforms the data to the CloudZero format
        - Sends the data to CloudZero

        Args:
            limit: Optional limit on number of records to export
            operation: CloudZero operation type ("replace_hourly" or "sum")
            start_time_utc: Optional start of the window passed to the database query
            end_time_utc: Optional end of the window passed to the database query

        Raises:
            ValueError: If the API key or connection ID is not configured.
            Exception: Any error from loading, transforming or sending is
                logged and re-raised.
        """
        from litellm.integrations.cloudzero.cz_stream_api import CloudZeroStreamer
        from litellm.integrations.cloudzero.database import LiteLLMDatabase
        from litellm.integrations.cloudzero.transform import CBFTransformer

        try:
            verbose_logger.debug("CloudZero Logger: Starting usage data export")

            # Validate required configuration
            if not self.api_key or not self.connection_id:
                raise ValueError(
                    "CloudZero configuration missing. Please set CLOUDZERO_API_KEY and CLOUDZERO_CONNECTION_ID environment variables."
                )

            # Initialize database connection and load data
            database = LiteLLMDatabase()
            verbose_logger.debug("CloudZero Logger: Loading usage data from database")
            data = await database.get_usage_data(
                limit=limit, start_time_utc=start_time_utc, end_time_utc=end_time_utc
            )

            if data.is_empty():
                verbose_logger.debug("CloudZero Logger: No usage data found to export")
                return

            verbose_logger.debug(f"CloudZero Logger: Processing {len(data)} records")

            # Transform data to CloudZero CBF format
            transformer = CBFTransformer()
            cbf_data = transformer.transform(data)

            if cbf_data.is_empty():
                verbose_logger.warning(
                    "CloudZero Logger: No valid data after transformation"
                )
                return

            # Send data to CloudZero
            streamer = CloudZeroStreamer(
                api_key=self.api_key,
                connection_id=self.connection_id,
                user_timezone=self.timezone,
            )

            verbose_logger.debug(
                f"CloudZero Logger: Transmitting {len(cbf_data)} records to CloudZero"
            )
            streamer.send_batched(cbf_data, operation=operation)

            verbose_logger.debug(
                f"CloudZero Logger: Successfully exported {len(cbf_data)} records to CloudZero"
            )

        except Exception as e:
            verbose_logger.error(
                f"CloudZero Logger: Error exporting usage data: {str(e)}"
            )
            raise

    @staticmethod
    def _sum_field(records: List[dict], key: str):
        """Sum ``key`` over ``records``, counting missing or None values as 0."""
        return sum((record.get(key) or 0) for record in records)

    @staticmethod
    def _count_unique(records: List[dict], key: str) -> int:
        """Count the distinct truthy values of ``key`` across ``records``."""
        return len({record.get(key) for record in records if record.get(key)})

    @staticmethod
    def _build_summary(
        total_records: int,
        total_cost: Any = 0,
        total_tokens: Any = 0,
        unique_accounts: int = 0,
        unique_services: int = 0,
    ) -> dict:
        """Assemble the ``summary`` section of a dry-run response."""
        return {
            "total_records": total_records,
            "total_cost": total_cost,
            "total_tokens": total_tokens,
            "unique_accounts": unique_accounts,
            "unique_services": unique_services,
        }

    async def dry_run_export_usage_data(self, limit: Optional[int] = 10000):
        """
        Returns the data that would be exported to CloudZero without actually sending it.

        Args:
            limit: Limit number of records to load (default: 10000)

        Returns:
            dict: Contains usage_data (first 50 raw rows), cbf_data, and summary statistics

        Raises:
            Exception: Any error from loading or transforming is logged and re-raised.
        """
        from litellm.integrations.cloudzero.database import LiteLLMDatabase
        from litellm.integrations.cloudzero.transform import CBFTransformer

        try:
            verbose_logger.debug("CloudZero Logger: Starting dry run export")

            # Initialize database connection and load data
            database = LiteLLMDatabase()
            verbose_logger.debug("CloudZero Logger: Loading usage data for dry run")
            data = await database.get_usage_data(limit=limit)

            if data.is_empty():
                verbose_logger.warning("CloudZero Dry Run: No usage data found")
                return {
                    "usage_data": [],
                    "cbf_data": [],
                    "summary": self._build_summary(total_records=0),
                }

            verbose_logger.debug(
                f"CloudZero Dry Run: Processing {len(data)} records..."
            )

            # Convert usage data to dict format for response
            usage_data_sample = data.head(50).to_dicts()  # Return first 50 rows

            # Transform data to CloudZero CBF format
            transformer = CBFTransformer()
            cbf_data = transformer.transform(data)

            if cbf_data.is_empty():
                verbose_logger.warning(
                    "CloudZero Dry Run: No valid data after transformation"
                )
                # Nothing survived the transform, so the summary is computed
                # from the raw sample rows (at most 50) instead.
                return {
                    "usage_data": usage_data_sample,
                    "cbf_data": [],
                    "summary": self._build_summary(
                        total_records=len(usage_data_sample),
                        total_cost=self._sum_field(usage_data_sample, "spend"),
                        total_tokens=(
                            self._sum_field(usage_data_sample, "prompt_tokens")
                            + self._sum_field(usage_data_sample, "completion_tokens")
                        ),
                    ),
                }

            # Convert CBF data to dict format for response
            cbf_data_dict = cbf_data.to_dicts()

            # Calculate summary statistics
            summary = self._build_summary(
                total_records=len(cbf_data_dict),
                total_cost=self._sum_field(cbf_data_dict, "cost/cost"),
                total_tokens=self._sum_field(cbf_data_dict, "usage/amount"),
                unique_accounts=self._count_unique(cbf_data_dict, "resource/account"),
                unique_services=self._count_unique(cbf_data_dict, "resource/service"),
            )

            verbose_logger.debug(
                f"CloudZero Logger: Dry run completed for {len(cbf_data)} records"
            )

            return {
                "usage_data": usage_data_sample,
                "cbf_data": cbf_data_dict,
                "summary": summary,
            }

        except Exception as e:
            verbose_logger.error(f"CloudZero Logger: Error in dry run export: {str(e)}")
            raise

    def _display_cbf_data_on_screen(self, cbf_data):
        """Display CBF transformed data in a formatted table on screen.

        Prints one table row per CBF record followed by a short summary
        (record count, total cost, total tokens, unique accounts/services).

        Args:
            cbf_data: Transformed CBF frame; must provide ``is_empty()``,
                ``to_dicts()`` and ``len()``.
        """
        from rich.box import SIMPLE
        from rich.console import Console
        from rich.table import Table

        console = Console()

        if cbf_data.is_empty():
            console.print("[yellow]No CBF data to display[/yellow]")
            return

        console.print(
            f"\n[bold green]💰 CloudZero CBF Transformed Data ({len(cbf_data)} records)[/bold green]"
        )

        # Convert to dicts for easier processing
        records = cbf_data.to_dicts()

        # (column header, record key, default when key is absent, style, extra kwargs)
        columns = (
            ("time/usage_start", "time/usage_start", "N/A", "blue", {}),
            ("cost/cost", "cost/cost", 0, "green", {"justify": "right"}),
            ("entity_type", "entity_type", "N/A", "magenta", {"justify": "right"}),
            ("entity_id", "entity_id", "N/A", "magenta", {"justify": "right"}),
            ("team_id", "resource/tag:team_id", "N/A", "cyan", {}),
            ("team_alias", "resource/tag:team_alias", "N/A", "cyan", {}),
            ("user_email", "resource/tag:user_email", "N/A", "cyan", {}),
            ("api_key_alias", "resource/tag:api_key_alias", "N/A", "yellow", {}),
            ("usage/amount", "usage/amount", 0, "yellow", {"justify": "right"}),
            ("resource/id", "resource/id", "N/A", "magenta", {}),
            ("resource/service", "resource/service", "N/A", "cyan", {}),
            ("resource/account", "resource/account", "N/A", "white", {}),
            ("resource/region", "resource/region", "N/A", "dim", {}),
        )

        # Create main CBF table
        cbf_table = Table(
            show_header=True, header_style="bold cyan", box=SIMPLE, padding=(0, 1)
        )
        for header, _key, _default, style, extra in columns:
            cbf_table.add_column(header, style=style, no_wrap=False, **extra)

        for record in records:
            cbf_table.add_row(
                *(
                    str(record.get(key, default))
                    for _header, key, default, _style, _extra in columns
                )
            )

        console.print(cbf_table)

        # Show summary statistics
        total_cost = self._sum_field(records, "cost/cost")
        unique_accounts = self._count_unique(records, "resource/account")
        unique_services = self._count_unique(records, "resource/service")

        # Count total tokens from usage metrics
        total_tokens = self._sum_field(records, "usage/amount")

        console.print("\n[bold blue]📊 CBF Summary[/bold blue]")
        console.print(f"  Records: {len(records):,}")
        console.print(f"  Total Cost: ${total_cost:.2f}")
        console.print(f"  Total Tokens: {total_tokens:,}")
        console.print(f"  Unique Accounts: {unique_accounts}")
        console.print(f"  Unique Services: {unique_services}")

        console.print(
            "\n[dim]💡 This is the CloudZero CBF format ready for AnyCost ingestion[/dim]"
        )

    @staticmethod
    async def init_cloudzero_background_job(scheduler: AsyncIOScheduler):
        """
        Initialize the CloudZero background job.

        Registers an interval job on ``scheduler`` that runs
        ``initialize_cloudzero_export_job`` every
        CLOUDZERO_EXPORT_INTERVAL_MINUTES minutes.  Does nothing when no
        CloudZeroLogger is registered.

        Args:
            scheduler: Scheduler the interval job is added to.
        """
        cloudzero_loggers: List[
            CustomLogger
        ] = litellm.logging_callback_manager.get_custom_loggers_for_type(
            callback_type=CloudZeroLogger
        )
        # Only the first registered CloudZero logger drives the export job.
        verbose_logger.debug("found %s cloudzero loggers", len(cloudzero_loggers))
        if len(cloudzero_loggers) > 0:
            cloudzero_logger = cast(CloudZeroLogger, cloudzero_loggers[0])
            verbose_logger.debug(
                "Initializing CloudZero usage export as a cron job executing every %s minutes",
                CLOUDZERO_EXPORT_INTERVAL_MINUTES,
            )
            scheduler.add_job(
                cloudzero_logger.initialize_cloudzero_export_job,
                "interval",
                minutes=CLOUDZERO_EXPORT_INTERVAL_MINUTES,
            )
|
||||
@@ -0,0 +1,161 @@
|
||||
# Copyright 2025 CloudZero
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# CHANGELOG: 2025-01-19 - Initial CZRN module for CloudZero Resource Names (erik.peterson)
|
||||
|
||||
"""CloudZero Resource Names (CZRN) generation and validation for LiteLLM resources."""
|
||||
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class CZEntityType(str, Enum):
    """Entity kinds used to label exported CloudZero records.

    Members are plain strings (the class mixes in ``str``), so they compare
    equal to their literal value.
    """

    TEAM = "team"
|
||||
|
||||
|
||||
class CZRNGenerator:
    """Generate and validate CloudZero Resource Names (CZRNs) for LiteLLM resources.

    A CZRN has the shape::

        czrn:<service-type>:<provider>:<region>:<owner-account-id>:<resource-type>:<cloud-local-id>

    ``CZRN_REGEX`` is the single source of truth for what counts as valid: every
    component except ``provider`` must be lowercase alphanumerics/hyphens,
    ``provider`` may also contain uppercase letters, and ``cloud-local-id`` is
    free-form (any non-empty text, including ``:`` and ``|``).
    """

    CZRN_REGEX = re.compile(
        r"^czrn:([a-z0-9-]+):([a-zA-Z0-9-]+):([a-z0-9-]+):([a-z0-9-]+):([a-z0-9-]+):(.+)$"
    )

    def create_from_litellm_data(self, row: dict[str, Any]) -> str:
        """Create a CZRN from a LiteLLM daily spend row.

        Component mapping:
        - service-type: always ``'litellm'``
        - provider: ``row['custom_llm_provider']`` after provider normalization
        - region: always ``'cross-region'``
        - owner-account-id: ``row['team_id']`` (``'unknown'`` when missing/empty)
        - resource-type: always ``'llm-usage'``
        - cloud-local-id: ``row['model']`` (``'unknown'`` when missing/empty)

        Args:
            row: Mapping with (optionally) ``custom_llm_provider``, ``team_id``
                and ``model`` keys.

        Returns:
            The generated CZRN string.

        Raises:
            ValueError: If the assembled CZRN does not match ``CZRN_REGEX``.
        """
        provider = self._normalize_provider(row.get("custom_llm_provider", "unknown"))
        owner_account_id = self._normalize_component(row.get("team_id", "unknown"))

        # A NULL/empty model would otherwise be rendered as the literal "None"
        # (or make the CZRN invalid); fall back like every other component does.
        model = row.get("model") or "unknown"

        return self.create_from_components(
            service_type="litellm",
            provider=provider,
            region="cross-region",
            owner_account_id=owner_account_id,
            resource_type="llm-usage",
            cloud_local_id=str(model),
        )

    def create_from_components(
        self,
        service_type: str,
        provider: str,
        region: str,
        owner_account_id: str,
        resource_type: str,
        cloud_local_id: str,
    ) -> str:
        """Create a CZRN from individual components.

        Every component except ``cloud_local_id`` is normalized (lowercased,
        invalid characters replaced by hyphens). ``cloud_local_id`` is inserted
        verbatim because it may legitimately contain pipes, colons, slashes, etc.

        Returns:
            The generated CZRN string.

        Raises:
            ValueError: If the result does not match ``CZRN_REGEX`` (for example
                when ``cloud_local_id`` is empty).
        """
        # CZRN_REGEX only accepts lowercase in the service-type slot, so the
        # service type must be lowercased too; keeping uppercase here made any
        # mixed-case service type fail validation below.
        service_type = self._normalize_component(service_type)
        provider = self._normalize_component(provider)
        region = self._normalize_component(region)
        owner_account_id = self._normalize_component(owner_account_id)
        resource_type = self._normalize_component(resource_type)

        czrn = f"czrn:{service_type}:{provider}:{region}:{owner_account_id}:{resource_type}:{cloud_local_id}"

        if not self.is_valid(czrn):
            raise ValueError(f"Generated CZRN is invalid: {czrn}")

        return czrn

    def is_valid(self, czrn: str) -> bool:
        """Return True when ``czrn`` matches ``CZRN_REGEX``."""
        return bool(self.CZRN_REGEX.match(czrn))

    def extract_components(self, czrn: str) -> tuple[str, str, str, str, str, str]:
        """Split a CZRN into its six components.

        Returns:
            ``(service_type, provider, region, owner_account_id, resource_type, cloud_local_id)``

        Raises:
            ValueError: If ``czrn`` does not match ``CZRN_REGEX``.
        """
        match = self.CZRN_REGEX.match(czrn)
        if not match:
            raise ValueError(f"Invalid CZRN format: {czrn}")

        return cast(tuple[str, str, str, str, str, str], match.groups())

    def _normalize_provider(self, provider: str) -> str:
        """Map a LiteLLM provider name onto the provider token used in CZRNs.

        Known providers are translated through a fixed table (e.g. bedrock ->
        ``aws``); anything else is returned lowercased with underscores turned
        into hyphens. A missing/empty provider yields ``'unknown'``.
        """
        if not provider:
            return "unknown"

        provider_map = {
            litellm.LlmProviders.AZURE.value: "azure",
            litellm.LlmProviders.AZURE_AI.value: "azure",
            litellm.LlmProviders.ANTHROPIC.value: "anthropic",
            litellm.LlmProviders.BEDROCK.value: "aws",
            litellm.LlmProviders.VERTEX_AI.value: "gcp",
            litellm.LlmProviders.GEMINI.value: "google",
            litellm.LlmProviders.COHERE.value: "cohere",
            litellm.LlmProviders.HUGGINGFACE.value: "huggingface",
            litellm.LlmProviders.REPLICATE.value: "replicate",
            litellm.LlmProviders.TOGETHER_AI.value: "together-ai",
        }
        # The lookup key is hyphenated below, so the table keys must be
        # hyphenated as well; otherwise providers whose enum value contains an
        # underscore could never be matched.
        hyphenated_map = {
            key.lower().replace("_", "-"): target
            for key, target in provider_map.items()
        }

        normalized = str(provider).lower().replace("_", "-")
        return hyphenated_map.get(normalized, normalized)

    def _normalize_component(
        self, component: str, allow_uppercase: bool = False
    ) -> str:
        """Normalize one CZRN component to ``[a-z0-9-]`` (optionally keeping case).

        Invalid characters become hyphens, hyphen runs are collapsed, and
        leading/trailing hyphens are removed. Falsy input, or input that
        normalizes to nothing, yields ``'unknown'``.
        """
        if not component:
            return "unknown"

        # Accept non-string identifiers (e.g. numeric ids) instead of crashing.
        text = str(component)
        if not allow_uppercase:
            text = text.lower()

        text = re.sub(r"[^a-zA-Z0-9-]", "-", text)
        text = re.sub(r"-+", "-", text)
        text = text.strip("-")

        return text or "unknown"
|
||||
@@ -0,0 +1,278 @@
|
||||
# Copyright 2025 CloudZero
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# CHANGELOG: 2025-01-19 - Added pathlib for filesystem operations (erik.peterson)
|
||||
# CHANGELOG: 2025-01-19 - Migrated from pandas to polars and requests to httpx (erik.peterson)
|
||||
# CHANGELOG: 2025-01-19 - Initial output module for CSV and CloudZero API (erik.peterson)
|
||||
|
||||
"""Output modules for writing CBF data to various destinations."""
|
||||
|
||||
import zoneinfo
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
import polars as pl
|
||||
from rich.console import Console
|
||||
|
||||
|
||||
class CloudZeroStreamer:
    """Stream CBF data to CloudZero AnyCost API with proper batching and timezone handling."""

    def __init__(
        self, api_key: str, connection_id: str, user_timezone: Optional[str] = None
    ):
        """Store credentials and resolve the timezone assumed for naive timestamps.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            connection_id: AnyCost connection id interpolated into the request URL.
            user_timezone: IANA timezone name applied to timestamps that carry no
                UTC offset. ``None``/empty means UTC; an unknown or malformed
                name prints a warning and falls back to UTC.
        """
        self.api_key = api_key
        self.connection_id = connection_id
        self.base_url = "https://api.cloudzero.com"
        self.console = Console()

        # Default to UTC and only override when a usable timezone name is given.
        self.user_timezone: Union[zoneinfo.ZoneInfo, timezone] = timezone.utc
        if user_timezone:
            try:
                self.user_timezone = zoneinfo.ZoneInfo(user_timezone)
            except (zoneinfo.ZoneInfoNotFoundError, ValueError):
                # ZoneInfo raises ValueError (not ZoneInfoNotFoundError) for
                # malformed keys such as absolute paths.
                self.console.print(
                    f"[yellow]Warning: Unknown timezone '{user_timezone}', using UTC[/yellow]"
                )

    def send_batched(
        self, data: pl.DataFrame, operation: str = "replace_hourly"
    ) -> None:
        """Send CBF data to the AnyCost API, one request per UTC calendar day.

        Args:
            data: CBF rows; must contain a ``time/usage_start`` column.
            operation: Value sent as the payload's ``operation`` field.

        Raises:
            httpx.RequestError, httpx.HTTPStatusError: Propagated from the first
                failing batch; later batches are not attempted.
        """
        if data.is_empty():
            self.console.print("[yellow]No data to send to CloudZero[/yellow]")
            return

        daily_batches = self._group_by_date(data)
        if not daily_batches:
            self.console.print("[yellow]No valid daily batches to send[/yellow]")
            return

        self.console.print(
            f"[blue]Sending {len(daily_batches)} daily batch(es) with operation '{operation}'[/blue]"
        )

        for batch_date, batch_data in daily_batches.items():
            self._send_daily_batch(batch_date, batch_data, operation)

    def _group_by_date(self, data: pl.DataFrame) -> dict[str, pl.DataFrame]:
        """Group rows by the UTC date (``YYYY-MM-DD``) of ``time/usage_start``.

        Rows with an empty timestamp are skipped silently; rows whose timestamp
        cannot be processed are skipped with a console warning. Returns an empty
        dict when the ``time/usage_start`` column is missing.
        """
        if "time/usage_start" not in data.columns:
            self.console.print(
                "[red]Error: Missing 'time/usage_start' column for date grouping[/red]"
            )
            return {}

        rows_by_day: dict[str, list[dict[str, Any]]] = {}

        # Kept outside the try block so the warning below can always reference it.
        timestamp_str: Optional[str] = None
        for row in data.iter_rows(named=True):
            try:
                timestamp_str = row.get("time/usage_start")
                if not timestamp_str:
                    continue

                dt = self._parse_and_convert_timestamp(timestamp_str)
                rows_by_day.setdefault(dt.strftime("%Y-%m-%d"), []).append(row)
            except Exception as e:
                # Best effort: one bad row must not abort the whole export.
                self.console.print(
                    f"[yellow]Warning: Could not process timestamp '{timestamp_str}': {e}[/yellow]"
                )
                continue

        return {
            date_key: pl.DataFrame(records)
            for date_key, records in rows_by_day.items()
            if records
        }

    def _parse_and_convert_timestamp(self, timestamp_str: str) -> datetime:
        """Parse an ISO 8601 timestamp and return it as an aware UTC datetime.

        Timestamps without an offset are interpreted in ``self.user_timezone``.

        Raises:
            ValueError: If ``timestamp_str`` is not parseable.
        """
        # datetime.fromisoformat() only understands a trailing "Z" from Python
        # 3.11 on, so translate it to an explicit offset first. Only the suffix
        # is touched; a "Z" elsewhere in the string is left alone.
        iso_text = timestamp_str
        if iso_text.endswith("Z"):
            iso_text = iso_text[:-1] + "+00:00"

        try:
            dt = datetime.fromisoformat(iso_text)
        except ValueError as e:
            raise ValueError(
                f"Could not parse timestamp '{timestamp_str}': {e}"
            ) from e

        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=self.user_timezone)

        return dt.astimezone(timezone.utc)

    def _send_daily_batch(
        self, batch_date: str, batch_data: pl.DataFrame, operation: str
    ) -> None:
        """POST one day's rows to the AnyCost ``billing_drops`` endpoint.

        Raises:
            httpx.RequestError: On transport-level failures (after logging).
            httpx.HTTPStatusError: On non-2xx responses (after logging).
        """
        if batch_data.is_empty():
            return

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        url = f"{self.base_url}/v2/connections/billing/anycost/{self.connection_id}/billing_drops"
        payload = self._prepare_batch_payload(batch_date, batch_data, operation)

        try:
            with httpx.Client(timeout=30.0) as client:
                self.console.print(
                    f"[blue]Sending batch for {batch_date} ({len(batch_data)} records)[/blue]"
                )

                response = client.post(url, headers=headers, json=payload)
                response.raise_for_status()

                self.console.print(
                    f"[green]✓ Successfully sent batch for {batch_date} ({len(batch_data)} records)[/green]"
                )

        except httpx.RequestError as e:
            self.console.print(
                f"[red]✗ Network error sending batch for {batch_date}: {e}[/red]"
            )
            raise
        except httpx.HTTPStatusError as e:
            self.console.print(
                f"[red]✗ HTTP error sending batch for {batch_date}: {e.response.status_code} {e.response.text}[/red]"
            )
            raise

    def _prepare_batch_payload(
        self, batch_date: str, batch_data: pl.DataFrame, operation: str
    ) -> dict[str, Any]:
        """Build the request body: ``{"month": "YYYY-MM", "operation": ..., "data": [...]}``.

        Rows that cannot be converted are left out of ``data``.
        """
        try:
            month_str = datetime.strptime(batch_date, "%Y-%m-%d").strftime("%Y-%m")
        except ValueError:
            # Fallback to the current month, in UTC like every other timestamp here.
            month_str = datetime.now(timezone.utc).strftime("%Y-%m")

        data_records = []
        for row in batch_data.iter_rows(named=True):
            record = self._convert_cbf_to_api_format(row)
            if record:
                data_records.append(record)

        return {"month": month_str, "operation": operation, "data": data_records}

    def _convert_cbf_to_api_format(
        self, row: dict[str, Any]
    ) -> Optional[dict[str, Any]]:
        """Convert one CBF row to the API record, keeping the CBF field names.

        ``None`` values are dropped, numbers are rendered as strings (floats in
        fixed-point notation, never scientific), and ``time/usage_start`` is
        rewritten as a UTC timestamp.

        Returns:
            The converted record, or ``None`` if conversion failed (a warning is
            printed in that case).
        """
        try:
            api_record: dict[str, Any] = {}

            for key, value in row.items():
                if value is None:
                    continue
                if isinstance(value, float):
                    # Fixed-point with trailing zeros trimmed: 1e-05 -> "0.00001".
                    api_record[key] = f"{value:.10f}".rstrip("0").rstrip(".")
                elif isinstance(value, int):
                    api_record[key] = str(value)
                else:
                    api_record[key] = value

            if "time/usage_start" in api_record:
                api_record["time/usage_start"] = self._ensure_utc_timestamp(
                    api_record["time/usage_start"]
                )

            return api_record

        except Exception as e:
            self.console.print(
                f"[yellow]Warning: Could not convert record to API format: {e}[/yellow]"
            )
            return None

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as an ISO 8601 string with a ``Z`` suffix."""
        return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

    def _ensure_utc_timestamp(self, timestamp_str: str) -> str:
        """Return ``timestamp_str`` as an ISO 8601 UTC string ending in ``Z``.

        Empty or unparseable input falls back to the current UTC time, rendered
        in the same ``Z``-suffixed form as every other return path.
        """
        if not timestamp_str:
            return self._utc_now_iso()

        try:
            dt = self._parse_and_convert_timestamp(timestamp_str)
            return dt.isoformat().replace("+00:00", "Z")
        except Exception:
            # Deliberate best effort: substitute "now" rather than drop the record.
            return self._utc_now_iso()
|
||||
@@ -0,0 +1,101 @@
|
||||
# Copyright 2025 CloudZero
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# CHANGELOG: 2025-01-19 - Refactored to use daily spend tables for proper CBF mapping (erik.peterson)
|
||||
# CHANGELOG: 2025-01-19 - Migrated from pandas to polars for database operations (erik.peterson)
|
||||
# CHANGELOG: 2025-01-19 - Initial database module for LiteLLM data extraction (erik.peterson)
|
||||
|
||||
"""Database connection and data extraction for LiteLLM."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, List
|
||||
|
||||
import polars as pl
|
||||
|
||||
|
||||
class LiteLLMDatabase:
    """Handle LiteLLM PostgreSQL database connections and queries."""

    def _ensure_prisma_client(self):
        """Return the proxy's Prisma client.

        Raises:
            Exception: If the proxy has no database connected
                (``prisma_client`` is ``None``).
        """
        # NOTE(review): imported at call time rather than at module import;
        # presumably so the current value of the proxy's global is read and to
        # avoid an import cycle - confirm.
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise Exception(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )
        return prisma_client

    async def get_usage_data(
        self,
        limit: Optional[int] = None,
        start_time_utc: Optional[datetime] = None,
        end_time_utc: Optional[datetime] = None,
    ) -> pl.DataFrame:
        """Retrieve usage data from the LiteLLM daily user spend table.

        Each spend row is left-joined with its verification token (team id and
        key alias), the token's team (team alias) and the user (email). Rows are
        filtered on ``updated_at`` and ordered newest first.

        Args:
            limit: Maximum number of rows to return; ``None`` means no limit.
            start_time_utc: Inclusive lower bound on ``updated_at``; ``None``
                disables the bound.
            end_time_utc: Inclusive upper bound on ``updated_at``; ``None``
                disables the bound.

        Returns:
            A polars DataFrame built from the query result.

        Raises:
            ValueError: If ``limit`` cannot be converted to an integer.
            Exception: If no database is connected, or if the query or the
                DataFrame conversion fails (the original error is chained).
        """
        # All user-supplied values are bound as parameters, never interpolated
        # into the SQL text.
        params: List[Any] = [start_time_utc, end_time_utc]

        # Validate arguments before touching the database so that a bad call
        # fails fast with a precise error.
        limit_clause = ""
        if limit is not None:
            try:
                params.append(int(limit))
            except (TypeError, ValueError) as exc:
                raise ValueError("limit must be an integer") from exc
            limit_clause = " LIMIT $3"

        client = self._ensure_prisma_client()

        query = """
        SELECT
            dus.id,
            dus.date,
            dus.user_id,
            dus.api_key,
            dus.model,
            dus.model_group,
            dus.custom_llm_provider,
            dus.prompt_tokens,
            dus.completion_tokens,
            dus.spend,
            dus.api_requests,
            dus.successful_requests,
            dus.failed_requests,
            dus.cache_creation_input_tokens,
            dus.cache_read_input_tokens,
            dus.created_at,
            dus.updated_at,
            vt.team_id,
            vt.key_alias as api_key_alias,
            tt.team_alias,
            ut.user_email as user_email
        FROM "LiteLLM_DailyUserSpend" dus
        LEFT JOIN "LiteLLM_VerificationToken" vt ON dus.api_key = vt.token
        LEFT JOIN "LiteLLM_TeamTable" tt ON vt.team_id = tt.team_id
        LEFT JOIN "LiteLLM_UserTable" ut ON dus.user_id = ut.user_id
        WHERE ($1::timestamptz IS NULL OR dus.updated_at >= $1::timestamptz)
          AND ($2::timestamptz IS NULL OR dus.updated_at <= $2::timestamptz)
        ORDER BY dus.date DESC, dus.created_at DESC
        """
        query += limit_clause

        try:
            db_response = await client.db.query_raw(query, *params)
            # infer_schema_length=None scans every row, which prevents schema
            # mismatch errors when value types vary across rows.
            return pl.DataFrame(db_response, infer_schema_length=None)
        except Exception as e:
            raise Exception(f"Error retrieving usage data: {str(e)}") from e
|
||||
@@ -0,0 +1,223 @@
|
||||
# Copyright 2025 CloudZero
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# CHANGELOG: 2025-01-19 - Updated CBF transformation for daily spend tables and proper CloudZero mapping (erik.peterson)
|
||||
# CHANGELOG: 2025-01-19 - Migrated from pandas to polars for data transformation (erik.peterson)
|
||||
# CHANGELOG: 2025-01-19 - Initial CBF transformation module (erik.peterson)
|
||||
|
||||
"""Transform LiteLLM data to CloudZero AnyCost CBF format."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
from ...types.integrations.cloudzero import CBFRecord
|
||||
from .cz_resource_names import CZEntityType, CZRNGenerator
|
||||
|
||||
|
||||
class CBFTransformer:
    """Transform LiteLLM usage data to CloudZero Billing Format (CBF)."""

    def __init__(self):
        """Initialize transformer with CZRN generator."""
        self.czrn_generator = CZRNGenerator()

    @staticmethod
    def _text(value: Any) -> str:
        """Return ``str(value)``, or ``''`` for ``None``.

        Database NULLs arrive as ``None``; a bare ``str(None)`` would export the
        literal text ``"None"`` as a tag value.
        """
        return "" if value is None else str(value)

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Transform LiteLLM rows to CBF rows.

        Rows whose ``successful_requests`` is not greater than zero are dropped
        first (when that column exists). Rows for which building the CBF record
        raises are dropped as well. A summary of what was dropped and what was
        transformed is printed to the console.

        Returns:
            A DataFrame of CBF records; empty when ``data`` is empty.
        """
        if data.is_empty():
            return pl.DataFrame()

        original_count = len(data)
        if "successful_requests" in data.columns:
            filtered_data = data.filter(pl.col("successful_requests") > 0)
            zero_requests_dropped = original_count - len(filtered_data)
        else:
            filtered_data = data
            zero_requests_dropped = 0

        cbf_data = []
        czrn_dropped_count = 0
        filtered_count = len(filtered_data)

        for row in filtered_data.iter_rows(named=True):
            try:
                cbf_data.append(self._create_cbf_record(row))
            except Exception:
                # Best effort: a row that cannot be converted (typically because
                # CZRN generation failed) is counted and skipped.
                czrn_dropped_count += 1

        from rich.console import Console

        console = Console()

        if zero_requests_dropped > 0:
            console.print(
                f"[yellow]⚠️  Dropped {zero_requests_dropped:,} of {original_count:,} records with zero successful_requests[/yellow]"
            )

        if czrn_dropped_count > 0:
            console.print(
                f"[yellow]⚠️  Dropped {czrn_dropped_count:,} of {filtered_count:,} filtered records due to invalid CZRNs[/yellow]"
            )

        if len(cbf_data) > 0:
            console.print(
                f"[green]✓ Successfully transformed {len(cbf_data):,} records[/green]"
            )

        return pl.DataFrame(cbf_data)

    def _create_cbf_record(self, row: dict[str, Any]) -> CBFRecord:
        """Create a single CBF record from a LiteLLM daily spend row.

        NULL (``None``) column values are treated like missing ones: numeric
        fields count as zero and text fields as empty, so they neither abort the
        record nor surface as the literal tag value ``"None"``.

        Raises:
            ValueError: If no valid CZRN can be generated for the row.
        """
        text = self._text

        # Daily spend tables carry date strings like '2025-04-19'.
        usage_date = self._parse_date(row.get("date"))

        prompt_tokens = int(row.get("prompt_tokens") or 0)
        completion_tokens = int(row.get("completion_tokens") or 0)
        total_tokens = prompt_tokens + completion_tokens

        # The CZRN doubles as the CBF resource id.
        resource_id = self.czrn_generator.create_from_litellm_data(row)

        model = text(row.get("model"))
        model_group = text(row.get("model_group"))
        llm_provider = text(row.get("custom_llm_provider"))
        # Only the first 8 characters of the key are exported, for identification.
        api_key_hash = text(row.get("api_key"))[:8]

        team_id = row.get("team_id")
        team_alias = row.get("team_alias")
        user_email = row.get("user_email")

        # Prefer the team alias, then the team id, then 'unknown'.
        entity_id = (
            str(team_alias) if team_alias else (str(team_id) if team_id else "unknown")
        )

        api_key_alias = row.get("api_key_alias")
        organization_alias = row.get("organization_alias")
        project_alias = row.get("project_alias")
        user_alias = row.get("user_alias")

        dimensions = {
            "entity_type": CZEntityType.TEAM.value,
            "entity_id": entity_id,
            "team_alias": str(team_alias) if team_alias else "unknown",
            "model": model,
            "model_group": model_group,
            "provider": llm_provider,
            "api_key_prefix": api_key_hash,
            "api_key_alias": str(api_key_alias) if api_key_alias else "",
            "user_email": str(user_email) if user_email else "",
            "api_requests": str(row.get("api_requests") or 0),
            "successful_requests": str(row.get("successful_requests") or 0),
            "failed_requests": str(row.get("failed_requests") or 0),
            "cache_creation_tokens": str(row.get("cache_creation_input_tokens") or 0),
            "cache_read_tokens": str(row.get("cache_read_input_tokens") or 0),
            "organization_alias": str(organization_alias) if organization_alias else "",
            "project_alias": str(project_alias) if project_alias else "",
            "user_alias": str(user_alias) if user_alias else "",
        }

        # Only provider, region and cloud-local-id of the CZRN are reused below.
        (
            _service_type,
            provider,
            region,
            _owner_account_id,
            _resource_type,
            cloud_local_id,
        ) = self.czrn_generator.extract_components(resource_id)

        # resource/account is "<api_key_alias>|<api_key_prefix>", or just the
        # prefix when the key has no alias.
        resource_account = (
            f"{api_key_alias}|{api_key_hash}" if api_key_alias else api_key_hash
        )

        cbf_record = {
            # Required CBF fields
            "time/usage_start": usage_date.isoformat() if usage_date else None,
            "cost/cost": float(row.get("spend") or 0.0),  # billed cost
            "resource/id": resource_id,  # CZRN (CloudZero Resource Name)
            # Usage metrics for token consumption
            "usage/amount": total_tokens,
            "usage/units": "tokens",
            # CBF fields - updated per LIT-1907
            "resource/service": model_group,
            "resource/account": resource_account,
            "resource/region": region,  # CZRN region (cross-region)
            "resource/usage_family": llm_provider,
            # Action field carries the team id
            "action/operation": str(team_id) if team_id else "",
            # Line item details
            "lineitem/type": "Usage",
        }

        # CZRN components without a direct CBF column become resource tags.
        cbf_record["resource/tag:provider"] = provider
        cbf_record["resource/tag:model"] = cloud_local_id

        # Every meaningful dimension becomes a resource/tag:<key> entry.
        for key, value in dimensions.items():
            if value and value != "N/A" and value != "unknown":
                cbf_record[f"resource/tag:{key}"] = str(value)

        # Token breakdown as tags (total_tokens deliberately excluded per LIT-1907).
        if prompt_tokens > 0:
            cbf_record["resource/tag:prompt_tokens"] = str(prompt_tokens)
        if completion_tokens > 0:
            cbf_record["resource/tag:completion_tokens"] = str(completion_tokens)

        return CBFRecord(cbf_record)

    def _parse_date(self, date_str) -> Optional[datetime]:
        """Parse a daily-spend date (e.g. ``'2025-04-19'``) into a datetime.

        ``datetime`` instances are returned unchanged. Plain ``YYYY-MM-DD``
        strings become midnight of that day as a naive datetime (no timezone is
        attached here). Other strings are handed to polars' datetime inference.

        Returns:
            The parsed datetime, or ``None`` when the input is ``None``, not a
            string/datetime, or unparseable.
        """
        if date_str is None:
            return None

        if isinstance(date_str, datetime):
            return date_str

        if not isinstance(date_str, str):
            return None

        try:
            # Common case handled by the standard library; avoids building a
            # polars Series for every row.
            return datetime.strptime(date_str, "%Y-%m-%d")
        except ValueError:
            pass

        try:
            # Fallback: let polars infer the datetime format.
            return pl.Series([date_str]).str.to_datetime().item()
        except Exception:
            return None
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Custom Logger that handles batching logic
|
||||
|
||||
Use this if you want your logs to be stored in memory and flushed periodically.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class CustomBatchLogger(CustomLogger):
    """Logger base that buffers events in ``log_queue`` and sends them in batches."""

    def __init__(
        self,
        flush_lock: Optional[asyncio.Lock] = None,
        batch_size: Optional[int] = None,
        flush_interval: Optional[int] = None,
        **kwargs,
    ) -> None:
        """
        Args:
            flush_lock (Optional[asyncio.Lock], optional): Lock held while the queue is flushed. Defaults to None, in which case ``flush_queue`` does nothing.
            batch_size (Optional[int], optional): Batch size; a falsy value selects ``litellm.DEFAULT_BATCH_SIZE``.
            flush_interval (Optional[int], optional): Seconds between periodic flushes; a falsy value selects ``litellm.DEFAULT_FLUSH_INTERVAL_SECONDS``.
            **kwargs: Forwarded to ``CustomLogger.__init__``.
        """
        self.log_queue: List = []
        self.flush_interval = (
            flush_interval
            if flush_interval
            else litellm.DEFAULT_FLUSH_INTERVAL_SECONDS
        )
        self.batch_size: int = batch_size if batch_size else litellm.DEFAULT_BATCH_SIZE
        self.last_flush_time = time.time()
        self.flush_lock = flush_lock

        super().__init__(**kwargs)

    async def periodic_flush(self):
        """Loop forever: sleep ``flush_interval`` seconds, then flush the queue."""
        while True:
            await asyncio.sleep(self.flush_interval)
            verbose_logger.debug(
                f"CustomLogger periodic flush after {self.flush_interval} seconds"
            )
            await self.flush_queue()

    async def flush_queue(self):
        """Send and clear the queued events while holding ``flush_lock``.

        Does nothing without a lock or when the queue is empty. The queue is
        cleared and ``last_flush_time`` updated only after ``async_send_batch``
        returns.
        """
        lock = self.flush_lock
        if lock is None:
            return

        async with lock:
            if not self.log_queue:
                return

            verbose_logger.debug(
                "CustomLogger: Flushing batch of %s events", len(self.log_queue)
            )
            await self.async_send_batch()
            self.log_queue.clear()
            self.last_flush_time = time.time()

    async def async_send_batch(self, *args, **kwargs):
        """Hook called by ``flush_queue`` to deliver the queue; no-op in this base class."""
        return None
|
||||
@@ -0,0 +1,955 @@
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
get_args,
|
||||
)
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.guardrails import (
|
||||
DynamicGuardrailParams,
|
||||
GuardrailEventHooks,
|
||||
LitellmParams,
|
||||
Mode,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
|
||||
from litellm.types.utils import (
|
||||
CallTypes,
|
||||
GenericGuardrailAPIInputs,
|
||||
GuardrailStatus,
|
||||
GuardrailTracingDetail,
|
||||
LLMResponseTypes,
|
||||
StandardLoggingGuardrailInformation,
|
||||
)
|
||||
|
||||
# fastapi is treated as optional here: when it cannot be imported,
# HTTPException is bound to None, so any code using it must check for None first.
try:
    from fastapi.exceptions import HTTPException
except ImportError:
    HTTPException = None # type: ignore

# Imported for type annotations only (never executed at runtime).
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

# Module-level DualCache instance; passed as `cache=` to `async_pre_call_hook`
# from `CustomGuardrail.async_pre_call_deployment_hook`.
dc = DualCache()
|
||||
|
||||
|
||||
class ModifyResponseException(Exception):
    """
    Signal that a guardrail wants to replace the response.

    The exception carries a synthetic message that should be returned to the
    user in place of calling the LLM, or in place of the LLM's response. The
    proxy is expected to catch it and answer with a 200 status code, so a
    violation message reaches the caller as a successful response rather
    than as an error.

    Any guardrail can raise this exception to replace a response.
    """

    def __init__(
        self,
        message: str,
        model: str,
        request_data: Dict[str, Any],
        guardrail_name: Optional[str] = None,
        detection_info: Optional[Dict[str, Any]] = None,
    ):
        """
        Build the exception.

        Args:
            message: Violation message to return to the user; also used as the
                exception's own message.
            model: The model that was being called.
            request_data: The original request data.
            guardrail_name: Name of the guardrail raising the exception.
            detection_info: Extra detection metadata (scores, rules, etc.).
                Stored as an empty dict when omitted or empty.
        """
        super().__init__(message)
        self.message = message
        self.model = model
        self.request_data = request_data
        self.guardrail_name = guardrail_name
        self.detection_info = detection_info if detection_info else {}
|
||||
|
||||
|
||||
class CustomGuardrail(CustomLogger):
|
||||
def __init__(
    self,
    guardrail_name: Optional[str] = None,
    supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
    event_hook: Optional[
        Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode]
    ] = None,
    default_on: bool = False,
    mask_request_content: bool = False,
    mask_response_content: bool = False,
    violation_message_template: Optional[str] = None,
    end_session_after_n_fails: Optional[int] = None,
    on_violation: Optional[str] = None,
    realtime_violation_message: Optional[str] = None,
    **kwargs,
):
    """
    Store the guardrail configuration and validate the configured event hook.

    Args:
        guardrail_name: The name of the guardrail. This is the name used in your requests.
        supported_event_hooks: The event hooks that the guardrail supports. When
            non-empty, `event_hook` is validated against this list.
        event_hook: The event hook (single hook, list of hooks, or `Mode`) to run the guardrail on
        default_on: If True, the guardrail will be run by default on all requests
        mask_request_content: If True, the guardrail will mask the request content
        mask_response_content: If True, the guardrail will mask the response content
        violation_message_template: Optional template that `render_violation_message`
            formats (via `str.format`) to produce a custom violation message
        end_session_after_n_fails: For /v1/realtime sessions, end the session after this many violations
        on_violation: For /v1/realtime sessions, 'warn' or 'end_session'
        realtime_violation_message: Message the bot speaks aloud when a /v1/realtime guardrail fires
        **kwargs: Forwarded unchanged to `CustomLogger.__init__`

    Raises:
        ValueError: If `event_hook` is not covered by `supported_event_hooks`.
    """
    # Identity and hook configuration
    self.guardrail_name = guardrail_name
    self.supported_event_hooks = supported_event_hooks
    self.event_hook: Optional[
        Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode]
    ] = event_hook
    self.default_on: bool = default_on

    # Content masking / violation messaging
    self.mask_request_content: bool = mask_request_content
    self.mask_response_content: bool = mask_response_content
    self.violation_message_template: Optional[str] = violation_message_template

    # /v1/realtime session behaviour
    self.end_session_after_n_fails: Optional[int] = end_session_after_n_fails
    self.on_violation: Optional[str] = on_violation
    self.realtime_violation_message: Optional[str] = realtime_violation_message

    # Validation is skipped entirely when no supported hooks are declared.
    if supported_event_hooks:
        self._validate_event_hook(event_hook, supported_event_hooks)

    super().__init__(**kwargs)
|
||||
|
||||
def render_violation_message(
    self, default: str, context: Optional[Dict[str, Any]] = None
) -> str:
    """
    Build the violation message, applying the configured template if there is one.

    Args:
        default: Message returned as-is when no template is configured or
            when formatting fails. Exposed to the template as `{default_message}`.
        context: Extra placeholder values for the template; these override
            `default_message` when they share its key.

    Returns:
        The formatted template, or `default`.
    """
    template = self.violation_message_template
    if not template:
        return default

    placeholders: Dict[str, Any] = {"default_message": default}
    placeholders.update(context or {})

    try:
        return template.format(**placeholders)
    except Exception as e:
        # Best effort: a broken template must not prevent a message being returned.
        verbose_logger.warning(
            "Failed to format violation message template for guardrail %s: %s",
            self.guardrail_name,
            e,
        )
        return default
|
||||
|
||||
def raise_passthrough_exception(
    self,
    violation_message: str,
    request_data: Dict[str, Any],
    detection_info: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Raise a `ModifyResponseException` describing a guardrail violation.

    Intended for guardrails that detect a violation while operating in
    passthrough mode. The proxy endpoints are expected to catch the exception
    and turn it into a 200 response carrying the violation message, which
    prevents the LLM call (pre_call/during_call) or replaces the LLM
    response (post_call).

    Args:
        violation_message: The formatted violation message to return to the user
        request_data: The original request data dictionary; its "model" entry
            is attached to the exception ("unknown" when absent)
        detection_info: Optional dictionary with detection metadata (scores, rules, etc.)

    Raises:
        ModifyResponseException: Always.

    Example:
        if violation_detected and self.on_flagged_action == "passthrough":
            message = self._format_violation_message(detection_info)
            self.raise_passthrough_exception(
                violation_message=message,
                request_data=data,
                detection_info=detection_info
            )
    """
    raise ModifyResponseException(
        message=violation_message,
        model=request_data.get("model", "unknown"),
        request_data=request_data,
        guardrail_name=self.guardrail_name,
        detection_info=detection_info,
    )
|
||||
|
||||
@staticmethod
def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
    """
    Return the config model class for this guardrail, if it has one.

    The config model is used to render the guardrail's configuration in the UI.
    This base implementation has no config model and returns None.
    """
    return None
|
||||
|
||||
def _validate_event_hook(
    self,
    event_hook: Optional[
        Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode]
    ],
    supported_event_hooks: List[GuardrailEventHooks],
) -> None:
    """
    Check that every hook referenced by `event_hook` is supported.

    Args:
        event_hook: None (nothing to check), a single hook (enum member or
            its string value), a list of hooks, or a `Mode` whose `tags`
            values and `default` are all checked.
        supported_event_hooks: Hooks this guardrail supports.

    Raises:
        ValueError: If any referenced hook is missing from `supported_event_hooks`.
    """

    def _require_supported(
        candidates: Union[List[GuardrailEventHooks], List[str]],
    ) -> None:
        # String entries are converted to enum members before the membership test.
        for candidate in candidates:
            hook = (
                GuardrailEventHooks(candidate)
                if isinstance(candidate, str)
                else candidate
            )
            if hook not in supported_event_hooks:
                raise ValueError(
                    f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}"
                )

    if event_hook is None:
        return

    if isinstance(event_hook, str):
        event_hook = GuardrailEventHooks(event_hook)

    if isinstance(event_hook, list):
        _require_supported(event_hook)
        return

    if isinstance(event_hook, Mode):
        # Each tag maps to one hook or a list of hooks; flatten before checking.
        tag_hooks: list = []
        for tag_value in event_hook.tags.values():
            tag_hooks.extend(tag_value if isinstance(tag_value, list) else [tag_value])
        _require_supported(tag_hooks)

        if event_hook.default:
            fallback = event_hook.default
            _require_supported(fallback if isinstance(fallback, list) else [fallback])
        return

    if isinstance(event_hook, GuardrailEventHooks):
        if event_hook not in supported_event_hooks:
            raise ValueError(
                f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
            )
|
||||
|
||||
def get_disable_global_guardrail(self, data: dict) -> Optional[bool]:
    """
    Return the request's `disable_global_guardrail` flag.

    Lookup order: the top level of `data`, then `data["litellm_metadata"]`,
    then `data["metadata"]` (the latter only when `litellm_metadata` is
    missing or empty).

    Args:
        data: The request data dictionary.

    Returns:
        The stored flag value when present, otherwise False.
    """
    if "disable_global_guardrail" in data:
        return data["disable_global_guardrail"]

    # `dict.get(key, default)` only uses the default when the key is absent, so
    # an explicit `"metadata": None` must be normalised to an empty dict here;
    # otherwise the membership test below raises TypeError.
    metadata = data.get("litellm_metadata") or data.get("metadata") or {}
    if "disable_global_guardrail" in metadata:
        return metadata["disable_global_guardrail"]
    return False
|
||||
|
||||
def _is_valid_response_type(self, result: Any) -> bool:
    """
    Report whether `result` is an instance of one of the `LLMResponseTypes` members.

    None is never valid. If the isinstance check itself fails with a
    TypeError mentioning "TypedDict" (TypedDict types do not support
    isinstance checks), the type cannot be validated and the result is
    allowed through as valid. Any other TypeError propagates.
    """
    if result is None:
        return False

    try:
        return isinstance(result, get_args(LLMResponseTypes))
    except TypeError as exc:
        if "TypedDict" not in str(exc):
            raise
        # Unable to validate against a TypedDict member: let the result through.
        return True
|
||||
|
||||
def get_guardrail_from_metadata(
    self, data: dict
) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
    """
    Return the guardrail(s) requested for this call.

    A top-level `data["guardrails"]` entry is returned as-is. Otherwise the
    "guardrails" entry of `data["litellm_metadata"]` is used, falling back to
    `data["metadata"]` when `litellm_metadata` is missing or empty.

    Args:
        data: The request data dictionary.

    Returns:
        The requested guardrails, or an empty list when none are specified.
    """
    if "guardrails" in data:
        return data["guardrails"]

    # `dict.get(key, default)` only uses the default when the key is absent, so
    # an explicit `"metadata": None` must be normalised to an empty dict here;
    # otherwise `.get` below raises AttributeError.
    metadata = data.get("litellm_metadata") or data.get("metadata") or {}
    return metadata.get("guardrails") or []
|
||||
|
||||
def _guardrail_is_in_requested_guardrails(
    self,
    requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
) -> bool:
    """
    Report whether this guardrail appears in `requested_guardrails`.

    A string entry matches when it equals `self.guardrail_name`; a dict
    entry matches when `self.guardrail_name` is one of its keys. Entries of
    any other type are ignored.
    """
    own_name = self.guardrail_name
    return any(
        (isinstance(entry, dict) and own_name in entry)
        or (isinstance(entry, str) and own_name == entry)
        for entry in requested_guardrails
    )
|
||||
|
||||
async def async_pre_call_deployment_hook(
    self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
) -> Optional[dict]:
    """
    Run this guardrail's pre-call hook just before a deployment is called.

    The hook only runs when `kwargs["guardrails"]` is a list,
    `should_run_guardrail` returns True for the pre_call event, and the call
    type is completion/acompletion. When the hook returns a dict with a
    non-None "messages" entry, that entry replaces `kwargs["messages"]`
    (e.g. after PII masking).

    Returns:
        `kwargs`, possibly with updated messages.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    # Guardrails must be explicitly listed on the request.
    if not isinstance(kwargs.get("guardrails"), list):
        return kwargs

    wants_run = self.should_run_guardrail(
        data=kwargs, event_type=GuardrailEventHooks.pre_call
    )
    if wants_run is not True:
        return kwargs

    if not (call_type == CallTypes.completion or call_type == CallTypes.acompletion):
        return kwargs

    # The guardrail may reject the request by raising from the hook.
    key_auth = UserAPIKeyAuth(
        user_id=kwargs.get("user_api_key_user_id"),
        team_id=kwargs.get("user_api_key_team_id"),
        end_user_id=kwargs.get("user_api_key_end_user_id"),
        api_key=kwargs.get("user_api_key_hash"),
        request_route=kwargs.get("user_api_key_request_route"),
    )
    hook_result = await self.async_pre_call_hook(
        user_api_key_dict=key_auth,
        cache=dc,
        data=kwargs,
        call_type=call_type.value or "acompletion",  # type: ignore
    )

    if isinstance(hook_result, dict):
        updated_messages = hook_result.get("messages")
        if updated_messages is not None:  # update for any pii / masking logic
            kwargs["messages"] = updated_messages

    return kwargs
|
||||
|
||||
async def async_post_call_success_deployment_hook(
    self,
    request_data: dict,
    response: LLMResponseTypes,
    call_type: Optional[CallTypes],
) -> Optional[LLMResponseTypes]:
    """
    Review or modify the response right after it is received from the deployment.

    The post-call hook only runs when `request_data["guardrails"]` is a list
    and `should_run_guardrail` returns True for the post_call event.

    Returns:
        The hook's result when `_is_valid_response_type` accepts it,
        otherwise the original `response`.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    # Guardrails must be explicitly listed on the request.
    if not isinstance(request_data.get("guardrails"), list):
        return response

    wants_run = self.should_run_guardrail(
        data=request_data, event_type=GuardrailEventHooks.post_call
    )
    if wants_run is not True:
        return response

    # The guardrail may reject the response by raising from the hook.
    key_auth = UserAPIKeyAuth(
        user_id=request_data.get("user_api_key_user_id"),
        team_id=request_data.get("user_api_key_team_id"),
        end_user_id=request_data.get("user_api_key_end_user_id"),
        api_key=request_data.get("user_api_key_hash"),
        request_route=request_data.get("user_api_key_request_route"),
    )
    hook_result = await self.async_post_call_success_hook(
        user_api_key_dict=key_auth,
        data=request_data,
        response=response,
    )

    return hook_result if self._is_valid_response_type(hook_result) else response
|
||||
|
||||
def should_run_guardrail(
    self,
    data,
    event_type: GuardrailEventHooks,
) -> bool:
    """
    Decide whether this guardrail should run for `event_type`.

    - When the guardrail is `default_on` and the request has not disabled
      global guardrails, it runs whenever its event hook covers `event_type`.
    - Otherwise, a guardrail with an event hook must be named in the
      request's guardrails (except for the "logging_only" event), and its
      event hook must cover `event_type`.
    - When the event hook is a `Mode`, the enterprise tag helper gets the
      final say; a None answer from it means "run".

    Raises:
        ImportError: If the event hook is a `Mode` and `litellm_enterprise`
            is not installed.
    """
    requested_guardrails = self.get_guardrail_from_metadata(data)
    disable_global_guardrail = self.get_disable_global_guardrail(data)
    verbose_logger.debug(
        "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
        self.guardrail_name,
        event_type,
        self.event_hook,
        requested_guardrails,
        self.default_on,
    )

    def _decide_by_tag() -> bool:
        # Tag-based (Mode) decisions are delegated to the enterprise helper.
        try:
            from litellm_enterprise.integrations.custom_guardrail import (
                EnterpriseCustomGuardrailHelper,
            )
        except ImportError:
            raise ImportError(
                "Setting tag-based guardrails is only available in litellm-enterprise. You must be a premium user to use this feature."
            )
        tag_decision = EnterpriseCustomGuardrailHelper._should_run_if_mode_by_tag(
            data, self.event_hook, event_type
        )
        return True if tag_decision is None else tag_decision

    if self.default_on is True and disable_global_guardrail is not True:
        if not self._event_hook_is_event_type(event_type):
            return False
        if isinstance(self.event_hook, Mode):
            return _decide_by_tag()
        return True

    if (
        self.event_hook
        and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
        and event_type.value != "logging_only"
    ):
        return False

    if not self._event_hook_is_event_type(event_type):
        return False

    if isinstance(self.event_hook, Mode):
        return _decide_by_tag()
    return True
|
||||
|
||||
def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
    """
    Report whether the configured event hook covers `event_type`.

    - No event hook configured: always True.
    - List of hooks: True when `event_type.value` is in the list.
    - `Mode`: True when `event_type.value` matches any tag value (single or
      list); otherwise decided by membership in `default`, or False when no
      default is set.
    - Single hook: True when it equals `event_type.value`.

    eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
    eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
    """
    configured = self.event_hook

    if configured is None:
        return True
    if isinstance(configured, list):
        return event_type.value in configured
    if not isinstance(configured, Mode):
        return configured == event_type.value

    # Mode: check the per-tag hooks first, then the default hook(s).
    for tag_value in configured.tags.values():
        if isinstance(tag_value, list):
            if event_type.value in tag_value:
                return True
        elif event_type.value == tag_value:
            return True

    if configured.default:
        fallback = configured.default
        fallback_hooks = fallback if isinstance(fallback, list) else [fallback]
        return event_type.value in fallback_hooks
    return False
|
||||
|
||||
def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
    """
    Return the `extra_body` to add to the request body of the Guardrail API call.

    Use this to pass dynamic params to the guardrail API call - eg.
    success_threshold, failure_threshold, etc.

    For a request carrying

    ```
    [{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
    ```

    a guardrail named `lakera_guard` gets back:
        {
            "foo": "bar"
        }

    Only the first dict entry keyed by `self.guardrail_name` is considered.
    When `_validate_premium_user()` is not True the dynamic params are
    ignored (with a warning if any were supplied) and `{}` is returned.

    Args:
        request_data: The original `request_data` passed to LiteLLM Proxy
    """
    for requested in self.get_guardrail_from_metadata(request_data):
        if not isinstance(requested, dict) or self.guardrail_name not in requested:
            continue

        # Configuration supplied on the request for this guardrail.
        guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
            **requested[self.guardrail_name]
        )
        extra_body = guardrail_config.get("extra_body", {})

        if self._validate_premium_user() is True:
            return extra_body

        if isinstance(extra_body, dict) and extra_body:
            verbose_logger.warning(
                "Guardrail %s: ignoring dynamic extra_body keys %s because premium_user is False",
                self.guardrail_name,
                list(extra_body.keys()),
            )
        return {}

    return {}
|
||||
|
||||
def _validate_premium_user(self) -> bool:
    """
    Report whether the proxy is running with `premium_user` set to True.

    Logs a warning when it is not.
    """
    from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

    if premium_user is True:
        return True

    verbose_logger.warning(
        f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
    )
    return False
|
||||
|
||||
def add_standard_logging_guardrail_information_to_request_data(
    self,
    guardrail_json_response: Union[Exception, str, dict, List[dict]],
    request_data: dict,
    guardrail_status: GuardrailStatus,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    duration: Optional[float] = None,
    masked_entity_count: Optional[Dict[str, int]] = None,
    guardrail_provider: Optional[str] = None,
    event_type: Optional[GuardrailEventHooks] = None,
    tracing_detail: Optional[GuardrailTracingDetail] = None,
) -> None:
    """
    Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc.

    The entry is appended to the `standard_logging_guardrail_information`
    list stored under `request_data["metadata"]` when that key exists,
    else under `request_data["litellm_metadata"]` when that key exists,
    else under a newly created `request_data["metadata"]`. A None value
    under the chosen key is replaced by an empty dict first.

    Args:
        guardrail_json_response: The guardrail's response; an Exception is
            logged as its string form.
        request_data: Request data dictionary; mutated in place.
        guardrail_status: Status recorded for this guardrail run.
        start_time: Optional start timestamp.
        end_time: Optional end timestamp.
        duration: Optional duration.
        masked_entity_count: Optional per-entity masking counts.
        guardrail_provider: Optional provider name.
        event_type: Event the guardrail ran on; when None, `self.event_hook`
            is recorded instead.
        tracing_detail: Optional typed dict with provider-specific tracing fields
            (guardrail_id, policy_template, detection_method, confidence_score,
            classification, match_details, patterns_checked, alert_recipients).
    """
    from litellm.litellm_core_utils.core_helpers import (
        filter_exceptions_from_params,
    )
    from litellm.types.utils import GuardrailMode

    if isinstance(guardrail_json_response, Exception):
        guardrail_json_response = str(guardrail_json_response)

    # Use event_type if provided, otherwise fall back to self.event_hook
    guardrail_mode: Union[
        GuardrailEventHooks, GuardrailMode, List[GuardrailEventHooks]
    ]
    if event_type is not None:
        guardrail_mode = event_type
    elif isinstance(self.event_hook, Mode):
        guardrail_mode = GuardrailMode(**dict(self.event_hook.model_dump()))  # type: ignore[typeddict-item]
    else:
        guardrail_mode = self.event_hook  # type: ignore[assignment]

    # Sanitize the response to ensure it's JSON serializable and free of circular refs
    # This prevents RecursionErrors in downstream loggers (Langfuse, Datadog, etc.)
    clean_guardrail_response = filter_exceptions_from_params(
        guardrail_json_response
    )

    # Strip secret_fields to prevent plaintext Authorization headers from
    # being persisted to spend logs, OTEL traces, or other logging backends.
    # This matches the pattern used by Langfuse and Arize integrations.
    if isinstance(clean_guardrail_response, dict):
        clean_guardrail_response.pop("secret_fields", None)
    elif isinstance(clean_guardrail_response, list):
        for item in clean_guardrail_response:
            if isinstance(item, dict):
                item.pop("secret_fields", None)

    slg = StandardLoggingGuardrailInformation(
        guardrail_name=self.guardrail_name,
        guardrail_provider=guardrail_provider,
        guardrail_mode=guardrail_mode,
        guardrail_response=clean_guardrail_response,
        guardrail_status=guardrail_status,
        start_time=start_time,
        end_time=end_time,
        duration=duration,
        masked_entity_count=masked_entity_count,
        **(tracing_detail or {}),
    )

    def _append_guardrail_info(container: dict) -> None:
        key = "standard_logging_guardrail_information"
        existing = container.get(key)
        if existing is None:
            container[key] = [slg]
        elif isinstance(existing, list):
            existing.append(slg)
        else:
            # should not happen
            container[key] = [existing, slg]

    # Prefer "metadata", then "litellm_metadata". If neither key exists, create
    # "metadata" so the guardrail info is always logged (e.g. the proxy may not
    # have set metadata yet) and spend log / standard logging can see it.
    if "metadata" in request_data:
        metadata_key = "metadata"
    elif "litellm_metadata" in request_data:
        metadata_key = "litellm_metadata"
    else:
        metadata_key = "metadata"

    # A key that is present but None (for either metadata key) is replaced by
    # an empty dict so appending cannot fail with AttributeError.
    if request_data.get(metadata_key) is None:
        request_data[metadata_key] = {}
    _append_guardrail_info(request_data[metadata_key])
|
||||
|
||||
async def apply_guardrail(
    self,
    inputs: GenericGuardrailAPIInputs,
    request_data: dict,
    input_type: Literal["request", "response"],
    logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> GenericGuardrailAPIInputs:
    """
    Apply your guardrail logic to the given inputs

    This base implementation applies no guardrail and returns `inputs`
    unchanged. Any of the custom guardrails can override this method to
    provide custom guardrail logic.

    Args:
        inputs: Dictionary containing:
            - texts: List of texts to apply the guardrail to
            - images: Optional list of images to apply the guardrail to
            - tool_calls: Optional list of tool calls to apply the guardrail to
        request_data: The request data dictionary - containing user api key metadata (e.g. user_id, team_id, etc.)
        input_type: The type of input to apply the guardrail to - "request" or "response"
        logging_obj: Optional logging object for tracking the guardrail execution

    Returns:
        The texts with the guardrail applied and the images with the guardrail applied (if any)

    Raises:
        Exception:
            - If the guardrail raises an exception (overriding implementations;
              this base implementation raises nothing)
    """
    return inputs
|
||||
|
||||
def _process_response(
    self,
    response: Optional[Dict],
    request_data: dict,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    duration: Optional[float] = None,
    event_type: Optional[GuardrailEventHooks] = None,
    original_inputs: Optional[Dict] = None,
):
    """
    Record a successful guardrail run as StandardLoggingGuardrailInformation on the request data.

    This gets logged on downstream Langfuse, DataDog, etc.

    The logged response is `response` itself (`{}` when None). When
    `original_inputs` is given and `response` is a dict (the apply_guardrail /
    custom_code_guardrail scenario), the logged response is simplified to
    "mask" if the inputs were modified, or "allow" if they were not.

    Returns:
        `response`, unchanged.
    """
    # None is logged as an empty dict to satisfy the type requirements.
    logged_response: Union[Dict[str, Any], str] = {} if response is None else response

    if original_inputs is not None and isinstance(response, dict):
        logged_response = (
            "mask" if self._inputs_were_modified(original_inputs, response) else "allow"
        )

    verbose_logger.debug(f"Guardrail response: {response}")

    self.add_standard_logging_guardrail_information_to_request_data(
        guardrail_json_response=logged_response,
        request_data=request_data,
        guardrail_status="success",
        duration=duration,
        start_time=start_time,
        end_time=end_time,
        event_type=event_type,
    )
    return response
|
||||
|
||||
@staticmethod
def _is_guardrail_intervention(e: Exception) -> bool:
    """
    Report whether `e` represents an intentional guardrail block
    (previously logged as an API failure - guardrail_failed_to_respond).

    Guardrails signal intentional blocks by raising:
    - ModifyResponseException (passthrough mode violation)
    - HTTPException with status 400 (content policy violation); only
      checked when fastapi's HTTPException could be imported
    """
    if isinstance(e, ModifyResponseException):
        return True

    return (
        HTTPException is not None
        and isinstance(e, HTTPException)
        and e.status_code == 400
    )
|
||||
|
||||
def _process_error(
    self,
    e: Exception,
    request_data: dict,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    duration: Optional[float] = None,
    event_type: Optional[GuardrailEventHooks] = None,
):
    """
    Record a failed guardrail run as StandardLoggingGuardrailInformation on the request data, then re-raise.

    This gets logged on downstream Langfuse, DataDog, etc.

    The status is "guardrail_intervened" when `_is_guardrail_intervention(e)`
    is True, otherwise "guardrail_failed_to_respond". Classes whose name
    contains "CustomCodeGuardrail" log the response as "deny" instead of
    the exception.

    Raises:
        Exception: Always re-raises `e` after logging.
    """
    status: GuardrailStatus = (
        "guardrail_intervened"
        if self._is_guardrail_intervention(e)
        else "guardrail_failed_to_respond"
    )

    # The custom_code_guardrail scenario is detected by class name.
    is_custom_code = "CustomCodeGuardrail" in self.__class__.__name__
    logged_response: Union[Exception, str] = "deny" if is_custom_code else e

    self.add_standard_logging_guardrail_information_to_request_data(
        guardrail_json_response=logged_response,
        request_data=request_data,
        guardrail_status=status,
        duration=duration,
        start_time=start_time,
        end_time=end_time,
        event_type=event_type,
    )
    raise e
|
||||
|
||||
def _inputs_were_modified(self, original_inputs: Dict, response: Dict) -> bool:
    """
    Report whether `response` differs from `original_inputs`.

    Every key present in either dictionary is compared; a key missing from
    one side is treated as None on that side.

    Returns:
        True if any value differs (mask scenario), False otherwise (allow scenario).
    """
    keys_to_compare = set(original_inputs.keys()) | set(response.keys())
    return any(
        original_inputs.get(key) != response.get(key) for key in keys_to_compare
    )
|
||||
|
||||
def mask_content_in_string(
    self,
    content_string: str,
    mask_string: str,
    start_index: int,
    end_index: int,
) -> str:
    """
    Replace `content_string[start_index:end_index]` with `mask_string`.

    The string is returned unchanged unless
    `0 <= start_index < end_index <= len(content_string)`.
    """
    span_is_valid = 0 <= start_index < end_index <= len(content_string)
    if span_is_valid:
        prefix = content_string[:start_index]
        suffix = content_string[end_index:]
        return prefix + mask_string + suffix

    # Invalid or empty span: nothing to mask.
    return content_string
|
||||
|
||||
def update_in_memory_litellm_params(self, litellm_params: LitellmParams) -> None:
    """
    Update the guardrail's litellm params in memory.

    Every instance attribute of `litellm_params` (as reported by `vars`) is
    copied onto this guardrail under the same attribute name.
    """
    param_values = vars(litellm_params)
    for attr_name, attr_value in param_values.items():
        setattr(self, attr_name, attr_value)
|
||||
|
||||
def get_guardrails_messages_for_call_type(
    self, call_type: CallTypes, data: Optional[dict] = None
) -> Optional[List[AllMessageValues]]:
    """
    Return the messages carried by `data` for the given call type.

    - completion / acompletion / anthropic_messages (/chat/completions and
      /messages): the "messages" entry of `data`.
    - responses / aresponses (/responses): the "input" entry converted to
      messages with the litellm responses transformation; None when there
      is no "input".
    - Anything else, or a None `call_type` / `data`: None.
    """
    if call_type is None or data is None:
        return None

    # /chat/completions and /messages both store the messages in the "messages" key
    reads_messages_key = (
        call_type == CallTypes.completion.value
        or call_type == CallTypes.acompletion.value
        or call_type == CallTypes.anthropic_messages.value
    )
    if reads_messages_key:
        return data.get("messages")

    # /responses stores user/system messages in the "input" key
    is_responses_call = (
        call_type == CallTypes.responses.value
        or call_type == CallTypes.aresponses.value
    )
    if not is_responses_call:
        return None

    from typing import cast

    from litellm.responses.litellm_completion_transformation.transformation import (
        LiteLLMCompletionResponsesConfig,
    )

    input_data = data.get("input")
    if input_data is None:
        return None

    # Use the litellm transformation to turn the responses-API input into messages.
    converted = LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
        input=input_data,
        responses_api_request=data,
    )
    return cast(List[AllMessageValues], converted)
|
||||
|
||||
|
||||
def log_guardrail_information(func):
    """
    Decorator to add standard logging guardrail information to any function

    Add this decorator to ensure your guardrail response is logged to DataDog, OTEL, s3, GCS etc.

    Logs for:
        - pre_call
        - during_call
        - post_call

    The decorated callable must be a method: its first positional argument is
    treated as the guardrail instance, and the outcome is handed to that
    instance's ``_process_response`` (success) or ``_process_error`` (any
    ``Exception`` raised by the call or by ``_process_response``), together
    with ``start_time`` / ``end_time`` (POSIX timestamps), ``duration``
    (seconds) and the inferred ``event_type``.

    Request data is read only from the ``data`` / ``request_data`` keyword
    arguments, and ``original_inputs`` only from an ``inputs`` keyword argument
    of a function named ``apply_guardrail``; positional values are not
    inspected.

    Returns:
        An ``async def`` wrapper when ``func`` is a coroutine function (so
        ``inspect.iscoroutinefunction`` stays true for the decorated method),
        otherwise a synchronous wrapper.
    """
    import functools
    import inspect

    def _infer_event_type_from_function_name(
        func_name: str,
    ) -> "Optional[GuardrailEventHooks]":
        """Infer the actual event type from the function name"""
        if func_name == "async_pre_call_hook":
            return GuardrailEventHooks.pre_call
        elif func_name == "async_moderation_hook":
            return GuardrailEventHooks.during_call
        elif func_name in (
            "async_post_call_success_hook",
            "async_post_call_streaming_hook",
        ):
            return GuardrailEventHooks.post_call
        return None

    # The wrapped function never changes, so resolve its event type once.
    event_type = _infer_event_type_from_function_name(func.__name__)

    def _call_context(args, kwargs):
        """Return (guardrail instance, request data, original inputs) for one call."""
        guardrail = args[0]
        request_data: dict = kwargs.get("data") or kwargs.get("request_data") or {}
        # Store original inputs for comparison (for apply_guardrail functions)
        original_inputs = None
        if func.__name__ == "apply_guardrail" and "inputs" in kwargs:
            original_inputs = kwargs.get("inputs")
        return guardrail, request_data, original_inputs

    def _timing_kwargs(start_time):
        """Build start/end/duration from a single end-of-call clock reading."""
        end_time = datetime.now()
        return {
            "start_time": start_time.timestamp(),
            "end_time": end_time.timestamp(),
            "duration": (end_time - start_time).total_seconds(),
        }

    @functools.wraps(func)
    async def async_wrapper(*args, **kwargs):
        start_time = datetime.now()
        guardrail, request_data, original_inputs = _call_context(args, kwargs)
        try:
            response = await func(*args, **kwargs)
            return guardrail._process_response(
                response=response,
                request_data=request_data,
                event_type=event_type,
                original_inputs=original_inputs,
                **_timing_kwargs(start_time),
            )
        except Exception as e:
            return guardrail._process_error(
                e=e,
                request_data=request_data,
                event_type=event_type,
                **_timing_kwargs(start_time),
            )

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        start_time = datetime.now()
        guardrail, request_data, original_inputs = _call_context(args, kwargs)
        try:
            response = func(*args, **kwargs)
            # Same timing fields as the async path, so sync guardrails are
            # logged with start/end times too.
            return guardrail._process_response(
                response=response,
                request_data=request_data,
                event_type=event_type,
                original_inputs=original_inputs,
                **_timing_kwargs(start_time),
            )
        except Exception as e:
            return guardrail._process_error(
                e=e,
                request_data=request_data,
                event_type=event_type,
                **_timing_kwargs(start_time),
            )

    # Decide once at decoration time; returning the async wrapper directly keeps
    # the decorated method recognisable as a coroutine function.
    if inspect.iscoroutinefunction(func):
        return async_wrapper
    return sync_wrapper
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,83 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.prompt_management_base import (
|
||||
PromptManagementBase,
|
||||
PromptManagementClient,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
|
||||
class CustomPromptManagement(CustomLogger, PromptManagementBase):
    """
    Pass-through prompt-management integration intended as a base for custom
    implementations.

    As written, ``get_chat_completion_prompt`` returns its inputs untouched,
    ``should_run_prompt_management`` always answers ``True`` and both
    compile-prompt helpers raise ``NotImplementedError``.
    """

    def __init__(
        self,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
        **kwargs,
    ):
        # NOTE(review): **kwargs is accepted but not used, and no base-class
        # __init__ is invoked here - confirm this is intentional.
        self.ignore_prompt_manager_model = ignore_prompt_manager_model
        self.ignore_prompt_manager_optional_params = ignore_prompt_manager_optional_params

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Default hook: hand back the caller's values unchanged.

        Returns:
            A ``(model, messages, non_default_params)`` tuple:
            - model: str - the model to use (can be pulled from prompt management tool)
            - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
            - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
        """
        passthrough = (model, messages, non_default_params)
        return passthrough

    @property
    def integration_name(self) -> str:
        """Fixed identifier string for this integration."""
        return "custom-prompt-management"

    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """Always ``True``; none of the arguments are consulted."""
        return True

    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Not supported by this class.

        Raises:
            NotImplementedError: always.
        """
        unsupported_msg = "Custom prompt management does not support compile prompt helper"
        raise NotImplementedError(unsupported_msg)

    async def async_compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Not supported by this class.

        Raises:
            NotImplementedError: always.
        """
        unsupported_msg = (
            "Custom prompt management does not support async compile prompt helper"
        )
        raise NotImplementedError(unsupported_msg)
|
||||
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
Custom Secret Manager Integration
|
||||
|
||||
This module provides a base class for implementing custom secret managers in LiteLLM.
|
||||
|
||||
Usage:
|
||||
from litellm.integrations.custom_secret_manager import CustomSecretManager
|
||||
|
||||
class MySecretManager(CustomSecretManager):
|
||||
def __init__(self):
|
||||
super().__init__(secret_manager_name="my_secret_manager")
|
||||
|
||||
async def async_read_secret(
|
||||
self,
|
||||
secret_name: str,
|
||||
optional_params=None,
|
||||
timeout=None,
|
||||
):
|
||||
# Your implementation here
|
||||
return await self._fetch_secret_from_service(secret_name)
|
||||
|
||||
def sync_read_secret(
|
||||
self,
|
||||
secret_name: str,
|
||||
optional_params=None,
|
||||
timeout=None,
|
||||
):
|
||||
# Your implementation here
|
||||
return self._fetch_secret_from_service_sync(secret_name)
|
||||
|
||||
# Set your custom secret manager
|
||||
import litellm
|
||||
from litellm.types.secret_managers.main import KeyManagementSystem
|
||||
|
||||
litellm.secret_manager_client = MySecretManager()
|
||||
litellm._key_management_system = KeyManagementSystem.CUSTOM
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.secret_managers.base_secret_manager import BaseSecretManager
|
||||
|
||||
|
||||
class CustomSecretManager(BaseSecretManager):
    """
    Base class for user-defined secret managers.

    Subclasses must provide ``async_read_secret`` and ``sync_read_secret``.
    Writing, deleting, environment validation and health checks are optional
    overrides: by default write/delete raise ``NotImplementedError`` while
    validation and health check log a debug message and return ``True``.

    Example:
        ```python
        from litellm.integrations.custom_secret_manager import CustomSecretManager

        class MyVaultSecretManager(CustomSecretManager):
            def __init__(self, vault_url: str, token: str):
                super().__init__(secret_manager_name="my_vault")
                self.vault_url = vault_url
                self.token = token

            async def async_read_secret(self, secret_name: str, optional_params=None, timeout=None):
                # Implementation for reading secrets from your vault
                async with httpx.AsyncClient() as client:
                    response = await client.get(
                        f"{self.vault_url}/v1/secret/{secret_name}",
                        headers={"X-Vault-Token": self.token},
                        timeout=timeout
                    )
                    return response.json()["data"]["value"]

            def sync_read_secret(self, secret_name: str, optional_params=None, timeout=None):
                # Sync implementation
                with httpx.Client() as client:
                    response = client.get(
                        f"{self.vault_url}/v1/secret/{secret_name}",
                        headers={"X-Vault-Token": self.token},
                        timeout=timeout
                    )
                    return response.json()["data"]["value"]
        ```
    """

    def __init__(
        self,
        secret_manager_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the CustomSecretManager.

        Args:
            secret_manager_name: Descriptive name used in this class's error
                messages, debug logs and ``repr``. Falls back to
                ``"custom_secret_manager"`` when omitted or empty.
            **kwargs: Accepted for subclass convenience; not used by this
                base implementation.
        """
        super().__init__()
        if secret_manager_name:
            self.secret_manager_name = secret_manager_name
        else:
            self.secret_manager_name = "custom_secret_manager"
        verbose_logger.info("Initialized custom secret manager")

    @abstractmethod
    async def async_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Asynchronously read a secret from your custom secret manager.

        Args:
            secret_name: Name/path of the secret to read
            optional_params: Additional parameters specific to your secret manager
            timeout: Request timeout

        Returns:
            The secret value if found, None otherwise

        Raises:
            Exception: If there's an error reading the secret
        """

    @abstractmethod
    def sync_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Synchronously read a secret from your custom secret manager.

        Args:
            secret_name: Name/path of the secret to read
            optional_params: Additional parameters specific to your secret manager
            timeout: Request timeout

        Returns:
            The secret value if found, None otherwise

        Raises:
            Exception: If there's an error reading the secret
        """

    async def async_write_secret(
        self,
        secret_name: str,
        secret_value: str,
        description: Optional[str] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tags: Optional[Union[dict, list]] = None,
    ) -> Dict[str, Any]:
        """
        Asynchronously write a secret to your custom secret manager.

        Optional to implement; override it if your secret manager supports
        writing secrets.

        Args:
            secret_name: Name/path of the secret to write
            secret_value: Value to store
            description: Description of the secret
            optional_params: Additional parameters specific to your secret manager
            timeout: Request timeout
            tags: Optional tags to apply to the secret

        Returns:
            Response from the secret manager containing write operation details

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        not_supported = (
            f"Write operations are not implemented for {self.secret_manager_name}. "
            "Override async_write_secret() to add write support."
        )
        raise NotImplementedError(not_supported)

    async def async_delete_secret(
        self,
        secret_name: str,
        recovery_window_in_days: Optional[int] = 7,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Asynchronously delete a secret from your custom secret manager.

        Optional to implement; override it if your secret manager supports
        deleting secrets.

        Args:
            secret_name: Name of the secret to delete
            recovery_window_in_days: Number of days before permanent deletion (if supported)
            optional_params: Additional parameters specific to your secret manager
            timeout: Request timeout

        Returns:
            Response from the secret manager containing deletion details

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        not_supported = (
            f"Delete operations are not implemented for {self.secret_manager_name}. "
            "Override async_delete_secret() to add delete support."
        )
        raise NotImplementedError(not_supported)

    def validate_environment(self) -> bool:
        """
        Validate that all required environment variables and configuration are present.

        Override this method to validate your secret manager's configuration.
        This base implementation performs no checks: it logs a debug message
        and reports success.

        Returns:
            True if the environment is valid

        Raises:
            ValueError: Overrides are expected to raise this when required
                configuration is missing; the base implementation never does.
        """
        verbose_logger.debug(
            "No environment validation configured for custom secret manager"
        )
        return True

    async def async_health_check(
        self, timeout: Optional[Union[float, httpx.Timeout]] = None
    ) -> bool:
        """
        Perform a health check on your secret manager.

        Optional to implement. This base implementation does not contact
        anything: it logs a debug message and returns ``True``.

        Args:
            timeout: Request timeout (unused by the base implementation)

        Returns:
            True if the secret manager is healthy, False otherwise
        """
        verbose_logger.debug(
            f"Health check not implemented for {self.secret_manager_name}"
        )
        return True

    def __repr__(self) -> str:
        """Return ``<ClassName(name=<secret_manager_name>)>``."""
        class_name = self.__class__.__name__
        return f"<{class_name}(name={self.secret_manager_name})>"
|
||||
@@ -0,0 +1,30 @@
|
||||
from fastapi import Request
|
||||
from fastapi_sso.sso.base import OpenID
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class CustomSSOLoginHandler(CustomLogger):
    """
    Custom logger for the UI SSO sign in

    Use this to parse the request headers and return a OpenID object

    Useful when you have an OAuth proxy in front of LiteLLM
    and you want to use the headers from the proxy to sign in the user
    """

    async def handle_custom_ui_sso_sign_in(
        self,
        request: Request,
    ) -> OpenID:
        """
        Build an ``OpenID`` identity from the incoming request headers.

        ``id`` and ``email`` come from the ``x-litellm-user-id`` and
        ``x-litellm-user-email`` headers (``None`` when a header is absent);
        every other field is a hard-coded placeholder value.
        """
        headers = dict(request.headers)
        user_id = headers.get("x-litellm-user-id")
        user_email = headers.get("x-litellm-user-email")
        return OpenID(
            id=user_id,
            email=user_email,
            first_name="Test",
            last_name="Test",
            display_name="Test",
            picture="https://test.com/test.png",
            provider="test",
        )
|
||||
@@ -0,0 +1,777 @@
|
||||
"""
|
||||
DataDog Integration - sends logs to /api/v2/logs
|
||||
|
||||
DD Reference API: https://docs.datadoghq.com/api/latest/logs
|
||||
|
||||
`async_log_success_event` - used by litellm proxy to send logs to datadog
|
||||
`log_success_event` - sync version of logging to DataDog, only used on litellm Python SDK, if user opts in to using sync functions
|
||||
|
||||
async_log_success_event: will store batch of DD_MAX_BATCH_SIZE in memory and flush to Datadog once it reaches DD_MAX_BATCH_SIZE or every 5 seconds
|
||||
|
||||
async_service_failure_hook: Logs failures from Redis, Postgres (Adjacent systems), as 'WARNING' on DataDog
|
||||
|
||||
For batching specific details see CustomBatchLogger class
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import os
|
||||
import traceback
|
||||
from datetime import datetime as datetimeObj
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from httpx import Response
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.integrations.datadog.datadog_mock_client import (
|
||||
should_use_datadog_mock,
|
||||
create_mock_datadog_client,
|
||||
)
|
||||
from litellm.integrations.datadog.datadog_handler import (
|
||||
get_datadog_hostname,
|
||||
get_datadog_service,
|
||||
get_datadog_source,
|
||||
get_datadog_tags,
|
||||
get_datadog_base_url_from_env,
|
||||
)
|
||||
from litellm.litellm_core_utils.dd_tracing import tracer
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
from litellm.types.integrations.datadog import (
|
||||
DD_ERRORS,
|
||||
DD_MAX_BATCH_SIZE,
|
||||
DataDogStatus,
|
||||
DatadogInitParams,
|
||||
DatadogPayload,
|
||||
DatadogProxyFailureHookJsonMessage,
|
||||
)
|
||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
from ..additional_logging_utils import AdditionalLoggingUtils
|
||||
|
||||
# max number of logs DD API can accept
|
||||
|
||||
|
||||
# ServiceTypes that are logged as success events to DD. Kept deliberately
# short so DD traces are not spammed with a large number of service types.
DD_LOGGED_SUCCESS_SERVICE_TYPES = [ServiceTypes.RESET_BUDGET_JOB]
|
||||
|
||||
|
||||
class DataDogLogger(
|
||||
CustomBatchLogger,
|
||||
AdditionalLoggingUtils,
|
||||
):
|
||||
# Class variables or attributes
|
||||
def __init__(
    self,
    **kwargs,
):
    """
    Set up the Datadog logger: resolve the intake URL, create HTTP clients
    and schedule the periodic flush task.

    Required environment variables (Direct API):
        `DD_API_KEY` - your datadog api key
        `DD_SITE` - your datadog site, example = `"us5.datadoghq.com"`

    Optional environment variables (DataDog Agent):
        `LITELLM_DD_AGENT_HOST` - hostname or IP of DataDog agent, example = `"localhost"`
        `LITELLM_DD_AGENT_PORT` - port of DataDog agent (default: 10518 for logs)

    Note: We use LITELLM_DD_AGENT_HOST instead of DD_AGENT_HOST to avoid conflicts
    with ddtrace which automatically sets DD_AGENT_HOST for APM tracing.

    Args:
        **kwargs: Forwarded to the base-class initializer after being updated
            with the values from ``litellm.datadog_params``.

    Raises:
        Exception: Any failure during set-up is logged and re-raised.
    """
    try:
        verbose_logger.debug("Datadog: in init datadog logger")

        self.is_mock_mode = should_use_datadog_mock()
        if self.is_mock_mode:
            create_mock_datadog_client()
            verbose_logger.debug(
                "[DATADOG MOCK] Datadog logger initialized in mock mode"
            )

        # Values from litellm.datadog_params overwrite same-named caller kwargs.
        kwargs.update(self._get_datadog_params())

        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        # Endpoint selection: agent when LITELLM_DD_AGENT_HOST is set
        # (not DD_AGENT_HOST, which ddtrace uses), otherwise the direct API.
        agent_host = os.getenv("LITELLM_DD_AGENT_HOST")
        if not agent_host:
            self._configure_dd_direct_api()
        else:
            self._configure_dd_agent(dd_agent_host=agent_host)

        # Optional override for testing
        base_url_override = get_datadog_base_url_from_env()
        if base_url_override:
            self.intake_url = f"{base_url_override}/api/v2/logs"

        self.sync_client = _get_httpx_client()
        asyncio.create_task(self.periodic_flush())
        self.flush_lock = asyncio.Lock()
        super().__init__(
            **kwargs, flush_lock=self.flush_lock, batch_size=DD_MAX_BATCH_SIZE
        )
    except Exception as e:
        init_error = f"Datadog: Got exception on init Datadog client {str(e)}"
        verbose_logger.exception(init_error)
        raise e
|
||||
|
||||
def _get_datadog_params(self) -> Dict:
    """
    Return ``litellm.datadog_params`` as a plain dict.

    These are params specific to initializing the DataDogLogger e.g. turn_off_message_logging

    Returns:
        ``model_dump()`` of the configured ``DatadogInitParams`` (a dict
        value is first validated by constructing ``DatadogInitParams`` from
        it); an empty dict when nothing is configured or the value is of
        another type.
    """
    configured = litellm.datadog_params
    if configured is None:
        return {}
    if isinstance(configured, DatadogInitParams):
        return configured.model_dump()
    if isinstance(configured, Dict):
        # only allow params that are of DatadogInitParams
        return DatadogInitParams(**configured).model_dump()
    return {}
|
||||
|
||||
def _configure_dd_agent(self, dd_agent_host: str) -> None:
    """
    Point the logger at a DataDog Agent for log forwarding.

    Sets ``self.intake_url`` to the agent's ``/api/v2/logs`` endpoint over
    plain HTTP and ``self.DD_API_KEY`` from the environment.

    Args:
        dd_agent_host: Hostname or IP of DataDog agent
    """
    # default port for logs
    agent_port = os.getenv("LITELLM_DD_AGENT_PORT", "10518")
    self.intake_url = f"http://{dd_agent_host}:{agent_port}/api/v2/logs"
    # Optional when using agent, so a missing key is not an error here.
    self.DD_API_KEY = os.getenv("DD_API_KEY")
    verbose_logger.debug(f"Datadog: Using DD Agent at {self.intake_url}")
|
||||
|
||||
def _configure_dd_direct_api(self) -> None:
|
||||
"""
|
||||
Configure direct DataDog API connection
|
||||
|
||||
Raises:
|
||||
Exception: If required environment variables are not set
|
||||
"""
|
||||
if os.getenv("DD_API_KEY", None) is None:
|
||||
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>")
|
||||
if os.getenv("DD_SITE", None) is None:
|
||||
raise Exception("DD_SITE is not set in .env, set 'DD_SITE=<>")
|
||||
|
||||
self.DD_API_KEY = os.getenv("DD_API_KEY")
|
||||
self.intake_url = f"https://http-intake.logs.{os.getenv('DD_SITE')}/api/v2/logs"
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Async Log success events to Datadog

    Delegates to ``_log_async_event``, which builds the Datadog payload,
    appends it to the in-memory queue and sends the batch once the queue
    length reaches ``self.batch_size``.

    Raises:
        Nothing: any exception is logged through ``verbose_logger.exception``
        and swallowed, so logging stays non-blocking.
    """
    try:
        verbose_logger.debug(
            "Datadog: Logging - Enters logging function for model %s", kwargs
        )
        await self._log_async_event(kwargs, response_obj, start_time, end_time)
    except Exception as e:
        failure_details = f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
        verbose_logger.exception(failure_details)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """
    Async log failure events to Datadog.

    Uses the same path as success events (``_log_async_event``); any
    exception is logged through ``verbose_logger.exception`` and swallowed.
    """
    try:
        verbose_logger.debug(
            "Datadog: Logging - Enters logging function for model %s", kwargs
        )
        await self._log_async_event(kwargs, response_obj, start_time, end_time)
    except Exception as e:
        failure_details = f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
        verbose_logger.exception(failure_details)
|
||||
|
||||
async def async_post_call_failure_hook(
    self,
    request_data: dict,
    original_exception: Exception,
    user_api_key_dict: Any,
    traceback_str: Optional[str] = None,
) -> Optional[Any]:
    """
    Log proxy-level failures (e.g. 401 auth, DB connection errors) to Datadog.

    Ensures failures that occur before or outside the LLM completion flow
    (e.g. ConnectError during auth when DB is down) are visible in Datadog
    alongside Prometheus.

    The failure is queued as an ERROR-status payload and the batch is sent
    once the queue length reaches ``self.batch_size``.

    Args:
        request_data: Not read by this implementation.
        original_exception: Exception to report.
        user_api_key_dict: Source of the user context attached to the log.
        traceback_str: Optional pre-formatted traceback passed through to the
            error-information helper.

    Returns:
        Always ``None``; errors raised while logging are themselves logged
        and swallowed.
    """
    try:
        from litellm.litellm_core_utils.litellm_logging import (
            StandardLoggingPayloadSetup,
        )
        from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

        error_information = StandardLoggingPayloadSetup.get_error_information(
            original_exception=original_exception,
            traceback_str=traceback_str,
        )

        # Only a purely numeric error code is reported as a status code.
        raw_code = error_information.get("error_code") or ""
        status_code: Optional[int] = (
            int(raw_code) if raw_code and str(raw_code).strip().isdigit() else None
        )

        # Use project-standard sanitized user context when running in proxy
        user_context: Dict[str, Any] = {}
        try:
            from litellm.proxy.litellm_pre_call_utils import (
                LiteLLMProxyRequestSetup,
            )

            sanitized = (
                LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
                    user_api_key_dict=user_api_key_dict
                )
            )
            user_context = dict(sanitized) if isinstance(sanitized, dict) else sanitized
        except Exception:
            # Fallback if proxy not available (e.g. SDK-only): minimal safe fields
            for field_name in ("request_route", "team_id", "user_id", "end_user_id"):
                if hasattr(user_api_key_dict, field_name):
                    user_context[field_name] = getattr(
                        user_api_key_dict, field_name, None
                    )

        message_payload: DatadogProxyFailureHookJsonMessage = {
            "exception": error_information.get("error_message")
            or str(original_exception),
            "error_class": error_information.get("error_class")
            or original_exception.__class__.__name__,
            "status_code": status_code,
            "traceback": error_information.get("traceback") or "",
            "user_api_key_dict": user_context,
        }

        dd_payload = DatadogPayload(
            ddsource=get_datadog_source(),
            ddtags=get_datadog_tags(),
            hostname=get_datadog_hostname(),
            message=safe_dumps(message_payload),
            service=get_datadog_service(),
            status=DataDogStatus.ERROR,
        )
        self._add_trace_context_to_payload(dd_payload=dd_payload)
        self.log_queue.append(dd_payload)

        queue_is_full = len(self.log_queue) >= self.batch_size
        if queue_is_full:
            await self.async_send_batch()
    except Exception as e:
        verbose_logger.exception(
            f"Datadog: async_post_call_failure_hook - {str(e)}\n{traceback.format_exc()}"
        )
    return None
|
||||
|
||||
async def async_send_batch(self):
    """
    Sends the in memory logs queue to datadog api

    Logs sent to /api/v2/logs

    DD Ref: https://docs.datadoghq.com/api/latest/logs/

    The queue is sent as-is via ``async_send_compressed_data``; this method
    does not remove anything from ``self.log_queue``.

    Raises:
        Nothing: failures are logged (non-blocking) instead of propagated.
        An empty queue and an HTTP 413 response are reported at ERROR level
        and end the call early; any other non-202 outcome is logged with its
        traceback.
    """
    try:
        if not self.log_queue:
            # No exception is active here, so use .error(): .exception()
            # would append a meaningless "NoneType: None" traceback.
            verbose_logger.error("Datadog: log_queue does not exist")
            return

        verbose_logger.debug(
            "Datadog - about to flush %s events on %s",
            len(self.log_queue),
            self.intake_url,
        )

        if self.is_mock_mode:
            verbose_logger.debug(
                "[DATADOG MOCK] Mock mode enabled - API calls will be intercepted"
            )

        response = await self.async_send_compressed_data(self.log_queue)
        if response.status_code == 413:
            # Same reasoning as above: plain error log, no active exception.
            verbose_logger.error(DD_ERRORS.DATADOG_413_ERROR.value)
            return

        response.raise_for_status()
        if response.status_code != 202:
            raise Exception(
                f"Response from datadog API status_code: {response.status_code}, text: {response.text}"
            )

        if self.is_mock_mode:
            verbose_logger.debug(
                f"[DATADOG MOCK] Batch of {len(self.log_queue)} events successfully mocked"
            )
        else:
            verbose_logger.debug(
                "Datadog: Response from datadog API status_code: %s, text: %s",
                response.status_code,
                response.text,
            )
    except Exception as e:
        verbose_logger.exception(
            f"Datadog Error sending batch API - {str(e)}\n{traceback.format_exc()}"
        )
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Sync Log success events to Datadog

    - Creates a Datadog payload (built by ``_create_v0_logging_payload`` when
      ``litellm.datadog_use_v1`` is ``True``, otherwise by
      ``create_datadog_logging_payload``)
    - instantly logs it on DD API with the sync HTTP client

    Any exception, including a non-202 response, is logged and swallowed.
    """
    try:
        build_payload = (
            self._create_v0_logging_payload
            if litellm.datadog_use_v1 is True
            else self.create_datadog_logging_payload
        )
        dd_payload = build_payload(
            kwargs=kwargs,
            response_obj=response_obj,
            start_time=start_time,
            end_time=end_time,
        )

        # Add API key if available (required for direct API, optional for agent)
        headers = {"DD-API-KEY": self.DD_API_KEY} if self.DD_API_KEY else {}

        response = self.sync_client.post(
            url=self.intake_url,
            json=dd_payload,  # type: ignore
            headers=headers,
        )

        response.raise_for_status()
        if response.status_code != 202:
            raise Exception(
                f"Response from datadog API status_code: {response.status_code}, text: {response.text}"
            )

        verbose_logger.debug(
            "Datadog: Response from datadog API status_code: %s, text: %s",
            response.status_code,
            response.text,
        )
    except Exception as e:
        failure_details = f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
        verbose_logger.exception(failure_details)
|
||||
|
||||
async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
    """
    Build a Datadog payload for one event and queue it.

    The payload is appended to ``self.log_queue``; once the queue length
    reaches ``self.batch_size`` the batch is sent immediately.

    Raises:
        ValueError: Propagated from ``create_datadog_logging_payload`` when
            ``kwargs`` has no ``standard_logging_object``.
    """
    queued_payload = self.create_datadog_logging_payload(
        kwargs=kwargs,
        response_obj=response_obj,
        start_time=start_time,
        end_time=end_time,
    )
    self.log_queue.append(queued_payload)
    verbose_logger.debug(
        f"Datadog, event added to queue. Will flush in {self.flush_interval} seconds..."
    )

    # NOTE(review): async_send_batch in this class does not clear
    # self.log_queue - confirm the queue is drained elsewhere, otherwise every
    # later event re-sends the already-sent batch.
    if len(self.log_queue) < self.batch_size:
        return
    await self.async_send_batch()
|
||||
|
||||
def _create_datadog_logging_payload_helper(
    self,
    standard_logging_object: StandardLoggingPayload,
    status: DataDogStatus,
) -> DatadogPayload:
    """
    Wrap a standard logging object in a ``DatadogPayload``.

    Args:
        standard_logging_object: Serialized with ``safe_dumps`` to become the
            payload's ``message``; also passed to ``get_datadog_tags``.
        status: Datadog status recorded on the payload.

    Returns:
        The payload, after ``_add_trace_context_to_payload`` has been applied
        to it.
    """
    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    serialized_message = safe_dumps(standard_logging_object)
    verbose_logger.debug(
        "Datadog: Logger - Logging payload = %s", serialized_message
    )

    payload = DatadogPayload(
        ddsource=get_datadog_source(),
        ddtags=get_datadog_tags(standard_logging_object=standard_logging_object),
        hostname=get_datadog_hostname(),
        message=serialized_message,
        service=get_datadog_service(),
        status=status,
    )
    self._add_trace_context_to_payload(dd_payload=payload)
    return payload
|
||||
|
||||
def create_datadog_logging_payload(
    self,
    kwargs: Union[dict, Any],
    response_obj: Any,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
) -> DatadogPayload:
    """
    Build the Datadog payload for a single request.

    Only `kwargs["standard_logging_object"]` is consulted here;
    `response_obj`, `start_time` and `end_time` are accepted but not read.

    Args:
        kwargs (Union[dict, Any]): request kwargs
        response_obj (Any): llm api response
        start_time (datetime.datetime): start time of request
        end_time (datetime.datetime): end time of request

    Returns:
        DatadogPayload: defined in types.py

    Raises:
        ValueError: if `standard_logging_object` is missing from kwargs
    """
    logging_object: Optional[StandardLoggingPayload] = kwargs.get(
        "standard_logging_object", None
    )
    if logging_object is None:
        raise ValueError("standard_logging_object not found in kwargs")

    # Failed requests are reported as ERROR, everything else as INFO.
    request_failed = logging_object.get("status") == "failure"
    status = DataDogStatus.ERROR if request_failed else DataDogStatus.INFO

    # The return value is not used, so any truncation happens on the
    # object itself before it is serialized.
    self.truncate_standard_logging_payload_content(logging_object)

    return self._create_datadog_logging_payload_helper(
        standard_logging_object=logging_object,
        status=status,
    )
|
||||
|
||||
async def async_send_compressed_data(self, data: List) -> Response:
    """
    POST `data` to `self.intake_url` as gzip-compressed JSON.

    Datadog recommends sending logs compressed and marking the request with
    `Content-Encoding: gzip`:
    https://docs.datadoghq.com/api/latest/logs/

    Args:
        data (List): items serialized with `safe_dumps` before compression

    Returns:
        Response: the response returned by `self.async_client.post`
    """
    import gzip

    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    json_bytes = safe_dumps(data).encode("utf-8")
    gzipped_body = gzip.compress(json_bytes)

    request_headers = {
        "Content-Encoding": "gzip",
        "Content-Type": "application/json",
    }
    # The API key is required for the direct API and optional for the agent,
    # so it is only sent when one is configured.
    api_key = self.DD_API_KEY
    if api_key:
        request_headers["DD-API-KEY"] = api_key

    return await self.async_client.post(
        url=self.intake_url,
        data=gzipped_body,  # type: ignore
        headers=request_headers,
    )
|
||||
|
||||
async def async_service_failure_hook(
    self,
    payload: ServiceLoggerPayload,
    error: Optional[str] = "",
    parent_otel_span: Optional[Any] = None,
    start_time: Optional[Union[datetimeObj, float]] = None,
    end_time: Optional[Union[float, datetimeObj]] = None,
    event_metadata: Optional[dict] = None,
):
    """
    Queue a 'WARNING' Datadog log for a failure in an adjacent system
    (e.g. Redis, Postgres).

    The message is the dumped `payload` merged with `event_metadata`.
    `error`, `parent_otel_span`, `start_time` and `end_time` are not read
    here. Any exception is logged and swallowed.
    """
    try:
        message_fields = payload.model_dump()
        message_fields.update(event_metadata or {})

        from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

        warning_entry = DatadogPayload(
            ddsource=get_datadog_source(),
            ddtags=get_datadog_tags(),
            hostname=get_datadog_hostname(),
            message=safe_dumps(message_fields),
            service=get_datadog_service(),
            status=DataDogStatus.WARN,
        )
        self.log_queue.append(warning_entry)
    except Exception as e:
        verbose_logger.exception(
            f"Datadog: Logger - Exception in async_service_failure_hook: {e}"
        )
|
||||
|
||||
async def async_service_success_hook(
    self,
    payload: ServiceLoggerPayload,
    error: Optional[str] = "",
    parent_otel_span: Optional[Any] = None,
    start_time: Optional[Union[datetimeObj, float]] = None,
    end_time: Optional[Union[float, datetimeObj]] = None,
    event_metadata: Optional[dict] = None,
):
    """
    Queue an 'INFO' Datadog log for a success in an adjacent system
    (e.g. Redis, Postgres).

    Only services listed in DD_LOGGED_SUCCESS_SERVICE_TYPES are logged;
    everything else is dropped on purpose so Datadog is not flooded with
    success events.

    The message is the dumped `payload` merged with `event_metadata`.
    `error`, `parent_otel_span`, `start_time` and `end_time` are not read
    here. Any exception is logged and swallowed.
    """
    try:
        # intentionally done. Don't want to log all service types to DD
        if payload.service not in DD_LOGGED_SUCCESS_SERVICE_TYPES:
            return

        _payload_dict = payload.model_dump()
        _payload_dict.update(event_metadata or {})

        from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

        _dd_message_str = safe_dumps(_payload_dict)
        _dd_payload = DatadogPayload(
            ddsource=get_datadog_source(),
            ddtags=get_datadog_tags(),
            hostname=get_datadog_hostname(),
            message=_dd_message_str,
            service=get_datadog_service(),
            status=DataDogStatus.INFO,
        )

        self.log_queue.append(_dd_payload)

    except Exception as e:
        # Name this hook in the message so a failure here is not mistaken
        # for one in async_service_failure_hook.
        verbose_logger.exception(
            f"Datadog: Logger - Exception in async_service_success_hook: {e}"
        )
|
||||
|
||||
def _create_v0_logging_payload(
    self,
    kwargs: Union[dict, Any],
    response_obj: Any,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
) -> DatadogPayload:
    """
    Build the legacy (first-generation) Datadog logging payload.

    (Not Recommended) This format is only used when
    `litellm.datadog_use_v1 = True`.

    The message is assembled directly from `kwargs` and `response_obj`
    rather than from the standard logging object, and is always sent with
    INFO status.
    """
    litellm_params = kwargs.get("litellm_params", {})
    # litellm_params["metadata"] may be None, hence the `or {}`
    raw_metadata = litellm_params.get("metadata", {}) or {}
    messages = kwargs.get("messages")
    optional_params = kwargs.get("optional_params", {})
    call_type = kwargs.get("call_type", "litellm.completion")
    cache_hit = kwargs.get("cache_hit", False)

    raw_usage = response_obj["usage"]
    response_id = response_obj.get("id", str(uuid.uuid4()))
    usage = dict(raw_usage)

    try:
        elapsed_ms = (end_time - start_time).total_seconds() * 1000
    except Exception:
        elapsed_ms = None

    # Best effort: log the response as a plain dict when it converts to one,
    # otherwise log it unchanged.
    try:
        response_obj = dict(response_obj)
    except Exception:
        pass

    # Never log raw metadata: it can contain circular references, which
    # leads to infinite recursion. Drop the extra litellm-internal keys.
    excluded_metadata_keys = ("endpoint", "caching_groups", "previous_models")
    if isinstance(raw_metadata, dict):
        clean_metadata = {
            key: value
            for key, value in raw_metadata.items()
            if key not in excluded_metadata_keys
        }
    else:
        clean_metadata = {}

    message_fields = {
        "id": response_id,
        "call_type": call_type,
        "cache_hit": cache_hit,
        "start_time": start_time,
        "end_time": end_time,
        "response_time": elapsed_ms,
        "model": kwargs.get("model", ""),
        "user": kwargs.get("user", ""),
        "model_parameters": optional_params,
        "spend": kwargs.get("response_cost", 0),
        "messages": messages,
        "response": response_obj,
        "usage": usage,
        "metadata": clean_metadata,
    }

    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    serialized_message = safe_dumps(message_fields)
    verbose_logger.debug(
        "Datadog: Logger - Logging payload = %s", serialized_message
    )

    return DatadogPayload(
        ddsource=get_datadog_source(),
        ddtags=get_datadog_tags(),
        hostname=get_datadog_hostname(),
        message=serialized_message,
        service=get_datadog_service(),
        status=DataDogStatus.INFO,
    )
|
||||
|
||||
def _add_trace_context_to_payload(
|
||||
self,
|
||||
dd_payload: DatadogPayload,
|
||||
) -> None:
|
||||
"""Attach Datadog APM trace context if one is active."""
|
||||
|
||||
try:
|
||||
trace_context = self._get_active_trace_context()
|
||||
if trace_context is None:
|
||||
return
|
||||
|
||||
dd_payload["dd.trace_id"] = trace_context["trace_id"]
|
||||
span_id = trace_context.get("span_id")
|
||||
if span_id is not None:
|
||||
dd_payload["dd.span_id"] = span_id
|
||||
except Exception:
|
||||
verbose_logger.exception(
|
||||
"Datadog: Failed to attach trace context to payload"
|
||||
)
|
||||
|
||||
def _get_active_trace_context(self) -> Optional[Dict[str, str]]:
    """
    Return the trace/span ids of the tracer's active span, as strings.

    `tracer.current_span()` is tried first and `tracer.current_root_span()`
    is the fallback; either accessor is skipped when the tracer does not
    expose it as a callable.

    Returns:
        A dict with "trace_id" (and "span_id" when the span has one), or
        None when there is no span, the span has no trace id, or the lookup
        raised (the exception is logged).
    """
    try:
        active_span = None
        for accessor_name in ("current_span", "current_root_span"):
            accessor = getattr(tracer, accessor_name, None)
            if callable(accessor):
                active_span = accessor()
            if active_span is not None:
                break

        if active_span is None:
            return None

        raw_trace_id = getattr(active_span, "trace_id", None)
        if raw_trace_id is None:
            return None

        raw_span_id = getattr(active_span, "span_id", None)
        context: Dict[str, str] = {"trace_id": str(raw_trace_id)}
        if raw_span_id is not None:
            context["span_id"] = str(raw_span_id)
        return context
    except Exception:
        verbose_logger.exception(
            "Datadog: Failed to retrieve active trace context from tracer"
        )
        return None
|
||||
|
||||
async def async_health_check(self) -> IntegrationHealthCheckStatus:
    """
    Check if the service is healthy.

    Sends one dummy standard logging payload to Datadog through
    `async_send_compressed_data`.

    Returns:
        IntegrationHealthCheckStatus: "healthy" when the request succeeds;
        "unhealthy" with the response text for an HTTP error status, or
        with `str(exception)` for any other failure (including errors
        raised while building or sending the payload).
    """
    from litellm.litellm_core_utils.litellm_logging import (
        create_dummy_standard_logging_payload,
    )

    try:
        # Building and sending the payload sit inside the try so that a
        # transport failure (DNS, connection refused, timeout) is reported
        # as "unhealthy" instead of escaping the health check.
        standard_logging_object = create_dummy_standard_logging_payload()
        dd_payload = self._create_datadog_logging_payload_helper(
            standard_logging_object=standard_logging_object,
            status=DataDogStatus.INFO,
        )
        response = await self.async_send_compressed_data([dd_payload])
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        return IntegrationHealthCheckStatus(
            status="unhealthy",
            error_message=e.response.text,
        )
    except Exception as e:
        return IntegrationHealthCheckStatus(
            status="unhealthy",
            error_message=str(e),
        )

    return IntegrationHealthCheckStatus(
        status="healthy",
        error_message=None,
    )
|
||||
|
||||
async def get_request_response_payload(
    self,
    request_id: str,
    start_time_utc: Optional[datetimeObj],
    end_time_utc: Optional[datetimeObj],
) -> Optional[dict]:
    """
    Not implemented for this logger: the arguments are ignored and the
    result is always None.
    """
    return None
|
||||
@@ -0,0 +1,216 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.datadog_cost_management import (
|
||||
DatadogFOCUSCostEntry,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class DatadogCostManagementLogger(CustomBatchLogger):
    """
    Batch logger that uploads aggregated LLM spend to Datadog custom costs.

    Successful requests with a positive `response_cost` are queued. Each
    flush aggregates the queue into DatadogFOCUSCostEntry records and PUTs
    them to `https://api.<DD_SITE>/api/v2/cost/custom_costs`.
    """

    def __init__(self, **kwargs):
        """
        Read Datadog credentials from the environment and start the
        periodic flush task.

        Environment variables: DD_API_KEY, DD_APP_KEY and DD_SITE (default
        "datadoghq.com"). Missing keys only produce a warning; uploads are
        then skipped by `_upload_to_datadog`.

        `asyncio.create_task` is called here, so the logger has to be
        constructed while an event loop is running.
        """
        self.dd_api_key = os.getenv("DD_API_KEY")
        self.dd_app_key = os.getenv("DD_APP_KEY")
        self.dd_site = os.getenv("DD_SITE", "datadoghq.com")

        if not self.dd_api_key or not self.dd_app_key:
            verbose_logger.warning(
                "Datadog Cost Management: DD_API_KEY and DD_APP_KEY are required. Integration will not work."
            )

        self.upload_url = f"https://api.{self.dd_site}/api/v2/cost/custom_costs"

        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        # Initialize lock and start periodic flush task
        self.flush_lock = asyncio.Lock()
        asyncio.create_task(self.periodic_flush())

        # Check if flush_lock is already in kwargs to avoid double passing (unlikely but safe)
        if "flush_lock" not in kwargs:
            kwargs["flush_lock"] = self.flush_lock

        super().__init__(**kwargs)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Queue the request's standard logging object when it carries a cost.

        Requests without a `standard_logging_object`, or whose
        `response_cost` is missing, None or not positive, are ignored.
        Reaching `self.batch_size` triggers an immediate flush. Any
        exception is logged and swallowed.
        """
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )

            if standard_logging_object is None:
                return

            # `response_cost` may be present but None; treat that as "no
            # cost" instead of raising TypeError on the comparison.
            response_cost = standard_logging_object.get("response_cost") or 0

            # Only log if there is a cost associated
            if response_cost > 0:
                self.log_queue.append(standard_logging_object)

                if len(self.log_queue) >= self.batch_size:
                    # Go through flush_queue rather than async_send_batch:
                    # async_send_batch never empties the queue, so calling
                    # it directly would re-upload the same entries on every
                    # later event until the periodic flush ran.
                    await self.flush_queue()

        except Exception as e:
            verbose_logger.exception(
                f"Datadog Cost Management: Error in async_log_success_event: {str(e)}"
            )

    async def async_send_batch(self):
        """
        Aggregate the queued logs and upload them to Datadog.

        The queue is not cleared here; CustomBatchLogger clears it in
        `flush_queue`. Any exception is logged and swallowed.
        """
        if not self.log_queue:
            return

        try:
            # Aggregate costs from the batch
            aggregated_entries = self._aggregate_costs(self.log_queue)

            if not aggregated_entries:
                return

            # Send to Datadog
            await self._upload_to_datadog(aggregated_entries)

        except Exception as e:
            verbose_logger.exception(
                f"Datadog Cost Management: Error in async_send_batch: {str(e)}"
            )

    def _aggregate_costs(
        self, logs: List[StandardLoggingPayload]
    ) -> List[DatadogFOCUSCostEntry]:
        """
        Aggregate costs by provider, model, UTC date and tag set.

        Logs with no cost are skipped. A log that fails to process is
        reported with a warning and skipped; it does not abort the batch.

        Returns:
            One DatadogFOCUSCostEntry per (provider, model, date, tags) group.
        """
        from datetime import timezone

        aggregator: Dict[
            Tuple[str, str, str, Tuple[Tuple[str, str], ...]], DatadogFOCUSCostEntry
        ] = {}

        for log in logs:
            try:
                # Extract keys for aggregation
                provider = log.get("custom_llm_provider") or "unknown"
                model = log.get("model") or "unknown"
                cost = log.get("response_cost") or 0

                if cost == 0:
                    continue

                # The charge period is the UTC day of the request. The
                # timezone is passed explicitly: without it fromtimestamp()
                # uses the host's local zone and requests near midnight
                # would land on the wrong day.
                ts = log.get("startTime") or time.time()
                dt = datetime.fromtimestamp(ts, tz=timezone.utc)
                date_str = dt.strftime("%Y-%m-%d")

                # Tags are part of the grouping key; sorting the items makes
                # the key independent of dict insertion order.
                tags = self._extract_tags(log)
                tags_key = tuple(sorted(tags.items())) if tags else ()

                key = (provider, model, date_str, tags_key)

                if key not in aggregator:
                    # Daily granularity: period start and end are the same day.
                    aggregator[key] = {
                        "ProviderName": provider,
                        "ChargeDescription": f"LLM Usage for {model}",
                        "ChargePeriodStart": date_str,
                        "ChargePeriodEnd": date_str,
                        "BilledCost": 0.0,
                        "BillingCurrency": "USD",
                        "Tags": tags if tags else None,
                    }

                aggregator[key]["BilledCost"] += cost

            except Exception as e:
                verbose_logger.warning(
                    f"Error processing log for cost aggregation: {e}"
                )
                continue

        return list(aggregator.values())

    def _extract_tags(self, log: StandardLoggingPayload) -> Dict[str, str]:
        """
        Build the tag dict attached to a cost entry.

        Always contains env, service, host and pod_name from the
        `get_datadog_*` helpers. When the log has metadata, `user`, `team`
        and `model_group` are added if the corresponding values are set.
        """
        from litellm.integrations.datadog.datadog_handler import (
            get_datadog_env,
            get_datadog_hostname,
            get_datadog_pod_name,
            get_datadog_service,
        )

        tags = {
            "env": get_datadog_env(),
            "service": get_datadog_service(),
            "host": get_datadog_hostname(),
            "pod_name": get_datadog_pod_name(),
        }

        # Add metadata as tags
        metadata = log.get("metadata", {})
        if metadata:
            # Add user info
            if metadata.get("user_api_key_alias"):
                tags["user"] = str(metadata["user_api_key_alias"])

            # Add Team Tag: aliases are preferred over ids
            team_tag = (
                metadata.get("user_api_key_team_alias")
                or metadata.get("team_alias")  # type: ignore
                or metadata.get("user_api_key_team_id")
                or metadata.get("team_id")  # type: ignore
            )

            if team_tag:
                tags["team"] = str(team_tag)

            # model_group is not in StandardLoggingMetadata TypedDict, so we need to access it via dict.get()
            model_group = metadata.get("model_group")  # type: ignore[misc]
            if model_group:
                tags["model_group"] = str(model_group)

        return tags

    async def _upload_to_datadog(self, payload: List[Dict]):
        """
        PUT the cost entries to `self.upload_url` as a JSON list.

        Returns without sending when DD_API_KEY or DD_APP_KEY is missing.
        An HTTP error status raises via `response.raise_for_status()`.
        """
        if not self.dd_api_key or not self.dd_app_key:
            return

        headers = {
            "Content-Type": "application/json",
            "DD-API-KEY": self.dd_api_key,
            "DD-APPLICATION-KEY": self.dd_app_key,
        }

        # The API endpoint expects a list of objects directly in the body (file content behavior)
        from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

        data_json = safe_dumps(payload)

        response = await self.async_client.put(
            self.upload_url, content=data_json, headers=headers
        )

        response.raise_for_status()

        verbose_logger.debug(
            f"Datadog Cost Management: Uploaded {len(payload)} cost entries. Status: {response.status_code}"
        )
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Shared helpers for Datadog integrations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
def get_datadog_source() -> str:
    """Return the DD_SOURCE environment variable, defaulting to "litellm"."""
    return os.environ.get("DD_SOURCE", "litellm")
|
||||
|
||||
|
||||
def get_datadog_service() -> str:
    """Return the DD_SERVICE environment variable, defaulting to "litellm-server"."""
    return os.environ.get("DD_SERVICE", "litellm-server")
|
||||
|
||||
|
||||
def get_datadog_hostname() -> str:
    """Return the HOSTNAME environment variable, or "" when it is unset."""
    return os.environ.get("HOSTNAME", "")
|
||||
|
||||
|
||||
def get_datadog_base_url_from_env() -> Optional[str]:
    """
    Return the base URL override from the common DD_BASE_URL env var.

    Useful for testing or custom endpoints. None when the variable is unset.
    """
    return os.environ.get("DD_BASE_URL")
|
||||
|
||||
|
||||
def get_datadog_env() -> str:
    """Return the DD_ENV environment variable, defaulting to "unknown"."""
    return os.environ.get("DD_ENV", "unknown")
|
||||
|
||||
|
||||
def get_datadog_pod_name() -> str:
    """Return the POD_NAME environment variable, defaulting to "unknown"."""
    return os.environ.get("POD_NAME", "unknown")
|
||||
|
||||
|
||||
def get_datadog_tags(
    standard_logging_object: Optional[StandardLoggingPayload] = None,
) -> str:
    """
    Build the comma-separated Datadog tags string used by multiple integrations.

    The string always starts with the env, service, version, HOSTNAME and
    POD_NAME tags. When a standard logging object is given, one
    `request_tag:<tag>` entry per request tag follows, then a `team:<value>`
    tag taken from the first non-empty team field in its metadata.
    """
    tag_parts: List[str] = [
        f"env:{get_datadog_env()}",
        f"service:{get_datadog_service()}",
        f"version:{os.getenv('DD_VERSION', 'unknown')}",
        f"HOSTNAME:{get_datadog_hostname()}",
        f"POD_NAME:{get_datadog_pod_name()}",
    ]

    if not standard_logging_object:
        return ",".join(tag_parts)

    for request_tag in standard_logging_object.get("request_tags", []) or []:
        tag_parts.append(f"request_tag:{request_tag}")

    # Add Team Tag: aliases are preferred over ids
    metadata = standard_logging_object.get("metadata", {}) or {}
    team_fields = (
        "user_api_key_team_alias",
        "team_alias",
        "user_api_key_team_id",
        "team_id",
    )
    for field in team_fields:
        team_value = metadata.get(field)
        if team_value:
            tag_parts.append(f"team:{team_value}")
            break

    return ",".join(tag_parts)
|
||||
@@ -0,0 +1,856 @@
|
||||
"""
|
||||
Implements logging integration with Datadog's LLM Observability Service
|
||||
|
||||
|
||||
API Reference: https://docs.datadoghq.com/llm_observability/setup/api/?tab=example#api-standards
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from litellm._uuid import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.integrations.datadog.datadog_mock_client import (
|
||||
should_use_datadog_mock,
|
||||
create_mock_datadog_client,
|
||||
)
|
||||
from litellm.integrations.datadog.datadog_handler import (
|
||||
get_datadog_service,
|
||||
get_datadog_tags,
|
||||
get_datadog_base_url_from_env,
|
||||
)
|
||||
from litellm.litellm_core_utils.dd_tracing import tracer
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
handle_any_messages_to_chat_completion_str_messages_conversion,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.datadog_llm_obs import *
|
||||
from litellm.types.utils import (
|
||||
CallTypes,
|
||||
StandardLoggingGuardrailInformation,
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingPayloadErrorInformation,
|
||||
)
|
||||
|
||||
|
||||
class DataDogLLMObsLogger(CustomBatchLogger):
|
||||
def __init__(self, **kwargs):
    """
    Set up the Datadog LLM Observability logger.

    Endpoint selection:
      - LITELLM_DD_AGENT_HOST set: send to the Datadog Agent
        (`_configure_dd_agent`); DD_API_KEY and DD_SITE are not required.
      - otherwise: send directly to the Datadog API
        (`_configure_dd_direct_api`), which raises when DD_API_KEY or
        DD_SITE is missing.
      - DD_BASE_URL, when set, overrides the intake URL in either mode.

    Params from `litellm.datadog_llm_observability_params` are merged into
    `kwargs` before CustomBatchLogger is initialized.

    `asyncio.create_task` is called here, so the logger has to be
    constructed while an event loop is running.

    Raises:
        Exception: any initialization error is logged and re-raised.
    """
    try:
        verbose_logger.debug("DataDogLLMObs: Initializing logger")

        self.is_mock_mode = should_use_datadog_mock()

        if self.is_mock_mode:
            create_mock_datadog_client()
            verbose_logger.debug(
                "[DATADOG MOCK] DataDogLLMObs logger initialized in mock mode"
            )

        # Configure DataDog endpoint (Agent or Direct API)
        # Use LITELLM_DD_AGENT_HOST to avoid conflicts with ddtrace's DD_AGENT_HOST
        # Check for agent mode FIRST - agent mode doesn't require DD_API_KEY or DD_SITE
        dd_agent_host = os.getenv("LITELLM_DD_AGENT_HOST")

        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.DD_API_KEY = os.getenv("DD_API_KEY")

        if dd_agent_host:
            self._configure_dd_agent(dd_agent_host=dd_agent_host)
        else:
            # DD_API_KEY and DD_SITE are only required for direct API mode.
            # _configure_dd_direct_api validates both and raises the
            # "... is not set" errors, so the checks are not repeated here.
            self._configure_dd_direct_api()

        # Optional override for testing
        dd_base_url = get_datadog_base_url_from_env()
        if dd_base_url:
            self.intake_url = f"{dd_base_url}/api/intake/llm-obs/v1/trace/spans"

        asyncio.create_task(self.periodic_flush())
        self.flush_lock = asyncio.Lock()
        self.log_queue: List[LLMObsPayload] = []

        #########################################################
        # Handle datadog_llm_observability_params set as litellm.datadog_llm_observability_params
        #########################################################
        dict_datadog_llm_obs_params = self._get_datadog_llm_obs_params()
        kwargs.update(dict_datadog_llm_obs_params)
        CustomBatchLogger.__init__(self, **kwargs, flush_lock=self.flush_lock)
    except Exception as e:
        verbose_logger.exception(f"DataDogLLMObs: Error initializing - {str(e)}")
        # Bare raise re-raises the caught exception as-is.
        raise
|
||||
|
||||
def _configure_dd_agent(self, dd_agent_host: str):
    """
    Point the logger at a Datadog Agent instead of the public API.

    When using the Agent, LLM Observability intake does NOT require the
    API key.
    Reference: https://docs.datadoghq.com/llm_observability/setup/sdk/#agent-setup

    The port comes from LITELLM_DD_LLM_OBS_PORT (default "8126"); a
    dedicated LLM Obs (Trace Agent) port avoids a conflict with the Logs
    Agent (10518).
    """
    llm_obs_port = os.getenv("LITELLM_DD_LLM_OBS_PORT", "8126")
    spans_path = "/api/intake/llm-obs/v1/trace/spans"

    # Not used for URL construction in agent mode
    self.DD_SITE = "localhost"
    self.intake_url = f"http://{dd_agent_host}:{llm_obs_port}{spans_path}"
    verbose_logger.debug(f"DataDogLLMObs: Using DD Agent at {self.intake_url}")
|
||||
|
||||
def _configure_dd_direct_api(self):
|
||||
"""
|
||||
Configure the Datadog logger to send traces directly to the Datadog API.
|
||||
"""
|
||||
if not self.DD_API_KEY:
|
||||
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>'")
|
||||
|
||||
self.DD_SITE = os.getenv("DD_SITE")
|
||||
if not self.DD_SITE:
|
||||
raise Exception(
|
||||
"DD_SITE is not set, set 'DD_SITE=<>', example site = `us5.datadoghq.com`"
|
||||
)
|
||||
|
||||
self.intake_url = (
|
||||
f"https://api.{self.DD_SITE}/api/intake/llm-obs/v1/trace/spans"
|
||||
)
|
||||
|
||||
def _get_datadog_llm_obs_params(self) -> Dict:
    """
    Return `litellm.datadog_llm_observability_params` as a plain dict.

    These are params specific to initializing the DataDogLLMObsLogger,
    e.g. turn_off_message_logging.

    A DatadogLLMObsInitParams instance is dumped as-is; a dict is first
    passed through DatadogLLMObsInitParams before being dumped. Anything
    else, including None, yields an empty dict.
    """
    configured = litellm.datadog_llm_observability_params
    if configured is None:
        return {}

    if isinstance(configured, DatadogLLMObsInitParams):
        return configured.model_dump()

    if isinstance(configured, Dict):
        # only allow params that are of DatadogLLMObsInitParams
        return DatadogLLMObsInitParams(**configured).model_dump()

    return {}
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Queue an LLM Observability span for a successful call.

    The span is built from `kwargs`, `start_time` and `end_time`;
    `response_obj` is not read. The batch is sent once the queue reaches
    `self.batch_size`. Any exception is logged and swallowed.
    """
    try:
        verbose_logger.debug(
            f"DataDogLLMObs: Logging success event for model {kwargs.get('model', 'unknown')}"
        )
        span = self.create_llm_obs_payload(kwargs, start_time, end_time)
        verbose_logger.debug(f"DataDogLLMObs: Payload: {span}")
        self.log_queue.append(span)

        if len(self.log_queue) < self.batch_size:
            return
        await self.async_send_batch()
    except Exception as e:
        verbose_logger.exception(
            f"DataDogLLMObs: Error logging success event - {str(e)}"
        )
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """
    Queue an LLM Observability span for a failed call.

    The span is built from `kwargs`, `start_time` and `end_time`;
    `response_obj` is not read. The batch is sent once the queue reaches
    `self.batch_size`. Any exception is logged and swallowed.
    """
    try:
        verbose_logger.debug(
            f"DataDogLLMObs: Logging failure event for model {kwargs.get('model', 'unknown')}"
        )
        span = self.create_llm_obs_payload(kwargs, start_time, end_time)
        verbose_logger.debug(f"DataDogLLMObs: Payload: {span}")
        self.log_queue.append(span)

        if len(self.log_queue) < self.batch_size:
            return
        await self.async_send_batch()
    except Exception as e:
        verbose_logger.exception(
            f"DataDogLLMObs: Error logging failure event - {str(e)}"
        )
|
||||
|
||||
async def async_send_batch(self):
    """
    Send every queued span to the Datadog LLM Observability intake.

    The queue is wrapped in a single DDIntakePayload and POSTed to
    `self.intake_url`. Any status other than 202 is treated as a failure.
    On success the spans that were sent are removed from the queue; on
    failure the queue is left untouched. Exceptions are logged and
    swallowed.
    """
    try:
        if not self.log_queue:
            return

        # Number of spans in this request. Spans appended while the POST is
        # awaited are not part of the request body and must stay queued.
        sent_count = len(self.log_queue)

        verbose_logger.debug(f"DataDogLLMObs: Flushing {sent_count} events")

        if self.is_mock_mode:
            verbose_logger.debug(
                "[DATADOG MOCK] Mock mode enabled - API calls will be intercepted"
            )

        # Prepare the payload
        payload = {
            "data": DDIntakePayload(
                type="span",
                attributes=DDSpanAttributes(
                    ml_app=get_datadog_service(),
                    tags=[get_datadog_tags()],
                    spans=self.log_queue,
                ),
            ),
        }

        # serialize datetime objects - for budget reset time in spend metrics
        from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

        # Serialize once and reuse the string for both the debug log and
        # the request body.
        json_payload = safe_dumps(payload)
        verbose_logger.debug("payload %s", json_payload)

        headers = {"Content-Type": "application/json"}
        if self.DD_API_KEY:
            headers["DD-API-KEY"] = self.DD_API_KEY

        response = await self.async_client.post(
            url=self.intake_url,
            content=json_payload,
            headers=headers,
        )

        if response.status_code != 202:
            raise Exception(
                f"DataDogLLMObs: Unexpected response - status_code: {response.status_code}, text: {response.text}"
            )

        if self.is_mock_mode:
            verbose_logger.debug(
                f"[DATADOG MOCK] Batch of {sent_count} events successfully mocked"
            )
        else:
            verbose_logger.debug(
                f"DataDogLLMObs: Successfully sent batch - status_code: {response.status_code}"
            )

        # Drop only what was sent; clear() would also discard spans that
        # were queued while the request was in flight.
        del self.log_queue[:sent_count]
    except httpx.HTTPStatusError as e:
        verbose_logger.exception(
            f"DataDogLLMObs: Error sending batch - {e.response.text}"
        )
    except Exception as e:
        verbose_logger.exception(f"DataDogLLMObs: Error sending batch - {str(e)}")
|
||||
|
||||
def create_llm_obs_payload(
    self, kwargs: Dict, start_time: datetime, end_time: datetime
) -> LLMObsPayload:
    """
    Build one LLM Observability span from a request's standard logging object.

    Args:
        kwargs (Dict): request kwargs; must contain `standard_logging_object`.
            `litellm_params.metadata` may supply `parent_id`, `span_id`
            and `name` for the span.
        start_time (datetime): start time of request
        end_time (datetime): end time of request

    Returns:
        LLMObsPayload: the span, with `apm_id` set when an APM trace is active.

    Raises:
        Exception: if `standard_logging_object` is not set in kwargs.
    """
    standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
        "standard_logging_object"
    )
    if standard_logging_payload is None:
        raise Exception("DataDogLLMObs: standard_logging_object is not set")

    messages = standard_logging_payload["messages"]
    messages = self._ensure_string_content(messages=messages)

    # `litellm_params` and its `metadata` can be missing, None or (in
    # principle) a non-dict. Normalize to a dict once so every lookup below
    # is safe; previously only the parent_id lookup was guarded.
    metadata = (kwargs.get("litellm_params") or {}).get("metadata") or {}
    if not isinstance(metadata, dict):
        metadata = {}

    input_meta = InputMeta(
        messages=handle_any_messages_to_chat_completion_str_messages_conversion(
            messages
        )
    )
    output_meta = OutputMeta(
        messages=self._get_response_messages(
            standard_logging_payload=standard_logging_payload,
            call_type=standard_logging_payload.get("call_type"),
        )
    )

    error_info = self._assemble_error_info(standard_logging_payload)

    metadata_parent_id: Optional[str] = metadata.get("parent_id")

    meta = Meta(
        kind=self._get_datadog_span_kind(
            standard_logging_payload.get("call_type"), metadata_parent_id
        ),
        input=input_meta,
        output=output_meta,
        metadata=self._get_dd_llm_obs_payload_metadata(standard_logging_payload),
        error=error_info,
    )

    # `or 0` covers fields that are present but None (float(None) raises).
    metrics = LLMMetrics(
        input_tokens=float(standard_logging_payload.get("prompt_tokens") or 0),
        output_tokens=float(
            standard_logging_payload.get("completion_tokens") or 0
        ),
        total_tokens=float(standard_logging_payload.get("total_tokens") or 0),
        total_cost=float(standard_logging_payload.get("response_cost") or 0),
        time_to_first_token=self._get_time_to_first_token_seconds(
            standard_logging_payload
        ),
    )

    payload: LLMObsPayload = LLMObsPayload(
        parent_id=metadata_parent_id if metadata_parent_id else "undefined",
        trace_id=standard_logging_payload.get("trace_id", str(uuid.uuid4())),
        span_id=metadata.get("span_id", str(uuid.uuid4())),
        name=metadata.get("name", "litellm_llm_call"),
        meta=meta,
        start_ns=int(start_time.timestamp() * 1e9),
        duration=int((end_time - start_time).total_seconds() * 1e9),
        metrics=metrics,
        status="error" if error_info else "ok",
        tags=[get_datadog_tags(standard_logging_object=standard_logging_payload)],
    )

    apm_trace_id = self._get_apm_trace_id()
    if apm_trace_id is not None:
        payload["apm_id"] = apm_trace_id

    return payload
|
||||
|
||||
def _get_apm_trace_id(self) -> Optional[str]:
|
||||
"""Retrieve the current APM trace ID if available."""
|
||||
try:
|
||||
current_span_fn = getattr(tracer, "current_span", None)
|
||||
if callable(current_span_fn):
|
||||
current_span = current_span_fn()
|
||||
if current_span is not None:
|
||||
trace_id = getattr(current_span, "trace_id", None)
|
||||
if trace_id is not None:
|
||||
return str(trace_id)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _assemble_error_info(
|
||||
self, standard_logging_payload: StandardLoggingPayload
|
||||
) -> Optional[DDLLMObsError]:
|
||||
"""
|
||||
Assemble error information for failure cases according to DD LLM Obs API spec
|
||||
"""
|
||||
# Handle error information for failure cases according to DD LLM Obs API spec
|
||||
error_info: Optional[DDLLMObsError] = None
|
||||
|
||||
if standard_logging_payload.get("status") == "failure":
|
||||
# Try to get structured error information first
|
||||
error_information: Optional[
|
||||
StandardLoggingPayloadErrorInformation
|
||||
] = standard_logging_payload.get("error_information")
|
||||
|
||||
if error_information:
|
||||
error_info = DDLLMObsError(
|
||||
message=error_information.get("error_message")
|
||||
or standard_logging_payload.get("error_str")
|
||||
or "Unknown error",
|
||||
type=error_information.get("error_class"),
|
||||
stack=error_information.get("traceback"),
|
||||
)
|
||||
return error_info
|
||||
|
||||
def _get_time_to_first_token_seconds(
|
||||
self, standard_logging_payload: StandardLoggingPayload
|
||||
) -> float:
|
||||
"""
|
||||
Get the time to first token in seconds
|
||||
|
||||
CompletionStartTime - StartTime = Time to first token
|
||||
|
||||
For non streaming calls, CompletionStartTime is time we get the response back
|
||||
"""
|
||||
start_time: Optional[float] = standard_logging_payload.get("startTime")
|
||||
completion_start_time: Optional[float] = standard_logging_payload.get(
|
||||
"completionStartTime"
|
||||
)
|
||||
end_time: Optional[float] = standard_logging_payload.get("endTime")
|
||||
|
||||
if completion_start_time is not None and start_time is not None:
|
||||
return completion_start_time - start_time
|
||||
elif end_time is not None and start_time is not None:
|
||||
return end_time - start_time
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def _get_response_messages(
|
||||
self, standard_logging_payload: StandardLoggingPayload, call_type: Optional[str]
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Get the messages from the response object
|
||||
|
||||
for now this handles logging /chat/completions responses
|
||||
"""
|
||||
|
||||
response_obj = standard_logging_payload.get("response")
|
||||
if response_obj is None:
|
||||
return []
|
||||
|
||||
# edge case: handle response_obj is a string representation of a dict
|
||||
if isinstance(response_obj, str):
|
||||
try:
|
||||
import ast
|
||||
|
||||
response_obj = ast.literal_eval(response_obj)
|
||||
except (ValueError, SyntaxError):
|
||||
try:
|
||||
# fallback to json parsing
|
||||
response_obj = json.loads(str(response_obj))
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
if call_type in [
|
||||
CallTypes.completion.value,
|
||||
CallTypes.acompletion.value,
|
||||
CallTypes.text_completion.value,
|
||||
CallTypes.atext_completion.value,
|
||||
CallTypes.generate_content.value,
|
||||
CallTypes.agenerate_content.value,
|
||||
CallTypes.generate_content_stream.value,
|
||||
CallTypes.agenerate_content_stream.value,
|
||||
CallTypes.anthropic_messages.value,
|
||||
]:
|
||||
try:
|
||||
# Safely extract message from response_obj, handle failure cases
|
||||
if isinstance(response_obj, dict) and "choices" in response_obj:
|
||||
choices = response_obj["choices"]
|
||||
if choices and len(choices) > 0 and "message" in choices[0]:
|
||||
return [choices[0]["message"]]
|
||||
return []
|
||||
except (KeyError, IndexError, TypeError):
|
||||
# In case of any error accessing the response structure, return empty list
|
||||
return []
|
||||
return []
|
||||
|
||||
def _get_datadog_span_kind(
|
||||
self, call_type: Optional[str], parent_id: Optional[str] = None
|
||||
) -> Literal["llm", "tool", "task", "embedding", "retrieval"]:
|
||||
"""
|
||||
Map liteLLM call_type to appropriate DataDog LLM Observability span kind.
|
||||
|
||||
Available DataDog span kinds: "llm", "tool", "task", "embedding", "retrieval"
|
||||
see: https://docs.datadoghq.com/ja/llm_observability/terms/
|
||||
"""
|
||||
# Non llm/workflow/agent kinds cannot be root spans, so fallback to "llm" when parent metadata is missing
|
||||
if call_type is None or parent_id is None:
|
||||
return "llm"
|
||||
|
||||
# Embedding operations
|
||||
if call_type in [CallTypes.embedding.value, CallTypes.aembedding.value]:
|
||||
return "embedding"
|
||||
|
||||
# LLM completion operations
|
||||
if call_type in [
|
||||
CallTypes.completion.value,
|
||||
CallTypes.acompletion.value,
|
||||
CallTypes.text_completion.value,
|
||||
CallTypes.atext_completion.value,
|
||||
CallTypes.generate_content.value,
|
||||
CallTypes.agenerate_content.value,
|
||||
CallTypes.generate_content_stream.value,
|
||||
CallTypes.agenerate_content_stream.value,
|
||||
CallTypes.anthropic_messages.value,
|
||||
CallTypes.responses.value,
|
||||
CallTypes.aresponses.value,
|
||||
]:
|
||||
return "llm"
|
||||
|
||||
# Tool operations
|
||||
if call_type in [CallTypes.call_mcp_tool.value]:
|
||||
return "tool"
|
||||
|
||||
# Retrieval operations
|
||||
if call_type in [
|
||||
CallTypes.get_assistants.value,
|
||||
CallTypes.aget_assistants.value,
|
||||
CallTypes.get_thread.value,
|
||||
CallTypes.aget_thread.value,
|
||||
CallTypes.get_messages.value,
|
||||
CallTypes.aget_messages.value,
|
||||
CallTypes.afile_retrieve.value,
|
||||
CallTypes.file_retrieve.value,
|
||||
CallTypes.afile_list.value,
|
||||
CallTypes.file_list.value,
|
||||
CallTypes.afile_content.value,
|
||||
CallTypes.file_content.value,
|
||||
CallTypes.retrieve_batch.value,
|
||||
CallTypes.aretrieve_batch.value,
|
||||
CallTypes.retrieve_fine_tuning_job.value,
|
||||
CallTypes.aretrieve_fine_tuning_job.value,
|
||||
CallTypes.alist_input_items.value,
|
||||
]:
|
||||
return "retrieval"
|
||||
|
||||
# Task operations (batch, fine-tuning, file operations, etc.)
|
||||
if call_type in [
|
||||
CallTypes.create_batch.value,
|
||||
CallTypes.acreate_batch.value,
|
||||
CallTypes.create_fine_tuning_job.value,
|
||||
CallTypes.acreate_fine_tuning_job.value,
|
||||
CallTypes.cancel_fine_tuning_job.value,
|
||||
CallTypes.acancel_fine_tuning_job.value,
|
||||
CallTypes.list_fine_tuning_jobs.value,
|
||||
CallTypes.alist_fine_tuning_jobs.value,
|
||||
CallTypes.create_assistants.value,
|
||||
CallTypes.acreate_assistants.value,
|
||||
CallTypes.delete_assistant.value,
|
||||
CallTypes.adelete_assistant.value,
|
||||
CallTypes.create_thread.value,
|
||||
CallTypes.acreate_thread.value,
|
||||
CallTypes.add_message.value,
|
||||
CallTypes.a_add_message.value,
|
||||
CallTypes.run_thread.value,
|
||||
CallTypes.arun_thread.value,
|
||||
CallTypes.run_thread_stream.value,
|
||||
CallTypes.arun_thread_stream.value,
|
||||
CallTypes.file_delete.value,
|
||||
CallTypes.afile_delete.value,
|
||||
CallTypes.create_file.value,
|
||||
CallTypes.acreate_file.value,
|
||||
CallTypes.image_generation.value,
|
||||
CallTypes.aimage_generation.value,
|
||||
CallTypes.image_edit.value,
|
||||
CallTypes.aimage_edit.value,
|
||||
CallTypes.moderation.value,
|
||||
CallTypes.amoderation.value,
|
||||
CallTypes.transcription.value,
|
||||
CallTypes.atranscription.value,
|
||||
CallTypes.speech.value,
|
||||
CallTypes.aspeech.value,
|
||||
CallTypes.rerank.value,
|
||||
CallTypes.arerank.value,
|
||||
]:
|
||||
return "task"
|
||||
|
||||
# Default fallback for unknown or passthrough operations
|
||||
return "llm"
|
||||
|
||||
def _ensure_string_content(
|
||||
self, messages: Optional[Union[str, List[Any], Dict[Any, Any]]]
|
||||
) -> List[Any]:
|
||||
if messages is None:
|
||||
return []
|
||||
if isinstance(messages, str):
|
||||
return [messages]
|
||||
elif isinstance(messages, list):
|
||||
return [message for message in messages]
|
||||
elif isinstance(messages, dict):
|
||||
return [str(messages.get("content", ""))]
|
||||
return []
|
||||
|
||||
def _get_dd_llm_obs_payload_metadata(
    self, standard_logging_payload: StandardLoggingPayload
) -> Dict[str, Any]:
    """
    Fields to track in DD LLM Observability metadata from litellm standard logging payload

    The result combines, in this order: a fixed set of request fields, the
    latency metrics, the spend metrics, flattened tool-call key/value pairs,
    and finally every key of the payload's own "metadata" dict. Later
    updates win, so a key in the payload metadata overrides a computed key
    of the same name.
    """
    _metadata: Dict[str, Any] = {
        "model_name": standard_logging_payload.get("model", "unknown"),
        "model_provider": standard_logging_payload.get(
            "custom_llm_provider", "unknown"
        ),
        "id": standard_logging_payload.get("id", "unknown"),
        "trace_id": standard_logging_payload.get("trace_id", "unknown"),
        "cache_hit": standard_logging_payload.get("cache_hit", "unknown"),
        "cache_key": standard_logging_payload.get("cache_key", "unknown"),
        "saved_cache_cost": standard_logging_payload.get("saved_cache_cost", 0),
        "guardrail_information": standard_logging_payload.get(
            "guardrail_information", None
        ),
        "is_streamed_request": self._get_stream_value_from_payload(
            standard_logging_payload
        ),
    }

    # Add latency metrics to metadata
    latency_metrics = self._get_latency_metrics(standard_logging_payload)
    _metadata["latency_metrics"] = dict(latency_metrics)

    # Add spend metrics to metadata
    spend_metrics = self._get_spend_metrics(standard_logging_payload)
    _metadata["spend_metrics"] = dict(spend_metrics)

    # Extract tool calls and add to metadata
    _metadata.update(self._extract_tool_call_metadata(standard_logging_payload))

    # `.get("metadata", {})` still returns None when the key is present with a
    # None value, and dict(None) raises TypeError; coerce before copying.
    _standard_logging_metadata: dict = dict(
        standard_logging_payload.get("metadata") or {}
    )
    _metadata.update(_standard_logging_metadata)
    return _metadata
|
||||
|
||||
def _get_latency_metrics(
    self, standard_logging_payload: StandardLoggingPayload
) -> DDLLMObsLatencyMetrics:
    """
    Get the latency metrics from the standard logging payload

    Keys are only set when a value is available:
    - time_to_first_token_ms: time to first token, when it is > 0
    - litellm_overhead_time_ms: hidden_params["litellm_overhead_time_ms"]
    - guardrail_overhead_time_ms: sum of guardrail "duration" values, when > 0

    Durations that arrive in seconds are converted to milliseconds.
    """
    latency_metrics: DDLLMObsLatencyMetrics = DDLLMObsLatencyMetrics()

    # Time to first token (convert from seconds to milliseconds for consistency)
    time_to_first_token_seconds = self._get_time_to_first_token_seconds(
        standard_logging_payload
    )
    if time_to_first_token_seconds > 0:
        latency_metrics["time_to_first_token_ms"] = (
            time_to_first_token_seconds * 1000
        )

    # LiteLLM overhead time. `or {}` covers the key being present but None,
    # which a plain `.get(key, {})` default does not.
    hidden_params = standard_logging_payload.get("hidden_params") or {}
    litellm_overhead_ms = hidden_params.get("litellm_overhead_time_ms")
    if litellm_overhead_ms is not None:
        latency_metrics["litellm_overhead_time_ms"] = litellm_overhead_ms

    # Guardrail overhead latency
    raw_guardrail_info = standard_logging_payload.get("guardrail_information")
    if isinstance(raw_guardrail_info, dict):
        # Tolerate a single guardrail entry supplied without the enclosing list;
        # iterating the dict directly would walk its keys and fail on `.get`.
        guardrail_entries = [raw_guardrail_info]
    else:
        guardrail_entries = raw_guardrail_info or []

    total_duration = 0.0
    for info in guardrail_entries:
        if not isinstance(info, dict):
            continue
        _guardrail_duration_seconds: Optional[float] = info.get("duration")
        if _guardrail_duration_seconds is not None:
            total_duration += float(_guardrail_duration_seconds)

    if total_duration > 0:
        # Convert from seconds to milliseconds for consistency
        latency_metrics["guardrail_overhead_time_ms"] = total_duration * 1000

    return latency_metrics
|
||||
|
||||
def _get_stream_value_from_payload(
|
||||
self, standard_logging_payload: StandardLoggingPayload
|
||||
) -> bool:
|
||||
"""
|
||||
Extract the stream value from standard logging payload.
|
||||
|
||||
The stream field in StandardLoggingPayload is only set to True for completed streaming responses.
|
||||
For non-streaming requests, it's None. The original stream parameter is in model_parameters.
|
||||
|
||||
Returns:
|
||||
bool: True if this was a streaming request, False otherwise
|
||||
"""
|
||||
# Check top-level stream field first (only True for completed streaming)
|
||||
stream_value = standard_logging_payload.get("stream")
|
||||
if stream_value is True:
|
||||
return True
|
||||
|
||||
# Fallback to model_parameters.stream for original request parameters
|
||||
model_params = standard_logging_payload.get("model_parameters", {})
|
||||
if isinstance(model_params, dict):
|
||||
stream_value = model_params.get("stream")
|
||||
if stream_value is True:
|
||||
return True
|
||||
|
||||
# Default to False for non-streaming requests
|
||||
return False
|
||||
|
||||
def _get_spend_metrics(
    self, standard_logging_payload: StandardLoggingPayload
) -> DDLLMObsSpendMetrics:
    """
    Get the spend metrics from the standard logging payload

    Always sets "response_cost". From the payload metadata it additionally
    sets, when present and convertible:
    - user_api_key_max_budget (float)
    - user_api_key_spend (float)
    - user_api_key_budget_reset_at (ISO-8601 string; naive datetimes are
      stamped as UTC)

    Values that cannot be converted are skipped with a debug log rather than
    raising.
    """
    spend_metrics: DDLLMObsSpendMetrics = DDLLMObsSpendMetrics()

    # send response cost
    spend_metrics["response_cost"] = standard_logging_payload.get(
        "response_cost", 0.0
    )

    # Budget information lives in metadata; `or {}` covers an explicit None.
    metadata = standard_logging_payload.get("metadata") or {}

    # API key max budget (guarded like the spend conversion below)
    user_api_key_max_budget = metadata.get("user_api_key_max_budget")
    if user_api_key_max_budget is not None:
        try:
            spend_metrics["user_api_key_max_budget"] = float(user_api_key_max_budget)
        except (ValueError, TypeError):
            verbose_logger.debug(
                f"Invalid user_api_key_max_budget value: {user_api_key_max_budget}"
            )

    # API key spend
    user_api_key_spend = metadata.get("user_api_key_spend")
    if user_api_key_spend is not None:
        try:
            spend_metrics["user_api_key_spend"] = float(user_api_key_spend)
        except (ValueError, TypeError):
            verbose_logger.debug(
                f"Invalid user_api_key_spend value: {user_api_key_spend}"
            )

    # API key budget reset datetime
    user_api_key_budget_reset_at = metadata.get("user_api_key_budget_reset_at")
    if user_api_key_budget_reset_at is not None:
        try:
            from datetime import datetime, timezone

            budget_reset_at = None
            if isinstance(user_api_key_budget_reset_at, str):
                # Handle ISO format strings that might have 'Z' suffix
                iso_string = user_api_key_budget_reset_at.replace("Z", "+00:00")
                budget_reset_at = datetime.fromisoformat(iso_string)
            elif isinstance(user_api_key_budget_reset_at, datetime):
                budget_reset_at = user_api_key_budget_reset_at

            if budget_reset_at is not None:
                # Preserve timezone info if already present
                if budget_reset_at.tzinfo is None:
                    budget_reset_at = budget_reset_at.replace(tzinfo=timezone.utc)

                # Store as an ISO string so the value is JSON-serialisable
                iso_string = budget_reset_at.isoformat()
                spend_metrics["user_api_key_budget_reset_at"] = iso_string

                verbose_logger.debug(
                    f"Converted budget_reset_at to ISO format: {iso_string}"
                )
        except Exception as e:
            verbose_logger.debug(f"Error processing budget reset datetime: {e}")
            verbose_logger.debug(f"Original value: {user_api_key_budget_reset_at}")

    return spend_metrics
|
||||
|
||||
def _process_input_messages_preserving_tool_calls(
|
||||
self, messages: List[Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Process input messages while preserving tool_calls and tool message types.
|
||||
|
||||
This bypasses the lossy string conversion when tool calls are present,
|
||||
allowing complex nested tool_calls objects to be preserved for Datadog.
|
||||
"""
|
||||
processed = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict):
|
||||
# Preserve messages with tool_calls or tool role as-is
|
||||
if "tool_calls" in msg or msg.get("role") == "tool":
|
||||
processed.append(msg)
|
||||
else:
|
||||
# For regular messages, still apply string conversion
|
||||
converted = (
|
||||
handle_any_messages_to_chat_completion_str_messages_conversion(
|
||||
[msg]
|
||||
)
|
||||
)
|
||||
processed.extend(converted)
|
||||
else:
|
||||
# For non-dict messages, apply string conversion
|
||||
converted = (
|
||||
handle_any_messages_to_chat_completion_str_messages_conversion(
|
||||
[msg]
|
||||
)
|
||||
)
|
||||
processed.extend(converted)
|
||||
return processed
|
||||
|
||||
@staticmethod
|
||||
def _tool_calls_kv_pair(tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract tool call information into key-value pairs for Datadog metadata.
|
||||
|
||||
Similar to OpenTelemetry's implementation but adapted for Datadog's format.
|
||||
"""
|
||||
kv_pairs: Dict[str, Any] = {}
|
||||
for idx, tool_call in enumerate(tool_calls):
|
||||
try:
|
||||
# Extract tool call ID
|
||||
tool_id = tool_call.get("id")
|
||||
if tool_id:
|
||||
kv_pairs[f"tool_calls.{idx}.id"] = tool_id
|
||||
|
||||
# Extract tool call type
|
||||
tool_type = tool_call.get("type")
|
||||
if tool_type:
|
||||
kv_pairs[f"tool_calls.{idx}.type"] = tool_type
|
||||
|
||||
# Extract function information
|
||||
function = tool_call.get("function")
|
||||
if function:
|
||||
function_name = function.get("name")
|
||||
if function_name:
|
||||
kv_pairs[f"tool_calls.{idx}.function.name"] = function_name
|
||||
|
||||
function_arguments = function.get("arguments")
|
||||
if function_arguments:
|
||||
# Store arguments as JSON string for Datadog
|
||||
if isinstance(function_arguments, str):
|
||||
kv_pairs[
|
||||
f"tool_calls.{idx}.function.arguments"
|
||||
] = function_arguments
|
||||
else:
|
||||
import json
|
||||
|
||||
kv_pairs[
|
||||
f"tool_calls.{idx}.function.arguments"
|
||||
] = json.dumps(function_arguments)
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
verbose_logger.debug(
|
||||
f"DataDogLLMObs: Error processing tool call {idx}: {str(e)}"
|
||||
)
|
||||
continue
|
||||
|
||||
return kv_pairs
|
||||
|
||||
def _extract_tool_call_metadata(
|
||||
self, standard_logging_payload: StandardLoggingPayload
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract tool call information from both input messages and response for Datadog metadata.
|
||||
"""
|
||||
tool_call_metadata: Dict[str, Any] = {}
|
||||
|
||||
try:
|
||||
# Extract tool calls from input messages
|
||||
messages = standard_logging_payload.get("messages", [])
|
||||
if messages and isinstance(messages, list):
|
||||
for message in messages:
|
||||
if isinstance(message, dict) and "tool_calls" in message:
|
||||
tool_calls = message.get("tool_calls")
|
||||
if tool_calls:
|
||||
input_tool_calls_kv = self._tool_calls_kv_pair(tool_calls)
|
||||
# Prefix with "input_" to distinguish from response tool calls
|
||||
for key, value in input_tool_calls_kv.items():
|
||||
tool_call_metadata[f"input_{key}"] = value
|
||||
|
||||
# Extract tool calls from response
|
||||
response_obj = standard_logging_payload.get("response")
|
||||
if response_obj and isinstance(response_obj, dict):
|
||||
choices = response_obj.get("choices", [])
|
||||
for choice in choices:
|
||||
if isinstance(choice, dict):
|
||||
message = choice.get("message")
|
||||
if message and isinstance(message, dict):
|
||||
tool_calls = message.get("tool_calls")
|
||||
if tool_calls:
|
||||
response_tool_calls_kv = self._tool_calls_kv_pair(
|
||||
tool_calls
|
||||
)
|
||||
# Prefix with "output_" to distinguish from input tool calls
|
||||
for key, value in response_tool_calls_kv.items():
|
||||
tool_call_metadata[f"output_{key}"] = value
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"DataDogLLMObs: Error extracting tool call metadata: {str(e)}"
|
||||
)
|
||||
|
||||
return tool_call_metadata
|
||||
@@ -0,0 +1,286 @@
|
||||
import asyncio
|
||||
import gzip
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.integrations.datadog.datadog_handler import (
|
||||
get_datadog_env,
|
||||
get_datadog_hostname,
|
||||
get_datadog_pod_name,
|
||||
get_datadog_service,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
from litellm.types.integrations.datadog_metrics import (
|
||||
DatadogMetricPoint,
|
||||
DatadogMetricSeries,
|
||||
DatadogMetricsPayload,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class DatadogMetricsLogger(CustomBatchLogger):
    """
    Batch logger that sends request latency and request-count metrics to the
    Datadog v2 series endpoint.

    Configuration is read from the environment: DD_API_KEY (required for any
    upload to happen), DD_APP_KEY (optional) and DD_SITE (defaults to
    "datadoghq.com"). Metric series are queued in ``self.log_queue`` and
    uploaded gzip-compressed by ``async_send_batch``.
    """

    def __init__(self, start_periodic_flush: bool = True, **kwargs):
        """
        Args:
            start_periodic_flush: when True, schedule ``self.periodic_flush()``
                with ``asyncio.create_task`` at the end of construction.
            **kwargs: forwarded to ``CustomBatchLogger.__init__``. ``flush_lock``
                and ``flush_interval`` (5) are filled in only when the caller
                did not supply them.
        """
        self.dd_api_key = os.getenv("DD_API_KEY")
        self.dd_app_key = os.getenv("DD_APP_KEY")
        self.dd_site = os.getenv("DD_SITE", "datadoghq.com")

        if not self.dd_api_key:
            verbose_logger.warning(
                "Datadog Metrics: DD_API_KEY is required. Integration will not work."
            )

        self.upload_url = f"https://api.{self.dd_site}/api/v2/series"

        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        # Initialize lock
        self.flush_lock = asyncio.Lock()

        # Only set flush_lock if not already provided by caller
        if "flush_lock" not in kwargs:
            kwargs["flush_lock"] = self.flush_lock

        # Send metrics more quickly to datadog (every 5 seconds)
        if "flush_interval" not in kwargs:
            kwargs["flush_interval"] = 5

        super().__init__(**kwargs)

        # Start periodic flush task only if instructed
        if start_periodic_flush:
            asyncio.create_task(self.periodic_flush())

    def _extract_tags(
        self,
        log: StandardLoggingPayload,
        status_code: Optional[Union[str, int]] = None,
    ) -> List[str]:
        """
        Builds the list of tags for a Datadog metric point

        Base tags (env, service, version, HOSTNAME, POD_NAME) are always
        present. provider, model_name, model_group, status_code and team are
        appended only when a value is available. The team tag uses the first
        truthy value among user_api_key_team_alias, team_alias,
        user_api_key_team_id and team_id from the log metadata.
        """
        # Base tags
        tags = [
            f"env:{get_datadog_env()}",
            f"service:{get_datadog_service()}",
            f"version:{os.getenv('DD_VERSION', 'unknown')}",
            f"HOSTNAME:{get_datadog_hostname()}",
            f"POD_NAME:{get_datadog_pod_name()}",
        ]

        # Add metric-specific tags
        if provider := log.get("custom_llm_provider"):
            tags.append(f"provider:{provider}")

        if model := log.get("model"):
            tags.append(f"model_name:{model}")

        if model_group := log.get("model_group"):
            tags.append(f"model_group:{model_group}")

        if status_code is not None:
            tags.append(f"status_code:{status_code}")

        # Extract team tag
        metadata = log.get("metadata", {}) or {}
        team_tag = (
            metadata.get("user_api_key_team_alias")
            or metadata.get("team_alias")  # type: ignore
            or metadata.get("user_api_key_team_id")
            or metadata.get("team_id")  # type: ignore
        )

        if team_tag:
            tags.append(f"team:{team_tag}")

        return tags

    def _add_metrics_from_log(
        self,
        log: StandardLoggingPayload,
        kwargs: dict,
        status_code: Union[str, int] = "200",
    ):
        """
        Extracts latencies and appends Datadog metric series to the queue

        Up to three series are queued, all sharing the same tags and the same
        point timestamp (kwargs["end_time"], or now when it is missing):
        - litellm.request.total_latency (gauge): end_time - start_time
        - litellm.llm_api.latency (gauge): end_time - api_call_start_time
        - litellm.llm_api.request_count (count): always 1.0

        The two latency series are skipped when their start timestamp is
        absent from kwargs.
        """
        tags = self._extract_tags(log, status_code=status_code)

        # We record metrics with the end_time as the timestamp for the point
        end_time_dt = kwargs.get("end_time") or datetime.now()
        timestamp = int(end_time_dt.timestamp())

        # 1. Total Request Latency Metric (End to End)
        start_time_dt = kwargs.get("start_time")
        if start_time_dt and end_time_dt:
            total_duration = (end_time_dt - start_time_dt).total_seconds()
            series_total_latency: DatadogMetricSeries = {
                "metric": "litellm.request.total_latency",
                "type": 3,  # gauge
                "points": [{"timestamp": timestamp, "value": total_duration}],
                "tags": tags,
            }
            self.log_queue.append(series_total_latency)

        # 2. LLM API Latency Metric (Provider alone)
        api_call_start_time = kwargs.get("api_call_start_time")
        if api_call_start_time and end_time_dt:
            llm_api_duration = (end_time_dt - api_call_start_time).total_seconds()
            series_llm_latency: DatadogMetricSeries = {
                "metric": "litellm.llm_api.latency",
                "type": 3,  # gauge
                "points": [{"timestamp": timestamp, "value": llm_api_duration}],
                "tags": tags,
            }
            self.log_queue.append(series_llm_latency)

        # 3. Request Count / Status Code
        series_count: DatadogMetricSeries = {
            "metric": "litellm.llm_api.request_count",
            "type": 1,  # count
            "points": [{"timestamp": timestamp, "value": 1.0}],
            "tags": tags,
            "interval": self.flush_interval,
        }
        self.log_queue.append(series_count)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Queue metrics for a successful request (status_code "200") and flush
        when the queue has reached ``self.batch_size``.

        Timestamps are read from ``kwargs``; nothing is queued when kwargs has
        no "standard_logging_object". Errors are logged, never raised.
        """
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )

            if standard_logging_object is None:
                return

            self._add_metrics_from_log(
                log=standard_logging_object, kwargs=kwargs, status_code="200"
            )

            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()

        except Exception as e:
            verbose_logger.exception(
                f"Datadog Metrics: Error in async_log_success_event: {str(e)}"
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Queue metrics for a failed request and flush when the queue has
        reached ``self.batch_size``.

        The status_code tag is error_information["error_code"] as a string,
        or "500" when no error code is available. Errors are logged, never
        raised.
        """
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )

            if standard_logging_object is None:
                return

            # Extract status code from error information
            status_code = "500"  # default
            error_information = (
                standard_logging_object.get("error_information", {}) or {}
            )
            error_code = error_information.get("error_code")  # type: ignore
            if error_code is not None:
                status_code = str(error_code)

            self._add_metrics_from_log(
                log=standard_logging_object, kwargs=kwargs, status_code=status_code
            )

            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()

        except Exception as e:
            verbose_logger.exception(
                f"Datadog Metrics: Error in async_log_failure_event: {str(e)}"
            )

    async def async_send_batch(self):
        """
        Upload a copy of the current queue as one series payload.

        Does nothing when the queue is empty. Upload errors are logged and
        re-raised to the caller.
        """
        if not self.log_queue:
            return

        batch = self.log_queue.copy()
        payload_data: DatadogMetricsPayload = {"series": batch}

        try:
            await self._upload_to_datadog(payload_data)
        except Exception as e:
            verbose_logger.exception(
                f"Datadog Metrics: Error in async_send_batch: {str(e)}"
            )
            raise

    async def _upload_to_datadog(self, payload: DatadogMetricsPayload):
        """
        POST the payload to ``self.upload_url`` as gzip-compressed JSON.

        Returns without sending anything when no API key is configured.
        Raises whatever ``response.raise_for_status()`` raises on an error
        status.
        """
        if not self.dd_api_key:
            return

        headers = {
            "Content-Type": "application/json",
            "DD-API-KEY": self.dd_api_key,
        }

        if self.dd_app_key:
            headers["DD-APPLICATION-KEY"] = self.dd_app_key

        json_data = safe_dumps(payload)
        compressed_data = gzip.compress(json_data.encode("utf-8"))
        headers["Content-Encoding"] = "gzip"

        response = await self.async_client.post(
            self.upload_url, content=compressed_data, headers=headers  # type: ignore
        )

        response.raise_for_status()

        verbose_logger.debug(
            f"Datadog Metrics: Uploaded {len(payload['series'])} metric points. Status: {response.status_code}"
        )

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        """
        Check if the service is healthy

        Sends one "litellm.health_check" gauge point to Datadog. Reports
        "unhealthy" when the upload raises, and also when DD_API_KEY is not
        configured: in that case the upload helper returns without sending,
        so a "healthy" result would be misleading.
        """
        if not self.dd_api_key:
            return IntegrationHealthCheckStatus(
                status="unhealthy",
                error_message="Datadog Metrics: DD_API_KEY is not set.",
            )

        try:
            # Send a test metric point to Datadog
            test_metric_point: DatadogMetricPoint = {
                "timestamp": int(time.time()),
                "value": 1.0,
            }
            test_metric_series: DatadogMetricSeries = {
                "metric": "litellm.health_check",
                "type": 3,  # Gauge
                "points": [test_metric_point],
                "tags": ["env:health_check"],
            }

            payload_data: DatadogMetricsPayload = {"series": [test_metric_series]}

            await self._upload_to_datadog(payload_data)

            return IntegrationHealthCheckStatus(
                status="healthy",
                error_message=None,
            )
        except Exception as e:
            return IntegrationHealthCheckStatus(
                status="unhealthy",
                error_message=str(e),
            )

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """Not implemented for this logger; always returns None."""
        return None
|
||||
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Mock client for Datadog integration testing.
|
||||
|
||||
This module intercepts Datadog API calls and returns successful mock responses,
|
||||
allowing full code execution without making actual network calls.
|
||||
|
||||
Usage:
|
||||
Set DATADOG_MOCK=true in environment variables or config to enable mock mode.
|
||||
"""
|
||||
|
||||
from litellm.integrations.mock_client_factory import (
|
||||
MockClientConfig,
|
||||
create_mock_client_factory,
|
||||
)
|
||||
|
||||
# Settings handed to the shared mock-client factory for the Datadog integration.
# The mock is switched on through the DATADOG_MOCK environment variable and
# answers with HTTP 202 and a small JSON body.
_config = MockClientConfig(
    name="DATADOG",
    env_var="DATADOG_MOCK",
    url_matchers=[".datadoghq.com", "datadoghq.com"],
    default_status_code=202,
    default_json_data={"status": "ok"},
    default_latency_ms=100,
    patch_async_handler=True,
    patch_sync_client=True,
)

# The factory returns the pair (client creator, "is mock enabled" predicate).
(
    create_mock_datadog_client,
    should_use_datadog_mock,
) = create_mock_client_factory(_config)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .deepeval import DeepEvalLogger
|
||||
|
||||
__all__ = ["DeepEvalLogger"]
|
||||
@@ -0,0 +1,120 @@
|
||||
# duplicate -> https://github.com/confident-ai/deepeval/blob/main/deepeval/confident/api.py
|
||||
import logging
|
||||
import httpx
|
||||
from enum import Enum
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
# Confident AI / DeepEval service hosts. The *_EU names are presumably the
# EU-region counterparts (TODO confirm). Within the visible code only
# API_BASE_URL is used, as the default base URL of `Api`.
DEEPEVAL_BASE_URL = "https://deepeval.confident-ai.com"
DEEPEVAL_BASE_URL_EU = "https://eu.deepeval.confident-ai.com"
API_BASE_URL = "https://api.confident-ai.com"
API_BASE_URL_EU = "https://eu.api.confident-ai.com"
# Exception type treated as retryable; presumably consumed by retry logic
# elsewhere (not referenced in the visible part of this file).
retryable_exceptions = httpx.HTTPError
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
|
||||
def log_retry_error(details):
    """
    Log a retry attempt at ERROR level through the root ``logging`` module.

    ``details`` is a mapping read for "exception" and "tries". When an
    exception is present it is included in the message; otherwise only the
    attempt count is reported.
    """
    tries = details.get("tries")
    exception = details.get("exception")

    if exception:
        message = f"Confident AI Error: {exception}. Retrying: {tries} time(s)..."
    else:
        message = f"Retrying: {tries} time(s)..."
    logging.error(message)
|
||||
|
||||
|
||||
class HttpMethods(Enum):
    """HTTP verbs accepted by ``Api.send_request``; each value is the verb string."""

    GET = "GET"
    POST = "POST"
    DELETE = "DELETE"
    PUT = "PUT"
|
||||
|
||||
|
||||
class Endpoints(Enum):
    """URL paths of the Confident AI API; each value is appended to the base URL."""

    # Versioned (/v1) endpoints
    DATASET_ENDPOINT = "/v1/dataset"
    TEST_RUN_ENDPOINT = "/v1/test-run"
    TRACING_ENDPOINT = "/v1/tracing"
    EVENT_ENDPOINT = "/v1/event"
    FEEDBACK_ENDPOINT = "/v1/feedback"
    PROMPT_ENDPOINT = "/v1/prompt"
    RECOMMEND_ENDPOINT = "/v1/recommend-metrics"

    # Unversioned endpoints
    EVALUATE_ENDPOINT = "/evaluate"
    GUARD_ENDPOINT = "/guard"
    GUARDRAILS_ENDPOINT = "/guardrails"
    BASELINE_ATTACKS_ENDPOINT = "/generate-baseline-attacks"
|
||||
|
||||
|
||||
class Api:
    """Small POST-only HTTP client for the Confident AI (DeepEval) API.

    Holds the API key, the default request headers, and one synchronous and
    one asynchronous litellm HTTP handler.
    """

    def __init__(self, api_key: str, base_url=None):
        """Store credentials and create the HTTP handlers.

        Args:
            api_key: Confident AI API key, sent in the ``CONFIDENT_API_KEY`` header.
            base_url: Optional API base URL; defaults to ``API_BASE_URL``.
        """
        self.api_key = api_key
        self._headers = {
            "Content-Type": "application/json",
            # "User-Agent": "Python/Requests",
            "CONFIDENT_API_KEY": api_key,
        }
        # using the global non-eu variable for base url
        self.base_api_url = base_url or API_BASE_URL
        self.sync_http_handler = HTTPHandler()
        self.async_http_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

    def _http_request(
        self, method: str, url: str, headers=None, json=None, params=None
    ):
        """Issue a synchronous POST and return the handler's response.

        Args:
            method: HTTP verb as a string; only ``"POST"`` is accepted.
            url: Fully qualified request URL.
            headers: Optional request headers.
            json: Optional JSON-serialisable request body.
            params: Optional query parameters.

        Returns:
            The response object returned by the sync HTTP handler.

        Raises:
            Exception: If ``method`` is not ``"POST"`` or the handler raised
                ``httpx.HTTPStatusError`` (re-raised with the response text).
        """
        if method != "POST":
            raise Exception("Only POST requests are supported")
        try:
            # The response must be returned: ``send_request`` inspects it.
            return self.sync_http_handler.post(
                url=url,
                headers=headers,
                json=json,
                params=params,
            )
        except httpx.HTTPStatusError as e:
            raise Exception(f"DeepEval logging error: {e.response.text}") from e

    def send_request(
        self, method: "HttpMethods", endpoint: "Endpoints", body=None, params=None
    ):
        """Send a synchronous request and decode the reply.

        Args:
            method: HTTP verb enum member (only POST is supported downstream).
            endpoint: Endpoint enum member whose value is appended to the base URL.
            body: Optional JSON-serialisable request body.
            params: Optional query parameters.

        Returns:
            The decoded JSON body on HTTP 200, or the raw text when the body
            is not valid JSON.

        Raises:
            Exception: On any non-200 status, carrying the ``"error"`` field
                of the JSON body when present, otherwise the response text.
        """
        url = f"{self.base_api_url}{endpoint.value}"
        res = self._http_request(
            method=method.value,
            url=url,
            headers=self._headers,
            json=body,
            params=params,
        )

        if res.status_code == 200:
            try:
                return res.json()
            except ValueError:
                return res.text

        # Error bodies are not guaranteed to be JSON (or a JSON object), so
        # decode defensively instead of letting ``res.json()`` mask the error.
        try:
            payload = res.json()
        except ValueError:
            payload = None
        verbose_logger.debug(payload if payload is not None else res.text)
        if isinstance(payload, dict):
            raise Exception(payload.get("error", res.text))
        raise Exception(res.text)

    async def a_send_request(
        self, method: "HttpMethods", endpoint: "Endpoints", body=None, params=None
    ):
        """Send an asynchronous POST request.

        Args:
            method: HTTP verb enum member; must be ``HttpMethods.POST``.
            endpoint: Endpoint enum member whose value is appended to the base URL.
            body: Optional JSON-serialisable request body.
            params: Optional query parameters.

        Returns:
            The response object returned by the async HTTP handler.

        Raises:
            Exception: If ``method`` is not POST or the handler raised
                ``httpx.HTTPStatusError`` (re-raised with the response text).
        """
        if method != HttpMethods.POST:
            raise Exception("Only POST requests are supported")

        url = f"{self.base_api_url}{endpoint.value}"
        try:
            return await self.async_http_handler.post(
                url=url,
                headers=self._headers,
                json=body,
                params=params,
            )
        except httpx.HTTPStatusError as e:
            raise Exception(f"DeepEval logging error: {e.response.text}") from e
|
||||
@@ -0,0 +1,175 @@
|
||||
import os
|
||||
from litellm._uuid import uuid
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.deepeval.api import Api, Endpoints, HttpMethods
|
||||
from litellm.integrations.deepeval.types import (
|
||||
BaseApiSpan,
|
||||
SpanApiType,
|
||||
TraceApi,
|
||||
TraceSpanApiStatus,
|
||||
)
|
||||
from litellm.integrations.deepeval.utils import (
|
||||
to_zod_compatible_iso,
|
||||
validate_environment,
|
||||
)
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
# This file includes the custom callbacks for LiteLLM Proxy
|
||||
# Once defined, these can be passed in proxy_config.yaml
|
||||
class DeepEvalLogger(CustomLogger):
    """Logs litellm traces to DeepEval's platform.

    Every success/failure callback is turned into a trace holding a single
    LLM span and POSTed to the tracing endpoint.
    """

    def __init__(self, *args, **kwargs):
        """Read configuration from the environment and create the API client.

        Environment variables:
            CONFIDENT_API_KEY: Required API key.
            LITELLM_ENVIRONMENT: Optional environment label. The historical
                misspelling ``LITELM_ENVIRONMENT`` is still honoured as a
                fallback. Defaults to ``"development"``.

        Raises:
            ValueError: If the environment label is rejected by
                ``validate_environment`` or the API key is missing.
        """
        api_key = os.getenv("CONFIDENT_API_KEY")
        # Prefer the correctly spelled variable; keep the old typo working so
        # existing deployments are not silently reset to "development".
        self.litellm_environment = (
            os.getenv("LITELLM_ENVIRONMENT")
            or os.getenv("LITELM_ENVIRONMENT")
            or "development"
        )
        validate_environment(self.litellm_environment)
        if not api_key:
            raise ValueError(
                "Please set 'CONFIDENT_API_KEY=<>' in your environment variables."
            )
        self.api = Api(api_key=api_key)
        super().__init__(*args, **kwargs)

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Logs a success event to DeepEval's platform."""
        self._sync_event_handler(
            kwargs, response_obj, start_time, end_time, is_success=True
        )

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Logs a failure event to DeepEval's platform."""
        self._sync_event_handler(
            kwargs, response_obj, start_time, end_time, is_success=False
        )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Logs a failure event to DeepEval's platform."""
        await self._async_event_handler(
            kwargs, response_obj, start_time, end_time, is_success=False
        )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Logs a success event to DeepEval's platform."""
        await self._async_event_handler(
            kwargs, response_obj, start_time, end_time, is_success=True
        )

    def _prepare_trace_api(
        self, kwargs, response_obj, start_time, end_time, is_success
    ):
        """Build the JSON-ready request body for one callback invocation.

        Args:
            kwargs: Callback kwargs; ``"standard_logging_object"`` and
                ``"input"`` are read from it.
            response_obj: Unused here; kept for signature parity with callbacks.
            start_time: Call start, converted with ``to_zod_compatible_iso``.
            end_time: Call end, converted with ``to_zod_compatible_iso``.
            is_success: Whether the call succeeded.

        Returns:
            dict: The trace dumped by alias with ``None`` fields excluded.
        """
        _start_time = to_zod_compatible_iso(start_time)
        _end_time = to_zod_compatible_iso(end_time)
        # Work on a copy so the fallback id below never mutates the caller's
        # payload; ``or {}`` also covers an explicit ``None`` value.
        _standard_logging_object = dict(kwargs.get("standard_logging_object") or {})
        if not _standard_logging_object.get("trace_id"):
            # One shared fallback keeps the span's traceUuid and the trace's
            # uuid identical when the payload carries no trace id.
            _standard_logging_object["trace_id"] = str(uuid.uuid4())

        base_api_span = self._create_base_api_span(
            kwargs,
            standard_logging_object=_standard_logging_object,
            start_time=_start_time,
            end_time=_end_time,
            is_success=is_success,
        )
        trace_api = self._create_trace_api(
            base_api_span,
            standard_logging_object=_standard_logging_object,
            start_time=_start_time,
            end_time=_end_time,
            litellm_environment=self.litellm_environment,
        )

        try:
            return trace_api.model_dump(by_alias=True, exclude_none=True)
        except AttributeError:
            # Pydantic version below 2.0
            return trace_api.dict(by_alias=True, exclude_none=True)

    def _sync_event_handler(
        self, kwargs, response_obj, start_time, end_time, is_success
    ):
        """Build the trace body and send it synchronously; errors propagate."""
        body = self._prepare_trace_api(
            kwargs, response_obj, start_time, end_time, is_success
        )
        response = self.api.send_request(
            method=HttpMethods.POST,
            endpoint=Endpoints.TRACING_ENDPOINT,
            body=body,
        )
        # This handler serves both success and failure events, so the label
        # no longer claims "failure".
        verbose_logger.debug(
            "DeepEvalLogger: sync_event_handler: Api response %s", response
        )

    async def _async_event_handler(
        self, kwargs, response_obj, start_time, end_time, is_success
    ):
        """Build the trace body and send it asynchronously; errors propagate."""
        body = self._prepare_trace_api(
            kwargs, response_obj, start_time, end_time, is_success
        )
        response = await self.api.a_send_request(
            method=HttpMethods.POST,
            endpoint=Endpoints.TRACING_ENDPOINT,
            body=body,
        )

        verbose_logger.debug(
            "DeepEvalLogger: async_event_handler: Api response %s", response
        )

    def _create_base_api_span(
        self, kwargs, standard_logging_object, start_time, end_time, is_success
    ):
        """Create the single LLM span describing this call.

        On success the output is the first choice's message content and the
        token counts come from ``response["usage"]``; on failure the output is
        the payload's ``error_string`` and token counts are omitted.

        Returns:
            BaseApiSpan: Span of type ``SpanApiType.LLM``.
        """
        # NOTE(review): presumably ``response`` can be None or a plain string
        # (e.g. for failed calls) - confirm against the logging payload type.
        response = standard_logging_object.get("response")
        if not isinstance(response, dict):
            response = {}
        # extract usage
        usage = response.get("usage") or {}
        if is_success:
            # Guard against a missing or empty ``choices`` list.
            choices = response.get("choices") or [{}]
            message = choices[0].get("message") or {}
            output = message.get("content", "NO_OUTPUT")
        else:
            output = str(standard_logging_object.get("error_string", ""))
        return BaseApiSpan(
            # The model fields are ``str``; a raw UUID object must be stringified.
            uuid=str(standard_logging_object.get("id") or uuid.uuid4()),
            name=(
                "litellm_success_callback" if is_success else "litellm_failure_callback"
            ),
            status=(
                TraceSpanApiStatus.SUCCESS if is_success else TraceSpanApiStatus.ERRORED
            ),
            type=SpanApiType.LLM,
            traceUuid=str(standard_logging_object.get("trace_id") or uuid.uuid4()),
            startTime=str(start_time),
            endTime=str(end_time),
            input=kwargs.get("input", "NO_INPUT"),
            output=output,
            model=standard_logging_object.get("model", None),
            inputTokenCount=usage.get("prompt_tokens", None) if is_success else None,
            outputTokenCount=(
                usage.get("completion_tokens", None) if is_success else None
            ),
        )

    def _create_trace_api(
        self,
        base_api_span,
        standard_logging_object,
        start_time,
        end_time,
        litellm_environment,
    ):
        """Wrap ``base_api_span`` in a trace; all other span lists are empty.

        Returns:
            TraceApi: Trace whose ``llmSpans`` holds exactly ``base_api_span``.
        """
        return TraceApi(
            uuid=str(standard_logging_object.get("trace_id") or uuid.uuid4()),
            baseSpans=[],
            agentSpans=[],
            llmSpans=[base_api_span],
            retrieverSpans=[],
            toolSpans=[],
            startTime=str(start_time),
            endTime=str(end_time),
            environment=litellm_environment,
        )
|
||||
@@ -0,0 +1,63 @@
|
||||
# Duplicate -> https://github.com/confident-ai/deepeval/blob/main/deepeval/tracing/api.py
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Union, Literal
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class SpanApiType(Enum):
    """Span categories; each value is the lowercase name used in payloads."""

    BASE = "base"
    AGENT = "agent"
    LLM = "llm"
    RETRIEVER = "retriever"
    TOOL = "tool"


# Literal counterpart of the ``SpanApiType`` values, for type annotations.
span_api_type_literals = Literal[
    "base",
    "agent",
    "llm",
    "retriever",
    "tool",
]
|
||||
|
||||
|
||||
class TraceSpanApiStatus(Enum):
    """Outcome recorded on a span: the call either succeeded or errored."""

    SUCCESS = "SUCCESS"
    ERRORED = "ERRORED"
|
||||
|
||||
|
||||
class BaseApiSpan(BaseModel):
    """One span of a trace; camelCase aliases are the serialized field names."""

    # Enum-typed fields are stored as their plain values.
    model_config: ClassVar[ConfigDict] = ConfigDict(use_enum_values=True)

    # Identity and classification
    uuid: str
    name: Optional[str] = None
    status: TraceSpanApiStatus
    type: SpanApiType
    trace_uuid: str = Field(alias="traceUuid")
    parent_uuid: Optional[str] = Field(default=None, alias="parentUuid")

    # Timing (strings, as supplied by the caller)
    start_time: str = Field(alias="startTime")
    end_time: str = Field(alias="endTime")

    # Payload
    input: Optional[Union[Dict, list, str]] = None
    output: Optional[Union[Dict, list, str]] = None
    error: Optional[str] = None

    # llm
    model: Optional[str] = None
    input_token_count: Optional[int] = Field(default=None, alias="inputTokenCount")
    output_token_count: Optional[int] = Field(default=None, alias="outputTokenCount")
    cost_per_input_token: Optional[float] = Field(
        default=None, alias="costPerInputToken"
    )
    cost_per_output_token: Optional[float] = Field(
        default=None, alias="costPerOutputToken"
    )
|
||||
|
||||
|
||||
class TraceApi(BaseModel):
    """A complete trace: spans grouped by kind plus trace-level attributes."""

    uuid: str

    # Spans, one list per span kind (all required).
    base_spans: List[BaseApiSpan] = Field(alias="baseSpans")
    agent_spans: List[BaseApiSpan] = Field(alias="agentSpans")
    llm_spans: List[BaseApiSpan] = Field(alias="llmSpans")
    retriever_spans: List[BaseApiSpan] = Field(alias="retrieverSpans")
    tool_spans: List[BaseApiSpan] = Field(alias="toolSpans")

    # Timing (strings, as supplied by the caller)
    start_time: str = Field(alias="startTime")
    end_time: str = Field(alias="endTime")

    # Optional trace-level attributes
    metadata: Optional[Dict[str, Any]] = Field(default=None)
    tags: Optional[List[str]] = Field(default=None)
    environment: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class Environment(Enum):
    """Deployment environments accepted for a trace's ``environment`` label."""

    PRODUCTION = "production"
    DEVELOPMENT = "development"
    STAGING = "staging"
|
||||
@@ -0,0 +1,18 @@
|
||||
from datetime import datetime, timezone
|
||||
from litellm.integrations.deepeval.types import Environment
|
||||
|
||||
|
||||
def to_zod_compatible_iso(dt: datetime) -> str:
    """Format ``dt`` as a UTC ISO-8601 string with millisecond precision and a ``Z`` suffix.

    Args:
        dt: Datetime to format; it is converted to UTC first.

    Returns:
        A string such as ``"2024-01-02T03:04:05.678Z"``.
    """
    utc_stamp = dt.astimezone(timezone.utc).isoformat(timespec="milliseconds")
    # isoformat() renders UTC as "+00:00"; swap it for the "Z" designator.
    return utc_stamp.replace("+00:00", "Z")
|
||||
|
||||
|
||||
def validate_environment(environment: str):
    """Ensure ``environment`` is one of the ``Environment`` enum values.

    Args:
        environment: Environment label to check.

    Raises:
        ValueError: If the label is not a known environment; the message
            lists every accepted value.
    """
    accepted = [member.value for member in Environment]
    if environment in accepted:
        return

    valid_values = ", ".join(f'"{env.value}"' for env in Environment)
    raise ValueError(
        f"Invalid environment: {environment}. Please use one of the following instead: {valid_values}"
    )
|
||||
@@ -0,0 +1,316 @@
|
||||
# LiteLLM Dotprompt Manager
|
||||
|
||||
A powerful prompt management system for LiteLLM that supports [Google's Dotprompt specification](https://google.github.io/dotprompt/getting-started/). This allows you to manage your AI prompts in organized `.prompt` files with YAML frontmatter, Handlebars templating, and full integration with LiteLLM's completion API.
|
||||
|
||||
## Features
|
||||
|
||||
- **📁 File-based prompt management**: Organize prompts in `.prompt` files
|
||||
- **🎯 YAML frontmatter**: Define model, parameters, and schemas in file headers
|
||||
- **🔧 Handlebars templating**: Use `{{variable}}` syntax with Jinja2 backend
|
||||
- **✅ Input validation**: Automatic validation against defined schemas
|
||||
- **🔗 LiteLLM integration**: Works seamlessly with `litellm.completion()`
|
||||
- **💬 Smart message parsing**: Converts prompts to proper chat messages
|
||||
- **⚙️ Parameter extraction**: Automatically applies model settings from prompts
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Create a `.prompt` file
|
||||
|
||||
Create a file called `chat_assistant.prompt`:
|
||||
|
||||
```yaml
|
||||
---
|
||||
model: gpt-4
|
||||
temperature: 0.7
|
||||
max_tokens: 150
|
||||
input:
|
||||
schema:
|
||||
user_message: string
|
||||
system_context?: string
|
||||
---
|
||||
|
||||
{% if system_context %}System: {{system_context}}
|
||||
|
||||
{% endif %}User: {{user_message}}
|
||||
```
|
||||
|
||||
### 2. Use with LiteLLM
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
litellm.set_global_prompt_directory("path/to/your/prompts")
|
||||
|
||||
# Use with completion - the model prefix 'dotprompt/' tells LiteLLM to use prompt management
|
||||
response = litellm.completion(
|
||||
model="dotprompt/gpt-4", # The actual model comes from the .prompt file
|
||||
prompt_id="chat_assistant",
|
||||
prompt_variables={
|
||||
"user_message": "What is machine learning?",
|
||||
"system_context": "You are a helpful AI tutor."
|
||||
},
|
||||
# Any additional messages will be appended after the prompt
|
||||
messages=[{"role": "user", "content": "Please explain it simply."}]
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Prompt File Format
|
||||
|
||||
### Basic Structure
|
||||
|
||||
```yaml
|
||||
---
|
||||
# Model configuration
|
||||
model: gpt-4
|
||||
temperature: 0.7
|
||||
max_tokens: 500
|
||||
|
||||
# Input schema (optional)
|
||||
input:
|
||||
schema:
|
||||
name: string
|
||||
age: integer
|
||||
preferences?: array
|
||||
---
|
||||
|
||||
# Template content: Handlebars-style variables with Jinja2 control blocks (if / for)
|
||||
Hello {{name}}!
|
||||
|
||||
{% if age >= 18 %}
|
||||
You're an adult, so here are some mature recommendations:
|
||||
{% else %}
|
||||
Here are some age-appropriate suggestions:
|
||||
{% endif %}
|
||||
|
||||
{% for pref in preferences %}
|
||||
- Based on your interest in {{pref}}, I recommend...
|
||||
{% endfor %}
|
||||
```
|
||||
|
||||
### Supported Frontmatter Fields
|
||||
|
||||
- **`model`**: The LLM model to use (e.g., `gpt-4`, `claude-3-sonnet`)
|
||||
- **`input.schema`**: Define expected input variables and their types
|
||||
- **`output.format`**: Expected output format (`json`, `text`, etc.)
|
||||
- **`output.schema`**: Structure of expected output
|
||||
|
||||
### Additional Parameters
|
||||
|
||||
- **`temperature`**: Model temperature (0.0 to 1.0)
|
||||
- **`max_tokens`**: Maximum tokens to generate
|
||||
- **`top_p`**: Nucleus sampling parameter (0.0 to 1.0)
|
||||
- **`frequency_penalty`**: Frequency penalty (0.0 to 1.0)
|
||||
- **`presence_penalty`**: Presence penalty (0.0 to 1.0)
|
||||
- Any other frontmatter fields that are not model- or schema-related are passed to the model as optional parameters.
|
||||
|
||||
### Input Schema Types
|
||||
|
||||
- `string` or `str`: Text values
|
||||
- `integer` or `int`: Whole numbers
|
||||
- `float`: Decimal numbers
|
||||
- `boolean` or `bool`: True/false values
|
||||
- `array` or `list`: Lists of values
|
||||
- `object` or `dict`: Key-value objects
|
||||
|
||||
Use `?` suffix for optional fields: `name?: string`
|
||||
|
||||
## Message Format Conversion
|
||||
|
||||
The dotprompt manager intelligently converts your rendered prompts into proper chat messages:
|
||||
|
||||
### Simple Text → User Message
|
||||
```yaml
|
||||
---
|
||||
model: gpt-4
|
||||
---
|
||||
Tell me about {{topic}}.
|
||||
```
|
||||
Becomes: `[{"role": "user", "content": "Tell me about AI."}]`
|
||||
|
||||
### Role-Based Format → Multiple Messages
|
||||
```yaml
|
||||
---
|
||||
model: gpt-4
|
||||
---
|
||||
System: You are a {{role}}.
|
||||
|
||||
User: {{question}}
|
||||
```
|
||||
|
||||
Becomes:
|
||||
```python
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is AI?"}
|
||||
]
|
||||
```
|
||||
|
||||
|
||||
## Example Prompts
|
||||
|
||||
### Data Extraction
|
||||
```yaml
|
||||
# extract_info.prompt
|
||||
---
|
||||
model: gemini/gemini-1.5-pro
|
||||
input:
|
||||
schema:
|
||||
text: string
|
||||
output:
|
||||
format: json
|
||||
schema:
|
||||
title?: string
|
||||
summary: string
|
||||
tags: array
|
||||
---
|
||||
|
||||
Extract the requested information from the given text. Return JSON format.
|
||||
|
||||
Text: {{text}}
|
||||
```
|
||||
|
||||
### Code Assistant
|
||||
```yaml
|
||||
# code_helper.prompt
|
||||
---
|
||||
model: claude-3-5-sonnet-20241022
|
||||
temperature: 0.2
|
||||
max_tokens: 2000
|
||||
input:
|
||||
schema:
|
||||
language: string
|
||||
task: string
|
||||
code?: string
|
||||
---
|
||||
|
||||
You are an expert {{language}} programmer.
|
||||
|
||||
Task: {{task}}
|
||||
|
||||
{% if code %}
|
||||
Current code:
|
||||
```{{language}}
|
||||
{{code}}
|
||||
```
|
||||
{% endif %}
|
||||
|
||||
Please provide a complete, well-documented solution.
|
||||
```
|
||||
|
||||
### Multi-turn Conversation
|
||||
```yaml
|
||||
# conversation.prompt
|
||||
---
|
||||
model: gpt-4
|
||||
temperature: 0.8
|
||||
input:
|
||||
schema:
|
||||
personality: string
|
||||
context: string
|
||||
---
|
||||
|
||||
System: You are a {{personality}}. {{context}}
|
||||
|
||||
User: Let's start our conversation.
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### PromptManager
|
||||
|
||||
The core class for managing `.prompt` files.
|
||||
|
||||
#### Methods
|
||||
|
||||
- **`__init__(prompt_directory: str)`**: Initialize with directory path
|
||||
- **`render(prompt_id: str, variables: dict) -> str`**: Render prompt with variables
|
||||
- **`list_prompts() -> List[str]`**: Get all available prompt IDs
|
||||
- **`get_prompt(prompt_id: str) -> PromptTemplate`**: Get prompt template object
|
||||
- **`get_prompt_metadata(prompt_id: str) -> dict`**: Get prompt metadata
|
||||
- **`reload_prompts() -> None`**: Reload all prompts from directory
|
||||
- **`add_prompt(prompt_id: str, content: str, metadata: dict)`**: Add prompt programmatically
|
||||
|
||||
### DotpromptManager
|
||||
|
||||
LiteLLM integration class extending `PromptManagementBase`.
|
||||
|
||||
#### Methods
|
||||
|
||||
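- **`__init__(prompt_directory=None, prompt_file=None, prompt_data=None, prompt_id=None)`**: Initialize from a prompt directory, a single `.prompt` file, or in-memory prompt data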
|
||||
- **`should_run_prompt_management(prompt_id, prompt_spec, dynamic_callback_params) -> bool`**: Check if the prompt ID exists in the prompt manager
|
||||
- **`set_prompt_directory(directory: str)`**: Change prompt directory
|
||||
- **`reload_prompts()`**: Reload prompts from directory
|
||||
|
||||
### PromptTemplate
|
||||
|
||||
Represents a single prompt with metadata.
|
||||
|
||||
#### Properties
|
||||
|
||||
- **`content: str`**: The prompt template content
|
||||
- **`metadata: dict`**: Full metadata from frontmatter
|
||||
- **`model: str`**: Specified model name
|
||||
- **`temperature: float`**: Model temperature
|
||||
- **`max_tokens: int`**: Token limit
|
||||
- **`input_schema: dict`**: Input validation schema
|
||||
- **`output_format: str`**: Expected output format
|
||||
- **`output_schema: dict`**: Output structure schema
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Organize by purpose**: Group related prompts in subdirectories
|
||||
2. **Use descriptive names**: `extract_user_info.prompt` vs `prompt1.prompt`
|
||||
3. **Define schemas**: Always specify input schemas for validation
|
||||
4. **Version control**: Store `.prompt` files in git for change tracking
|
||||
5. **Test prompts**: Use the test framework to validate prompt behavior
|
||||
6. **Keep templates focused**: One prompt should do one thing well
|
||||
7. **Use includes**: Break complex prompts into reusable components
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Prompt not found**: Ensure the `.prompt` file exists and has correct extension
|
||||
```python
|
||||
# Check available prompts
|
||||
from litellm.integrations.dotprompt import get_dotprompt_manager
|
||||
manager = get_dotprompt_manager()
|
||||
print(manager.prompt_manager.list_prompts())
|
||||
```
|
||||
|
||||
**Template errors**: Verify Handlebars syntax and variable names
|
||||
```python
|
||||
# Test rendering directly
|
||||
manager.prompt_manager.render("my_prompt", {"test": "value"})
|
||||
```
|
||||
|
||||
**Model not working**: Check that model name in frontmatter is correct
|
||||
```python
|
||||
# Check prompt metadata
|
||||
metadata = manager.prompt_manager.get_prompt_metadata("my_prompt")
|
||||
print(metadata)
|
||||
```
|
||||
|
||||
### Validation Errors
|
||||
|
||||
Input validation failures show helpful error messages:
|
||||
```
|
||||
ValueError: Invalid type for field 'age': expected int, got str
|
||||
```
|
||||
|
||||
Make sure your variables match the defined schema types.
|
||||
|
||||
## Contributing
|
||||
|
||||
The LiteLLM Dotprompt manager follows the [Dotprompt specification](https://google.github.io/dotprompt/) for maximum compatibility. When contributing:
|
||||
|
||||
1. Ensure compatibility with existing `.prompt` files
|
||||
2. Add tests for new features
|
||||
3. Update documentation
|
||||
4. Follow the existing code style
|
||||
|
||||
## License
|
||||
|
||||
This prompt management system is part of LiteLLM and follows the same license terms.
|
||||
@@ -0,0 +1,91 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .prompt_manager import PromptManager, PromptTemplate
|
||||
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
|
||||
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
|
||||
|
||||
from .dotprompt_manager import DotpromptManager
|
||||
|
||||
# Global instances
|
||||
global_prompt_directory: Optional[str] = None
|
||||
global_prompt_manager: Optional["PromptManager"] = None
|
||||
|
||||
|
||||
def set_global_prompt_directory(directory: str) -> None:
    """
    Point litellm at the directory that holds the dotprompt files.

    The value is stored on the ``litellm`` package as
    ``global_prompt_directory``.

    Args:
        directory: Path to directory containing .prompt files
    """
    # Imported lazily, inside the function, as in the original code.
    import litellm as _litellm

    setattr(_litellm, "global_prompt_directory", directory)
|
||||
|
||||
|
||||
def _get_prompt_data_from_dotprompt_content(dotprompt_content: str) -> dict:
    """
    Split raw dotprompt text into the prompt-data dict used by the prompt manager.

    The UI stores prompts under `dotprompt_content` in the database; this
    parses that text with ``PromptManager._parse_frontmatter``.

    Returns:
        dict with ``"content"`` (template body, stripped of surrounding
        whitespace) and ``"metadata"`` (the parsed frontmatter).
    """
    from .prompt_manager import PromptManager

    # A throwaway manager is created only to reuse its frontmatter parser.
    frontmatter, body = PromptManager()._parse_frontmatter(dotprompt_content)
    return {"content": body.strip(), "metadata": frontmatter}
|
||||
|
||||
|
||||
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Initialize a prompt from a .prompt file.

    The prompt source is taken from ``litellm_params`` in this order:
    explicit ``prompt_data``/``prompt_file``, otherwise raw
    ``dotprompt_content`` (as stored in the database), which is parsed into
    prompt data.

    Args:
        litellm_params: Object whose ``prompt_directory``, ``prompt_data``,
            ``prompt_id``, ``prompt_file`` and ``dotprompt_content``
            attributes are read (missing attributes count as ``None``).
        prompt_spec: Not used by this initializer.

    Returns:
        A ``DotpromptManager`` configured for the single prompt.

    Raises:
        ValueError: If ``prompt_directory`` is set; this initializer handles
            one specific prompt, not a directory.
    """
    prompt_directory = getattr(litellm_params, "prompt_directory", None)
    if prompt_directory:
        raise ValueError(
            "Cannot set prompt_directory when working with prompt_initializer. Needs to be a specific dotprompt file"
        )

    prompt_data = getattr(litellm_params, "prompt_data", None)
    prompt_id = getattr(litellm_params, "prompt_id", None)
    prompt_file = getattr(litellm_params, "prompt_file", None)

    # Handle dotprompt_content from database
    dotprompt_content = getattr(litellm_params, "dotprompt_content", None)
    if dotprompt_content and not prompt_data and not prompt_file:
        prompt_data = _get_prompt_data_from_dotprompt_content(dotprompt_content)

    # Construction errors propagate unchanged; the former
    # ``except Exception as e: raise e`` wrapper added nothing.
    return DotpromptManager(
        prompt_directory=prompt_directory,
        prompt_data=prompt_data,
        prompt_file=prompt_file,
        prompt_id=prompt_id,
    )
|
||||
|
||||
|
||||
# Maps the integration identifier to the function that builds its manager.
prompt_initializer_registry = {
    SupportedPromptIntegrations.DOT_PROMPT.value: prompt_initializer,
}

# Export public API
__all__ = [
    "PromptManager",
    "DotpromptManager",
    "PromptTemplate",
    "set_global_prompt_directory",
    "global_prompt_directory",
    "global_prompt_manager",
]

# ``PromptManager`` and ``PromptTemplate`` are imported only under
# ``TYPE_CHECKING`` above, so they do not exist as module attributes at
# runtime even though ``__all__`` advertises them. Resolve them on demand so
# ``from <package> import PromptManager`` and ``import *`` work.
_LAZY_EXPORTS = ("PromptManager", "PromptTemplate")


def __getattr__(name: str):
    """Lazily resolve the prompt-manager classes listed in ``__all__`` (PEP 562).

    Raises:
        AttributeError: For any name other than the lazy exports, which keeps
            normal attribute and submodule lookup semantics intact.
    """
    if name in _LAZY_EXPORTS:
        from .prompt_manager import PromptManager, PromptTemplate

        return {"PromptManager": PromptManager, "PromptTemplate": PromptTemplate}[
            name
        ]
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
Dotprompt manager that integrates with LiteLLM's prompt management system.
|
||||
Builds on top of PromptManagementBase to provide .prompt file support.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
from litellm.integrations.prompt_management_base import PromptManagementClient
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
from .prompt_manager import PromptManager, PromptTemplate
|
||||
|
||||
|
||||
class DotpromptManager(CustomPromptManagement):
|
||||
"""
|
||||
Dotprompt manager that integrates with LiteLLM's prompt management system.
|
||||
|
||||
This class enables using .prompt files with the litellm completion() function
|
||||
by implementing the PromptManagementBase interface.
|
||||
|
||||
Usage:
|
||||
# Set global prompt directory
|
||||
litellm.global_prompt_directory = "path/to/prompts"
|
||||
|
||||
# Use with completion
|
||||
response = litellm.completion(
|
||||
model="dotprompt/gpt-4",
|
||||
prompt_id="my_prompt",
|
||||
prompt_variables={"variable": "value"},
|
||||
messages=[{"role": "user", "content": "This will be combined with the prompt"}]
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    prompt_directory: Optional[str] = None,
    prompt_file: Optional[str] = None,
    prompt_data: Optional[Union[dict, str]] = None,
    prompt_id: Optional[str] = None,
):
    """Record where prompts come from; the prompt manager itself is built lazily.

    Args:
        prompt_directory: Directory of .prompt files; falls back to
            ``litellm.global_prompt_directory`` when empty.
        prompt_file: Optional single prompt file.
        prompt_data: Prompt data as a dict, or as a JSON string that is
            decoded here; an empty value becomes ``{}``.
        prompt_id: Optional prompt identifier handed to the prompt manager.
    """
    import litellm

    self.prompt_directory = prompt_directory or litellm.global_prompt_directory
    # Support for JSON-based prompts stored in memory/database
    self.prompt_data = (
        json.loads(prompt_data) if isinstance(prompt_data, str) else (prompt_data or {})
    )

    # Created on first access of the ``prompt_manager`` property.
    self._prompt_manager: Optional[PromptManager] = None
    self.prompt_file = prompt_file
    self.prompt_id = prompt_id
|
||||
|
||||
@property
def integration_name(self) -> str:
    """Return the prefix that selects this integration in model names (e.g. 'dotprompt/gpt-4')."""
    name = "dotprompt"
    return name
|
||||
|
||||
@property
def prompt_manager(self) -> PromptManager:
    """Return the prompt manager, creating and caching it on first access.

    Raises:
        ValueError: If none of ``prompt_directory``, ``prompt_data`` or
            ``prompt_file`` has been provided.
    """
    # Fast path: already built.
    if self._prompt_manager is not None:
        return self._prompt_manager

    has_prompt_source = (
        self.prompt_directory is not None or self.prompt_data or self.prompt_file
    )
    if not has_prompt_source:
        raise ValueError(
            "Either prompt_directory or prompt_data must be set before using dotprompt manager. "
            "Set litellm.global_prompt_directory, initialize with prompt_directory parameter, or provide prompt_data."
        )

    self._prompt_manager = PromptManager(
        prompt_directory=self.prompt_directory,
        prompt_data=self.prompt_data,
        prompt_file=self.prompt_file,
        prompt_id=self.prompt_id,
    )
    return self._prompt_manager
|
||||
|
||||
def should_run_prompt_management(
    self,
    prompt_id: Optional[str],
    prompt_spec: Optional[PromptSpec],
    dynamic_callback_params: StandardCallbackDynamicParams,
) -> bool:
    """
    Report whether this manager knows ``prompt_id``.

    Returns True only when ``prompt_id`` is listed by the prompt manager.
    A missing id, or any error while listing prompts, yields False.
    ``prompt_spec`` and ``dynamic_callback_params`` are not consulted.
    """
    if prompt_id is None:
        return False
    try:
        known_prompts = self.prompt_manager.list_prompts()
        return prompt_id in known_prompts
    except Exception:
        # Best effort: if prompts cannot be accessed, skip prompt management.
        return False
|
||||
|
||||
def _compile_prompt_helper(
    self,
    prompt_id: Optional[str],
    prompt_spec: Optional[PromptSpec],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
) -> PromptManagementClient:
    """
    Compile a .prompt file into a PromptManagementClient structure.

    This method:
    1. Loads the prompt template from the .prompt file (with optional version)
    2. Renders it with the provided variables
    3. Converts the rendered text into chat messages
    4. Extracts model and optional parameters from metadata

    ``prompt_spec``, ``dynamic_callback_params`` and ``prompt_label`` are
    accepted for interface compatibility and are not used here.

    Raises:
        ValueError: If ``prompt_id`` is None, or if loading/rendering the
            prompt fails (the original error is attached as ``__cause__``).
    """

    if prompt_id is None:
        raise ValueError("prompt_id is required for dotprompt manager")

    try:
        # Get the prompt template (versioned or base)
        template = self.prompt_manager.get_prompt(
            prompt_id=prompt_id, version=prompt_version
        )
        if template is None:
            # Compare against None so that an explicit version 0 is still
            # named in the error message.
            version_str = (
                f" (version {prompt_version})" if prompt_version is not None else ""
            )
            raise ValueError(
                f"Prompt '{prompt_id}'{version_str} not found in prompt directory"
            )

        # Render the template with variables (pass version for proper lookup)
        rendered_content = self.prompt_manager.render(
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            version=prompt_version,
        )

        # Convert rendered content to chat messages
        messages = self._convert_to_messages(rendered_content)

        # Extract model from metadata (if specified)
        template_model = template.model

        # Extract optional parameters from metadata
        optional_params = self._extract_optional_params(template)

        return PromptManagementClient(
            prompt_id=prompt_id,
            prompt_template=messages,
            prompt_template_model=template_model,
            prompt_template_optional_params=optional_params,
            completed_messages=None,
        )

    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error compiling prompt '{prompt_id}': {e}") from e
|
||||
|
||||
async def async_compile_prompt_helper(
    self,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_spec: Optional[PromptSpec] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
) -> PromptManagementClient:
    """
    Asynchronous entry point for compiling a prompt.

    Nothing is awaited here: once ``prompt_id`` has been checked, every
    argument is handed to the synchronous ``_compile_prompt_helper`` and its
    result is returned as-is.

    Raises:
        ValueError: If ``prompt_id`` is None.
    """
    if prompt_id is None:
        raise ValueError("prompt_id is required for dotprompt manager")

    forwarded_arguments = {
        "prompt_id": prompt_id,
        "prompt_spec": prompt_spec,
        "prompt_variables": prompt_variables,
        "dynamic_callback_params": dynamic_callback_params,
        "prompt_label": prompt_label,
        "prompt_version": prompt_version,
    }
    return self._compile_prompt_helper(**forwarded_arguments)
|
||||
|
||||
def get_chat_completion_prompt(
    self,
    model: str,
    messages: List[AllMessageValues],
    non_default_params: dict,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_spec: Optional[PromptSpec] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
    ignore_prompt_manager_model: Optional[bool] = False,
    ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
    """
    Sync version - delegates to the PromptManagementBase implementation.

    The call is made through the class (``PromptManagementBase.<method>``)
    with this instance passed explicitly, and the result is returned
    unchanged.

    NOTE(review): ``ignore_prompt_manager_model`` and
    ``ignore_prompt_manager_optional_params`` are accepted but not forwarded
    here, while ``async_get_chat_completion_prompt`` forwards both - confirm
    whether the base sync method should receive them.
    """
    # Resolved at call time rather than at module import time.
    from litellm.integrations.prompt_management_base import PromptManagementBase

    base_implementation = PromptManagementBase.get_chat_completion_prompt
    return base_implementation(
        self,
        model,
        messages,
        non_default_params,
        prompt_id,
        prompt_variables,
        dynamic_callback_params,
        prompt_spec=prompt_spec,
        prompt_label=prompt_label,
        prompt_version=prompt_version,
    )
|
||||
|
||||
async def async_get_chat_completion_prompt(
    self,
    model: str,
    messages: List[AllMessageValues],
    non_default_params: dict,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    litellm_logging_obj: LiteLLMLoggingObj,
    prompt_spec: Optional[PromptSpec] = None,
    tools: Optional[List[Dict]] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
    ignore_prompt_manager_model: Optional[bool] = False,
    ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
    """
    Async version - delegates to PromptManagementBase async implementation.

    ``model``, ``messages`` and ``non_default_params`` are passed
    positionally; every other argument is forwarded by keyword. The awaited
    result is returned unchanged.
    """
    # Resolved at call time rather than at module import time.
    from litellm.integrations.prompt_management_base import PromptManagementBase

    keyword_arguments = {
        "prompt_id": prompt_id,
        "prompt_variables": prompt_variables,
        "litellm_logging_obj": litellm_logging_obj,
        "dynamic_callback_params": dynamic_callback_params,
        "prompt_spec": prompt_spec,
        "tools": tools,
        "prompt_label": prompt_label,
        "prompt_version": prompt_version,
        "ignore_prompt_manager_model": ignore_prompt_manager_model,
        "ignore_prompt_manager_optional_params": ignore_prompt_manager_optional_params,
    }
    return await PromptManagementBase.async_get_chat_completion_prompt(
        self,
        model,
        messages,
        non_default_params,
        **keyword_arguments,
    )
|
||||
|
||||
def _convert_to_messages(self, rendered_content: str) -> List[AllMessageValues]:
|
||||
"""
|
||||
Convert rendered prompt content to chat messages.
|
||||
|
||||
This method supports multiple formats:
|
||||
1. Simple text -> converted to user message
|
||||
2. Text with role prefixes (System:, User:, Assistant:) -> parsed into separate messages
|
||||
3. Already formatted as a single message
|
||||
"""
|
||||
# Clean up the content
|
||||
content = rendered_content.strip()
|
||||
|
||||
# Try to parse role-based format (System: ..., User: ..., etc.)
|
||||
messages = []
|
||||
current_role = None
|
||||
current_content = []
|
||||
|
||||
lines = content.split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
# Check for role prefixes
|
||||
if line.startswith("System:"):
|
||||
if current_role and current_content:
|
||||
messages.append(
|
||||
self._create_message(
|
||||
current_role, "\n".join(current_content).strip()
|
||||
)
|
||||
)
|
||||
current_role = "system"
|
||||
current_content = [line[7:].strip()] # Remove "System:" prefix
|
||||
elif line.startswith("User:"):
|
||||
if current_role and current_content:
|
||||
messages.append(
|
||||
self._create_message(
|
||||
current_role, "\n".join(current_content).strip()
|
||||
)
|
||||
)
|
||||
current_role = "user"
|
||||
current_content = [line[5:].strip()] # Remove "User:" prefix
|
||||
elif line.startswith("Assistant:"):
|
||||
if current_role and current_content:
|
||||
messages.append(
|
||||
self._create_message(
|
||||
current_role, "\n".join(current_content).strip()
|
||||
)
|
||||
)
|
||||
current_role = "assistant"
|
||||
current_content = [line[10:].strip()] # Remove "Assistant:" prefix
|
||||
else:
|
||||
# Continue current message content
|
||||
if current_role:
|
||||
current_content.append(line)
|
||||
else:
|
||||
# No role prefix found, treat as user message
|
||||
current_role = "user"
|
||||
current_content = [line]
|
||||
|
||||
# Add the last message
|
||||
if current_role and current_content:
|
||||
content_text = "\n".join(current_content).strip()
|
||||
if content_text: # Only add if there's actual content
|
||||
messages.append(self._create_message(current_role, content_text))
|
||||
|
||||
# If no messages were created, treat the entire content as a user message
|
||||
if not messages and content:
|
||||
messages.append(self._create_message("user", content))
|
||||
|
||||
return messages
|
||||
|
||||
def _create_message(self, role: str, content: str) -> AllMessageValues:
|
||||
"""Create a message with the specified role and content."""
|
||||
return {
|
||||
"role": role, # type: ignore
|
||||
"content": content,
|
||||
}
|
||||
|
||||
def _extract_optional_params(self, template: PromptTemplate) -> dict:
|
||||
"""
|
||||
Extract optional parameters from the prompt template metadata.
|
||||
|
||||
Includes parameters like temperature, max_tokens, etc.
|
||||
"""
|
||||
optional_params = {}
|
||||
|
||||
# Extract common parameters from metadata
|
||||
if template.optional_params is not None:
|
||||
optional_params.update(template.optional_params)
|
||||
|
||||
return optional_params
|
||||
|
||||
def set_prompt_directory(self, prompt_directory: str) -> None:
    """
    Store a new prompt directory and discard the cached prompt manager.

    Args:
        prompt_directory: Path of the directory to use from now on.
    """
    self.prompt_directory = prompt_directory
    # Drop the cached manager; presumably it is rebuilt from the new
    # directory on next access - the accessor is not visible here.
    self._prompt_manager = None
|
||||
|
||||
def reload_prompts(self) -> None:
    """Ask the cached prompt manager, if one exists, to reload its prompts."""
    if not self._prompt_manager:
        # No manager has been created yet, so there is nothing to reload.
        return

    self._prompt_manager.reload_prompts()
|
||||
|
||||
def add_prompt_from_json(self, prompt_id: str, json_data: Dict[str, Any]) -> None:
    """
    Register one prompt described by a JSON-style dict.

    Args:
        prompt_id: ID to register the prompt under.
        json_data: Dict read for ``"content"`` (default ``""``) and
            ``"metadata"`` (default ``{}``).
    """
    content, metadata = json_data.get("content", ""), json_data.get("metadata", {})
    manager = self.prompt_manager
    manager.add_prompt(prompt_id, content, metadata)
|
||||
|
||||
def load_prompts_from_json(self, prompts_data: Dict[str, Dict[str, Any]]) -> None:
    """Pass a mapping of prompt ID -> prompt data to the prompt manager for loading."""
    manager = self.prompt_manager
    manager.load_prompts_from_json_data(prompts_data)
|
||||
|
||||
def get_prompts_as_json(self) -> Dict[str, Dict[str, Any]]:
    """Return whatever the prompt manager reports as its prompts in JSON form."""
    manager = self.prompt_manager
    return manager.get_all_prompts_as_json()
|
||||
|
||||
def convert_prompt_file_to_json(self, file_path: str) -> Dict[str, Any]:
    """Return the prompt manager's JSON conversion of the .prompt file at ``file_path``."""
    manager = self.prompt_manager
    return manager.prompt_file_to_json(file_path)
|
||||
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Based on Google's GenAI Kit dotprompt implementation: https://google.github.io/dotprompt/reference/frontmatter/
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from jinja2 import DictLoader, Environment, select_autoescape
|
||||
|
||||
|
||||
class PromptTemplate:
    """
    Represents a single prompt template with metadata and content.

    Attributes set by ``__init__``:
        content: The template text.
        metadata: The metadata dict (``{}`` when none was given).
        template_id: The ID supplied by the caller, if any.
        model: ``metadata["model"]`` or None.
        input_schema: ``metadata["input"]["schema"]`` or ``{}``.
        output_format: ``metadata["output"]["format"]`` or None.
        output_schema: ``metadata["output"]["schema"]`` or ``{}``.
        optional_params: Every metadata entry whose key is not ``model``,
            ``input`` or ``output``.
    """

    def __init__(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
        template_id: Optional[str] = None,
    ):
        self.content = content
        self.metadata = metadata or {}
        self.template_id = template_id

        # Keys interpreted here; everything else is an optional parameter.
        restricted_keys = ("model", "input", "output")

        # A frontmatter key that is present but empty (`input:`) parses to
        # None, so a plain `.get("input", {})` would hand back None and the
        # chained `.get` would raise. Treat any non-mapping section as empty.
        input_section = self._mapping_section("input")
        output_section = self._mapping_section("output")

        self.model = self.metadata.get("model")
        self.input_schema = input_section.get("schema", {})
        self.output_format = output_section.get("format")
        self.output_schema = output_section.get("schema", {})
        self.optional_params = {
            key: value
            for key, value in self.metadata.items()
            if key not in restricted_keys
        }

    def _mapping_section(self, key: str) -> Dict[str, Any]:
        """Return ``metadata[key]`` when it is a dict, otherwise an empty dict."""
        section = self.metadata.get(key)
        return section if isinstance(section, dict) else {}

    def __repr__(self):
        return f"PromptTemplate(id='{self.template_id}', model='{self.model}')"
|
||||
|
||||
|
||||
class PromptManager:
    """
    Manager for loading and rendering .prompt files following the Dotprompt specification.

    Supports:
    - YAML frontmatter for metadata
    - Handlebars-style templating (using Jinja2)
    - Basic type validation of input variables against the frontmatter schema
    - Model configuration

    Loaded templates live in ``self.prompts``, keyed by prompt ID. A
    versioned prompt is looked up under the key ``"{prompt_id}.v{version}"``.
    """

    def __init__(
        self,
        prompt_id: Optional[str] = None,
        prompt_directory: Optional[str] = None,
        prompt_data: Optional[Dict[str, Dict[str, Any]]] = None,
        prompt_file: Optional[str] = None,
    ):
        """
        Create the manager and load prompts from every source that was given.

        Sources are applied in this order, later ones overwriting earlier
        entries with the same ID: ``prompt_directory``, ``prompt_file``,
        ``prompt_data``.

        Args:
            prompt_id: ID for the prompt in ``prompt_file``; also used to
                wrap ``prompt_data`` when it describes a single prompt.
            prompt_directory: Directory whose ``*.prompt`` files are loaded.
            prompt_data: Prompts in JSON form (see ``_load_prompts_from_json``).
            prompt_file: Path of a single .prompt file to load.

        Raises:
            ValueError: If ``prompt_directory`` does not exist, or if
                ``prompt_file`` is given without ``prompt_id``.
        """
        self.prompt_directory = Path(prompt_directory) if prompt_directory else None
        self.prompts: Dict[str, PromptTemplate] = {}
        self.prompt_file = prompt_file
        # NOTE(review): templates are compiled with from_string() and so have
        # no file name; select_autoescape() then applies its
        # default_for_string setting (True unless overridden), which would
        # HTML-escape substituted variables - confirm this is intended.
        self.jinja_env = Environment(
            loader=DictLoader({}),
            autoescape=select_autoescape(["html", "xml"]),
            # Use Handlebars-style delimiters to match Dotprompt spec
            variable_start_string="{{",
            variable_end_string="}}",
            block_start_string="{%",
            block_end_string="%}",
            comment_start_string="{#",
            comment_end_string="#}",
        )

        # Load prompts from directory if provided
        if self.prompt_directory:
            self._load_prompts()

        if self.prompt_file:
            if not prompt_id:
                raise ValueError("prompt_id is required when prompt_file is provided")

            template = self._load_prompt_file(self.prompt_file, prompt_id)
            self.prompts[prompt_id] = template

        # Load prompts from JSON data if provided
        if prompt_data:
            self._load_prompts_from_json(prompt_data, prompt_id)

    def _load_prompts(self) -> None:
        """
        Load all .prompt files from the prompt directory.

        The file name without its extension becomes the prompt ID.

        Raises:
            ValueError: If no directory is configured or it does not exist.
        """
        if not self.prompt_directory or not self.prompt_directory.exists():
            raise ValueError(
                f"Prompt directory does not exist: {self.prompt_directory}"
            )

        for prompt_file in self.prompt_directory.glob("*.prompt"):
            try:
                prompt_id = prompt_file.stem  # filename without extension
                self.prompts[prompt_id] = self._load_prompt_file(
                    prompt_file, prompt_id
                )
            except Exception:
                # Best effort: a file that cannot be read or parsed is
                # skipped so the remaining prompts still load.
                pass

    def _load_prompts_from_json(
        self, prompt_data: Dict[str, Dict[str, Any]], prompt_id: Optional[str] = None
    ) -> None:
        """Load prompts from JSON data structure.

        Expected format:
        {
            "prompt_id": {
                "content": "template content",
                "metadata": {"model": "gpt-4", "temperature": 0.7, ...}
            }
        }

        or

        {
            "content": "template content",
            "metadata": {"model": "gpt-4", "temperature": 0.7, ...}
        } + prompt_id

        When ``prompt_id`` is given, ``prompt_data`` is treated as the second
        form and registered under that ID.
        """
        if prompt_id:
            prompt_data = {prompt_id: prompt_data}  # type: ignore

        for entry_id, prompt_info in prompt_data.items():
            try:
                self.prompts[entry_id] = PromptTemplate(
                    content=prompt_info.get("content", ""),
                    metadata=prompt_info.get("metadata", {}),
                    template_id=entry_id,
                )
            except Exception:
                # Best effort: a malformed entry is skipped so the other
                # entries still load.
                pass

    def _load_prompt_file(
        self, file_path: Union[str, Path], prompt_id: str
    ) -> PromptTemplate:
        """Load and parse a single .prompt file into a ``PromptTemplate``."""
        if isinstance(file_path, str):
            file_path = Path(file_path)

        content = file_path.read_text(encoding="utf-8")

        # Split frontmatter and content
        frontmatter, template_content = self._parse_frontmatter(content)

        return PromptTemplate(
            content=template_content.strip(),
            metadata=frontmatter,
            template_id=prompt_id,
        )

    def _parse_frontmatter(self, content: str) -> Tuple[Dict[str, Any], str]:
        """
        Parse YAML frontmatter from prompt content.

        Returns:
            ``(frontmatter, template_content)``. When the content does not
            start with a ``---`` delimited block, the frontmatter is ``{}``
            and the template is the whole content.

        Raises:
            ValueError: If the frontmatter block is not valid YAML.
        """
        # Match YAML frontmatter between --- delimiters
        frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n(.*)$"
        match = re.match(frontmatter_pattern, content, re.DOTALL)

        if not match:
            # No frontmatter found, treat entire content as template
            return {}, content

        frontmatter_yaml = match.group(1)
        template_content = match.group(2)

        try:
            frontmatter = yaml.safe_load(frontmatter_yaml) or {}
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML frontmatter: {e}")

        return frontmatter, template_content

    def render(
        self,
        prompt_id: str,
        prompt_variables: Optional[Dict[str, Any]] = None,
        version: Optional[int] = None,
    ) -> str:
        """
        Render a prompt template with the given variables.

        Args:
            prompt_id: The ID of the prompt template to render
            prompt_variables: Variables to substitute in the template
            version: Optional version number. If provided, looks for {prompt_id}.v{version}

        Returns:
            The rendered prompt string

        Raises:
            KeyError: If prompt_id is not found
            ValueError: If input validation or template rendering fails
        """
        # Get the template (versioned or base)
        template = self.get_prompt(prompt_id=prompt_id, version=version)

        if template is None:
            available_prompts = list(self.prompts.keys())
            version_str = f" (version {version})" if version else ""
            raise KeyError(
                f"Prompt '{prompt_id}'{version_str} not found. Available prompts: {available_prompts}"
            )

        variables = prompt_variables or {}

        # Validate input variables against schema if defined
        if template.input_schema:
            self._validate_input(variables, template.input_schema)

        try:
            # Create Jinja2 template and render
            jinja_template = self.jinja_env.from_string(template.content)
            return jinja_template.render(**variables)
        except Exception as e:
            raise ValueError(f"Error rendering template '{prompt_id}': {e}")

    def _validate_input(
        self, variables: Dict[str, Any], schema: Dict[str, Any]
    ) -> None:
        """
        Basic validation of input variables against schema.

        Only variables whose name appears verbatim as a schema key are
        checked; schema fields without a matching variable are ignored.

        Raises:
            ValueError: If a variable is not an instance of the Python type
                mapped from its schema entry.
        """
        for field_name, field_type in schema.items():
            if field_name not in variables:
                continue

            value = variables[field_name]
            expected_type = self._get_python_type(field_type)

            if not isinstance(value, expected_type):
                raise ValueError(
                    f"Invalid type for field '{field_name}': "
                    f"expected {getattr(expected_type, '__name__', str(expected_type))}, got {type(value).__name__}"
                )

    def _get_python_type(self, schema_type: Any) -> Union[type, tuple]:
        """
        Convert a schema type spec to the Python type(s) used for isinstance().

        A string spec may carry a description after a comma
        (``"integer, number of items"``); only the part before the first
        comma names the type. Unknown type names map to ``str``. A spec that
        is not a string (for example a nested mapping) maps to ``object``,
        which accepts any value, because this basic validator does not
        descend into structured specs.
        """
        type_mapping: Dict[str, Union[type, tuple]] = {
            "string": str,
            "str": str,
            "number": (int, float),
            "integer": int,
            "int": int,
            "float": float,
            "boolean": bool,
            "bool": bool,
            "array": list,
            "list": list,
            "object": dict,
            "dict": dict,
            "null": type(None),
            "any": object,
        }

        if not isinstance(schema_type, str):
            return object

        type_name = schema_type.split(",", 1)[0].strip().lower()
        return type_mapping.get(type_name, str)  # type: ignore

    def get_prompt(
        self, prompt_id: str, version: Optional[int] = None
    ) -> Optional[PromptTemplate]:
        """
        Get a prompt template by ID and optional version.

        Args:
            prompt_id: The base prompt ID
            version: Optional version number. If provided, looks for {prompt_id}.v{version}

        Returns:
            The versioned template when it exists, otherwise the template
            stored under the base ``prompt_id``, otherwise None.
        """
        if version is not None:
            # Try versioned prompt first: prompt_id.v{version}
            versioned_id = f"{prompt_id}.v{version}"
            if versioned_id in self.prompts:
                return self.prompts[versioned_id]

        # Fall back to base prompt_id
        return self.prompts.get(prompt_id)

    def list_prompts(self) -> List[str]:
        """Get a list of all available prompt IDs."""
        return list(self.prompts.keys())

    def get_prompt_metadata(self, prompt_id: str) -> Optional[Dict[str, Any]]:
        """Get metadata for a specific prompt, or None if the ID is unknown."""
        template = self.prompts.get(prompt_id)
        return template.metadata if template else None

    def reload_prompts(self) -> None:
        """
        Reload all prompts from the directory (if directory was provided).

        All loaded prompts are cleared first; only the directory is read
        again, so prompts that came from other sources are not restored here.
        """
        self.prompts.clear()
        if self.prompt_directory:
            self._load_prompts()

    def add_prompt(
        self, prompt_id: str, content: str, metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Add a prompt template programmatically, replacing any with the same ID."""
        self.prompts[prompt_id] = PromptTemplate(
            content=content, metadata=metadata or {}, template_id=prompt_id
        )

    def prompt_file_to_json(self, file_path: Union[str, Path]) -> Dict[str, Any]:
        """Convert a .prompt file to JSON format.

        Args:
            file_path: Path to the .prompt file

        Returns:
            Dictionary with 'content' and 'metadata' keys
        """
        file_path = Path(file_path)
        content = file_path.read_text(encoding="utf-8")

        # Parse frontmatter and content
        frontmatter, template_content = self._parse_frontmatter(content)

        return {"content": template_content.strip(), "metadata": frontmatter}

    def json_to_prompt_file(self, prompt_data: Dict[str, Any]) -> str:
        """Convert JSON prompt data to .prompt file format.

        Args:
            prompt_data: Dictionary with 'content' and 'metadata' keys

        Returns:
            String content in .prompt file format
        """
        content = prompt_data.get("content", "")
        metadata = prompt_data.get("metadata", {})

        if not metadata:
            # No metadata, return just the content
            return content

        # Convert metadata to YAML frontmatter (yaml is imported at module level)
        frontmatter_yaml = yaml.dump(metadata, default_flow_style=False)

        return f"---\n{frontmatter_yaml}---\n{content}"

    def get_all_prompts_as_json(self) -> Dict[str, Dict[str, Any]]:
        """Get all loaded prompts in JSON format.

        Returns:
            Dictionary mapping prompt_id to prompt data
        """
        return {
            prompt_id: {
                "content": template.content,
                "metadata": template.metadata,
            }
            for prompt_id, template in self.prompts.items()
        }

    def load_prompts_from_json_data(
        self, prompt_data: Dict[str, Dict[str, Any]]
    ) -> None:
        """Load additional prompts from JSON data (merges with existing prompts)."""
        self._load_prompts_from_json(prompt_data)
|
||||
@@ -0,0 +1,89 @@
|
||||
#### What this does ####
|
||||
# On success + failure, log events to DynamoDB
|
||||
|
||||
import os
|
||||
import traceback
|
||||
from litellm._uuid import uuid
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class DyanmoDBLogger:
    """Logger that writes one item per logged call to a DynamoDB table."""

    def __init__(self):
        """
        Create the boto3 DynamoDB resource and record the target table name.

        The region is read from the ``AWS_REGION_NAME`` environment variable
        and the table name from ``litellm.dynamodb_table_name``.

        Raises:
            KeyError: If ``AWS_REGION_NAME`` is not set.
            ValueError: If ``litellm.dynamodb_table_name`` is None.
        """
        import boto3

        self.dynamodb: Any = boto3.resource(
            "dynamodb", region_name=os.environ["AWS_REGION_NAME"]
        )

        table_name = litellm.dynamodb_table_name
        if table_name is None:
            raise ValueError(
                "LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
            )
        self.table_name = table_name

    async def _async_log_event(
        self, kwargs, response_obj, start_time, end_time, print_verbose
    ):
        """Async entry point; calls ``log_event`` directly and returns None."""
        self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)

    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
        """
        Build a string-valued item from the call details and put it in the table.

        Args:
            kwargs: Call arguments; read for ``litellm_params``, ``messages``,
                ``optional_params``, ``call_type``, ``model`` and ``user``.
            response_obj: Response; must provide ``"usage"`` and may provide
                ``"id"`` (a random UUID is used otherwise).
            start_time: Stored under ``startTime``.
            end_time: Stored under ``endTime``.
            print_verbose: Callable used for all diagnostic output.

        Returns:
            The ``put_item`` response, or None when anything failed (the
            traceback is passed to ``print_verbose`` and not re-raised).
        """
        try:
            print_verbose(
                f"DynamoDB Logging - Enters logging function for model {kwargs}"
            )

            # Gather the fields to log (same params as langfuse.py).
            litellm_params = kwargs.get("litellm_params", {})
            # `or {}` covers litellm_params["metadata"] being None.
            metadata = litellm_params.get("metadata", {}) or {}
            messages = kwargs.get("messages")
            optional_params = kwargs.get("optional_params", {})
            call_type = kwargs.get("call_type", "litellm.completion")
            usage = response_obj["usage"]
            event_id = response_obj.get("id", str(uuid.uuid4()))

            raw_fields = {
                "id": event_id,
                "call_type": call_type,
                "startTime": start_time,
                "endTime": end_time,
                "model": kwargs.get("model", ""),
                "user": kwargs.get("user", ""),
                "modelParameters": optional_params,
                "messages": messages,
                "response": response_obj,
                "usage": usage,
                "metadata": metadata,
            }

            # Every value is stored as a string; a value that cannot be
            # converted is kept as-is rather than aborting the log call.
            payload = {}
            for field_name, field_value in raw_fields.items():
                try:
                    payload[field_name] = str(field_value)
                except Exception:
                    payload[field_name] = field_value

            print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")

            # Write the item to the configured DynamoDB table.
            response = self.dynamodb.Table(self.table_name).put_item(Item=payload)

            print_verbose(f"Response from DynamoDB:{str(response)}")

            print_verbose(
                f"DynamoDB Layer Logging - final response object: {response_obj}"
            )
            return response
        except Exception:
            print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
|
||||
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Functions for sending Email Alerts
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.proxy._types import WebhookEvent
|
||||
|
||||
# Used for the email header. If you change this, send a test email and verify that it renders correctly.
|
||||
LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
|
||||
LITELLM_SUPPORT_CONTACT = "support@berri.ai"
|
||||
|
||||
|
||||
async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
    """
    Look up the email addresses of the members of a team.

    The team row is fetched by ``team_id``; string ``user_id`` values found
    in its ``members_with_roles`` entries are then used to query
    ``LiteLLM_UserTable`` for ``user_email``.

    Args:
        team_id: ID of the team. None yields an empty list.

    Returns:
        The non-None ``user_email`` values returned by the query. Empty when
        ``team_id`` is None, the team does not exist, or the query returns
        None.

    Raises:
        Exception: If the proxy's ``prisma_client`` is None.
    """
    verbose_logger.debug(
        "Email Alerting: Getting all team members for team_id=%s", team_id
    )
    if team_id is None:
        return []
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise Exception("Not connected to DB!")

    team_lookup = {"team_id": team_id}
    team_row = await prisma_client.db.litellm_teamtable.find_unique(where=team_lookup)
    if team_row is None:
        return []

    _team_members = team_row.members_with_roles
    verbose_logger.debug(
        "Email Alerting: Got team members for team_id=%s Team Members: %s",
        team_id,
        _team_members,
    )

    # Keep only user IDs that are non-empty strings from dict-shaped entries.
    candidate_ids = (
        member.get("user_id")
        for member in _team_members
        if member and isinstance(member, dict)
    )
    _team_member_user_ids: List[str] = [
        user_id for user_id in candidate_ids if user_id and isinstance(user_id, str)
    ]

    sql_query = """
        SELECT user_email
        FROM "LiteLLM_UserTable"
        WHERE user_id = ANY($1::TEXT[]);
    """

    _result = await prisma_client.db.query_raw(sql_query, _team_member_user_ids)

    verbose_logger.debug("Email Alerting: Got all Emails for team, emails=%s", _result)

    if _result is None:
        return []

    return [
        row.get("user_email")
        for row in _result
        if row and isinstance(row, dict) and row.get("user_email", None) is not None
    ]
|
||||
|
||||
|
||||
async def send_team_budget_alert(webhook_event: WebhookEvent) -> bool:
    """
    Send an Email Alert to All Team Members when the Team Budget is crossed

    The logo comes from ``SMTP_SENDER_LOGO`` (falling back to
    ``EMAIL_LOGO_URL``) and the support contact from
    ``EMAIL_SUPPORT_CONTACT``; the module defaults are used when these
    environment variables are unset.

    Args:
        webhook_event: Event read for ``team_id``, ``team_alias``,
            ``event_message``, ``max_budget`` and ``spend``.

    Returns -> True if sent, False if not (no recipient email was found).
    """
    from litellm.proxy.utils import send_email

    _team_id = webhook_event.team_id
    team_alias = webhook_event.team_alias
    verbose_logger.debug(
        "Email Alerting: Sending Team Budget Alert for team=%s", team_alias
    )

    email_logo_url = os.getenv("SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None))
    email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)

    if email_logo_url is None:
        email_logo_url = LITELLM_LOGO_URL
    if email_support_contact is None:
        email_support_contact = LITELLM_SUPPORT_CONTACT

    recipient_emails = await get_all_team_member_emails(_team_id)
    if not recipient_emails:
        # ",".join() of an empty list is "", never None, so the emptiness
        # check has to be made on the list itself. Nothing to send to.
        verbose_proxy_logger.warning(
            "Email Alerting: Trying to send email alert to no recipient, got recipient_emails=%s",
            recipient_emails,
        )
        return False

    recipient_emails_str: str = ",".join(recipient_emails)
    verbose_logger.debug(
        "Email Alerting: Sending team budget alert to %s", recipient_emails_str
    )

    event_name = webhook_event.event_message
    max_budget = webhook_event.max_budget

    email_html_content = f"""
    <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> <br/><br/><br/>

    Budget Crossed for Team <b> {team_alias} </b> <br/> <br/>

    Your Teams LLM API usage has crossed it's <b> budget of ${max_budget} </b>, current spend is <b>${webhook_event.spend}</b><br /> <br />

    API requests will be rejected until either (a) you increase your budget or (b) your budget gets reset <br /> <br />

    If you have any questions, please send an email to {email_support_contact} <br /> <br />

    Best, <br />
    The LiteLLM team <br />
    """

    await send_email(
        receiver_email=recipient_emails_str,
        subject=f"LiteLLM {event_name} for Team {team_alias}",
        html=email_html_content,
    )

    # The email was handed to send_email without raising.
    return True
|
||||
@@ -0,0 +1,10 @@
|
||||
# HTML footer fragment: copyright line plus Twitter / GitHub / Website links.
EMAIL_FOOTER = """
        <div class="footer">
            <p>© 2025 LiteLLM. All rights reserved.</p>
            <div class="social-links">
                <a href="https://twitter.com/litellm">Twitter</a> •
                <a href="https://github.com/BerriAI/litellm">GitHub</a> •
                <a href="https://litellm.ai">Website</a>
            </div>
        </div>
"""
|
||||
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Modern Email Templates for LiteLLM Email Service with professional styling
|
||||
"""
|
||||
|
||||
KEY_CREATED_EMAIL_TEMPLATE = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Your API Key is Ready</title>
|
||||
<style>
|
||||
body, html {{
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
|
||||
color: #333333;
|
||||
background-color: #f8fafc;
|
||||
line-height: 1.5;
|
||||
}}
|
||||
.container {{
|
||||
max-width: 560px;
|
||||
margin: 20px auto;
|
||||
background-color: #ffffff;
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||
}}
|
||||
.header {{
|
||||
padding: 24px 0;
|
||||
text-align: center;
|
||||
border-bottom: 1px solid #f1f5f9;
|
||||
}}
|
||||
.content {{
|
||||
padding: 32px 40px;
|
||||
}}
|
||||
.greeting {{
|
||||
font-size: 16px;
|
||||
margin-bottom: 20px;
|
||||
color: #333333;
|
||||
}}
|
||||
.message {{
|
||||
font-size: 16px;
|
||||
color: #333333;
|
||||
margin-bottom: 20px;
|
||||
}}
|
||||
.key-container {{
|
||||
margin: 28px 0;
|
||||
}}
|
||||
.key-label {{
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
margin-bottom: 8px;
|
||||
color: #4b5563;
|
||||
}}
|
||||
.key {{
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
word-break: break-all;
|
||||
background-color: #f9fafb;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
font-size: 14px;
|
||||
border: 1px solid #e5e7eb;
|
||||
color: #4338ca;
|
||||
}}
|
||||
h2 {{
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
margin-top: 36px;
|
||||
margin-bottom: 16px;
|
||||
color: #333333;
|
||||
}}
|
||||
.budget-info {{
|
||||
background-color: #f0fdf4;
|
||||
border-radius: 6px;
|
||||
padding: 14px 16px;
|
||||
margin: 24px 0;
|
||||
font-size: 14px;
|
||||
border: 1px solid #dcfce7;
|
||||
}}
|
||||
.code-block {{
|
||||
background-color: #f8fafc;
|
||||
color: #334155;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
font-size: 13px;
|
||||
overflow-x: auto;
|
||||
margin: 20px 0;
|
||||
line-height: 1.6;
|
||||
border: 1px solid #e2e8f0;
|
||||
}}
|
||||
.code-comment {{
|
||||
color: #64748b;
|
||||
}}
|
||||
.code-string {{
|
||||
color: #0369a1;
|
||||
}}
|
||||
.code-keyword {{
|
||||
color: #7e22ce;
|
||||
}}
|
||||
.btn {{
|
||||
display: inline-block;
|
||||
padding: 8px 20px;
|
||||
background-color: #6366f1;
|
||||
color: #ffffff !important;
|
||||
text-decoration: none;
|
||||
border-radius: 6px;
|
||||
font-weight: 500;
|
||||
margin-top: 24px;
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
transition: background-color 0.2s;
|
||||
}}
|
||||
.btn:hover {{
|
||||
background-color: #4f46e5;
|
||||
color: #ffffff !important;
|
||||
}}
|
||||
.separator {{
|
||||
height: 1px;
|
||||
background-color: #f1f5f9;
|
||||
margin: 40px 0 30px;
|
||||
}}
|
||||
.footer {{
|
||||
padding: 24px 40px 32px;
|
||||
text-align: center;
|
||||
color: #64748b;
|
||||
font-size: 13px;
|
||||
background-color: #f8fafc;
|
||||
border-top: 1px solid #f1f5f9;
|
||||
}}
|
||||
.social-links {{
|
||||
margin-top: 12px;
|
||||
}}
|
||||
.social-links a {{
|
||||
display: inline-block;
|
||||
margin: 0 8px;
|
||||
color: #64748b;
|
||||
text-decoration: none;
|
||||
}}
|
||||
@media only screen and (max-width: 620px) {{
|
||||
.container {{
|
||||
width: 100%;
|
||||
margin: 0;
|
||||
border-radius: 0;
|
||||
}}
|
||||
.content {{
|
||||
padding: 24px 20px;
|
||||
}}
|
||||
.footer {{
|
||||
padding: 20px;
|
||||
}}
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" style="height: 32px; width: auto;">
|
||||
</div>
|
||||
<div class="content">
|
||||
<div class="greeting">
|
||||
<p>Hi {recipient_email},</p>
|
||||
</div>
|
||||
|
||||
<div class="message">
|
||||
<p>Great news! Your LiteLLM API key is ready to use.</p>
|
||||
</div>
|
||||
|
||||
<div class="budget-info">
|
||||
<p style="margin: 0;"><strong>Monthly Budget:</strong> {key_budget}</p>
|
||||
</div>
|
||||
|
||||
<div class="key-container">
|
||||
<div class="key-label">Your API Key</div>
|
||||
<div class="key">{key_token}</div>
|
||||
</div>
|
||||
|
||||
<h2>Quick Start Guide</h2>
|
||||
<p>Here's how to use your key with the OpenAI SDK:</p>
|
||||
|
||||
<div class="code-block">
|
||||
<span class="code-keyword">import</span> openai<br>
|
||||
<br>
|
||||
client = openai.OpenAI(<br>
|
||||
api_key=<span class="code-string">"{key_token}"</span>,<br>
|
||||
base_url=<span class="code-string">"{base_url}"</span><br>
|
||||
)<br>
|
||||
<br>
|
||||
response = client.chat.completions.create(<br>
|
||||
model=<span class="code-string">"gpt-3.5-turbo"</span>, <span class="code-comment"># model to send to the proxy</span><br>
|
||||
messages = [<br>
|
||||
{{<br>
|
||||
<span class="code-string">"role"</span>: <span class="code-string">"user"</span>,<br>
|
||||
<span class="code-string">"content"</span>: <span class="code-string">"this is a test request, write a short poem"</span><br>
|
||||
}}<br>
|
||||
]<br>
|
||||
)
|
||||
</div>
|
||||
|
||||
<a href="https://docs.litellm.ai/docs/proxy/user_keys" class="btn" style="color: #ffffff;">View Documentation</a>
|
||||
|
||||
<div class="separator"></div>
|
||||
|
||||
<h2>Need Help?</h2>
|
||||
<p>If you have any questions or need assistance, please contact us at {email_support_contact}.</p>
|
||||
</div>
|
||||
{email_footer}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Modern Email Templates for LiteLLM Email Service with professional styling
|
||||
"""
|
||||
|
||||
KEY_ROTATED_EMAIL_TEMPLATE = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Your API Key Has Been Rotated</title>
|
||||
<style>
|
||||
body, html {{
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
|
||||
color: #333333;
|
||||
background-color: #f8fafc;
|
||||
line-height: 1.5;
|
||||
}}
|
||||
.container {{
|
||||
max-width: 560px;
|
||||
margin: 20px auto;
|
||||
background-color: #ffffff;
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||
}}
|
||||
.header {{
|
||||
padding: 24px 0;
|
||||
text-align: center;
|
||||
border-bottom: 1px solid #f1f5f9;
|
||||
}}
|
||||
.content {{
|
||||
padding: 32px 40px;
|
||||
}}
|
||||
.greeting {{
|
||||
font-size: 16px;
|
||||
margin-bottom: 20px;
|
||||
color: #333333;
|
||||
}}
|
||||
.message {{
|
||||
font-size: 16px;
|
||||
color: #333333;
|
||||
margin-bottom: 20px;
|
||||
}}
|
||||
.key-container {{
|
||||
margin: 28px 0;
|
||||
}}
|
||||
.key-label {{
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
margin-bottom: 8px;
|
||||
color: #4b5563;
|
||||
}}
|
||||
.key {{
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
word-break: break-all;
|
||||
background-color: #f9fafb;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
font-size: 14px;
|
||||
border: 1px solid #e5e7eb;
|
||||
color: #4338ca;
|
||||
}}
|
||||
h2 {{
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
margin-top: 36px;
|
||||
margin-bottom: 16px;
|
||||
color: #333333;
|
||||
}}
|
||||
.budget-info {{
|
||||
background-color: #f0fdf4;
|
||||
border-radius: 6px;
|
||||
padding: 14px 16px;
|
||||
margin: 24px 0;
|
||||
font-size: 14px;
|
||||
border: 1px solid #dcfce7;
|
||||
}}
|
||||
.code-block {{
|
||||
background-color: #f8fafc;
|
||||
color: #334155;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
font-size: 13px;
|
||||
overflow-x: auto;
|
||||
margin: 20px 0;
|
||||
line-height: 1.6;
|
||||
border: 1px solid #e2e8f0;
|
||||
}}
|
||||
.code-comment {{
|
||||
color: #64748b;
|
||||
}}
|
||||
.code-string {{
|
||||
color: #0369a1;
|
||||
}}
|
||||
.code-keyword {{
|
||||
color: #7e22ce;
|
||||
}}
|
||||
.btn {{
|
||||
display: inline-block;
|
||||
padding: 8px 20px;
|
||||
background-color: #6366f1;
|
||||
color: #ffffff !important;
|
||||
text-decoration: none;
|
||||
border-radius: 6px;
|
||||
font-weight: 500;
|
||||
margin-top: 24px;
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
transition: background-color 0.2s;
|
||||
}}
|
||||
.btn:hover {{
|
||||
background-color: #4f46e5;
|
||||
color: #ffffff !important;
|
||||
}}
|
||||
.separator {{
|
||||
height: 1px;
|
||||
background-color: #f1f5f9;
|
||||
margin: 40px 0 30px;
|
||||
}}
|
||||
.footer {{
|
||||
padding: 24px 40px 32px;
|
||||
text-align: center;
|
||||
color: #64748b;
|
||||
font-size: 13px;
|
||||
background-color: #f8fafc;
|
||||
border-top: 1px solid #f1f5f9;
|
||||
}}
|
||||
.social-links {{
|
||||
margin-top: 12px;
|
||||
}}
|
||||
.social-links a {{
|
||||
display: inline-block;
|
||||
margin: 0 8px;
|
||||
color: #64748b;
|
||||
text-decoration: none;
|
||||
}}
|
||||
@media only screen and (max-width: 620px) {{
|
||||
.container {{
|
||||
width: 100%;
|
||||
margin: 0;
|
||||
border-radius: 0;
|
||||
}}
|
||||
.content {{
|
||||
padding: 24px 20px;
|
||||
}}
|
||||
.footer {{
|
||||
padding: 20px;
|
||||
}}
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" style="height: 32px; width: auto;">
|
||||
</div>
|
||||
<div class="content">
|
||||
<div class="greeting">
|
||||
<p>Hi {recipient_email},</p>
|
||||
</div>
|
||||
|
||||
<div class="message">
|
||||
<p><strong>Your LiteLLM API key has been rotated</strong> as part of our ongoing commitment to security best practices.</p>
|
||||
<p style="margin-top: 16px;">Your previous API key has been deactivated and will no longer work. Please update your applications with the new key below.</p>
|
||||
</div>
|
||||
|
||||
<div class="key-container">
|
||||
<div class="key-label">Your New API Key</div>
|
||||
<div class="key">{key_token}</div>
|
||||
</div>
|
||||
|
||||
<div class="budget-info">
|
||||
<p style="margin: 0;"><strong>Monthly Budget:</strong> {key_budget}</p>
|
||||
</div>
|
||||
|
||||
<h2>Action Required</h2>
|
||||
<p>Update your applications and systems with the new API key. Here's an example:</p>
|
||||
|
||||
<div class="code-block">
|
||||
<span class="code-keyword">import</span> openai<br>
|
||||
<br>
|
||||
client = openai.OpenAI(<br>
|
||||
api_key=<span class="code-string">"{key_token}"</span>,<br>
|
||||
base_url=<span class="code-string">"{base_url}"</span><br>
|
||||
)<br>
|
||||
<br>
|
||||
response = client.chat.completions.create(<br>
|
||||
model=<span class="code-string">"gpt-3.5-turbo"</span>, <span class="code-comment"># model to send to the proxy</span><br>
|
||||
messages = [<br>
|
||||
{{<br>
|
||||
<span class="code-string">"role"</span>: <span class="code-string">"user"</span>,<br>
|
||||
<span class="code-string">"content"</span>: <span class="code-string">"this is a test request, write a short poem"</span><br>
|
||||
}}<br>
|
||||
]<br>
|
||||
)
|
||||
</div>
|
||||
|
||||
<div class="separator"></div>
|
||||
|
||||
<h2>Security Best Practices</h2>
|
||||
<p style="margin-bottom: 12px;">To keep your API key secure:</p>
|
||||
<ul style="margin: 0; padding-left: 20px; color: #333333;">
|
||||
<li style="margin-bottom: 8px;">Never share your API key publicly or commit it to version control</li>
|
||||
<li style="margin-bottom: 8px;">Store it securely using environment variables or secret management systems</li>
|
||||
<li style="margin-bottom: 8px;">Monitor your API usage regularly for any unusual activity</li>
|
||||
<li style="margin-bottom: 8px;">Rotate your keys periodically as a security best practice</li>
|
||||
</ul>
|
||||
|
||||
<a href="https://docs.litellm.ai/docs/proxy/user_keys" class="btn" style="color: #ffffff;">View Documentation</a>
|
||||
|
||||
<div class="separator"></div>
|
||||
|
||||
<h2>Need Help?</h2>
|
||||
<p>If you have any questions or need assistance updating your systems, please contact us at {email_support_contact}.</p>
|
||||
</div>
|
||||
{email_footer}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Email Templates used by the LiteLLM Email Service in slack_alerting.py
|
||||
"""
|
||||
|
||||
# Legacy HTML email sent when a new API key is issued.
#
# Single-braced names are substitution placeholders; doubled braces ({{ }})
# are literal braces escaped for str.format-style rendering.
#
# Placeholders: {email_logo_url}, {recipient_email}, {key_budget},
# {key_token}, {base_url}, {email_support_contact}.
#
# NOTE: the usage example previously contained `base_url={{base_url}}`.
# Escaped braces are emitted literally, so recipients saw the text
# "{base_url}" instead of the proxy URL. It is now a real placeholder,
# quoted like the api_key argument so the sample is valid Python.
KEY_CREATED_EMAIL_TEMPLATE = """
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />

<p> Hi {recipient_email}, <br/>

I'm happy to provide you with an OpenAI Proxy API Key, loaded with ${key_budget} per month. <br /> <br />

<b>
Key: <pre>{key_token}</pre> <br>
</b>

<h2>Usage Example</h2>

Detailed Documentation on <a href="https://docs.litellm.ai/docs/proxy/user_keys">Usage with OpenAI Python SDK, Langchain, LlamaIndex, Curl</a>

<pre>

import openai
client = openai.OpenAI(
api_key="{key_token}",
base_url="{base_url}"
)

response = client.chat.completions.create(
model="gpt-3.5-turbo", # model to send to the proxy
messages = [
{{
"role": "user",
"content": "this is a test request, write a short poem"
}}
]
)

</pre>


If you have any questions, please send an email to {email_support_contact} <br /> <br />

Best, <br />
The LiteLLM team <br />
"""
|
||||
|
||||
|
||||
USER_INVITED_EMAIL_TEMPLATE = """
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
|
||||
|
||||
<p> Hi {recipient_email}, <br/>
|
||||
|
||||
You were invited to use OpenAI Proxy API for team {team_name} <br /> <br />
|
||||
|
||||
<a href="{base_url}" style="display: inline-block; padding: 10px 20px; background-color: #87ceeb; color: #fff; text-decoration: none; border-radius: 20px;">Get Started here</a> <br /> <br />
|
||||
|
||||
|
||||
If you have any questions, please send an email to {email_support_contact} <br /> <br />
|
||||
|
||||
Best, <br />
|
||||
The LiteLLM team <br />
|
||||
"""
|
||||
|
||||
SOFT_BUDGET_ALERT_EMAIL_TEMPLATE = """
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
|
||||
|
||||
<p> Hi {recipient_email}, <br/>
|
||||
|
||||
Your LiteLLM API key has crossed its <b>soft budget limit of {soft_budget}</b>. <br /> <br />
|
||||
|
||||
<b>Current Spend:</b> {spend} <br />
|
||||
<b>Soft Budget:</b> {soft_budget} <br />
|
||||
{max_budget_info}
|
||||
|
||||
<p style="color: #dc2626; font-weight: 500;">
|
||||
⚠️ Note: Your API requests will continue to work, but you should monitor your usage closely.
|
||||
If you reach your maximum budget, requests will be rejected.
|
||||
</p>
|
||||
|
||||
You can view your usage and manage your budget in the <a href="{base_url}">LiteLLM Dashboard</a>. <br /> <br />
|
||||
|
||||
If you have any questions, please send an email to {email_support_contact} <br /> <br />
|
||||
|
||||
Best, <br />
|
||||
The LiteLLM team <br />
|
||||
"""
|
||||
|
||||
TEAM_SOFT_BUDGET_ALERT_EMAIL_TEMPLATE = """
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
|
||||
|
||||
<p> Hi {team_alias} team member, <br/>
|
||||
|
||||
Your LiteLLM team has crossed its <b>soft budget limit of {soft_budget}</b>. <br /> <br />
|
||||
|
||||
<b>Current Spend:</b> {spend} <br />
|
||||
<b>Soft Budget:</b> {soft_budget} <br />
|
||||
{max_budget_info}
|
||||
|
||||
<p style="color: #dc2626; font-weight: 500;">
|
||||
⚠️ Note: Your API requests will continue to work, but you should monitor your usage closely.
|
||||
If you reach your maximum budget, requests will be rejected.
|
||||
</p>
|
||||
|
||||
You can view your usage and manage your budget in the <a href="{base_url}">LiteLLM Dashboard</a>. <br /> <br />
|
||||
|
||||
If you have any questions, please send an email to {email_support_contact} <br /> <br />
|
||||
|
||||
Best, <br />
|
||||
The LiteLLM team <br />
|
||||
"""
|
||||
|
||||
MAX_BUDGET_ALERT_EMAIL_TEMPLATE = """
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
|
||||
|
||||
<p> Hi {recipient_email}, <br/>
|
||||
|
||||
Your LiteLLM API key has reached <b>{percentage}% of its maximum budget</b>. <br /> <br />
|
||||
|
||||
<b>Current Spend:</b> {spend} <br />
|
||||
<b>Maximum Budget:</b> {max_budget} <br />
|
||||
<b>Alert Threshold:</b> {alert_threshold} ({percentage}%) <br />
|
||||
|
||||
<p style="color: #dc2626; font-weight: 500;">
|
||||
⚠️ Warning: You are approaching your maximum budget limit.
|
||||
Once you reach your maximum budget of {max_budget}, all API requests will be rejected.
|
||||
</p>
|
||||
|
||||
You can view your usage and manage your budget in the <a href="{base_url}">LiteLLM Dashboard</a>. <br /> <br />
|
||||
|
||||
If you have any questions, please send an email to {email_support_contact} <br /> <br />
|
||||
|
||||
Best, <br />
|
||||
The LiteLLM team <br />
|
||||
"""
|
||||
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Modern Email Templates for LiteLLM Email Service with professional styling
|
||||
"""
|
||||
|
||||
USER_INVITATION_EMAIL_TEMPLATE = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Welcome to LiteLLM</title>
|
||||
<style>
|
||||
body, html {{
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
|
||||
color: #333333;
|
||||
background-color: #f8f8f8;
|
||||
line-height: 1.5;
|
||||
}}
|
||||
.container {{
|
||||
max-width: 560px;
|
||||
margin: 20px auto;
|
||||
background-color: #ffffff;
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||
}}
|
||||
.logo {{
|
||||
padding: 24px 0 0 24px;
|
||||
text-align: left;
|
||||
}}
|
||||
.greeting {{
|
||||
font-size: 16px;
|
||||
margin-bottom: 20px;
|
||||
color: #333333;
|
||||
}}
|
||||
.content {{
|
||||
padding: 24px 40px 32px;
|
||||
}}
|
||||
h1 {{
|
||||
font-size: 24px;
|
||||
font-weight: 600;
|
||||
margin-top: 24px;
|
||||
margin-bottom: 16px;
|
||||
color: #333333;
|
||||
}}
|
||||
p {{
|
||||
font-size: 16px;
|
||||
color: #333333;
|
||||
margin-bottom: 16px;
|
||||
line-height: 1.5;
|
||||
}}
|
||||
.intro-text {{
|
||||
margin-bottom: 24px;
|
||||
}}
|
||||
.link {{
|
||||
color: #6366f1;
|
||||
text-decoration: none;
|
||||
font-weight: 500;
|
||||
}}
|
||||
.link:hover {{
|
||||
text-decoration: underline;
|
||||
}}
|
||||
.link-with-arrow {{
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
color: #6366f1;
|
||||
text-decoration: none;
|
||||
font-weight: 500;
|
||||
margin-bottom: 20px;
|
||||
}}
|
||||
.link-with-arrow:hover {{
|
||||
text-decoration: underline;
|
||||
}}
|
||||
.arrow {{
|
||||
margin-left: 6px;
|
||||
}}
|
||||
.divider {{
|
||||
height: 1px;
|
||||
background-color: #f1f1f1;
|
||||
margin: 24px 0;
|
||||
}}
|
||||
.btn {{
|
||||
display: inline-block;
|
||||
padding: 12px 24px;
|
||||
background-color: #5c5ce0;
|
||||
color: #ffffff !important;
|
||||
text-decoration: none;
|
||||
border-radius: 6px;
|
||||
font-weight: 500;
|
||||
margin-top: 12px;
|
||||
text-align: center;
|
||||
font-size: 15px;
|
||||
transition: background-color 0.2s ease;
|
||||
}}
|
||||
.btn:hover {{
|
||||
background-color: #4b4bb3;
|
||||
}}
|
||||
.btn-container {{
|
||||
text-align: center;
|
||||
margin: 24px 0;
|
||||
}}
|
||||
.footer {{
|
||||
padding: 24px 40px 32px;
|
||||
text-align: left;
|
||||
color: #666;
|
||||
font-size: 14px;
|
||||
}}
|
||||
.quickstart {{
|
||||
margin-top: 32px;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="logo">
|
||||
<img src="{email_logo_url}" alt="LiteLLM Logo" style="height: 32px; width: auto;">
|
||||
</div>
|
||||
<div class="content">
|
||||
<h1>Welcome to LiteLLM</h1>
|
||||
|
||||
<div class="greeting">
|
||||
<p>Hi {recipient_email},</p>
|
||||
</div>
|
||||
|
||||
<div class="intro-text">
|
||||
<p>LiteLLM allows you to call 100+ LLM providers in the OpenAI API format. Get started by accepting your invitation.</p>
|
||||
</div>
|
||||
|
||||
<div class="btn-container">
|
||||
<a href="{base_url}" class="btn">Accept Invitation</a>
|
||||
</div>
|
||||
|
||||
<div class="quickstart">
|
||||
<p>Here's a quickstart guide to get you started:</p>
|
||||
</div>
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
<a href="https://docs.litellm.ai/docs/proxy/user_keys" class="link-with-arrow">
|
||||
Make your first LLM request →
|
||||
<span class="arrow"></span>
|
||||
</a>
|
||||
|
||||
<p>Making LLM requests with OpenAI SDK, Langchain, LlamaIndex, and more.</p>
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
<a href="https://docs.litellm.ai/docs/supported_endpoints" class="link-with-arrow">
|
||||
Supported Endpoints →
|
||||
<span class="arrow"></span>
|
||||
</a>
|
||||
|
||||
<p>View all supported LLM endpoints on LiteLLM (/chat/completions, /embeddings, /responses etc.)</p>
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
<a href="https://docs.litellm.ai/docs/pass_through/vertex_ai" class="link-with-arrow">
|
||||
Passthrough Endpoints →
|
||||
<span class="arrow"></span>
|
||||
</a>
|
||||
|
||||
<p>We support calling VertexAI, Anthropic, and other providers in their native API format.</p>
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
<p>Thanks for signing up. We're here to help you and your team. If you have any questions, contact us at {email_support_contact}</p>
|
||||
|
||||
</div>
|
||||
{email_footer}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Database access helpers for Focus export."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
|
||||
class FocusLiteLLMDatabase:
    """Read-side gateway to the proxy database for Focus export jobs."""

    def _ensure_prisma_client(self):
        """Return the proxy's Prisma client.

        Raises:
            RuntimeError: when the proxy server exposes no database client.
        """
        # Resolved at call time from the proxy server module.
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is not None:
            return prisma_client
        raise RuntimeError(
            "Database not connected. Connect a database to your proxy - "
            "https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
        )

    async def get_usage_data(
        self,
        *,
        limit: Optional[int] = None,
        start_time_utc: Optional[datetime] = None,
        end_time_utc: Optional[datetime] = None,
    ) -> pl.DataFrame:
        """Fetch daily-spend rows joined with key, team and user metadata.

        Args:
            limit: Maximum number of rows to return; ``None`` means no LIMIT.
            start_time_utc: Inclusive lower bound on ``dus.updated_at``.
            end_time_utc: Inclusive upper bound on ``dus.updated_at``.

        Returns:
            A polars DataFrame built from the raw query result.

        Raises:
            ValueError: if ``limit`` is not convertible to int or is negative.
            RuntimeError: if no database is connected, or the query or frame
                construction fails.
        """
        prisma = self._ensure_prisma_client()

        # Positional bind values; a value's placeholder number ($1, $2, ...)
        # is its 1-based position in this list.
        bind_values: list[Any] = []
        conditions: list[str] = []
        for comparator, bound in ((">=", start_time_utc), ("<=", end_time_utc)):
            if not bound:
                continue
            bind_values.append(bound)
            conditions.append(
                f"dus.updated_at {comparator} ${len(bind_values)}::timestamptz"
            )

        where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""

        limit_clause = ""
        if limit is not None:
            try:
                row_cap = int(limit)
            except (TypeError, ValueError) as exc:  # pragma: no cover - defensive guard
                raise ValueError("limit must be an integer") from exc
            if row_cap < 0:
                raise ValueError("limit must be non-negative")
            bind_values.append(row_cap)
            limit_clause = f" LIMIT ${len(bind_values)}"

        query = f"""
        SELECT
            dus.id,
            dus.date,
            dus.user_id,
            dus.api_key,
            dus.model,
            dus.model_group,
            dus.custom_llm_provider,
            dus.prompt_tokens,
            dus.completion_tokens,
            dus.spend,
            dus.api_requests,
            dus.successful_requests,
            dus.failed_requests,
            dus.cache_creation_input_tokens,
            dus.cache_read_input_tokens,
            dus.created_at,
            dus.updated_at,
            vt.team_id,
            vt.key_alias as api_key_alias,
            tt.team_alias,
            ut.user_email as user_email
        FROM "LiteLLM_DailyUserSpend" dus
        LEFT JOIN "LiteLLM_VerificationToken" vt ON dus.api_key = vt.token
        LEFT JOIN "LiteLLM_TeamTable" tt ON vt.team_id = tt.team_id
        LEFT JOIN "LiteLLM_UserTable" ut ON dus.user_id = ut.user_id
        {where_clause}
        ORDER BY dus.date DESC, dus.created_at DESC
        {limit_clause}
        """

        try:
            rows = await prisma.db.query_raw(query, *bind_values)
            # infer_schema_length=None asks polars to scan every row when
            # inferring column types.
            return pl.DataFrame(rows, infer_schema_length=None)
        except Exception as exc:
            raise RuntimeError(f"Error retrieving usage data: {exc}") from exc

    async def get_table_info(self) -> Dict[str, Any]:
        """Describe the columns of the daily spend table.

        Returns:
            ``{"columns": <information_schema rows>, "table_name": <name>}``.

        Raises:
            RuntimeError: if no database is connected or the query fails.
        """
        prisma = self._ensure_prisma_client()

        info_query = """
        SELECT column_name, data_type, is_nullable
        FROM information_schema.columns
        WHERE table_name = 'LiteLLM_DailyUserSpend'
        ORDER BY ordinal_position;
        """
        try:
            column_rows = await prisma.db.query_raw(info_query)
        except Exception as exc:
            raise RuntimeError(f"Error getting table info: {exc}") from exc
        return {"columns": column_rows, "table_name": "LiteLLM_DailyUserSpend"}
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Destination implementations for Focus export."""
|
||||
|
||||
from .base import FocusDestination, FocusTimeWindow
|
||||
from .factory import FocusDestinationFactory
|
||||
from .s3_destination import FocusS3Destination
|
||||
|
||||
__all__ = [
|
||||
"FocusDestination",
|
||||
"FocusDestinationFactory",
|
||||
"FocusTimeWindow",
|
||||
"FocusS3Destination",
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Abstract destination interfaces for Focus export."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FocusTimeWindow:
    """Represents the span of data exported in a single batch.

    Instances are immutable (``frozen=True``).
    """

    # Start of the exported span.
    start_time: datetime
    # End of the exported span.
    end_time: datetime
    # Export cadence label; the S3 destination adds an ``hour=`` partition
    # when this equals "hourly".
    frequency: str
|
||||
|
||||
|
||||
class FocusDestination(Protocol):
    """Protocol for anything that can receive Focus export files."""

    async def deliver(
        self,
        *,
        content: bytes,
        time_window: FocusTimeWindow,
        filename: str,
    ) -> None:
        """Persist the serialized export for the provided time window.

        Args:
            content: Serialized export payload.
            time_window: Window the payload covers.
            filename: File name to store the payload under.
        """
        ...
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Factory helpers for Focus export destinations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import FocusDestination
|
||||
from .s3_destination import FocusS3Destination
|
||||
|
||||
|
||||
class FocusDestinationFactory:
    """Turns a provider name plus optional settings into a destination object."""

    @staticmethod
    def create(
        *,
        provider: str,
        prefix: str,
        config: Optional[Dict[str, Any]] = None,
    ) -> FocusDestination:
        """Build the destination for ``provider``.

        Args:
            provider: Destination provider name; matched case-insensitively.
            prefix: Key/path prefix handed to the destination.
            config: Explicit settings that take precedence over environment
                variables.

        Raises:
            NotImplementedError: if the provider is not supported.
            ValueError: if required settings cannot be resolved.
        """
        provider_key = provider.lower()
        explicit = config or {}
        settings = FocusDestinationFactory._resolve_config(
            provider=provider_key, overrides=explicit
        )
        if provider_key != "s3":
            raise NotImplementedError(
                f"Provider '{provider}' not supported for Focus export"
            )
        return FocusS3Destination(prefix=prefix, config=settings)

    @staticmethod
    def _resolve_config(
        *,
        provider: str,
        overrides: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Merge explicit overrides with environment fallbacks for ``provider``.

        A falsy override falls back to the environment variable. Entries that
        resolve to ``None`` are omitted from the result.

        Raises:
            NotImplementedError: if the provider is not supported.
            ValueError: if no bucket name can be resolved for S3.
        """
        if provider != "s3":
            raise NotImplementedError(
                f"Provider '{provider}' not supported for Focus export configuration"
            )

        # (setting name, environment variable used when no override is given)
        env_fallbacks = (
            ("bucket_name", "FOCUS_S3_BUCKET_NAME"),
            ("region_name", "FOCUS_S3_REGION_NAME"),
            ("endpoint_url", "FOCUS_S3_ENDPOINT_URL"),
            ("aws_access_key_id", "FOCUS_S3_ACCESS_KEY"),
            ("aws_secret_access_key", "FOCUS_S3_SECRET_KEY"),
            ("aws_session_token", "FOCUS_S3_SESSION_TOKEN"),
        )
        merged = {
            setting: overrides.get(setting) or os.getenv(env_name)
            for setting, env_name in env_fallbacks
        }
        if not merged.get("bucket_name"):
            raise ValueError("FOCUS_S3_BUCKET_NAME must be provided for S3 exports")
        return {
            setting: value for setting, value in merged.items() if value is not None
        }
|
||||
@@ -0,0 +1,74 @@
|
||||
"""S3 destination implementation for Focus export."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
|
||||
from .base import FocusDestination, FocusTimeWindow
|
||||
|
||||
|
||||
class FocusS3Destination(FocusDestination):
    """Handles uploading serialized exports to S3 buckets."""

    def __init__(
        self,
        *,
        prefix: str,
        config: Optional[dict[str, Any]] = None,
    ) -> None:
        """Store bucket, prefix and connection settings.

        Args:
            prefix: Key prefix for uploaded objects; trailing slashes are
                stripped.
            config: Settings dict. ``bucket_name`` is required; ``region_name``,
                ``endpoint_url`` and the ``aws_*`` credential keys are optional.

        Raises:
            ValueError: if ``bucket_name`` is missing or empty.
        """
        config = config or {}
        bucket_name = config.get("bucket_name")
        if not bucket_name:
            raise ValueError("bucket_name must be provided for S3 destination")
        self.bucket_name = bucket_name
        self.prefix = prefix.rstrip("/")
        self.config = config

    async def deliver(
        self,
        *,
        content: bytes,
        time_window: FocusTimeWindow,
        filename: str,
    ) -> None:
        """Upload ``content`` under a key derived from ``time_window``.

        The upload itself is synchronous, so it is dispatched through
        ``asyncio.to_thread``.
        """
        object_key = self._build_object_key(time_window=time_window, filename=filename)
        await asyncio.to_thread(self._upload, content, object_key)

    def _build_object_key(self, *, time_window: FocusTimeWindow, filename: str) -> str:
        """Return ``<prefix>/date=YYYY-MM-DD[/hour=HH]/<filename>``.

        The date (and hour, for "hourly" windows) come from the window start
        expressed in UTC. A naive ``start_time`` is interpreted as UTC:
        ``datetime.astimezone`` would otherwise treat it as host-local time
        and make the partition depend on the machine's timezone.
        """
        start = time_window.start_time
        if start.tzinfo is None or start.utcoffset() is None:
            # Naive datetime: attach UTC rather than converting from local time.
            start_utc = start.replace(tzinfo=timezone.utc)
        else:
            start_utc = start.astimezone(timezone.utc)

        date_component = f"date={start_utc.strftime('%Y-%m-%d')}"
        parts = [self.prefix, date_component]
        if time_window.frequency == "hourly":
            parts.append(f"hour={start_utc.strftime('%H')}")
        # filter(None, ...) drops an empty prefix so the key has no leading "/".
        key_prefix = "/".join(filter(None, parts))
        return f"{key_prefix}/{filename}" if key_prefix else filename

    def _upload(self, content: bytes, object_key: str) -> None:
        """Create an S3 client from the stored config and put the object.

        Only settings with truthy values are forwarded to ``boto3.client``.
        A new client is created on every call.
        """
        client_kwargs: dict[str, Any] = {}
        region_name = self.config.get("region_name")
        if region_name:
            client_kwargs["region_name"] = region_name
        endpoint_url = self.config.get("endpoint_url")
        if endpoint_url:
            client_kwargs["endpoint_url"] = endpoint_url

        session_kwargs: dict[str, Any] = {}
        for key in (
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_session_token",
        ):
            if self.config.get(key):
                session_kwargs[key] = self.config[key]

        s3_client = boto3.client("s3", **client_kwargs, **session_kwargs)
        s3_client.put_object(
            Bucket=self.bucket_name,
            Key=object_key,
            Body=content,
            ContentType="application/octet-stream",
        )
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Core export engine for Focus integrations (heavy dependencies)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
from .database import FocusLiteLLMDatabase
|
||||
from .destinations import FocusDestinationFactory, FocusTimeWindow
|
||||
from .serializers import FocusParquetSerializer, FocusSerializer
|
||||
from .transformer import FocusTransformer
|
||||
|
||||
|
||||
class FocusExportEngine:
|
||||
"""Engine that fetches, normalizes, and uploads Focus exports."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
export_format: str,
|
||||
prefix: str,
|
||||
destination_config: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.export_format = export_format
|
||||
self.prefix = prefix
|
||||
self._destination = FocusDestinationFactory.create(
|
||||
provider=self.provider,
|
||||
prefix=self.prefix,
|
||||
config=destination_config,
|
||||
)
|
||||
self._serializer = self._init_serializer()
|
||||
self._transformer = FocusTransformer()
|
||||
self._database = FocusLiteLLMDatabase()
|
||||
|
||||
def _init_serializer(self) -> FocusSerializer:
|
||||
if self.export_format != "parquet":
|
||||
raise NotImplementedError("Only parquet export supported currently")
|
||||
return FocusParquetSerializer()
|
||||
|
||||
async def dry_run_export_usage_data(self, limit: Optional[int]) -> Dict[str, Any]:
|
||||
data = await self._database.get_usage_data(limit=limit)
|
||||
normalized = self._transformer.transform(data)
|
||||
|
||||
usage_sample = data.head(min(50, len(data))).to_dicts()
|
||||
normalized_sample = normalized.head(min(50, len(normalized))).to_dicts()
|
||||
|
||||
summary = {
|
||||
"total_records": len(normalized),
|
||||
"total_spend": self._sum_column(normalized, "spend"),
|
||||
"total_tokens": self._sum_column(normalized, "total_tokens"),
|
||||
"unique_teams": self._count_unique(normalized, "team_id"),
|
||||
"unique_models": self._count_unique(normalized, "model"),
|
||||
}
|
||||
|
||||
return {
|
||||
"usage_data": usage_sample,
|
||||
"normalized_data": normalized_sample,
|
||||
"summary": summary,
|
||||
}
|
||||
|
||||
async def export_window(
|
||||
self,
|
||||
*,
|
||||
window: FocusTimeWindow,
|
||||
limit: Optional[int],
|
||||
) -> None:
|
||||
data = await self._database.get_usage_data(
|
||||
limit=limit,
|
||||
start_time_utc=window.start_time,
|
||||
end_time_utc=window.end_time,
|
||||
)
|
||||
if data.is_empty():
|
||||
verbose_logger.debug("Focus export: no usage data for window %s", window)
|
||||
return
|
||||
|
||||
normalized = self._transformer.transform(data)
|
||||
if normalized.is_empty():
|
||||
verbose_logger.debug(
|
||||
"Focus export: normalized data empty for window %s", window
|
||||
)
|
||||
return
|
||||
|
||||
await self._serialize_and_upload(normalized, window)
|
||||
|
||||
async def _serialize_and_upload(
|
||||
self, frame: pl.DataFrame, window: FocusTimeWindow
|
||||
) -> None:
|
||||
payload = self._serializer.serialize(frame)
|
||||
if not payload:
|
||||
verbose_logger.debug("Focus export: serializer returned empty payload")
|
||||
return
|
||||
await self._destination.deliver(
|
||||
content=payload,
|
||||
time_window=window,
|
||||
filename=self._build_filename(),
|
||||
)
|
||||
|
||||
def _build_filename(self) -> str:
|
||||
if not self._serializer.extension:
|
||||
raise ValueError("Serializer must declare a file extension")
|
||||
return f"usage.{self._serializer.extension}"
|
||||
|
||||
@staticmethod
def _sum_column(frame: pl.DataFrame, column: str) -> float:
    """Sum ``column`` of ``frame`` as a float.

    Returns ``0.0`` when the frame is empty, the column is absent, or the
    aggregate comes back as ``None``.
    """
    if column not in frame.columns or frame.is_empty():
        return 0.0

    total = frame.select(pl.col(column).sum().alias("sum")).row(0)[0]
    return 0.0 if total is None else float(total)
|
||||
|
||||
@staticmethod
def _count_unique(frame: pl.DataFrame, column: str) -> int:
    """Count the distinct values in ``column`` of ``frame``.

    Returns ``0`` when the frame is empty, the column is absent, or the
    aggregate comes back as ``None``.
    """
    if column not in frame.columns or frame.is_empty():
        return 0

    distinct = frame.select(pl.col(column).n_unique().alias("unique")).row(0)[0]
    return 0 if distinct is None else int(distinct)
|
||||
@@ -0,0 +1,214 @@
|
||||
"""Focus export logger orchestrating DB pull/transform/upload."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
from .destinations import FocusTimeWindow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from .export_engine import FocusExportEngine
|
||||
else:
|
||||
AsyncIOScheduler = Any
|
||||
|
||||
FOCUS_USAGE_DATA_JOB_NAME = "focus_export_usage_data"
|
||||
DEFAULT_DRY_RUN_LIMIT = 500
|
||||
|
||||
|
||||
class FocusLogger(CustomLogger):
    """Schedules and triggers Focus exports, delegating the work to the export engine."""

    def __init__(
        self,
        *,
        provider: Optional[str] = None,
        export_format: Optional[str] = None,
        frequency: Optional[str] = None,
        cron_offset_minute: Optional[int] = None,
        interval_seconds: Optional[int] = None,
        prefix: Optional[str] = None,
        destination_config: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Resolve configuration from arguments, then environment, then defaults.

        Environment fallbacks: ``FOCUS_PROVIDER`` (default ``s3``),
        ``FOCUS_FORMAT`` (``parquet``), ``FOCUS_FREQUENCY`` (``hourly``),
        ``FOCUS_CRON_OFFSET`` (``5``), ``FOCUS_INTERVAL_SECONDS`` (unset) and
        ``FOCUS_PREFIX`` (``focus_exports``). Remaining keyword arguments are
        forwarded to ``CustomLogger``.
        """
        super().__init__(**kwargs)

        provider_name = provider or os.getenv("FOCUS_PROVIDER") or "s3"
        self.provider = provider_name.lower()

        format_name = export_format or os.getenv("FOCUS_FORMAT") or "parquet"
        self.export_format = format_name.lower()

        frequency_name = frequency or os.getenv("FOCUS_FREQUENCY") or "hourly"
        self.frequency = frequency_name.lower()

        # Only consult the environment when no explicit offset was passed.
        if cron_offset_minute is None:
            cron_offset_minute = int(os.getenv("FOCUS_CRON_OFFSET", "5"))
        self.cron_offset_minute = cron_offset_minute

        interval_source = interval_seconds
        if interval_source is None:
            interval_source = os.getenv("FOCUS_INTERVAL_SECONDS")
        self.interval_seconds = (
            None if interval_source is None else int(interval_source)
        )

        # An explicit prefix (even an empty string) wins over the environment.
        if prefix is None:
            prefix = os.getenv("FOCUS_PREFIX") or "focus_exports"
        self.prefix: str = prefix

        self._destination_config = destination_config
        self._engine: Optional["FocusExportEngine"] = None

    def _ensure_engine(self) -> "FocusExportEngine":
        """Return the export engine, building it on first use."""
        engine = self._engine
        if engine is not None:
            return engine

        # Imported here rather than at module import time.
        from .export_engine import FocusExportEngine

        engine = FocusExportEngine(
            provider=self.provider,
            export_format=self.export_format,
            prefix=self.prefix,
            destination_config=self._destination_config,
        )
        self._engine = engine
        return engine

    async def export_usage_data(
        self,
        *,
        limit: Optional[int] = None,
        start_time_utc: Optional[datetime] = None,
        end_time_utc: Optional[datetime] = None,
    ) -> None:
        """Run an export right now.

        With both bounds given, exactly that window is exported; with
        neither, the window is derived from the configured frequency.

        Raises:
            ValueError: If only one of the two bounds is supplied.
        """
        has_start = bool(start_time_utc)
        has_end = bool(end_time_utc)
        if has_start != has_end:
            raise ValueError(
                "start_time_utc and end_time_utc must be provided together"
            )

        if has_start:
            window = FocusTimeWindow(
                start_time=start_time_utc,
                end_time=end_time_utc,
                frequency=self.frequency,
            )
        else:
            window = self._compute_time_window(datetime.now(timezone.utc))

        await self._export_window(window=window, limit=limit)

    async def dry_run_export_usage_data(
        self, limit: Optional[int] = DEFAULT_DRY_RUN_LIMIT
    ) -> dict[str, Any]:
        """Delegate to the engine's dry run and return its result."""
        return await self._ensure_engine().dry_run_export_usage_data(limit=limit)

    async def initialize_focus_export_job(self) -> None:
        """Scheduler entry point: run one export cycle, under a pod lock if available.

        The lock is only used when a pod lock manager with a truthy
        ``redis_cache`` can be found on the proxy logging object; otherwise
        the export runs unguarded.
        """
        from litellm.proxy.proxy_server import proxy_logging_obj

        spend_writer = (
            None
            if proxy_logging_obj is None
            else getattr(proxy_logging_obj, "db_spend_update_writer", None)
        )
        lock_manager = (
            None
            if spend_writer is None
            else getattr(spend_writer, "pod_lock_manager", None)
        )

        if not (lock_manager and lock_manager.redis_cache):
            await self._run_scheduled_export()
            return

        got_lock = await lock_manager.acquire_lock(
            cronjob_id=FOCUS_USAGE_DATA_JOB_NAME
        )
        if not got_lock:
            verbose_logger.debug("Focus export: unable to acquire pod lock")
            return

        try:
            await self._run_scheduled_export()
        finally:
            # Always hand the lock back, even when the export raised.
            await lock_manager.release_lock(cronjob_id=FOCUS_USAGE_DATA_JOB_NAME)

    @staticmethod
    async def init_focus_export_background_job(
        scheduler: AsyncIOScheduler,
    ) -> None:
        """Add the export job for the first registered FocusLogger to ``scheduler``.

        Logs at debug level and does nothing when no FocusLogger is registered.
        """
        registered: List[
            CustomLogger
        ] = litellm.logging_callback_manager.get_custom_loggers_for_type(
            callback_type=FocusLogger
        )
        if registered:
            active_logger = cast(FocusLogger, registered[0])
            job_options = active_logger._build_scheduler_trigger()
            scheduler.add_job(
                active_logger.initialize_focus_export_job,
                **job_options,
            )
            return

        verbose_logger.debug("No Focus export logger registered; skipping scheduler")

    def _build_scheduler_trigger(self) -> Dict[str, Any]:
        """Translate the configured frequency into ``add_job`` keyword arguments.

        Raises:
            ValueError: If the frequency is not interval/hourly/daily.
        """
        mode = self.frequency

        if mode == "interval":
            return {"trigger": "interval", "seconds": self.interval_seconds or 60}

        if mode == "hourly":
            # Offset is a minute-of-hour, clamped into 0..59.
            return {
                "trigger": "cron",
                "minute": max(0, min(59, self.cron_offset_minute)),
                "second": 0,
            }

        if mode == "daily":
            # Offset is minutes after midnight, split into hour and minute.
            whole_hours, leftover_minutes = divmod(
                max(0, self.cron_offset_minute), 60
            )
            return {
                "trigger": "cron",
                "hour": min(23, whole_hours),
                "minute": min(59, leftover_minutes),
                "second": 0,
            }

        raise ValueError(f"Unsupported frequency: {self.frequency}")

    async def _run_scheduled_export(self) -> None:
        """Export the window derived from the current UTC time, with no row limit."""
        current_window = self._compute_time_window(datetime.now(timezone.utc))
        await self._export_window(window=current_window, limit=None)

    async def _export_window(
        self,
        *,
        window: FocusTimeWindow,
        limit: Optional[int],
    ) -> None:
        """Forward ``window`` and ``limit`` to the engine's ``export_window``."""
        await self._ensure_engine().export_window(window=window, limit=limit)

    def _compute_time_window(self, now: datetime) -> FocusTimeWindow:
        """Build the export window for ``now`` according to the frequency.

        ``hourly`` and ``daily`` yield the most recently completed hour/day
        (UTC); ``interval`` yields the trailing ``interval_seconds`` (60 when
        unset or zero) ending at ``now``.

        Raises:
            ValueError: If the frequency is not interval/hourly/daily.
        """
        now_utc = now.astimezone(timezone.utc)
        mode = self.frequency

        if mode == "interval":
            end_time = now_utc
            start_time = end_time - timedelta(seconds=self.interval_seconds or 60)
        elif mode == "hourly":
            end_time = now_utc.replace(minute=0, second=0, microsecond=0)
            start_time = end_time - timedelta(hours=1)
        elif mode == "daily":
            end_time = now_utc.replace(hour=0, minute=0, second=0, microsecond=0)
            start_time = end_time - timedelta(days=1)
        else:
            raise ValueError(f"Unsupported frequency: {self.frequency}")

        return FocusTimeWindow(
            start_time=start_time,
            end_time=end_time,
            frequency=self.frequency,
        )
|
||||
|
||||
|
||||
__all__ = ["FocusLogger"]
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Schema definitions for Focus export data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import polars as pl
|
||||
|
||||
# see: https://focus.finops.org/focus-specification/v1-2/
|
||||
# Column name -> dtype, in output order. Monetary and quantity columns are
# Decimal(18, 6), period boundaries are microsecond datetimes, everything
# else except Tags is a string.
FOCUS_NORMALIZED_SCHEMA = pl.Schema(
    {
        "BilledCost": pl.Decimal(18, 6),
        "BillingAccountId": pl.String,
        "BillingAccountName": pl.String,
        "BillingCurrency": pl.String,
        "BillingPeriodStart": pl.Datetime(time_unit="us"),
        "BillingPeriodEnd": pl.Datetime(time_unit="us"),
        "ChargeCategory": pl.String,
        "ChargeClass": pl.String,
        "ChargeDescription": pl.String,
        "ChargeFrequency": pl.String,
        "ChargePeriodStart": pl.Datetime(time_unit="us"),
        "ChargePeriodEnd": pl.Datetime(time_unit="us"),
        "ConsumedQuantity": pl.Decimal(18, 6),
        "ConsumedUnit": pl.String,
        "ContractedCost": pl.Decimal(18, 6),
        "ContractedUnitPrice": pl.Decimal(18, 6),
        "EffectiveCost": pl.Decimal(18, 6),
        "InvoiceIssuerName": pl.String,
        "ListCost": pl.Decimal(18, 6),
        "ListUnitPrice": pl.Decimal(18, 6),
        "PricingCategory": pl.String,
        "PricingQuantity": pl.Decimal(18, 6),
        "PricingUnit": pl.String,
        "ProviderName": pl.String,
        "PublisherName": pl.String,
        "RegionId": pl.String,
        "RegionName": pl.String,
        "ResourceId": pl.String,
        "ResourceName": pl.String,
        "ResourceType": pl.String,
        "ServiceCategory": pl.String,
        "ServiceSubcategory": pl.String,
        "ServiceName": pl.String,
        "SubAccountId": pl.String,
        "SubAccountName": pl.String,
        "SubAccountType": pl.String,
        "Tags": pl.Object,
    }
)
|
||||
|
||||
__all__ = ["FOCUS_NORMALIZED_SCHEMA"]
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Serializer package exports for Focus integration."""
|
||||
|
||||
from .base import FocusSerializer
|
||||
from .parquet import FocusParquetSerializer
|
||||
|
||||
__all__ = ["FocusSerializer", "FocusParquetSerializer"]
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Serializer abstractions for Focus export."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import polars as pl
|
||||
|
||||
|
||||
class FocusSerializer(ABC):
    """Abstract interface: encode a Focus frame into a bytes payload."""

    # File extension for the produced payload. Empty here; concrete
    # serializers override it (for example with "parquet").
    extension: str = ""

    @abstractmethod
    def serialize(self, frame: pl.DataFrame) -> bytes:
        """Encode ``frame`` in the serializer's format and return the bytes.

        Subclasses must override this; the base implementation only raises.
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Parquet serializer for Focus export."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
|
||||
import polars as pl
|
||||
|
||||
from .base import FocusSerializer
|
||||
|
||||
|
||||
class FocusParquetSerializer(FocusSerializer):
    """Writes Focus frames as snappy-compressed Parquet held in memory."""

    extension = "parquet"

    def serialize(self, frame: pl.DataFrame) -> bytes:
        """Return ``frame`` encoded as Parquet bytes.

        An empty frame is replaced by a fresh empty frame built from the
        same schema before writing.
        """
        if frame.is_empty():
            frame_to_write = pl.DataFrame(schema=frame.schema)
        else:
            frame_to_write = frame

        sink = io.BytesIO()
        frame_to_write.write_parquet(sink, compression="snappy")
        return sink.getvalue()
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Focus export data transformer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import polars as pl
|
||||
|
||||
from .schema import FOCUS_NORMALIZED_SCHEMA
|
||||
|
||||
|
||||
class FocusTransformer:
    """Transforms LiteLLM DB rows into Focus-compatible schema."""

    schema = FOCUS_NORMALIZED_SCHEMA

    def transform(self, frame: pl.DataFrame) -> pl.DataFrame:
        """Map a usage frame onto Focus columns.

        Args:
            frame: Usage rows. The columns read here are ``date``, ``spend``,
                ``api_key``, ``api_key_alias``, ``model``, ``model_group``,
                ``custom_llm_provider``, ``team_id`` and ``team_alias``.

        Returns:
            For an empty input, an empty frame with ``self.schema``.
            Otherwise one output row per input row. Period columns are
            emitted as ``%Y-%m-%dT%H:%M:%SZ`` strings, costs and quantities
            as ``Decimal(18, 6)``.

        NOTE(review): the non-empty output carries columns that
        ``self.schema`` does not list (e.g. ``BillingAccountType``,
        ``InvoiceId``) and string-typed periods where the schema declares
        datetimes - confirm which of the two is the intended contract.
        """
        if frame.is_empty():
            return pl.DataFrame(schema=self.schema)

        # Derive the charge period from the usage date; with strict=False an
        # unparsable date becomes null instead of raising.
        frame = frame.with_columns(
            pl.col("date")
            .cast(pl.Utf8)
            .str.strptime(pl.Datetime(time_unit="us"), format="%Y-%m-%d", strict=False)
            .alias("usage_date"),
        )
        frame = frame.with_columns(
            pl.col("usage_date").alias("ChargePeriodStart"),
            (pl.col("usage_date") + timedelta(days=1)).alias("ChargePeriodEnd"),
        )

        def fmt(col):
            return col.dt.strftime("%Y-%m-%dT%H:%M:%SZ")

        DEC = pl.Decimal(18, 6)

        def dec(col):
            return col.cast(DEC)

        none_str = pl.lit(None, dtype=pl.Utf8)
        none_dec = pl.lit(None, dtype=DEC)

        return frame.select(
            dec(pl.col("spend").fill_null(0.0)).alias("BilledCost"),
            pl.col("api_key").cast(pl.String).alias("BillingAccountId"),
            pl.col("api_key_alias").cast(pl.String).alias("BillingAccountName"),
            pl.lit("API Key").alias("BillingAccountType"),
            pl.lit("USD").alias("BillingCurrency"),
            fmt(pl.col("ChargePeriodEnd")).alias("BillingPeriodEnd"),
            fmt(pl.col("ChargePeriodStart")).alias("BillingPeriodStart"),
            pl.lit("Usage").alias("ChargeCategory"),
            none_str.alias("ChargeClass"),
            pl.col("model").cast(pl.String).alias("ChargeDescription"),
            pl.lit("Usage-Based").alias("ChargeFrequency"),
            fmt(pl.col("ChargePeriodEnd")).alias("ChargePeriodEnd"),
            fmt(pl.col("ChargePeriodStart")).alias("ChargePeriodStart"),
            dec(pl.lit(1.0)).alias("ConsumedQuantity"),
            pl.lit("Requests").alias("ConsumedUnit"),
            dec(pl.col("spend").fill_null(0.0)).alias("ContractedCost"),
            # Unit prices are decimals in the declared schema, so the null
            # placeholder must be decimal-typed (it used to be a null string,
            # unlike ListUnitPrice below).
            none_dec.alias("ContractedUnitPrice"),
            dec(pl.col("spend").fill_null(0.0)).alias("EffectiveCost"),
            pl.col("custom_llm_provider").cast(pl.String).alias("InvoiceIssuerName"),
            none_str.alias("InvoiceId"),
            dec(pl.col("spend").fill_null(0.0)).alias("ListCost"),
            none_dec.alias("ListUnitPrice"),
            none_str.alias("AvailabilityZone"),
            pl.lit("USD").alias("PricingCurrency"),
            none_str.alias("PricingCategory"),
            dec(pl.lit(1.0)).alias("PricingQuantity"),
            none_dec.alias("PricingCurrencyContractedUnitPrice"),
            dec(pl.col("spend").fill_null(0.0)).alias("PricingCurrencyEffectiveCost"),
            none_dec.alias("PricingCurrencyListUnitPrice"),
            pl.lit("Requests").alias("PricingUnit"),
            pl.col("custom_llm_provider").cast(pl.String).alias("ProviderName"),
            pl.col("custom_llm_provider").cast(pl.String).alias("PublisherName"),
            none_str.alias("RegionId"),
            none_str.alias("RegionName"),
            pl.col("model").cast(pl.String).alias("ResourceId"),
            pl.col("model").cast(pl.String).alias("ResourceName"),
            pl.col("model").cast(pl.String).alias("ResourceType"),
            pl.lit("AI and Machine Learning").alias("ServiceCategory"),
            pl.lit("Generative AI").alias("ServiceSubcategory"),
            pl.col("model_group").cast(pl.String).alias("ServiceName"),
            pl.col("team_id").cast(pl.String).alias("SubAccountId"),
            pl.col("team_alias").cast(pl.String).alias("SubAccountName"),
            none_str.alias("SubAccountType"),
            none_str.alias("Tags"),
        )
|
||||
@@ -0,0 +1,157 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
|
||||
# from here: https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#structuring-your-records
|
||||
class LLMResponse(BaseModel):
    """One record in the payload sent to Galileo Observe's ingest endpoint."""

    # Request timing and outcome.
    latency_ms: int
    status_code: int

    # Prompt and completion text, plus the call type and model name.
    input_text: str
    output_text: str
    node_type: str
    model: str

    # Token usage.
    num_input_tokens: int
    num_output_tokens: int

    output_logprobs: Optional[Dict[str, Any]] = Field(
        None,
        description="Optional. When available, logprobs are used to compute Uncertainty.",
    )
    created_at: str = Field(
        ...,
        description='timestamp constructed in "%Y-%m-%dT%H:%M:%S" format',
    )

    # Optional free-form annotations.
    tags: Optional[List[str]] = None
    user_metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GalileoObserve(CustomLogger):
    """Logger that buffers successful LLM calls and posts them to Galileo Observe.

    Configuration is read from ``GALILEO_BASE_URL`` and ``GALILEO_PROJECT_ID``;
    ``set_galileo_headers`` additionally reads ``GALILEO_USERNAME`` and
    ``GALILEO_PASSWORD``.
    """

    def __init__(self) -> None:
        # Initialise the CustomLogger base as the other loggers in this
        # package do; previously the base initializer was skipped entirely.
        super().__init__()
        self.in_memory_records: List[dict] = []
        # With a batch size of 1 every logged record triggers a flush.
        self.batch_size = 1
        self.base_url = os.getenv("GALILEO_BASE_URL", None)
        self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
        # NOTE(review): stays None until set_galileo_headers() is called;
        # nothing in this class calls it - confirm the caller does.
        self.headers: Optional[Dict[str, str]] = None
        self.async_httpx_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

    def set_galileo_headers(self):
        """Log in to Galileo and store bearer-token headers on ``self.headers``.

        Performs a synchronous POST to ``{base_url}/login``.
        """
        # following https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#logging-your-records

        headers = {
            "accept": "application/json",
            "Content-Type": "application/x-www-form-urlencoded",
        }
        galileo_login_response = litellm.module_level_client.post(
            url=f"{self.base_url}/login",
            headers=headers,
            data={
                "username": os.getenv("GALILEO_USERNAME"),
                "password": os.getenv("GALILEO_PASSWORD"),
            },
        )

        access_token = galileo_login_response.json()["access_token"]

        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {access_token}",
        }

    def get_output_str_from_response(self, response_obj, kwargs):
        """Extract the loggable output from a response, or ``None``.

        Returns ``None`` for a missing response, for embedding calls and for
        unrecognised response types. Chat responses yield the first choice's
        message as JSON, text completions the first choice's text, and image
        responses the ``data`` entry (not a string).
        """
        if response_obj is None:
            return None

        if kwargs.get("call_type", None) == "embedding" or isinstance(
            response_obj, litellm.EmbeddingResponse
        ):
            return None
        if isinstance(response_obj, litellm.ModelResponse):
            return response_obj["choices"][0]["message"].json()
        if isinstance(response_obj, litellm.TextCompletionResponse):
            return response_obj.choices[0].text
        if isinstance(response_obj, litellm.ImageResponse):
            return response_obj["data"]
        return None

    async def async_log_success_event(
        self, kwargs: Any, response_obj: Any, start_time: Any, end_time: Any
    ):
        """Record a successful call and flush once the batch size is reached.

        Calls without extractable output (see ``get_output_str_from_response``)
        are not recorded.
        """
        verbose_logger.debug("On Async Success")

        _latency_ms = int((end_time - start_time).total_seconds() * 1000)
        _call_type = kwargs.get("call_type", "litellm")
        input_text = litellm.utils.get_formatted_prompt(
            data=kwargs, call_type=_call_type
        )

        _usage = response_obj.get("usage", {}) or {}
        # A present-but-None count would fail the int fields of LLMResponse.
        num_input_tokens = _usage.get("prompt_tokens", 0) or 0
        num_output_tokens = _usage.get("completion_tokens", 0) or 0

        output_text = self.get_output_str_from_response(
            response_obj=response_obj, kwargs=kwargs
        )

        if output_text is not None:
            # Image responses produce a non-string output; LLMResponse declares
            # output_text as str, so convert instead of failing validation.
            if not isinstance(output_text, str):
                output_text = str(output_text)

            request_record = LLMResponse(
                latency_ms=_latency_ms,
                status_code=200,
                input_text=input_text,
                output_text=output_text,
                node_type=_call_type,
                model=kwargs.get("model", "-"),
                num_input_tokens=num_input_tokens,
                num_output_tokens=num_output_tokens,
                created_at=start_time.strftime(
                    "%Y-%m-%dT%H:%M:%S"
                ),  # timestamp str constructed in "%Y-%m-%dT%H:%M:%S" format
            )

            # dump to dict
            request_dict = request_record.model_dump()
            self.in_memory_records.append(request_dict)

            if len(self.in_memory_records) >= self.batch_size:
                await self.flush_in_memory_records()

    async def flush_in_memory_records(self):
        """POST the buffered records to the ingest endpoint.

        The buffer is cleared only on HTTP 200; on any other status the
        records stay buffered and the failure is logged at debug level.
        """
        verbose_logger.debug("flushing in memory records")
        response = await self.async_httpx_handler.post(
            url=f"{self.base_url}/projects/{self.project_id}/observe/ingest",
            headers=self.headers,
            json={"records": self.in_memory_records},
        )

        if response.status_code == 200:
            verbose_logger.debug(
                "Galileo Logger:successfully flushed in memory records"
            )
            self.in_memory_records = []
        else:
            verbose_logger.debug("Galileo Logger: failed to flush in memory records")
            verbose_logger.debug(
                "Galileo Logger error=%s, status code=%s",
                response.text,
                response.status_code,
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Failures are not sent to Galileo; only a debug line is emitted."""
        verbose_logger.debug("On Async Failure")
|
||||
@@ -0,0 +1,12 @@
|
||||
# GCS (Google Cloud Storage) Bucket Logging on LiteLLM Gateway
|
||||
|
||||
This folder contains the GCS Bucket Logging integration for LiteLLM Gateway.
|
||||
|
||||
## Folder Structure
|
||||
|
||||
- `gcs_bucket.py`: This is the main file that handles failure/success logging to GCS Bucket
|
||||
- `gcs_bucket_base.py`: This file contains the GCSBucketBase class which handles Authentication for GCS Buckets
|
||||
|
||||
## Further Reading
|
||||
- [Docs: setting up GCS Bucket Logging on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/observability/gcs_bucket_integration)
|
||||
- [Doc on Key / Team Based logging with GCS](https://docs.litellm.ai/docs/proxy/team_logging)
|
||||
@@ -0,0 +1,419 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from litellm._uuid import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
|
||||
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
from litellm.types.integrations.gcs_bucket import *
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
else:
|
||||
VertexBase = Any
|
||||
|
||||
|
||||
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
def __init__(self, bucket_name: Optional[str] = None) -> None:
    """Configure batching from the environment and start the periodic flush task.

    Reads ``GCS_BATCH_SIZE``, ``GCS_FLUSH_INTERVAL`` and
    ``GCS_USE_BATCHED_LOGGING`` (falling back to the ``GCS_DEFAULT_*``
    constants). Must be called while an event loop is running, because a
    background task is created.

    Args:
        bucket_name: Optional bucket name forwarded to the base class.

    Raises:
        ValueError: If the proxy is not running as a premium user.
    """
    from litellm.proxy.proxy_server import premium_user

    # Check entitlement first so that a rejected construction does not leave
    # a background flush task running for a half-built logger.
    if premium_user is not True:
        raise ValueError(
            f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
        )

    self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
    self.flush_interval = int(
        os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
    )
    self.use_batched_logging = (
        os.getenv(
            "GCS_USE_BATCHED_LOGGING", str(GCS_DEFAULT_USE_BATCHED_LOGGING).lower()
        ).lower()
        == "true"
    )
    self.flush_lock = asyncio.Lock()

    # Initialise the base class once with every argument. It used to be
    # initialised twice, and the second call omitted bucket_name, which
    # presumably discarded the explicitly passed bucket - TODO confirm
    # against GCSBucketBase.
    super().__init__(
        bucket_name=bucket_name,
        flush_lock=self.flush_lock,
        batch_size=self.batch_size,
        flush_interval=self.flush_interval,
    )
    self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue(  # type: ignore[assignment]
        maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE
    )
    # Keep a reference: the event loop only holds weak references to tasks,
    # so an unreferenced task may be garbage-collected before it finishes.
    self._periodic_flush_task = asyncio.create_task(self.periodic_flush())
    AdditionalLoggingUtils.__init__(self)
|
||||
|
||||
#### ASYNC ####
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """Queue the standard logging payload of a successful call.

    Errors while queueing are logged, not raised.

    Raises:
        ValueError: If the proxy is not running as a premium user (raised
            before the error-swallowing section).
    """
    from litellm.proxy.proxy_server import premium_user

    if premium_user is not True:
        raise ValueError(
            f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
        )

    try:
        verbose_logger.debug(
            "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
            kwargs,
            response_obj,
        )

        standard_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if standard_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")

        # A full queue is flushed first so the put below has room.
        if self.log_queue.full():
            await self.flush_queue()

        queue_item = GCSLogQueueItem(
            payload=standard_payload, kwargs=kwargs, response_obj=response_obj
        )
        await self.log_queue.put(queue_item)
    except Exception as e:
        verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """Queue the standard logging payload of a failed call.

    Errors while queueing (including a missing payload) are logged, not raised.
    """
    try:
        verbose_logger.debug(
            "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
            kwargs,
            response_obj,
        )

        standard_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if standard_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")

        # A full queue is flushed first so the put below has room.
        if self.log_queue.full():
            await self.flush_queue()

        queue_item = GCSLogQueueItem(
            payload=standard_payload, kwargs=kwargs, response_obj=response_obj
        )
        await self.log_queue.put(queue_item)
    except Exception as e:
        verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
|
||||
|
||||
def _drain_queue_batch(self) -> List[GCSLogQueueItem]:
|
||||
"""
|
||||
Drain items from the queue (non-blocking), respecting batch_size limit.
|
||||
|
||||
This prevents unbounded queue growth when processing is slower than log accumulation.
|
||||
|
||||
Returns:
|
||||
List of items to process, up to batch_size items
|
||||
"""
|
||||
items_to_process: List[GCSLogQueueItem] = []
|
||||
while len(items_to_process) < self.batch_size:
|
||||
try:
|
||||
items_to_process.append(self.log_queue.get_nowait())
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
return items_to_process
|
||||
|
||||
def _generate_batch_object_name(self, date_str: str, batch_id: str) -> str:
|
||||
"""
|
||||
Generate object name for a batched log file.
|
||||
Format: {date}/batch-{batch_id}.ndjson
|
||||
"""
|
||||
return f"{date_str}/batch-{batch_id}.ndjson"
|
||||
|
||||
def _get_config_key(self, kwargs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract a synchronous grouping key from kwargs to group items by GCS config.
|
||||
This allows us to batch items with the same bucket/credentials together.
|
||||
|
||||
Returns a string key that uniquely identifies the GCS config combination.
|
||||
This key may contain sensitive information (bucket names, paths) - use _sanitize_config_key()
|
||||
for logging purposes.
|
||||
"""
|
||||
standard_callback_dynamic_params = (
|
||||
kwargs.get("standard_callback_dynamic_params", None) or {}
|
||||
)
|
||||
|
||||
bucket_name = (
|
||||
standard_callback_dynamic_params.get("gcs_bucket_name", None)
|
||||
or self.BUCKET_NAME
|
||||
or "default"
|
||||
)
|
||||
path_service_account = (
|
||||
standard_callback_dynamic_params.get("gcs_path_service_account", None)
|
||||
or self.path_service_account_json
|
||||
or "default"
|
||||
)
|
||||
|
||||
return f"{bucket_name}|{path_service_account}"
|
||||
|
||||
def _sanitize_config_key(self, config_key: str) -> str:
|
||||
"""
|
||||
Create a sanitized version of the config key for logging.
|
||||
Uses a hash to avoid exposing sensitive bucket names or service account paths.
|
||||
|
||||
Returns a short hash prefix for safe logging.
|
||||
"""
|
||||
hash_obj = hashlib.sha256(config_key.encode("utf-8"))
|
||||
return f"config-{hash_obj.hexdigest()[:8]}"
|
||||
|
||||
def _group_items_by_config(
|
||||
self, items: List[GCSLogQueueItem]
|
||||
) -> Dict[str, List[GCSLogQueueItem]]:
|
||||
"""
|
||||
Group items by their GCS config (bucket + credentials).
|
||||
This ensures items with different configs are processed separately.
|
||||
|
||||
Returns a dict mapping config_key -> list of items with that config.
|
||||
"""
|
||||
grouped: Dict[str, List[GCSLogQueueItem]] = {}
|
||||
for item in items:
|
||||
config_key = self._get_config_key(item["kwargs"])
|
||||
if config_key not in grouped:
|
||||
grouped[config_key] = []
|
||||
grouped[config_key].append(item)
|
||||
return grouped
|
||||
|
||||
def _combine_payloads_to_ndjson(self, items: List[GCSLogQueueItem]) -> str:
|
||||
"""
|
||||
Combine multiple log payloads into newline-delimited JSON (NDJSON) format.
|
||||
Each line is a valid JSON object representing one log entry.
|
||||
"""
|
||||
lines = []
|
||||
for item in items:
|
||||
logging_payload = item["payload"]
|
||||
json_line = json.dumps(logging_payload, default=str, ensure_ascii=False)
|
||||
lines.append(json_line)
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _send_grouped_batch(
|
||||
self, items: List[GCSLogQueueItem], config_key: str
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Send a batch of items that share the same GCS config.
|
||||
|
||||
Returns:
|
||||
(success_count, error_count)
|
||||
"""
|
||||
if not items:
|
||||
return (0, 0)
|
||||
|
||||
first_kwargs = items[0]["kwargs"]
|
||||
|
||||
try:
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
first_kwargs
|
||||
)
|
||||
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
|
||||
current_date = self._get_object_date_from_datetime(
|
||||
datetime.now(timezone.utc)
|
||||
)
|
||||
batch_id = f"{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"
|
||||
object_name = self._generate_batch_object_name(current_date, batch_id)
|
||||
combined_payload = self._combine_payloads_to_ndjson(items)
|
||||
|
||||
await self._log_json_data_on_gcs(
|
||||
headers=headers,
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
logging_payload=combined_payload,
|
||||
)
|
||||
|
||||
success_count = len(items)
|
||||
error_count = 0
|
||||
return (success_count, error_count)
|
||||
|
||||
except Exception as e:
|
||||
success_count = 0
|
||||
error_count = len(items)
|
||||
verbose_logger.exception(
|
||||
f"GCS Bucket error logging batch payload to GCS bucket: {str(e)}"
|
||||
)
|
||||
return (success_count, error_count)
|
||||
|
||||
async def _send_individual_logs(self, items: List[GCSLogQueueItem]) -> None:
    """
    Upload the queued items one after another, one GCS object per item.

    This is the legacy, non-batched path (per the original note, taken when
    GCS_USE_BATCHED_LOGGING is disabled). Items are awaited sequentially in
    queue order.
    """
    upload_one = self._send_single_log_item
    for queued_item in items:
        await upload_one(queued_item)
|
||||
|
||||
async def _send_single_log_item(self, item: GCSLogQueueItem) -> None:
    """
    Upload one queued log item to GCS as its own object.

    The config, headers and object name are all derived from the item itself.
    Any exception is logged and swallowed so a single bad item never stops the
    caller's loop.
    """
    try:
        item_kwargs = item["kwargs"]
        config: GCSLoggingConfig = await self.get_gcs_logging_config(item_kwargs)

        request_headers = await self.construct_request_headers(
            vertex_instance=config["vertex_instance"],
            service_account_json=config["path_service_account"],
        )
        target_bucket = config["bucket_name"]

        target_object = self._get_object_name(
            kwargs=item["kwargs"],
            logging_payload=item["payload"],
            response_obj=item["response_obj"],
        )

        await self._log_json_data_on_gcs(
            headers=request_headers,
            bucket_name=target_bucket,
            object_name=target_object,
            logging_payload=item["payload"],
        )
    except Exception as e:
        verbose_logger.exception(
            f"GCS Bucket error logging individual payload to GCS bucket: {str(e)}"
        )
|
||||
|
||||
async def async_send_batch(self):
    """
    Drain the queue and ship whatever was waiting to the GCS bucket.

    When `self.use_batched_logging` is truthy, the drained items are grouped by
    GCS config and each group is uploaded as one combined object. Otherwise
    every item is uploaded as its own object (legacy behavior). Nothing happens
    when the drained batch is empty.
    """
    drained = self._drain_queue_batch()
    if not drained:
        return

    if not self.use_batched_logging:
        await self._send_individual_logs(drained)
        return

    groups = self._group_items_by_config(drained)
    for config_key, group_items in groups.items():
        await self._send_grouped_batch(group_items, config_key)
|
||||
|
||||
def _get_object_name(
    self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
) -> str:
    """
    Work out the GCS object name for one payload.

    Payloads carrying a non-None `error_str` get a failure name; all others get
    a success name built from the response's `id`. Both are prefixed with
    today's UTC date. A `gcs_log_id` entry in
    `kwargs["litellm_params"]["metadata"]` overrides the generated name
    (used for testing).
    """
    date_prefix = self._get_object_date_from_datetime(datetime.now(timezone.utc))

    is_failure = logging_payload.get("error_str", None) is not None
    if is_failure:
        generated_name = self._generate_failure_object_name(
            request_date_str=date_prefix,
        )
    else:
        generated_name = self._generate_success_object_name(
            request_date_str=date_prefix,
            response_id=response_obj.get("id", ""),
        )

    # used for testing
    litellm_params = kwargs.get("litellm_params", None) or {}
    metadata = litellm_params.get("metadata", None) or {}
    if "gcs_log_id" not in metadata:
        return generated_name
    return metadata["gcs_log_id"]
|
||||
|
||||
async def get_request_response_payload(
    self,
    request_id: str,
    start_time_utc: Optional[datetime],
    end_time_utc: Optional[datetime],
) -> Optional[dict]:
    """
    Look up the stored request/response payload for `request_id`.

    The object name embeds a date, so the lookup tries the day of
    `start_time_utc`, then the following day, then the preceding day, and
    returns the first payload that downloads and parses.

    Args:
        request_id: Response id used as the object-name suffix.
        start_time_utc: Anchor date for the lookup; required.
        end_time_utc: Accepted but not used by this lookup.

    Returns:
        The decoded JSON payload, or None when no candidate date yields one.

    Raises:
        ValueError: If `start_time_utc` is None.
    """
    if start_time_utc is None:
        raise ValueError(
            "start_time_utc is required for getting a payload from GCS Bucket"
        )

    one_day = timedelta(days=1)
    candidate_dates = (
        start_time_utc,
        start_time_utc + one_day,
        start_time_utc - one_day,
    )

    # Kept outside the loop so the debug message below can always reference it.
    date_str = None
    for candidate in candidate_dates:
        try:
            date_str = self._get_object_date_from_datetime(datetime_obj=candidate)
            plain_name = self._generate_success_object_name(
                request_date_str=date_str,
                response_id=request_id,
            )
            # Encode everything, including "/", before handing the name over.
            raw_object = await self.download_gcs_object(quote(plain_name, safe=""))

            if raw_object is not None:
                return json.loads(raw_object)
        except Exception as e:
            verbose_logger.debug(
                f"Failed to fetch payload for date {date_str}: {str(e)}"
            )
            continue

    return None
|
||||
|
||||
def _generate_success_object_name(
|
||||
self,
|
||||
request_date_str: str,
|
||||
response_id: str,
|
||||
) -> str:
|
||||
return f"{request_date_str}/{response_id}"
|
||||
|
||||
def _generate_failure_object_name(
|
||||
self,
|
||||
request_date_str: str,
|
||||
) -> str:
|
||||
return f"{request_date_str}/failure-{uuid.uuid4().hex}"
|
||||
|
||||
def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
|
||||
return datetime_obj.strftime("%Y-%m-%d")
|
||||
|
||||
async def flush_queue(self):
    """
    Send whatever is queued, then record the flush time.

    Overrides the base-class flush so it goes through `async_send_batch`,
    which is written for an asyncio.Queue (per the original note).
    `last_flush_time` is only updated if sending did not raise.
    """
    await self.async_send_batch()
    flushed_at = time.time()
    self.last_flush_time = flushed_at
|
||||
|
||||
async def periodic_flush(self):
    """
    Flush the queue forever, sleeping `self.flush_interval` seconds between runs.

    Overrides the base-class loop so flushing goes through this class's
    `flush_queue`. The loop has no exit condition of its own; it ends only if
    `flush_queue` raises or the task is cancelled.
    """
    while True:
        # Sleep first: the first flush happens one full interval after start.
        await asyncio.sleep(self.flush_interval)
        verbose_logger.debug(
            f"GCS Bucket periodic flush after {self.flush_interval} seconds"
        )
        await self.flush_queue()
|
||||
|
||||
async def async_health_check(self) -> IntegrationHealthCheckStatus:
    """
    Health checks are not implemented for the GCS Bucket logger.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_reason = "GCS Bucket does not support health check"
    raise NotImplementedError(unsupported_reason)
|
||||
@@ -0,0 +1,347 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_mock_client import (
|
||||
should_use_gcs_mock,
|
||||
create_mock_gcs_client,
|
||||
mock_vertex_auth_methods,
|
||||
)
|
||||
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.gcs_bucket import *
|
||||
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
else:
|
||||
VertexBase = Any
|
||||
IAM_AUTH_KEY = "IAM_AUTH"
|
||||
|
||||
|
||||
class GCSBucketBase(CustomBatchLogger):
|
||||
def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
    """
    Set up the HTTP client, default bucket/credentials and the Vertex cache.

    Args:
        bucket_name: Default bucket; falls back to the GCS_BUCKET_NAME
            environment variable when falsy.
        **kwargs: Forwarded unchanged to the parent initializer.
    """
    self.is_mock_mode = should_use_gcs_mock()
    if self.is_mock_mode:
        # Patch auth and HTTP handlers before the client is obtained below.
        mock_vertex_auth_methods()
        create_mock_gcs_client()

    self.async_httpx_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.LoggingCallback
    )

    # Default service-account path and bucket; may both be None.
    self.path_service_account_json: Optional[str] = os.getenv(
        "GCS_PATH_SERVICE_ACCOUNT"
    )
    self.BUCKET_NAME: Optional[str] = bucket_name or os.getenv("GCS_BUCKET_NAME")

    # Cache of Vertex instances, keyed per credentials string.
    self.vertex_instances: Dict[str, VertexBase] = {}

    super().__init__(**kwargs)
|
||||
|
||||
async def construct_request_headers(
    self,
    service_account_json: Optional[str],
    vertex_instance: Optional[VertexBase] = None,
) -> Dict[str, str]:
    """
    Build the Authorization/Content-Type headers for a GCS API call.

    Args:
        service_account_json: Credentials passed to the Vertex auth helpers;
            may be None.
        vertex_instance: Vertex object used to obtain the token. Falls back to
            litellm's shared `vertex_chat_completion` when None.

    Returns:
        A dict with a `Bearer` Authorization header and a JSON Content-Type.
    """
    from litellm import vertex_chat_completion

    if vertex_instance is None:
        vertex_instance = vertex_chat_completion

    _auth_header, vertex_project = await vertex_instance._ensure_access_token_async(
        credentials=service_account_json,
        project_id=None,
        custom_llm_provider="vertex_ai",
    )

    auth_header, _ = vertex_instance._get_token_and_url(
        model="gcs-bucket",
        auth_header=_auth_header,
        vertex_credentials=service_account_json,
        vertex_project=vertex_project,
        vertex_location=None,
        gemini_api_key=None,
        stream=None,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )
    # The token is a live credential: record that it was built, never its value.
    verbose_logger.debug("constructed auth_header (token redacted)")
    headers = {
        "Authorization": f"Bearer {auth_header}",  # auth_header
        "Content-Type": "application/json",
    }

    return headers
|
||||
|
||||
def sync_construct_request_headers(self) -> Dict[str, str]:
    """
    Construct request headers for GCS API calls (synchronous variant).

    Uses litellm's shared `vertex_chat_completion` object together with this
    instance's `path_service_account_json`.

    Returns:
        A dict with a `Bearer` Authorization header and a JSON Content-Type.
    """
    from litellm import vertex_chat_completion

    # Get project_id from environment if available, otherwise None
    # This helps support use of this library to auth to pull secrets
    # from Secret Manager.
    project_id = os.getenv("GOOGLE_SECRET_MANAGER_PROJECT_ID")

    _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
        credentials=self.path_service_account_json,
        project_id=project_id,
        custom_llm_provider="vertex_ai",
    )

    auth_header, _ = vertex_chat_completion._get_token_and_url(
        model="gcs-bucket",
        auth_header=_auth_header,
        vertex_credentials=self.path_service_account_json,
        vertex_project=vertex_project,
        vertex_location=None,
        gemini_api_key=None,
        stream=None,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )
    # The token is a live credential: record that it was built, never its value.
    verbose_logger.debug("constructed auth_header (token redacted)")
    headers = {
        "Authorization": f"Bearer {auth_header}",  # auth_header
        "Content-Type": "application/json",
    }

    return headers
|
||||
|
||||
def _handle_folders_in_bucket_name(
|
||||
self,
|
||||
bucket_name: str,
|
||||
object_name: str,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Handles when the user passes a bucket name with a folder postfix
|
||||
|
||||
|
||||
Example:
|
||||
- Bucket name: "my-bucket/my-folder/dev"
|
||||
- Object name: "my-object"
|
||||
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
|
||||
|
||||
"""
|
||||
if "/" in bucket_name:
|
||||
bucket_name, prefix = bucket_name.split("/", 1)
|
||||
object_name = f"{prefix}/{object_name}"
|
||||
return bucket_name, object_name
|
||||
return bucket_name, object_name
|
||||
|
||||
async def get_gcs_logging_config(
    self, kwargs: Optional[Dict[str, Any]] = None
) -> GCSLoggingConfig:
    """
    Resolve the bucket, credentials and Vertex instance to use for a request.

    When `kwargs` carries `standard_callback_dynamic_params`, its
    `gcs_bucket_name` / `gcs_path_service_account` entries take precedence and
    the instance defaults fill in whatever is missing or falsy. Without dynamic
    params the instance defaults are used as-is.

    Args:
        kwargs: Request kwargs; None is treated as an empty dict.

    Returns:
        A GCSLoggingConfig with `bucket_name`, `vertex_instance` and
        `path_service_account`.

    Raises:
        ValueError: If no bucket name can be resolved.
    """
    if kwargs is None:
        kwargs = {}

    standard_callback_dynamic_params: Optional[
        StandardCallbackDynamicParams
    ] = kwargs.get("standard_callback_dynamic_params", None)

    # Start from the instance defaults; dynamic params may override them below.
    bucket_name: Optional[str] = self.BUCKET_NAME
    path_service_account: Optional[str] = self.path_service_account_json

    if standard_callback_dynamic_params is not None:
        verbose_logger.debug("Using dynamic GCS logging")
        verbose_logger.debug(
            "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
        )
        bucket_name = (
            standard_callback_dynamic_params.get("gcs_bucket_name", None)
            or self.BUCKET_NAME
        )
        path_service_account = (
            standard_callback_dynamic_params.get("gcs_path_service_account", None)
            or self.path_service_account_json
        )

    if bucket_name is None:
        raise ValueError(
            "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
        )

    vertex_instance = await self.get_or_create_vertex_instance(
        credentials=path_service_account
    )

    return GCSLoggingConfig(
        bucket_name=bucket_name,
        vertex_instance=vertex_instance,
        path_service_account=path_service_account,
    )
|
||||
|
||||
async def get_or_create_vertex_instance(
    self, credentials: Optional[str]
) -> VertexBase:
    """
    Return the cached Vertex instance for `credentials`, creating it on a miss.

    A new instance has `_ensure_access_token_async` awaited on it before it is
    stored in `self.vertex_instances`.
    """
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

    cache_key = self._get_in_memory_key_for_vertex_instance(credentials)
    if cache_key in self.vertex_instances:
        return self.vertex_instances[cache_key]

    fresh_instance = VertexBase()
    await fresh_instance._ensure_access_token_async(
        credentials=credentials,
        project_id=None,
        custom_llm_provider="vertex_ai",
    )
    self.vertex_instances[cache_key] = fresh_instance
    return self.vertex_instances[cache_key]
|
||||
|
||||
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
|
||||
"""
|
||||
Returns key to use for caching the Vertex instance in-memory.
|
||||
|
||||
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
|
||||
|
||||
- If a credentials string is provided, it is used as the key.
|
||||
- If no credentials string is provided, "IAM_AUTH" is used as the key.
|
||||
"""
|
||||
return credentials or IAM_AUTH_KEY
|
||||
|
||||
async def download_gcs_object(self, object_name: str, **kwargs):
    """
    Fetch an object's content from GCS.

    https://cloud.google.com/storage/docs/downloading-objects#download-object-json

    Args:
        object_name: Name placed verbatim in the URL path (callers are expected
            to have URL-encoded it already; this function does not).
        **kwargs: Passed to `get_gcs_logging_config` as its `kwargs` dict.

    Returns:
        The raw response content on HTTP 200, otherwise None. Exceptions are
        logged and also result in None.
    """
    try:
        config: GCSLoggingConfig = await self.get_gcs_logging_config(kwargs=kwargs)
        request_headers = await self.construct_request_headers(
            vertex_instance=config["vertex_instance"],
            service_account_json=config["path_service_account"],
        )
        bucket_name, object_name = self._handle_folders_in_bucket_name(
            bucket_name=config["bucket_name"],
            object_name=object_name,
        )

        # `alt=media` asks for the object data rather than its metadata.
        url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"

        response = await self.async_httpx_client.get(url=url, headers=request_headers)

        if response.status_code == 200:
            verbose_logger.debug(
                "GCS object download response status code: %s", response.status_code
            )
            return response.content

        verbose_logger.error("GCS object download error: %s", str(response.text))
        return None

    except Exception as e:
        verbose_logger.error("GCS object download error: %s", str(e))
        return None
|
||||
|
||||
async def delete_gcs_object(self, object_name: str, **kwargs):
    """
    Delete an object from GCS.

    Args:
        object_name: Name placed verbatim in the URL path.
        **kwargs: Passed to `get_gcs_logging_config` as its `kwargs` dict.

    Returns:
        The response text when the server answers 200 or 204, otherwise None.
        Exceptions are logged and also result in None.
    """
    try:
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs=kwargs
        )
        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
        )
        bucket_name = gcs_logging_config["bucket_name"]
        bucket_name, object_name = self._handle_folders_in_bucket_name(
            bucket_name=bucket_name,
            object_name=object_name,
        )

        url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"

        # Send the DELETE request to delete the object
        response = await self.async_httpx_client.delete(url=url, headers=headers)

        # Both 200 and 204 are success. The previous `!= 200 or != 204` test was
        # true for every status code, so successful deletes were reported as errors.
        if response.status_code not in (200, 204):
            verbose_logger.error(
                "GCS object delete error: %s, status code: %s",
                str(response.text),
                response.status_code,
            )
            return None

        verbose_logger.debug(
            "GCS object delete response status code: %s, response: %s",
            response.status_code,
            response.text,
        )

        # Return the body of the delete response (empty for a 204)
        return response.text

    except Exception as e:
        verbose_logger.error("GCS object delete error: %s", str(e))
        return None
|
||||
|
||||
async def _log_json_data_on_gcs(
    self,
    headers: Dict[str, str],
    bucket_name: str,
    object_name: str,
    logging_payload: Union[StandardLoggingPayload, str],
):
    """
    POST one payload to the given GCS bucket as a media upload.

    Args:
        headers: Request headers, including authorization.
        bucket_name: Target bucket; may carry a "/folder" suffix, which is
            moved onto the object name.
        object_name: Name of the object to create.
        logging_payload: A ready-made string (sent as-is) or a payload that is
            JSON-encoded with `default=str`.

    Returns:
        The parsed JSON body of the response. A non-200 status is logged as an
        error but does not raise here.
    """
    if isinstance(logging_payload, str):
        request_body = logging_payload
    else:
        request_body = json.dumps(logging_payload, default=str)

    bucket_name, object_name = self._handle_folders_in_bucket_name(
        bucket_name=bucket_name,
        object_name=object_name,
    )
    upload_url = f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"

    response = await self.async_httpx_client.post(
        headers=headers,
        url=upload_url,
        data=request_body,
    )

    upload_failed = response.status_code != 200
    if upload_failed:
        verbose_logger.error("GCS Bucket logging error: %s", str(response.text))

    verbose_logger.debug("GCS Bucket response %s", response)
    verbose_logger.debug("GCS Bucket status code %s", response.status_code)
    verbose_logger.debug("GCS Bucket response.text %s", response.text)

    return response.json()
|
||||
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Mock client for GCS Bucket integration testing.
|
||||
|
||||
This module intercepts GCS API calls and Vertex AI auth calls, returning successful
|
||||
mock responses, allowing full code execution without making actual network calls.
|
||||
|
||||
Usage:
|
||||
Set GCS_MOCK=true in environment variables or config to enable mock mode.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.mock_client_factory import (
|
||||
MockClientConfig,
|
||||
create_mock_client_factory,
|
||||
MockResponse,
|
||||
)
|
||||
|
||||
# Mock latency in seconds, read once at import time from GCS_MOCK_LATENCY_MS
# (milliseconds, default 150). Used by the GET/DELETE interceptors below.
_MOCK_LATENCY_SECONDS = (
    float(__import__("os").getenv("GCS_MOCK_LATENCY_MS", "150")) / 1000.0
)

# POST interception comes from the shared mock-client factory.
_config = MockClientConfig(
    name="GCS",
    env_var="GCS_MOCK",
    default_latency_ms=150,
    default_status_code=200,
    default_json_data={"kind": "storage#object", "name": "mock-object"},
    url_matchers=["storage.googleapis.com"],
    patch_async_handler=True,
    patch_sync_client=False,
)
_create_mock_gcs_post, should_use_gcs_mock = create_mock_client_factory(_config)

# GET/DELETE are patched by this module itself; the originals are stored here
# so non-GCS calls can still be forwarded to them.
_original_async_handler_get = None
_original_async_handler_delete = None
# Set to True once the GET/DELETE patches have been installed.
_mocks_initialized = False
|
||||
|
||||
|
||||
async def _mock_async_handler_get(
    self, url, params=None, headers=None, follow_redirects=None
):
    """
    Replacement for AsyncHTTPHandler.get that answers GCS calls locally.

    String URLs containing "storage.googleapis.com" get a canned 200 response
    after the configured mock latency. Every other call is forwarded to the
    original `get`; if no original was stored, RuntimeError is raised.
    """
    is_gcs_call = isinstance(url, str) and "storage.googleapis.com" in url

    if not is_gcs_call:
        if _original_async_handler_get is None:
            raise RuntimeError("Original AsyncHTTPHandler.get not available")
        return await _original_async_handler_get(
            self,
            url=url,
            params=params,
            headers=headers,
            follow_redirects=follow_redirects,
        )

    verbose_logger.info(f"[GCS MOCK] GET to {url}")
    await asyncio.sleep(_MOCK_LATENCY_SECONDS)

    # Minimal logging payload standing in for the body of an `?alt=media`
    # download.
    canned_payload = {
        "id": "mock-request-id",
        "trace_id": "mock-trace-id",
        "call_type": "completion",
        "stream": False,
        "response_cost": 0.0,
        "status": "success",
        "status_fields": {"llm_api_status": "success"},
        "custom_llm_provider": "mock",
        "total_tokens": 0,
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "startTime": 0.0,
        "endTime": 0.0,
        "completionStartTime": 0.0,
        "response_time": 0.0,
        "model_map_information": {"model": "mock-model"},
        "model": "mock-model",
        "model_id": None,
        "model_group": None,
        "api_base": "https://api.mock.com",
        "metadata": {},
        "cache_hit": None,
        "cache_key": None,
        "saved_cache_cost": 0.0,
        "request_tags": [],
        "end_user": None,
        "requester_ip_address": None,
        "messages": None,
        "response": None,
        "error_str": None,
        "error_information": None,
        "model_parameters": {},
        "hidden_params": {},
        "guardrail_information": None,
        "standard_built_in_tools_params": None,
    }
    return MockResponse(
        status_code=200,
        json_data=canned_payload,
        url=url,
        elapsed_seconds=_MOCK_LATENCY_SECONDS,
    )
|
||||
|
||||
|
||||
async def _mock_async_handler_delete(
    self,
    url,
    data=None,
    json=None,
    params=None,
    headers=None,
    timeout=None,
    stream=False,
    content=None,
):
    """
    Replacement for AsyncHTTPHandler.delete that answers GCS calls locally.

    String URLs containing "storage.googleapis.com" get a 204 response with no
    JSON body after the configured mock latency. Every other call is forwarded
    to the original `delete`; if no original was stored, RuntimeError is raised.
    """
    is_gcs_call = isinstance(url, str) and "storage.googleapis.com" in url

    if not is_gcs_call:
        if _original_async_handler_delete is None:
            raise RuntimeError("Original AsyncHTTPHandler.delete not available")
        return await _original_async_handler_delete(
            self,
            url=url,
            data=data,
            json=json,
            params=params,
            headers=headers,
            timeout=timeout,
            stream=stream,
            content=content,
        )

    verbose_logger.info(f"[GCS MOCK] DELETE to {url}")
    await asyncio.sleep(_MOCK_LATENCY_SECONDS)
    # 204 No Content: the mocked delete carries no JSON body.
    return MockResponse(
        status_code=204,
        json_data=None,
        url=url,
        elapsed_seconds=_MOCK_LATENCY_SECONDS,
    )
|
||||
|
||||
|
||||
def create_mock_gcs_client():
    """
    Install the GCS interceptors on AsyncHTTPHandler.

    The factory-built POST patch is invoked on every call. The GET and DELETE
    patches are installed only on the first call; later calls return right
    after the POST step because `_mocks_initialized` is already set.
    """
    global _original_async_handler_get, _original_async_handler_delete, _mocks_initialized

    # POST interception is delegated to the shared factory.
    _create_mock_gcs_post()

    if _mocks_initialized:
        # GET/DELETE were patched by an earlier call.
        return

    verbose_logger.debug("[GCS MOCK] Initializing GCS GET/DELETE handlers...")

    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    # Remember each original before replacing it so non-GCS traffic can still
    # be forwarded.
    if _original_async_handler_get is None:
        _original_async_handler_get = AsyncHTTPHandler.get
        AsyncHTTPHandler.get = _mock_async_handler_get  # type: ignore
        verbose_logger.debug("[GCS MOCK] Patched AsyncHTTPHandler.get")

    if _original_async_handler_delete is None:
        _original_async_handler_delete = AsyncHTTPHandler.delete
        AsyncHTTPHandler.delete = _mock_async_handler_delete  # type: ignore
        verbose_logger.debug("[GCS MOCK] Patched AsyncHTTPHandler.delete")

    latency_ms = _MOCK_LATENCY_SECONDS * 1000
    verbose_logger.debug(f"[GCS MOCK] Mock latency set to {latency_ms:.0f}ms")
    verbose_logger.debug("[GCS MOCK] GCS mock client initialization complete")

    _mocks_initialized = True
|
||||
|
||||
|
||||
def mock_vertex_auth_methods():
    """
    Replace VertexBase's auth methods with stubs that return fake tokens.

    This prevents auth failures when GCS_MOCK is enabled. The original methods
    are saved on the class under `_original_*` names the first time this runs;
    the stubs themselves are (re)assigned on every call.
    """
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

    # Back up the real methods once, so repeated calls never overwrite the
    # backups with the stubs.
    if not hasattr(VertexBase, "_original_ensure_access_token_async"):
        backups = (
            ("_original_ensure_access_token_async", "_ensure_access_token_async"),
            ("_original_ensure_access_token", "_ensure_access_token"),
            ("_original_get_token_and_url", "_get_token_and_url"),
        )
        for backup_name, method_name in backups:
            setattr(VertexBase, backup_name, getattr(VertexBase, method_name))

    async def _mock_ensure_access_token_async(
        self, credentials, project_id, custom_llm_provider
    ):
        """Async auth stub: returns a fixed (token, project) pair."""
        verbose_logger.debug(
            "[GCS MOCK] Vertex AI auth: _ensure_access_token_async called"
        )
        return ("mock-gcs-token", "mock-project-id")

    def _mock_ensure_access_token(self, credentials, project_id, custom_llm_provider):
        """Sync auth stub: returns a fixed (token, project) pair."""
        verbose_logger.debug("[GCS MOCK] Vertex AI auth: _ensure_access_token called")
        return ("mock-gcs-token", "mock-project-id")

    def _mock_get_token_and_url(
        self,
        model,
        auth_header,
        vertex_credentials,
        vertex_project,
        vertex_location,
        gemini_api_key,
        stream,
        custom_llm_provider,
        api_base,
    ):
        """Token/URL stub: returns a fixed (token, url) pair."""
        verbose_logger.debug("[GCS MOCK] Vertex AI auth: _get_token_and_url called")
        return ("mock-gcs-token", "https://storage.googleapis.com")

    # Install the stubs on the class so every instance picks them up.
    VertexBase._ensure_access_token_async = _mock_ensure_access_token_async  # type: ignore
    VertexBase._ensure_access_token = _mock_ensure_access_token  # type: ignore
    VertexBase._get_token_and_url = _mock_get_token_and_url  # type: ignore

    verbose_logger.debug("[GCS MOCK] Patched Vertex AI auth methods")
|
||||
|
||||
|
||||
# should_use_gcs_mock is already created by the factory
|
||||
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
BETA
|
||||
|
||||
This is the PubSub logger for GCS PubSub, this sends LiteLLM SpendLogs Payloads to GCS PubSub.
|
||||
|
||||
Users can use this instead of sending their SpendLogs to their Postgres database.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy._types import SpendLogsPayload
|
||||
else:
|
||||
SpendLogsPayload = Any
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
|
||||
class GcsPubSubLogger(CustomBatchLogger):
|
||||
def __init__(
    self,
    project_id: Optional[str] = None,
    topic_id: Optional[str] = None,
    credentials_path: Optional[str] = None,
    **kwargs,
):
    """
    Initialize Google Cloud Pub/Sub publisher

    Args:
        project_id (str): Google Cloud project ID; falls back to the
            GCS_PUBSUB_PROJECT_ID environment variable.
        topic_id (str): Pub/Sub topic ID; falls back to the
            GCS_PUBSUB_TOPIC_ID environment variable.
        credentials_path (str, optional): Path to Google Cloud credentials JSON
            file; falls back to the GCS_PATH_SERVICE_ACCOUNT environment variable.
        **kwargs: Forwarded to the parent initializer.

    Raises:
        ValueError: If either the project ID or the topic ID is still unset
            after the environment fallback.
    """
    from litellm.proxy.utils import _premium_user_check

    _premium_user_check()

    self.async_httpx_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.LoggingCallback
    )

    self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID")
    self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID")
    self.path_service_account_json = credentials_path or os.getenv(
        "GCS_PATH_SERVICE_ACCOUNT"
    )

    if not self.project_id or not self.topic_id:
        raise ValueError("Both project_id and topic_id must be provided")

    self.flush_lock = asyncio.Lock()
    super().__init__(**kwargs, flush_lock=self.flush_lock)
    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected before it finishes.
    self._periodic_flush_task = asyncio.create_task(self.periodic_flush())
    self.log_queue: List[Union[SpendLogsPayload, StandardLoggingPayload]] = []
|
||||
|
||||
async def construct_request_headers(self) -> Dict[str, str]:
    """
    Build the Authorization/Content-Type headers for a Pub/Sub request.

    The token comes from litellm's shared `vertex_chat_completion` object,
    using this logger's service-account path and project ID.
    """
    from litellm import vertex_chat_completion

    credentials = self.path_service_account_json

    raw_token, vertex_project = await vertex_chat_completion._ensure_access_token_async(
        credentials=credentials,
        project_id=self.project_id,
        custom_llm_provider="vertex_ai",
    )

    # Only the token half of the (token, url) pair is needed here.
    bearer_token, _ = vertex_chat_completion._get_token_and_url(
        model="pub-sub",
        auth_header=raw_token,
        vertex_credentials=self.path_service_account_json,
        vertex_project=vertex_project,
        vertex_location=None,
        gemini_api_key=None,
        stream=None,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )

    return {
        "Authorization": f"Bearer {bearer_token}",
        "Content-Type": "application/json",
    }
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Queue a success event for publishing to the GCS PubSub topic.

    - Picks the payload: a spend-logs payload when
      `litellm.gcs_pub_sub_use_v1 is True`, otherwise the
      `standard_logging_object` found in `kwargs`
    - Appends it to the batch queue
    - Sends the batch once the queue reaches `self.batch_size`

    Errors raised while queueing or sending are logged via
    verbose_logger.exception and not re-raised. The premium-user check runs
    before that guard, so whatever it raises does propagate.
    """
    from litellm.proxy.spend_tracking.spend_tracking_utils import (
        get_logging_payload,
    )
    from litellm.proxy.utils import _premium_user_check

    _premium_user_check()

    try:
        verbose_logger.debug(
            "PubSub: Logging - Enters logging function for model %s", kwargs
        )
        standard_payload = kwargs.get("standard_logging_object", None)

        use_legacy_payload = litellm.gcs_pub_sub_use_v1 is True
        if use_legacy_payload:
            # Backwards compatibility with old logging payload
            queued_payload = get_logging_payload(
                kwargs=kwargs,
                response_obj=response_obj,
                start_time=start_time,
                end_time=end_time,
            )
        else:
            # New logging payload, StandardLoggingPayload
            queued_payload = standard_payload
        self.log_queue.append(queued_payload)

        queue_is_full = len(self.log_queue) >= self.batch_size
        if queue_is_full:
            await self.async_send_batch()

    except Exception as e:
        verbose_logger.exception(
            f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}"
        )
|
||||
|
||||
async def async_send_batch(self):
    """
    Publish every queued message to Pub/Sub, one ``publish_message`` call at a time.

    Errors are logged through verbose_logger.exception instead of being raised,
    and ``self.log_queue`` is cleared afterwards whether or not publishing
    succeeded.
    """
    try:
        if self.log_queue:
            verbose_logger.debug(
                f"PubSub - about to flush {len(self.log_queue)} events"
            )

            for queued_message in self.log_queue:
                await self.publish_message(queued_message)

    except Exception as e:
        verbose_logger.exception(
            f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}"
        )
    finally:
        self.log_queue.clear()
|
||||
|
||||
async def publish_message(
    self, message: Union[SpendLogsPayload, StandardLoggingPayload]
) -> Optional[Dict[str, Any]]:
    """
    Publish a single message to Google Cloud Pub/Sub using the REST API.

    Args:
        message: Message to publish. A ``str`` is sent as-is; anything else is
            serialized with ``json.dumps(..., default=str)``.

    Returns:
        The parsed JSON body of the publish response, or None when anything
        fails (the failure is logged, never raised).
    """
    try:
        headers = await self.construct_request_headers()

        # Serialize the payload unless the caller already handed us text
        serialized = (
            message if isinstance(message, str) else json.dumps(message, default=str)
        )

        # Pub/Sub expects the message data base64-encoded
        import base64

        raw_bytes = serialized.encode("utf-8")
        encoded_message = base64.b64encode(raw_bytes).decode("utf-8")

        request_body = {"messages": [{"data": encoded_message}]}
        url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish"

        response = await self.async_httpx_client.post(
            url=url, headers=headers, json=request_body
        )

        if response.status_code in (200, 202):
            verbose_logger.debug("Pub/Sub response: %s", response.text)
            return response.json()

        # Any other status is treated as a failure; the raise below is caught
        # by this method's own handler, which logs again and returns None.
        verbose_logger.error("Pub/Sub publish error: %s", str(response.text))
        raise Exception(f"Failed to publish message: {response.text}")

    except Exception as e:
        verbose_logger.error("Pub/Sub publish error: %s", str(e))
        return None
|
||||
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
Callback to log events to a Generic API Endpoint
|
||||
|
||||
- Creates a StandardLoggingPayload
|
||||
- Adds to batch queue
|
||||
- Flushes based on CustomBatchLogger settings
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
API_EVENT_TYPES = Literal["llm_api_success", "llm_api_failure"]
|
||||
LOG_FORMAT_TYPES = Literal["json_array", "ndjson", "single"]
|
||||
|
||||
|
||||
def load_compatible_callbacks() -> Dict:
    """
    Load the generic_api_compatible_callbacks.json file

    The file is looked up in the same directory as this module and is read on
    every call (no caching).

    Returns:
        Dict: Dictionary of compatible callbacks configuration. An empty dict is
        returned (and a warning is logged) when the file cannot be read or
        parsed, or when its top-level JSON value is not an object.
    """
    try:
        json_path = os.path.join(
            os.path.dirname(__file__), "generic_api_compatible_callbacks.json"
        )
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(json_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except Exception as e:
        # Best effort: a missing/broken config must not break logger setup.
        verbose_logger.warning(
            f"Error loading generic_api_compatible_callbacks.json: {str(e)}"
        )
        return {}

    if not isinstance(loaded, dict):
        # Callers use `in` / `.get()` on the result, so honour the Dict contract.
        verbose_logger.warning(
            "Error loading generic_api_compatible_callbacks.json: "
            f"expected a JSON object, got {type(loaded).__name__}"
        )
        return {}

    return loaded
|
||||
|
||||
|
||||
def is_callback_compatible(callback_name: str) -> bool:
    """
    Tell whether ``callback_name`` is a key of the compatible callbacks config.

    Args:
        callback_name: Name of the callback to look up

    Returns:
        bool: True when the name is present in the mapping returned by
        ``load_compatible_callbacks()``, False otherwise
    """
    return callback_name in load_compatible_callbacks()
|
||||
|
||||
|
||||
def get_callback_config(callback_name: str) -> Optional[Dict]:
    """
    Look up the configuration entry for one callback.

    Args:
        callback_name: Name of the callback whose config is wanted

    Returns:
        Optional[Dict]: The entry stored under ``callback_name`` in the mapping
        returned by ``load_compatible_callbacks()``, or None if there is none
    """
    return load_compatible_callbacks().get(callback_name)
|
||||
|
||||
|
||||
def substitute_env_variables(value: str) -> str:
    """
    Replace {{environment_variables.VAR_NAME}} patterns with actual environment variable values

    VAR_NAME may contain letters, digits and underscores and must not start with
    a digit (e.g. ``API_KEY_2``). Placeholders whose variable is not set are
    replaced with an empty string; text that does not match the placeholder
    syntax is left untouched.

    Args:
        value: String that may contain {{environment_variables.VAR_NAME}} patterns

    Returns:
        str: String with environment variables substituted
    """
    # The previous character class ([A-Z_]+) silently skipped names containing
    # digits or lowercase letters, leaving the raw placeholder in the output.
    pattern = r"\{\{environment_variables\.([A-Za-z_][A-Za-z0-9_]*)\}\}"

    def replace_env_var(match):
        env_var_name = match.group(1)
        return os.getenv(env_var_name, "")

    return re.sub(pattern, replace_env_var, value)
|
||||
|
||||
|
||||
class GenericAPILogger(CustomBatchLogger):
    """
    Batch logger that POSTs LiteLLM logging payloads to a generic HTTP endpoint.

    Events are collected in ``self.log_queue`` and sent by ``async_send_batch``
    once ``self.batch_size`` events are queued; ``__init__`` also schedules
    ``self.periodic_flush()`` as a background task.
    """

    def __init__(
        self,
        endpoint: Optional[str] = None,
        headers: Optional[dict] = None,
        event_types: Optional[List[API_EVENT_TYPES]] = None,
        callback_name: Optional[str] = None,
        log_format: Optional[LOG_FORMAT_TYPES] = None,
        **kwargs,
    ):
        """
        Initialize the GenericAPILogger

        Args:
            endpoint: URL to POST logs to. When None, taken from the callback
                config (if any), then from the GENERIC_LOGGER_ENDPOINT env var.
            headers: Extra request headers. Explicit entries win over headers
                from the callback config; the caller's dict is not modified.
            event_types: Event types to log; None means both success and failure.
            callback_name: If provided, loads config from generic_api_compatible_callbacks.json
            log_format: Format for log output: "json_array" (default), "ndjson", or "single"
            **kwargs: Forwarded to ``CustomBatchLogger.__init__``.

        Raises:
            ValueError: If no endpoint can be resolved or ``log_format`` is invalid.
        """
        #########################################################
        # Check if callback_name is provided and load config
        #########################################################
        if callback_name:
            if is_callback_compatible(callback_name):
                verbose_logger.debug(
                    f"Loading configuration for callback: {callback_name}"
                )
                callback_config = get_callback_config(callback_name)

                # Use config from JSON if not explicitly provided
                if callback_config:
                    if endpoint is None and "endpoint" in callback_config:
                        endpoint = substitute_env_variables(callback_config["endpoint"])

                    if "headers" in callback_config:
                        # Work on a copy so the caller's dict is never mutated.
                        headers = dict(headers) if headers else {}
                        for key, value in callback_config["headers"].items():
                            if key not in headers:
                                headers[key] = substitute_env_variables(value)

                    if event_types is None and "event_types" in callback_config:
                        event_types = callback_config["event_types"]

                    if log_format is None and "log_format" in callback_config:
                        log_format = callback_config["log_format"]
            else:
                verbose_logger.warning(
                    f"callback_name '{callback_name}' not found in generic_api_compatible_callbacks.json"
                )

        #########################################################
        # Init httpx client
        #########################################################
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        endpoint = endpoint or os.getenv("GENERIC_LOGGER_ENDPOINT")
        if endpoint is None:
            raise ValueError(
                "endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
            )

        self.headers: Dict = self._get_headers(headers)
        self.endpoint: str = endpoint
        self.event_types: Optional[List[API_EVENT_TYPES]] = event_types
        self.callback_name: Optional[str] = callback_name

        # Validate and store log_format
        if log_format is not None and log_format not in [
            "json_array",
            "ndjson",
            "single",
        ]:
            raise ValueError(
                f"Invalid log_format: {log_format}. Must be one of: 'json_array', 'ndjson', 'single'"
            )
        self.log_format: LOG_FORMAT_TYPES = log_format or "json_array"

        # Only header names are logged: the values typically carry credentials
        # (e.g. "Authorization: Bearer ...") that must not end up in log files.
        verbose_logger.debug(
            f"in init GenericAPILogger, callback_name: {self.callback_name}, endpoint {self.endpoint}, header names {sorted(self.headers)}, event_types: {self.event_types}, log_format: {self.log_format}"
        )

        #########################################################
        # Init variables for batch flushing logs
        #########################################################
        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        # Keep a reference to the task: the event loop only holds weak
        # references, so an unreferenced task may be garbage collected.
        self._periodic_flush_task = asyncio.create_task(self.periodic_flush())
        self.log_queue: List[Union[Dict, StandardLoggingPayload]] = []

    def _get_headers(self, headers: Optional[dict] = None):
        """
        Get headers for the Generic API Logger

        Sources, later ones overriding earlier ones:
            1. ``{"Content-Type": "application/json"}``
            2. GENERIC_LOGGER_HEADERS env var, "key1=value1,key2=value2"
            3. ``litellm.generic_logger_headers``
            4. ``headers`` passed to this method

        Args:
            headers: Optional[dict] = None

        Returns:
            Dict: Headers for the Generic API Logger
        """
        headers_dict = {
            "Content-Type": "application/json",
        }

        # 1. First check for headers from env var
        env_headers = os.getenv("GENERIC_LOGGER_HEADERS")
        if env_headers:
            # Parse headers in format "key1=value1,key2=value2" or "key1=value1";
            # items without "=" are ignored, only the first "=" splits the item.
            for item in env_headers.split(","):
                key, separator, value = item.partition("=")
                if separator:
                    headers_dict[key.strip()] = value.strip()

        # 2. Update with litellm generic headers if available
        if litellm.generic_logger_headers:
            headers_dict.update(litellm.generic_logger_headers)

        # 3. Override with directly provided headers if any
        if headers:
            headers_dict.update(headers)

        return headers_dict

    async def _enqueue_event(self, kwargs, response_obj, start_time, end_time) -> None:
        """Build the payload for one event, queue it and flush when the batch is full."""
        verbose_logger.debug(
            "Generic API Logger - Enters logging function for model %s", kwargs
        )

        if litellm.generic_api_use_v1 is True:
            # Backwards compatibility with old logging payload
            payload = self._get_v1_logging_payload(
                kwargs=kwargs,
                response_obj=response_obj,
                start_time=start_time,
                end_time=end_time,
            )
        else:
            # New logging payload, StandardLoggingPayload
            payload = kwargs.get("standard_logging_object", None)
            if payload is None:
                # Nothing to send; queueing None would post a JSON `null` entry.
                verbose_logger.debug(
                    "Generic API Logger - standard_logging_object not found in kwargs, skipping event"
                )
                return

        self.log_queue.append(payload)

        if len(self.log_queue) >= self.batch_size:
            await self.async_send_batch()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log success events to Generic API Endpoint

        - Creates a StandardLoggingPayload
        - Adds to batch queue
        - Flushes based on CustomBatchLogger settings

        Does nothing when ``event_types`` is set and lacks "llm_api_success".

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        if self.event_types is not None and "llm_api_success" not in self.event_types:
            return

        try:
            await self._enqueue_event(kwargs, response_obj, start_time, end_time)
        except Exception as e:
            verbose_logger.exception(
                f"Generic API Logger Error - {str(e)}\n{traceback.format_exc()}"
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log failure events to Generic API Endpoint

        - Creates a StandardLoggingPayload
        - Adds to batch queue
        - Flushes based on CustomBatchLogger settings

        Does nothing when ``event_types`` is set and lacks "llm_api_failure".

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        if self.event_types is not None and "llm_api_failure" not in self.event_types:
            return

        try:
            await self._enqueue_event(kwargs, response_obj, start_time, end_time)
        except Exception as e:
            verbose_logger.exception(
                f"Generic API Logger Error - {str(e)}\n{traceback.format_exc()}"
            )

    async def async_send_batch(self):
        """
        Sends the batch of messages to Generic API Endpoint

        Supports three formats:
        - json_array: Sends all logs as a JSON array (default)
        - ndjson: Sends logs as newline-delimited JSON
        - single: Sends each log as individual HTTP request in parallel

        Errors are logged, never raised; the events of a failed batch are dropped.
        """
        if not self.log_queue:
            return

        # Detach the pending events before the first await. Events queued by
        # other coroutines while the request is in flight then stay in
        # self.log_queue for the next flush instead of being cleared unsent.
        batch = list(self.log_queue)
        self.log_queue.clear()

        try:
            verbose_logger.debug(
                f"Generic API Logger - about to flush {len(batch)} events in '{self.log_format}' format"
            )

            if self.log_format == "single":
                await self._send_individually(batch)
            else:
                await self._send_combined(batch)

        except Exception as e:
            verbose_logger.exception(
                f"Generic API Logger Error sending batch - {str(e)}\n{traceback.format_exc()}"
            )

    async def _send_individually(
        self, batch: List[Union[Dict, StandardLoggingPayload]]
    ) -> None:
        """POST every log entry of ``batch`` as its own request, all in parallel."""
        tasks = [
            self.async_httpx_client.post(
                url=self.endpoint,
                headers=self.headers,
                data=safe_dumps(log_entry),
            )
            for log_entry in batch
        ]

        # Execute all requests in parallel; failures come back as results
        responses = await asyncio.gather(*tasks, return_exceptions=True)

        for idx, result in enumerate(responses):
            if isinstance(result, BaseException):
                # No exception is being handled here, so pass it explicitly
                # to get its traceback into the log record.
                verbose_logger.error(
                    f"Generic API Logger - Error sending log {idx}: {result}",
                    exc_info=result,
                )
            else:
                # result is a Response object
                verbose_logger.debug(
                    f"Generic API Logger - sent log {idx}, status: {result.status_code}"  # type: ignore
                )

    async def _send_combined(
        self, batch: List[Union[Dict, StandardLoggingPayload]]
    ) -> None:
        """POST all of ``batch`` in one request, as a JSON array or as NDJSON."""
        if self.log_format == "json_array":
            data = safe_dumps(batch)
        elif self.log_format == "ndjson":
            data = "\n".join(safe_dumps(log) for log in batch)
        else:
            raise ValueError(f"Unknown log_format: {self.log_format}")

        response = await self.async_httpx_client.post(
            url=self.endpoint,
            headers=self.headers,
            data=data,
        )

        verbose_logger.debug(
            f"Generic API Logger - sent batch to {self.endpoint}, "
            f"status: {response.status_code}, format: {self.log_format}"
        )

    def _get_v1_logging_payload(
        self, kwargs, response_obj, start_time, end_time
    ) -> dict:
        """
        Maintained for backwards compatibility with old logging payload

        Returns a dict of the payload to send to the Generic API Endpoint.
        Every value is converted to ``str`` where possible. ``response_obj`` may
        be None (in which case "usage" is None and a fresh UUID is used as id).
        """
        verbose_logger.debug(
            f"GenericAPILogger Logging - Enters logging function for model {kwargs}"
        )

        # construct payload to send custom logger
        # follows the same params as langfuse.py
        litellm_params = kwargs.get("litellm_params", {})
        metadata = (
            litellm_params.get("metadata", {}) or {}
        )  # if litellm_params['metadata'] == None
        messages = kwargs.get("messages")
        cost = kwargs.get("response_cost", 0.0)
        optional_params = kwargs.get("optional_params", {})
        call_type = kwargs.get("call_type", "litellm.completion")
        cache_hit = kwargs.get("cache_hit", False)

        if response_obj is None:
            # No response to read from; previously this raised and the whole
            # event was dropped.
            usage = None
            response_id = str(uuid.uuid4())
        else:
            usage = response_obj.get("usage")
            response_id = response_obj.get("id", str(uuid.uuid4()))

        # Build the initial payload
        payload = {
            "id": response_id,
            "call_type": call_type,
            "cache_hit": cache_hit,
            "startTime": start_time,
            "endTime": end_time,
            "model": kwargs.get("model", ""),
            "user": kwargs.get("user", ""),
            "modelParameters": optional_params,
            "messages": messages,
            "response": response_obj,
            "usage": usage,
            "metadata": metadata,
            "cost": cost,
        }

        # Ensure everything in the payload is converted to str
        for key, value in payload.items():
            try:
                payload[key] = str(value)
            except Exception:
                # non blocking if it can't cast to a str
                pass

        return payload
|
||||
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"sample_callback": {
|
||||
"event_types": ["llm_api_success", "llm_api_failure"],
|
||||
"endpoint": "{{environment_variables.SAMPLE_CALLBACK_URL}}",
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {{environment_variables.SAMPLE_CALLBACK_API_KEY}}"
|
||||
},
|
||||
"environment_variables": ["SAMPLE_CALLBACK_URL", "SAMPLE_CALLBACK_API_KEY"]
|
||||
},
|
||||
"rubrik": {
|
||||
"event_types": ["llm_api_success"],
|
||||
"endpoint": "{{environment_variables.RUBRIK_WEBHOOK_URL}}",
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {{environment_variables.RUBRIK_API_KEY}}"
|
||||
},
|
||||
"environment_variables": ["RUBRIK_API_KEY", "RUBRIK_WEBHOOK_URL"]
|
||||
},
|
||||
"sumologic": {
|
||||
"endpoint": "{{environment_variables.SUMOLOGIC_WEBHOOK_URL}}",
|
||||
"headers": {
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
"environment_variables": ["SUMOLOGIC_WEBHOOK_URL"],
|
||||
"log_format": "ndjson"
|
||||
},
|
||||
"qualifire_eval": {
|
||||
"event_types": ["llm_api_success"],
|
||||
"endpoint": "{{environment_variables.QUALIFIRE_WEBHOOK_URL}}",
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"X-Qualifire-API-Key": "{{environment_variables.QUALIFIRE_API_KEY}}"
|
||||
},
|
||||
"environment_variables": ["QUALIFIRE_API_KEY", "QUALIFIRE_WEBHOOK_URL"]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Generic prompt management integration for LiteLLM."""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .generic_prompt_manager import GenericPromptManager
|
||||
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
|
||||
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
|
||||
|
||||
from .generic_prompt_manager import GenericPromptManager
|
||||
|
||||
# Global instances
|
||||
global_generic_prompt_config: Optional[dict] = None
|
||||
|
||||
|
||||
def set_global_generic_prompt_config(config: dict) -> None:
    """
    Set the global generic prompt configuration.

    The configuration is stored on the ``litellm`` module as
    ``litellm.global_generic_prompt_config`` and mirrored in this module's
    ``global_generic_prompt_config`` variable, which is part of ``__all__``.

    Args:
        config: Dictionary containing generic prompt configuration
            - api_base: Base URL for the API
            - api_key: Optional API key for authentication
            - timeout: Request timeout in seconds (default: 30)
    """
    import litellm

    global global_generic_prompt_config

    litellm.global_generic_prompt_config = config  # type: ignore
    # Keep the exported module-level value in sync; it was previously left at None.
    global_generic_prompt_config = config
|
||||
|
||||
|
||||
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Build a GenericPromptManager from the given prompt parameters.

    ``api_base``, ``api_key``, ``prompt_id`` and
    ``provider_specific_query_params`` are passed explicitly; every other
    non-None field of ``litellm_params`` is forwarded as a keyword argument.
    ``prompt_spec`` is accepted but not used here.

    Raises:
        ValueError: If ``litellm_params.api_base`` is empty.
    """
    prompt_id = getattr(litellm_params, "prompt_id", None)
    api_base, api_key = litellm_params.api_base, litellm_params.api_key

    if not api_base:
        raise ValueError("api_base is required in generic_prompt_config")

    query_params = litellm_params.provider_specific_query_params

    # Fields handed over explicitly must not show up again in **kwargs
    explicitly_passed = {
        "prompt_id",
        "api_key",
        "provider_specific_query_params",
        "api_base",
    }
    remaining_params = litellm_params.model_dump(
        exclude_none=True,
        exclude=explicitly_passed,
    )

    return GenericPromptManager(
        api_base=api_base,
        api_key=api_key,
        prompt_id=prompt_id,
        additional_provider_specific_query_params=query_params,
        **remaining_params,
    )
|
||||
|
||||
|
||||
prompt_initializer_registry = {
|
||||
SupportedPromptIntegrations.GENERIC_PROMPT_MANAGEMENT.value: prompt_initializer,
|
||||
}
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"GenericPromptManager",
|
||||
"set_global_generic_prompt_config",
|
||||
"global_generic_prompt_config",
|
||||
"prompt_initializer_registry",
|
||||
]
|
||||
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
Generic prompt manager that integrates with LiteLLM's prompt management system.
|
||||
Fetches prompts from any API that implements the /beta/litellm_prompt_management endpoint.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
from litellm.integrations.prompt_management_base import (
|
||||
PromptManagementBase,
|
||||
PromptManagementClient,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
|
||||
class GenericPromptManager(CustomPromptManagement):
|
||||
"""
|
||||
Generic prompt manager that integrates with LiteLLM's prompt management system.
|
||||
|
||||
This class enables using prompts from any API that implements the
|
||||
/beta/litellm_prompt_management endpoint.
|
||||
|
||||
Usage:
|
||||
# Configure API access
|
||||
generic_config = {
|
||||
"api_base": "https://your-api.com",
|
||||
"api_key": "your-api-key", # optional
|
||||
"timeout": 30, # optional, defaults to 30
|
||||
}
|
||||
|
||||
# Use with completion
|
||||
response = litellm.completion(
|
||||
model="generic_prompt/gpt-4",
|
||||
prompt_id="my_prompt_id",
|
||||
prompt_variables={"variable": "value"},
|
||||
generic_prompt_config=generic_config,
|
||||
messages=[{"role": "user", "content": "Additional message"}]
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    api_base: str,
    api_key: Optional[str] = None,
    timeout: int = 30,
    prompt_id: Optional[str] = None,
    additional_provider_specific_query_params: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """
    Set up the Generic Prompt Manager.

    Args:
        api_base: Base URL for the API (e.g., "https://your-api.com");
            trailing slashes are stripped
        api_key: Optional API key for authentication
        timeout: Request timeout in seconds (default: 30)
        prompt_id: Optional prompt ID, stored on the instance
        additional_provider_specific_query_params: Optional extra query
            parameters, stored on the instance
        **kwargs: Forwarded to the parent class initializer
    """
    super().__init__(**kwargs)

    # Prompts already fetched by this instance, keyed by cache key
    self._prompt_cache: Dict[str, PromptManagementClient] = {}

    # API access settings
    self.api_base = api_base.rstrip("/")
    self.api_key = api_key
    self.timeout = timeout

    # Prompt selection
    self.prompt_id = prompt_id
    self.additional_provider_specific_query_params = (
        additional_provider_specific_query_params
    )
|
||||
|
||||
@property
def integration_name(self) -> str:
    """Name of this integration, i.e. the prefix in model names such as 'generic_prompt/gpt-4'."""
    name = "generic_prompt"
    return name
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
    """
    Build the HTTP headers sent with every API request.

    JSON content negotiation headers are always present; a Bearer
    ``Authorization`` header is added only when ``self.api_key`` is set.
    """
    auth_header = (
        {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
    )
    return {
        "Content-Type": "application/json",
        "Accept": "application/json",
        **auth_header,
    }
|
||||
|
||||
def _fetch_prompt_from_api(
    self, prompt_id: Optional[str], prompt_spec: Optional[PromptSpec]
) -> Dict[str, Any]:
    """
    Fetch a prompt from the API.

    Sends a GET request to ``{api_base}/beta/litellm_prompt_management`` with
    ``prompt_id`` and the instance's
    ``additional_provider_specific_query_params`` as query parameters.

    Args:
        prompt_id: The ID of the prompt to fetch
        prompt_spec: Optional prompt spec. Here it is only used to validate
            that at least one of ``prompt_id`` / ``prompt_spec`` was given.

    Returns:
        The prompt data from the API

    Raises:
        ValueError: If both ``prompt_id`` and ``prompt_spec`` are None
        Exception: If the API request fails or the response is not valid JSON
            (the original error is attached as ``__cause__``)
    """
    if prompt_id is None and prompt_spec is None:
        raise ValueError("prompt_id or prompt_spec is required")

    url = f"{self.api_base}/beta/litellm_prompt_management"
    params = {
        "prompt_id": prompt_id,
        **(self.additional_provider_specific_query_params or {}),
    }
    # NOTE(review): self.timeout is not passed to the client or the request
    # here - confirm whether the client's default timeout is intended.
    http_client = _get_httpx_client()

    try:
        response = http_client.get(
            url,
            params=params,
            headers=self._get_headers(),
        )

        response.raise_for_status()
        return response.json()
    except httpx.HTTPError as e:
        raise Exception(
            f"Failed to fetch prompt '{prompt_id}' from API: {e}"
        ) from e
    except json.JSONDecodeError as e:
        raise Exception(
            f"Failed to parse prompt response for '{prompt_id}': {e}"
        ) from e
|
||||
|
||||
async def async_fetch_prompt_from_api(
    self, prompt_id: Optional[str], prompt_spec: Optional[PromptSpec]
) -> Dict[str, Any]:
    """
    Fetch a prompt from the API asynchronously.

    Sends a GET request to ``{api_base}/beta/litellm_prompt_management`` with
    ``prompt_id`` plus, when present,
    ``prompt_spec.litellm_params.provider_specific_query_params`` as query
    parameters.

    Args:
        prompt_id: The ID of the prompt to fetch
        prompt_spec: Optional prompt spec supplying extra query parameters

    Returns:
        The prompt data from the API

    Raises:
        ValueError: If both ``prompt_id`` and ``prompt_spec`` are None
        Exception: If the API request fails or the response is not valid JSON
            (the original error is attached as ``__cause__``)
    """
    if prompt_id is None and prompt_spec is None:
        raise ValueError("prompt_id or prompt_spec is required")

    url = f"{self.api_base}/beta/litellm_prompt_management"
    # NOTE(review): the extra query parameters come from prompt_spec here, not
    # from self.additional_provider_specific_query_params - confirm intended.
    params = {
        "prompt_id": prompt_id,
        **(
            prompt_spec.litellm_params.provider_specific_query_params
            if prompt_spec
            and prompt_spec.litellm_params.provider_specific_query_params
            else {}
        ),
    }

    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.PromptManagement,
    )

    try:
        response = await http_client.get(
            url,
            params=params,
            headers=self._get_headers(),
        )
        response.raise_for_status()
        return response.json()
    except httpx.HTTPError as e:
        raise Exception(
            f"Failed to fetch prompt '{prompt_id}' from API: {e}"
        ) from e
    except json.JSONDecodeError as e:
        raise Exception(
            f"Failed to parse prompt response for '{prompt_id}': {e}"
        ) from e
|
||||
|
||||
def _parse_api_response(
    self,
    prompt_id: Optional[str],
    prompt_spec: Optional[PromptSpec],
    api_response: Dict[str, Any],
) -> PromptManagementClient:
    """
    Convert a raw API response into a PromptManagementClient structure.

    Expected API response format:
    {
        "prompt_id": "string",
        "prompt_template": [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."}
        ],
        "prompt_template_model": "gpt-4",  # optional
        "prompt_template_optional_params": {  # optional
            "temperature": 0.7,
            "max_tokens": 100
        }
    }

    Args:
        prompt_id: The ID of the prompt; used as-is (the response's own
            "prompt_id" field is not read)
        prompt_spec: Accepted but not used by this method
        api_response: The response from the API

    Returns:
        PromptManagementClient structure; a missing "prompt_template"
        becomes an empty list and ``completed_messages`` is always None
    """
    template = api_response.get("prompt_template", [])
    template_model = api_response.get("prompt_template_model")
    template_optional_params = api_response.get("prompt_template_optional_params")

    return PromptManagementClient(
        prompt_id=prompt_id,
        prompt_template=template,
        prompt_template_model=template_model,
        prompt_template_optional_params=template_optional_params,
        completed_messages=None,
    )
|
||||
|
||||
def should_run_prompt_management(
    self,
    prompt_id: Optional[str],
    prompt_spec: Optional[PromptSpec],
    dynamic_callback_params: StandardCallbackDynamicParams,
) -> bool:
    """
    Decide whether prompt management should run for this request.

    Returns True when a ``prompt_id`` is given, or when ``prompt_spec`` is
    given and carries ``provider_specific_query_params``; False otherwise.
    ``dynamic_callback_params`` is not consulted.
    """
    if prompt_id is not None:
        return True
    if prompt_spec is None:
        return False
    return prompt_spec.litellm_params.provider_specific_query_params is not None
|
||||
|
||||
def _get_cache_key(
    self,
    prompt_id: Optional[str],
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
) -> str:
    """Build the prompt cache key: id, label and version joined by ':' (None renders as 'None')."""
    key_parts = (prompt_id, prompt_label, prompt_version)
    return ":".join(str(part) for part in key_parts)
|
||||
|
||||
def _common_caching_logic(
    self,
    prompt_id: Optional[str],
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
    prompt_variables: Optional[dict] = None,
) -> Optional[PromptManagementClient]:
    """
    Cache lookup shared by the sync and async compile helpers.

    Returns None on a cache miss. On a hit, returns the cached prompt, passed
    through ``_apply_variables`` first when ``prompt_variables`` is non-empty.
    """
    lookup_key = self._get_cache_key(prompt_id, prompt_label, prompt_version)
    try:
        cached_entry = self._prompt_cache[lookup_key]
    except KeyError:
        return None

    if not prompt_variables:
        return cached_entry
    # Variables requested: hand back the result of applying them
    return self._apply_variables(cached_entry, prompt_variables)
|
||||
|
||||
def _compile_prompt_helper(
    self,
    prompt_id: Optional[str],
    prompt_spec: Optional[PromptSpec],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
) -> PromptManagementClient:
    """
    Compile a prompt template into a PromptManagementClient structure.

    This method:
    1. Returns the cached prompt when one exists for the cache key
    2. Otherwise fetches the prompt from the API, parses and caches it
    3. Applies any prompt variables and returns the structured prompt data

    Args:
        prompt_id: The ID of the prompt
        prompt_spec: Optional prompt spec, passed on to fetching/parsing
        prompt_variables: Variables to substitute in the template (optional)
        dynamic_callback_params: Dynamic callback parameters (not used here)
        prompt_label: Optional label for the prompt version
        prompt_version: Optional specific version number

    Returns:
        PromptManagementClient structure

    Raises:
        ValueError: If fetching, parsing or variable substitution fails; the
            original error is attached as ``__cause__``
    """
    cached_prompt = self._common_caching_logic(
        prompt_id=prompt_id,
        prompt_label=prompt_label,
        prompt_version=prompt_version,
        prompt_variables=prompt_variables,
    )
    if cached_prompt:
        return cached_prompt

    # NOTE(review): the key is built from id/label/version only, not from
    # prompt_spec - confirm that is sufficient when prompt_id is None.
    cache_key = self._get_cache_key(prompt_id, prompt_label, prompt_version)
    try:
        # Fetch from API
        api_response = self._fetch_prompt_from_api(prompt_id, prompt_spec)

        # Parse the response
        prompt_client = self._parse_api_response(
            prompt_id, prompt_spec, api_response
        )

        # Cache the result (before variables are applied)
        self._prompt_cache[cache_key] = prompt_client

        # Apply variables if provided
        if prompt_variables:
            prompt_client = self._apply_variables(prompt_client, prompt_variables)

        return prompt_client

    except Exception as e:
        raise ValueError(f"Error compiling prompt '{prompt_id}': {e}") from e
|
||||
|
||||
async def async_compile_prompt_helper(
    self,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_spec: Optional[PromptSpec] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
) -> PromptManagementClient:
    """
    Async counterpart of `_compile_prompt_helper`.

    Returns the result of `_common_caching_logic` when it is truthy; otherwise
    awaits `async_fetch_prompt_from_api`, parses the response, stores the
    parsed prompt in `self._prompt_cache` and applies `prompt_variables`.

    Args:
        prompt_id: The ID of the prompt
        prompt_variables: Variables to substitute in the template (optional)
        dynamic_callback_params: Dynamic callback parameters (not read by this method)
        prompt_spec: Optional prompt spec, forwarded to the fetch and parse helpers
        prompt_label: Optional label for the prompt version (part of the cache key)
        prompt_version: Optional specific version number (part of the cache key)

    Returns:
        PromptManagementClient structure

    Raises:
        ValueError: If fetching, parsing or variable substitution fails. The
            original exception is attached as ``__cause__``.
    """
    # Check cache first
    cached_prompt = self._common_caching_logic(
        prompt_id=prompt_id,
        prompt_label=prompt_label,
        prompt_version=prompt_version,
        prompt_variables=prompt_variables,
    )
    if cached_prompt:
        return cached_prompt

    cache_key = self._get_cache_key(prompt_id, prompt_label, prompt_version)

    try:
        # Fetch from API
        api_response = await self.async_fetch_prompt_from_api(
            prompt_id=prompt_id, prompt_spec=prompt_spec
        )

        # Parse the response
        prompt_client = self._parse_api_response(
            prompt_id, prompt_spec, api_response
        )

        # Cache the template before substitution so the cached entry stays
        # independent of the variables of this particular call.
        self._prompt_cache[cache_key] = prompt_client

        # Apply variables if provided
        if prompt_variables:
            prompt_client = self._apply_variables(prompt_client, prompt_variables)

        return prompt_client

    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(
            f"Error compiling prompt '{prompt_id}': {e}, prompt_spec: {prompt_spec}"
        ) from e
|
||||
|
||||
def _apply_variables(
    self,
    prompt_client: PromptManagementClient,
    variables: Dict[str, Any],
) -> PromptManagementClient:
    """
    Apply variables to the prompt template.

    Performs plain string substitution on every message whose ``content`` is
    a string. Both ``{{variable_name}}`` and ``{variable_name}`` placeholders
    are replaced with ``str(value)``. Messages with non-string content are
    copied unchanged. The input structure is not mutated.

    Args:
        prompt_client: The prompt client structure
        variables: Variables to substitute

    Returns:
        A new PromptManagementClient with the substituted messages and
        ``completed_messages`` reset to None
    """
    updated_messages: List[AllMessageValues] = []
    for message in prompt_client["prompt_template"]:
        # Shallow copy so the (possibly cached) template is left untouched.
        updated_message = dict(message)  # type: ignore
        content = updated_message.get("content")
        if isinstance(content, str):
            for key, value in variables.items():
                replacement = str(value)
                # The double-brace form must go first: "{key}" is a substring
                # of "{{key}}", so replacing it first would turn "{{key}}"
                # into "{value}" and leave stray braces behind.
                content = content.replace(f"{{{{{key}}}}}", replacement)
                content = content.replace(f"{{{key}}}", replacement)
            updated_message["content"] = content
        updated_messages.append(updated_message)  # type: ignore

    return PromptManagementClient(
        prompt_id=prompt_client["prompt_id"],
        prompt_template=updated_messages,
        prompt_template_model=prompt_client["prompt_template_model"],
        prompt_template_optional_params=prompt_client[
            "prompt_template_optional_params"
        ],
        completed_messages=None,
    )
|
||||
|
||||
async def async_get_chat_completion_prompt(
    self,
    model: str,
    messages: List[AllMessageValues],
    non_default_params: dict,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    litellm_logging_obj: "LiteLLMLoggingObj",
    prompt_spec: Optional[PromptSpec] = None,
    tools: Optional[List[Dict]] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
    ignore_prompt_manager_model: Optional[bool] = False,
    ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
    """
    Get chat completion prompt and return processed model, messages, and parameters.

    Delegates to `PromptManagementBase.async_get_chat_completion_prompt`. Each
    ``ignore_*`` flag is enabled when either the caller passes a truthy value
    or the matching field on ``prompt_spec.litellm_params`` is truthy.

    Returns:
        Tuple of (model, messages, optional params) as produced by the base class.
    """
    # Parenthesised on purpose: without the parentheses the conditional
    # expression binds looser than `or`, which dropped the caller's explicit
    # flag whenever `prompt_spec` was None.
    ignore_model = ignore_prompt_manager_model or (
        prompt_spec.litellm_params.ignore_prompt_manager_model
        if prompt_spec
        else False
    )
    ignore_optional_params = ignore_prompt_manager_optional_params or (
        prompt_spec.litellm_params.ignore_prompt_manager_optional_params
        if prompt_spec
        else False
    )

    return await PromptManagementBase.async_get_chat_completion_prompt(
        self,
        model,
        messages,
        non_default_params,
        prompt_id=prompt_id,
        prompt_variables=prompt_variables,
        litellm_logging_obj=litellm_logging_obj,
        dynamic_callback_params=dynamic_callback_params,
        prompt_spec=prompt_spec,
        tools=tools,
        prompt_label=prompt_label,
        prompt_version=prompt_version,
        ignore_prompt_manager_model=ignore_model,
        ignore_prompt_manager_optional_params=ignore_optional_params,
    )
|
||||
|
||||
def get_chat_completion_prompt(
    self,
    model: str,
    messages: List[AllMessageValues],
    non_default_params: dict,
    prompt_id: Optional[str],
    prompt_variables: Optional[dict],
    dynamic_callback_params: StandardCallbackDynamicParams,
    prompt_spec: Optional[PromptSpec] = None,
    prompt_label: Optional[str] = None,
    prompt_version: Optional[int] = None,
    ignore_prompt_manager_model: Optional[bool] = False,
    ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
    """
    Get chat completion prompt and return processed model, messages, and parameters.

    Delegates to `PromptManagementBase.get_chat_completion_prompt`. Each
    ``ignore_*`` flag is enabled when either the caller passes a truthy value
    or the matching field on ``prompt_spec.litellm_params`` is truthy.

    Returns:
        Tuple of (model, messages, optional params) as produced by the base class.
    """
    # Parenthesised on purpose: without the parentheses the conditional
    # expression binds looser than `or`, which dropped the caller's explicit
    # flag whenever `prompt_spec` was None.
    ignore_model = ignore_prompt_manager_model or (
        prompt_spec.litellm_params.ignore_prompt_manager_model
        if prompt_spec
        else False
    )
    ignore_optional_params = ignore_prompt_manager_optional_params or (
        prompt_spec.litellm_params.ignore_prompt_manager_optional_params
        if prompt_spec
        else False
    )

    return PromptManagementBase.get_chat_completion_prompt(
        self,
        model,
        messages,
        non_default_params,
        prompt_id=prompt_id,
        prompt_variables=prompt_variables,
        dynamic_callback_params=dynamic_callback_params,
        prompt_spec=prompt_spec,
        prompt_label=prompt_label,
        prompt_version=prompt_version,
        ignore_prompt_manager_model=ignore_model,
        ignore_prompt_manager_optional_params=ignore_optional_params,
    )
|
||||
|
||||
def clear_cache(self) -> None:
    """Empty the prompt cache in place, discarding every stored entry."""
    cache = self._prompt_cache
    cache.clear()
|
||||
@@ -0,0 +1,317 @@
|
||||
# LiteLLM gitlab Prompt Management
|
||||
|
||||
A powerful prompt management system for LiteLLM that fetches `.prompt` files from GitLab repositories. This enables team-based prompt management with GitLab's built-in access control and version control capabilities.
|
||||
|
||||
## Features
|
||||
|
||||
- **🏢 Team-based access control**: Leverage gitlab's workspace and repository permissions
|
||||
- **📁 Repository-based prompt storage**: Store prompts in gitlab repositories
|
||||
- **🔐 Multiple authentication methods**: Support for access tokens and basic auth
|
||||
- **🎯 YAML frontmatter**: Define model, parameters, and schemas in file headers
|
||||
- **🔧 Handlebars templating**: Use `{{variable}}` syntax with Jinja2 backend
|
||||
- **✅ Input validation**: Automatic validation against defined schemas
|
||||
- **🔗 LiteLLM integration**: Works seamlessly with `litellm.completion()`
|
||||
- **💬 Smart message parsing**: Converts prompts to proper chat messages
|
||||
- **⚙️ Parameter extraction**: Automatically applies model settings from prompts
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Set up gitlab Repository
|
||||
|
||||
Create a repository in your gitlab workspace and add `.prompt` files:
|
||||
|
||||
```
|
||||
your-repo/
|
||||
├── prompts/
|
||||
│ ├── chat_assistant.prompt
|
||||
│ ├── code_reviewer.prompt
|
||||
│ └── data_analyst.prompt
|
||||
```
|
||||
|
||||
### 2. Create a `.prompt` file
|
||||
|
||||
Create a file called `prompts/chat_assistant.prompt`:
|
||||
|
||||
```yaml
|
||||
---
|
||||
model: gpt-4
|
||||
temperature: 0.7
|
||||
max_tokens: 150
|
||||
input:
|
||||
schema:
|
||||
user_message: string
|
||||
system_context?: string
|
||||
---
|
||||
|
||||
{% if system_context %}System: {{system_context}}
|
||||
|
||||
{% endif %}User: {{user_message}}
|
||||
```
|
||||
|
||||
### 3. Configure gitlab Access
|
||||
|
||||
#### Option A: Access Token (Recommended)
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
# Configure gitlab access
|
||||
gitlab_config = {
|
||||
"project": "a/b/<repo_name>",
|
||||
"access_token": "your-access-token",
|
||||
"base_url": "gitlab url",
|
||||
"prompts_path": "src/prompts", # folder to point to, defaults to root
|
||||
"branch":"main" # optional, defaults to main
|
||||
}
|
||||
|
||||
# Set global gitlab configuration
|
||||
litellm.set_global_gitlab_config(gitlab_config)
|
||||
```
|
||||
|
||||
#### Option B: Basic Authentication
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
# Configure gitlab access with basic auth
|
||||
gitlab_config = {
|
||||
"project": "a/b/<repo_name>",
|
||||
"base_url": "base url",
|
||||
"access_token": "your-app-password", # Use app password for basic auth
|
||||
"branch": "main",
|
||||
"prompts_path": "src/prompts", # folder to point to, defaults to root
|
||||
}
|
||||
|
||||
litellm.set_global_gitlab_config(gitlab_config)
|
||||
```
|
||||
|
||||
### 4. Use with LiteLLM
|
||||
|
||||
```python
|
||||
# Use with completion - the model prefix 'gitlab/' tells LiteLLM to use gitlab prompt management
|
||||
response = litellm.completion(
|
||||
model="gitlab/gpt-4", # The actual model comes from the .prompt file
|
||||
prompt_id="prompts/chat_assistant", # Location of the prompt file
|
||||
prompt_variables={
|
||||
"user_message": "What is machine learning?",
|
||||
"system_context": "You are a helpful AI tutor."
|
||||
},
|
||||
# Any additional messages will be appended after the prompt
|
||||
messages=[{"role": "user", "content": "Please explain it simply."}]
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Proxy Server Configuration
|
||||
|
||||
### 1. Create a `.prompt` file
|
||||
|
||||
Create `prompts/hello.prompt`:
|
||||
|
||||
```yaml
|
||||
---
|
||||
model: gpt-4
|
||||
temperature: 0.7
|
||||
---
|
||||
System: You are a helpful assistant.
|
||||
|
||||
User: {{user_message}}
|
||||
```
|
||||
|
||||
### 2. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: my-gitlab-model
|
||||
litellm_params:
|
||||
model: gitlab/gpt-4
|
||||
prompt_id: "prompts/hello"
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
global_gitlab_config:
|
||||
project: "your-group/your-repo"
|
||||
access_token: "your-access-token"
|
||||
branch: "main"
|
||||
```
|
||||
|
||||
### 3. Start the proxy
|
||||
|
||||
```bash
|
||||
litellm --config config.yaml --detailed_debug
|
||||
```
|
||||
|
||||
### 4. Test it!
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-d '{
|
||||
"model": "my-gitlab-model",
|
||||
"messages": [{"role": "user", "content": "IGNORED"}],
|
||||
"prompt_variables": {
|
||||
"user_message": "What is the capital of France?"
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## Prompt File Format
|
||||
|
||||
### Basic Structure
|
||||
|
||||
```yaml
|
||||
---
|
||||
# Model configuration
|
||||
model: gpt-4
|
||||
temperature: 0.7
|
||||
max_tokens: 500
|
||||
|
||||
# Input schema (optional)
|
||||
input:
|
||||
schema:
|
||||
user_message: string
|
||||
system_context?: string
|
||||
---
|
||||
|
||||
System: You are a helpful {{role}} assistant.
|
||||
|
||||
User: {{user_message}}
|
||||
```
|
||||
|
||||
### Advanced Features
|
||||
|
||||
**Multi-role conversations:**
|
||||
|
||||
```yaml
|
||||
---
|
||||
model: gpt-4
|
||||
temperature: 0.3
|
||||
---
|
||||
System: You are a helpful coding assistant.
|
||||
|
||||
User: {{user_question}}
|
||||
```
|
||||
|
||||
**Dynamic model selection:**
|
||||
|
||||
```yaml
|
||||
---
|
||||
model: "{{preferred_model}}" # Model can be a variable
|
||||
temperature: 0.7
|
||||
---
|
||||
System: You are a helpful assistant specialized in {{domain}}.
|
||||
|
||||
User: {{user_message}}
|
||||
```
|
||||
|
||||
## Team-Based Access Control
|
||||
|
||||
gitlab's built-in permission system provides team-based access control:
|
||||
|
||||
1. **Workspace-level permissions**: Control access to entire workspaces
|
||||
2. **Repository-level permissions**: Control access to specific repositories
|
||||
3. **Branch-level permissions**: Control access to specific branches
|
||||
4. **User and group management**: Manage team members and their access levels
|
||||
|
||||
### Setting up Team Access
|
||||
|
||||
1. **Create workspaces for each team**:
|
||||
```
|
||||
team-a-prompts/
|
||||
team-b-prompts/
|
||||
team-c-prompts/
|
||||
```
|
||||
|
||||
2. **Configure repository permissions**:
|
||||
- Grant read access to team members
|
||||
- Grant write access to prompt maintainers
|
||||
- Use branch protection rules for production prompts
|
||||
|
||||
3. **Use different access tokens**:
|
||||
- Each team can have their own access token
|
||||
- Tokens can be scoped to specific repositories
|
||||
- Use app passwords for additional security
|
||||
|
||||
## API Reference
|
||||
|
||||
### gitlab Configuration
|
||||
|
||||
```python
|
||||
gitlab_config = {
    "project": str,       # Required: Project path ("group/subgroup/repo") or numeric project ID
    "access_token": str,  # Required: GitLab personal/access token or OAuth token
    "auth_method": str,   # Optional: "token" (Private-Token header) or "oauth" (Bearer token); default: "token"
    "tag": str,           # Optional: Tag to fetch from (takes precedence over branch)
    "branch": str,        # Optional: Branch to fetch from (default: "main")
    "base_url": str,      # Optional: GitLab API base URL (default: "https://gitlab.com/api/v4")
    "prompts_path": str,  # Optional: Folder containing the .prompt files (defaults to the repository root)
}
|
||||
```
|
||||
|
||||
### LiteLLM Integration
|
||||
|
||||
```python
|
||||
response = litellm.completion(
|
||||
model="gitlab/<base_model>", # required (e.g., gitlab/gpt-4)
|
||||
prompt_id=str, # required - the .prompt filename without extension
|
||||
prompt_variables=dict, # optional - variables for template rendering
|
||||
gitlab_config=dict, # optional - gitlab configuration (if not set globally)
|
||||
messages=list, # optional - additional messages
|
||||
)
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The gitlab integration provides detailed error messages for common issues:
|
||||
|
||||
- **Authentication errors**: Invalid access tokens or credentials
|
||||
- **Permission errors**: Insufficient access to workspace/repository
|
||||
- **File not found**: Missing .prompt files
|
||||
- **Network errors**: Connection issues with gitlab API
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Access Token Security**: Store access tokens securely using environment variables or secret management systems
|
||||
2. **Repository Permissions**: Use gitlab's permission system to control access
|
||||
3. **Branch Protection**: Protect main branches from unauthorized changes
|
||||
4. **Audit Logging**: gitlab provides audit logs for all repository access
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **"Access denied" errors**: Check your gitlab permissions for the workspace and repository
|
||||
2. **"Authentication failed" errors**: Verify your access token or credentials
|
||||
3. **"File not found" errors**: Ensure the .prompt file exists in the specified branch
|
||||
4. **Template rendering errors**: Check your Handlebars syntax in the .prompt file
|
||||
|
||||
### Debug Mode
|
||||
|
||||
Enable debug logging to troubleshoot issues:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
litellm.set_verbose = True
|
||||
|
||||
# Your gitlab prompt calls will now show detailed logs
|
||||
response = litellm.completion(
|
||||
model="gitlab/gpt-4",
|
||||
prompt_id="your_prompt",
|
||||
prompt_variables={"key": "value"}
|
||||
)
|
||||
```
|
||||
|
||||
## Migration from File-Based Prompts
|
||||
|
||||
If you're currently using file-based prompts with the dotprompt integration, you can easily migrate to gitlab:
|
||||
|
||||
1. **Upload your .prompt files** to a gitlab repository
|
||||
2. **Update your configuration** to use gitlab instead of local files
|
||||
3. **Set up team access** using gitlab's permission system
|
||||
4. **Update your code** to use `gitlab/` model prefix instead of `dotprompt/`
|
||||
|
||||
This provides better collaboration, version control, and team-based access control for your prompts.
|
||||
@@ -0,0 +1,94 @@
|
||||
from typing import TYPE_CHECKING, Optional, Dict, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .gitlab_prompt_manager import GitLabPromptManager
|
||||
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
|
||||
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
from litellm.types.prompts.init_prompts import PromptSpec, PromptLiteLLMParams
|
||||
from .gitlab_prompt_manager import GitLabPromptManager, GitLabPromptCache
|
||||
|
||||
# Global instances
|
||||
global_gitlab_config: Optional[dict] = None
|
||||
|
||||
|
||||
def set_global_gitlab_config(config: dict) -> None:
    """
    Store `config` as the process-wide GitLab configuration for prompt management.

    The dictionary is assigned as-is to ``litellm.global_gitlab_config``; it
    is not validated or copied here.

    Args:
        config: Dictionary containing GitLab configuration. The GitLab client
            in this package reads, among others:
            - project: Project path or numeric project ID
            - access_token: GitLab access token
            - branch: Branch to fetch prompts from (default: main)
    """
    # Imported lazily inside the function, as in the original implementation.
    import litellm

    setattr(litellm, "global_gitlab_config", config)
|
||||
|
||||
|
||||
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Initialize a prompt from a Gitlab repository.

    Args:
        litellm_params: Object whose ``gitlab_config`` and ``prompt_id``
            attributes are read (a missing attribute is treated as None).
        prompt_spec: Not read by this function.

    Returns:
        A GitLabPromptManager built from the config and prompt id.

    Raises:
        ValueError: If ``gitlab_config`` is missing or empty.
    """
    gitlab_config = getattr(litellm_params, "gitlab_config", None)
    prompt_id = getattr(litellm_params, "prompt_id", None)

    if not gitlab_config:
        raise ValueError("gitlab_config is required for gitlab prompt integration")

    # Construction errors propagate unchanged; the previous
    # `except Exception as e: raise e` wrapper added nothing.
    return GitLabPromptManager(
        gitlab_config=gitlab_config,
        prompt_id=prompt_id,
    )
|
||||
|
||||
|
||||
def _gitlab_prompt_initializer(
    litellm_params: PromptLiteLLMParams,
    prompt: PromptSpec,
) -> CustomPromptManagement:
    """
    Create the GitLab-backed prompt manager for a single prompt.

    Attributes read from `litellm_params` (each defaults to None when absent):
      - gitlab_config: Dict[str, Any] with the GitLab settings
        (project/access_token/branch/prompts_path/etc.); required
      - git_ref: optional per-prompt tag/branch/SHA override

    `prompt.prompt_id` is passed through as the manager's prompt id; it can
    map to a file path under prompts_path (e.g. "chat/greet/hi").

    Raises:
        ValueError: If `gitlab_config` is missing or empty.
    """
    config: Dict[str, Any] = getattr(litellm_params, "gitlab_config", None) or {}
    if not config:
        raise ValueError("gitlab_config is required for gitlab prompt integration")

    ref_override: Optional[str] = getattr(litellm_params, "git_ref", None)
    manager = GitLabPromptManager(
        gitlab_config=config,
        prompt_id=prompt.prompt_id,
        ref=ref_override,
    )
    return manager
|
||||
|
||||
|
||||
prompt_initializer_registry = {
|
||||
SupportedPromptIntegrations.GITLAB.value: _gitlab_prompt_initializer,
|
||||
}
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"GitLabPromptManager",
|
||||
"GitLabPromptCache",
|
||||
"set_global_gitlab_config",
|
||||
"global_gitlab_config",
|
||||
]
|
||||
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
GitLab API client for fetching files from GitLab repositories.
|
||||
Now supports selecting a tag via `config["tag"]`; falls back to branch ("main").
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
|
||||
class GitLabClient:
    """
    Client for interacting with the GitLab API to fetch files.

    Supports:
    - Authentication with personal/access tokens or OAuth bearer tokens
    - Fetching file contents from repositories (raw endpoint with JSON fallback)
    - Namespace/project path or numeric project ID addressing
    - Ref selection via tag (preferred) or branch (default "main")
    - Directory listing via the repository tree API
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the GitLab client.

        Args:
            config: Dictionary containing:
                - project: Project path ("group/subgroup/repo") or numeric project ID (str|int) [required]
                - access_token: GitLab personal/access token or OAuth token [required] (str)
                - auth_method: 'token' (default; sends Private-Token) or 'oauth' (Authorization: Bearer)
                - tag: Tag name to fetch from (takes precedence over branch if provided)
                - branch: Branch to fetch from (default: "main")
                - base_url: Base GitLab API URL (default: "https://gitlab.com/api/v4")

        Raises:
            ValueError: If `project` or `access_token` is missing or empty.
        """
        project = config.get("project")
        access_token = config.get("access_token")
        # One check covers both missing (None) and empty/falsy values.
        if not project or access_token is None or not str(access_token):
            raise ValueError("project and access_token are required")

        self.project: str | int = project
        self.access_token: str = str(access_token)
        self.auth_method = config.get("auth_method", "token")  # 'token' or 'oauth'
        self.branch = config.get("branch") or "main"
        self.tag = config.get("tag")
        # Fall back to the default for None/empty values and drop a trailing
        # slash so the URL builders never emit "//projects/...".
        self.base_url = str(
            config.get("base_url") or "https://gitlab.com/api/v4"
        ).rstrip("/")

        # Effective ref: prefer tag if provided, else branch ("main")
        self.ref = str(self.tag or self.branch)

        # Build headers
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if self.auth_method == "oauth":
            self.headers["Authorization"] = f"Bearer {self.access_token}"
        else:
            # Default GitLab token header
            self.headers["Private-Token"] = self.access_token

        # Project identifier must be URL-encoded (slashes become %2F)
        self._project_enc = quote(str(self.project), safe="")

        # HTTP handler
        self.http_handler = HTTPHandler()

    # ------------------------
    # Core helpers
    # ------------------------

    def _file_raw_url(self, file_path: str, *, ref: Optional[str] = None) -> str:
        """Build the raw-file endpoint URL for `file_path` at `ref` (default: self.ref)."""
        file_enc = quote(file_path, safe="")
        ref_q = quote(ref or self.ref, safe="")
        return f"{self.base_url}/projects/{self._project_enc}/repository/files/{file_enc}/raw?ref={ref_q}"

    def _file_json_url(self, file_path: str, *, ref: Optional[str] = None) -> str:
        """Build the JSON file endpoint URL for `file_path` at `ref` (default: self.ref)."""
        file_enc = quote(file_path, safe="")
        ref_q = quote(ref or self.ref, safe="")
        return f"{self.base_url}/projects/{self._project_enc}/repository/files/{file_enc}?ref={ref_q}"

    def _tree_url(
        self,
        directory_path: str = "",
        recursive: bool = False,
        *,
        ref: Optional[str] = None,
    ) -> str:
        """Build the repository-tree endpoint URL, adding `path`/`recursive` only when set."""
        path_q = f"&path={quote(directory_path, safe='')}" if directory_path else ""
        rec_q = "&recursive=true" if recursive else ""
        ref_q = quote(ref or self.ref, safe="")
        return f"{self.base_url}/projects/{self._project_enc}/repository/tree?ref={ref_q}{path_q}{rec_q}"

    @staticmethod
    def _status_code_of(error: BaseException) -> Optional[int]:
        """Return `error.response.status_code` if the exception carries one, else None."""
        return getattr(getattr(error, "response", None), "status_code", None)

    def _translate_error(
        self, error: BaseException, *, resource: str, default_message: str
    ) -> Exception:
        """
        Map an HTTP failure to the exception this client raises.

        403 and 401 get dedicated messages; anything else uses `default_message`.
        `resource` is the human-readable target, e.g. "file 'a.prompt'".
        """
        status = self._status_code_of(error)
        if status == 403:
            return Exception(
                f"Access denied to {resource}. Check your GitLab permissions for project '{self.project}'."
            )
        if status == 401:
            return Exception(
                "Authentication failed. Check your GitLab token and auth_method."
            )
        return Exception(default_message)

    @staticmethod
    def _decode_response_body(resp: Any) -> str:
        """Return the response body as text, decoding bytes as UTF-8 when needed."""
        ctype = (resp.headers.get("content-type") or "").lower()
        if (
            ctype.startswith("text/")
            or "charset=" in ctype
            or ctype.startswith("application/json")
        ):
            return resp.text
        try:
            return resp.content.decode("utf-8")
        except UnicodeDecodeError:
            return resp.content.decode("utf-8", errors="replace")

    # ------------------------
    # Public API
    # ------------------------

    def set_ref(self, ref: str) -> None:
        """Override the default ref (tag/branch) for subsequent calls."""
        if not ref:
            raise ValueError("ref must be a non-empty string")
        self.ref = ref

    def get_file_content(
        self, file_path: str, *, ref: Optional[str] = None
    ) -> Optional[str]:
        """
        Fetch the content of a file from the GitLab repository at the given ref
        (tag, branch, or commit SHA). If `ref` is None, uses self.ref.

        Strategy:
        1) Try the RAW endpoint (returns bytes of the file)
        2) Fallback to the JSON endpoint (returns base64-encoded content)

        Returns:
            File content as UTF-8 string, or None if file not found.

        Raises:
            Exception: On access denied (403), authentication failure (401)
                or any other fetch error; the original error is chained.
        """
        raw_url = self._file_raw_url(file_path, ref=ref)

        try:
            resp = self.http_handler.get(raw_url, headers=self.headers)
            if resp.status_code != 404:
                resp.raise_for_status()
                return self._decode_response_body(resp)
        except Exception as e:
            if self._status_code_of(e) == 404:
                return None
            raise self._translate_error(
                e,
                resource=f"file '{file_path}'",
                default_message=f"Failed to fetch file '{file_path}': {e}",
            ) from e

        # The raw endpoint answered 404. The fallback runs outside the try
        # block so errors it has already translated are not wrapped again.
        return self._get_file_content_via_json(file_path, ref=ref)

    def _get_file_content_via_json(
        self, file_path: str, *, ref: Optional[str] = None
    ) -> Optional[str]:
        """
        Fallback for get_file_content(): use the JSON file API which returns base64 content.

        Returns the decoded text, the raw `content` field when it is not
        base64-encoded, or None when the file is not found.
        """
        json_url = self._file_json_url(file_path, ref=ref)
        try:
            resp = self.http_handler.get(json_url, headers=self.headers)
            if resp.status_code == 404:
                return None
            resp.raise_for_status()
            data = resp.json()
            content = data.get("content")
            if content and data.get("encoding", "") == "base64":
                raw = base64.b64decode(content)
                try:
                    return raw.decode("utf-8")
                except UnicodeDecodeError:
                    return raw.decode("utf-8", errors="replace")
            return content
        except Exception as e:
            if self._status_code_of(e) == 404:
                return None
            raise self._translate_error(
                e,
                resource=f"file '{file_path}'",
                default_message=f"Failed to fetch file '{file_path}' via JSON endpoint: {e}",
            ) from e

    def list_files(
        self,
        directory_path: str = "",
        file_extension: str = ".prompt",
        recursive: bool = False,
        *,
        ref: Optional[str] = None,
    ) -> List[str]:
        """
        List files in a directory with a specific extension using the repository tree API.

        Args:
            directory_path: Directory path in the repository (empty for repo root)
            file_extension: File extension to filter by (default: .prompt);
                an empty value disables filtering
            recursive: If True, traverses subdirectories
            ref: Optional override (tag/branch/SHA). Defaults to self.ref.

        Returns:
            List of file paths (relative to repo root); empty if the directory
            is not found.
        """
        url = self._tree_url(directory_path, recursive=recursive, ref=ref)

        try:
            resp = self.http_handler.get(url, headers=self.headers)
            if resp.status_code == 404:
                return []
            resp.raise_for_status()

            # NOTE(review): GitLab paginates tree listings; only the single
            # response for `url` is read here - confirm whether large
            # directories need paging.
            files: List[str] = []
            for item in resp.json() or []:
                if item.get("type") != "blob":
                    continue
                file_path = item.get("path", "")
                if not file_extension or file_path.endswith(file_extension):
                    files.append(file_path)
            return files

        except Exception as e:
            if self._status_code_of(e) == 404:
                return []
            raise self._translate_error(
                e,
                resource=f"directory '{directory_path}'",
                default_message=f"Failed to list files in '{directory_path}': {e}",
            ) from e

    def get_repository_info(self) -> Dict[str, Any]:
        """Get information about the project/repository."""
        url = f"{self.base_url}/projects/{self._project_enc}"
        try:
            resp = self.http_handler.get(url, headers=self.headers)
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            raise Exception(f"Failed to get repository info: {e}") from e

    def test_connection(self) -> bool:
        """Test the connection to the GitLab project."""
        try:
            self.get_repository_info()
            return True
        except Exception:
            # Deliberate best-effort probe: any failure means "not reachable".
            return False

    def get_branches(self) -> List[Dict[str, Any]]:
        """Get list of branches in the repository."""
        url = f"{self.base_url}/projects/{self._project_enc}/repository/branches"
        try:
            resp = self.http_handler.get(url, headers=self.headers)
            resp.raise_for_status()
            data = resp.json()
            return data if isinstance(data, list) else []
        except Exception as e:
            raise Exception(f"Failed to get branches: {e}") from e

    def get_file_metadata(
        self, file_path: str, *, ref: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get minimal metadata about a file via RAW endpoint headers at a given ref.

        Args:
            file_path: Path to the file in the repository.
            ref: Optional override (tag/branch/SHA). Defaults to self.ref.

        Returns:
            Dict with `content_type`, `content_length` and `last_modified`
            taken from the response headers, or None if the file is not found.
        """
        url = self._file_raw_url(file_path, ref=ref)
        try:
            # Ask for a single byte: only the response headers are needed.
            headers = dict(self.headers)
            headers["Range"] = "bytes=0-0"
            resp = self.http_handler.get(url, headers=headers)
            if resp.status_code == 404:
                return None
            resp.raise_for_status()
            return {
                "content_type": resp.headers.get("content-type"),
                "content_length": resp.headers.get("content-length"),
                "last_modified": resp.headers.get("last-modified"),
            }
        except Exception as e:
            if self._status_code_of(e) == 404:
                return None
            raise Exception(
                f"Failed to get file metadata for '{file_path}': {e}"
            ) from e

    def close(self):
        """Close the HTTP handler to free resources."""
        if hasattr(self, "http_handler"):
            self.http_handler.close()
|
||||
@@ -0,0 +1,760 @@
|
||||
"""
|
||||
GitLab prompt manager with configurable prompts folder.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from jinja2 import DictLoader, Environment, select_autoescape
|
||||
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
from litellm.integrations.gitlab.gitlab_client import GitLabClient
|
||||
from litellm.integrations.prompt_management_base import (
|
||||
PromptManagementBase,
|
||||
PromptManagementClient,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
GITLAB_PREFIX = "gitlab::"


def encode_prompt_id(raw_id: str) -> str:
    """Turn a path-style ID such as 'invoice/extract' into 'gitlab::invoice::extract'.

    IDs that already start with the GitLab prefix are returned unchanged.
    """
    already_encoded = raw_id.startswith(GITLAB_PREFIX)
    if already_encoded:
        return raw_id
    return GITLAB_PREFIX + "::".join(raw_id.split("/"))


def decode_prompt_id(encoded_id: str) -> str:
    """Turn 'gitlab::invoice::extract' back into 'invoice/extract'.

    IDs without the GitLab prefix are returned unchanged.
    """
    if encoded_id.startswith(GITLAB_PREFIX):
        without_prefix = encoded_id[len(GITLAB_PREFIX) :]
        return "/".join(without_prefix.split("::"))
    return encoded_id
|
||||
|
||||
|
||||
class GitLabPromptTemplate:
    """A parsed ``.prompt`` file: template text plus its front-matter metadata.

    Attributes set by the constructor:
        template_id: ID the template was loaded under.
        content: Template text (front matter already removed by the caller).
        metadata: The front-matter mapping as given.
        model: Explicit ``model`` argument, else ``metadata["model"]``.
        temperature / max_tokens: Taken from ``metadata`` (None when absent).
        input_schema: ``metadata["input"]["schema"]``, or ``{}`` when there is
            no usable ``input`` mapping.
        optional_params: Every metadata entry except ``model``, ``input`` and
            ``content``.
    """

    # Metadata keys that are not forwarded as optional parameters.
    _NON_PARAM_KEYS = ("model", "input", "content")

    def __init__(
        self,
        template_id: str,
        content: str,
        metadata: Dict[str, Any],
        model: Optional[str] = None,
    ):
        self.template_id = template_id
        self.content = content
        self.metadata = metadata
        self.model = model or metadata.get("model")
        self.temperature = metadata.get("temperature")
        self.max_tokens = metadata.get("max_tokens")

        # Front matter may contain `input:` with a null or non-mapping value;
        # treat that the same as a missing `input` instead of raising.
        input_section = metadata.get("input")
        if isinstance(input_section, dict):
            self.input_schema = input_section.get("schema", {})
        else:
            self.input_schema = {}

        self.optional_params = {
            key: value
            for key, value in metadata.items()
            if key not in self._NON_PARAM_KEYS
        }

    def __repr__(self):
        return f"GitLabPromptTemplate(id='{self.template_id}', model='{self.model}')"
|
||||
|
||||
|
||||
class GitLabTemplateManager:
    """
    Manager for loading and rendering .prompt files from GitLab repositories.

    Supports `prompts_path` (or `folder`) in gitlab_config to scope where
    prompts live inside the repository. Loaded templates are kept in
    ``self.prompts`` under the ID they were requested with.
    """

    def __init__(
        self,
        gitlab_config: Dict[str, Any],
        prompt_id: Optional[str] = None,
        ref: Optional[str] = None,
        gitlab_client: Optional[GitLabClient] = None,
    ):
        """
        Args:
            gitlab_config: Client configuration; a shallow copy is stored.
            prompt_id: When given, that prompt is loaded immediately.
            ref: When given, applied to the client via ``set_ref``.
            gitlab_client: Client to use instead of building one from config.
        """
        self.gitlab_config = dict(gitlab_config)
        self.prompt_id = prompt_id
        self.prompts: Dict[str, GitLabPromptTemplate] = {}
        self.gitlab_client = gitlab_client or GitLabClient(self.gitlab_config)

        if ref:
            self.gitlab_client.set_ref(ref)

        # Folder inside repo to look for prompts (e.g., "prompts" or "prompts/chat")
        self.prompts_path: str = (
            self.gitlab_config.get("prompts_path")
            or self.gitlab_config.get("folder")
            or ""
        ).strip("/")

        # NOTE(review): per the Jinja docs select_autoescape() uses
        # default_for_string=True, so templates built with from_string()
        # (as render_template does) appear to have autoescaping enabled —
        # confirm that HTML-escaping prompt variables is intended.
        self.jinja_env = Environment(
            loader=DictLoader({}),
            autoescape=select_autoescape(["html", "xml"]),
            variable_start_string="{{",
            variable_end_string="}}",
            block_start_string="{%",
            block_end_string="%}",
            comment_start_string="{#",
            comment_end_string="#}",
        )

        if self.prompt_id:
            self._load_prompt_from_gitlab(self.prompt_id)

    # ---------- path helpers ----------

    def _id_to_repo_path(self, prompt_id: str) -> str:
        """Map a prompt_id to a repo path (respects prompts_path and adds .prompt)."""
        prompt_id = decode_prompt_id(prompt_id)
        if self.prompts_path:
            return f"{self.prompts_path}/{prompt_id}.prompt"
        return f"{prompt_id}.prompt"

    def _repo_path_to_id(self, repo_path: str) -> str:
        """
        Map a repo path like 'prompts/chat/greeting.prompt' to an encoded ID.

        The prompts_path prefix and the '.prompt' extension are removed, then
        the remainder is passed through ``encode_prompt_id``.
        """
        path = repo_path.strip("/")
        base = self.prompts_path.strip("/")
        if self.prompts_path and path.startswith(base + "/"):
            path = path[len(base) + 1 :]
        if path.endswith(".prompt"):
            path = path[: -len(".prompt")]
        return encode_prompt_id(path)

    # ---------- loading ----------

    def _load_prompt_from_gitlab(
        self, prompt_id: str, *, ref: Optional[str] = None
    ) -> None:
        """
        Load a specific .prompt file from GitLab (scoped under prompts_path if set).

        The parsed template is stored in ``self.prompts[prompt_id]``. Nothing
        is stored when the client returns empty content.

        Raises:
            Exception: wraps any failure, naming the encoded prompt ID; the
                original error is kept as ``__cause__``.
        """
        try:
            file_path = self._id_to_repo_path(prompt_id)
            prompt_content = self.gitlab_client.get_file_content(file_path, ref=ref)
            if prompt_content:
                self.prompts[prompt_id] = self._parse_prompt_file(
                    prompt_content, prompt_id
                )
        except Exception as e:
            raise Exception(
                f"Failed to load prompt '{encode_prompt_id(prompt_id)}' from GitLab: {e}"
            ) from e

    def load_all_prompts(self, *, recursive: bool = True) -> List[str]:
        """
        Eagerly load all .prompt files from prompts_path.

        Returns:
            Every ID reported by ``list_templates``, after ensuring each one
            not yet in ``self.prompts`` has been loaded.
        """
        loaded: List[str] = []
        for pid in self.list_templates(recursive=recursive):
            if pid not in self.prompts:
                self._load_prompt_from_gitlab(pid)
            loaded.append(pid)
        return loaded

    # ---------- parsing & rendering ----------

    def _parse_prompt_file(self, content: str, prompt_id: str) -> GitLabPromptTemplate:
        """
        Split optional ``---`` front matter from the template body and parse it.

        Front matter is parsed with PyYAML when importable, otherwise with
        ``_parse_yaml_basic``. Unparseable front matter, or front matter that
        is not a mapping, yields empty metadata.
        """
        frontmatter_str = ""
        template_content = content
        if content.startswith("---"):
            parts = content.split("---", 2)
            if len(parts) >= 3:
                frontmatter_str = parts[1].strip()
                template_content = parts[2].strip()

        metadata: Dict[str, Any] = {}
        if frontmatter_str:
            try:
                import yaml

                metadata = yaml.safe_load(frontmatter_str) or {}
            except ImportError:
                metadata = self._parse_yaml_basic(frontmatter_str)
            except Exception:
                # Best effort: a malformed header must not prevent the body
                # from being used as a template.
                metadata = {}

        # safe_load can return a scalar or list for valid-but-unexpected
        # YAML; downstream code requires a mapping.
        if not isinstance(metadata, dict):
            metadata = {}

        return GitLabPromptTemplate(
            template_id=prompt_id,
            content=template_content,
            metadata=metadata,
        )

    def _parse_yaml_basic(self, yaml_str: str) -> Dict[str, Any]:
        """
        Minimal ``key: value`` parser used when PyYAML is not importable.

        Handles one flat mapping level only. Values become bool
        ("true"/"false", case-insensitive), int (all digits), float (digits
        with dots, when convertible) or a string with surrounding quotes
        stripped. Lines without ':' and lines starting with '#' are ignored.
        """
        result: Dict[str, Any] = {}
        for line in yaml_str.split("\n"):
            line = line.strip()
            if ":" not in line or line.startswith("#"):
                continue
            key, value = (part.strip() for part in line.split(":", 1))
            lowered = value.lower()
            if lowered in ("true", "false"):
                result[key] = lowered == "true"
            elif value.isdigit():
                result[key] = int(value)
            elif value.replace(".", "").isdigit():
                try:
                    result[key] = float(value)
                except ValueError:
                    # e.g. "1.2.3" — keep it as text
                    result[key] = value
            else:
                result[key] = value.strip("\"'")
        return result

    def render_template(
        self, template_id: str, variables: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Render a loaded template with Jinja.

        Raises:
            ValueError: when ``template_id`` is not in ``self.prompts``.
        """
        if template_id not in self.prompts:
            raise ValueError(f"Template '{template_id}' not found")
        template = self.prompts[template_id]
        jinja_template = self.jinja_env.from_string(template.content)
        return jinja_template.render(**(variables or {}))

    def get_template(self, template_id: str) -> Optional[GitLabPromptTemplate]:
        """Return the loaded template for ``template_id``, or None."""
        return self.prompts.get(template_id)

    def list_templates(self, *, recursive: bool = True) -> List[str]:
        """
        List available prompt IDs under prompts_path (encoded, no extension).

        First calls ``list_files(directory_path=..., file_extension=...,
        recursive=...)`` and expects path strings back. If that raises
        TypeError, retries with ``ref=None`` instead of ``file_extension``
        and expects GitLab tree entries (dicts), keeping ``*.prompt`` blobs.
        """
        try:
            files = self.gitlab_client.list_files(
                directory_path=self.prompts_path,
                file_extension=".prompt",
                recursive=recursive,
            )
            base = self.prompts_path.strip("/")
            out: List[str] = []
            for p in files or []:
                path = str(p).strip("/")
                if base and not path.startswith(base + "/"):
                    # if the client returns extra files outside the folder, skip them
                    continue
                if not path.endswith(".prompt"):
                    continue
                out.append(self._repo_path_to_id(path))
            return out
        except TypeError:
            # NOTE(review): the fallback still passes `directory_path=`; a
            # client whose signature is list_files(path=..., ref=...,
            # recursive=...) would raise TypeError here too — confirm against
            # GitLabClient.list_files.
            raw = self.gitlab_client.list_files(
                directory_path=self.prompts_path or "",
                ref=None,
                recursive=recursive,
            )
            # Tree entries: keep only *.prompt blobs
            files = [
                f["path"]
                for f in raw or []
                if isinstance(f, dict)
                and f.get("type") == "blob"
                and "path" in f
                and str(f["path"]).endswith(".prompt")
            ]
            return [self._repo_path_to_id(p) for p in files]
|
||||
|
||||
|
||||
class GitLabPromptManager(CustomPromptManagement):
    """
    GitLab prompt manager with folder support.

    Wraps a lazily created ``GitLabTemplateManager`` and exposes the prompt
    management hooks on top of it.

    Example config:
    gitlab_config = {
        "project": "group/subgroup/repo",
        "access_token": "glpat_***",
        "tag": "v1.2.3",                 # optional; takes precedence
        "branch": "main",                # default fallback
        "prompts_path": "prompts/chat"
    }
    """

    # Front-matter parameters that are copied into the outgoing request.
    _FORWARDED_PARAMS = (
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
    )
    # Line prefixes ("<role>:") that start a new message in a rendered prompt.
    _ROLE_MARKERS = ("system", "user", "assistant")

    def __init__(
        self,
        gitlab_config: Dict[str, Any],
        prompt_id: Optional[str] = None,
        ref: Optional[str] = None,  # tag/branch/SHA override
        gitlab_client: Optional[GitLabClient] = None,
    ):
        """
        Args:
            gitlab_config: Configuration handed to the template manager.
            prompt_id: When given, the template manager is built (and that
                prompt loaded) immediately; otherwise on first use.
            ref: tag/branch/SHA override applied to the client.
            gitlab_client: Client to use instead of building one from config.
        """
        self.gitlab_config = gitlab_config
        self.prompt_id = prompt_id
        self._prompt_manager: Optional[GitLabTemplateManager] = None
        self._ref_override = ref
        self._injected_gitlab_client = gitlab_client
        if self.prompt_id:
            # Pass the injected client here as well, so the eager path and
            # the lazy `prompt_manager` property build the same manager.
            self._prompt_manager = GitLabTemplateManager(
                gitlab_config=self.gitlab_config,
                prompt_id=self.prompt_id,
                ref=self._ref_override,
                gitlab_client=self._injected_gitlab_client,
            )

    @property
    def integration_name(self) -> str:
        return "gitlab"

    @property
    def prompt_manager(self) -> GitLabTemplateManager:
        """The template manager, created on first access."""
        if self._prompt_manager is None:
            self._prompt_manager = GitLabTemplateManager(
                gitlab_config=self.gitlab_config,
                prompt_id=self.prompt_id,
                ref=self._ref_override,
                gitlab_client=self._injected_gitlab_client,
            )
        return self._prompt_manager

    def get_prompt_template(
        self,
        prompt_id: str,
        prompt_variables: Optional[Dict[str, Any]] = None,
        *,
        ref: Optional[str] = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Render a prompt and return it with its metadata.

        The prompt is fetched (at ``ref``) only when it is not already loaded.

        Returns:
            ``(rendered_prompt, metadata)`` where metadata always contains
            ``model``, ``temperature`` and ``max_tokens`` (possibly None) plus
            the template's optional params.

        Raises:
            ValueError: when the template is still missing after loading.
        """
        if prompt_id not in self.prompt_manager.prompts:
            self.prompt_manager._load_prompt_from_gitlab(prompt_id, ref=ref)

        template = self.prompt_manager.get_template(prompt_id)
        if not template:
            raise ValueError(f"Prompt template '{prompt_id}' not found")

        rendered_prompt = self.prompt_manager.render_template(
            prompt_id, prompt_variables or {}
        )

        metadata = {
            "model": template.model,
            "temperature": template.temperature,
            "max_tokens": template.max_tokens,
            **template.optional_params,
        }
        return rendered_prompt, metadata

    def pre_call_hook(
        self,
        user_id: Optional[str],
        messages: List[AllMessageValues],
        function_call: Optional[Union[Dict[str, Any], str]] = None,
        litellm_params: Optional[Dict[str, Any]] = None,
        prompt_id: Optional[str] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        prompt_version: Optional[str] = None,
        **kwargs,
    ) -> Tuple[List[AllMessageValues], Optional[Dict[str, Any]]]:
        """
        Replace the request messages with the rendered GitLab prompt.

        Without a ``prompt_id`` the inputs are returned untouched. On any
        error the failure is logged and the original inputs are returned.
        ``litellm_params`` (when given) is updated in place with the prompt's
        model and forwarded parameters.
        """
        if not prompt_id:
            return messages, litellm_params
        try:
            # Precedence: explicit prompt_version → per-call git_ref kwarg → manager override → config default
            git_ref = prompt_version or kwargs.get("git_ref") or self._ref_override

            rendered_prompt, prompt_metadata = self.get_prompt_template(
                prompt_id, prompt_variables, ref=git_ref
            )
            parsed_messages = self._parse_prompt_to_messages(rendered_prompt)

            if parsed_messages:
                final_messages: List[AllMessageValues] = parsed_messages
            else:
                final_messages = [{"role": "user", "content": rendered_prompt}] + messages  # type: ignore

            if litellm_params is None:
                litellm_params = {}

            if prompt_metadata.get("model"):
                litellm_params["model"] = prompt_metadata["model"]

            # The metadata always carries temperature/max_tokens keys, so
            # check the value: an unset parameter must not overwrite the
            # caller's setting with None.
            for param in self._FORWARDED_PARAMS:
                if prompt_metadata.get(param) is not None:
                    litellm_params[param] = prompt_metadata[param]

            return final_messages, litellm_params
        except Exception as e:
            import litellm

            litellm._logging.verbose_proxy_logger.error(
                f"Error in GitLab prompt pre_call_hook: {e}"
            )
            return messages, litellm_params

    def _parse_prompt_to_messages(self, prompt_content: str) -> List[AllMessageValues]:
        """
        Split rendered prompt text into chat messages.

        A line starting with ``system:``, ``user:`` or ``assistant:``
        (case-insensitive) opens a new message; following lines are appended
        to it. Blank lines are dropped and every line is stripped. Text with
        no role marker at all becomes a single user message.
        """
        messages: List[AllMessageValues] = []
        current_role: Optional[str] = None
        current_content: List[str] = []

        def flush() -> None:
            # Emit the message collected so far, if one has been opened.
            if current_role and current_content:
                messages.append({"role": current_role, "content": "\n".join(current_content).strip()})  # type: ignore

        for raw in prompt_content.strip().split("\n"):
            line = raw.strip()
            if not line:
                continue
            low = line.lower()
            for role in self._ROLE_MARKERS:
                marker = f"{role}:"
                if low.startswith(marker):
                    flush()
                    current_role = role
                    current_content = [line[len(marker) :].strip()]
                    break
            else:
                current_content.append(line)

        flush()
        if not messages and prompt_content.strip():
            messages = [{"role": "user", "content": prompt_content.strip()}]  # type: ignore
        return messages

    def post_call_hook(
        self,
        user_id: Optional[str],
        response: Any,
        input_messages: List[AllMessageValues],
        function_call: Optional[Union[Dict[str, Any], str]] = None,
        litellm_params: Optional[Dict[str, Any]] = None,
        prompt_id: Optional[str] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Return the response unchanged."""
        return response

    def get_available_prompts(self) -> List[str]:
        """
        Return the sorted union of loaded prompt IDs and IDs listed by GitLab.

        If listing fails (auth, network), only the already-loaded IDs are
        returned.
        """
        ids = set(self.prompt_manager.prompts.keys())
        try:
            ids.update(self.prompt_manager.list_templates())
        except Exception:
            # Listing is best effort; fall back to what is in memory.
            pass
        return sorted(ids)

    def reload_prompts(self) -> None:
        """Rebuild the template manager when a default prompt_id is configured."""
        if self.prompt_id:
            self._prompt_manager = None
            _ = self.prompt_manager  # trigger re-init/load

    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """Prompt management runs whenever a prompt_id is supplied."""
        return prompt_id is not None

    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Render a prompt into a ``PromptManagementClient``.

        The ID is decoded once and used for both loading and rendering, so
        a ``git_ref`` found on ``dynamic_callback_params.extra`` applies to
        the template that is actually rendered.

        Raises:
            ValueError: when ``prompt_id`` is None, or wrapping any failure
                during loading/rendering (original error kept as cause).
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for GitLab prompt manager")

        try:
            decoded_id = decode_prompt_id(prompt_id)
            extra = getattr(dynamic_callback_params, "extra", None)
            git_ref = extra.get("git_ref") if hasattr(extra, "get") else None

            rendered_prompt, prompt_metadata = self.get_prompt_template(
                decoded_id, prompt_variables, ref=git_ref
            )

            messages = self._parse_prompt_to_messages(rendered_prompt)

            # Skip parameters the prompt leaves unset (value None).
            optional_params: Dict[str, Any] = {
                param: prompt_metadata[param]
                for param in self._FORWARDED_PARAMS
                if prompt_metadata.get(param) is not None
            }

            return PromptManagementClient(
                prompt_id=prompt_id,
                prompt_template=messages,
                prompt_template_model=prompt_metadata.get("model"),
                prompt_template_optional_params=optional_params,
                completed_messages=None,
            )
        except Exception as e:
            raise ValueError(f"Error compiling prompt '{prompt_id}': {e}") from e

    async def async_compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Async version of compile prompt helper. Since GitLab operations use sync client,
        this simply delegates to the sync version.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for GitLab prompt manager")

        return self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_spec=prompt_spec,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Delegate to ``PromptManagementBase.get_chat_completion_prompt``."""
        # NOTE(review): unlike the async variant, the two ignore_* flags are
        # not forwarded here — confirm whether the base method accepts them.
        return PromptManagementBase.get_chat_completion_prompt(
            self,
            model,
            messages,
            non_default_params,
            prompt_id,
            prompt_variables,
            dynamic_callback_params,
            prompt_spec=prompt_spec,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )

    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        litellm_logging_obj: LiteLLMLoggingObj,
        prompt_spec: Optional[PromptSpec] = None,
        tools: Optional[List[Dict]] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Async version - delegates to PromptManagementBase async implementation.
        """
        return await PromptManagementBase.async_get_chat_completion_prompt(
            self,
            model,
            messages,
            non_default_params,
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            litellm_logging_obj=litellm_logging_obj,
            dynamic_callback_params=dynamic_callback_params,
            prompt_spec=prompt_spec,
            tools=tools,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
            ignore_prompt_manager_model=ignore_prompt_manager_model,
            ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
        )
|
||||
|
||||
|
||||
class GitLabPromptCache:
    """
    In-memory cache of the .prompt files found in a GitLab repo.

    Two views are kept over the same entries:

    - by repo file path (e.g. "prompts/chat/greet/hi.prompt")
    - by encoded prompt ID (the value produced by ``encode_prompt_id``)

    Each entry is a JSON-like dict with the template content and metadata.

    Usage:

        cfg = {
            "project": "group/subgroup/repo",
            "access_token": "glpat_***",
            "prompts_path": "prompts/chat",  # optional, can be empty for repo root
            # "branch": "main",              # default is "main"
            # "tag": "v1.2.3",               # takes precedence over branch
            # "base_url": "https://gitlab.com/api/v4"  # default
        }

        cache = GitLabPromptCache(cfg)
        cache.load_all()                 # fetch + parse all .prompt files

        print(cache.list_files())        # repo file paths
        print(cache.list_ids())          # cached prompt IDs

        prompt_json = cache.get_by_file("prompts/chat/greet/hi.prompt")
        prompt_json2 = cache.get_by_id("greet/hi")

        # If GitLab content changes and you want to refresh:
        cache.reload()                   # re-scan and refresh all
    """

    def __init__(
        self,
        gitlab_config: Dict[str, Any],
        *,
        ref: Optional[str] = None,
        gitlab_client: Optional[GitLabClient] = None,
    ) -> None:
        """Create the prompt manager / template manager pair and empty stores."""
        # The prompt manager owns the template manager (and its client).
        manager = GitLabPromptManager(
            gitlab_config=gitlab_config,
            prompt_id=None,
            ref=ref,
            gitlab_client=gitlab_client,
        )
        self.prompt_manager = manager
        self.template_manager: GitLabTemplateManager = manager.prompt_manager

        # Entries keyed by repo path and by encoded ID respectively.
        self._by_file: Dict[str, Dict[str, Any]] = {}
        self._by_id: Dict[str, Dict[str, Any]] = {}

    # -------------------------
    # Public API
    # -------------------------

    def load_all(self, *, recursive: bool = True) -> Dict[str, Dict[str, Any]]:
        """
        List every .prompt file under prompts_path, load each one that is not
        yet loaded, and cache a JSON-like entry for it.

        Returns:
            The by-ID mapping (encoded ID -> entry).
        """
        templates = self.template_manager
        for pid in templates.list_templates(recursive=recursive):
            if pid not in templates.prompts:
                templates._load_prompt_from_gitlab(pid)

            tmpl = templates.get_template(pid)
            if tmpl is None:
                # One retry before giving up on this ID.
                templates._load_prompt_from_gitlab(pid)
                tmpl = templates.get_template(pid)
            if tmpl is None:
                continue

            repo_path = templates._id_to_repo_path(pid)
            entry = self._template_to_json(pid, tmpl)

            self._by_file[repo_path] = entry
            self._by_id[encode_prompt_id(pid)] = entry

        return self._by_id

    def reload(self, *, recursive: bool = True) -> Dict[str, Dict[str, Any]]:
        """Empty both stores, then run ``load_all`` again."""
        for store in (self._by_file, self._by_id):
            store.clear()
        return self.load_all(recursive=recursive)

    def list_files(self) -> List[str]:
        """Repo file paths of the cached entries."""
        return [*self._by_file]

    def list_ids(self) -> List[str]:
        """Prompt IDs (keys of the by-ID store) of the cached entries."""
        return [*self._by_id]

    def get_by_file(self, file_path: str) -> Optional[Dict[str, Any]]:
        """Look up a cached entry by repo file path; None when absent."""
        return self._by_file.get(file_path)

    def get_by_id(self, prompt_id: str) -> Optional[Dict[str, Any]]:
        """
        Look up a cached entry by prompt ID.

        The ID is tried as given first, then in its encoded form, then in
        its decoded form.
        """
        try:
            return self._by_id[prompt_id]
        except KeyError:
            pass

        plain = decode_prompt_id(prompt_id)
        prefixed = encode_prompt_id(plain)
        return self._by_id.get(prefixed) or self._by_id.get(plain)

    # -------------------------
    # Internals
    # -------------------------

    def _template_to_json(
        self, prompt_id: str, tmpl: GitLabPromptTemplate
    ) -> Dict[str, Any]:
        """
        Build the JSON-like cache entry for a template.

        ``metadata`` and ``optional_params`` are shallow copies, so changing
        the entry does not change the template's own dicts.
        """
        return {
            "id": prompt_id,
            "path": self.template_manager._id_to_repo_path(prompt_id),
            "content": tmpl.content,  # template text without front matter
            "metadata": dict(tmpl.metadata or {}),
            "model": tmpl.model,
            "temperature": tmpl.temperature,
            "max_tokens": tmpl.max_tokens,
            "optional_params": dict(tmpl.optional_params or {}),
        }
|
||||
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class GreenscaleLogger:
    """Posts per-request usage records to a Greenscale endpoint.

    Configuration is read from the environment when the logger is created:
    ``GREENSCALE_API_KEY`` (sent as the ``api-key`` header) and
    ``GREENSCALE_ENDPOINT`` (the URL records are posted to).
    """

    def __init__(self):
        import os

        self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
        self.headers = {
            "api-key": self.greenscale_api_key,
            "Content-Type": "application/json",
        }
        self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")

    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
        """
        Build a usage record for one call and POST it to Greenscale.

        Never raises: any failure is reported through ``print_verbose``.

        Args:
            kwargs: Call kwargs; ``model`` and ``litellm_params.metadata`` are
                read. Metadata keys starting with "greenscale" are exported:
                ``greenscale_project`` / ``greenscale_application`` as
                top-level fields, the rest as tags.
            response_obj: Response exposing ``model_dump()``, or a falsy value.
            start_time, end_time: When both are datetimes the latency in
                milliseconds is included.
            print_verbose: Callable taking one message string.
        """
        try:
            response_json = response_obj.model_dump() if response_obj else {}
            # `usage` can be present but None; treat it as empty.
            usage = response_json.get("usage") or {}
            data = {
                "modelId": kwargs.get("model"),
                "inputTokenCount": usage.get("prompt_tokens"),
                "outputTokenCount": usage.get("completion_tokens"),
                "timestamp": datetime.now(timezone.utc).strftime(
                    "%Y-%m-%dT%H:%M:%SZ"
                ),
            }

            if isinstance(end_time, datetime) and isinstance(start_time, datetime):
                data["invocationLatency"] = int(
                    (end_time - start_time).total_seconds() * 1000
                )

            # `litellm_params` / `metadata` may be explicitly None; without
            # the fallbacks the whole event would be dropped by the handler
            # below.
            metadata = (kwargs.get("litellm_params") or {}).get("metadata") or {}

            # Add additional metadata keys to tags
            tags = []
            for key, value in metadata.items():
                if not key.startswith("greenscale"):
                    continue
                if key == "greenscale_project":
                    data["project"] = value
                elif key == "greenscale_application":
                    data["application"] = value
                else:
                    tags.append(
                        {"key": key.replace("greenscale_", ""), "value": str(value)}
                    )
            data["tags"] = tags

            if self.greenscale_logging_url is None:
                raise Exception("Greenscale Logger Error - No logging URL found")

            response = litellm.module_level_client.post(
                self.greenscale_logging_url,
                headers=self.headers,
                data=json.dumps(data, default=str),
            )
            if response.status_code != 200:
                print_verbose(
                    f"Greenscale Logger Error - {response.text}, {response.status_code}"
                )
            else:
                print_verbose(f"Greenscale Logger Succeeded - {response.text}")
        except Exception as e:
            # Logging must never break the request path.
            print_verbose(
                f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}"
            )
|
||||
@@ -0,0 +1,228 @@
|
||||
#### What this does ####
|
||||
# On success, logs events to Helicone
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.helicone_mock_client import (
|
||||
should_use_helicone_mock,
|
||||
create_mock_helicone_client,
|
||||
)
|
||||
|
||||
|
||||
class HeliconeLogger:
|
||||
# Class variables or attributes
|
||||
helicone_model_list = [
|
||||
"gpt",
|
||||
"claude",
|
||||
"gemini",
|
||||
"command-r",
|
||||
"command-r-plus",
|
||||
"command-light",
|
||||
"command-medium",
|
||||
"command-medium-beta",
|
||||
"command-xlarge-nightly",
|
||||
"command-nightly",
|
||||
]
|
||||
|
||||
def __init__(self):
    """Set up mock mode (if enabled) and read the Helicone settings.

    ``HELICONE_API_KEY`` becomes ``self.key``. ``HELICONE_API_BASE`` becomes
    ``self.api_base`` (default "https://api.hconeai.com"), with one trailing
    slash removed.
    """
    mock_enabled = should_use_helicone_mock()
    self.is_mock_mode = mock_enabled
    if mock_enabled:
        create_mock_helicone_client()
        verbose_logger.info(
            "[HELICONE MOCK] Helicone logger initialized in mock mode"
        )

    self.provider_url = "https://api.openai.com/v1"
    self.key = os.getenv("HELICONE_API_KEY")

    base = os.getenv("HELICONE_API_BASE") or "https://api.hconeai.com"
    self.api_base = base[:-1] if base.endswith("/") else base
|
||||
|
||||
def claude_mapping(self, model, messages, response_obj):
    """
    Convert an OpenAI-style response into an Anthropic "message" shaped dict.

    Only the first choice is used. Tool calls, when present, become
    ``tool_use`` content items; otherwise non-empty text content becomes a
    single ``text`` item.

    Args:
        model: Model name copied into the result.
        messages: Request messages. Not used for the result; kept so the
            signature stays compatible with existing callers.
        response_obj: Response exposing ``id``, ``choices`` and ``usage``
            through item access.

    Returns:
        Dict with ``id``, ``type``, ``role``, ``model``, ``content``,
        ``stop_reason``, ``stop_sequence`` and ``usage``.
    """
    # The previous version built a HUMAN_PROMPT/AI_PROMPT string from
    # `messages` that was never read; that dead code (and the `anthropic`
    # import it needed) has been removed.
    choice = response_obj["choices"][0]
    message = choice["message"]

    content = []
    if "tool_calls" in message and message["tool_calls"]:
        content = [
            {
                "type": "tool_use",
                "id": tool_call["id"],
                "name": tool_call["function"]["name"],
                "input": tool_call["function"]["arguments"],
            }
            for tool_call in message["tool_calls"]
        ]
    elif "content" in message and message["content"]:
        content = [{"type": "text", "text": message["content"]}]

    usage = response_obj["usage"]
    return {
        "id": response_obj["id"],
        "type": "message",
        "role": "assistant",
        "model": model,
        "content": content,
        "stop_reason": choice["finish_reason"],
        "stop_sequence": None,
        "usage": {
            "input_tokens": usage["prompt_tokens"],
            "output_tokens": usage["completion_tokens"],
        },
    }
|
||||
|
||||
@staticmethod
|
||||
def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
|
||||
"""
|
||||
Adds metadata from proxy request headers to Helicone logging if keys start with "helicone_"
|
||||
and overwrites litellm_params.metadata if already included.
|
||||
|
||||
For example if you want to add custom property to your request, send
|
||||
`headers: { ..., helicone-property-something: 1234 }` via proxy request.
|
||||
"""
|
||||
if litellm_params is None:
|
||||
return metadata
|
||||
|
||||
if litellm_params.get("proxy_server_request") is None:
|
||||
return metadata
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
proxy_headers = (
|
||||
litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
|
||||
)
|
||||
|
||||
for header_key in proxy_headers:
|
||||
if header_key.startswith("helicone_"):
|
||||
metadata[header_key] = proxy_headers.get(header_key)
|
||||
|
||||
# Remove OpenTelemetry span from metadata as it's not JSON serializable
|
||||
# The span is used internally for tracing but shouldn't be logged to external services
|
||||
if "litellm_parent_otel_span" in metadata:
|
||||
metadata.pop("litellm_parent_otel_span")
|
||||
|
||||
return metadata
|
||||
|
||||
def log_success(
    self, model, messages, response_obj, start_time, end_time, print_verbose, kwargs
):
    """
    Report one successful LLM call to Helicone's log endpoint.

    Builds a ``providerRequest`` / ``providerResponse`` / ``timing`` payload
    and POSTs it with ``litellm.module_level_client``. The Helicone route and
    the reported upstream URL depend on the model: Vertex AI and Gemini use
    the ``custom`` route, non-Vertex Claude models use the ``anthropic``
    route (after converting the response with ``claude_mapping``), and
    everything else uses the ``oai`` route.

    This is best-effort: any exception is reported through
    ``print_verbose`` and swallowed.

    Args:
        model: Requested model name.
        messages: Request messages, forwarded as-is.
        response_obj: A litellm ``ModelResponse`` / ``EmbeddingResponse``
            (converted with ``.json()``) or an already-serializable object.
        start_time, end_time: Objects exposing ``.timestamp()``.
        print_verbose: Callable used for all diagnostics.
        kwargs: Call kwargs; ``litellm_params`` is read from it.
    """
    try:
        print_verbose(
            f"Helicone Logging - Enters logging function for model {model}"
        )
        litellm_params = kwargs.get("litellm_params", {})
        custom_llm_provider = litellm_params.get("custom_llm_provider", "")
        metadata = self.add_metadata_from_header(
            litellm_params, litellm_params.get("metadata", {}) or {}
        )

        is_vertex_ai = custom_llm_provider == "vertex_ai" or model.startswith(
            "vertex_ai/"
        )

        # Models Helicone does not know (and that are not Vertex AI) are
        # reported under a placeholder model name.
        model_is_accepted = any(
            accepted_model in model for accepted_model in self.helicone_model_list
        )
        if not (model_is_accepted or is_vertex_ai):
            model = "gpt-3.5-turbo"

        provider_request = {"model": model, "messages": messages}
        if isinstance(
            response_obj, (litellm.EmbeddingResponse, litellm.ModelResponse)
        ):
            response_obj = response_obj.json()

        # Choose the Helicone route and the upstream URL to report.
        if is_vertex_ai:
            url = f"{self.api_base}/custom/v1/log"
            provider_url = "https://aiplatform.googleapis.com/v1"
        elif "claude" in model:
            response_obj = self.claude_mapping(
                model=model, messages=messages, response_obj=response_obj
            )
            url = f"{self.api_base}/anthropic/v1/log"
            provider_url = "https://api.anthropic.com/v1/messages"
        elif "gemini" in model:
            url = f"{self.api_base}/custom/v1/log"
            provider_url = "https://generativelanguage.googleapis.com/v1beta"
        else:
            url = f"{self.api_base}/oai/v1/log"
            provider_url = self.provider_url

        # Helicone takes each timestamp as whole seconds plus the
        # millisecond remainder.
        start_ts = start_time.timestamp()
        end_ts = end_time.timestamp()
        start_time_seconds = int(start_ts)
        end_time_seconds = int(end_ts)

        meta = {"Helicone-Auth": f"Bearer {self.key}"}
        meta.update(metadata)

        data = {
            "providerRequest": {
                "url": provider_url,
                "json": provider_request,
                "meta": meta,
            },
            "providerResponse": {
                "json": response_obj,
                "headers": {"openai-version": "2020-10-01"},
                "status": 200,
            },
            "timing": {
                "startTime": {
                    "seconds": start_time_seconds,
                    "milliseconds": int((start_ts - start_time_seconds) * 1000),
                },
                "endTime": {
                    "seconds": end_time_seconds,
                    "milliseconds": int((end_ts - end_time_seconds) * 1000),
                },
            },
        }
        headers = {
            "Authorization": f"Bearer {self.key}",
            "Content-Type": "application/json",
        }

        response = litellm.module_level_client.post(url, headers=headers, json=data)
        if response.status_code != 200:
            print_verbose(
                f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}"
            )
            print_verbose(f"Helicone Logging - Error {response.text}")
        elif self.is_mock_mode:
            print_verbose("[HELICONE MOCK] Helicone Logging - Successfully mocked!")
        else:
            print_verbose("Helicone Logging - Success!")
    except Exception:
        # Logging must never break the caller's request.
        print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Mock HTTP client for Helicone integration testing.
|
||||
|
||||
This module intercepts Helicone API calls and returns successful mock responses,
|
||||
allowing full code execution without making actual network calls.
|
||||
|
||||
Usage:
|
||||
Set HELICONE_MOCK=true in environment variables or config to enable mock mode.
|
||||
"""
|
||||
|
||||
from litellm.integrations.mock_client_factory import (
|
||||
MockClientConfig,
|
||||
create_mock_client_factory,
|
||||
)
|
||||
|
||||
# Create the Helicone mock client using the shared mock-client factory.
# Helicone uses HTTPHandler, which internally uses httpx.Client.send(), not
# httpx.Client.post(). That is why only HTTPHandler.post is patched below and
# the raw sync client / async handler patches are switched off.
_config = MockClientConfig(
    name="HELICONE",
    env_var="HELICONE_MOCK",  # environment variable that enables mock mode
    default_latency_ms=100,
    default_status_code=200,
    default_json_data={"status": "success"},
    # URL fragments identifying Helicone endpoints (legacy hconeai.com and
    # helicone.ai hosts).
    url_matchers=[
        ".hconeai.com",
        "hconeai.com",
        ".helicone.ai",
        "helicone.ai",
    ],
    patch_async_handler=False,
    patch_sync_client=False,  # HTTPHandler uses self.client.send(), not self.client.post()
    patch_http_handler=True,  # Patch HTTPHandler.post directly
)

# The factory returns two callables: one that installs the mock client and one
# that reports whether mock mode should be used (presumably driven by
# HELICONE_MOCK — confirm in mock_client_factory).
create_mock_helicone_client, should_use_helicone_mock = create_mock_client_factory(
    _config
)
|
||||
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Humanloop integration
|
||||
|
||||
https://humanloop.com/
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import litellm
|
||||
from litellm.caching import DualCache
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
from .custom_logger import CustomLogger
|
||||
|
||||
|
||||
class PromptManagementClient(TypedDict):
    """Prompt definition fetched from Humanloop, as kept in the prompt cache."""

    # Humanloop prompt id this entry was fetched for (also its cache key).
    prompt_id: str
    # Prompt template as a list of chat messages; string contents may contain
    # `{{variable}}` placeholders.
    prompt_template: List[AllMessageValues]
    # Model named by the Humanloop prompt, if any.
    model: Optional[str]
    # Fields of the Humanloop response whose names are listed in
    # litellm.OPENAI_CHAT_COMPLETION_PARAMS.
    optional_params: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
class HumanLoopPromptManager(DualCache):
    """
    Fetches prompt definitions from the Humanloop API and caches them by
    prompt id, using the cache methods inherited from ``DualCache``.
    """

    @property
    def integration_name(self):
        """Integration name; also the ``humanloop/`` model-name prefix."""
        return "humanloop"

    def _get_prompt_from_id_cache(
        self, humanloop_prompt_id: str
    ) -> Optional[PromptManagementClient]:
        """Return the cached prompt for ``humanloop_prompt_id``, or ``None``."""
        return cast(
            Optional[PromptManagementClient], self.get_cache(key=humanloop_prompt_id)
        )

    def _compile_prompt_helper(
        self, prompt_template: List[AllMessageValues], prompt_variables: Dict[str, Any]
    ) -> List[AllMessageValues]:
        """
        Substitute ``prompt_variables`` into every string ``content`` of the
        template.

        ``{{name}}`` placeholders are rewritten to ``{name}`` and filled with
        ``str.format``. Messages without non-empty string content are passed
        through unchanged. The input messages are never modified.

        Args:
            prompt_template: Template messages.
            prompt_variables: Values to substitute into the template.

        Returns:
            A new list of messages with variables substituted.

        Raises:
            KeyError: if the template references a variable that is not in
                ``prompt_variables`` (raised by ``str.format``).
        """
        compiled_prompts: List[AllMessageValues] = []

        for template in prompt_template:
            content = template.get("content")
            if content and isinstance(content, str):
                formatted_template = content.replace("{{", "{").replace("}}", "}")
                # Build a new message rather than assigning into `template`:
                # the template may be the object held in the prompt cache,
                # and writing the compiled text back into it would bake one
                # request's variables into every later request.
                template = cast(
                    AllMessageValues,
                    {
                        **template,
                        "content": formatted_template.format(**prompt_variables),
                    },
                )
            compiled_prompts.append(template)

        return compiled_prompts

    def _get_prompt_from_id_api(
        self, humanloop_prompt_id: str, humanloop_api_key: str
    ) -> PromptManagementClient:
        """
        Fetch a prompt definition from the Humanloop v5 prompts endpoint.

        Raises:
            Exception: if Humanloop answers with an error status (the
                original ``httpx.HTTPStatusError`` is chained as the cause).
            ValueError: if the response's ``template`` is neither a dict nor
                a list.
            KeyError: if the response lacks ``template`` or ``model``.
        """
        client = _get_httpx_client()

        base_url = "https://api.humanloop.com/v5/prompts/{}".format(humanloop_prompt_id)

        response = client.get(
            url=base_url,
            headers={
                "X-Api-Key": humanloop_api_key,
                "Content-Type": "application/json",
            },
        )

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise Exception(
                f"Error getting prompt from Humanloop: {e.response.text}"
            ) from e

        json_response = response.json()

        # A single-message template is normalized to a one-element list.
        template_message = json_response["template"]
        if isinstance(template_message, dict):
            template_messages = [template_message]
        elif isinstance(template_message, list):
            template_messages = template_message
        else:
            raise ValueError(f"Invalid template message type: {type(template_message)}")

        template_model = json_response["model"]
        # Keep only response fields that are OpenAI chat-completion params.
        optional_params = {
            k: v
            for k, v in json_response.items()
            if k in litellm.OPENAI_CHAT_COMPLETION_PARAMS
        }
        return PromptManagementClient(
            prompt_id=humanloop_prompt_id,
            prompt_template=cast(List[AllMessageValues], template_messages),
            model=template_model,
            optional_params=optional_params,
        )

    def _get_prompt_from_id(
        self, humanloop_prompt_id: str, humanloop_api_key: str
    ) -> PromptManagementClient:
        """
        Return the prompt for ``humanloop_prompt_id``, reading the cache first
        and falling back to the API (caching the result for
        ``litellm.HUMANLOOP_PROMPT_CACHE_TTL_SECONDS``).
        """
        prompt = self._get_prompt_from_id_cache(humanloop_prompt_id)
        if prompt is None:
            prompt = self._get_prompt_from_id_api(
                humanloop_prompt_id, humanloop_api_key
            )
            self.set_cache(
                key=humanloop_prompt_id,
                value=prompt,
                ttl=litellm.HUMANLOOP_PROMPT_CACHE_TTL_SECONDS,
            )
        return prompt

    def compile_prompt(
        self,
        prompt_template: List[AllMessageValues],
        prompt_variables: Optional[dict],
    ) -> List[AllMessageValues]:
        """
        Compile ``prompt_template`` with ``prompt_variables`` (``None`` is
        treated as no variables) and return the resulting messages.
        """
        return self._compile_prompt_helper(
            prompt_template=prompt_template,
            prompt_variables={} if prompt_variables is None else prompt_variables,
        )

    def _get_model_from_prompt(
        self, prompt_management_client: PromptManagementClient, model: str
    ) -> str:
        """
        Return the model named by the prompt, or ``model`` with the
        ``humanloop/`` prefix removed when the prompt names none.
        """
        if prompt_management_client["model"] is not None:
            return prompt_management_client["model"]
        return model.replace("{}/".format(self.integration_name), "")
|
||||
|
||||
|
||||
# Module-level prompt manager (and prompt cache) used by HumanloopLogger.
prompt_manager = HumanLoopPromptManager()
|
||||
|
||||
|
||||
class HumanloopLogger(CustomLogger):
    """Prompt-management hook that resolves prompts stored in Humanloop."""

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict,]:
        """
        Resolve the model, messages and params for a Humanloop-managed prompt.

        The API key is taken from ``dynamic_callback_params`` or, failing
        that, the ``HUMANLOOP_API_KEY`` secret. Without a key the call is
        delegated to ``CustomLogger.get_chat_completion_prompt``.

        Returns:
            ``(model, messages, params)`` where ``messages`` is the compiled
            Humanloop template and ``params`` is ``non_default_params``
            overlaid with the prompt's optional params (the prompt's values
            win on conflicts).

        Raises:
            ValueError: if ``prompt_id`` is ``None``.
        """
        humanloop_api_key = dynamic_callback_params.get(
            "humanloop_api_key"
        ) or get_secret_str("HUMANLOOP_API_KEY")

        if prompt_id is None:
            raise ValueError("prompt_id is required for Humanloop integration")

        if humanloop_api_key is None:
            # No credentials: fall back to the base implementation.
            return super().get_chat_completion_prompt(
                model=model,
                messages=messages,
                non_default_params=non_default_params,
                prompt_id=prompt_id,
                prompt_variables=prompt_variables,
                dynamic_callback_params=dynamic_callback_params,
                prompt_spec=prompt_spec,
            )

        prompt_client = prompt_manager._get_prompt_from_id(
            humanloop_prompt_id=prompt_id, humanloop_api_key=humanloop_api_key
        )
        compiled_messages = prompt_manager.compile_prompt(
            prompt_template=prompt_client["prompt_template"],
            prompt_variables=prompt_variables,
        )

        # Prompt-level params override the caller's on key conflicts.
        merged_params = dict(non_default_params)
        merged_params.update(prompt_client["optional_params"] or {})

        resolved_model = prompt_manager._get_model_from_prompt(
            prompt_management_client=prompt_client, model=model
        )
        return resolved_model, compiled_messages, merged_params
|
||||
@@ -0,0 +1,202 @@
|
||||
# What is this?
|
||||
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
|
||||
|
||||
import json
|
||||
import os
|
||||
from litellm._uuid import uuid
|
||||
from typing import Literal, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
|
||||
def get_utc_datetime():
    """
    Return the current time as a timezone-aware UTC ``datetime``.

    ``datetime.now(timezone.utc)`` is available on every supported Python
    version and is equivalent to ``datetime.now(datetime.UTC)`` on 3.11+, so
    the result no longer switches between an aware value (3.11+) and the
    naive value produced by the deprecated ``datetime.utcnow()``.
    """
    from datetime import datetime, timezone

    return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class LagoLogger(CustomLogger):
    """
    Success callback that reports the cost and token usage of each call to
    Lago's events API.

    Configuration is read from the environment: ``LAGO_API_BASE``,
    ``LAGO_API_KEY`` and ``LAGO_API_EVENT_CODE`` are required;
    ``LAGO_API_CHARGE_BY`` (``end_user_id`` | ``user_id`` | ``team_id``) is
    optional and defaults to ``end_user_id``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.validate_environment()
        self.async_http_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_http_handler = HTTPHandler()

    def validate_environment(self):
        """
        Expects
        LAGO_API_BASE,
        LAGO_API_KEY,
        LAGO_API_EVENT_CODE,

        Optional:
        LAGO_API_CHARGE_BY

        in the environment

        Raises:
            Exception: listing every required variable that is missing.
        """
        required_keys = ("LAGO_API_KEY", "LAGO_API_BASE", "LAGO_API_EVENT_CODE")
        missing_keys = [key for key in required_keys if os.getenv(key) is None]
        if missing_keys:
            raise Exception("Missing keys={} in environment.".format(missing_keys))

    def _get_events_request(self) -> tuple:
        """
        Return ``(url, headers)`` for a POST to Lago's events endpoint.

        Raises:
            ValueError: if ``LAGO_API_BASE`` is not set. An explicit raise is
                used instead of ``assert`` so the check survives ``python -O``.
        """
        base_url = os.getenv("LAGO_API_BASE")
        if not isinstance(base_url, str):
            raise ValueError(
                "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(
                    base_url
                )
            )
        # Avoid a doubled slash when the base already ends with one.
        suffix = "api/v1/events" if base_url.endswith("/") else "/api/v1/events"
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer {}".format(os.getenv("LAGO_API_KEY")),
        }
        return base_url + suffix, headers

    @staticmethod
    def _log_error_response(error, response=None) -> None:
        """Debug-log the body of the failed response, if one is available."""
        # Prefer the response we hold; otherwise use the one attached to the
        # exception, if any (e.g. when the HTTP handler itself raised).
        error_response = (
            response if response is not None else getattr(error, "response", None)
        )
        if error_response is not None and hasattr(error_response, "text"):
            verbose_logger.debug(f"\nError Message: {error_response.text}")

    def _common_logic(self, kwargs: dict, response_obj) -> dict:
        """
        Build the Lago event payload for one successful call.

        Args:
            kwargs: Callback kwargs; ``response_cost``, ``model`` and
                ``litellm_params`` are read from it.
            response_obj: The call's response. Token usage is included only
                for ``ModelResponse`` / ``EmbeddingResponse`` objects that
                have a ``usage`` attribute.

        Returns:
            ``{"event": {...}}`` with a fresh ``transaction_id``, the customer
            id selected by ``LAGO_API_CHARGE_BY``, the configured event code
            and the model / cost / usage properties.

        Raises:
            Exception: if ``LAGO_API_CHARGE_BY`` holds an unsupported value,
                or if the selected customer id is not available.
        """
        cost = kwargs.get("response_cost", None)
        model = kwargs.get("model")
        usage = {}

        if isinstance(
            response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
        ) and hasattr(response_obj, "usage"):
            usage = {
                "prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
                "completion_tokens": response_obj["usage"].get("completion_tokens", 0),
                "total_tokens": response_obj["usage"].get("total_tokens"),
            }

        # Each level may be missing or None; treat both as empty so that a
        # missing id surfaces as the explicit error below rather than as a
        # KeyError / AttributeError.
        litellm_params = kwargs.get("litellm_params", {}) or {}
        proxy_server_request = litellm_params.get("proxy_server_request") or {}
        metadata = litellm_params.get("metadata") or {}
        end_user_id = (proxy_server_request.get("body") or {}).get("user", None)
        user_id = metadata.get("user_api_key_user_id", None)
        team_id = metadata.get("user_api_key_team_id", None)

        charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
        configured_charge_by = os.getenv("LAGO_API_CHARGE_BY", None)
        if configured_charge_by is not None:
            if configured_charge_by not in ("end_user_id", "user_id", "team_id"):
                raise Exception("invalid LAGO_API_CHARGE_BY set")
            charge_by = configured_charge_by  # type: ignore

        external_customer_id: Optional[str] = {
            "end_user_id": end_user_id,
            "team_id": team_id,
            "user_id": user_id,
        }[charge_by]

        if external_customer_id is None:
            raise Exception(
                "External Customer ID is not set. Charge_by={}. User_id={}. End_user_id={}. Team_id={}".format(
                    charge_by, user_id, end_user_id, team_id
                )
            )

        returned_val = {
            "event": {
                "transaction_id": str(uuid.uuid4()),
                "external_subscription_id": external_customer_id,
                "code": os.getenv("LAGO_API_EVENT_CODE"),
                "properties": {"model": model, "response_cost": cost, **usage},
            }
        }

        verbose_logger.debug(
            "\033[91mLogged Lago Object:\n{}\033[0m\n".format(returned_val)
        )
        return returned_val

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Synchronously post the Lago event for a successful call.

        Raises:
            Exception: anything raised while building or sending the event is
                re-raised after the error response body (if any) is logged.
        """
        _url, _headers = self._get_events_request()
        _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)

        try:
            response = self.sync_http_handler.post(
                url=_url,
                data=json.dumps(_data),
                headers=_headers,
            )
            response.raise_for_status()
        except Exception as e:
            self._log_error_response(e)
            raise

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Asynchronously post the Lago event for a successful call.

        Raises:
            Exception: anything raised while building or sending the event is
                re-raised after the error response body (if any) is logged.
        """
        verbose_logger.debug("ENTERS LAGO CALLBACK")
        _url, _headers = self._get_events_request()
        _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)

        response: Optional[httpx.Response] = None
        try:
            response = await self.async_http_handler.post(
                url=_url,
                data=json.dumps(_data),
                headers=_headers,
            )
            response.raise_for_status()

            verbose_logger.debug(f"Logged Lago Object: {response.text}")
        except Exception as e:
            self._log_error_response(e, response)
            raise
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
This file contains the LangFuseHandler class
|
||||
|
||||
Used to get the LangFuseLogger for a given request
|
||||
|
||||
Handles Key/Team Based Langfuse Logging
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams
|
||||
|
||||
from .langfuse import LangFuseLogger, LangfuseLoggingConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
|
||||
else:
|
||||
DynamicLoggingCache = Any
|
||||
|
||||
|
||||
class LangFuseHandler:
    """Selects the ``LangFuseLogger`` to use for a request (key/team based logging)."""

    @staticmethod
    def get_langfuse_logger_for_request(
        standard_callback_dynamic_params: StandardCallbackDynamicParams,
        in_memory_dynamic_logger_cache: DynamicLoggingCache,
        globalLangfuseLogger: Optional[LangFuseLogger] = None,
    ) -> LangFuseLogger:
        """
        Return the ``LangFuseLogger`` that should handle this request.

        - No dynamic Langfuse credentials in the request: return the global
          logger (see ``_return_global_langfuse_logger``).
        - Dynamic credentials present: return the logger cached for those
          credentials, creating and caching one if none exists yet.
        """
        if not LangFuseHandler._dynamic_langfuse_credentials_are_passed(
            standard_callback_dynamic_params
        ):
            return LangFuseHandler._return_global_langfuse_logger(
                globalLangfuseLogger=globalLangfuseLogger,
                in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
            )

        # Logging config for this request, derived from the dynamic params.
        credentials_dict = dict(
            LangFuseHandler.get_dynamic_langfuse_logging_config(
                globalLangfuseLogger=globalLangfuseLogger,
                standard_callback_dynamic_params=standard_callback_dynamic_params,
            )
        )

        cached_logger = in_memory_dynamic_logger_cache.get_cache(
            credentials=credentials_dict, service_name="langfuse"
        )
        if cached_logger is not None:
            return cached_logger

        # Not cached yet: create a logger for these credentials and cache it.
        return LangFuseHandler._create_langfuse_logger_from_credentials(
            credentials=credentials_dict,
            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
        )

    @staticmethod
    def _return_global_langfuse_logger(
        globalLangfuseLogger: Optional[LangFuseLogger],
        in_memory_dynamic_logger_cache: DynamicLoggingCache,
    ) -> LangFuseLogger:
        """
        Return the default logger used when no dynamic credentials are passed.

        ``globalLangfuseLogger`` is returned when it is set. Otherwise the
        logger cached under empty credentials is used, and created (then
        cached) if it does not exist yet.
        """
        if globalLangfuseLogger is not None:
            return globalLangfuseLogger

        # The global logger is configured through environment variables, so
        # it has no dynamic credentials and is cached under an empty dict.
        env_credentials: Dict[str, Any] = {}
        default_logger = in_memory_dynamic_logger_cache.get_cache(
            credentials=env_credentials,
            service_name="langfuse",
        )
        if default_logger is not None:
            return default_logger

        return LangFuseHandler._create_langfuse_logger_from_credentials(
            credentials=env_credentials,
            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
        )

    @staticmethod
    def _create_langfuse_logger_from_credentials(
        credentials: Dict,
        in_memory_dynamic_logger_cache: DynamicLoggingCache,
    ) -> LangFuseLogger:
        """
        Create a ``LangFuseLogger`` from ``credentials`` and store it in the
        cache under those credentials so it is not re-created later.
        """
        new_logger = LangFuseLogger(
            langfuse_public_key=credentials.get("langfuse_public_key"),
            langfuse_secret=credentials.get("langfuse_secret"),
            langfuse_host=credentials.get("langfuse_host"),
        )
        in_memory_dynamic_logger_cache.set_cache(
            credentials=credentials,
            service_name="langfuse",
            logging_obj=new_logger,
        )
        return new_logger

    @staticmethod
    def get_dynamic_langfuse_logging_config(
        standard_callback_dynamic_params: StandardCallbackDynamicParams,
        globalLangfuseLogger: Optional[LangFuseLogger] = None,
    ) -> LangfuseLoggingConfig:
        """
        Build the Langfuse logging config from the request's dynamic params.

        The secret is read from ``langfuse_secret``, falling back to
        ``langfuse_secret_key``. Values that were not passed are ``None``.
        ``globalLangfuseLogger`` is accepted but not consulted here.
        """
        params = standard_callback_dynamic_params
        secret = params.get("langfuse_secret") or params.get("langfuse_secret_key")
        return LangfuseLoggingConfig(
            langfuse_secret=secret,
            langfuse_public_key=params.get("langfuse_public_key"),
            langfuse_host=params.get("langfuse_host"),
        )

    @staticmethod
    def _dynamic_langfuse_credentials_are_passed(
        standard_callback_dynamic_params: StandardCallbackDynamicParams,
    ) -> bool:
        """
        Report whether any Langfuse credential is present in the dynamic params.

        Returns:
            bool: True if at least one of ``langfuse_host``,
            ``langfuse_public_key``, ``langfuse_secret`` or
            ``langfuse_secret_key`` is not ``None``; False otherwise.
        """
        credential_keys = (
            "langfuse_host",
            "langfuse_public_key",
            "langfuse_secret",
            "langfuse_secret_key",
        )
        return any(
            standard_callback_dynamic_params.get(key) is not None
            for key in credential_keys
        )
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user