chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
## Supported Secret Managers to read credentials from
Example: read `OPENAI_API_KEY` and `AZURE_API_KEY` from a secret manager.

View File

@@ -0,0 +1,143 @@
"""
This is a file for the AWS Secret Manager Integration
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
Requires:
* `os.environ["AWS_REGION_NAME"]`
* `pip install boto3>=1.28.57`
"""
import ast
import base64
import os
import re
from typing import Any, Dict, Optional
import litellm
from litellm.proxy._types import KeyManagementSystem
def validate_environment():
    """Fail fast with a clear error when the AWS region is not configured."""
    region = os.environ.get("AWS_REGION_NAME")
    if region is None:
        raise ValueError("Missing required environment variable - AWS_REGION_NAME")
def load_aws_kms(use_aws_kms: Optional[bool]):
    """
    Initialize a boto3 KMS client and register it as litellm's secret manager.

    No-op when `use_aws_kms` is None/False. Requires `AWS_REGION_NAME` in the
    environment (checked via `validate_environment`) and boto3 to be installed.

    Raises:
        ValueError: if `AWS_REGION_NAME` is missing.
        ImportError: if boto3 is not installed.
    """
    if use_aws_kms is None or use_aws_kms is False:
        return
    import boto3

    validate_environment()
    # NOTE: this creates a KMS client (previous comment incorrectly said
    # "Secrets Manager client") — it is used for KMS Decrypt calls.
    kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
    litellm.secret_manager_client = kms_client
    litellm._key_management_system = KeyManagementSystem.AWS_KMS
class AWSKeyManagementService_V2:
    """
    V2 Clean Class for decrypting keys from AWS KeyManagementService.

    Requires `AWS_REGION_NAME` in the environment and a LiteLLM license
    (premium feature). Decrypts base64-encoded KMS ciphertexts stored in
    environment variables, optionally prefixed with `aws_kms/`.
    """

    def __init__(self) -> None:
        self.validate_environment()
        # boto3 KMS client used by decrypt_value()
        self.kms_client = self.load_aws_kms(use_aws_kms=True)

    def validate_environment(
        self,
    ) -> None:
        """Raise ValueError if the AWS region or LiteLLM license is missing."""
        if "AWS_REGION_NAME" not in os.environ:
            raise ValueError("Missing required environment variable - AWS_REGION_NAME")

        ## CHECK IF LICENSE IN ENV ## - premium feature
        is_litellm_license_in_env: bool = False
        if os.getenv("LITELLM_LICENSE", None) is not None:
            is_litellm_license_in_env = True
        elif os.getenv("LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE", None) is not None:
            is_litellm_license_in_env = True
        if is_litellm_license_in_env is False:
            # Fixed typo in user-facing error message: "envionment" -> "environment"
            raise ValueError(
                "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your environment."
            )

    def load_aws_kms(self, use_aws_kms: Optional[bool]):
        """Create and return a boto3 KMS client; returns None if disabled."""
        if use_aws_kms is None or use_aws_kms is False:
            return
        try:
            import boto3

            validate_environment()
            # Create a KMS client (used for the Decrypt API call below)
            kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
            return kms_client
        except Exception as e:
            raise e

    def decrypt_value(self, secret_name: str) -> Any:
        """
        Decrypt the KMS-encrypted value stored in env var `secret_name`.

        The env value may carry an `aws_kms/` prefix; the remainder is treated
        as base64-encoded ciphertext. Returns the stripped plaintext string,
        or a bool when the plaintext literal-evals to True/False.

        Raises:
            ValueError: if no KMS client was initialized.
            Exception: if the env var is unset.
        """
        if self.kms_client is None:
            raise ValueError("kms_client is None")
        encrypted_value = os.getenv(secret_name, None)
        if encrypted_value is None:
            raise Exception(
                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
            )
        if isinstance(encrypted_value, str) and encrypted_value.startswith("aws_kms/"):
            encrypted_value = encrypted_value.replace("aws_kms/", "")
        # Decode the base64 encoded ciphertext
        ciphertext_blob = base64.b64decode(encrypted_value)
        # Set up the parameters for the decrypt call
        params = {"CiphertextBlob": ciphertext_blob}
        # Perform the decryption
        response = self.kms_client.decrypt(**params)
        # Extract and decode the plaintext
        plaintext = response["Plaintext"]
        secret = plaintext.decode("utf-8")
        if isinstance(secret, str):
            secret = secret.strip()
        # Allow "True"/"False" secrets to round-trip as booleans
        try:
            secret_value_as_bool = ast.literal_eval(secret)
            if isinstance(secret_value_as_bool, bool):
                return secret_value_as_bool
        except Exception:
            pass
        return secret
"""
- look for all values in the env with `aws_kms/<hashed_key>`
- decrypt keys
- rewrite env var with decrypted key (). Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
"""
def decrypt_env_var() -> Dict[str, Any]:
# setup client class
aws_kms = AWSKeyManagementService_V2()
# iterate through env - for `aws_kms/`
new_values = {}
for k, v in os.environ.items():
if (
k is not None
and isinstance(k, str)
and k.lower().startswith("litellm_secret_aws_kms")
) or (v is not None and isinstance(v, str) and v.startswith("aws_kms/")):
decrypted_value = aws_kms.decrypt_value(secret_name=k)
# reset env var
k = re.sub("litellm_secret_aws_kms_", "", k, flags=re.IGNORECASE)
new_values[k] = decrypted_value
return new_values

View File

