chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
This file contains the handler for AWS Bedrock Nova Sonic realtime API.
|
||||
|
||||
This uses aws_sdk_bedrock_runtime for bidirectional streaming with Nova Sonic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from .transformation import BedrockRealtimeConfig
|
||||
|
||||
|
||||
class BedrockRealtime(BaseAWSLLM):
|
||||
"""Handler for Bedrock Nova Sonic realtime speech-to-speech API."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def async_realtime(
|
||||
self,
|
||||
model: str,
|
||||
websocket: Any,
|
||||
logging_obj: LiteLLMLogging,
|
||||
api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_web_identity_token: Optional[str] = None,
|
||||
aws_sts_endpoint: Optional[str] = None,
|
||||
aws_bedrock_runtime_endpoint: Optional[str] = None,
|
||||
aws_external_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Establish bidirectional streaming connection with Bedrock Nova Sonic.
|
||||
|
||||
Args:
|
||||
model: Model ID (e.g., 'amazon.nova-sonic-v1:0')
|
||||
websocket: Client WebSocket connection
|
||||
logging_obj: LiteLLM logging object
|
||||
aws_region_name: AWS region
|
||||
Various AWS authentication parameters
|
||||
"""
|
||||
try:
|
||||
from aws_sdk_bedrock_runtime.client import (
|
||||
BedrockRuntimeClient,
|
||||
InvokeModelWithBidirectionalStreamOperationInput,
|
||||
)
|
||||
from aws_sdk_bedrock_runtime.config import Config
|
||||
from smithy_aws_core.identity.environment import (
|
||||
EnvironmentCredentialsResolver,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Missing aws_sdk_bedrock_runtime. Install with: pip install aws-sdk-bedrock-runtime"
|
||||
)
|
||||
|
||||
# Get AWS region
|
||||
if aws_region_name is None:
|
||||
optional_params = {
|
||||
"aws_region_name": aws_region_name,
|
||||
}
|
||||
aws_region_name = self._get_aws_region_name(optional_params, model)
|
||||
|
||||
# Get endpoint URL
|
||||
if api_base is not None:
|
||||
endpoint_uri = api_base
|
||||
elif aws_bedrock_runtime_endpoint is not None:
|
||||
endpoint_uri = aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_uri = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Bedrock Realtime: Connecting to {endpoint_uri} with model {model}"
|
||||
)
|
||||
|
||||
# Initialize Bedrock client with aws_sdk_bedrock_runtime
|
||||
config = Config(
|
||||
endpoint_uri=endpoint_uri,
|
||||
region=aws_region_name,
|
||||
aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
|
||||
)
|
||||
bedrock_client = BedrockRuntimeClient(config=config)
|
||||
|
||||
transformation_config = BedrockRealtimeConfig()
|
||||
|
||||
try:
|
||||
# Initialize the bidirectional stream
|
||||
bedrock_stream = (
|
||||
await bedrock_client.invoke_model_with_bidirectional_stream(
|
||||
InvokeModelWithBidirectionalStreamOperationInput(model_id=model)
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Bedrock Realtime: Bidirectional stream established"
|
||||
)
|
||||
|
||||
# Track state for transformation
|
||||
session_state = {
|
||||
"current_output_item_id": None,
|
||||
"current_response_id": None,
|
||||
"current_conversation_id": None,
|
||||
"current_delta_chunks": None,
|
||||
"current_item_chunks": None,
|
||||
"current_delta_type": None,
|
||||
"session_configuration_request": None,
|
||||
}
|
||||
|
||||
# Create tasks for bidirectional forwarding
|
||||
client_to_bedrock_task = asyncio.create_task(
|
||||
self._forward_client_to_bedrock(
|
||||
websocket,
|
||||
bedrock_stream,
|
||||
transformation_config,
|
||||
model,
|
||||
session_state,
|
||||
)
|
||||
)
|
||||
|
||||
bedrock_to_client_task = asyncio.create_task(
|
||||
self._forward_bedrock_to_client(
|
||||
bedrock_stream,
|
||||
websocket,
|
||||
transformation_config,
|
||||
model,
|
||||
logging_obj,
|
||||
session_state,
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for both tasks to complete
|
||||
await asyncio.gather(
|
||||
client_to_bedrock_task,
|
||||
bedrock_to_client_task,
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"Error in BedrockRealtime.async_realtime: {e}"
|
||||
)
|
||||
try:
|
||||
await websocket.close(code=1011, reason=f"Internal error: {str(e)}")
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
async def _forward_client_to_bedrock(
|
||||
self,
|
||||
client_ws: Any,
|
||||
bedrock_stream: Any,
|
||||
transformation_config: BedrockRealtimeConfig,
|
||||
model: str,
|
||||
session_state: dict,
|
||||
):
|
||||
"""Forward messages from client WebSocket to Bedrock stream."""
|
||||
try:
|
||||
from aws_sdk_bedrock_runtime.models import (
|
||||
BidirectionalInputPayloadPart,
|
||||
InvokeModelWithBidirectionalStreamInputChunk,
|
||||
)
|
||||
|
||||
while True:
|
||||
# Receive message from client
|
||||
message = await client_ws.receive_text()
|
||||
verbose_proxy_logger.debug(
|
||||
f"Bedrock Realtime: Received from client: {message[:200]}"
|
||||
)
|
||||
|
||||
# Transform OpenAI format to Bedrock format
|
||||
transformed_messages = transformation_config.transform_realtime_request(
|
||||
message=message,
|
||||
model=model,
|
||||
session_configuration_request=session_state.get(
|
||||
"session_configuration_request"
|
||||
),
|
||||
)
|
||||
|
||||
# Send transformed messages to Bedrock
|
||||
for bedrock_message in transformed_messages:
|
||||
event = InvokeModelWithBidirectionalStreamInputChunk(
|
||||
value=BidirectionalInputPayloadPart(
|
||||
bytes_=bedrock_message.encode("utf-8")
|
||||
)
|
||||
)
|
||||
await bedrock_stream.input_stream.send(event)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Bedrock Realtime: Sent to Bedrock: {bedrock_message[:200]}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Client to Bedrock forwarding ended: {e}", exc_info=True
|
||||
)
|
||||
# Close the Bedrock stream input
|
||||
try:
|
||||
await bedrock_stream.input_stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _forward_bedrock_to_client(
|
||||
self,
|
||||
bedrock_stream: Any,
|
||||
client_ws: Any,
|
||||
transformation_config: BedrockRealtimeConfig,
|
||||
model: str,
|
||||
logging_obj: LiteLLMLogging,
|
||||
session_state: dict,
|
||||
):
|
||||
"""Forward messages from Bedrock stream to client WebSocket."""
|
||||
try:
|
||||
while True:
|
||||
# Receive from Bedrock
|
||||
output = await bedrock_stream.await_output()
|
||||
result = await output[1].receive()
|
||||
|
||||
if result.value and result.value.bytes_:
|
||||
bedrock_response = result.value.bytes_.decode("utf-8")
|
||||
verbose_proxy_logger.debug(
|
||||
f"Bedrock Realtime: Received from Bedrock: {bedrock_response[:200]}"
|
||||
)
|
||||
|
||||
# Transform Bedrock format to OpenAI format
|
||||
from litellm.types.realtime import RealtimeResponseTransformInput
|
||||
|
||||
realtime_response_transform_input: RealtimeResponseTransformInput = {
|
||||
"current_output_item_id": session_state.get(
|
||||
"current_output_item_id"
|
||||
),
|
||||
"current_response_id": session_state.get("current_response_id"),
|
||||
"current_conversation_id": session_state.get(
|
||||
"current_conversation_id"
|
||||
),
|
||||
"current_delta_chunks": session_state.get(
|
||||
"current_delta_chunks"
|
||||
),
|
||||
"current_item_chunks": session_state.get("current_item_chunks"),
|
||||
"current_delta_type": session_state.get("current_delta_type"),
|
||||
"session_configuration_request": session_state.get(
|
||||
"session_configuration_request"
|
||||
),
|
||||
}
|
||||
|
||||
transformed_response = transformation_config.transform_realtime_response(
|
||||
message=bedrock_response,
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
realtime_response_transform_input=realtime_response_transform_input,
|
||||
)
|
||||
|
||||
# Update session state
|
||||
session_state.update(
|
||||
{
|
||||
"current_output_item_id": transformed_response.get(
|
||||
"current_output_item_id"
|
||||
),
|
||||
"current_response_id": transformed_response.get(
|
||||
"current_response_id"
|
||||
),
|
||||
"current_conversation_id": transformed_response.get(
|
||||
"current_conversation_id"
|
||||
),
|
||||
"current_delta_chunks": transformed_response.get(
|
||||
"current_delta_chunks"
|
||||
),
|
||||
"current_item_chunks": transformed_response.get(
|
||||
"current_item_chunks"
|
||||
),
|
||||
"current_delta_type": transformed_response.get(
|
||||
"current_delta_type"
|
||||
),
|
||||
"session_configuration_request": transformed_response.get(
|
||||
"session_configuration_request"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Send transformed messages to client
|
||||
openai_messages = transformed_response.get("response", [])
|
||||
for openai_message in openai_messages:
|
||||
message_json = json.dumps(openai_message)
|
||||
await client_ws.send_text(message_json)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Bedrock Realtime: Sent to client: {message_json[:200]}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Bedrock to client forwarding ended: {e}", exc_info=True
|
||||
)
|
||||
# Close the client WebSocket
|
||||
try:
|
||||
await client_ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user