chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
## Supported Secret Managers to read credentials from
|
||||
|
||||
Example: read OPENAI_API_KEY and AZURE_API_KEY from a secret manager
|
||||
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
This is a file for the AWS KMS (Key Management Service) Integration
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
|
||||
|
||||
Requires:
|
||||
* `os.environ["AWS_REGION_NAME"]`,
|
||||
* `pip install boto3>=1.28.57`
|
||||
"""
|
||||
|
||||
import ast
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
||||
def validate_environment():
    """Check that the AWS region needed to build a KMS client is configured.

    Raises:
        ValueError: If ``AWS_REGION_NAME`` is not present in ``os.environ``.
    """
    region_is_configured = "AWS_REGION_NAME" in os.environ
    if region_is_configured:
        return
    raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||
|
||||
|
||||
def load_aws_kms(use_aws_kms: Optional[bool]):
    """Register a boto3 KMS client as litellm's global secret manager client.

    Args:
        use_aws_kms: Opt-in flag. ``None`` and ``False`` make this a no-op.

    Raises:
        ImportError: If ``boto3`` cannot be imported.
        ValueError: If ``AWS_REGION_NAME`` is not set (raised by
            ``validate_environment``).
    """
    # Identity checks on purpose: only the literal None / False disable loading.
    kms_requested = use_aws_kms is not None and use_aws_kms is not False
    if not kms_requested:
        return

    # Imported lazily so boto3 is only required when KMS is actually enabled.
    import boto3

    validate_environment()

    # Create a KMS client for the configured region
    region = os.getenv("AWS_REGION_NAME")
    litellm.secret_manager_client = boto3.client("kms", region_name=region)
    litellm._key_management_system = KeyManagementSystem.AWS_KMS
|
||||
|
||||
|
||||
class AWSKeyManagementService_V2:
    """
    Decrypts KMS-encrypted environment variable values (V2 helper class).

    Creating an instance validates the environment (AWS region plus a LiteLLM
    license variable) and builds a boto3 ``kms`` client stored on
    ``self.kms_client``.
    """

    def __init__(self) -> None:
        self.validate_environment()
        self.kms_client = self.load_aws_kms(use_aws_kms=True)

    def validate_environment(
        self,
    ):
        """Verify the region and license environment variables are present.

        Raises:
            ValueError: If ``AWS_REGION_NAME`` is missing, or if neither
                ``LITELLM_LICENSE`` nor ``LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE``
                is set.
        """
        if "AWS_REGION_NAME" not in os.environ:
            raise ValueError("Missing required environment variable - AWS_REGION_NAME")

        ## CHECK IF LICENSE IN ENV ## - premium feature
        # The license may be supplied in plain form or in its KMS-prefixed form.
        license_env_names = (
            "LITELLM_LICENSE",
            "LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE",
        )
        license_present = any(
            os.getenv(env_name, None) is not None for env_name in license_env_names
        )
        if not license_present:
            raise ValueError(
                "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
            )

    def load_aws_kms(self, use_aws_kms: Optional[bool]):
        """Build and return a boto3 KMS client.

        Args:
            use_aws_kms: Opt-in flag. ``None`` and ``False`` return ``None``
                without creating a client.

        Returns:
            The boto3 ``kms`` client for ``AWS_REGION_NAME``, or ``None`` when
            disabled.
        """
        kms_requested = use_aws_kms is not None and use_aws_kms is not False
        if not kms_requested:
            return

        # Imported lazily so boto3 is only required when a client is requested.
        import boto3

        # Module-level region check (the instance-level one also checks the license).
        validate_environment()

        # Create a KMS client for the configured region
        return boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))

    def decrypt_value(self, secret_name: str) -> Any:
        """Decrypt the environment variable ``secret_name`` with KMS.

        The variable's value is expected to be base64-encoded ciphertext,
        optionally marked with ``aws_kms/``.

        Args:
            secret_name: Name of the environment variable holding the ciphertext.

        Returns:
            The decrypted, whitespace-stripped string; or a ``bool`` when the
            plaintext is a Python boolean literal such as ``True``/``False``.

        Raises:
            ValueError: If no KMS client is available.
            Exception: If the environment variable is not set.
        """
        if self.kms_client is None:
            raise ValueError("kms_client is None")

        stored_value = os.getenv(secret_name, None)
        if stored_value is None:
            raise Exception(
                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
            )

        has_marker = isinstance(stored_value, str) and stored_value.startswith(
            "aws_kms/"
        )
        if has_marker:
            stored_value = stored_value.replace("aws_kms/", "")

        # Decode the base64 ciphertext and hand it to KMS for decryption.
        decrypt_response = self.kms_client.decrypt(
            CiphertextBlob=base64.b64decode(stored_value)
        )

        # Extract and decode the plaintext
        secret = decrypt_response["Plaintext"].decode("utf-8")
        if isinstance(secret, str):
            secret = secret.strip()

        # Best effort: surface boolean literals as real booleans; anything that
        # is not a literal (the usual case for API keys) is returned as text.
        try:
            parsed_literal = ast.literal_eval(secret)
        except Exception:
            return secret
        if isinstance(parsed_literal, bool):
            return parsed_literal
        return secret
|
||||
|
||||
|
||||
"""
|
||||
- look for all values in the env with `aws_kms/<hashed_key>`
|
||||
- decrypt keys
|
||||
- rewrite env var with decrypted key. Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
|
||||
"""
|
||||
|
||||
|
||||
def decrypt_env_var() -> Dict[str, Any]:
    """Decrypt every KMS-marked environment variable and return the results.

    A variable is selected when its name starts (case-insensitively) with
    ``litellm_secret_aws_kms`` or when its value starts with ``aws_kms/``.
    This function does not write to ``os.environ``; it only returns the values.

    Returns:
        Mapping of variable name (with any ``litellm_secret_aws_kms_`` marker
        removed) to the decrypted value.
    """
    # setup client class
    kms_service = AWSKeyManagementService_V2()

    decrypted: Dict[str, Any] = {}
    for env_name, env_value in os.environ.items():
        name_is_marked = isinstance(env_name, str) and env_name.lower().startswith(
            "litellm_secret_aws_kms"
        )
        value_is_marked = isinstance(env_value, str) and env_value.startswith(
            "aws_kms/"
        )
        if not (name_is_marked or value_is_marked):
            continue

        plain_value = kms_service.decrypt_value(secret_name=env_name)
        # Publish the value under the name without the KMS marker.
        public_name = re.sub(
            "litellm_secret_aws_kms_", "", env_name, flags=re.IGNORECASE
        )
        decrypted[public_name] = plain_value

    return decrypted
|
||||
@@ -0,0 +1,539 @@
|
||||
"""
|
||||
This is a file for the AWS Secret Manager Integration
|
||||
|
||||
Handles Async Operations for:
|
||||
- Read Secret
|
||||
- Write Secret (CreateSecret)
|
||||
- Update Secret (PutSecretValue) - for in-place rotation when alias is preserved
|
||||
- Delete Secret
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
|
||||
|
||||
Requires:
|
||||
* `os.environ["AWS_REGION_NAME"]`,
|
||||
* `pip install boto3>=1.28.57`
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
|
||||
from .base_secret_manager import BaseSecretManager
|
||||
|
||||
|
||||
class AWSSecretsManagerV2(BaseAWSLLM, BaseSecretManager):
    """
    AWS Secrets Manager client that calls the service over SigV4-signed HTTP.

    Supports reading (sync and async), creating, updating in place, rotating
    and deleting secrets. AWS auth settings passed to the constructor are used
    as defaults for every request; a per-call ``optional_params`` entry with
    the same key takes precedence.
    """

    # AWS auth settings accepted by ``__init__``. The same names are looked up
    # on ``key_management_settings`` and back-filled into ``optional_params``.
    _AWS_AUTH_SETTING_NAMES = (
        "aws_region_name",
        "aws_role_name",
        "aws_session_name",
        "aws_external_id",
        "aws_profile_name",
        "aws_web_identity_token",
        "aws_sts_endpoint",
    )

    def __init__(
        self,
        aws_region_name: Optional[str] = None,
        aws_role_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_external_id: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_web_identity_token: Optional[str] = None,
        aws_sts_endpoint: Optional[str] = None,
        **kwargs,
    ):
        """
        Store the AWS auth settings used as defaults for all requests.

        Any extra ``kwargs`` are forwarded to both base-class initializers.
        """
        BaseSecretManager.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)

        # Store AWS authentication settings
        self.aws_region_name = aws_region_name
        self.aws_role_name = aws_role_name
        self.aws_session_name = aws_session_name
        self.aws_external_id = aws_external_id
        self.aws_profile_name = aws_profile_name
        self.aws_web_identity_token = aws_web_identity_token
        self.aws_sts_endpoint = aws_sts_endpoint

    @classmethod
    def validate_environment(cls):
        """
        Warn (do not fail) when no AWS region variable is set in the environment.
        """
        # AWS_REGION_NAME is only strictly required if not using a profile or role
        # When using IAM roles, the region can come from multiple sources
        if (
            "AWS_REGION_NAME" not in os.environ
            and "AWS_REGION" not in os.environ
            and "AWS_DEFAULT_REGION" not in os.environ
        ):
            verbose_logger.warning(
                "No AWS region found in environment. Ensure aws_region_name is set in key_management_settings "
                "or AWS_REGION_NAME/AWS_REGION/AWS_DEFAULT_REGION is set in environment."
            )

    @classmethod
    def load_aws_secret_manager(
        cls,
        use_aws_secret_manager: Optional[bool],
        key_management_settings: Optional[Any] = None,
    ):
        """
        Initialize AWSSecretsManagerV2 with settings from key_management_settings

        Installs the new instance as ``litellm.secret_manager_client`` and sets
        ``litellm._key_management_system`` to ``AWS_SECRET_MANAGER``.

        Args:
            use_aws_secret_manager: Opt-in flag. ``None`` and ``False`` make
                this a no-op.
            key_management_settings: Optional object whose ``aws_*`` attributes
                (when present and not ``None``) are passed to the constructor.
        """
        if use_aws_secret_manager is None or use_aws_secret_manager is False:
            return

        cls.validate_environment()

        # Extract AWS settings from key_management_settings if provided,
        # skipping anything that is unset.
        aws_kwargs: Dict[str, Any] = {}
        if key_management_settings is not None:
            for setting_name in cls._AWS_AUTH_SETTING_NAMES:
                setting_value = getattr(key_management_settings, setting_name, None)
                if setting_value is not None:
                    aws_kwargs[setting_name] = setting_value

        litellm.secret_manager_client = cls(**aws_kwargs)
        litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER

    async def async_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        primary_secret_name: Optional[str] = None,
    ) -> Optional[str]:
        """
        Async function to read a secret from AWS Secrets Manager

        Args:
            secret_name: Name of the secret, or the key inside the primary
                secret when ``primary_secret_name`` is given.
            optional_params: Per-call AWS parameters.
            timeout: Request timeout.
            primary_secret_name: When set, this secret is read and parsed as
                JSON, and the value stored under ``secret_name`` is returned.

        Returns:
            The secret value, or ``None`` when the request fails for any reason
            other than a timeout (the error is logged).

        Raises:
            ValueError: If the request times out.
        """
        if primary_secret_name:
            return await self.async_read_secret_from_primary_secret(
                secret_name=secret_name, primary_secret_name=primary_secret_name
            )

        endpoint_url, headers, body = self._prepare_request(
            action="GetSecretValue",
            secret_name=secret_name,
            optional_params=optional_params,
        )

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()["SecretString"]
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")
        except Exception as e:
            # Deliberate best effort: a failed read is logged and reported as None.
            verbose_logger.exception(
                "Error reading secret='%s' from AWS Secrets Manager: %s",
                secret_name,
                str(e),
            )
        return None

    def sync_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        primary_secret_name: Optional[str] = None,
    ) -> Optional[str]:
        """
        Sync function to read a secret from AWS Secrets Manager

        Done for backwards compatibility with existing codebase, since get_secret is a sync function

        Returns:
            The secret value, or ``None`` when the request fails for any reason
            other than a timeout (the error is logged).

        Raises:
            ValueError: If the request times out.
        """
        # self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop
        if secret_name in [
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_REGION_NAME",
            "AWS_REGION",
            "AWS_BEDROCK_RUNTIME_ENDPOINT",
        ]:
            return os.getenv(secret_name)

        if primary_secret_name:
            return self.sync_read_secret_from_primary_secret(
                secret_name=secret_name, primary_secret_name=primary_secret_name
            )

        endpoint_url, headers, body = self._prepare_request(
            action="GetSecretValue",
            secret_name=secret_name,
            optional_params=optional_params,
        )

        sync_client = _get_httpx_client(
            params={"timeout": timeout},
        )

        try:
            response = sync_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            # Same check as the async path, so an error status is reported by
            # the HTTPStatusError handler below instead of a KeyError on the body.
            response.raise_for_status()
            return response.json()["SecretString"]
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")
        except httpx.HTTPStatusError as e:
            verbose_logger.exception(
                "Error reading secret='%s' from AWS Secrets Manager: %s, %s",
                secret_name,
                str(e.response.text),
                str(e.response.status_code),
            )
        except Exception as e:
            verbose_logger.exception(
                "Error reading secret='%s' from AWS Secrets Manager: %s",
                secret_name,
                str(e),
            )
        return None

    def _parse_primary_secret(self, primary_secret_json_str: Optional[str]) -> dict:
        """
        Parse the primary secret JSON string into a dictionary

        Args:
            primary_secret_json_str: JSON string containing key-value pairs;
                ``None`` or an empty string is treated as an empty object.

        Returns:
            Dictionary of key-value pairs from the primary secret
        """
        return json.loads(primary_secret_json_str or "{}")

    def sync_read_secret_from_primary_secret(
        self, secret_name: str, primary_secret_name: str
    ) -> Optional[str]:
        """
        Read a secret from the primary secret

        Returns the value stored under ``secret_name`` in the primary secret's
        JSON, or ``None`` when the key is absent.
        """
        primary_secret_json_str = self.sync_read_secret(secret_name=primary_secret_name)
        primary_secret_kv_pairs = self._parse_primary_secret(primary_secret_json_str)
        return primary_secret_kv_pairs.get(secret_name)

    async def async_read_secret_from_primary_secret(
        self, secret_name: str, primary_secret_name: str
    ) -> Optional[str]:
        """
        Read a secret from the primary secret

        Returns the value stored under ``secret_name`` in the primary secret's
        JSON, or ``None`` when the key is absent.
        """
        primary_secret_json_str = await self.async_read_secret(
            secret_name=primary_secret_name
        )
        primary_secret_kv_pairs = self._parse_primary_secret(primary_secret_json_str)
        return primary_secret_kv_pairs.get(secret_name)

    async def _async_post_for_json(
        self,
        endpoint_url: str,
        headers: Any,
        body: bytes,
        timeout: Optional[Union[float, httpx.Timeout]],
    ) -> dict:
        """
        POST a prepared (already signed) request and return the parsed JSON body.

        Shared by the write, put and delete operations.

        Raises:
            ValueError: If the service answers with an error status or the
                request times out.
        """
        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as err:
            raise ValueError(f"HTTP error occurred: {err.response.text}") from err
        except httpx.TimeoutException as err:
            raise ValueError("Timeout error occurred") from err

    async def async_write_secret(
        self,
        secret_name: str,
        secret_value: str,
        description: Optional[str] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tags: Optional[Union[dict, list]] = None,
    ) -> dict:
        """
        Async function to write a secret to AWS Secrets Manager

        Args:
            secret_name: Name of the secret
            secret_value: Value to store (can be a JSON string)
            description: Optional description for the secret
            optional_params: Additional AWS parameters
            timeout: Request timeout
            tags: Optional dict or list of tags to apply, e.g.
                {"Environment": "Prod", "Owner": "AI-Platform"} or
                [{"Key": "Environment", "Value": "Prod"}]

        Returns:
            dict: Parsed JSON response of the CreateSecret call

        Raises:
            ValueError: If ``tags`` is neither a dict nor a list, on an HTTP
                error status, or on timeout.
        """
        from litellm._uuid import uuid

        data: Dict[str, Any] = {
            "Name": secret_name,
            "SecretString": secret_value,
            "ClientRequestToken": str(uuid.uuid4()),
        }

        if description:
            data["Description"] = description

        # Normalize tags to AWS format
        if tags:
            if isinstance(tags, dict):
                tags_list = [{"Key": k, "Value": str(v)} for k, v in tags.items()]
            elif isinstance(tags, list):
                tags_list = tags
            else:
                raise ValueError("Tags must be a dict or list of {Key, Value} pairs")
            data["Tags"] = tags_list  # type: ignore[assignment]

        endpoint_url, headers, body = self._prepare_request(
            action="CreateSecret",
            secret_name=secret_name,
            secret_value=secret_value,
            optional_params=optional_params,
            request_data=data,
        )

        return await self._async_post_for_json(
            endpoint_url=endpoint_url, headers=headers, body=body, timeout=timeout
        )

    async def async_put_secret_value(
        self,
        secret_name: str,
        secret_value: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to update an existing secret's value in AWS Secrets Manager.

        Uses PutSecretValue to update in place. Use this when rotating a secret
        that keeps the same name (current_secret_name == new_secret_name).

        Args:
            secret_name: Name of the existing secret to update
            secret_value: New value to store
            optional_params: Additional AWS parameters
            timeout: Request timeout

        Returns:
            dict: Response from AWS Secrets Manager containing update details

        Raises:
            ValueError: On an HTTP error status or on timeout.
        """
        from litellm._uuid import uuid

        data: Dict[str, Any] = {
            "SecretId": secret_name,
            "SecretString": secret_value,
            "ClientRequestToken": str(uuid.uuid4()),
        }

        endpoint_url, headers, body = self._prepare_request(
            action="PutSecretValue",
            secret_name=secret_name,
            secret_value=secret_value,
            optional_params=optional_params,
            request_data=data,
        )

        return await self._async_post_for_json(
            endpoint_url=endpoint_url, headers=headers, body=body, timeout=timeout
        )

    async def async_rotate_secret(
        self,
        current_secret_name: str,
        new_secret_name: str,
        new_secret_value: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Rotate a secret. When current_secret_name == new_secret_name (in-place
        update), uses PutSecretValue instead of create+delete to avoid
        ResourceExistsException.
        """
        if current_secret_name == new_secret_name:
            # Same alias: update in place via PutSecretValue
            verbose_logger.info(
                "Secret rotated in-place (PutSecretValue): secret_name=%s",
                current_secret_name,
            )
            return await self.async_put_secret_value(
                secret_name=current_secret_name,
                secret_value=new_secret_value,
                optional_params=optional_params,
                timeout=timeout,
            )
        # Different names: create new, delete old (base class logic)
        return await super().async_rotate_secret(
            current_secret_name=current_secret_name,
            new_secret_name=new_secret_name,
            new_secret_value=new_secret_value,
            optional_params=optional_params,
            timeout=timeout,
        )

    async def async_delete_secret(
        self,
        secret_name: str,
        recovery_window_in_days: Optional[int] = 7,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to delete a secret from AWS Secrets Manager

        Args:
            secret_name: Name of the secret to delete
            recovery_window_in_days: Number of days before permanent deletion (default: 7)
            optional_params: Additional AWS parameters
            timeout: Request timeout

        Returns:
            dict: Response from AWS Secrets Manager containing deletion details

        Raises:
            ValueError: On an HTTP error status or on timeout.
        """
        # Prepare the request data
        data = {
            "SecretId": secret_name,
            "RecoveryWindowInDays": recovery_window_in_days,
        }

        endpoint_url, headers, body = self._prepare_request(
            action="DeleteSecret",
            secret_name=secret_name,
            optional_params=optional_params,
            request_data=data,
        )

        return await self._async_post_for_json(
            endpoint_url=endpoint_url, headers=headers, body=body, timeout=timeout
        )

    def _prepare_request(
        self,
        action: str,  # e.g. "GetSecretValue", "CreateSecret", "PutSecretValue", "DeleteSecret"
        secret_name: str,
        secret_value: Optional[str] = None,
        optional_params: Optional[dict] = None,
        request_data: Optional[dict] = None,
    ) -> tuple[str, Any, bytes]:
        """
        Prepare the AWS Secrets Manager request

        Args:
            action: Secrets Manager API action, sent in the ``X-Amz-Target`` header.
            secret_name: Secret identifier used when ``request_data`` is not given.
            secret_value: Value added to the default payload for ``PutSecretValue``.
            optional_params: Per-call AWS parameters. Not modified; instance
                settings fill in any auth keys that are missing or empty.
            request_data: Complete JSON payload; when truthy it replaces the
                default ``{"SecretId": secret_name}`` payload.

        Returns:
            Tuple of (endpoint URL, signed request headers, JSON body bytes).

        Raises:
            ImportError: If botocore is not installed.
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        # Work on a copy: the defaults merged in below must not leak into the
        # caller's dict, which may be reused across several requests.
        optional_params = dict(optional_params or {})

        # Build optional_params from instance settings if not provided
        # This allows the IAM role settings to be used for Secret Manager calls
        for setting_name in self._AWS_AUTH_SETTING_NAMES:
            instance_value = getattr(self, setting_name, None)
            if instance_value and not optional_params.get(setting_name):
                optional_params[setting_name] = instance_value

        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params
        )

        # Get endpoint: reuse the Bedrock runtime endpoint builder and swap the
        # service name for Secrets Manager.
        _, endpoint_url = self.get_runtime_endpoint(
            api_base=None,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")

        # Use provided request_data if available, otherwise build default data
        if request_data:
            data = request_data
        else:
            data = {"SecretId": secret_name}
            if secret_value and action == "PutSecretValue":
                data["SecretString"] = secret_value

        body = json.dumps(data).encode("utf-8")
        headers = {
            "Content-Type": "application/x-amz-json-1.1",
            "X-Amz-Target": f"secretsmanager.{action}",
        }

        # Sign request
        request = AWSRequest(
            method="POST", url=endpoint_url, data=body, headers=headers
        )
        SigV4Auth(
            boto3_credentials_info.credentials,
            "secretsmanager",
            boto3_credentials_info.aws_region_name,
        ).add_auth(request)
        prepped = request.prepare()

        return endpoint_url, prepped.headers, body
|
||||
@@ -0,0 +1,180 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm import verbose_logger
|
||||
|
||||
|
||||
class BaseSecretManager(ABC):
    """
    Interface every secret manager integration implements.

    Subclasses provide read (sync and async), write and delete operations;
    ``async_rotate_secret`` is implemented here on top of those primitives.
    """

    @abstractmethod
    async def async_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Read a secret without blocking the event loop.

        Args:
            secret_name (str): Name/path of the secret to read
            optional_params (Optional[dict]): Extra, backend-specific parameters
            timeout (Optional[Union[float, httpx.Timeout]]): Request timeout

        Returns:
            Optional[str]: The secret value, or None when it is not found
        """
        pass

    @abstractmethod
    def sync_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Read a secret with a blocking call.

        Args:
            secret_name (str): Name/path of the secret to read
            optional_params (Optional[dict]): Extra, backend-specific parameters
            timeout (Optional[Union[float, httpx.Timeout]]): Request timeout

        Returns:
            Optional[str]: The secret value, or None when it is not found
        """
        pass

    @abstractmethod
    async def async_write_secret(
        self,
        secret_name: str,
        secret_value: str,
        description: Optional[str] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tags: Optional[Union[dict, list]] = None,
    ) -> Dict[str, Any]:
        """
        Store a secret without blocking the event loop.

        Args:
            secret_name (str): Name/path of the secret to write
            secret_value (str): Value to store
            description (Optional[str]): Human-readable description, for
                backends that can store one alongside the secret
            optional_params (Optional[dict]): Extra, backend-specific parameters
            timeout (Optional[Union[float, httpx.Timeout]]): Request timeout
            tags: Optional dict or list of tags to apply, e.g.
                {"Environment": "Prod", "Owner": "AI-Platform"} or
                [{"Key": "Environment", "Value": "Prod"}]

        Returns:
            Dict[str, Any]: Backend response describing the write operation
        """
        pass

    @abstractmethod
    async def async_delete_secret(
        self,
        secret_name: str,
        recovery_window_in_days: Optional[int] = 7,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Delete a secret without blocking the event loop.

        Args:
            secret_name: Name of the secret to delete
            recovery_window_in_days: Days to wait before permanent deletion (default: 7)
            optional_params: Extra, backend-specific parameters
            timeout: Request timeout

        Returns:
            dict: Backend response describing the deletion
        """
        pass

    async def async_rotate_secret(
        self,
        current_secret_name: str,
        new_secret_name: str,
        new_secret_value: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Rotate a secret by writing a replacement and then deleting the original.

        Steps, in order: read the current secret to confirm it exists, write the
        new secret, read the new secret back to confirm the write, delete the
        current secret with a 7-day recovery window. Because the new secret is
        written under its own name, both the name and the value may change.

        Args:
            current_secret_name: Name of the secret being replaced
            new_secret_name: Name to store the replacement under
            new_secret_value: Value of the replacement
            optional_params: Extra parameters forwarded to every step
            timeout: Request timeout forwarded to every step

        Returns:
            dict: The response of the write step

        Raises:
            ValueError: If the current secret cannot be read, the new secret
                cannot be read back, or a step fails with an HTTP error or timeout
        """
        try:
            # Step 1: the secret being rotated has to exist.
            existing_value = await self.async_read_secret(
                secret_name=current_secret_name,
                optional_params=optional_params,
                timeout=timeout,
            )
            if existing_value is None:
                raise ValueError(f"Current secret {current_secret_name} not found")

            # Step 2: store the replacement under the new name.
            write_result = await self.async_write_secret(
                secret_name=new_secret_name,
                secret_value=new_secret_value,
                description=f"Rotated from {current_secret_name}",
                optional_params=optional_params,
                timeout=timeout,
            )

            # Step 3: read it back before touching the old secret.
            stored_value = await self.async_read_secret(
                secret_name=new_secret_name,
                optional_params=optional_params,
                timeout=timeout,
            )
            if stored_value is None:
                raise ValueError(f"Failed to verify new secret {new_secret_name}")

            # Step 4: only now remove the old secret.
            await self.async_delete_secret(
                secret_name=current_secret_name,
                recovery_window_in_days=7,  # Keep for recovery if needed
                optional_params=optional_params,
                timeout=timeout,
            )

            return write_result

        except httpx.HTTPStatusError as http_err:
            verbose_logger.exception(
                "Error rotating secret in AWS Secrets Manager: %s",
                str(http_err.response.text),
            )
            raise ValueError(f"HTTP error occurred: {http_err.response.text}")
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")
        except Exception as unexpected:
            # Log for visibility, then let the original exception propagate.
            verbose_logger.exception(
                "Error rotating secret in AWS Secrets Manager: %s", str(unexpected)
            )
            raise
|
||||
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Loader for custom secret managers.
|
||||
|
||||
Handles dynamic loading of user-defined secret manager classes from Python files.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_secret_manager import CustomSecretManager
|
||||
from litellm.types.secret_managers.main import KeyManagementSystem
|
||||
|
||||
|
||||
def load_custom_secret_manager(config_file_path: Optional[str] = None) -> None:
    """
    Load and initialize a custom secret manager from a python file.

    Reads the ``custom_secret_manager`` field of
    ``litellm._key_management_settings`` (format ``"file_name.ClassName"``),
    executes the module found next to the config file, instantiates the class
    and registers the instance as ``litellm.secret_manager_client`` with
    ``KeyManagementSystem.CUSTOM``.

    The value is split on its LAST dot, so a dotted module path such as
    ``"pkg.my_secret_manager.MyManager"`` resolves to
    ``<config dir>/pkg/my_secret_manager.py``.

    Args:
        config_file_path: Path to the config.yaml file

    Raises:
        ValueError: If required configuration is missing or malformed
        ImportError: If no module spec/loader can be built for the file
        TypeError: If the target is not a CustomSecretManager subclass
    """
    if not config_file_path:
        raise ValueError(
            "CustomSecretManagerException - config_file_path is required to load custom secret manager"
        )

    # Get the custom_secret_manager class path from settings
    if litellm._key_management_settings is None:
        raise ValueError(
            "CustomSecretManagerException - key_management_settings is required with custom_secret_manager field"
        )

    custom_secret_manager_path = getattr(
        litellm._key_management_settings, "custom_secret_manager", None
    )
    if not custom_secret_manager_path:
        raise ValueError(
            "CustomSecretManagerException - custom_secret_manager field is required in key_management_settings"
        )

    # Split on the last dot only: a bare two-target `split(".")` blows up with
    # an opaque "too many/not enough values to unpack" for 0 or 2+ dots.
    _file_name, _, _class_name = custom_secret_manager_path.rpartition(".")
    if not _file_name or not _class_name:
        raise ValueError(
            "CustomSecretManagerException - custom_secret_manager must be in the format "
            f"'file_name.ClassName', got: {custom_secret_manager_path}"
        )
    verbose_proxy_logger.debug(
        "Initializing custom secret manager: %s, file_name: %s, class_name: %s",
        custom_secret_manager_path,
        _file_name,
        _class_name,
    )

    # Load the module from the same directory as config.yaml; any extra dots
    # in the module part are treated as sub-directories.
    directory = os.path.dirname(config_file_path)
    module_file_path = os.path.join(directory, *_file_name.split(".")) + ".py"

    spec = importlib.util.spec_from_file_location(_class_name, module_file_path)  # type: ignore
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Could not find a module specification for {module_file_path}"
        )

    module = importlib.util.module_from_spec(spec)  # type: ignore
    spec.loader.exec_module(module)  # type: ignore
    _secret_manager_class = getattr(module, _class_name)

    # Validate that it's a CustomSecretManager subclass. The isinstance guard
    # keeps the error message meaningful when the attribute is not a class.
    if not (
        isinstance(_secret_manager_class, type)
        and issubclass(_secret_manager_class, CustomSecretManager)
    ):
        raise TypeError(
            f"CustomSecretManagerException - {_class_name} must be a subclass of CustomSecretManager"
        )

    # Instantiate the custom secret manager
    _secret_manager_instance = _secret_manager_class()

    # Set it as the secret manager client
    litellm.secret_manager_client = _secret_manager_instance

    # Set the key management system to CUSTOM so get_secret knows to use it
    litellm._key_management_system = KeyManagementSystem.CUSTOM

    verbose_proxy_logger.info(
        "Successfully initialized custom secret manager: %s",
        custom_secret_manager_path,
    )
|
||||
@@ -0,0 +1,350 @@
|
||||
import base64
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
from .base_secret_manager import BaseSecretManager
|
||||
from .main import str_to_bool
|
||||
|
||||
|
||||
class CyberArkSecretManager(BaseSecretManager):
|
||||
def __init__(self):
    """
    Configure the CyberArk Conjur client from ``CYBERARK_*`` environment variables.

    Validates credentials and the premium license, builds the in-memory
    cache, and only then registers this instance on the ``litellm`` module
    so that a rejected construction never leaves a half-initialised client
    installed globally.

    Raises:
        ValueError: If neither an API key nor a client cert/key pair is
            configured, or if the proxy is not running as a premium user.
    """
    from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

    # CyberArk Conjur-specific config
    self.conjur_addr = os.getenv("CYBERARK_API_BASE", "http://127.0.0.1:8080")
    self.conjur_account = os.getenv("CYBERARK_ACCOUNT", "default")
    self.conjur_username = os.getenv("CYBERARK_USERNAME", "admin")
    self.conjur_api_key = os.getenv("CYBERARK_API_KEY", "")

    # Optional config for certificate-based auth
    self.tls_cert_path = os.getenv("CYBERARK_CLIENT_CERT", "")
    self.tls_key_path = os.getenv("CYBERARK_CLIENT_KEY", "")

    # SSL verification - can be disabled for self-signed certificates
    # Set CYBERARK_SSL_VERIFY=false to disable SSL verification
    ssl_verify_env = str_to_bool(os.getenv("CYBERARK_SSL_VERIFY"))
    self.ssl_verify: bool = ssl_verify_env if ssl_verify_env is not None else True

    # Validate environment
    if not self.conjur_api_key and not (self.tls_cert_path and self.tls_key_path):
        raise ValueError(
            "Missing CyberArk credentials. Please set CYBERARK_API_KEY or both CYBERARK_CLIENT_CERT and CYBERARK_CLIENT_KEY in your environment."
        )

    # Tokens expire after ~8 minutes, so we cache for 5 minutes to be safe.
    # `or "300"` also covers the variable being set to an empty string.
    _refresh_interval = int(os.environ.get("CYBERARK_REFRESH_INTERVAL") or "300")
    self.cache = InMemoryCache(default_ttl=_refresh_interval)

    if premium_user is not True:
        raise ValueError(
            f"CyberArk secret manager is only available for premium users. {CommonProxyErrors.not_premium_user.value}"
        )

    if not self.ssl_verify:
        verbose_logger.warning(
            "CyberArk SSL verification is disabled. This is insecure and should only be used for testing with self-signed certificates."
        )

    # Register globally only after every check has passed.
    litellm.secret_manager_client = self
    litellm._key_management_system = KeyManagementSystem.CYBERARK
|
||||
|
||||
def _authenticate(self) -> str:
    """
    Authenticate with CyberArk Conjur and get a session token.

    The token is a JSON object that must be base64-encoded for use in subsequent requests.
    A cached token (key ``cyberark_auth_token``) is returned when present;
    otherwise a new one is requested and cached with the cache's default TTL.

    Returns:
        str: Base64-encoded session token

    Raises:
        RuntimeError: If the authentication request fails for any reason
            (the original exception is attached as ``__cause__``).
    """
    # Check if we have a cached token
    cached_token = self.cache.get_cache("cyberark_auth_token")
    if cached_token is not None:
        return cached_token

    verbose_logger.debug("Authenticating with CyberArk Conjur...")
    auth_url = f"{self.conjur_addr}/authn/{self.conjur_account}/{self.conjur_username}/authenticate"

    try:
        if self.tls_cert_path and self.tls_key_path:
            # Certificate-based authentication - need custom client for cert.
            # The context manager closes the one-off client's connection pool.
            with httpx.Client(
                cert=(self.tls_cert_path, self.tls_key_path),
                verify=self.ssl_verify,
            ) as http_client:
                resp = http_client.post(auth_url, content=self.conjur_api_key)
        else:
            # API key authentication
            http_handler = _get_httpx_client(params={"ssl_verify": self.ssl_verify})
            resp = http_handler.client.post(auth_url, content=self.conjur_api_key)

        resp.raise_for_status()

        # The response is a JSON token that needs to be base64-encoded
        token_json = resp.text
        token_b64 = base64.b64encode(token_json.encode()).decode()

        verbose_logger.debug("Successfully authenticated with CyberArk Conjur.")

        # Cache the token for the refresh interval
        self.cache.set_cache(key="cyberark_auth_token", value=token_b64)

        return token_b64
    except Exception as e:
        raise RuntimeError(f"Could not authenticate to CyberArk Conjur: {e}") from e
|
||||
|
||||
def _get_request_headers(self) -> dict:
|
||||
"""
|
||||
Get headers for CyberArk API requests including authentication.
|
||||
|
||||
Returns:
|
||||
dict: Headers with authentication token
|
||||
"""
|
||||
token = self._authenticate()
|
||||
return {"Authorization": f'Token token="{token}"'}
|
||||
|
||||
def _ensure_variable_exists(self, secret_name: str) -> None:
    """
    Best-effort declaration of a variable in the Conjur root policy.

    Posts a one-line ``!variable`` policy entry. A 409/422 response is
    treated as "already there" and logged at debug level; every other
    failure is logged as a warning. Nothing is raised to the caller.

    Args:
        secret_name: Name of the variable to ensure exists
    """
    # No existence check is made up front: the entry is always posted and an
    # "already exists"-style rejection is tolerated below.
    policy_endpoint = f"{self.conjur_addr}/policies/{self.conjur_account}/policy/root"
    policy_document = f"- !variable {secret_name}\n"

    try:
        http_handler = _get_httpx_client(params={"ssl_verify": self.ssl_verify})
        request_headers = {
            **self._get_request_headers(),
            "Content-Type": "application/x-yaml",
        }
        policy_response = http_handler.client.post(
            policy_endpoint,
            headers=request_headers,
            content=policy_document,
        )
        policy_response.raise_for_status()
        verbose_logger.debug(f"Created policy entry for variable: {secret_name}")
    except httpx.HTTPStatusError as status_error:
        status_code = status_error.response.status_code
        if status_code not in [409, 422]:
            verbose_logger.warning(
                f"Could not ensure variable exists: {status_code} - {status_error.response.text}"
            )
            return
        # Variable might already exist, which is fine
        verbose_logger.debug(
            f"Variable {secret_name} already exists or policy conflict (expected)"
        )
    except Exception as unexpected_error:
        verbose_logger.warning(f"Error ensuring variable exists: {unexpected_error}")
|
||||
|
||||
def get_url(self, secret_name: str) -> str:
    """
    Return the Conjur endpoint URL for a variable.

    Args:
        secret_name: Name of the secret (will be URL-encoded)

    Returns:
        str: Full URL for the secret
    """
    # safe="" so that slashes in the name are percent-encoded as well
    url_safe_name = quote(secret_name, safe="")
    variables_root = f"{self.conjur_addr}/secrets/{self.conjur_account}/variable"
    return f"{variables_root}/{url_safe_name}"
|
||||
|
||||
async def async_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Reads a secret from CyberArk Conjur using an async HTTPX client.

    A cached value is returned when available; otherwise the value is
    fetched and stored in the cache. All errors are logged and swallowed.

    Args:
        secret_name: Name/path of the secret to read
        optional_params: Additional parameters (not used for Conjur)
        timeout: Request timeout (accepted for interface compatibility;
            not forwarded to the HTTP call)

    Returns:
        Optional[str]: The secret value if found, None otherwise
    """
    # Single cache lookup: a second get_cache() call could miss if the entry
    # expires between the check and the read, returning None by accident.
    cached_value = self.cache.get_cache(secret_name)
    if cached_value is not None:
        return cached_value

    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"ssl_verify": self.ssl_verify},
    )

    try:
        url = self.get_url(secret_name)
        response = await async_client.get(url, headers=self._get_request_headers())
        response.raise_for_status()

        # CyberArk Conjur returns the raw secret value as text
        secret_value = response.text
        self.cache.set_cache(secret_name, secret_value)
        return secret_value

    except httpx.HTTPStatusError as e:
        if e.response.status_code == 404:
            verbose_logger.debug(
                f"Secret {secret_name} not found in CyberArk Conjur"
            )
        else:
            verbose_logger.exception(
                f"Error reading secret from CyberArk Conjur: {e}"
            )
        return None
    except Exception as e:
        verbose_logger.exception(f"Error reading secret from CyberArk Conjur: {e}")
        return None
|
||||
|
||||
def sync_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Reads a secret from CyberArk Conjur using a sync HTTPX client.

    A cached value is returned when available; otherwise the value is
    fetched and stored in the cache. All errors are logged and swallowed.

    Args:
        secret_name: Name/path of the secret to read
        optional_params: Additional parameters (not used for Conjur)
        timeout: Request timeout (accepted for interface compatibility;
            not forwarded to the HTTP call)

    Returns:
        Optional[str]: The secret value if found, None otherwise
    """
    # Single cache lookup: a second get_cache() call could miss if the entry
    # expires between the check and the read, returning None by accident.
    cached_value = self.cache.get_cache(secret_name)
    if cached_value is not None:
        return cached_value

    sync_client = _get_httpx_client(params={"ssl_verify": self.ssl_verify})

    try:
        url = self.get_url(secret_name)
        response = sync_client.client.get(url, headers=self._get_request_headers())
        response.raise_for_status()

        # CyberArk Conjur returns the raw secret value as text
        secret_value = response.text
        self.cache.set_cache(secret_name, secret_value)
        return secret_value

    except httpx.HTTPStatusError as e:
        if e.response.status_code == 404:
            verbose_logger.debug(
                f"Secret {secret_name} not found in CyberArk Conjur"
            )
        else:
            verbose_logger.exception(
                f"Error reading secret from CyberArk Conjur: {e}"
            )
        return None
    except Exception as e:
        verbose_logger.exception(f"Error reading secret from CyberArk Conjur: {e}")
        return None
|
||||
|
||||
async def async_write_secret(
    self,
    secret_name: str,
    secret_value: str,
    description: Optional[str] = None,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    tags: Optional[Union[dict, list]] = None,
) -> Dict[str, Any]:
    """
    Store a secret value in CyberArk Conjur via the async HTTPX client.

    The variable is first declared in the policy (best effort), then the
    value is posted and mirrored into the local cache. Failures are logged
    and reported through the returned dict rather than raised.

    Args:
        secret_name: Name/path of the secret to write
        secret_value: Value to store
        description: Optional description (not used by Conjur)
        optional_params: Additional parameters
        timeout: Request timeout
        tags: Optional tags (not used by Conjur)

    Returns:
        dict: ``{"status": "success" | "error", "message": ...}``
    """
    http_handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"ssl_verify": self.ssl_verify},
    )

    try:
        # The policy entry must exist before a value can be assigned to it
        self._ensure_variable_exists(secret_name)

        target_url = self.get_url(secret_name)
        auth_headers = self._get_request_headers()
        write_response = await http_handler.post(
            url=target_url, headers=auth_headers, content=secret_value
        )
        write_response.raise_for_status()

        # Keep the cache in step with what was just written
        self.cache.set_cache(secret_name, secret_value)
    except Exception as write_error:
        verbose_logger.exception(
            f"Error writing secret to CyberArk Conjur: {write_error}"
        )
        return {"status": "error", "message": str(write_error)}

    return {
        "status": "success",
        "message": f"Secret {secret_name} written successfully",
    }