@@ -0,0 +1,539 @@
"""
This is a file for the AWS Secret Manager Integration
Handles Async Operations for:
- Read Secret
- Write Secret (CreateSecret)
- Update Secret (PutSecretValue) - for in-place rotation when alias is preserved
- Delete Secret
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
Requires:
* `os.environ["AWS_REGION_NAME"]`
* `pip install boto3>=1.28.57`
"""
import json
import os
from typing import Any, Dict, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.proxy._types import KeyManagementSystem
from litellm.types.llms.custom_http import httpxSpecialProvider
from .base_secret_manager import BaseSecretManager
class AWSSecretsManagerV2(BaseAWSLLM, BaseSecretManager):
    """
    AWS Secrets Manager integration.

    Calls the Secrets Manager HTTP API directly with SigV4-signed requests
    (via botocore) so that litellm's shared sync/async httpx clients can be
    used instead of a blocking boto3 client.
    """

    def __init__(
        self,
        aws_region_name: Optional[str] = None,
        aws_role_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_external_id: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_web_identity_token: Optional[str] = None,
        aws_sts_endpoint: Optional[str] = None,
        **kwargs,
    ):
        BaseSecretManager.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)
        # Store AWS authentication settings
        self.aws_region_name = aws_region_name
        self.aws_role_name = aws_role_name
        self.aws_session_name = aws_session_name
        self.aws_external_id = aws_external_id
        self.aws_profile_name = aws_profile_name
        self.aws_web_identity_token = aws_web_identity_token
        self.aws_sts_endpoint = aws_sts_endpoint

    @classmethod
    def validate_environment(cls):
        # AWS_REGION_NAME is only strictly required if not using a profile or role
        # When using IAM roles, the region can come from multiple sources
        if (
            "AWS_REGION_NAME" not in os.environ
            and "AWS_REGION" not in os.environ
            and "AWS_DEFAULT_REGION" not in os.environ
        ):
            verbose_logger.warning(
                "No AWS region found in environment. Ensure aws_region_name is set in key_management_settings "
                "or AWS_REGION_NAME/AWS_REGION/AWS_DEFAULT_REGION is set in environment."
            )

    @classmethod
    def load_aws_secret_manager(
        cls,
        use_aws_secret_manager: Optional[bool],
        key_management_settings: Optional[Any] = None,
    ):
        """
        Initialize AWSSecretsManagerV2 with settings from key_management_settings
        """
        if use_aws_secret_manager is None or use_aws_secret_manager is False:
            return
        try:
            cls.validate_environment()
            # Extract AWS settings from key_management_settings if provided
            aws_kwargs = {}
            if key_management_settings is not None:
                aws_kwargs = {
                    "aws_region_name": getattr(
                        key_management_settings, "aws_region_name", None
                    ),
                    "aws_role_name": getattr(
                        key_management_settings, "aws_role_name", None
                    ),
                    "aws_session_name": getattr(
                        key_management_settings, "aws_session_name", None
                    ),
                    "aws_external_id": getattr(
                        key_management_settings, "aws_external_id", None
                    ),
                    "aws_profile_name": getattr(
                        key_management_settings, "aws_profile_name", None
                    ),
                    "aws_web_identity_token": getattr(
                        key_management_settings, "aws_web_identity_token", None
                    ),
                    "aws_sts_endpoint": getattr(
                        key_management_settings, "aws_sts_endpoint", None
                    ),
                }
                # Remove None values
                aws_kwargs = {k: v for k, v in aws_kwargs.items() if v is not None}
            litellm.secret_manager_client = cls(**aws_kwargs)
            litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
        except Exception as e:
            raise e

    async def async_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        primary_secret_name: Optional[str] = None,
    ) -> Optional[str]:
        """
        Async function to read a secret from AWS Secrets Manager

        Returns:
            str: Secret value
        Raises:
            ValueError: If the secret is not found or an HTTP error occurs
        """
        if primary_secret_name:
            return await self.async_read_secret_from_primary_secret(
                secret_name=secret_name, primary_secret_name=primary_secret_name
            )

        endpoint_url, headers, body = self._prepare_request(
            action="GetSecretValue",
            secret_name=secret_name,
            optional_params=optional_params,
        )

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()["SecretString"]
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")
        except Exception as e:
            verbose_logger.exception(
                "Error reading secret='%s' from AWS Secrets Manager: %s",
                secret_name,
                str(e),
            )
        return None

    def sync_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        primary_secret_name: Optional[str] = None,
    ) -> Optional[str]:
        """
        Sync function to read a secret from AWS Secrets Manager

        Done for backwards compatibility with existing codebase, since get_secret is a sync function
        """
        # self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop
        if secret_name in [
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_REGION_NAME",
            "AWS_REGION",
            "AWS_BEDROCK_RUNTIME_ENDPOINT",
        ]:
            return os.getenv(secret_name)

        if primary_secret_name:
            return self.sync_read_secret_from_primary_secret(
                secret_name=secret_name, primary_secret_name=primary_secret_name
            )

        endpoint_url, headers, body = self._prepare_request(
            action="GetSecretValue",
            secret_name=secret_name,
            optional_params=optional_params,
        )

        sync_client = _get_httpx_client(
            params={"timeout": timeout},
        )

        try:
            response = sync_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            # BUGFIX: without raise_for_status(), non-2xx responses never hit
            # the HTTPStatusError handler below — they surfaced as a
            # KeyError("SecretString") instead. This mirrors async_read_secret.
            response.raise_for_status()
            return response.json()["SecretString"]
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")
        except httpx.HTTPStatusError as e:
            verbose_logger.exception(
                "Error reading secret='%s' from AWS Secrets Manager: %s, %s",
                secret_name,
                str(e.response.text),
                str(e.response.status_code),
            )
        except Exception as e:
            verbose_logger.exception(
                "Error reading secret='%s' from AWS Secrets Manager: %s",
                secret_name,
                str(e),
            )
        return None

    def _parse_primary_secret(self, primary_secret_json_str: Optional[str]) -> dict:
        """
        Parse the primary secret JSON string into a dictionary

        Args:
            primary_secret_json_str: JSON string containing key-value pairs

        Returns:
            Dictionary of key-value pairs from the primary secret
        """
        return json.loads(primary_secret_json_str or "{}")

    def sync_read_secret_from_primary_secret(
        self, secret_name: str, primary_secret_name: str
    ) -> Optional[str]:
        """
        Read a secret from the primary secret
        """
        primary_secret_json_str = self.sync_read_secret(secret_name=primary_secret_name)
        primary_secret_kv_pairs = self._parse_primary_secret(primary_secret_json_str)
        return primary_secret_kv_pairs.get(secret_name)

    async def async_read_secret_from_primary_secret(
        self, secret_name: str, primary_secret_name: str
    ) -> Optional[str]:
        """
        Read a secret from the primary secret
        """
        primary_secret_json_str = await self.async_read_secret(
            secret_name=primary_secret_name
        )
        primary_secret_kv_pairs = self._parse_primary_secret(primary_secret_json_str)
        return primary_secret_kv_pairs.get(secret_name)

    async def async_write_secret(
        self,
        secret_name: str,
        secret_value: str,
        description: Optional[str] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tags: Optional[Union[dict, list]] = None,
    ) -> dict:
        """
        Async function to write a secret to AWS Secrets Manager

        Args:
            secret_name: Name of the secret
            secret_value: Value to store (can be a JSON string)
            description: Optional description for the secret
            optional_params: Additional AWS parameters
            timeout: Request timeout
            tags: Optional dict or list of tags to apply, e.g.
                  {"Environment": "Prod", "Owner": "AI-Platform"} or
                  [{"Key": "Environment", "Value": "Prod"}]
        """
        from litellm._uuid import uuid

        data: Dict[str, Any] = {
            "Name": secret_name,
            "SecretString": secret_value,
            "ClientRequestToken": str(uuid.uuid4()),
        }

        if description:
            data["Description"] = description

        # Normalize tags to the AWS [{"Key": ..., "Value": ...}] format
        if tags:
            if isinstance(tags, dict):
                tags_list = [{"Key": k, "Value": str(v)} for k, v in tags.items()]
            elif isinstance(tags, list):
                tags_list = tags
            else:
                raise ValueError("Tags must be a dict or list of {Key, Value} pairs")
            data["Tags"] = tags_list  # type: ignore[assignment]

        endpoint_url, headers, body = self._prepare_request(
            action="CreateSecret",
            secret_name=secret_name,
            secret_value=secret_value,
            optional_params=optional_params,
            request_data=data,
        )

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as err:
            raise ValueError(f"HTTP error occurred: {err.response.text}")
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")

    async def async_put_secret_value(
        self,
        secret_name: str,
        secret_value: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to update an existing secret's value in AWS Secrets Manager.

        Uses PutSecretValue to update in place. Use this when rotating a secret
        that keeps the same name (current_secret_name == new_secret_name).

        Args:
            secret_name: Name of the existing secret to update
            secret_value: New value to store
            optional_params: Additional AWS parameters
            timeout: Request timeout

        Returns:
            dict: Response from AWS Secrets Manager containing update details
        """
        from litellm._uuid import uuid

        data: Dict[str, Any] = {
            "SecretId": secret_name,
            "SecretString": secret_value,
            "ClientRequestToken": str(uuid.uuid4()),
        }

        endpoint_url, headers, body = self._prepare_request(
            action="PutSecretValue",
            secret_name=secret_name,
            secret_value=secret_value,
            optional_params=optional_params,
            request_data=data,
        )

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as err:
            raise ValueError(f"HTTP error occurred: {err.response.text}")
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")

    async def async_rotate_secret(
        self,
        current_secret_name: str,
        new_secret_name: str,
        new_secret_value: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Rotate a secret. When current_secret_name == new_secret_name (in-place
        update), uses PutSecretValue instead of create+delete to avoid
        ResourceExistsException.
        """
        if current_secret_name == new_secret_name:
            # Same alias: update in place via PutSecretValue
            verbose_logger.info(
                "Secret rotated in-place (PutSecretValue): secret_name=%s",
                current_secret_name,
            )
            return await self.async_put_secret_value(
                secret_name=current_secret_name,
                secret_value=new_secret_value,
                optional_params=optional_params,
                timeout=timeout,
            )

        # Different names: create new, delete old (base class logic)
        return await super().async_rotate_secret(
            current_secret_name=current_secret_name,
            new_secret_name=new_secret_name,
            new_secret_value=new_secret_value,
            optional_params=optional_params,
            timeout=timeout,
        )

    async def async_delete_secret(
        self,
        secret_name: str,
        recovery_window_in_days: Optional[int] = 7,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to delete a secret from AWS Secrets Manager

        Args:
            secret_name: Name of the secret to delete
            recovery_window_in_days: Number of days before permanent deletion (default: 7)
            optional_params: Additional AWS parameters
            timeout: Request timeout

        Returns:
            dict: Response from AWS Secrets Manager containing deletion details
        """
        # Prepare the request data
        data = {
            "SecretId": secret_name,
            "RecoveryWindowInDays": recovery_window_in_days,
        }

        endpoint_url, headers, body = self._prepare_request(
            action="DeleteSecret",
            secret_name=secret_name,
            optional_params=optional_params,
            request_data=data,
        )

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as err:
            raise ValueError(f"HTTP error occurred: {err.response.text}")
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")

    def _prepare_request(
        self,
        action: str,  # "GetSecretValue" or "PutSecretValue"
        secret_name: str,
        secret_value: Optional[str] = None,
        optional_params: Optional[dict] = None,
        request_data: Optional[dict] = None,
    ) -> tuple[str, Any, bytes]:
        """Prepare the AWS Secrets Manager request"""
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
        optional_params = optional_params or {}

        # Build optional_params from instance settings if not provided
        # This allows the IAM role settings to be used for Secret Manager calls
        if not optional_params.get("aws_role_name") and self.aws_role_name:
            optional_params["aws_role_name"] = self.aws_role_name
        if not optional_params.get("aws_session_name") and self.aws_session_name:
            optional_params["aws_session_name"] = self.aws_session_name
        if not optional_params.get("aws_region_name") and self.aws_region_name:
            optional_params["aws_region_name"] = self.aws_region_name
        if not optional_params.get("aws_external_id") and self.aws_external_id:
            optional_params["aws_external_id"] = self.aws_external_id
        if not optional_params.get("aws_profile_name") and self.aws_profile_name:
            optional_params["aws_profile_name"] = self.aws_profile_name
        if (
            not optional_params.get("aws_web_identity_token")
            and self.aws_web_identity_token
        ):
            optional_params["aws_web_identity_token"] = self.aws_web_identity_token
        if not optional_params.get("aws_sts_endpoint") and self.aws_sts_endpoint:
            optional_params["aws_sts_endpoint"] = self.aws_sts_endpoint

        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params
        )

        # Get endpoint — reuse the bedrock endpoint resolver, then swap the
        # service name into the URL.
        _, endpoint_url = self.get_runtime_endpoint(
            api_base=None,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")

        # Use provided request_data if available, otherwise build default data
        if request_data:
            data = request_data
        else:
            data = {"SecretId": secret_name}
            if secret_value and action == "PutSecretValue":
                data["SecretString"] = secret_value

        body = json.dumps(data).encode("utf-8")
        headers = {
            "Content-Type": "application/x-amz-json-1.1",
            "X-Amz-Target": f"secretsmanager.{action}",
        }

        # Sign request
        request = AWSRequest(
            method="POST", url=endpoint_url, data=body, headers=headers
        )
        SigV4Auth(
            boto3_credentials_info.credentials,
            "secretsmanager",
            boto3_credentials_info.aws_region_name,
        ).add_auth(request)
        prepped = request.prepare()

        return endpoint_url, prepped.headers, body

View File

@@ -0,0 +1,180 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union
import httpx
from litellm import verbose_logger
class BaseSecretManager(ABC):
    """
    Abstract base class for secret management implementations.
    """

    @abstractmethod
    async def async_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Asynchronously read a secret from the secret manager.

        Args:
            secret_name (str): Name/path of the secret to read
            optional_params (Optional[dict]): Additional parameters specific to the secret manager
            timeout (Optional[Union[float, httpx.Timeout]]): Request timeout

        Returns:
            Optional[str]: The secret value if found, None otherwise
        """
        pass

    @abstractmethod
    def sync_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Synchronously read a secret from the secret manager.

        Args:
            secret_name (str): Name/path of the secret to read
            optional_params (Optional[dict]): Additional parameters specific to the secret manager
            timeout (Optional[Union[float, httpx.Timeout]]): Request timeout

        Returns:
            Optional[str]: The secret value if found, None otherwise
        """
        pass

    @abstractmethod
    async def async_write_secret(
        self,
        secret_name: str,
        secret_value: str,
        description: Optional[str] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tags: Optional[Union[dict, list]] = None,
    ) -> Dict[str, Any]:
        """
        Asynchronously write a secret to the secret manager.

        Args:
            secret_name (str): Name/path of the secret to write
            secret_value (str): Value to store
            description (Optional[str]): Description of the secret. Some secret managers allow storing a description with the secret.
            optional_params (Optional[dict]): Additional parameters specific to the secret manager
            timeout (Optional[Union[float, httpx.Timeout]]): Request timeout
            tags: Optional dict or list of tags to apply, e.g.
                  {"Environment": "Prod", "Owner": "AI-Platform"} or
                  [{"Key": "Environment", "Value": "Prod"}]

        Returns:
            Dict[str, Any]: Response from the secret manager containing write operation details
        """
        pass

    @abstractmethod
    async def async_delete_secret(
        self,
        secret_name: str,
        recovery_window_in_days: Optional[int] = 7,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to delete a secret from the secret manager

        Args:
            secret_name: Name of the secret to delete
            recovery_window_in_days: Number of days before permanent deletion (default: 7)
            optional_params: Additional parameters specific to the secret manager
            timeout: Request timeout

        Returns:
            dict: Response from the secret manager containing deletion details
        """
        pass

    async def async_rotate_secret(
        self,
        current_secret_name: str,
        new_secret_name: str,
        new_secret_value: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to rotate a secret by creating a new one and deleting the old one.
        This allows for both value and name changes during rotation.

        Args:
            current_secret_name: Current name of the secret
            new_secret_name: New name for the secret
            new_secret_value: New value for the secret
            optional_params: Additional parameters specific to the secret manager
            timeout: Request timeout

        Returns:
            dict: Response containing the new secret details
        Raises:
            ValueError: If the secret doesn't exist or if there's an HTTP error
        """
        try:
            # First verify the old secret exists
            old_secret = await self.async_read_secret(
                secret_name=current_secret_name,
                optional_params=optional_params,
                timeout=timeout,
            )
            if old_secret is None:
                raise ValueError(f"Current secret {current_secret_name} not found")
            # Create new secret with new name and value
            create_response = await self.async_write_secret(
                secret_name=new_secret_name,
                secret_value=new_secret_value,
                description=f"Rotated from {current_secret_name}",
                optional_params=optional_params,
                timeout=timeout,
            )
            # Verify new secret was created successfully
            new_secret = await self.async_read_secret(
                secret_name=new_secret_name,
                optional_params=optional_params,
                timeout=timeout,
            )
            if new_secret is None:
                raise ValueError(f"Failed to verify new secret {new_secret_name}")
            # If everything is successful, delete the old secret
            await self.async_delete_secret(
                secret_name=current_secret_name,
                recovery_window_in_days=7,  # Keep for recovery if needed
                optional_params=optional_params,
                timeout=timeout,
            )
            return create_response
        except httpx.HTTPStatusError as err:
            # BUGFIX: this base class is provider-agnostic; the old message
            # incorrectly claimed the error came from "AWS Secrets Manager".
            verbose_logger.exception(
                "Error rotating secret: %s",
                str(err.response.text),
            )
            raise ValueError(f"HTTP error occurred: {err.response.text}")
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")
        except Exception as e:
            verbose_logger.exception("Error rotating secret: %s", str(e))
            raise

View File

@@ -0,0 +1,93 @@
"""
Loader for custom secret managers.
Handles dynamic loading of user-defined secret manager classes from Python files.
"""
import importlib.util
import os
from typing import Optional
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_secret_manager import CustomSecretManager
from litellm.types.secret_managers.main import KeyManagementSystem
def load_custom_secret_manager(config_file_path: Optional[str] = None) -> None:
    """
    Load and initialize a custom secret manager from a python file.

    Similar to how custom guardrails are loaded - loads the class from
    the custom_secret_manager field in key_management_settings.

    Args:
        config_file_path: Path to the config.yaml file

    Raises:
        ValueError: If required configuration is missing
        ImportError: If the custom secret manager module cannot be loaded
        TypeError: If the loaded class is not a CustomSecretManager subclass
    """
    if not config_file_path:
        raise ValueError(
            "CustomSecretManagerException - config_file_path is required to load custom secret manager"
        )

    # Get the custom_secret_manager class path from settings
    if litellm._key_management_settings is None:
        raise ValueError(
            "CustomSecretManagerException - key_management_settings is required with custom_secret_manager field"
        )
    custom_secret_manager_path = getattr(
        litellm._key_management_settings, "custom_secret_manager", None
    )
    if not custom_secret_manager_path:
        raise ValueError(
            "CustomSecretManagerException - custom_secret_manager field is required in key_management_settings"
        )

    # Split into file_name and class_name (e.g., "my_secret_manager.InMemorySecretManager").
    # BUGFIX: rsplit on the LAST dot — plain split(".") raised
    # "too many values to unpack" for any module name containing a dot.
    _file_name, _class_name = custom_secret_manager_path.rsplit(".", 1)
    verbose_proxy_logger.debug(
        "Initializing custom secret manager: %s, file_name: %s, class_name: %s",
        custom_secret_manager_path,
        _file_name,
        _class_name,
    )

    # Load the module from the same directory as config.yaml
    directory = os.path.dirname(config_file_path)
    module_file_path = os.path.join(directory, _file_name) + ".py"
    spec = importlib.util.spec_from_file_location(_class_name, module_file_path)  # type: ignore
    if not spec:
        raise ImportError(
            f"Could not find a module specification for {module_file_path}"
        )
    module = importlib.util.module_from_spec(spec)  # type: ignore
    spec.loader.exec_module(module)  # type: ignore
    _secret_manager_class = getattr(module, _class_name)

    # Validate that it's a CustomSecretManager subclass
    if not issubclass(_secret_manager_class, CustomSecretManager):
        raise TypeError(
            f"CustomSecretManagerException - {_class_name} must be a subclass of CustomSecretManager"
        )

    # Instantiate the custom secret manager
    _secret_manager_instance = _secret_manager_class()

    # Set it as the secret manager client
    litellm.secret_manager_client = _secret_manager_instance
    # Set the key management system to CUSTOM so get_secret knows to use it
    litellm._key_management_system = KeyManagementSystem.CUSTOM

    verbose_proxy_logger.info(
        "Successfully initialized custom secret manager: %s",
        custom_secret_manager_path,
    )

View File

@@ -0,0 +1,350 @@
import base64
import os
from typing import Any, Dict, Optional, Union
from urllib.parse import quote
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.caching import InMemoryCache
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import KeyManagementSystem
from .base_secret_manager import BaseSecretManager
from .main import str_to_bool
class CyberArkSecretManager(BaseSecretManager):
def __init__(self):
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
# CyberArk Conjur-specific config
self.conjur_addr = os.getenv("CYBERARK_API_BASE", "http://127.0.0.1:8080")
self.conjur_account = os.getenv("CYBERARK_ACCOUNT", "default")
self.conjur_username = os.getenv("CYBERARK_USERNAME", "admin")
self.conjur_api_key = os.getenv("CYBERARK_API_KEY", "")
# Optional config for certificate-based auth
self.tls_cert_path = os.getenv("CYBERARK_CLIENT_CERT", "")
self.tls_key_path = os.getenv("CYBERARK_CLIENT_KEY", "")
# SSL verification - can be disabled for self-signed certificates
# Set CYBERARK_SSL_VERIFY=false to disable SSL verification
ssl_verify_env = str_to_bool(os.getenv("CYBERARK_SSL_VERIFY"))
self.ssl_verify: bool = ssl_verify_env if ssl_verify_env is not None else True
# Validate environment
if not self.conjur_api_key and not (self.tls_cert_path and self.tls_key_path):
raise ValueError(
"Missing CyberArk credentials. Please set CYBERARK_API_KEY or both CYBERARK_CLIENT_CERT and CYBERARK_CLIENT_KEY in your environment."
)
litellm.secret_manager_client = self
litellm._key_management_system = KeyManagementSystem.CYBERARK
# Tokens expire after ~8 minutes, so we cache for 5 minutes to be safe
_refresh_interval = int(os.environ.get("CYBERARK_REFRESH_INTERVAL", "300"))
self.cache = InMemoryCache(default_ttl=_refresh_interval)
if premium_user is not True:
raise ValueError(
f"CyberArk secret manager is only available for premium users. {CommonProxyErrors.not_premium_user.value}"
)
if not self.ssl_verify:
verbose_logger.warning(
"CyberArk SSL verification is disabled. This is insecure and should only be used for testing with self-signed certificates."
)
def _authenticate(self) -> str:
"""
Authenticate with CyberArk Conjur and get a session token.
The token is a JSON object that must be base64-encoded for use in subsequent requests.
Returns:
str: Base64-encoded session token
"""
# Check if we have a cached token
cached_token = self.cache.get_cache("cyberark_auth_token")
if cached_token is not None:
return cached_token
verbose_logger.debug("Authenticating with CyberArk Conjur...")
auth_url = f"{self.conjur_addr}/authn/{self.conjur_account}/{self.conjur_username}/authenticate"
try:
if self.tls_cert_path and self.tls_key_path:
# Certificate-based authentication - need custom client for cert
http_client = httpx.Client(
cert=(self.tls_cert_path, self.tls_key_path),
verify=self.ssl_verify,
)
resp = http_client.post(auth_url, content=self.conjur_api_key)
else:
# API key authentication
http_handler = _get_httpx_client(params={"ssl_verify": self.ssl_verify})
resp = http_handler.client.post(auth_url, content=self.conjur_api_key)
resp.raise_for_status()
# The response is a JSON token that needs to be base64-encoded
token_json = resp.text
token_b64 = base64.b64encode(token_json.encode()).decode()
verbose_logger.debug("Successfully authenticated with CyberArk Conjur.")
# Cache the token for the refresh interval
self.cache.set_cache(key="cyberark_auth_token", value=token_b64)
return token_b64
except Exception as e:
raise RuntimeError(f"Could not authenticate to CyberArk Conjur: {e}")
def _get_request_headers(self) -> dict:
"""
Get headers for CyberArk API requests including authentication.
Returns:
dict: Headers with authentication token
"""
token = self._authenticate()
return {"Authorization": f'Token token="{token}"'}
def _ensure_variable_exists(self, secret_name: str) -> None:
    """
    Make sure a variable is declared in CyberArk Conjur policy.

    Conjur variables must exist in policy before a value can be stored.
    We always attempt the policy load and treat "already exists"
    conflicts (409/422) as success.

    Args:
        secret_name: Name of the variable to ensure exists
    """
    policy_url = f"{self.conjur_addr}/policies/{self.conjur_account}/policy/root"
    policy_yaml = f"- !variable {secret_name}\n"
    request_headers = dict(self._get_request_headers())
    request_headers["Content-Type"] = "application/x-yaml"
    try:
        http_handler = _get_httpx_client(params={"ssl_verify": self.ssl_verify})
        resp = http_handler.client.post(
            policy_url,
            headers=request_headers,
            content=policy_yaml,
        )
        resp.raise_for_status()
        verbose_logger.debug(f"Created policy entry for variable: {secret_name}")
    except httpx.HTTPStatusError as e:
        status = e.response.status_code
        if status in (409, 422):
            # Conflict responses mean the variable was already declared.
            verbose_logger.debug(
                f"Variable {secret_name} already exists or policy conflict (expected)"
            )
        else:
            verbose_logger.warning(
                f"Could not ensure variable exists: {status} - {e.response.text}"
            )
    except Exception as e:
        verbose_logger.warning(f"Error ensuring variable exists: {e}")
def get_url(self, secret_name: str) -> str:
    """
    Construct the Conjur REST endpoint for a given secret.

    Args:
        secret_name: Name of the secret; URL-encoded so names containing
            slashes and other special characters are addressed correctly.

    Returns:
        str: Full URL for the secret.
    """
    safe_name = quote(secret_name, safe="")
    return "{}/secrets/{}/variable/{}".format(
        self.conjur_addr, self.conjur_account, safe_name
    )
async def async_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Reads a secret from CyberArk Conjur using an async HTTPX client.

    Args:
        secret_name: Name/path of the secret to read
        optional_params: Additional parameters (not used for Conjur)
        timeout: Request timeout

    Returns:
        Optional[str]: The secret value if found, None otherwise
    """
    # Read the cache once instead of twice so the hit/miss decision and
    # the returned value cannot diverge between two lookups.
    cached_value = self.cache.get_cache(secret_name)
    if cached_value is not None:
        return cached_value
    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"ssl_verify": self.ssl_verify},
    )
    try:
        url = self.get_url(secret_name)
        response = await async_client.get(url, headers=self._get_request_headers())
        response.raise_for_status()
        # CyberArk Conjur returns the raw secret value as text
        secret_value = response.text
        self.cache.set_cache(secret_name, secret_value)
        return secret_value
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 404:
            verbose_logger.debug(
                f"Secret {secret_name} not found in CyberArk Conjur"
            )
        else:
            verbose_logger.exception(
                f"Error reading secret from CyberArk Conjur: {e}"
            )
        return None
    except Exception as e:
        verbose_logger.exception(f"Error reading secret from CyberArk Conjur: {e}")
        return None
def sync_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Reads a secret from CyberArk Conjur using a sync HTTPX client.

    Args:
        secret_name: Name/path of the secret to read
        optional_params: Additional parameters (not used for Conjur)
        timeout: Request timeout

    Returns:
        Optional[str]: The secret value if found, None otherwise
    """
    # Read the cache once instead of twice so the hit/miss decision and
    # the returned value cannot diverge between two lookups.
    cached_value = self.cache.get_cache(secret_name)
    if cached_value is not None:
        return cached_value
    sync_client = _get_httpx_client(params={"ssl_verify": self.ssl_verify})
    try:
        url = self.get_url(secret_name)
        response = sync_client.client.get(url, headers=self._get_request_headers())
        response.raise_for_status()
        # CyberArk Conjur returns the raw secret value as text
        secret_value = response.text
        self.cache.set_cache(secret_name, secret_value)
        return secret_value
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 404:
            verbose_logger.debug(
                f"Secret {secret_name} not found in CyberArk Conjur"
            )
        else:
            verbose_logger.exception(
                f"Error reading secret from CyberArk Conjur: {e}"
            )
        return None
    except Exception as e:
        verbose_logger.exception(f"Error reading secret from CyberArk Conjur: {e}")
        return None
async def async_write_secret(
    self,
    secret_name: str,
    secret_value: str,
    description: Optional[str] = None,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    tags: Optional[Union[dict, list]] = None,
) -> Dict[str, Any]:
    """
    Writes a secret to CyberArk Conjur using an async HTTPX client.

    First ensures the variable is declared in Conjur policy, then POSTs the
    raw secret value to the variable endpoint and updates the local cache.

    Args:
        secret_name: Name/path of the secret to write
        secret_value: Value to store
        description: Optional description (not used by Conjur)
        optional_params: Additional parameters
        timeout: Request timeout
        tags: Optional tags (not used by Conjur)

    Returns:
        dict: Response containing status and details of the operation
    """
    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"ssl_verify": self.ssl_verify},
    )
    try:
        # Ensure the variable exists in the policy first
        # NOTE(review): _ensure_variable_exists performs a *synchronous*
        # HTTP call, which blocks the event loop for its duration —
        # consider offloading (e.g. asyncio.to_thread) — TODO confirm.
        self._ensure_variable_exists(secret_name)
        # Now set the secret value (Conjur takes the raw value as the body)
        url = self.get_url(secret_name)
        response = await async_client.post(
            url=url, headers=self._get_request_headers(), content=secret_value
        )
        response.raise_for_status()
        # Update cache
        self.cache.set_cache(secret_name, secret_value)
        return {
            "status": "success",
            "message": f"Secret {secret_name} written successfully",
        }
    except Exception as e:
        # Best-effort contract: errors are reported in the return value,
        # not raised to the caller.
        verbose_logger.exception(f"Error writing secret to CyberArk Conjur: {e}")
        return {"status": "error", "message": str(e)}
async def async_delete_secret(
    self,
    secret_name: str,
    recovery_window_in_days: Optional[int] = 7,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
    """
    Report that Conjur cannot delete secrets directly.

    CyberArk Conjur exposes no deletion API for variables; removal is done
    through policy updates. The local cache entry is still evicted so stale
    values are not served.

    Args:
        secret_name: Name of the secret
        recovery_window_in_days: Not used
        optional_params: Additional parameters
        timeout: Request timeout

    Returns:
        dict: Response indicating operation not supported
    """
    verbose_logger.warning(
        "CyberArk Conjur does not support direct secret deletion. "
        "Secrets must be removed through policy updates."
    )
    # Evict any cached copy of this secret.
    self.cache.delete_cache(secret_name)
    result = {
        "status": "not_supported",
        "message": "CyberArk Conjur does not support direct secret deletion. Use policy updates to remove variables.",
    }
    return result

View File

@@ -0,0 +1,121 @@
import os
from typing import Any, Callable, Optional, Union
from litellm._logging import verbose_logger
from litellm.types.secret_managers.get_azure_ad_token_provider import (
AzureCredentialType,
)
def infer_credential_type_from_environment() -> AzureCredentialType:
    """
    Infer the Azure credential type from environment variables.

    The most specific configurations are checked first so that overlapping
    sets of variables resolve to the intended credential:

    1. client id + client secret + tenant id        -> ClientSecretCredential
    2. client id + tenant id + cert path + password -> CertificateCredential
    3. client id alone                              -> ManagedIdentityCredential
    4. certificate password or path alone           -> CertificateCredential
    5. nothing set                                  -> DefaultAzureCredential
    """
    if (
        os.environ.get("AZURE_CLIENT_ID")
        and os.environ.get("AZURE_CLIENT_SECRET")
        and os.environ.get("AZURE_TENANT_ID")
    ):
        return AzureCredentialType.ClientSecretCredential
    elif (
        os.environ.get("AZURE_CLIENT_ID")
        and os.environ.get("AZURE_TENANT_ID")
        and os.environ.get("AZURE_CERTIFICATE_PATH")
        and os.environ.get("AZURE_CERTIFICATE_PASSWORD")
    ):
        # Checked before the bare AZURE_CLIENT_ID branch; in the previous
        # ordering this branch was unreachable dead code because any env
        # with AZURE_CLIENT_ID matched ManagedIdentityCredential first.
        return AzureCredentialType.CertificateCredential
    elif os.environ.get("AZURE_CLIENT_ID"):
        return AzureCredentialType.ManagedIdentityCredential
    elif os.environ.get("AZURE_CERTIFICATE_PASSWORD"):
        return AzureCredentialType.CertificateCredential
    elif os.environ.get("AZURE_CERTIFICATE_PATH"):
        return AzureCredentialType.CertificateCredential
    else:
        return AzureCredentialType.DefaultAzureCredential
def get_azure_ad_token_provider(
    azure_scope: Optional[str] = None,
    azure_credential: Optional[AzureCredentialType] = None,
) -> Callable[[], str]:
    """
    Get Azure AD token provider based on Service Principal with Secret workflow.

    Based on: https://github.com/openai/openai-python/blob/main/examples/azure_ad.py
    See Also:
        https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret;
        https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python.

    Args:
        azure_scope (str, optional): The Azure scope to request token for.
            Defaults to environment variable AZURE_SCOPE or
            "https://cognitiveservices.azure.com/.default".
        azure_credential (AzureCredentialType, optional): Explicit credential
            type. Falls back to the AZURE_CREDENTIAL env var, then to
            inference from the environment.

    Returns:
        Callable that returns a temporary authentication token.
    """
    # Lazy import: azure-identity is only needed when this path is used.
    import azure.identity as identity
    from azure.identity import (
        CertificateCredential,
        ClientSecretCredential,
        DefaultAzureCredential,
        ManagedIdentityCredential,
        get_bearer_token_provider,
    )

    if azure_scope is None:
        azure_scope = (
            os.environ.get("AZURE_SCOPE")
            or "https://cognitiveservices.azure.com/.default"
        )
    # Resolve the credential type explicitly instead of relying on the
    # easily-misread `x if cond else None or y or z` precedence chain.
    # Note: this can be a plain string (from env) or an AzureCredentialType
    # (from inference), so it is not annotated as `str`.
    cred: Union[str, AzureCredentialType]
    if azure_credential is not None:
        cred = azure_credential.value
    else:
        cred = (
            os.environ.get("AZURE_CREDENTIAL")
            or infer_credential_type_from_environment()
        )
    verbose_logger.info(
        f"For Azure AD Token Provider, choosing credential type: {cred}"
    )
    credential: Optional[
        Union[
            ClientSecretCredential,
            ManagedIdentityCredential,
            CertificateCredential,
            DefaultAzureCredential,
            Any,
        ]
    ] = None
    if cred == AzureCredentialType.ClientSecretCredential:
        credential = ClientSecretCredential(
            client_id=os.environ["AZURE_CLIENT_ID"],
            client_secret=os.environ["AZURE_CLIENT_SECRET"],
            tenant_id=os.environ["AZURE_TENANT_ID"],
        )
    elif cred == AzureCredentialType.ManagedIdentityCredential:
        credential = ManagedIdentityCredential(client_id=os.environ["AZURE_CLIENT_ID"])
    elif cred == AzureCredentialType.CertificateCredential:
        # Password-protected certificates need the password passed through.
        if os.getenv("AZURE_CERTIFICATE_PASSWORD"):
            credential = CertificateCredential(
                client_id=os.environ["AZURE_CLIENT_ID"],
                tenant_id=os.environ["AZURE_TENANT_ID"],
                certificate_path=os.environ["AZURE_CERTIFICATE_PATH"],
                password=os.environ["AZURE_CERTIFICATE_PASSWORD"],
            )
        else:
            credential = CertificateCredential(
                client_id=os.environ["AZURE_CLIENT_ID"],
                tenant_id=os.environ["AZURE_TENANT_ID"],
                certificate_path=os.environ["AZURE_CERTIFICATE_PATH"],
            )
    elif cred == AzureCredentialType.DefaultAzureCredential:
        # DefaultAzureCredential doesn't require explicit environment variables
        # It automatically discovers credentials from the environment (managed identity, CLI, etc.)
        credential = DefaultAzureCredential()
    else:
        # Fall back to any azure.identity credential class named by `cred`.
        cred_cls = getattr(identity, cred)
        credential = cred_cls()
    if credential is None:
        raise ValueError("No credential provided")
    return get_bearer_token_provider(credential, azure_scope)

View File

@@ -0,0 +1,43 @@
"""
This is a file for the Google KMS integration
Relevant issue: https://github.com/BerriAI/litellm/issues/1235
Requires:
* `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]`
* `pip install google-cloud-kms`
"""
import os
from typing import Optional
import litellm
from litellm.proxy._types import KeyManagementSystem
def validate_environment():
    """Raise ValueError if a required Google KMS environment variable is unset."""
    required_vars = ("GOOGLE_APPLICATION_CREDENTIALS", "GOOGLE_KMS_RESOURCE_NAME")
    for var_name in required_vars:
        if var_name not in os.environ:
            raise ValueError(
                f"Missing required environment variable - {var_name}"
            )
def load_google_kms(use_google_kms: Optional[bool]):
    """
    Configure litellm to use Google Cloud KMS as its key-management system.

    No-op unless ``use_google_kms`` is truthy. On activation, validates the
    required environment variables, creates a KMS client, and registers it
    on the ``litellm`` module.

    Args:
        use_google_kms: Flag from config; None/False disables the integration.

    Raises:
        ValueError: If required environment variables are missing.
        ImportError: If ``google-cloud-kms`` is not installed.
    """
    if use_google_kms is None or use_google_kms is False:
        return
    # Imported lazily so the dependency is only required when enabled.
    # (The previous `except Exception as e: raise e` wrapper was a no-op
    # and has been removed; exceptions propagate unchanged.)
    from google.cloud import kms_v1  # type: ignore

    validate_environment()

    # Create the KMS client and register it globally on the litellm module.
    client = kms_v1.KeyManagementServiceClient()
    litellm.secret_manager_client = client
    litellm._key_management_system = KeyManagementSystem.GOOGLE_KMS
    litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME")

View File

@@ -0,0 +1,117 @@
import base64
import os
from typing import Optional
import litellm
from litellm._logging import verbose_logger
from litellm.caching.caching import InMemoryCache
from litellm.constants import SECRET_MANAGER_REFRESH_INTERVAL
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from litellm.proxy._types import CommonProxyErrors, KeyManagementSystem
class GoogleSecretManager(GCSBucketBase):
    """Read-only Google Secret Manager integration (enterprise feature)."""

    def __init__(
        self,
        refresh_interval: Optional[int] = SECRET_MANAGER_REFRESH_INTERVAL,
        always_read_secret_manager: Optional[bool] = False,
    ) -> None:
        """
        Args:
            refresh_interval (int, optional): The refresh interval in seconds. Defaults to 86400. (24 hours)
            always_read_secret_manager (bool, optional): Whether to always read from the secret manager. Defaults to False. Since we do want to cache values
        """
        from litellm.proxy.proxy_server import premium_user

        if premium_user is not True:
            raise ValueError(
                f"Google Secret Manager requires an Enterprise License {CommonProxyErrors.not_premium_user.value}"
            )
        super().__init__()
        self.PROJECT_ID = os.environ.get("GOOGLE_SECRET_MANAGER_PROJECT_ID", None)
        if self.PROJECT_ID is None:
            raise ValueError(
                "Google Secret Manager requires a project ID, please set 'GOOGLE_SECRET_MANAGER_PROJECT_ID' in your .env"
            )
        self.sync_httpx_client = _get_httpx_client()
        # Register this instance as litellm's global secret manager.
        litellm.secret_manager_client = self
        litellm._key_management_system = KeyManagementSystem.GOOGLE_SECRET_MANAGER
        # Env var overrides the constructor value when present.
        _refresh_interval = os.environ.get(
            "GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL", refresh_interval
        )
        _refresh_interval = (
            int(_refresh_interval) if _refresh_interval else refresh_interval
        )
        self.cache = InMemoryCache(
            default_ttl=_refresh_interval
        )  # cache secrets in memory for the refresh interval (default 24h)
        _always_read_secret_manager = os.environ.get(
            "GOOGLE_SECRET_MANAGER_ALWAYS_READ_SECRET_MANAGER",
        )
        if (
            _always_read_secret_manager
            and _always_read_secret_manager.lower() == "true"
        ):
            self.always_read_secret_manager = True
        else:
            # by default this should be False, we want to use in memory caching for this. It's a bad idea to fetch from secret manager for all requests
            self.always_read_secret_manager = always_read_secret_manager or False

    def get_secret_from_google_secret_manager(self, secret_name: str) -> Optional[str]:
        """
        Retrieve a secret from Google Secret Manager or cache.

        Args:
            secret_name (str): The name of the secret.

        Returns:
            str: The secret value if successful, None if the secret was
                previously cached as missing.

        Raises:
            ValueError: If the secret cannot be retrieved from the API.
        """
        if self.always_read_secret_manager is not True:
            cached_secret = self.cache.get_cache(secret_name)
            if cached_secret is not None:
                return cached_secret
            if secret_name in self.cache.cache_dict:
                # Key present but value is None: this secret was previously
                # looked up and not found (negative cache). Return None
                # explicitly instead of the obscure `return cached_secret`.
                return None
        _secret_name = (
            f"projects/{self.PROJECT_ID}/secrets/{secret_name}/versions/latest"
        )
        headers = self.sync_construct_request_headers()
        url = f"https://secretmanager.googleapis.com/v1/{_secret_name}:access"
        # Send the GET request to retrieve the secret
        response = self.sync_httpx_client.get(url=url, headers=headers)
        if response.status_code != 200:
            verbose_logger.error(
                "Google Secret Manager retrieval error: %s", str(response.text)
            )
            self.cache.set_cache(
                secret_name, None
            )  # Cache that the secret was not found
            raise ValueError(
                f"secret {secret_name} not found in Google Secret Manager. Error: {response.text}"
            )
        verbose_logger.debug(
            "Google Secret Manager retrieval response status code: %s",
            response.status_code,
        )
        # Parse the JSON response; the payload data is base64-encoded.
        secret_data = response.json()
        _base64_encoded_value = secret_data.get("payload", {}).get("data")
        # decode the base64 encoded value
        if _base64_encoded_value is not None:
            _decoded_value = base64.b64decode(_base64_encoded_value).decode("utf-8")
            self.cache.set_cache(
                secret_name, _decoded_value
            )  # Cache the retrieved secret
            return _decoded_value
        self.cache.set_cache(secret_name, None)  # Cache that the secret was not found
        raise ValueError(f"secret {secret_name} not found in Google Secret Manager")

View File

@@ -0,0 +1,677 @@
import os
from typing import Any, Dict, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.caching import InMemoryCache
from litellm.constants import SECRET_MANAGER_REFRESH_INTERVAL
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import KeyManagementSystem
from .base_secret_manager import BaseSecretManager
class HashicorpSecretManager(BaseSecretManager):
def __init__(self):
    """
    Configure the Vault connection from environment variables and register
    this instance as litellm's global secret manager.

    Supported auth methods (see _get_request_headers for priority):
    direct token (HCP_VAULT_TOKEN), AppRole, and TLS client certificates.

    Raises:
        ValueError: If no auth method is configured, or the user does not
            hold a premium license.
    """
    from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

    # Vault-specific config
    self.vault_addr = os.getenv("HCP_VAULT_ADDR", "http://127.0.0.1:8200")
    self.vault_token = os.getenv("HCP_VAULT_TOKEN", "")
    # Vault namespace (for X-Vault-Namespace header)
    self.vault_namespace = os.getenv("HCP_VAULT_NAMESPACE", None)
    # KV engine mount name (default: "secret")
    # If your KV engine is mounted somewhere other than "secret", set HCP_VAULT_MOUNT_NAME
    self.vault_mount_name = os.getenv("HCP_VAULT_MOUNT_NAME", "secret")
    # Optional path prefix for secrets (e.g., "myapp" -> secret/data/myapp/{secret_name})
    self.vault_path_prefix = os.getenv("HCP_VAULT_PATH_PREFIX", None)
    # Optional config for TLS cert auth
    self.tls_cert_path = os.getenv("HCP_VAULT_CLIENT_CERT", "")
    self.tls_key_path = os.getenv("HCP_VAULT_CLIENT_KEY", "")
    self.vault_cert_role = os.getenv("HCP_VAULT_CERT_ROLE", None)
    # Optional config for AppRole auth
    self.approle_role_id = os.getenv("HCP_VAULT_APPROLE_ROLE_ID", "")
    self.approle_secret_id = os.getenv("HCP_VAULT_APPROLE_SECRET_ID", "")
    self.approle_mount_path = os.getenv("HCP_VAULT_APPROLE_MOUNT_PATH", "approle")
    # Fail fast if no auth method is configured at all.
    self._verify_required_credentials_exist()
    if premium_user is not True:
        raise ValueError(
            f"Hashicorp secret manager is only available for premium users. {CommonProxyErrors.not_premium_user.value}"
        )
    litellm.secret_manager_client = self
    litellm._key_management_system = KeyManagementSystem.HASHICORP_VAULT
    # Cache TTL: env override first, falling back to the library default.
    _refresh_interval = os.environ.get(
        "HCP_VAULT_REFRESH_INTERVAL", SECRET_MANAGER_REFRESH_INTERVAL
    )
    _refresh_interval = (
        int(_refresh_interval)
        if _refresh_interval
        else SECRET_MANAGER_REFRESH_INTERVAL
    )
    self.cache = InMemoryCache(
        default_ttl=_refresh_interval
    )  # store in memory for the refresh interval (default: 24 hours)
def _verify_required_credentials_exist(self) -> None:
"""
Validate that at least one authentication method is configured.
Raises:
ValueError: If no valid authentication credentials are provided
"""
has_token = bool(self.vault_token)
has_approle = bool(self.approle_role_id and self.approle_secret_id)
has_tls_cert = bool(self.tls_cert_path and self.tls_key_path)
if not has_token and not has_approle and not has_tls_cert:
raise ValueError(
"Missing Vault authentication credentials. Please set either:\n"
" - HCP_VAULT_TOKEN for token-based auth, or\n"
" - HCP_VAULT_APPROLE_ROLE_ID and HCP_VAULT_APPROLE_SECRET_ID for AppRole auth, or\n"
" - HCP_VAULT_CLIENT_CERT and HCP_VAULT_CLIENT_KEY for TLS certificate auth"
)
def _auth_via_approle(self) -> str:
    """
    Authenticate to Vault using AppRole auth method.

    Ref: https://developer.hashicorp.com/vault/api-docs/auth/approle

    Request:
    ```
    curl \
        --request POST \
        --header "X-Vault-Namespace: mynamespace/" \
        --data '{"role_id": "...", "secret_id": "..."}' \
        http://127.0.0.1:8200/v1/auth/approle/login
    ```

    Response:
    ```
    {
        "auth": {
            "client_token": "hvs.CAESI...",
            "accessor": "hmac-sha256...",
            "policies": ["default", "dev-policy"],
            "token_policies": ["default", "dev-policy"],
            "lease_duration": 2764800,
            "renewable": true
        }
    }
    ```

    Returns:
        str: A Vault client token, cached for its lease duration.

    Raises:
        RuntimeError: If the AppRole login fails.
    """
    verbose_logger.debug("Using AppRole auth for Hashicorp Vault")
    # Check cache first
    cached_token = self.cache.get_cache(key="hcp_vault_approle_token")
    if cached_token:
        verbose_logger.debug("Using cached Vault token from AppRole auth")
        return cached_token
    # Vault endpoint for AppRole login
    login_url = f"{self.vault_addr}/v1/auth/{self.approle_mount_path}/login"
    headers = {}
    if hasattr(self, "vault_namespace") and self.vault_namespace:
        headers["X-Vault-Namespace"] = self.vault_namespace
    try:
        client = _get_httpx_client()
        resp = client.post(
            url=login_url,
            headers=headers,
            json={
                "role_id": self.approle_role_id,
                "secret_id": self.approle_secret_id,
            },
        )
        resp.raise_for_status()
        auth_data = resp.json()["auth"]
        token = auth_data["client_token"]
        _lease_duration = auth_data["lease_duration"]
        verbose_logger.debug(
            f"Successfully obtained Vault token via AppRole auth. Lease duration: {_lease_duration}s"
        )
        # Cache the token with its lease duration
        self.cache.set_cache(
            key="hcp_vault_approle_token", value=token, ttl=_lease_duration
        )
        return token
    except Exception as e:
        # Chain the cause so the underlying HTTP failure is preserved.
        raise RuntimeError(f"Could not authenticate to Vault via AppRole: {e}") from e
def _auth_via_tls_cert(self) -> str:
    """
    Ref: https://developer.hashicorp.com/vault/api-docs/auth/cert

    Request:
    ```
    curl \
        --request POST \
        --cacert vault-ca.pem \
        --cert cert.pem \
        --key key.pem \
        --header "X-Vault-Namespace: mynamespace/" \
        --data '{"name": "my-cert-role"}' \
        https://127.0.0.1:8200/v1/auth/cert/login
    ```

    Response:
    ```
    {
        "auth": {
            "client_token": "cf95f87d-f95b-47ff-b1f5-ba7bff850425",
            "policies": ["web", "stage"],
            "lease_duration": 3600,
            "renewable": true
        }
    }
    ```

    Returns:
        str: A Vault client token, cached for its lease duration.

    Raises:
        RuntimeError: If the cert login fails.
    """
    verbose_logger.debug("Using TLS cert auth for Hashicorp Vault")
    # Vault endpoint for cert-based login, e.g. '/v1/auth/cert/login'
    login_url = f"{self.vault_addr}/v1/auth/cert/login"
    # Include your Vault namespace in the header if you're using namespaces.
    # E.g. self.vault_namespace = 'mynamespace/'
    # If you only have root namespace, you can omit this header entirely.
    headers = {}
    if hasattr(self, "vault_namespace") and self.vault_namespace:
        headers["X-Vault-Namespace"] = self.vault_namespace
    try:
        # We use the client cert and key for mutual TLS. A context manager
        # ensures the client (and its connection pool) is closed instead of
        # leaked on every authentication.
        with httpx.Client(cert=(self.tls_cert_path, self.tls_key_path)) as client:
            resp = client.post(
                login_url,
                headers=headers,
                json=self._get_tls_cert_auth_body(),
            )
        resp.raise_for_status()
        # Parse the body once instead of twice.
        auth_data = resp.json()["auth"]
        token = auth_data["client_token"]
        _lease_duration = auth_data["lease_duration"]
        verbose_logger.debug("Successfully obtained Vault token via TLS cert auth.")
        self.cache.set_cache(
            key="hcp_vault_token", value=token, ttl=_lease_duration
        )
        return token
    except Exception as e:
        # Chain the cause so the underlying TLS/HTTP failure is preserved.
        raise RuntimeError(f"Could not authenticate to Vault via TLS cert: {e}") from e
def _get_tls_cert_auth_body(self) -> dict:
return {"name": self.vault_cert_role}
def get_url(
    self,
    secret_name: str,
    namespace: Optional[str] = None,
    mount_name: Optional[str] = None,
    path_prefix: Optional[str] = None,
) -> str:
    """
    Constructs the Vault URL for KV v2 secrets.

    Format: {VAULT_ADDR}/v1/{NAMESPACE}/{MOUNT_NAME}/data/{PATH_PREFIX}/{SECRET_NAME}

    Examples:
    - Default: http://127.0.0.1:8200/v1/secret/data/mykey
    - With namespace: http://127.0.0.1:8200/v1/mynamespace/secret/data/mykey
    - With custom mount: http://127.0.0.1:8200/v1/kv/data/mykey
    - With path prefix: http://127.0.0.1:8200/v1/secret/data/myapp/mykey
    """
    # Explicit arguments win; otherwise fall back to instance config.
    ns = self._sanitize_path_component(
        self.vault_namespace if namespace is None else namespace
    )
    mount = self._sanitize_path_component(
        self.vault_mount_name if mount_name is None else mount_name
    )
    prefix = self._sanitize_path_component(
        self.vault_path_prefix if path_prefix is None else path_prefix
    )
    # Assemble path segments, skipping the optional ones when unset.
    segments = []
    if ns:
        segments.append(ns)
    segments.append(mount if mount is not None else "secret")
    segments.append("data")
    if prefix:
        segments.append(prefix)
    segments.append(secret_name)
    return f"{self.vault_addr}/v1/" + "/".join(segments)
def _sanitize_plain_value(self, value: Optional[Union[str, int]]) -> Optional[str]:
if value is None:
return None
value_str = str(value).strip()
if value_str == "":
return None
return value_str
def _sanitize_path_component(
self, value: Optional[Union[str, int]]
) -> Optional[str]:
sanitized_value = self._sanitize_plain_value(value)
if sanitized_value is None:
return None
sanitized_value = sanitized_value.strip("/")
return sanitized_value or None
def _extract_secret_manager_settings(
self, optional_params: Optional[dict]
) -> Dict[str, Any]:
if not isinstance(optional_params, dict):
return {}
candidate = optional_params.get("secret_manager_settings")
source = candidate if isinstance(candidate, dict) else optional_params
allowed_keys = {"namespace", "mount", "path_prefix", "data"}
return {k: source[k] for k in allowed_keys if k in source}
def _build_secret_target(
    self, secret_name: str, optional_params: Optional[dict]
) -> Dict[str, Any]:
    """
    Resolve the full target for a secret operation: its URL and the key
    under which the value lives in the KV v2 data payload.
    """
    overrides = self._extract_secret_manager_settings(optional_params)
    resolved_url = self.get_url(
        secret_name=secret_name,
        namespace=overrides.get("namespace", self.vault_namespace),
        mount_name=overrides.get("mount", self.vault_mount_name),
        path_prefix=overrides.get("path_prefix", self.vault_path_prefix),
    )
    # LiteLLM stores values under "key" unless the caller overrides it.
    value_key = self._sanitize_plain_value(overrides.get("data")) or "key"
    return {
        "url": resolved_url,
        "data_key": value_key,
        "secret_name": secret_name,
    }
def _get_request_headers(self) -> dict:
"""
Get the headers for Vault API requests.
Authentication priority:
1. AppRole (if role_id and secret_id are configured)
2. TLS Certificate (if cert paths are configured)
3. Direct token (if HCP_VAULT_TOKEN is set)
"""
# Priority 1: AppRole auth
if self.approle_role_id and self.approle_secret_id:
return {"X-Vault-Token": self._auth_via_approle()}
# Priority 2: TLS cert auth
if self.tls_cert_path and self.tls_key_path:
return {"X-Vault-Token": self._auth_via_tls_cert()}
# Priority 3: Direct token
return {"X-Vault-Token": self.vault_token}
async def async_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Reads a secret from Vault KV v2 using an async HTTPX client.

    secret_name is just the path inside the KV mount (e.g., 'myapp/config').

    Returns:
        Optional[str]: The secret value on success, None on failure.
    """
    # Single cache read instead of a read-check-read pair, so the hit
    # decision and returned value cannot diverge.
    cached_value = self.cache.get_cache(secret_name)
    if cached_value is not None:
        return cached_value
    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
    )
    try:
        # For KV v2: /v1/<mount>/data/<path>
        # Example: http://127.0.0.1:8200/v1/secret/data/myapp/config
        url = self.get_url(secret_name)
        response = await async_client.get(url, headers=self._get_request_headers())
        response.raise_for_status()
        # For KV v2, the secret is in response.json()["data"]["data"]
        json_resp = response.json()
        _value = self._get_secret_value_from_json_response(json_resp)
        self.cache.set_cache(secret_name, _value)
        return _value
    except Exception as e:
        verbose_logger.exception(f"Error reading secret from Hashicorp Vault: {e}")
        return None
