Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/bedrock/batches/transformation.py
2026-03-26 20:06:14 +08:00

550 lines
19 KiB
Python

import os
import time
from typing import Any, Dict, List, Literal, Optional, Union, cast
from httpx import Headers, Response
from litellm.llms.base_llm.batches.transformation import BaseBatchesConfig
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.bedrock import (
BedrockCreateBatchRequest,
BedrockCreateBatchResponse,
BedrockInputDataConfig,
BedrockOutputDataConfig,
BedrockS3InputDataConfig,
BedrockS3OutputDataConfig,
)
from litellm.types.llms.openai import (
AllMessageValues,
CreateBatchRequest,
)
from litellm.types.utils import LiteLLMBatch, LlmProviders
from ..base_aws_llm import BaseAWSLLM
from ..common_utils import CommonBatchFilesUtils
class BedrockBatchesConfig(BaseAWSLLM, BaseBatchesConfig):
    """
    Config for Bedrock Batches - handles batch job creation and management for Bedrock.

    Builds signed requests against the Bedrock ``model-invocation-job`` endpoint
    (create and retrieve) and converts the JSON responses into ``LiteLLMBatch``
    objects with OpenAI-compatible status values.
    """
def __init__(self):
    """Run the parent-class initialisation and attach a ``CommonBatchFilesUtils`` helper."""
    super().__init__()
    # Shared helper used by the other methods for S3 URI parsing, job naming,
    # request signing and error-class lookup.
    self.common_utils = CommonBatchFilesUtils()