|
||||
|
||||
async def async_delete_secret(
    self,
    secret_name: str,
    recovery_window_in_days: Optional[int] = 7,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
    """
    Report that deletion is unsupported and drop the secret from the local cache.

    No request is sent to Conjur: a warning is logged, the cached entry
    for ``secret_name`` is removed, and a ``not_supported`` status is
    returned.

    Args:
        secret_name: Name of the secret
        recovery_window_in_days: Not used
        optional_params: Additional parameters
        timeout: Request timeout

    Returns:
        dict: Response indicating operation not supported
    """
    verbose_logger.warning(
        "CyberArk Conjur does not support direct secret deletion. Secrets must be removed through policy updates."
    )

    # The remote value stays; only the local copy is evicted
    self.cache.delete_cache(secret_name)

    unsupported_message = "CyberArk Conjur does not support direct secret deletion. Use policy updates to remove variables."
    return {"status": "not_supported", "message": unsupported_message}
|
||||
@@ -0,0 +1,121 @@
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.secret_managers.get_azure_ad_token_provider import (
|
||||
AzureCredentialType,
|
||||
)
|
||||
|
||||
|
||||
def infer_credential_type_from_environment() -> AzureCredentialType:
    """
    Pick an Azure credential type from the AZURE_* environment variables.

    Precedence (first match wins):
      1. client id + client secret + tenant id -> ClientSecretCredential
      2. client id + tenant id + certificate path -> CertificateCredential
         (AZURE_CERTIFICATE_PASSWORD is optional)
      3. client id only -> ManagedIdentityCredential
      4. certificate password or path without a client id ->
         CertificateCredential (kept for backward compatibility)
      5. otherwise -> DefaultAzureCredential

    The certificate check must come before the managed-identity check:
    both require AZURE_CLIENT_ID, so testing the client id alone first
    made the certificate branch unreachable.

    Returns:
        AzureCredentialType: The inferred credential type.
    """
    client_id = os.environ.get("AZURE_CLIENT_ID")
    tenant_id = os.environ.get("AZURE_TENANT_ID")
    certificate_path = os.environ.get("AZURE_CERTIFICATE_PATH")
    certificate_password = os.environ.get("AZURE_CERTIFICATE_PASSWORD")

    if client_id and os.environ.get("AZURE_CLIENT_SECRET") and tenant_id:
        return AzureCredentialType.ClientSecretCredential
    if client_id and tenant_id and certificate_path:
        return AzureCredentialType.CertificateCredential
    if client_id:
        return AzureCredentialType.ManagedIdentityCredential
    if certificate_password or certificate_path:
        return AzureCredentialType.CertificateCredential
    return AzureCredentialType.DefaultAzureCredential
|
||||
|
||||
|
||||
def get_azure_ad_token_provider(
    azure_scope: Optional[str] = None,
    azure_credential: Optional[AzureCredentialType] = None,
) -> Callable[[], str]:
    """
    Build an Azure AD bearer-token provider for the selected credential type.

    The credential type is taken from ``azure_credential`` when given,
    else from the AZURE_CREDENTIAL environment variable, else inferred from
    the environment. Any type name not handled explicitly is looked up as
    an attribute of ``azure.identity`` and constructed without arguments.

    Based on: https://github.com/openai/openai-python/blob/main/examples/azure_ad.py
    See Also:
        https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret;
        https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python.

    Args:
        azure_scope (str, optional): The Azure scope to request token for.
            Defaults to environment variable AZURE_SCOPE or
            "https://cognitiveservices.azure.com/.default".
        azure_credential (AzureCredentialType, optional): Explicit
            credential type to use.

    Returns:
        Callable that returns a temporary authentication token.

    Raises:
        KeyError: If an environment variable required by the chosen
            credential type is not set.
        ValueError: If no credential object could be created.
    """
    import azure.identity as identity
    from azure.identity import (
        CertificateCredential,
        ClientSecretCredential,
        DefaultAzureCredential,
        ManagedIdentityCredential,
        get_bearer_token_provider,
    )

    if azure_scope is None:
        scope_from_env = os.environ.get("AZURE_SCOPE")
        azure_scope = scope_from_env or "https://cognitiveservices.azure.com/.default"

    # Explicit argument wins; otherwise fall back to env var, then inference
    if azure_credential:
        cred = azure_credential.value
    else:
        cred = (
            os.environ.get("AZURE_CREDENTIAL")
            or infer_credential_type_from_environment()
        )
    verbose_logger.info(
        f"For Azure AD Token Provider, choosing credential type: {cred}"
    )

    credential: Any = None
    if cred == AzureCredentialType.ClientSecretCredential:
        credential = ClientSecretCredential(
            client_id=os.environ["AZURE_CLIENT_ID"],
            client_secret=os.environ["AZURE_CLIENT_SECRET"],
            tenant_id=os.environ["AZURE_TENANT_ID"],
        )
    elif cred == AzureCredentialType.ManagedIdentityCredential:
        credential = ManagedIdentityCredential(client_id=os.environ["AZURE_CLIENT_ID"])
    elif cred == AzureCredentialType.CertificateCredential:
        certificate_kwargs = {
            "client_id": os.environ["AZURE_CLIENT_ID"],
            "tenant_id": os.environ["AZURE_TENANT_ID"],
            "certificate_path": os.environ["AZURE_CERTIFICATE_PATH"],
        }
        # The password is passed only when one is configured
        if os.getenv("AZURE_CERTIFICATE_PASSWORD"):
            certificate_kwargs["password"] = os.environ["AZURE_CERTIFICATE_PASSWORD"]
        credential = CertificateCredential(**certificate_kwargs)
    elif cred == AzureCredentialType.DefaultAzureCredential:
        # DefaultAzureCredential doesn't require explicit environment variables
        # It automatically discovers credentials from the environment (managed identity, CLI, etc.)
        credential = DefaultAzureCredential()
    else:
        credential_factory = getattr(identity, cred)
        credential = credential_factory()

    if credential is None:
        raise ValueError("No credential provided")

    return get_bearer_token_provider(credential, azure_scope)
|
||||
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
This is a file for the Google KMS integration
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/1235
|
||||
|
||||
Requires:
|
||||
* `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]`
|
||||
* `pip install google-cloud-kms`
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
||||
def validate_environment():
    """
    Check that the environment variables needed for Google KMS are present.

    Raises:
        ValueError: Naming the first of GOOGLE_APPLICATION_CREDENTIALS /
            GOOGLE_KMS_RESOURCE_NAME that is missing from ``os.environ``.
    """
    required_variables = ("GOOGLE_APPLICATION_CREDENTIALS", "GOOGLE_KMS_RESOURCE_NAME")
    for variable_name in required_variables:
        if variable_name not in os.environ:
            raise ValueError(
                f"Missing required environment variable - {variable_name}"
            )
|
||||
|
||||
|
||||
def load_google_kms(use_google_kms: Optional[bool]):
    """
    Create a Google KMS client and register it on the ``litellm`` module.

    Does nothing when ``use_google_kms`` is ``None`` or ``False``. Otherwise
    validates the environment, builds a ``KeyManagementServiceClient`` and
    stores it, the GOOGLE_KMS key-management system marker, and the
    GOOGLE_KMS_RESOURCE_NAME value on ``litellm``.

    Args:
        use_google_kms: Whether Google KMS should be enabled.

    Raises:
        ValueError: If a required environment variable is missing.
        ImportError: If ``google-cloud-kms`` is not installed.
    """
    if use_google_kms is None or use_google_kms is False:
        return

    # Exceptions propagate unchanged; the former `except Exception as e: raise e`
    # wrapper added nothing but an extra traceback frame.
    from google.cloud import kms_v1  # type: ignore

    validate_environment()

    # Create the KMS client
    client = kms_v1.KeyManagementServiceClient()
    litellm.secret_manager_client = client
    litellm._key_management_system = KeyManagementSystem.GOOGLE_KMS
    litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME")
|
||||
@@ -0,0 +1,117 @@
|
||||
import base64
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import InMemoryCache
|
||||
from litellm.constants import SECRET_MANAGER_REFRESH_INTERVAL
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
from litellm.proxy._types import CommonProxyErrors, KeyManagementSystem
|
||||
|
||||
|
||||
class GoogleSecretManager(GCSBucketBase):
|
||||
def __init__(
    self,
    refresh_interval: Optional[int] = SECRET_MANAGER_REFRESH_INTERVAL,
    always_read_secret_manager: Optional[bool] = False,
) -> None:
    """
    Set up the Google Secret Manager client and register it on ``litellm``.

    Args:
        refresh_interval (int, optional): Cache TTL in seconds. Overridden by
            the GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL environment variable
            when that is set to a non-empty value.
        always_read_secret_manager (bool, optional): Bypass the cache on
            reads. Forced to True when the environment variable
            GOOGLE_SECRET_MANAGER_ALWAYS_READ_SECRET_MANAGER equals "true"
            (case-insensitive). Defaults to False so values are cached.

    Raises:
        ValueError: If the proxy is not a premium user, or
            GOOGLE_SECRET_MANAGER_PROJECT_ID is not set.
    """
    from litellm.proxy.proxy_server import premium_user

    if premium_user is not True:
        raise ValueError(
            f"Google Secret Manager requires an Enterprise License {CommonProxyErrors.not_premium_user.value}"
        )
    super().__init__()

    self.PROJECT_ID = os.environ.get("GOOGLE_SECRET_MANAGER_PROJECT_ID", None)
    if self.PROJECT_ID is None:
        raise ValueError(
            "Google Secret Manager requires a project ID, please set 'GOOGLE_SECRET_MANAGER_PROJECT_ID' in your .env"
        )

    self.sync_httpx_client = _get_httpx_client()
    litellm.secret_manager_client = self
    litellm._key_management_system = KeyManagementSystem.GOOGLE_SECRET_MANAGER

    # Environment variable (a string) takes priority over the argument
    configured_interval = os.environ.get(
        "GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL", refresh_interval
    )
    if configured_interval:
        cache_ttl = int(configured_interval)
    else:
        cache_ttl = refresh_interval
    self.cache = InMemoryCache(default_ttl=cache_ttl)

    always_read_env = os.environ.get(
        "GOOGLE_SECRET_MANAGER_ALWAYS_READ_SECRET_MANAGER",
    )
    forced_by_env = bool(always_read_env) and always_read_env.lower() == "true"
    if forced_by_env:
        self.always_read_secret_manager = True
    else:
        # Caching is the default: fetching from the secret manager on every
        # request is a bad idea.
        self.always_read_secret_manager = always_read_secret_manager or False
|
||||
|
||||
def get_secret_from_google_secret_manager(self, secret_name: str) -> Optional[str]:
    """
    Fetch the latest version of a secret, consulting the cache first.

    Unless ``always_read_secret_manager`` is True, a cached entry is
    returned as-is, including a cached ``None`` recorded by an earlier
    failed lookup. Failed lookups are cached as ``None`` before raising.

    Args:
        secret_name (str): The name of the secret.

    Returns:
        str: The decoded secret value, or a cached ``None`` for a secret
        previously recorded as not found.

    Raises:
        ValueError: If the API responds with a non-200 status or the
            response carries no payload data.
    """
    cache_enabled = self.always_read_secret_manager is not True
    if cache_enabled:
        cached_entry = self.cache.get_cache(secret_name)
        # A key present with a None value is a remembered "not found"
        if cached_entry is not None or secret_name in self.cache.cache_dict:
            return cached_entry

    resource_path = (
        f"projects/{self.PROJECT_ID}/secrets/{secret_name}/versions/latest"
    )
    request_headers = self.sync_construct_request_headers()
    access_url = f"https://secretmanager.googleapis.com/v1/{resource_path}:access"

    # Send the GET request to retrieve the secret
    response = self.sync_httpx_client.get(url=access_url, headers=request_headers)

    if response.status_code != 200:
        verbose_logger.error(
            "Google Secret Manager retrieval error: %s", str(response.text)
        )
        # Cache that the secret was not found
        self.cache.set_cache(secret_name, None)
        raise ValueError(
            f"secret {secret_name} not found in Google Secret Manager. Error: {response.text}"
        )

    verbose_logger.debug(
        "Google Secret Manager retrieval response status code: %s",
        response.status_code,
    )

    # The payload data arrives base64-encoded inside the JSON body
    payload = response.json().get("payload", {})
    encoded_value = payload.get("data")
    if encoded_value is None:
        # Cache that the secret was not found
        self.cache.set_cache(secret_name, None)
        raise ValueError(f"secret {secret_name} not found in Google Secret Manager")

    secret_value = base64.b64decode(encoded_value).decode("utf-8")
    # Cache the retrieved secret
    self.cache.set_cache(secret_name, secret_value)
    return secret_value
|
||||
@@ -0,0 +1,677 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.constants import SECRET_MANAGER_REFRESH_INTERVAL
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
from .base_secret_manager import BaseSecretManager
|
||||
|
||||
|
||||
class HashicorpSecretManager(BaseSecretManager):
|
||||
def __init__(self):
    """
    Configure the Hashicorp Vault client from ``HCP_VAULT_*`` environment variables.

    Reads connection, TLS-cert and AppRole settings, verifies that at least
    one auth method is configured, enforces the premium license, registers
    this instance on ``litellm`` and creates the in-memory cache.

    Raises:
        ValueError: If no auth credentials are configured or the proxy is
            not running as a premium user.
    """
    from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

    env = os.getenv

    # Vault-specific config
    self.vault_addr = env("HCP_VAULT_ADDR", "http://127.0.0.1:8200")
    self.vault_token = env("HCP_VAULT_TOKEN", "")
    # Vault namespace (for X-Vault-Namespace header)
    self.vault_namespace = env("HCP_VAULT_NAMESPACE", None)
    # KV engine mount name (default: "secret"); set HCP_VAULT_MOUNT_NAME if
    # the KV engine is mounted somewhere else
    self.vault_mount_name = env("HCP_VAULT_MOUNT_NAME", "secret")
    # Optional path prefix for secrets (e.g., "myapp" -> secret/data/myapp/{secret_name})
    self.vault_path_prefix = env("HCP_VAULT_PATH_PREFIX", None)

    # Optional config for TLS cert auth
    self.tls_cert_path = env("HCP_VAULT_CLIENT_CERT", "")
    self.tls_key_path = env("HCP_VAULT_CLIENT_KEY", "")
    self.vault_cert_role = env("HCP_VAULT_CERT_ROLE", None)

    # Optional config for AppRole auth
    self.approle_role_id = env("HCP_VAULT_APPROLE_ROLE_ID", "")
    self.approle_secret_id = env("HCP_VAULT_APPROLE_SECRET_ID", "")
    self.approle_mount_path = env("HCP_VAULT_APPROLE_MOUNT_PATH", "approle")

    self._verify_required_credentials_exist()

    if premium_user is not True:
        raise ValueError(
            f"Hashicorp secret manager is only available for premium users. {CommonProxyErrors.not_premium_user.value}"
        )

    litellm.secret_manager_client = self
    litellm._key_management_system = KeyManagementSystem.HASHICORP_VAULT

    # Env value (a string) wins; an empty value falls back to the constant
    configured_interval = os.environ.get(
        "HCP_VAULT_REFRESH_INTERVAL", SECRET_MANAGER_REFRESH_INTERVAL
    )
    if configured_interval:
        cache_ttl = int(configured_interval)
    else:
        cache_ttl = SECRET_MANAGER_REFRESH_INTERVAL
    self.cache = InMemoryCache(default_ttl=cache_ttl)
|
||||
|
||||
def _verify_required_credentials_exist(self) -> None:
|
||||
"""
|
||||
Validate that at least one authentication method is configured.
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid authentication credentials are provided
|
||||
"""
|
||||
has_token = bool(self.vault_token)
|
||||
has_approle = bool(self.approle_role_id and self.approle_secret_id)
|
||||
has_tls_cert = bool(self.tls_cert_path and self.tls_key_path)
|
||||
|
||||
if not has_token and not has_approle and not has_tls_cert:
|
||||
raise ValueError(
|
||||
"Missing Vault authentication credentials. Please set either:\n"
|
||||
" - HCP_VAULT_TOKEN for token-based auth, or\n"
|
||||
" - HCP_VAULT_APPROLE_ROLE_ID and HCP_VAULT_APPROLE_SECRET_ID for AppRole auth, or\n"
|
||||
" - HCP_VAULT_CLIENT_CERT and HCP_VAULT_CLIENT_KEY for TLS certificate auth"
|
||||
)
|
||||
|
||||
def _auth_via_approle(self) -> str:
    """
    Authenticate to Vault using AppRole auth method.

    Ref: https://developer.hashicorp.com/vault/api-docs/auth/approle

    POSTs ``{"role_id": ..., "secret_id": ...}`` to
    ``{vault_addr}/v1/auth/{approle_mount_path}/login`` (with the
    ``X-Vault-Namespace`` header when a namespace is configured) and reads
    ``auth.client_token`` / ``auth.lease_duration`` from the JSON response,
    e.g.::

        {"auth": {"client_token": "hvs.CAESI...", "lease_duration": 2764800, "renewable": true}}

    The token is cached under ``hcp_vault_approle_token`` for its lease
    duration; a cached token is returned without contacting Vault.

    Returns:
        str: The Vault client token.

    Raises:
        RuntimeError: If the login request or response parsing fails (the
            original exception is attached as ``__cause__``).
    """
    verbose_logger.debug("Using AppRole auth for Hashicorp Vault")

    # Check cache first
    cached_token = self.cache.get_cache(key="hcp_vault_approle_token")
    if cached_token:
        verbose_logger.debug("Using cached Vault token from AppRole auth")
        return cached_token

    # Vault endpoint for AppRole login
    login_url = f"{self.vault_addr}/v1/auth/{self.approle_mount_path}/login"

    headers = {}
    # getattr keeps the original tolerance for a missing attribute
    vault_namespace = getattr(self, "vault_namespace", None)
    if vault_namespace:
        headers["X-Vault-Namespace"] = vault_namespace

    try:
        client = _get_httpx_client()
        resp = client.post(
            url=login_url,
            headers=headers,
            json={
                "role_id": self.approle_role_id,
                "secret_id": self.approle_secret_id,
            },
        )
        resp.raise_for_status()

        auth_data = resp.json()["auth"]
        token = auth_data["client_token"]
        _lease_duration = auth_data["lease_duration"]

        verbose_logger.debug(
            f"Successfully obtained Vault token via AppRole auth. Lease duration: {_lease_duration}s"
        )

        # Cache the token with its lease duration
        self.cache.set_cache(
            key="hcp_vault_approle_token", value=token, ttl=_lease_duration
        )
        return token
    except Exception as e:
        raise RuntimeError(f"Could not authenticate to Vault via AppRole: {e}") from e
|
||||
|
||||
def _auth_via_tls_cert(self) -> str:
    """
    Log in to Vault with the TLS certificate auth method and return the client token.

    Ref: https://developer.hashicorp.com/vault/api-docs/auth/cert

    Request:
    ```
    curl \
        --request POST \
        --cacert vault-ca.pem \
        --cert cert.pem \
        --key key.pem \
        --header "X-Vault-Namespace: mynamespace/" \
        --data '{"name": "my-cert-role"}' \
        https://127.0.0.1:8200/v1/auth/cert/login
    ```

    Response:
    ```
    {
        "auth": {
            "client_token": "cf95f87d-f95b-47ff-b1f5-ba7bff850425",
            "policies": ["web", "stage"],
            "lease_duration": 3600,
            "renewable": true
        }
    }
    ```

    Returns:
        The token found at ``auth.client_token`` in the login response (or a
        previously cached token that has not expired yet).

    Raises:
        RuntimeError: If the login request fails or the response does not
            contain the expected ``auth`` fields.
    """
    verbose_logger.debug("Using TLS cert auth for Hashicorp Vault")

    # The token is stored below with a TTL equal to its lease duration, so
    # reuse it while it is still cached instead of logging in on every call.
    cached_token = self.cache.get_cache(key="hcp_vault_token")
    if cached_token:
        verbose_logger.debug("Using cached Vault token from TLS cert auth")
        return cached_token

    # Vault endpoint for cert-based login, e.g. '/v1/auth/cert/login'
    login_url = f"{self.vault_addr}/v1/auth/cert/login"

    # Include your Vault namespace in the header if you're using namespaces.
    # E.g. self.vault_namespace = 'mynamespace/'
    # If you only have root namespace, you can omit this header entirely.
    headers = {}
    if hasattr(self, "vault_namespace") and self.vault_namespace:
        headers["X-Vault-Namespace"] = self.vault_namespace

    try:
        # We use the client cert and key for mutual TLS. The context manager
        # closes the client (and its connection pool) once the login is done.
        with httpx.Client(cert=(self.tls_cert_path, self.tls_key_path)) as client:
            resp = client.post(
                login_url,
                headers=headers,
                json=self._get_tls_cert_auth_body(),
            )
            resp.raise_for_status()
            auth_data = resp.json()["auth"]

        token = auth_data["client_token"]
        _lease_duration = auth_data["lease_duration"]
        verbose_logger.debug("Successfully obtained Vault token via TLS cert auth.")
        self.cache.set_cache(
            key="hcp_vault_token", value=token, ttl=_lease_duration
        )
        return token
    except Exception as e:
        raise RuntimeError(
            f"Could not authenticate to Vault via TLS cert: {e}"
        ) from e
|
||||
|
||||
def _get_tls_cert_auth_body(self) -> dict:
|
||||
return {"name": self.vault_cert_role}
|
||||
|
||||
def get_url(
    self,
    secret_name: str,
    namespace: Optional[str] = None,
    mount_name: Optional[str] = None,
    path_prefix: Optional[str] = None,
) -> str:
    """
    Build the Vault KV v2 URL for a secret.

    Layout: {VAULT_ADDR}/v1/{NAMESPACE}/{MOUNT_NAME}/data/{PATH_PREFIX}/{SECRET_NAME}

    Each explicit argument wins over the matching instance setting; the
    namespace and path prefix segments are left out when they resolve to
    nothing, and the mount falls back to "secret".

    Examples:
    - Default: http://127.0.0.1:8200/v1/secret/data/mykey
    - With namespace: http://127.0.0.1:8200/v1/mynamespace/secret/data/mykey
    - With custom mount: http://127.0.0.1:8200/v1/kv/data/mykey
    - With path prefix: http://127.0.0.1:8200/v1/secret/data/myapp/mykey
    """
    clean = self._sanitize_path_component

    namespace_segment = clean(self.vault_namespace if namespace is None else namespace)

    mount_segment = clean(self.vault_mount_name if mount_name is None else mount_name)
    if mount_segment is None:
        mount_segment = "secret"

    prefix_segment = clean(
        self.vault_path_prefix if path_prefix is None else path_prefix
    )

    # Collect the path pieces in order and join them with single slashes.
    pieces = [f"{self.vault_addr}/v1"]
    if namespace_segment:
        pieces.append(namespace_segment)
    pieces.append(mount_segment)
    pieces.append("data")
    if prefix_segment:
        pieces.append(prefix_segment)
    pieces.append(secret_name)
    return "/".join(pieces)
|
||||
|
||||
def _sanitize_plain_value(self, value: Optional[Union[str, int]]) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
value_str = str(value).strip()
|
||||
if value_str == "":
|
||||
return None
|
||||
return value_str
|
||||
|
||||
def _sanitize_path_component(
|
||||
self, value: Optional[Union[str, int]]
|
||||
) -> Optional[str]:
|
||||
sanitized_value = self._sanitize_plain_value(value)
|
||||
if sanitized_value is None:
|
||||
return None
|
||||
sanitized_value = sanitized_value.strip("/")
|
||||
return sanitized_value or None
|
||||
|
||||
def _extract_secret_manager_settings(
|
||||
self, optional_params: Optional[dict]
|
||||
) -> Dict[str, Any]:
|
||||
if not isinstance(optional_params, dict):
|
||||
return {}
|
||||
|
||||
candidate = optional_params.get("secret_manager_settings")
|
||||
source = candidate if isinstance(candidate, dict) else optional_params
|
||||
allowed_keys = {"namespace", "mount", "path_prefix", "data"}
|
||||
return {k: source[k] for k in allowed_keys if k in source}
|
||||
|
||||
def _build_secret_target(
    self, secret_name: str, optional_params: Optional[dict]
) -> Dict[str, Any]:
    """
    Resolve where a secret lives and which data field holds its value.

    Overrides for namespace, mount, path prefix and data key are taken from
    ``optional_params``; anything not overridden falls back to the instance
    settings, and the data key falls back to "key".

    Returns:
        A dict with "url" (from ``get_url``), "data_key" and "secret_name".
    """
    overrides = self._extract_secret_manager_settings(optional_params)

    target_url = self.get_url(
        secret_name=secret_name,
        namespace=overrides.get("namespace", self.vault_namespace),
        mount_name=overrides.get("mount", self.vault_mount_name),
        path_prefix=overrides.get("path_prefix", self.vault_path_prefix),
    )

    field_name = self._sanitize_plain_value(overrides.get("data"))
    if not field_name:
        field_name = "key"

    return dict(url=target_url, data_key=field_name, secret_name=secret_name)
|
||||
|
||||
def _get_request_headers(self) -> dict:
|
||||
"""
|
||||
Get the headers for Vault API requests.
|
||||
|
||||
Authentication priority:
|
||||
1. AppRole (if role_id and secret_id are configured)
|
||||
2. TLS Certificate (if cert paths are configured)
|
||||
3. Direct token (if HCP_VAULT_TOKEN is set)
|
||||
"""
|
||||
# Priority 1: AppRole auth
|
||||
if self.approle_role_id and self.approle_secret_id:
|
||||
return {"X-Vault-Token": self._auth_via_approle()}
|
||||
|
||||
# Priority 2: TLS cert auth
|
||||
if self.tls_cert_path and self.tls_key_path:
|
||||
return {"X-Vault-Token": self._auth_via_tls_cert()}
|
||||
|
||||
# Priority 3: Direct token
|
||||
return {"X-Vault-Token": self.vault_token}
|
||||
|
||||
async def async_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Read a secret from Vault KV v2 using an async HTTPX client.

    Args:
        secret_name: Path inside the KV mount (e.g., 'myapp/config').
        optional_params: Accepted for interface compatibility; not used here.
        timeout: Request timeout, forwarded to the HTTP client when given.

    Returns:
        The value extracted from the response by
        ``_get_secret_value_from_json_response``, or None on failure.
        Successful reads are cached under ``secret_name``.
    """
    # Look the cache up once: two separate lookups could disagree if the
    # entry expires in between, which would return None without a fetch.
    cached_value = self.cache.get_cache(secret_name)
    if cached_value is not None:
        return cached_value

    client_kwargs: Dict[str, Any] = {
        "llm_provider": httpxSpecialProvider.SecretManager
    }
    if timeout is not None:
        # Only override the client's default timeout when the caller asked for one.
        client_kwargs["params"] = {"timeout": timeout}
    async_client = get_async_httpx_client(**client_kwargs)

    try:
        # For KV v2: /v1/<mount>/data/<path>
        # Example: http://127.0.0.1:8200/v1/secret/data/myapp/config
        url = self.get_url(secret_name)

        response = await async_client.get(url, headers=self._get_request_headers())
        response.raise_for_status()

        # For KV v2, the secret is in response.json()["data"]["data"]
        _value = self._get_secret_value_from_json_response(response.json())
        self.cache.set_cache(secret_name, _value)
        return _value
    except Exception as e:
        verbose_logger.exception(f"Error reading secret from Hashicorp Vault: {e}")
        return None
|
||||
|
||||
def sync_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Read a secret from Vault KV v2 using a sync HTTPX client.

    Args:
        secret_name: Path inside the KV mount (e.g., 'myapp/config').
        optional_params: Accepted for interface compatibility; not used here.
        timeout: Accepted for interface compatibility; not used here.

    Returns:
        The value extracted from the response by
        ``_get_secret_value_from_json_response``, or None on failure.
        Successful reads are cached under ``secret_name``.
    """
    # Look the cache up once: two separate lookups could disagree if the
    # entry expires in between, which would return None without a fetch.
    cached_value = self.cache.get_cache(secret_name)
    if cached_value is not None:
        return cached_value

    sync_client = _get_httpx_client()
    try:
        # For KV v2: /v1/<mount>/data/<path>
        url = self.get_url(secret_name)

        response = sync_client.get(url, headers=self._get_request_headers())
        response.raise_for_status()

        # For KV v2, the secret is in response.json()["data"]["data"]
        _value = self._get_secret_value_from_json_response(response.json())
        self.cache.set_cache(secret_name, _value)
        return _value
    except Exception as e:
        verbose_logger.exception(f"Error reading secret from Hashicorp Vault: {e}")
        return None
|
||||
|
||||
async def async_write_secret(
    self,
    secret_name: str,
    secret_value: str,
    description: Optional[str] = None,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    tags: Optional[Union[dict, list]] = None,
) -> Dict[str, Any]:
    """
    Write a secret to Vault KV v2 using an async HTTPX client.

    Args:
        secret_name: Path inside the KV mount (e.g., 'myapp/config')
        secret_value: Value to store
        description: Optional description, stored next to the value under "description"
        optional_params: Optional overrides (namespace, mount, path_prefix, data)
            resolved by ``_build_secret_target``
        timeout: Request timeout
        tags: Accepted for interface compatibility; not used here

    Returns:
        dict: The JSON body of Vault's response on success, or
        {"status": "error", "message": ...} on failure
    """
    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )

    try:
        target = self._build_secret_target(secret_name, optional_params)

        # Prepare the secret data
        data = {"data": {target["data_key"]: secret_value}}
        if description:
            data["data"]["description"] = description

        response = await async_client.post(
            url=target["url"],
            headers=self._get_request_headers(),
            json=data,
        )
        response.raise_for_status()
        write_result = response.json()

        # Drop any cached value so the next read does not return the
        # value that was stored before this write.
        self.cache.delete_cache(secret_name)
        return write_result
    except Exception as e:
        verbose_logger.exception(f"Error writing secret to Hashicorp Vault: {e}")
        return {"status": "error", "message": str(e)}