def sync_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Reads a secret from Vault KV v2 using a sync HTTPX client.

    secret_name is just the path inside the KV mount (e.g., 'myapp/config').

    Returns:
        Optional[str]: The secret value on success, None on failure.
    """
    # Single cache read instead of a read-check-read pair, so the hit
    # decision and returned value cannot diverge.
    cached_value = self.cache.get_cache(secret_name)
    if cached_value is not None:
        return cached_value
    sync_client = _get_httpx_client()
    try:
        # For KV v2: /v1/<mount>/data/<path>
        url = self.get_url(secret_name)
        response = sync_client.get(url, headers=self._get_request_headers())
        response.raise_for_status()
        # For KV v2, the secret is in response.json()["data"]["data"]
        json_resp = response.json()
        _value = self._get_secret_value_from_json_response(json_resp)
        self.cache.set_cache(secret_name, _value)
        return _value
    except Exception as e:
        verbose_logger.exception(f"Error reading secret from Hashicorp Vault: {e}")
        return None
async def async_write_secret(
    self,
    secret_name: str,
    secret_value: str,
    description: Optional[str] = None,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    tags: Optional[Union[dict, list]] = None,
) -> Dict[str, Any]:
    """
    Writes a secret to Vault KV v2 using an async HTTPX client.

    Args:
        secret_name: Path inside the KV mount (e.g., 'myapp/config')
        secret_value: Value to store
        description: Optional description, stored alongside the value
        optional_params: Vault overrides (namespace/mount/path_prefix/data)
        timeout: Request timeout
        tags: Unused for Vault

    Returns:
        dict: Vault's JSON response on success, or
            {"status": "error", "message": ...} on failure.
    """
    client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )
    try:
        target = self._build_secret_target(secret_name, optional_params)
        # KV v2 expects the value nested under a top-level "data" key.
        payload: Dict[str, Any] = {"data": {target["data_key"]: secret_value}}
        if description:
            payload["data"]["description"] = description
        response = await client.post(
            url=target["url"],
            headers=self._get_request_headers(),
            json=payload,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        verbose_logger.exception(f"Error writing secret to Hashicorp Vault: {e}")
        return {"status": "error", "message": str(e)}
async def async_rotate_secret(
    self,
    current_secret_name: str,
    new_secret_name: str,
    new_secret_value: str,
    optional_params: Dict | None = None,
    timeout: float | httpx.Timeout | None = None,
) -> Dict:
    """
    Rotates a secret by creating a new one and deleting the old one.

    Uses _build_secret_target to handle optional_params for namespace, mount, path_prefix customization.

    Flow: (1) verify the current secret exists, (2) write the new secret,
    (3) read it back and compare values, (4) delete the old secret only if
    the names differ. Deletion failure is logged but does not fail the
    rotation, since the new secret was already created.

    Args:
        current_secret_name: Current name of the secret
        new_secret_name: New name for the secret
        new_secret_value: New value for the secret
        optional_params: Additional parameters (namespace, mount, path_prefix, data)
        timeout: Request timeout

    Returns:
        dict: Response containing status and details of the operation.
            On success, returns the response from async_write_secret.
            On error, returns {"status": "error", "message": "error message"}
    """
    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )
    try:
        # First verify the old secret exists using _build_secret_target
        current_target = self._build_secret_target(
            current_secret_name, optional_params
        )
        try:
            response = await async_client.get(
                url=current_target["url"],
                headers=self._get_request_headers(),
            )
            response.raise_for_status()
            # Secret exists, we can proceed
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                verbose_logger.exception(
                    f"Current secret {current_secret_name} not found"
                )
                return {
                    "status": "error",
                    "message": f"Current secret {current_secret_name} not found",
                }
            verbose_logger.exception(
                f"Error checking current secret existence: {e.response.text if hasattr(e, 'response') else str(e)}"
            )
            return {
                "status": "error",
                "message": f"HTTP error occurred while checking current secret: {e.response.text if hasattr(e, 'response') else str(e)}",
            }
        except Exception as e:
            verbose_logger.exception(
                f"Error checking current secret existence: {e}"
            )
            return {
                "status": "error",
                "message": f"Error checking current secret: {e}",
            }
        # Create new secret with new name and value
        # Use _build_secret_target to handle optional_params
        create_response = await self.async_write_secret(
            secret_name=new_secret_name,
            secret_value=new_secret_value,
            description=f"Rotated from {current_secret_name}",
            optional_params=optional_params,
            timeout=timeout,
        )
        # Check if async_write_secret returned an error
        if (
            isinstance(create_response, dict)
            and create_response.get("status") == "error"
        ):
            return create_response
        # Verify new secret was created successfully using _build_secret_target
        new_target = self._build_secret_target(new_secret_name, optional_params)
        try:
            response = await async_client.get(
                url=new_target["url"],
                headers=self._get_request_headers(),
            )
            response.raise_for_status()
            json_resp = response.json()
            # Use data_key from target to get the correct value
            data_key = new_target["data_key"]
            new_secret_value_from_vault = (
                json_resp.get("data", {}).get("data", {}).get(data_key, None)
            )
            if new_secret_value_from_vault != new_secret_value:
                # NOTE(review): these messages include the plaintext secret
                # values and may end up in logs — consider redacting.
                verbose_logger.exception(
                    f"New secret value mismatch. Expected: {new_secret_value}, Got: {new_secret_value_from_vault}"
                )
                return {
                    "status": "error",
                    "message": f"New secret value mismatch. Expected: {new_secret_value}, Got: {new_secret_value_from_vault}",
                }
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                verbose_logger.exception(
                    f"Failed to verify new secret {new_secret_name}"
                )
                return {
                    "status": "error",
                    "message": f"Failed to verify new secret {new_secret_name}",
                }
            verbose_logger.exception(
                f"Error verifying new secret: {e.response.text if hasattr(e, 'response') else str(e)}"
            )
            return {
                "status": "error",
                "message": f"HTTP error occurred while verifying new secret: {e.response.text if hasattr(e, 'response') else str(e)}",
            }
        except Exception as e:
            verbose_logger.exception(f"Error verifying new secret: {e}")
            return {
                "status": "error",
                "message": f"Error verifying new secret: {e}",
            }
        # If everything is successful, delete the old secret
        # Only delete if the names are different (same name means we're just updating the value)
        if current_secret_name != new_secret_name:
            delete_response = await self.async_delete_secret(
                secret_name=current_secret_name,
                recovery_window_in_days=7,  # Keep for recovery if needed
                optional_params=optional_params,
                timeout=timeout,
            )
            # Check if async_delete_secret returned an error
            if (
                isinstance(delete_response, dict)
                and delete_response.get("status") == "error"
            ):
                # Log the error but don't fail the rotation since new secret was created successfully
                verbose_logger.warning(
                    f"Failed to delete old secret {current_secret_name} after rotation: {delete_response.get('message')}"
                )
            else:
                # Clear cache for the old secret only if deletion was successful
                self.cache.delete_cache(current_secret_name)
        # Clear cache for the new secret (or updated secret if names are the same)
        self.cache.delete_cache(new_secret_name)
        return create_response
    except httpx.TimeoutException:
        verbose_logger.exception("Timeout error occurred during secret rotation")
        return {"status": "error", "message": "Timeout error occurred"}
    except Exception as e:
        verbose_logger.exception(f"Error rotating secret in Hashicorp Vault: {e}")
        return {"status": "error", "message": str(e)}
async def async_delete_secret(
    self,
    secret_name: str,
    recovery_window_in_days: Optional[int] = 7,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
    """
    Delete a secret from Hashicorp Vault (async).

    For KV v2 engines this marks the latest version of the secret as deleted.

    Args:
        secret_name: Name of the secret to delete.
        recovery_window_in_days: Not used for Vault (Vault manages retention
            internally); kept for interface parity with other secret managers.
        optional_params: Additional manager-specific parameters.
        timeout: Request timeout.

    Returns:
        dict with "status" ("success" / "error") and a human-readable "message".
    """
    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )
    try:
        resolved = self._build_secret_target(secret_name, optional_params)
        vault_response = await http_client.delete(
            url=resolved["url"], headers=self._get_request_headers()
        )
        vault_response.raise_for_status()

        # Drop any cached copy — both under the requested name and under the
        # resolved target name, when the two differ.
        self.cache.delete_cache(secret_name)
        resolved_name = resolved["secret_name"]
        if resolved_name != secret_name:
            self.cache.delete_cache(resolved_name)

        return {
            "status": "success",
            "message": f"Secret {resolved['secret_name']} deleted successfully",
        }
    except Exception as e:
        verbose_logger.exception(f"Error deleting secret from Hashicorp Vault: {e}")
        return {"status": "error", "message": str(e)}
def _get_secret_value_from_json_response(
self, json_resp: Optional[dict]
) -> Optional[str]:
"""
Get the secret value from the JSON response
Json response from hashicorp vault is of the form:
{
"request_id":"036ba77c-018b-31dd-047b-323bcd0cd332",
"lease_id":"",
"renewable":false,
"lease_duration":0,
"data":
{"data":
{"key":"Vault Is The Way"},
"metadata":{"created_time":"2025-01-01T22:13:50.93942388Z","custom_metadata":null,"deletion_time":"","destroyed":false,"version":1}
},
"wrap_info":null,
"warnings":null,
"auth":null,
"mount_type":"kv"
}
Note: LiteLLM assumes that all secrets are stored as under the key "key"
"""
if json_resp is None:
return None
return json_resp.get("data", {}).get("data", {}).get("key", None)

View File

@@ -0,0 +1,290 @@
import ast
import os
import traceback
from typing import Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
from litellm.secret_managers.secret_manager_handler import get_secret_from_manager
oidc_cache = DualCache()
def _get_oidc_http_handler(timeout: Optional[httpx.Timeout] = None) -> HTTPHandler:
    """
    Build the HTTPHandler used for OIDC token requests.

    Kept as a factory so tests can mock it.

    Args:
        timeout: Optional request timeout. When omitted, defaults to
            600.0 seconds total with a 5.0 second connect timeout.

    Returns:
        HTTPHandler configured for OIDC requests.
    """
    effective_timeout = (
        httpx.Timeout(timeout=600.0, connect=5.0) if timeout is None else timeout
    )
    return HTTPHandler(timeout=effective_timeout)
######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
def str_to_bool(value: Optional[str]) -> Optional[bool]:
    """
    Convert a string to a boolean when it is a recognized boolean literal.

    Recognized values (case-insensitive, surrounding whitespace ignored):
    "true" -> True, "false" -> False. Anything else — including None —
    yields None.

    :param value: The string to interpret.
    :return: True/False for recognized values, otherwise None.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    # dict lookup returns None for unrecognized strings, matching the
    # original if/elif/else chain.
    return {"true": True, "false": False}.get(normalized)