@property
def custom_llm_provider(self) -> LlmProviders:
    """Provider identifier for this config; always ``LlmProviders.BEDROCK``."""
    return LlmProviders.BEDROCK
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Validate and prepare environment for Bedrock batch requests.

    AWS credentials are handled by BaseAWSLLM, so this method is a
    pass-through: ``headers`` is returned unchanged and none of the other
    arguments are read.

    Returns:
        The ``headers`` dict exactly as it was passed in.
    """
    # Add any Bedrock-specific headers if needed
    return headers
def get_complete_batch_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: Dict,
    litellm_params: Dict,
    data: CreateBatchRequest,
) -> str:
    """
    Build the regional Bedrock endpoint used to create a batch job.

    The region comes from ``self._get_aws_region_name(optional_params, model)``;
    ``api_base``, ``api_key``, ``litellm_params`` and ``data`` are not read.

    Returns:
        ``https://bedrock.{region}.amazonaws.com/model-invocation-job``
    """
    region = self._get_aws_region_name(optional_params, model)
    # Batch jobs are created through the model-invocation-job resource.
    return f"https://bedrock.{region}.amazonaws.com/model-invocation-job"
def transform_create_batch_request(
    self,
    model: str,
    create_batch_data: CreateBatchRequest,
    optional_params: dict,
    litellm_params: dict,
) -> Dict[str, Any]:
    """
    Turn an OpenAI-style batch request into a signed Bedrock job-creation request.

    The Bedrock payload contains:
    - modelId: ``model`` as passed in
    - jobName: unique name generated with the ``litellm`` prefix
    - inputDataConfig: S3 URI rebuilt from the parsed ``input_file_id``
    - outputDataConfig: ``litellm-batch-outputs/{jobName}/`` in the output bucket,
      plus ``s3EncryptionKeyId`` when a KMS key is configured
    - roleArn: IAM role ARN for the batch job
    - timeoutDurationInHours: 24, only when ``completion_window`` is ``"24h"``

    Returns:
        Dict with ``method``, ``url``, ``headers`` (signed) and ``data``
        (the signed body decoded as UTF-8).

    Raises:
        ValueError: If ``input_file_id``, the IAM role ARN or ``model`` is missing.
    """
    source_file_id = create_batch_data.get("input_file_id")
    if not source_file_id:
        raise ValueError("input_file_id is required for Bedrock batch creation")

    # The common helper splits the file ID into its S3 bucket and key.
    input_bucket, input_key = self.common_utils.parse_s3_uri(source_file_id)

    # Output bucket precedence: litellm_params, env var, then the input bucket.
    output_bucket = (
        litellm_params.get("s3_output_bucket_name")
        or os.getenv("AWS_S3_OUTPUT_BUCKET_NAME")
        or input_bucket
    )

    # Role precedence: litellm_params, optional_params, then env var.
    job_role_arn = (
        litellm_params.get("aws_batch_role_arn")
        or optional_params.get("aws_batch_role_arn")
        or os.getenv("AWS_BATCH_ROLE_ARN")
    )
    if not job_role_arn:
        raise ValueError(
            "AWS IAM role ARN is required for Bedrock batch jobs. "
            "Set 'aws_batch_role_arn' in litellm_params or AWS_BATCH_ROLE_ARN env var"
        )

    if not model:
        raise ValueError(
            "Could not determine Bedrock model ID. Please pass `model` in your request body."
        )

    job_name = self.common_utils.generate_unique_job_name(model, prefix="litellm")

    input_section: BedrockInputDataConfig = {
        "s3InputDataConfig": BedrockS3InputDataConfig(
            s3Uri=f"s3://{input_bucket}/{input_key}"
        )
    }

    # Results land under a per-job prefix in the output bucket.
    output_location: BedrockS3OutputDataConfig = BedrockS3OutputDataConfig(
        s3Uri=f"s3://{output_bucket}/litellm-batch-outputs/{job_name}/"
    )
    kms_key_id = litellm_params.get("s3_encryption_key_id") or get_secret_str(
        "AWS_S3_ENCRYPTION_KEY_ID"
    )
    if kms_key_id:
        output_location["s3EncryptionKeyId"] = kms_key_id
    output_section: BedrockOutputDataConfig = {"s3OutputDataConfig": output_location}

    payload: BedrockCreateBatchRequest = {
        "modelId": model,
        "jobName": job_name,
        "inputDataConfig": input_section,
        "outputDataConfig": output_section,
        "roleArn": job_role_arn,
    }

    # OpenAI uses "24h"; Bedrock expects the timeout in hours. Any other
    # completion window is left unset.
    if create_batch_data.get("completion_window") == "24h":
        payload["timeoutDurationInHours"] = 24

    # The HTTP handler expects a pre-signed request, so sign here.
    region = self._get_aws_region_name(optional_params, model)
    endpoint_url = f"https://bedrock.{region}.amazonaws.com/model-invocation-job"
    signed_headers, signed_body = self.common_utils.sign_aws_request(
        service_name="bedrock",
        data=payload,
        endpoint_url=endpoint_url,
        optional_params=optional_params,
        method="POST",
    )

    return {
        "method": "POST",
        "url": endpoint_url,
        "headers": signed_headers,
        "data": signed_body.decode("utf-8"),
    }
def transform_create_batch_response(
    self,
    model: Optional[str],
    raw_response: Response,
    logging_obj: Any,
    litellm_params: dict,
) -> LiteLLMBatch:
    """
    Convert the Bedrock job-creation response into a ``LiteLLMBatch``.

    The job ARN becomes the batch ID. The Bedrock status is mapped to an
    OpenAI-compatible status, with unknown values falling back to
    ``"validating"``. Endpoint, input file, completion window and metadata are
    taken from ``litellm_params["original_batch_request"]`` when present.

    Raises:
        ValueError: If the response body cannot be parsed as JSON.
    """
    try:
        created_job: BedrockCreateBatchResponse = raw_response.json()
    except Exception as e:
        raise ValueError(f"Failed to parse Bedrock batch response: {e}")

    bedrock_status: str = str(created_job.get("status", "Submitted"))

    # Bedrock job status -> OpenAI batch status.
    bedrock_to_openai: Dict[str, str] = {
        "Submitted": "validating",
        "Validating": "validating",
        "Scheduled": "in_progress",
        "InProgress": "in_progress",
        "PartiallyCompleted": "completed",
        "Completed": "completed",
        "Failed": "failed",
        "Stopping": "cancelling",
        "Stopped": "cancelled",
        "Expired": "expired",
    }
    openai_status = cast(
        Literal[
            "validating",
            "failed",
            "in_progress",
            "finalizing",
            "completed",
            "expired",
            "cancelling",
            "cancelled",
        ],
        bedrock_to_openai.get(bedrock_status, "validating"),
    )

    # Values not echoed back by Bedrock come from the original request.
    submitted = litellm_params.get("original_batch_request", {})

    created_ts = int(time.time())
    in_progress_ts = int(time.time()) if bedrock_status == "InProgress" else None

    return LiteLLMBatch(
        id=created_job.get("jobArn", ""),  # the ARN doubles as the batch ID
        object="batch",
        endpoint=submitted.get("endpoint", "/v1/chat/completions"),
        errors=None,
        input_file_id=submitted.get("input_file_id", ""),
        completion_window=submitted.get("completion_window", "24h"),
        status=openai_status,
        output_file_id=None,  # not known until the job completes
        error_file_id=None,
        created_at=created_ts,
        in_progress_at=in_progress_ts,
        expires_at=None,
        finalizing_at=None,
        completed_at=None,
        failed_at=None,
        expired_at=None,
        cancelling_at=None,
        cancelled_at=None,
        request_counts=None,
        metadata=submitted.get("metadata", {}),
    )
def transform_retrieve_batch_request(
    self,
    batch_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> Dict[str, Any]:
    """
    Build a signed GetModelInvocationJob request for a Bedrock batch job.

    Args:
        batch_id: Full Bedrock job ARN, e.g.
            ``arn:aws:bedrock:us-east-1:123456789012:model-invocation-job/abc``.
            ARNs in the ``aws``, ``aws-us-gov`` and ``aws-cn`` partitions are
            accepted.
        optional_params: Optional parameters, forwarded to the request signer
        litellm_params: LiteLLM parameters (not read here)

    Returns:
        Pre-signed request dict with ``method`` ("GET"), ``url``, ``headers``
        and ``data`` (always ``None``, a GET has no body).

    Raises:
        ValueError: If ``batch_id`` is not a Bedrock ARN in a known partition,
            has fewer than six colon-separated fields, or has an empty region.
    """
    import urllib.parse as _ul

    # ARN partition -> DNS suffix of the regional service endpoint.
    # A literal "arn:aws:bedrock:" prefix check would reject GovCloud/China ARNs.
    partition_dns_suffix = {
        "aws": "amazonaws.com",
        "aws-us-gov": "amazonaws.com",
        "aws-cn": "amazonaws.com.cn",
    }

    # ARN format: arn:partition:bedrock:region:account:model-invocation-job/job-id
    arn_parts = batch_id.split(":")
    is_bedrock_arn = (
        len(arn_parts) >= 4
        and arn_parts[0] == "arn"
        and arn_parts[1] in partition_dns_suffix
        and arn_parts[2] == "bedrock"
    )
    if not is_bedrock_arn:
        raise ValueError(f"Invalid batch_id format. Expected ARN, got: {batch_id}")
    if len(arn_parts) < 6:
        raise ValueError(f"Invalid ARN format: {batch_id}")

    region = arn_parts[3]
    if not region:
        # An empty region would produce the host "bedrock..amazonaws.com".
        raise ValueError(f"Invalid ARN format: {batch_id}")

    # AWS API format: GET /model-invocation-job/{jobIdentifier}
    # The FULL ARN is the jobIdentifier; it contains ':' and '/', so every
    # reserved character is percent-encoded (safe="").
    encoded_arn = _ul.quote(batch_id, safe="")
    dns_suffix = partition_dns_suffix[arn_parts[1]]
    endpoint_url = (
        f"https://bedrock.{region}.{dns_suffix}/model-invocation-job/{encoded_arn}"
    )

    signed_headers, _ = self.common_utils.sign_aws_request(
        service_name="bedrock",
        data={},  # GET request has no body
        endpoint_url=endpoint_url,
        optional_params=optional_params,
        method="GET",
    )

    return {
        "method": "GET",
        "url": endpoint_url,
        "headers": signed_headers,
        "data": None,
    }
def _parse_timestamps_and_status(self, response_data, status_str: str):
"""Helper to parse timestamps based on status."""
import datetime
def parse_timestamp(ts_str: Optional[str]) -> Optional[int]:
if not ts_str:
return None
try:
dt = datetime.datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
return int(dt.timestamp())
except Exception:
return None
created_at = parse_timestamp(
str(response_data.get("submitTime"))
if response_data.get("submitTime") is not None
else None
)
in_progress_states = {"InProgress", "Validating", "Scheduled"}
in_progress_at = (
parse_timestamp(
str(response_data.get("lastModifiedTime"))
if response_data.get("lastModifiedTime") is not None
else None
)
if status_str in in_progress_states
else None
)
completed_at = (
parse_timestamp(
str(response_data.get("endTime"))
if response_data.get("endTime") is not None
else None
)
if status_str in {"Completed", "PartiallyCompleted"}
else None
)
failed_at = (
parse_timestamp(
str(response_data.get("endTime"))
if response_data.get("endTime") is not None
else None
)
if status_str == "Failed"
else None
)
cancelled_at = (
parse_timestamp(
str(response_data.get("endTime"))
if response_data.get("endTime") is not None
else None
)
if status_str == "Stopped"
else None
)
expires_at = parse_timestamp(
str(response_data.get("jobExpirationTime"))
if response_data.get("jobExpirationTime") is not None
else None
)
return (
created_at,
in_progress_at,
completed_at,
failed_at,
cancelled_at,
expires_at,
)
def _extract_file_configs(self, response_data):
"""Helper to extract input and output file configurations."""
# Extract input file ID
input_file_id = ""
input_data_config = response_data.get("inputDataConfig", {})
if isinstance(input_data_config, dict):
s3_input_config = input_data_config.get("s3InputDataConfig", {})
if isinstance(s3_input_config, dict):
input_file_id = s3_input_config.get("s3Uri", "")
# Extract output file ID
output_file_id = None
output_data_config = response_data.get("outputDataConfig", {})
if isinstance(output_data_config, dict):
s3_output_config = output_data_config.get("s3OutputDataConfig", {})
if isinstance(s3_output_config, dict):
output_file_id = s3_output_config.get("s3Uri", "")
return input_file_id, output_file_id
def _extract_errors_and_metadata(self, response_data, raw_response):
"""Helper to extract errors and enriched metadata."""
# Extract errors
message = response_data.get("message")
errors = None
if message:
from openai.types.batch import Errors
from openai.types.batch_error import BatchError
errors = Errors(
data=[BatchError(message=message, code=str(raw_response.status_code))],
object="list",
)
# Enrich metadata with useful Bedrock fields
enriched_metadata_raw: Dict[str, Any] = {
"jobName": response_data.get("jobName"),
"clientRequestToken": response_data.get("clientRequestToken"),
"modelId": response_data.get("modelId"),
"roleArn": response_data.get("roleArn"),
"timeoutDurationInHours": response_data.get("timeoutDurationInHours"),
"vpcConfig": response_data.get("vpcConfig"),
}
import json as _json
enriched_metadata: Dict[str, str] = {}
for _k, _v in enriched_metadata_raw.items():
if _v is None:
continue
if isinstance(_v, (dict, list)):
try:
enriched_metadata[_k] = _json.dumps(_v)
except Exception:
enriched_metadata[_k] = str(_v)
else:
enriched_metadata[_k] = str(_v)
return errors, enriched_metadata
def transform_retrieve_batch_response(
    self,
    model: Optional[str],
    raw_response: Response,
    logging_obj: Any,
    litellm_params: dict,
) -> LiteLLMBatch:
    """
    Convert a Bedrock job description into a ``LiteLLMBatch``.

    The job ARN becomes the batch ID and the Bedrock status is mapped to an
    OpenAI-compatible status (unknown values fall back to ``"validating"``).
    Timestamps, file IDs, errors and metadata come from the
    ``_parse_timestamps_and_status``, ``_extract_file_configs`` and
    ``_extract_errors_and_metadata`` helpers. ``created_at`` falls back to the
    current time when the job carries no usable ``submitTime``.

    Raises:
        ValueError: If the response body cannot be parsed as JSON.
    """
    from litellm.types.llms.bedrock import BedrockGetBatchResponse

    try:
        job: BedrockGetBatchResponse = raw_response.json()
    except Exception as e:
        raise ValueError(f"Failed to parse Bedrock batch response: {e}")

    job_arn = job.get("jobArn", "")
    bedrock_status: str = str(job.get("status", "Submitted"))

    # Bedrock job status -> OpenAI batch status.
    bedrock_to_openai: Dict[str, str] = {
        "Submitted": "validating",
        "Validating": "validating",
        "Scheduled": "in_progress",
        "InProgress": "in_progress",
        "PartiallyCompleted": "completed",
        "Completed": "completed",
        "Failed": "failed",
        "Stopping": "cancelling",
        "Stopped": "cancelled",
        "Expired": "expired",
    }
    openai_status = cast(
        Literal[
            "validating",
            "failed",
            "in_progress",
            "finalizing",
            "completed",
            "expired",
            "cancelling",
            "cancelled",
        ],
        bedrock_to_openai.get(bedrock_status, "validating"),
    )

    timestamps = self._parse_timestamps_and_status(job, bedrock_status)
    (
        created_at,
        in_progress_at,
        completed_at,
        failed_at,
        cancelled_at,
        expires_at,
    ) = timestamps

    input_file_id, output_file_id = self._extract_file_configs(job)
    errors, metadata = self._extract_errors_and_metadata(job, raw_response)

    return LiteLLMBatch(
        id=job_arn,
        object="batch",
        endpoint="/v1/chat/completions",
        errors=errors,
        input_file_id=input_file_id,
        completion_window="24h",
        status=openai_status,
        output_file_id=output_file_id,
        error_file_id=None,
        created_at=created_at or int(time.time()),
        in_progress_at=in_progress_at,
        expires_at=expires_at,
        finalizing_at=None,
        completed_at=completed_at,
        failed_at=failed_at,
        expired_at=None,
        cancelling_at=None,
        cancelled_at=cancelled_at,
        request_counts=None,
        metadata=metadata,
    )
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> BaseLLMException:
    """
    Get Bedrock-specific error class using common utility.

    Delegates to ``self.common_utils.get_error_class`` with the same
    ``error_message``, ``status_code`` and ``headers`` and returns its result.
    """
    return self.common_utils.get_error_class(error_message, status_code, headers)