Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/bedrock/base_aws_llm.py
2026-03-26 20:06:14 +08:00

1403 lines
54 KiB
Python

import hashlib
import json
import os
import urllib.parse
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
cast,
get_args,
)
import httpx
from pydantic import BaseModel
from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.constants import (
BEDROCK_EMBEDDING_PROVIDERS_LITERAL,
BEDROCK_INVOKE_PROVIDERS_LITERAL,
BEDROCK_MAX_POLICY_SIZE,
)
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.secret_managers.main import get_secret, get_secret_str
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
from botocore.credentials import Credentials
else:
Credentials = Any
AWSPreparedRequest = Any
# Pydantic model bundling resolved AWS auth details: the credentials object,
# the region name, and an optional Bedrock runtime endpoint.
class Boto3CredentialsInfo(BaseModel):
    # Credentials object. The `Credentials` name is aliased to `Any` outside of
    # type checking (see the TYPE_CHECKING guard above this class).
    credentials: Credentials
    # AWS region name.
    aws_region_name: str
    # Bedrock runtime endpoint; annotated Optional but declared without a default.
    aws_bedrock_runtime_endpoint: Optional[str]
class AwsAuthError(Exception):
    """
    Exception raised for AWS authentication failures.

    Attributes:
        status_code: status code supplied by the raiser.
        message: error text, also passed to the base `Exception`.
        request: placeholder `httpx.Request` (POST to a fixed Bedrock console URL).
        response: `httpx.Response` built from `status_code` and the placeholder request.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # A fixed placeholder request so the exception exposes request/response
        # attributes like an HTTP error would.
        placeholder_request = httpx.Request(
            method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
        )
        self.request = placeholder_request
        self.response = httpx.Response(
            status_code=status_code, request=placeholder_request
        )
        # Hand the message to the base Exception so str(exc) returns it.
        super().__init__(message)
class BaseAWSLLM:
def __init__(self) -> None:
    """
    Set up the per-instance credential cache and the list of AWS auth parameter names.
    """
    credential_cache = DualCache()
    self.iam_cache = credential_cache
    super().__init__()
    # Names of the AWS authentication-related parameters, in a fixed order.
    auth_param_names = [
        # static credentials
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        # region / session / profile
        "aws_region_name",
        "aws_session_name",
        "aws_profile_name",
        # role assumption
        "aws_role_name",
        "aws_web_identity_token",
        "aws_sts_endpoint",
        # endpoint override and external id
        "aws_bedrock_runtime_endpoint",
        "aws_external_id",
    ]
    self.aws_authentication_params = auth_param_names
def _get_ssl_verify(self, ssl_verify: Optional[Union[bool, str]] = None):
    """
    Resolve the SSL verification setting to pass to boto3 clients.

    Delegates to `get_ssl_verify` from litellm's http handler so AWS clients
    (STS, Bedrock) share the same verification configuration.

    Args:
        ssl_verify: optional override forwarded to the shared helper.

    Returns:
        Whatever `get_ssl_verify` returns - documented by the original code as
        False to disable, True to enable, or a string path to a CA bundle file.
    """
    # Imported lazily, as in the rest of this class's boto3/httpx usage.
    from litellm.llms.custom_httpx.http_handler import (
        get_ssl_verify as _resolve_ssl_verify,
    )

    resolved = _resolve_ssl_verify(ssl_verify=ssl_verify)
    return resolved
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
    """
    Build a deterministic cache key from the credential arguments.

    The arguments are serialized to JSON with sorted keys (so key order does
    not matter) and the SHA-256 hex digest of that text is returned.
    """
    canonical_json = json.dumps(credential_args, sort_keys=True)
    hasher = hashlib.sha256()
    hasher.update(canonical_json.encode())
    return hasher.hexdigest()
@tracer.wrap()
def get_credentials(
    self,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_session_token: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    aws_sts_endpoint: Optional[str] = None,
    aws_external_id: Optional[str] = None,
    ssl_verify: Optional[Union[bool, str]] = None,
):
    """
    Return a boto3.Credentials object

    Each `aws_*` argument is first resolved: a value starting with
    "os.environ/" is looked up via `get_secret`; a missing value falls back
    to the upper-cased environment variable of the same name (for example
    `aws_external_id` -> `AWS_EXTERNAL_ID`).

    The resolved arguments (plus `ssl_verify`) are hashed into a cache key;
    a cached hit is returned directly. Otherwise the first matching flow is
    used, in this order:
      1. web identity token + role + session name
      2. role (ambient credentials if already running as that role,
         otherwise role assumption)
      3. named profile
      4. access key + secret key + session token
      5. access key + secret key + region
      6. ambient credentials from a default boto3 session

    The result is stored in `self.iam_cache` with the TTL reported by the flow.
    """

    def _mask(secret: Optional[str]) -> Optional[str]:
        # Keep None/empty visible so "was it set?" stays debuggable without
        # writing the secret itself to the log.
        return "***REDACTED***" if secret else secret

    ## CHECK IS 'os.environ/' passed in
    # Each value is keyed by its own parameter name. Indexing into
    # self.aws_authentication_params by position is wrong: that list also
    # contains "aws_bedrock_runtime_endpoint", so the last position does not
    # line up and aws_external_id would be read from
    # AWS_BEDROCK_RUNTIME_ENDPOINT.
    resolved_params: Dict[str, Optional[str]] = {
        "aws_access_key_id": aws_access_key_id,
        "aws_secret_access_key": aws_secret_access_key,
        "aws_session_token": aws_session_token,
        "aws_region_name": aws_region_name,
        "aws_session_name": aws_session_name,
        "aws_profile_name": aws_profile_name,
        "aws_role_name": aws_role_name,
        "aws_web_identity_token": aws_web_identity_token,
        "aws_sts_endpoint": aws_sts_endpoint,
        "aws_external_id": aws_external_id,
    }
    for param_name, param_value in list(resolved_params.items()):
        if param_value and param_value.startswith("os.environ/"):
            _v = get_secret(param_value)
            if _v is not None and isinstance(_v, str):
                resolved_params[param_name] = _v
        elif param_value is None:  # check if uppercase value in env
            env_name = param_name.upper()
            if env_name in os.environ:
                resolved_params[param_name] = os.getenv(env_name)

    # Assign updated values back to parameters
    aws_access_key_id = resolved_params["aws_access_key_id"]
    aws_secret_access_key = resolved_params["aws_secret_access_key"]
    aws_session_token = resolved_params["aws_session_token"]
    aws_region_name = resolved_params["aws_region_name"]
    aws_session_name = resolved_params["aws_session_name"]
    aws_profile_name = resolved_params["aws_profile_name"]
    aws_role_name = resolved_params["aws_role_name"]
    aws_web_identity_token = resolved_params["aws_web_identity_token"]
    aws_sts_endpoint = resolved_params["aws_sts_endpoint"]
    aws_external_id = resolved_params["aws_external_id"]

    verbose_logger.debug(
        "in get credentials\n"
        "aws_access_key_id=%s\n"
        "aws_secret_access_key=%s\n"
        "aws_session_token=%s\n"
        "aws_region_name=%s\n"
        "aws_session_name=%s\n"
        "aws_profile_name=%s\n"
        "aws_role_name=%s\n"
        "aws_web_identity_token=%s\n"
        "aws_sts_endpoint=%s\n"
        "aws_external_id=%s",
        aws_access_key_id,
        _mask(aws_secret_access_key),
        _mask(aws_session_token),
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
        aws_sts_endpoint,
        aws_external_id,
    )

    # create cache key for non-expiring auth flows
    # (same key set as before: every aws_* argument plus ssl_verify)
    args = {**resolved_params, "ssl_verify": ssl_verify}
    cache_key = self.get_cache_key(args)
    _cached_credentials = self.iam_cache.get_cache(cache_key)
    if _cached_credentials:
        return _cached_credentials

    #########################################################
    # Handle diff boto3 auth flows
    # for each helper
    # Return:
    #   Credentials - boto3.Credentials
    #   cache ttl - Optional[int]. If None, the credentials are not cached. Some auth flows have no expiry time.
    #########################################################
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        credentials, _cache_ttl = self._auth_with_web_identity_token(
            aws_web_identity_token=aws_web_identity_token,
            aws_role_name=aws_role_name,
            aws_session_name=aws_session_name,
            aws_region_name=aws_region_name,
            aws_sts_endpoint=aws_sts_endpoint,
            aws_external_id=aws_external_id,
            # Forward the TLS setting like the role-assumption flow does.
            ssl_verify=ssl_verify,
        )
    elif aws_role_name is not None:
        # Check if we're already running as the target role and can skip assumption
        # This handles IRSA (EKS), ECS task roles, and EC2 instance profiles
        if self._is_already_running_as_role(aws_role_name, ssl_verify=ssl_verify):
            verbose_logger.debug(
                "Already running as target role %s, using ambient credentials",
                aws_role_name,
            )
            credentials, _cache_ttl = self._auth_with_env_vars()
        else:
            verbose_logger.debug(
                "Using role assumption: calling _auth_with_aws_role"
            )
            # If aws_session_name is not provided, generate a default one
            if aws_session_name is None:
                aws_session_name = (
                    f"litellm-session-{int(datetime.now().timestamp())}"
                )
            credentials, _cache_ttl = self._auth_with_aws_role(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                aws_role_name=aws_role_name,
                aws_session_name=aws_session_name,
                aws_region_name=aws_region_name,
                aws_sts_endpoint=aws_sts_endpoint,
                aws_external_id=aws_external_id,
                ssl_verify=ssl_verify,
            )
    elif aws_profile_name is not None:  ### CHECK SESSION ###
        credentials, _cache_ttl = self._auth_with_aws_profile(aws_profile_name)
    elif (
        aws_access_key_id is not None
        and aws_secret_access_key is not None
        and aws_session_token is not None
    ):
        credentials, _cache_ttl = self._auth_with_aws_session_token(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
    elif (
        aws_access_key_id is not None
        and aws_secret_access_key is not None
        and aws_region_name is not None
    ):
        credentials, _cache_ttl = self._auth_with_access_key_and_secret_key(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_region_name=aws_region_name,
        )
    else:
        credentials, _cache_ttl = self._auth_with_env_vars()

    self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
    return credentials
def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
try:
# First check if the string contains the expected prefix
if not isinstance(model, str) or "arn:aws:bedrock" not in model:
return None
# Split the ARN and check if we have enough parts
parts = model.split(":")
if len(parts) < 4:
return None
# Get the region from the correct position
region = parts[3]
if not region: # Check if region is empty
return None
return region
except Exception:
# Catch any unexpected errors and return None
return None
@staticmethod
def _get_provider_from_model_path(
    model_path: str,
) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
    """
    Read the provider from a `provider/model-name` style path.

    Args:
        model_path (str): e.g. 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n'
            or 'anthropic/model-name'
    Returns:
        The text before the first "/" if it is one of the known Bedrock
        invoke providers, otherwise None.
    """
    # Text before the first "/" (the whole string when there is no "/").
    leading_segment = model_path.partition("/")[0]
    if leading_segment not in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
        return None
    return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, leading_segment)
@staticmethod
def get_bedrock_invoke_provider(
    model: str,
) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
    """
    Work out the Bedrock invoke provider for a model string.

    Handles 4 scenarios:
    1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
    2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
    3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
    4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`

    Returns None when no known provider can be found.
    """
    known_providers = get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL)

    # Drop a leading "invoke/" route marker.
    if model.startswith("invoke/"):
        model = model[len("invoke/") :]

    # "nova" is checked before anything else so that amazon.nova-* models are
    # not classified as "amazon" (Titan).
    if "nova" in model.lower() and "nova" in known_providers:
        return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, "nova")

    # Standard `provider.model-name` form: the text before the first ".".
    dotted_prefix = model.split(".")[0]
    if dotted_prefix in known_providers:
        return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, dotted_prefix)

    # `provider/model-name` form.
    path_provider = BaseAWSLLM._get_provider_from_model_path(model)
    if path_provider is not None:
        return path_provider

    # Last resort: first known provider name that appears anywhere in the model.
    return next(
        (candidate for candidate in known_providers if candidate in model), None
    )
@staticmethod
def get_bedrock_model_id(
    optional_params: dict,
    provider: Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL],
    model: str,
) -> str:
    """
    Resolve the model id to send to Bedrock.

    An explicit `model_id` in `optional_params` (removed from the dict by this
    call) takes priority and is URL-encoded; otherwise `model` is used. The
    first "invoke/" is stripped, and a `<spec>/` marker for imported/custom
    models is removed (with the remainder URL-encoded) when it applies.
    """
    explicit_model_id = optional_params.pop("model_id", None)
    if explicit_model_id is None:
        model_id = model
    else:
        model_id = BaseAWSLLM.encode_model_id(model_id=explicit_model_id)

    model_id = model_id.replace("invoke/", "", 1)

    # Specs that are only honoured when they match the detected provider.
    provider_bound_specs = (
        "llama",
        "deepseek_r1",
        "openai",
        "qwen2",
        "qwen3",
        "stability",
        "moonshot",
    )
    for spec in provider_bound_specs:
        if provider == spec and f"{spec}/" in model_id:
            return BaseAWSLLM._get_model_id_from_model_with_spec(
                model_id, spec=spec
            )

    # Nova specs apply regardless of provider; "nova-2" must be tried first.
    for spec in ("nova-2", "nova"):
        if f"{spec}/" in model_id:
            return BaseAWSLLM._get_model_id_from_model_with_spec(
                model_id, spec=spec
            )

    return model_id
@staticmethod
def _get_model_id_from_model_with_spec(
    model: str,
    spec: str,
) -> str:
    """
    Strip the `<spec>/` marker (e.g. `llama/`) from a model id and URL-encode the rest.

    The spec only tells litellm which request format a custom Bedrock model
    follows; it is not part of the id sent to Bedrock. Every occurrence of
    `<spec>/` in `model` is removed.
    """
    spec_marker = spec + "/"
    # Splitting on the marker and re-joining drops every occurrence of it.
    bare_model_id = "".join(model.split(spec_marker))
    return BaseAWSLLM.encode_model_id(model_id=bare_model_id)
@staticmethod
def encode_model_id(model_id: str) -> str:
"""
Double encode the model ID to ensure it matches the expected double-encoded format.
Args:
model_id (str): The model ID to encode.
Returns:
str: The double-encoded model ID.
"""
return urllib.parse.quote(model_id, safe="")
@staticmethod
def get_bedrock_embedding_provider(
    model: str,
) -> Optional[BEDROCK_EMBEDDING_PROVIDERS_LITERAL]:
    """
    Work out the Bedrock embedding provider for a model string.

    Examples:
    1. model=cohere.embed-english-v3:0 -> Returns `cohere`
    2. model=amazon.titan-embed-text-v1 -> Returns `amazon`
    3. model=amazon.nova-2-multimodal-embeddings-v1:0 -> Returns `nova`
    4. model=us.twelvelabs.marengo-embed-2-7-v1:0 -> Returns `twelvelabs`
    5. model=twelvelabs.marengo-embed-2-7-v1:0 -> Returns `twelvelabs`

    Returns None when no known provider can be found.
    """
    known_providers = get_args(BEDROCK_EMBEDDING_PROVIDERS_LITERAL)

    # "nova" is checked first so amazon.nova-* models are not classified as "amazon".
    if "nova" in model.lower() and "nova" in known_providers:
        return cast(BEDROCK_EMBEDDING_PROVIDERS_LITERAL, "nova")

    if "." in model:
        segments = model.split(".")
        # Try the second segment first (regional form such as
        # "us.twelvelabs.marengo-..."), then the first segment (standard form
        # such as "cohere.embed-english-v3:0"). A "." in the model guarantees
        # at least two segments.
        for candidate in (segments[1], segments[0]):
            if candidate in known_providers:
                return cast(BEDROCK_EMBEDDING_PROVIDERS_LITERAL, candidate)

    # Fallback: first known provider name that appears anywhere in the model.
    for candidate in known_providers:
        if candidate in model:
            return cast(BEDROCK_EMBEDDING_PROVIDERS_LITERAL, candidate)
    return None
def _get_aws_region_name(
self,
optional_params: dict,
model: Optional[str] = None,
model_id: Optional[str] = None,
) -> str:
"""
Get the AWS region name from the environment variables.
Parameters:
optional_params (dict): Optional parameters for the model call
model (str): The model name
model_id (str): The model ID. This is the ARN of the model, if passed in as a separate param.
Returns:
str: The AWS region name
"""
aws_region_name = optional_params.get("aws_region_name", None)
### SET REGION NAME ###
if aws_region_name is None:
# check model arn #
if model_id is not None:
aws_region_name = self._get_aws_region_from_model_arn(model_id)
else:
aws_region_name = self._get_aws_region_from_model_arn(model)
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if (
aws_region_name is None
and litellm_aws_region_name is not None
and isinstance(litellm_aws_region_name, str)
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if (
aws_region_name is None
and standard_aws_region_name is not None
and isinstance(standard_aws_region_name, str)
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
try:
import boto3
with tracer.trace("boto3.Session()"):
session = boto3.Session()
configured_region = session.region_name
if configured_region:
aws_region_name = configured_region
else:
aws_region_name = "us-west-2"
except Exception:
aws_region_name = "us-west-2"
return aws_region_name
def get_aws_region_name_for_non_llm_api_calls(
    self,
    aws_region_name: Optional[str] = None,
):
    """
    Get the AWS region name for non-llm api calls.

    LLM API calls check the model arn and end up using that as the region name.
    For non-llm api calls eg. Guardrails, Vector Stores we just need to check the dynamic param or env vars.

    Lookup order (first hit wins):
      1. the `aws_region_name` argument
      2. the `AWS_REGION_NAME` secret / env var
      3. the `AWS_REGION` secret / env var
      4. "us-west-2"

    Returns:
        str: The AWS region name
    """
    if aws_region_name is None:
        # check env #
        # AWS_REGION_NAME takes priority over AWS_REGION, matching the
        # precedence used by _get_aws_region_name for LLM calls. Previously
        # AWS_REGION overwrote AWS_REGION_NAME here, so LLM and non-LLM calls
        # could resolve different regions from the same environment.
        for secret_name in ("AWS_REGION_NAME", "AWS_REGION"):
            secret_value = get_secret(secret_name, None)
            if secret_value is not None and isinstance(secret_value, str):
                aws_region_name = secret_value
                break

    if aws_region_name is None:
        aws_region_name = "us-west-2"
    return aws_region_name
@staticmethod
def _parse_arn_account_and_role_name(
arn: str,
) -> Optional[Tuple[str, str, str]]:
"""
Parse an ARN and return (partition, account_id, role_name).
Handles:
- arn:aws:iam::123456789012:role/MyRole
- arn:aws:iam::123456789012:role/path/to/MyRole
- arn:aws:sts::123456789012:assumed-role/MyRole/session-name
Returns None if the ARN cannot be parsed.
"""
# ARN format: arn:PARTITION:SERVICE:REGION:ACCOUNT:RESOURCE
parts = arn.split(":")
if len(parts) < 6 or parts[0] != "arn":
return None
partition = parts[1] # e.g. "aws", "aws-cn", "aws-us-gov"
account_id = parts[4]
resource = ":".join(parts[5:]) # rejoin in case resource contains colons
if resource.startswith("role/"):
# arn:aws:iam::ACCOUNT:role/[path/]ROLE_NAME
role_name = resource.split("/")[-1]
elif resource.startswith("assumed-role/"):
# arn:aws:sts::ACCOUNT:assumed-role/ROLE_NAME/SESSION
role_parts = resource.split("/")
if len(role_parts) >= 2:
role_name = role_parts[1]
else:
return None
else:
return None
return partition, account_id, role_name
def _is_already_running_as_role(
    self,
    aws_role_name: str,
    ssl_verify: Optional[Union[bool, str]] = None,
) -> bool:
    """
    Tell whether the process already runs as the target IAM role.

    Two checks are made:
    - IRSA (EKS): when both AWS_ROLE_ARN and AWS_WEB_IDENTITY_TOKEN_FILE are
      set, the answer is simply whether AWS_ROLE_ARN equals `aws_role_name`
      (no API call).
    - Otherwise (e.g. ECS task roles, EC2 instance profiles):
      sts:GetCallerIdentity is called and the caller's partition, account ID
      and role name are compared with the target's, which avoids
      cross-account false matches.

    Returns True when the identities match, meaning sts:AssumeRole can be
    skipped and ambient credentials used. Returns False when the target ARN
    cannot be parsed or the identity lookup fails.
    """
    target_identity = self._parse_arn_account_and_role_name(aws_role_name)
    if target_identity is None:
        return False

    # Fast path: IRSA environment check (no API call needed)
    irsa_role_arn = os.getenv("AWS_ROLE_ARN")
    irsa_token_file = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
    if irsa_role_arn and irsa_token_file:
        return irsa_role_arn == aws_role_name

    # Slow path: ask STS who we are.
    try:
        import boto3

        with tracer.trace("boto3.client(sts).get_caller_identity"):
            sts_client = boto3.client(
                "sts", verify=self._get_ssl_verify(ssl_verify)
            )
            identity = sts_client.get_caller_identity()

        caller_identity = self._parse_arn_account_and_role_name(
            identity.get("Arn", "")
        )
        # Both are (partition, account_id, role_name) tuples, so tuple
        # equality compares all three fields.
        if caller_identity is not None and caller_identity == target_identity:
            verbose_logger.debug(
                "Current identity already matches target role: %s",
                aws_role_name,
            )
            return True
    except Exception as e:
        # Best effort: any failure here means "cannot confirm".
        verbose_logger.debug(
            "Could not determine current role identity: %s", str(e)
        )
    return False
@tracer.wrap()
def _auth_with_web_identity_token(
    self,
    aws_web_identity_token: str,
    aws_role_name: str,
    aws_session_name: str,
    aws_region_name: Optional[str],
    aws_sts_endpoint: Optional[str],
    aws_external_id: Optional[str] = None,
    ssl_verify: Optional[Union[bool, str]] = None,
) -> Tuple[Credentials, Optional[int]]:
    """
    Authenticate with AWS Web Identity Token

    Resolves the OIDC token via `get_secret(aws_web_identity_token)`, calls
    STS AssumeRoleWithWebIdentity for `aws_role_name` with a one-hour
    duration and an inline session policy restricted to Bedrock invoke
    actions, then builds a boto3 session from the returned temporary keys.

    Args:
        aws_web_identity_token: reference passed to `get_secret` to obtain the OIDC token.
        aws_role_name: ARN of the role to assume.
        aws_session_name: role session name.
        aws_region_name: region for the STS client and the resulting session.
        aws_sts_endpoint: explicit STS endpoint URL; when omitted a regional
            endpoint is derived from `aws_region_name`, and when that is also
            missing boto3 resolves the endpoint itself.
        aws_external_id: accepted for interface compatibility but NOT sent:
            AssumeRoleWithWebIdentity has no ExternalId parameter.
        ssl_verify: TLS verification override for the STS client.

    Returns:
        (credentials, ttl_seconds) where ttl is the default boto3 credential TTL.

    Raises:
        AwsAuthError: (401) when the OIDC token cannot be retrieved.
    """
    import boto3

    verbose_logger.debug(
        f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
    )

    sts_client_kwargs: dict = {
        "region_name": aws_region_name,
        "verify": self._get_ssl_verify(ssl_verify),
    }
    if aws_sts_endpoint is not None:
        sts_client_kwargs["endpoint_url"] = aws_sts_endpoint
    elif aws_region_name is not None:
        sts_client_kwargs["endpoint_url"] = (
            f"https://sts.{aws_region_name}.amazonaws.com"
        )
    # else: no region and no endpoint - let boto3 pick the endpoint instead
    # of building the invalid host "sts.None.amazonaws.com".

    oidc_token = get_secret(aws_web_identity_token)

    if oidc_token is None:
        raise AwsAuthError(
            message="OIDC token could not be retrieved from secret manager.",
            status_code=401,
        )

    with tracer.trace("boto3.client(sts)"):
        sts_client = boto3.client("sts", **sts_client_kwargs)

    # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
    assume_role_params = {
        "RoleArn": aws_role_name,
        "RoleSessionName": aws_session_name,
        "WebIdentityToken": oidc_token,
        "DurationSeconds": 3600,
        "Policy": '{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}',
    }

    # ExternalId is only defined for sts:AssumeRole. Passing it to
    # AssumeRoleWithWebIdentity makes boto3 reject the call with a parameter
    # validation error, so it is deliberately left out here.
    if aws_external_id is not None:
        verbose_logger.warning(
            "aws_external_id is not supported by AssumeRoleWithWebIdentity "
            "and is ignored when assuming role %s",
            aws_role_name,
        )

    sts_response = sts_client.assume_role_with_web_identity(**assume_role_params)

    iam_creds_dict = {
        "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
        "aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
        "aws_session_token": sts_response["Credentials"]["SessionToken"],
        "region_name": aws_region_name,
    }

    if sts_response["PackedPolicySize"] > BEDROCK_MAX_POLICY_SIZE:
        verbose_logger.warning(
            f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
        )

    with tracer.trace("boto3.Session(**iam_creds_dict)"):
        session = boto3.Session(**iam_creds_dict)

    iam_creds = session.get_credentials()
    return iam_creds, self._get_default_ttl_for_boto3_credentials()
def _handle_irsa_cross_account(
    self,
    irsa_role_arn: str,
    aws_role_name: str,
    aws_session_name: str,
    region: str,
    web_identity_token_file: str,
    aws_external_id: Optional[str] = None,
    aws_sts_endpoint: Optional[str] = None,
    ssl_verify: Optional[Union[bool, str]] = None,
) -> dict:
    """
    Handle cross-account role assumption for IRSA.

    Two hops: first the IRSA role is assumed manually with the web identity
    token read from `web_identity_token_file`, then those temporary
    credentials are used to call sts:AssumeRole on `aws_role_name`
    (optionally with `aws_external_id`).

    Returns:
        The raw sts:AssumeRole response for the target role.
    """
    import boto3

    verbose_logger.debug("Cross-account role assumption detected")

    # Read the web identity token from the projected token file.
    with open(web_identity_token_file, "r") as token_file:
        web_identity_token = token_file.read().strip()

    # Settings shared by both STS clients created below.
    sts_kwargs: dict = dict(
        region_name=region,
        verify=self._get_ssl_verify(ssl_verify),
    )
    if aws_sts_endpoint is not None:
        sts_kwargs["endpoint_url"] = aws_sts_endpoint

    # Hop 1: an STS client without explicit credentials assumes the IRSA role.
    with tracer.trace("boto3.client(sts) for manual IRSA"):
        anonymous_sts = boto3.client("sts", **sts_kwargs)

    verbose_logger.debug(
        f"Manually assuming IRSA role {irsa_role_arn} with session {aws_session_name}"
    )
    irsa_creds = anonymous_sts.assume_role_with_web_identity(
        RoleArn=irsa_role_arn,
        RoleSessionName=aws_session_name,
        WebIdentityToken=web_identity_token,
    )["Credentials"]

    # Hop 2: an STS client authenticated with the IRSA role's credentials.
    with tracer.trace("boto3.client(sts) with manual IRSA credentials"):
        irsa_sts = boto3.client(
            "sts",
            aws_access_key_id=irsa_creds["AccessKeyId"],
            aws_secret_access_key=irsa_creds["SecretAccessKey"],
            aws_session_token=irsa_creds["SessionToken"],
            **sts_kwargs,
        )

    # Identity lookup is for debugging only; failures are logged and ignored.
    try:
        current_arn = irsa_sts.get_caller_identity().get("Arn", "unknown")
        verbose_logger.debug(
            f"Current identity after manual IRSA assumption: {current_arn}"
        )
    except Exception as e:
        verbose_logger.debug(f"Failed to get caller identity: {e}")

    verbose_logger.debug(
        f"Attempting to assume target role: {aws_role_name} with session: {aws_session_name}"
    )
    target_role_request = dict(
        RoleArn=aws_role_name,
        RoleSessionName=aws_session_name,
    )
    # ExternalId is sent only when the caller supplied one.
    if aws_external_id is not None:
        target_role_request["ExternalId"] = aws_external_id

    return irsa_sts.assume_role(**target_role_request)
def _handle_irsa_same_account(
    self,
    aws_role_name: str,
    aws_session_name: str,
    region: str,
    aws_external_id: Optional[str] = None,
    aws_sts_endpoint: Optional[str] = None,
    ssl_verify: Optional[Union[bool, str]] = None,
) -> dict:
    """
    Handle same-account role assumption for IRSA.

    Creates an STS client without explicit credentials (relying on boto3's
    automatic IRSA handling) and calls sts:AssumeRole on `aws_role_name`,
    optionally with `aws_external_id`.

    Returns:
        The raw sts:AssumeRole response.
    """
    import boto3

    sts_kwargs: dict = dict(
        region_name=region,
        verify=self._get_ssl_verify(ssl_verify),
    )
    if aws_sts_endpoint is not None:
        sts_kwargs["endpoint_url"] = aws_sts_endpoint

    verbose_logger.debug("Same account role assumption, using automatic IRSA")
    with tracer.trace("boto3.client(sts) with automatic IRSA"):
        ambient_sts = boto3.client("sts", **sts_kwargs)

    # Identity lookup is for debugging only; failures are logged and ignored.
    try:
        current_arn = ambient_sts.get_caller_identity().get("Arn", "unknown")
        verbose_logger.debug(f"Current IRSA identity: {current_arn}")
    except Exception as e:
        verbose_logger.debug(f"Failed to get caller identity: {e}")

    verbose_logger.debug(
        f"Attempting to assume role: {aws_role_name} with session: {aws_session_name}"
    )
    role_request = dict(
        RoleArn=aws_role_name,
        RoleSessionName=aws_session_name,
    )
    # ExternalId is sent only when the caller supplied one.
    if aws_external_id is not None:
        role_request["ExternalId"] = aws_external_id

    return ambient_sts.assume_role(**role_request)
def _extract_credentials_and_ttl(
    self, sts_response: dict
) -> Tuple[Credentials, Optional[int]]:
    """
    Extract credentials and TTL from STS response.

    Args:
        sts_response: an STS response containing a "Credentials" mapping with
            "AccessKeyId", "SecretAccessKey", "SessionToken" and a datetime
            "Expiration".

    Returns:
        (credentials, ttl_seconds). The TTL is the time left until expiration
        minus a 60 second safety margin, so cached credentials are dropped
        shortly before STS expires them. This matches the margin applied in
        `_auth_with_aws_role` and `_get_default_ttl_for_boto3_credentials`.
    """
    from botocore.credentials import Credentials

    sts_credentials = sts_response["Credentials"]
    credentials = Credentials(
        access_key=sts_credentials["AccessKeyId"],
        secret_key=sts_credentials["SecretAccessKey"],
        token=sts_credentials["SessionToken"],
    )

    expiration_time = sts_credentials["Expiration"]
    # Use "now" in the expiration's own timezone so the subtraction never
    # mixes naive and aware datetimes.
    seconds_remaining = (
        expiration_time - datetime.now(expiration_time.tzinfo)
    ).total_seconds()
    # Without the margin, a credential handed out from the cache just before
    # expiry could be rejected mid-request.
    ttl = int(seconds_remaining - 60)
    return credentials, ttl
@tracer.wrap()
def _auth_with_aws_role(
    self,
    aws_access_key_id: Optional[str],
    aws_secret_access_key: Optional[str],
    aws_session_token: Optional[str],
    aws_role_name: str,
    aws_session_name: str,
    aws_region_name: Optional[str] = None,
    aws_sts_endpoint: Optional[str] = None,
    aws_external_id: Optional[str] = None,
    ssl_verify: Optional[Union[bool, str]] = None,
) -> Tuple[Credentials, Optional[int]]:
    """
    Authenticate with AWS Role

    Obtains temporary credentials for `aws_role_name` in one of two ways:

    1. IRSA path - taken when AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are
       set and no explicit access key / secret key was passed. Dispatches to
       the cross-account helper when the target role differs from
       AWS_ROLE_ARN, otherwise to the same-account helper. Any failure on
       this path is re-raised.
    2. Standard path - builds an STS client (with the explicit keys when
       given, otherwise ambient credentials) and calls sts:AssumeRole. On an
       "AccessDenied" error it falls back to ambient credentials only if the
       caller is confirmed to already be the target role; every other
       failure is re-raised.

    Returns:
        (credentials, ttl). On the standard path the TTL is the time left
        until the STS expiration minus 60 seconds.
        NOTE(review): that value is a float although the annotation says
        Optional[int].
    """
    import boto3
    from botocore.credentials import Credentials

    # Check if we're in an EKS/IRSA environment
    web_identity_token_file = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
    irsa_role_arn = os.getenv("AWS_ROLE_ARN")
    # Region precedence: explicit argument, then AWS_REGION, then AWS_DEFAULT_REGION.
    region = (
        aws_region_name
        or os.getenv("AWS_REGION")
        or os.getenv("AWS_DEFAULT_REGION")
    )

    # If we have IRSA environment variables and no explicit credentials,
    # we need to use the web identity token flow
    if (
        web_identity_token_file
        and irsa_role_arn
        and aws_access_key_id is None
        and aws_secret_access_key is None
    ):
        # For cross-account role assumption with specific session names,
        # we need to manually assume the IRSA role first with the correct session name
        verbose_logger.debug(
            f"IRSA detected: using web identity token from {web_identity_token_file}"
        )

        try:
            # Use passed-in region when set, else env, else default (align with AssumeRole path)
            region = region or "us-east-1"

            # Check if we need to do cross-account role assumption
            if aws_role_name != irsa_role_arn:
                sts_response = self._handle_irsa_cross_account(
                    irsa_role_arn,
                    aws_role_name,
                    aws_session_name,
                    region,
                    web_identity_token_file,
                    aws_external_id,
                    aws_sts_endpoint=aws_sts_endpoint,
                    ssl_verify=ssl_verify,
                )
            else:
                sts_response = self._handle_irsa_same_account(
                    aws_role_name,
                    aws_session_name,
                    region,
                    aws_external_id,
                    aws_sts_endpoint=aws_sts_endpoint,
                    ssl_verify=ssl_verify,
                )

            return self._extract_credentials_and_ttl(sts_response)

        except Exception as e:
            verbose_logger.debug(f"Failed to assume role via IRSA: {e}")
            if "AccessDenied" in str(
                e
            ) and "is not authorized to perform: sts:AssumeRole" in str(e):
                # Provide a more helpful error message for trust policy issues
                verbose_logger.error(
                    f"Access denied when trying to assume role {aws_role_name}. "
                    f"Please ensure the trust policy of {aws_role_name} allows "
                    f"the current role to assume it. Current identity: check logs with verbose mode."
                )
            # Re-raise the exception instead of falling through
            raise

    # In EKS/IRSA environments, use ambient credentials (no explicit keys needed)
    # This allows the web identity token to work automatically
    # Optional client settings are only added when they have a value.
    sts_client_kwargs: dict = {"verify": self._get_ssl_verify(ssl_verify)}
    if region is not None:
        sts_client_kwargs["region_name"] = region
    if aws_sts_endpoint is not None:
        sts_client_kwargs["endpoint_url"] = aws_sts_endpoint

    if aws_access_key_id is None and aws_secret_access_key is None:
        # No explicit keys: boto3 resolves credentials from the environment.
        with tracer.trace("boto3.client(sts)"):
            sts_client = boto3.client("sts", **sts_client_kwargs)
    else:
        # Explicit keys (and optional session token) authenticate the STS call.
        with tracer.trace("boto3.client(sts)"):
            sts_client = boto3.client(
                "sts",
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                **sts_client_kwargs,
            )

    assume_role_params = {
        "RoleArn": aws_role_name,
        "RoleSessionName": aws_session_name,
    }

    # Add ExternalId parameter if provided
    if aws_external_id is not None:
        assume_role_params["ExternalId"] = aws_external_id

    try:
        sts_response = sts_client.assume_role(**assume_role_params)
    except Exception as e:
        error_str = str(e)
        if "AccessDenied" in error_str:
            # Only fall back to ambient credentials if we can positively
            # confirm the caller is already the target role (same account,
            # partition, and role name). This avoids silently using the
            # wrong identity when there is a genuine trust-policy or
            # permission misconfiguration.
            if self._is_already_running_as_role(
                aws_role_name, ssl_verify=ssl_verify
            ):
                verbose_logger.warning(
                    "AssumeRole failed for %s (%s). "
                    "Caller is already running as this role; "
                    "falling back to ambient credentials.",
                    aws_role_name,
                    error_str,
                )
                return self._auth_with_env_vars()
            # Genuine permission error — re-raise
            verbose_logger.error(
                "AssumeRole AccessDenied for %s and caller is NOT "
                "the same role. Re-raising. Error: %s",
                aws_role_name,
                error_str,
            )
        # Any failure that was not resolved by the fallback above propagates.
        raise

    # Extract the credentials from the response and convert to Session Credentials
    sts_credentials = sts_response["Credentials"]
    credentials = Credentials(
        access_key=sts_credentials["AccessKeyId"],
        secret_key=sts_credentials["SecretAccessKey"],
        token=sts_credentials["SessionToken"],
    )

    sts_expiry = sts_credentials["Expiration"]
    # Convert to timezone-aware datetime for comparison
    current_time = datetime.now(sts_expiry.tzinfo)
    # 60 second safety margin so cached credentials are dropped before they expire.
    sts_ttl = (sts_expiry - current_time).total_seconds() - 60
    return credentials, sts_ttl
@tracer.wrap()
def _auth_with_aws_profile(
    self, aws_profile_name: str
) -> Tuple[Credentials, Optional[int]]:
    """
    Authenticate with AWS profile

    Builds a boto3 session for the named profile (auth values usually stored
    in ~/.aws/credentials) and returns its credentials with no TTL.
    """
    import boto3

    with tracer.trace("boto3.Session(profile_name=aws_profile_name)"):
        profile_session = boto3.Session(profile_name=aws_profile_name)
        profile_credentials = profile_session.get_credentials()
    return profile_credentials, None
@tracer.wrap()
def _auth_with_aws_session_token(
    self,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    aws_session_token: str,
) -> Tuple[Credentials, Optional[int]]:
    """
    Authenticate with AWS Session Token

    Wraps the three supplied values in a botocore `Credentials` object as-is;
    no AWS call is made and no TTL is reported.
    """
    from botocore.credentials import Credentials

    session_credentials = Credentials(
        access_key=aws_access_key_id,
        secret_key=aws_secret_access_key,
        token=aws_session_token,
    )
    return session_credentials, None
@tracer.wrap()
def _auth_with_access_key_and_secret_key(
    self,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    aws_region_name: Optional[str],
) -> Tuple[Credentials, Optional[int]]:
    """
    Authenticate with AWS Access Key and Secret Key

    Builds a boto3 session from the static key pair and region and returns
    its credentials together with the default boto3 credential TTL.
    """
    import boto3

    session_kwargs = dict(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        region_name=aws_region_name,
    )
    with tracer.trace(
        "boto3.Session(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name)"
    ):
        key_session = boto3.Session(**session_kwargs)

    key_credentials = key_session.get_credentials()
    default_ttl = self._get_default_ttl_for_boto3_credentials()
    return key_credentials, default_ttl
@tracer.wrap()
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
    """
    Authenticate with AWS Environment Variables

    Creates a default `boto3.Session()` with no arguments and returns the
    credentials it resolves, with no TTL.
    """
    import boto3

    with tracer.trace("boto3.Session()"):
        ambient_session = boto3.Session()
        ambient_credentials = ambient_session.get_credentials()
    return ambient_credentials, None
@tracer.wrap()
def _get_default_ttl_for_boto3_credentials(self) -> int:
    """
    Return the default TTL, in seconds, for boto3 credentials.

    The value is one hour minus one minute: 3600 - 60 = 3540 seconds
    (59 minutes).
    """
    one_hour_seconds = 3600
    one_minute_seconds = 60
    return one_hour_seconds - one_minute_seconds
def get_runtime_endpoint(
    self,
    api_base: Optional[str],
    aws_bedrock_runtime_endpoint: Optional[str],
    aws_region_name: str,
    endpoint_type: Optional[Literal["runtime", "agent", "agentcore"]] = "runtime",
) -> Tuple[str, str]:
    """
    Work out which endpoint URL to call and which one to treat as the proxy URL.

    Precedence for `endpoint_url`:
        1. `api_base`, when it is not `None`
        2. `aws_bedrock_runtime_endpoint`, when it is a string
        3. the value of `get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")`, when it
           is a non-empty string
        4. the regional default from `_select_default_endpoint_url`

    `proxy_endpoint_url` uses rules 2 and 3 only (it ignores `api_base`) and
    otherwise falls back to `endpoint_url`.

    Returns:
        `(endpoint_url, proxy_endpoint_url)`
    """
    env_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")

    # Explicit override shared by both results; the parameter beats the secret.
    # `None` means "no override" - an empty-string parameter still counts.
    override_endpoint: Optional[str] = None
    if isinstance(aws_bedrock_runtime_endpoint, str):
        override_endpoint = aws_bedrock_runtime_endpoint
    elif env_endpoint and isinstance(env_endpoint, str):
        override_endpoint = env_endpoint

    if api_base is not None:
        endpoint_url = api_base
    elif override_endpoint is not None:
        endpoint_url = override_endpoint
    else:
        endpoint_url = self._select_default_endpoint_url(
            endpoint_type=endpoint_type,
            aws_region_name=aws_region_name,
        )

    proxy_endpoint_url = (
        override_endpoint if override_endpoint is not None else endpoint_url
    )
    return endpoint_url, proxy_endpoint_url
def _select_default_endpoint_url(
self,
endpoint_type: Optional[Literal["runtime", "agent", "agentcore"]],
aws_region_name: str,
) -> str:
"""
Select the default endpoint url based on the endpoint type
Default endpoint url is https://bedrock-runtime.{aws_region_name}.amazonaws.com
"""
if endpoint_type == "agent":
return f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
elif endpoint_type == "agentcore":
return f"https://bedrock-agentcore.{aws_region_name}.amazonaws.com"
else:
return f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
def _get_boto_credentials_from_optional_params(
    self, optional_params: dict, model: Optional[str] = None
) -> Boto3CredentialsInfo:
    """
    Extract the AWS auth settings from `optional_params` and resolve credentials.

    The AWS-specific keys are popped, so `optional_params` is mutated in place
    and no longer contains them afterwards.

    Args:
        optional_params: Optional parameters for the model call.
        model: Forwarded to `_get_aws_region_name`.

    Returns:
        A `Boto3CredentialsInfo` carrying the resolved credentials, the region
        name and the popped `aws_bedrock_runtime_endpoint` (possibly `None`).

    Raises:
        ImportError: if botocore cannot be imported.
    """
    try:
        from botocore.credentials import Credentials
    except ImportError:
        raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

    # AWS auth keys are removed from the params because completion calls
    # fail when they are left in. Pop order is kept deliberately.
    credential_kwargs: Dict[str, Any] = {}
    for param_name in (
        "aws_secret_access_key",
        "aws_access_key_id",
        "aws_session_token",
    ):
        credential_kwargs[param_name] = optional_params.pop(param_name, None)

    # The region is resolved before "aws_region_name" is dropped from the params.
    region_name = self._get_aws_region_name(optional_params, model)
    optional_params.pop("aws_region_name", None)
    credential_kwargs["aws_region_name"] = region_name

    for param_name in (
        "aws_role_name",
        "aws_session_name",
        "aws_profile_name",
        "aws_web_identity_token",
        "aws_sts_endpoint",
    ):
        credential_kwargs[param_name] = optional_params.pop(param_name, None)

    # e.g. https://bedrock-runtime.{region_name}.amazonaws.com
    runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
    credential_kwargs["aws_external_id"] = optional_params.pop(
        "aws_external_id", None
    )

    resolved_credentials: Credentials = self.get_credentials(**credential_kwargs)

    return Boto3CredentialsInfo(
        credentials=resolved_credentials,
        aws_region_name=region_name,
        aws_bedrock_runtime_endpoint=runtime_endpoint,
    )
@tracer.wrap()
def get_request_headers(
    self,
    credentials: Credentials,
    aws_region_name: str,
    extra_headers: Optional[dict],
    endpoint_url: str,
    data: Union[str, bytes],
    headers: dict,
    api_key: Optional[str] = None,
) -> AWSPreparedRequest:
    """
    Build a prepared POST request for `endpoint_url`, authenticated either
    with a bearer token or with a SigV4 signature.

    Args:
        credentials: Credentials handed to `SigV4Auth`; only used on the
            SigV4 path.
        aws_region_name: Region handed to `SigV4Auth`.
        extra_headers: When it contains an "Authorization" key, that value
            replaces the request's Authorization header on the SigV4 path.
        endpoint_url: Target URL of the request.
        data: Request body.
        headers: Headers for the request. On the bearer-token path this dict
            is mutated in place (an "Authorization" entry is written to it).
        api_key: Bearer token; when `None`, the token is looked up with
            `get_secret_str("AWS_BEARER_TOKEN_BEDROCK")`.

    Returns:
        The result of `AWSRequest.prepare()`.

    Raises:
        ImportError: if botocore cannot be imported.
    """
    # An explicit api_key takes precedence over the configured bearer token.
    if api_key is not None:
        aws_bearer_token: Optional[str] = api_key
    else:
        aws_bearer_token = get_secret_str("AWS_BEARER_TOKEN_BEDROCK")
    if aws_bearer_token:
        # Bearer-token path: no signing, the token goes straight into the header.
        try:
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError(
                "Missing boto3 to call bedrock. Run 'pip install boto3'."
            )
        headers["Authorization"] = f"Bearer {aws_bearer_token}"
        request = AWSRequest(
            method="POST", url=endpoint_url, data=data, headers=headers
        )
    else:
        # SigV4 path: sign with the supplied credentials for the "bedrock" service.
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError(
                "Missing boto3 to call bedrock. Run 'pip install boto3'."
            )
        # Filter headers for AWS signature calculation
        # AWS SigV4 only includes specific headers in signature calculation
        aws_signature_headers = self._filter_headers_for_aws_signature(headers)
        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
        request = AWSRequest(
            method="POST",
            url=endpoint_url,
            data=data,
            headers=aws_signature_headers,
        )
        sigv4.add_auth(request)
        # Add back all original headers (including forwarded ones) after signature calculation
        for header_name, header_value in headers.items():
            request.headers[header_name] = header_value
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]
    prepped = request.prepare()
    return prepped
def _filter_headers_for_aws_signature(self, headers: dict) -> dict:
"""
Filter headers to only include those that AWS SigV4 includes in signature calculation.
This Fixes forwarded client headers from breaking the signature calculation.
"""
aws_signature_headers = {}
aws_headers = {
"host",
"content-type",
"date",
"x-amz-date",
"x-amz-security-token",
"x-amz-content-sha256",
"x-amz-algorithm",
"x-amz-credential",
"x-amz-signedheaders",
"x-amz-signature",
}
for header_name, header_value in headers.items():
header_lower = header_name.lower()
if (
header_lower in aws_headers
or header_lower.startswith("x-amz-")
or header_lower.startswith("x-amzn-")
):
aws_signature_headers[header_name] = header_value
return aws_signature_headers
def _sign_request(
    self,
    service_name: Literal["bedrock", "sagemaker", "bedrock-agentcore", "s3vectors"],
    headers: dict,
    optional_params: dict,
    request_data: dict,
    api_base: str,
    model: Optional[str] = None,
    stream: Optional[bool] = None,
    fake_stream: Optional[bool] = None,
    api_key: Optional[str] = None,
) -> Tuple[dict, Optional[bytes]]:
    """
    Authenticate a JSON POST request with a bearer token or a SigV4 signature.

    Args:
        service_name: AWS service name handed to `SigV4Auth`.
        headers: Caller headers; may be `None`. On the bearer-token path a
            non-empty dict is updated in place and returned.
        optional_params: Source of the `aws_*` auth settings. They are read
            with `.get`, so this dict is left unchanged here.
        request_data: Payload, serialized with `json.dumps`.
        api_base: URL the request is signed for.
        model: Forwarded to `_get_aws_region_name`.
        stream: Not used by this method.
        fake_stream: Not used by this method.
        api_key: Bearer token; when `None`, the token is looked up with
            `get_secret_str("AWS_BEARER_TOKEN_BEDROCK")`.

    Returns:
        `(headers, body)` - the headers to send and the request body as bytes.

    Raises:
        ImportError: if botocore cannot be imported (SigV4 path only).
    """
    bearer_token: Optional[str] = (
        api_key if api_key is not None else get_secret_str("AWS_BEARER_TOKEN_BEDROCK")
    )

    # With a bearer token there is nothing to sign: set the header and return.
    if bearer_token:
        headers = headers or {}
        headers.update(
            {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {bearer_token}",
            }
        )
        return headers, json.dumps(request_data).encode()

    # No bearer token: fall back to SigV4 authentication.
    try:
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest
        from botocore.credentials import Credentials
    except ImportError:
        raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

    ## CREDENTIALS ##
    credential_param_names = (
        "aws_secret_access_key",
        "aws_access_key_id",
        "aws_session_token",
        "aws_role_name",
        "aws_session_name",
        "aws_profile_name",
        "aws_web_identity_token",
        "aws_sts_endpoint",
        "aws_external_id",
    )
    credential_kwargs = {
        param_name: optional_params.get(param_name, None)
        for param_name in credential_param_names
    }
    region_name = self._get_aws_region_name(
        optional_params=optional_params, model=model
    )
    resolved_credentials: Credentials = self.get_credentials(
        aws_region_name=region_name, **credential_kwargs
    )
    signer = SigV4Auth(resolved_credentials, service_name, region_name)

    # Caller headers win over the default Content-Type.
    headers = {
        "Content-Type": "application/json",
        **({} if headers is None else headers),
    }

    # Only the filtered subset takes part in the signature calculation.
    signable_headers = self._filter_headers_for_aws_signature(headers)
    aws_request = AWSRequest(
        method="POST",
        url=api_base,
        data=json.dumps(request_data),
        headers=signable_headers,
    )
    signer.add_auth(aws_request)

    # Add back the original headers after signing. Only headers in SignedHeaders
    # are integrity-protected; forwarded headers (x-forwarded-*) must remain unsigned.
    # Because caller headers are merged last, a caller-supplied "Authorization"
    # replaces the SigV4-generated one.
    outgoing_headers = dict(aws_request.headers)
    outgoing_headers.update(headers)
    return outgoing_headers, aws_request.body