def get_secret_str(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
) -> Optional[str]:
    """
    String-typed wrapper around 'get_secret'.

    Guarantees the result is either a string or None (non-string values —
    e.g. booleans parsed from the environment — are mapped to None). Used
    for fixing linting errors.
    """
    resolved = get_secret(secret_name=secret_name, default_value=default_value)
    if resolved is None or isinstance(resolved, str):
        return resolved
    return None
def get_secret_bool(
    secret_name: str,
    default_value: Optional[bool] = None,
) -> Optional[bool]:
    """
    Boolean-typed wrapper around 'get_secret'.

    Guarantees the result is either a boolean or None. String values are
    interpreted via str_to_bool ("true"/"false", case-insensitive).

    Args:
        secret_name: The name of the secret to get.
        default_value: The default value to return if the secret is not found.

    Returns:
        The secret as a bool, or None when absent/unrecognized.
    """
    raw_value = get_secret(secret_name, default_value)
    if isinstance(raw_value, bool):
        return raw_value
    if raw_value is None:
        return None
    return str_to_bool(raw_value)
def get_secret(  # noqa: PLR0915
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
):
    """
    Resolve a secret by name.

    Resolution order:
    1. Names starting with "oidc/<provider>/<audience>" are resolved against
       the named OIDC provider (google, circleci, circleci_v2, github, azure,
       file, env, env_path) and return immediately.
    2. If a secret manager client is configured and its access mode permits
       reads, the secret is fetched via get_secret_from_manager; on any
       exception the lookup falls back to os.environ.
    3. Otherwise the value comes from os.environ.

    String values that parse as booleans are returned as bool. If resolution
    raises and ``default_value`` is not None, the default is returned instead
    of propagating the exception.

    :param secret_name: Secret name; may carry an "os.environ/" or "oidc/"
        prefix selecting the resolution path.
    :param default_value: Fallback returned when resolution raises.
    :return: str, bool, or None (None when the secret-manager path yields a
        non-string, falsy result with no explicit return).
    :raises ValueError: when an OIDC provider fails or is unsupported.
    """
    key_management_system = litellm._key_management_system
    key_management_settings = litellm._key_management_settings
    secret = None

    # "os.environ/KEY" is an explicit environment-variable reference; strip
    # the prefix and resolve the remainder normally.
    if secret_name.startswith("os.environ/"):
        secret_name = secret_name.replace("os.environ/", "")

    # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
    if secret_name.startswith("oidc/"):
        secret_name_split = secret_name.replace("oidc/", "")
        oidc_provider, oidc_aud = secret_name_split.split("/", 1)
        # NOTE(review): this recomputes the same value that split("/", 1)
        # already produced for oidc_aud above; kept as-is.
        oidc_aud = "/".join(secret_name_split.split("/")[1:])
        # TODO: Add caching for HTTP requests
        if oidc_provider == "google":
            # Google tokens are cached under the full secret_name; TTL is kept
            # 60s shorter than the 1h token lifetime to avoid serving stale tokens.
            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            oidc_client = _get_oidc_http_handler()
            # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
            response = oidc_client.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
                params={"audience": oidc_aud},
                headers={"Metadata-Flavor": "Google"},
            )
            if response.status_code == 200:
                oidc_token = response.text
                oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
                return oidc_token
            else:
                raise ValueError("Google OIDC provider failed")
        elif oidc_provider == "circleci":
            # https://circleci.com/docs/openid-connect-tokens/
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
            return env_secret
        elif oidc_provider == "circleci_v2":
            # https://circleci.com/docs/openid-connect-tokens/
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
            return env_secret
        elif oidc_provider == "github":
            # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
            actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
            actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
            if (
                actions_id_token_request_url is None
                or actions_id_token_request_token is None
            ):
                raise ValueError(
                    "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
                )

            # Cached with a TTL 5s shorter than GitHub's 5-minute token lifetime.
            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            oidc_client = _get_oidc_http_handler()
            response = oidc_client.get(
                actions_id_token_request_url,
                params={"audience": oidc_aud},
                headers={
                    "Authorization": f"Bearer {actions_id_token_request_token}",
                    "Accept": "application/json; api-version=2.0",
                },
            )
            if response.status_code == 200:
                oidc_token = response.json().get("value", None)
                oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
                return oidc_token
            else:
                raise ValueError("Github OIDC provider failed")
        elif oidc_provider == "azure":
            # https://azure.github.io/azure-workload-identity/docs/quick-start.html
            azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
            if azure_federated_token_file is None:
                # No federated token file: fall back to an Azure AD token
                # provider scoped to the audience. This branch always returns
                # or raises, so the file read below is never reached here.
                verbose_logger.warning(
                    "AZURE_FEDERATED_TOKEN_FILE not found in environment will use Azure AD token provider"
                )
                azure_token_provider = get_azure_ad_token_provider(azure_scope=oidc_aud)
                try:
                    oidc_token = azure_token_provider()
                    if oidc_token is None:
                        raise ValueError("Azure OIDC provider returned None token")
                    return oidc_token
                except Exception as e:
                    error_msg = f"Azure OIDC provider failed: {str(e)}"
                    verbose_logger.error(error_msg)
                    raise ValueError(error_msg)
            with open(azure_federated_token_file, "r") as f:
                oidc_token = f.read()
            return oidc_token
        elif oidc_provider == "file":
            # Load token from a file (the "audience" segment is the file path)
            with open(oidc_aud, "r") as f:
                oidc_token = f.read()
            return oidc_token
        elif oidc_provider == "env":
            # Load token directly from an environment variable
            oidc_token = os.getenv(oidc_aud)
            if oidc_token is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            return oidc_token
        elif oidc_provider == "env_path":
            # Load token from a file path specified in an environment variable
            token_file_path = os.getenv(oidc_aud)
            if token_file_path is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            with open(token_file_path, "r") as f:
                oidc_token = f.read()
            return oidc_token
        else:
            raise ValueError("Unsupported OIDC provider")

    try:
        if (
            _should_read_secret_from_secret_manager()
            and litellm.secret_manager_client is not None
        ):
            try:
                client = litellm.secret_manager_client
                key_manager = "local"
                if key_management_system is not None:
                    key_manager = key_management_system.value

                if key_management_settings is not None:
                    if (
                        key_management_settings.hosted_keys is not None
                        and secret_name not in key_management_settings.hosted_keys
                    ):  # allow user to specify which keys to check in hosted key manager
                        key_manager = "local"

                # Delegate to the secret manager handler
                secret = get_secret_from_manager(
                    client=client,
                    key_manager=key_manager,
                    secret_name=secret_name,
                    key_management_settings=key_management_settings,
                )
            except Exception as e:  # check if it's in os.environ
                verbose_logger.error(
                    f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
                )
                secret = os.getenv(secret_name)

            # String secrets that literal-eval to a bool are returned as bool;
            # any eval failure returns the raw string. Non-string secrets fall
            # through and the function implicitly returns None.
            try:
                if isinstance(secret, str):
                    secret_value_as_bool = ast.literal_eval(secret)
                    if isinstance(secret_value_as_bool, bool):
                        return secret_value_as_bool
                    else:
                        return secret
            except Exception:
                return secret
        else:
            # Plain environment lookup with "true"/"false" coercion.
            secret = os.environ.get(secret_name)
            secret_value_as_bool = str_to_bool(secret) if secret is not None else None
            if secret_value_as_bool is not None and isinstance(
                secret_value_as_bool, bool
            ):
                return secret_value_as_bool
            else:
                return secret
    except Exception as e:
        if default_value is not None:
            return default_value
        else:
            raise e
