chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
# Integrations
This folder contains logging integrations for litellm
eg. logging to Datadog, Langfuse, Prometheus, s3, GCS Bucket, etc.

View File

@@ -0,0 +1,46 @@
# Slack Alerting on LiteLLM Gateway
This folder contains the Slack Alerting integration for LiteLLM Gateway.
## Folder Structure
- `slack_alerting.py`: This is the main file that handles sending different types of alerts
- `batching_handler.py`: Handles Batching + sending Httpx Post requests to slack. Slack alerts are sent every 10s or when events are greater than X events. Done to ensure litellm has good performance under high traffic
- `types.py`: This file contains the AlertType enum which is used to define the different types of alerts that can be sent to Slack.
- `utils.py`: This file contains common utils used specifically for slack alerting
## Budget Alert Types
The `budget_alert_types.py` module provides a flexible framework for handling different types of budget alerts:
- `BaseBudgetAlertType`: An abstract base class with abstract methods that all alert types must implement:
  - `get_event_message()`: Returns the message prefix for the alert
  - `get_id(user_info)`: Returns the ID to use for caching/tracking the alert
Concrete implementations include:
- `ProxyBudgetAlert`: Alerting for proxy-level budget concerns
- `SoftBudgetAlert`: Alerting when soft budgets are crossed
- `UserBudgetAlert`: Alerting for user-level budget concerns
- `TeamBudgetAlert`: Alerting for team-level budget concerns
- `OrganizationBudgetAlert`: Alerting for organization-level budget concerns
- `TokenBudgetAlert`: Alerting for API key budget concerns
- `ProjectBudgetAlert`: Alerting for project-level budget concerns
- `ProjectedLimitExceededAlert`: Alerting when projected spend will exceed budget
Use the `get_budget_alert_type()` factory function to get the appropriate alert type class for a given alert type string:
```python
from litellm.integrations.SlackAlerting.budget_alert_types import get_budget_alert_type
# Get the appropriate handler
budget_alert_class = get_budget_alert_type("user_budget")
# Use the handler methods
event_message = budget_alert_class.get_event_message()  # Returns "User Budget: "
cache_id = budget_alert_class.get_id(user_info)  # Returns user_id
```
To add a new budget alert type, simply create a new class that extends `BaseBudgetAlertType` and implements all the required methods, then add it to the dictionary in the `get_budget_alert_type()` function.
## Further Reading
- [Doc setting up Alerting on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/alerting)

View File

@@ -0,0 +1,81 @@
"""
Handles Batching + sending Httpx Post requests to slack
Slack alerts are sent every 10s or when events are greater than X events
see custom_batch_logger.py for more details / defaults
"""
from typing import TYPE_CHECKING, Any
from litellm._logging import verbose_proxy_logger
if TYPE_CHECKING:
from .slack_alerting import SlackAlerting as _SlackAlerting
SlackAlertingType = _SlackAlerting
else:
SlackAlertingType = Any
def squash_payloads(queue):
    """Group queued Slack alert payloads by (url, alert_type).

    Returns a dict mapping each (url, alert_type) pair to
    {"item": first_item_seen, "count": occurrences}. An empty queue yields
    {}; a single-item queue is returned under the literal key "key" as a
    fast path (callers only iterate over the values).
    """
    if not queue:
        return {}
    if len(queue) == 1:
        return {"key": {"item": queue[0], "count": 1}}

    grouped = {}
    for entry in queue:
        bucket = (entry["url"], entry["alert_type"])
        if bucket in grouped:
            # NOTE(review): only the count is tracked today; the payloads
            # themselves are not merged — confirm this is intentional.
            grouped[bucket]["count"] += 1
        else:
            grouped[bucket] = {"item": entry, "count": 1}
    return grouped
def _print_alerting_payload_warning(
    payload: dict, slackAlertingInstance: SlackAlertingType
):
    """
    Echo the alert payload through the proxy logger (warning level) when the
    instance was configured with `alerting_args.log_to_console = True`.

    Relevant issue: https://github.com/BerriAI/litellm/issues/7372
    """
    log_to_console = slackAlertingInstance.alerting_args.log_to_console
    if log_to_console is True:
        verbose_proxy_logger.warning(payload)
async def send_to_webhook(slackAlertingInstance: SlackAlertingType, item, count):
    """
    POST a single (possibly batched) alert payload to its Slack webhook.

    When `count` > 1 the payload text is prefixed with the number of
    squashed alerts. Any failure is swallowed after a debug log; the
    payload is always handed to the console-logging helper afterwards.
    """
    import json

    payload = item.get("payload", {})
    try:
        if count > 1:
            payload["text"] = f"[Num Alerts: {count}]\n\n{payload['text']}"

        serialized = json.dumps(payload)
        resp = await slackAlertingInstance.async_http_handler.post(
            url=item["url"],
            headers=item["headers"],
            data=serialized,
        )
        if resp.status_code != 200:
            verbose_proxy_logger.debug(
                f"Error sending slack alert to url={item['url']}. Error={resp.text}"
            )
    except Exception as e:
        verbose_proxy_logger.debug(f"Error sending slack alert: {str(e)}")
    finally:
        _print_alerting_payload_warning(
            payload, slackAlertingInstance=slackAlertingInstance
        )

View File

@@ -0,0 +1,115 @@
from abc import ABC, abstractmethod
from typing import Literal
from litellm.proxy._types import CallInfo
class BaseBudgetAlertType(ABC):
    """Contract shared by all budget alert categories.

    Subclasses supply the Slack message prefix and the identifier used to
    de-duplicate / cache alerts for a given entity.
    """

    @abstractmethod
    def get_event_message(self) -> str:
        """Message prefix shown at the start of the Slack alert."""
        ...

    @abstractmethod
    def get_id(self, user_info: CallInfo) -> str:
        """Identifier used for caching/tracking this alert."""
        ...
class ProxyBudgetAlert(BaseBudgetAlertType):
    """Alert for the global proxy-level budget."""

    def get_id(self, user_info: CallInfo) -> str:
        # The proxy budget is global, so a fixed cache id is used.
        return "default_id"

    def get_event_message(self) -> str:
        return "Proxy Budget: "
class SoftBudgetAlert(BaseBudgetAlertType):
    """Alert fired when a soft budget limit is crossed."""

    def get_id(self, user_info: CallInfo) -> str:
        return user_info.token or "default_id"

    def get_event_message(self) -> str:
        return "Soft Budget Crossed: "
class UserBudgetAlert(BaseBudgetAlertType):
    """Alert for user-level budget limits."""

    def get_id(self, user_info: CallInfo) -> str:
        return user_info.user_id or "default_id"

    def get_event_message(self) -> str:
        return "User Budget: "
class TeamBudgetAlert(BaseBudgetAlertType):
    """Alert for team-level budget limits."""

    def get_id(self, user_info: CallInfo) -> str:
        return user_info.team_id or "default_id"

    def get_event_message(self) -> str:
        return "Team Budget: "
class OrganizationBudgetAlert(BaseBudgetAlertType):
    """Alert for organization-level budget limits."""

    def get_id(self, user_info: CallInfo) -> str:
        return user_info.organization_id or "default_id"

    def get_event_message(self) -> str:
        return "Organization Budget: "
class TokenBudgetAlert(BaseBudgetAlertType):
    """Alert for API-key (virtual key) budget limits."""

    def get_id(self, user_info: CallInfo) -> str:
        return user_info.token or "default_id"

    def get_event_message(self) -> str:
        return "Key Budget: "
class ProjectedLimitExceededAlert(BaseBudgetAlertType):
    """Alert fired when projected spend will exceed an API key's budget."""

    def get_id(self, user_info: CallInfo) -> str:
        return user_info.token or "default_id"

    def get_event_message(self) -> str:
        return "Key Budget: Projected Limit Exceeded"
class ProjectBudgetAlert(BaseBudgetAlertType):
    """Alert for project-level budget limits."""

    def get_id(self, user_info: CallInfo) -> str:
        # NOTE(review): keyed on the token, like the key-level alerts —
        # confirm there is no dedicated project id on CallInfo.
        return user_info.token or "default_id"

    def get_event_message(self) -> str:
        return "Project Budget: "
def get_budget_alert_type(
    type: Literal[
        "token_budget",
        "user_budget",
        "soft_budget",
        "max_budget_alert",
        "team_budget",
        "organization_budget",
        "proxy_budget",
        "projected_limit_exceeded",
        "project_budget",
    ],
) -> BaseBudgetAlertType:
    """Factory: map an alert-type string to a budget alert handler instance.

    Unknown strings fall back to `ProxyBudgetAlert`. Both "token_budget"
    and "max_budget_alert" resolve to `TokenBudgetAlert`.
    """
    registry = {
        "proxy_budget": ProxyBudgetAlert,
        "soft_budget": SoftBudgetAlert,
        "user_budget": UserBudgetAlert,
        "max_budget_alert": TokenBudgetAlert,
        "team_budget": TeamBudgetAlert,
        "organization_budget": OrganizationBudgetAlert,
        "token_budget": TokenBudgetAlert,
        "projected_limit_exceeded": ProjectedLimitExceededAlert,
        "project_budget": ProjectBudgetAlert,
    }
    return registry.get(type, ProxyBudgetAlert)()

View File

@@ -0,0 +1,177 @@
"""
Class to check for LLM API hanging requests
Notes:
- Do not create tasks that sleep, that can saturate the event loop
- Do not store large objects (eg. messages in memory) that can increase RAM usage
"""
import asyncio
from typing import TYPE_CHECKING, Any, Optional
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
from litellm.types.integrations.slack_alerting import (
HANGING_ALERT_BUFFER_TIME_SECONDS,
MAX_OLDEST_HANGING_REQUESTS_TO_CHECK,
HangingRequestData,
)
if TYPE_CHECKING:
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
else:
SlackAlerting = Any
class AlertingHangingRequestCheck:
    """
    Periodically checks in-flight LLM requests and sends a Slack alert for
    any that have been running longer than the configured alerting
    threshold.

    Design notes (see module docstring):
    - no per-request sleeping tasks, to avoid saturating the event loop
    - only small metadata (HangingRequestData) is cached, never messages
    """

    def __init__(
        self,
        slack_alerting_object: SlackAlerting,
    ):
        self.slack_alerting_object = slack_alerting_object
        # Cache TTL slightly exceeds the alerting threshold so entries
        # survive long enough to be inspected, then expire on their own.
        self.hanging_request_cache = InMemoryCache(
            default_ttl=int(
                self.slack_alerting_object.alerting_threshold
                + HANGING_ALERT_BUFFER_TIME_SECONDS
            ),
        )

    async def add_request_to_hanging_request_check(
        self,
        request_data: Optional[dict] = None,
    ):
        """
        Add a request to the hanging request cache. This is the list of
        request_ids that gets periodically checked for hanging requests.
        """
        if request_data is None:
            return

        request_metadata = get_litellm_metadata_from_kwargs(kwargs=request_data)
        model = request_data.get("model", "")
        api_base: Optional[str] = None

        # Resolve the api base from the routed deployment, when present.
        if request_data.get("deployment", None) is not None and isinstance(
            request_data["deployment"], dict
        ):
            api_base = litellm.get_api_base(
                model=model,
                optional_params=request_data["deployment"].get("litellm_params", {}),
            )

        # Only lightweight metadata is stored — never the request messages.
        hanging_request_data = HangingRequestData(
            request_id=request_data.get("litellm_call_id", ""),
            model=model,
            api_base=api_base,
            key_alias=request_metadata.get("user_api_key_alias", ""),
            team_alias=request_metadata.get("user_api_key_team_alias", ""),
        )

        await self.hanging_request_cache.async_set_cache(
            key=hanging_request_data.request_id,
            value=hanging_request_data,
            ttl=int(
                self.slack_alerting_object.alerting_threshold
                + HANGING_ALERT_BUFFER_TIME_SECONDS
            ),
        )
        return

    async def send_alerts_for_hanging_requests(self):
        """
        Send alerts for requests that are still marked in-flight after the
        alerting threshold has elapsed.
        """
        from litellm.proxy.proxy_server import proxy_logging_obj

        #########################################################
        # Find all requests that have been hanging for more than the alerting threshold
        # Get the last 50 oldest items in the cache and check if they have completed
        #########################################################
        # check if request_id is in internal usage cache
        if proxy_logging_obj.internal_usage_cache is None:
            return

        hanging_requests = await self.hanging_request_cache.async_get_oldest_n_keys(
            n=MAX_OLDEST_HANGING_REQUESTS_TO_CHECK,
        )

        for request_id in hanging_requests:
            hanging_request_data: Optional[
                HangingRequestData
            ] = await self.hanging_request_cache.async_get_cache(
                key=request_id,
            )
            if hanging_request_data is None:
                # Entry expired between listing and lookup; nothing to do.
                continue

            request_status = (
                await proxy_logging_obj.internal_usage_cache.async_get_cache(
                    key="request_status:{}".format(hanging_request_data.request_id),
                    litellm_parent_otel_span=None,
                    local_only=True,
                )
            )
            # this means the request status was either success or fail
            # and is not hanging
            if request_status is not None:
                # clear this request from hanging request cache since the request was either success or failed
                self.hanging_request_cache._remove_key(
                    key=request_id,
                )
                continue

            ################
            # Send the Alert on Slack
            ################
            await self.send_hanging_request_alert(
                hanging_request_data=hanging_request_data
            )
        return

    async def check_for_hanging_requests(
        self,
    ):
        """
        Background task that checks all request ids in self.hanging_request_cache to check if they have completed

        Runs every alerting_threshold/2 seconds to check for hanging requests
        """
        while True:
            verbose_proxy_logger.debug("Checking for hanging requests....")
            await self.send_alerts_for_hanging_requests()
            await asyncio.sleep(self.slack_alerting_object.alerting_threshold / 2)

    async def send_hanging_request_alert(
        self,
        hanging_request_data: HangingRequestData,
    ):
        """
        Send a hanging request alert to Slack for one request.
        """
        from litellm.integrations.SlackAlerting.slack_alerting import AlertType

        ################
        # Send the Alert on Slack
        ################
        request_info = f"""Request Model: `{hanging_request_data.model}`
API Base: `{hanging_request_data.api_base}`
Key Alias: `{hanging_request_data.key_alias}`
Team Alias: `{hanging_request_data.team_alias}`"""

        alerting_message = f"`Requests are hanging - {self.slack_alerting_object.alerting_threshold}s+ request time`"
        # NOTE(review): HangingRequestData is constructed above without
        # alerting_metadata — this relies on the type defaulting it; confirm.
        await self.slack_alerting_object.send_alert(
            message=alerting_message + "\n" + request_info,
            level="Medium",
            alert_type=AlertType.llm_requests_hanging,
            alerting_metadata=hanging_request_data.alerting_metadata or {},
            request_model=hanging_request_data.model,
            api_base=hanging_request_data.api_base,
        )

View File

@@ -0,0 +1,99 @@
"""
Utils used for slack alerting
"""
import asyncio
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import litellm
from litellm.proxy._types import AlertType
from litellm.secret_managers.main import get_secret
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _Logging
Logging = _Logging
else:
Logging = Any
def process_slack_alerting_variables(
    alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
    """
    Resolve `os.environ/...` placeholders in the alert -> webhook-url map.

    Values may be a single url or a list of urls; any entry containing
    "os.environ/" is replaced with the secret it points at.

    Raises:
        ValueError: when a resolved secret is not a string.

    The mapping is updated in place and also returned.
    """
    if alert_to_webhook_url is None:
        return None

    def _resolve(url: str) -> str:
        # Plain urls pass through untouched; env-style references are
        # resolved through the secret manager.
        if "os.environ/" not in url:
            return url
        _env_value = get_secret(secret_name=url)
        if not isinstance(_env_value, str):
            raise ValueError(
                f"Invalid webhook url value for: {url}. Got type={type(_env_value)}"
            )
        return _env_value

    for alert_type, webhook_urls in alert_to_webhook_url.items():
        if isinstance(webhook_urls, list):
            alert_to_webhook_url[alert_type] = [_resolve(u) for u in webhook_urls]
        else:
            alert_to_webhook_url[alert_type] = _resolve(webhook_urls)
    return alert_to_webhook_url
async def _add_langfuse_trace_id_to_alert(
    request_data: Optional[dict] = None,
) -> Optional[str]:
    """
    Build the Langfuse trace url for a request, when langfuse logging is
    enabled.

    The trace id is looked up on the request's logging object
    (existing_trace_id -> trace_id -> litellm_call_id internally), retrying
    up to 3 times with a 3s pause since the trace may not be registered yet.

    Returns None when langfuse is not a configured callback, the request
    carries no logging object, or no langfuse callback object is found.
    """
    if "langfuse" not in litellm.logging_callback_manager._get_all_callbacks():
        # Only relevant when langfuse is registered as a callback.
        return None

    if request_data is None or request_data.get("litellm_logging_obj", None) is None:
        return None

    litellm_logging_obj: Logging = request_data["litellm_logging_obj"]
    trace_id: Optional[str] = None
    for _attempt in range(3):
        trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse")
        if trace_id is not None:
            break
        await asyncio.sleep(3)  # wait 3s before retrying for trace id

    langfuse_object = litellm_logging_obj._get_callback_object(
        service_name="langfuse"
    )
    if langfuse_object is not None:
        base_url = langfuse_object.Langfuse.base_url
        return f"{base_url}/trace/{trace_id}"
    return None

View File

@@ -0,0 +1 @@
from . import *

View File

@@ -0,0 +1,440 @@
from enum import Enum
class SpanAttributes:
    """Flat namespace of OpenInference span attribute keys."""

    # --- Generic input / output ---
    OUTPUT_VALUE = "output.value"
    # Mime type of output.value; plain text unless specified. JSON means the
    # value is a string representing a JSON object.
    OUTPUT_MIME_TYPE = "output.mime_type"
    INPUT_VALUE = "input.value"
    # Mime type of input.value; same convention as OUTPUT_MIME_TYPE.
    INPUT_MIME_TYPE = "input.mime_type"

    # --- Embeddings ---
    # List of objects containing embedding data (vector + represented text).
    EMBEDDING_EMBEDDINGS = "embedding.embeddings"
    EMBEDDING_MODEL_NAME = "embedding.model_name"  # name of the embedding model

    # --- LLM calls ---
    # Function-calling details (function name and arguments) for models/APIs
    # that support function calling.
    LLM_FUNCTION_CALL = "llm.function_call"
    # Invocation parameters passed to the LLM/API (model name, temperature, ...).
    LLM_INVOCATION_PARAMETERS = "llm.invocation_parameters"
    LLM_INPUT_MESSAGES = "llm.input_messages"  # messages provided to a chat API
    LLM_OUTPUT_MESSAGES = "llm.output_messages"  # messages received from a chat API
    LLM_MODEL_NAME = "llm.model_name"
    LLM_PROVIDER = "llm.provider"  # e.g. OpenAI, Azure, Google
    LLM_SYSTEM = "llm.system"  # the AI product as identified by client or server
    LLM_PROMPTS = "llm.prompts"  # prompts provided to a completions API

    # --- Prompt templates ---
    LLM_PROMPT_TEMPLATE = "llm.prompt_template.template"  # Python f-string template
    LLM_PROMPT_TEMPLATE_VARIABLES = "llm.prompt_template.variables"
    LLM_PROMPT_TEMPLATE_VERSION = "llm.prompt_template.version"

    # --- Token counts ---
    LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt"
    # Prompt tokens written to / read from cache.
    LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = "llm.token_count.prompt_details.cache_write"
    LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = "llm.token_count.prompt_details.cache_read"
    LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = "llm.token_count.prompt_details.audio"
    LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"
    # Completion tokens used for reasoning steps / audio output.
    LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = "llm.token_count.completion_details.reasoning"
    LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = "llm.token_count.completion_details.audio"
    LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total"  # prompt + completion

    # --- Tools ---
    LLM_TOOLS = "llm.tools"  # tools advertised to the LLM
    TOOL_NAME = "tool.name"
    TOOL_DESCRIPTION = "tool.description"  # typically used to select the tool
    # Tool parameters as a dictionary JSON string, e.g. see
    # https://platform.openai.com/docs/guides/gpt/function-calling
    TOOL_PARAMETERS = "tool.parameters"

    # --- Retrieval / misc ---
    RETRIEVAL_DOCUMENTS = "retrieval.documents"
    METADATA = "metadata"  # user-defined key-value pairs
    TAG_TAGS = "tag.tags"  # custom categorical tags for the span
    OPENINFERENCE_SPAN_KIND = "openinference.span.kind"
    SESSION_ID = "session.id"
    USER_ID = "user.id"

    # --- Prompt registry ---
    PROMPT_VENDOR = "prompt.vendor"  # origin of the prompt (library, service, ...)
    PROMPT_ID = "prompt.id"  # vendor-specific id used to locate the prompt
    PROMPT_URL = "prompt.url"  # vendor-specific url used to locate the prompt
class MessageAttributes:
    """
    Attributes for a message sent to or from an LLM
    """

    MESSAGE_ROLE = "message.role"  # e.g. "user", "agent", "function"
    MESSAGE_CONTENT = "message.content"  # string content to or from the llm
    # Array of `message_content`-prefixed attributes.
    MESSAGE_CONTENTS = "message.contents"
    # Often identifies the function used to generate the message.
    MESSAGE_NAME = "message.name"
    MESSAGE_TOOL_CALLS = "message.tool_calls"  # e.g. function calls
    # Function name within the message list; populated for role
    # 'function'/'agent' to identify the function called by a tool.
    MESSAGE_FUNCTION_CALL_NAME = "message.function_call_name"
    # JSON string of the arguments passed during a function call.
    MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = "message.function_call_arguments_json"
    MESSAGE_TOOL_CALL_ID = "message.tool_call_id"
    # Reasoning summary from the model's chain-of-thought process.
    MESSAGE_REASONING_SUMMARY = "message.reasoning_summary"
class MessageContentAttributes:
    """
    Attributes for the contents of user messages sent to an LLM.
    """

    MESSAGE_CONTENT_TYPE = "message_content.type"  # e.g. "text" or "image"
    MESSAGE_CONTENT_TEXT = "message_content.text"  # text content when type is "text"
    # Image content when type is "image"; the image may be passed as a link
    # or as base64-encoded data directly in the request.
    MESSAGE_CONTENT_IMAGE = "message_content.image"
class ImageAttributes:
    """
    Attributes for images
    """

    IMAGE_URL = "image.url"  # an http or base64 image url
class AudioAttributes:
    """
    Attributes for audio
    """

    AUDIO_URL = "audio.url"  # url to an audio file
    AUDIO_MIME_TYPE = "audio.mime_type"  # mime type of the audio file
    AUDIO_TRANSCRIPT = "audio.transcript"  # transcript of the audio file
class DocumentAttributes:
    """
    Attributes for a document.
    """

    DOCUMENT_ID = "document.id"
    DOCUMENT_SCORE = "document.score"
    DOCUMENT_CONTENT = "document.content"
    # Document metadata as a dictionary JSON string, e.g. `"{ 'title': 'foo' }"`
    DOCUMENT_METADATA = "document.metadata"
class RerankerAttributes:
    """
    Attributes for a reranker
    """

    RERANKER_INPUT_DOCUMENTS = "reranker.input_documents"  # documents fed in
    RERANKER_OUTPUT_DOCUMENTS = "reranker.output_documents"  # documents returned
    RERANKER_QUERY = "reranker.query"  # query string for the reranker
    RERANKER_MODEL_NAME = "reranker.model_name"
    RERANKER_TOP_K = "reranker.top_k"
class EmbeddingAttributes:
    """
    Attributes for an embedding
    """

    EMBEDDING_TEXT = "embedding.text"  # text represented by the embedding
    EMBEDDING_VECTOR = "embedding.vector"  # the embedding vector itself
class ToolCallAttributes:
    """
    Attributes for a tool call
    """

    TOOL_CALL_ID = "tool_call.id"
    TOOL_CALL_FUNCTION_NAME = "tool_call.function.name"
    # JSON string of the arguments passed to the function during the call.
    TOOL_CALL_FUNCTION_ARGUMENTS_JSON = "tool_call.function.arguments"
class ToolAttributes:
    """
    Attributes for a tool
    """

    # JSON schema of a tool input; RECOMMENDED to follow the OpenAI tool
    # calling format: https://platform.openai.com/docs/assistants/tools
    TOOL_JSON_SCHEMA = "tool.json_schema"
class OpenInferenceSpanKindValues(Enum):
    """Allowed values for the `openinference.span.kind` span attribute."""

    TOOL = "TOOL"
    CHAIN = "CHAIN"
    LLM = "LLM"
    RETRIEVER = "RETRIEVER"
    EMBEDDING = "EMBEDDING"
    AGENT = "AGENT"
    RERANKER = "RERANKER"
    UNKNOWN = "UNKNOWN"
    GUARDRAIL = "GUARDRAIL"
    EVALUATOR = "EVALUATOR"
class OpenInferenceMimeTypeValues(Enum):
    """Allowed values for input/output mime-type span attributes."""

    TEXT = "text/plain"
    JSON = "application/json"
class OpenInferenceLLMSystemValues(Enum):
    """Known values for the `llm.system` span attribute."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    COHERE = "cohere"
    MISTRALAI = "mistralai"
    VERTEXAI = "vertexai"
class OpenInferenceLLMProviderValues(Enum):
    """Known values for the `llm.provider` span attribute."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    COHERE = "cohere"
    MISTRALAI = "mistralai"
    GOOGLE = "google"
    AZURE = "azure"
    AWS = "aws"
class ErrorAttributes:
    """
    Attributes for error information in spans.

    These follow OpenTelemetry semantic conventions for exceptions and
    mirror the fields of StandardLoggingPayloadErrorInformation (noted in
    parentheses per attribute).
    """

    ERROR_TYPE = "error.type"  # error class, e.g. 'ValueError' (error_class)
    ERROR_MESSAGE = "error.message"  # what went wrong (error_message)
    # HTTP status like '500'/'429' or provider-specific code (error_code).
    ERROR_CODE = "error.code"
    ERROR_STACK_TRACE = "error.stack_trace"  # full stack trace (traceback)
    # Provider where the error occurred, e.g. 'openai' (llm_provider).
    ERROR_LLM_PROVIDER = "error.llm_provider"

View File

@@ -0,0 +1,36 @@
"""
Base class for Additional Logging Utils for CustomLoggers
- Health Check for the logging util
- Get Request / Response Payload for the logging util
"""
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
class AdditionalLoggingUtils(ABC):
    """
    Optional extras a CustomLogger integration can implement:

    - a health check for the logging destination
    - retrieval of the stored request/response payload for a request id
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        """Report whether the logging service is reachable/healthy."""
        ...

    @abstractmethod
    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """Return the stored request/response payload for `request_id`."""
        return None

View File

@@ -0,0 +1,3 @@
from .agentops import AgentOps
__all__ = ["AgentOps"]

View File

@@ -0,0 +1,116 @@
"""
AgentOps integration for LiteLLM - Provides OpenTelemetry tracing for LLM calls
"""
import os
from dataclasses import dataclass
from typing import Optional, Dict, Any
from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
@dataclass
class AgentOpsConfig:
    """Configuration for the AgentOps OTLP exporter.

    Attributes:
        endpoint: OTLP traces endpoint spans are exported to.
        api_key: AgentOps API key; when unset, no JWT auth is attempted.
        service_name: value for the `service.name` resource attribute.
        deployment_environment: value for `deployment.environment`.
        auth_endpoint: endpoint used to exchange the api key for a JWT.
    """

    endpoint: str = "https://otlp.agentops.cloud/v1/traces"
    api_key: Optional[str] = None
    service_name: Optional[str] = None
    deployment_environment: Optional[str] = None
    auth_endpoint: str = "https://api.agentops.ai/v3/auth/token"

    @classmethod
    def from_env(cls):
        """Build a config from AGENTOPS_* environment variables.

        Falls back to service_name="agentops" and
        deployment_environment="production" when unset.
        """
        # Fix: endpoint/auth_endpoint come from the field defaults instead of
        # repeating the URL literals here (previous duplication risked the
        # two copies drifting apart).
        return cls(
            api_key=os.getenv("AGENTOPS_API_KEY"),
            service_name=os.getenv("AGENTOPS_SERVICE_NAME", "agentops"),
            deployment_environment=os.getenv("AGENTOPS_ENVIRONMENT", "production"),
        )
class AgentOps(OpenTelemetry):
    """
    AgentOps integration - built on top of OpenTelemetry

    On construction this (optionally) exchanges the configured AgentOps API
    key for a JWT and wires up an OTLP/HTTP exporter pointed at the
    AgentOps collector.

    Example usage:
    ```python
    import litellm
    litellm.success_callback = ["agentops"]
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    ```
    """

    def __init__(
        self,
        config: Optional[AgentOpsConfig] = None,
    ):
        if config is None:
            config = AgentOpsConfig.from_env()

        # Prefetch JWT token for authentication
        # NOTE(review): this performs a *blocking* HTTP call inside __init__,
        # and any auth failure is silently swallowed so the exporter falls
        # back to unauthenticated requests — confirm both are intended.
        jwt_token = None
        project_id = None
        if config.api_key:
            try:
                response = self._fetch_auth_token(config.api_key, config.auth_endpoint)
                jwt_token = response.get("token")
                project_id = response.get("project_id")
            except Exception:
                pass

        # Header is a "key=value" string — presumably the form the OTLP
        # exporter's headers parsing expects; confirm against OpenTelemetryConfig.
        headers = f"Authorization=Bearer {jwt_token}" if jwt_token else None

        otel_config = OpenTelemetryConfig(
            exporter="otlp_http", endpoint=config.endpoint, headers=headers
        )

        # Initialize OpenTelemetry with our config
        super().__init__(config=otel_config, callback_name="agentops")

        # Set AgentOps-specific resource attributes
        resource_attrs = {
            "service.name": config.service_name or "litellm",
            "deployment.environment": config.deployment_environment or "production",
            "telemetry.sdk.name": "agentops",
        }
        if project_id:
            resource_attrs["project.id"] = project_id
        self.resource_attributes = resource_attrs

    def _fetch_auth_token(self, api_key: str, auth_endpoint: str) -> Dict[str, Any]:
        """
        Fetch JWT authentication token from AgentOps API

        Args:
            api_key: AgentOps API key
            auth_endpoint: Authentication endpoint

        Returns:
            Dict containing JWT token and project ID

        Raises:
            Exception: when the auth endpoint responds with a non-200 status.
        """
        headers = {
            "Content-Type": "application/json",
            "Connection": "keep-alive",
        }

        client = _get_httpx_client()
        try:
            response = client.post(
                url=auth_endpoint,
                headers=headers,
                json={"api_key": api_key},
                timeout=10,
            )
            if response.status_code != 200:
                raise Exception(f"Failed to fetch auth token: {response.text}")
            return response.json()
        finally:
            # Always release the client, even when auth fails.
            client.close()

View File

@@ -0,0 +1,253 @@
"""
This hook is used to inject cache control directives into the messages of a chat completion.
Users can define
- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.
"""
import copy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.integrations.prompt_management_base import PromptManagementClient
from litellm.types.integrations.anthropic_cache_control_hook import (
CacheControlInjectionPoint,
CacheControlMessageInjectionPoint,
)
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AnthropicCacheControlHook(CustomPromptManagement):
    """
    Prompt-management hook that injects Anthropic ``cache_control`` directives
    into chat messages, driven by the ``cache_control_injection_points`` entry
    in ``non_default_params``.

    This is not a real prompt store: ``should_run_prompt_management`` always
    returns False and the compile helpers are stubs. The hook only rewrites
    the outgoing message list.
    """

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Apply cache control directives based on specified injection points.

        Returns:
        - model: str - the model to use
        - messages: List[AllMessageValues] - messages with applied cache controls
        - non_default_params: dict - params with any global cache controls
        """
        # Extract cache control injection points.
        # NOTE: pop() mutates the caller's non_default_params so the key is
        # consumed here and never forwarded to the provider.
        injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
            "cache_control_injection_points", []
        )
        if not injection_points:
            return model, messages, non_default_params
        # Create a deep copy of messages to avoid modifying the original list
        processed_messages = copy.deepcopy(messages)
        # Process message-level cache controls (the only supported location).
        for point in injection_points:
            if point.get("location") == "message":
                point = cast(CacheControlMessageInjectionPoint, point)
                processed_messages = self._process_message_injection(
                    point=point, messages=processed_messages
                )
        return model, processed_messages, non_default_params

    @staticmethod
    def _process_message_injection(
        point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """Process message-level cache control injection.

        Targets either a single message by index (negative indices supported)
        or every message with a matching role. Mutates and returns *messages*.
        """
        # Default to an ephemeral cache entry when no control is given.
        control: ChatCompletionCachedContent = point.get(
            "control", None
        ) or ChatCompletionCachedContent(type="ephemeral")
        # The index may arrive as a string (e.g. from YAML config); coerce it.
        _targetted_index: Optional[Union[int, str]] = point.get("index", None)
        targetted_index: Optional[int] = None
        if isinstance(_targetted_index, str):
            try:
                targetted_index = int(_targetted_index)
            except ValueError:
                # Non-numeric string index: fall through to role targeting.
                pass
        else:
            targetted_index = _targetted_index
        targetted_role = point.get("role", None)
        # Case 1: Target by specific index
        if targetted_index is not None:
            original_index = targetted_index
            # Handle negative indices (convert to positive)
            if targetted_index < 0:
                targetted_index += len(messages)
            if 0 <= targetted_index < len(messages):
                messages[
                    targetted_index
                ] = AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                    messages[targetted_index], control
                )
            else:
                verbose_logger.warning(
                    f"AnthropicCacheControlHook: Provided index {original_index} is out of bounds for message list of length {len(messages)}. "
                    f"Targeted index was {targetted_index}. Skipping cache control injection for this point."
                )
        # Case 2: Target by role
        elif targetted_role is not None:
            for msg in messages:
                if msg.get("role") == targetted_role:
                    # NOTE(review): the rebinding of `msg` is a no-op; this
                    # works because _safe_insert_cache_control_in_message
                    # mutates the message dict in place.
                    msg = (
                        AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                            message=msg, control=control
                        )
                    )
        return messages

    @staticmethod
    def _safe_insert_cache_control_in_message(
        message: AllMessageValues, control: ChatCompletionCachedContent
    ) -> AllMessageValues:
        """
        Safe way to insert cache control in a message

        OpenAI Message content can be either:
        - string
        - list of objects

        This method handles inserting cache control in both cases.

        Per Anthropic's API specification, when using multiple content blocks,
        only the last content block can have cache_control.
        """
        message_content = message.get("content", None)
        # 1. if string, insert cache control in the message
        if isinstance(message_content, str):
            message["cache_control"] = control  # type: ignore
        # 2. list of objects - only apply to last item per Anthropic spec
        elif isinstance(message_content, list):
            if len(message_content) > 0 and isinstance(message_content[-1], dict):
                message_content[-1]["cache_control"] = control  # type: ignore
        return message

    @property
    def integration_name(self) -> str:
        """Return the integration name for this hook."""
        return "anthropic_cache_control_hook"

    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """Always return False since this is not a true prompt management system."""
        return False

    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """Not used - this hook only modifies messages, doesn't fetch prompts."""
        return PromptManagementClient(
            prompt_id=prompt_id,
            prompt_template=[],
            prompt_template_model=None,
            prompt_template_optional_params=None,
            completed_messages=None,
        )

    async def async_compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """Not used - this hook only modifies messages, doesn't fetch prompts."""
        return self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_spec=prompt_spec,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )

    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        litellm_logging_obj: LiteLLMLoggingObj,
        prompt_spec: Optional[PromptSpec] = None,
        tools: Optional[List[Dict]] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Async version - delegates to sync since no async operations needed."""
        return self.get_chat_completion_prompt(
            model=model,
            messages=messages,
            non_default_params=non_default_params,
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
            prompt_spec=prompt_spec,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
            ignore_prompt_manager_model=ignore_prompt_manager_model,
            ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
        )

    @staticmethod
    def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
        # True only when the caller supplied a non-empty injection-point list.
        if non_default_params.get("cache_control_injection_points", None):
            return True
        return False

    @staticmethod
    def get_custom_logger_for_anthropic_cache_control_hook(
        non_default_params: Dict,
    ) -> Optional[CustomLogger]:
        """Return an initialized hook instance when injection points are present, else None."""
        from litellm.litellm_core_utils.litellm_logging import (
            _init_custom_logger_compatible_class,
        )

        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
            non_default_params
        ):
            return _init_custom_logger_compatible_class(
                logging_integration="anthropic_cache_control_hook",
                internal_usage_cache=None,
                llm_router=None,
            )
        return None

View File

@@ -0,0 +1,391 @@
"""
Send logs to Argilla for annotation
"""
import asyncio
import json
import os
import random
import types
from typing import Any, Dict, List, Optional
import httpx
from pydantic import BaseModel # type: ignore
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.argilla import (
SUPPORTED_PAYLOAD_FIELDS,
ArgillaCredentialsObject,
ArgillaItem,
)
from litellm.types.utils import StandardLoggingPayload
def is_serializable(value):
    """Heuristic filter: True when *value* is plain data safe to serialize.

    Coroutines, functions, generators, and pydantic models are rejected.
    """
    for blocked_type in (
        types.CoroutineType,
        types.FunctionType,
        types.GeneratorType,
        BaseModel,
    ):
        if isinstance(value, blocked_type):
            return False
    return True
class ArgillaLogger(CustomBatchLogger):
    """
    Batched logging callback that ships LLM request/response pairs to an
    Argilla dataset for annotation.

    Records are accumulated in ``self.log_queue`` and flushed either by the
    ``periodic_flush`` background task or once the queue reaches
    ``self.batch_size``.
    """

    def __init__(
        self,
        argilla_api_key: Optional[str] = None,
        argilla_dataset_name: Optional[str] = None,
        argilla_base_url: Optional[str] = None,
        **kwargs,
    ):
        # The transformation object maps Argilla field names -> standard
        # logging payload fields; without it we cannot build records.
        if litellm.argilla_transformation_object is None:
            raise Exception(
                "'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
            )
        self.validate_argilla_transformation_object(
            litellm.argilla_transformation_object
        )
        self.argilla_transformation_object = litellm.argilla_transformation_object
        self.default_credentials = self.get_credentials_from_env(
            argilla_api_key=argilla_api_key,
            argilla_dataset_name=argilla_dataset_name,
            argilla_base_url=argilla_base_url,
        )
        # Fraction of events to log, in [0.0, 1.0]. Bug fix: the previous
        # ``isdigit()`` guard rejected fractional values such as "0.5",
        # silently forcing the rate back to 1.0.
        self.sampling_rate: float = self._read_sampling_rate_from_env()
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        _batch_size = (
            os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
        )
        if _batch_size:
            self.batch_size = int(_batch_size)
        asyncio.create_task(self.periodic_flush())
        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)

    @staticmethod
    def _read_sampling_rate_from_env() -> float:
        """Parse ARGILLA_SAMPLING_RATE as a float; default to 1.0 (log everything)."""
        raw = os.getenv("ARGILLA_SAMPLING_RATE")
        if raw is None:
            return 1.0
        try:
            return float(raw.strip())
        except ValueError:
            return 1.0

    def validate_argilla_transformation_object(
        self, argilla_transformation_object: Dict[str, Any]
    ):
        """Ensure the transformation object is a dict whose values are supported payload fields."""
        if not isinstance(argilla_transformation_object, dict):
            raise Exception(
                "'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
            )
        for v in argilla_transformation_object.values():
            if v not in SUPPORTED_PAYLOAD_FIELDS:
                raise Exception(
                    f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
                )

    def get_credentials_from_env(
        self,
        argilla_api_key: Optional[str],
        argilla_dataset_name: Optional[str],
        argilla_base_url: Optional[str],
    ) -> ArgillaCredentialsObject:
        """
        Resolve the API key, base URL, and dataset id from arguments / environment.

        The dataset *name* is resolved to its server-side *id* via a blocking
        HTTP call, since the bulk-records endpoint expects a dataset id.

        Raises:
            Exception: If no API key can be resolved.
        """
        _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
        if _credentials_api_key is None:
            raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")
        _credentials_base_url = (
            argilla_base_url
            or os.getenv("ARGILLA_BASE_URL")
            or "http://localhost:6900/"
        )
        if _credentials_base_url is None:
            raise Exception(
                "Invalid Argilla Base URL given. _credentials_base_url=None."
            )
        # Normalize the base URL so later f-string joins don't produce double
        # slashes (the default base URL ends with "/").
        _credentials_base_url = _credentials_base_url.rstrip("/")
        _credentials_dataset_name = (
            argilla_dataset_name
            or os.getenv("ARGILLA_DATASET_NAME")
            or "litellm-completion"
        )
        if _credentials_dataset_name is None:
            # Defensive only; the default above makes this unreachable.
            raise Exception("Invalid Argilla Dataset given. Value=None.")
        else:
            dataset_response = litellm.module_level_client.get(
                url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
                headers={"X-Argilla-Api-Key": _credentials_api_key},
            )
            json_response = dataset_response.json()
            if (
                "items" in json_response
                and isinstance(json_response["items"], list)
                and len(json_response["items"]) > 0
            ):
                # Replace the human-readable name with the server-side id.
                _credentials_dataset_name = json_response["items"][0]["id"]
        return ArgillaCredentialsObject(
            ARGILLA_API_KEY=_credentials_api_key,
            ARGILLA_BASE_URL=_credentials_base_url,
            ARGILLA_DATASET_NAME=_credentials_dataset_name,
        )

    def get_chat_messages(
        self, payload: StandardLoggingPayload
    ) -> List[Dict[str, Any]]:
        """Extract the chat message list from the payload, normalizing a bare dict to a list."""
        payload_messages = payload.get("messages", None)
        if payload_messages is None:
            raise Exception("No chat messages found in payload.")
        if (
            isinstance(payload_messages, list)
            and len(payload_messages) > 0
            and isinstance(payload_messages[0], dict)
        ):
            return payload_messages
        elif isinstance(payload_messages, dict):
            return [payload_messages]
        else:
            raise Exception(f"Invalid chat messages format: {payload_messages}")

    def get_str_response(self, payload: StandardLoggingPayload) -> str:
        """Extract the response text from the payload (string or chat-completion dict)."""
        response = payload["response"]
        if response is None:
            raise Exception("No response found in payload.")
        if isinstance(response, str):
            return response
        elif isinstance(response, dict):
            return (
                response.get("choices", [{}])[0].get("message", {}).get("content", "")
            )
        else:
            raise Exception(f"Invalid response format: {response}")

    def _prepare_log_data(
        self, kwargs, response_obj, start_time, end_time
    ) -> Optional[ArgillaItem]:
        """Build an Argilla record from the standard logging payload.

        Raises:
            Exception: If the standard logging payload is missing or malformed.
        """
        # (The previous ``try: ... except Exception: raise`` wrapper was a
        # no-op and has been removed.)
        payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if payload is None:
            raise Exception("Error logging request payload. Payload=none.")
        argilla_message = self.get_chat_messages(payload)
        argilla_response = self.get_str_response(payload)
        argilla_item: ArgillaItem = {"fields": {}}
        # Map each configured Argilla field to the matching payload field.
        for field_name, payload_field in self.argilla_transformation_object.items():
            if payload_field == "messages":
                argilla_item["fields"][field_name] = argilla_message
            elif payload_field == "response":
                argilla_item["fields"][field_name] = argilla_response
            else:
                argilla_item["fields"][field_name] = payload.get(payload_field, None)
        return argilla_item

    def _send_batch(self):
        """Synchronously POST queued records to the Argilla bulk endpoint.

        Never raises; failures are logged and the queue is kept for retry.
        """
        if not self.log_queue:
            return
        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
        headers = {"X-Argilla-Api-Key": argilla_api_key}
        try:
            # NOTE(review): this POSTs the raw list, while async_send_batch
            # PUTs {"items": [...]} — confirm which shape the bulk endpoint
            # actually expects.
            response = litellm.module_level_client.post(
                url=url,
                json=self.log_queue,
                headers=headers,
            )
            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
            self.log_queue.clear()
        except Exception:
            verbose_logger.exception("Argilla Layer Error - Error sending batch.")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Sync success hook: sample, build a record, queue it, flush when the batch is full."""
        try:
            # Bug fix: this previously re-read LANGSMITH_SAMPLING_RATE (a
            # copy-paste from the Langsmith logger). Use the Argilla rate
            # computed in __init__, matching the async hook.
            sampling_rate = self.sampling_rate
            random_sample = random.random()
            if random_sample > sampling_rate:
                verbose_logger.info(
                    "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                        sampling_rate, random_sample
                    )
                )
                return  # Skip logging
            verbose_logger.debug(
                "Argilla Sync Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            if data is None:
                return
            self.log_queue.append(data)
            verbose_logger.debug(
                f"Argilla, event added to queue. Will flush in {self.flush_interval} seconds..."
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception("Argilla Layer Error - log_success_event error")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Async success hook: sample, build a record, let callbacks filter it, queue, flush."""
        try:
            sampling_rate = self.sampling_rate
            random_sample = random.random()
            if random_sample > sampling_rate:
                verbose_logger.info(
                    "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                        sampling_rate, random_sample
                    )
                )
                return  # Skip logging
            verbose_logger.debug(
                "Argilla Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            ## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
            for callback in litellm.callbacks:
                if isinstance(callback, CustomLogger):
                    try:
                        if data is None:
                            break
                        data = await callback.async_dataset_hook(data, payload)
                    except NotImplementedError:
                        # Callback doesn't implement the dataset hook; skip it.
                        pass
            if data is None:
                return
            self.log_queue.append(data)
            verbose_logger.debug(
                "Argilla logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Argilla Layer Error - error logging async success event."
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Async failure hook: sample, build a record, queue it, flush when the batch is full."""
        sampling_rate = self.sampling_rate
        random_sample = random.random()
        if random_sample > sampling_rate:
            verbose_logger.info(
                "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                    sampling_rate, random_sample
                )
            )
            return  # Skip logging
        verbose_logger.info("Argilla Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Argilla logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Argilla Layer Error - error logging async failure event."
            )

    async def async_send_batch(self):
        """
        sends runs to /batch endpoint

        Sends runs from self.log_queue

        Returns: None

        Raises: Does not raise an exception, will only verbose_logger.exception()
        """
        if not self.log_queue:
            return
        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
        headers = {"X-Argilla-Api-Key": argilla_api_key}
        try:
            response = await self.async_httpx_client.put(
                url=url,
                data=json.dumps(
                    {
                        "items": self.log_queue,
                    }
                ),
                headers=headers,
                # NOTE(review): httpx timeouts are in seconds; 60000s (~16.7h)
                # looks unintended — confirm whether 60 was meant.
                timeout=60000,
            )
            response.raise_for_status()
            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    "Batch of %s runs successfully created", len(self.log_queue)
                )
        except httpx.HTTPStatusError:
            verbose_logger.exception("Argilla HTTP Error")
        except Exception:
            verbose_logger.exception("Argilla Layer Error")

View File

@@ -0,0 +1,210 @@
# Arize Phoenix Prompt Management Integration
This integration enables using prompt versions from Arize Phoenix with LiteLLM's completion function.
## Features
- Fetch prompt versions from Arize Phoenix API
- Workspace-based access control through Arize Phoenix permissions
- Mustache/Handlebars-style variable templating (`{{variable}}`)
- Support for multi-message chat templates
- Automatic model and parameter configuration from prompt metadata
- OpenAI and Anthropic provider parameter support
## Configuration
Configure Arize Phoenix access in your application:
```python
import litellm
# Configure Arize Phoenix access
# api_base should include your workspace, e.g., "https://app.phoenix.arize.com/s/your-workspace/v1"
api_key = "your-arize-phoenix-token"
api_base = "https://app.phoenix.arize.com/s/krrishdholakia/v1"
```
## Usage
### Basic Usage
```python
import litellm
# Use with completion
response = litellm.completion(
model="arize/gpt-4o",
prompt_id="UHJvbXB0VmVyc2lvbjox", # Your prompt version ID
prompt_variables={"question": "What is artificial intelligence?"},
api_key="your-arize-phoenix-token",
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
)
print(response.choices[0].message.content)
```
### With Additional Messages
You can also combine prompt templates with additional messages:
```python
response = litellm.completion(
model="arize/gpt-4o",
prompt_id="UHJvbXB0VmVyc2lvbjox",
prompt_variables={"question": "Explain quantum computing"},
api_key="your-arize-phoenix-token",
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
messages=[
{"role": "user", "content": "Please keep your response under 100 words."}
],
)
```
### Direct Manager Usage
You can also use the prompt manager directly:
```python
from litellm.integrations.arize.arize_phoenix_prompt_manager import ArizePhoenixPromptManager
# Initialize the manager
manager = ArizePhoenixPromptManager(
api_key="your-arize-phoenix-token",
api_base="https://app.phoenix.arize.com/s/krrishdholakia/v1",
prompt_id="UHJvbXB0VmVyc2lvbjox",
)
# Get rendered messages
messages, metadata = manager.get_prompt_template(
prompt_id="UHJvbXB0VmVyc2lvbjox",
prompt_variables={"question": "What is machine learning?"}
)
print("Rendered messages:", messages)
print("Metadata:", metadata)
```
## Prompt Format
Arize Phoenix prompts support the following structure:
```json
{
"data": {
"description": "A chatbot prompt",
"model_provider": "OPENAI",
"model_name": "gpt-4o",
"template": {
"type": "chat",
"messages": [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are a chatbot"
}
]
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "{{question}}"
}
]
}
]
},
"template_type": "CHAT",
"template_format": "MUSTACHE",
"invocation_parameters": {
"type": "openai",
"openai": {
"temperature": 1.0
}
},
"id": "UHJvbXB0VmVyc2lvbjox"
}
}
```
### Variable Substitution
Variables in your prompt templates use Mustache/Handlebars syntax:
- `{{variable_name}}` - Simple variable substitution
Example:
```
Template: "Hello {{name}}, your order {{order_id}} is ready!"
Variables: {"name": "Alice", "order_id": "12345"}
Result: "Hello Alice, your order 12345 is ready!"
```
## API Reference
### ArizePhoenixPromptManager
Main class for managing Arize Phoenix prompts.
**Methods:**
- `get_prompt_template(prompt_id, prompt_variables)` - Get and render a prompt template
- `get_available_prompts()` - List available prompt IDs
- `reload_prompts()` - Reload prompts from Arize Phoenix
### ArizePhoenixClient
Low-level client for Arize Phoenix API.
**Methods:**
- `get_prompt_version(prompt_version_id)` - Fetch a prompt version
- `test_connection()` - Test API connection
## Error Handling
The integration provides detailed error messages:
- **404**: Prompt version not found
- **401**: Authentication failed (check your access token)
- **403**: Access denied (check workspace permissions)
Example:
```python
try:
    response = litellm.completion(
        model="arize/gpt-4o",
        prompt_id="invalid-id",
        api_key="your-arize-phoenix-token",
        api_base="https://app.phoenix.arize.com/s/your-workspace/v1",
    )
except Exception as e:
    print(f"Error: {e}")
```
## Getting Your Prompt Version ID and API Base
1. Log in to Arize Phoenix
2. Navigate to your workspace
3. Go to Prompts section
4. Select a prompt version
5. The ID will be in the URL: `/s/{workspace}/v1/prompt_versions/{PROMPT_VERSION_ID}`
Your `api_base` should be: `https://app.phoenix.arize.com/s/{workspace}/v1`
For example:
- Workspace: `krrishdholakia`
- API Base: `https://app.phoenix.arize.com/s/krrishdholakia/v1`
- Prompt Version ID: `UHJvbXB0VmVyc2lvbjox`
You can also fetch it via API:
```bash
curl -L -X GET 'https://app.phoenix.arize.com/s/krrishdholakia/v1/prompt_versions/UHJvbXB0VmVyc2lvbjox' \
-H 'Authorization: Bearer YOUR_TOKEN'
```
## Support
For issues or questions:
- LiteLLM Issues: https://github.com/BerriAI/litellm/issues
- Arize Phoenix Docs: https://docs.arize.com/phoenix

View File

@@ -0,0 +1,52 @@
import os
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
from .arize_phoenix_prompt_manager import ArizePhoenixPromptManager
# Global instances
# NOTE(review): `global_arize_config` is not referenced anywhere in this
# module's visible code — confirm it is used elsewhere before removing.
global_arize_config: Optional[dict] = None
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Initialize a prompt from Arize Phoenix.

    Args:
        litellm_params: Prompt params; must provide ``api_key`` (or the
            PHOENIX_API_KEY env var) and ``api_base``. ``prompt_id`` is optional.
        prompt_spec: The prompt spec being initialized (currently unused; kept
            for registry-signature compatibility).

    Returns:
        A configured ArizePhoenixPromptManager.

    Raises:
        ValueError: If api_key or api_base cannot be resolved.
    """
    api_key = getattr(litellm_params, "api_key", None) or os.environ.get(
        "PHOENIX_API_KEY"
    )
    api_base = getattr(litellm_params, "api_base", None)
    prompt_id = getattr(litellm_params, "prompt_id", None)
    if not api_key or not api_base:
        raise ValueError(
            "api_key and api_base are required for Arize Phoenix prompt integration"
        )
    # Forward any remaining prompt params besides the three handled above.
    # (The previous `try: ... except Exception as e: raise e` wrapper was a
    # no-op that only rewrote the traceback; it has been removed.)
    return ArizePhoenixPromptManager(
        api_key=api_key,
        api_base=api_base,
        prompt_id=prompt_id,
        **litellm_params.model_dump(exclude={"api_key", "api_base", "prompt_id"}),
    )
# Maps integration name -> initializer callable; consumed by the
# prompt-management factory to construct the right manager for a prompt spec.
prompt_initializer_registry = {
    SupportedPromptIntegrations.ARIZE_PHOENIX.value: prompt_initializer,
}

View File

@@ -0,0 +1,502 @@
import json
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from typing_extensions import override
from litellm._logging import verbose_logger
from litellm.integrations.opentelemetry_utils.base_otel_llm_obs_attributes import (
BaseLLMObsOTELAttributes,
safe_set_attribute,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from opentelemetry.trace import Span
from litellm.integrations._types.open_inference import (
MessageAttributes,
ImageAttributes,
SpanAttributes,
AudioAttributes,
EmbeddingAttributes,
OpenInferenceSpanKindValues,
)
class ArizeOTELAttributes(BaseLLMObsOTELAttributes):
    """
    OpenInference attribute-setters for Arize tracing.

    Writes request (input) messages and response (output) messages onto a
    span using the Arize/OpenInference semantic conventions.
    """

    @staticmethod
    @override
    def set_messages(span: "Span", kwargs: Dict[str, Any]):
        """Set INPUT_VALUE (last message's content) and per-message role/content attributes."""
        messages = kwargs.get("messages")
        # for /chat/completions
        # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
        if messages:
            last_message = messages[-1]
            safe_set_attribute(
                span,
                SpanAttributes.INPUT_VALUE,
                last_message.get("content", ""),
            )
            # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page.
            for idx, msg in enumerate(messages):
                prefix = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}"
                # Set the role per message.
                safe_set_attribute(
                    span, f"{prefix}.{MessageAttributes.MESSAGE_ROLE}", msg.get("role")
                )
                # Set the content per message.
                safe_set_attribute(
                    span,
                    f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
                    msg.get("content", ""),
                )

    @staticmethod
    @override
    def set_response_output_messages(span: "Span", response_obj):
        """
        Sets output message attributes on the span from the LLM response.

        Args:
            span: The OpenTelemetry span to set attributes on
            response_obj: The response object containing choices with messages
        """
        # Consistency fix: MessageAttributes / SpanAttributes are already
        # imported at module level; the redundant local re-import was removed.
        for idx, choice in enumerate(response_obj.get("choices", [])):
            response_message = choice.get("message", {})
            # OUTPUT_VALUE is overwritten per choice; the final choice wins.
            safe_set_attribute(
                span,
                SpanAttributes.OUTPUT_VALUE,
                response_message.get("content", ""),
            )
            # This shows up under `output_messages` tab on the span page.
            prefix = f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}"
            safe_set_attribute(
                span,
                f"{prefix}.{MessageAttributes.MESSAGE_ROLE}",
                response_message.get("role"),
            )
            safe_set_attribute(
                span,
                f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
                response_message.get("content", ""),
            )
def _set_response_attributes(span: "Span", response_obj):
    """Populate span output attributes: choices, media, embeddings, structured output, usage."""
    if not hasattr(response_obj, "get"):
        # Not dict-like (e.g. a raw string); nothing to extract.
        return
    output_setters = (
        (_set_choice_outputs, MessageAttributes),
        (_set_image_outputs, ImageAttributes),
        (_set_audio_outputs, AudioAttributes),
        (_set_embedding_outputs, EmbeddingAttributes),
        (_set_structured_outputs, MessageAttributes),
    )
    for setter, attr_namespace in output_setters:
        setter(span, response_obj, attr_namespace, SpanAttributes)
    _set_usage_outputs(span, response_obj, SpanAttributes)
def _set_choice_outputs(span: "Span", response_obj, msg_attrs, span_attrs):
    """Record each chat choice's role and content as indexed output-message attributes."""
    for position, choice in enumerate(response_obj.get("choices", [])):
        message = choice.get("message", {})
        content = message.get("content", "")
        # OUTPUT_VALUE is rewritten per choice; the final choice wins.
        safe_set_attribute(span, span_attrs.OUTPUT_VALUE, content)
        base = f"{span_attrs.LLM_OUTPUT_MESSAGES}.{position}"
        safe_set_attribute(
            span, f"{base}.{msg_attrs.MESSAGE_ROLE}", message.get("role")
        )
        safe_set_attribute(span, f"{base}.{msg_attrs.MESSAGE_CONTENT}", content)
def _set_image_outputs(span: "Span", response_obj, image_attrs, span_attrs):
    """Record generated image URLs (or base64 data URIs) on the span."""
    for position, entry in enumerate(response_obj.get("data", [])):
        url = entry.get("url")
        if url is None and entry.get("b64_json"):
            # No hosted URL: embed the base64 payload as a data URI.
            url = f"data:image/png;base64,{entry.get('b64_json')}"
        if not url:
            continue
        if position == 0:
            # The first image doubles as the span's overall output value.
            safe_set_attribute(span, span_attrs.OUTPUT_VALUE, url)
        safe_set_attribute(span, f"{image_attrs.IMAGE_URL}.{position}", url)
def _set_audio_outputs(span: "Span", response_obj, audio_attrs, span_attrs):
    """Record audio URL, mime type, and transcript for each audio output."""
    for position, entry in enumerate(response_obj.get("audio", [])):
        url = entry.get("url")
        if url is None and entry.get("b64_json"):
            # No hosted URL: embed the base64 payload as a data URI.
            url = f"data:audio/wav;base64,{entry.get('b64_json')}"
        if url:
            if position == 0:
                # The first clip doubles as the span's overall output value.
                safe_set_attribute(span, span_attrs.OUTPUT_VALUE, url)
            safe_set_attribute(span, f"{audio_attrs.AUDIO_URL}.{position}", url)
        mime = entry.get("mime_type")
        if mime:
            safe_set_attribute(span, f"{audio_attrs.AUDIO_MIME_TYPE}.{position}", mime)
        transcript = entry.get("transcript")
        if transcript:
            safe_set_attribute(
                span, f"{audio_attrs.AUDIO_TRANSCRIPT}.{position}", transcript
            )
def _set_embedding_outputs(span: "Span", response_obj, embedding_attrs, span_attrs):
    """Record embedding vectors (and any source text) as indexed span attributes."""
    for position, entry in enumerate(response_obj.get("data", [])):
        vector = entry.get("embedding")
        if vector:
            vector_repr = str(vector)
            if position == 0:
                # The first vector doubles as the span's overall output value.
                safe_set_attribute(span, span_attrs.OUTPUT_VALUE, vector_repr)
            safe_set_attribute(
                span,
                f"{embedding_attrs.EMBEDDING_VECTOR}.{position}",
                vector_repr,
            )
        text = entry.get("text")
        if text:
            safe_set_attribute(
                span,
                f"{embedding_attrs.EMBEDDING_TEXT}.{position}",
                str(text),
            )
def _set_structured_outputs(span: "Span", response_obj, msg_attrs, span_attrs):
    """Record Responses-API style ``output`` items (reasoning summaries and messages)."""
    for position, item in enumerate(response_obj.get("output", [])):
        prefix = f"{span_attrs.LLM_OUTPUT_MESSAGES}.{position}"
        if not hasattr(item, "type"):
            # Untyped items carry nothing we know how to record.
            continue
        kind = item.type
        if kind == "reasoning" and hasattr(item, "summary"):
            for summary in item.summary:
                if hasattr(summary, "text"):
                    safe_set_attribute(
                        span,
                        f"{prefix}.{msg_attrs.MESSAGE_REASONING_SUMMARY}",
                        summary.text,
                    )
        elif kind == "message" and hasattr(item, "content"):
            blocks = item.content
            # Only the first content block's text is recorded.
            text = ""
            if blocks and len(blocks) > 0:
                text = getattr(blocks[0], "text", "")
            role = getattr(item, "role", "assistant")
            safe_set_attribute(span, span_attrs.OUTPUT_VALUE, text)
            safe_set_attribute(span, f"{prefix}.{msg_attrs.MESSAGE_CONTENT}", text)
            safe_set_attribute(span, f"{prefix}.{msg_attrs.MESSAGE_ROLE}", role)
def _set_usage_outputs(span: "Span", response_obj, span_attrs):
    """Record token-usage counts (total / completion / prompt / reasoning) on the span."""
    usage = response_obj and response_obj.get("usage")
    if not usage:
        return
    safe_set_attribute(
        span, span_attrs.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens")
    )
    # Chat-completions and Responses APIs name these fields differently.
    output_tokens = usage.get("completion_tokens") or usage.get("output_tokens")
    if output_tokens:
        safe_set_attribute(
            span, span_attrs.LLM_TOKEN_COUNT_COMPLETION, output_tokens
        )
    input_tokens = usage.get("prompt_tokens") or usage.get("input_tokens")
    if input_tokens:
        safe_set_attribute(span, span_attrs.LLM_TOKEN_COUNT_PROMPT, input_tokens)
    reasoning_tokens = usage.get("output_tokens_details", {}).get("reasoning_tokens")
    if reasoning_tokens:
        safe_set_attribute(
            span,
            span_attrs.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
            reasoning_tokens,
        )
def _infer_open_inference_span_kind(call_type: Optional[str]) -> str:
    """
    Map LiteLLM call types to OpenInference span kinds.

    Checks are ordered; the first match wins (e.g. "embed" beats "chat").
    """
    if not call_type:
        return OpenInferenceSpanKindValues.UNKNOWN.value
    name = str(call_type).lower()
    # Substring checks, in priority order.
    for needle, kind in (
        ("embed", OpenInferenceSpanKindValues.EMBEDDING),
        ("rerank", OpenInferenceSpanKindValues.RERANKER),
        ("search", OpenInferenceSpanKindValues.RETRIEVER),
    ):
        if needle in name:
            return kind.value
    if "moderation" in name or "guardrail" in name:
        return OpenInferenceSpanKindValues.GUARDRAIL.value
    if name in ("call_mcp_tool", "mcp") or name.endswith("tool"):
        return OpenInferenceSpanKindValues.TOOL.value
    if "asend_message" in name or "a2a" in name or "assistant" in name:
        return OpenInferenceSpanKindValues.AGENT.value
    llm_keywords = (
        "completion",
        "chat",
        "image",
        "audio",
        "speech",
        "transcription",
        "generate_content",
        "response",
        "videos",
        "realtime",
        "pass_through",
        "anthropic_messages",
        "ocr",
    )
    if any(keyword in name for keyword in llm_keywords):
        return OpenInferenceSpanKindValues.LLM.value
    chain_keywords = ("file", "batch", "container", "fine_tuning_job")
    if any(keyword in name for keyword in chain_keywords):
        return OpenInferenceSpanKindValues.CHAIN.value
    return OpenInferenceSpanKindValues.UNKNOWN.value
def _set_tool_attributes(
    span: "Span", optional_tools: Optional[list], metadata_tools: Optional[list]
):
    """set tool attributes on span from optional_params or tool call metadata"""
    # Tools declared in the request's optional_params (OpenAI tool schema:
    # {"function": {"name", "description", "parameters"}}).
    if optional_tools:
        for idx, tool in enumerate(optional_tools):
            if not isinstance(tool, dict):
                continue
            function = (
                tool.get("function") if isinstance(tool.get("function"), dict) else None
            )
            if not function:
                # Skip entries without a dict-valued "function" key.
                continue
            tool_name = function.get("name")
            if tool_name:
                safe_set_attribute(
                    span, f"{SpanAttributes.LLM_TOOLS}.{idx}.name", tool_name
                )
            tool_description = function.get("description")
            if tool_description:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_TOOLS}.{idx}.description",
                    tool_description,
                )
            params = function.get("parameters")
            if params is not None:
                # Parameters are a JSON schema; serialize for the attribute value.
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_TOOLS}.{idx}.parameters",
                    json.dumps(params),
                )
    # Tools surfaced via logging metadata (flat {"name", "description"} dicts),
    # recorded under the invocation-parameters namespace instead of LLM_TOOLS.
    if metadata_tools and isinstance(metadata_tools, list):
        for idx, tool in enumerate(metadata_tools):
            if not isinstance(tool, dict):
                continue
            tool_name = tool.get("name")
            if tool_name:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_INVOCATION_PARAMETERS}.tools.{idx}.name",
                    tool_name,
                )
            tool_description = tool.get("description")
            if tool_description:
                safe_set_attribute(
                    span,
                    f"{SpanAttributes.LLM_INVOCATION_PARAMETERS}.tools.{idx}.description",
                    tool_description,
                )
def set_attributes(
    span: "Span", kwargs, response_obj, attributes: Type[BaseLLMObsOTELAttributes]
):
    """
    Populate ``span`` with OpenInference-compliant LLM attributes for
    Arize / Phoenix tracing.

    Reads the standard logging payload out of ``kwargs`` and fans out to the
    module-level helpers for metadata, request, tool, message, model-parameter
    and response attributes. Any error is logged and recorded on the span
    instead of being raised to the caller.
    """
    try:
        params = _sanitize_optional_params(kwargs.get("optional_params"))
        litellm_params = kwargs.get("litellm_params", {}) or {}
        payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if payload is None:
            raise ValueError("standard_logging_object not found in kwargs")

        metadata = payload.get("metadata") if payload else None
        _set_metadata_attributes(span, metadata, SpanAttributes)

        tools_from_metadata = _extract_metadata_tools(metadata)
        tools_from_params = _extract_optional_tools(params)

        _set_request_attributes(
            span=span,
            kwargs=kwargs,
            standard_logging_payload=payload,
            optional_params=params,
            litellm_params=litellm_params,
            response_obj=response_obj,
            span_attrs=SpanAttributes,
        )

        span_kind = _infer_open_inference_span_kind(call_type=payload.get("call_type"))
        _set_tool_attributes(span, tools_from_params, tools_from_metadata)
        # The presence of any tool definition forces the TOOL span kind,
        # regardless of what the call type inferred.
        if tools_from_params or tools_from_metadata:
            span_kind = OpenInferenceSpanKindValues.TOOL.value
        safe_set_attribute(span, SpanAttributes.OPENINFERENCE_SPAN_KIND, span_kind)

        attributes.set_messages(span, kwargs)

        model_params = payload.get("model_parameters") if payload else None
        _set_model_params(span, model_params, SpanAttributes)
        _set_response_attributes(span=span, response_obj=response_obj)
    except Exception as e:
        verbose_logger.error(
            f"[Arize/Phoenix] Failed to set OpenInference span attributes: {e}"
        )
        if hasattr(span, "record_exception"):
            span.record_exception(e)
def _sanitize_optional_params(optional_params: Optional[dict]) -> dict:
if not isinstance(optional_params, dict):
return {}
optional_params.pop("secret_fields", None)
return optional_params
def _set_metadata_attributes(span: "Span", metadata: Optional[Any], span_attrs) -> None:
    """Serialize request metadata onto the span when it is present."""
    if metadata is None:
        return
    safe_set_attribute(span, span_attrs.METADATA, safe_dumps(metadata))
def _extract_metadata_tools(metadata: Optional[Any]) -> Optional[list]:
if not isinstance(metadata, dict):
return None
llm_obj = metadata.get("llm")
if isinstance(llm_obj, dict):
return llm_obj.get("tools")
return None
def _extract_optional_tools(optional_params: dict) -> Optional[list]:
return optional_params.get("tools") if isinstance(optional_params, dict) else None
def _set_request_attributes(
    span: "Span",
    kwargs,
    standard_logging_payload: StandardLoggingPayload,
    optional_params: dict,
    litellm_params: dict,
    response_obj,
    span_attrs,
):
    """
    Set request-side attributes on the span: model name, call type, provider,
    sampling parameters, streaming flag, user, and the response id/model when
    a response object is available.

    Args:
        span: Target OTEL span.
        kwargs: Raw logging kwargs from litellm.
        standard_logging_payload: Normalized logging payload.
        optional_params: Sanitized request params (secrets already removed).
        litellm_params: litellm-internal params (carries the provider name).
        response_obj: Provider response; may be None/empty on failure.
        span_attrs: Attribute-name namespace (e.g. ``SpanAttributes``).
    """
    if kwargs.get("model"):
        safe_set_attribute(span, span_attrs.LLM_MODEL_NAME, kwargs.get("model"))
    safe_set_attribute(
        span, "llm.request.type", standard_logging_payload.get("call_type")
    )
    safe_set_attribute(
        span,
        span_attrs.LLM_PROVIDER,
        litellm_params.get("custom_llm_provider", "Unknown"),
    )
    if optional_params.get("max_tokens") is not None:
        safe_set_attribute(
            span, "llm.request.max_tokens", optional_params.get("max_tokens")
        )
    # temperature=0 and top_p=0 are valid, deliberate settings; compare
    # against None (not truthiness) so zeros are not silently dropped.
    if optional_params.get("temperature") is not None:
        safe_set_attribute(
            span, "llm.request.temperature", optional_params.get("temperature")
        )
    if optional_params.get("top_p") is not None:
        safe_set_attribute(span, "llm.request.top_p", optional_params.get("top_p"))
    safe_set_attribute(
        span, "llm.is_streaming", str(optional_params.get("stream", False))
    )
    if optional_params.get("user"):
        safe_set_attribute(span, "llm.user", optional_params.get("user"))
    if response_obj and response_obj.get("id"):
        safe_set_attribute(span, "llm.response.id", response_obj.get("id"))
    if response_obj and response_obj.get("model"):
        safe_set_attribute(span, "llm.response.model", response_obj.get("model"))
def _set_model_params(span: "Span", model_params: Optional[dict], span_attrs) -> None:
    """Attach serialized invocation parameters (and the user id, when set) to the span."""
    if not model_params:
        return
    safe_set_attribute(
        span, span_attrs.LLM_INVOCATION_PARAMETERS, safe_dumps(model_params)
    )
    user_id = model_params.get("user")
    if user_id:
        safe_set_attribute(span, span_attrs.USER_ID, user_id)

View File

@@ -0,0 +1,214 @@
"""
arize AI is OTEL compatible
this file has Arize ai specific helper functions
"""
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm.integrations.arize import _utils
from litellm.integrations.arize._utils import ArizeOTELAttributes
from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.types.integrations.arize import ArizeConfig
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import StandardCallbackDynamicParams
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.types.integrations.arize import Protocol as _Protocol
Protocol = _Protocol
Span = Union[_Span, Any]
else:
Protocol = Any
Span = Any
class ArizeLogger(OpenTelemetry):
    """
    Arize logger that sends traces to an Arize endpoint.

    Creates its own dedicated TracerProvider so it can coexist with the
    generic ``otel`` callback (or any other OTEL-based integration) without
    fighting over the global ``opentelemetry.trace`` TracerProvider singleton.
    """

    def _init_tracing(self, tracer_provider):
        """
        Override to always create a *private* TracerProvider for Arize.
        See ArizePhoenixLogger._init_tracing for full rationale.
        """
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.trace import SpanKind

        # An explicitly supplied provider (e.g. from tests) takes precedence.
        if tracer_provider is not None:
            self.tracer = tracer_provider.get_tracer("litellm")
            self.span_kind = SpanKind
            return
        # Dedicated provider: never registered as the global provider, so it
        # cannot clash with another OTEL-based callback's exporter pipeline.
        provider = TracerProvider(resource=self._get_litellm_resource(self.config))
        provider.add_span_processor(self._get_span_processor())
        self.tracer = provider.get_tracer("litellm")
        self.span_kind = SpanKind

    def _init_otel_logger_on_litellm_proxy(self):
        """
        Override: Arize should NOT overwrite the proxy's
        ``open_telemetry_logger``. That attribute is reserved for the
        primary ``otel`` callback which handles proxy-level parent spans.
        """
        pass

    def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
        # Delegate to the static helper so the attribute mapping can also be
        # used without an instance.
        ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
        return

    @staticmethod
    def set_arize_attributes(span: Span, kwargs, response_obj):
        # Shared OpenInference attribute mapping lives in arize._utils.
        _utils.set_attributes(span, kwargs, response_obj, ArizeOTELAttributes)
        return

    @staticmethod
    def get_arize_config() -> ArizeConfig:
        """
        Helper function to build the Arize configuration from environment
        variables (ARIZE_SPACE_ID / ARIZE_SPACE_KEY / ARIZE_API_KEY /
        ARIZE_PROJECT_NAME, plus optional ARIZE_ENDPOINT and
        ARIZE_HTTP_ENDPOINT overrides).

        Returns:
            ArizeConfig: A Pydantic model containing Arize configuration.
            Missing credentials are returned as ``None`` rather than raised
            here; ``async_health_check`` reports on them.
        """
        space_id = os.environ.get("ARIZE_SPACE_ID")
        space_key = os.environ.get("ARIZE_SPACE_KEY")
        api_key = os.environ.get("ARIZE_API_KEY")
        project_name = os.environ.get("ARIZE_PROJECT_NAME")
        grpc_endpoint = os.environ.get("ARIZE_ENDPOINT")
        http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT")
        endpoint = None
        protocol: Protocol = "otlp_grpc"
        # Endpoint precedence: explicit gRPC endpoint > explicit HTTP
        # endpoint > default Arize-cloud gRPC endpoint.
        if grpc_endpoint:
            protocol = "otlp_grpc"
            endpoint = grpc_endpoint
        elif http_endpoint:
            protocol = "otlp_http"
            endpoint = http_endpoint
        else:
            protocol = "otlp_grpc"
            endpoint = "https://otlp.arize.com/v1"
        return ArizeConfig(
            space_id=space_id,
            space_key=space_key,
            api_key=api_key,
            protocol=protocol,
            endpoint=endpoint,
            project_name=project_name,
        )

    async def async_service_success_hook(
        self,
        payload: ServiceLoggerPayload,
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[datetime, float]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
        pass

    async def async_service_failure_hook(
        self,
        payload: ServiceLoggerPayload,
        error: Optional[str] = "",
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[float, datetime]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
        pass

    # def create_litellm_proxy_request_started_span(
    #     self,
    #     start_time: datetime,
    #     headers: dict,
    # ):
    #     """Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
    #     pass

    async def async_health_check(self):
        """
        Performs a health check for Arize integration.

        Returns:
            dict: Health check result with status and message
        """
        try:
            config = self.get_arize_config()
            # Either a space id or a (legacy) space key is acceptable.
            if not config.space_id and not config.space_key:
                return {
                    "status": "unhealthy",
                    "error_message": "ARIZE_SPACE_ID or ARIZE_SPACE_KEY environment variable not set",
                }
            if not config.api_key:
                return {
                    "status": "unhealthy",
                    "error_message": "ARIZE_API_KEY environment variable not set",
                }
            return {
                "status": "healthy",
                "message": "Arize credentials are configured properly",
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "error_message": f"Arize health check failed: {str(e)}",
            }

    def construct_dynamic_otel_headers(
        self, standard_callback_dynamic_params: StandardCallbackDynamicParams
    ) -> Optional[dict]:
        """
        Construct dynamic Arize headers from standard callback dynamic params
        This is used for team/key based logging.

        Returns:
            dict: A dictionary of dynamic Arize headers (possibly empty).
        """
        dynamic_headers = {}
        #########################################################
        # `arize-space-id` handling
        # the suggested param is `arize_space_key`
        #########################################################
        if standard_callback_dynamic_params.get("arize_space_id"):
            dynamic_headers["arize-space-id"] = standard_callback_dynamic_params.get(
                "arize_space_id"
            )
        # Both params write the same header, so `arize_space_key` overrides
        # `arize_space_id` when both are supplied.
        if standard_callback_dynamic_params.get("arize_space_key"):
            dynamic_headers["arize-space-id"] = standard_callback_dynamic_params.get(
                "arize_space_key"
            )
        #########################################################
        # `api_key` handling
        #########################################################
        # NOTE(review): header key "api_key" (underscore) vs "arize-space-id"
        # (hyphenated) looks inconsistent — confirm against Arize's expected
        # OTLP header names before changing.
        if standard_callback_dynamic_params.get("arize_api_key"):
            dynamic_headers["api_key"] = standard_callback_dynamic_params.get(
                "arize_api_key"
            )
        return dynamic_headers

View File

@@ -0,0 +1,360 @@
import os
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm._logging import verbose_logger
from litellm.integrations.arize import _utils
from litellm.integrations.arize._utils import ArizeOTELAttributes
from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig
if TYPE_CHECKING:
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Span as _Span
from opentelemetry.trace import SpanKind
from litellm.integrations.opentelemetry import OpenTelemetry as _OpenTelemetry
from litellm.integrations.opentelemetry import (
OpenTelemetryConfig as _OpenTelemetryConfig,
)
from litellm.types.integrations.arize import Protocol as _Protocol
Protocol = _Protocol
OpenTelemetryConfig = _OpenTelemetryConfig
Span = Union[_Span, Any]
OpenTelemetry = _OpenTelemetry
else:
Protocol = Any
OpenTelemetryConfig = Any
Span = Any
TracerProvider = Any
SpanKind = Any
# Import OpenTelemetry at runtime
try:
from litellm.integrations.opentelemetry import OpenTelemetry
except ImportError:
OpenTelemetry = None # type: ignore
ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://otlp.arize.com/v1/traces"
class ArizePhoenixLogger(OpenTelemetry):  # type: ignore
    """
    Arize Phoenix logger that sends traces to a Phoenix endpoint.

    Creates its own dedicated TracerProvider so it can coexist with the
    generic ``otel`` callback (or any other OTEL-based integration) without
    fighting over the global ``opentelemetry.trace`` TracerProvider singleton.
    """

    def _init_tracing(self, tracer_provider):
        """
        Override to always create a *private* TracerProvider for Arize Phoenix.

        The base ``OpenTelemetry._init_tracing`` falls back to the global
        TracerProvider when one already exists. That causes whichever
        integration initialises second to silently reuse the first one's
        exporter, so spans only reach one destination.

        By creating our own provider we guarantee Arize Phoenix always gets
        its own exporter pipeline, regardless of initialisation order.
        """
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.trace import SpanKind

        if tracer_provider is not None:
            # Explicitly supplied (e.g. in tests) — honour it.
            self.tracer = tracer_provider.get_tracer("litellm")
            self.span_kind = SpanKind
            return
        # Always create a dedicated provider — never touch the global one.
        provider = TracerProvider(resource=self._get_litellm_resource(self.config))
        provider.add_span_processor(self._get_span_processor())
        self.tracer = provider.get_tracer("litellm")
        self.span_kind = SpanKind
        verbose_logger.debug(
            "ArizePhoenixLogger: Created dedicated TracerProvider "
            "(endpoint=%s, exporter=%s)",
            self.config.endpoint,
            self.config.exporter,
        )

    def _init_otel_logger_on_litellm_proxy(self):
        """
        Override: Arize Phoenix should NOT overwrite the proxy's
        ``open_telemetry_logger``. That attribute is reserved for the
        primary ``otel`` callback which handles proxy-level parent spans.
        """
        pass

    def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
        # Delegate to the static helper so the attribute mapping can also be
        # used without an instance.
        ArizePhoenixLogger.set_arize_phoenix_attributes(span, kwargs, response_obj)
        return

    @staticmethod
    def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
        """Set OpenInference attributes plus the Phoenix project-name attribute."""
        from litellm.integrations.opentelemetry_utils.base_otel_llm_obs_attributes import (
            safe_set_attribute,
        )

        _utils.set_attributes(span, kwargs, response_obj, ArizeOTELAttributes)
        # Dynamic project name: check metadata first, then fall back to env var config
        dynamic_project_name = ArizePhoenixLogger._get_dynamic_project_name(kwargs)
        if dynamic_project_name:
            safe_set_attribute(span, "openinference.project.name", dynamic_project_name)
        else:
            # Fall back to static config from env var
            # NOTE(review): this re-reads the env-based config for every span;
            # presumably cheap, but could be cached — confirm before changing.
            config = ArizePhoenixLogger.get_arize_phoenix_config()
            if config.project_name:
                safe_set_attribute(
                    span, "openinference.project.name", config.project_name
                )
        return

    @staticmethod
    def _get_dynamic_project_name(kwargs) -> Optional[str]:
        """
        Retrieve dynamic Phoenix project name from request metadata.

        Users can set `metadata.phoenix_project_name` in their request to route
        traces to different Phoenix projects dynamically.
        """
        # Proxy path: metadata lives inside the standard logging payload.
        standard_logging_payload = kwargs.get("standard_logging_object")
        if isinstance(standard_logging_payload, dict):
            metadata = standard_logging_payload.get("metadata")
            if isinstance(metadata, dict):
                project_name = metadata.get("phoenix_project_name")
                if project_name:
                    return str(project_name)
        # Also check litellm_params.metadata for SDK usage
        litellm_params = kwargs.get("litellm_params")
        if isinstance(litellm_params, dict):
            metadata = litellm_params.get("metadata") or {}
        else:
            metadata = {}
        if isinstance(metadata, dict):
            project_name = metadata.get("phoenix_project_name")
            if project_name:
                return str(project_name)
        return None

    def _get_phoenix_context(self, kwargs):
        """
        Build a trace context for Phoenix's dedicated TracerProvider.

        The base ``_get_span_context`` returns parent spans from the global
        TracerProvider (the ``otel`` callback). Those spans live on a
        *different* TracerProvider, so they won't appear in Phoenix — using
        them as parents just creates broken links.

        Instead we:
          1. Honour an incoming ``traceparent`` HTTP header (distributed tracing).
          2. In proxy mode, create our *own* parent span on Phoenix's tracer
             so the hierarchy is visible end-to-end inside Phoenix.
          3. In SDK (non-proxy) mode, just return (None, None) for a root span.

        Returns:
            Tuple of (context, parent_span); parent_span is None outside
            proxy mode and must be ended by the caller when present.
        """
        from opentelemetry import trace

        litellm_params = kwargs.get("litellm_params", {}) or {}
        proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
        headers = proxy_server_request.get("headers", {}) or {}
        # Propagate distributed trace context if the caller sent a traceparent
        traceparent_ctx = (
            self.get_traceparent_from_header(headers=headers)
            if headers.get("traceparent")
            else None
        )
        # A non-empty proxy_server_request is how we detect proxy mode.
        is_proxy_mode = bool(proxy_server_request)
        if is_proxy_mode:
            # Create a parent span on Phoenix's own tracer so both parent
            # and child are exported to Phoenix.
            start_time_val = kwargs.get("start_time", kwargs.get("api_call_start_time"))
            parent_span = self.tracer.start_span(
                name="litellm_proxy_request",
                start_time=self._to_ns(start_time_val)
                if start_time_val is not None
                else None,
                context=traceparent_ctx,
                kind=self.span_kind.SERVER,
            )
            ctx = trace.set_span_in_context(parent_span)
            return ctx, parent_span
        # SDK mode — no parent span needed
        return traceparent_ctx, None

    def _handle_success(self, kwargs, response_obj, start_time, end_time):
        """
        Override to always create spans on ArizePhoenixLogger's dedicated TracerProvider.

        The base class's ``_get_span_context`` would find the parent span created by
        the ``otel`` callback on the *global* TracerProvider. That span is invisible
        in Phoenix (different exporter pipeline), so we ignore it and build our own
        hierarchy via ``_get_phoenix_context``.
        """
        from opentelemetry.trace import Status, StatusCode

        verbose_logger.debug(
            "ArizePhoenixLogger: Logging kwargs: %s, OTEL config settings=%s",
            kwargs,
            self.config,
        )
        ctx, parent_span = self._get_phoenix_context(kwargs)
        # Create litellm_request span (child of our parent when in proxy mode)
        span = self.tracer.start_span(
            name=self._get_span_name(kwargs),
            start_time=self._to_ns(start_time),
            context=ctx,
        )
        span.set_status(Status(StatusCode.OK))
        self.set_attributes(span, kwargs, response_obj)
        # Raw-request sub-span (if enabled) — must be created before
        # ending the parent span so the hierarchy is valid.
        self._maybe_log_raw_request(kwargs, response_obj, start_time, end_time, span)
        span.end(end_time=self._to_ns(end_time))
        # Guardrail span
        self._create_guardrail_span(kwargs=kwargs, context=ctx)
        # Annotate and close our proxy parent span; it gets the same
        # attributes as the request span so it is self-describing in Phoenix.
        if parent_span is not None:
            parent_span.set_status(Status(StatusCode.OK))
            self.set_attributes(parent_span, kwargs, response_obj)
            parent_span.end(end_time=self._to_ns(end_time))
        # Metrics & cost recording (base-class helper)
        self._record_metrics(kwargs, response_obj, start_time, end_time)
        # Semantic logs
        if self.config.enable_events:
            self._emit_semantic_logs(kwargs, response_obj, span)

    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
        """
        Override to always create failure spans on ArizePhoenixLogger's dedicated
        TracerProvider. Mirrors ``_handle_success`` but sets ERROR status.
        """
        from opentelemetry.trace import Status, StatusCode

        verbose_logger.debug(
            "ArizePhoenixLogger: Failure - Logging kwargs: %s, OTEL config settings=%s",
            kwargs,
            self.config,
        )
        ctx, parent_span = self._get_phoenix_context(kwargs)
        # Create litellm_request span (child of our parent when in proxy mode)
        span = self.tracer.start_span(
            name=self._get_span_name(kwargs),
            start_time=self._to_ns(start_time),
            context=ctx,
        )
        span.set_status(Status(StatusCode.ERROR))
        self.set_attributes(span, kwargs, response_obj)
        self._record_exception_on_span(span=span, kwargs=kwargs)
        span.end(end_time=self._to_ns(end_time))
        # Guardrail span
        self._create_guardrail_span(kwargs=kwargs, context=ctx)
        # Annotate and close our proxy parent span
        if parent_span is not None:
            parent_span.set_status(Status(StatusCode.ERROR))
            self.set_attributes(parent_span, kwargs, response_obj)
            self._record_exception_on_span(span=parent_span, kwargs=kwargs)
            parent_span.end(end_time=self._to_ns(end_time))

    @staticmethod
    def get_arize_phoenix_config() -> ArizePhoenixConfig:
        """
        Retrieves the Arize Phoenix configuration based on environment variables.

        Endpoint resolution: PHOENIX_COLLECTOR_HTTP_ENDPOINT wins over
        PHOENIX_COLLECTOR_ENDPOINT; with neither set, a local self-hosted
        Phoenix endpoint is assumed.

        Returns:
            ArizePhoenixConfig: A Pydantic model containing Arize Phoenix configuration.

        Raises:
            ValueError: If a Phoenix Cloud endpoint is used without PHOENIX_API_KEY.
        """
        api_key = os.environ.get("PHOENIX_API_KEY", None)
        collector_endpoint = os.environ.get("PHOENIX_COLLECTOR_HTTP_ENDPOINT", None)
        if not collector_endpoint:
            grpc_endpoint = os.environ.get("PHOENIX_COLLECTOR_ENDPOINT", None)
            # NOTE(review): the HTTP endpoint was already read above and is
            # None on this branch, so this second read is redundant.
            http_endpoint = os.environ.get("PHOENIX_COLLECTOR_HTTP_ENDPOINT", None)
            collector_endpoint = http_endpoint or grpc_endpoint
        endpoint = None
        protocol: Protocol = "otlp_http"
        if collector_endpoint:
            # Parse the endpoint to determine protocol.
            # Heuristic: a grpc:// scheme, or the conventional OTLP gRPC port
            # 4317 without an HTTP traces path, means gRPC.
            if collector_endpoint.startswith("grpc://") or (
                ":4317" in collector_endpoint and "/v1/traces" not in collector_endpoint
            ):
                endpoint = collector_endpoint
                protocol = "otlp_grpc"
            else:
                # Phoenix Cloud endpoints (app.phoenix.arize.com) include the space in the URL
                if "app.phoenix.arize.com" in collector_endpoint:
                    endpoint = collector_endpoint
                    protocol = "otlp_http"
                # For other HTTP endpoints, ensure they have the correct path
                elif "/v1/traces" not in collector_endpoint:
                    if collector_endpoint.endswith("/v1"):
                        endpoint = collector_endpoint + "/traces"
                    elif collector_endpoint.endswith("/"):
                        endpoint = f"{collector_endpoint}v1/traces"
                    else:
                        endpoint = f"{collector_endpoint}/v1/traces"
                else:
                    endpoint = collector_endpoint
                protocol = "otlp_http"
        else:
            # If no endpoint specified, self hosted phoenix
            endpoint = "http://localhost:6006/v1/traces"
            protocol = "otlp_http"
            verbose_logger.debug(
                f"No PHOENIX_COLLECTOR_ENDPOINT found, using default local Phoenix endpoint: {endpoint}"
            )
        otlp_auth_headers = None
        if api_key is not None:
            otlp_auth_headers = f"Authorization=Bearer {api_key}"
        elif "app.phoenix.arize.com" in endpoint:
            # Phoenix Cloud requires an API key
            raise ValueError(
                "PHOENIX_API_KEY must be set when using Phoenix Cloud (app.phoenix.arize.com)."
            )
        project_name = os.environ.get("PHOENIX_PROJECT_NAME", "default")
        return ArizePhoenixConfig(
            otlp_auth_headers=otlp_auth_headers,
            protocol=protocol,
            endpoint=endpoint,
            project_name=project_name,
        )

    ## cannot suppress additional proxy server spans, removed previous methods.
    async def async_health_check(self):
        """
        Health check: verifies that Phoenix auth headers could be built.

        Returns:
            dict: Health check result with status and message.
        """
        config = self.get_arize_phoenix_config()
        if not config.otlp_auth_headers:
            return {
                "status": "unhealthy",
                "error_message": "PHOENIX_API_KEY environment variable not set",
            }
        return {
            "status": "healthy",
            "message": "Arize-Phoenix credentials are configured properly",
        }

View File

@@ -0,0 +1,108 @@
"""
Arize Phoenix API client for fetching prompt versions from Arize Phoenix.
"""
from typing import Any, Dict, Optional
from litellm.llms.custom_httpx.http_handler import HTTPHandler
class ArizePhoenixClient:
    """
    Client for interacting with Arize Phoenix API to fetch prompt versions.

    Supports:
    - Authentication with Bearer tokens
    - Fetching prompt versions
    - Direct API base URL configuration
    """

    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
        """
        Initialize the Arize Phoenix client.

        Args:
            api_key: Arize Phoenix API token
            api_base: Base URL for the Arize Phoenix API, including the API
                version prefix (e.g., 'https://app.phoenix.arize.com/s/workspace/v1')

        Raises:
            ValueError: If api_key or api_base is not provided.
        """
        self.api_key = api_key
        self.api_base = api_base
        if not self.api_key:
            raise ValueError("api_key is required")
        if not self.api_base:
            raise ValueError("api_base is required")
        # Bearer-token auth headers used for every Phoenix API call.
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
        }
        # disable_default_headers keeps litellm's defaults out of Phoenix requests.
        self.http_handler = HTTPHandler(disable_default_headers=True)

    def get_prompt_version(self, prompt_version_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch a prompt version from Arize Phoenix.

        Args:
            prompt_version_id: The ID of the prompt version to fetch

        Returns:
            Dictionary containing prompt version data, or None if not found

        Raises:
            Exception: On auth/permission failures or any other HTTP/transport error.
        """
        # api_base already includes the API version prefix (e.g. '.../v1'),
        # so the path must not repeat it (previously produced '/v1/v1/...';
        # test_connection below builds the URL the same way).
        url = f"{self.api_base}/prompt_versions/{prompt_version_id}"
        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()
            data = response.json()
            return data.get("data")
        except Exception as e:
            # httpx's HTTPStatusError carries the response; inspect the
            # status code to map it to a friendlier error.
            response = getattr(e, "response", None)
            if response is not None and hasattr(response, "status_code"):
                if response.status_code == 404:
                    return None
                elif response.status_code == 403:
                    raise Exception(
                        f"Access denied to prompt version '{prompt_version_id}'. Check your Arize Phoenix permissions."
                    )
                elif response.status_code == 401:
                    raise Exception(
                        "Authentication failed. Check your Arize Phoenix API key and permissions."
                    )
                else:
                    raise Exception(
                        f"Failed to fetch prompt version '{prompt_version_id}': {e}"
                    )
            else:
                raise Exception(
                    f"Error fetching prompt version '{prompt_version_id}': {e}"
                )

    def test_connection(self) -> bool:
        """
        Test the connection to the Arize Phoenix API.

        Returns:
            True if connection is successful, False otherwise
        """
        try:
            # Try to access the prompt_versions endpoint to test connection.
            url = f"{self.api_base}/prompt_versions"
            # Use the same HTTPHandler entry point as get_prompt_version for
            # consistent header/timeout handling (was bypassing it via .client).
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()
            return True
        except Exception:
            return False

    def close(self):
        """Close the HTTP handler to free resources."""
        if hasattr(self, "http_handler"):
            self.http_handler.close()

View File

@@ -0,0 +1,488 @@
"""
Arize Phoenix prompt manager that integrates with LiteLLM's prompt management system.
Fetches prompt versions from Arize Phoenix and provides workspace-based access control.
"""
from typing import Any, Dict, List, Optional, Tuple, Union
from jinja2 import DictLoader, Environment, select_autoescape
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.integrations.prompt_management_base import (
PromptManagementBase,
PromptManagementClient,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams
from .arize_phoenix_client import ArizePhoenixClient
class ArizePhoenixPromptTemplate:
    """
    Represents a prompt template loaded from Arize Phoenix.

    Wraps the raw chat messages together with the model/invocation metadata
    returned by the Phoenix prompt-version API, exposing the commonly used
    fields as attributes.
    """

    def __init__(
        self,
        template_id: str,
        messages: List[Dict[str, Any]],
        metadata: Dict[str, Any],
        model: Optional[str] = None,
    ):
        self.template_id = template_id
        self.messages = messages
        self.metadata = metadata
        # An explicitly supplied model wins; otherwise use the metadata value.
        self.model = model if model else metadata.get("model_name")
        # Convenience accessors pulled out of the metadata dict.
        self.model_provider = metadata.get("model_provider")
        self.temperature = metadata.get("temperature")
        self.max_tokens = metadata.get("max_tokens")
        self.invocation_parameters = metadata.get("invocation_parameters", {})
        self.description = metadata.get("description", "")
        self.template_format = metadata.get("template_format", "MUSTACHE")

    def __repr__(self):
        return (
            f"ArizePhoenixPromptTemplate(id='{self.template_id}', model='{self.model}')"
        )
class ArizePhoenixTemplateManager:
    """
    Manager for loading and rendering prompt templates from Arize Phoenix.

    Supports:
    - Fetching prompt versions from Arize Phoenix API
    - Workspace-based access control through Arize Phoenix permissions
    - Mustache/Handlebars-style templating (using Jinja2)
    - Model configuration and invocation parameters
    - Multi-message chat templates
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        prompt_id: Optional[str] = None,
    ):
        """
        Args:
            api_key: Phoenix API token (forwarded to ArizePhoenixClient).
            api_base: Phoenix API base URL including the version prefix.
            prompt_id: Optional prompt version ID to eagerly load on init.
        """
        self.api_key = api_key
        self.api_base = api_base
        self.prompt_id = prompt_id
        # Cache of loaded templates keyed by prompt version ID.
        self.prompts: Dict[str, ArizePhoenixPromptTemplate] = {}
        self.arize_client = ArizePhoenixClient(
            api_key=self.api_key, api_base=self.api_base
        )
        self.jinja_env = Environment(
            loader=DictLoader({}),
            autoescape=select_autoescape(["html", "xml"]),
            # Use Mustache/Handlebars-style delimiters
            variable_start_string="{{",
            variable_end_string="}}",
            block_start_string="{%",
            block_end_string="%}",
            comment_start_string="{#",
            comment_end_string="#}",
        )
        # Load prompt from Arize Phoenix if prompt_id is provided
        if self.prompt_id:
            self._load_prompt_from_arize(self.prompt_id)

    def _load_prompt_from_arize(self, prompt_version_id: str) -> None:
        """Load a specific prompt version from Arize Phoenix into the cache.

        Raises:
            Exception: If the prompt version cannot be fetched or is not found.
        """
        try:
            # Fetch the prompt version from Arize Phoenix
            prompt_data = self.arize_client.get_prompt_version(prompt_version_id)
            if prompt_data:
                template = self._parse_prompt_data(prompt_data, prompt_version_id)
                self.prompts[prompt_version_id] = template
            else:
                raise ValueError(f"Prompt version '{prompt_version_id}' not found")
        except Exception as e:
            # Chain the cause so the original failure stays visible.
            raise Exception(
                f"Failed to load prompt version '{prompt_version_id}' from Arize Phoenix: {e}"
            ) from e

    def _parse_prompt_data(
        self, data: Dict[str, Any], prompt_version_id: str
    ) -> ArizePhoenixPromptTemplate:
        """Parse Arize Phoenix prompt data and extract messages and metadata."""
        template_data = data.get("template", {})
        messages = template_data.get("messages", [])
        # Extract invocation parameters
        invocation_params = data.get("invocation_parameters", {})
        provider_params = {}
        # Extract provider-specific parameters; known providers first, then
        # fall back to the first nested dict found.
        if "openai" in invocation_params:
            provider_params = invocation_params["openai"]
        elif "anthropic" in invocation_params:
            provider_params = invocation_params["anthropic"]
        else:
            for key, value in invocation_params.items():
                if isinstance(value, dict):
                    provider_params = value
                    break
        # Build metadata dictionary
        metadata = {
            "model_name": data.get("model_name"),
            "model_provider": data.get("model_provider"),
            "description": data.get("description", ""),
            "template_type": data.get("template_type"),
            "template_format": data.get("template_format", "MUSTACHE"),
            "invocation_parameters": invocation_params,
            "temperature": provider_params.get("temperature"),
            "max_tokens": provider_params.get("max_tokens"),
        }
        return ArizePhoenixPromptTemplate(
            template_id=prompt_version_id,
            messages=messages,
            metadata=metadata,
        )

    def render_template(
        self, template_id: str, variables: Optional[Dict[str, Any]] = None
    ) -> List[AllMessageValues]:
        """Render a template with the given variables and return formatted messages.

        Raises:
            ValueError: If the template has not been loaded.
        """
        if template_id not in self.prompts:
            raise ValueError(f"Template '{template_id}' not found")
        template = self.prompts[template_id]
        rendered_messages: List[AllMessageValues] = []
        for message in template.messages:
            role = message.get("role", "user")
            content = message.get("content", [])
            # Phoenix normally returns a list of typed content parts, but
            # accept a plain string for robustness (previously a string here
            # would crash on ``part.get`` while iterating its characters).
            if isinstance(content, str):
                content = [{"type": "text", "text": content}]
            rendered_content_parts: List[str] = []
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    text = part.get("text", "")
                    # Render the text with Jinja2 (Mustache-style)
                    jinja_template = self.jinja_env.from_string(text)
                    rendered_content_parts.append(
                        jinja_template.render(**(variables or {}))
                    )
                else:
                    # Non-text parts cannot be templated; stringify them so
                    # the join below cannot fail on a non-string (previously
                    # raw dicts were appended and `" ".join` raised TypeError).
                    rendered_content_parts.append(
                        part if isinstance(part, str) else str(part)
                    )
            # Combine rendered content into a single string message.
            final_content = " ".join(rendered_content_parts)
            rendered_messages.append(
                {"role": role, "content": final_content}  # type: ignore
            )
        return rendered_messages

    def get_template(self, template_id: str) -> Optional[ArizePhoenixPromptTemplate]:
        """Get a template by ID, or None if it has not been loaded."""
        return self.prompts.get(template_id)

    def list_templates(self) -> List[str]:
        """List all available template IDs."""
        return list(self.prompts.keys())
class ArizePhoenixPromptManager(CustomPromptManagement):
    """
    Arize Phoenix prompt manager that integrates with LiteLLM's prompt management system.

    This class enables using prompt versions from Arize Phoenix with the
    litellm completion() function by implementing the PromptManagementBase interface.

    Usage:
        # Configure Arize Phoenix access
        arize_config = {
            "workspace": "your-workspace",
            "access_token": "your-token",
        }

        # Use with completion
        response = litellm.completion(
            model="arize/gpt-4o",
            prompt_id="UHJvbXB0VmVyc2lvbjox",
            prompt_variables={"question": "What is AI?"},
            arize_config=arize_config,
            messages=[{"role": "user", "content": "This will be combined with the prompt"}]
        )
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        prompt_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: Arize Phoenix API key (optional).
            api_base: Base URL of the Arize Phoenix deployment (optional).
            prompt_id: Default prompt version ID to preload (optional).
            **kwargs: Forwarded to CustomPromptManagement.
        """
        super().__init__(**kwargs)
        self.api_key = api_key
        self.api_base = api_base
        self.prompt_id = prompt_id
        # Created lazily on first access through the `prompt_manager` property.
        self._prompt_manager: Optional[ArizePhoenixTemplateManager] = None

    @property
    def integration_name(self) -> str:
        """Integration name used in model names like 'arize/gpt-4o'."""
        return "arize"

    @property
    def prompt_manager(self) -> ArizePhoenixTemplateManager:
        """Get or create the prompt manager instance (lazy singleton)."""
        if self._prompt_manager is None:
            self._prompt_manager = ArizePhoenixTemplateManager(
                api_key=self.api_key,
                api_base=self.api_base,
                prompt_id=self.prompt_id,
            )
        return self._prompt_manager

    def get_prompt_template(
        self,
        prompt_id: str,
        prompt_variables: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[AllMessageValues], Dict[str, Any]]:
        """
        Get a prompt template and render it with variables.

        Args:
            prompt_id: The ID of the prompt version
            prompt_variables: Variables to substitute in the template

        Returns:
            Tuple of (rendered_messages, metadata) where metadata carries the
            template's model/temperature/max_tokens plus provider-specific
            invocation parameters.

        Raises:
            ValueError: if no template with `prompt_id` is loaded.
        """
        template = self.prompt_manager.get_template(prompt_id)
        if not template:
            raise ValueError(f"Prompt template '{prompt_id}' not found")

        # Render the template
        rendered_messages = self.prompt_manager.render_template(
            prompt_id, prompt_variables or {}
        )

        # Extract metadata
        metadata = {
            "model": template.model,
            "temperature": template.temperature,
            "max_tokens": template.max_tokens,
        }

        # Add additional invocation parameters; only one provider block is
        # applied, with "openai" taking precedence over "anthropic".
        invocation_params = template.invocation_parameters
        provider_params = {}
        if "openai" in invocation_params:
            provider_params = invocation_params["openai"]
        elif "anthropic" in invocation_params:
            provider_params = invocation_params["anthropic"]

        # Add any additional parameters (never overwriting the core keys above)
        for key, value in provider_params.items():
            if key not in metadata:
                metadata[key] = value

        return rendered_messages, metadata

    def pre_call_hook(
        self,
        user_id: Optional[str],
        messages: List[AllMessageValues],
        function_call: Optional[Union[Dict[str, Any], str]] = None,
        litellm_params: Optional[Dict[str, Any]] = None,
        prompt_id: Optional[str] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Tuple[List[AllMessageValues], Optional[Dict[str, Any]]]:
        """
        Pre-call hook that processes the prompt template before making the LLM call.

        Returns (final_messages, litellm_params). When `prompt_id` is unset or
        rendering fails, the inputs are returned unchanged — a prompt-management
        failure must never break the completion call.
        """
        if not prompt_id:
            return messages, litellm_params

        try:
            # Get the rendered messages and metadata
            rendered_messages, prompt_metadata = self.get_prompt_template(
                prompt_id, prompt_variables
            )

            # Merge rendered messages with existing messages
            if rendered_messages:
                # Prepend rendered messages to existing messages
                final_messages = rendered_messages + messages
            else:
                final_messages = messages

            # Update litellm_params with prompt metadata
            if litellm_params is None:
                litellm_params = {}

            # Apply model and parameters from prompt metadata.
            # NOTE(review): `ignore_prompt_manager_model` /
            # `ignore_prompt_manager_optional_params` are presumably set by the
            # base class — confirm they exist before this hook runs.
            if prompt_metadata.get("model") and not self.ignore_prompt_manager_model:
                litellm_params["model"] = prompt_metadata["model"]

            if not self.ignore_prompt_manager_optional_params:
                for param in [
                    "temperature",
                    "max_tokens",
                    "top_p",
                    "frequency_penalty",
                    "presence_penalty",
                ]:
                    if param in prompt_metadata:
                        litellm_params[param] = prompt_metadata[param]

            return final_messages, litellm_params
        except Exception as e:
            # Log error but don't fail the call
            import litellm

            litellm._logging.verbose_proxy_logger.error(
                f"Error in Arize Phoenix prompt pre_call_hook: {e}"
            )
            return messages, litellm_params

    def get_available_prompts(self) -> List[str]:
        """Get list of available prompt IDs."""
        return self.prompt_manager.list_templates()

    def reload_prompts(self) -> None:
        """Reload prompts from Arize Phoenix."""
        if self.prompt_id:
            self._prompt_manager = None  # Reset to force reload
            self.prompt_manager  # This will trigger reload

    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """
        Determine if prompt management should run based on the prompt_id.

        For Arize Phoenix, we always return True and handle the prompt loading
        in the _compile_prompt_helper method.
        """
        return True

    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Compile an Arize Phoenix prompt template into a PromptManagementClient structure.

        This method:
        1. Loads the prompt version from Arize Phoenix (if not already cached)
        2. Renders it with the provided variables
        3. Returns formatted chat messages
        4. Extracts model and optional parameters from metadata

        Raises:
            ValueError: when prompt_id is missing, or when loading/rendering fails.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for Arize Phoenix prompt manager")

        try:
            # Load the prompt from Arize Phoenix if not already loaded
            if prompt_id not in self.prompt_manager.prompts:
                self.prompt_manager._load_prompt_from_arize(prompt_id)

            # Get the rendered messages and metadata
            rendered_messages, prompt_metadata = self.get_prompt_template(
                prompt_id, prompt_variables
            )

            # Extract model from metadata (if specified)
            template_model = prompt_metadata.get("model")

            # Extract optional parameters from metadata
            optional_params = {}
            for param in [
                "temperature",
                "max_tokens",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
            ]:
                if param in prompt_metadata:
                    optional_params[param] = prompt_metadata[param]

            return PromptManagementClient(
                prompt_id=prompt_id,
                prompt_template=rendered_messages,
                prompt_template_model=template_model,
                prompt_template_optional_params=optional_params,
                completed_messages=None,
            )
        except Exception as e:
            raise ValueError(f"Error compiling prompt '{prompt_id}': {e}")

    async def async_compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Async version of compile prompt helper. Since Arize Phoenix operations are synchronous,
        this simply delegates to the sync version.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for Arize Phoenix prompt manager")
        return self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_spec=prompt_spec,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Get chat completion prompt from Arize Phoenix and return processed model, messages, and parameters.

        Delegates explicitly to PromptManagementBase's implementation (bypassing
        any intermediate overrides in the MRO).
        """
        return PromptManagementBase.get_chat_completion_prompt(
            self,
            model,
            messages,
            non_default_params,
            prompt_id,
            prompt_variables,
            dynamic_callback_params,
            prompt_spec=prompt_spec,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
            ignore_prompt_manager_model=ignore_prompt_manager_model,
            ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
        )

View File

@@ -0,0 +1,105 @@
import datetime
import litellm
class AthinaLogger:
    """Sends LiteLLM request/response logs to the Athina inference-logging API."""

    def __init__(self):
        import os

        # Credentials and endpoint come exclusively from environment variables.
        self.athina_api_key = os.getenv("ATHINA_API_KEY")
        self.headers = {
            "athina-api-key": self.athina_api_key,
            "Content-Type": "application/json",
        }
        self.athina_logging_url = (
            os.getenv("ATHINA_BASE_URL", "https://log.athina.ai")
            + "/api/v1/log/inference"
        )
        # Metadata keys forwarded verbatim from litellm metadata to Athina.
        self.additional_keys = [
            "environment",
            "prompt_slug",
            "customer_id",
            "customer_user_id",
            "session_id",
            "external_reference_id",
            "context",
            "expected_response",
            "user_query",
            "tags",
            "user_feedback",
            "model_options",
            "custom_attributes",
        ]

    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
        """
        Log a single completion (streaming or not) to Athina.

        Args:
            kwargs: litellm call kwargs (model, messages, litellm_params, ...).
            response_obj: the model response; ignored for streaming calls, where
                the aggregated response is read from kwargs instead.
            start_time: call start (datetime) used for response_time in ms.
            end_time: call end (datetime).
            print_verbose: logging callable; all errors are reported through it
                and never raised, so logging can't break the request path.
        """
        import json
        import traceback

        try:
            is_stream = kwargs.get("stream", False)
            if is_stream:
                if "complete_streaming_response" in kwargs:
                    # Log the aggregated completion response in streaming mode
                    completion_response = kwargs["complete_streaming_response"]
                    response_json = (
                        completion_response.model_dump() if completion_response else {}
                    )
                else:
                    # Skip logging if the completion response is not available
                    return
            else:
                # Log the completion response in non streaming mode
                response_json = response_obj.model_dump() if response_obj else {}
            data = {
                "language_model_id": kwargs.get("model"),
                "request": kwargs,
                "response": response_json,
                "prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
                "completion_tokens": response_json.get("usage", {}).get(
                    "completion_tokens"
                ),
                "total_tokens": response_json.get("usage", {}).get("total_tokens"),
            }

            # isinstance (not `type(...) is`) so datetime subclasses (e.g.
            # pandas.Timestamp) also get a response_time recorded.
            if isinstance(end_time, datetime.datetime) and isinstance(
                start_time, datetime.datetime
            ):
                data["response_time"] = int(
                    (end_time - start_time).total_seconds() * 1000
                )

            if "messages" in kwargs:
                data["prompt"] = kwargs.get("messages", None)

            # Directly add tools or functions if present
            optional_params = kwargs.get("optional_params", {})
            data.update(
                (k, v)
                for k, v in optional_params.items()
                if k in ["tools", "functions"]
            )

            # Add additional metadata keys
            metadata = kwargs.get("litellm_params", {}).get("metadata", {})
            if metadata:
                for key in self.additional_keys:
                    if key in metadata:
                        data[key] = metadata[key]

            response = litellm.module_level_client.post(
                self.athina_logging_url,
                headers=self.headers,
                data=json.dumps(data, default=str),
            )
            if response.status_code != 200:
                print_verbose(
                    f"Athina Logger Error - {response.text}, {response.status_code}"
                )
            else:
                print_verbose(f"Athina Logger Succeeded - {response.text}")
        except Exception as e:
            print_verbose(
                f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}"
            )

View File

@@ -0,0 +1,3 @@
from litellm.integrations.azure_sentinel.azure_sentinel import AzureSentinelLogger
__all__ = ["AzureSentinelLogger"]

View File

@@ -0,0 +1,304 @@
"""
Azure Sentinel Integration - sends logs to Azure Log Analytics using Logs Ingestion API
Azure Sentinel uses Log Analytics workspaces for data storage. This integration sends
LiteLLM logs to the Log Analytics workspace using the Azure Monitor Logs Ingestion API.
Reference API: https://learn.microsoft.com/en-us/azure/azure-monitor/logs/logs-ingestion-api-overview
`async_log_success_event` - used by litellm proxy to send logs to Azure Sentinel
`async_log_failure_event` - used by litellm proxy to send failure logs to Azure Sentinel
For batching specific details see CustomBatchLogger class
"""
import asyncio
import os
import traceback
from typing import List, Optional
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload
class AzureSentinelLogger(CustomBatchLogger):
    """
    Logger that sends LiteLLM logs to Azure Sentinel via Azure Monitor Logs Ingestion API
    """

    def __init__(
        self,
        dcr_immutable_id: Optional[str] = None,
        stream_name: Optional[str] = None,
        endpoint: Optional[str] = None,
        tenant_id: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Azure Sentinel logger using Logs Ingestion API

        Args:
            dcr_immutable_id (str, optional): Data Collection Rule (DCR) Immutable ID.
                If not provided, will use AZURE_SENTINEL_DCR_IMMUTABLE_ID env var.
            stream_name (str, optional): Stream name from DCR (e.g., "Custom-LiteLLM").
                If not provided, will use AZURE_SENTINEL_STREAM_NAME env var or default to "Custom-LiteLLM".
            endpoint (str, optional): Data Collection Endpoint (DCE) or DCR ingestion endpoint.
                If not provided, will use AZURE_SENTINEL_ENDPOINT env var.
            tenant_id (str, optional): Azure Tenant ID for OAuth2 authentication.
                If not provided, will use AZURE_SENTINEL_TENANT_ID or AZURE_TENANT_ID env var.
            client_id (str, optional): Azure Client ID (Application ID) for OAuth2 authentication.
                If not provided, will use AZURE_SENTINEL_CLIENT_ID or AZURE_CLIENT_ID env var.
            client_secret (str, optional): Azure Client Secret for OAuth2 authentication.
                If not provided, will use AZURE_SENTINEL_CLIENT_SECRET or AZURE_CLIENT_SECRET env var.

        Raises:
            ValueError: when any required setting is missing.
        """
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.dcr_immutable_id = dcr_immutable_id or os.getenv(
            "AZURE_SENTINEL_DCR_IMMUTABLE_ID"
        )
        self.stream_name = stream_name or os.getenv(
            "AZURE_SENTINEL_STREAM_NAME", "Custom-LiteLLM"
        )
        self.endpoint = endpoint or os.getenv("AZURE_SENTINEL_ENDPOINT")
        self.tenant_id = (
            tenant_id
            or os.getenv("AZURE_SENTINEL_TENANT_ID")
            or os.getenv("AZURE_TENANT_ID")
        )
        self.client_id = (
            client_id
            or os.getenv("AZURE_SENTINEL_CLIENT_ID")
            or os.getenv("AZURE_CLIENT_ID")
        )
        self.client_secret = (
            client_secret
            or os.getenv("AZURE_SENTINEL_CLIENT_SECRET")
            or os.getenv("AZURE_CLIENT_SECRET")
        )

        if not self.dcr_immutable_id:
            raise ValueError(
                "AZURE_SENTINEL_DCR_IMMUTABLE_ID is required. Set it as an environment variable or pass dcr_immutable_id parameter."
            )
        if not self.endpoint:
            raise ValueError(
                "AZURE_SENTINEL_ENDPOINT is required. Set it as an environment variable or pass endpoint parameter."
            )
        if not self.tenant_id:
            raise ValueError(
                "AZURE_SENTINEL_TENANT_ID or AZURE_TENANT_ID is required. Set it as an environment variable or pass tenant_id parameter."
            )
        if not self.client_id:
            raise ValueError(
                "AZURE_SENTINEL_CLIENT_ID or AZURE_CLIENT_ID is required. Set it as an environment variable or pass client_id parameter."
            )
        if not self.client_secret:
            raise ValueError(
                "AZURE_SENTINEL_CLIENT_SECRET or AZURE_CLIENT_SECRET is required. Set it as an environment variable or pass client_secret parameter."
            )

        # Build API endpoint: {Endpoint}/dataCollectionRules/{DCR Immutable ID}/streams/{Stream Name}?api-version=2023-01-01
        self.api_endpoint = f"{self.endpoint.rstrip('/')}/dataCollectionRules/{self.dcr_immutable_id}/streams/{self.stream_name}?api-version=2023-01-01"

        # OAuth2 scope for Azure Monitor
        self.oauth_scope = "https://monitor.azure.com/.default"
        self.oauth_token: Optional[str] = None
        self.oauth_token_expires_at: Optional[float] = None

        # Initialize the queue/lock BEFORE scheduling periodic_flush so the
        # background task can never observe a partially-constructed logger.
        self.log_queue: List[StandardLoggingPayload] = []
        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        asyncio.create_task(self.periodic_flush())

    async def _get_oauth_token(self) -> str:
        """
        Get OAuth2 Bearer token for Azure Monitor Logs Ingestion API

        Uses the client-credentials flow and caches the token until shortly
        before its expiry.

        Returns:
            Bearer token string

        Raises:
            Exception: when the token endpoint fails or returns no access_token.
        """
        import time

        # Check if we have a valid cached token (refresh 60s before expiry)
        if (
            self.oauth_token
            and self.oauth_token_expires_at
            and time.time() < self.oauth_token_expires_at - 60
        ):
            return self.oauth_token

        # Get new token using client credentials flow
        assert self.tenant_id is not None, "tenant_id is required"
        assert self.client_id is not None, "client_id is required"
        assert self.client_secret is not None, "client_secret is required"
        token_url = (
            f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token"
        )
        token_data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": self.oauth_scope,
            "grant_type": "client_credentials",
        }

        response = await self.async_httpx_client.post(
            url=token_url,
            data=token_data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        if response.status_code != 200:
            raise Exception(
                f"Failed to get OAuth2 token: {response.status_code} - {response.text}"
            )

        token_response = response.json()
        self.oauth_token = token_response.get("access_token")
        expires_in = token_response.get("expires_in", 3600)

        if not self.oauth_token:
            raise Exception("OAuth2 token response did not contain access_token")

        # Cache token expiry time
        self.oauth_token_expires_at = time.time() + expires_in
        return self.oauth_token

    async def _enqueue_standard_logging_payload(self, kwargs) -> None:
        """
        Shared enqueue logic for success and failure events: append the
        StandardLoggingPayload from kwargs and flush when the batch is full.
        """
        standard_logging_payload = kwargs.get("standard_logging_object", None)
        if standard_logging_payload is None:
            verbose_logger.warning(
                "Azure Sentinel: standard_logging_object not found in kwargs"
            )
            return
        self.log_queue.append(standard_logging_payload)
        if len(self.log_queue) >= self.batch_size:
            await self.async_send_batch()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log success events to Azure Sentinel

        - Gets StandardLoggingPayload from kwargs
        - Adds to batch queue
        - Flushes based on CustomBatchLogger settings

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        try:
            verbose_logger.debug(
                "Azure Sentinel: Logging - Enters logging function for model %s", kwargs
            )
            await self._enqueue_standard_logging_payload(kwargs)
        except Exception as e:
            verbose_logger.exception(
                f"Azure Sentinel Layer Error - {str(e)}\n{traceback.format_exc()}"
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log failure events to Azure Sentinel

        - Gets StandardLoggingPayload from kwargs
        - Adds to batch queue
        - Flushes based on CustomBatchLogger settings

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        try:
            verbose_logger.debug(
                "Azure Sentinel: Logging - Enters failure logging function for model %s",
                kwargs,
            )
            await self._enqueue_standard_logging_payload(kwargs)
        except Exception as e:
            verbose_logger.exception(
                f"Azure Sentinel Layer Error - {str(e)}\n{traceback.format_exc()}"
            )

    async def async_send_batch(self):
        """
        Sends the batch of logs to Azure Monitor Logs Ingestion API

        The queue is always cleared afterwards, even on failure, so a bad
        batch cannot be retried forever.

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        try:
            if not self.log_queue:
                return

            verbose_logger.debug(
                "Azure Sentinel - about to flush %s events", len(self.log_queue)
            )

            from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

            # Get OAuth2 token
            bearer_token = await self._get_oauth_token()

            # Convert log queue to JSON array format expected by Logs Ingestion API
            # Each log entry should be a JSON object in the array
            body = safe_dumps(self.log_queue)

            # Set headers for Logs Ingestion API
            headers = {
                "Authorization": f"Bearer {bearer_token}",
                "Content-Type": "application/json",
            }

            # Send the request
            response = await self.async_httpx_client.post(
                url=self.api_endpoint, data=body.encode("utf-8"), headers=headers
            )

            if response.status_code not in [200, 204]:
                verbose_logger.error(
                    "Azure Sentinel API error: status_code=%s, response=%s",
                    response.status_code,
                    response.text,
                )
                raise Exception(
                    f"Failed to send logs to Azure Sentinel: {response.status_code} - {response.text}"
                )

            verbose_logger.debug(
                "Azure Sentinel: Response from API status_code: %s",
                response.status_code,
            )
        except Exception as e:
            verbose_logger.exception(
                f"Azure Sentinel Error sending batch API - {str(e)}\n{traceback.format_exc()}"
            )
        finally:
            self.log_queue.clear()

View File

@@ -0,0 +1,179 @@
{
"id": "chatcmpl-2299b6a2-82a3-465a-b47c-04e685a2227f",
"trace_id": "97311c60-9a61-4f48-a814-70139ee57868",
"call_type": "acompletion",
"cache_hit": null,
"stream": true,
"status": "success",
"custom_llm_provider": "openai",
"saved_cache_cost": 0.0,
"startTime": 1766000068.28466,
"endTime": 1766000070.07935,
"completionStartTime": 1766000070.07935,
"response_time": 1.79468512535095,
"model": "gpt-4o",
"metadata": {
"user_api_key_hash": null,
"user_api_key_alias": null,
"user_api_key_team_id": null,
"user_api_key_org_id": null,
"user_api_key_user_id": null,
"user_api_key_team_alias": null,
"user_api_key_user_email": null,
"spend_logs_metadata": null,
"requester_ip_address": null,
"requester_metadata": null,
"user_api_key_end_user_id": null,
"prompt_management_metadata": null,
"applied_guardrails": [],
"mcp_tool_call_metadata": null,
"vector_store_request_metadata": null,
"guardrail_information": null
},
"cache_key": null,
"response_cost": 0.00022500000000000002,
"total_tokens": 30,
"prompt_tokens": 10,
"completion_tokens": 20,
"request_tags": [],
"end_user": "",
"api_base": "",
"model_group": "",
"model_id": "",
"requester_ip_address": null,
"messages": [
{
"role": "user",
"content": "Hello, world!"
}
],
"response": {
"id": "chatcmpl-2299b6a2-82a3-465a-b47c-04e685a2227f",
"created": 1742855151,
"model": "gpt-4o",
"object": "chat.completion",
"system_fingerprint": null,
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "hi",
"role": "assistant",
"tool_calls": null,
"function_call": null,
"provider_specific_fields": null
}
}
],
"usage": {
"completion_tokens": 20,
"prompt_tokens": 10,
"total_tokens": 30,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
},
"model_parameters": {},
"hidden_params": {
"model_id": null,
"cache_key": null,
"api_base": "https://api.openai.com",
"response_cost": 0.00022500000000000002,
"additional_headers": {},
"litellm_overhead_time_ms": null,
"batch_models": null,
"litellm_model_name": "gpt-4o"
},
"model_map_information": {
"model_map_key": "gpt-4o",
"model_map_value": {
"key": "gpt-4o",
"max_tokens": 16384,
"max_input_tokens": 128000,
"max_output_tokens": 16384,
"input_cost_per_token": 2.5e-06,
"cache_creation_input_token_cost": null,
"cache_read_input_token_cost": 1.25e-06,
"input_cost_per_character": null,
"input_cost_per_token_above_128k_tokens": null,
"input_cost_per_query": null,
"input_cost_per_second": null,
"input_cost_per_audio_token": null,
"input_cost_per_token_batches": 1.25e-06,
"output_cost_per_token_batches": 5e-06,
"output_cost_per_token": 1e-05,
"output_cost_per_audio_token": null,
"output_cost_per_character": null,
"output_cost_per_token_above_128k_tokens": null,
"output_cost_per_character_above_128k_tokens": null,
"output_cost_per_second": null,
"output_cost_per_image": null,
"output_vector_size": null,
"litellm_provider": "openai",
"mode": "chat",
"supports_system_messages": true,
"supports_response_schema": true,
"supports_vision": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_assistant_prefill": false,
"supports_prompt_caching": true,
"supports_audio_input": false,
"supports_audio_output": false,
"supports_pdf_input": false,
"supports_embedding_image_input": false,
"supports_native_streaming": null,
"supports_web_search": true,
"search_context_cost_per_query": {
"search_context_size_low": 0.03,
"search_context_size_medium": 0.035,
"search_context_size_high": 0.05
},
"tpm": null,
"rpm": null,
"supported_openai_params": [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"max_completion_tokens",
"modalities",
"prediction",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
"audio",
"response_format",
"user"
]
}
},
"error_str": null,
"error_information": {
"error_code": "",
"error_class": "",
"llm_provider": "",
"traceback": "",
"error_message": ""
},
"response_cost_failure_debug_info": null,
"guardrail_information": null,
"standard_built_in_tools_params": {
"web_search_options": null,
"file_search": null
}
}

View File

@@ -0,0 +1,400 @@
import asyncio
import os
import time
from litellm._uuid import uuid
from datetime import datetime, timedelta
from typing import List, Optional
from litellm._logging import verbose_logger
from litellm.constants import _DEFAULT_TTL_FOR_HTTPX_CLIENTS, AZURE_STORAGE_MSFT_VERSION
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.azure.common_utils import get_azure_ad_token_from_entra_id
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.types.utils import StandardLoggingPayload
class AzureBlobStorageLogger(CustomBatchLogger):
def __init__(
    self,
    **kwargs,
):
    """
    Initialize the Azure Blob Storage logger from environment variables.

    Required: AZURE_STORAGE_ACCOUNT_NAME, AZURE_STORAGE_FILE_SYSTEM.
    Authentication: either AZURE_STORAGE_ACCOUNT_KEY (SDK path) or
    AZURE_STORAGE_TENANT_ID / AZURE_STORAGE_CLIENT_ID /
    AZURE_STORAGE_CLIENT_SECRET (Azure AD token path).

    Raises:
        ValueError: when a required environment variable is missing.
    """
    try:
        verbose_logger.debug(
            "AzureBlobStorageLogger: in init azure blob storage logger"
        )

        # Env Variables used for Azure Storage Authentication
        self.tenant_id = os.getenv("AZURE_STORAGE_TENANT_ID")
        self.client_id = os.getenv("AZURE_STORAGE_CLIENT_ID")
        self.client_secret = os.getenv("AZURE_STORAGE_CLIENT_SECRET")
        self.azure_storage_account_key: Optional[str] = os.getenv(
            "AZURE_STORAGE_ACCOUNT_KEY"
        )

        # Required Env Variables for Azure Storage
        _azure_storage_account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
        if not _azure_storage_account_name:
            raise ValueError(
                "Missing required environment variable: AZURE_STORAGE_ACCOUNT_NAME"
            )
        self.azure_storage_account_name: str = _azure_storage_account_name
        _azure_storage_file_system = os.getenv("AZURE_STORAGE_FILE_SYSTEM")
        if not _azure_storage_file_system:
            raise ValueError(
                "Missing required environment variable: AZURE_STORAGE_FILE_SYSTEM"
            )
        self.azure_storage_file_system: str = _azure_storage_file_system
        self._service_client = None
        # Time that the azure service client expires, in order to reset the connection pool and keep it fresh
        self._service_client_timeout: Optional[float] = None

        # Internal variables used for Token based authentication
        self.azure_auth_token: Optional[
            str
        ] = None  # the Azure AD token to use for Azure Storage API requests
        self.token_expiry: Optional[
            datetime
        ] = None  # the expiry time of the current Azure AD token

        # Fully initialize batching state (lock, queue, base-class config)
        # BEFORE scheduling the periodic flush task, so the background task
        # can never observe a partially-constructed logger.
        self.flush_lock = asyncio.Lock()
        self.log_queue: List[StandardLoggingPayload] = []
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        asyncio.create_task(self.periodic_flush())
    except Exception as e:
        verbose_logger.exception(
            f"AzureBlobStorageLogger: Got exception on init AzureBlobStorageLogger client {str(e)}"
        )
        raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Async Log success events to Azure Blob Storage

    Requires a premium license; the payload is queued and uploaded later by
    the periodic flush.

    Raises:
        Raises a NON Blocking verbose_logger.exception if an error occurs
    """
    try:
        self._premium_user_check()
        verbose_logger.debug(
            "AzureBlobStorageLogger: Logging - Enters logging function for model %s",
            kwargs,
        )
        payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if payload is None:
            raise ValueError("standard_logging_payload is not set")
        self.log_queue.append(payload)
    except Exception as e:
        verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """
    Async Log failure events to Azure Blob Storage

    Requires a premium license; the payload is queued and uploaded later by
    the periodic flush.

    Raises:
        Raises a NON Blocking verbose_logger.exception if an error occurs
    """
    try:
        self._premium_user_check()
        verbose_logger.debug(
            "AzureBlobStorageLogger: Logging - Enters logging function for model %s",
            kwargs,
        )
        payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if payload is None:
            raise ValueError("standard_logging_payload is not set")
        self.log_queue.append(payload)
    except Exception as e:
        verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
async def async_send_batch(self):
    """
    Sends the in memory logs queue to Azure Blob Storage, one upload per
    payload. The caller (CustomBatchLogger's flush) owns clearing the queue.

    Raises:
        Raises a NON Blocking verbose_logger.exception if an error occurs
    """
    try:
        if not self.log_queue:
            # Bug fix: the original logged a copy-pasted "Datadog: log_queue
            # does not exist" message at exception level — an empty queue is
            # a normal condition, not an error.
            verbose_logger.debug("AzureBlobStorageLogger: no logs to flush")
            return

        verbose_logger.debug(
            "AzureBlobStorageLogger - about to flush %s events",
            len(self.log_queue),
        )

        for payload in self.log_queue:
            await self.async_upload_payload_to_azure_blob_storage(payload=payload)

    except Exception as e:
        verbose_logger.exception(
            f"AzureBlobStorageLogger Error sending batch API - {str(e)}"
        )
async def async_upload_payload_to_azure_blob_storage(
    self, payload: StandardLoggingPayload
):
    """
    Uploads the payload to Azure Blob Storage using a 3-step process:
    1. Create file resource
    2. Append data
    3. Flush the data

    When AZURE_STORAGE_ACCOUNT_KEY is set the Azure SDK path is used instead,
    because an account key cannot be used directly with the REST API.
    """
    try:
        if self.azure_storage_account_key:
            await self.upload_to_azure_data_lake_with_azure_account_key(
                payload=payload
            )
        else:
            # Get a valid token instead of always requesting a new one
            await self.set_valid_azure_ad_token()
            async_client = get_async_httpx_client(
                llm_provider=httpxSpecialProvider.LoggingCallback
            )
            json_payload = (
                safe_dumps(payload) + "\n"
            )  # Add newline for each log entry
            payload_bytes = json_payload.encode("utf-8")
            # Name the blob after the request id so each upload is unique.
            filename = f"{payload.get('id') or str(uuid.uuid4())}.json"
            # Bug fix: the blob filename must be part of the upload URL
            # (it was previously computed but never used).
            base_url = f"https://{self.azure_storage_account_name}.dfs.core.windows.net/{self.azure_storage_file_system}/{filename}"

            # Execute the 3-step upload process
            await self._create_file(async_client, base_url)
            await self._append_data(async_client, base_url, json_payload)
            await self._flush_data(async_client, base_url, len(payload_bytes))
            verbose_logger.debug(
                f"Successfully uploaded log to Azure Blob Storage: {filename}"
            )
    except Exception as e:
        verbose_logger.exception(f"Error uploading to Azure Blob Storage: {str(e)}")
        raise e
async def _create_file(self, client: AsyncHTTPHandler, base_url: str):
    """
    Step 1 of the 3-step upload: create the (empty) file resource at `base_url`.

    Authenticates with the bearer token in self.azure_auth_token — the caller
    must have refreshed it via set_valid_azure_ad_token() first.

    Raises:
        Exception: re-raises any HTTP/transport error after logging it.
    """
    try:
        verbose_logger.debug(f"Creating file resource at: {base_url}")
        headers = {
            "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
            "Content-Length": "0",
            "Authorization": f"Bearer {self.azure_auth_token}",
        }
        # PUT with ?resource=file creates an empty path in the file system
        response = await client.put(f"{base_url}?resource=file", headers=headers)
        response.raise_for_status()
        verbose_logger.debug("Successfully created file resource")
    except Exception as e:
        verbose_logger.exception(f"Error creating file resource: {str(e)}")
        raise
async def _append_data(
    self, client: AsyncHTTPHandler, base_url: str, json_payload: str
):
    """
    Step 2 of the 3-step upload: append `json_payload` to the file.

    Always appends at position=0, which assumes the file was freshly created
    (empty) by _create_file.

    Raises:
        Exception: re-raises any HTTP/transport error after logging it.
    """
    try:
        verbose_logger.debug(f"Appending data to file: {base_url}")
        headers = {
            "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.azure_auth_token}",
        }
        response = await client.patch(
            f"{base_url}?action=append&position=0",
            headers=headers,
            data=json_payload,
        )
        response.raise_for_status()
        verbose_logger.debug("Successfully appended data")
    except Exception as e:
        verbose_logger.exception(f"Error appending data: {str(e)}")
        raise
async def _flush_data(self, client: AsyncHTTPHandler, base_url: str, position: int):
    """
    Step 3 of the 3-step upload: flush (commit) the appended data.

    Args:
        client: HTTP client used for the PATCH request.
        base_url: URL of the file created by _create_file.
        position: total number of bytes appended so far — the flush offset
            must equal the file length for the commit to succeed.

    Raises:
        Exception: re-raises any HTTP/transport error after logging it.
    """
    try:
        verbose_logger.debug(f"Flushing data at position {position}")
        headers = {
            "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
            "Content-Length": "0",
            "Authorization": f"Bearer {self.azure_auth_token}",
        }
        response = await client.patch(
            f"{base_url}?action=flush&position={position}", headers=headers
        )
        response.raise_for_status()
        verbose_logger.debug("Successfully flushed data")
    except Exception as e:
        verbose_logger.exception(f"Error flushing data: {str(e)}")
        raise
####### Helper methods to managing Authentication to Azure Storage #######
##########################################################################
async def set_valid_azure_ad_token(self):
    """
    Wrapper to set self.azure_auth_token to a valid Azure AD token, refreshing if necessary

    Refreshes the token when:
    - Token is expired
    - Token is not set

    The expiry is not read from the token response; a fixed 1-hour lifetime
    is assumed (tokens typically last 1 hour) and tracked in self.token_expiry.
    """
    # Check if token needs refresh
    if self._azure_ad_token_is_expired() or self.azure_auth_token is None:
        verbose_logger.debug("Azure AD token needs refresh")
        self.azure_auth_token = self.get_azure_ad_token_from_azure_storage(
            tenant_id=self.tenant_id,
            client_id=self.client_id,
            client_secret=self.client_secret,
        )
        # Token typically expires in 1 hour
        self.token_expiry = datetime.now() + timedelta(hours=1)
        verbose_logger.debug(f"New token will expire at {self.token_expiry}")
def get_azure_ad_token_from_azure_storage(
    self,
    tenant_id: Optional[str],
    client_id: Optional[str],
    client_secret: Optional[str],
) -> str:
    """
    Gets Azure AD token to use for Azure Storage API requests

    Args:
        tenant_id: Azure tenant to authenticate against.
        client_id: Application (client) ID of the service principal.
        client_secret: Client secret of the service principal.

    Returns:
        A bearer token scoped to https://storage.azure.com/.default

    Raises:
        ValueError: when any of the three credentials is missing.
    """
    verbose_logger.debug("Getting Azure AD Token from Azure Storage")
    # Security fix: never log the client secret or the issued token —
    # debug logs previously leaked both. Only log non-sensitive identifiers.
    verbose_logger.debug(
        "tenant_id %s, client_id %s, client_secret set: %s",
        tenant_id,
        client_id,
        client_secret is not None,
    )
    if tenant_id is None:
        raise ValueError(
            "Missing required environment variable: AZURE_STORAGE_TENANT_ID"
        )
    if client_id is None:
        raise ValueError(
            "Missing required environment variable: AZURE_STORAGE_CLIENT_ID"
        )
    if client_secret is None:
        raise ValueError(
            "Missing required environment variable: AZURE_STORAGE_CLIENT_SECRET"
        )

    token_provider = get_azure_ad_token_from_entra_id(
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret=client_secret,
        scope="https://storage.azure.com/.default",
    )
    token = token_provider()
    verbose_logger.debug("successfully obtained azure auth token")
    return token
def _azure_ad_token_is_expired(self):
    """
    Returns True if Azure AD token is expired, False otherwise

    A token is treated as expired 5 minutes before its recorded expiry so a
    refresh happens with a safety margin. Missing token/expiry state counts
    as "not expired" (set_valid_azure_ad_token handles the unset case).
    """
    if not (self.azure_auth_token and self.token_expiry):
        return False
    if datetime.now() + timedelta(minutes=5) >= self.token_expiry:
        verbose_logger.debug("Azure AD token is expired. Requesting new token")
        return True
    return False
def _premium_user_check(self):
"""
Checks if the user is a premium user, raises an error if not
"""
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
if premium_user is not True:
raise ValueError(
f"AzureBlobStorageLogger is only available for premium users. {CommonProxyErrors.not_premium_user}"
)
async def get_service_client(self):
from azure.storage.filedatalake.aio import DataLakeServiceClient
# expire old clients to recover from connection issues
if (
self._service_client_timeout
and self._service_client
and self._service_client_timeout > time.time()
):
await self._service_client.close()
self._service_client = None
if not self._service_client:
self._service_client = DataLakeServiceClient(
account_url=f"https://{self.azure_storage_account_name}.dfs.core.windows.net",
credential=self.azure_storage_account_key,
)
self._service_client_timeout = time.time() + _DEFAULT_TTL_FOR_HTTPX_CLIENTS
return self._service_client
async def upload_to_azure_data_lake_with_azure_account_key(
self, payload: StandardLoggingPayload
):
"""
Uploads the payload to Azure Data Lake using the Azure SDK
This is used when Azure Storage Account Key is set - Azure Storage Account Key does not work directly with Azure Rest API
"""
# Create an async service client
service_client = await self.get_service_client()
# Get file system client
file_system_client = service_client.get_file_system_client(
file_system=self.azure_storage_file_system
)
try:
# Create directory with today's date
from datetime import datetime
today = datetime.now().strftime("%Y-%m-%d")
directory_client = file_system_client.get_directory_client(today)
# check if the directory exists
if not await directory_client.exists():
await directory_client.create_directory()
verbose_logger.debug(f"Created directory: {today}")
# Create a file client
file_name = f"{payload.get('id') or str(uuid.uuid4())}.json"
file_client = directory_client.get_file_client(file_name)
# Create the file
await file_client.create_file()
# Content to append
content = safe_dumps(payload).encode("utf-8")
# Append content to the file
await file_client.append_data(data=content, offset=0, length=len(content))
# Flush the content to finalize the file
await file_client.flush_data(position=len(content), offset=0)
verbose_logger.debug(
f"Successfully uploaded and wrote to {today}/{file_name}"
)
except Exception as e:
verbose_logger.exception(f"Error occurred: {str(e)}")

View File

@@ -0,0 +1,317 @@
# LiteLLM BitBucket Prompt Management
A powerful prompt management system for LiteLLM that fetches `.prompt` files from BitBucket repositories. This enables team-based prompt management with BitBucket's built-in access control and version control capabilities.
## Features
- **🏢 Team-based access control**: Leverage BitBucket's workspace and repository permissions
- **📁 Repository-based prompt storage**: Store prompts in BitBucket repositories
- **🔐 Multiple authentication methods**: Support for access tokens and basic auth
- **🎯 YAML frontmatter**: Define model, parameters, and schemas in file headers
- **🔧 Handlebars templating**: Use `{{variable}}` syntax with Jinja2 backend
- **✅ Input validation**: Automatic validation against defined schemas
- **🔗 LiteLLM integration**: Works seamlessly with `litellm.completion()`
- **💬 Smart message parsing**: Converts prompts to proper chat messages
- **⚙️ Parameter extraction**: Automatically applies model settings from prompts
## Quick Start
### 1. Set up BitBucket Repository
Create a repository in your BitBucket workspace and add `.prompt` files:
```
your-repo/
├── prompts/
│ ├── chat_assistant.prompt
│ ├── code_reviewer.prompt
│ └── data_analyst.prompt
```
### 2. Create a `.prompt` file
Create a file called `prompts/chat_assistant.prompt`:
```yaml
---
model: gpt-4
temperature: 0.7
max_tokens: 150
input:
schema:
user_message: string
system_context?: string
---
{% if system_context %}System: {{system_context}}
{% endif %}User: {{user_message}}
```
### 3. Configure BitBucket Access
#### Option A: Access Token (Recommended)
```python
import litellm
# Configure BitBucket access
bitbucket_config = {
"workspace": "your-workspace",
"repository": "your-repo",
"access_token": "your-access-token",
"branch": "main" # optional, defaults to main
}
# Set global BitBucket configuration
litellm.set_global_bitbucket_config(bitbucket_config)
```
#### Option B: Basic Authentication
```python
import litellm
# Configure BitBucket access with basic auth
bitbucket_config = {
"workspace": "your-workspace",
"repository": "your-repo",
"username": "your-username",
"access_token": "your-app-password", # Use app password for basic auth
"auth_method": "basic",
"branch": "main"
}
litellm.set_global_bitbucket_config(bitbucket_config)
```
### 4. Use with LiteLLM
```python
# Use with completion - the model prefix 'bitbucket/' tells LiteLLM to use BitBucket prompt management
response = litellm.completion(
model="bitbucket/gpt-4", # The actual model comes from the .prompt file
prompt_id="prompts/chat_assistant", # Location of the prompt file
prompt_variables={
"user_message": "What is machine learning?",
"system_context": "You are a helpful AI tutor."
},
# Any additional messages will be appended after the prompt
messages=[{"role": "user", "content": "Please explain it simply."}]
)
print(response.choices[0].message.content)
```
## Proxy Server Configuration
### 1. Create a `.prompt` file
Create `prompts/hello.prompt`:
```yaml
---
model: gpt-4
temperature: 0.7
---
System: You are a helpful assistant.
User: {{user_message}}
```
### 2. Setup config.yaml
```yaml
model_list:
- model_name: my-bitbucket-model
litellm_params:
model: bitbucket/gpt-4
prompt_id: "prompts/hello"
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
global_bitbucket_config:
workspace: "your-workspace"
repository: "your-repo"
access_token: "your-access-token"
branch: "main"
```
### 3. Start the proxy
```bash
litellm --config config.yaml --detailed_debug
```
### 4. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "my-bitbucket-model",
"messages": [{"role": "user", "content": "IGNORED"}],
"prompt_variables": {
"user_message": "What is the capital of France?"
}
}'
```
## Prompt File Format
### Basic Structure
```yaml
---
# Model configuration
model: gpt-4
temperature: 0.7
max_tokens: 500
# Input schema (optional)
input:
schema:
user_message: string
system_context?: string
---
System: You are a helpful {{role}} assistant.
User: {{user_message}}
```
### Advanced Features
**Multi-role conversations:**
```yaml
---
model: gpt-4
temperature: 0.3
---
System: You are a helpful coding assistant.
User: {{user_question}}
```
**Dynamic model selection:**
```yaml
---
model: "{{preferred_model}}" # Model can be a variable
temperature: 0.7
---
System: You are a helpful assistant specialized in {{domain}}.
User: {{user_message}}
```
## Team-Based Access Control
BitBucket's built-in permission system provides team-based access control:
1. **Workspace-level permissions**: Control access to entire workspaces
2. **Repository-level permissions**: Control access to specific repositories
3. **Branch-level permissions**: Control access to specific branches
4. **User and group management**: Manage team members and their access levels
### Setting up Team Access
1. **Create workspaces for each team**:
```
team-a-prompts/
team-b-prompts/
team-c-prompts/
```
2. **Configure repository permissions**:
- Grant read access to team members
- Grant write access to prompt maintainers
- Use branch protection rules for production prompts
3. **Use different access tokens**:
- Each team can have their own access token
- Tokens can be scoped to specific repositories
- Use app passwords for additional security
## API Reference
### BitBucket Configuration
```python
bitbucket_config = {
"workspace": str, # Required: BitBucket workspace name
"repository": str, # Required: Repository name
"access_token": str, # Required: BitBucket access token or app password
"branch": str, # Optional: Branch to fetch from (default: "main")
"base_url": str, # Optional: Custom BitBucket API URL
"auth_method": str, # Optional: "token" or "basic" (default: "token")
"username": str, # Optional: Username for basic auth
"base_url" : str # Optional: Incase where the base url is not https://api.bitbucket.org/2.0
}
```
### LiteLLM Integration
```python
response = litellm.completion(
model="bitbucket/<base_model>", # required (e.g., bitbucket/gpt-4)
prompt_id=str, # required - the .prompt filename without extension
prompt_variables=dict, # optional - variables for template rendering
bitbucket_config=dict, # optional - BitBucket configuration (if not set globally)
messages=list, # optional - additional messages
)
```
## Error Handling
The BitBucket integration provides detailed error messages for common issues:
- **Authentication errors**: Invalid access tokens or credentials
- **Permission errors**: Insufficient access to workspace/repository
- **File not found**: Missing .prompt files
- **Network errors**: Connection issues with BitBucket API
## Security Considerations
1. **Access Token Security**: Store access tokens securely using environment variables or secret management systems
2. **Repository Permissions**: Use BitBucket's permission system to control access
3. **Branch Protection**: Protect main branches from unauthorized changes
4. **Audit Logging**: BitBucket provides audit logs for all repository access
## Troubleshooting
### Common Issues
1. **"Access denied" errors**: Check your BitBucket permissions for the workspace and repository
2. **"Authentication failed" errors**: Verify your access token or credentials
3. **"File not found" errors**: Ensure the .prompt file exists in the specified branch
4. **Template rendering errors**: Check your Handlebars syntax in the .prompt file
### Debug Mode
Enable debug logging to troubleshoot issues:
```python
import litellm
litellm.set_verbose = True
# Your BitBucket prompt calls will now show detailed logs
response = litellm.completion(
model="bitbucket/gpt-4",
prompt_id="your_prompt",
prompt_variables={"key": "value"}
)
```
## Migration from File-Based Prompts
If you're currently using file-based prompts with the dotprompt integration, you can easily migrate to BitBucket:
1. **Upload your .prompt files** to a BitBucket repository
2. **Update your configuration** to use BitBucket instead of local files
3. **Set up team access** using BitBucket's permission system
4. **Update your code** to use `bitbucket/` model prefix instead of `dotprompt/`
This provides better collaboration, version control, and team-based access control for your prompts.

View File

@@ -0,0 +1,66 @@
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from .bitbucket_prompt_manager import BitBucketPromptManager
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
from .bitbucket_prompt_manager import BitBucketPromptManager
# Global instances
# Process-wide BitBucket configuration; None until a config is registered via
# set_global_bitbucket_config().
global_bitbucket_config: Optional[dict] = None
def set_global_bitbucket_config(config: dict) -> None:
    """
    Register a process-wide BitBucket configuration for prompt management.

    Args:
        config: Dictionary containing BitBucket configuration
            - workspace: BitBucket workspace name
            - repository: Repository name
            - access_token: BitBucket access token
            - branch: Branch to fetch prompts from (default: main)
    """
    import litellm

    # Stored on the litellm module so any later prompt lookup can reuse it.
    litellm.global_bitbucket_config = config  # type: ignore
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Build a ``BitBucketPromptManager`` for a prompt backed by a BitBucket repo.

    Args:
        litellm_params: Prompt params; must carry ``bitbucket_config`` and may
            carry ``prompt_id``.
        prompt_spec: The prompt spec being initialized (unused here; part of
            the initializer-registry signature).

    Returns:
        A configured ``BitBucketPromptManager``.

    Raises:
        ValueError: if ``bitbucket_config`` is missing.
    """
    bitbucket_config = getattr(litellm_params, "bitbucket_config", None)
    prompt_id = getattr(litellm_params, "prompt_id", None)

    if not bitbucket_config:
        raise ValueError(
            "bitbucket_config is required for BitBucket prompt integration"
        )

    # No try/except wrapper here: the previous `except Exception: raise e`
    # added nothing and only obscured the traceback.
    return BitBucketPromptManager(
        bitbucket_config=bitbucket_config,
        prompt_id=prompt_id,
    )
# Maps the BitBucket integration name to its initializer so LiteLLM can
# construct BitBucket-backed prompts from configuration.
prompt_initializer_registry = {
    SupportedPromptIntegrations.BITBUCKET.value: prompt_initializer,
}
# Export public API
__all__ = [
    "BitBucketPromptManager",
    "set_global_bitbucket_config",
    "global_bitbucket_config",
]

View File

@@ -0,0 +1,241 @@
"""
BitBucket API client for fetching .prompt files from BitBucket repositories.
"""
import base64
from typing import Any, Dict, List, Optional
from litellm.llms.custom_httpx.http_handler import HTTPHandler
class BitBucketClient:
    """
    Client for interacting with BitBucket API to fetch .prompt files.

    Supports:
    - Authentication with access tokens (Bearer) or basic auth (username + app password)
    - Fetching file contents from repositories
    - Team-based access control through BitBucket permissions
    - Branch-specific file fetching
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the BitBucket client.

        Args:
            config: Dictionary containing:
                - workspace: BitBucket workspace name (required)
                - repository: Repository name (required)
                - access_token: BitBucket access token or app password (required)
                - branch: Branch to fetch from (default: main)
                - base_url: Custom BitBucket API base URL (optional)
                - auth_method: Authentication method ('token' or 'basic', default: 'token')
                - username: Username for basic auth (optional)

        Raises:
            ValueError: if workspace, repository, or access_token is missing.
        """
        self.workspace = config.get("workspace")
        self.repository = config.get("repository")
        self.access_token = config.get("access_token")
        self.branch = config.get("branch", "main")
        # Bug fix: read the custom API base URL from the "base_url" key. The
        # previous lookup used the empty-string key, so a user-supplied
        # base_url was silently ignored.
        self.base_url = config.get("base_url", "https://api.bitbucket.org/2.0")
        self.auth_method = config.get("auth_method", "token")
        self.username = config.get("username")

        if not all([self.workspace, self.repository, self.access_token]):
            raise ValueError("workspace, repository, and access_token are required")

        # Set up authentication headers
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

        if self.auth_method == "basic" and self.username:
            # Basic auth with username and app password
            credentials = f"{self.username}:{self.access_token}"
            encoded_credentials = base64.b64encode(credentials.encode()).decode()
            self.headers["Authorization"] = f"Basic {encoded_credentials}"
        else:
            # Token-based authentication (default)
            self.headers["Authorization"] = f"Bearer {self.access_token}"

        # Initialize HTTPHandler
        self.http_handler = HTTPHandler()

    def get_file_content(self, file_path: str) -> Optional[str]:
        """
        Fetch the content of a file from the BitBucket repository.

        Args:
            file_path: Path to the file in the repository

        Returns:
            File content as string, or None if the file does not exist (404).

        Raises:
            Exception: on authentication/permission failures or other errors.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}/src/{self.branch}/{file_path}"
        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()
            if response.headers.get("content-type", "").startswith("text/"):
                return response.text
            # For binary files or when content-type is not text, try to
            # decode as base64 and fall back to the raw text body.
            try:
                return base64.b64decode(response.content).decode("utf-8")
            except Exception:
                return response.text
        except Exception as e:
            # Translate HTTP status codes into actionable errors; 404 means
            # "prompt file absent" and is reported as None, not an error.
            if hasattr(e, "response") and hasattr(e.response, "status_code"):
                if e.response.status_code == 404:
                    return None
                elif e.response.status_code == 403:
                    raise Exception(
                        f"Access denied to file '{file_path}'. Check your BitBucket permissions for workspace '{self.workspace}' and repository '{self.repository}'."
                    )
                elif e.response.status_code == 401:
                    raise Exception(
                        "Authentication failed. Check your BitBucket access token and permissions."
                    )
                else:
                    raise Exception(f"Failed to fetch file '{file_path}': {e}")
            else:
                raise Exception(f"Error fetching file '{file_path}': {e}")

    def list_files(
        self, directory_path: str = "", file_extension: str = ".prompt"
    ) -> List[str]:
        """
        List files in a directory with a specific extension.

        Args:
            directory_path: Directory path in the repository (empty for root)
            file_extension: File extension to filter by (default: .prompt)

        Returns:
            List of matching file paths; empty list if the directory is absent.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}/src/{self.branch}/{directory_path}"
        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()
            data = response.json()
            files = []
            # The src endpoint returns a paged listing; file entries carry
            # type "commit_file".
            for item in data.get("values", []):
                if item.get("type") == "commit_file":
                    file_path = item.get("path", "")
                    if file_path.endswith(file_extension):
                        files.append(file_path)
            return files
        except Exception as e:
            if hasattr(e, "response") and hasattr(e.response, "status_code"):
                if e.response.status_code == 404:
                    return []
                elif e.response.status_code == 403:
                    raise Exception(
                        f"Access denied to directory '{directory_path}'. Check your BitBucket permissions for workspace '{self.workspace}' and repository '{self.repository}'."
                    )
                elif e.response.status_code == 401:
                    raise Exception(
                        "Authentication failed. Check your BitBucket access token and permissions."
                    )
                else:
                    raise Exception(f"Failed to list files in '{directory_path}': {e}")
            else:
                raise Exception(f"Error listing files in '{directory_path}': {e}")

    def get_repository_info(self) -> Dict[str, Any]:
        """
        Get information about the repository.

        Returns:
            Dictionary containing repository information.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}"
        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            raise Exception(f"Failed to get repository info: {e}")

    def test_connection(self) -> bool:
        """
        Test the connection to the BitBucket repository.

        Returns:
            True if connection is successful, False otherwise.
        """
        try:
            self.get_repository_info()
            return True
        except Exception:
            return False

    def get_branches(self) -> List[Dict[str, Any]]:
        """
        Get list of branches in the repository.

        Returns:
            List of branch information dictionaries.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}/refs/branches"
        try:
            response = self.http_handler.get(url, headers=self.headers)
            response.raise_for_status()
            data = response.json()
            return data.get("values", [])
        except Exception as e:
            raise Exception(f"Failed to get branches: {e}")

    def get_file_metadata(self, file_path: str) -> Optional[Dict[str, Any]]:
        """
        Get metadata about a file (size, last modified, etc.).

        Args:
            file_path: Path to the file in the repository

        Returns:
            Dictionary containing file metadata, or None if file not found.
        """
        url = f"{self.base_url}/repositories/{self.workspace}/{self.repository}/src/{self.branch}/{file_path}"
        try:
            # Use GET with a Range header as a HEAD equivalent: request only
            # the first byte so we get response headers without the body.
            headers = self.headers.copy()
            headers["Range"] = "bytes=0-0"
            response = self.http_handler.get(url, headers=headers)
            response.raise_for_status()
            return {
                "content_type": response.headers.get("content-type"),
                "content_length": response.headers.get("content-length"),
                "last_modified": response.headers.get("last-modified"),
            }
        except Exception as e:
            if hasattr(e, "response") and hasattr(e.response, "status_code"):
                if e.response.status_code == 404:
                    return None
                raise Exception(f"Failed to get file metadata for '{file_path}': {e}")
            else:
                raise Exception(f"Error getting file metadata for '{file_path}': {e}")

    def close(self):
        """Close the HTTP handler to free resources."""
        if hasattr(self, "http_handler"):
            self.http_handler.close()

View File

@@ -0,0 +1,584 @@
"""
BitBucket prompt manager that integrates with LiteLLM's prompt management system.
Fetches .prompt files from BitBucket repositories and provides team-based access control.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from jinja2 import DictLoader, Environment, select_autoescape
from litellm.integrations.custom_prompt_management import CustomPromptManagement
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
from litellm.integrations.prompt_management_base import (
PromptManagementBase,
PromptManagementClient,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams
from .bitbucket_client import BitBucketClient
class BitBucketPromptTemplate:
    """
    A single .prompt template fetched from BitBucket.

    Splits the parsed frontmatter metadata into well-known fields (model,
    temperature, max_tokens, input schema) and keeps every other metadata key
    in ``optional_params``.
    """

    def __init__(
        self,
        template_id: str,
        content: str,
        metadata: Dict[str, Any],
        model: Optional[str] = None,
    ):
        self.template_id = template_id
        self.content = content
        self.metadata = metadata
        # An explicitly supplied (truthy) model wins over the frontmatter one.
        self.model = model or metadata.get("model")
        self.temperature = metadata.get("temperature")
        self.max_tokens = metadata.get("max_tokens")
        self.input_schema = metadata.get("input", {}).get("schema", {})
        reserved_keys = ("model", "input", "content")
        self.optional_params = {
            key: value
            for key, value in metadata.items()
            if key not in reserved_keys
        }

    def __repr__(self):
        return f"BitBucketPromptTemplate(id='{self.template_id}', model='{self.model}')"
class BitBucketTemplateManager:
    """
    Manager for loading and rendering .prompt files from BitBucket repositories.
    Supports:
    - Fetching .prompt files from BitBucket repositories
    - Team-based access control through BitBucket permissions
    - YAML frontmatter for metadata
    - Handlebars-style templating (using Jinja2)
    - Input/output schema validation
    - Model configuration
    """
    def __init__(
        self,
        bitbucket_config: Dict[str, Any],
        prompt_id: Optional[str] = None,
    ):
        """
        Create the manager and, when ``prompt_id`` is given, eagerly fetch
        that prompt from BitBucket.

        Args:
            bitbucket_config: Connection settings passed to BitBucketClient.
            prompt_id: Optional prompt to pre-load (repo path without the
                ".prompt" extension).
        """
        self.bitbucket_config = bitbucket_config
        self.prompt_id = prompt_id
        # Cache of parsed templates, keyed by prompt id.
        self.prompts: Dict[str, BitBucketPromptTemplate] = {}
        self.bitbucket_client = BitBucketClient(bitbucket_config)
        self.jinja_env = Environment(
            loader=DictLoader({}),
            autoescape=select_autoescape(["html", "xml"]),
            # Use Handlebars-style delimiters to match Dotprompt spec
            variable_start_string="{{",
            variable_end_string="}}",
            block_start_string="{%",
            block_end_string="%}",
            comment_start_string="{#",
            comment_end_string="#}",
        )
        # Load prompts from BitBucket if prompt_id is provided
        if self.prompt_id:
            self._load_prompt_from_bitbucket(self.prompt_id)
    def _load_prompt_from_bitbucket(self, prompt_id: str) -> None:
        """Load a specific .prompt file from BitBucket.

        Fetches ``<prompt_id>.prompt`` from the configured repo/branch,
        parses it, and caches the result in ``self.prompts``. A missing file
        (get_file_content returns None on 404) is silently skipped; other
        failures are re-raised wrapped in a generic Exception.
        """
        try:
            # Fetch the .prompt file from BitBucket
            prompt_content = self.bitbucket_client.get_file_content(
                f"{prompt_id}.prompt"
            )
            if prompt_content:
                template = self._parse_prompt_file(prompt_content, prompt_id)
                self.prompts[prompt_id] = template
        except Exception as e:
            raise Exception(f"Failed to load prompt '{prompt_id}' from BitBucket: {e}")
    def _parse_prompt_file(
        self, content: str, prompt_id: str
    ) -> BitBucketPromptTemplate:
        """Parse a .prompt file content and extract metadata and template.

        Files may start with a ``---`` delimited YAML frontmatter block; the
        remainder is the template body. Files without frontmatter are treated
        as pure template content with empty metadata.
        """
        # Split frontmatter and content
        if content.startswith("---"):
            parts = content.split("---", 2)
            if len(parts) >= 3:
                frontmatter_str = parts[1].strip()
                template_content = parts[2].strip()
            else:
                # A lone "---" with no closing delimiter: treat the whole
                # file as template content.
                frontmatter_str = ""
                template_content = content
        else:
            frontmatter_str = ""
            template_content = content
        # Parse YAML frontmatter
        metadata: Dict[str, Any] = {}
        if frontmatter_str:
            try:
                import yaml
                metadata = yaml.safe_load(frontmatter_str) or {}
            except ImportError:
                # Fallback to basic parsing if PyYAML is not available
                metadata = self._parse_yaml_basic(frontmatter_str)
            except Exception:
                # Malformed YAML: degrade to empty metadata rather than fail.
                metadata = {}
        return BitBucketPromptTemplate(
            template_id=prompt_id,
            content=template_content,
            metadata=metadata,
        )
    def _parse_yaml_basic(self, yaml_str: str) -> Dict[str, Any]:
        """Basic YAML parser for simple cases when PyYAML is not available.

        Handles only flat ``key: value`` pairs with bool/int/float/string
        scalars. NOTE(review): a value like "1.2.3" passes the digit check
        below but raises in float() -- confirm frontmatter stays simple.
        """
        result: Dict[str, Any] = {}
        for line in yaml_str.split("\n"):
            line = line.strip()
            if ":" in line and not line.startswith("#"):
                key, value = line.split(":", 1)
                key = key.strip()
                value = value.strip()
                # Try to parse value as appropriate type
                if value.lower() in ["true", "false"]:
                    result[key] = value.lower() == "true"
                elif value.isdigit():
                    result[key] = int(value)
                elif value.replace(".", "").isdigit():
                    result[key] = float(value)
                else:
                    # Strip surrounding quotes from plain string values.
                    result[key] = value.strip("\"'")
        return result
    def render_template(
        self, template_id: str, variables: Optional[Dict[str, Any]] = None
    ) -> str:
        """Render a template with the given variables.

        Raises:
            ValueError: if the template id has not been loaded.
        """
        if template_id not in self.prompts:
            raise ValueError(f"Template '{template_id}' not found")
        template = self.prompts[template_id]
        jinja_template = self.jinja_env.from_string(template.content)
        return jinja_template.render(**(variables or {}))
    def get_template(self, template_id: str) -> Optional[BitBucketPromptTemplate]:
        """Get a template by ID, or None if it has not been loaded."""
        return self.prompts.get(template_id)
    def list_templates(self) -> List[str]:
        """List all available template IDs."""
        return list(self.prompts.keys())
class BitBucketPromptManager(CustomPromptManagement):
"""
BitBucket prompt manager that integrates with LiteLLM's prompt management system.
This class enables using .prompt files from BitBucket repositories with the
litellm completion() function by implementing the PromptManagementBase interface.
Usage:
# Configure BitBucket access
bitbucket_config = {
"workspace": "your-workspace",
"repository": "your-repo",
"access_token": "your-token",
"branch": "main" # optional, defaults to main
}
# Use with completion
response = litellm.completion(
model="bitbucket/gpt-4",
prompt_id="my_prompt",
prompt_variables={"variable": "value"},
bitbucket_config=bitbucket_config,
messages=[{"role": "user", "content": "This will be combined with the prompt"}]
)
"""
    def __init__(
        self,
        bitbucket_config: Dict[str, Any],
        prompt_id: Optional[str] = None,
    ):
        """
        Store the BitBucket configuration and optional prompt id.

        The underlying BitBucketTemplateManager is created lazily on first
        access of the ``prompt_manager`` property, so constructing this class
        performs no network I/O.
        """
        self.bitbucket_config = bitbucket_config
        self.prompt_id = prompt_id
        # Lazily built by the `prompt_manager` property.
        self._prompt_manager: Optional[BitBucketTemplateManager] = None
    @property
    def integration_name(self) -> str:
        """Integration name used in model names like 'bitbucket/gpt-4'.

        Presumably LiteLLM routes prompt management to this class based on
        this prefix (see the README example) -- confirm against the caller.
        """
        return "bitbucket"
@property
def prompt_manager(self) -> BitBucketTemplateManager:
"""Get or create the prompt manager instance."""
if self._prompt_manager is None:
self._prompt_manager = BitBucketTemplateManager(
bitbucket_config=self.bitbucket_config,
prompt_id=self.prompt_id,
)
return self._prompt_manager
def get_prompt_template(
self,
prompt_id: str,
prompt_variables: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
"""
Get a prompt template and render it with variables.
Args:
prompt_id: The ID of the prompt template
prompt_variables: Variables to substitute in the template
Returns:
Tuple of (rendered_prompt, metadata)
"""
template = self.prompt_manager.get_template(prompt_id)
if not template:
raise ValueError(f"Prompt template '{prompt_id}' not found")
# Render the template
rendered_prompt = self.prompt_manager.render_template(
prompt_id, prompt_variables or {}
)
# Extract metadata
metadata = {
"model": template.model,
"temperature": template.temperature,
"max_tokens": template.max_tokens,
**template.optional_params,
}
return rendered_prompt, metadata
    def pre_call_hook(
        self,
        user_id: Optional[str],
        messages: List[AllMessageValues],
        function_call: Optional[Union[Dict[str, Any], str]] = None,
        litellm_params: Optional[Dict[str, Any]] = None,
        prompt_id: Optional[str] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Tuple[List[AllMessageValues], Optional[Dict[str, Any]]]:
        """
        Pre-call hook that processes the prompt template before making the LLM call.

        Renders the prompt identified by ``prompt_id`` with
        ``prompt_variables``, converts it into chat messages, and copies the
        model plus sampling parameters from the prompt's frontmatter into
        ``litellm_params``. On any failure the original ``messages`` and
        ``litellm_params`` are returned unchanged -- this hook never fails
        the call.
        """
        if not prompt_id:
            return messages, litellm_params
        try:
            # Get the rendered prompt and metadata
            rendered_prompt, prompt_metadata = self.get_prompt_template(
                prompt_id, prompt_variables
            )
            # Parse the rendered prompt into messages
            parsed_messages = self._parse_prompt_to_messages(rendered_prompt)
            # Merge with existing messages
            if parsed_messages:
                # If we have parsed messages, use them instead of the original messages
                final_messages: List[AllMessageValues] = parsed_messages
            else:
                # If no messages were parsed, prepend the prompt to existing messages
                final_messages = [
                    {"role": "user", "content": rendered_prompt}  # type: ignore
                ] + messages
            # Update litellm_params with prompt metadata
            if litellm_params is None:
                litellm_params = {}
            # Apply model and parameters from prompt metadata
            if prompt_metadata.get("model"):
                litellm_params["model"] = prompt_metadata["model"]
            # Only a whitelisted set of sampling params is copied over.
            for param in [
                "temperature",
                "max_tokens",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
            ]:
                if param in prompt_metadata:
                    litellm_params[param] = prompt_metadata[param]
            return final_messages, litellm_params
        except Exception as e:
            # Log error but don't fail the call
            import litellm
            litellm._logging.verbose_proxy_logger.error(
                f"Error in BitBucket prompt pre_call_hook: {e}"
            )
            return messages, litellm_params
def _parse_prompt_to_messages(self, prompt_content: str) -> List[AllMessageValues]:
"""
Parse prompt content into a list of messages.
Handles both simple prompts and multi-role conversations.
"""
messages = []
lines = prompt_content.strip().split("\n")
current_role = None
current_content = []
for line in lines:
line = line.strip()
if not line:
continue
# Check for role indicators
if line.lower().startswith("system:"):
if current_role and current_content:
messages.append(
{
"role": current_role,
"content": "\n".join(current_content).strip(),
} # type: ignore
)
current_role = "system"
current_content = [line[7:].strip()] # Remove "System:" prefix
elif line.lower().startswith("user:"):
if current_role and current_content:
messages.append(
{
"role": current_role,
"content": "\n".join(current_content).strip(),
} # type: ignore
)
current_role = "user"
current_content = [line[5:].strip()] # Remove "User:" prefix
elif line.lower().startswith("assistant:"):
if current_role and current_content:
messages.append(
{
"role": current_role,
"content": "\n".join(current_content).strip(),
} # type: ignore
)
current_role = "assistant"
current_content = [line[10:].strip()] # Remove "Assistant:" prefix
else:
# Continue building current message
current_content.append(line)
# Add the last message
if current_role and current_content:
messages.append(
{"role": current_role, "content": "\n".join(current_content).strip()}
)
# If no role indicators found, treat as a single user message
if not messages and prompt_content.strip():
messages = [{"role": "user", "content": prompt_content.strip()}] # type: ignore
return messages # type: ignore
    def post_call_hook(
        self,
        user_id: Optional[str],
        response: Any,
        input_messages: List[AllMessageValues],
        function_call: Optional[Union[Dict[str, Any], str]] = None,
        litellm_params: Optional[Dict[str, Any]] = None,
        prompt_id: Optional[str] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """
        Post-call hook for any post-processing after the LLM call.

        BitBucket prompt management needs no post-processing, so the response
        is returned unchanged.
        """
        return response
def get_available_prompts(self) -> List[str]:
"""Get list of available prompt IDs."""
return self.prompt_manager.list_templates()
def reload_prompts(self) -> None:
"""Reload prompts from BitBucket."""
if self.prompt_id:
self._prompt_manager = None # Reset to force reload
self.prompt_manager # This will trigger reload
    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """
        Determine if prompt management should run based on the prompt_id.

        Returns True only when a prompt_id is supplied; the actual prompt
        loading happens later in ``_compile_prompt_helper``. (The other
        arguments are unused here but required by the interface.)
        """
        return prompt_id is not None
    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Compile a BitBucket prompt template into a PromptManagementClient structure.
        This method:
        1. Loads the prompt template from BitBucket
        2. Renders it with the provided variables
        3. Converts the rendered text into chat messages
        4. Extracts model and optional parameters from metadata

        Raises:
            ValueError: if prompt_id is missing, or (wrapping the underlying
                error) if loading/rendering fails.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for BitBucket prompt manager")
        try:
            # Load the prompt from BitBucket if not already loaded
            if prompt_id not in self.prompt_manager.prompts:
                self.prompt_manager._load_prompt_from_bitbucket(prompt_id)
            # Get the rendered prompt and metadata
            rendered_prompt, prompt_metadata = self.get_prompt_template(
                prompt_id, prompt_variables
            )
            # Convert rendered content to chat messages
            messages = self._parse_prompt_to_messages(rendered_prompt)
            # Extract model from metadata (if specified)
            template_model = prompt_metadata.get("model")
            # Extract optional parameters from metadata
            # (same sampling-param whitelist as pre_call_hook)
            optional_params = {}
            for param in [
                "temperature",
                "max_tokens",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
            ]:
                if param in prompt_metadata:
                    optional_params[param] = prompt_metadata[param]
            return PromptManagementClient(
                prompt_id=prompt_id,
                prompt_template=messages,
                prompt_template_model=template_model,
                prompt_template_optional_params=optional_params,
                completed_messages=None,
            )
        except Exception as e:
            raise ValueError(f"Error compiling prompt '{prompt_id}': {e}")
async def async_compile_prompt_helper(
self,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
prompt_spec: Optional[PromptSpec] = None,
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
) -> PromptManagementClient:
"""
Async version of compile prompt helper. Since BitBucket operations use sync client,
this simply delegates to the sync version.
"""
if prompt_id is None:
raise ValueError("prompt_id is required for BitBucket prompt manager")
return self._compile_prompt_helper(
prompt_id=prompt_id,
prompt_spec=prompt_spec,
prompt_variables=prompt_variables,
dynamic_callback_params=dynamic_callback_params,
prompt_label=prompt_label,
prompt_version=prompt_version,
)
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Get chat completion prompt from BitBucket and return processed model, messages, and parameters.

        Delegates to ``PromptManagementBase.get_chat_completion_prompt``.

        NOTE(review): ``ignore_prompt_manager_model`` and
        ``ignore_prompt_manager_optional_params`` are accepted here but NOT
        forwarded to the base call, while the async counterpart does forward
        them - confirm whether the sync base supports these flags and, if so,
        pass them through.
        """
        return PromptManagementBase.get_chat_completion_prompt(
            self,
            model,
            messages,
            non_default_params,
            prompt_id,
            prompt_variables,
            dynamic_callback_params,
            prompt_spec=prompt_spec,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )
async def async_get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
litellm_logging_obj: LiteLLMLoggingObj,
prompt_spec: Optional[PromptSpec] = None,
tools: Optional[List[Dict]] = None,
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
ignore_prompt_manager_model: Optional[bool] = False,
ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
"""
Async version - delegates to PromptManagementBase async implementation.
"""
return await PromptManagementBase.async_get_chat_completion_prompt(
self,
model,
messages,
non_default_params,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
litellm_logging_obj=litellm_logging_obj,
dynamic_callback_params=dynamic_callback_params,
prompt_spec=prompt_spec,
tools=tools,
prompt_label=prompt_label,
prompt_version=prompt_version,
ignore_prompt_manager_model=ignore_prompt_manager_model,
ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
)

View File

@@ -0,0 +1,422 @@
# What is this?
## Log success + failure events to Braintrust
import os
from datetime import datetime
from typing import Dict, Optional
import httpx
import litellm
from litellm import verbose_logger
from litellm.integrations.braintrust_mock_client import (
should_use_braintrust_mock,
create_mock_braintrust_client,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.utils import print_verbose
API_BASE = "https://api.braintrustdata.com/v1"
def get_utc_datetime() -> datetime:
    """Return the current time in UTC as a timezone-aware ``datetime``.

    The previous implementation returned an *aware* datetime on Python 3.11+
    (via ``dt.UTC``) but a *naive* one on older versions (``datetime.utcnow``),
    so arithmetic mixing the two could raise ``TypeError``. ``timezone.utc``
    exists on every supported Python and always yields an aware value;
    ``datetime.utcnow`` is also deprecated since Python 3.12.
    """
    import datetime as dt

    return datetime.now(dt.timezone.utc)
class BraintrustLogger(CustomLogger):
    """Log LiteLLM success events to Braintrust.

    Projects are resolved by name (and created on demand) through the
    Braintrust REST API; resolved IDs are cached per-instance. Spans are
    written via ``POST {api_base}/project_logs/{project_id}/insert``. When
    mock mode is enabled (``BRAINTRUST_MOCK``), the HTTP handlers are
    monkey-patched so no real network calls are made.
    """

    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> None:
        """Set up credentials, HTTP handlers, and the project-ID cache.

        Args:
            api_key: Braintrust API key; falls back to ``BRAINTRUST_API_KEY``.
            api_base: API base URL; falls back to ``BRAINTRUST_API_BASE``,
                then the public default.

        Raises:
            Exception: If no API key is available (see validate_environment).
        """
        super().__init__()
        # Mock mode must be wired up before any request is made so the
        # patched transports are used for every call below.
        self.is_mock_mode = should_use_braintrust_mock()
        if self.is_mock_mode:
            create_mock_braintrust_client()
            verbose_logger.info(
                "[BRAINTRUST MOCK] Braintrust logger initialized in mock mode"
            )
        self.validate_environment(api_key=api_key)
        self.api_base = api_base or os.getenv("BRAINTRUST_API_BASE") or API_BASE
        # Created lazily on the first log call when metadata names no project.
        self.default_project_id = None
        self.api_key: str = api_key or os.getenv("BRAINTRUST_API_KEY")  # type: ignore
        self.headers = {
            "Authorization": "Bearer " + self.api_key,
            "Content-Type": "application/json",
        }
        self._project_id_cache: Dict[
            str, str
        ] = {}  # Cache mapping project names to IDs
        self.global_braintrust_http_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.global_braintrust_sync_http_handler = HTTPHandler()

    def validate_environment(self, api_key: Optional[str]):
        """
        Expects
        BRAINTRUST_API_KEY

        in the environment
        """
        missing_keys = []
        if api_key is None and os.getenv("BRAINTRUST_API_KEY", None) is None:
            missing_keys.append("BRAINTRUST_API_KEY")
        if len(missing_keys) > 0:
            raise Exception("Missing keys={} in environment.".format(missing_keys))

    def get_project_id_sync(self, project_name: str) -> str:
        """
        Get project ID from name, using cache if available.
        If project doesn't exist, creates it.
        """
        if project_name in self._project_id_cache:
            return self._project_id_cache[project_name]

        try:
            # NOTE(review): the async counterpart posts to /project/register
            # while this posts to /project - confirm which endpoint is
            # intended; the two paths should probably agree.
            response = self.global_braintrust_sync_http_handler.post(
                f"{self.api_base}/project",
                headers=self.headers,
                json={"name": project_name},
            )
            project_dict = response.json()
            project_id = project_dict["id"]
            self._project_id_cache[project_name] = project_id
            return project_id
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")

    async def get_project_id_async(self, project_name: str) -> str:
        """
        Async version of get_project_id_sync
        """
        if project_name in self._project_id_cache:
            return self._project_id_cache[project_name]

        try:
            response = await self.global_braintrust_http_handler.post(
                f"{self.api_base}/project/register",
                headers=self.headers,
                json={"name": project_name},
            )
            project_dict = response.json()
            project_id = project_dict["id"]
            self._project_id_cache[project_name] = project_id
            return project_id
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")

    async def create_default_project_and_experiment(self):
        """Create (or fetch) the fallback 'litellm' project and remember its ID."""
        project = await self.global_braintrust_http_handler.post(
            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
        )

        project_dict = project.json()

        self.default_project_id = project_dict["id"]

    def create_sync_default_project_and_experiment(self):
        """Sync version of create_default_project_and_experiment."""
        project = self.global_braintrust_sync_http_handler.post(
            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
        )

        project_dict = project.json()

        self.default_project_id = project_dict["id"]

    def log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        """Synchronous success handler: build a Braintrust span from the call
        metadata and POST it to ``/project_logs/{project_id}/insert``.
        """
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            standard_logging_object = kwargs.get("standard_logging_object", {})
            prompt = {"messages": kwargs.get("messages")}
            output = None
            choices = []
            # Derive `output`/`choices` from the response type. Embeddings
            # produce no loggable output.
            if response_obj is not None and (
                kwargs.get("call_type", None) == "embedding"
                or isinstance(response_obj, litellm.EmbeddingResponse)
            ):
                output = None
            elif response_obj is not None and isinstance(
                response_obj, litellm.ModelResponse
            ):
                output = response_obj["choices"][0]["message"].json()
                choices = response_obj["choices"]
            elif response_obj is not None and isinstance(
                response_obj, litellm.TextCompletionResponse
            ):
                output = response_obj.choices[0].text
                choices = response_obj.choices
            elif response_obj is not None and isinstance(
                response_obj, litellm.ImageResponse
            ):
                output = response_obj["data"]

            litellm_params = kwargs.get("litellm_params", {}) or {}
            dynamic_metadata = litellm_params.get("metadata", {}) or {}

            # Get project_id from metadata or create default if needed
            project_id = dynamic_metadata.get("project_id")
            if project_id is None:
                project_name = dynamic_metadata.get("project_name")
                project_id = (
                    self.get_project_id_sync(project_name) if project_name else None
                )

            if project_id is None:
                if self.default_project_id is None:
                    self.create_sync_default_project_and_experiment()
                project_id = self.default_project_id

            tags = []
            if isinstance(dynamic_metadata, dict):
                for key, value in dynamic_metadata.items():
                    # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
                    if (
                        litellm.langfuse_default_tags is not None
                        and isinstance(litellm.langfuse_default_tags, list)
                        and key in litellm.langfuse_default_tags
                    ):
                        tags.append(f"{key}:{value}")

                    if (
                        isinstance(value, str) and key not in standard_logging_object
                    ):  # support logging dynamic metadata to braintrust
                        standard_logging_object[key] = value

            cost = kwargs.get("response_cost", None)

            metrics: Optional[dict] = None
            usage_obj = getattr(response_obj, "usage", None)
            if usage_obj and isinstance(usage_obj, litellm.Usage):
                # NOTE(review): the return value of get_logging_id is
                # discarded here - this looks like dead code; confirm.
                litellm.utils.get_logging_id(start_time, response_obj)
                metrics = {
                    "prompt_tokens": usage_obj.prompt_tokens,
                    "completion_tokens": usage_obj.completion_tokens,
                    "total_tokens": usage_obj.total_tokens,
                    "total_cost": cost,
                    # NOTE(review): this is total request latency, not
                    # time-to-first-token; the async handler derives TTFT from
                    # api_call_start_time/completion_start_time - confirm intent.
                    "time_to_first_token": end_time.timestamp()
                    - start_time.timestamp(),
                    "start": start_time.timestamp(),
                    "end": end_time.timestamp(),
                }

            # Allow metadata override for span name
            span_name = dynamic_metadata.get("span_name", "Chat Completion")

            # Span parents is a special case
            span_parents = dynamic_metadata.get("span_parents")
            # Convert comma-separated string to list if present
            # NOTE(review): assumes a string value; a list here would break
            # .split - confirm what callers pass.
            if span_parents:
                span_parents = [s.strip() for s in span_parents.split(",") if s.strip()]

            # Add optional span attributes only if present
            span_attributes = {
                "span_id": dynamic_metadata.get("span_id"),
                "root_span_id": dynamic_metadata.get("root_span_id"),
                "span_parents": span_parents,
            }

            request_data = {
                "id": litellm_call_id,
                "input": prompt["messages"],
                "metadata": standard_logging_object,
                "span_attributes": {"name": span_name, "type": "llm"},
            }

            # Braintrust cannot specify 'tags' for non-root spans
            if dynamic_metadata.get("root_span_id") is None:
                request_data["tags"] = tags

            # Only add those that are not None (or falsy)
            for key, value in span_attributes.items():
                if value:
                    request_data[key] = value

            # NOTE(review): `choices` is initialized to [] and is never None,
            # so this branch always runs; for ImageResponse the `output`
            # computed above is overwritten with an empty list - confirm
            # whether `if choices:` was intended.
            if choices is not None:
                request_data["output"] = [choice.dict() for choice in choices]
            else:
                request_data["output"] = output

            if metrics is not None:
                request_data["metrics"] = metrics

            try:
                print_verbose(
                    f"self.global_braintrust_sync_http_handler.post: {self.global_braintrust_sync_http_handler.post}"
                )
                self.global_braintrust_sync_http_handler.post(
                    url=f"{self.api_base}/project_logs/{project_id}/insert",
                    json={"events": [request_data]},
                    headers=self.headers,
                )
                if self.is_mock_mode:
                    print_verbose("[BRAINTRUST MOCK] Sync event successfully mocked")
            except httpx.HTTPStatusError as e:
                raise Exception(e.response.text)
        except Exception as e:
            raise e  # don't use verbose_logger.exception, if exception is raised

    async def async_log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        """Async success handler: mirrors :meth:`log_success_event` but uses
        the async HTTP handler and derives time-to-first-token from the
        api-call/completion timestamps when available.
        """
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            standard_logging_object = kwargs.get("standard_logging_object", {})
            prompt = {"messages": kwargs.get("messages")}
            output = None
            choices = []
            if response_obj is not None and (
                kwargs.get("call_type", None) == "embedding"
                or isinstance(response_obj, litellm.EmbeddingResponse)
            ):
                output = None
            elif response_obj is not None and isinstance(
                response_obj, litellm.ModelResponse
            ):
                output = response_obj["choices"][0]["message"].json()
                choices = response_obj["choices"]
            elif response_obj is not None and isinstance(
                response_obj, litellm.TextCompletionResponse
            ):
                output = response_obj.choices[0].text
                choices = response_obj.choices
            elif response_obj is not None and isinstance(
                response_obj, litellm.ImageResponse
            ):
                output = response_obj["data"]

            # NOTE(review): unlike the sync handler this has no `or {}` guard,
            # so an explicit None value for "litellm_params" would raise on
            # the .get below - confirm and align with the sync path.
            litellm_params = kwargs.get("litellm_params", {})
            dynamic_metadata = litellm_params.get("metadata", {}) or {}

            # Get project_id from metadata or create default if needed
            project_id = dynamic_metadata.get("project_id")
            if project_id is None:
                project_name = dynamic_metadata.get("project_name")
                project_id = (
                    await self.get_project_id_async(project_name)
                    if project_name
                    else None
                )

            if project_id is None:
                if self.default_project_id is None:
                    await self.create_default_project_and_experiment()
                project_id = self.default_project_id

            tags = []
            if isinstance(dynamic_metadata, dict):
                for key, value in dynamic_metadata.items():
                    # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
                    if (
                        litellm.langfuse_default_tags is not None
                        and isinstance(litellm.langfuse_default_tags, list)
                        and key in litellm.langfuse_default_tags
                    ):
                        tags.append(f"{key}:{value}")

                    if (
                        isinstance(value, str) and key not in standard_logging_object
                    ):  # support logging dynamic metadata to braintrust
                        standard_logging_object[key] = value

            cost = kwargs.get("response_cost", None)

            metrics: Optional[dict] = None
            usage_obj = getattr(response_obj, "usage", None)
            if usage_obj and isinstance(usage_obj, litellm.Usage):
                # NOTE(review): the return value of get_logging_id is
                # discarded here - this looks like dead code; confirm.
                litellm.utils.get_logging_id(start_time, response_obj)
                metrics = {
                    "prompt_tokens": usage_obj.prompt_tokens,
                    "completion_tokens": usage_obj.completion_tokens,
                    "total_tokens": usage_obj.total_tokens,
                    "total_cost": cost,
                    "start": start_time.timestamp(),
                    "end": end_time.timestamp(),
                }
                # TTFT is only computable when both the API-call start and the
                # first-completion timestamps were recorded.
                api_call_start_time = kwargs.get("api_call_start_time")
                completion_start_time = kwargs.get("completion_start_time")

                if (
                    api_call_start_time is not None
                    and completion_start_time is not None
                ):
                    metrics["time_to_first_token"] = (
                        completion_start_time.timestamp()
                        - api_call_start_time.timestamp()
                    )

            # Allow metadata override for span name
            span_name = dynamic_metadata.get("span_name", "Chat Completion")

            # Span parents is a special case
            span_parents = dynamic_metadata.get("span_parents")
            # Convert comma-separated string to list if present
            # NOTE(review): assumes a string value; a list here would break
            # .split - confirm what callers pass.
            if span_parents:
                span_parents = [s.strip() for s in span_parents.split(",") if s.strip()]

            # Add optional span attributes only if present
            span_attributes = {
                "span_id": dynamic_metadata.get("span_id"),
                "root_span_id": dynamic_metadata.get("root_span_id"),
                "span_parents": span_parents,
            }

            request_data = {
                "id": litellm_call_id,
                "input": prompt["messages"],
                "output": output,
                "metadata": standard_logging_object,
                "span_attributes": {"name": span_name, "type": "llm"},
            }

            # Braintrust cannot specify 'tags' for non-root spans
            if dynamic_metadata.get("root_span_id") is None:
                request_data["tags"] = tags

            # Only add those that are not None (or falsy)
            for key, value in span_attributes.items():
                if value:
                    request_data[key] = value

            # NOTE(review): `choices` is initialized to [] and is never None,
            # so this branch always runs; for ImageResponse the `output` set
            # above is overwritten with an empty list - confirm whether
            # `if choices:` was intended.
            if choices is not None:
                request_data["output"] = [choice.dict() for choice in choices]
            else:
                request_data["output"] = output

            if metrics is not None:
                request_data["metrics"] = metrics

            try:
                await self.global_braintrust_http_handler.post(
                    url=f"{self.api_base}/project_logs/{project_id}/insert",
                    json={"events": [request_data]},
                    headers=self.headers,
                )
                if self.is_mock_mode:
                    print_verbose("[BRAINTRUST MOCK] Async event successfully mocked")
            except httpx.HTTPStatusError as e:
                raise Exception(e.response.text)
        except Exception as e:
            raise e  # don't use verbose_logger.exception, if exception is raised

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # No Braintrust-specific failure handling; defer to the base class.
        return super().log_failure_event(kwargs, response_obj, start_time, end_time)

View File

@@ -0,0 +1,166 @@
"""
Mock HTTP client for Braintrust integration testing.
This module intercepts Braintrust API calls and returns successful mock responses,
allowing full code execution without making actual network calls.
Usage:
Set BRAINTRUST_MOCK=true in environment variables or config to enable mock mode.
"""
import os
import time
from urllib.parse import urlparse
from litellm._logging import verbose_logger
from litellm.integrations.mock_client_factory import (
MockClientConfig,
MockResponse,
create_mock_client_factory,
)
# Use factory for should_use_mock and MockResponse
# Braintrust uses both HTTPHandler (sync) and AsyncHTTPHandler (async)
# Braintrust needs endpoint-specific responses, so we use custom HTTPHandler.post patching
_config = MockClientConfig(
    "BRAINTRUST",
    "BRAINTRUST_MOCK",
    default_latency_ms=100,
    default_status_code=200,
    default_json_data={"id": "mock-project-id", "status": "success"},
    # Hosts whose requests the factory-level mock should intercept.
    url_matchers=[
        ".braintrustdata.com",
        "braintrustdata.com",
        ".braintrust.dev",
        "braintrust.dev",
    ],
    patch_async_handler=True,  # Patch AsyncHTTPHandler.post for async calls
    patch_sync_client=False,  # HTTPHandler uses self.client.send(), not self.client.post()
    patch_http_handler=False,  # We use custom patching for endpoint-specific responses
)

# Get should_use_mock and create_mock_client from factory
# We need to call the factory's create_mock_client to patch AsyncHTTPHandler.post
(
    create_mock_braintrust_factory_client,
    should_use_braintrust_mock,
) = create_mock_client_factory(_config)

# Store original HTTPHandler.post method (Braintrust-specific for sync calls with custom logic)
# Populated once by create_mock_braintrust_client() before patching.
_original_http_handler_post = None
_mocks_initialized = False

# Default mock latency in seconds
# (BRAINTRUST_MOCK_LATENCY_MS is given in milliseconds; default 100ms.)
_MOCK_LATENCY_SECONDS = float(os.getenv("BRAINTRUST_MOCK_LATENCY_MS", "100")) / 1000.0
def _is_braintrust_url(url: str) -> bool:
    """Return True when *url* points at a Braintrust API host.

    Matches the apex domains ``braintrustdata.com`` / ``braintrust.dev`` and
    any of their subdomains; non-string or unparseable inputs are rejected.
    """
    if not isinstance(url, str):
        return False
    host = (urlparse(url).hostname or "").lower()
    if not host:
        return False
    braintrust_domains = ("braintrustdata.com", "braintrust.dev")
    return any(
        host == domain or host.endswith("." + domain)
        for domain in braintrust_domains
    )
def _mock_http_handler_post(
    self,
    url,
    data=None,
    json=None,
    params=None,
    headers=None,
    timeout=None,
    stream=False,
    files=None,
    content=None,
    logging_obj=None,
):
    """Monkey-patched ``HTTPHandler.post`` intercepting Braintrust calls.

    Braintrust URLs receive an endpoint-specific :class:`MockResponse`;
    every other URL is forwarded to the original (pre-patch) implementation.

    Bug fix: the log-insert URL (``.../project_logs/<id>/insert``) also
    contains the substring ``"/project"``, so the previous ``"/project"``
    check shadowed the ``"/project_logs"`` branch and log inserts wrongly
    got a project payload. The more specific check now runs first.
    """
    # Only mock Braintrust API calls
    if isinstance(url, str) and _is_braintrust_url(url):
        verbose_logger.info(f"[BRAINTRUST MOCK] POST to {url}")
        time.sleep(_MOCK_LATENCY_SECONDS)  # simulate network latency

        # Return appropriate mock response based on endpoint. Check the more
        # specific "/project_logs" path first - it also contains "/project".
        if "/project_logs" in url:
            # Log insertion endpoint
            mock_data = {"status": "success"}
        elif "/project" in url:
            # Project creation/retrieval/register endpoint
            project_name = json.get("name", "litellm") if json else "litellm"
            mock_data = {"id": f"mock-project-id-{project_name}", "name": project_name}
        else:
            mock_data = _config.default_json_data

        return MockResponse(
            status_code=_config.default_status_code,
            json_data=mock_data,
            url=url,
            elapsed_seconds=_MOCK_LATENCY_SECONDS,
        )

    # Not a Braintrust URL: delegate to the real implementation.
    if _original_http_handler_post is not None:
        return _original_http_handler_post(
            self,
            url=url,
            data=data,
            json=json,
            params=params,
            headers=headers,
            timeout=timeout,
            stream=stream,
            files=files,
            content=content,
            logging_obj=logging_obj,
        )
    raise RuntimeError("Original HTTPHandler.post not available")
def create_mock_braintrust_client() -> None:
    """
    Monkey-patch HTTPHandler.post to intercept Braintrust sync calls.

    Braintrust uses HTTPHandler for sync calls and AsyncHTTPHandler for async calls.
    HTTPHandler.post uses self.client.send(), not self.client.post(), so we need
    custom patching for sync (similar to Helicone).
    AsyncHTTPHandler.post is patched by the factory.

    We use custom patching instead of factory's patch_http_handler because we need
    endpoint-specific responses (different for /project vs /project_logs).

    This function is idempotent - it only initializes mocks once, even if called multiple times.
    """
    global _original_http_handler_post, _mocks_initialized

    # Idempotence guard: repeated calls are no-ops after first initialization.
    if _mocks_initialized:
        return

    verbose_logger.debug("[BRAINTRUST MOCK] Initializing Braintrust mock client...")

    # Imported here to avoid importing the HTTP layer unless mocking is active.
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    # Save the original once so the mock can delegate non-Braintrust URLs;
    # never overwrite it with an already-patched version.
    if _original_http_handler_post is None:
        _original_http_handler_post = HTTPHandler.post
        HTTPHandler.post = _mock_http_handler_post  # type: ignore
        verbose_logger.debug("[BRAINTRUST MOCK] Patched HTTPHandler.post")

    # CRITICAL: Call the factory's initialization function to patch AsyncHTTPHandler.post
    # This is required for async calls to be mocked
    create_mock_braintrust_factory_client()

    verbose_logger.debug(
        f"[BRAINTRUST MOCK] Mock latency set to {_MOCK_LATENCY_SECONDS*1000:.0f}ms"
    )
    verbose_logger.debug(
        "[BRAINTRUST MOCK] Braintrust mock client initialization complete"
    )

    _mocks_initialized = True

View File

@@ -0,0 +1,458 @@
[
{
"id": "arize",
"displayName": "Arize",
"logo": "arize.png",
"supports_key_team_logging": true,
"dynamic_params": {
"arize_api_key": {
"type": "password",
"ui_name": "API Key",
"description": "Arize API key for authentication",
"required": true
},
"arize_space_id": {
"type": "password",
"ui_name": "Space ID",
"description": "Arize Space ID to identify your workspace",
"required": true
}
},
"description": "Arize Logging Integration"
},
{
"id": "braintrust",
"displayName": "Braintrust",
"logo": "braintrust.png",
"supports_key_team_logging": false,
"dynamic_params": {
"braintrust_api_key": {
"type": "password",
"ui_name": "API Key",
"description": "Braintrust API key for authentication",
"required": true
},
"braintrust_project_name": {
"type": "text",
"ui_name": "Project Name",
"description": "Name of the Braintrust project to log to",
"required": true
}
},
"description": "Braintrust Logging Integration"
},
{
"id": "generic_api",
"displayName": "Custom Callback API",
"logo": "custom.svg",
"supports_key_team_logging": true,
"dynamic_params": {
"GENERIC_LOGGER_ENDPOINT": {
"type": "text",
"ui_name": "Callback URL",
"description": "Your custom webhook/API endpoint URL to receive logs",
"required": true
},
"GENERIC_LOGGER_HEADERS": {
"type": "text",
"ui_name": "Headers",
"description": "Custom HTTP headers as a comma-separated string (e.g., Authorization: Bearer token, Content-Type: application/json)",
"required": false
}
},
"description": "Custom Callback API Logging Integration"
},
{
"id": "datadog",
"displayName": "Datadog",
"logo": "datadog.png",
"supports_key_team_logging": false,
"dynamic_params": {
"dd_api_key": {
"type": "password",
"ui_name": "API Key",
"description": "Datadog API key for authentication",
"required": true
},
"dd_site": {
"type": "text",
"ui_name": "Site",
"description": "Datadog site URL (e.g., us5.datadoghq.com)",
"required": true
}
},
"description": "Datadog Logging Integration"
},
{
"id": "datadog_metrics",
"displayName": "Datadog Metrics",
"logo": "datadog.png",
"supports_key_team_logging": false,
"dynamic_params": {
"dd_api_key": {
"type": "password",
"ui_name": "API Key",
"description": "Datadog API key for authentication",
"required": true
},
"dd_site": {
"type": "text",
"ui_name": "Site",
"description": "Datadog site URL (e.g., us5.datadoghq.com)",
"required": true
}
},
"description": "Datadog Custom Metrics Integration"
},
{
"id": "datadog_cost_management",
"displayName": "Datadog Cost Management",
"logo": "datadog.png",
"supports_key_team_logging": false,
"dynamic_params": {
"dd_api_key": {
"type": "password",
"ui_name": "API Key",
"description": "Datadog API key for authentication",
"required": true
},
"dd_app_key": {
"type": "password",
"ui_name": "App Key",
"description": "Datadog Application Key for Cloud Cost Management",
"required": true
},
"dd_site": {
"type": "text",
"ui_name": "Site",
"description": "Datadog site URL (e.g., us5.datadoghq.com)",
"required": true
}
},
"description": "Datadog Cloud Cost Management Integration"
},
{
"id": "lago",
"displayName": "Lago",
"logo": "lago.svg",
"supports_key_team_logging": false,
"dynamic_params": {
"lago_api_url": {
"type": "text",
"ui_name": "API URL",
"description": "Lago API base URL",
"required": true
},
"lago_api_key": {
"type": "password",
"ui_name": "API Key",
"description": "Lago API key for authentication",
"required": true
}
},
"description": "Lago Billing Logging Integration"
},
{
"id": "langfuse",
"displayName": "Langfuse",
"logo": "langfuse.png",
"supports_key_team_logging": true,
"dynamic_params": {
"langfuse_public_key": {
"type": "text",
"ui_name": "Public Key",
"description": "Langfuse public key",
"required": true
},
"langfuse_secret_key": {
"type": "password",
"ui_name": "Secret Key",
"description": "Langfuse secret key for authentication",
"required": true
},
"langfuse_host": {
"type": "text",
"ui_name": "Host URL",
"description": "Langfuse host URL (default: https://cloud.langfuse.com)",
"required": false
}
},
"description": "Langfuse v2 Logging Integration"
},
{
"id": "langfuse_otel",
"displayName": "Langfuse OTEL",
"logo": "langfuse.png",
"supports_key_team_logging": true,
"dynamic_params": {
"langfuse_public_key": {
"type": "text",
"ui_name": "Public Key",
"description": "Langfuse public key",
"required": true
},
"langfuse_secret_key": {
"type": "password",
"ui_name": "Secret Key",
"description": "Langfuse secret key for authentication",
"required": true
},
"langfuse_host": {
"type": "text",
"ui_name": "Host URL",
"description": "Langfuse host URL (default: https://cloud.langfuse.com)",
"required": false
}
},
"description": "Langfuse v3 OTEL Logging Integration"
},
{
"id": "langsmith",
"displayName": "LangSmith",
"logo": "langsmith.png",
"supports_key_team_logging": true,
"dynamic_params": {
"langsmith_api_key": {
"type": "password",
"ui_name": "API Key",
"description": "LangSmith API key for authentication",
"required": true
},
"langsmith_project": {
"type": "text",
"ui_name": "Project Name",
"description": "LangSmith project name (default: litellm-completion)",
"required": false
},
"langsmith_base_url": {
"type": "text",
"ui_name": "Base URL",
"description": "LangSmith base URL (default: https://api.smith.langchain.com)",
"required": false
},
"langsmith_sampling_rate": {
"type": "number",
"ui_name": "Sampling Rate",
"description": "Sampling rate for logging (0.0 to 1.0, default: 1.0)",
"required": false
},
"langsmith_tenant_id": {
"type": "text",
"ui_name": "Tenant ID",
"description": "LangSmith tenant ID for organization-scoped API keys (required when using org-scoped keys)",
"required": false
}
},
"description": "Langsmith Logging Integration"
},
{
"id": "openmeter",
"displayName": "OpenMeter",
"logo": "openmeter.png",
"supports_key_team_logging": false,
"dynamic_params": {
"openmeter_api_key": {
"type": "password",
"ui_name": "API Key",
"description": "OpenMeter API key for authentication",
"required": true
},
"openmeter_base_url": {
"type": "text",
"ui_name": "Base URL",
"description": "OpenMeter base URL (default: https://openmeter.cloud)",
"required": false
}
},
"description": "OpenMeter Logging Integration"
},
{
"id": "otel",
"displayName": "Open Telemetry",
"logo": "otel.png",
"supports_key_team_logging": false,
"dynamic_params": {
"otel_endpoint": {
"type": "text",
"ui_name": "Endpoint URL",
"description": "OpenTelemetry collector endpoint URL",
"required": true
},
"otel_headers": {
"type": "text",
"ui_name": "Headers",
"description": "Headers for OTEL exporter (e.g., x-honeycomb-team=YOUR_API_KEY)",
"required": false
}
},
"description": "OpenTelemetry Logging Integration"
},
{
"id": "s3",
"displayName": "S3",
"logo": "aws.svg",
"supports_key_team_logging": false,
"dynamic_params": {
"s3_bucket_name": {
"type": "text",
"ui_name": "Bucket Name",
"description": "AWS S3 bucket name to store logs",
"required": true
},
"s3_region_name": {
"type": "text",
"ui_name": "AWS Region",
"description": "AWS region name (e.g., us-east-1)",
"required": false
},
"s3_aws_access_key_id": {
"type": "password",
"ui_name": "AWS Access Key ID",
"description": "AWS access key ID for authentication",
"required": false
},
"s3_aws_secret_access_key": {
"type": "password",
"ui_name": "AWS Secret Access Key",
"description": "AWS secret access key for authentication",
"required": false
},
"s3_aws_session_token": {
"type": "password",
"ui_name": "AWS Session Token",
"description": "AWS session token for temporary credentials",
"required": false
},
"s3_endpoint_url": {
"type": "text",
"ui_name": "S3 Endpoint URL",
"description": "Custom S3 endpoint URL (for MinIO or custom S3-compatible services)",
"required": false
},
"s3_path": {
"type": "text",
"ui_name": "S3 Path Prefix",
"description": "Path prefix within the bucket for organizing logs",
"required": false
}
},
"description": "S3 Bucket (AWS) Logging Integration"
},
{
"id": "sqs",
"displayName": "SQS",
"logo": "aws.svg",
"supports_key_team_logging": false,
"dynamic_params": {
"sqs_queue_url": {
"type": "text",
"ui_name": "Queue URL",
"description": "AWS SQS Queue URL",
"required": true
},
"sqs_region_name": {
"type": "text",
"ui_name": "AWS Region",
"description": "AWS region name (e.g., us-east-1)",
"required": false
},
"sqs_aws_access_key_id": {
"type": "password",
"ui_name": "AWS Access Key ID",
"description": "AWS access key ID for authentication",
"required": false
},
"sqs_aws_secret_access_key": {
"type": "password",
"ui_name": "AWS Secret Access Key",
"description": "AWS secret access key for authentication",
"required": false
},
"sqs_aws_session_token": {
"type": "password",
"ui_name": "AWS Session Token",
"description": "AWS session token for temporary credentials",
"required": false
},
"sqs_aws_session_name": {
"type": "text",
"ui_name": "AWS Session Name",
"description": "Name for AWS session",
"required": false
},
"sqs_aws_profile_name": {
"type": "text",
"ui_name": "AWS Profile Name",
"description": "AWS profile name from credentials file",
"required": false
},
"sqs_aws_role_name": {
"type": "text",
"ui_name": "AWS Role Name",
"description": "AWS IAM role name to assume",
"required": false
},
"sqs_aws_web_identity_token": {
"type": "password",
"ui_name": "AWS Web Identity Token",
"description": "AWS web identity token for authentication",
"required": false
},
"sqs_aws_sts_endpoint": {
"type": "text",
"ui_name": "AWS STS Endpoint",
"description": "AWS STS endpoint URL",
"required": false
},
"sqs_endpoint_url": {
"type": "text",
"ui_name": "SQS Endpoint URL",
"description": "Custom SQS endpoint URL (for LocalStack or custom endpoints)",
"required": false
},
"sqs_api_version": {
"type": "text",
"ui_name": "API Version",
"description": "SQS API version",
"required": false
},
"sqs_use_ssl": {
"type": "boolean",
"ui_name": "Use SSL",
"description": "Whether to use SSL for SQS connections",
"required": false
},
"sqs_verify": {
"type": "boolean",
"ui_name": "Verify SSL",
"description": "Whether to verify SSL certificates",
"required": false
},
"sqs_strip_base64_files": {
"type": "boolean",
"ui_name": "Strip Base64 Files",
"description": "Remove base64-encoded files from logs to reduce payload size",
"required": false
},
"sqs_aws_use_application_level_encryption": {
"type": "boolean",
"ui_name": "Use Application-Level Encryption",
"description": "Enable application-level encryption for SQS messages",
"required": false
},
"sqs_app_encryption_key_b64": {
"type": "password",
"ui_name": "Encryption Key (Base64)",
"description": "Base64-encoded encryption key for application-level encryption",
"required": false
},
"sqs_app_encryption_aad": {
"type": "text",
"ui_name": "Encryption AAD",
"description": "Additional authenticated data for encryption",
"required": false
}
},
"description": "SQS Queue (AWS) Logging Integration"
}
]

View File

@@ -0,0 +1,422 @@
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, cast
import litellm
from litellm._logging import verbose_logger
from litellm.constants import CLOUDZERO_EXPORT_INTERVAL_MINUTES
from litellm.integrations.custom_logger import CustomLogger
if TYPE_CHECKING:
from apscheduler.schedulers.asyncio import AsyncIOScheduler
else:
AsyncIOScheduler = Any
class CloudZeroLogger(CustomLogger):
    """
    CloudZero Logger for exporting LiteLLM usage data to CloudZero AnyCost API.

    Reads usage rows from the LiteLLM database, transforms them into CloudZero's
    CBF format, and streams them to the AnyCost API on a recurring schedule.

    Environment Variables:
        CLOUDZERO_API_KEY: CloudZero API key for authentication
        CLOUDZERO_CONNECTION_ID: CloudZero connection ID for data submission
        CLOUDZERO_TIMEZONE: Timezone for date handling (default: UTC)
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        connection_id: Optional[str] = None,
        timezone: Optional[str] = None,
        **kwargs,
    ):
        """Initialize CloudZero logger with configuration from parameters or environment variables."""
        super().__init__(**kwargs)
        # Explicit constructor arguments win; environment variables are the fallback.
        self.api_key = api_key or os.getenv("CLOUDZERO_API_KEY")
        self.connection_id = connection_id or os.getenv("CLOUDZERO_CONNECTION_ID")
        self.timezone = timezone or os.getenv("CLOUDZERO_TIMEZONE", "UTC")
        verbose_logger.debug(
            f"CloudZero Logger initialized with connection ID: {self.connection_id}, timezone: {self.timezone}"
        )

    async def initialize_cloudzero_export_job(self):
        """
        Handler for initializing CloudZero export job.

        Runs when CloudZero logger starts up.

        - If redis cache is available, we use the pod lock manager to acquire a lock and export the data.
            - Ensures only one pod exports the data at a time.
        - If redis cache is not available, we export the data directly.
        """
        from litellm.constants import (
            CLOUDZERO_EXPORT_USAGE_DATA_JOB_NAME,
        )
        from litellm.proxy.proxy_server import proxy_logging_obj

        pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager

        # if using redis, ensure only one pod exports the data at a time
        if pod_lock_manager and pod_lock_manager.redis_cache:
            if await pod_lock_manager.acquire_lock(
                cronjob_id=CLOUDZERO_EXPORT_USAGE_DATA_JOB_NAME
            ):
                try:
                    await self._hourly_usage_data_export()
                finally:
                    # Always release the lock, even when the export raises.
                    await pod_lock_manager.release_lock(
                        cronjob_id=CLOUDZERO_EXPORT_USAGE_DATA_JOB_NAME
                    )
        else:
            # if not using redis, export the data directly
            await self._hourly_usage_data_export()

    async def _hourly_usage_data_export(self):
        """
        Exports the hourly usage data to CloudZero.

        Start time: 2x the export interval ago (overlap is intentional)
        End time: current time
        """
        from datetime import timedelta, timezone

        from litellm.constants import CLOUDZERO_MAX_FETCHED_DATA_RECORDS

        current_time_utc = datetime.now(timezone.utc)
        # Mitigates the possibility of missing spend if an hour is skipped due to a restart in an ephemeral environment
        one_hour_ago_utc = current_time_utc - timedelta(
            minutes=CLOUDZERO_EXPORT_INTERVAL_MINUTES * 2
        )
        await self.export_usage_data(
            limit=CLOUDZERO_MAX_FETCHED_DATA_RECORDS,
            operation="replace_hourly",
            start_time_utc=one_hour_ago_utc,
            end_time_utc=current_time_utc,
        )

    async def export_usage_data(
        self,
        limit: Optional[int] = None,
        operation: str = "replace_hourly",
        start_time_utc: Optional[datetime] = None,
        end_time_utc: Optional[datetime] = None,
    ):
        """
        Exports the usage data to CloudZero.

        - Reads data from the DB
        - Transforms the data to the CloudZero format
        - Sends the data to CloudZero

        Args:
            limit: Optional limit on number of records to export
            operation: CloudZero operation type ("replace_hourly" or "sum")
            start_time_utc: Optional lower bound (inclusive) on record update time
            end_time_utc: Optional upper bound (inclusive) on record update time

        Raises:
            ValueError: if API key / connection ID configuration is missing.
        """
        from litellm.integrations.cloudzero.cz_stream_api import CloudZeroStreamer
        from litellm.integrations.cloudzero.database import LiteLLMDatabase
        from litellm.integrations.cloudzero.transform import CBFTransformer

        try:
            verbose_logger.debug("CloudZero Logger: Starting usage data export")

            # Validate required configuration
            if not self.api_key or not self.connection_id:
                raise ValueError(
                    "CloudZero configuration missing. Please set CLOUDZERO_API_KEY and CLOUDZERO_CONNECTION_ID environment variables."
                )

            # Initialize database connection and load data
            database = LiteLLMDatabase()
            verbose_logger.debug("CloudZero Logger: Loading usage data from database")
            data = await database.get_usage_data(
                limit=limit, start_time_utc=start_time_utc, end_time_utc=end_time_utc
            )

            if data.is_empty():
                verbose_logger.debug("CloudZero Logger: No usage data found to export")
                return

            verbose_logger.debug(f"CloudZero Logger: Processing {len(data)} records")

            # Transform data to CloudZero CBF format
            transformer = CBFTransformer()
            cbf_data = transformer.transform(data)

            if cbf_data.is_empty():
                verbose_logger.warning(
                    "CloudZero Logger: No valid data after transformation"
                )
                return

            # Send data to CloudZero
            streamer = CloudZeroStreamer(
                api_key=self.api_key,
                connection_id=self.connection_id,
                user_timezone=self.timezone,
            )

            verbose_logger.debug(
                f"CloudZero Logger: Transmitting {len(cbf_data)} records to CloudZero"
            )
            # NOTE(review): send_batched performs synchronous HTTP I/O inside this
            # async method — confirm this is acceptable for the event loop.
            streamer.send_batched(cbf_data, operation=operation)

            verbose_logger.debug(
                f"CloudZero Logger: Successfully exported {len(cbf_data)} records to CloudZero"
            )
        except Exception as e:
            verbose_logger.error(
                f"CloudZero Logger: Error exporting usage data: {str(e)}"
            )
            raise

    async def dry_run_export_usage_data(self, limit: Optional[int] = 10000):
        """
        Returns the data that would be exported to CloudZero without actually sending it.

        Args:
            limit: Limit number of records to display (default: 10000)

        Returns:
            dict: Contains usage_data, cbf_data, and summary statistics
        """
        from litellm.integrations.cloudzero.database import LiteLLMDatabase
        from litellm.integrations.cloudzero.transform import CBFTransformer

        try:
            verbose_logger.debug("CloudZero Logger: Starting dry run export")

            # Initialize database connection and load data
            database = LiteLLMDatabase()
            verbose_logger.debug("CloudZero Logger: Loading usage data for dry run")
            data = await database.get_usage_data(limit=limit)

            if data.is_empty():
                verbose_logger.warning("CloudZero Dry Run: No usage data found")
                return {
                    "usage_data": [],
                    "cbf_data": [],
                    "summary": {
                        "total_records": 0,
                        "total_cost": 0,
                        "total_tokens": 0,
                        "unique_accounts": 0,
                        "unique_services": 0,
                    },
                }

            verbose_logger.debug(
                f"CloudZero Dry Run: Processing {len(data)} records..."
            )

            # Convert usage data to dict format for response
            usage_data_sample = data.head(50).to_dicts()  # Return first 50 rows

            # Transform data to CloudZero CBF format
            transformer = CBFTransformer()
            cbf_data = transformer.transform(data)

            if cbf_data.is_empty():
                verbose_logger.warning(
                    "CloudZero Dry Run: No valid data after transformation"
                )
                return {
                    "usage_data": usage_data_sample,
                    "cbf_data": [],
                    "summary": {
                        "total_records": len(usage_data_sample),
                        "total_cost": sum(
                            row.get("spend", 0) for row in usage_data_sample
                        ),
                        "total_tokens": sum(
                            row.get("prompt_tokens", 0)
                            + row.get("completion_tokens", 0)
                            for row in usage_data_sample
                        ),
                        "unique_accounts": 0,
                        "unique_services": 0,
                    },
                }

            # Convert CBF data to dict format for response
            cbf_data_dict = cbf_data.to_dicts()

            # Calculate summary statistics over the full transformed dataset
            total_cost = sum(record.get("cost/cost", 0) for record in cbf_data_dict)
            unique_accounts = len(
                set(
                    record.get("resource/account", "")
                    for record in cbf_data_dict
                    if record.get("resource/account")
                )
            )
            unique_services = len(
                set(
                    record.get("resource/service", "")
                    for record in cbf_data_dict
                    if record.get("resource/service")
                )
            )
            total_tokens = sum(
                record.get("usage/amount", 0) for record in cbf_data_dict
            )

            verbose_logger.debug(
                f"CloudZero Logger: Dry run completed for {len(cbf_data)} records"
            )
            return {
                "usage_data": usage_data_sample,
                "cbf_data": cbf_data_dict,
                "summary": {
                    "total_records": len(cbf_data_dict),
                    "total_cost": total_cost,
                    "total_tokens": total_tokens,
                    "unique_accounts": unique_accounts,
                    "unique_services": unique_services,
                },
            }
        except Exception as e:
            # Single structured error log before re-raising (previously logged twice).
            verbose_logger.error(f"CloudZero Logger: Error in dry run export: {str(e)}")
            raise

    def _display_cbf_data_on_screen(self, cbf_data):
        """Display CBF transformed data in a formatted table on screen."""
        from rich.box import SIMPLE
        from rich.console import Console
        from rich.table import Table

        console = Console()
        if cbf_data.is_empty():
            console.print("[yellow]No CBF data to display[/yellow]")
            return

        console.print(
            f"\n[bold green]💰 CloudZero CBF Transformed Data ({len(cbf_data)} records)[/bold green]"
        )

        # Convert to dicts for easier processing
        records = cbf_data.to_dicts()

        # Create main CBF table
        cbf_table = Table(
            show_header=True, header_style="bold cyan", box=SIMPLE, padding=(0, 1)
        )
        cbf_table.add_column("time/usage_start", style="blue", no_wrap=False)
        cbf_table.add_column("cost/cost", style="green", justify="right", no_wrap=False)
        cbf_table.add_column(
            "entity_type", style="magenta", justify="right", no_wrap=False
        )
        cbf_table.add_column(
            "entity_id", style="magenta", justify="right", no_wrap=False
        )
        cbf_table.add_column("team_id", style="cyan", no_wrap=False)
        cbf_table.add_column("team_alias", style="cyan", no_wrap=False)
        cbf_table.add_column("user_email", style="cyan", no_wrap=False)
        cbf_table.add_column("api_key_alias", style="yellow", no_wrap=False)
        cbf_table.add_column(
            "usage/amount", style="yellow", justify="right", no_wrap=False
        )
        cbf_table.add_column("resource/id", style="magenta", no_wrap=False)
        cbf_table.add_column("resource/service", style="cyan", no_wrap=False)
        cbf_table.add_column("resource/account", style="white", no_wrap=False)
        cbf_table.add_column("resource/region", style="dim", no_wrap=False)

        for record in records:
            # Use proper CBF field names
            time_usage_start = str(record.get("time/usage_start", "N/A"))
            cost_cost = str(record.get("cost/cost", 0))
            usage_amount = str(record.get("usage/amount", 0))
            resource_id = str(record.get("resource/id", "N/A"))
            resource_service = str(record.get("resource/service", "N/A"))
            resource_account = str(record.get("resource/account", "N/A"))
            resource_region = str(record.get("resource/region", "N/A"))
            entity_type = str(record.get("entity_type", "N/A"))
            entity_id = str(record.get("entity_id", "N/A"))
            team_id = str(record.get("resource/tag:team_id", "N/A"))
            team_alias = str(record.get("resource/tag:team_alias", "N/A"))
            user_email = str(record.get("resource/tag:user_email", "N/A"))
            api_key_alias = str(record.get("resource/tag:api_key_alias", "N/A"))
            cbf_table.add_row(
                time_usage_start,
                cost_cost,
                entity_type,
                entity_id,
                team_id,
                team_alias,
                user_email,
                api_key_alias,
                usage_amount,
                resource_id,
                resource_service,
                resource_account,
                resource_region,
            )

        console.print(cbf_table)

        # Show summary statistics
        total_cost = sum(record.get("cost/cost", 0) for record in records)
        unique_accounts = len(
            set(
                record.get("resource/account", "")
                for record in records
                if record.get("resource/account")
            )
        )
        unique_services = len(
            set(
                record.get("resource/service", "")
                for record in records
                if record.get("resource/service")
            )
        )
        # Count total tokens from usage metrics
        total_tokens = sum(record.get("usage/amount", 0) for record in records)

        console.print("\n[bold blue]📊 CBF Summary[/bold blue]")
        console.print(f"  Records: {len(records):,}")
        console.print(f"  Total Cost: ${total_cost:.2f}")
        console.print(f"  Total Tokens: {total_tokens:,}")
        console.print(f"  Unique Accounts: {unique_accounts}")
        console.print(f"  Unique Services: {unique_services}")
        console.print(
            "\n[dim]💡 This is the CloudZero CBF format ready for AnyCost ingestion[/dim]"
        )

    @staticmethod
    async def init_cloudzero_background_job(scheduler: AsyncIOScheduler):
        """
        Initialize the CloudZero background job.

        Starts the background job that exports the usage data to CloudZero every hour.
        """
        from litellm.constants import CLOUDZERO_EXPORT_INTERVAL_MINUTES
        from litellm.integrations.custom_logger import CustomLogger

        # Fetch the initialized CloudZero logger instance(s) registered on the
        # callback manager so the scheduled export runs on a configured logger.
        cloudzero_loggers: List[
            CustomLogger
        ] = litellm.logging_callback_manager.get_custom_loggers_for_type(
            callback_type=CloudZeroLogger
        )

        verbose_logger.debug("found %s cloudzero loggers", len(cloudzero_loggers))
        if len(cloudzero_loggers) > 0:
            cloudzero_logger = cast(CloudZeroLogger, cloudzero_loggers[0])
            verbose_logger.debug(
                "Scheduling CloudZero usage export as a cron job executing every %s minutes"
                % CLOUDZERO_EXPORT_INTERVAL_MINUTES
            )
            scheduler.add_job(
                cloudzero_logger.initialize_cloudzero_export_job,
                "interval",
                minutes=CLOUDZERO_EXPORT_INTERVAL_MINUTES,
            )

View File

@@ -0,0 +1,161 @@
# Copyright 2025 CloudZero
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# CHANGELOG: 2025-01-19 - Initial CZRN module for CloudZero Resource Names (erik.peterson)
"""CloudZero Resource Names (CZRN) generation and validation for LiteLLM resources."""
import re
from enum import Enum
from typing import Any, cast
import litellm
class CZEntityType(str, Enum):
    """Entity types that can own a CloudZero cost record (currently teams only)."""

    TEAM = "team"
class CZRNGenerator:
    """Generate CloudZero Resource Names (CZRNs) for LiteLLM resources."""

    # CZRN layout: czrn:<service-type>:<provider>:<region>:<owner-account-id>:<resource-type>:<cloud-local-id>
    # Only the provider component (group 2) may contain uppercase letters; the
    # cloud-local-id (last group) is free-form and only required to be non-empty.
    CZRN_REGEX = re.compile(
        r"^czrn:([a-z0-9-]+):([a-zA-Z0-9-]+):([a-z0-9-]+):([a-z0-9-]+):([a-z0-9-]+):(.+)$"
    )

    def __init__(self):
        """Initialize CZRN generator."""
        pass

    def create_from_litellm_data(self, row: dict[str, Any]) -> str:
        """Create a CZRN from LiteLLM daily spend data.

        CZRN format: czrn:<service-type>:<provider>:<region>:<owner-account-id>:<resource-type>:<cloud-local-id>

        For LiteLLM resources, we map:
        - service-type: 'litellm' (the service managing the LLM calls)
        - provider: The custom_llm_provider (e.g., 'openai', 'anthropic', 'azure')
        - region: 'cross-region' (LiteLLM operates across regions)
        - owner-account-id: The team_id or user_id (entity_id)
        - resource-type: 'llm-usage' (represents LLM usage/inference)
        - cloud-local-id: model
        """
        service_type = "litellm"
        provider = self._normalize_provider(row.get("custom_llm_provider", "unknown"))
        region = "cross-region"
        # Use the actual entity_id (team_id or user_id) as the owner account
        team_id = row.get("team_id", "unknown")
        owner_account_id = self._normalize_component(team_id)
        resource_type = "llm-usage"
        # Create a unique identifier with just the model (entity info already in owner_account_id)
        model = row.get("model", "unknown")
        cloud_local_id = model
        return self.create_from_components(
            service_type=service_type,
            provider=provider,
            region=region,
            owner_account_id=owner_account_id,
            resource_type=resource_type,
            cloud_local_id=cloud_local_id,
        )

    def create_from_components(
        self,
        service_type: str,
        provider: str,
        region: str,
        owner_account_id: str,
        resource_type: str,
        cloud_local_id: str,
    ) -> str:
        """Create a CZRN from individual components.

        Raises:
            ValueError: if the assembled CZRN fails validation against CZRN_REGEX.
        """
        # Normalize components to ensure they meet CZRN requirements.
        # NOTE(review): allow_uppercase=True here conflicts with CZRN_REGEX, which
        # only accepts lowercase in the service-type position — an uppercase
        # service_type would fail the validation below. Confirm intended behavior.
        service_type = self._normalize_component(service_type, allow_uppercase=True)
        provider = self._normalize_component(provider)
        region = self._normalize_component(region)
        owner_account_id = self._normalize_component(owner_account_id)
        resource_type = self._normalize_component(resource_type)
        # cloud_local_id can contain pipes and other characters, so don't normalize it
        czrn = f"czrn:{service_type}:{provider}:{region}:{owner_account_id}:{resource_type}:{cloud_local_id}"
        if not self.is_valid(czrn):
            raise ValueError(f"Generated CZRN is invalid: {czrn}")
        return czrn

    def is_valid(self, czrn: str) -> bool:
        """Validate a CZRN string against the standard format."""
        return bool(self.CZRN_REGEX.match(czrn))

    def extract_components(self, czrn: str) -> tuple[str, str, str, str, str, str]:
        """Extract all components from a CZRN.

        Returns: (service_type, provider, region, owner_account_id, resource_type, cloud_local_id)

        Raises:
            ValueError: if the string does not match the CZRN format.
        """
        match = self.CZRN_REGEX.match(czrn)
        if not match:
            raise ValueError(f"Invalid CZRN format: {czrn}")
        return cast(tuple[str, str, str, str, str, str], match.groups())

    def _normalize_provider(self, provider: str) -> str:
        """Normalize provider names to standard CZRN format."""
        # Map common provider names to CZRN standards
        provider_map = {
            litellm.LlmProviders.AZURE.value: "azure",
            litellm.LlmProviders.AZURE_AI.value: "azure",
            litellm.LlmProviders.ANTHROPIC.value: "anthropic",
            litellm.LlmProviders.BEDROCK.value: "aws",
            litellm.LlmProviders.VERTEX_AI.value: "gcp",
            litellm.LlmProviders.GEMINI.value: "google",
            litellm.LlmProviders.COHERE.value: "cohere",
            litellm.LlmProviders.HUGGINGFACE.value: "huggingface",
            litellm.LlmProviders.REPLICATE.value: "replicate",
            litellm.LlmProviders.TOGETHER_AI.value: "together-ai",
        }
        # Look up the raw lowercased provider first: the map keys are LiteLLM
        # provider identifiers which may contain underscores (e.g. 'azure_ai',
        # 'vertex_ai'). Hyphenating before the lookup would make those entries
        # unreachable.
        raw = provider.lower()
        if raw in provider_map:
            return provider_map[raw]
        # Unknown providers fall back to a hyphenated, lowercased form.
        return raw.replace("_", "-")

    def _normalize_component(
        self, component: str, allow_uppercase: bool = False
    ) -> str:
        """Normalize a CZRN component to meet format requirements.

        Lowercases (unless allow_uppercase), maps invalid characters to hyphens,
        collapses hyphen runs, strips edge hyphens, and falls back to 'unknown'
        for empty results.
        """
        if not component:
            return "unknown"
        # Convert to lowercase unless uppercase is allowed
        if not allow_uppercase:
            component = component.lower()
        # Replace invalid characters with hyphens
        component = re.sub(r"[^a-zA-Z0-9-]", "-", component)
        # Remove consecutive hyphens
        component = re.sub(r"-+", "-", component)
        # Remove leading/trailing hyphens
        component = component.strip("-")
        return component or "unknown"

View File

@@ -0,0 +1,278 @@
# Copyright 2025 CloudZero
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# CHANGELOG: 2025-01-19 - Added pathlib for filesystem operations (erik.peterson)
# CHANGELOG: 2025-01-19 - Migrated from pandas to polars and requests to httpx (erik.peterson)
# CHANGELOG: 2025-01-19 - Initial output module for CSV and CloudZero API (erik.peterson)
"""Output modules for writing CBF data to various destinations."""
import zoneinfo
from datetime import datetime, timezone
from typing import Any, Optional, Union
import httpx
import polars as pl
from rich.console import Console
class CloudZeroStreamer:
    """Stream CBF data to CloudZero AnyCost API with proper batching and timezone handling."""

    def __init__(
        self, api_key: str, connection_id: str, user_timezone: Optional[str] = None
    ):
        """Initialize CloudZero streamer with credentials.

        Args:
            api_key: CloudZero API key, sent as a bearer token.
            connection_id: AnyCost connection to stream billing drops into.
            user_timezone: IANA timezone name assumed for naive timestamps;
                falls back to UTC when missing or unknown.
        """
        self.api_key = api_key
        self.connection_id = connection_id
        self.base_url = "https://api.cloudzero.com"
        self.console = Console()
        # Set timezone - default to UTC
        self.user_timezone: Union[zoneinfo.ZoneInfo, timezone]
        if user_timezone:
            try:
                self.user_timezone = zoneinfo.ZoneInfo(user_timezone)
            except zoneinfo.ZoneInfoNotFoundError:
                self.console.print(
                    f"[yellow]Warning: Unknown timezone '{user_timezone}', using UTC[/yellow]"
                )
                self.user_timezone = timezone.utc
        else:
            self.user_timezone = timezone.utc

    def send_batched(
        self, data: pl.DataFrame, operation: str = "replace_hourly"
    ) -> None:
        """Send CBF data in daily batches to CloudZero AnyCost API."""
        if data.is_empty():
            self.console.print("[yellow]No data to send to CloudZero[/yellow]")
            return

        # Group data by date and send each day as a batch
        daily_batches = self._group_by_date(data)
        if not daily_batches:
            self.console.print("[yellow]No valid daily batches to send[/yellow]")
            return

        self.console.print(
            f"[blue]Sending {len(daily_batches)} daily batch(es) with operation '{operation}'[/blue]"
        )
        for batch_date, batch_data in daily_batches.items():
            self._send_daily_batch(batch_date, batch_data, operation)

    def _group_by_date(self, data: pl.DataFrame) -> dict[str, pl.DataFrame]:
        """Group data by UTC date, skipping rows whose timestamp cannot be parsed."""
        daily_batches: dict[str, list[dict[str, Any]]] = {}

        # Ensure we have the required columns
        if "time/usage_start" not in data.columns:
            self.console.print(
                "[red]Error: Missing 'time/usage_start' column for date grouping[/red]"
            )
            return {}

        timestamp_str: Optional[str] = None
        for row in data.iter_rows(named=True):
            try:
                # Parse the timestamp and convert to UTC
                timestamp_str = row.get("time/usage_start")
                if not timestamp_str:
                    continue
                dt = self._parse_and_convert_timestamp(timestamp_str)
                batch_date = dt.strftime("%Y-%m-%d")
                if batch_date not in daily_batches:
                    daily_batches[batch_date] = []
                daily_batches[batch_date].append(row)
            except Exception as e:
                # Bad timestamps are reported and skipped; one bad row must not
                # abort the whole export.
                self.console.print(
                    f"[yellow]Warning: Could not process timestamp '{timestamp_str}': {e}[/yellow]"
                )
                continue

        # Convert lists back to DataFrames
        return {
            date_key: pl.DataFrame(records)
            for date_key, records in daily_batches.items()
            if records
        }

    def _parse_and_convert_timestamp(self, timestamp_str: str) -> datetime:
        """Parse an ISO-8601 timestamp string and convert it to UTC.

        A trailing 'Z' is treated as UTC; timestamps without any offset are
        assumed to be in the configured user timezone.

        Raises:
            ValueError: if the string is not a parseable ISO-8601 timestamp.
        """
        try:
            # fromisoformat() does not accept the 'Z' suffix on all supported
            # Python versions, so normalize it to an explicit UTC offset first.
            # Any other explicit offset (e.g. '+05:30', '-08:00') is handled
            # directly by fromisoformat().
            dt = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
            if dt.tzinfo is None:
                # No offset present: interpret in the user's timezone.
                dt = dt.replace(tzinfo=self.user_timezone)
            # Convert to UTC
            return dt.astimezone(timezone.utc)
        except ValueError as e:
            raise ValueError(f"Could not parse timestamp '{timestamp_str}': {e}")

    def _send_daily_batch(
        self, batch_date: str, batch_data: pl.DataFrame, operation: str
    ) -> None:
        """Send a single daily batch to CloudZero API.

        Raises:
            httpx.RequestError: on network failures.
            httpx.HTTPStatusError: on non-2xx responses.
        """
        if batch_data.is_empty():
            return

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # Use the correct API endpoint format from documentation
        url = f"{self.base_url}/v2/connections/billing/anycost/{self.connection_id}/billing_drops"

        # Prepare the batch payload according to AnyCost API format
        payload = self._prepare_batch_payload(batch_date, batch_data, operation)

        try:
            with httpx.Client(timeout=30.0) as client:
                self.console.print(
                    f"[blue]Sending batch for {batch_date} ({len(batch_data)} records)[/blue]"
                )
                response = client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                self.console.print(
                    f"[green]✓ Successfully sent batch for {batch_date} ({len(batch_data)} records)[/green]"
                )
        except httpx.RequestError as e:
            self.console.print(
                f"[red]✗ Network error sending batch for {batch_date}: {e}[/red]"
            )
            raise
        except httpx.HTTPStatusError as e:
            self.console.print(
                f"[red]✗ HTTP error sending batch for {batch_date}: {e.response.status_code} {e.response.text}[/red]"
            )
            raise

    def _prepare_batch_payload(
        self, batch_date: str, batch_data: pl.DataFrame, operation: str
    ) -> dict[str, Any]:
        """Prepare batch payload according to CloudZero AnyCost API format."""
        # Convert batch_date to month for the API (YYYY-MM format)
        try:
            date_obj = datetime.strptime(batch_date, "%Y-%m-%d")
            month_str = date_obj.strftime("%Y-%m")
        except ValueError:
            # Fallback to current month
            month_str = datetime.now().strftime("%Y-%m")

        # Convert DataFrame rows to API format, dropping rows that fail conversion
        data_records = []
        for row in batch_data.iter_rows(named=True):
            record = self._convert_cbf_to_api_format(row)
            if record:
                data_records.append(record)

        payload = {"month": month_str, "operation": operation, "data": data_records}
        return payload

    def _convert_cbf_to_api_format(
        self, row: dict[str, Any]
    ) -> Optional[dict[str, Any]]:
        """Convert CBF row to CloudZero API format - keeping CBF field names as CloudZero expects them.

        Returns None (after printing a warning) when conversion fails.
        """
        try:
            # CloudZero expects CBF format field names directly, not converted names
            api_record = {}
            # Copy all CBF fields, converting numeric values to strings as required by CloudZero
            for key, value in row.items():
                if value is not None:
                    # CloudZero requires numeric values to be strings, but NOT in scientific notation
                    if isinstance(value, (int, float)):
                        # Format floats to avoid scientific notation
                        if isinstance(value, float):
                            # Use a reasonable precision that avoids scientific notation
                            api_record[key] = f"{value:.10f}".rstrip("0").rstrip(".")
                        else:
                            api_record[key] = str(value)
                    else:
                        api_record[key] = value
            # Ensure timestamp is in UTC format
            if "time/usage_start" in api_record:
                api_record["time/usage_start"] = self._ensure_utc_timestamp(
                    api_record["time/usage_start"]
                )
            return api_record
        except Exception as e:
            self.console.print(
                f"[yellow]Warning: Could not convert record to API format: {e}[/yellow]"
            )
            return None

    def _ensure_utc_timestamp(self, timestamp_str: str) -> str:
        """Ensure timestamp is in UTC 'Z'-suffixed ISO format for the API."""
        if not timestamp_str:
            return datetime.now(timezone.utc).isoformat()
        try:
            dt = self._parse_and_convert_timestamp(timestamp_str)
            return dt.isoformat().replace("+00:00", "Z")
        except Exception:
            # Fallback to current time in UTC
            return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

View File

@@ -0,0 +1,101 @@
# Copyright 2025 CloudZero
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# CHANGELOG: 2025-01-19 - Refactored to use daily spend tables for proper CBF mapping (erik.peterson)
# CHANGELOG: 2025-01-19 - Migrated from pandas to polars for database operations (erik.peterson)
# CHANGELOG: 2025-01-19 - Initial database module for LiteLLM data extraction (erik.peterson)
"""Database connection and data extraction for LiteLLM."""
from datetime import datetime
from typing import Any, Optional, List
import polars as pl
class LiteLLMDatabase:
    """Handle LiteLLM PostgreSQL database connections and queries."""

    def _ensure_prisma_client(self):
        """Ensure prisma client is available and return it.

        Raises:
            Exception: if no database is connected to the proxy.
        """
        # Imported lazily so this module can load without a running proxy.
        # (The docstring previously sat after this import, making it a dead
        # no-op string instead of a real docstring.)
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise Exception(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )
        return prisma_client

    async def get_usage_data(
        self,
        limit: Optional[int] = None,
        start_time_utc: Optional[datetime] = None,
        end_time_utc: Optional[datetime] = None,
    ) -> pl.DataFrame:
        """Retrieve usage data from LiteLLM daily user spend table.

        Args:
            limit: Optional cap on the number of rows returned.
            start_time_utc: Optional inclusive lower bound on updated_at.
            end_time_utc: Optional inclusive upper bound on updated_at.

        Returns:
            A polars DataFrame of daily spend rows joined with key/team/user info.
        """
        client = self._ensure_prisma_client()

        # Query to get user spend data with team information. Use parameter binding to
        # avoid SQL injection from user-supplied timestamps or limits.
        query = """
            SELECT
                dus.id,
                dus.date,
                dus.user_id,
                dus.api_key,
                dus.model,
                dus.model_group,
                dus.custom_llm_provider,
                dus.prompt_tokens,
                dus.completion_tokens,
                dus.spend,
                dus.api_requests,
                dus.successful_requests,
                dus.failed_requests,
                dus.cache_creation_input_tokens,
                dus.cache_read_input_tokens,
                dus.created_at,
                dus.updated_at,
                vt.team_id,
                vt.key_alias as api_key_alias,
                tt.team_alias,
                ut.user_email as user_email
            FROM "LiteLLM_DailyUserSpend" dus
            LEFT JOIN "LiteLLM_VerificationToken" vt ON dus.api_key = vt.token
            LEFT JOIN "LiteLLM_TeamTable" tt ON vt.team_id = tt.team_id
            LEFT JOIN "LiteLLM_UserTable" ut ON dus.user_id = ut.user_id
            WHERE ($1::timestamptz IS NULL OR dus.updated_at >= $1::timestamptz)
              AND ($2::timestamptz IS NULL OR dus.updated_at <= $2::timestamptz)
            ORDER BY dus.date DESC, dus.created_at DESC
        """
        params: List[Any] = [
            start_time_utc,
            end_time_utc,
        ]
        if limit is not None:
            try:
                params.append(int(limit))
            except (TypeError, ValueError):
                raise ValueError("limit must be an integer")
            query += " LIMIT $3"

        try:
            db_response = await client.db.query_raw(query, *params)
            # Convert the response to polars DataFrame with full schema inference
            # This prevents schema mismatch errors when data types vary across rows
            return pl.DataFrame(db_response, infer_schema_length=None)
        except Exception as e:
            raise Exception(f"Error retrieving usage data: {str(e)}")

View File

@@ -0,0 +1,223 @@
# Copyright 2025 CloudZero
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# CHANGELOG: 2025-01-19 - Updated CBF transformation for daily spend tables and proper CloudZero mapping (erik.peterson)
# CHANGELOG: 2025-01-19 - Migrated from pandas to polars for data transformation (erik.peterson)
# CHANGELOG: 2025-01-19 - Initial CBF transformation module (erik.peterson)
"""Transform LiteLLM data to CloudZero AnyCost CBF format."""
from datetime import datetime
from typing import Any, Optional
import polars as pl
from ...types.integrations.cloudzero import CBFRecord
from .cz_resource_names import CZEntityType, CZRNGenerator
class CBFTransformer:
"""Transform LiteLLM usage data to CloudZero Billing Format (CBF)."""
    def __init__(self):
        """Initialize transformer with CZRN generator."""
        # The CZRN generator builds the resource identifier for each CBF record.
        self.czrn_generator = CZRNGenerator()
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
"""Transform LiteLLM data to CBF format, dropping records with zero successful_requests or invalid CZRNs."""
if data.is_empty():
return pl.DataFrame()
# Filter out records with zero successful_requests first
original_count = len(data)
if "successful_requests" in data.columns:
filtered_data = data.filter(pl.col("successful_requests") > 0)
zero_requests_dropped = original_count - len(filtered_data)
else:
filtered_data = data
zero_requests_dropped = 0
cbf_data = []
czrn_dropped_count = 0
filtered_count = len(filtered_data)
for row in filtered_data.iter_rows(named=True):
try:
cbf_record = self._create_cbf_record(row)
# Only include the record if CZRN generation was successful
cbf_data.append(cbf_record)
except Exception:
# Skip records that fail CZRN generation
czrn_dropped_count += 1
continue
# Print summary of dropped records if any
from rich.console import Console
console = Console()
if zero_requests_dropped > 0:
console.print(
f"[yellow]⚠️ Dropped {zero_requests_dropped:,} of {original_count:,} records with zero successful_requests[/yellow]"
)
if czrn_dropped_count > 0:
console.print(
f"[yellow]⚠️ Dropped {czrn_dropped_count:,} of {filtered_count:,} filtered records due to invalid CZRNs[/yellow]"
)
if len(cbf_data) > 0:
console.print(
f"[green]✓ Successfully transformed {len(cbf_data):,} records[/green]"
)
return pl.DataFrame(cbf_data)
def _create_cbf_record(self, row: dict[str, Any]) -> CBFRecord:
"""Create a single CBF record from LiteLLM daily spend row."""
# Parse date (daily spend tables use date strings like '2025-04-19')
usage_date = self._parse_date(row.get("date"))
# Calculate total tokens
prompt_tokens = int(row.get("prompt_tokens", 0))
completion_tokens = int(row.get("completion_tokens", 0))
total_tokens = prompt_tokens + completion_tokens
# Create CloudZero Resource Name (CZRN) as resource_id
resource_id = self.czrn_generator.create_from_litellm_data(row)
# Build dimensions for CloudZero
model = str(row.get("model", ""))
api_key_hash = str(row.get("api_key", ""))[
:8
] # First 8 chars for identification
# Handle team information with fallbacks
team_id = row.get("team_id")
team_alias = row.get("team_alias")
user_email = row.get("user_email")
# Use team_alias if available, otherwise team_id, otherwise fallback to 'unknown'
entity_id = (
str(team_alias) if team_alias else (str(team_id) if team_id else "unknown")
)
# Get alias fields if they exist
api_key_alias = row.get("api_key_alias")
organization_alias = row.get("organization_alias")
project_alias = row.get("project_alias")
user_alias = row.get("user_alias")
dimensions = {
"entity_type": CZEntityType.TEAM.value,
"entity_id": entity_id,
"team_alias": str(team_alias) if team_alias else "unknown",
"model": model,
"model_group": str(row.get("model_group", "")),
"provider": str(row.get("custom_llm_provider", "")),
"api_key_prefix": api_key_hash,
"api_key_alias": str(row.get("api_key_alias", "")),
"user_email": str(user_email) if user_email else "",
"api_requests": str(row.get("api_requests", 0)),
"successful_requests": str(row.get("successful_requests", 0)),
"failed_requests": str(row.get("failed_requests", 0)),
"cache_creation_tokens": str(row.get("cache_creation_input_tokens", 0)),
"cache_read_tokens": str(row.get("cache_read_input_tokens", 0)),
"organization_alias": str(organization_alias) if organization_alias else "",
"project_alias": str(project_alias) if project_alias else "",
"user_alias": str(user_alias) if user_alias else "",
}
# Extract CZRN components to populate corresponding CBF columns
czrn_components = self.czrn_generator.extract_components(resource_id)
(
service_type,
provider,
region,
owner_account_id,
resource_type,
cloud_local_id,
) = czrn_components
# Build resource/account as concat of api_key_alias and api_key_prefix
resource_account = (
f"{api_key_alias}|{api_key_hash}" if api_key_alias else api_key_hash
)
# CloudZero CBF format with proper column names
cbf_record = {
# Required CBF fields
"time/usage_start": usage_date.isoformat()
if usage_date
else None, # Required: ISO-formatted UTC datetime
"cost/cost": float(row.get("spend", 0.0)), # Required: billed cost
"resource/id": resource_id, # CZRN (CloudZero Resource Name)
# Usage metrics for token consumption
"usage/amount": total_tokens, # Numeric value of tokens consumed
"usage/units": "tokens", # Description of token units
# CBF fields - updated per LIT-1907
"resource/service": str(row.get("model_group", "")), # Send model_group
"resource/account": resource_account, # Send api_key_alias|api_key_prefix
"resource/region": region, # Maps to CZRN region (cross-region)
"resource/usage_family": str(
row.get("custom_llm_provider", "")
), # Send provider
# Action field
"action/operation": str(team_id) if team_id else "", # Send team_id
# Line item details
"lineitem/type": "Usage", # Standard usage line item
}
# Add CZRN components that don't have direct CBF column mappings as resource tags
cbf_record["resource/tag:provider"] = provider # CZRN provider component
cbf_record[
"resource/tag:model"
] = cloud_local_id # CZRN cloud-local-id component (model)
# Add resource tags for all dimensions (using resource/tag:<key> format)
for key, value in dimensions.items():
if (
value and value != "N/A" and value != "unknown"
): # Only add meaningful tags
cbf_record[f"resource/tag:{key}"] = str(value)
# Add token breakdown as resource tags for analysis (excluding total_tokens per LIT-1907)
if prompt_tokens > 0:
cbf_record["resource/tag:prompt_tokens"] = str(prompt_tokens)
if completion_tokens > 0:
cbf_record["resource/tag:completion_tokens"] = str(completion_tokens)
return CBFRecord(cbf_record)
def _parse_date(self, date_str) -> Optional[datetime]:
"""Parse date string from daily spend tables (e.g., '2025-04-19')."""
if date_str is None:
return None
if isinstance(date_str, datetime):
return date_str
if isinstance(date_str, str):
try:
# Parse date string and set to midnight UTC for daily aggregation
return pl.Series([date_str]).str.to_datetime("%Y-%m-%d").item()
except Exception:
try:
# Fallback: try ISO format parsing
return pl.Series([date_str]).str.to_datetime().item()
except Exception:
return None
return None

View File

@@ -0,0 +1,58 @@
"""
Custom Logger that handles batching logic
Use this if you want your logs to be stored in memory and flushed periodically.
"""
import asyncio
import time
from typing import List, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
class CustomBatchLogger(CustomLogger):
    """Logger base that buffers events in ``self.log_queue`` and flushes them
    either on a timer (``periodic_flush``) or explicitly (``flush_queue``).
    """

    def __init__(
        self,
        flush_lock: Optional[asyncio.Lock] = None,
        batch_size: Optional[int] = None,
        flush_interval: Optional[int] = None,
        **kwargs,
    ) -> None:
        """
        Args:
            flush_lock: Lock guarding queue flushes; only needed by loggers that batch.
            batch_size: Max events per batch; defaults to litellm.DEFAULT_BATCH_SIZE.
            flush_interval: Seconds between periodic flushes; defaults to
                litellm.DEFAULT_FLUSH_INTERVAL_SECONDS.
        """
        self.flush_lock = flush_lock
        self.flush_interval = flush_interval or litellm.DEFAULT_FLUSH_INTERVAL_SECONDS
        self.batch_size: int = batch_size or litellm.DEFAULT_BATCH_SIZE
        self.log_queue: List = []
        self.last_flush_time = time.time()
        super().__init__(**kwargs)

    async def periodic_flush(self):
        """Loop forever, flushing the queue every ``flush_interval`` seconds."""
        while True:
            await asyncio.sleep(self.flush_interval)
            verbose_logger.debug(
                f"CustomLogger periodic flush after {self.flush_interval} seconds"
            )
            await self.flush_queue()

    async def flush_queue(self):
        """Ship any queued events via ``async_send_batch`` and reset the queue."""
        lock = self.flush_lock
        if lock is None:
            # No lock configured -> this logger does not batch; nothing to do.
            return
        async with lock:
            if not self.log_queue:
                return
            verbose_logger.debug(
                "CustomLogger: Flushing batch of %s events", len(self.log_queue)
            )
            await self.async_send_batch()
            self.log_queue.clear()
            self.last_flush_time = time.time()

    async def async_send_batch(self, *args, **kwargs):
        """No-op hook; subclasses override this to ship the batched events."""
        pass

View File

@@ -0,0 +1,955 @@
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Type,
Union,
get_args,
)
from litellm._logging import verbose_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.guardrails import (
DynamicGuardrailParams,
GuardrailEventHooks,
LitellmParams,
Mode,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
from litellm.types.utils import (
CallTypes,
GenericGuardrailAPIInputs,
GuardrailStatus,
GuardrailTracingDetail,
LLMResponseTypes,
StandardLoggingGuardrailInformation,
)
# fastapi is an optional dependency: when it is missing, HTTPException is set
# to None and code that checks it (e.g. _is_guardrail_intervention) degrades
# gracefully instead of crashing at import time.
try:
    from fastapi.exceptions import HTTPException
except ImportError:
    HTTPException = None  # type: ignore
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
# Module-level cache instance shared by deployment hooks that need a DualCache.
dc = DualCache()
class ModifyResponseException(Exception):
    """Raised when a guardrail wants to replace the LLM response.

    Carries a synthetic response message that should be returned to the user
    with a 200 status code instead of calling the LLM (or instead of the LLM's
    own output). It is a base exception any guardrail can use, so violation
    messages surface as successful responses rather than errors; the proxy is
    expected to catch it.
    """

    def __init__(
        self,
        message: str,
        model: str,
        request_data: Dict[str, Any],
        guardrail_name: Optional[str] = None,
        detection_info: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            message: Violation message to return to the user.
            model: Model that was being called.
            request_data: Original request payload.
            guardrail_name: Name of the guardrail that raised this exception.
            detection_info: Additional detection metadata (scores, rules, etc.).
        """
        super().__init__(message)
        self.message = message
        self.model = model
        self.request_data = request_data
        self.guardrail_name = guardrail_name
        self.detection_info = {} if detection_info is None else detection_info
class CustomGuardrail(CustomLogger):
def __init__(
self,
guardrail_name: Optional[str] = None,
supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
event_hook: Optional[
Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode]
] = None,
default_on: bool = False,
mask_request_content: bool = False,
mask_response_content: bool = False,
violation_message_template: Optional[str] = None,
end_session_after_n_fails: Optional[int] = None,
on_violation: Optional[str] = None,
realtime_violation_message: Optional[str] = None,
**kwargs,
):
"""
Initialize the CustomGuardrail class
Args:
guardrail_name: The name of the guardrail. This is the name used in your requests.
supported_event_hooks: The event hooks that the guardrail supports
event_hook: The event hook to run the guardrail on
default_on: If True, the guardrail will be run by default on all requests
mask_request_content: If True, the guardrail will mask the request content
mask_response_content: If True, the guardrail will mask the response content
end_session_after_n_fails: For /v1/realtime sessions, end the session after this many violations
on_violation: For /v1/realtime sessions, 'warn' or 'end_session'
realtime_violation_message: Message the bot speaks aloud when a /v1/realtime guardrail fires
"""
self.guardrail_name = guardrail_name
self.supported_event_hooks = supported_event_hooks
self.event_hook: Optional[
Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode]
] = event_hook
self.default_on: bool = default_on
self.mask_request_content: bool = mask_request_content
self.mask_response_content: bool = mask_response_content
self.violation_message_template: Optional[str] = violation_message_template
self.end_session_after_n_fails: Optional[int] = end_session_after_n_fails
self.on_violation: Optional[str] = on_violation
self.realtime_violation_message: Optional[str] = realtime_violation_message
if supported_event_hooks:
## validate event_hook is in supported_event_hooks
self._validate_event_hook(event_hook, supported_event_hooks)
super().__init__(**kwargs)
def render_violation_message(
self, default: str, context: Optional[Dict[str, Any]] = None
) -> str:
"""Return a custom violation message if template is configured."""
if not self.violation_message_template:
return default
format_context: Dict[str, Any] = {"default_message": default}
if context:
format_context.update(context)
try:
return self.violation_message_template.format(**format_context)
except Exception as e:
verbose_logger.warning(
"Failed to format violation message template for guardrail %s: %s",
self.guardrail_name,
e,
)
return default
def raise_passthrough_exception(
self,
violation_message: str,
request_data: Dict[str, Any],
detection_info: Optional[Dict[str, Any]] = None,
) -> None:
"""
Raise a passthrough exception for guardrail violations.
This helper method should be used by guardrails when they detect a violation
in passthrough mode.
The exception will be caught by the proxy endpoints and converted to a 200 response
with the violation message, preventing the LLM call from being made (pre_call/during_call)
or replacing the LLM response (post_call).
Args:
violation_message: The formatted violation message to return to the user
request_data: The original request data dictionary
detection_info: Optional dictionary with detection metadata (scores, rules, etc.)
Raises:
ModifyResponseException: Always raises this exception to short-circuit
the LLM call and return the violation message
Example:
if violation_detected and self.on_flagged_action == "passthrough":
message = self._format_violation_message(detection_info)
self.raise_passthrough_exception(
violation_message=message,
request_data=data,
detection_info=detection_info
)
"""
model = request_data.get("model", "unknown")
raise ModifyResponseException(
message=violation_message,
model=model,
request_data=request_data,
guardrail_name=self.guardrail_name,
detection_info=detection_info,
)
@staticmethod
def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
"""
Returns the config model for the guardrail
This is used to render the config model in the UI.
"""
return None
    def _validate_event_hook(
        self,
        event_hook: Optional[
            Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode]
        ],
        supported_event_hooks: List[GuardrailEventHooks],
    ) -> None:
        """Raise ValueError if `event_hook` names any hook outside `supported_event_hooks`.

        Accepts a single hook (enum or raw string), a list of hooks, or a
        tag-based Mode object. None is always considered valid.
        """

        def _validate_event_hook_list_is_in_supported_event_hooks(
            event_hook: Union[List[GuardrailEventHooks], List[str]],
            supported_event_hooks: List[GuardrailEventHooks],
        ) -> None:
            # Entries may be raw strings; coerce to the enum before membership check.
            for hook in event_hook:
                if isinstance(hook, str):
                    hook = GuardrailEventHooks(hook)
                if hook not in supported_event_hooks:
                    raise ValueError(
                        f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}"
                    )

        if event_hook is None:
            return
        # A bare string is treated as a single hook name.
        if isinstance(event_hook, str):
            event_hook = GuardrailEventHooks(event_hook)
        if isinstance(event_hook, list):
            _validate_event_hook_list_is_in_supported_event_hooks(
                event_hook, supported_event_hooks
            )
        elif isinstance(event_hook, Mode):
            # Flatten per-tag hook values (each may be a scalar or a list).
            tag_values_flat: list = []
            for v in event_hook.tags.values():
                if isinstance(v, list):
                    tag_values_flat.extend(v)
                else:
                    tag_values_flat.append(v)
            _validate_event_hook_list_is_in_supported_event_hooks(
                tag_values_flat, supported_event_hooks
            )
            # The Mode's default hook(s) must be supported too.
            if event_hook.default:
                default_list = (
                    event_hook.default
                    if isinstance(event_hook.default, list)
                    else [event_hook.default]
                )
                _validate_event_hook_list_is_in_supported_event_hooks(
                    default_list, supported_event_hooks
                )
        elif isinstance(event_hook, GuardrailEventHooks):
            if event_hook not in supported_event_hooks:
                raise ValueError(
                    f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
                )
def get_disable_global_guardrail(self, data: dict) -> Optional[bool]:
"""
Returns True if the global guardrail should be disabled
"""
if "disable_global_guardrail" in data:
return data["disable_global_guardrail"]
metadata = data.get("litellm_metadata") or data.get("metadata", {})
if "disable_global_guardrail" in metadata:
return metadata["disable_global_guardrail"]
return False
def _is_valid_response_type(self, result: Any) -> bool:
"""
Check if result is a valid LLMResponseTypes instance.
Safely handles TypedDict types which don't support isinstance checks.
For non-LiteLLM responses (like passthrough httpx.Response), returns True
to allow them through.
"""
if result is None:
return False
try:
# Try isinstance check on valid types that support it
response_types = get_args(LLMResponseTypes)
return isinstance(result, response_types)
except TypeError as e:
# TypedDict types don't support isinstance checks
# In this case, we can't validate the type, so we allow it through
if "TypedDict" in str(e):
return True
raise
def get_guardrail_from_metadata(
self, data: dict
) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
"""
Returns the guardrail(s) to be run from the metadata or root
"""
if "guardrails" in data:
return data["guardrails"]
metadata = data.get("litellm_metadata") or data.get("metadata", {})
return metadata.get("guardrails") or []
def _guardrail_is_in_requested_guardrails(
self,
requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
) -> bool:
for _guardrail in requested_guardrails:
if isinstance(_guardrail, dict):
if self.guardrail_name in _guardrail:
return True
elif isinstance(_guardrail, str):
if self.guardrail_name == _guardrail:
return True
return False
    async def async_pre_call_deployment_hook(
        self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
    ) -> Optional[dict]:
        """Run this guardrail just before the deployment call.

        Only acts on completion/acompletion calls when the request lists
        guardrails and this guardrail's pre_call hook applies; may rewrite
        `kwargs["messages"]` (e.g. PII masking). Always returns `kwargs`.
        """
        from litellm.proxy._types import UserAPIKeyAuth

        # should run guardrail
        litellm_guardrails = kwargs.get("guardrails")
        if litellm_guardrails is None or not isinstance(litellm_guardrails, list):
            return kwargs
        if (
            self.should_run_guardrail(
                data=kwargs, event_type=GuardrailEventHooks.pre_call
            )
            is not True
        ):
            return kwargs
        # CHECK IF GUARDRAIL REJECTS THE REQUEST
        if call_type == CallTypes.completion or call_type == CallTypes.acompletion:
            # Rebuild the caller's auth context from the flattened kwargs keys.
            result = await self.async_pre_call_hook(
                user_api_key_dict=UserAPIKeyAuth(
                    user_id=kwargs.get("user_api_key_user_id"),
                    team_id=kwargs.get("user_api_key_team_id"),
                    end_user_id=kwargs.get("user_api_key_end_user_id"),
                    api_key=kwargs.get("user_api_key_hash"),
                    request_route=kwargs.get("user_api_key_request_route"),
                ),
                cache=dc,
                data=kwargs,
                call_type=call_type.value or "acompletion",  # type: ignore
            )
            if result is not None and isinstance(result, dict):
                result_messages = result.get("messages")
                if result_messages is not None:  # update for any pii / masking logic
                    kwargs["messages"] = result_messages
        return kwargs
    async def async_post_call_success_deployment_hook(
        self,
        request_data: dict,
        response: LLMResponseTypes,
        call_type: Optional[CallTypes],
    ) -> Optional[LLMResponseTypes]:
        """
        Allow modifying / reviewing the response just after it's received from the deployment.

        Only runs when the request lists guardrails and this guardrail's
        post_call hook applies; otherwise the response passes through untouched.
        """
        from litellm.proxy._types import UserAPIKeyAuth

        # should run guardrail
        litellm_guardrails = request_data.get("guardrails")
        if litellm_guardrails is None or not isinstance(litellm_guardrails, list):
            return response
        if (
            self.should_run_guardrail(
                data=request_data, event_type=GuardrailEventHooks.post_call
            )
            is not True
        ):
            return response
        # CHECK IF GUARDRAIL REJECTS THE REQUEST
        result = await self.async_post_call_success_hook(
            user_api_key_dict=UserAPIKeyAuth(
                user_id=request_data.get("user_api_key_user_id"),
                team_id=request_data.get("user_api_key_team_id"),
                end_user_id=request_data.get("user_api_key_end_user_id"),
                api_key=request_data.get("user_api_key_hash"),
                request_route=request_data.get("user_api_key_request_route"),
            ),
            data=request_data,
            response=response,
        )
        # Hooks may return None or odd types; keep the original response unless
        # the hook returned something that looks like a valid LLM response.
        if not self._is_valid_response_type(result):
            return response
        return result
    def should_run_guardrail(
        self,
        data,
        event_type: GuardrailEventHooks,
    ) -> bool:
        """
        Returns True if the guardrail should be run on the event_type.

        Decision order:
        1. default_on guardrails run whenever the event type matches, unless
           the request disables global guardrails (tag-based Modes may still
           opt out via the enterprise helper).
        2. Otherwise the guardrail must be explicitly requested (except for
           "logging_only" events) and the event type must match.
        """
        requested_guardrails = self.get_guardrail_from_metadata(data)
        disable_global_guardrail = self.get_disable_global_guardrail(data)
        verbose_logger.debug(
            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
            self.guardrail_name,
            event_type,
            self.event_hook,
            requested_guardrails,
            self.default_on,
        )
        if self.default_on is True and disable_global_guardrail is not True:
            if self._event_hook_is_event_type(event_type):
                if isinstance(self.event_hook, Mode):
                    # Tag-based modes are an enterprise-only feature.
                    try:
                        from litellm_enterprise.integrations.custom_guardrail import (
                            EnterpriseCustomGuardrailHelper,
                        )
                    except ImportError:
                        raise ImportError(
                            "Setting tag-based guardrails is only available in litellm-enterprise. You must be a premium user to use this feature."
                        )
                    result = EnterpriseCustomGuardrailHelper._should_run_if_mode_by_tag(
                        data, self.event_hook, event_type
                    )
                    # None means "no tag matched" -> fall through to default True.
                    if result is not None:
                        return result
                return True
            return False
        # Not default_on: the guardrail must have been requested explicitly.
        if (
            self.event_hook
            and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
            and event_type.value != "logging_only"
        ):
            return False
        if not self._event_hook_is_event_type(event_type):
            return False
        if isinstance(self.event_hook, Mode):
            # Tag-based modes are an enterprise-only feature.
            try:
                from litellm_enterprise.integrations.custom_guardrail import (
                    EnterpriseCustomGuardrailHelper,
                )
            except ImportError:
                raise ImportError(
                    "Setting tag-based guardrails is only available in litellm-enterprise. You must be a premium user to use this feature."
                )
            result = EnterpriseCustomGuardrailHelper._should_run_if_mode_by_tag(
                data, self.event_hook, event_type
            )
            if result is not None:
                return result
        return True
    def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
        """
        Returns True if the event_hook is the same as the event_type

        eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
        eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False

        A None event_hook matches every event type. For Mode objects, any
        per-tag value or the Mode default may match.
        """
        if self.event_hook is None:
            return True
        if isinstance(self.event_hook, list):
            return event_type.value in self.event_hook
        if isinstance(self.event_hook, Mode):
            # Match against each tag's configured hook(s) first.
            for tag_value in self.event_hook.tags.values():
                if isinstance(tag_value, list):
                    if event_type.value in tag_value:
                        return True
                elif event_type.value == tag_value:
                    return True
            # No tag matched; fall back to the Mode's default hook(s).
            if self.event_hook.default:
                default_list = (
                    self.event_hook.default
                    if isinstance(self.event_hook.default, list)
                    else [self.event_hook.default]
                )
                return event_type.value in default_list
            return False
        return self.event_hook == event_type.value
    def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
        """
        Returns `extra_body` to be added to the request body for the Guardrail API call

        Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc.

        ```
        [{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
        ```

        Will return: for guardrail=`lakera-guard`:
        {
            "foo": "bar"
        }

        Args:
            request_data: The original `request_data` passed to LiteLLM Proxy

        Note: dynamic extra_body is a premium feature; for non-premium users it
        is logged and dropped.
        """
        requested_guardrails = self.get_guardrail_from_metadata(request_data)
        # Look for the guardrail configuration matching self.guardrail_name
        for guardrail in requested_guardrails:
            if isinstance(guardrail, dict) and self.guardrail_name in guardrail:
                # Get the configuration for this guardrail
                guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
                    **guardrail[self.guardrail_name]
                )
                extra_body = guardrail_config.get("extra_body", {})
                if self._validate_premium_user() is not True:
                    if isinstance(extra_body, dict) and extra_body:
                        verbose_logger.warning(
                            "Guardrail %s: ignoring dynamic extra_body keys %s because premium_user is False",
                            self.guardrail_name,
                            list(extra_body.keys()),
                        )
                    return {}
                # Return the extra_body if it exists, otherwise empty dict
                return extra_body
        return {}
def _validate_premium_user(self) -> bool:
"""
Returns True if the user is a premium user
"""
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
if premium_user is not True:
verbose_logger.warning(
f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
)
return False
return True
    def add_standard_logging_guardrail_information_to_request_data(
        self,
        guardrail_json_response: Union[Exception, str, dict, List[dict]],
        request_data: dict,
        guardrail_status: GuardrailStatus,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        duration: Optional[float] = None,
        masked_entity_count: Optional[Dict[str, int]] = None,
        guardrail_provider: Optional[str] = None,
        event_type: Optional[GuardrailEventHooks] = None,
        tracing_detail: Optional[GuardrailTracingDetail] = None,
    ) -> None:
        """
        Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc.

        Args:
            tracing_detail: Optional typed dict with provider-specific tracing fields
                (guardrail_id, policy_template, detection_method, confidence_score,
                classification, match_details, patterns_checked, alert_recipients).
        """
        # Exceptions are not JSON-serializable; log their string form.
        if isinstance(guardrail_json_response, Exception):
            guardrail_json_response = str(guardrail_json_response)
        from litellm.types.utils import GuardrailMode

        # Use event_type if provided, otherwise fall back to self.event_hook
        guardrail_mode: Union[
            GuardrailEventHooks, GuardrailMode, List[GuardrailEventHooks]
        ]
        if event_type is not None:
            guardrail_mode = event_type
        elif isinstance(self.event_hook, Mode):
            guardrail_mode = GuardrailMode(**dict(self.event_hook.model_dump()))  # type: ignore[typeddict-item]
        else:
            guardrail_mode = self.event_hook  # type: ignore[assignment]
        from litellm.litellm_core_utils.core_helpers import (
            filter_exceptions_from_params,
        )

        # Sanitize the response to ensure it's JSON serializable and free of circular refs
        # This prevents RecursionErrors in downstream loggers (Langfuse, Datadog, etc.)
        clean_guardrail_response = filter_exceptions_from_params(
            guardrail_json_response
        )
        # Strip secret_fields to prevent plaintext Authorization headers from
        # being persisted to spend logs, OTEL traces, or other logging backends.
        # This matches the pattern used by Langfuse and Arize integrations.
        if isinstance(clean_guardrail_response, dict):
            clean_guardrail_response.pop("secret_fields", None)
        elif isinstance(clean_guardrail_response, list):
            for item in clean_guardrail_response:
                if isinstance(item, dict):
                    item.pop("secret_fields", None)
        slg = StandardLoggingGuardrailInformation(
            guardrail_name=self.guardrail_name,
            guardrail_provider=guardrail_provider,
            guardrail_mode=guardrail_mode,
            guardrail_response=clean_guardrail_response,
            guardrail_status=guardrail_status,
            start_time=start_time,
            end_time=end_time,
            duration=duration,
            masked_entity_count=masked_entity_count,
            **(tracing_detail or {}),
        )

        def _append_guardrail_info(container: dict) -> None:
            # Append to any existing list so multiple guardrails on the same
            # request are all recorded.
            key = "standard_logging_guardrail_information"
            existing = container.get(key)
            if existing is None:
                container[key] = [slg]
            elif isinstance(existing, list):
                existing.append(slg)
            else:
                # should not happen
                container[key] = [existing, slg]

        if "metadata" in request_data:
            if request_data["metadata"] is None:
                request_data["metadata"] = {}
            _append_guardrail_info(request_data["metadata"])
        elif "litellm_metadata" in request_data:
            _append_guardrail_info(request_data["litellm_metadata"])
        else:
            # Ensure guardrail info is always logged (e.g. proxy may not have set
            # metadata yet). Attach to "metadata" so spend log / standard logging see it.
            request_data["metadata"] = {}
            _append_guardrail_info(request_data["metadata"])
async def apply_guardrail(
self,
inputs: GenericGuardrailAPIInputs,
request_data: dict,
input_type: Literal["request", "response"],
logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> GenericGuardrailAPIInputs:
"""
Apply your guardrail logic to the given inputs
Args:
inputs: Dictionary containing:
- texts: List of texts to apply the guardrail to
- images: Optional list of images to apply the guardrail to
- tool_calls: Optional list of tool calls to apply the guardrail to
request_data: The request data dictionary - containing user api key metadata (e.g. user_id, team_id, etc.)
input_type: The type of input to apply the guardrail to - "request" or "response"
logging_obj: Optional logging object for tracking the guardrail execution
Any of the custom guardrails can override this method to provide custom guardrail logic
Returns the texts with the guardrail applied and the images with the guardrail applied (if any)
Raises:
Exception:
- If the guardrail raises an exception
"""
return inputs
    def _process_response(
        self,
        response: Optional[Dict],
        request_data: dict,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        duration: Optional[float] = None,
        event_type: Optional[GuardrailEventHooks] = None,
        original_inputs: Optional[Dict] = None,
    ):
        """
        Add StandardLoggingGuardrailInformation to the request data

        This gets logged on downstream Langfuse, DataDog, etc.

        Returns the response unchanged; logging is the only side effect.
        """
        # Convert None to empty dict to satisfy type requirements
        guardrail_response: Union[Dict[str, Any], str] = (
            {} if response is None else response
        )
        # For apply_guardrail functions in custom_code_guardrail scenario,
        # simplify the logged response to "allow", "deny", or "mask"
        if original_inputs is not None and isinstance(response, dict):
            # Check if inputs were modified by comparing them
            if self._inputs_were_modified(original_inputs, response):
                guardrail_response = "mask"
            else:
                guardrail_response = "allow"
        verbose_logger.debug(f"Guardrail response: {response}")
        self.add_standard_logging_guardrail_information_to_request_data(
            guardrail_json_response=guardrail_response,
            request_data=request_data,
            guardrail_status="success",
            duration=duration,
            start_time=start_time,
            end_time=end_time,
            event_type=event_type,
        )
        return response
@staticmethod
def _is_guardrail_intervention(e: Exception) -> bool:
"""
Returns True if the exception represents an intentional guardrail block
(this was logged previously as an API failure - guardrail_failed_to_respond).
Guardrails signal intentional blocks by raising:
- HTTPException with status 400 (content policy violation)
- ModifyResponseException (passthrough mode violation)
"""
if isinstance(e, ModifyResponseException):
return True
if (
HTTPException is not None
and isinstance(e, HTTPException)
and e.status_code == 400
):
return True
return False
    def _process_error(
        self,
        e: Exception,
        request_data: dict,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        duration: Optional[float] = None,
        event_type: Optional[GuardrailEventHooks] = None,
    ):
        """
        Add StandardLoggingGuardrailInformation to the request data

        This gets logged on downstream Langfuse, DataDog, etc.

        Always re-raises `e` after recording the guardrail information.
        """
        # Distinguish intentional blocks from genuine guardrail API failures.
        guardrail_status: GuardrailStatus = (
            "guardrail_intervened"
            if self._is_guardrail_intervention(e)
            else "guardrail_failed_to_respond"
        )
        # For custom_code_guardrail scenario, log as "deny" instead of full exception
        # Check if this is from custom_code_guardrail by checking the class name
        guardrail_response: Union[Exception, str] = e
        if "CustomCodeGuardrail" in self.__class__.__name__:
            guardrail_response = "deny"
        self.add_standard_logging_guardrail_information_to_request_data(
            guardrail_json_response=guardrail_response,
            request_data=request_data,
            guardrail_status=guardrail_status,
            duration=duration,
            start_time=start_time,
            end_time=end_time,
            event_type=event_type,
        )
        raise e
def _inputs_were_modified(self, original_inputs: Dict, response: Dict) -> bool:
"""
Compare original inputs with response to determine if content was modified.
Returns True if the inputs were modified (mask scenario), False otherwise (allow scenario).
"""
# Get all keys from both dictionaries
all_keys = set(original_inputs.keys()) | set(response.keys())
# Compare each key's value
for key in all_keys:
original_value = original_inputs.get(key)
response_value = response.get(key)
if original_value != response_value:
return True
# No modifications detected
return False
def mask_content_in_string(
self,
content_string: str,
mask_string: str,
start_index: int,
end_index: int,
) -> str:
"""
Mask the content in the string between the start and end indices.
"""
# Do nothing if the start or end are not valid
if not (0 <= start_index < end_index <= len(content_string)):
return content_string
# Mask the content
return content_string[:start_index] + mask_string + content_string[end_index:]
def update_in_memory_litellm_params(self, litellm_params: LitellmParams) -> None:
"""
Update the guardrails litellm params in memory
"""
for key, value in vars(litellm_params).items():
setattr(self, key, value)
    def get_guardrails_messages_for_call_type(
        self, call_type: CallTypes, data: Optional[dict] = None
    ) -> Optional[List[AllMessageValues]]:
        """
        Returns the messages for the given call type and data

        Returns None for call types whose payload shape is not handled here.
        """
        if call_type is None or data is None:
            return None
        #########################################################
        # /chat/completions
        # /messages
        # Both endpoints store the messages in the "messages" key
        #########################################################
        if (
            call_type == CallTypes.completion.value
            or call_type == CallTypes.acompletion.value
            or call_type == CallTypes.anthropic_messages.value
        ):
            return data.get("messages")
        #########################################################
        # /responses
        # User/System messages are stored in the "input" key, use litellm transformation to get the messages
        #########################################################
        if (
            call_type == CallTypes.responses.value
            or call_type == CallTypes.aresponses.value
        ):
            from typing import cast

            from litellm.responses.litellm_completion_transformation.transformation import (
                LiteLLMCompletionResponsesConfig,
            )

            input_data = data.get("input")
            if input_data is None:
                return None
            messages = LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
                input=input_data,
                responses_api_request=data,
            )
            return cast(List[AllMessageValues], messages)
        return None
def log_guardrail_information(func):
    """
    Decorator to add standard logging guardrail information to any function

    Add this decorator to ensure your guardrail response is logged to DataDog, OTEL, s3, GCS etc.

    Logs for:
        - pre_call
        - during_call
        - post_call

    Works for both sync and async guardrail hooks: the outer `wrapper`
    dispatches on `inspect.iscoroutinefunction(func)`. Both inner wrappers
    delegate the actual logging to `self._process_response` /
    `self._process_error` on the guardrail instance (the decorated method's
    first positional argument).
    """
    import functools
    import inspect

    def _infer_event_type_from_function_name(
        func_name: str,
    ) -> Optional[GuardrailEventHooks]:
        """Infer the actual event type from the function name"""
        if func_name == "async_pre_call_hook":
            return GuardrailEventHooks.pre_call
        elif func_name == "async_moderation_hook":
            return GuardrailEventHooks.during_call
        elif func_name in (
            "async_post_call_success_hook",
            "async_post_call_streaming_hook",
        ):
            return GuardrailEventHooks.post_call
        return None

    @functools.wraps(func)
    async def async_wrapper(*args, **kwargs):
        start_time = datetime.now()  # Move start_time inside the wrapper
        # args[0] is the guardrail instance the decorated method is bound to
        self: CustomGuardrail = args[0]
        request_data: dict = kwargs.get("data") or kwargs.get("request_data") or {}
        event_type = _infer_event_type_from_function_name(func.__name__)
        # Store original inputs for comparison (for apply_guardrail functions)
        original_inputs = None
        if func.__name__ == "apply_guardrail" and "inputs" in kwargs:
            original_inputs = kwargs.get("inputs")
        try:
            response = await func(*args, **kwargs)
            return self._process_response(
                response=response,
                request_data=request_data,
                start_time=start_time.timestamp(),
                end_time=datetime.now().timestamp(),
                duration=(datetime.now() - start_time).total_seconds(),
                event_type=event_type,
                original_inputs=original_inputs,
            )
        except Exception as e:
            return self._process_error(
                e=e,
                request_data=request_data,
                start_time=start_time.timestamp(),
                end_time=datetime.now().timestamp(),
                duration=(datetime.now() - start_time).total_seconds(),
                event_type=event_type,
            )

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        start_time = datetime.now()  # Move start_time inside the wrapper
        self: CustomGuardrail = args[0]
        request_data: dict = kwargs.get("data") or kwargs.get("request_data") or {}
        event_type = _infer_event_type_from_function_name(func.__name__)
        # Store original inputs for comparison (for apply_guardrail functions)
        original_inputs = None
        if func.__name__ == "apply_guardrail" and "inputs" in kwargs:
            original_inputs = kwargs.get("inputs")
        try:
            response = func(*args, **kwargs)
            # NOTE(review): unlike async_wrapper, start_time/end_time kwargs are
            # not forwarded here — presumably _process_response tolerates their
            # absence; confirm the signatures agree.
            return self._process_response(
                response=response,
                request_data=request_data,
                duration=(datetime.now() - start_time).total_seconds(),
                event_type=event_type,
                original_inputs=original_inputs,
            )
        except Exception as e:
            return self._process_error(
                e=e,
                request_data=request_data,
                duration=(datetime.now() - start_time).total_seconds(),
                event_type=event_type,
            )

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Dispatch: returns a coroutine for async hooks, a value for sync hooks
        if inspect.iscoroutinefunction(func):
            return async_wrapper(*args, **kwargs)
        return sync_wrapper(*args, **kwargs)

    return wrapper

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,83 @@
from typing import List, Optional, Tuple
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.prompt_management_base import (
PromptManagementBase,
PromptManagementClient,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams
class CustomPromptManagement(CustomLogger, PromptManagementBase):
    """
    Base class for user-defined prompt-management integrations.

    Subclass this and override `get_chat_completion_prompt` to rewrite the
    model/messages/optional-params from your own prompt store. The default
    implementation is a pass-through.
    """

    def __init__(
        self,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
        **kwargs,
    ):
        # When True, keep the caller's model instead of the prompt manager's
        self.ignore_prompt_manager_model = ignore_prompt_manager_model
        # When True, keep the caller's optional params (temperature, etc.)
        self.ignore_prompt_manager_optional_params = (
            ignore_prompt_manager_optional_params
        )
        # NOTE(review): **kwargs is accepted but dropped and super().__init__()
        # is not called — confirm CustomLogger needs no initialization here.

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Default pass-through: returns the inputs unchanged.

        Returns:
            - model: str - the model to use (can be pulled from prompt management tool)
            - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
            - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
        """
        return model, messages, non_default_params

    @property
    def integration_name(self) -> str:
        # Identifier used to select this integration by name
        return "custom-prompt-management"

    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        # Custom prompt management always runs; subclasses may narrow this.
        return True

    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        # Prompt compilation is intentionally unsupported for this base class.
        raise NotImplementedError(
            "Custom prompt management does not support compile prompt helper"
        )

    async def async_compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        # Async counterpart of _compile_prompt_helper; also unsupported.
        raise NotImplementedError(
            "Custom prompt management does not support async compile prompt helper"
        )

View File

@@ -0,0 +1,252 @@
"""
Custom Secret Manager Integration
This module provides a base class for implementing custom secret managers in LiteLLM.
Usage:
from litellm.integrations.custom_secret_manager import CustomSecretManager
class MySecretManager(CustomSecretManager):
def __init__(self):
super().__init__(secret_manager_name="my_secret_manager")
async def async_read_secret(
self,
secret_name: str,
optional_params=None,
timeout=None,
):
# Your implementation here
return await self._fetch_secret_from_service(secret_name)
def sync_read_secret(
self,
secret_name: str,
optional_params=None,
timeout=None,
):
# Your implementation here
return self._fetch_secret_from_service_sync(secret_name)
# Set your custom secret manager
import litellm
from litellm.types.secret_managers.main import KeyManagementSystem
litellm.secret_manager_client = MySecretManager()
litellm._key_management_system = KeyManagementSystem.CUSTOM
"""
from abc import abstractmethod
from typing import Any, Dict, Optional, Union
import httpx
from litellm._logging import verbose_logger
from litellm.secret_managers.base_secret_manager import BaseSecretManager
class CustomSecretManager(BaseSecretManager):
    """
    Base class for implementing custom secret managers.

    This class provides a standard interface for implementing custom secret management
    integrations in LiteLLM. Users can extend this class to integrate their own secret
    management systems.

    Subclasses MUST implement `async_read_secret` and `sync_read_secret`;
    write/delete/health-check support is optional.

    Example:
        ```python
        from litellm.integrations.custom_secret_manager import CustomSecretManager

        class MyVaultSecretManager(CustomSecretManager):
            def __init__(self, vault_url: str, token: str):
                super().__init__(secret_manager_name="my_vault")
                self.vault_url = vault_url
                self.token = token

            async def async_read_secret(self, secret_name: str, optional_params=None, timeout=None):
                # Implementation for reading secrets from your vault
                async with httpx.AsyncClient() as client:
                    response = await client.get(
                        f"{self.vault_url}/v1/secret/{secret_name}",
                        headers={"X-Vault-Token": self.token},
                        timeout=timeout
                    )
                    return response.json()["data"]["value"]

            def sync_read_secret(self, secret_name: str, optional_params=None, timeout=None):
                # Sync implementation
                with httpx.Client() as client:
                    response = client.get(
                        f"{self.vault_url}/v1/secret/{secret_name}",
                        headers={"X-Vault-Token": self.token},
                        timeout=timeout
                    )
                    return response.json()["data"]["value"]
        ```
    """

    def __init__(
        self,
        secret_manager_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the CustomSecretManager.

        Args:
            secret_manager_name: A descriptive name for your secret manager.
                This is used for logging and debugging purposes.
            **kwargs: Additional keyword arguments to pass to your secret manager.
        """
        super().__init__()
        # Name surfaced in log lines and NotImplementedError messages below
        self.secret_manager_name = secret_manager_name or "custom_secret_manager"
        verbose_logger.info("Initialized custom secret manager")

    @abstractmethod
    async def async_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Asynchronously read a secret from your custom secret manager.

        Args:
            secret_name: Name/path of the secret to read
            optional_params: Additional parameters specific to your secret manager
            timeout: Request timeout

        Returns:
            The secret value if found, None otherwise

        Raises:
            Exception: If there's an error reading the secret
        """
        pass

    @abstractmethod
    def sync_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Synchronously read a secret from your custom secret manager.

        Args:
            secret_name: Name/path of the secret to read
            optional_params: Additional parameters specific to your secret manager
            timeout: Request timeout

        Returns:
            The secret value if found, None otherwise

        Raises:
            Exception: If there's an error reading the secret
        """
        pass

    async def async_write_secret(
        self,
        secret_name: str,
        secret_value: str,
        description: Optional[str] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tags: Optional[Union[dict, list]] = None,
    ) -> Dict[str, Any]:
        """
        Asynchronously write a secret to your custom secret manager.

        This is optional to implement. If your secret manager supports writing secrets,
        you can override this method.

        Args:
            secret_name: Name/path of the secret to write
            secret_value: Value to store
            description: Description of the secret
            optional_params: Additional parameters specific to your secret manager
            timeout: Request timeout
            tags: Optional tags to apply to the secret

        Returns:
            Response from the secret manager containing write operation details

        Raises:
            NotImplementedError: If write operations are not supported
        """
        raise NotImplementedError(
            f"Write operations are not implemented for {self.secret_manager_name}. "
            "Override async_write_secret() to add write support."
        )

    async def async_delete_secret(
        self,
        secret_name: str,
        recovery_window_in_days: Optional[int] = 7,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Asynchronously delete a secret from your custom secret manager.

        This is optional to implement. If your secret manager supports deleting secrets,
        you can override this method.

        Args:
            secret_name: Name of the secret to delete
            recovery_window_in_days: Number of days before permanent deletion (if supported)
            optional_params: Additional parameters specific to your secret manager
            timeout: Request timeout

        Returns:
            Response from the secret manager containing deletion details

        Raises:
            NotImplementedError: If delete operations are not supported
        """
        raise NotImplementedError(
            f"Delete operations are not implemented for {self.secret_manager_name}. "
            "Override async_delete_secret() to add delete support."
        )

    def validate_environment(self) -> bool:
        """
        Validate that all required environment variables and configuration are present.

        Override this method to validate your secret manager's configuration.
        The default implementation performs no checks and always succeeds.

        Returns:
            True if the environment is valid

        Raises:
            ValueError: If required configuration is missing
        """
        verbose_logger.debug(
            "No environment validation configured for custom secret manager"
        )
        return True

    async def async_health_check(
        self, timeout: Optional[Union[float, httpx.Timeout]] = None
    ) -> bool:
        """
        Perform a health check on your secret manager.

        This is optional to implement. Override this method to add health check support.
        The default implementation always reports healthy.

        Args:
            timeout: Request timeout

        Returns:
            True if the secret manager is healthy, False otherwise
        """
        verbose_logger.debug(
            f"Health check not implemented for {self.secret_manager_name}"
        )
        return True

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(name={self.secret_manager_name})>"

View File

@@ -0,0 +1,30 @@
from fastapi import Request
from fastapi_sso.sso.base import OpenID
from litellm.integrations.custom_logger import CustomLogger
class CustomSSOLoginHandler(CustomLogger):
    """
    Custom logger hook for the UI SSO sign-in flow.

    Parses the incoming request headers and returns an OpenID object.
    Useful when an OAuth proxy sits in front of LiteLLM and the proxy's
    headers should be trusted to identify the signing-in user.
    """

    async def handle_custom_ui_sso_sign_in(
        self,
        request: Request,
    ) -> OpenID:
        # Materialize the headers once so both lookups hit a plain dict.
        headers = dict(request.headers)
        user_id = headers.get("x-litellm-user-id")
        user_email = headers.get("x-litellm-user-email")
        # Non-identity fields are fixed placeholder values.
        return OpenID(
            id=user_id,
            email=user_email,
            first_name="Test",
            last_name="Test",
            display_name="Test",
            picture="https://test.com/test.png",
            provider="test",
        )

View File

@@ -0,0 +1,777 @@
"""
DataDog Integration - sends logs to /api/v2/log
DD Reference API: https://docs.datadoghq.com/api/latest/logs
`async_log_success_event` - used by litellm proxy to send logs to datadog
`log_success_event` - sync version of logging to DataDog, only used on litellm Python SDK, if user opts in to using sync functions
async_log_success_event: will store batch of DD_MAX_BATCH_SIZE in memory and flush to Datadog once it reaches DD_MAX_BATCH_SIZE or every 5 seconds
async_service_failure_hook: Logs failures from Redis, Postgres (Adjacent systems), as 'WARNING' on DataDog
For batching specific details see CustomBatchLogger class
"""
import asyncio
import datetime
import os
import traceback
from datetime import datetime as datetimeObj
from typing import Any, Dict, List, Optional, Union
import httpx
from httpx import Response
import litellm
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.integrations.datadog.datadog_mock_client import (
should_use_datadog_mock,
create_mock_datadog_client,
)
from litellm.integrations.datadog.datadog_handler import (
get_datadog_hostname,
get_datadog_service,
get_datadog_source,
get_datadog_tags,
get_datadog_base_url_from_env,
)
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.datadog import (
DD_ERRORS,
DD_MAX_BATCH_SIZE,
DataDogStatus,
DatadogInitParams,
DatadogPayload,
DatadogProxyFailureHookJsonMessage,
)
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from litellm.types.utils import StandardLoggingPayload
from ..additional_logging_utils import AdditionalLoggingUtils
# NOTE: the "max number of logs DD API can accept" limit is DD_MAX_BATCH_SIZE,
# imported above from litellm.types.integrations.datadog.
# Allow-list of ServiceTypes logged as *success* events to DD.
# (We don't want to spam DD traces with large number of service types)
DD_LOGGED_SUCCESS_SERVICE_TYPES = [
    ServiceTypes.RESET_BUDGET_JOB,
]
class DataDogLogger(
CustomBatchLogger,
AdditionalLoggingUtils,
):
# Class variables or attributes
    def __init__(
        self,
        **kwargs,
    ):
        """
        Initializes the datadog logger, checks if the correct env variables are set

        Required environment variables (Direct API):
            `DD_API_KEY` - your datadog api key
            `DD_SITE` - your datadog site, example = `"us5.datadoghq.com"`

        Optional environment variables (DataDog Agent):
            `LITELLM_DD_AGENT_HOST` - hostname or IP of DataDog agent, example = `"localhost"`
            `LITELLM_DD_AGENT_PORT` - port of DataDog agent (default: 10518 for logs)

        Note: We use LITELLM_DD_AGENT_HOST instead of DD_AGENT_HOST to avoid conflicts
        with ddtrace which automatically sets DD_AGENT_HOST for APM tracing.
        """
        try:
            verbose_logger.debug("Datadog: in init datadog logger")
            # Mock mode intercepts outbound API calls (used for tests/CI)
            self.is_mock_mode = should_use_datadog_mock()
            if self.is_mock_mode:
                create_mock_datadog_client()
                verbose_logger.debug(
                    "[DATADOG MOCK] Datadog logger initialized in mock mode"
                )
            #########################################################
            # Handle datadog_params set as litellm.datadog_params
            #########################################################
            dict_datadog_params = self._get_datadog_params()
            kwargs.update(dict_datadog_params)
            self.async_client = get_async_httpx_client(
                llm_provider=httpxSpecialProvider.LoggingCallback
            )
            # Configure DataDog endpoint (Agent or Direct API)
            # Use LITELLM_DD_AGENT_HOST to avoid conflicts with ddtrace's DD_AGENT_HOST
            dd_agent_host = os.getenv("LITELLM_DD_AGENT_HOST")
            if dd_agent_host:
                self._configure_dd_agent(dd_agent_host=dd_agent_host)
            else:
                self._configure_dd_direct_api()
            # Optional override for testing
            dd_base_url = get_datadog_base_url_from_env()
            if dd_base_url:
                self.intake_url = f"{dd_base_url}/api/v2/logs"
            self.sync_client = _get_httpx_client()
            # NOTE(review): create_task requires a running event loop; confirm
            # this logger is only constructed from async context.
            asyncio.create_task(self.periodic_flush())
            self.flush_lock = asyncio.Lock()
            super().__init__(
                **kwargs, flush_lock=self.flush_lock, batch_size=DD_MAX_BATCH_SIZE
            )
        except Exception as e:
            verbose_logger.exception(
                f"Datadog: Got exception on init Datadog client {str(e)}"
            )
            raise e
def _get_datadog_params(self) -> Dict:
"""
Get the datadog_params from litellm.datadog_params
These are params specific to initializing the DataDogLogger e.g. turn_off_message_logging
"""
dict_datadog_params: Dict = {}
if litellm.datadog_params is not None:
if isinstance(litellm.datadog_params, DatadogInitParams):
dict_datadog_params = litellm.datadog_params.model_dump()
elif isinstance(litellm.datadog_params, Dict):
# only allow params that are of DatadogInitParams
dict_datadog_params = DatadogInitParams(
**litellm.datadog_params
).model_dump()
return dict_datadog_params
def _configure_dd_agent(self, dd_agent_host: str) -> None:
"""
Configure DataDog Agent for log forwarding
Args:
dd_agent_host: Hostname or IP of DataDog agent
"""
dd_agent_port = os.getenv(
"LITELLM_DD_AGENT_PORT", "10518"
) # default port for logs
self.intake_url = f"http://{dd_agent_host}:{dd_agent_port}/api/v2/logs"
self.DD_API_KEY = os.getenv("DD_API_KEY") # Optional when using agent
verbose_logger.debug(f"Datadog: Using DD Agent at {self.intake_url}")
def _configure_dd_direct_api(self) -> None:
"""
Configure direct DataDog API connection
Raises:
Exception: If required environment variables are not set
"""
if os.getenv("DD_API_KEY", None) is None:
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>")
if os.getenv("DD_SITE", None) is None:
raise Exception("DD_SITE is not set in .env, set 'DD_SITE=<>")
self.DD_API_KEY = os.getenv("DD_API_KEY")
self.intake_url = f"https://http-intake.logs.{os.getenv('DD_SITE')}/api/v2/logs"
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log success events to Datadog

        - Creates a Datadog payload
        - Adds the Payload to the in memory logs queue
        - Payload is flushed every 10 seconds or when batch size is greater than 100

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        try:
            verbose_logger.debug(
                "Datadog: Logging - Enters logging function for model %s", kwargs
            )
            await self._log_async_event(kwargs, response_obj, start_time, end_time)
        except Exception as e:
            # Logging must never break the request path — swallow after logging
            verbose_logger.exception(
                f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
            )
            pass
    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async log failure events to Datadog.

        Failure events reuse the same batching path as successes; the payload
        status is derived downstream from the standard logging object.
        """
        try:
            verbose_logger.debug(
                "Datadog: Logging - Enters logging function for model %s", kwargs
            )
            await self._log_async_event(kwargs, response_obj, start_time, end_time)
        except Exception as e:
            # Non-blocking: never let logging failures propagate to the caller
            verbose_logger.exception(
                f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
            )
            pass
    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: Any,
        traceback_str: Optional[str] = None,
    ) -> Optional[Any]:
        """
        Log proxy-level failures (e.g. 401 auth, DB connection errors) to Datadog.

        Ensures failures that occur before or outside the LLM completion flow
        (e.g. ConnectError during auth when DB is down) are visible in Datadog
        alongside Prometheus.
        """
        try:
            from litellm.litellm_core_utils.litellm_logging import (
                StandardLoggingPayloadSetup,
            )
            from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

            error_information = StandardLoggingPayloadSetup.get_error_information(
                original_exception=original_exception,
                traceback_str=traceback_str,
            )
            # error_code may be a non-numeric string; only keep real HTTP codes
            _code = error_information.get("error_code") or ""
            status_code: Optional[int] = None
            if _code and str(_code).strip().isdigit():
                status_code = int(_code)
            # Use project-standard sanitized user context when running in proxy
            user_context: Dict[str, Any] = {}
            try:
                from litellm.proxy.litellm_pre_call_utils import (
                    LiteLLMProxyRequestSetup,
                )

                _meta = (
                    LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
                        user_api_key_dict=user_api_key_dict
                    )
                )
                user_context = dict(_meta) if isinstance(_meta, dict) else _meta
            except Exception:
                # Fallback if proxy not available (e.g. SDK-only): minimal safe fields
                if hasattr(user_api_key_dict, "request_route"):
                    user_context["request_route"] = getattr(
                        user_api_key_dict, "request_route", None
                    )
                if hasattr(user_api_key_dict, "team_id"):
                    user_context["team_id"] = getattr(
                        user_api_key_dict, "team_id", None
                    )
                if hasattr(user_api_key_dict, "user_id"):
                    user_context["user_id"] = getattr(
                        user_api_key_dict, "user_id", None
                    )
                if hasattr(user_api_key_dict, "end_user_id"):
                    user_context["end_user_id"] = getattr(
                        user_api_key_dict, "end_user_id", None
                    )
            message_payload: DatadogProxyFailureHookJsonMessage = {
                "exception": error_information.get("error_message")
                or str(original_exception),
                "error_class": error_information.get("error_class")
                or original_exception.__class__.__name__,
                "status_code": status_code,
                "traceback": error_information.get("traceback") or "",
                "user_api_key_dict": user_context,
            }
            dd_payload = DatadogPayload(
                ddsource=get_datadog_source(),
                ddtags=get_datadog_tags(),
                hostname=get_datadog_hostname(),
                message=safe_dumps(message_payload),
                service=get_datadog_service(),
                status=DataDogStatus.ERROR,
            )
            self._add_trace_context_to_payload(dd_payload=dd_payload)
            self.log_queue.append(dd_payload)
            # Flush eagerly once the in-memory batch is full
            if len(self.log_queue) >= self.batch_size:
                await self.async_send_batch()
        except Exception as e:
            # Non-blocking: failures in the hook itself are only logged
            verbose_logger.exception(
                f"Datadog: async_post_call_failure_hook - {str(e)}\n{traceback.format_exc()}"
            )
        return None
    async def async_send_batch(self):
        """
        Sends the in memory logs queue to datadog api

        Logs sent to /api/v2/logs

        DD Ref: https://docs.datadoghq.com/api/latest/logs/

        NOTE(review): this method does not clear self.log_queue — presumably
        the CustomBatchLogger flush path owns clearing; verify events cannot
        be sent twice.

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        try:
            if not self.log_queue:
                verbose_logger.exception("Datadog: log_queue does not exist")
                return
            verbose_logger.debug(
                "Datadog - about to flush %s events on %s",
                len(self.log_queue),
                self.intake_url,
            )
            if self.is_mock_mode:
                verbose_logger.debug(
                    "[DATADOG MOCK] Mock mode enabled - API calls will be intercepted"
                )
            response = await self.async_send_compressed_data(self.log_queue)
            # 413 = payload too large; drop this batch instead of retrying forever
            if response.status_code == 413:
                verbose_logger.exception(DD_ERRORS.DATADOG_413_ERROR.value)
                return
            response.raise_for_status()
            # Datadog intake returns 202 Accepted on success
            if response.status_code != 202:
                raise Exception(
                    f"Response from datadog API status_code: {response.status_code}, text: {response.text}"
                )
            if self.is_mock_mode:
                verbose_logger.debug(
                    f"[DATADOG MOCK] Batch of {len(self.log_queue)} events successfully mocked"
                )
            else:
                verbose_logger.debug(
                    "Datadog: Response from datadog API status_code: %s, text: %s",
                    response.status_code,
                    response.text,
                )
        except Exception as e:
            verbose_logger.exception(
                f"Datadog Error sending batch API - {str(e)}\n{traceback.format_exc()}"
            )
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Sync Log success events to Datadog

        - Creates a Datadog payload
        - instantly logs it on DD API (no batching, unlike the async path)
        """
        try:
            # Legacy v1 payload shape; opt-in via litellm.datadog_use_v1
            if litellm.datadog_use_v1 is True:
                dd_payload = self._create_v0_logging_payload(
                    kwargs=kwargs,
                    response_obj=response_obj,
                    start_time=start_time,
                    end_time=end_time,
                )
            else:
                dd_payload = self.create_datadog_logging_payload(
                    kwargs=kwargs,
                    response_obj=response_obj,
                    start_time=start_time,
                    end_time=end_time,
                )
            # Build headers
            headers = {}
            # Add API key if available (required for direct API, optional for agent)
            if self.DD_API_KEY:
                headers["DD-API-KEY"] = self.DD_API_KEY
            response = self.sync_client.post(
                url=self.intake_url,
                json=dd_payload,  # type: ignore
                headers=headers,
            )
            response.raise_for_status()
            # Datadog intake returns 202 Accepted on success
            if response.status_code != 202:
                raise Exception(
                    f"Response from datadog API status_code: {response.status_code}, text: {response.text}"
                )
            verbose_logger.debug(
                "Datadog: Response from datadog API status_code: %s, text: %s",
                response.status_code,
                response.text,
            )
        except Exception as e:
            # Non-blocking: swallow after logging
            verbose_logger.exception(
                f"Datadog Layer Error - {str(e)}\n{traceback.format_exc()}"
            )
            pass
        pass
async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
dd_payload = self.create_datadog_logging_payload(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
)
self.log_queue.append(dd_payload)
verbose_logger.debug(
f"Datadog, event added to queue. Will flush in {self.flush_interval} seconds..."
)
if len(self.log_queue) >= self.batch_size:
await self.async_send_batch()
    def _create_datadog_logging_payload_helper(
        self,
        standard_logging_object: StandardLoggingPayload,
        status: DataDogStatus,
    ) -> DatadogPayload:
        """
        Serialize a StandardLoggingPayload into a DatadogPayload.

        Uses safe_dumps (circular-reference-safe JSON) for the message body,
        stamps source/tags/hostname/service from the environment helpers, and
        attaches APM trace context when a span is active.
        """
        from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

        json_payload = safe_dumps(standard_logging_object)
        verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
        dd_payload = DatadogPayload(
            ddsource=get_datadog_source(),
            ddtags=get_datadog_tags(standard_logging_object=standard_logging_object),
            hostname=get_datadog_hostname(),
            message=json_payload,
            service=get_datadog_service(),
            status=status,
        )
        self._add_trace_context_to_payload(dd_payload=dd_payload)
        return dd_payload
def create_datadog_logging_payload(
self,
kwargs: Union[dict, Any],
response_obj: Any,
start_time: datetime.datetime,
end_time: datetime.datetime,
) -> DatadogPayload:
"""
Helper function to create a datadog payload for logging
Args:
kwargs (Union[dict, Any]): request kwargs
response_obj (Any): llm api response
start_time (datetime.datetime): start time of request
end_time (datetime.datetime): end time of request
Returns:
DatadogPayload: defined in types.py
"""
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if standard_logging_object is None:
raise ValueError("standard_logging_object not found in kwargs")
status = DataDogStatus.INFO
if standard_logging_object.get("status") == "failure":
status = DataDogStatus.ERROR
# Build the initial payload
self.truncate_standard_logging_payload_content(standard_logging_object)
dd_payload = self._create_datadog_logging_payload_helper(
standard_logging_object=standard_logging_object,
status=status,
)
return dd_payload
async def async_send_compressed_data(self, data: List) -> Response:
"""
Async helper to send compressed data to datadog self.intake_url
Datadog recommends using gzip to compress data
https://docs.datadoghq.com/api/latest/logs/
"Datadog recommends sending your logs compressed. Add the Content-Encoding: gzip header to the request when sending"
"""
import gzip
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
compressed_data = gzip.compress(safe_dumps(data).encode("utf-8"))
# Build headers
headers = {
"Content-Encoding": "gzip",
"Content-Type": "application/json",
}
# Add API key if available (required for direct API, optional for agent)
if self.DD_API_KEY:
headers["DD-API-KEY"] = self.DD_API_KEY
response = await self.async_client.post(
url=self.intake_url,
data=compressed_data, # type: ignore
headers=headers,
)
return response
    async def async_service_failure_hook(
        self,
        payload: ServiceLoggerPayload,
        error: Optional[str] = "",
        parent_otel_span: Optional[Any] = None,
        start_time: Optional[Union[datetimeObj, float]] = None,
        end_time: Optional[Union[float, datetimeObj]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """
        Logs failures from Redis, Postgres (Adjacent systems), as 'WARNING' on DataDog

        - example - Redis is failing / erroring, will be logged on DataDog

        The payload is only enqueued here; delivery happens on the next
        batch flush.
        """
        try:
            _payload_dict = payload.model_dump()
            # Merge any caller-supplied metadata into the message body
            _payload_dict.update(event_metadata or {})
            from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

            _dd_message_str = safe_dumps(_payload_dict)
            _dd_payload = DatadogPayload(
                ddsource=get_datadog_source(),
                ddtags=get_datadog_tags(),
                hostname=get_datadog_hostname(),
                message=_dd_message_str,
                service=get_datadog_service(),
                status=DataDogStatus.WARN,
            )
            self.log_queue.append(_dd_payload)
        except Exception as e:
            # Non-blocking: failures here are only logged
            verbose_logger.exception(
                f"Datadog: Logger - Exception in async_service_failure_hook: {e}"
            )
            pass
async def async_service_success_hook(
self,
payload: ServiceLoggerPayload,
error: Optional[str] = "",
parent_otel_span: Optional[Any] = None,
start_time: Optional[Union[datetimeObj, float]] = None,
end_time: Optional[Union[float, datetimeObj]] = None,
event_metadata: Optional[dict] = None,
):
"""
Logs success from Redis, Postgres (Adjacent systems), as 'INFO' on DataDog
No user has asked for this so far, this might be spammy on datatdog. If need arises we can implement this
"""
try:
# intentionally done. Don't want to log all service types to DD
if payload.service not in DD_LOGGED_SUCCESS_SERVICE_TYPES:
return
_payload_dict = payload.model_dump()
_payload_dict.update(event_metadata or {})
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
_dd_message_str = safe_dumps(_payload_dict)
_dd_payload = DatadogPayload(
ddsource=get_datadog_source(),
ddtags=get_datadog_tags(),
hostname=get_datadog_hostname(),
message=_dd_message_str,
service=get_datadog_service(),
status=DataDogStatus.INFO,
)
self.log_queue.append(_dd_payload)
except Exception as e:
verbose_logger.exception(
f"Datadog: Logger - Exception in async_service_failure_hook: {e}"
)
    def _create_v0_logging_payload(
        self,
        kwargs: Union[dict, Any],
        response_obj: Any,
        start_time: datetime.datetime,
        end_time: datetime.datetime,
    ) -> DatadogPayload:
        """
        Note: This is our V1 Version of DataDog Logging Payload

        (Not Recommended) If you want this to get logged set `litellm.datadog_use_v1 = True`

        Builds the legacy payload shape directly from request kwargs rather
        than from the StandardLoggingPayload.
        """
        litellm_params = kwargs.get("litellm_params", {})
        metadata = (
            litellm_params.get("metadata", {}) or {}
        )  # if litellm_params['metadata'] == None
        messages = kwargs.get("messages")
        optional_params = kwargs.get("optional_params", {})
        call_type = kwargs.get("call_type", "litellm.completion")
        cache_hit = kwargs.get("cache_hit", False)
        usage = response_obj["usage"]
        id = response_obj.get("id", str(uuid.uuid4()))
        usage = dict(usage)
        try:
            # latency in milliseconds
            response_time = (end_time - start_time).total_seconds() * 1000
        except Exception:
            response_time = None

        try:
            response_obj = dict(response_obj)
        except Exception:
            # keep the original object if it is not dict-convertible
            response_obj = response_obj

        # Clean Metadata before logging - never log raw metadata
        # the raw metadata can contain circular references which leads to infinite recursion
        # we clean out all extra litellm metadata params before logging
        clean_metadata = {}
        if isinstance(metadata, dict):
            for key, value in metadata.items():
                # clean litellm metadata before logging
                if key in [
                    "endpoint",
                    "caching_groups",
                    "previous_models",
                ]:
                    continue
                else:
                    clean_metadata[key] = value

        # Build the initial payload
        payload = {
            "id": id,
            "call_type": call_type,
            "cache_hit": cache_hit,
            "start_time": start_time,
            "end_time": end_time,
            "response_time": response_time,
            "model": kwargs.get("model", ""),
            "user": kwargs.get("user", ""),
            "model_parameters": optional_params,
            "spend": kwargs.get("response_cost", 0),
            "messages": messages,
            "response": response_obj,
            "usage": usage,
            "metadata": clean_metadata,
        }
        from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

        json_payload = safe_dumps(payload)
        verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
        dd_payload = DatadogPayload(
            ddsource=get_datadog_source(),
            ddtags=get_datadog_tags(),
            hostname=get_datadog_hostname(),
            message=json_payload,
            service=get_datadog_service(),
            status=DataDogStatus.INFO,
        )
        return dd_payload
def _add_trace_context_to_payload(
self,
dd_payload: DatadogPayload,
) -> None:
"""Attach Datadog APM trace context if one is active."""
try:
trace_context = self._get_active_trace_context()
if trace_context is None:
return
dd_payload["dd.trace_id"] = trace_context["trace_id"]
span_id = trace_context.get("span_id")
if span_id is not None:
dd_payload["dd.span_id"] = span_id
except Exception:
verbose_logger.exception(
"Datadog: Failed to attach trace context to payload"
)
def _get_active_trace_context(self) -> Optional[Dict[str, str]]:
try:
current_span = None
current_span_fn = getattr(tracer, "current_span", None)
if callable(current_span_fn):
current_span = current_span_fn()
if current_span is None:
current_root_span_fn = getattr(tracer, "current_root_span", None)
if callable(current_root_span_fn):
current_span = current_root_span_fn()
if current_span is None:
return None
trace_id = getattr(current_span, "trace_id", None)
if trace_id is None:
return None
span_id = getattr(current_span, "span_id", None)
trace_context: Dict[str, str] = {"trace_id": str(trace_id)}
if span_id is not None:
trace_context["span_id"] = str(span_id)
return trace_context
except Exception:
verbose_logger.exception(
"Datadog: Failed to retrieve active trace context from tracer"
)
return None
async def async_health_check(self) -> IntegrationHealthCheckStatus:
"""
Check if the service is healthy
"""
from litellm.litellm_core_utils.litellm_logging import (
create_dummy_standard_logging_payload,
)
standard_logging_object = create_dummy_standard_logging_payload()
dd_payload = self._create_datadog_logging_payload_helper(
standard_logging_object=standard_logging_object,
status=DataDogStatus.INFO,
)
log_queue = [dd_payload]
response = await self.async_send_compressed_data(log_queue)
try:
response.raise_for_status()
return IntegrationHealthCheckStatus(
status="healthy",
error_message=None,
)
except httpx.HTTPStatusError as e:
return IntegrationHealthCheckStatus(
status="unhealthy",
error_message=e.response.text,
)
except Exception as e:
return IntegrationHealthCheckStatus(
status="unhealthy",
error_message=str(e),
)
    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetimeObj],
        end_time_utc: Optional[datetimeObj],
    ) -> Optional[dict]:
        """
        Fetch the stored request/response payload for a given request id.

        Not implemented for the Datadog integration; always returns None.
        """
        pass

View File

@@ -0,0 +1,216 @@
import asyncio
import os
import time
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.datadog_cost_management import (
DatadogFOCUSCostEntry,
)
from litellm.types.utils import StandardLoggingPayload
class DatadogCostManagementLogger(CustomBatchLogger):
    """
    Uploads aggregated LLM spend to Datadog Cloud Cost Management.

    Successful request costs are queued, aggregated by
    (provider, model, UTC day, tags) into FOCUS-format entries, and uploaded
    in batches via Datadog's custom-costs API
    (PUT /api/v2/cost/custom_costs).
    """

    def __init__(self, **kwargs):
        # Credentials come from the environment; without both keys the
        # integration stays inert (uploads are skipped in _upload_to_datadog).
        self.dd_api_key = os.getenv("DD_API_KEY")
        self.dd_app_key = os.getenv("DD_APP_KEY")
        self.dd_site = os.getenv("DD_SITE", "datadoghq.com")

        if not self.dd_api_key or not self.dd_app_key:
            verbose_logger.warning(
                "Datadog Cost Management: DD_API_KEY and DD_APP_KEY are required. Integration will not work."
            )

        self.upload_url = f"https://api.{self.dd_site}/api/v2/cost/custom_costs"
        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        # Initialize lock and start periodic flush task.
        # NOTE(review): create_task assumes a running event loop - confirm this
        # logger is only constructed from async context.
        self.flush_lock = asyncio.Lock()
        asyncio.create_task(self.periodic_flush())

        # Check if flush_lock is already in kwargs to avoid double passing (unlikely but safe)
        if "flush_lock" not in kwargs:
            kwargs["flush_lock"] = self.flush_lock
        super().__init__(**kwargs)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue the standard logging payload when the call has a positive cost."""
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if standard_logging_object is None:
                return

            # Only log if there is a cost associated.
            # `or 0` guards against an explicit response_cost=None, which would
            # otherwise raise on the comparison.
            if (standard_logging_object.get("response_cost") or 0) > 0:
                self.log_queue.append(standard_logging_object)

                if len(self.log_queue) >= self.batch_size:
                    await self.async_send_batch()
        except Exception as e:
            verbose_logger.exception(
                f"Datadog Cost Management: Error in async_log_success_event: {str(e)}"
            )

    async def async_send_batch(self):
        """Aggregate the queued payloads and upload them to Datadog."""
        if not self.log_queue:
            return

        try:
            # Aggregate costs from the batch
            aggregated_entries = self._aggregate_costs(self.log_queue)

            if not aggregated_entries:
                return

            # Send to Datadog
            await self._upload_to_datadog(aggregated_entries)

            # Queue clearing is handled by CustomBatchLogger.flush_queue; this
            # method only transforms and uploads the current batch.
        except Exception as e:
            verbose_logger.exception(
                f"Datadog Cost Management: Error in async_send_batch: {str(e)}"
            )

    def _aggregate_costs(
        self, logs: List[StandardLoggingPayload]
    ) -> List[DatadogFOCUSCostEntry]:
        """
        Aggregate costs by (provider, model, UTC day, tags).

        Returns one FOCUS-format entry per group with the summed BilledCost.
        Malformed individual logs are skipped with a warning.
        """
        from datetime import timezone

        aggregator: Dict[
            Tuple[str, str, str, Tuple[Tuple[str, str], ...]], DatadogFOCUSCostEntry
        ] = {}

        for log in logs:
            try:
                # Extract keys for aggregation
                provider = log.get("custom_llm_provider") or "unknown"
                model = log.get("model") or "unknown"
                cost = log.get("response_cost", 0)
                # Skip zero/None costs - nothing to bill.
                if not cost:
                    continue

                # The charge period is the UTC day of the request. Using a
                # tz-aware UTC conversion keeps the day bucket independent of
                # the server's local timezone.
                ts = log.get("startTime") or time.time()
                dt = datetime.fromtimestamp(ts, tz=timezone.utc)
                date_str = dt.strftime("%Y-%m-%d")

                # Group by Provider + Model + Day + Tags. Tags are folded into
                # the key as a sorted tuple so equal tag sets hash identically.
                tags = self._extract_tags(log)
                tags_key = tuple(sorted(tags.items())) if tags else ()
                key = (provider, model, date_str, tags_key)

                if key not in aggregator:
                    # Daily granularity: ChargePeriodStart == ChargePeriodEnd.
                    aggregator[key] = {
                        "ProviderName": provider,
                        "ChargeDescription": f"LLM Usage for {model}",
                        "ChargePeriodStart": date_str,
                        "ChargePeriodEnd": date_str,
                        "BilledCost": 0.0,
                        "BillingCurrency": "USD",
                        "Tags": tags if tags else None,
                    }

                aggregator[key]["BilledCost"] += cost
            except Exception as e:
                verbose_logger.warning(
                    f"Error processing log for cost aggregation: {e}"
                )
                continue

        return list(aggregator.values())

    def _extract_tags(self, log: StandardLoggingPayload) -> Dict[str, str]:
        """Build the cost-allocation tag dict (env/service/host/pod + metadata)."""
        from litellm.integrations.datadog.datadog_handler import (
            get_datadog_env,
            get_datadog_hostname,
            get_datadog_pod_name,
            get_datadog_service,
        )

        tags = {
            "env": get_datadog_env(),
            "service": get_datadog_service(),
            "host": get_datadog_hostname(),
            "pod_name": get_datadog_pod_name(),
        }

        # Add metadata as tags
        metadata = log.get("metadata", {})
        if metadata:
            # Add user info
            if metadata.get("user_api_key_alias"):
                tags["user"] = str(metadata["user_api_key_alias"])

            # Add Team Tag: first populated team identifier wins.
            team_tag = (
                metadata.get("user_api_key_team_alias")
                or metadata.get("team_alias")  # type: ignore
                or metadata.get("user_api_key_team_id")
                or metadata.get("team_id")  # type: ignore
            )
            if team_tag:
                tags["team"] = str(team_tag)

            # model_group is not in StandardLoggingMetadata TypedDict, so we need to access it via dict.get()
            model_group = metadata.get("model_group")  # type: ignore[misc]
            if model_group:
                tags["model_group"] = str(model_group)

        return tags

    async def _upload_to_datadog(self, payload: List[Dict]):
        """PUT the aggregated cost entries to the Datadog custom-costs API."""
        if not self.dd_api_key or not self.dd_app_key:
            # Without credentials this is a silent no-op (warned at init time).
            return

        headers = {
            "Content-Type": "application/json",
            "DD-API-KEY": self.dd_api_key,
            "DD-APPLICATION-KEY": self.dd_app_key,
        }

        # The API endpoint expects a list of objects directly in the body (file content behavior)
        from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

        data_json = safe_dumps(payload)

        response = await self.async_client.put(
            self.upload_url, content=data_json, headers=headers
        )
        response.raise_for_status()

        verbose_logger.debug(
            f"Datadog Cost Management: Uploaded {len(payload)} cost entries. Status: {response.status_code}"
        )

View File

@@ -0,0 +1,69 @@
"""Shared helpers for Datadog integrations."""
from __future__ import annotations
import os
from typing import List, Optional
from litellm.types.utils import StandardLoggingPayload
def get_datadog_source() -> str:
    """Datadog log source, from ``DD_SOURCE`` (default: ``litellm``)."""
    return os.environ.get("DD_SOURCE", "litellm")
def get_datadog_service() -> str:
    """Datadog service name, from ``DD_SERVICE`` (default: ``litellm-server``)."""
    return os.environ.get("DD_SERVICE", "litellm-server")
def get_datadog_hostname() -> str:
    """Hostname tag value, from ``HOSTNAME`` (empty string when unset)."""
    return os.environ.get("HOSTNAME", "")
def get_datadog_base_url_from_env() -> Optional[str]:
    """
    Optional base-URL override from the common ``DD_BASE_URL`` env var.

    Useful for testing or custom endpoints; None when unset.
    """
    return os.environ.get("DD_BASE_URL")
def get_datadog_env() -> str:
    """Datadog environment tag, from ``DD_ENV`` (default: ``unknown``)."""
    return os.environ.get("DD_ENV", "unknown")
def get_datadog_pod_name() -> str:
    """Pod-name tag value, from ``POD_NAME`` (default: ``unknown``)."""
    return os.environ.get("POD_NAME", "unknown")
def get_datadog_tags(
    standard_logging_object: Optional[StandardLoggingPayload] = None,
) -> str:
    """Build the comma-separated Datadog tag string shared by the DD integrations."""
    tags: List[str] = [
        f"env:{get_datadog_env()}",
        f"service:{get_datadog_service()}",
        f"version:{os.getenv('DD_VERSION', 'unknown')}",
        f"HOSTNAME:{get_datadog_hostname()}",
        f"POD_NAME:{get_datadog_pod_name()}",
    ]

    if standard_logging_object:
        for request_tag in standard_logging_object.get("request_tags", []) or []:
            tags.append(f"request_tag:{request_tag}")

        # Add Team Tag: the first populated team identifier wins.
        metadata = standard_logging_object.get("metadata", {}) or {}
        for team_key in (
            "user_api_key_team_alias",
            "team_alias",
            "user_api_key_team_id",
            "team_id",
        ):
            team_tag = metadata.get(team_key)
            if team_tag:
                tags.append(f"team:{team_tag}")
                break

    return ",".join(tags)

View File

@@ -0,0 +1,856 @@
"""
Implements logging integration with Datadog's LLM Observability Service
API Reference: https://docs.datadoghq.com/llm_observability/setup/api/?tab=example#api-standards
"""
import asyncio
import json
import os
from litellm._uuid import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.integrations.datadog.datadog_mock_client import (
should_use_datadog_mock,
create_mock_datadog_client,
)
from litellm.integrations.datadog.datadog_handler import (
get_datadog_service,
get_datadog_tags,
get_datadog_base_url_from_env,
)
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.litellm_core_utils.prompt_templates.common_utils import (
handle_any_messages_to_chat_completion_str_messages_conversion,
)
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.datadog_llm_obs import *
from litellm.types.utils import (
CallTypes,
StandardLoggingGuardrailInformation,
StandardLoggingPayload,
StandardLoggingPayloadErrorInformation,
)
class DataDogLLMObsLogger(CustomBatchLogger):
def __init__(self, **kwargs):
try:
verbose_logger.debug("DataDogLLMObs: Initializing logger")
self.is_mock_mode = should_use_datadog_mock()
if self.is_mock_mode:
create_mock_datadog_client()
verbose_logger.debug(
"[DATADOG MOCK] DataDogLLMObs logger initialized in mock mode"
)
# Configure DataDog endpoint (Agent or Direct API)
# Use LITELLM_DD_AGENT_HOST to avoid conflicts with ddtrace's DD_AGENT_HOST
# Check for agent mode FIRST - agent mode doesn't require DD_API_KEY or DD_SITE
dd_agent_host = os.getenv("LITELLM_DD_AGENT_HOST")
self.async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
self.DD_API_KEY = os.getenv("DD_API_KEY")
if dd_agent_host:
self._configure_dd_agent(dd_agent_host=dd_agent_host)
else:
# Only require DD_API_KEY and DD_SITE for direct API mode
if os.getenv("DD_API_KEY", None) is None:
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>'")
if os.getenv("DD_SITE", None) is None:
raise Exception(
"DD_SITE is not set, set 'DD_SITE=<>', example sit = `us5.datadoghq.com`"
)
self._configure_dd_direct_api()
# Optional override for testing
dd_base_url = get_datadog_base_url_from_env()
if dd_base_url:
self.intake_url = f"{dd_base_url}/api/intake/llm-obs/v1/trace/spans"
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
self.log_queue: List[LLMObsPayload] = []
#########################################################
# Handle datadog_llm_observability_params set as litellm.datadog_llm_observability_params
#########################################################
dict_datadog_llm_obs_params = self._get_datadog_llm_obs_params()
kwargs.update(dict_datadog_llm_obs_params)
CustomBatchLogger.__init__(self, **kwargs, flush_lock=self.flush_lock)
except Exception as e:
verbose_logger.exception(f"DataDogLLMObs: Error initializing - {str(e)}")
raise e
def _configure_dd_agent(self, dd_agent_host: str):
"""
Configure the Datadog logger to send traces to the Agent.
"""
# When using the Agent, LLM Observability Intake does NOT require the API Key
# Reference: https://docs.datadoghq.com/llm_observability/setup/sdk/#agent-setup
# Use specific port for LLM Obs (Trace Agent) to avoid conflict with Logs Agent (10518)
agent_port = os.getenv("LITELLM_DD_LLM_OBS_PORT", "8126")
self.DD_SITE = "localhost" # Not used for URL construction in agent mode
self.intake_url = (
f"http://{dd_agent_host}:{agent_port}/api/intake/llm-obs/v1/trace/spans"
)
verbose_logger.debug(f"DataDogLLMObs: Using DD Agent at {self.intake_url}")
def _configure_dd_direct_api(self):
"""
Configure the Datadog logger to send traces directly to the Datadog API.
"""
if not self.DD_API_KEY:
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>'")
self.DD_SITE = os.getenv("DD_SITE")
if not self.DD_SITE:
raise Exception(
"DD_SITE is not set, set 'DD_SITE=<>', example site = `us5.datadoghq.com`"
)
self.intake_url = (
f"https://api.{self.DD_SITE}/api/intake/llm-obs/v1/trace/spans"
)
def _get_datadog_llm_obs_params(self) -> Dict:
"""
Get the datadog_llm_observability_params from litellm.datadog_llm_observability_params
These are params specific to initializing the DataDogLLMObsLogger e.g. turn_off_message_logging
"""
dict_datadog_llm_obs_params: Dict = {}
if litellm.datadog_llm_observability_params is not None:
if isinstance(
litellm.datadog_llm_observability_params, DatadogLLMObsInitParams
):
dict_datadog_llm_obs_params = (
litellm.datadog_llm_observability_params.model_dump()
)
elif isinstance(litellm.datadog_llm_observability_params, Dict):
# only allow params that are of DatadogLLMObsInitParams
dict_datadog_llm_obs_params = DatadogLLMObsInitParams(
**litellm.datadog_llm_observability_params
).model_dump()
return dict_datadog_llm_obs_params
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
verbose_logger.debug(
f"DataDogLLMObs: Logging success event for model {kwargs.get('model', 'unknown')}"
)
payload = self.create_llm_obs_payload(kwargs, start_time, end_time)
verbose_logger.debug(f"DataDogLLMObs: Payload: {payload}")
self.log_queue.append(payload)
if len(self.log_queue) >= self.batch_size:
await self.async_send_batch()
except Exception as e:
verbose_logger.exception(
f"DataDogLLMObs: Error logging success event - {str(e)}"
)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
verbose_logger.debug(
f"DataDogLLMObs: Logging failure event for model {kwargs.get('model', 'unknown')}"
)
payload = self.create_llm_obs_payload(kwargs, start_time, end_time)
verbose_logger.debug(f"DataDogLLMObs: Payload: {payload}")
self.log_queue.append(payload)
if len(self.log_queue) >= self.batch_size:
await self.async_send_batch()
except Exception as e:
verbose_logger.exception(
f"DataDogLLMObs: Error logging failure event - {str(e)}"
)
    async def async_send_batch(self):
        """
        Flush the queued spans to the DataDog LLM Observability intake.

        Sends every span in ``self.log_queue`` as a single intake request and
        clears the queue only after a successful (HTTP 202) response, so a
        failed flush keeps the events queued for the next attempt. Errors are
        logged and swallowed - flushing must never crash the caller.
        """
        try:
            if not self.log_queue:
                return
            verbose_logger.debug(
                f"DataDogLLMObs: Flushing {len(self.log_queue)} events"
            )
            if self.is_mock_mode:
                verbose_logger.debug(
                    "[DATADOG MOCK] Mock mode enabled - API calls will be intercepted"
                )

            # Prepare the payload
            payload = {
                "data": DDIntakePayload(
                    type="span",
                    attributes=DDSpanAttributes(
                        ml_app=get_datadog_service(),
                        tags=[get_datadog_tags()],
                        spans=self.log_queue,
                    ),
                ),
            }
            # serialize datetime objects - for budget reset time in spend metrics
            from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

            # Debug-serialize separately so a logging failure can't block the send.
            try:
                verbose_logger.debug("payload %s", safe_dumps(payload))
            except Exception as debug_error:
                verbose_logger.debug(
                    "payload serialization failed: %s", str(debug_error)
                )

            json_payload = safe_dumps(payload)
            headers = {"Content-Type": "application/json"}
            # Agent mode has no API key; only attach the header when present.
            if self.DD_API_KEY:
                headers["DD-API-KEY"] = self.DD_API_KEY
            response = await self.async_client.post(
                url=self.intake_url,
                content=json_payload,
                headers=headers,
            )

            # The intake API acknowledges accepted spans with 202.
            if response.status_code != 202:
                raise Exception(
                    f"DataDogLLMObs: Unexpected response - status_code: {response.status_code}, text: {response.text}"
                )
            if self.is_mock_mode:
                verbose_logger.debug(
                    f"[DATADOG MOCK] Batch of {len(self.log_queue)} events successfully mocked"
                )
            else:
                verbose_logger.debug(
                    f"DataDogLLMObs: Successfully sent batch - status_code: {response.status_code}"
                )
            self.log_queue.clear()
        except httpx.HTTPStatusError as e:
            verbose_logger.exception(
                f"DataDogLLMObs: Error sending batch - {e.response.text}"
            )
        except Exception as e:
            verbose_logger.exception(f"DataDogLLMObs: Error sending batch - {str(e)}")
    def create_llm_obs_payload(
        self, kwargs: Dict, start_time: datetime, end_time: datetime
    ) -> LLMObsPayload:
        """
        Build the LLM Observability span payload for a single request.

        Args:
            kwargs: litellm logging kwargs; must contain ``standard_logging_object``.
            start_time: wall-clock start of the request.
            end_time: wall-clock end of the request.

        Returns:
            An ``LLMObsPayload`` ready to be queued for the intake API.

        Raises:
            Exception: if ``standard_logging_object`` is missing from kwargs.
        """
        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if standard_logging_payload is None:
            raise Exception("DataDogLLMObs: standard_logging_object is not set")

        messages = standard_logging_payload["messages"]
        messages = self._ensure_string_content(messages=messages)

        metadata = kwargs.get("litellm_params", {}).get("metadata", {})

        input_meta = InputMeta(
            messages=handle_any_messages_to_chat_completion_str_messages_conversion(
                messages
            )
        )
        output_meta = OutputMeta(
            messages=self._get_response_messages(
                standard_logging_payload=standard_logging_payload,
                call_type=standard_logging_payload.get("call_type"),
            )
        )

        # Populated only for failed requests; also drives the span status below.
        error_info = self._assemble_error_info(standard_logging_payload)

        # parent_id comes from request metadata when the caller nests this span.
        metadata_parent_id: Optional[str] = None
        if isinstance(metadata, dict):
            metadata_parent_id = metadata.get("parent_id")

        meta = Meta(
            kind=self._get_datadog_span_kind(
                standard_logging_payload.get("call_type"), metadata_parent_id
            ),
            input=input_meta,
            output=output_meta,
            metadata=self._get_dd_llm_obs_payload_metadata(standard_logging_payload),
            error=error_info,
        )

        # Calculate metrics (you may need to adjust these based on available data)
        metrics = LLMMetrics(
            input_tokens=float(standard_logging_payload.get("prompt_tokens", 0)),
            output_tokens=float(standard_logging_payload.get("completion_tokens", 0)),
            total_tokens=float(standard_logging_payload.get("total_tokens", 0)),
            total_cost=float(standard_logging_payload.get("response_cost", 0)),
            time_to_first_token=self._get_time_to_first_token_seconds(
                standard_logging_payload
            ),
        )

        # start_ns/duration are nanoseconds, per the DD LLM Obs API.
        payload: LLMObsPayload = LLMObsPayload(
            parent_id=metadata_parent_id if metadata_parent_id else "undefined",
            trace_id=standard_logging_payload.get("trace_id", str(uuid.uuid4())),
            span_id=metadata.get("span_id", str(uuid.uuid4())),
            name=metadata.get("name", "litellm_llm_call"),
            meta=meta,
            start_ns=int(start_time.timestamp() * 1e9),
            duration=int((end_time - start_time).total_seconds() * 1e9),
            metrics=metrics,
            status="error" if error_info else "ok",
            tags=[get_datadog_tags(standard_logging_object=standard_logging_payload)],
        )

        # Correlate with an active APM trace when one exists.
        apm_trace_id = self._get_apm_trace_id()
        if apm_trace_id is not None:
            payload["apm_id"] = apm_trace_id
        return payload
def _get_apm_trace_id(self) -> Optional[str]:
"""Retrieve the current APM trace ID if available."""
try:
current_span_fn = getattr(tracer, "current_span", None)
if callable(current_span_fn):
current_span = current_span_fn()
if current_span is not None:
trace_id = getattr(current_span, "trace_id", None)
if trace_id is not None:
return str(trace_id)
except Exception:
pass
return None
def _assemble_error_info(
self, standard_logging_payload: StandardLoggingPayload
) -> Optional[DDLLMObsError]:
"""
Assemble error information for failure cases according to DD LLM Obs API spec
"""
# Handle error information for failure cases according to DD LLM Obs API spec
error_info: Optional[DDLLMObsError] = None
if standard_logging_payload.get("status") == "failure":
# Try to get structured error information first
error_information: Optional[
StandardLoggingPayloadErrorInformation
] = standard_logging_payload.get("error_information")
if error_information:
error_info = DDLLMObsError(
message=error_information.get("error_message")
or standard_logging_payload.get("error_str")
or "Unknown error",
type=error_information.get("error_class"),
stack=error_information.get("traceback"),
)
return error_info
def _get_time_to_first_token_seconds(
self, standard_logging_payload: StandardLoggingPayload
) -> float:
"""
Get the time to first token in seconds
CompletionStartTime - StartTime = Time to first token
For non streaming calls, CompletionStartTime is time we get the response back
"""
start_time: Optional[float] = standard_logging_payload.get("startTime")
completion_start_time: Optional[float] = standard_logging_payload.get(
"completionStartTime"
)
end_time: Optional[float] = standard_logging_payload.get("endTime")
if completion_start_time is not None and start_time is not None:
return completion_start_time - start_time
elif end_time is not None and start_time is not None:
return end_time - start_time
else:
return 0.0
    def _get_response_messages(
        self, standard_logging_payload: StandardLoggingPayload, call_type: Optional[str]
    ) -> List[Any]:
        """
        Extract the assistant message(s) from the logged response object.

        Currently only handles chat/completion-style responses (first choice's
        ``message``); every other call type, and any malformed response,
        yields an empty list.
        """
        response_obj = standard_logging_payload.get("response")

        if response_obj is None:
            return []

        # edge case: handle response_obj is a string representation of a dict
        # (literal_eval covers Python-repr strings; json.loads covers JSON strings)
        if isinstance(response_obj, str):
            try:
                import ast

                response_obj = ast.literal_eval(response_obj)
            except (ValueError, SyntaxError):
                try:
                    # fallback to json parsing
                    response_obj = json.loads(str(response_obj))
                except json.JSONDecodeError:
                    return []

        if call_type in [
            CallTypes.completion.value,
            CallTypes.acompletion.value,
            CallTypes.text_completion.value,
            CallTypes.atext_completion.value,
            CallTypes.generate_content.value,
            CallTypes.agenerate_content.value,
            CallTypes.generate_content_stream.value,
            CallTypes.agenerate_content_stream.value,
            CallTypes.anthropic_messages.value,
        ]:
            try:
                # Safely extract message from response_obj, handle failure cases
                if isinstance(response_obj, dict) and "choices" in response_obj:
                    choices = response_obj["choices"]
                    if choices and len(choices) > 0 and "message" in choices[0]:
                        return [choices[0]["message"]]
                return []
            except (KeyError, IndexError, TypeError):
                # In case of any error accessing the response structure, return empty list
                return []
        return []
    def _get_datadog_span_kind(
        self, call_type: Optional[str], parent_id: Optional[str] = None
    ) -> Literal["llm", "tool", "task", "embedding", "retrieval"]:
        """
        Map liteLLM call_type to appropriate DataDog LLM Observability span kind.

        Available DataDog span kinds: "llm", "tool", "task", "embedding", "retrieval"
        see: https://docs.datadoghq.com/ja/llm_observability/terms/

        Args:
            call_type: litellm ``CallTypes`` value for the request, or None.
            parent_id: parent span id from request metadata, when present.

        Returns:
            The span kind to report; "llm" is the default/fallback.
        """
        # Non llm/workflow/agent kinds cannot be root spans, so fallback to "llm" when parent metadata is missing
        if call_type is None or parent_id is None:
            return "llm"

        # Embedding operations
        if call_type in [CallTypes.embedding.value, CallTypes.aembedding.value]:
            return "embedding"

        # LLM completion operations
        if call_type in [
            CallTypes.completion.value,
            CallTypes.acompletion.value,
            CallTypes.text_completion.value,
            CallTypes.atext_completion.value,
            CallTypes.generate_content.value,
            CallTypes.agenerate_content.value,
            CallTypes.generate_content_stream.value,
            CallTypes.agenerate_content_stream.value,
            CallTypes.anthropic_messages.value,
            CallTypes.responses.value,
            CallTypes.aresponses.value,
        ]:
            return "llm"

        # Tool operations
        if call_type in [CallTypes.call_mcp_tool.value]:
            return "tool"

        # Retrieval operations (reads of assistants/threads/files/batches/jobs)
        if call_type in [
            CallTypes.get_assistants.value,
            CallTypes.aget_assistants.value,
            CallTypes.get_thread.value,
            CallTypes.aget_thread.value,
            CallTypes.get_messages.value,
            CallTypes.aget_messages.value,
            CallTypes.afile_retrieve.value,
            CallTypes.file_retrieve.value,
            CallTypes.afile_list.value,
            CallTypes.file_list.value,
            CallTypes.afile_content.value,
            CallTypes.file_content.value,
            CallTypes.retrieve_batch.value,
            CallTypes.aretrieve_batch.value,
            CallTypes.retrieve_fine_tuning_job.value,
            CallTypes.aretrieve_fine_tuning_job.value,
            CallTypes.alist_input_items.value,
        ]:
            return "retrieval"

        # Task operations (batch, fine-tuning, file operations, etc.)
        if call_type in [
            CallTypes.create_batch.value,
            CallTypes.acreate_batch.value,
            CallTypes.create_fine_tuning_job.value,
            CallTypes.acreate_fine_tuning_job.value,
            CallTypes.cancel_fine_tuning_job.value,
            CallTypes.acancel_fine_tuning_job.value,
            CallTypes.list_fine_tuning_jobs.value,
            CallTypes.alist_fine_tuning_jobs.value,
            CallTypes.create_assistants.value,
            CallTypes.acreate_assistants.value,
            CallTypes.delete_assistant.value,
            CallTypes.adelete_assistant.value,
            CallTypes.create_thread.value,
            CallTypes.acreate_thread.value,
            CallTypes.add_message.value,
            CallTypes.a_add_message.value,
            CallTypes.run_thread.value,
            CallTypes.arun_thread.value,
            CallTypes.run_thread_stream.value,
            CallTypes.arun_thread_stream.value,
            CallTypes.file_delete.value,
            CallTypes.afile_delete.value,
            CallTypes.create_file.value,
            CallTypes.acreate_file.value,
            CallTypes.image_generation.value,
            CallTypes.aimage_generation.value,
            CallTypes.image_edit.value,
            CallTypes.aimage_edit.value,
            CallTypes.moderation.value,
            CallTypes.amoderation.value,
            CallTypes.transcription.value,
            CallTypes.atranscription.value,
            CallTypes.speech.value,
            CallTypes.aspeech.value,
            CallTypes.rerank.value,
            CallTypes.arerank.value,
        ]:
            return "task"

        # Default fallback for unknown or passthrough operations
        return "llm"
def _ensure_string_content(
self, messages: Optional[Union[str, List[Any], Dict[Any, Any]]]
) -> List[Any]:
if messages is None:
return []
if isinstance(messages, str):
return [messages]
elif isinstance(messages, list):
return [message for message in messages]
elif isinstance(messages, dict):
return [str(messages.get("content", ""))]
return []
    def _get_dd_llm_obs_payload_metadata(
        self, standard_logging_payload: StandardLoggingPayload
    ) -> Dict[str, Any]:
        """
        Fields to track in DD LLM Observability metadata from litellm standard logging payload.

        Combines model/cache identifiers, latency metrics, spend metrics, tool-call
        info, and the payload's own metadata dict. Note: the payload metadata is
        merged last, so its keys override any computed key of the same name.
        """
        _metadata: Dict[str, Any] = {
            "model_name": standard_logging_payload.get("model", "unknown"),
            "model_provider": standard_logging_payload.get(
                "custom_llm_provider", "unknown"
            ),
            "id": standard_logging_payload.get("id", "unknown"),
            "trace_id": standard_logging_payload.get("trace_id", "unknown"),
            "cache_hit": standard_logging_payload.get("cache_hit", "unknown"),
            "cache_key": standard_logging_payload.get("cache_key", "unknown"),
            "saved_cache_cost": standard_logging_payload.get("saved_cache_cost", 0),
            "guardrail_information": standard_logging_payload.get(
                "guardrail_information", None
            ),
            "is_streamed_request": self._get_stream_value_from_payload(
                standard_logging_payload
            ),
        }
        #########################################################
        # Add latency metrics to metadata
        #########################################################
        latency_metrics = self._get_latency_metrics(standard_logging_payload)
        _metadata.update({"latency_metrics": dict(latency_metrics)})

        #########################################################
        # Add spend metrics to metadata
        #########################################################
        spend_metrics = self._get_spend_metrics(standard_logging_payload)
        _metadata.update({"spend_metrics": dict(spend_metrics)})

        ## extract tool calls and add to metadata
        tool_call_metadata = self._extract_tool_call_metadata(standard_logging_payload)
        _metadata.update(tool_call_metadata)

        # Merge user/request metadata last (may override computed keys above).
        _standard_logging_metadata: dict = (
            dict(standard_logging_payload.get("metadata", {})) or {}
        )
        _metadata.update(_standard_logging_metadata)
        return _metadata
    def _get_latency_metrics(
        self, standard_logging_payload: StandardLoggingPayload
    ) -> DDLLMObsLatencyMetrics:
        """
        Get the latency metrics from the standard logging payload.

        Collects (all in milliseconds, each only when available/positive):
        time-to-first-token, litellm overhead, and total guardrail overhead.
        """
        latency_metrics: DDLLMObsLatencyMetrics = DDLLMObsLatencyMetrics()

        # Add latency metrics to metadata
        # Time to first token (convert from seconds to milliseconds for consistency)
        time_to_first_token_seconds = self._get_time_to_first_token_seconds(
            standard_logging_payload
        )
        if time_to_first_token_seconds > 0:
            latency_metrics["time_to_first_token_ms"] = (
                time_to_first_token_seconds * 1000
            )

        # LiteLLM overhead time
        hidden_params = standard_logging_payload.get("hidden_params", {})
        litellm_overhead_ms = hidden_params.get("litellm_overhead_time_ms")
        if litellm_overhead_ms is not None:
            latency_metrics["litellm_overhead_time_ms"] = litellm_overhead_ms

        # Guardrail overhead latency: sum durations across all guardrail runs.
        guardrail_info: Optional[
            list[StandardLoggingGuardrailInformation]
        ] = standard_logging_payload.get("guardrail_information")
        if guardrail_info is not None:
            total_duration = 0.0
            for info in guardrail_info:
                _guardrail_duration_seconds: Optional[float] = info.get("duration")
                if _guardrail_duration_seconds is not None:
                    total_duration += float(_guardrail_duration_seconds)
            if total_duration > 0:
                # Convert from seconds to milliseconds for consistency
                latency_metrics["guardrail_overhead_time_ms"] = total_duration * 1000

        return latency_metrics
def _get_stream_value_from_payload(
self, standard_logging_payload: StandardLoggingPayload
) -> bool:
"""
Extract the stream value from standard logging payload.
The stream field in StandardLoggingPayload is only set to True for completed streaming responses.
For non-streaming requests, it's None. The original stream parameter is in model_parameters.
Returns:
bool: True if this was a streaming request, False otherwise
"""
# Check top-level stream field first (only True for completed streaming)
stream_value = standard_logging_payload.get("stream")
if stream_value is True:
return True
# Fallback to model_parameters.stream for original request parameters
model_params = standard_logging_payload.get("model_parameters", {})
if isinstance(model_params, dict):
stream_value = model_params.get("stream")
if stream_value is True:
return True
# Default to False for non-streaming requests
return False
    def _get_spend_metrics(
        self, standard_logging_payload: StandardLoggingPayload
    ) -> DDLLMObsSpendMetrics:
        """
        Get the spend metrics from the standard logging payload.

        Includes the response cost plus (when present in metadata) the API key's
        max budget, current spend, and budget reset time (normalized to an ISO
        string for JSON serialization).
        """
        spend_metrics: DDLLMObsSpendMetrics = DDLLMObsSpendMetrics()
        # send response cost
        spend_metrics["response_cost"] = standard_logging_payload.get(
            "response_cost", 0.0
        )

        # Get budget information from metadata
        metadata = standard_logging_payload.get("metadata", {})

        # API key max budget
        user_api_key_max_budget = metadata.get("user_api_key_max_budget")
        if user_api_key_max_budget is not None:
            spend_metrics["user_api_key_max_budget"] = float(user_api_key_max_budget)

        # API key spend (skip silently on non-numeric values)
        user_api_key_spend = metadata.get("user_api_key_spend")
        if user_api_key_spend is not None:
            try:
                spend_metrics["user_api_key_spend"] = float(user_api_key_spend)
            except (ValueError, TypeError):
                verbose_logger.debug(
                    f"Invalid user_api_key_spend value: {user_api_key_spend}"
                )

        # API key budget reset datetime
        user_api_key_budget_reset_at = metadata.get("user_api_key_budget_reset_at")
        if user_api_key_budget_reset_at is not None:
            try:
                from datetime import datetime, timezone

                budget_reset_at = None
                if isinstance(user_api_key_budget_reset_at, str):
                    # Handle ISO format strings that might have 'Z' suffix
                    iso_string = user_api_key_budget_reset_at.replace("Z", "+00:00")
                    budget_reset_at = datetime.fromisoformat(iso_string)
                elif isinstance(user_api_key_budget_reset_at, datetime):
                    budget_reset_at = user_api_key_budget_reset_at

                if budget_reset_at is not None:
                    # Preserve timezone info if already present; naive values are
                    # interpreted as UTC.
                    if budget_reset_at.tzinfo is None:
                        budget_reset_at = budget_reset_at.replace(tzinfo=timezone.utc)

                    # Convert to ISO string format for JSON serialization
                    # This prevents circular reference issues and ensures proper timezone representation
                    iso_string = budget_reset_at.isoformat()
                    spend_metrics["user_api_key_budget_reset_at"] = iso_string

                    # Debug logging to verify the conversion
                    verbose_logger.debug(
                        f"Converted budget_reset_at to ISO format: {iso_string}"
                    )
            except Exception as e:
                verbose_logger.debug(f"Error processing budget reset datetime: {e}")
                verbose_logger.debug(f"Original value: {user_api_key_budget_reset_at}")

        return spend_metrics
def _process_input_messages_preserving_tool_calls(
self, messages: List[Any]
) -> List[Dict[str, Any]]:
"""
Process input messages while preserving tool_calls and tool message types.
This bypasses the lossy string conversion when tool calls are present,
allowing complex nested tool_calls objects to be preserved for Datadog.
"""
processed = []
for msg in messages:
if isinstance(msg, dict):
# Preserve messages with tool_calls or tool role as-is
if "tool_calls" in msg or msg.get("role") == "tool":
processed.append(msg)
else:
# For regular messages, still apply string conversion
converted = (
handle_any_messages_to_chat_completion_str_messages_conversion(
[msg]
)
)
processed.extend(converted)
else:
# For non-dict messages, apply string conversion
converted = (
handle_any_messages_to_chat_completion_str_messages_conversion(
[msg]
)
)
processed.extend(converted)
return processed
@staticmethod
def _tool_calls_kv_pair(tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Extract tool call information into key-value pairs for Datadog metadata.
Similar to OpenTelemetry's implementation but adapted for Datadog's format.
"""
kv_pairs: Dict[str, Any] = {}
for idx, tool_call in enumerate(tool_calls):
try:
# Extract tool call ID
tool_id = tool_call.get("id")
if tool_id:
kv_pairs[f"tool_calls.{idx}.id"] = tool_id
# Extract tool call type
tool_type = tool_call.get("type")
if tool_type:
kv_pairs[f"tool_calls.{idx}.type"] = tool_type
# Extract function information
function = tool_call.get("function")
if function:
function_name = function.get("name")
if function_name:
kv_pairs[f"tool_calls.{idx}.function.name"] = function_name
function_arguments = function.get("arguments")
if function_arguments:
# Store arguments as JSON string for Datadog
if isinstance(function_arguments, str):
kv_pairs[
f"tool_calls.{idx}.function.arguments"
] = function_arguments
else:
import json
kv_pairs[
f"tool_calls.{idx}.function.arguments"
] = json.dumps(function_arguments)
except (KeyError, TypeError, ValueError) as e:
verbose_logger.debug(
f"DataDogLLMObs: Error processing tool call {idx}: {str(e)}"
)
continue
return kv_pairs
def _extract_tool_call_metadata(
self, standard_logging_payload: StandardLoggingPayload
) -> Dict[str, Any]:
"""
Extract tool call information from both input messages and response for Datadog metadata.
"""
tool_call_metadata: Dict[str, Any] = {}
try:
# Extract tool calls from input messages
messages = standard_logging_payload.get("messages", [])
if messages and isinstance(messages, list):
for message in messages:
if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message.get("tool_calls")
if tool_calls:
input_tool_calls_kv = self._tool_calls_kv_pair(tool_calls)
# Prefix with "input_" to distinguish from response tool calls
for key, value in input_tool_calls_kv.items():
tool_call_metadata[f"input_{key}"] = value
# Extract tool calls from response
response_obj = standard_logging_payload.get("response")
if response_obj and isinstance(response_obj, dict):
choices = response_obj.get("choices", [])
for choice in choices:
if isinstance(choice, dict):
message = choice.get("message")
if message and isinstance(message, dict):
tool_calls = message.get("tool_calls")
if tool_calls:
response_tool_calls_kv = self._tool_calls_kv_pair(
tool_calls
)
# Prefix with "output_" to distinguish from input tool calls
for key, value in response_tool_calls_kv.items():
tool_call_metadata[f"output_{key}"] = value
except Exception as e:
verbose_logger.debug(
f"DataDogLLMObs: Error extracting tool call metadata: {str(e)}"
)
return tool_call_metadata

View File

@@ -0,0 +1,286 @@
import asyncio
import gzip
import os
import time
from datetime import datetime
from typing import List, Optional, Union
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.integrations.datadog.datadog_handler import (
get_datadog_env,
get_datadog_hostname,
get_datadog_pod_name,
get_datadog_service,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.datadog_metrics import (
DatadogMetricPoint,
DatadogMetricSeries,
DatadogMetricsPayload,
)
from litellm.types.utils import StandardLoggingPayload
class DatadogMetricsLogger(CustomBatchLogger):
    """
    Batched logger that ships LiteLLM latency and request-count metrics to
    Datadog's v2 metrics API (``POST /api/v2/series``).

    Environment configuration:
        DD_API_KEY: required; without it uploads are skipped.
        DD_APP_KEY: optional application key.
        DD_SITE: Datadog site, defaults to ``datadoghq.com``.
        DD_VERSION: used for the ``version:`` tag.

    Metric points are queued and flushed either periodically (every 5 seconds
    by default) or when the queue reaches the configured batch size.
    """

    def __init__(self, start_periodic_flush: bool = True, **kwargs):
        """
        Args:
            start_periodic_flush: when True, spawn the background flush task.
                Requires a running asyncio event loop at construction time.
            **kwargs: forwarded to ``CustomBatchLogger``; ``flush_lock`` and
                ``flush_interval`` are defaulted here if not supplied.
        """
        self.dd_api_key = os.getenv("DD_API_KEY")
        self.dd_app_key = os.getenv("DD_APP_KEY")
        self.dd_site = os.getenv("DD_SITE", "datadoghq.com")
        if not self.dd_api_key:
            # Don't raise: the integration degrades to a no-op uploader
            # (see _upload_to_datadog) so the proxy still starts.
            verbose_logger.warning(
                "Datadog Metrics: DD_API_KEY is required. Integration will not work."
            )
        self.upload_url = f"https://api.{self.dd_site}/api/v2/series"
        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        # Initialize lock
        self.flush_lock = asyncio.Lock()
        # Only set flush_lock if not already provided by caller
        if "flush_lock" not in kwargs:
            kwargs["flush_lock"] = self.flush_lock
        # Send metrics more quickly to datadog (every 5 seconds)
        if "flush_interval" not in kwargs:
            kwargs["flush_interval"] = 5
        super().__init__(**kwargs)
        # Start periodic flush task only if instructed
        if start_periodic_flush:
            asyncio.create_task(self.periodic_flush())

    def _extract_tags(
        self,
        log: StandardLoggingPayload,
        status_code: Optional[Union[str, int]] = None,
    ) -> List[str]:
        """
        Builds the list of tags for a Datadog metric point.

        Args:
            log: the standard logging payload for the request.
            status_code: HTTP-style status code to tag the point with;
                omitted from the tags when None.
        """
        # Base tags
        # NOTE(review): upper-case tag keys (HOSTNAME/POD_NAME) are unusual
        # for Datadog tags (conventionally lowercase) — confirm intended.
        tags = [
            f"env:{get_datadog_env()}",
            f"service:{get_datadog_service()}",
            f"version:{os.getenv('DD_VERSION', 'unknown')}",
            f"HOSTNAME:{get_datadog_hostname()}",
            f"POD_NAME:{get_datadog_pod_name()}",
        ]
        # Add metric-specific tags
        if provider := log.get("custom_llm_provider"):
            tags.append(f"provider:{provider}")
        if model := log.get("model"):
            tags.append(f"model_name:{model}")
        if model_group := log.get("model_group"):
            tags.append(f"model_group:{model_group}")
        if status_code is not None:
            tags.append(f"status_code:{status_code}")
        # Extract team tag — first non-empty of alias/id from either the
        # user_api_key_* or plain metadata keys.
        metadata = log.get("metadata", {}) or {}
        team_tag = (
            metadata.get("user_api_key_team_alias")
            or metadata.get("team_alias")  # type: ignore
            or metadata.get("user_api_key_team_id")
            or metadata.get("team_id")  # type: ignore
        )
        if team_tag:
            tags.append(f"team:{team_tag}")
        return tags

    def _add_metrics_from_log(
        self,
        log: StandardLoggingPayload,
        kwargs: dict,
        status_code: Union[str, int] = "200",
    ):
        """
        Extracts latencies and appends Datadog metric series to the queue.

        Emits up to three series per request:
          1. litellm.request.total_latency (gauge) — end-to-end duration
          2. litellm.llm_api.latency (gauge) — provider call duration
          3. litellm.llm_api.request_count (count) — always emitted
        """
        tags = self._extract_tags(log, status_code=status_code)
        # We record metrics with the end_time as the timestamp for the point
        end_time_dt = kwargs.get("end_time") or datetime.now()
        timestamp = int(end_time_dt.timestamp())
        # 1. Total Request Latency Metric (End to End)
        start_time_dt = kwargs.get("start_time")
        if start_time_dt and end_time_dt:
            total_duration = (end_time_dt - start_time_dt).total_seconds()
            series_total_latency: DatadogMetricSeries = {
                "metric": "litellm.request.total_latency",
                "type": 3,  # gauge
                "points": [{"timestamp": timestamp, "value": total_duration}],
                "tags": tags,
            }
            self.log_queue.append(series_total_latency)
        # 2. LLM API Latency Metric (Provider alone)
        api_call_start_time = kwargs.get("api_call_start_time")
        if api_call_start_time and end_time_dt:
            llm_api_duration = (end_time_dt - api_call_start_time).total_seconds()
            series_llm_latency: DatadogMetricSeries = {
                "metric": "litellm.llm_api.latency",
                "type": 3,  # gauge
                "points": [{"timestamp": timestamp, "value": llm_api_duration}],
                "tags": tags,
            }
            self.log_queue.append(series_llm_latency)
        # 3. Request Count / Status Code
        # Count metrics carry an interval so Datadog can normalize to a rate.
        series_count: DatadogMetricSeries = {
            "metric": "litellm.llm_api.request_count",
            "type": 1,  # count
            "points": [{"timestamp": timestamp, "value": 1.0}],
            "tags": tags,
            "interval": self.flush_interval,
        }
        self.log_queue.append(series_count)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue metrics for a successful request; flush if the batch is full."""
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if standard_logging_object is None:
                return
            self._add_metrics_from_log(
                log=standard_logging_object, kwargs=kwargs, status_code="200"
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception as e:
            # Logging must never break the request path.
            verbose_logger.exception(
                f"Datadog Metrics: Error in async_log_success_event: {str(e)}"
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Queue metrics for a failed request, tagged with the error's status code."""
        try:
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if standard_logging_object is None:
                return
            # Extract status code from error information
            status_code = "500"  # default
            error_information = (
                standard_logging_object.get("error_information", {}) or {}
            )
            error_code = error_information.get("error_code")  # type: ignore
            if error_code is not None:
                status_code = str(error_code)
            self._add_metrics_from_log(
                log=standard_logging_object, kwargs=kwargs, status_code=status_code
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception as e:
            # Logging must never break the request path.
            verbose_logger.exception(
                f"Datadog Metrics: Error in async_log_failure_event: {str(e)}"
            )

    async def async_send_batch(self):
        """Upload the queued metric series to Datadog in a single payload."""
        if not self.log_queue:
            return
        batch = self.log_queue.copy()
        payload_data: DatadogMetricsPayload = {"series": batch}
        try:
            await self._upload_to_datadog(payload_data)
        except Exception as e:
            verbose_logger.exception(
                f"Datadog Metrics: Error in async_send_batch: {str(e)}"
            )
            # Re-raise so the batching base class can handle the failure.
            raise

    async def _upload_to_datadog(self, payload: DatadogMetricsPayload):
        """
        POST a gzip-compressed metrics payload to the v2 series endpoint.

        No-op when DD_API_KEY is missing. Raises on non-2xx responses.
        """
        if not self.dd_api_key:
            return
        headers = {
            "Content-Type": "application/json",
            "DD-API-KEY": self.dd_api_key,
        }
        if self.dd_app_key:
            headers["DD-APPLICATION-KEY"] = self.dd_app_key
        json_data = safe_dumps(payload)
        # Compress to reduce payload size; Datadog accepts gzip intake.
        compressed_data = gzip.compress(json_data.encode("utf-8"))
        headers["Content-Encoding"] = "gzip"
        response = await self.async_client.post(
            self.upload_url, content=compressed_data, headers=headers  # type: ignore
        )
        response.raise_for_status()
        verbose_logger.debug(
            f"Datadog Metrics: Uploaded {len(payload['series'])} metric points. Status: {response.status_code}"
        )

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        """
        Check if the service is healthy by sending one test metric point.
        """
        try:
            # Send a test metric point to Datadog
            test_metric_point: DatadogMetricPoint = {
                "timestamp": int(time.time()),
                "value": 1.0,
            }
            test_metric_series: DatadogMetricSeries = {
                "metric": "litellm.health_check",
                "type": 3,  # Gauge
                "points": [test_metric_point],
                "tags": ["env:health_check"],
            }
            payload_data: DatadogMetricsPayload = {"series": [test_metric_series]}
            await self._upload_to_datadog(payload_data)
            return IntegrationHealthCheckStatus(
                status="healthy",
                error_message=None,
            )
        except Exception as e:
            return IntegrationHealthCheckStatus(
                status="unhealthy",
                error_message=str(e),
            )

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """Unimplemented hook from the base interface; always returns None."""
        pass

View File

@@ -0,0 +1,33 @@
"""
Mock client for Datadog integration testing.
This module intercepts Datadog API calls and returns successful mock responses,
allowing full code execution without making actual network calls.
Usage:
Set DATADOG_MOCK=true in environment variables or config to enable mock mode.
"""
from litellm.integrations.mock_client_factory import (
MockClientConfig,
create_mock_client_factory,
)
# Create mock client using factory
_config = MockClientConfig(
name="DATADOG",
env_var="DATADOG_MOCK",
default_latency_ms=100,
default_status_code=202,
default_json_data={"status": "ok"},
url_matchers=[
".datadoghq.com",
"datadoghq.com",
],
patch_async_handler=True,
patch_sync_client=True,
)
create_mock_datadog_client, should_use_datadog_mock = create_mock_client_factory(
_config
)

View File

@@ -0,0 +1,3 @@
from .deepeval import DeepEvalLogger
__all__ = ["DeepEvalLogger"]

View File

@@ -0,0 +1,120 @@
# duplicate -> https://github.com/confident-ai/deepeval/blob/main/deepeval/confident/api.py
import logging
import httpx
from enum import Enum
from litellm._logging import verbose_logger
# Confident AI / DeepEval endpoints (non-EU and EU regions).
DEEPEVAL_BASE_URL = "https://deepeval.confident-ai.com"
DEEPEVAL_BASE_URL_EU = "https://eu.deepeval.confident-ai.com"
API_BASE_URL = "https://api.confident-ai.com"
API_BASE_URL_EU = "https://eu.api.confident-ai.com"
# Exception type treated as retryable when talking to the API.
retryable_exceptions = httpx.HTTPError
from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
def log_retry_error(details):
    """Log a retry attempt.

    Args:
        details: backoff-style details dict; reads the optional ``exception``
            and ``tries`` keys.
    """
    exception = details.get("exception")
    tries = details.get("tries")
    if not exception:
        logging.error(f"Retrying: {tries} time(s)...")
    else:
        logging.error(f"Confident AI Error: {exception}. Retrying: {tries} time(s)...")
class HttpMethods(Enum):
    """HTTP verbs understood by the Api client (only POST is implemented)."""

    GET = "GET"
    POST = "POST"
    DELETE = "DELETE"
    PUT = "PUT"
class Endpoints(Enum):
    """Path suffixes of the Confident AI REST API, appended to the base URL."""

    DATASET_ENDPOINT = "/v1/dataset"
    TEST_RUN_ENDPOINT = "/v1/test-run"
    TRACING_ENDPOINT = "/v1/tracing"
    EVENT_ENDPOINT = "/v1/event"
    FEEDBACK_ENDPOINT = "/v1/feedback"
    PROMPT_ENDPOINT = "/v1/prompt"
    RECOMMEND_ENDPOINT = "/v1/recommend-metrics"
    EVALUATE_ENDPOINT = "/evaluate"
    GUARD_ENDPOINT = "/guard"
    GUARDRAILS_ENDPOINT = "/guardrails"
    BASELINE_ATTACKS_ENDPOINT = "/generate-baseline-attacks"
class Api:
    """
    Minimal HTTP client for the Confident AI (DeepEval) REST API.

    Only POST requests are supported. Authentication is via the
    ``CONFIDENT_API_KEY`` header.
    """

    def __init__(self, api_key: str, base_url=None):
        """
        Args:
            api_key: Confident AI API key, sent in the CONFIDENT_API_KEY header.
            base_url: override for the API base URL (defaults to the non-EU
                API_BASE_URL constant).
        """
        self.api_key = api_key
        self._headers = {
            "Content-Type": "application/json",
            # "User-Agent": "Python/Requests",
            "CONFIDENT_API_KEY": api_key,
        }
        # using the global non-eu variable for base url
        self.base_api_url = base_url or API_BASE_URL
        self.sync_http_handler = HTTPHandler()
        self.async_http_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

    def _http_request(
        self, method: str, url: str, headers=None, json=None, params=None
    ):
        """
        Send a synchronous POST request and return the httpx response.

        Bug fix: this method previously dropped the response (no return),
        which made send_request crash with AttributeError on None.

        Raises:
            Exception: for non-POST methods or on an HTTP error status.
        """
        if method != "POST":
            raise Exception("Only POST requests are supported")
        try:
            return self.sync_http_handler.post(
                url=url,
                headers=headers,
                json=json,
                params=params,
            )
        except httpx.HTTPStatusError as e:
            raise Exception(f"DeepEval logging error: {e.response.text}")

    def send_request(
        self, method: HttpMethods, endpoint: Endpoints, body=None, params=None
    ):
        """
        POST ``body`` to ``endpoint`` and return the parsed response.

        Returns:
            Parsed JSON on HTTP 200 (falls back to raw text if the body is
            not valid JSON).

        Raises:
            Exception: on any non-200 response, carrying the server's error.
        """
        url = f"{self.base_api_url}{endpoint.value}"
        res = self._http_request(
            method=method.value,
            url=url,
            headers=self._headers,
            json=body,
            params=params,
        )
        if res.status_code == 200:
            try:
                return res.json()
            except ValueError:
                # Non-JSON 200 response body.
                return res.text
        else:
            verbose_logger.debug(res.json())
            raise Exception(res.json().get("error", res.text))

    async def a_send_request(
        self, method: HttpMethods, endpoint: Endpoints, body=None, params=None
    ):
        """
        Async variant of send_request; returns the raw httpx response so
        callers can log it.

        Raises:
            Exception: for non-POST methods or on an HTTP error status.
        """
        if method != HttpMethods.POST:
            raise Exception("Only POST requests are supported")
        url = f"{self.base_api_url}{endpoint.value}"
        try:
            return await self.async_http_handler.post(
                url=url,
                headers=self._headers,
                json=body,
                params=params,
            )
        except httpx.HTTPStatusError as e:
            raise Exception(f"DeepEval logging error: {e.response.text}")

View File

@@ -0,0 +1,175 @@
import os
from litellm._uuid import uuid
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.deepeval.api import Api, Endpoints, HttpMethods
from litellm.integrations.deepeval.types import (
BaseApiSpan,
SpanApiType,
TraceApi,
TraceSpanApiStatus,
)
from litellm.integrations.deepeval.utils import (
to_zod_compatible_iso,
validate_environment,
)
from litellm._logging import verbose_logger
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class DeepEvalLogger(CustomLogger):
    """Logs litellm traces to DeepEval's (Confident AI) platform.

    Requires ``CONFIDENT_API_KEY`` in the environment. The deployment
    environment is read from ``LITELM_ENVIRONMENT`` (default ``development``)
    and validated against the supported Environment values.
    """

    def __init__(self, *args, **kwargs):
        api_key = os.getenv("CONFIDENT_API_KEY")
        # NOTE(review): "LITELM_ENVIRONMENT" looks like a typo of
        # "LITELLM_ENVIRONMENT" but is kept as-is for backward compatibility.
        self.litellm_environment = os.getenv("LITELM_ENVIRONMENT", "development")
        validate_environment(self.litellm_environment)
        if not api_key:
            raise ValueError(
                "Please set 'CONFIDENT_API_KEY=<>' in your environment variables."
            )
        self.api = Api(api_key=api_key)
        super().__init__(*args, **kwargs)

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Logs a success event to DeepEval's platform."""
        self._sync_event_handler(
            kwargs, response_obj, start_time, end_time, is_success=True
        )

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Logs a failure event to DeepEval's platform."""
        self._sync_event_handler(
            kwargs, response_obj, start_time, end_time, is_success=False
        )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Logs a failure event to DeepEval's platform (async)."""
        await self._async_event_handler(
            kwargs, response_obj, start_time, end_time, is_success=False
        )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Logs a success event to DeepEval's platform (async)."""
        await self._async_event_handler(
            kwargs, response_obj, start_time, end_time, is_success=True
        )

    def _prepare_trace_api(
        self, kwargs, response_obj, start_time, end_time, is_success
    ):
        """Build the request body (plain dict) for the tracing endpoint."""
        _start_time = to_zod_compatible_iso(start_time)
        _end_time = to_zod_compatible_iso(end_time)
        _standard_logging_object = kwargs.get("standard_logging_object", {})
        base_api_span = self._create_base_api_span(
            kwargs,
            standard_logging_object=_standard_logging_object,
            start_time=_start_time,
            end_time=_end_time,
            is_success=is_success,
        )
        trace_api = self._create_trace_api(
            base_api_span,
            standard_logging_object=_standard_logging_object,
            start_time=_start_time,
            end_time=_end_time,
            litellm_environment=self.litellm_environment,
        )
        try:
            return trace_api.model_dump(by_alias=True, exclude_none=True)
        except AttributeError:
            # Pydantic version below 2.0
            return trace_api.dict(by_alias=True, exclude_none=True)

    def _sync_event_handler(
        self, kwargs, response_obj, start_time, end_time, is_success
    ):
        """Prepare and synchronously send a single trace."""
        body = self._prepare_trace_api(
            kwargs, response_obj, start_time, end_time, is_success
        )
        response = self.api.send_request(
            method=HttpMethods.POST,
            endpoint=Endpoints.TRACING_ENDPOINT,
            body=body,
        )
        # Bug fix: previous message said "sync_log_failure_event" even for
        # success events.
        verbose_logger.debug(
            "DeepEvalLogger: _sync_event_handler: Api response %s", response
        )

    async def _async_event_handler(
        self, kwargs, response_obj, start_time, end_time, is_success
    ):
        """Prepare and asynchronously send a single trace."""
        body = self._prepare_trace_api(
            kwargs, response_obj, start_time, end_time, is_success
        )
        response = await self.api.a_send_request(
            method=HttpMethods.POST,
            endpoint=Endpoints.TRACING_ENDPOINT,
            body=body,
        )
        verbose_logger.debug(
            "DeepEvalLogger: async_event_handler: Api response %s", response
        )

    def _create_base_api_span(
        self, kwargs, standard_logging_object, start_time, end_time, is_success
    ):
        """Build the LLM span for a single request.

        Robustness fixes versus the original: tolerates ``response`` /
        ``choices`` / ``message`` being None or empty, and always passes
        string UUIDs to the pydantic model (``uuid.uuid4()`` returns a UUID
        object, which fails ``str`` field validation).
        """
        response = standard_logging_object.get("response") or {}
        if not isinstance(response, dict):
            response = {}
        # extract usage (token counts are only meaningful on success)
        usage = response.get("usage") or {}
        if is_success:
            choices = response.get("choices") or [{}]
            first_choice = choices[0] if choices else {}
            message = (
                first_choice.get("message") if isinstance(first_choice, dict) else None
            ) or {}
            output = message.get("content", "NO_OUTPUT")
        else:
            output = str(standard_logging_object.get("error_string", ""))
        return BaseApiSpan(
            uuid=str(standard_logging_object.get("id") or uuid.uuid4()),
            name=(
                "litellm_success_callback" if is_success else "litellm_failure_callback"
            ),
            status=(
                TraceSpanApiStatus.SUCCESS if is_success else TraceSpanApiStatus.ERRORED
            ),
            type=SpanApiType.LLM,
            traceUuid=str(standard_logging_object.get("trace_id") or uuid.uuid4()),
            startTime=str(start_time),
            endTime=str(end_time),
            input=kwargs.get("input", "NO_INPUT"),
            output=output,
            model=standard_logging_object.get("model", None),
            inputTokenCount=usage.get("prompt_tokens", None) if is_success else None,
            outputTokenCount=(
                usage.get("completion_tokens", None) if is_success else None
            ),
        )

    def _create_trace_api(
        self,
        base_api_span,
        standard_logging_object,
        start_time,
        end_time,
        litellm_environment,
    ):
        """Wrap the span in a TraceApi payload for the tracing endpoint."""
        return TraceApi(
            uuid=str(standard_logging_object.get("trace_id") or uuid.uuid4()),
            baseSpans=[],
            agentSpans=[],
            llmSpans=[base_api_span],
            retrieverSpans=[],
            toolSpans=[],
            startTime=str(start_time),
            endTime=str(end_time),
            environment=litellm_environment,
        )

View File

@@ -0,0 +1,63 @@
# Duplicate -> https://github.com/confident-ai/deepeval/blob/main/deepeval/tracing/api.py
from enum import Enum
from typing import Any, ClassVar, Dict, List, Optional, Union, Literal
from pydantic import BaseModel, Field, ConfigDict
class SpanApiType(Enum):
    """Span categories recognized by the DeepEval tracing API."""

    BASE = "base"
    AGENT = "agent"
    LLM = "llm"
    RETRIEVER = "retriever"
    TOOL = "tool"


# Literal mirror of the SpanApiType values, for type annotations.
span_api_type_literals = Literal["base", "agent", "llm", "retriever", "tool"]
class TraceSpanApiStatus(Enum):
    """Terminal status of a trace span."""

    SUCCESS = "SUCCESS"
    ERRORED = "ERRORED"
class BaseApiSpan(BaseModel):
    """
    A single span in a DeepEval trace (duplicated from deepeval's API model).

    Field aliases are camelCase to match the DeepEval REST schema; callers
    serialize with ``by_alias=True``.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(use_enum_values=True)
    uuid: str
    name: Optional[str] = None
    status: TraceSpanApiStatus
    type: SpanApiType
    trace_uuid: str = Field(alias="traceUuid")
    parent_uuid: Optional[str] = Field(None, alias="parentUuid")
    start_time: str = Field(alias="startTime")
    end_time: str = Field(alias="endTime")
    input: Optional[Union[Dict, list, str]] = None
    output: Optional[Union[Dict, list, str]] = None
    error: Optional[str] = None
    # llm-specific fields
    model: Optional[str] = None
    input_token_count: Optional[int] = Field(None, alias="inputTokenCount")
    output_token_count: Optional[int] = Field(None, alias="outputTokenCount")
    cost_per_input_token: Optional[float] = Field(None, alias="costPerInputToken")
    cost_per_output_token: Optional[float] = Field(None, alias="costPerOutputToken")
class TraceApi(BaseModel):
    """
    Top-level trace payload for the DeepEval tracing endpoint, grouping spans
    by kind. Aliases are camelCase; callers serialize with ``by_alias=True``.
    """

    uuid: str
    base_spans: List[BaseApiSpan] = Field(alias="baseSpans")
    agent_spans: List[BaseApiSpan] = Field(alias="agentSpans")
    llm_spans: List[BaseApiSpan] = Field(alias="llmSpans")
    retriever_spans: List[BaseApiSpan] = Field(alias="retrieverSpans")
    tool_spans: List[BaseApiSpan] = Field(alias="toolSpans")
    start_time: str = Field(alias="startTime")
    end_time: str = Field(alias="endTime")
    metadata: Optional[Dict[str, Any]] = Field(None)
    tags: Optional[List[str]] = Field(None)
    environment: Optional[str] = Field(None)
class Environment(Enum):
    """Deployment environments accepted by ``validate_environment``."""

    PRODUCTION = "production"
    DEVELOPMENT = "development"
    STAGING = "staging"

View File

@@ -0,0 +1,18 @@
from datetime import datetime, timezone
from litellm.integrations.deepeval.types import Environment
def to_zod_compatible_iso(dt: datetime) -> str:
    """Format *dt* as a UTC ISO-8601 string with millisecond precision and a
    trailing ``Z`` (the form DeepEval's zod schema accepts).

    Note: naive datetimes are interpreted in the local timezone by
    ``astimezone`` before conversion to UTC.
    """
    utc_dt = dt.astimezone(timezone.utc)
    iso = utc_dt.isoformat(timespec="milliseconds")
    return iso.replace("+00:00", "Z")
def validate_environment(environment: str):
    """Raise ValueError if *environment* is not a recognized Environment value."""
    allowed = [env.value for env in Environment]
    if environment in allowed:
        return
    valid_values = ", ".join(f'"{value}"' for value in allowed)
    raise ValueError(
        f"Invalid environment: {environment}. Please use one of the following instead: {valid_values}"
    )

View File

@@ -0,0 +1,316 @@
# LiteLLM Dotprompt Manager
A powerful prompt management system for LiteLLM that supports [Google's Dotprompt specification](https://google.github.io/dotprompt/getting-started/). This allows you to manage your AI prompts in organized `.prompt` files with YAML frontmatter, Handlebars templating, and full integration with LiteLLM's completion API.
## Features
- **📁 File-based prompt management**: Organize prompts in `.prompt` files
- **🎯 YAML frontmatter**: Define model, parameters, and schemas in file headers
- **🔧 Handlebars templating**: Use `{{variable}}` syntax with Jinja2 backend
- **✅ Input validation**: Automatic validation against defined schemas
- **🔗 LiteLLM integration**: Works seamlessly with `litellm.completion()`
- **💬 Smart message parsing**: Converts prompts to proper chat messages
- **⚙️ Parameter extraction**: Automatically applies model settings from prompts
## Quick Start
### 1. Create a `.prompt` file
Create a file called `chat_assistant.prompt`:
```yaml
---
model: gpt-4
temperature: 0.7
max_tokens: 150
input:
schema:
user_message: string
system_context?: string
---
{% if system_context %}System: {{system_context}}
{% endif %}User: {{user_message}}
```
### 2. Use with LiteLLM
```python
import litellm
litellm.set_global_prompt_directory("path/to/your/prompts")
# Use with completion - the model prefix 'dotprompt/' tells LiteLLM to use prompt management
response = litellm.completion(
model="dotprompt/gpt-4", # The actual model comes from the .prompt file
prompt_id="chat_assistant",
prompt_variables={
"user_message": "What is machine learning?",
"system_context": "You are a helpful AI tutor."
},
# Any additional messages will be appended after the prompt
messages=[{"role": "user", "content": "Please explain it simply."}]
)
print(response.choices[0].message.content)
```
## Prompt File Format
### Basic Structure
```yaml
---
# Model configuration
model: gpt-4
temperature: 0.7
max_tokens: 500
# Input schema (optional)
input:
schema:
name: string
age: integer
preferences?: array
---
# Template content using Handlebars-style {{variable}} syntax (Jinja2 control tags)
Hello {{name}}!
{% if age >= 18 %}
You're an adult, so here are some mature recommendations:
{% else %}
Here are some age-appropriate suggestions:
{% endif %}
{% for pref in preferences %}
- Based on your interest in {{pref}}, I recommend...
{% endfor %}
```
### Supported Frontmatter Fields
- **`model`**: The LLM model to use (e.g., `gpt-4`, `claude-3-sonnet`)
- **`input.schema`**: Define expected input variables and their types
- **`output.format`**: Expected output format (`json`, `text`, etc.)
- **`output.schema`**: Structure of expected output
### Additional Parameters
- **`temperature`**: Model temperature (typically 0.0 to 2.0, model-dependent)
- **`max_tokens`**: Maximum tokens to generate
- **`top_p`**: Nucleus sampling parameter (0.0 to 1.0)
- **`frequency_penalty`**: Frequency penalty (-2.0 to 2.0)
- **`presence_penalty`**: Presence penalty (-2.0 to 2.0)
- any other parameters that are not model or schema-related will be treated as optional parameters to the model.
### Input Schema Types
- `string` or `str`: Text values
- `integer` or `int`: Whole numbers
- `float`: Decimal numbers
- `boolean` or `bool`: True/false values
- `array` or `list`: Lists of values
- `object` or `dict`: Key-value objects
Use `?` suffix for optional fields: `name?: string`
## Message Format Conversion
The dotprompt manager intelligently converts your rendered prompts into proper chat messages:
### Simple Text → User Message
```yaml
---
model: gpt-4
---
Tell me about {{topic}}.
```
Becomes: `[{"role": "user", "content": "Tell me about AI."}]`
### Role-Based Format → Multiple Messages
```yaml
---
model: gpt-4
---
System: You are a {{role}}.
User: {{question}}
```
Becomes:
```python
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is AI?"}
]
```
## Example Prompts
### Data Extraction
```yaml
# extract_info.prompt
---
model: gemini/gemini-1.5-pro
input:
schema:
text: string
output:
format: json
schema:
title?: string
summary: string
tags: array
---
Extract the requested information from the given text. Return JSON format.
Text: {{text}}
```
### Code Assistant
```yaml
# code_helper.prompt
---
model: claude-3-5-sonnet-20241022
temperature: 0.2
max_tokens: 2000
input:
schema:
language: string
task: string
code?: string
---
You are an expert {{language}} programmer.
Task: {{task}}
{% if code %}
Current code:
```{{language}}
{{code}}
```
{% endif %}
Please provide a complete, well-documented solution.
```
### Multi-turn Conversation
```yaml
# conversation.prompt
---
model: gpt-4
temperature: 0.8
input:
schema:
personality: string
context: string
---
System: You are a {{personality}}. {{context}}
User: Let's start our conversation.
```
## API Reference
### PromptManager
The core class for managing `.prompt` files.
#### Methods
- **`__init__(prompt_directory: str)`**: Initialize with directory path
- **`render(prompt_id: str, variables: dict) -> str`**: Render prompt with variables
- **`list_prompts() -> List[str]`**: Get all available prompt IDs
- **`get_prompt(prompt_id: str) -> PromptTemplate`**: Get prompt template object
- **`get_prompt_metadata(prompt_id: str) -> dict`**: Get prompt metadata
- **`reload_prompts() -> None`**: Reload all prompts from directory
- **`add_prompt(prompt_id: str, content: str, metadata: dict)`**: Add prompt programmatically
### DotpromptManager
LiteLLM integration class extending `PromptManagementBase`.
#### Methods
- **`__init__(prompt_directory: str)`**: Initialize with directory path
- **`should_run_prompt_management(prompt_id: str, params: dict) -> bool`**: Check if prompt exists
- **`set_prompt_directory(directory: str)`**: Change prompt directory
- **`reload_prompts()`**: Reload prompts from directory
### PromptTemplate
Represents a single prompt with metadata.
#### Properties
- **`content: str`**: The prompt template content
- **`metadata: dict`**: Full metadata from frontmatter
- **`model: str`**: Specified model name
- **`temperature: float`**: Model temperature
- **`max_tokens: int`**: Token limit
- **`input_schema: dict`**: Input validation schema
- **`output_format: str`**: Expected output format
- **`output_schema: dict`**: Output structure schema
## Best Practices
1. **Organize by purpose**: Group related prompts in subdirectories
2. **Use descriptive names**: `extract_user_info.prompt` vs `prompt1.prompt`
3. **Define schemas**: Always specify input schemas for validation
4. **Version control**: Store `.prompt` files in git for change tracking
5. **Test prompts**: Use the test framework to validate prompt behavior
6. **Keep templates focused**: One prompt should do one thing well
7. **Use includes**: Break complex prompts into reusable components
## Troubleshooting
### Common Issues
**Prompt not found**: Ensure the `.prompt` file exists and has correct extension
```python
# Check available prompts
from litellm.integrations.dotprompt import get_dotprompt_manager
manager = get_dotprompt_manager()
print(manager.prompt_manager.list_prompts())
```
**Template errors**: Verify Handlebars syntax and variable names
```python
# Test rendering directly
manager.prompt_manager.render("my_prompt", {"test": "value"})
```
**Model not working**: Check that model name in frontmatter is correct
```python
# Check prompt metadata
metadata = manager.prompt_manager.get_prompt_metadata("my_prompt")
print(metadata)
```
### Validation Errors
Input validation failures show helpful error messages:
```
ValueError: Invalid type for field 'age': expected int, got str
```
Make sure your variables match the defined schema types.
## Contributing
The LiteLLM Dotprompt manager follows the [Dotprompt specification](https://google.github.io/dotprompt/) for maximum compatibility. When contributing:
1. Ensure compatibility with existing `.prompt` files
2. Add tests for new features
3. Update documentation
4. Follow the existing code style
## License
This prompt management system is part of LiteLLM and follows the same license terms.

View File

@@ -0,0 +1,91 @@
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from .prompt_manager import PromptManager, PromptTemplate
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
from .dotprompt_manager import DotpromptManager
# Global instances
global_prompt_directory: Optional[str] = None
global_prompt_manager: Optional["PromptManager"] = None
def set_global_prompt_directory(directory: str) -> None:
    """Register *directory* as the process-wide source of ``.prompt`` files.

    Args:
        directory: Filesystem path of the folder holding ``.prompt`` files.
    """
    import litellm

    litellm.global_prompt_directory = directory  # type: ignore
def _get_prompt_data_from_dotprompt_content(dotprompt_content: str) -> dict:
    """
    Parse raw dotprompt text into prompt-manager format.

    The UI stores prompts under `dotprompt_content` in the database. This
    function splits that text into YAML frontmatter + template body and
    returns ``{"content": ..., "metadata": ...}`` as expected by the
    prompt manager.
    """
    from .prompt_manager import PromptManager

    # A throwaway manager instance gives us access to the frontmatter parser.
    parser = PromptManager()
    frontmatter, body = parser._parse_frontmatter(dotprompt_content)
    return {"content": body.strip(), "metadata": frontmatter}
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Initialize a dotprompt-backed prompt manager from litellm params.

    The prompt may be supplied as a single ``.prompt`` file (``prompt_file``),
    pre-parsed data (``prompt_data``), or raw dotprompt text stored by the UI
    (``dotprompt_content``), in that order of precedence.

    Args:
        litellm_params: Params carrying prompt_file / prompt_data /
            dotprompt_content / prompt_id.
        prompt_spec: The full prompt spec (currently unused here; kept for
            registry-signature compatibility).

    Returns:
        A configured DotpromptManager.

    Raises:
        ValueError: If ``prompt_directory`` is set — this initializer handles
            a single prompt, not a directory of prompts.
    """
    prompt_directory = getattr(litellm_params, "prompt_directory", None)
    prompt_data = getattr(litellm_params, "prompt_data", None)
    prompt_id = getattr(litellm_params, "prompt_id", None)
    if prompt_directory:
        raise ValueError(
            "Cannot set prompt_directory when working with prompt_initializer. Needs to be a specific dotprompt file"
        )
    prompt_file = getattr(litellm_params, "prompt_file", None)
    # Handle dotprompt_content from database: parse it into prompt_data,
    # unless an explicit prompt_data / prompt_file already takes precedence.
    dotprompt_content = getattr(litellm_params, "dotprompt_content", None)
    if dotprompt_content and not prompt_data and not prompt_file:
        prompt_data = _get_prompt_data_from_dotprompt_content(dotprompt_content)
    # Fix: the previous `try: ... except Exception as e: raise e` wrapper was a
    # no-op that only obscured tracebacks; construct and return directly.
    return DotpromptManager(
        prompt_directory=prompt_directory,
        prompt_data=prompt_data,
        prompt_file=prompt_file,
        prompt_id=prompt_id,
    )
# Maps integration name -> initializer callable; consulted when litellm
# resolves which backend should construct a prompt (only dotprompt here).
prompt_initializer_registry = {
    SupportedPromptIntegrations.DOT_PROMPT.value: prompt_initializer,
}
# Export public API
__all__ = [
    "PromptManager",
    "DotpromptManager",
    "PromptTemplate",
    "set_global_prompt_directory",
    "global_prompt_directory",
    "global_prompt_manager",
]

View File

@@ -0,0 +1,378 @@
"""
Dotprompt manager that integrates with LiteLLM's prompt management system.
Builds on top of PromptManagementBase to provide .prompt file support.
"""
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.integrations.prompt_management_base import PromptManagementClient
from litellm.types.llms.openai import AllMessageValues
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
from .prompt_manager import PromptManager, PromptTemplate
class DotpromptManager(CustomPromptManagement):
    """
    Dotprompt manager that integrates with LiteLLM's prompt management system.
    This class enables using .prompt files with the litellm completion() function
    by implementing the PromptManagementBase interface.
    Usage:
        # Set global prompt directory
        litellm.prompt_directory = "path/to/prompts"
        # Use with completion
        response = litellm.completion(
            model="dotprompt/gpt-4",
            prompt_id="my_prompt",
            prompt_variables={"variable": "value"},
            messages=[{"role": "user", "content": "This will be combined with the prompt"}]
        )
    """

    def __init__(
        self,
        prompt_directory: Optional[str] = None,
        prompt_file: Optional[str] = None,
        prompt_data: Optional[Union[dict, str]] = None,
        prompt_id: Optional[str] = None,
    ):
        """
        Args:
            prompt_directory: Directory of .prompt files. Falls back to
                litellm.global_prompt_directory when not provided.
            prompt_file: Path to a single .prompt file.
            prompt_data: Prompt definition(s) as a dict, or a JSON string
                (e.g. when loaded from a database row).
            prompt_id: ID used to register the prompt when prompt_file /
                prompt_data describes a single prompt.
        """
        import litellm

        self.prompt_directory = prompt_directory or litellm.global_prompt_directory
        # Support for JSON-based prompts stored in memory/database
        if isinstance(prompt_data, str):
            self.prompt_data = json.loads(prompt_data)
        else:
            self.prompt_data = prompt_data or {}
        # Built lazily by the `prompt_manager` property on first access.
        self._prompt_manager: Optional[PromptManager] = None
        self.prompt_file = prompt_file
        self.prompt_id = prompt_id

    @property
    def integration_name(self) -> str:
        """Integration name used in model names like 'dotprompt/gpt-4'."""
        return "dotprompt"

    @property
    def prompt_manager(self) -> PromptManager:
        """Lazy-load the prompt manager.

        Raises:
            ValueError: If no prompt source (directory, data, or file) was
                configured before first use.
        """
        if self._prompt_manager is None:
            if (
                self.prompt_directory is None
                and not self.prompt_data
                and not self.prompt_file
            ):
                raise ValueError(
                    "Either prompt_directory or prompt_data must be set before using dotprompt manager. "
                    "Set litellm.global_prompt_directory, initialize with prompt_directory parameter, or provide prompt_data."
                )
            self._prompt_manager = PromptManager(
                prompt_directory=self.prompt_directory,
                prompt_data=self.prompt_data,
                prompt_file=self.prompt_file,
                prompt_id=self.prompt_id,
            )
        return self._prompt_manager

    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """
        Determine if prompt management should run based on the prompt_id.
        Returns True if the prompt_id exists in our prompt manager.
        """
        if prompt_id is None:
            return False
        try:
            return prompt_id in self.prompt_manager.list_prompts()
        except Exception:
            # If there's any error accessing prompts, don't run prompt management
            return False

    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Compile a .prompt file into a PromptManagementClient structure.
        This method:
        1. Loads the prompt template from the .prompt file (with optional version)
        2. Renders it with the provided variables
        3. Converts the rendered text into chat messages
        4. Extracts model and optional parameters from metadata

        Raises:
            ValueError: If prompt_id is missing, the prompt is not found, or
                rendering/compilation fails (original error text is embedded).
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for dotprompt manager")
        try:
            # Get the prompt template (versioned or base)
            template = self.prompt_manager.get_prompt(
                prompt_id=prompt_id, version=prompt_version
            )
            if template is None:
                version_str = f" (version {prompt_version})" if prompt_version else ""
                raise ValueError(
                    f"Prompt '{prompt_id}'{version_str} not found in prompt directory"
                )
            # Render the template with variables (pass version for proper lookup)
            rendered_content = self.prompt_manager.render(
                prompt_id=prompt_id,
                prompt_variables=prompt_variables,
                version=prompt_version,
            )
            # Convert rendered content to chat messages
            messages = self._convert_to_messages(rendered_content)
            # Extract model from metadata (if specified)
            template_model = template.model
            # Extract optional parameters from metadata
            optional_params = self._extract_optional_params(template)
            return PromptManagementClient(
                prompt_id=prompt_id,
                prompt_template=messages,
                prompt_template_model=template_model,
                prompt_template_optional_params=optional_params,
                completed_messages=None,
            )
        except Exception as e:
            # NOTE(review): re-wrapping drops the original traceback; consider
            # `raise ... from e` if chained context is wanted.
            raise ValueError(f"Error compiling prompt '{prompt_id}': {e}")

    async def async_compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Async version of compile prompt helper. Since dotprompt operations are synchronous,
        this simply delegates to the sync version.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for dotprompt manager")
        return self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_spec=prompt_spec,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Delegate to PromptManagementBase's sync implementation.

        NOTE(review): the `ignore_prompt_manager_*` flags are accepted but not
        forwarded to the base call here (unlike the async variant below) —
        confirm whether that is intentional.
        """
        from litellm.integrations.prompt_management_base import PromptManagementBase

        return PromptManagementBase.get_chat_completion_prompt(
            self,
            model,
            messages,
            non_default_params,
            prompt_id,
            prompt_variables,
            dynamic_callback_params,
            prompt_spec=prompt_spec,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )

    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        litellm_logging_obj: LiteLLMLoggingObj,
        prompt_spec: Optional[PromptSpec] = None,
        tools: Optional[List[Dict]] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Async version - delegates to PromptManagementBase async implementation.
        """
        from litellm.integrations.prompt_management_base import PromptManagementBase

        return await PromptManagementBase.async_get_chat_completion_prompt(
            self,
            model,
            messages,
            non_default_params,
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            litellm_logging_obj=litellm_logging_obj,
            dynamic_callback_params=dynamic_callback_params,
            prompt_spec=prompt_spec,
            tools=tools,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
            ignore_prompt_manager_model=ignore_prompt_manager_model,
            ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
        )

    def _convert_to_messages(self, rendered_content: str) -> List[AllMessageValues]:
        """
        Convert rendered prompt content to chat messages.
        This method supports multiple formats:
        1. Simple text -> converted to user message
        2. Text with role prefixes (System:, User:, Assistant:) -> parsed into separate messages
        3. Already formatted as a single message
        """
        # Clean up the content
        content = rendered_content.strip()
        # Try to parse role-based format (System: ..., User: ..., etc.)
        messages = []
        current_role = None
        current_content = []
        lines = content.split("\n")
        for line in lines:
            line = line.strip()
            # Check for role prefixes; each prefix flushes the message
            # accumulated so far before starting a new one.
            if line.startswith("System:"):
                if current_role and current_content:
                    messages.append(
                        self._create_message(
                            current_role, "\n".join(current_content).strip()
                        )
                    )
                current_role = "system"
                current_content = [line[7:].strip()]  # Remove "System:" prefix
            elif line.startswith("User:"):
                if current_role and current_content:
                    messages.append(
                        self._create_message(
                            current_role, "\n".join(current_content).strip()
                        )
                    )
                current_role = "user"
                current_content = [line[5:].strip()]  # Remove "User:" prefix
            elif line.startswith("Assistant:"):
                if current_role and current_content:
                    messages.append(
                        self._create_message(
                            current_role, "\n".join(current_content).strip()
                        )
                    )
                current_role = "assistant"
                current_content = [line[10:].strip()]  # Remove "Assistant:" prefix
            else:
                # Continue current message content
                if current_role:
                    current_content.append(line)
                else:
                    # No role prefix found, treat as user message
                    current_role = "user"
                    current_content = [line]
        # Add the last message
        if current_role and current_content:
            content_text = "\n".join(current_content).strip()
            if content_text:  # Only add if there's actual content
                messages.append(self._create_message(current_role, content_text))
        # If no messages were created, treat the entire content as a user message
        if not messages and content:
            messages.append(self._create_message("user", content))
        return messages

    def _create_message(self, role: str, content: str) -> AllMessageValues:
        """Create a message with the specified role and content."""
        return {
            "role": role,  # type: ignore
            "content": content,
        }

    def _extract_optional_params(self, template: PromptTemplate) -> dict:
        """
        Extract optional parameters from the prompt template metadata.
        Includes parameters like temperature, max_tokens, etc.
        """
        optional_params = {}
        # Extract common parameters from metadata
        if template.optional_params is not None:
            optional_params.update(template.optional_params)
        return optional_params

    def set_prompt_directory(self, prompt_directory: str) -> None:
        """Set the prompt directory and reload prompts."""
        self.prompt_directory = prompt_directory
        self._prompt_manager = None  # Reset to force reload

    def reload_prompts(self) -> None:
        """Reload all prompts from the directory."""
        if self._prompt_manager:
            self._prompt_manager.reload_prompts()

    def add_prompt_from_json(self, prompt_id: str, json_data: Dict[str, Any]) -> None:
        """Add a prompt from JSON data (expects 'content' and 'metadata' keys)."""
        content = json_data.get("content", "")
        metadata = json_data.get("metadata", {})
        self.prompt_manager.add_prompt(prompt_id, content, metadata)

    def load_prompts_from_json(self, prompts_data: Dict[str, Dict[str, Any]]) -> None:
        """Load multiple prompts from JSON data."""
        self.prompt_manager.load_prompts_from_json_data(prompts_data)

    def get_prompts_as_json(self) -> Dict[str, Dict[str, Any]]:
        """Get all prompts in JSON format."""
        return self.prompt_manager.get_all_prompts_as_json()

    def convert_prompt_file_to_json(self, file_path: str) -> Dict[str, Any]:
        """Convert a .prompt file to JSON format."""
        return self.prompt_manager.prompt_file_to_json(file_path)

View File

@@ -0,0 +1,368 @@
"""
Based on Google's GenAI Kit dotprompt implementation: https://google.github.io/dotprompt/reference/frontmatter/
"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import yaml
from jinja2 import DictLoader, Environment, select_autoescape
class PromptTemplate:
    """A single prompt template: body text plus parsed frontmatter metadata.

    Well-known metadata keys ('model', 'input', 'output') are lifted onto
    dedicated attributes; every other metadata key is collected into
    ``optional_params`` (temperature, max_tokens, ...).
    """

    def __init__(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
        template_id: Optional[str] = None,
    ):
        self.content = content
        self.metadata = metadata or {}
        self.template_id = template_id
        # Lift the reserved metadata fields onto attributes.
        reserved = ["model", "input", "output"]
        self.model = self.metadata.get("model")
        self.input_schema = self.metadata.get("input", {}).get("schema", {})
        self.output_format = self.metadata.get("output", {}).get("format")
        self.output_schema = self.metadata.get("output", {}).get("schema", {})
        # Everything else in the frontmatter is passed through as-is.
        self.optional_params = {
            key: value
            for key, value in self.metadata.items()
            if key not in reserved
        }

    def __repr__(self):
        return f"PromptTemplate(id='{self.template_id}', model='{self.model}')"
class PromptManager:
    """
    Manager for loading and rendering .prompt files following the Dotprompt specification.
    Supports:
    - YAML frontmatter for metadata
    - Handlebars-style templating (using Jinja2)
    - Input/output schema validation
    - Model configuration
    """

    def __init__(
        self,
        prompt_id: Optional[str] = None,
        prompt_directory: Optional[str] = None,
        prompt_data: Optional[Dict[str, Dict[str, Any]]] = None,
        prompt_file: Optional[str] = None,
    ):
        """
        Args:
            prompt_id: Required when prompt_file is given; also used when
                prompt_data describes a single prompt.
            prompt_directory: Directory whose *.prompt files are loaded.
            prompt_data: Pre-parsed prompts (see _load_prompts_from_json).
            prompt_file: Path to a single .prompt file.
        """
        self.prompt_directory = Path(prompt_directory) if prompt_directory else None
        self.prompts: Dict[str, PromptTemplate] = {}
        self.prompt_file = prompt_file
        self.jinja_env = Environment(
            loader=DictLoader({}),
            autoescape=select_autoescape(["html", "xml"]),
            # Use Handlebars-style delimiters to match Dotprompt spec
            variable_start_string="{{",
            variable_end_string="}}",
            block_start_string="{%",
            block_end_string="%}",
            comment_start_string="{#",
            comment_end_string="#}",
        )
        # Load prompts from directory if provided
        if self.prompt_directory:
            self._load_prompts()
        if self.prompt_file:
            if not prompt_id:
                raise ValueError("prompt_id is required when prompt_file is provided")
            template = self._load_prompt_file(self.prompt_file, prompt_id)
            self.prompts[prompt_id] = template
        # Load prompts from JSON data if provided
        if prompt_data:
            self._load_prompts_from_json(prompt_data, prompt_id)

    def _load_prompts(self) -> None:
        """Load all .prompt files from the prompt directory.

        Raises:
            ValueError: If the configured directory does not exist.
        """
        if not self.prompt_directory or not self.prompt_directory.exists():
            raise ValueError(
                f"Prompt directory does not exist: {self.prompt_directory}"
            )
        prompt_files = list(self.prompt_directory.glob("*.prompt"))
        for prompt_file in prompt_files:
            try:
                prompt_id = prompt_file.stem  # filename without extension
                template = self._load_prompt_file(prompt_file, prompt_id)
                self.prompts[prompt_id] = template
                # Optional: print(f"Loaded prompt: {prompt_id}")
            except Exception:
                # Best-effort loading: a malformed file is skipped silently
                # so one bad prompt does not block the rest.
                # Optional: print(f"Error loading prompt file {prompt_file}")
                pass

    def _load_prompts_from_json(
        self, prompt_data: Dict[str, Dict[str, Any]], prompt_id: Optional[str] = None
    ) -> None:
        """Load prompts from JSON data structure.
        Expected format:
        {
            "prompt_id": {
                "content": "template content",
                "metadata": {"model": "gpt-4", "temperature": 0.7, ...}
            }
        }
        or
        {
            "content": "template content",
            "metadata": {"model": "gpt-4", "temperature": 0.7, ...}
        } + prompt_id
        """
        if prompt_id:
            # Single-prompt form: wrap it so the loop below handles both shapes.
            prompt_data = {prompt_id: prompt_data}
        # NOTE: the loop variable intentionally reuses (shadows) `prompt_id`.
        for prompt_id, prompt_info in prompt_data.items():
            try:
                content = prompt_info.get("content", "")
                metadata = prompt_info.get("metadata", {})
                template = PromptTemplate(
                    content=content,
                    metadata=metadata,
                    template_id=prompt_id,
                )
                self.prompts[prompt_id] = template
            except Exception:
                # Best-effort: skip malformed entries without failing the batch.
                # Optional: print(f"Error loading prompt from JSON: {prompt_id}")
                pass

    def _load_prompt_file(
        self, file_path: Union[str, Path], prompt_id: str
    ) -> PromptTemplate:
        """Load and parse a single .prompt file."""
        if isinstance(file_path, str):
            file_path = Path(file_path)
        content = file_path.read_text(encoding="utf-8")
        # Split frontmatter and content
        frontmatter, template_content = self._parse_frontmatter(content)
        return PromptTemplate(
            content=template_content.strip(),
            metadata=frontmatter,
            template_id=prompt_id,
        )

    def _parse_frontmatter(self, content: str) -> Tuple[Dict[str, Any], str]:
        """Parse YAML frontmatter from prompt content.

        Returns:
            (frontmatter dict, template body). If no frontmatter block is
            present, returns ({}, content) unchanged.

        Raises:
            ValueError: If the frontmatter block is not valid YAML.
        """
        # Match YAML frontmatter between --- delimiters
        frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n(.*)$"
        match = re.match(frontmatter_pattern, content, re.DOTALL)
        if match:
            frontmatter_yaml = match.group(1)
            template_content = match.group(2)
            try:
                frontmatter = yaml.safe_load(frontmatter_yaml) or {}
            except yaml.YAMLError as e:
                raise ValueError(f"Invalid YAML frontmatter: {e}")
        else:
            # No frontmatter found, treat entire content as template
            frontmatter = {}
            template_content = content
        return frontmatter, template_content

    def render(
        self,
        prompt_id: str,
        prompt_variables: Optional[Dict[str, Any]] = None,
        version: Optional[int] = None,
    ) -> str:
        """
        Render a prompt template with the given variables.
        Args:
            prompt_id: The ID of the prompt template to render
            prompt_variables: Variables to substitute in the template
            version: Optional version number. If provided, looks for {prompt_id}.v{version}
        Returns:
            The rendered prompt string
        Raises:
            KeyError: If prompt_id is not found
            ValueError: If template rendering fails
        """
        # Get the template (versioned or base)
        template = self.get_prompt(prompt_id=prompt_id, version=version)
        if template is None:
            available_prompts = list(self.prompts.keys())
            version_str = f" (version {version})" if version else ""
            raise KeyError(
                f"Prompt '{prompt_id}'{version_str} not found. Available prompts: {available_prompts}"
            )
        variables = prompt_variables or {}
        # Validate input variables against schema if defined
        if template.input_schema:
            self._validate_input(variables, template.input_schema)
        try:
            # Create Jinja2 template and render
            jinja_template = self.jinja_env.from_string(template.content)
            rendered = jinja_template.render(**variables)
            return rendered
        except Exception as e:
            raise ValueError(f"Error rendering template '{prompt_id}': {e}")

    def _validate_input(
        self, variables: Dict[str, Any], schema: Dict[str, Any]
    ) -> None:
        """Basic validation of input variables against schema.

        Only variables that are present are type-checked; missing fields
        are not treated as errors here.
        """
        for field_name, field_type in schema.items():
            if field_name in variables:
                value = variables[field_name]
                expected_type = self._get_python_type(field_type)
                if not isinstance(value, expected_type):
                    raise ValueError(
                        f"Invalid type for field '{field_name}': "
                        f"expected {getattr(expected_type, '__name__', str(expected_type))}, got {type(value).__name__}"
                    )

    def _get_python_type(self, schema_type: str) -> Union[type, tuple]:
        """Convert schema type string to Python type (defaults to str if unknown)."""
        type_mapping: Dict[str, Union[type, tuple]] = {
            "string": str,
            "str": str,
            "number": (int, float),
            "integer": int,
            "int": int,
            "float": float,
            "boolean": bool,
            "bool": bool,
            "array": list,
            "list": list,
            "object": dict,
            "dict": dict,
        }
        return type_mapping.get(schema_type.lower(), str)  # type: ignore

    def get_prompt(
        self, prompt_id: str, version: Optional[int] = None
    ) -> Optional[PromptTemplate]:
        """
        Get a prompt template by ID and optional version.
        Args:
            prompt_id: The base prompt ID
            version: Optional version number. If provided, looks for {prompt_id}.v{version}
        Returns:
            The prompt template if found, None otherwise
        """
        if version is not None:
            # Try versioned prompt first: prompt_id.v{version}
            versioned_id = f"{prompt_id}.v{version}"
            if versioned_id in self.prompts:
                return self.prompts[versioned_id]
        # Fall back to base prompt_id
        return self.prompts.get(prompt_id)

    def list_prompts(self) -> List[str]:
        """Get a list of all available prompt IDs."""
        return list(self.prompts.keys())

    def get_prompt_metadata(self, prompt_id: str) -> Optional[Dict[str, Any]]:
        """Get metadata for a specific prompt."""
        template = self.prompts.get(prompt_id)
        return template.metadata if template else None

    def reload_prompts(self) -> None:
        """Reload all prompts from the directory (if directory was provided)."""
        self.prompts.clear()
        if self.prompt_directory:
            self._load_prompts()

    def add_prompt(
        self, prompt_id: str, content: str, metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Add a prompt template programmatically."""
        template = PromptTemplate(
            content=content, metadata=metadata or {}, template_id=prompt_id
        )
        self.prompts[prompt_id] = template

    def prompt_file_to_json(self, file_path: Union[str, Path]) -> Dict[str, Any]:
        """Convert a .prompt file to JSON format.
        Args:
            file_path: Path to the .prompt file
        Returns:
            Dictionary with 'content' and 'metadata' keys
        """
        file_path = Path(file_path)
        content = file_path.read_text(encoding="utf-8")
        # Parse frontmatter and content
        frontmatter, template_content = self._parse_frontmatter(content)
        return {"content": template_content.strip(), "metadata": frontmatter}

    def json_to_prompt_file(self, prompt_data: Dict[str, Any]) -> str:
        """Convert JSON prompt data to .prompt file format.
        Args:
            prompt_data: Dictionary with 'content' and 'metadata' keys
        Returns:
            String content in .prompt file format
        """
        content = prompt_data.get("content", "")
        metadata = prompt_data.get("metadata", {})
        if not metadata:
            # No metadata, return just the content
            return content
        # Convert metadata to YAML frontmatter
        import yaml

        frontmatter_yaml = yaml.dump(metadata, default_flow_style=False)
        return f"---\n{frontmatter_yaml}---\n{content}"

    def get_all_prompts_as_json(self) -> Dict[str, Dict[str, Any]]:
        """Get all loaded prompts in JSON format.
        Returns:
            Dictionary mapping prompt_id to prompt data
        """
        result = {}
        for prompt_id, template in self.prompts.items():
            result[prompt_id] = {
                "content": template.content,
                "metadata": template.metadata,
            }
        return result

    def load_prompts_from_json_data(
        self, prompt_data: Dict[str, Dict[str, Any]]
    ) -> None:
        """Load additional prompts from JSON data (merges with existing prompts)."""
        self._load_prompts_from_json(prompt_data)

View File

@@ -0,0 +1,89 @@
#### What this does ####
# On success + failure, log events to Supabase
import os
import traceback
from litellm._uuid import uuid
from typing import Any
import litellm
class DyanmoDBLogger:
    """Best-effort logger that writes completion events to a DynamoDB table.

    NOTE(review): the class name misspells "DynamoDB"; kept as-is because it
    is the public import name — renaming would break existing callers.
    """

    # Class variables or attributes
    def __init__(self):
        """Create the boto3 DynamoDB resource.

        Raises:
            KeyError: If AWS_REGION_NAME is not set in the environment.
            ValueError: If litellm.dynamodb_table_name is not configured.
        """
        # Instance variables
        import boto3

        self.dynamodb: Any = boto3.resource(
            "dynamodb", region_name=os.environ["AWS_REGION_NAME"]
        )
        if litellm.dynamodb_table_name is None:
            raise ValueError(
                "LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
            )
        self.table_name = litellm.dynamodb_table_name

    async def _async_log_event(
        self, kwargs, response_obj, start_time, end_time, print_verbose
    ):
        # Delegates to the sync implementation (boto3 calls block the loop;
        # acceptable here because logging is best-effort).
        self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)

    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
        """Build a flat string payload from the call and put_item it to DynamoDB.

        Never raises: any failure is logged via print_verbose and swallowed.
        """
        try:
            print_verbose(
                f"DynamoDB Logging - Enters logging function for model {kwargs}"
            )
            # construct payload to send to DynamoDB
            # follows the same params as langfuse.py
            litellm_params = kwargs.get("litellm_params", {})
            metadata = (
                litellm_params.get("metadata", {}) or {}
            )  # if litellm_params['metadata'] == None
            messages = kwargs.get("messages")
            optional_params = kwargs.get("optional_params", {})
            call_type = kwargs.get("call_type", "litellm.completion")
            usage = response_obj["usage"]
            id = response_obj.get("id", str(uuid.uuid4()))
            # Build the initial payload
            payload = {
                "id": id,
                "call_type": call_type,
                "startTime": start_time,
                "endTime": end_time,
                "model": kwargs.get("model", ""),
                "user": kwargs.get("user", ""),
                "modelParameters": optional_params,
                "messages": messages,
                "response": response_obj,
                "usage": usage,
                "metadata": metadata,
            }
            # Ensure everything in the payload is converted to str
            for key, value in payload.items():
                try:
                    payload[key] = str(value)
                except Exception:
                    # non blocking if it can't cast to a str
                    pass
            print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
            # put data in dyanmo DB
            table = self.dynamodb.Table(self.table_name)
            # Assuming log_data is a dictionary with log information
            response = table.put_item(Item=payload)
            print_verbose(f"Response from DynamoDB:{str(response)}")
            print_verbose(
                f"DynamoDB Layer Logging - final response object: {response_obj}"
            )
            return response
        except Exception:
            print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
            pass

View File

@@ -0,0 +1,136 @@
"""
Functions for sending Email Alerts
"""
import os
from typing import List, Optional
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.proxy._types import WebhookEvent
# we use this for the email header, please send a test email if you change this. verify it looks good on email
LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
LITELLM_SUPPORT_CONTACT = "support@berri.ai"
async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
    """Collect the e-mail addresses of every member of *team_id*.

    Returns an empty list when no team id is given, the team row does not
    exist, or no member has an e-mail on record.

    Raises:
        Exception: If the proxy has no database connection.
    """
    verbose_logger.debug(
        "Email Alerting: Getting all team members for team_id=%s", team_id
    )
    if team_id is None:
        return []

    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise Exception("Not connected to DB!")

    team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={
            "team_id": team_id,
        }
    )
    if team_row is None:
        return []

    _team_members = team_row.members_with_roles
    verbose_logger.debug(
        "Email Alerting: Got team members for team_id=%s Team Members: %s",
        team_id,
        _team_members,
    )
    # Keep only well-formed member entries carrying a string user_id.
    member_user_ids: List[str] = [
        member.get("user_id")
        for member in _team_members
        if member
        and isinstance(member, dict)
        and member.get("user_id")
        and isinstance(member.get("user_id"), str)
    ]

    sql_query = """
    SELECT user_email
    FROM "LiteLLM_UserTable"
    WHERE user_id = ANY($1::TEXT[]);
    """
    rows = await prisma_client.db.query_raw(sql_query, member_user_ids)
    verbose_logger.debug("Email Alerting: Got all Emails for team, emails=%s", rows)
    if rows is None:
        return []

    # Drop rows without a user_email value.
    return [
        row.get("user_email")
        for row in rows
        if row and isinstance(row, dict) and row.get("user_email", None) is not None
    ]
async def send_team_budget_alert(webhook_event: WebhookEvent) -> bool:
    """
    Send an Email Alert to All Team Members when the Team Budget is crossed

    Returns:
        True if the alert email was sent, False if there were no recipients.
    """
    from litellm.proxy.utils import send_email

    _team_id = webhook_event.team_id
    team_alias = webhook_event.team_alias
    verbose_logger.debug(
        "Email Alerting: Sending Team Budget Alert for team=%s", team_alias
    )
    # Logo / support contact can be overridden via env vars; fall back to defaults.
    email_logo_url = os.getenv("SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None))
    email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
    if email_logo_url is None:
        email_logo_url = LITELLM_LOGO_URL
    if email_support_contact is None:
        email_support_contact = LITELLM_SUPPORT_CONTACT

    recipient_emails = await get_all_team_member_emails(_team_id)
    # BUG FIX: the old code compared the joined string against None, which can
    # never be true (str.join always returns a str), so the no-recipient case
    # was silently ignored and an email with an empty "to" was attempted.
    if not recipient_emails:
        verbose_proxy_logger.warning(
            "Email Alerting: Trying to send email alert to no recipient, got recipient_emails=%s",
            recipient_emails,
        )
        return False
    recipient_emails_str: str = ",".join(recipient_emails)
    verbose_logger.debug(
        "Email Alerting: Sending team budget alert to %s", recipient_emails_str
    )

    event_name = webhook_event.event_message
    max_budget = webhook_event.max_budget
    email_html_content = f"""
    <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> <br/><br/><br/>
    Budget Crossed for Team <b> {team_alias} </b> <br/> <br/>
    Your Teams LLM API usage has crossed it's <b> budget of ${max_budget} </b>, current spend is <b>${webhook_event.spend}</b><br /> <br />
    API requests will be rejected until either (a) you increase your budget or (b) your budget gets reset <br /> <br />
    If you have any questions, please send an email to {email_support_contact} <br /> <br />
    Best, <br />
    The LiteLLM team <br />
    """
    email_event = {
        "to": recipient_emails_str,
        "subject": f"LiteLLM {event_name} for Team {team_alias}",
        "html": email_html_content,
    }
    await send_email(
        receiver_email=email_event["to"],
        subject=email_event["subject"],
        html=email_event["html"],
    )
    # BUG FIX: previously this always returned False, even after a successful
    # send, contradicting the documented "True if sent" contract.
    return True

View File

@@ -0,0 +1,10 @@
EMAIL_FOOTER = """
<div class="footer">
<p>© 2025 LiteLLM. All rights reserved.</p>
<div class="social-links">
<a href="https://twitter.com/litellm">Twitter</a> •
<a href="https://github.com/BerriAI/litellm">GitHub</a> •
<a href="https://litellm.ai">Website</a>
</div>
</div>
"""

View File

@@ -0,0 +1,212 @@
"""
Modern Email Templates for LiteLLM Email Service with professional styling
"""
KEY_CREATED_EMAIL_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Your API Key is Ready</title>
<style>
body, html {{
margin: 0;
padding: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
color: #333333;
background-color: #f8fafc;
line-height: 1.5;
}}
.container {{
max-width: 560px;
margin: 20px auto;
background-color: #ffffff;
border-radius: 8px;
overflow: hidden;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}}
.header {{
padding: 24px 0;
text-align: center;
border-bottom: 1px solid #f1f5f9;
}}
.content {{
padding: 32px 40px;
}}
.greeting {{
font-size: 16px;
margin-bottom: 20px;
color: #333333;
}}
.message {{
font-size: 16px;
color: #333333;
margin-bottom: 20px;
}}
.key-container {{
margin: 28px 0;
}}
.key-label {{
font-size: 14px;
font-weight: 500;
margin-bottom: 8px;
color: #4b5563;
}}
.key {{
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
word-break: break-all;
background-color: #f9fafb;
border-radius: 6px;
padding: 16px;
font-size: 14px;
border: 1px solid #e5e7eb;
color: #4338ca;
}}
h2 {{
font-size: 18px;
font-weight: 600;
margin-top: 36px;
margin-bottom: 16px;
color: #333333;
}}
.budget-info {{
background-color: #f0fdf4;
border-radius: 6px;
padding: 14px 16px;
margin: 24px 0;
font-size: 14px;
border: 1px solid #dcfce7;
}}
.code-block {{
background-color: #f8fafc;
color: #334155;
border-radius: 8px;
padding: 20px;
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
font-size: 13px;
overflow-x: auto;
margin: 20px 0;
line-height: 1.6;
border: 1px solid #e2e8f0;
}}
.code-comment {{
color: #64748b;
}}
.code-string {{
color: #0369a1;
}}
.code-keyword {{
color: #7e22ce;
}}
.btn {{
display: inline-block;
padding: 8px 20px;
background-color: #6366f1;
color: #ffffff !important;
text-decoration: none;
border-radius: 6px;
font-weight: 500;
margin-top: 24px;
text-align: center;
font-size: 14px;
transition: background-color 0.2s;
}}
.btn:hover {{
background-color: #4f46e5;
color: #ffffff !important;
}}
.separator {{
height: 1px;
background-color: #f1f5f9;
margin: 40px 0 30px;
}}
.footer {{
padding: 24px 40px 32px;
text-align: center;
color: #64748b;
font-size: 13px;
background-color: #f8fafc;
border-top: 1px solid #f1f5f9;
}}
.social-links {{
margin-top: 12px;
}}
.social-links a {{
display: inline-block;
margin: 0 8px;
color: #64748b;
text-decoration: none;
}}
@media only screen and (max-width: 620px) {{
.container {{
width: 100%;
margin: 0;
border-radius: 0;
}}
.content {{
padding: 24px 20px;
}}
.footer {{
padding: 20px;
}}
}}
</style>
</head>
<body>
<div class="container">
<div class="header">
<img src="{email_logo_url}" alt="LiteLLM Logo" style="height: 32px; width: auto;">
</div>
<div class="content">
<div class="greeting">
<p>Hi {recipient_email},</p>
</div>
<div class="message">
<p>Great news! Your LiteLLM API key is ready to use.</p>
</div>
<div class="budget-info">
<p style="margin: 0;"><strong>Monthly Budget:</strong> {key_budget}</p>
</div>
<div class="key-container">
<div class="key-label">Your API Key</div>
<div class="key">{key_token}</div>
</div>
<h2>Quick Start Guide</h2>
<p>Here's how to use your key with the OpenAI SDK:</p>
<div class="code-block">
<span class="code-keyword">import</span> openai<br>
<br>
client = openai.OpenAI(<br>
&nbsp;&nbsp;api_key=<span class="code-string">"{key_token}"</span>,<br>
&nbsp;&nbsp;base_url=<span class="code-string">"{base_url}"</span><br>
)<br>
<br>
response = client.chat.completions.create(<br>
&nbsp;&nbsp;model=<span class="code-string">"gpt-3.5-turbo"</span>, <span class="code-comment"># model to send to the proxy</span><br>
&nbsp;&nbsp;messages = [<br>
&nbsp;&nbsp;&nbsp;&nbsp;{{<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="code-string">"role"</span>: <span class="code-string">"user"</span>,<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="code-string">"content"</span>: <span class="code-string">"this is a test request, write a short poem"</span><br>
&nbsp;&nbsp;&nbsp;&nbsp;}}<br>
&nbsp;&nbsp;]<br>
)
</div>
<a href="https://docs.litellm.ai/docs/proxy/user_keys" class="btn" style="color: #ffffff;">View Documentation</a>
<div class="separator"></div>
<h2>Need Help?</h2>
<p>If you have any questions or need assistance, please contact us at {email_support_contact}.</p>
</div>
{email_footer}
</div>
</body>
</html>
"""

View File

@@ -0,0 +1,224 @@
"""
Modern Email Templates for LiteLLM Email Service with professional styling
"""
# HTML email sent when a user's API key is rotated. Rendered via str.format();
# literal CSS/JS braces in the markup are escaped as doubled braces ({{ }}).
# Expected placeholders: {email_logo_url}, {recipient_email}, {key_token},
# {key_budget}, {base_url}, {email_support_contact}, {email_footer}.
KEY_ROTATED_EMAIL_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Your API Key Has Been Rotated</title>
<style>
body, html {{
margin: 0;
padding: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
color: #333333;
background-color: #f8fafc;
line-height: 1.5;
}}
.container {{
max-width: 560px;
margin: 20px auto;
background-color: #ffffff;
border-radius: 8px;
overflow: hidden;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}}
.header {{
padding: 24px 0;
text-align: center;
border-bottom: 1px solid #f1f5f9;
}}
.content {{
padding: 32px 40px;
}}
.greeting {{
font-size: 16px;
margin-bottom: 20px;
color: #333333;
}}
.message {{
font-size: 16px;
color: #333333;
margin-bottom: 20px;
}}
.key-container {{
margin: 28px 0;
}}
.key-label {{
font-size: 14px;
font-weight: 500;
margin-bottom: 8px;
color: #4b5563;
}}
.key {{
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
word-break: break-all;
background-color: #f9fafb;
border-radius: 6px;
padding: 16px;
font-size: 14px;
border: 1px solid #e5e7eb;
color: #4338ca;
}}
h2 {{
font-size: 18px;
font-weight: 600;
margin-top: 36px;
margin-bottom: 16px;
color: #333333;
}}
.budget-info {{
background-color: #f0fdf4;
border-radius: 6px;
padding: 14px 16px;
margin: 24px 0;
font-size: 14px;
border: 1px solid #dcfce7;
}}
.code-block {{
background-color: #f8fafc;
color: #334155;
border-radius: 8px;
padding: 20px;
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
font-size: 13px;
overflow-x: auto;
margin: 20px 0;
line-height: 1.6;
border: 1px solid #e2e8f0;
}}
.code-comment {{
color: #64748b;
}}
.code-string {{
color: #0369a1;
}}
.code-keyword {{
color: #7e22ce;
}}
.btn {{
display: inline-block;
padding: 8px 20px;
background-color: #6366f1;
color: #ffffff !important;
text-decoration: none;
border-radius: 6px;
font-weight: 500;
margin-top: 24px;
text-align: center;
font-size: 14px;
transition: background-color 0.2s;
}}
.btn:hover {{
background-color: #4f46e5;
color: #ffffff !important;
}}
.separator {{
height: 1px;
background-color: #f1f5f9;
margin: 40px 0 30px;
}}
.footer {{
padding: 24px 40px 32px;
text-align: center;
color: #64748b;
font-size: 13px;
background-color: #f8fafc;
border-top: 1px solid #f1f5f9;
}}
.social-links {{
margin-top: 12px;
}}
.social-links a {{
display: inline-block;
margin: 0 8px;
color: #64748b;
text-decoration: none;
}}
@media only screen and (max-width: 620px) {{
.container {{
width: 100%;
margin: 0;
border-radius: 0;
}}
.content {{
padding: 24px 20px;
}}
.footer {{
padding: 20px;
}}
}}
</style>
</head>
<body>
<div class="container">
<div class="header">
<img src="{email_logo_url}" alt="LiteLLM Logo" style="height: 32px; width: auto;">
</div>
<div class="content">
<div class="greeting">
<p>Hi {recipient_email},</p>
</div>
<div class="message">
<p><strong>Your LiteLLM API key has been rotated</strong> as part of our ongoing commitment to security best practices.</p>
<p style="margin-top: 16px;">Your previous API key has been deactivated and will no longer work. Please update your applications with the new key below.</p>
</div>
<div class="key-container">
<div class="key-label">Your New API Key</div>
<div class="key">{key_token}</div>
</div>
<div class="budget-info">
<p style="margin: 0;"><strong>Monthly Budget:</strong> {key_budget}</p>
</div>
<h2>Action Required</h2>
<p>Update your applications and systems with the new API key. Here's an example:</p>
<div class="code-block">
<span class="code-keyword">import</span> openai<br>
<br>
client = openai.OpenAI(<br>
&nbsp;&nbsp;api_key=<span class="code-string">"{key_token}"</span>,<br>
&nbsp;&nbsp;base_url=<span class="code-string">"{base_url}"</span><br>
)<br>
<br>
response = client.chat.completions.create(<br>
&nbsp;&nbsp;model=<span class="code-string">"gpt-3.5-turbo"</span>, <span class="code-comment"># model to send to the proxy</span><br>
&nbsp;&nbsp;messages = [<br>
&nbsp;&nbsp;&nbsp;&nbsp;{{<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="code-string">"role"</span>: <span class="code-string">"user"</span>,<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="code-string">"content"</span>: <span class="code-string">"this is a test request, write a short poem"</span><br>
&nbsp;&nbsp;&nbsp;&nbsp;}}<br>
&nbsp;&nbsp;]<br>
)
</div>
<div class="separator"></div>
<h2>Security Best Practices</h2>
<p style="margin-bottom: 12px;">To keep your API key secure:</p>
<ul style="margin: 0; padding-left: 20px; color: #333333;">
<li style="margin-bottom: 8px;">Never share your API key publicly or commit it to version control</li>
<li style="margin-bottom: 8px;">Store it securely using environment variables or secret management systems</li>
<li style="margin-bottom: 8px;">Monitor your API usage regularly for any unusual activity</li>
<li style="margin-bottom: 8px;">Rotate your keys periodically as a security best practice</li>
</ul>
<a href="https://docs.litellm.ai/docs/proxy/user_keys" class="btn" style="color: #ffffff;">View Documentation</a>
<div class="separator"></div>
<h2>Need Help?</h2>
<p>If you have any questions or need assistance updating your systems, please contact us at {email_support_contact}.</p>
</div>
{email_footer}
</div>
</body>
</html>
"""

View File

@@ -0,0 +1,134 @@
"""
Email Templates used by the LiteLLM Email Service in slack_alerting.py
"""
# Legacy email sent when a new virtual key is provisioned. Rendered via
# str.format(); expected placeholders: {email_logo_url}, {recipient_email},
# {key_budget}, {key_token}, {base_url}, {email_support_contact}.
# Literal braces in the JSON example are escaped as doubled braces ({{ }}).
# FIX: the usage example previously read `base_url={{base_url}}`, which
# rendered literally as `base_url={base_url}` in the email instead of
# substituting (and quoting) the proxy URL; now matches the modern template.
KEY_CREATED_EMAIL_TEMPLATE = """
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
<p> Hi {recipient_email}, <br/>
I'm happy to provide you with an OpenAI Proxy API Key, loaded with ${key_budget} per month. <br /> <br />
<b>
Key: <pre>{key_token}</pre> <br>
</b>
<h2>Usage Example</h2>
Detailed Documentation on <a href="https://docs.litellm.ai/docs/proxy/user_keys">Usage with OpenAI Python SDK, Langchain, LlamaIndex, Curl</a>
<pre>
import openai
client = openai.OpenAI(
api_key="{key_token}",
base_url="{base_url}"
)
response = client.chat.completions.create(
model="gpt-3.5-turbo", # model to send to the proxy
messages = [
{{
"role": "user",
"content": "this is a test request, write a short poem"
}}
]
)
</pre>
If you have any questions, please send an email to {email_support_contact} <br /> <br />
Best, <br />
The LiteLLM team <br />
"""
# Legacy invite email for a team member. Rendered via str.format();
# expected placeholders: {email_logo_url}, {recipient_email}, {team_name},
# {base_url}, {email_support_contact}.
USER_INVITED_EMAIL_TEMPLATE = """
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
<p> Hi {recipient_email}, <br/>
You were invited to use OpenAI Proxy API for team {team_name} <br /> <br />
<a href="{base_url}" style="display: inline-block; padding: 10px 20px; background-color: #87ceeb; color: #fff; text-decoration: none; border-radius: 20px;">Get Started here</a> <br /> <br />
If you have any questions, please send an email to {email_support_contact} <br /> <br />
Best, <br />
The LiteLLM team <br />
"""
# Sent when a single API key crosses its soft budget (requests still allowed).
# Rendered via str.format(); expected placeholders: {email_logo_url},
# {recipient_email}, {soft_budget}, {spend}, {max_budget_info}, {base_url},
# {email_support_contact}. {max_budget_info} is a pre-rendered HTML fragment.
SOFT_BUDGET_ALERT_EMAIL_TEMPLATE = """
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
<p> Hi {recipient_email}, <br/>
Your LiteLLM API key has crossed its <b>soft budget limit of {soft_budget}</b>. <br /> <br />
<b>Current Spend:</b> {spend} <br />
<b>Soft Budget:</b> {soft_budget} <br />
{max_budget_info}
<p style="color: #dc2626; font-weight: 500;">
⚠️ Note: Your API requests will continue to work, but you should monitor your usage closely.
If you reach your maximum budget, requests will be rejected.
</p>
You can view your usage and manage your budget in the <a href="{base_url}">LiteLLM Dashboard</a>. <br /> <br />
If you have any questions, please send an email to {email_support_contact} <br /> <br />
Best, <br />
The LiteLLM team <br />
"""
# Team-level counterpart of the soft-budget alert. Rendered via str.format();
# expected placeholders: {email_logo_url}, {team_alias}, {soft_budget},
# {spend}, {max_budget_info}, {base_url}, {email_support_contact}.
TEAM_SOFT_BUDGET_ALERT_EMAIL_TEMPLATE = """
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
<p> Hi {team_alias} team member, <br/>
Your LiteLLM team has crossed its <b>soft budget limit of {soft_budget}</b>. <br /> <br />
<b>Current Spend:</b> {spend} <br />
<b>Soft Budget:</b> {soft_budget} <br />
{max_budget_info}
<p style="color: #dc2626; font-weight: 500;">
⚠️ Note: Your API requests will continue to work, but you should monitor your usage closely.
If you reach your maximum budget, requests will be rejected.
</p>
You can view your usage and manage your budget in the <a href="{base_url}">LiteLLM Dashboard</a>. <br /> <br />
If you have any questions, please send an email to {email_support_contact} <br /> <br />
Best, <br />
The LiteLLM team <br />
"""
# Sent when a key approaches its hard (maximum) budget. Rendered via
# str.format(); expected placeholders: {email_logo_url}, {recipient_email},
# {percentage}, {spend}, {max_budget}, {alert_threshold}, {base_url},
# {email_support_contact}.
MAX_BUDGET_ALERT_EMAIL_TEMPLATE = """
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
<p> Hi {recipient_email}, <br/>
Your LiteLLM API key has reached <b>{percentage}% of its maximum budget</b>. <br /> <br />
<b>Current Spend:</b> {spend} <br />
<b>Maximum Budget:</b> {max_budget} <br />
<b>Alert Threshold:</b> {alert_threshold} ({percentage}%) <br />
<p style="color: #dc2626; font-weight: 500;">
⚠️ Warning: You are approaching your maximum budget limit.
Once you reach your maximum budget of {max_budget}, all API requests will be rejected.
</p>
You can view your usage and manage your budget in the <a href="{base_url}">LiteLLM Dashboard</a>. <br /> <br />
If you have any questions, please send an email to {email_support_contact} <br /> <br />
Best, <br />
The LiteLLM team <br />
"""

View File

@@ -0,0 +1,175 @@
"""
Modern Email Templates for LiteLLM Email Service with professional styling
"""
USER_INVITATION_EMAIL_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Welcome to LiteLLM</title>
<style>
body, html {{
margin: 0;
padding: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
color: #333333;
background-color: #f8f8f8;
line-height: 1.5;
}}
.container {{
max-width: 560px;
margin: 20px auto;
background-color: #ffffff;
border-radius: 8px;
overflow: hidden;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}}
.logo {{
padding: 24px 0 0 24px;
text-align: left;
}}
.greeting {{
font-size: 16px;
margin-bottom: 20px;
color: #333333;
}}
.content {{
padding: 24px 40px 32px;
}}
h1 {{
font-size: 24px;
font-weight: 600;
margin-top: 24px;
margin-bottom: 16px;
color: #333333;
}}
p {{
font-size: 16px;
color: #333333;
margin-bottom: 16px;
line-height: 1.5;
}}
.intro-text {{
margin-bottom: 24px;
}}
.link {{
color: #6366f1;
text-decoration: none;
font-weight: 500;
}}
.link:hover {{
text-decoration: underline;
}}
.link-with-arrow {{
display: inline-flex;
align-items: center;
color: #6366f1;
text-decoration: none;
font-weight: 500;
margin-bottom: 20px;
}}
.link-with-arrow:hover {{
text-decoration: underline;
}}
.arrow {{
margin-left: 6px;
}}
.divider {{
height: 1px;
background-color: #f1f1f1;
margin: 24px 0;
}}
.btn {{
display: inline-block;
padding: 12px 24px;
background-color: #5c5ce0;
color: #ffffff !important;
text-decoration: none;
border-radius: 6px;
font-weight: 500;
margin-top: 12px;
text-align: center;
font-size: 15px;
transition: background-color 0.2s ease;
}}
.btn:hover {{
background-color: #4b4bb3;
}}
.btn-container {{
text-align: center;
margin: 24px 0;
}}
.footer {{
padding: 24px 40px 32px;
text-align: left;
color: #666;
font-size: 14px;
}}
.quickstart {{
margin-top: 32px;
}}
</style>
</head>
<body>
<div class="container">
<div class="logo">
<img src="{email_logo_url}" alt="LiteLLM Logo" style="height: 32px; width: auto;">
</div>
<div class="content">
<h1>Welcome to LiteLLM</h1>
<div class="greeting">
<p>Hi {recipient_email},</p>
</div>
<div class="intro-text">
<p>LiteLLM allows you to call 100+ LLM providers in the OpenAI API format. Get started by accepting your invitation.</p>
</div>
<div class="btn-container">
<a href="{base_url}" class="btn">Accept Invitation</a>
</div>
<div class="quickstart">
<p>Here's a quickstart guide to get you started:</p>
</div>
<div class="divider"></div>
<a href="https://docs.litellm.ai/docs/proxy/user_keys" class="link-with-arrow">
Make your first LLM request →
<span class="arrow"></span>
</a>
<p>Making LLM requests with OpenAI SDK, Langchain, LlamaIndex, and more.</p>
<div class="divider"></div>
<a href="https://docs.litellm.ai/docs/supported_endpoints" class="link-with-arrow">
Supported Endpoints →
<span class="arrow"></span>
</a>
<p>View all supported LLM endpoints on LiteLLM (/chat/completions, /embeddings, /responses etc.)</p>
<div class="divider"></div>
<a href="https://docs.litellm.ai/docs/pass_through/vertex_ai" class="link-with-arrow">
Passthrough Endpoints →
<span class="arrow"></span>
</a>
<p>We support calling VertexAI, Anthropic, and other providers in their native API format.</p>
<div class="divider"></div>
<p>Thanks for signing up. We're here to help you and your team. If you have any questions, contact us at {email_support_contact}</p>
</div>
{email_footer}
</div>
</body>
</html>
"""

View File

@@ -0,0 +1,113 @@
"""Database access helpers for Focus export."""
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Optional
import polars as pl
class FocusLiteLLMDatabase:
    """Retrieves LiteLLM usage data for Focus export workflows."""

    def _ensure_prisma_client(self):
        """Return the proxy's Prisma client, raising if no database is attached."""
        # Imported lazily so this module can load without a running proxy.
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise RuntimeError(
                "Database not connected. Connect a database to your proxy - "
                "https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )
        return prisma_client

    async def get_usage_data(
        self,
        *,
        limit: Optional[int] = None,
        start_time_utc: Optional[datetime] = None,
        end_time_utc: Optional[datetime] = None,
    ) -> pl.DataFrame:
        """Return usage data for the requested window.

        Args:
            limit: Optional cap on the number of rows returned; must be a
                non-negative integer when provided.
            start_time_utc: Inclusive lower bound on ``dus.updated_at``.
            end_time_utc: Inclusive upper bound on ``dus.updated_at``.

        Returns:
            A polars DataFrame of daily spend rows joined with key, team and
            user metadata.

        Raises:
            ValueError: If ``limit`` is not a non-negative integer.
            RuntimeError: If the database is not connected or the query fails.
        """
        client = self._ensure_prisma_client()
        where_clauses: list[str] = []
        query_params: list[Any] = []
        # Postgres positional placeholders ($1, $2, ...) numbered as we append
        # params; order of appends must match placeholder numbering exactly.
        placeholder_index = 1
        if start_time_utc:
            where_clauses.append(f"dus.updated_at >= ${placeholder_index}::timestamptz")
            query_params.append(start_time_utc)
            placeholder_index += 1
        if end_time_utc:
            where_clauses.append(f"dus.updated_at <= ${placeholder_index}::timestamptz")
            query_params.append(end_time_utc)
            placeholder_index += 1
        where_clause = ""
        if where_clauses:
            where_clause = "WHERE " + " AND ".join(where_clauses)
        limit_clause = ""
        if limit is not None:
            try:
                limit_value = int(limit)
            except (TypeError, ValueError) as exc:  # pragma: no cover - defensive guard
                raise ValueError("limit must be an integer") from exc
            if limit_value < 0:
                raise ValueError("limit must be non-negative")
            # LIMIT is always the last placeholder, so no further increment needed.
            limit_clause = f" LIMIT ${placeholder_index}"
            query_params.append(limit_value)
        query = f"""
            SELECT
                dus.id,
                dus.date,
                dus.user_id,
                dus.api_key,
                dus.model,
                dus.model_group,
                dus.custom_llm_provider,
                dus.prompt_tokens,
                dus.completion_tokens,
                dus.spend,
                dus.api_requests,
                dus.successful_requests,
                dus.failed_requests,
                dus.cache_creation_input_tokens,
                dus.cache_read_input_tokens,
                dus.created_at,
                dus.updated_at,
                vt.team_id,
                vt.key_alias as api_key_alias,
                tt.team_alias,
                ut.user_email as user_email
            FROM "LiteLLM_DailyUserSpend" dus
            LEFT JOIN "LiteLLM_VerificationToken" vt ON dus.api_key = vt.token
            LEFT JOIN "LiteLLM_TeamTable" tt ON vt.team_id = tt.team_id
            LEFT JOIN "LiteLLM_UserTable" ut ON dus.user_id = ut.user_id
            {where_clause}
            ORDER BY dus.date DESC, dus.created_at DESC
            {limit_clause}
        """
        try:
            db_response = await client.db.query_raw(query, *query_params)
            # infer_schema_length=None scans all rows so sparse columns type correctly.
            return pl.DataFrame(db_response, infer_schema_length=None)
        except Exception as exc:
            raise RuntimeError(f"Error retrieving usage data: {exc}") from exc

    async def get_table_info(self) -> Dict[str, Any]:
        """Return metadata about the spend table for diagnostics.

        Raises:
            RuntimeError: If the information_schema query fails.
        """
        client = self._ensure_prisma_client()
        info_query = """
            SELECT column_name, data_type, is_nullable
            FROM information_schema.columns
            WHERE table_name = 'LiteLLM_DailyUserSpend'
            ORDER BY ordinal_position;
        """
        try:
            columns_response = await client.db.query_raw(info_query)
            return {"columns": columns_response, "table_name": "LiteLLM_DailyUserSpend"}
        except Exception as exc:
            raise RuntimeError(f"Error getting table info: {exc}") from exc

View File

@@ -0,0 +1,12 @@
"""Destination implementations for Focus export."""
# Re-export the public destination API so callers can import everything from
# the package root instead of reaching into individual submodules.
from .base import FocusDestination, FocusTimeWindow
from .factory import FocusDestinationFactory
from .s3_destination import FocusS3Destination

__all__ = [
    "FocusDestination",
    "FocusDestinationFactory",
    "FocusTimeWindow",
    "FocusS3Destination",
]

View File

@@ -0,0 +1,30 @@
"""Abstract destination interfaces for Focus export."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Protocol
@dataclass(frozen=True)
class FocusTimeWindow:
    """Represents the span of data exported in a single batch.

    Frozen so a window can be shared safely across the scheduler,
    engine and destination layers without accidental mutation.
    """

    # Start of the window covered by this export batch.
    start_time: datetime
    # End of the window covered by this export batch.
    end_time: datetime
    # Export cadence that produced this window, e.g. "hourly" / "daily" / "interval".
    frequency: str
class FocusDestination(Protocol):
    """Protocol for anything that can receive Focus export files."""

    async def deliver(
        self,
        *,
        content: bytes,
        time_window: FocusTimeWindow,
        filename: str,
    ) -> None:
        """Persist the serialized export for the provided time window.

        Args:
            content: Serialized export payload (e.g. parquet bytes).
            time_window: Span of usage data covered by this batch.
            filename: Base file name to store the payload under.
        """
        ...

View File

@@ -0,0 +1,59 @@
"""Factory helpers for Focus export destinations."""
from __future__ import annotations
import os
from typing import Any, Dict, Optional
from .base import FocusDestination
from .s3_destination import FocusS3Destination
class FocusDestinationFactory:
    """Builds destination instances based on provider/config settings."""

    @staticmethod
    def create(
        *,
        provider: str,
        prefix: str,
        config: Optional[Dict[str, Any]] = None,
    ) -> FocusDestination:
        """Return a destination implementation for the requested provider.

        Raises:
            NotImplementedError: For providers other than ``s3``.
            ValueError: When required provider configuration is missing.
        """
        normalized_provider = provider.lower()
        # Resolving config first surfaces unsupported providers and missing
        # settings before any destination object is constructed.
        resolved_config = FocusDestinationFactory._resolve_config(
            provider=normalized_provider, overrides=config or {}
        )
        if normalized_provider != "s3":
            raise NotImplementedError(
                f"Provider '{provider}' not supported for Focus export"
            )
        return FocusS3Destination(prefix=prefix, config=resolved_config)

    @staticmethod
    def _resolve_config(
        *,
        provider: str,
        overrides: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Merge explicit overrides with FOCUS_S3_* environment fallbacks.

        Keys whose resolved value is ``None`` are dropped from the result.
        """
        if provider != "s3":
            raise NotImplementedError(
                f"Provider '{provider}' not supported for Focus export configuration"
            )
        # (config key, environment-variable fallback) pairs; order defines the
        # key order of the returned mapping.
        env_fallbacks = (
            ("bucket_name", "FOCUS_S3_BUCKET_NAME"),
            ("region_name", "FOCUS_S3_REGION_NAME"),
            ("endpoint_url", "FOCUS_S3_ENDPOINT_URL"),
            ("aws_access_key_id", "FOCUS_S3_ACCESS_KEY"),
            ("aws_secret_access_key", "FOCUS_S3_SECRET_KEY"),
            ("aws_session_token", "FOCUS_S3_SESSION_TOKEN"),
        )
        resolved = {
            key: overrides.get(key) or os.getenv(env_var)
            for key, env_var in env_fallbacks
        }
        if not resolved.get("bucket_name"):
            raise ValueError("FOCUS_S3_BUCKET_NAME must be provided for S3 exports")
        return {key: value for key, value in resolved.items() if value is not None}

View File

@@ -0,0 +1,74 @@
"""S3 destination implementation for Focus export."""
from __future__ import annotations
import asyncio
from datetime import timezone
from typing import Any, Optional
import boto3
from .base import FocusDestination, FocusTimeWindow
class FocusS3Destination(FocusDestination):
    """Handles uploading serialized exports to S3 buckets."""

    def __init__(
        self,
        *,
        prefix: str,
        config: Optional[dict[str, Any]] = None,
    ) -> None:
        """Store bucket/credential configuration.

        Args:
            prefix: Object-key prefix under which exports are written.
            config: S3 settings; must include ``bucket_name``, may include
                region/endpoint/credential keys.

        Raises:
            ValueError: If ``bucket_name`` is missing from ``config``.
        """
        config = config or {}
        bucket_name = config.get("bucket_name")
        if not bucket_name:
            raise ValueError("bucket_name must be provided for S3 destination")
        self.bucket_name = bucket_name
        # Strip trailing slashes so _build_object_key never emits "//".
        self.prefix = prefix.rstrip("/")
        self.config = config

    async def deliver(
        self,
        *,
        content: bytes,
        time_window: FocusTimeWindow,
        filename: str,
    ) -> None:
        """Upload ``content`` under a date/hour-partitioned key."""
        object_key = self._build_object_key(time_window=time_window, filename=filename)
        # boto3 is synchronous; run the PUT in a worker thread so the event
        # loop is not blocked during the upload.
        await asyncio.to_thread(self._upload, content, object_key)

    def _build_object_key(self, *, time_window: FocusTimeWindow, filename: str) -> str:
        """Return ``<prefix>/date=YYYY-MM-DD[/hour=HH]/<filename>`` (UTC-based)."""
        start_utc = time_window.start_time.astimezone(timezone.utc)
        date_component = f"date={start_utc.strftime('%Y-%m-%d')}"
        parts = [self.prefix, date_component]
        # Hourly exports get an extra hour= partition level.
        if time_window.frequency == "hourly":
            parts.append(f"hour={start_utc.strftime('%H')}")
        key_prefix = "/".join(filter(None, parts))
        # BUG FIX: a literal placeholder string was previously interpolated
        # here instead of the caller-supplied filename, so the `filename`
        # argument was silently ignored and every key ended in junk text.
        return f"{key_prefix}/{filename}" if key_prefix else filename

    def _upload(self, content: bytes, object_key: str) -> None:
        """Synchronously PUT the payload to S3 (invoked via asyncio.to_thread)."""
        client_kwargs: dict[str, Any] = {}
        region_name = self.config.get("region_name")
        if region_name:
            client_kwargs["region_name"] = region_name
        endpoint_url = self.config.get("endpoint_url")
        if endpoint_url:
            client_kwargs["endpoint_url"] = endpoint_url
        # Credentials are passed only when explicitly configured; otherwise
        # boto3 falls back to its default credential chain.
        session_kwargs: dict[str, Any] = {}
        for key in (
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_session_token",
        ):
            if self.config.get(key):
                session_kwargs[key] = self.config[key]
        s3_client = boto3.client("s3", **client_kwargs, **session_kwargs)
        s3_client.put_object(
            Bucket=self.bucket_name,
            Key=object_key,
            Body=content,
            ContentType="application/octet-stream",
        )

View File

@@ -0,0 +1,124 @@
"""Core export engine for Focus integrations (heavy dependencies)."""
from __future__ import annotations
from typing import Any, Dict, Optional
import polars as pl
from litellm._logging import verbose_logger
from .database import FocusLiteLLMDatabase
from .destinations import FocusDestinationFactory, FocusTimeWindow
from .serializers import FocusParquetSerializer, FocusSerializer
from .transformer import FocusTransformer
class FocusExportEngine:
    """Engine that fetches, normalizes, and uploads Focus exports."""

    def __init__(
        self,
        *,
        provider: str,
        export_format: str,
        prefix: str,
        destination_config: Optional[dict[str, Any]] = None,
    ) -> None:
        """Wire together the destination, serializer, transformer and DB layers.

        Raises:
            NotImplementedError: For unsupported providers or export formats.
        """
        self.provider = provider
        self.export_format = export_format
        self.prefix = prefix
        self._destination = FocusDestinationFactory.create(
            provider=self.provider,
            prefix=self.prefix,
            config=destination_config,
        )
        self._serializer = self._init_serializer()
        self._transformer = FocusTransformer()
        self._database = FocusLiteLLMDatabase()

    def _init_serializer(self) -> FocusSerializer:
        # Fail fast: parquet is the only implemented output format today.
        if self.export_format != "parquet":
            raise NotImplementedError("Only parquet export supported currently")
        return FocusParquetSerializer()

    async def dry_run_export_usage_data(self, limit: Optional[int]) -> Dict[str, Any]:
        """Fetch + transform usage data and return samples without uploading.

        Returns a dict with up-to-50-row samples of the raw and normalized
        frames plus aggregate summary statistics.
        """
        data = await self._database.get_usage_data(limit=limit)
        normalized = self._transformer.transform(data)
        # Cap both samples at 50 rows to keep the response payload small.
        usage_sample = data.head(min(50, len(data))).to_dicts()
        normalized_sample = normalized.head(min(50, len(normalized))).to_dicts()
        summary = {
            "total_records": len(normalized),
            "total_spend": self._sum_column(normalized, "spend"),
            "total_tokens": self._sum_column(normalized, "total_tokens"),
            "unique_teams": self._count_unique(normalized, "team_id"),
            "unique_models": self._count_unique(normalized, "model"),
        }
        return {
            "usage_data": usage_sample,
            "normalized_data": normalized_sample,
            "summary": summary,
        }

    async def export_window(
        self,
        *,
        window: FocusTimeWindow,
        limit: Optional[int],
    ) -> None:
        """Fetch, normalize, serialize and upload one export window.

        Silently returns (with a debug log) when there is no data to export.
        """
        data = await self._database.get_usage_data(
            limit=limit,
            start_time_utc=window.start_time,
            end_time_utc=window.end_time,
        )
        if data.is_empty():
            verbose_logger.debug("Focus export: no usage data for window %s", window)
            return
        normalized = self._transformer.transform(data)
        if normalized.is_empty():
            verbose_logger.debug(
                "Focus export: normalized data empty for window %s", window
            )
            return
        await self._serialize_and_upload(normalized, window)

    async def _serialize_and_upload(
        self, frame: pl.DataFrame, window: FocusTimeWindow
    ) -> None:
        """Serialize ``frame`` and deliver it to the configured destination."""
        payload = self._serializer.serialize(frame)
        if not payload:
            verbose_logger.debug("Focus export: serializer returned empty payload")
            return
        await self._destination.deliver(
            content=payload,
            time_window=window,
            filename=self._build_filename(),
        )

    def _build_filename(self) -> str:
        """Return the export file name, e.g. ``usage.parquet``."""
        if not self._serializer.extension:
            raise ValueError("Serializer must declare a file extension")
        return f"usage.{self._serializer.extension}"

    @staticmethod
    def _sum_column(frame: pl.DataFrame, column: str) -> float:
        """Sum ``column`` as a float; 0.0 when the column is absent/empty/null."""
        if frame.is_empty() or column not in frame.columns:
            return 0.0
        value = frame.select(pl.col(column).sum().alias("sum")).row(0)[0]
        if value is None:
            return 0.0
        return float(value)

    @staticmethod
    def _count_unique(frame: pl.DataFrame, column: str) -> int:
        """Count distinct values in ``column``; 0 when the column is absent/empty."""
        if frame.is_empty() or column not in frame.columns:
            return 0
        value = frame.select(pl.col(column).n_unique().alias("unique")).row(0)[0]
        if value is None:
            return 0
        return int(value)

View File

@@ -0,0 +1,214 @@
"""Focus export logger orchestrating DB pull/transform/upload."""
from __future__ import annotations
import os
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from .destinations import FocusTimeWindow
if TYPE_CHECKING:
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from .export_engine import FocusExportEngine
else:
AsyncIOScheduler = Any
FOCUS_USAGE_DATA_JOB_NAME = "focus_export_usage_data"
DEFAULT_DRY_RUN_LIMIT = 500
class FocusLogger(CustomLogger):
"""Coordinates Focus export jobs across transformer/serializer/destination layers."""
def __init__(
self,
*,
provider: Optional[str] = None,
export_format: Optional[str] = None,
frequency: Optional[str] = None,
cron_offset_minute: Optional[int] = None,
interval_seconds: Optional[int] = None,
prefix: Optional[str] = None,
destination_config: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.provider = (provider or os.getenv("FOCUS_PROVIDER") or "s3").lower()
self.export_format = (
export_format or os.getenv("FOCUS_FORMAT") or "parquet"
).lower()
self.frequency = (frequency or os.getenv("FOCUS_FREQUENCY") or "hourly").lower()
self.cron_offset_minute = (
cron_offset_minute
if cron_offset_minute is not None
else int(os.getenv("FOCUS_CRON_OFFSET", "5"))
)
raw_interval = (
interval_seconds
if interval_seconds is not None
else os.getenv("FOCUS_INTERVAL_SECONDS")
)
self.interval_seconds = int(raw_interval) if raw_interval is not None else None
env_prefix = os.getenv("FOCUS_PREFIX")
self.prefix: str = (
prefix
if prefix is not None
else (env_prefix if env_prefix else "focus_exports")
)
self._destination_config = destination_config
self._engine: Optional["FocusExportEngine"] = None
def _ensure_engine(self) -> "FocusExportEngine":
"""Instantiate the heavy export engine lazily."""
if self._engine is None:
from .export_engine import FocusExportEngine
self._engine = FocusExportEngine(
provider=self.provider,
export_format=self.export_format,
prefix=self.prefix,
destination_config=self._destination_config,
)
return self._engine
async def export_usage_data(
self,
*,
limit: Optional[int] = None,
start_time_utc: Optional[datetime] = None,
end_time_utc: Optional[datetime] = None,
) -> None:
"""Public hook to trigger export immediately."""
if bool(start_time_utc) ^ bool(end_time_utc):
raise ValueError(
"start_time_utc and end_time_utc must be provided together"
)
if start_time_utc and end_time_utc:
window = FocusTimeWindow(
start_time=start_time_utc,
end_time=end_time_utc,
frequency=self.frequency,
)
else:
window = self._compute_time_window(datetime.now(timezone.utc))
await self._export_window(window=window, limit=limit)
async def dry_run_export_usage_data(
self, limit: Optional[int] = DEFAULT_DRY_RUN_LIMIT
) -> dict[str, Any]:
"""Return transformed data without uploading."""
engine = self._ensure_engine()
return await engine.dry_run_export_usage_data(limit=limit)
async def initialize_focus_export_job(self) -> None:
"""Entry point for scheduler jobs to run export cycle with locking."""
from litellm.proxy.proxy_server import proxy_logging_obj
pod_lock_manager = None
if proxy_logging_obj is not None:
writer = getattr(proxy_logging_obj, "db_spend_update_writer", None)
if writer is not None:
pod_lock_manager = getattr(writer, "pod_lock_manager", None)
if pod_lock_manager and pod_lock_manager.redis_cache:
acquired = await pod_lock_manager.acquire_lock(
cronjob_id=FOCUS_USAGE_DATA_JOB_NAME
)
if not acquired:
verbose_logger.debug("Focus export: unable to acquire pod lock")
return
try:
await self._run_scheduled_export()
finally:
await pod_lock_manager.release_lock(
cronjob_id=FOCUS_USAGE_DATA_JOB_NAME
)
else:
await self._run_scheduled_export()
@staticmethod
async def init_focus_export_background_job(
scheduler: AsyncIOScheduler,
) -> None:
"""Register the export cron/interval job with the provided scheduler."""
focus_loggers: List[
CustomLogger
] = litellm.logging_callback_manager.get_custom_loggers_for_type(
callback_type=FocusLogger
)
if not focus_loggers:
verbose_logger.debug(
"No Focus export logger registered; skipping scheduler"
)
return
focus_logger = cast(FocusLogger, focus_loggers[0])
trigger_kwargs = focus_logger._build_scheduler_trigger()
scheduler.add_job(
focus_logger.initialize_focus_export_job,
**trigger_kwargs,
)
def _build_scheduler_trigger(self) -> Dict[str, Any]:
"""Return scheduler configuration for the selected frequency."""
if self.frequency == "interval":
seconds = self.interval_seconds or 60
return {"trigger": "interval", "seconds": seconds}
if self.frequency == "hourly":
minute = max(0, min(59, self.cron_offset_minute))
return {"trigger": "cron", "minute": minute, "second": 0}
if self.frequency == "daily":
total_minutes = max(0, self.cron_offset_minute)
hour = min(23, total_minutes // 60)
minute = min(59, total_minutes % 60)
return {"trigger": "cron", "hour": hour, "minute": minute, "second": 0}
raise ValueError(f"Unsupported frequency: {self.frequency}")
async def _run_scheduled_export(self) -> None:
"""Execute the scheduled export for the configured window."""
window = self._compute_time_window(datetime.now(timezone.utc))
await self._export_window(window=window, limit=None)
async def _export_window(
self,
*,
window: FocusTimeWindow,
limit: Optional[int],
) -> None:
engine = self._ensure_engine()
await engine.export_window(window=window, limit=limit)
def _compute_time_window(self, now: datetime) -> FocusTimeWindow:
    """Map *now* to the export window preceding it for the configured frequency.

    - "hourly": the previous full clock hour.
    - "daily": the previous full calendar day (UTC midnight to midnight).
    - "interval": the trailing ``interval_seconds`` (default 60s) ending at *now*.

    Raises:
        ValueError: if ``self.frequency`` is unrecognized.
    """
    anchor = now.astimezone(timezone.utc)
    if self.frequency == "hourly":
        window_end = anchor.replace(minute=0, second=0, microsecond=0)
        window_start = window_end - timedelta(hours=1)
    elif self.frequency == "daily":
        window_end = anchor.replace(hour=0, minute=0, second=0, microsecond=0)
        window_start = window_end - timedelta(days=1)
    elif self.frequency == "interval":
        window_end = anchor
        window_start = window_end - timedelta(seconds=self.interval_seconds or 60)
    else:
        raise ValueError(f"Unsupported frequency: {self.frequency}")
    return FocusTimeWindow(
        start_time=window_start,
        end_time=window_end,
        frequency=self.frequency,
    )
__all__ = ["FocusLogger"]

View File

@@ -0,0 +1,50 @@
"""Schema definitions for Focus export data."""
from __future__ import annotations
import polars as pl
# see: https://focus.finops.org/focus-specification/v1-2/
# Column names and dtypes for the normalized Focus export frame.
# NOTE(review): FocusTransformer formats the *PeriodStart/*PeriodEnd columns
# as ISO-8601 strings and emits several columns not listed here (e.g.
# BillingAccountType, InvoiceId, PricingCurrency) — confirm whether this
# schema is meant to match the transformer output exactly.
FOCUS_NORMALIZED_SCHEMA = pl.Schema(
    [
        # Cost and billing-account columns.
        ("BilledCost", pl.Decimal(18, 6)),
        ("BillingAccountId", pl.String),
        ("BillingAccountName", pl.String),
        ("BillingCurrency", pl.String),
        ("BillingPeriodStart", pl.Datetime(time_unit="us")),
        ("BillingPeriodEnd", pl.Datetime(time_unit="us")),
        # Charge classification and period.
        ("ChargeCategory", pl.String),
        ("ChargeClass", pl.String),
        ("ChargeDescription", pl.String),
        ("ChargeFrequency", pl.String),
        ("ChargePeriodStart", pl.Datetime(time_unit="us")),
        ("ChargePeriodEnd", pl.Datetime(time_unit="us")),
        # Consumption and pricing quantities.
        ("ConsumedQuantity", pl.Decimal(18, 6)),
        ("ConsumedUnit", pl.String),
        ("ContractedCost", pl.Decimal(18, 6)),
        ("ContractedUnitPrice", pl.Decimal(18, 6)),
        ("EffectiveCost", pl.Decimal(18, 6)),
        ("InvoiceIssuerName", pl.String),
        ("ListCost", pl.Decimal(18, 6)),
        ("ListUnitPrice", pl.Decimal(18, 6)),
        ("PricingCategory", pl.String),
        ("PricingQuantity", pl.Decimal(18, 6)),
        ("PricingUnit", pl.String),
        # Provider / resource / service identity.
        ("ProviderName", pl.String),
        ("PublisherName", pl.String),
        ("RegionId", pl.String),
        ("RegionName", pl.String),
        ("ResourceId", pl.String),
        ("ResourceName", pl.String),
        ("ResourceType", pl.String),
        ("ServiceCategory", pl.String),
        ("ServiceSubcategory", pl.String),
        ("ServiceName", pl.String),
        # Sub-account (team) identity and free-form tags.
        ("SubAccountId", pl.String),
        ("SubAccountName", pl.String),
        ("SubAccountType", pl.String),
        ("Tags", pl.Object),
    ]
)
__all__ = ["FOCUS_NORMALIZED_SCHEMA"]

View File

@@ -0,0 +1,6 @@
"""Serializer package exports for Focus integration."""
from .base import FocusSerializer
from .parquet import FocusParquetSerializer
__all__ = ["FocusSerializer", "FocusParquetSerializer"]

View File

@@ -0,0 +1,18 @@
"""Serializer abstractions for Focus export."""
from __future__ import annotations
from abc import ABC, abstractmethod
import polars as pl
class FocusSerializer(ABC):
    """Abstract base for serializers that render Focus frames as bytes."""

    # File extension (no leading dot) used when naming exported objects;
    # concrete subclasses override this.
    extension: str = ""

    @abstractmethod
    def serialize(self, frame: pl.DataFrame) -> bytes:
        """Encode *frame* into this serializer's output format."""
        raise NotImplementedError

View File

@@ -0,0 +1,22 @@
"""Parquet serializer for Focus export."""
from __future__ import annotations
import io
import polars as pl
from .base import FocusSerializer
class FocusParquetSerializer(FocusSerializer):
    """Render normalized Focus frames as snappy-compressed Parquet bytes."""

    extension = "parquet"

    def serialize(self, frame: pl.DataFrame) -> bytes:
        """Encode *frame* as a Parquet payload.

        An empty frame is re-materialized with its own schema so the
        Parquet file still carries the full column metadata.
        """
        payload = pl.DataFrame(schema=frame.schema) if frame.is_empty() else frame
        sink = io.BytesIO()
        payload.write_parquet(sink, compression="snappy")
        return sink.getvalue()

View File

@@ -0,0 +1,90 @@
"""Focus export data transformer."""
from __future__ import annotations
from datetime import timedelta
import polars as pl
from .schema import FOCUS_NORMALIZED_SCHEMA
class FocusTransformer:
    """Transforms LiteLLM DB rows into the Focus-compatible schema.

    Input rows are expected to carry at least: date, spend, api_key,
    api_key_alias, model, model_group, custom_llm_provider, team_id,
    team_alias.
    """

    schema = FOCUS_NORMALIZED_SCHEMA

    def transform(self, frame: pl.DataFrame) -> pl.DataFrame:
        """Return a normalized frame expected by downstream serializers.

        An empty input yields an empty frame with the Focus schema so
        serializers still see the expected columns.
        """
        if frame.is_empty():
            return pl.DataFrame(schema=self.schema)
        # derive period start/end from usage date (one-day charge period)
        frame = frame.with_columns(
            pl.col("date")
            .cast(pl.Utf8)
            .str.strptime(pl.Datetime(time_unit="us"), format="%Y-%m-%d", strict=False)
            .alias("usage_date"),
        )
        frame = frame.with_columns(
            pl.col("usage_date").alias("ChargePeriodStart"),
            (pl.col("usage_date") + timedelta(days=1)).alias("ChargePeriodEnd"),
        )

        def fmt(col):
            # Render datetimes as ISO-8601 UTC strings ("...Z").
            return col.dt.strftime("%Y-%m-%dT%H:%M:%SZ")

        DEC = pl.Decimal(18, 6)

        def dec(col):
            return col.cast(DEC)

        # Typed nulls for columns LiteLLM has no data for.
        none_str = pl.lit(None, dtype=pl.Utf8)
        none_dec = pl.lit(None, dtype=pl.Decimal(18, 6))
        # NOTE(review): this select emits some columns not declared in
        # FOCUS_NORMALIZED_SCHEMA (e.g. BillingAccountType, InvoiceId,
        # PricingCurrency) — confirm downstream consumers expect them.
        return frame.select(
            dec(pl.col("spend").fill_null(0.0)).alias("BilledCost"),
            pl.col("api_key").cast(pl.String).alias("BillingAccountId"),
            pl.col("api_key_alias").cast(pl.String).alias("BillingAccountName"),
            pl.lit("API Key").alias("BillingAccountType"),
            pl.lit("USD").alias("BillingCurrency"),
            fmt(pl.col("ChargePeriodEnd")).alias("BillingPeriodEnd"),
            fmt(pl.col("ChargePeriodStart")).alias("BillingPeriodStart"),
            pl.lit("Usage").alias("ChargeCategory"),
            none_str.alias("ChargeClass"),
            pl.col("model").cast(pl.String).alias("ChargeDescription"),
            pl.lit("Usage-Based").alias("ChargeFrequency"),
            fmt(pl.col("ChargePeriodEnd")).alias("ChargePeriodEnd"),
            fmt(pl.col("ChargePeriodStart")).alias("ChargePeriodStart"),
            dec(pl.lit(1.0)).alias("ConsumedQuantity"),
            pl.lit("Requests").alias("ConsumedUnit"),
            dec(pl.col("spend").fill_null(0.0)).alias("ContractedCost"),
            # Fix: FOCUS_NORMALIZED_SCHEMA declares ContractedUnitPrice as
            # Decimal(18, 6); it was previously emitted as a Utf8 null
            # (none_str), inconsistent with the other *UnitPrice columns.
            none_dec.alias("ContractedUnitPrice"),
            dec(pl.col("spend").fill_null(0.0)).alias("EffectiveCost"),
            pl.col("custom_llm_provider").cast(pl.String).alias("InvoiceIssuerName"),
            none_str.alias("InvoiceId"),
            dec(pl.col("spend").fill_null(0.0)).alias("ListCost"),
            none_dec.alias("ListUnitPrice"),
            none_str.alias("AvailabilityZone"),
            pl.lit("USD").alias("PricingCurrency"),
            none_str.alias("PricingCategory"),
            dec(pl.lit(1.0)).alias("PricingQuantity"),
            none_dec.alias("PricingCurrencyContractedUnitPrice"),
            dec(pl.col("spend").fill_null(0.0)).alias("PricingCurrencyEffectiveCost"),
            none_dec.alias("PricingCurrencyListUnitPrice"),
            pl.lit("Requests").alias("PricingUnit"),
            pl.col("custom_llm_provider").cast(pl.String).alias("ProviderName"),
            pl.col("custom_llm_provider").cast(pl.String).alias("PublisherName"),
            none_str.alias("RegionId"),
            none_str.alias("RegionName"),
            pl.col("model").cast(pl.String).alias("ResourceId"),
            pl.col("model").cast(pl.String).alias("ResourceName"),
            pl.col("model").cast(pl.String).alias("ResourceType"),
            pl.lit("AI and Machine Learning").alias("ServiceCategory"),
            pl.lit("Generative AI").alias("ServiceSubcategory"),
            pl.col("model_group").cast(pl.String).alias("ServiceName"),
            pl.col("team_id").cast(pl.String).alias("SubAccountId"),
            pl.col("team_alias").cast(pl.String).alias("SubAccountName"),
            none_str.alias("SubAccountType"),
            # NOTE(review): schema declares Tags as pl.Object; a Utf8 null is
            # emitted here — confirm intended.
            none_str.alias("Tags"),
        )

View File

@@ -0,0 +1,157 @@
import os
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
# from here: https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#structuring-your-records
class LLMResponse(BaseModel):
    """A single record in the Galileo Observe ingest payload.

    Field layout follows the Galileo Observe REST ingestion schema:
    https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#structuring-your-records
    """

    latency_ms: int
    status_code: int
    input_text: str
    output_text: str
    # litellm call type (e.g. "completion"); see async_log_success_event.
    node_type: str
    model: str
    num_input_tokens: int
    num_output_tokens: int
    output_logprobs: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Optional. When available, logprobs are used to compute Uncertainty.",
    )
    created_at: str = Field(
        ..., description='timestamp constructed in "%Y-%m-%dT%H:%M:%S" format'
    )
    tags: Optional[List[str]] = None
    user_metadata: Optional[Dict[str, Any]] = None
class GalileoObserve(CustomLogger):
    """Logs LiteLLM successes to the Galileo Observe ingest API.

    Records are buffered in memory and flushed to
    ``{base_url}/projects/{project_id}/observe/ingest`` once the buffer
    reaches ``batch_size``.
    """

    def __init__(self) -> None:
        # Buffered records awaiting flush to the Galileo ingest endpoint.
        self.in_memory_records: List[dict] = []
        # NOTE(review): batch_size of 1 flushes after every record — confirm
        # whether a larger batch was intended.
        self.batch_size = 1
        self.base_url = os.getenv("GALILEO_BASE_URL", None)
        self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
        # Populated lazily by set_galileo_headers() after login.
        self.headers: Optional[Dict[str, str]] = None
        self.async_httpx_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        pass

    def set_galileo_headers(self):
        """Log in with GALILEO_USERNAME/GALILEO_PASSWORD and cache bearer headers."""
        # following https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#logging-your-records
        headers = {
            "accept": "application/json",
            "Content-Type": "application/x-www-form-urlencoded",
        }
        # NOTE(review): the login response status is not checked; a failed
        # login raises KeyError on "access_token" below.
        galileo_login_response = litellm.module_level_client.post(
            url=f"{self.base_url}/login",
            headers=headers,
            data={
                "username": os.getenv("GALILEO_USERNAME"),
                "password": os.getenv("GALILEO_PASSWORD"),
            },
        )
        access_token = galileo_login_response.json()["access_token"]
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {access_token}",
        }

    def get_output_str_from_response(self, response_obj, kwargs):
        """Extract the textual output from a litellm response object.

        Returns None for embedding responses (they carry no text output)
        and for unrecognized response types.
        """
        output = None
        if response_obj is not None and (
            kwargs.get("call_type", None) == "embedding"
            or isinstance(response_obj, litellm.EmbeddingResponse)
        ):
            output = None
        elif response_obj is not None and isinstance(
            response_obj, litellm.ModelResponse
        ):
            # JSON-serialized first choice message (chat completions).
            output = response_obj["choices"][0]["message"].json()
        elif response_obj is not None and isinstance(
            response_obj, litellm.TextCompletionResponse
        ):
            output = response_obj.choices[0].text
        elif response_obj is not None and isinstance(
            response_obj, litellm.ImageResponse
        ):
            output = response_obj["data"]
        return output

    async def async_log_success_event(
        self, kwargs: Any, response_obj: Any, start_time: Any, end_time: Any
    ):
        """Buffer one success record; flush when batch_size is reached.

        Responses with no extractable output text (e.g. embeddings) are
        silently skipped.
        """
        verbose_logger.debug("On Async Success")
        _latency_ms = int((end_time - start_time).total_seconds() * 1000)
        _call_type = kwargs.get("call_type", "litellm")
        input_text = litellm.utils.get_formatted_prompt(
            data=kwargs, call_type=_call_type
        )
        _usage = response_obj.get("usage", {}) or {}
        num_input_tokens = _usage.get("prompt_tokens", 0)
        num_output_tokens = _usage.get("completion_tokens", 0)
        output_text = self.get_output_str_from_response(
            response_obj=response_obj, kwargs=kwargs
        )
        if output_text is not None:
            request_record = LLMResponse(
                latency_ms=_latency_ms,
                status_code=200,
                input_text=input_text,
                output_text=output_text,
                node_type=_call_type,
                model=kwargs.get("model", "-"),
                num_input_tokens=num_input_tokens,
                num_output_tokens=num_output_tokens,
                created_at=start_time.strftime(
                    "%Y-%m-%dT%H:%M:%S"
                ),  # timestamp str constructed in "%Y-%m-%dT%H:%M:%S" format
            )
            # dump to dict
            request_dict = request_record.model_dump()
            self.in_memory_records.append(request_dict)
            if len(self.in_memory_records) >= self.batch_size:
                await self.flush_in_memory_records()

    async def flush_in_memory_records(self):
        """POST buffered records to the Galileo ingest endpoint.

        The buffer is cleared only on HTTP 200; on failure the records are
        retained and retried on the next flush.
        """
        verbose_logger.debug("flushing in memory records")
        response = await self.async_httpx_handler.post(
            url=f"{self.base_url}/projects/{self.project_id}/observe/ingest",
            headers=self.headers,
            json={"records": self.in_memory_records},
        )
        if response.status_code == 200:
            verbose_logger.debug(
                "Galileo Logger:successfully flushed in memory records"
            )
            self.in_memory_records = []
        else:
            verbose_logger.debug("Galileo Logger: failed to flush in memory records")
            verbose_logger.debug(
                "Galileo Logger error=%s, status code=%s",
                response.text,
                response.status_code,
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # Failures are only logged locally; they are not sent to Galileo.
        verbose_logger.debug("On Async Failure")

View File

@@ -0,0 +1,12 @@
# GCS (Google Cloud Storage) Bucket Logging on LiteLLM Gateway
This folder contains the GCS Bucket Logging integration for LiteLLM Gateway.
## Folder Structure
- `gcs_bucket.py`: This is the main file that handles failure/success logging to GCS Bucket
- `gcs_bucket_base.py`: This file contains the GCSBucketBase class which handles Authentication for GCS Buckets
## Further Reading
- [Doc setting up GCS Bucket Logging on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/observability/gcs_bucket_integration)
- [Doc on Key / Team Based logging with GCS](https://docs.litellm.ai/docs/proxy/team_logging)

View File

@@ -0,0 +1,419 @@
import asyncio
import hashlib
import json
import os
import time
from litellm._uuid import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from urllib.parse import quote
from litellm._logging import verbose_logger
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.proxy._types import CommonProxyErrors
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
else:
VertexBase = Any
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
    """Logs LiteLLM payloads to a Google Cloud Storage bucket.

    Payloads are queued in memory and flushed periodically (or when the
    queue fills). With `GCS_USE_BATCHED_LOGGING` enabled, payloads sharing
    the same bucket/credentials are combined into a single NDJSON object
    upload; otherwise each payload becomes its own GCS object.
    """

    def __init__(self, bucket_name: Optional[str] = None) -> None:
        from litellm.proxy.proxy_server import premium_user

        # Batch/flush configuration, overridable via environment variables.
        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
        self.flush_interval = int(
            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
        )
        self.use_batched_logging = (
            os.getenv(
                "GCS_USE_BATCHED_LOGGING", str(GCS_DEFAULT_USE_BATCHED_LOGGING).lower()
            ).lower()
            == "true"
        )
        self.flush_lock = asyncio.Lock()
        # Fix: initialize the base class exactly once, forwarding both the
        # bucket name and the batching config. The previous implementation
        # called super().__init__() twice; the second call (without
        # bucket_name) re-ran GCSBucketBase.__init__ and reset
        # self.BUCKET_NAME / self.path_service_account_json from the
        # environment, silently dropping an explicitly passed bucket_name.
        super().__init__(
            bucket_name=bucket_name,
            flush_lock=self.flush_lock,
            batch_size=self.batch_size,
            flush_interval=self.flush_interval,
        )
        self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue(  # type: ignore[assignment]
            maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE
        )
        asyncio.create_task(self.periodic_flush())
        AdditionalLoggingUtils.__init__(self)
        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )

    #### ASYNC ####
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a success payload for upload; flush first if the queue is full."""
        from litellm.proxy.proxy_server import premium_user

        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if logging_payload is None:
                raise ValueError("standard_logging_object not found in kwargs")
            # When queue is at maxsize, flush immediately to make room (no blocking, no data dropped)
            if self.log_queue.full():
                await self.flush_queue()
            await self.log_queue.put(
                GCSLogQueueItem(
                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
                )
            )
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a failure payload for upload; flush first if the queue is full."""
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if logging_payload is None:
                raise ValueError("standard_logging_object not found in kwargs")
            # When queue is at maxsize, flush immediately to make room (no blocking, no data dropped)
            if self.log_queue.full():
                await self.flush_queue()
            await self.log_queue.put(
                GCSLogQueueItem(
                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
                )
            )
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    def _drain_queue_batch(self) -> List[GCSLogQueueItem]:
        """
        Drain items from the queue (non-blocking), respecting batch_size limit.
        This prevents unbounded queue growth when processing is slower than log accumulation.

        Returns:
            List of items to process, up to batch_size items
        """
        items_to_process: List[GCSLogQueueItem] = []
        while len(items_to_process) < self.batch_size:
            try:
                items_to_process.append(self.log_queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        return items_to_process

    def _generate_batch_object_name(self, date_str: str, batch_id: str) -> str:
        """
        Generate object name for a batched log file.
        Format: {date}/batch-{batch_id}.ndjson
        """
        return f"{date_str}/batch-{batch_id}.ndjson"

    def _get_config_key(self, kwargs: Dict[str, Any]) -> str:
        """
        Extract a synchronous grouping key from kwargs to group items by GCS config.
        This allows us to batch items with the same bucket/credentials together.

        Returns a string key that uniquely identifies the GCS config combination.
        This key may contain sensitive information (bucket names, paths) - use _sanitize_config_key()
        for logging purposes.
        """
        standard_callback_dynamic_params = (
            kwargs.get("standard_callback_dynamic_params", None) or {}
        )
        bucket_name = (
            standard_callback_dynamic_params.get("gcs_bucket_name", None)
            or self.BUCKET_NAME
            or "default"
        )
        path_service_account = (
            standard_callback_dynamic_params.get("gcs_path_service_account", None)
            or self.path_service_account_json
            or "default"
        )
        return f"{bucket_name}|{path_service_account}"

    def _sanitize_config_key(self, config_key: str) -> str:
        """
        Create a sanitized version of the config key for logging.
        Uses a hash to avoid exposing sensitive bucket names or service account paths.

        Returns a short hash prefix for safe logging.
        """
        hash_obj = hashlib.sha256(config_key.encode("utf-8"))
        return f"config-{hash_obj.hexdigest()[:8]}"

    def _group_items_by_config(
        self, items: List[GCSLogQueueItem]
    ) -> Dict[str, List[GCSLogQueueItem]]:
        """
        Group items by their GCS config (bucket + credentials).
        This ensures items with different configs are processed separately.

        Returns a dict mapping config_key -> list of items with that config.
        """
        grouped: Dict[str, List[GCSLogQueueItem]] = {}
        for item in items:
            config_key = self._get_config_key(item["kwargs"])
            if config_key not in grouped:
                grouped[config_key] = []
            grouped[config_key].append(item)
        return grouped

    def _combine_payloads_to_ndjson(self, items: List[GCSLogQueueItem]) -> str:
        """
        Combine multiple log payloads into newline-delimited JSON (NDJSON) format.
        Each line is a valid JSON object representing one log entry.
        """
        lines = []
        for item in items:
            logging_payload = item["payload"]
            json_line = json.dumps(logging_payload, default=str, ensure_ascii=False)
            lines.append(json_line)
        return "\n".join(lines)

    async def _send_grouped_batch(
        self, items: List[GCSLogQueueItem], config_key: str
    ) -> Tuple[int, int]:
        """
        Send a batch of items that share the same GCS config.

        The first item's kwargs are used to resolve the shared config.

        Returns:
            (success_count, error_count)
        """
        if not items:
            return (0, 0)
        first_kwargs = items[0]["kwargs"]
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                first_kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            current_date = self._get_object_date_from_datetime(
                datetime.now(timezone.utc)
            )
            # Millisecond timestamp + short uuid keeps batch names unique.
            batch_id = f"{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"
            object_name = self._generate_batch_object_name(current_date, batch_id)
            combined_payload = self._combine_payloads_to_ndjson(items)
            await self._log_json_data_on_gcs(
                headers=headers,
                bucket_name=bucket_name,
                object_name=object_name,
                logging_payload=combined_payload,
            )
            success_count = len(items)
            error_count = 0
            return (success_count, error_count)
        except Exception as e:
            success_count = 0
            error_count = len(items)
            verbose_logger.exception(
                f"GCS Bucket error logging batch payload to GCS bucket: {str(e)}"
            )
            return (success_count, error_count)

    async def _send_individual_logs(self, items: List[GCSLogQueueItem]) -> None:
        """
        Send each log individually as separate GCS objects (legacy behavior).
        This is used when GCS_USE_BATCHED_LOGGING is disabled.
        """
        for item in items:
            await self._send_single_log_item(item)

    async def _send_single_log_item(self, item: GCSLogQueueItem) -> None:
        """
        Send a single log item to GCS as an individual object.
        Errors are logged and swallowed so one bad item cannot block the rest.
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                item["kwargs"]
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            object_name = self._get_object_name(
                kwargs=item["kwargs"],
                logging_payload=item["payload"],
                response_obj=item["response_obj"],
            )
            await self._log_json_data_on_gcs(
                headers=headers,
                bucket_name=bucket_name,
                object_name=object_name,
                logging_payload=item["payload"],
            )
        except Exception as e:
            verbose_logger.exception(
                f"GCS Bucket error logging individual payload to GCS bucket: {str(e)}"
            )

    async def async_send_batch(self):
        """
        Process queued logs - sends logs to GCS Bucket.

        If `GCS_USE_BATCHED_LOGGING` is enabled (default), batches multiple log payloads
        into single GCS object uploads (NDJSON format), dramatically reducing API calls.
        If disabled, sends each log individually as separate GCS objects (legacy behavior).
        """
        items_to_process = self._drain_queue_batch()
        if not items_to_process:
            return
        if self.use_batched_logging:
            grouped_items = self._group_items_by_config(items_to_process)
            for config_key, group_items in grouped_items.items():
                await self._send_grouped_batch(group_items, config_key)
        else:
            await self._send_individual_logs(items_to_process)

    def _get_object_name(
        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
    ) -> str:
        """
        Get the object name to use for the current payload.

        Failures get a uuid-based name; successes are keyed by response id so
        they can be fetched back via get_request_response_payload().
        """
        current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
        if logging_payload.get("error_str", None) is not None:
            object_name = self._generate_failure_object_name(
                request_date_str=current_date,
            )
        else:
            object_name = self._generate_success_object_name(
                request_date_str=current_date,
                response_id=response_obj.get("id", ""),
            )
        # used for testing
        _litellm_params = kwargs.get("litellm_params", None) or {}
        _metadata = _litellm_params.get("metadata", None) or {}
        if "gcs_log_id" in _metadata:
            object_name = _metadata["gcs_log_id"]
        return object_name

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """
        Get the request and response payload for a given `request_id`
        Tries current day, next day, and previous day until it finds the payload

        Raises:
            ValueError: if start_time_utc is None.
        """
        if start_time_utc is None:
            raise ValueError(
                "start_time_utc is required for getting a payload from GCS Bucket"
            )
        dates_to_try = [
            start_time_utc,
            start_time_utc + timedelta(days=1),
            start_time_utc - timedelta(days=1),
        ]
        date_str = None
        for date in dates_to_try:
            try:
                date_str = self._get_object_date_from_datetime(datetime_obj=date)
                object_name = self._generate_success_object_name(
                    request_date_str=date_str,
                    response_id=request_id,
                )
                # Object names contain "/" which must be percent-encoded in the URL path.
                encoded_object_name = quote(object_name, safe="")
                response = await self.download_gcs_object(encoded_object_name)
                if response is not None:
                    loaded_response = json.loads(response)
                    return loaded_response
            except Exception as e:
                verbose_logger.debug(
                    f"Failed to fetch payload for date {date_str}: {str(e)}"
                )
                continue
        return None

    def _generate_success_object_name(
        self,
        request_date_str: str,
        response_id: str,
    ) -> str:
        """Object name for a success payload: {date}/{response_id}."""
        return f"{request_date_str}/{response_id}"

    def _generate_failure_object_name(
        self,
        request_date_str: str,
    ) -> str:
        """Object name for a failure payload: {date}/failure-{uuid}."""
        return f"{request_date_str}/failure-{uuid.uuid4().hex}"

    def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
        """Format a datetime as the YYYY-MM-DD folder prefix."""
        return datetime_obj.strftime("%Y-%m-%d")

    async def flush_queue(self):
        """
        Override flush_queue to work with asyncio.Queue.
        """
        await self.async_send_batch()
        self.last_flush_time = time.time()

    async def periodic_flush(self):
        """
        Override periodic_flush to work with asyncio.Queue.
        """
        while True:
            await asyncio.sleep(self.flush_interval)
            verbose_logger.debug(
                f"GCS Bucket periodic flush after {self.flush_interval} seconds"
            )
            await self.flush_queue()

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        raise NotImplementedError("GCS Bucket does not support health check")

View File

@@ -0,0 +1,347 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from litellm.integrations.gcs_bucket.gcs_bucket_mock_client import (
should_use_gcs_mock,
create_mock_gcs_client,
mock_vertex_auth_methods,
)
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
if TYPE_CHECKING:
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
else:
VertexBase = Any
IAM_AUTH_KEY = "IAM_AUTH"
class GCSBucketBase(CustomBatchLogger):
def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
    """Configure httpx client, bucket settings from env, and the vertex cache.

    Args:
        bucket_name: Explicit bucket; falls back to the GCS_BUCKET_NAME env var.
        **kwargs: Forwarded to CustomBatchLogger (flush_lock, batch_size, ...).
    """
    self.is_mock_mode = should_use_gcs_mock()
    if self.is_mock_mode:
        mock_vertex_auth_methods()
        create_mock_gcs_client()
    self.async_httpx_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.LoggingCallback
    )
    self.path_service_account_json: Optional[str] = os.getenv(
        "GCS_PATH_SERVICE_ACCOUNT"
    )
    self.BUCKET_NAME: Optional[str] = bucket_name or os.getenv("GCS_BUCKET_NAME")
    # Vertex auth clients cached per credentials string (see get_or_create_vertex_instance).
    self.vertex_instances: Dict[str, VertexBase] = {}
    super().__init__(**kwargs)
async def construct_request_headers(
    self,
    service_account_json: Optional[str],
    vertex_instance: Optional[VertexBase] = None,
) -> Dict[str, str]:
    """Build Authorization/Content-Type headers for async GCS API calls.

    Args:
        service_account_json: Credentials used to mint the access token;
            None presumably falls back to ambient auth — confirm with VertexBase.
        vertex_instance: Token manager; defaults to the module-level
            `vertex_chat_completion` client.
    """
    from litellm import vertex_chat_completion

    if vertex_instance is None:
        vertex_instance = vertex_chat_completion
    # Mint (or refresh) an access token for the given credentials.
    _auth_header, vertex_project = await vertex_instance._ensure_access_token_async(
        credentials=service_account_json,
        project_id=None,
        custom_llm_provider="vertex_ai",
    )
    # Resolve the final bearer token; the returned url is unused here.
    auth_header, _ = vertex_instance._get_token_and_url(
        model="gcs-bucket",
        auth_header=_auth_header,
        vertex_credentials=service_account_json,
        vertex_project=vertex_project,
        vertex_location=None,
        gemini_api_key=None,
        stream=None,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )
    verbose_logger.debug("constructed auth_header %s", auth_header)
    headers = {
        "Authorization": f"Bearer {auth_header}",  # auth_header
        "Content-Type": "application/json",
    }
    return headers
def sync_construct_request_headers(self) -> Dict[str, str]:
    """
    Construct request headers for GCS API calls.

    Synchronous counterpart of construct_request_headers(); always uses the
    module-level `vertex_chat_completion` client and the instance's
    configured service-account path.
    """
    from litellm import vertex_chat_completion

    # Get project_id from environment if available, otherwise None
    # This helps support use of this library to auth to pull secrets
    # from Secret Manager.
    project_id = os.getenv("GOOGLE_SECRET_MANAGER_PROJECT_ID")
    _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
        credentials=self.path_service_account_json,
        project_id=project_id,
        custom_llm_provider="vertex_ai",
    )
    # Resolve the final bearer token; the returned url is unused here.
    auth_header, _ = vertex_chat_completion._get_token_and_url(
        model="gcs-bucket",
        auth_header=_auth_header,
        vertex_credentials=self.path_service_account_json,
        vertex_project=vertex_project,
        vertex_location=None,
        gemini_api_key=None,
        stream=None,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )
    verbose_logger.debug("constructed auth_header %s", auth_header)
    headers = {
        "Authorization": f"Bearer {auth_header}",  # auth_header
        "Content-Type": "application/json",
    }
    return headers
def _handle_folders_in_bucket_name(
self,
bucket_name: str,
object_name: str,
) -> Tuple[str, str]:
"""
Handles when the user passes a bucket name with a folder postfix
Example:
- Bucket name: "my-bucket/my-folder/dev"
- Object name: "my-object"
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
"""
if "/" in bucket_name:
bucket_name, prefix = bucket_name.split("/", 1)
object_name = f"{prefix}/{object_name}"
return bucket_name, object_name
return bucket_name, object_name
async def get_gcs_logging_config(
    self, kwargs: Optional[Dict[str, Any]] = None
) -> GCSLoggingConfig:
    """
    This function is used to get the GCS logging config for the GCS Bucket Logger.
    It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
    If no dynamic parameters are provided, it uses the default values.

    Fix: the default for `kwargs` was the mutable `{}` — a classic shared
    mutable default. It is now None (the None path below already existed).

    Raises:
        ValueError: if no bucket name is configured anywhere.
    """
    if kwargs is None:
        kwargs = {}
    standard_callback_dynamic_params: Optional[
        StandardCallbackDynamicParams
    ] = kwargs.get("standard_callback_dynamic_params", None)
    bucket_name: str
    path_service_account: Optional[str]
    if standard_callback_dynamic_params is not None:
        verbose_logger.debug("Using dynamic GCS logging")
        verbose_logger.debug(
            "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
        )
        # Per-request values take precedence over instance defaults.
        _bucket_name: Optional[str] = (
            standard_callback_dynamic_params.get("gcs_bucket_name", None)
            or self.BUCKET_NAME
        )
        _path_service_account: Optional[str] = (
            standard_callback_dynamic_params.get("gcs_path_service_account", None)
            or self.path_service_account_json
        )
        if _bucket_name is None:
            raise ValueError(
                "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
            )
        bucket_name = _bucket_name
        path_service_account = _path_service_account
        vertex_instance = await self.get_or_create_vertex_instance(
            credentials=path_service_account
        )
    else:
        # If no dynamic parameters, use the default instance
        if self.BUCKET_NAME is None:
            raise ValueError(
                "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
            )
        bucket_name = self.BUCKET_NAME
        path_service_account = self.path_service_account_json
        vertex_instance = await self.get_or_create_vertex_instance(
            credentials=path_service_account
        )
    return GCSLoggingConfig(
        bucket_name=bucket_name,
        vertex_instance=vertex_instance,
        path_service_account=path_service_account,
    )
async def get_or_create_vertex_instance(
    self, credentials: Optional[str]
) -> VertexBase:
    """
    This function is used to get the Vertex instance for the GCS Bucket Logger.
    It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
    """
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

    _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
    if _in_memory_key not in self.vertex_instances:
        vertex_instance = VertexBase()
        # Prime the access token before caching so later callers get a
        # ready-to-use instance.
        await vertex_instance._ensure_access_token_async(
            credentials=credentials,
            project_id=None,
            custom_llm_provider="vertex_ai",
        )
        self.vertex_instances[_in_memory_key] = vertex_instance
    return self.vertex_instances[_in_memory_key]
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
"""
Returns key to use for caching the Vertex instance in-memory.
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
- If a credentials string is provided, it is used as the key.
- If no credentials string is provided, "IAM_AUTH" is used as the key.
"""
return credentials or IAM_AUTH_KEY
async def download_gcs_object(self, object_name: str, **kwargs):
    """
    Download an object from GCS.

    https://cloud.google.com/storage/docs/downloading-objects#download-object-json

    Args:
        object_name: GCS object name; callers must pre-encode any "/" if the
            name is used verbatim in the URL path.
        **kwargs: Forwarded to get_gcs_logging_config() for dynamic config.

    Returns:
        The raw object bytes, or None on any error (errors are logged).
    """
    try:
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs=kwargs
        )
        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
        )
        bucket_name = gcs_logging_config["bucket_name"]
        bucket_name, object_name = self._handle_folders_in_bucket_name(
            bucket_name=bucket_name,
            object_name=object_name,
        )
        # alt=media returns the object content rather than its metadata.
        url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
        # Send the GET request to download the object
        response = await self.async_httpx_client.get(url=url, headers=headers)
        if response.status_code != 200:
            verbose_logger.error(
                "GCS object download error: %s", str(response.text)
            )
            return None
        verbose_logger.debug(
            "GCS object download response status code: %s", response.status_code
        )
        # Return the content of the downloaded object
        return response.content
    except Exception as e:
        verbose_logger.error("GCS object download error: %s", str(e))
        return None
async def delete_gcs_object(self, object_name: str, **kwargs):
"""
Delete an object from GCS.
"""
try:
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs=kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
# Send the DELETE request to delete the object
response = await self.async_httpx_client.delete(url=url, headers=headers)
if (response.status_code != 200) or (response.status_code != 204):
verbose_logger.error(
"GCS object delete error: %s, status code: %s",
str(response.text),
response.status_code,
)
return None
verbose_logger.debug(
"GCS object delete response status code: %s, response: %s",
response.status_code,
response.text,
)
# Return the content of the downloaded object
return response.text
except Exception as e:
verbose_logger.error("GCS object download error: %s", str(e))
return None
async def _log_json_data_on_gcs(
self,
headers: Dict[str, str],
bucket_name: str,
object_name: str,
logging_payload: Union[StandardLoggingPayload, str],
):
"""
Helper function to make POST request to GCS Bucket in the specified bucket.
"""
if isinstance(logging_payload, str):
json_logged_payload = logging_payload
else:
json_logged_payload = json.dumps(logging_payload, default=str)
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
response = await self.async_httpx_client.post(
headers=headers,
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
data=json_logged_payload,
)
if response.status_code != 200:
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
verbose_logger.debug("GCS Bucket response %s", response)
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
verbose_logger.debug("GCS Bucket response.text %s", response.text)
return response.json()

View File

@@ -0,0 +1,254 @@
"""
Mock client for GCS Bucket integration testing.
This module intercepts GCS API calls and Vertex AI auth calls, returning successful
mock responses, allowing full code execution without making actual network calls.
Usage:
Set GCS_MOCK=true in environment variables or config to enable mock mode.
"""
import asyncio
import os

from litellm._logging import verbose_logger
from litellm.integrations.mock_client_factory import (
    MockClientConfig,
    create_mock_client_factory,
    MockResponse,
)
# Use factory for POST handler
_config = MockClientConfig(
    name="GCS",
    env_var="GCS_MOCK",
    default_latency_ms=150,
    default_status_code=200,
    default_json_data={"kind": "storage#object", "name": "mock-object"},
    url_matchers=["storage.googleapis.com"],
    patch_async_handler=True,
    patch_sync_client=False,
)

_create_mock_gcs_post, should_use_gcs_mock = create_mock_client_factory(_config)

# Store original methods for GET/DELETE (GCS-specific)
_original_async_handler_get = None
_original_async_handler_delete = None
_mocks_initialized = False

# Default mock latency in seconds (simulates network round-trip)
# Typical GCS API calls take 100-300ms for uploads, 50-150ms for GET/DELETE
# NOTE: previously used an inline `__import__("os")` hack; a normal
# module-level `import os` is clearer and equivalent.
_MOCK_LATENCY_SECONDS = float(os.getenv("GCS_MOCK_LATENCY_MS", "150")) / 1000.0
async def _mock_async_handler_get(
    self, url, params=None, headers=None, follow_redirects=None
):
    """Monkey-patched AsyncHTTPHandler.get that intercepts GCS calls.

    GCS download URLs (``storage.googleapis.com``) receive a canned
    StandardLoggingPayload-shaped body after a simulated network delay;
    every other URL is delegated to the original handler.
    """
    is_gcs_call = isinstance(url, str) and "storage.googleapis.com" in url
    if not is_gcs_call:
        # Not a GCS call - fall through to the real implementation.
        if _original_async_handler_get is None:
            raise RuntimeError("Original AsyncHTTPHandler.get not available")
        return await _original_async_handler_get(
            self,
            url=url,
            params=params,
            headers=headers,
            follow_redirects=follow_redirects,
        )

    verbose_logger.info(f"[GCS MOCK] GET to {url}")
    await asyncio.sleep(_MOCK_LATENCY_SECONDS)

    # Return a minimal but valid StandardLoggingPayload JSON string as bytes.
    # This matches what GCS returns when downloading with ?alt=media.
    mock_payload = {
        "id": "mock-request-id",
        "trace_id": "mock-trace-id",
        "call_type": "completion",
        "stream": False,
        "response_cost": 0.0,
        "status": "success",
        "status_fields": {"llm_api_status": "success"},
        "custom_llm_provider": "mock",
        "total_tokens": 0,
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "startTime": 0.0,
        "endTime": 0.0,
        "completionStartTime": 0.0,
        "response_time": 0.0,
        "model_map_information": {"model": "mock-model"},
        "model": "mock-model",
        "model_id": None,
        "model_group": None,
        "api_base": "https://api.mock.com",
        "metadata": {},
        "cache_hit": None,
        "cache_key": None,
        "saved_cache_cost": 0.0,
        "request_tags": [],
        "end_user": None,
        "requester_ip_address": None,
        "messages": None,
        "response": None,
        "error_str": None,
        "error_information": None,
        "model_parameters": {},
        "hidden_params": {},
        "guardrail_information": None,
        "standard_built_in_tools_params": None,
    }
    return MockResponse(
        status_code=200,
        json_data=mock_payload,
        url=url,
        elapsed_seconds=_MOCK_LATENCY_SECONDS,
    )
async def _mock_async_handler_delete(
    self,
    url,
    data=None,
    json=None,
    params=None,
    headers=None,
    timeout=None,
    stream=False,
    content=None,
):
    """Monkey-patched AsyncHTTPHandler.delete that intercepts GCS calls.

    GCS URLs get a simulated 204 No Content after a short delay; all other
    URLs are delegated to the original handler.
    """
    is_gcs_call = isinstance(url, str) and "storage.googleapis.com" in url
    if is_gcs_call:
        verbose_logger.info(f"[GCS MOCK] DELETE to {url}")
        await asyncio.sleep(_MOCK_LATENCY_SECONDS)
        # DELETE returns 204 No Content with empty body (not JSON)
        return MockResponse(
            status_code=204,
            json_data=None,  # Empty body for DELETE
            url=url,
            elapsed_seconds=_MOCK_LATENCY_SECONDS,
        )

    if _original_async_handler_delete is None:
        raise RuntimeError("Original AsyncHTTPHandler.delete not available")
    return await _original_async_handler_delete(
        self,
        url=url,
        data=data,
        json=json,
        params=params,
        headers=headers,
        timeout=timeout,
        stream=stream,
        content=content,
    )
def create_mock_gcs_client():
    """
    Monkey-patch AsyncHTTPHandler methods to intercept GCS calls.

    AsyncHTTPHandler is used by LiteLLM's get_async_httpx_client() which is what
    GCSBucketBase uses for making API calls.

    This function is idempotent - it only initializes mocks once, even if called multiple times.
    """
    global _original_async_handler_get, _original_async_handler_delete, _mocks_initialized

    # Use factory for POST handler
    _create_mock_gcs_post()

    # If already initialized, skip GET/DELETE patching
    if _mocks_initialized:
        return

    verbose_logger.debug("[GCS MOCK] Initializing GCS GET/DELETE handlers...")

    # Patch GET and DELETE handlers (GCS-specific).
    # Imported lazily so the handler class is only loaded when mocking is on.
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    # Save each original exactly once so non-GCS calls can be delegated,
    # and so we never overwrite the saved original with our own mock.
    if _original_async_handler_get is None:
        _original_async_handler_get = AsyncHTTPHandler.get
        AsyncHTTPHandler.get = _mock_async_handler_get  # type: ignore
        verbose_logger.debug("[GCS MOCK] Patched AsyncHTTPHandler.get")

    if _original_async_handler_delete is None:
        _original_async_handler_delete = AsyncHTTPHandler.delete
        AsyncHTTPHandler.delete = _mock_async_handler_delete  # type: ignore
        verbose_logger.debug("[GCS MOCK] Patched AsyncHTTPHandler.delete")

    verbose_logger.debug(
        f"[GCS MOCK] Mock latency set to {_MOCK_LATENCY_SECONDS*1000:.0f}ms"
    )
    verbose_logger.debug("[GCS MOCK] GCS mock client initialization complete")
    _mocks_initialized = True
def mock_vertex_auth_methods():
    """
    Monkey-patch Vertex AI auth methods to return fake tokens.

    This prevents auth failures when GCS_MOCK is enabled.
    This function is idempotent - it only patches once, even if called multiple times.
    """
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

    # Store original methods as class attributes if not already stored,
    # so a caller could restore them later; the hasattr check makes
    # repeated invocations a no-op for the save step.
    if not hasattr(VertexBase, "_original_ensure_access_token_async"):
        setattr(
            VertexBase,
            "_original_ensure_access_token_async",
            VertexBase._ensure_access_token_async,
        )
        setattr(
            VertexBase, "_original_ensure_access_token", VertexBase._ensure_access_token
        )
        setattr(
            VertexBase, "_original_get_token_and_url", VertexBase._get_token_and_url
        )

    async def _mock_ensure_access_token_async(
        self, credentials, project_id, custom_llm_provider
    ):
        """Mock async auth method - returns fake token."""
        verbose_logger.debug(
            "[GCS MOCK] Vertex AI auth: _ensure_access_token_async called"
        )
        return ("mock-gcs-token", "mock-project-id")

    def _mock_ensure_access_token(
        self, credentials, project_id, custom_llm_provider
    ):
        """Mock sync auth method - returns fake token."""
        verbose_logger.debug(
            "[GCS MOCK] Vertex AI auth: _ensure_access_token called"
        )
        return ("mock-gcs-token", "mock-project-id")

    def _mock_get_token_and_url(
        self,
        model,
        auth_header,
        vertex_credentials,
        vertex_project,
        vertex_location,
        gemini_api_key,
        stream,
        custom_llm_provider,
        api_base,
    ):
        """Mock get_token_and_url - returns fake token."""
        verbose_logger.debug("[GCS MOCK] Vertex AI auth: _get_token_and_url called")
        return ("mock-gcs-token", "https://storage.googleapis.com")

    # Patch the methods (re-assigning on every call is harmless since the
    # originals were saved above on first invocation)
    VertexBase._ensure_access_token_async = _mock_ensure_access_token_async  # type: ignore
    VertexBase._ensure_access_token = _mock_ensure_access_token  # type: ignore
    VertexBase._get_token_and_url = _mock_get_token_and_url  # type: ignore

    verbose_logger.debug("[GCS MOCK] Patched Vertex AI auth methods")
# should_use_gcs_mock is already created by the factory

View File

@@ -0,0 +1,214 @@
"""
BETA
This is the PubSub logger for GCS PubSub, this sends LiteLLM SpendLogs Payloads to GCS PubSub.
Users can use this instead of sending their SpendLogs to their Postgres database.
"""
import asyncio
import json
import os
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from litellm.proxy._types import SpendLogsPayload
else:
SpendLogsPayload = Any
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
class GcsPubSubLogger(CustomBatchLogger):
    """Batched logger that publishes LiteLLM spend logs to a GCS Pub/Sub topic."""

    def __init__(
        self,
        project_id: Optional[str] = None,
        topic_id: Optional[str] = None,
        credentials_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Google Cloud Pub/Sub publisher

        Args:
            project_id (str): Google Cloud project ID
            topic_id (str): Pub/Sub topic ID
            credentials_path (str, optional): Path to Google Cloud credentials JSON file

        Raises:
            ValueError: If project_id/topic_id cannot be resolved from args or env.
        """
        from litellm.proxy.utils import _premium_user_check

        # Pub/Sub logging is gated behind the premium (enterprise) check
        _premium_user_check()

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        # Constructor args win; env vars are the fallback
        self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID")
        self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID")
        self.path_service_account_json = credentials_path or os.getenv(
            "GCS_PATH_SERVICE_ACCOUNT"
        )

        if not self.project_id or not self.topic_id:
            raise ValueError("Both project_id and topic_id must be provided")

        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        # create_task only schedules the coroutine; it first runs when the
        # event loop next yields, i.e. after log_queue is assigned below.
        # NOTE(review): requires a running event loop at construction time.
        asyncio.create_task(self.periodic_flush())
        self.log_queue: List[Union[SpendLogsPayload, StandardLoggingPayload]] = []

    async def construct_request_headers(self) -> Dict[str, str]:
        """Construct authorization headers using Vertex AI auth"""
        from litellm import vertex_chat_completion

        # Exchange the service-account credentials for an OAuth access token
        (
            _auth_header,
            vertex_project,
        ) = await vertex_chat_completion._ensure_access_token_async(
            credentials=self.path_service_account_json,
            project_id=self.project_id,
            custom_llm_provider="vertex_ai",
        )

        auth_header, _ = vertex_chat_completion._get_token_and_url(
            model="pub-sub",
            auth_header=_auth_header,
            vertex_credentials=self.path_service_account_json,
            vertex_project=vertex_project,
            vertex_location=None,
            gemini_api_key=None,
            stream=None,
            custom_llm_provider="vertex_ai",
            api_base=None,
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }
        return headers

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log success events to GCS PubSub Topic

        - Creates a SpendLogsPayload
        - Adds to batch queue
        - Flushes based on CustomBatchLogger settings

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        from litellm.proxy.spend_tracking.spend_tracking_utils import (
            get_logging_payload,
        )
        from litellm.proxy.utils import _premium_user_check

        _premium_user_check()
        try:
            verbose_logger.debug(
                "PubSub: Logging - Enters logging function for model %s", kwargs
            )
            standard_logging_payload = kwargs.get("standard_logging_object", None)

            # Backwards compatibility with old logging payload
            if litellm.gcs_pub_sub_use_v1 is True:
                spend_logs_payload = get_logging_payload(
                    kwargs=kwargs,
                    response_obj=response_obj,
                    start_time=start_time,
                    end_time=end_time,
                )
                self.log_queue.append(spend_logs_payload)
            else:
                # New logging payload, StandardLoggingPayload
                # NOTE(review): may be None if the caller did not attach a
                # standard_logging_object - confirm upstream always sets it
                self.log_queue.append(standard_logging_payload)

            # Flush early when the batch threshold is reached
            if len(self.log_queue) >= self.batch_size:
                await self.async_send_batch()

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}"
            )
            pass

    async def async_send_batch(self):
        """
        Sends the batch of messages to Pub/Sub

        Messages are published one at a time; the queue is cleared even when
        a publish fails (best-effort logging, not guaranteed delivery).
        """
        try:
            if not self.log_queue:
                return

            verbose_logger.debug(
                f"PubSub - about to flush {len(self.log_queue)} events"
            )

            for message in self.log_queue:
                await self.publish_message(message)

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}"
            )
        finally:
            self.log_queue.clear()

    async def publish_message(
        self, message: Union[SpendLogsPayload, StandardLoggingPayload]
    ) -> Optional[Dict[str, Any]]:
        """
        Publish message to Google Cloud Pub/Sub using REST API

        Args:
            message: Message to publish (dict or string)
        Returns:
            dict: Published message response, or None on failure (non-blocking)
        """
        try:
            headers = await self.construct_request_headers()

            # Prepare message data
            if isinstance(message, str):
                message_data = message
            else:
                message_data = json.dumps(message, default=str)

            # Base64 encode the message (required by the Pub/Sub REST API)
            import base64

            encoded_message = base64.b64encode(message_data.encode("utf-8")).decode(
                "utf-8"
            )

            # Construct request body
            request_body = {"messages": [{"data": encoded_message}]}

            url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish"

            response = await self.async_httpx_client.post(
                url=url, headers=headers, json=request_body
            )

            if response.status_code not in [200, 202]:
                verbose_logger.error("Pub/Sub publish error: %s", str(response.text))
                raise Exception(f"Failed to publish message: {response.text}")

            verbose_logger.debug("Pub/Sub response: %s", response.text)
            return response.json()

        except Exception as e:
            verbose_logger.error("Pub/Sub publish error: %s", str(e))
            return None

View File

@@ -0,0 +1,428 @@
"""
Callback to log events to a Generic API Endpoint
- Creates a StandardLoggingPayload
- Adds to batch queue
- Flushes based on CustomBatchLogger settings
"""
import asyncio
import json
import os
import re
import traceback
from typing import Dict, List, Literal, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload
# Event categories a callback can subscribe to
API_EVENT_TYPES = Literal["llm_api_success", "llm_api_failure"]
# Wire formats used when flushing batched logs to the endpoint
LOG_FORMAT_TYPES = Literal["json_array", "ndjson", "single"]
def load_compatible_callbacks() -> Dict:
    """
    Load the generic_api_compatible_callbacks.json file

    Returns:
        Dict: Dictionary of compatible callbacks configuration, or an empty
        dict when the file is missing or unreadable.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "generic_api_compatible_callbacks.json"
    )
    try:
        with open(config_path, "r") as fp:
            return json.load(fp)
    except Exception as e:
        verbose_logger.warning(
            f"Error loading generic_api_compatible_callbacks.json: {str(e)}"
        )
        return {}
def is_callback_compatible(callback_name: str) -> bool:
    """
    Check if a callback_name exists in the compatible callbacks list

    Args:
        callback_name: Name of the callback to check

    Returns:
        bool: True if callback_name exists in the compatible callbacks, False otherwise
    """
    return callback_name in load_compatible_callbacks()
def get_callback_config(callback_name: str) -> Optional[Dict]:
    """
    Get the configuration for a specific callback

    Args:
        callback_name: Name of the callback to get config for

    Returns:
        Optional[Dict]: Configuration dict for the callback, or None if not found
    """
    return load_compatible_callbacks().get(callback_name)
def substitute_env_variables(value: str) -> str:
    """
    Replace {{environment_variables.VAR_NAME}} patterns with actual environment variable values

    Unset variables are replaced with an empty string.

    Args:
        value: String that may contain {{environment_variables.VAR_NAME}} patterns

    Returns:
        str: String with environment variables substituted
    """
    # Accept any valid env-var identifier (letters, digits, underscore; not
    # starting with a digit). The previous pattern `[A-Z_]+` only matched
    # uppercase letters and underscores, so names like MY_KEY_2 were silently
    # left unsubstituted.
    pattern = r"\{\{environment_variables\.([A-Za-z_][A-Za-z0-9_]*)\}\}"

    def replace_env_var(match):
        env_var_name = match.group(1)
        return os.getenv(env_var_name, "")

    return re.sub(pattern, replace_env_var, value)
class GenericAPILogger(CustomBatchLogger):
    """
    Batched logger that POSTs StandardLoggingPayload events to a generic
    HTTP endpoint, optionally pre-configured via a named compatible callback.
    """

    def __init__(
        self,
        endpoint: Optional[str] = None,
        headers: Optional[dict] = None,
        event_types: Optional[List[API_EVENT_TYPES]] = None,
        callback_name: Optional[str] = None,
        log_format: Optional[LOG_FORMAT_TYPES] = None,
        **kwargs,
    ):
        """
        Initialize the GenericAPILogger

        Args:
            endpoint: Optional[str] = None - target URL; falls back to the GENERIC_LOGGER_ENDPOINT env var
            headers: Optional[dict] = None - extra headers merged over env/litellm-level headers
            event_types: Optional[List[API_EVENT_TYPES]] = None - which events to log (None = all)
            callback_name: Optional[str] = None - If provided, loads config from generic_api_compatible_callbacks.json
            log_format: Optional[LOG_FORMAT_TYPES] = None - Format for log output: "json_array" (default), "ndjson", or "single"

        Raises:
            ValueError: if no endpoint can be resolved, or log_format is invalid.
        """
        #########################################################
        # Check if callback_name is provided and load config
        #########################################################
        if callback_name:
            if is_callback_compatible(callback_name):
                verbose_logger.debug(
                    f"Loading configuration for callback: {callback_name}"
                )
                callback_config = get_callback_config(callback_name)
                # Use config from JSON if not explicitly provided
                if callback_config:
                    if endpoint is None and "endpoint" in callback_config:
                        endpoint = substitute_env_variables(callback_config["endpoint"])
                    if "headers" in callback_config:
                        # BUGFIX: copy the caller's dict before merging config
                        # headers into it - the previous code (`headers = headers
                        # or {}`) mutated the dict the caller passed in.
                        headers = dict(headers) if headers else {}
                        for key, value in callback_config["headers"].items():
                            if key not in headers:
                                headers[key] = substitute_env_variables(value)
                    if event_types is None and "event_types" in callback_config:
                        event_types = callback_config["event_types"]
                    if log_format is None and "log_format" in callback_config:
                        log_format = callback_config["log_format"]
            else:
                verbose_logger.warning(
                    f"callback_name '{callback_name}' not found in generic_api_compatible_callbacks.json"
                )

        #########################################################
        # Init httpx client
        #########################################################
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        endpoint = endpoint or os.getenv("GENERIC_LOGGER_ENDPOINT")
        if endpoint is None:
            raise ValueError(
                "endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
            )

        self.headers: Dict = self._get_headers(headers)
        self.endpoint: str = endpoint
        self.event_types: Optional[List[API_EVENT_TYPES]] = event_types
        self.callback_name: Optional[str] = callback_name

        # Validate and store log_format
        if log_format is not None and log_format not in [
            "json_array",
            "ndjson",
            "single",
        ]:
            raise ValueError(
                f"Invalid log_format: {log_format}. Must be one of: 'json_array', 'ndjson', 'single'"
            )
        self.log_format: LOG_FORMAT_TYPES = log_format or "json_array"

        verbose_logger.debug(
            f"in init GenericAPILogger, callback_name: {self.callback_name}, endpoint {self.endpoint}, headers {self.headers}, event_types: {self.event_types}, log_format: {self.log_format}"
        )

        #########################################################
        # Init variables for batch flushing logs
        #########################################################
        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        # create_task only schedules the flusher; it first runs when the event
        # loop next yields, after log_queue has been assigned below.
        asyncio.create_task(self.periodic_flush())
        self.log_queue: List[Union[Dict, StandardLoggingPayload]] = []

    def _get_headers(self, headers: Optional[dict] = None):
        """
        Get headers for the Generic API Logger

        Precedence (lowest to highest): Content-Type default, the
        GENERIC_LOGGER_HEADERS env var, litellm.generic_logger_headers,
        then the explicitly provided `headers` argument.

        Returns:
            Dict: Headers for the Generic API Logger

        Args:
            headers: Optional[dict] = None
        """
        # Process headers from different sources
        headers_dict = {
            "Content-Type": "application/json",
        }

        # 1. First check for headers from env var
        env_headers = os.getenv("GENERIC_LOGGER_HEADERS")
        if env_headers:
            try:
                # Parse headers in format "key1=value1,key2=value2" or "key1=value1"
                header_items = env_headers.split(",")
                for item in header_items:
                    if "=" in item:
                        key, value = item.split("=", 1)
                        headers_dict[key.strip()] = value.strip()
            except Exception as e:
                verbose_logger.warning(
                    f"Error parsing headers from environment variables: {str(e)}"
                )

        # 2. Update with litellm generic headers if available
        if litellm.generic_logger_headers:
            headers_dict.update(litellm.generic_logger_headers)

        # 3. Override with directly provided headers if any
        if headers:
            headers_dict.update(headers)

        return headers_dict

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log success events to Generic API Endpoint

        - Creates a StandardLoggingPayload
        - Adds to batch queue
        - Flushes based on CustomBatchLogger settings

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        # Skip when the callback is not subscribed to success events
        if self.event_types is not None and "llm_api_success" not in self.event_types:
            return
        try:
            verbose_logger.debug(
                "Generic API Logger - Enters logging function for model %s", kwargs
            )
            standard_logging_payload = kwargs.get("standard_logging_object", None)

            # Backwards compatibility with old logging payload
            if litellm.generic_api_use_v1 is True:
                payload = self._get_v1_logging_payload(
                    kwargs=kwargs,
                    response_obj=response_obj,
                    start_time=start_time,
                    end_time=end_time,
                )
                self.log_queue.append(payload)
            else:
                # New logging payload, StandardLoggingPayload
                self.log_queue.append(standard_logging_payload)

            if len(self.log_queue) >= self.batch_size:
                await self.async_send_batch()

        except Exception as e:
            verbose_logger.exception(
                f"Generic API Logger Error - {str(e)}\n{traceback.format_exc()}"
            )
            pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log failure events to Generic API Endpoint

        - Creates a StandardLoggingPayload
        - Adds to batch queue
        """
        # Skip when the callback is not subscribed to failure events
        if self.event_types is not None and "llm_api_failure" not in self.event_types:
            return
        try:
            verbose_logger.debug(
                "Generic API Logger - Enters logging function for model %s", kwargs
            )
            standard_logging_payload = kwargs.get("standard_logging_object", None)

            if litellm.generic_api_use_v1 is True:
                payload = self._get_v1_logging_payload(
                    kwargs=kwargs,
                    response_obj=response_obj,
                    start_time=start_time,
                    end_time=end_time,
                )
                self.log_queue.append(payload)
            else:
                self.log_queue.append(standard_logging_payload)

            if len(self.log_queue) >= self.batch_size:
                await self.async_send_batch()

        except Exception as e:
            verbose_logger.exception(
                f"Generic API Logger Error - {str(e)}\n{traceback.format_exc()}"
            )

    async def async_send_batch(self):
        """
        Sends the batch of messages to Generic API Endpoint

        Supports three formats:
        - json_array: Sends all logs as a JSON array (default)
        - ndjson: Sends logs as newline-delimited JSON
        - single: Sends each log as individual HTTP request in parallel

        The queue is always cleared afterwards (best-effort delivery).
        """
        try:
            if not self.log_queue:
                return

            verbose_logger.debug(
                f"Generic API Logger - about to flush {len(self.log_queue)} events in '{self.log_format}' format"
            )

            if self.log_format == "single":
                # Send each log as individual HTTP request in parallel
                tasks = []
                for log_entry in self.log_queue:
                    task = self.async_httpx_client.post(
                        url=self.endpoint,
                        headers=self.headers,
                        data=safe_dumps(log_entry),
                    )
                    tasks.append(task)

                # Execute all requests in parallel
                responses = await asyncio.gather(*tasks, return_exceptions=True)

                # Log results
                for idx, result in enumerate(responses):
                    if isinstance(result, Exception):
                        verbose_logger.exception(
                            f"Generic API Logger - Error sending log {idx}: {result}"
                        )
                    else:
                        # result is a Response object
                        verbose_logger.debug(
                            f"Generic API Logger - sent log {idx}, status: {result.status_code}"  # type: ignore
                        )
            else:
                # Format the payload based on log_format
                if self.log_format == "json_array":
                    data = safe_dumps(self.log_queue)
                elif self.log_format == "ndjson":
                    data = "\n".join(safe_dumps(log) for log in self.log_queue)
                else:
                    raise ValueError(f"Unknown log_format: {self.log_format}")

                # Make POST request
                response = await self.async_httpx_client.post(
                    url=self.endpoint,
                    headers=self.headers,
                    data=data,
                )

                verbose_logger.debug(
                    f"Generic API Logger - sent batch to {self.endpoint}, "
                    f"status: {response.status_code}, format: {self.log_format}"
                )

        except Exception as e:
            verbose_logger.exception(
                f"Generic API Logger Error sending batch - {str(e)}\n{traceback.format_exc()}"
            )
        finally:
            self.log_queue.clear()

    def _get_v1_logging_payload(
        self, kwargs, response_obj, start_time, end_time
    ) -> dict:
        """
        Maintained for backwards compatibility with old logging payload

        Returns a dict of the payload to send to the Generic API Endpoint;
        every value is stringified (best effort) before returning.
        """
        verbose_logger.debug(
            f"GenericAPILogger Logging - Enters logging function for model {kwargs}"
        )

        # construct payload to send custom logger
        # follows the same params as langfuse.py
        litellm_params = kwargs.get("litellm_params", {})
        metadata = (
            litellm_params.get("metadata", {}) or {}
        )  # if litellm_params['metadata'] == None
        messages = kwargs.get("messages")
        cost = kwargs.get("response_cost", 0.0)
        optional_params = kwargs.get("optional_params", {})
        call_type = kwargs.get("call_type", "litellm.completion")
        cache_hit = kwargs.get("cache_hit", False)
        usage = response_obj["usage"]
        id = response_obj.get("id", str(uuid.uuid4()))

        # Build the initial payload
        payload = {
            "id": id,
            "call_type": call_type,
            "cache_hit": cache_hit,
            "startTime": start_time,
            "endTime": end_time,
            "model": kwargs.get("model", ""),
            "user": kwargs.get("user", ""),
            "modelParameters": optional_params,
            "messages": messages,
            "response": response_obj,
            "usage": usage,
            "metadata": metadata,
            "cost": cost,
        }

        # Ensure everything in the payload is converted to str
        for key, value in payload.items():
            try:
                payload[key] = str(value)
            except Exception:
                # non blocking if it can't cast to a str
                pass

        return payload

View File

@@ -0,0 +1,37 @@
{
"sample_callback": {
"event_types": ["llm_api_success", "llm_api_failure"],
"endpoint": "{{environment_variables.SAMPLE_CALLBACK_URL}}",
"headers": {
"Content-Type": "application/json",
"Authorization": "Bearer {{environment_variables.SAMPLE_CALLBACK_API_KEY}}"
},
"environment_variables": ["SAMPLE_CALLBACK_URL", "SAMPLE_CALLBACK_API_KEY"]
},
"rubrik": {
"event_types": ["llm_api_success"],
"endpoint": "{{environment_variables.RUBRIK_WEBHOOK_URL}}",
"headers": {
"Content-Type": "application/json",
"Authorization": "Bearer {{environment_variables.RUBRIK_API_KEY}}"
},
"environment_variables": ["RUBRIK_API_KEY", "RUBRIK_WEBHOOK_URL"]
},
"sumologic": {
"endpoint": "{{environment_variables.SUMOLOGIC_WEBHOOK_URL}}",
"headers": {
"Content-Type": "application/json"
},
"environment_variables": ["SUMOLOGIC_WEBHOOK_URL"],
"log_format": "ndjson"
},
"qualifire_eval": {
"event_types": ["llm_api_success"],
"endpoint": "{{environment_variables.QUALIFIRE_WEBHOOK_URL}}",
"headers": {
"Content-Type": "application/json",
"X-Qualifire-API-Key": "{{environment_variables.QUALIFIRE_API_KEY}}"
},
"environment_variables": ["QUALIFIRE_API_KEY", "QUALIFIRE_WEBHOOK_URL"]
}
}

View File

@@ -0,0 +1,80 @@
"""Generic prompt management integration for LiteLLM."""
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from .generic_prompt_manager import GenericPromptManager
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
from .generic_prompt_manager import GenericPromptManager
# Global instances
# NOTE(review): set_global_generic_prompt_config writes to
# litellm.global_generic_prompt_config, not to this module-level name -
# confirm which one readers actually consult.
global_generic_prompt_config: Optional[dict] = None
def set_global_generic_prompt_config(config: dict) -> None:
    """
    Store the global generic prompt configuration on the litellm module.

    Args:
        config: Generic prompt configuration dict containing:
            - api_base: Base URL for the API
            - api_key: Optional API key for authentication
            - timeout: Request timeout in seconds (default: 30)
    """
    import litellm

    # The attribute is set on the litellm module itself.
    litellm.global_generic_prompt_config = config  # type: ignore
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Initialize a prompt from a generic prompt management API.

    Args:
        litellm_params: Prompt params; must include `api_base`, and may include
            `api_key`, `prompt_id` and `provider_specific_query_params`.
        prompt_spec: Full prompt spec (part of the registry contract).

    Returns:
        A configured GenericPromptManager.

    Raises:
        ValueError: If `api_base` is missing.
    """
    prompt_id = getattr(litellm_params, "prompt_id", None)
    api_base = litellm_params.api_base
    api_key = litellm_params.api_key

    if not api_base:
        raise ValueError("api_base is required in generic_prompt_config")

    provider_specific_query_params = litellm_params.provider_specific_query_params

    # The original wrapped this in `try/except Exception as e: raise e`,
    # which is a no-op that only re-originates the traceback; construction
    # errors now propagate directly.
    return GenericPromptManager(
        api_base=api_base,
        api_key=api_key,
        prompt_id=prompt_id,
        additional_provider_specific_query_params=provider_specific_query_params,
        # Forward any remaining params, excluding the ones passed explicitly
        **litellm_params.model_dump(
            exclude_none=True,
            exclude={
                "prompt_id",
                "api_key",
                "provider_specific_query_params",
                "api_base",
            },
        ),
    )
# Maps the integration name to its initializer so litellm can construct
# generic prompt managers from config.
prompt_initializer_registry = {
    SupportedPromptIntegrations.GENERIC_PROMPT_MANAGEMENT.value: prompt_initializer,
}

# Export public API
__all__ = [
    "GenericPromptManager",
    "set_global_generic_prompt_config",
    "global_generic_prompt_config",
    "prompt_initializer_registry",
]

View File

@@ -0,0 +1,499 @@
"""
Generic prompt manager that integrates with LiteLLM's prompt management system.
Fetches prompts from any API that implements the /beta/litellm_prompt_management endpoint.
"""
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import httpx
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.integrations.prompt_management_base import (
PromptManagementBase,
PromptManagementClient,
)
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.llms.openai import AllMessageValues
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
class GenericPromptManager(CustomPromptManagement):
"""
Generic prompt manager that integrates with LiteLLM's prompt management system.
This class enables using prompts from any API that implements the
/beta/litellm_prompt_management endpoint.
Usage:
# Configure API access
generic_config = {
"api_base": "https://your-api.com",
"api_key": "your-api-key", # optional
"timeout": 30, # optional, defaults to 30
}
# Use with completion
response = litellm.completion(
model="generic_prompt/gpt-4",
prompt_id="my_prompt_id",
prompt_variables={"variable": "value"},
generic_prompt_config=generic_config,
messages=[{"role": "user", "content": "Additional message"}]
)
"""
    def __init__(
        self,
        api_base: str,
        api_key: Optional[str] = None,
        timeout: int = 30,
        prompt_id: Optional[str] = None,
        additional_provider_specific_query_params: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Initialize the Generic Prompt Manager.

        Args:
            api_base: Base URL for the API (e.g., "https://your-api.com")
            api_key: Optional API key for authentication
            timeout: Request timeout in seconds (default: 30)
            prompt_id: Optional prompt ID to pre-load
            additional_provider_specific_query_params: Extra query params
                appended to every prompt-fetch request.
            **kwargs: Forwarded to the CustomPromptManagement base class.
        """
        super().__init__(**kwargs)
        # Strip trailing slashes so URL joins don't produce double slashes
        self.api_base = api_base.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        self.prompt_id = prompt_id
        self.additional_provider_specific_query_params = (
            additional_provider_specific_query_params
        )
        # In-memory cache of already-fetched prompts, keyed by prompt id
        self._prompt_cache: Dict[str, PromptManagementClient] = {}
@property
def integration_name(self) -> str:
"""Integration name used in model names like 'generic_prompt/gpt-4'."""
return "generic_prompt"
def _get_headers(self) -> Dict[str, str]:
"""Get HTTP headers for API requests."""
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
return headers
def _fetch_prompt_from_api(
self, prompt_id: Optional[str], prompt_spec: Optional[PromptSpec]
) -> Dict[str, Any]:
"""
Fetch a prompt from the API.
Args:
prompt_id: The ID of the prompt to fetch
Returns:
The prompt data from the API
Raises:
Exception: If the API request fails
"""
if prompt_id is None and prompt_spec is None:
raise ValueError("prompt_id or prompt_spec is required")
url = f"{self.api_base}/beta/litellm_prompt_management"
params = {
"prompt_id": prompt_id,
**(self.additional_provider_specific_query_params or {}),
}
http_client = _get_httpx_client()
try:
response = http_client.get(
url,
params=params,
headers=self._get_headers(),
)
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
raise Exception(f"Failed to fetch prompt '{prompt_id}' from API: {e}")
except json.JSONDecodeError as e:
raise Exception(f"Failed to parse prompt response for '{prompt_id}': {e}")
async def async_fetch_prompt_from_api(
self, prompt_id: Optional[str], prompt_spec: Optional[PromptSpec]
) -> Dict[str, Any]:
"""
Fetch a prompt from the API asynchronously.
"""
if prompt_id is None and prompt_spec is None:
raise ValueError("prompt_id or prompt_spec is required")
url = f"{self.api_base}/beta/litellm_prompt_management"
params = {
"prompt_id": prompt_id,
**(
prompt_spec.litellm_params.provider_specific_query_params
if prompt_spec
and prompt_spec.litellm_params.provider_specific_query_params
else {}
),
}
http_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.PromptManagement,
)
try:
response = await http_client.get(
url,
params=params,
headers=self._get_headers(),
)
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
raise Exception(f"Failed to fetch prompt '{prompt_id}' from API: {e}")
except json.JSONDecodeError as e:
raise Exception(f"Failed to parse prompt response for '{prompt_id}': {e}")
def _parse_api_response(
self,
prompt_id: Optional[str],
prompt_spec: Optional[PromptSpec],
api_response: Dict[str, Any],
) -> PromptManagementClient:
"""
Parse the API response into a PromptManagementClient structure.
Expected API response format:
{
"prompt_id": "string",
"prompt_template": [
{"role": "system", "content": "..."},
{"role": "user", "content": "..."}
],
"prompt_template_model": "gpt-4", # optional
"prompt_template_optional_params": { # optional
"temperature": 0.7,
"max_tokens": 100
}
}
Args:
prompt_id: The ID of the prompt
api_response: The response from the API
Returns:
PromptManagementClient structure
"""
return PromptManagementClient(
prompt_id=prompt_id,
prompt_template=api_response.get("prompt_template", []),
prompt_template_model=api_response.get("prompt_template_model"),
prompt_template_optional_params=api_response.get(
"prompt_template_optional_params"
),
completed_messages=None,
)
def should_run_prompt_management(
self,
prompt_id: Optional[str],
prompt_spec: Optional[PromptSpec],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> bool:
"""
Determine if prompt management should run based on the prompt_id.
For Generic Prompt Manager, we always return True and handle the prompt loading
in the _compile_prompt_helper method.
"""
if prompt_id is not None or (
prompt_spec is not None
and prompt_spec.litellm_params.provider_specific_query_params is not None
):
return True
return False
def _get_cache_key(
self,
prompt_id: Optional[str],
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
) -> str:
return f"{prompt_id}:{prompt_label}:{prompt_version}"
def _common_caching_logic(
self,
prompt_id: Optional[str],
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
prompt_variables: Optional[dict] = None,
) -> Optional[PromptManagementClient]:
"""
Common caching logic for the prompt manager.
"""
# Check cache first
cache_key = self._get_cache_key(prompt_id, prompt_label, prompt_version)
if cache_key in self._prompt_cache:
cached_prompt = self._prompt_cache[cache_key]
# Return a copy with variables applied if needed
if prompt_variables:
return self._apply_variables(cached_prompt, prompt_variables)
return cached_prompt
return None
def _compile_prompt_helper(
self,
prompt_id: Optional[str],
prompt_spec: Optional[PromptSpec],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
) -> PromptManagementClient:
"""
Compile a prompt template into a PromptManagementClient structure.
This method:
1. Fetches the prompt from the API (with caching)
2. Applies any prompt variables (if the API supports it)
3. Returns the structured prompt data
Args:
prompt_id: The ID of the prompt
prompt_variables: Variables to substitute in the template (optional)
dynamic_callback_params: Dynamic callback parameters
prompt_label: Optional label for the prompt version
prompt_version: Optional specific version number
Returns:
PromptManagementClient structure
"""
cached_prompt = self._common_caching_logic(
prompt_id=prompt_id,
prompt_label=prompt_label,
prompt_version=prompt_version,
prompt_variables=prompt_variables,
)
if cached_prompt:
return cached_prompt
cache_key = self._get_cache_key(prompt_id, prompt_label, prompt_version)
try:
# Fetch from API
api_response = self._fetch_prompt_from_api(prompt_id, prompt_spec)
# Parse the response
prompt_client = self._parse_api_response(
prompt_id, prompt_spec, api_response
)
# Cache the result
self._prompt_cache[cache_key] = prompt_client
# Apply variables if provided
if prompt_variables:
prompt_client = self._apply_variables(prompt_client, prompt_variables)
return prompt_client
except Exception as e:
raise ValueError(f"Error compiling prompt '{prompt_id}': {e}")
async def async_compile_prompt_helper(
self,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
prompt_spec: Optional[PromptSpec] = None,
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
) -> PromptManagementClient:
# Check cache first
cached_prompt = self._common_caching_logic(
prompt_id=prompt_id,
prompt_label=prompt_label,
prompt_version=prompt_version,
prompt_variables=prompt_variables,
)
if cached_prompt:
return cached_prompt
cache_key = self._get_cache_key(prompt_id, prompt_label, prompt_version)
try:
# Fetch from API
api_response = await self.async_fetch_prompt_from_api(
prompt_id=prompt_id, prompt_spec=prompt_spec
)
# Parse the response
prompt_client = self._parse_api_response(
prompt_id, prompt_spec, api_response
)
# Cache the result
self._prompt_cache[cache_key] = prompt_client
# Apply variables if provided
if prompt_variables:
prompt_client = self._apply_variables(prompt_client, prompt_variables)
return prompt_client
except Exception as e:
raise ValueError(
f"Error compiling prompt '{prompt_id}': {e}, prompt_spec: {prompt_spec}"
)
def _apply_variables(
self,
prompt_client: PromptManagementClient,
variables: Dict[str, Any],
) -> PromptManagementClient:
"""
Apply variables to the prompt template.
This performs simple string substitution using {variable_name} syntax.
Args:
prompt_client: The prompt client structure
variables: Variables to substitute
Returns:
Updated PromptManagementClient with variables applied
"""
# Create a copy of the prompt template with variables applied
updated_messages: List[AllMessageValues] = []
for message in prompt_client["prompt_template"]:
updated_message = dict(message) # type: ignore
if "content" in updated_message and isinstance(
updated_message["content"], str
):
content = updated_message["content"]
for key, value in variables.items():
content = content.replace(f"{{{key}}}", str(value))
content = content.replace(
f"{{{{{key}}}}}", str(value)
) # Also support {{key}}
updated_message["content"] = content
updated_messages.append(updated_message) # type: ignore
return PromptManagementClient(
prompt_id=prompt_client["prompt_id"],
prompt_template=updated_messages,
prompt_template_model=prompt_client["prompt_template_model"],
prompt_template_optional_params=prompt_client[
"prompt_template_optional_params"
],
completed_messages=None,
)
async def async_get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
litellm_logging_obj: "LiteLLMLoggingObj",
prompt_spec: Optional[PromptSpec] = None,
tools: Optional[List[Dict]] = None,
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
ignore_prompt_manager_model: Optional[bool] = False,
ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
"""
Get chat completion prompt and return processed model, messages, and parameters.
"""
return await PromptManagementBase.async_get_chat_completion_prompt(
self,
model,
messages,
non_default_params,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
litellm_logging_obj=litellm_logging_obj,
dynamic_callback_params=dynamic_callback_params,
prompt_spec=prompt_spec,
tools=tools,
prompt_label=prompt_label,
prompt_version=prompt_version,
ignore_prompt_manager_model=(
ignore_prompt_manager_model
or prompt_spec.litellm_params.ignore_prompt_manager_model
if prompt_spec
else False
),
ignore_prompt_manager_optional_params=(
ignore_prompt_manager_optional_params
or prompt_spec.litellm_params.ignore_prompt_manager_optional_params
if prompt_spec
else False
),
)
def get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
prompt_spec: Optional[PromptSpec] = None,
prompt_label: Optional[str] = None,
prompt_version: Optional[int] = None,
ignore_prompt_manager_model: Optional[bool] = False,
ignore_prompt_manager_optional_params: Optional[bool] = False,
) -> Tuple[str, List[AllMessageValues], dict]:
"""
Get chat completion prompt and return processed model, messages, and parameters.
"""
return PromptManagementBase.get_chat_completion_prompt(
self,
model,
messages,
non_default_params,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=dynamic_callback_params,
prompt_spec=prompt_spec,
prompt_label=prompt_label,
prompt_version=prompt_version,
ignore_prompt_manager_model=(
ignore_prompt_manager_model
or prompt_spec.litellm_params.ignore_prompt_manager_model
if prompt_spec
else False
),
ignore_prompt_manager_optional_params=(
ignore_prompt_manager_optional_params
or prompt_spec.litellm_params.ignore_prompt_manager_optional_params
if prompt_spec
else False
),
)
def clear_cache(self) -> None:
"""Clear the prompt cache."""
self._prompt_cache.clear()

View File

@@ -0,0 +1,317 @@
# LiteLLM gitlab Prompt Management
A powerful prompt management system for LiteLLM that fetches `.prompt` files from gitlab repositories. This enables team-based prompt management with gitlab's built-in access control and version control capabilities.
## Features
- **🏢 Team-based access control**: Leverage gitlab's workspace and repository permissions
- **📁 Repository-based prompt storage**: Store prompts in gitlab repositories
- **🔐 Multiple authentication methods**: Support for access tokens and basic auth
- **🎯 YAML frontmatter**: Define model, parameters, and schemas in file headers
- **🔧 Handlebars templating**: Use `{{variable}}` syntax with Jinja2 backend
- **✅ Input validation**: Automatic validation against defined schemas
- **🔗 LiteLLM integration**: Works seamlessly with `litellm.completion()`
- **💬 Smart message parsing**: Converts prompts to proper chat messages
- **⚙️ Parameter extraction**: Automatically applies model settings from prompts
## Quick Start
### 1. Set up gitlab Repository
Create a repository in your GitLab group (or personal namespace) and add `.prompt` files:
```
your-repo/
├── prompts/
│ ├── chat_assistant.prompt
│ ├── code_reviewer.prompt
│ └── data_analyst.prompt
```
### 2. Create a `.prompt` file
Create a file called `prompts/chat_assistant.prompt`:
```yaml
---
model: gpt-4
temperature: 0.7
max_tokens: 150
input:
schema:
user_message: string
system_context?: string
---
{% if system_context %}System: {{system_context}}
{% endif %}User: {{user_message}}
```
### 3. Configure gitlab Access
#### Option A: Access Token (Recommended)
```python
import litellm
# Configure gitlab access
gitlab_config = {
"project": "a/b/<repo_name>",
"access_token": "your-access-token",
"base_url": "gitlab url",
"prompts_path": "src/prompts", # folder to point to, defaults to root
"branch":"main" # optional, defaults to main
}
# Set global gitlab configuration
litellm.set_global_gitlab_config(gitlab_config)
```
#### Option B: Basic Authentication
```python
import litellm
# Configure gitlab access with basic auth
gitlab_config = {
"project": "a/b/<repo_name>",
"base_url": "base url",
"access_token": "your-app-password", # Use app password for basic auth
"branch": "main",
"prompts_path": "src/prompts", # folder to point to, defaults to root
}
litellm.set_global_gitlab_config(gitlab_config)
```
### 4. Use with LiteLLM
```python
# Use with completion - the model prefix 'gitlab/' tells LiteLLM to use gitlab prompt management
response = litellm.completion(
model="gitlab/gpt-4", # The actual model comes from the .prompt file
prompt_id="prompts/chat_assistant", # Location of the prompt file
prompt_variables={
"user_message": "What is machine learning?",
"system_context": "You are a helpful AI tutor."
},
# Any additional messages will be appended after the prompt
messages=[{"role": "user", "content": "Please explain it simply."}]
)
print(response.choices[0].message.content)
```
## Proxy Server Configuration
### 1. Create a `.prompt` file
Create `prompts/hello.prompt`:
```yaml
---
model: gpt-4
temperature: 0.7
---
System: You are a helpful assistant.
User: {{user_message}}
```
### 2. Setup config.yaml
```yaml
model_list:
- model_name: my-gitlab-model
litellm_params:
model: gitlab/gpt-4
prompt_id: "prompts/hello"
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
  global_gitlab_config:
    project: "your-group/your-repo"
    access_token: "your-access-token"
    branch: "main"
```
### 3. Start the proxy
```bash
litellm --config config.yaml --detailed_debug
```
### 4. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "my-gitlab-model",
"messages": [{"role": "user", "content": "IGNORED"}],
"prompt_variables": {
"user_message": "What is the capital of France?"
}
}'
```
## Prompt File Format
### Basic Structure
```yaml
---
# Model configuration
model: gpt-4
temperature: 0.7
max_tokens: 500
# Input schema (optional)
input:
schema:
user_message: string
system_context?: string
---
System: You are a helpful {{role}} assistant.
User: {{user_message}}
```
### Advanced Features
**Multi-role conversations:**
```yaml
---
model: gpt-4
temperature: 0.3
---
System: You are a helpful coding assistant.
User: {{user_question}}
```
**Dynamic model selection:**
```yaml
---
model: "{{preferred_model}}" # Model can be a variable
temperature: 0.7
---
System: You are a helpful assistant specialized in {{domain}}.
User: {{user_message}}
```
## Team-Based Access Control
gitlab's built-in permission system provides team-based access control:
1. **Group-level permissions**: Control access to an entire group of projects
2. **Project-level permissions**: Control access to specific repositories
3. **Branch-level permissions**: Control access to specific branches (protected branches)
4. **User and group management**: Manage team members and their access levels
### Setting up Team Access
1. **Create groups (or projects) for each team**:
```
team-a-prompts/
team-b-prompts/
team-c-prompts/
```
2. **Configure repository permissions**:
- Grant read access to team members
- Grant write access to prompt maintainers
- Use branch protection rules for production prompts
3. **Use different access tokens**:
- Each team can have their own access token
- Tokens can be scoped to specific repositories
- Use app passwords for additional security
## API Reference
### gitlab Configuration
```python
gitlab_config = {
    "project": str,        # Required: Project path ("group/subgroup/repo") or numeric project ID
    "access_token": str,   # Required: GitLab personal/access token or OAuth token
    "auth_method": str,    # Optional: "token" (default, Private-Token header) or "oauth" (Bearer)
    "branch": str,         # Optional: Branch to fetch from (default: "main")
    "tag": str,            # Optional: Tag to fetch from (takes precedence over branch)
    "base_url": str,       # Optional: GitLab API base URL (default: "https://gitlab.com/api/v4")
    "prompts_path": str,   # Optional: Folder containing .prompt files (defaults to repo root)
}
```
### LiteLLM Integration
```python
response = litellm.completion(
model="gitlab/<base_model>", # required (e.g., gitlab/gpt-4)
prompt_id=str, # required - the .prompt filename without extension
prompt_variables=dict, # optional - variables for template rendering
gitlab_config=dict, # optional - gitlab configuration (if not set globally)
messages=list, # optional - additional messages
)
```
## Error Handling
The gitlab integration provides detailed error messages for common issues:
- **Authentication errors**: Invalid access tokens or credentials
- **Permission errors**: Insufficient access to workspace/repository
- **File not found**: Missing .prompt files
- **Network errors**: Connection issues with gitlab API
## Security Considerations
1. **Access Token Security**: Store access tokens securely using environment variables or secret management systems
2. **Repository Permissions**: Use gitlab's permission system to control access
3. **Branch Protection**: Protect main branches from unauthorized changes
4. **Audit Logging**: gitlab provides audit logs for all repository access
## Troubleshooting
### Common Issues
1. **"Access denied" errors**: Check your gitlab permissions for the workspace and repository
2. **"Authentication failed" errors**: Verify your access token or credentials
3. **"File not found" errors**: Ensure the .prompt file exists in the specified branch
4. **Template rendering errors**: Check your Handlebars syntax in the .prompt file
### Debug Mode
Enable debug logging to troubleshoot issues:
```python
import litellm
litellm.set_verbose = True
# Your gitlab prompt calls will now show detailed logs
response = litellm.completion(
model="gitlab/gpt-4",
prompt_id="your_prompt",
prompt_variables={"key": "value"}
)
```
## Migration from File-Based Prompts
If you're currently using file-based prompts with the dotprompt integration, you can easily migrate to gitlab:
1. **Upload your .prompt files** to a gitlab repository
2. **Update your configuration** to use gitlab instead of local files
3. **Set up team access** using gitlab's permission system
4. **Update your code** to use `gitlab/` model prefix instead of `dotprompt/`
This provides better collaboration, version control, and team-based access control for your prompts.

View File

@@ -0,0 +1,94 @@
from typing import TYPE_CHECKING, Optional, Dict, Any
if TYPE_CHECKING:
from .gitlab_prompt_manager import GitLabPromptManager
from litellm.types.prompts.init_prompts import PromptLiteLLMParams, PromptSpec
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.prompts.init_prompts import PromptSpec, PromptLiteLLMParams
from .gitlab_prompt_manager import GitLabPromptManager, GitLabPromptCache
# Module-level default; the active config is stored on the `litellm` module
# by set_global_gitlab_config() and read by the gitlab prompt integration.
global_gitlab_config: Optional[dict] = None
def set_global_gitlab_config(config: dict) -> None:
    """
    Set the global GitLab configuration for prompt management.

    The config is stored on the ``litellm`` module as
    ``litellm.global_gitlab_config`` and later read by the GitLab prompt
    integration.

    Args:
        config: Dictionary containing GitLab configuration:
            - project: Project path ("group/subgroup/repo") or numeric
              project ID (required)
            - access_token: GitLab personal/access token (required)
            - branch: Branch to fetch prompts from (default: "main")
            - base_url: GitLab API base URL
              (default: "https://gitlab.com/api/v4")
            - prompts_path: Folder containing .prompt files (optional)
    """
    import litellm

    litellm.global_gitlab_config = config  # type: ignore
def prompt_initializer(
    litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
) -> "CustomPromptManagement":
    """
    Initialize a GitLab-backed prompt manager from litellm params.

    Args:
        litellm_params: Params carrying ``gitlab_config`` (required) and an
            optional ``prompt_id``.
        prompt_spec: The prompt spec being initialized (unused here; kept for
            registry-signature compatibility).

    Returns:
        A configured ``GitLabPromptManager``.

    Raises:
        ValueError: If ``gitlab_config`` is missing.
    """
    gitlab_config = getattr(litellm_params, "gitlab_config", None)
    prompt_id = getattr(litellm_params, "prompt_id", None)
    if not gitlab_config:
        raise ValueError("gitlab_config is required for gitlab prompt integration")
    # The previous `try/except Exception as e: raise e` added nothing; let
    # construction errors propagate unchanged.
    return GitLabPromptManager(
        gitlab_config=gitlab_config,
        prompt_id=prompt_id,
    )
def _gitlab_prompt_initializer(
    litellm_params: PromptLiteLLMParams,
    prompt: PromptSpec,
) -> CustomPromptManagement:
    """
    Construct a GitLab-backed prompt manager for a single prompt spec.

    Integration-specific settings are read off ``litellm_params``:
    - ``gitlab_config`` (required): project/access_token/branch/prompts_path/etc.
    - ``git_ref`` (optional): per-prompt tag/branch/SHA override.

    Raises:
        ValueError: If ``gitlab_config`` is absent or empty.
    """
    repo_config: Dict[str, Any] = getattr(litellm_params, "gitlab_config", None) or {}
    if not repo_config:
        raise ValueError("gitlab_config is required for gitlab prompt integration")
    ref_override: Optional[str] = getattr(litellm_params, "git_ref", None)
    # prompt.prompt_id maps to a file path under prompts_path
    # (e.g. "chat/greet/hi").
    return GitLabPromptManager(
        gitlab_config=repo_config,
        prompt_id=prompt.prompt_id,
        ref=ref_override,
    )
# Maps prompt-integration names to initializer callables; consumed by
# litellm's prompt-spec initialization to construct the right manager.
prompt_initializer_registry = {
    SupportedPromptIntegrations.GITLAB.value: _gitlab_prompt_initializer,
}
# Export public API
__all__ = [
    "GitLabPromptManager",
    "GitLabPromptCache",
    "set_global_gitlab_config",
    "global_gitlab_config",
]

View File

@@ -0,0 +1,309 @@
"""
GitLab API client for fetching files from GitLab repositories.
Now supports selecting a tag via `config["tag"]`; falls back to branch ("main").
"""
import base64
from typing import Any, Dict, List, Optional
from urllib.parse import quote
from litellm.llms.custom_httpx.http_handler import HTTPHandler
class GitLabClient:
    """
    Client for interacting with the GitLab API to fetch files.

    Supports:
    - Authentication with personal/access tokens or OAuth bearer tokens
    - Fetching file contents from repositories (raw endpoint with JSON fallback)
    - Namespace/project path or numeric project ID addressing
    - Ref selection via tag (preferred) or branch (default "main")
    - Directory listing via the repository tree API
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the GitLab client.

        Args:
            config: Dictionary containing:
                - project: Project path ("group/subgroup/repo") or numeric project ID (str|int) [required]
                - access_token: GitLab personal/access token or OAuth token [required] (str)
                - auth_method: 'token' (default; sends Private-Token) or 'oauth' (Authorization: Bearer)
                - tag: Tag name to fetch from (takes precedence over branch if provided)
                - branch: Branch to fetch from (default: "main")
                - base_url: Base GitLab API URL (default: "https://gitlab.com/api/v4")

        Raises:
            ValueError: If project or access_token is missing or empty.
        """
        project = config.get("project")
        access_token = config.get("access_token")
        # Single up-front validation (previously duplicated further down);
        # `not x` also rejects empty strings, not just missing keys.
        if not project or not access_token:
            raise ValueError("project and access_token are required")
        self.project: str | int = project
        self.access_token: str = str(access_token)
        self.auth_method = config.get("auth_method", "token")  # 'token' or 'oauth'
        # `or "main"` also covers an explicit branch=None / branch="" value.
        self.branch = config.get("branch") or "main"
        self.tag = config.get("tag")
        # Strip any trailing slash so the URL builders below stay well-formed;
        # `or` also covers an explicit base_url=None value.
        self.base_url = (
            config.get("base_url") or "https://gitlab.com/api/v4"
        ).rstrip("/")
        # Effective ref: prefer tag if provided, else branch ("main")
        self.ref = str(self.tag or self.branch)
        # Build headers
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if self.auth_method == "oauth":
            self.headers["Authorization"] = f"Bearer {self.access_token}"
        else:
            # Default GitLab token header
            self.headers["Private-Token"] = self.access_token
        # Project identifier must be URL-encoded (slashes become %2F)
        self._project_enc = quote(str(self.project), safe="")
        # HTTP handler
        self.http_handler = HTTPHandler()

    # ------------------------
    # Core helpers
    # ------------------------
    def _file_raw_url(self, file_path: str, *, ref: Optional[str] = None) -> str:
        """URL for the raw-file endpoint at the given ref (defaults to self.ref)."""
        file_enc = quote(file_path, safe="")
        ref_q = quote(ref or self.ref, safe="")
        return f"{self.base_url}/projects/{self._project_enc}/repository/files/{file_enc}/raw?ref={ref_q}"

    def _file_json_url(self, file_path: str, *, ref: Optional[str] = None) -> str:
        """URL for the JSON file endpoint (base64 content) at the given ref."""
        file_enc = quote(file_path, safe="")
        ref_q = quote(ref or self.ref, safe="")
        return f"{self.base_url}/projects/{self._project_enc}/repository/files/{file_enc}?ref={ref_q}"

    def _tree_url(
        self,
        directory_path: str = "",
        recursive: bool = False,
        *,
        ref: Optional[str] = None,
    ) -> str:
        """URL for the repository tree (directory listing) endpoint."""
        path_q = f"&path={quote(directory_path, safe='')}" if directory_path else ""
        rec_q = "&recursive=true" if recursive else ""
        ref_q = quote(ref or self.ref, safe="")
        return f"{self.base_url}/projects/{self._project_enc}/repository/tree?ref={ref_q}{path_q}{rec_q}"

    # ------------------------
    # Public API
    # ------------------------
    def set_ref(self, ref: str) -> None:
        """Override the default ref (tag/branch) for subsequent calls."""
        if not ref:
            raise ValueError("ref must be a non-empty string")
        self.ref = ref

    def get_file_content(
        self, file_path: str, *, ref: Optional[str] = None
    ) -> Optional[str]:
        """
        Fetch the content of a file from the GitLab repository at the given ref
        (tag, branch, or commit SHA). If `ref` is None, uses self.ref.

        Strategy:
            1) Try the RAW endpoint (returns bytes of the file)
            2) Fallback to the JSON endpoint (returns base64-encoded content)

        Returns:
            File content as UTF-8 string, or None if file not found.

        Raises:
            Exception: On auth/permission failures or other request errors.
        """
        raw_url = self._file_raw_url(file_path, ref=ref)
        try:
            resp = self.http_handler.get(raw_url, headers=self.headers)
            if resp.status_code == 404:
                # Fallback to JSON endpoint before concluding the file is missing
                return self._get_file_content_via_json(file_path, ref=ref)
            resp.raise_for_status()
            ctype = (resp.headers.get("content-type") or "").lower()
            # Textual responses are decoded by the HTTP layer; otherwise decode
            # the raw bytes ourselves (UTF-8, replacing invalid sequences).
            if (
                ctype.startswith("text/")
                or "charset=" in ctype
                or ctype.startswith("application/json")
            ):
                return resp.text
            try:
                return resp.content.decode("utf-8")
            except Exception:
                return resp.content.decode("utf-8", errors="replace")
        except Exception as e:
            status = getattr(getattr(e, "response", None), "status_code", None)
            if status == 404:
                return None
            if status == 403:
                raise Exception(
                    f"Access denied to file '{file_path}'. Check your GitLab permissions for project '{self.project}'."
                )
            if status == 401:
                raise Exception(
                    "Authentication failed. Check your GitLab token and auth_method."
                )
            raise Exception(f"Failed to fetch file '{file_path}': {e}")

    def _get_file_content_via_json(
        self, file_path: str, *, ref: Optional[str] = None
    ) -> Optional[str]:
        """
        Fallback for get_file_content(): use the JSON file API which returns base64 content.
        """
        json_url = self._file_json_url(file_path, ref=ref)
        try:
            resp = self.http_handler.get(json_url, headers=self.headers)
            if resp.status_code == 404:
                return None
            resp.raise_for_status()
            data = resp.json()
            content = data.get("content")
            encoding = data.get("encoding", "")
            if content and encoding == "base64":
                try:
                    return base64.b64decode(content).decode("utf-8")
                except Exception:
                    return base64.b64decode(content).decode("utf-8", errors="replace")
            return content
        except Exception as e:
            status = getattr(getattr(e, "response", None), "status_code", None)
            if status == 404:
                return None
            if status == 403:
                raise Exception(
                    f"Access denied to file '{file_path}'. Check your GitLab permissions for project '{self.project}'."
                )
            if status == 401:
                raise Exception(
                    "Authentication failed. Check your GitLab token and auth_method."
                )
            raise Exception(
                f"Failed to fetch file '{file_path}' via JSON endpoint: {e}"
            )

    def list_files(
        self,
        directory_path: str = "",
        file_extension: str = ".prompt",
        recursive: bool = False,
        *,
        ref: Optional[str] = None,
    ) -> List[str]:
        """
        List files in a directory with a specific extension using the repository tree API.

        Args:
            directory_path: Directory path in the repository (empty for repo root)
            file_extension: File extension to filter by (default: .prompt);
                pass "" to list all files
            recursive: If True, traverses subdirectories
            ref: Optional override (tag/branch/SHA). Defaults to self.ref.

        Returns:
            List of file paths (relative to repo root); empty on 404.
        """
        url = self._tree_url(directory_path, recursive=recursive, ref=ref)
        try:
            resp = self.http_handler.get(url, headers=self.headers)
            if resp.status_code == 404:
                return []
            resp.raise_for_status()
            data = resp.json() or []
            files: List[str] = []
            for item in data:
                # "blob" entries are files; "tree" entries are directories.
                if item.get("type") == "blob":
                    file_path = item.get("path", "")
                    if not file_extension or file_path.endswith(file_extension):
                        files.append(file_path)
            return files
        except Exception as e:
            status = getattr(getattr(e, "response", None), "status_code", None)
            if status == 404:
                return []
            if status == 403:
                raise Exception(
                    f"Access denied to directory '{directory_path}'. Check your GitLab permissions for project '{self.project}'."
                )
            if status == 401:
                raise Exception(
                    "Authentication failed. Check your GitLab token and auth_method."
                )
            raise Exception(f"Failed to list files in '{directory_path}': {e}")

    def get_repository_info(self) -> Dict[str, Any]:
        """Get information about the project/repository."""
        url = f"{self.base_url}/projects/{self._project_enc}"
        try:
            resp = self.http_handler.get(url, headers=self.headers)
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            raise Exception(f"Failed to get repository info: {e}")

    def test_connection(self) -> bool:
        """Test the connection to the GitLab project."""
        try:
            self.get_repository_info()
            return True
        except Exception:
            return False

    def get_branches(self) -> List[Dict[str, Any]]:
        """Get list of branches in the repository."""
        url = f"{self.base_url}/projects/{self._project_enc}/repository/branches"
        try:
            resp = self.http_handler.get(url, headers=self.headers)
            resp.raise_for_status()
            data = resp.json()
            return data if isinstance(data, list) else []
        except Exception as e:
            raise Exception(f"Failed to get branches: {e}")

    def get_file_metadata(
        self, file_path: str, *, ref: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get minimal metadata about a file via RAW endpoint headers at a given ref.

        Args:
            file_path: Path to the file in the repository.
            ref: Optional override (tag/branch/SHA). Defaults to self.ref.

        Returns:
            Dict of content_type/content_length/last_modified, or None on 404.
        """
        url = self._file_raw_url(file_path, ref=ref)
        try:
            headers = dict(self.headers)
            # Request only the first byte — we only need the response headers.
            headers["Range"] = "bytes=0-0"
            resp = self.http_handler.get(url, headers=headers)
            if resp.status_code == 404:
                return None
            resp.raise_for_status()
            return {
                "content_type": resp.headers.get("content-type"),
                "content_length": resp.headers.get("content-length"),
                "last_modified": resp.headers.get("last-modified"),
            }
        except Exception as e:
            status = getattr(getattr(e, "response", None), "status_code", None)
            if status == 404:
                return None
            raise Exception(f"Failed to get file metadata for '{file_path}': {e}")

    def close(self):
        """Close the HTTP handler to free resources."""
        if hasattr(self, "http_handler"):
            self.http_handler.close()

View File

@@ -0,0 +1,760 @@
"""
GitLab prompt manager with configurable prompts folder.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from jinja2 import DictLoader, Environment, select_autoescape
from litellm.integrations.custom_prompt_management import CustomPromptManagement
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
from litellm.integrations.gitlab.gitlab_client import GitLabClient
from litellm.integrations.prompt_management_base import (
PromptManagementBase,
PromptManagementClient,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams
GITLAB_PREFIX = "gitlab::"


def encode_prompt_id(raw_id: str) -> str:
    """Encode a repo-path id ('invoice/extract') as 'gitlab::invoice::extract'."""
    if raw_id.startswith(GITLAB_PREFIX):
        # Already carries the prefix; return unchanged (idempotent).
        return raw_id
    return GITLAB_PREFIX + raw_id.replace("/", "::")


def decode_prompt_id(encoded_id: str) -> str:
    """Decode 'gitlab::invoice::extract' back to the repo path 'invoice/extract'."""
    if not encoded_id.startswith(GITLAB_PREFIX):
        return encoded_id
    tail = encoded_id[len(GITLAB_PREFIX):]
    return tail.replace("::", "/")
class GitLabPromptTemplate:
    """Parsed representation of one .prompt file: template text plus frontmatter."""

    def __init__(
        self,
        template_id: str,
        content: str,
        metadata: Dict[str, Any],
        model: Optional[str] = None,
    ):
        self.template_id = template_id
        self.content = content
        self.metadata = metadata
        # A truthy explicit model wins; otherwise fall back to frontmatter.
        self.model = model or metadata.get("model")
        self.temperature = metadata.get("temperature")
        self.max_tokens = metadata.get("max_tokens")
        # assumes metadata["input"], when present, is a dict — TODO confirm
        self.input_schema = metadata.get("input", {}).get("schema", {})
        # Everything except the reserved keys is forwarded as an optional param
        # (temperature/max_tokens deliberately appear in both places).
        reserved = ("model", "input", "content")
        self.optional_params = {
            key: val for key, val in metadata.items() if key not in reserved
        }

    def __repr__(self):
        return (
            f"GitLabPromptTemplate(id='{self.template_id}', "
            f"model='{self.model}')"
        )
class GitLabTemplateManager:
    """
    Manager for loading and rendering .prompt files from GitLab repositories.
    New: supports `prompts_path` (or `folder`) in gitlab_config to scope where prompts live.
    """
    def __init__(
        self,
        gitlab_config: Dict[str, Any],
        prompt_id: Optional[str] = None,
        ref: Optional[str] = None,
        gitlab_client: Optional[GitLabClient] = None,
    ):
        """
        Args:
            gitlab_config: GitLab connection settings (project, token, branch/tag, ...).
            prompt_id: Optional prompt to eagerly load on construction.
            ref: Optional tag/branch/SHA override applied to the client.
            gitlab_client: Optional pre-built client (reuse / testing).
        """
        # Copy so later mutation of the caller's dict cannot affect us.
        self.gitlab_config = dict(gitlab_config)
        self.prompt_id = prompt_id
        # In-memory cache of parsed templates, keyed by prompt id.
        self.prompts: Dict[str, GitLabPromptTemplate] = {}
        self.gitlab_client = gitlab_client or GitLabClient(self.gitlab_config)
        if ref:
            self.gitlab_client.set_ref(ref)
        # Folder inside repo to look for prompts (e.g., "prompts" or "prompts/chat")
        self.prompts_path: str = (
            self.gitlab_config.get("prompts_path")
            or self.gitlab_config.get("folder")
            or ""
        ).strip("/")
        # Jinja2 environment for rendering template bodies; delimiters are the
        # Jinja defaults, restated explicitly for clarity.
        self.jinja_env = Environment(
            loader=DictLoader({}),
            autoescape=select_autoescape(["html", "xml"]),
            variable_start_string="{{",
            variable_end_string="}}",
            block_start_string="{%",
            block_end_string="%}",
            comment_start_string="{#",
            comment_end_string="#}",
        )
        if self.prompt_id:
            self._load_prompt_from_gitlab(self.prompt_id)
    # ---------- path helpers ----------
    def _id_to_repo_path(self, prompt_id: str) -> str:
        """Map a prompt_id to a repo path (respects prompts_path and adds .prompt)."""
        # Accepts either encoded ('gitlab::a::b') or plain ('a/b') ids.
        prompt_id = decode_prompt_id(prompt_id)
        if self.prompts_path:
            return f"{self.prompts_path}/{prompt_id}.prompt"
        return f"{prompt_id}.prompt"
    def _repo_path_to_id(self, repo_path: str) -> str:
        """
        Map a repo path like 'prompts/chat/greeting.prompt' to an ID relative
        to prompts_path without the extension (e.g., 'chat/greeting').
        """
        path = repo_path.strip("/")
        # Drop the prompts_path prefix, then the .prompt extension.
        if self.prompts_path and path.startswith(self.prompts_path.strip("/") + "/"):
            path = path[len(self.prompts_path.strip("/")) + 1 :]
        if path.endswith(".prompt"):
            path = path[: -len(".prompt")]
        return encode_prompt_id(path)
    # ---------- loading ----------
    def _load_prompt_from_gitlab(
        self, prompt_id: str, *, ref: Optional[str] = None
    ) -> None:
        """Load a specific .prompt file from GitLab (scoped under prompts_path if set)."""
        try:
            # _id_to_repo_path() already decodes encoded ids, so no explicit
            # decode is needed here.
            # prompt_id = decode_prompt_id(prompt_id)
            file_path = self._id_to_repo_path(prompt_id)
            prompt_content = self.gitlab_client.get_file_content(file_path, ref=ref)
            # An empty/missing file is silently skipped (no cache entry).
            if prompt_content:
                template = self._parse_prompt_file(prompt_content, prompt_id)
                self.prompts[prompt_id] = template
        except Exception as e:
            raise Exception(
                f"Failed to load prompt '{encode_prompt_id(prompt_id)}' from GitLab: {e}"
            )
    def load_all_prompts(self, *, recursive: bool = True) -> List[str]:
        """
        Eagerly load all .prompt files from prompts_path. Returns loaded IDs.
        """
        files = self.list_templates(recursive=recursive)
        loaded: List[str] = []
        for pid in files:
            # Only fetch templates not already cached in memory.
            if pid not in self.prompts:
                self._load_prompt_from_gitlab(pid)
            loaded.append(pid)
        return loaded
    # ---------- parsing & rendering ----------
    def _parse_prompt_file(self, content: str, prompt_id: str) -> GitLabPromptTemplate:
        """Split optional '---'-delimited YAML frontmatter from the body and parse it."""
        if content.startswith("---"):
            parts = content.split("---", 2)
            if len(parts) >= 3:
                frontmatter_str = parts[1].strip()
                template_content = parts[2].strip()
            else:
                # Opening '---' without a closing one: treat whole file as body.
                frontmatter_str = ""
                template_content = content
        else:
            frontmatter_str = ""
            template_content = content
        metadata: Dict[str, Any] = {}
        if frontmatter_str:
            try:
                import yaml
                metadata = yaml.safe_load(frontmatter_str) or {}
            except ImportError:
                # PyYAML not installed -> best-effort line-based parsing.
                metadata = self._parse_yaml_basic(frontmatter_str)
            except Exception:
                # Malformed frontmatter is ignored rather than fatal.
                metadata = {}
        return GitLabPromptTemplate(
            template_id=prompt_id,
            content=template_content,
            metadata=metadata,
        )
    def _parse_yaml_basic(self, yaml_str: str) -> Dict[str, Any]:
        """Minimal flat 'key: value' parser used when PyYAML is unavailable."""
        result: Dict[str, Any] = {}
        for line in yaml_str.split("\n"):
            line = line.strip()
            if ":" in line and not line.startswith("#"):
                key, value = line.split(":", 1)
                key = key.strip()
                value = value.strip()
                if value.lower() in ["true", "false"]:
                    result[key] = value.lower() == "true"
                elif value.isdigit():
                    result[key] = int(value)
                elif value.replace(".", "").isdigit():
                    # Looks numeric, but e.g. '1.2.3' still fails float(),
                    # hence the fallback to the raw string.
                    try:
                        result[key] = float(value)
                    except Exception:
                        result[key] = value
                else:
                    result[key] = value.strip("\"'")
        return result
    def render_template(
        self, template_id: str, variables: Optional[Dict[str, Any]] = None
    ) -> str:
        """Render an already-loaded template with Jinja2, substituting `variables`."""
        if template_id not in self.prompts:
            raise ValueError(f"Template '{template_id}' not found")
        template = self.prompts[template_id]
        jinja_template = self.jinja_env.from_string(template.content)
        return jinja_template.render(**(variables or {}))
    def get_template(self, template_id: str) -> Optional[GitLabPromptTemplate]:
        """Return the cached template for `template_id`, or None if not loaded."""
        return self.prompts.get(template_id)
    def list_templates(self, *, recursive: bool = True) -> List[str]:
        """
        List available prompt IDs under prompts_path (no extension).
        Compatible with both list_files signatures:
        - list_files(directory_path=..., file_extension=..., recursive=...)
        - list_files(path=..., ref=None, recursive=...)
        """
        # First try the "new" signature (directory_path/file_extension)
        try:
            files = self.gitlab_client.list_files(
                directory_path=self.prompts_path,
                file_extension=".prompt",
                recursive=recursive,
            )
            base = self.prompts_path.strip("/")
            out: List[str] = []
            for p in files or []:
                path = str(p).strip("/")
                if base and not path.startswith(base + "/"):
                    # if the client returns extra files outside the folder, skip them
                    continue
                if not path.endswith(".prompt"):
                    continue
                out.append(self._repo_path_to_id(path))
            return out
        except TypeError:
            # Fallback to the "classic" signature
            # NOTE(review): this fallback still passes `directory_path=` (plus
            # `ref=`), not `path=` as the docstring above suggests — confirm it
            # matches the classic GitLabClient.list_files signature.
            raw = self.gitlab_client.list_files(
                directory_path=self.prompts_path or "",
                ref=None,
                recursive=recursive,
            )
            # Classic returns GitLab tree entries; filter *.prompt blobs
            files = []
            for f in raw or []:
                if (
                    isinstance(f, dict)
                    and f.get("type") == "blob"
                    and str(f.get("path", "")).endswith(".prompt")
                    and "path" in f
                ):
                    files.append(f["path"])  # type: ignore
            return [self._repo_path_to_id(p) for p in files]
class GitLabPromptManager(CustomPromptManagement):
    """
    GitLab prompt manager with folder support.
    Example config:
        gitlab_config = {
            "project": "group/subgroup/repo",
            "access_token": "glpat_***",
            "tag": "v1.2.3",  # optional; takes precedence
            "branch": "main",  # default fallback
            "prompts_path": "prompts/chat"
        }
    """
    def __init__(
        self,
        gitlab_config: Dict[str, Any],
        prompt_id: Optional[str] = None,
        ref: Optional[str] = None,  # tag/branch/SHA override
        gitlab_client: Optional[GitLabClient] = None,
    ):
        self.gitlab_config = gitlab_config
        self.prompt_id = prompt_id
        # Built lazily by the `prompt_manager` property unless a prompt_id
        # forces eager construction below.
        self._prompt_manager: Optional[GitLabTemplateManager] = None
        self._ref_override = ref
        self._injected_gitlab_client = gitlab_client
        if self.prompt_id:
            # NOTE(review): this eager path does not pass
            # `gitlab_client=self._injected_gitlab_client` (the lazy property
            # below does) — confirm whether that asymmetry is intentional.
            self._prompt_manager = GitLabTemplateManager(
                gitlab_config=self.gitlab_config,
                prompt_id=self.prompt_id,
                ref=self._ref_override,
            )
    @property
    def integration_name(self) -> str:
        """Identifier used to register this prompt-management integration."""
        return "gitlab"
    @property
    def prompt_manager(self) -> GitLabTemplateManager:
        """Lazily construct and memoize the underlying GitLabTemplateManager."""
        if self._prompt_manager is None:
            self._prompt_manager = GitLabTemplateManager(
                gitlab_config=self.gitlab_config,
                prompt_id=self.prompt_id,
                ref=self._ref_override,
                gitlab_client=self._injected_gitlab_client,
            )
        return self._prompt_manager
    def get_prompt_template(
        self,
        prompt_id: str,
        prompt_variables: Optional[Dict[str, Any]] = None,
        *,
        ref: Optional[str] = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Render `prompt_id` with `prompt_variables` and return
        (rendered_text, metadata). Loads from GitLab on first use; `ref`
        optionally pins a tag/branch/SHA for that load.
        """
        if prompt_id not in self.prompt_manager.prompts:
            self.prompt_manager._load_prompt_from_gitlab(prompt_id, ref=ref)
        template = self.prompt_manager.get_template(prompt_id)
        if not template:
            raise ValueError(f"Prompt template '{prompt_id}' not found")
        rendered_prompt = self.prompt_manager.render_template(
            prompt_id, prompt_variables or {}
        )
        # Frontmatter-derived call parameters accompany the rendered text.
        metadata = {
            "model": template.model,
            "temperature": template.temperature,
            "max_tokens": template.max_tokens,
            **template.optional_params,
        }
        return rendered_prompt, metadata
    def pre_call_hook(
        self,
        user_id: Optional[str],
        messages: List[AllMessageValues],
        function_call: Optional[Union[Dict[str, Any], str]] = None,
        litellm_params: Optional[Dict[str, Any]] = None,
        prompt_id: Optional[str] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        prompt_version: Optional[str] = None,
        **kwargs,
    ) -> Tuple[List[AllMessageValues], Optional[Dict[str, Any]]]:
        """
        Resolve `prompt_id` into messages + call params before the LLM call.
        Best-effort: on any failure the original `messages` / `litellm_params`
        are returned unchanged.
        """
        if not prompt_id:
            return messages, litellm_params
        try:
            # Precedence: explicit prompt_version → per-call git_ref kwarg → manager override → config default
            git_ref = prompt_version or kwargs.get("git_ref") or self._ref_override
            rendered_prompt, prompt_metadata = self.get_prompt_template(
                prompt_id, prompt_variables, ref=git_ref
            )
            parsed_messages = self._parse_prompt_to_messages(rendered_prompt)
            if parsed_messages:
                # Rendered prompt had role markers: it fully replaces `messages`.
                final_messages: List[AllMessageValues] = parsed_messages
            else:
                # Otherwise prepend the rendered text as a user message.
                final_messages = [{"role": "user", "content": rendered_prompt}] + messages  # type: ignore
            if litellm_params is None:
                litellm_params = {}
            # Template frontmatter overrides the caller's model/params.
            if prompt_metadata.get("model"):
                litellm_params["model"] = prompt_metadata["model"]
            for param in [
                "temperature",
                "max_tokens",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
            ]:
                if param in prompt_metadata:
                    litellm_params[param] = prompt_metadata[param]
            return final_messages, litellm_params
        except Exception as e:
            import litellm
            litellm._logging.verbose_proxy_logger.error(
                f"Error in GitLab prompt pre_call_hook: {e}"
            )
            return messages, litellm_params
    def _parse_prompt_to_messages(self, prompt_content: str) -> List[AllMessageValues]:
        """
        Split rendered text on 'System:'/'User:'/'Assistant:' line prefixes
        into chat messages; falls back to a single user message when no
        role markers are present.
        """
        messages: List[AllMessageValues] = []
        lines = prompt_content.strip().split("\n")
        current_role: Optional[str] = None
        current_content: List[str] = []
        for raw in lines:
            line = raw.strip()
            if not line:
                continue
            low = line.lower()
            if low.startswith("system:"):
                # Flush the message accumulated so far before switching roles.
                if current_role and current_content:
                    messages.append({"role": current_role, "content": "\n".join(current_content).strip()})  # type: ignore
                current_role = "system"
                current_content = [line[7:].strip()]
            elif low.startswith("user:"):
                if current_role and current_content:
                    messages.append({"role": current_role, "content": "\n".join(current_content).strip()})  # type: ignore
                current_role = "user"
                current_content = [line[5:].strip()]
            elif low.startswith("assistant:"):
                if current_role and current_content:
                    messages.append({"role": current_role, "content": "\n".join(current_content).strip()})  # type: ignore
                current_role = "assistant"
                current_content = [line[10:].strip()]
            else:
                # Continuation line for the current role's message.
                current_content.append(line)
        if current_role and current_content:
            messages.append({"role": current_role, "content": "\n".join(current_content).strip()})  # type: ignore
        if not messages and prompt_content.strip():
            messages = [{"role": "user", "content": prompt_content.strip()}]  # type: ignore
        return messages
    def post_call_hook(
        self,
        user_id: Optional[str],
        response: Any,
        input_messages: List[AllMessageValues],
        function_call: Optional[Union[Dict[str, Any], str]] = None,
        litellm_params: Optional[Dict[str, Any]] = None,
        prompt_id: Optional[str] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """No-op: the response is passed through unmodified."""
        return response
    def get_available_prompts(self) -> List[str]:
        """
        Return prompt IDs. Prefer already-loaded templates in memory to avoid
        unnecessary network calls (and to make tests deterministic).
        """
        ids = set(self.prompt_manager.prompts.keys())
        try:
            ids.update(self.prompt_manager.list_templates())
        except Exception:
            # If GitLab list fails (auth, network), still return what we've loaded.
            pass
        return sorted(ids)
    def reload_prompts(self) -> None:
        """Drop the cached manager so the next access re-fetches from GitLab."""
        if self.prompt_id:
            self._prompt_manager = None
            _ = self.prompt_manager  # trigger re-init/load
    def should_run_prompt_management(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """Prompt management runs whenever a prompt_id is supplied."""
        return prompt_id is not None
    def _compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_spec: Optional[PromptSpec],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Load + render a prompt into a PromptManagementClient.
        `prompt_label` / `prompt_version` are currently unused here; a git ref
        may instead arrive via `dynamic_callback_params.extra["git_ref"]`.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for GitLab prompt manager")
        try:
            decoded_id = decode_prompt_id(prompt_id)
            if decoded_id not in self.prompt_manager.prompts:
                git_ref = (
                    getattr(dynamic_callback_params, "extra", {}).get("git_ref")
                    if hasattr(dynamic_callback_params, "extra")
                    else None
                )
                self.prompt_manager._load_prompt_from_gitlab(decoded_id, ref=git_ref)
            rendered_prompt, prompt_metadata = self.get_prompt_template(
                prompt_id, prompt_variables
            )
            messages = self._parse_prompt_to_messages(rendered_prompt)
            template_model = prompt_metadata.get("model")
            # Forward only the recognized generation parameters.
            optional_params: Dict[str, Any] = {}
            for param in [
                "temperature",
                "max_tokens",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
            ]:
                if param in prompt_metadata:
                    optional_params[param] = prompt_metadata[param]
            return PromptManagementClient(
                prompt_id=prompt_id,
                prompt_template=messages,
                prompt_template_model=template_model,
                prompt_template_optional_params=optional_params,
                completed_messages=None,
            )
        except Exception as e:
            raise ValueError(f"Error compiling prompt '{prompt_id}': {e}")
    async def async_compile_prompt_helper(
        self,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
    ) -> PromptManagementClient:
        """
        Async version of compile prompt helper. Since GitLab operations use sync client,
        this simply delegates to the sync version.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for GitLab prompt manager")
        return self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_spec=prompt_spec,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Delegate to the shared PromptManagementBase implementation."""
        return PromptManagementBase.get_chat_completion_prompt(
            self,
            model,
            messages,
            non_default_params,
            prompt_id,
            prompt_variables,
            dynamic_callback_params,
            prompt_spec=prompt_spec,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
        )
    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        litellm_logging_obj: LiteLLMLoggingObj,
        prompt_spec: Optional[PromptSpec] = None,
        tools: Optional[List[Dict]] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Async version - delegates to PromptManagementBase async implementation.
        """
        return await PromptManagementBase.async_get_chat_completion_prompt(
            self,
            model,
            messages,
            non_default_params,
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            litellm_logging_obj=litellm_logging_obj,
            dynamic_callback_params=dynamic_callback_params,
            prompt_spec=prompt_spec,
            tools=tools,
            prompt_label=prompt_label,
            prompt_version=prompt_version,
            ignore_prompt_manager_model=ignore_prompt_manager_model,
            ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
        )
class GitLabPromptCache:
    """
    Cache all .prompt files from a GitLab repo into memory.
    - Keys are the *repo file paths* (e.g. "prompts/chat/greet/hi.prompt")
      mapped to JSON-like dicts containing content + metadata.
    - Also exposes a by-ID view (ID == path relative to prompts_path without ".prompt",
      e.g. "greet/hi").
    Usage:
        cfg = {
            "project": "group/subgroup/repo",
            "access_token": "glpat_***",
            "prompts_path": "prompts/chat",  # optional, can be empty for repo root
            # "branch": "main",  # default is "main"
            # "tag": "v1.2.3",  # takes precedence over branch
            # "base_url": "https://gitlab.com/api/v4"  # default
        }
        cache = GitLabPromptCache(cfg)
        cache.load_all()  # fetch + parse all .prompt files
        print(cache.list_files())  # repo file paths
        print(cache.list_ids())  # template IDs relative to prompts_path
        prompt_json = cache.get_by_file("prompts/chat/greet/hi.prompt")
        prompt_json2 = cache.get_by_id("greet/hi")
        # If GitLab content changes and you want to refresh:
        cache.reload()  # re-scan and refresh all
    """
    def __init__(
        self,
        gitlab_config: Dict[str, Any],
        *,
        ref: Optional[str] = None,
        gitlab_client: Optional[GitLabClient] = None,
    ) -> None:
        """Wire up a prompt manager + template manager; no network calls yet."""
        # Build a PromptManager (which internally builds TemplateManager + Client)
        self.prompt_manager = GitLabPromptManager(
            gitlab_config=gitlab_config,
            prompt_id=None,
            ref=ref,
            gitlab_client=gitlab_client,
        )
        self.template_manager: GitLabTemplateManager = (
            self.prompt_manager.prompt_manager
        )
        # In-memory stores
        self._by_file: Dict[str, Dict[str, Any]] = {}
        self._by_id: Dict[str, Dict[str, Any]] = {}
    # -------------------------
    # Public API
    # -------------------------
    def load_all(self, *, recursive: bool = True) -> Dict[str, Dict[str, Any]]:
        """
        Scan GitLab for all .prompt files under prompts_path, load and parse each,
        and return the mapping of repo file path -> JSON-like dict.
        """
        ids = self.template_manager.list_templates(
            recursive=recursive
        )  # IDs relative to prompts_path
        for pid in ids:
            # Ensure template is loaded into TemplateManager
            if pid not in self.template_manager.prompts:
                self.template_manager._load_prompt_from_gitlab(pid)
            tmpl = self.template_manager.get_template(pid)
            if tmpl is None:
                # If something raced/failed, try once more
                self.template_manager._load_prompt_from_gitlab(pid)
                tmpl = self.template_manager.get_template(pid)
                if tmpl is None:
                    continue
            file_path = self.template_manager._id_to_repo_path(
                pid
            )  # "prompts/chat/..../file.prompt"
            entry = self._template_to_json(pid, tmpl)
            # Same entry object is shared between both views.
            self._by_file[file_path] = entry
            # prefixed_id = pid if pid.startswith("gitlab::") else f"gitlab::{pid}"
            encoded_id = encode_prompt_id(pid)
            self._by_id[encoded_id] = entry
            # self._by_id[pid] = entry
        return self._by_id
    def reload(self, *, recursive: bool = True) -> Dict[str, Dict[str, Any]]:
        """Clear the cache and re-load from GitLab."""
        self._by_file.clear()
        self._by_id.clear()
        return self.load_all(recursive=recursive)
    def list_files(self) -> List[str]:
        """Return the repo file paths currently cached."""
        return list(self._by_file.keys())
    def list_ids(self) -> List[str]:
        """Return the template IDs (relative to prompts_path, without extension) currently cached."""
        return list(self._by_id.keys())
    def get_by_file(self, file_path: str) -> Optional[Dict[str, Any]]:
        """Get a cached prompt JSON by repo file path."""
        return self._by_file.get(file_path)
    def get_by_id(self, prompt_id: str) -> Optional[Dict[str, Any]]:
        """Get a cached prompt JSON by prompt ID (relative to prompts_path)."""
        if prompt_id in self._by_id:
            return self._by_id[prompt_id]
        # Try normalized forms
        decoded = decode_prompt_id(prompt_id)
        encoded = encode_prompt_id(decoded)
        return self._by_id.get(encoded) or self._by_id.get(decoded)
    # -------------------------
    # Internals
    # -------------------------
    def _template_to_json(
        self, prompt_id: str, tmpl: GitLabPromptTemplate
    ) -> Dict[str, Any]:
        """
        Normalize a GitLabPromptTemplate into a JSON-like dict that is easy to serialize.
        """
        # Safer copy of metadata (avoid accidental mutation)
        md = dict(tmpl.metadata or {})
        # Pull standard fields (also present in metadata sometimes)
        model = tmpl.model
        temperature = tmpl.temperature
        max_tokens = tmpl.max_tokens
        optional_params = dict(tmpl.optional_params or {})
        return {
            "id": prompt_id,  # e.g. "greet/hi"
            "path": self.template_manager._id_to_repo_path(
                prompt_id
            ),  # e.g. "prompts/chat/greet/hi.prompt"
            "content": tmpl.content,  # rendered content (without frontmatter)
            "metadata": md,  # parsed frontmatter
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "optional_params": optional_params,
        }

View File

@@ -0,0 +1,72 @@
import json
import traceback
from datetime import datetime, timezone
import litellm
class GreenscaleLogger:
    """Logs LiteLLM completion telemetry to the Greenscale API.

    Configured entirely via environment variables:
      - GREENSCALE_API_KEY: sent as the ``api-key`` request header.
      - GREENSCALE_ENDPOINT: URL the event payload is POSTed to.
    """

    def __init__(self):
        import os

        # May be None; Greenscale will reject unauthenticated requests.
        self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
        self.headers = {
            "api-key": self.greenscale_api_key,
            "Content-Type": "application/json",
        }
        self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")

    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
        """Send one completion event to Greenscale.

        Never raises: any failure is reported through ``print_verbose`` so
        logging cannot break the request path.
        """
        try:
            response_json = response_obj.model_dump() if response_obj else {}
            usage = response_json.get("usage") or {}
            data = {
                "modelId": kwargs.get("model"),
                "inputTokenCount": usage.get("prompt_tokens"),
                "outputTokenCount": usage.get("completion_tokens"),
            }
            data["timestamp"] = datetime.now(timezone.utc).strftime(
                "%Y-%m-%dT%H:%M:%SZ"
            )
            # isinstance (rather than an exact type() check) also accepts
            # datetime subclasses such as pandas.Timestamp.
            if isinstance(end_time, datetime) and isinstance(start_time, datetime):
                data["invocationLatency"] = int(
                    (end_time - start_time).total_seconds() * 1000
                )

            # Map greenscale_* metadata keys to top-level fields / tags.
            # `or {}` guards against litellm_params/metadata being None.
            tags = []
            metadata = (kwargs.get("litellm_params") or {}).get("metadata") or {}
            for key, value in metadata.items():
                if not key.startswith("greenscale"):
                    continue
                if key == "greenscale_project":
                    data["project"] = value
                elif key == "greenscale_application":
                    data["application"] = value
                else:
                    tags.append(
                        {"key": key.replace("greenscale_", ""), "value": str(value)}
                    )
            data["tags"] = tags

            if self.greenscale_logging_url is None:
                raise Exception("Greenscale Logger Error - No logging URL found")

            response = litellm.module_level_client.post(
                self.greenscale_logging_url,
                headers=self.headers,
                # default=str keeps non-JSON-serializable values from raising
                data=json.dumps(data, default=str),
            )
            if response.status_code != 200:
                print_verbose(
                    f"Greenscale Logger Error - {response.text}, {response.status_code}"
                )
            else:
                print_verbose(f"Greenscale Logger Succeeded - {response.text}")
        except Exception as e:
            print_verbose(
                f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}"
            )

View File

@@ -0,0 +1,228 @@
#### What this does ####
# On success, logs events to Helicone
import os
import traceback
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.helicone_mock_client import (
should_use_helicone_mock,
create_mock_helicone_client,
)
class HeliconeLogger:
    """Logs successful LLM calls to Helicone.

    Reads HELICONE_API_KEY / HELICONE_API_BASE from the environment; supports
    a mock mode (see helicone_mock_client) that intercepts outbound requests.
    """
    # Class variables or attributes
    # Model-name substrings Helicone natively understands; unrecognized models
    # (other than vertex_ai ones) are reported as "gpt-3.5-turbo" below.
    helicone_model_list = [
        "gpt",
        "claude",
        "gemini",
        "command-r",
        "command-r-plus",
        "command-light",
        "command-medium",
        "command-medium-beta",
        "command-xlarge-nightly",
        "command-nightly",
    ]
    def __init__(self):
        # Instance variables
        self.is_mock_mode = should_use_helicone_mock()
        if self.is_mock_mode:
            # Patch the HTTP layer so requests never leave the process.
            create_mock_helicone_client()
            verbose_logger.info(
                "[HELICONE MOCK] Helicone logger initialized in mock mode"
            )
        self.provider_url = "https://api.openai.com/v1"
        self.key = os.getenv("HELICONE_API_KEY")
        self.api_base = os.getenv("HELICONE_API_BASE") or "https://api.hconeai.com"
        # Normalize: strip trailing slash so URL joins below stay well-formed.
        if self.api_base.endswith("/"):
            self.api_base = self.api_base[:-1]
    def claude_mapping(self, model, messages, response_obj):
        """Convert an OpenAI-style response into Anthropic's message format for Helicone."""
        from anthropic import AI_PROMPT, HUMAN_PROMPT
        # NOTE(review): `prompt` is assembled but never used below — looks like
        # leftover code; confirm before removing.
        prompt = f"{HUMAN_PROMPT}"
        for message in messages:
            if "role" in message:
                if message["role"] == "user":
                    prompt += f"{HUMAN_PROMPT}{message['content']}"
                else:
                    prompt += f"{AI_PROMPT}{message['content']}"
            else:
                # Role-less messages are treated as human turns.
                prompt += f"{HUMAN_PROMPT}{message['content']}"
        prompt += f"{AI_PROMPT}"
        choice = response_obj["choices"][0]
        message = choice["message"]
        content = []
        if "tool_calls" in message and message["tool_calls"]:
            # Tool calls map to Anthropic "tool_use" content blocks.
            for tool_call in message["tool_calls"]:
                content.append(
                    {
                        "type": "tool_use",
                        "id": tool_call["id"],
                        "name": tool_call["function"]["name"],
                        "input": tool_call["function"]["arguments"],
                    }
                )
        elif "content" in message and message["content"]:
            content = [{"type": "text", "text": message["content"]}]
        claude_response_obj = {
            "id": response_obj["id"],
            "type": "message",
            "role": "assistant",
            "model": model,
            "content": content,
            "stop_reason": choice["finish_reason"],
            "stop_sequence": None,
            "usage": {
                "input_tokens": response_obj["usage"]["prompt_tokens"],
                "output_tokens": response_obj["usage"]["completion_tokens"],
            },
        }
        return claude_response_obj
    @staticmethod
    def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
        """
        Adds metadata from proxy request headers to Helicone logging if keys start with "helicone_"
        and overwrites litellm_params.metadata if already included.
        For example if you want to add custom property to your request, send
        `headers: { ..., helicone-property-something: 1234 }` via proxy request.
        """
        if litellm_params is None:
            return metadata
        if litellm_params.get("proxy_server_request") is None:
            return metadata
        if metadata is None:
            metadata = {}
        proxy_headers = (
            litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
        )
        for header_key in proxy_headers:
            if header_key.startswith("helicone_"):
                metadata[header_key] = proxy_headers.get(header_key)
        # Remove OpenTelemetry span from metadata as it's not JSON serializable
        # The span is used internally for tracing but shouldn't be logged to external services
        if "litellm_parent_otel_span" in metadata:
            metadata.pop("litellm_parent_otel_span")
        return metadata
    def log_success(
        self, model, messages, response_obj, start_time, end_time, print_verbose, kwargs
    ):
        """Send one successful completion event to Helicone's log endpoint (best-effort)."""
        try:
            print_verbose(
                f"Helicone Logging - Enters logging function for model {model}"
            )
            litellm_params = kwargs.get("litellm_params", {})
            custom_llm_provider = litellm_params.get("custom_llm_provider", "")
            # NOTE(review): the litellm_call_id lookup result is discarded —
            # looks like leftover code; confirm before removing.
            kwargs.get("litellm_call_id", None)
            metadata = litellm_params.get("metadata", {}) or {}
            metadata = self.add_metadata_from_header(litellm_params, metadata)
            # Check if model is a vertex_ai model
            is_vertex_ai = custom_llm_provider == "vertex_ai" or model.startswith(
                "vertex_ai/"
            )
            # Models Helicone does not recognize are reported as gpt-3.5-turbo.
            model = (
                model
                if any(
                    accepted_model in model
                    for accepted_model in self.helicone_model_list
                )
                or is_vertex_ai
                else "gpt-3.5-turbo"
            )
            provider_request = {"model": model, "messages": messages}
            if isinstance(response_obj, litellm.EmbeddingResponse) or isinstance(
                response_obj, litellm.ModelResponse
            ):
                response_obj = response_obj.json()
            if "claude" in model and not is_vertex_ai:
                response_obj = self.claude_mapping(
                    model=model, messages=messages, response_obj=response_obj
                )
            providerResponse = {
                "json": response_obj,
                "headers": {"openai-version": "2020-10-01"},
                "status": 200,
            }
            # Pick the Helicone ingest route + upstream URL by provider family.
            provider_url = self.provider_url
            url = f"{self.api_base}/oai/v1/log"
            if "claude" in model and not is_vertex_ai:
                url = f"{self.api_base}/anthropic/v1/log"
                provider_url = "https://api.anthropic.com/v1/messages"
            elif is_vertex_ai:
                url = f"{self.api_base}/custom/v1/log"
                provider_url = "https://aiplatform.googleapis.com/v1"
            elif "gemini" in model:
                url = f"{self.api_base}/custom/v1/log"
                provider_url = "https://generativelanguage.googleapis.com/v1beta"
            headers = {
                "Authorization": f"Bearer {self.key}",
                "Content-Type": "application/json",
            }
            # Helicone timing format: whole seconds + leftover milliseconds.
            start_time_seconds = int(start_time.timestamp())
            start_time_milliseconds = int(
                (start_time.timestamp() - start_time_seconds) * 1000
            )
            end_time_seconds = int(end_time.timestamp())
            end_time_milliseconds = int(
                (end_time.timestamp() - end_time_seconds) * 1000
            )
            meta = {"Helicone-Auth": f"Bearer {self.key}"}
            meta.update(metadata)
            data = {
                "providerRequest": {
                    "url": provider_url,
                    "json": provider_request,
                    "meta": meta,
                },
                "providerResponse": providerResponse,
                "timing": {
                    "startTime": {
                        "seconds": start_time_seconds,
                        "milliseconds": start_time_milliseconds,
                    },
                    "endTime": {
                        "seconds": end_time_seconds,
                        "milliseconds": end_time_milliseconds,
                    },
                },  # {"seconds": .., "milliseconds": ..}
            }
            response = litellm.module_level_client.post(url, headers=headers, json=data)
            if response.status_code == 200:
                if self.is_mock_mode:
                    print_verbose(
                        "[HELICONE MOCK] Helicone Logging - Successfully mocked!"
                    )
                else:
                    print_verbose("Helicone Logging - Success!")
            else:
                print_verbose(
                    f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}"
                )
                print_verbose(f"Helicone Logging - Error {response.text}")
        except Exception:
            # Never let logging failures propagate into the request path.
            print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
            pass

View File

@@ -0,0 +1,37 @@
"""
Mock HTTP client for Helicone integration testing.
This module intercepts Helicone API calls and returns successful mock responses,
allowing full code execution without making actual network calls.
Usage:
Set HELICONE_MOCK=true in environment variables or config to enable mock mode.
"""
from litellm.integrations.mock_client_factory import (
MockClientConfig,
create_mock_client_factory,
)
# Create mock client using factory
# Helicone uses HTTPHandler which internally uses httpx.Client.send(), not httpx.Client.post()
_config = MockClientConfig(
    name="HELICONE",
    env_var="HELICONE_MOCK",
    default_latency_ms=100,
    default_status_code=200,
    default_json_data={"status": "success"},
    # Hostnames whose outbound requests should be intercepted and mocked.
    url_matchers=[
        ".hconeai.com",
        "hconeai.com",
        ".helicone.ai",
        "helicone.ai",
    ],
    patch_async_handler=False,
    patch_sync_client=False,  # HTTPHandler uses self.client.send(), not self.client.post()
    patch_http_handler=True,  # Patch HTTPHandler.post directly
)
# Factory outputs: a mock-installer callable and an "is mock enabled?" predicate.
create_mock_helicone_client, should_use_helicone_mock = create_mock_client_factory(
    _config
)

View File

@@ -0,0 +1,204 @@
"""
Humanloop integration
https://humanloop.com/
"""
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import httpx
from typing_extensions import TypedDict
import litellm
from litellm.caching import DualCache
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.prompts.init_prompts import PromptSpec
from litellm.types.utils import StandardCallbackDynamicParams
from .custom_logger import CustomLogger
class PromptManagementClient(TypedDict):
    """Normalized Humanloop prompt payload: template messages plus model/params."""

    prompt_id: str
    prompt_template: List[AllMessageValues]
    model: Optional[str]
    optional_params: Optional[Dict[str, Any]]
class HumanLoopPromptManager(DualCache):
    @property
    def integration_name(self):
        """Identifier used to register this prompt-management integration."""
        return "humanloop"
    def _get_prompt_from_id_cache(
        self, humanloop_prompt_id: str
    ) -> Optional[PromptManagementClient]:
        """Return the cached prompt payload for this id, or None on a cache miss."""
        return cast(
            Optional[PromptManagementClient], self.get_cache(key=humanloop_prompt_id)
        )
def _compile_prompt_helper(
self, prompt_template: List[AllMessageValues], prompt_variables: Dict[str, Any]
) -> List[AllMessageValues]:
"""
Helper function to compile the prompt by substituting variables in the template.
Args:
prompt_template: List[AllMessageValues]
prompt_variables (dict): A dictionary of variables to substitute into the prompt template.
Returns:
list: A list of dictionaries with variables substituted.
"""
compiled_prompts: List[AllMessageValues] = []
for template in prompt_template:
tc = template.get("content")
if tc and isinstance(tc, str):
formatted_template = tc.replace("{{", "{").replace("}}", "}")
compiled_content = formatted_template.format(**prompt_variables)
template["content"] = compiled_content
compiled_prompts.append(template)
return compiled_prompts
def _get_prompt_from_id_api(
self, humanloop_prompt_id: str, humanloop_api_key: str
) -> PromptManagementClient:
client = _get_httpx_client()
base_url = "https://api.humanloop.com/v5/prompts/{}".format(humanloop_prompt_id)
response = client.get(
url=base_url,
headers={
"X-Api-Key": humanloop_api_key,
"Content-Type": "application/json",
},
)
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise Exception(f"Error getting prompt from Humanloop: {e.response.text}")
json_response = response.json()
template_message = json_response["template"]
if isinstance(template_message, dict):
template_messages = [template_message]
elif isinstance(template_message, list):
template_messages = template_message
else:
raise ValueError(f"Invalid template message type: {type(template_message)}")
template_model = json_response["model"]
optional_params = {}
for k, v in json_response.items():
if k in litellm.OPENAI_CHAT_COMPLETION_PARAMS:
optional_params[k] = v
return PromptManagementClient(
prompt_id=humanloop_prompt_id,
prompt_template=cast(List[AllMessageValues], template_messages),
model=template_model,
optional_params=optional_params,
)
def _get_prompt_from_id(
self, humanloop_prompt_id: str, humanloop_api_key: str
) -> PromptManagementClient:
prompt = self._get_prompt_from_id_cache(humanloop_prompt_id)
if prompt is None:
prompt = self._get_prompt_from_id_api(
humanloop_prompt_id, humanloop_api_key
)
self.set_cache(
key=humanloop_prompt_id,
value=prompt,
ttl=litellm.HUMANLOOP_PROMPT_CACHE_TTL_SECONDS,
)
return prompt
def compile_prompt(
self,
prompt_template: List[AllMessageValues],
prompt_variables: Optional[dict],
) -> List[AllMessageValues]:
compiled_prompt: Optional[Union[str, list]] = None
if prompt_variables is None:
prompt_variables = {}
compiled_prompt = self._compile_prompt_helper(
prompt_template=prompt_template,
prompt_variables=prompt_variables,
)
return compiled_prompt
def _get_model_from_prompt(
self, prompt_management_client: PromptManagementClient, model: str
) -> str:
if prompt_management_client["model"] is not None:
return prompt_management_client["model"]
else:
return model.replace("{}/".format(self.integration_name), "")
# Module-level singleton so the prompt cache is shared by every HumanloopLogger call.
prompt_manager = HumanLoopPromptManager()
class HumanloopLogger(CustomLogger):
    """CustomLogger hook that resolves the prompt for a request from Humanloop."""

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_spec: Optional[PromptSpec] = None,
        prompt_label: Optional[str] = None,
        prompt_version: Optional[int] = None,
        ignore_prompt_manager_model: Optional[bool] = False,
        ignore_prompt_manager_optional_params: Optional[bool] = False,
    ) -> Tuple[str, List[AllMessageValues], dict,]:
        """
        Resolve the model, messages and optional params for a request whose
        prompt lives in Humanloop.

        Returns:
            (model, compiled messages, merged non-default params).

        Raises:
            ValueError: if no prompt_id is supplied.
        """
        api_key = dynamic_callback_params.get(
            "humanloop_api_key"
        ) or get_secret_str("HUMANLOOP_API_KEY")

        if prompt_id is None:
            raise ValueError("prompt_id is required for Humanloop integration")

        # Without credentials, defer to the default (pass-through) implementation.
        if api_key is None:
            return super().get_chat_completion_prompt(
                model=model,
                messages=messages,
                non_default_params=non_default_params,
                prompt_id=prompt_id,
                prompt_variables=prompt_variables,
                dynamic_callback_params=dynamic_callback_params,
                prompt_spec=prompt_spec,
            )

        prompt_client = prompt_manager._get_prompt_from_id(
            humanloop_prompt_id=prompt_id, humanloop_api_key=api_key
        )

        compiled_messages = prompt_manager.compile_prompt(
            prompt_template=prompt_client["prompt_template"],
            prompt_variables=prompt_variables,
        )

        # Params stored on the Humanloop prompt override the caller's params.
        merged_params = dict(non_default_params)
        merged_params.update(prompt_client["optional_params"] or {})

        resolved_model = prompt_manager._get_model_from_prompt(
            prompt_management_client=prompt_client, model=model
        )

        return resolved_model, compiled_messages, merged_params

View File

@@ -0,0 +1,202 @@
# What is this?
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
import json
import os
from litellm._uuid import uuid
from typing import Literal, Optional
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
def get_utc_datetime():
    """Return the current UTC time.

    On interpreters where ``datetime.UTC`` exists (Python 3.11+) this returns
    a timezone-aware datetime; otherwise it falls back to the naive
    ``datetime.utcnow()`` for compatibility with older runtimes.
    """
    import datetime as dt
    from datetime import datetime

    utc_zone = getattr(dt, "UTC", None)
    if utc_zone is None:
        return datetime.utcnow()  # type: ignore
    return datetime.now(utc_zone)  # type: ignore
class LagoLogger(CustomLogger):
    """
    Logs LLM usage (response cost + token counts) to Lago as billable
    usage events. See https://github.com/BerriAI/litellm/issues/3639
    """

    def __init__(self) -> None:
        super().__init__()
        self.validate_environment()
        self.async_http_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_http_handler = HTTPHandler()

    def validate_environment(self):
        """
        Expects
        LAGO_API_BASE,
        LAGO_API_KEY,
        LAGO_API_EVENT_CODE,

        Optional:
        LAGO_API_CHARGE_BY

        in the environment

        Raises:
            Exception: listing every required variable that is unset.
        """
        missing_keys = []
        for required_key in ("LAGO_API_KEY", "LAGO_API_BASE", "LAGO_API_EVENT_CODE"):
            if os.getenv(required_key, None) is None:
                missing_keys.append(required_key)
        if len(missing_keys) > 0:
            raise Exception("Missing keys={} in environment.".format(missing_keys))

    @staticmethod
    def _get_event_url() -> str:
        """Return the Lago events endpoint built from LAGO_API_BASE.

        Shared by the sync and async log paths so the URL logic stays in one place.
        """
        _url = os.getenv("LAGO_API_BASE")
        assert _url is not None and isinstance(
            _url, str
        ), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url)
        if _url.endswith("/"):
            return _url + "api/v1/events"
        return _url + "/api/v1/events"

    def _common_logic(self, kwargs: dict, response_obj) -> dict:
        """
        Build the Lago ``/api/v1/events`` payload for a completed request.

        The subscription to charge is picked by LAGO_API_CHARGE_BY (one of
        "end_user_id", "user_id", "team_id"; default "end_user_id").

        Raises:
            Exception: if LAGO_API_CHARGE_BY holds an invalid value, or the
                chosen customer id is not present on the request.
        """
        cost = kwargs.get("response_cost", None)
        model = kwargs.get("model")
        usage = {}
        if isinstance(
            response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
        ) and hasattr(response_obj, "usage"):
            usage = {
                "prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
                "completion_tokens": response_obj["usage"].get("completion_tokens", 0),
                "total_tokens": response_obj["usage"].get("total_tokens"),
            }

        litellm_params = kwargs.get("litellm_params", {}) or {}
        # Guard: metadata may be missing or None on some call paths.
        metadata = litellm_params.get("metadata") or {}
        proxy_server_request = litellm_params.get("proxy_server_request") or {}
        end_user_id = proxy_server_request.get("body", {}).get("user", None)
        user_id = metadata.get("user_api_key_user_id", None)
        team_id = metadata.get("user_api_key_team_id", None)

        charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
        external_customer_id: Optional[str] = None

        charge_by_env = os.getenv("LAGO_API_CHARGE_BY", None)
        if charge_by_env is not None:
            if charge_by_env in ["end_user_id", "user_id", "team_id"]:
                charge_by = charge_by_env  # type: ignore
            else:
                raise Exception("invalid LAGO_API_CHARGE_BY set")

        if charge_by == "end_user_id":
            external_customer_id = end_user_id
        elif charge_by == "team_id":
            external_customer_id = team_id
        elif charge_by == "user_id":
            external_customer_id = user_id

        if external_customer_id is None:
            raise Exception(
                "External Customer ID is not set. Charge_by={}. User_id={}. End_user_id={}. Team_id={}".format(
                    charge_by, user_id, end_user_id, team_id
                )
            )

        returned_val = {
            "event": {
                "transaction_id": str(uuid.uuid4()),
                "external_subscription_id": external_customer_id,
                "code": os.getenv("LAGO_API_EVENT_CODE"),
                "properties": {"model": model, "response_cost": cost, **usage},
            }
        }
        verbose_logger.debug(
            "\033[91mLogged Lago Object:\n{}\033[0m\n".format(returned_val)
        )
        return returned_val

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Synchronously POST the usage event to Lago; re-raises HTTP errors."""
        _url = self._get_event_url()
        api_key = os.getenv("LAGO_API_KEY")

        _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
        _headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        try:
            response = self.sync_http_handler.post(
                url=_url,
                data=json.dumps(_data),
                headers=_headers,
            )
            response.raise_for_status()
        except Exception as e:
            # httpx.HTTPStatusError carries the response; surface its body for debugging.
            error_response = getattr(e, "response", None)
            if error_response is not None and hasattr(error_response, "text"):
                verbose_logger.debug(f"\nError Message: {error_response.text}")
            raise e

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Async variant of log_success_event; uses the shared async httpx client."""
        try:
            verbose_logger.debug("ENTERS LAGO CALLBACK")
            _url = self._get_event_url()
            api_key = os.getenv("LAGO_API_KEY")
            _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
            _headers = {
                "Content-Type": "application/json",
                "Authorization": "Bearer {}".format(api_key),
            }
        except Exception as e:
            raise e

        response: Optional[httpx.Response] = None
        try:
            response = await self.async_http_handler.post(
                url=_url,
                data=json.dumps(_data),
                headers=_headers,
            )
            response.raise_for_status()
            verbose_logger.debug(f"Logged Lago Object: {response.text}")
        except Exception as e:
            if response is not None and hasattr(response, "text"):
                verbose_logger.debug(f"\nError Message: {response.text}")
            raise e

View File

@@ -0,0 +1,170 @@
"""
This file contains the LangFuseHandler class
Used to get the LangFuseLogger for a given request
Handles Key/Team Based Langfuse Logging
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams
from .langfuse import LangFuseLogger, LangfuseLoggingConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
else:
DynamicLoggingCache = Any
class LangFuseHandler:
    """Resolves which LangFuseLogger instance to use for a given request,
    supporting both the global logger and per-request (dynamic) credentials."""

    @staticmethod
    def get_langfuse_logger_for_request(
        standard_callback_dynamic_params: StandardCallbackDynamicParams,
        in_memory_dynamic_logger_cache: DynamicLoggingCache,
        globalLangfuseLogger: Optional[LangFuseLogger] = None,
    ) -> LangFuseLogger:
        """
        Return the LangFuseLogger for this request.

        1. If dynamic credentials are passed:
           - return a cached LangFuseLogger for those credentials, or
           - create a new one and cache it.
        2. Otherwise return the global LangFuseLogger.
        """
        # No dynamic credentials -> the global logger handles this request.
        if not LangFuseHandler._dynamic_langfuse_credentials_are_passed(
            standard_callback_dynamic_params
        ):
            return LangFuseHandler._return_global_langfuse_logger(
                globalLangfuseLogger=globalLangfuseLogger,
                in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
            )

        # Dynamic credentials were supplied - build the config for this request.
        dynamic_config = LangFuseHandler.get_dynamic_langfuse_logging_config(
            globalLangfuseLogger=globalLangfuseLogger,
            standard_callback_dynamic_params=standard_callback_dynamic_params,
        )
        credentials = dict(dynamic_config)

        # Reuse a logger already built for these exact credentials, if any.
        cached_logger = in_memory_dynamic_logger_cache.get_cache(
            credentials=credentials, service_name="langfuse"
        )
        if cached_logger is not None:
            return cached_logger

        return LangFuseHandler._create_langfuse_logger_from_credentials(
            credentials=credentials,
            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
        )

    @staticmethod
    def _return_global_langfuse_logger(
        globalLangfuseLogger: Optional[LangFuseLogger],
        in_memory_dynamic_logger_cache: DynamicLoggingCache,
    ) -> LangFuseLogger:
        """
        Return the global LangfuseLogger (the default, used when no dynamic
        credentials are passed).

        If none is set, fall back to the cached environment-variable logger in
        in_memory_dynamic_logger_cache, creating and caching it on a miss.
        """
        if globalLangfuseLogger is not None:
            return globalLangfuseLogger

        # The global logger reads env vars only, so its cache key is an empty dict.
        env_credentials: Dict[str, Any] = {}
        cached_global = in_memory_dynamic_logger_cache.get_cache(
            credentials=env_credentials,
            service_name="langfuse",
        )
        if cached_global is not None:
            return cached_global

        return LangFuseHandler._create_langfuse_logger_from_credentials(
            credentials=env_credentials,
            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
        )

    @staticmethod
    def _create_langfuse_logger_from_credentials(
        credentials: Dict,
        in_memory_dynamic_logger_cache: DynamicLoggingCache,
    ) -> LangFuseLogger:
        """Create a LangFuseLogger from credentials and cache it so the same
        credentials never build a second instance."""
        new_logger = LangFuseLogger(
            langfuse_public_key=credentials.get("langfuse_public_key"),
            langfuse_secret=credentials.get("langfuse_secret"),
            langfuse_host=credentials.get("langfuse_host"),
        )
        in_memory_dynamic_logger_cache.set_cache(
            credentials=credentials,
            service_name="langfuse",
            logging_obj=new_logger,
        )
        return new_logger

    @staticmethod
    def get_dynamic_langfuse_logging_config(
        standard_callback_dynamic_params: StandardCallbackDynamicParams,
        globalLangfuseLogger: Optional[LangFuseLogger] = None,
    ) -> LangfuseLoggingConfig:
        """
        Build the Langfuse logging config for this request from the dynamic
        callback params ("langfuse_secret" with "langfuse_secret_key" as an
        alias, plus public key and host).
        """
        params = standard_callback_dynamic_params
        secret = params.get("langfuse_secret") or params.get("langfuse_secret_key")
        return LangfuseLoggingConfig(
            langfuse_secret=secret,
            langfuse_public_key=params.get("langfuse_public_key"),
            langfuse_host=params.get("langfuse_host"),
        )

    @staticmethod
    def _dynamic_langfuse_credentials_are_passed(
        standard_callback_dynamic_params: StandardCallbackDynamicParams,
    ) -> bool:
        """
        Check whether any dynamic langfuse credential is present in
        standard_callback_dynamic_params.

        Returns:
            bool: True if any credential/host key is set, False otherwise.
        """
        credential_keys = (
            "langfuse_host",
            "langfuse_public_key",
            "langfuse_secret",
            "langfuse_secret_key",
        )
        return any(
            standard_callback_dynamic_params.get(key) is not None
            for key in credential_keys
        )

Some files were not shown because too many files have changed in this diff Show More