chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,549 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Literal, Optional, Union, cast
|
||||
|
||||
from httpx import Headers, Response
|
||||
|
||||
from litellm.llms.base_llm.batches.transformation import BaseBatchesConfig
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.bedrock import (
|
||||
BedrockCreateBatchRequest,
|
||||
BedrockCreateBatchResponse,
|
||||
BedrockInputDataConfig,
|
||||
BedrockOutputDataConfig,
|
||||
BedrockS3InputDataConfig,
|
||||
BedrockS3OutputDataConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
CreateBatchRequest,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch, LlmProviders
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import CommonBatchFilesUtils
|
||||
|
||||
|
||||
class BedrockBatchesConfig(BaseAWSLLM, BaseBatchesConfig):
|
||||
"""
|
||||
Config for Bedrock Batches - handles batch job creation and management for Bedrock
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.common_utils = CommonBatchFilesUtils()
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> LlmProviders:
|
||||
return LlmProviders.BEDROCK
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate and prepare environment for Bedrock batch requests.
|
||||
AWS credentials are handled by BaseAWSLLM.
|
||||
"""
|
||||
# Add any Bedrock-specific headers if needed
|
||||
return headers
|
||||
|
||||
def get_complete_batch_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: Dict,
|
||||
litellm_params: Dict,
|
||||
data: CreateBatchRequest,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for Bedrock batch creation.
|
||||
Bedrock batch jobs are created via the model invocation job API.
|
||||
"""
|
||||
aws_region_name = self._get_aws_region_name(optional_params, model)
|
||||
|
||||
# Bedrock model invocation job endpoint
|
||||
# Format: https://bedrock.{region}.amazonaws.com/model-invocation-job
|
||||
bedrock_endpoint = (
|
||||
f"https://bedrock.{aws_region_name}.amazonaws.com/model-invocation-job"
|
||||
)
|
||||
|
||||
return bedrock_endpoint
|
||||
|
||||
def transform_create_batch_request(
|
||||
self,
|
||||
model: str,
|
||||
create_batch_data: CreateBatchRequest,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform the batch creation request to Bedrock format.
|
||||
|
||||
Bedrock batch inference requires:
|
||||
- modelId: The Bedrock model ID
|
||||
- jobName: Unique name for the batch job
|
||||
- inputDataConfig: Configuration for input data (S3 location)
|
||||
- outputDataConfig: Configuration for output data (S3 location)
|
||||
- roleArn: IAM role ARN for the batch job
|
||||
"""
|
||||
# Get required parameters
|
||||
input_file_id = create_batch_data.get("input_file_id")
|
||||
if not input_file_id:
|
||||
raise ValueError("input_file_id is required for Bedrock batch creation")
|
||||
|
||||
# Extract S3 information from file ID using common utility
|
||||
input_bucket, input_key = self.common_utils.parse_s3_uri(input_file_id)
|
||||
|
||||
# Get output S3 configuration
|
||||
output_bucket = litellm_params.get("s3_output_bucket_name") or os.getenv(
|
||||
"AWS_S3_OUTPUT_BUCKET_NAME"
|
||||
)
|
||||
if not output_bucket:
|
||||
# Use same bucket as input if no output bucket specified
|
||||
output_bucket = input_bucket
|
||||
|
||||
# Get IAM role ARN
|
||||
role_arn = (
|
||||
litellm_params.get("aws_batch_role_arn")
|
||||
or optional_params.get("aws_batch_role_arn")
|
||||
or os.getenv("AWS_BATCH_ROLE_ARN")
|
||||
)
|
||||
if not role_arn:
|
||||
raise ValueError(
|
||||
"AWS IAM role ARN is required for Bedrock batch jobs. "
|
||||
"Set 'aws_batch_role_arn' in litellm_params or AWS_BATCH_ROLE_ARN env var"
|
||||
)
|
||||
|
||||
if not model:
|
||||
raise ValueError(
|
||||
"Could not determine Bedrock model ID. Please pass `model` in your request body."
|
||||
)
|
||||
|
||||
# Generate job name with the correct model ID using common utility
|
||||
job_name = self.common_utils.generate_unique_job_name(model, prefix="litellm")
|
||||
output_key = f"litellm-batch-outputs/{job_name}/"
|
||||
|
||||
# Build input data config
|
||||
input_data_config: BedrockInputDataConfig = {
|
||||
"s3InputDataConfig": BedrockS3InputDataConfig(
|
||||
s3Uri=f"s3://{input_bucket}/{input_key}"
|
||||
)
|
||||
}
|
||||
|
||||
# Build output data config
|
||||
s3_output_config: BedrockS3OutputDataConfig = BedrockS3OutputDataConfig(
|
||||
s3Uri=f"s3://{output_bucket}/{output_key}"
|
||||
)
|
||||
|
||||
# Add optional KMS encryption key ID if provided
|
||||
s3_encryption_key_id = litellm_params.get(
|
||||
"s3_encryption_key_id"
|
||||
) or get_secret_str("AWS_S3_ENCRYPTION_KEY_ID")
|
||||
if s3_encryption_key_id:
|
||||
s3_output_config["s3EncryptionKeyId"] = s3_encryption_key_id
|
||||
|
||||
output_data_config: BedrockOutputDataConfig = {
|
||||
"s3OutputDataConfig": s3_output_config
|
||||
}
|
||||
|
||||
# Create Bedrock batch request with proper typing
|
||||
bedrock_request: BedrockCreateBatchRequest = {
|
||||
"modelId": model,
|
||||
"jobName": job_name,
|
||||
"inputDataConfig": input_data_config,
|
||||
"outputDataConfig": output_data_config,
|
||||
"roleArn": role_arn,
|
||||
}
|
||||
|
||||
# Add optional parameters if provided
|
||||
completion_window = create_batch_data.get("completion_window")
|
||||
if completion_window:
|
||||
# Map OpenAI completion window to Bedrock timeout
|
||||
# OpenAI uses "24h", Bedrock expects timeout in hours
|
||||
if completion_window == "24h":
|
||||
bedrock_request["timeoutDurationInHours"] = 24
|
||||
|
||||
# For Bedrock, we need to return a pre-signed request with AWS auth headers
|
||||
# Use common utility for AWS signing
|
||||
endpoint_url = f"https://bedrock.{self._get_aws_region_name(optional_params, model)}.amazonaws.com/model-invocation-job"
|
||||
signed_headers, signed_data = self.common_utils.sign_aws_request(
|
||||
service_name="bedrock",
|
||||
data=bedrock_request,
|
||||
endpoint_url=endpoint_url,
|
||||
optional_params=optional_params,
|
||||
method="POST",
|
||||
)
|
||||
|
||||
# Return a pre-signed request format that the HTTP handler can use
|
||||
return {
|
||||
"method": "POST",
|
||||
"url": endpoint_url,
|
||||
"headers": signed_headers,
|
||||
"data": signed_data.decode("utf-8"),
|
||||
}
|
||||
|
||||
def transform_create_batch_response(
|
||||
self,
|
||||
model: Optional[str],
|
||||
raw_response: Response,
|
||||
logging_obj: Any,
|
||||
litellm_params: dict,
|
||||
) -> LiteLLMBatch:
|
||||
"""
|
||||
Transform Bedrock batch creation response to LiteLLM format.
|
||||
"""
|
||||
try:
|
||||
response_data: BedrockCreateBatchResponse = raw_response.json()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse Bedrock batch response: {e}")
|
||||
|
||||
# Extract information from typed Bedrock response
|
||||
job_arn = response_data.get("jobArn", "")
|
||||
status_str: str = str(response_data.get("status", "Submitted"))
|
||||
|
||||
# Map Bedrock status to OpenAI-compatible status
|
||||
status_mapping: Dict[str, str] = {
|
||||
"Submitted": "validating",
|
||||
"Validating": "validating",
|
||||
"Scheduled": "in_progress",
|
||||
"InProgress": "in_progress",
|
||||
"PartiallyCompleted": "completed",
|
||||
"Completed": "completed",
|
||||
"Failed": "failed",
|
||||
"Stopping": "cancelling",
|
||||
"Stopped": "cancelled",
|
||||
"Expired": "expired",
|
||||
}
|
||||
|
||||
openai_status = cast(
|
||||
Literal[
|
||||
"validating",
|
||||
"failed",
|
||||
"in_progress",
|
||||
"finalizing",
|
||||
"completed",
|
||||
"expired",
|
||||
"cancelling",
|
||||
"cancelled",
|
||||
],
|
||||
status_mapping.get(status_str, "validating"),
|
||||
)
|
||||
|
||||
# Get original request data from litellm_params if available
|
||||
original_request = litellm_params.get("original_batch_request", {})
|
||||
|
||||
# Create LiteLLM batch object
|
||||
return LiteLLMBatch(
|
||||
id=job_arn, # Use ARN as the batch ID
|
||||
object="batch",
|
||||
endpoint=original_request.get("endpoint", "/v1/chat/completions"),
|
||||
errors=None,
|
||||
input_file_id=original_request.get("input_file_id", ""),
|
||||
completion_window=original_request.get("completion_window", "24h"),
|
||||
status=openai_status,
|
||||
output_file_id=None, # Will be populated when job completes
|
||||
error_file_id=None,
|
||||
created_at=int(time.time()),
|
||||
in_progress_at=int(time.time()) if status_str == "InProgress" else None,
|
||||
expires_at=None,
|
||||
finalizing_at=None,
|
||||
completed_at=None,
|
||||
failed_at=None,
|
||||
expired_at=None,
|
||||
cancelling_at=None,
|
||||
cancelled_at=None,
|
||||
request_counts=None,
|
||||
metadata=original_request.get("metadata", {}),
|
||||
)
|
||||
|
||||
def transform_retrieve_batch_request(
|
||||
self,
|
||||
batch_id: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform batch retrieval request for Bedrock.
|
||||
|
||||
Args:
|
||||
batch_id: Bedrock job ARN
|
||||
optional_params: Optional parameters
|
||||
litellm_params: LiteLLM parameters
|
||||
|
||||
Returns:
|
||||
Transformed request data for Bedrock GetModelInvocationJob API
|
||||
"""
|
||||
# For Bedrock, batch_id should be the full job ARN
|
||||
# The GetModelInvocationJob API expects the full ARN as the identifier
|
||||
if not batch_id.startswith("arn:aws:bedrock:"):
|
||||
raise ValueError(f"Invalid batch_id format. Expected ARN, got: {batch_id}")
|
||||
|
||||
# Extract the job identifier from the ARN - use the full ARN path part
|
||||
# ARN format: arn:aws:bedrock:region:account:model-invocation-job/job-name
|
||||
arn_parts = batch_id.split(":")
|
||||
if len(arn_parts) < 6:
|
||||
raise ValueError(f"Invalid ARN format: {batch_id}")
|
||||
|
||||
region = arn_parts[3]
|
||||
# arn_parts[5] contains "model-invocation-job/{jobId}"
|
||||
|
||||
# Build the endpoint URL for GetModelInvocationJob
|
||||
# AWS API format: GET /model-invocation-job/{jobIdentifier}
|
||||
# Use the FULL ARN as jobIdentifier and URL-encode it (includes ':' and '/')
|
||||
import urllib.parse as _ul
|
||||
|
||||
encoded_arn = _ul.quote(batch_id, safe="")
|
||||
endpoint_url = (
|
||||
f"https://bedrock.{region}.amazonaws.com/model-invocation-job/{encoded_arn}"
|
||||
)
|
||||
|
||||
# Use common utility for AWS signing
|
||||
signed_headers, _ = self.common_utils.sign_aws_request(
|
||||
service_name="bedrock",
|
||||
data={}, # GET request has no body
|
||||
endpoint_url=endpoint_url,
|
||||
optional_params=optional_params,
|
||||
method="GET",
|
||||
)
|
||||
|
||||
# Return pre-signed request format
|
||||
return {
|
||||
"method": "GET",
|
||||
"url": endpoint_url,
|
||||
"headers": signed_headers,
|
||||
"data": None,
|
||||
}
|
||||
|
||||
def _parse_timestamps_and_status(self, response_data, status_str: str):
|
||||
"""Helper to parse timestamps based on status."""
|
||||
import datetime
|
||||
|
||||
def parse_timestamp(ts_str: Optional[str]) -> Optional[int]:
|
||||
if not ts_str:
|
||||
return None
|
||||
try:
|
||||
dt = datetime.datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
|
||||
return int(dt.timestamp())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
created_at = parse_timestamp(
|
||||
str(response_data.get("submitTime"))
|
||||
if response_data.get("submitTime") is not None
|
||||
else None
|
||||
)
|
||||
in_progress_states = {"InProgress", "Validating", "Scheduled"}
|
||||
in_progress_at = (
|
||||
parse_timestamp(
|
||||
str(response_data.get("lastModifiedTime"))
|
||||
if response_data.get("lastModifiedTime") is not None
|
||||
else None
|
||||
)
|
||||
if status_str in in_progress_states
|
||||
else None
|
||||
)
|
||||
completed_at = (
|
||||
parse_timestamp(
|
||||
str(response_data.get("endTime"))
|
||||
if response_data.get("endTime") is not None
|
||||
else None
|
||||
)
|
||||
if status_str in {"Completed", "PartiallyCompleted"}
|
||||
else None
|
||||
)
|
||||
failed_at = (
|
||||
parse_timestamp(
|
||||
str(response_data.get("endTime"))
|
||||
if response_data.get("endTime") is not None
|
||||
else None
|
||||
)
|
||||
if status_str == "Failed"
|
||||
else None
|
||||
)
|
||||
cancelled_at = (
|
||||
parse_timestamp(
|
||||
str(response_data.get("endTime"))
|
||||
if response_data.get("endTime") is not None
|
||||
else None
|
||||
)
|
||||
if status_str == "Stopped"
|
||||
else None
|
||||
)
|
||||
expires_at = parse_timestamp(
|
||||
str(response_data.get("jobExpirationTime"))
|
||||
if response_data.get("jobExpirationTime") is not None
|
||||
else None
|
||||
)
|
||||
|
||||
return (
|
||||
created_at,
|
||||
in_progress_at,
|
||||
completed_at,
|
||||
failed_at,
|
||||
cancelled_at,
|
||||
expires_at,
|
||||
)
|
||||
|
||||
def _extract_file_configs(self, response_data):
|
||||
"""Helper to extract input and output file configurations."""
|
||||
# Extract input file ID
|
||||
input_file_id = ""
|
||||
input_data_config = response_data.get("inputDataConfig", {})
|
||||
if isinstance(input_data_config, dict):
|
||||
s3_input_config = input_data_config.get("s3InputDataConfig", {})
|
||||
if isinstance(s3_input_config, dict):
|
||||
input_file_id = s3_input_config.get("s3Uri", "")
|
||||
|
||||
# Extract output file ID
|
||||
output_file_id = None
|
||||
output_data_config = response_data.get("outputDataConfig", {})
|
||||
if isinstance(output_data_config, dict):
|
||||
s3_output_config = output_data_config.get("s3OutputDataConfig", {})
|
||||
if isinstance(s3_output_config, dict):
|
||||
output_file_id = s3_output_config.get("s3Uri", "")
|
||||
|
||||
return input_file_id, output_file_id
|
||||
|
||||
def _extract_errors_and_metadata(self, response_data, raw_response):
|
||||
"""Helper to extract errors and enriched metadata."""
|
||||
# Extract errors
|
||||
message = response_data.get("message")
|
||||
errors = None
|
||||
if message:
|
||||
from openai.types.batch import Errors
|
||||
from openai.types.batch_error import BatchError
|
||||
|
||||
errors = Errors(
|
||||
data=[BatchError(message=message, code=str(raw_response.status_code))],
|
||||
object="list",
|
||||
)
|
||||
|
||||
# Enrich metadata with useful Bedrock fields
|
||||
enriched_metadata_raw: Dict[str, Any] = {
|
||||
"jobName": response_data.get("jobName"),
|
||||
"clientRequestToken": response_data.get("clientRequestToken"),
|
||||
"modelId": response_data.get("modelId"),
|
||||
"roleArn": response_data.get("roleArn"),
|
||||
"timeoutDurationInHours": response_data.get("timeoutDurationInHours"),
|
||||
"vpcConfig": response_data.get("vpcConfig"),
|
||||
}
|
||||
import json as _json
|
||||
|
||||
enriched_metadata: Dict[str, str] = {}
|
||||
for _k, _v in enriched_metadata_raw.items():
|
||||
if _v is None:
|
||||
continue
|
||||
if isinstance(_v, (dict, list)):
|
||||
try:
|
||||
enriched_metadata[_k] = _json.dumps(_v)
|
||||
except Exception:
|
||||
enriched_metadata[_k] = str(_v)
|
||||
else:
|
||||
enriched_metadata[_k] = str(_v)
|
||||
|
||||
return errors, enriched_metadata
|
||||
|
||||
def transform_retrieve_batch_response(
|
||||
self,
|
||||
model: Optional[str],
|
||||
raw_response: Response,
|
||||
logging_obj: Any,
|
||||
litellm_params: dict,
|
||||
) -> LiteLLMBatch:
|
||||
"""
|
||||
Transform Bedrock batch retrieval response to LiteLLM format.
|
||||
"""
|
||||
from litellm.types.llms.bedrock import BedrockGetBatchResponse
|
||||
|
||||
try:
|
||||
response_data: BedrockGetBatchResponse = raw_response.json()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse Bedrock batch response: {e}")
|
||||
|
||||
job_arn = response_data.get("jobArn", "")
|
||||
status_str: str = str(response_data.get("status", "Submitted"))
|
||||
|
||||
# Map Bedrock status to OpenAI-compatible status
|
||||
status_mapping: Dict[str, str] = {
|
||||
"Submitted": "validating",
|
||||
"Validating": "validating",
|
||||
"Scheduled": "in_progress",
|
||||
"InProgress": "in_progress",
|
||||
"PartiallyCompleted": "completed",
|
||||
"Completed": "completed",
|
||||
"Failed": "failed",
|
||||
"Stopping": "cancelling",
|
||||
"Stopped": "cancelled",
|
||||
"Expired": "expired",
|
||||
}
|
||||
openai_status = cast(
|
||||
Literal[
|
||||
"validating",
|
||||
"failed",
|
||||
"in_progress",
|
||||
"finalizing",
|
||||
"completed",
|
||||
"expired",
|
||||
"cancelling",
|
||||
"cancelled",
|
||||
],
|
||||
status_mapping.get(status_str, "validating"),
|
||||
)
|
||||
|
||||
# Parse timestamps
|
||||
(
|
||||
created_at,
|
||||
in_progress_at,
|
||||
completed_at,
|
||||
failed_at,
|
||||
cancelled_at,
|
||||
expires_at,
|
||||
) = self._parse_timestamps_and_status(response_data, status_str)
|
||||
|
||||
# Extract file configurations
|
||||
input_file_id, output_file_id = self._extract_file_configs(response_data)
|
||||
|
||||
# Extract errors and metadata
|
||||
errors, enriched_metadata = self._extract_errors_and_metadata(
|
||||
response_data, raw_response
|
||||
)
|
||||
|
||||
return LiteLLMBatch(
|
||||
id=job_arn,
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
errors=errors,
|
||||
input_file_id=input_file_id,
|
||||
completion_window="24h",
|
||||
status=openai_status,
|
||||
output_file_id=output_file_id,
|
||||
error_file_id=None,
|
||||
created_at=created_at or int(time.time()),
|
||||
in_progress_at=in_progress_at,
|
||||
expires_at=expires_at,
|
||||
finalizing_at=None,
|
||||
completed_at=completed_at,
|
||||
failed_at=failed_at,
|
||||
expired_at=None,
|
||||
cancelling_at=None,
|
||||
cancelled_at=cancelled_at,
|
||||
request_counts=None,
|
||||
metadata=enriched_metadata,
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
|
||||
) -> BaseLLMException:
|
||||
"""
|
||||
Get Bedrock-specific error class using common utility.
|
||||
"""
|
||||
return self.common_utils.get_error_class(error_message, status_code, headers)
|
||||
Reference in New Issue
Block a user