def _should_read_secret_from_secret_manager() -> bool:
    """
    Decide whether secrets should be read from the configured secret manager.

    Returns True only when a secret manager client is configured AND the
    key-management settings grant read access ("read_only" or
    "read_and_write"); False in every other case.
    """
    if litellm.secret_manager_client is None:
        return False
    settings = litellm._key_management_settings
    if settings is None:
        return False
    return settings.access_mode in ("read_only", "read_and_write")

View File

@@ -0,0 +1,182 @@
"""
Secret Manager Handler
Handles retrieving secrets from different secret management systems.
"""
import base64
import os
from typing import Any, Optional
import litellm
from litellm._logging import print_verbose
from litellm.types.secret_managers.main import KeyManagementSystem
def _is_base64(s):
"""Check if a string is valid base64."""
import binascii
try:
return base64.b64encode(base64.b64decode(s)).decode() == s
except binascii.Error:
return False
def get_secret_from_manager(  # noqa: PLR0915
    client: Any,
    key_manager: str,
    secret_name: str,
    key_management_settings: Optional[Any] = None,
) -> Optional[str]:
    """
    Get a secret from the configured secret manager.

    Dispatches on ``key_manager`` (or, for Azure/Google KMS, on the client's
    type) to the matching backend: Azure Key Vault, Google KMS, AWS KMS,
    AWS Secrets Manager, Google Secret Manager, Hashicorp Vault, CyberArk,
    a custom manager, local environment, or (default) an Infisical-style
    client exposing ``get_secret(...).secret_value``.

    Args:
        client: The secret manager client instance
        key_manager: The type of key manager (e.g., "azure_key_vault", "google_kms", etc.)
        secret_name: The name/path of the secret to retrieve
        key_management_settings: Optional settings for the key management system

    Returns:
        The secret value as a string, or None if not found

    Raises:
        ValueError: If the secret cannot be retrieved or required parameters are missing
        Exception: For other errors during secret retrieval
    """
    secret = None

    if (
        key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
        or type(client).__module__ + "." + type(client).__name__
        == "azure.keyvault.secrets._client.SecretClient"
    ):  # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
        secret = client.get_secret(secret_name).value
    elif (
        key_manager == KeyManagementSystem.GOOGLE_KMS.value
        or client.__class__.__name__ == "KeyManagementServiceClient"
    ):
        # Google KMS path: the *encrypted* secret lives in the environment and
        # is decrypted via the KMS client using the configured resource name.
        encrypted_secret: Any = os.getenv(secret_name)
        if encrypted_secret is None:
            raise ValueError(
                "Google KMS requires the encrypted secret to be in the environment!"
            )
        b64_flag = _is_base64(encrypted_secret)
        if b64_flag is True:  # if passed in as encoded b64 string
            encrypted_secret = base64.b64decode(encrypted_secret)
            ciphertext = encrypted_secret
        else:
            raise ValueError(
                "Google KMS requires the encrypted secret to be encoded in base64"
            )  # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
        response = client.decrypt(
            request={
                "name": litellm._google_kms_resource_name,
                "ciphertext": ciphertext,
            }
        )
        secret = response.plaintext.decode(
            "utf-8"
        )  # assumes the original value was encoded with utf-8
    elif key_manager == KeyManagementSystem.AWS_KMS.value:
        """
        Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
        """
        encrypted_value = os.getenv(secret_name, None)
        if encrypted_value is None:
            raise Exception(
                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
            )
        # Decode the base64 encoded ciphertext
        ciphertext_blob = base64.b64decode(encrypted_value)

        # Set up the parameters for the decrypt call
        params = {"CiphertextBlob": ciphertext_blob}
        # Perform the decryption
        response = client.decrypt(**params)

        # Extract and decode the plaintext
        plaintext = response["Plaintext"]
        secret = plaintext.decode("utf-8")
        if isinstance(secret, str):
            secret = secret.strip()
    elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
        # Imported lazily to avoid paying the boto3 import cost on other paths.
        from litellm.secret_managers.aws_secret_manager_v2 import (
            AWSSecretsManagerV2,
        )

        if isinstance(client, AWSSecretsManagerV2):
            primary_secret_name = None
            if key_management_settings is not None:
                primary_secret_name = key_management_settings.primary_secret_name
            secret = client.sync_read_secret(
                secret_name=secret_name,
                primary_secret_name=primary_secret_name,
            )
            # NOTE(review): this logs the secret value when verbose logging is
            # on — consider redacting.
            print_verbose(f"get_secret_value_response: {secret}")
    elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
        try:
            secret = client.get_secret_from_google_secret_manager(secret_name)
            print_verbose(f"secret from google secret manager: {secret}")
            if secret is None:
                raise ValueError(
                    f"No secret found in Google Secret Manager for {secret_name}"
                )
        except Exception as e:
            print_verbose(f"An error occurred - {str(e)}")
            raise e
    elif key_manager == KeyManagementSystem.HASHICORP_VAULT.value:
        try:
            secret = client.sync_read_secret(secret_name=secret_name)
            if secret is None:
                raise ValueError(
                    f"No secret found in Hashicorp Secret Manager for {secret_name}"
                )
        except Exception as e:
            print_verbose(f"An error occurred - {str(e)}")
            raise e
    elif key_manager == KeyManagementSystem.CYBERARK.value:
        try:
            secret = client.sync_read_secret(secret_name=secret_name)
            if secret is None:
                raise ValueError(
                    f"No secret found in CyberArk Secret Manager for {secret_name}"
                )
        except Exception as e:
            print_verbose(f"An error occurred - {str(e)}")
            raise e
    elif key_manager == KeyManagementSystem.CUSTOM.value:
        # Check if client is a CustomSecretManager instance
        from litellm.integrations.custom_secret_manager import CustomSecretManager

        if isinstance(client, CustomSecretManager):
            secret = client.sync_read_secret(
                secret_name=secret_name,
                optional_params=key_management_settings.model_dump()
                if key_management_settings
                else None,
            )
            if secret is None:
                raise ValueError(
                    f"No secret found in Custom Secret Manager for {secret_name}"
                )
        else:
            raise ValueError(
                f"Custom secret manager client must be an instance of CustomSecretManager, got {type(client).__name__}"
            )
    elif key_manager == "local":
        secret = os.getenv(secret_name)
    else:  # assume the default is an Infisical client
        secret = client.get_secret(secret_name).secret_value

    return secret