|
||||
|
||||
async def async_rotate_secret(
    self,
    current_secret_name: str,
    new_secret_name: str,
    new_secret_value: str,
    optional_params: Optional[Dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Dict:
    """
    Rotate a secret by creating a new one and deleting the old one.

    Steps: check that the current secret exists, write the new secret, read
    it back to verify the stored value, then delete the old secret when the
    two names differ. ``_build_secret_target`` resolves namespace, mount,
    path_prefix and data key from ``optional_params``.

    Args:
        current_secret_name: Current name of the secret
        new_secret_name: New name for the secret
        new_secret_value: New value for the secret
        optional_params: Additional parameters (namespace, mount, path_prefix, data)
        timeout: Request timeout

    Returns:
        dict: Response containing status and details of the operation.
        On success, returns the response from async_write_secret.
        On error, returns {"status": "error", "message": "error message"}
    """
    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )

    try:
        # First verify the old secret exists using _build_secret_target
        current_target = self._build_secret_target(
            current_secret_name, optional_params
        )
        try:
            response = await async_client.get(
                url=current_target["url"],
                headers=self._get_request_headers(),
            )
            response.raise_for_status()
            # Secret exists, we can proceed
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                verbose_logger.exception(
                    f"Current secret {current_secret_name} not found"
                )
                return {
                    "status": "error",
                    "message": f"Current secret {current_secret_name} not found",
                }
            error_detail = e.response.text
            verbose_logger.exception(
                f"Error checking current secret existence: {error_detail}"
            )
            return {
                "status": "error",
                "message": f"HTTP error occurred while checking current secret: {error_detail}",
            }
        except Exception as e:
            verbose_logger.exception(
                f"Error checking current secret existence: {e}"
            )
            return {
                "status": "error",
                "message": f"Error checking current secret: {e}",
            }

        # Create new secret with new name and value
        create_response = await self.async_write_secret(
            secret_name=new_secret_name,
            secret_value=new_secret_value,
            description=f"Rotated from {current_secret_name}",
            optional_params=optional_params,
            timeout=timeout,
        )

        # Check if async_write_secret returned an error
        if (
            isinstance(create_response, dict)
            and create_response.get("status") == "error"
        ):
            return create_response

        # Verify new secret was created successfully using _build_secret_target
        new_target = self._build_secret_target(new_secret_name, optional_params)
        try:
            response = await async_client.get(
                url=new_target["url"],
                headers=self._get_request_headers(),
            )
            response.raise_for_status()
            json_resp = response.json()
            # Use data_key from target to get the correct value
            data_key = new_target["data_key"]
            new_secret_value_from_vault = (
                json_resp.get("data", {}).get("data", {}).get(data_key, None)
            )
            if new_secret_value_from_vault != new_secret_value:
                # Never put the secret values themselves into logs or
                # error responses; report only that they differ.
                mismatch_message = (
                    "New secret value mismatch: the value read back from Vault "
                    "does not match the value that was written"
                )
                verbose_logger.error(mismatch_message)
                return {
                    "status": "error",
                    "message": mismatch_message,
                }
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                verbose_logger.exception(
                    f"Failed to verify new secret {new_secret_name}"
                )
                return {
                    "status": "error",
                    "message": f"Failed to verify new secret {new_secret_name}",
                }
            error_detail = e.response.text
            verbose_logger.exception(
                f"Error verifying new secret: {error_detail}"
            )
            return {
                "status": "error",
                "message": f"HTTP error occurred while verifying new secret: {error_detail}",
            }
        except Exception as e:
            verbose_logger.exception(f"Error verifying new secret: {e}")
            return {
                "status": "error",
                "message": f"Error verifying new secret: {e}",
            }

        # If everything is successful, delete the old secret
        # Only delete if the names are different (same name means we're just updating the value)
        if current_secret_name != new_secret_name:
            delete_response = await self.async_delete_secret(
                secret_name=current_secret_name,
                recovery_window_in_days=7,  # Keep for recovery if needed
                optional_params=optional_params,
                timeout=timeout,
            )
            # Check if async_delete_secret returned an error
            if (
                isinstance(delete_response, dict)
                and delete_response.get("status") == "error"
            ):
                # Log the error but don't fail the rotation since new secret was created successfully
                verbose_logger.warning(
                    f"Failed to delete old secret {current_secret_name} after rotation: {delete_response.get('message')}"
                )
            else:
                # Clear cache for the old secret only if deletion was successful
                self.cache.delete_cache(current_secret_name)

        # Clear cache for the new secret (or updated secret if names are the same)
        self.cache.delete_cache(new_secret_name)

        return create_response

    except httpx.TimeoutException:
        verbose_logger.exception("Timeout error occurred during secret rotation")
        return {"status": "error", "message": "Timeout error occurred"}
    except Exception as e:
        verbose_logger.exception(f"Error rotating secret in Hashicorp Vault: {e}")
        return {"status": "error", "message": str(e)}
|
||||
|
||||
async def async_delete_secret(
    self,
    secret_name: str,
    recovery_window_in_days: Optional[int] = 7,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
    """
    Delete a secret from Hashicorp Vault asynchronously.

    In KV v2, this marks the latest version of the secret as deleted.

    Args:
        secret_name: Name of the secret to delete
        recovery_window_in_days: Not used for Vault (Vault handles this internally)
        optional_params: Additional parameters specific to the secret manager
        timeout: Request timeout

    Returns:
        dict: {"status": "success", ...} once the delete request succeeds,
        otherwise {"status": "error", "message": ...}
    """
    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )

    try:
        resolved = self._build_secret_target(secret_name, optional_params)
        request_headers = self._get_request_headers()
        delete_resp = await http_client.delete(
            url=resolved["url"], headers=request_headers
        )
        delete_resp.raise_for_status()

        # Evict the cached value for the requested name, and for the
        # resolved name as well when the two differ.
        self.cache.delete_cache(secret_name)
        resolved_name = resolved["secret_name"]
        if resolved_name != secret_name:
            self.cache.delete_cache(resolved_name)

        success_message = f"Secret {resolved['secret_name']} deleted successfully"
        return {"status": "success", "message": success_message}
    except Exception as e:
        verbose_logger.exception(f"Error deleting secret from Hashicorp Vault: {e}")
        return {"status": "error", "message": str(e)}
|
||||
|
||||
def _get_secret_value_from_json_response(
|
||||
self, json_resp: Optional[dict]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the secret value from the JSON response
|
||||
|
||||
Json response from hashicorp vault is of the form:
|
||||
|
||||
{
|
||||
"request_id":"036ba77c-018b-31dd-047b-323bcd0cd332",
|
||||
"lease_id":"",
|
||||
"renewable":false,
|
||||
"lease_duration":0,
|
||||
"data":
|
||||
{"data":
|
||||
{"key":"Vault Is The Way"},
|
||||
"metadata":{"created_time":"2025-01-01T22:13:50.93942388Z","custom_metadata":null,"deletion_time":"","destroyed":false,"version":1}
|
||||
},
|
||||
"wrap_info":null,
|
||||
"warnings":null,
|
||||
"auth":null,
|
||||
"mount_type":"kv"
|
||||
}
|
||||
|
||||
Note: LiteLLM assumes that all secrets are stored as under the key "key"
|
||||
"""
|
||||
if json_resp is None:
|
||||
return None
|
||||
return json_resp.get("data", {}).get("data", {}).get("key", None)
|
||||
@@ -0,0 +1,290 @@
|
||||
import ast
|
||||
import os
|
||||
import traceback
|
||||
from typing import Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
from litellm.secret_managers.secret_manager_handler import get_secret_from_manager
|
||||
|
||||
oidc_cache = DualCache()
|
||||
|
||||
|
||||
def _get_oidc_http_handler(timeout: Optional[httpx.Timeout] = None) -> HTTPHandler:
    """
    Create the HTTPHandler used for OIDC requests.

    Kept as a separate factory so it can be mocked in tests.

    Args:
        timeout: Optional timeout for HTTP requests. When omitted, a 600.0
            second timeout with a 5.0 second connect timeout is used.

    Returns:
        HTTPHandler instance configured with the chosen timeout.
    """
    effective_timeout = (
        timeout if timeout is not None else httpx.Timeout(timeout=600.0, connect=5.0)
    )
    return HTTPHandler(timeout=effective_timeout)
|
||||
|
||||
|
||||
######### Secret Manager ############################
|
||||
# checks if user has passed in a secret manager client
|
||||
# if passed in then checks the secret there
|
||||
def str_to_bool(value: Optional[str]) -> Optional[bool]:
    """
    Convert a string to a boolean when it spells one.

    Matching ignores surrounding whitespace and letter case; only "true"
    and "false" are recognised.

    :param value: The string to be checked.
    :return: True or False if the string is a recognized boolean, otherwise None.
    """
    if value is None:
        return None

    recognised = {"true": True, "false": False}
    return recognised.get(value.strip().lower())
|
||||
|
||||
|
||||
def get_secret_str(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
) -> Optional[str]:
    """
    Narrow the result of 'get_secret' to a string or None (any other type becomes None).

    Used for fixing linting errors.
    """
    resolved = get_secret(secret_name=secret_name, default_value=default_value)
    if resolved is None or isinstance(resolved, str):
        return resolved
    return None
|
||||
|
||||
|
||||
def get_secret_bool(
    secret_name: str,
    default_value: Optional[bool] = None,
) -> Optional[bool]:
    """
    Narrow the result of 'get_secret' to a boolean or None. Used for fixing linting errors.

    Args:
        secret_name: The name of the secret to get.
        default_value: The default value passed through to 'get_secret'.

    Returns:
        None when 'get_secret' returns None, the value itself when it is
        already a boolean, otherwise the result of 'str_to_bool' on it.
    """
    raw_value = get_secret(secret_name, default_value)
    if raw_value is None:
        return None
    if isinstance(raw_value, bool):
        return raw_value
    return str_to_bool(raw_value)
|
||||
|
||||
|
||||
def _get_oidc_token_for_secret(secret_name: str) -> Optional[str]:
    """
    Resolve an ``oidc/<provider>/<audience>`` reference to an OIDC token.

    Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke

    Raises:
        ValueError: For an unsupported provider, a missing environment
            variable, or a failed token request.
    """
    # Strip only the leading marker. The audience is often a URL or file
    # path and may itself contain "oidc/", which must be left intact.
    provider_and_audience = secret_name[len("oidc/") :]
    # partition() tolerates a reference without an audience part
    # (e.g. "oidc/circleci"); the audience is then the empty string.
    oidc_provider, _, oidc_aud = provider_and_audience.partition("/")

    # TODO: Add caching for HTTP requests
    if oidc_provider == "google":
        oidc_token = oidc_cache.get_cache(key=secret_name)
        if oidc_token is not None:
            return oidc_token

        oidc_client = _get_oidc_http_handler()
        # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
        response = oidc_client.get(
            "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
            params={"audience": oidc_aud},
            headers={"Metadata-Flavor": "Google"},
        )
        if response.status_code == 200:
            oidc_token = response.text
            oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
            return oidc_token
        raise ValueError("Google OIDC provider failed")

    if oidc_provider == "circleci":
        # https://circleci.com/docs/openid-connect-tokens/
        env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
        if env_secret is None:
            raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
        return env_secret

    if oidc_provider == "circleci_v2":
        # https://circleci.com/docs/openid-connect-tokens/
        env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
        if env_secret is None:
            raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
        return env_secret

    if oidc_provider == "github":
        # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
        actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
        actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
        if (
            actions_id_token_request_url is None
            or actions_id_token_request_token is None
        ):
            raise ValueError(
                "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
            )

        oidc_token = oidc_cache.get_cache(key=secret_name)
        if oidc_token is not None:
            return oidc_token

        oidc_client = _get_oidc_http_handler()
        response = oidc_client.get(
            actions_id_token_request_url,
            params={"audience": oidc_aud},
            headers={
                "Authorization": f"Bearer {actions_id_token_request_token}",
                "Accept": "application/json; api-version=2.0",
            },
        )
        if response.status_code == 200:
            oidc_token = response.json().get("value", None)
            oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
            return oidc_token
        raise ValueError("Github OIDC provider failed")

    if oidc_provider == "azure":
        # https://azure.github.io/azure-workload-identity/docs/quick-start.html
        azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
        if azure_federated_token_file is None:
            verbose_logger.warning(
                "AZURE_FEDERATED_TOKEN_FILE not found in environment will use Azure AD token provider"
            )
            azure_token_provider = get_azure_ad_token_provider(azure_scope=oidc_aud)
            try:
                oidc_token = azure_token_provider()
                if oidc_token is None:
                    raise ValueError("Azure OIDC provider returned None token")
                return oidc_token
            except Exception as e:
                error_msg = f"Azure OIDC provider failed: {str(e)}"
                verbose_logger.error(error_msg)
                raise ValueError(error_msg) from e
        with open(azure_federated_token_file, "r") as f:
            return f.read()

    if oidc_provider == "file":
        # Load token from a file
        with open(oidc_aud, "r") as f:
            return f.read()

    if oidc_provider == "env":
        # Load token directly from an environment variable
        oidc_token = os.getenv(oidc_aud)
        if oidc_token is None:
            raise ValueError(f"Environment variable {oidc_aud} not found")
        return oidc_token

    if oidc_provider == "env_path":
        # Load token from a file path specified in an environment variable
        token_file_path = os.getenv(oidc_aud)
        if token_file_path is None:
            raise ValueError(f"Environment variable {oidc_aud} not found")
        with open(token_file_path, "r") as f:
            return f.read()

    raise ValueError("Unsupported OIDC provider")


def get_secret(  # noqa: PLR0915
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
):
    """
    Resolve a secret by name.

    Resolution order:
    1. A leading "os.environ/" marker is stripped from the name.
    2. Names starting with "oidc/" are resolved to an OIDC token.
    3. If a secret manager client is configured and readable, the secret is
       fetched through ``get_secret_from_manager``; on error the environment
       variable of the same name is used instead.
    4. Otherwise the environment variable of the same name is read.

    String values that spell a boolean are returned as ``bool``.

    Args:
        secret_name: Name of the secret, optionally with one of the prefixes above.
        default_value: Returned only when resolution raises and this is not None.

    Returns:
        The secret as ``str`` or ``bool``, or None when it is not found.

    Raises:
        ValueError: From OIDC resolution (unsupported provider, missing
            environment variable, failed token request).
        Exception: Any error from step 3 or 4 when no default_value is given.
    """
    if secret_name.startswith("os.environ/"):
        # Remove the prefix only; replace() would also delete later occurrences.
        secret_name = secret_name[len("os.environ/") :]

    if secret_name.startswith("oidc/"):
        return _get_oidc_token_for_secret(secret_name)

    # Read after the OIDC branch, which does not need these settings.
    key_management_system = litellm._key_management_system
    key_management_settings = litellm._key_management_settings
    secret = None

    try:
        if (
            _should_read_secret_from_secret_manager()
            and litellm.secret_manager_client is not None
        ):
            try:
                client = litellm.secret_manager_client
                key_manager = "local"
                if key_management_system is not None:
                    key_manager = key_management_system.value

                if key_management_settings is not None:
                    if (
                        key_management_settings.hosted_keys is not None
                        and secret_name not in key_management_settings.hosted_keys
                    ):  # allow user to specify which keys to check in hosted key manager
                        key_manager = "local"

                # Delegate to the secret manager handler
                secret = get_secret_from_manager(
                    client=client,
                    key_manager=key_manager,
                    secret_name=secret_name,
                    key_management_settings=key_management_settings,
                )
            except Exception as e:  # check if it's in os.environ
                verbose_logger.error(
                    f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
                )
                secret = os.getenv(secret_name)

            try:
                if isinstance(secret, str):
                    # literal_eval only parses Python literals; it is used
                    # here to detect "True"/"False" and does not run code.
                    secret_value_as_bool = ast.literal_eval(secret)
                    if isinstance(secret_value_as_bool, bool):
                        return secret_value_as_bool
                    return secret
            except Exception:
                return secret
            # A non-string result (typically None) yields None, as before.
            return None
        else:
            secret = os.environ.get(secret_name)
            secret_value_as_bool = str_to_bool(secret) if secret is not None else None
            if secret_value_as_bool is not None and isinstance(
                secret_value_as_bool, bool
            ):
                return secret_value_as_bool
            return secret
    except Exception as e:
        if default_value is not None:
            return default_value
        raise e
|
||||
|
||||
|
||||
def _should_read_secret_from_secret_manager() -> bool:
    """
    Decide whether secrets should be read from the secret manager.

    - False when no secret manager client is set
    - False when no `_key_management_settings` are set
    - True when the settings' access mode is "read_only" or "read_and_write"
    - False for any other access mode
    """
    if litellm.secret_manager_client is None:
        return False

    settings = litellm._key_management_settings
    if settings is None:
        return False

    return settings.access_mode in ("read_only", "read_and_write")
|
||||
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Secret Manager Handler
|
||||
|
||||
Handles retrieving secrets from different secret management systems.
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.types.secret_managers.main import KeyManagementSystem
|
||||
|
||||
|
||||
def _is_base64(s):
|
||||
"""Check if a string is valid base64."""
|
||||
import binascii
|
||||
|
||||
try:
|
||||
return base64.b64encode(base64.b64decode(s)).decode() == s
|
||||
except binascii.Error:
|
||||
return False
|
||||
|
||||
|
||||
def get_secret_from_manager(  # noqa: PLR0915
    client: Any,
    key_manager: str,
    secret_name: str,
    key_management_settings: Optional[Any] = None,
) -> Optional[str]:
    """
    Get a secret from the configured secret manager.

    Dispatches on `key_manager` (and, for Azure Key Vault and Google KMS, on
    the class of `client`) to the matching retrieval strategy. Secret values
    are never written to the verbose log; only the secret name and whether a
    value was found are logged.

    Args:
        client: The secret manager client instance
        key_manager: The type of key manager (e.g., "azure_key_vault", "google_kms", etc.)
        secret_name: The name/path of the secret to retrieve. For Google KMS and
            AWS KMS this is the name of the environment variable holding the
            base64-encoded ciphertext.
        key_management_settings: Optional settings for the key management system

    Returns:
        The secret value as a string. None can be returned by the "local"
        branch (environment variable not set), by the AWS Secrets Manager
        branch (client of an unexpected type, or whatever `sync_read_secret`
        returns), and by any client call that itself returns None without a
        dedicated check.

    Raises:
        ValueError: If the Google KMS ciphertext is missing from the environment
            or is not base64, if no secret is found in Google Secret Manager /
            Hashicorp / CyberArk / the custom secret manager, or if the custom
            secret manager client has the wrong type.
        Exception: If the AWS KMS ciphertext is missing from the environment.
            Errors raised by the underlying client calls propagate unchanged.
    """
    secret = None

    if (
        key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
        or type(client).__module__ + "." + type(client).__name__
        == "azure.keyvault.secrets._client.SecretClient"
    ):  # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
        secret = client.get_secret(secret_name).value

    elif (
        key_manager == KeyManagementSystem.GOOGLE_KMS.value
        or client.__class__.__name__ == "KeyManagementServiceClient"
    ):
        encrypted_secret: Any = os.getenv(secret_name)
        if encrypted_secret is None:
            raise ValueError(
                "Google KMS requires the encrypted secret to be in the environment!"
            )
        if _is_base64(encrypted_secret) is not True:
            # Only base64-encoded ciphertext is accepted.
            # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
            raise ValueError(
                "Google KMS requires the encrypted secret to be encoded in base64"
            )
        ciphertext = base64.b64decode(encrypted_secret)
        response = client.decrypt(
            request={
                "name": litellm._google_kms_resource_name,
                "ciphertext": ciphertext,
            }
        )
        # assumes the original value was encoded with utf-8
        secret = response.plaintext.decode("utf-8")

    elif key_manager == KeyManagementSystem.AWS_KMS.value:
        # Original note: only tokens starting with 'aws_kms/' should be checked,
        # to avoid the latency of checking all keys.
        # NOTE(review): no prefix check is visible in this function - presumably
        # the caller filters on it; confirm.
        encrypted_value = os.getenv(secret_name, None)
        if encrypted_value is None:
            raise Exception(
                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
            )
        # Decode the base64 encoded ciphertext
        ciphertext_blob = base64.b64decode(encrypted_value)

        # Perform the decryption
        response = client.decrypt(CiphertextBlob=ciphertext_blob)

        # Extract and decode the plaintext, dropping surrounding whitespace
        secret = response["Plaintext"].decode("utf-8").strip()

    elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
        # Imported inside the branch, as in the original
        # (presumably to avoid import cycles / optional dependencies).
        from litellm.secret_managers.aws_secret_manager_v2 import (
            AWSSecretsManagerV2,
        )

        # Any other client type is left unhandled here and `secret` stays None.
        if isinstance(client, AWSSecretsManagerV2):
            primary_secret_name = None
            if key_management_settings is not None:
                primary_secret_name = key_management_settings.primary_secret_name

            secret = client.sync_read_secret(
                secret_name=secret_name,
                primary_secret_name=primary_secret_name,
            )
            # Never log the secret value itself - only whether one was found.
            print_verbose(
                f"AWS Secrets Manager lookup for {secret_name}: found={secret is not None}"
            )

    elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
        try:
            secret = client.get_secret_from_google_secret_manager(secret_name)
            # Never log the secret value itself - only whether one was found.
            print_verbose(
                f"Google Secret Manager lookup for {secret_name}: found={secret is not None}"
            )
            if secret is None:
                raise ValueError(
                    f"No secret found in Google Secret Manager for {secret_name}"
                )
        except Exception as e:
            print_verbose(f"An error occurred - {str(e)}")
            raise

    elif key_manager == KeyManagementSystem.HASHICORP_VAULT.value:
        try:
            secret = client.sync_read_secret(secret_name=secret_name)
            if secret is None:
                raise ValueError(
                    f"No secret found in Hashicorp Secret Manager for {secret_name}"
                )
        except Exception as e:
            print_verbose(f"An error occurred - {str(e)}")
            raise

    elif key_manager == KeyManagementSystem.CYBERARK.value:
        try:
            secret = client.sync_read_secret(secret_name=secret_name)
            if secret is None:
                raise ValueError(
                    f"No secret found in CyberArk Secret Manager for {secret_name}"
                )
        except Exception as e:
            print_verbose(f"An error occurred - {str(e)}")
            raise

    elif key_manager == KeyManagementSystem.CUSTOM.value:
        # Check if client is a CustomSecretManager instance
        from litellm.integrations.custom_secret_manager import CustomSecretManager

        if not isinstance(client, CustomSecretManager):
            raise ValueError(
                f"Custom secret manager client must be an instance of CustomSecretManager, got {type(client).__name__}"
            )

        optional_params = (
            key_management_settings.model_dump() if key_management_settings else None
        )
        secret = client.sync_read_secret(
            secret_name=secret_name,
            optional_params=optional_params,
        )
        if secret is None:
            raise ValueError(
                f"No secret found in Custom Secret Manager for {secret_name}"
            )

    elif key_manager == "local":
        secret = os.getenv(secret_name)

    else:  # assume the default is the Infisical client
        secret = client.get_secret(secret_name).secret_value

    return secret
|
||||
Reference in New Issue
Block a user