chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
from typing import Optional
|
||||
|
||||
from .converse_handler import BedrockConverseLLM
|
||||
from .invoke_handler import (
|
||||
AmazonAnthropicClaudeStreamDecoder,
|
||||
AmazonDeepSeekR1StreamDecoder,
|
||||
AWSEventStreamDecoder,
|
||||
BedrockLLM,
|
||||
)
|
||||
|
||||
|
||||
def get_bedrock_event_stream_decoder(
    invoke_provider: Optional[str], model: str, sync_stream: bool, json_mode: bool
) -> AWSEventStreamDecoder:
    """Return the event-stream decoder matching the Bedrock invoke provider.

    Args:
        invoke_provider: Provider route embedded in the model path
            (e.g. "anthropic", "deepseek_r1"), or None.
        model: Model id, forwarded to the decoder.
        sync_stream: Whether the caller consumes the stream synchronously.
        json_mode: Whether JSON-mode post-processing is enabled.

    Returns:
        A provider-specific decoder for "anthropic" / "deepseek_r1",
        otherwise the generic AWSEventStreamDecoder.
    """
    # `invoke_provider == "..."` already implies truthiness, so the original
    # `invoke_provider and ...` guard was redundant.
    if invoke_provider == "anthropic":
        return AmazonAnthropicClaudeStreamDecoder(
            model=model,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
    if invoke_provider == "deepseek_r1":
        return AmazonDeepSeekR1StreamDecoder(
            model=model,
            sync_stream=sync_stream,
        )
    # Fallback: generic decoder (no json_mode/sync_stream handling needed).
    return AWSEventStreamDecoder(model=model)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .transformation import AmazonAgentCoreConfig
|
||||
|
||||
__all__ = ["AmazonAgentCoreConfig"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,512 @@
|
||||
import json
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.anthropic_beta_headers_manager import (
|
||||
update_headers_with_filtered_beta,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM, Credentials
|
||||
from ..common_utils import BedrockError, _get_all_bedrock_regions
|
||||
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
|
||||
|
||||
|
||||
def make_sync_call(
    client: Optional[HTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj: LiteLLMLoggingObject,
    json_mode: Optional[bool] = False,
    fake_stream: bool = False,
    stream_chunk_size: int = 1024,
):
    """Synchronously POST a Converse request and return a chunk iterator.

    When ``fake_stream`` is True the full (non-streamed) response is fetched,
    transformed once, and replayed through a MockResponseIterator; otherwise
    the raw HTTP byte stream is fed to AWSEventStreamDecoder.

    Raises:
        BedrockError: on any non-200 response.
    """
    if client is None:
        client = _get_httpx_client()  # Create a new client if none provided

    # stream=False for fake_stream so the whole body is available at once.
    response = client.post(
        api_base,
        headers=headers,
        data=data,
        stream=not fake_stream,
        logging_obj=logging_obj,
    )

    if response.status_code != 200:
        raise BedrockError(
            status_code=response.status_code, message=str(response.read())
        )

    if fake_stream:
        # Transform the complete response once, then iterate it as if streamed.
        model_response: (
            ModelResponse
        ) = litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=litellm.ModelResponse(),
            stream=True,
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            data=data,
            messages=messages,
            encoding=litellm.encoding,
        )  # type: ignore
        completion_stream: Any = MockResponseIterator(
            model_response=model_response, json_mode=json_mode
        )
    else:
        # Real streaming: decode the AWS event stream chunk by chunk.
        decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
        completion_stream = decoder.iter_bytes(
            response.iter_bytes(chunk_size=stream_chunk_size)
        )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return completion_stream
|
||||
|
||||
|
||||
class BedrockConverseLLM(BaseAWSLLM):
    """Client for the Bedrock ``converse`` / ``converse-stream`` endpoints.

    Handles credential resolution, request-header preparation and routing
    between sync, async and streaming completion variants.
    """

    def __init__(self) -> None:
        # No state of its own; all configuration flows through method args.
        super().__init__()
|
||||
|
||||
    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        stream,
        optional_params: dict,
        litellm_params: dict,
        credentials: Credentials,
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
        fake_stream: bool = False,
        json_mode: Optional[bool] = False,
        api_key: Optional[str] = None,
        stream_chunk_size: int = 1024,
    ) -> CustomStreamWrapper:
        """Async streaming Converse call.

        Transforms messages into the Converse request shape, prepares signed
        request headers, opens the stream via ``make_call`` and wraps it in a
        CustomStreamWrapper for the caller to iterate.
        """
        request_data = await litellm.AmazonConverseConfig()._async_transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        data = json.dumps(request_data)

        # Region was stashed in litellm_params by completion(); fall back to
        # us-west-2 when it is absent.
        prepped = self.get_request_headers(
            credentials=credentials,
            aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
            extra_headers=headers,
            endpoint_url=api_base,
            data=data,
            headers=headers,
            api_key=api_key,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": dict(prepped.headers),
            },
        )

        completion_stream = await make_call(
            client=client,
            api_base=api_base,
            headers=dict(prepped.headers),
            data=data,
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            fake_stream=fake_stream,
            json_mode=json_mode,
            stream_chunk_size=stream_chunk_size,
        )
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response
|
||||
|
||||
    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj: LiteLLMLoggingObject,
        stream,
        optional_params: dict,
        litellm_params: dict,
        credentials: Credentials,
        logger_fn=None,
        headers: dict = {},
        client: Optional[AsyncHTTPHandler] = None,
        api_key: Optional[str] = None,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """Non-streaming async Converse call.

        Transforms the request, prepares signed headers, POSTs it, and
        transforms the HTTP response into ``model_response``.

        Raises:
            BedrockError: on non-2xx responses (with upstream status/text)
                and on timeouts (status 408).
        """
        request_data = await litellm.AmazonConverseConfig()._async_transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        data = json.dumps(request_data)

        # Region was stashed in litellm_params by completion(); fall back to
        # us-west-2 when it is absent.
        prepped = self.get_request_headers(
            credentials=credentials,
            aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
            extra_headers=headers,
            endpoint_url=api_base,
            data=data,
            headers=headers,
            api_key=api_key,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": prepped.headers,
            },
        )

        headers = dict(prepped.headers)
        if client is None or not isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = get_async_httpx_client(
                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
            )
        else:
            client = client  # type: ignore

        try:
            response = await client.post(
                url=api_base,
                headers=headers,
                data=data,
                logging_obj=logging_obj,
            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            optional_params=optional_params,
            encoding=encoding,
        )
|
||||
|
||||
    def completion(  # noqa: PLR0915
        self,
        model: str,
        messages: list,
        api_base: Optional[str],
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        encoding,
        logging_obj: LiteLLMLoggingObject,
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params: dict,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
        api_key: Optional[str] = None,
    ):
        """Entry point for Bedrock Converse completions.

        Resolves the model id (stripping route/region/nova prefixes), region
        and AWS credentials from ``optional_params``, builds the
        ``converse``/``converse-stream`` endpoint URL, then routes to one of:
        async streaming, async completion, sync streaming, or a plain sync
        POST — depending on ``acompletion`` and ``stream``.

        Raises:
            BedrockError: on non-2xx responses and timeouts (sync path).
        """
        ## SETUP ##
        stream = optional_params.pop("stream", None)
        stream_chunk_size = optional_params.pop("stream_chunk_size", 1024)
        unencoded_model_id = optional_params.pop("model_id", None)
        fake_stream = optional_params.pop("fake_stream", False)
        json_mode = optional_params.get("json_mode", False)
        if unencoded_model_id is not None:
            # Explicit model_id overrides any id derived from `model`.
            modelId = self.encode_model_id(model_id=unencoded_model_id)
        else:
            # Strip nova spec prefixes before encoding model ID for API URL
            _model_for_id = model
            _stripped = _model_for_id
            for rp in ["bedrock/converse/", "bedrock/", "converse/"]:
                if _stripped.startswith(rp):
                    _stripped = _stripped[len(rp) :]
                    break
            # Strip embedded region prefix (e.g. "bedrock/us-east-1/model" -> "model")
            # and capture it so it can be used as aws_region_name below.
            _region_from_model: Optional[str] = None
            _potential_region = _stripped.split("/", 1)[0]
            if _potential_region in _get_all_bedrock_regions() and "/" in _stripped:
                _region_from_model = _potential_region
                _stripped = _stripped.split("/", 1)[1]
            _model_for_id = _stripped
            for _nova_prefix in ["nova-2/", "nova/"]:
                if _stripped.startswith(_nova_prefix):
                    _model_for_id = _model_for_id.replace(_nova_prefix, "", 1)
                    break
            modelId = self.encode_model_id(model_id=_model_for_id)
            # Inject region extracted from model path so _get_aws_region_name picks it up
            if (
                _region_from_model is not None
                and "aws_region_name" not in optional_params
            ):
                optional_params["aws_region_name"] = _region_from_model

        # Some models cannot stream natively; should_fake_stream decides
        # whether to fetch the full response and replay it as a stream.
        fake_stream = litellm.AmazonConverseConfig().should_fake_stream(
            fake_stream=fake_stream,
            model=model,
            stream=stream,
            custom_llm_provider="bedrock",
        )

        ### SET REGION NAME ###
        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params,
            model=model,
            model_id=unencoded_model_id,
        )

        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
        aws_external_id = optional_params.pop("aws_external_id", None)
        optional_params.pop("aws_region_name", None)

        litellm_params[
            "aws_region_name"
        ] = aws_region_name  # [DO NOT DELETE] important for async calls

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
            aws_external_id=aws_external_id,
        )

        ### SET RUNTIME ENDPOINT ###
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=aws_region_name,
        )
        # fake_stream uses the non-streaming endpoint even when stream=True.
        if (stream is not None and stream is True) and not fake_stream:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
        else:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"

        ## COMPLETION CALL
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        # Filter beta headers in HTTP headers before making the request
        headers = update_headers_with_filtered_beta(
            headers=headers, provider="bedrock_converse"
        )
        ### ROUTING (ASYNC, STREAMING, SYNC)
        if acompletion:
            # A sync HTTPHandler cannot be used on the async path.
            if isinstance(client, HTTPHandler):
                client = None
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    api_base=proxy_endpoint_url,
                    model_response=model_response,
                    encoding=encoding,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=True,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                    client=client,
                    json_mode=json_mode,
                    fake_stream=fake_stream,
                    credentials=credentials,
                    api_key=api_key,
                    stream_chunk_size=stream_chunk_size,
                )  # type: ignore
            ### ASYNC COMPLETION
            return self.async_completion(
                model=model,
                messages=messages,
                api_base=proxy_endpoint_url,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
                optional_params=optional_params,
                stream=stream,  # type: ignore
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                headers=headers,
                timeout=timeout,
                client=client,
                credentials=credentials,
                api_key=api_key,
            )  # type: ignore

        ## TRANSFORMATION ##

        _data = litellm.AmazonConverseConfig()._transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=extra_headers,
        )
        data = json.dumps(_data)

        prepped = self.get_request_headers(
            credentials=credentials,
            aws_region_name=aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=proxy_endpoint_url,
            data=data,
            headers=headers,
            api_key=api_key,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        # Sync path needs a sync client; replace a missing/async client.
        if client is None or isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client

        if stream is not None and stream is True:
            completion_stream = make_sync_call(
                client=(
                    client
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
                api_base=proxy_endpoint_url,
                headers=prepped.headers,  # type: ignore
                data=data,
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                json_mode=json_mode,
                fake_stream=fake_stream,
                stream_chunk_size=stream_chunk_size,
            )
            streaming_response = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="bedrock",
                logging_obj=logging_obj,
            )

            return streaming_response

        ### COMPLETION

        try:
            response = client.post(
                url=proxy_endpoint_url,
                headers=prepped.headers,
                data=data,
                logging_obj=logging_obj,
            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            optional_params=optional_params,
            encoding=encoding,
        )
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Uses base_llm_http_handler to call the 'converse like' endpoint.
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/8085
|
||||
"""
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Uses `converse_transformation.py` to transform the messages to the format required by Bedrock Converse.
|
||||
"""
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,547 @@
|
||||
"""
|
||||
Transformation for Bedrock Invoke Agent
|
||||
|
||||
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_InvokeAgent.html
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.types.llms.bedrock_invoke_agents import (
|
||||
InvokeAgentChunkPayload,
|
||||
InvokeAgentEvent,
|
||||
InvokeAgentEventHeaders,
|
||||
InvokeAgentEventList,
|
||||
InvokeAgentMetadata,
|
||||
InvokeAgentModelInvocationInput,
|
||||
InvokeAgentModelInvocationOutput,
|
||||
InvokeAgentOrchestrationTrace,
|
||||
InvokeAgentPreProcessingTrace,
|
||||
InvokeAgentTrace,
|
||||
InvokeAgentTracePayload,
|
||||
InvokeAgentUsage,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, Message, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonInvokeAgentConfig(BaseConfig, BaseAWSLLM):
    """Transformation config for the Bedrock InvokeAgent API.

    https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_InvokeAgent.html
    """

    def __init__(self, **kwargs):
        # Initialize both bases explicitly (multiple inheritance; each base
        # takes its own **kwargs).
        BaseConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""
|
||||
This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class.
|
||||
|
||||
Bedrock Invoke Agents has 0 OpenAI compatible params
|
||||
|
||||
As of May 29th, 2025 - they don't support streaming.
|
||||
"""
|
||||
return []
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class.
|
||||
"""
|
||||
return optional_params
|
||||
|
||||
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete url for the request.

        Builds
        ``{runtime_endpoint}/agents/{agent_id}/agentAliases/{alias_id}/sessions/{session_id}/text``
        where agent/alias ids come from the model string and the session id
        from ``optional_params`` (or a fresh UUID).
        """
        ### SET RUNTIME ENDPOINT ###
        aws_bedrock_runtime_endpoint = optional_params.get(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=self._get_aws_region_name(
                optional_params=optional_params, model=model
            ),
            endpoint_type="agent",
        )

        agent_id, agent_alias_id = self._get_agent_id_and_alias_id(model)
        session_id = self._get_session_id(optional_params)

        endpoint_url = f"{endpoint_url}/agents/{agent_id}/agentAliases/{agent_alias_id}/sessions/{session_id}/text"

        return endpoint_url
|
||||
|
||||
def sign_request(
|
||||
self,
|
||||
headers: dict,
|
||||
optional_params: dict,
|
||||
request_data: dict,
|
||||
api_base: str,
|
||||
api_key: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
stream: Optional[bool] = None,
|
||||
fake_stream: Optional[bool] = None,
|
||||
) -> Tuple[dict, Optional[bytes]]:
|
||||
return self._sign_request(
|
||||
service_name="bedrock",
|
||||
headers=headers,
|
||||
optional_params=optional_params,
|
||||
request_data=request_data,
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
stream=stream,
|
||||
fake_stream=fake_stream,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
def _get_agent_id_and_alias_id(self, model: str) -> tuple[str, str]:
|
||||
"""
|
||||
model = "agent/L1RT58GYRW/MFPSBCXYTW"
|
||||
agent_id = "L1RT58GYRW"
|
||||
agent_alias_id = "MFPSBCXYTW"
|
||||
"""
|
||||
# Split the model string by '/' and extract components
|
||||
parts = model.split("/")
|
||||
if len(parts) != 3 or parts[0] != "agent":
|
||||
raise ValueError(
|
||||
"Invalid model format. Expected format: 'model=agent/AGENT_ID/ALIAS_ID'"
|
||||
)
|
||||
|
||||
return parts[1], parts[2] # Return (agent_id, agent_alias_id)
|
||||
|
||||
def _get_session_id(self, optional_params: dict) -> str:
|
||||
""" """
|
||||
return optional_params.get("sessionID", None) or str(uuid.uuid4())
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
# use the last message content as the query
|
||||
query: str = convert_content_list_to_str(messages[-1])
|
||||
return {
|
||||
"inputText": query,
|
||||
"enableTrace": True,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
    def _parse_aws_event_stream(self, raw_content: bytes) -> InvokeAgentEventList:
        """
        Parse AWS event stream format using boto3/botocore's built-in parser.
        This is the same approach used in the existing AWSEventStreamDecoder.

        Returns a list of ``{"headers": ..., "payload": ...}`` events; events
        that fail to parse are logged and skipped rather than aborting.
        """
        try:
            from botocore.eventstream import EventStreamBuffer
            from botocore.parsers import EventStreamJSONParser
        except ImportError:
            raise ImportError("boto3/botocore is required for AWS event stream parsing")

        events: InvokeAgentEventList = []
        parser = EventStreamJSONParser()
        event_stream_buffer = EventStreamBuffer()

        # Add the entire response to the buffer
        event_stream_buffer.add_data(raw_content)

        # Process all events in the buffer
        for event in event_stream_buffer:
            try:
                headers = self._extract_headers_from_event(event)

                event_type = headers.get("event_type", "")

                if event_type == "chunk":
                    # Handle chunk events specially - they contain decoded content, not JSON
                    message = self._parse_message_from_event(event, parser)
                    parsed_event: InvokeAgentEvent = InvokeAgentEvent()
                    if message:
                        # For chunk events, create a payload with the decoded content
                        parsed_event = {
                            "headers": headers,
                            "payload": {
                                "bytes": base64.b64encode(
                                    message.encode("utf-8")
                                ).decode("utf-8")
                            },  # Re-encode for consistency
                        }
                        events.append(parsed_event)

                elif event_type == "trace":
                    # Handle trace events normally - they contain JSON
                    message = self._parse_message_from_event(event, parser)

                    if message:
                        try:
                            event_data = json.loads(message)
                            parsed_event = {
                                "headers": headers,
                                "payload": event_data,
                            }
                            events.append(parsed_event)
                        except json.JSONDecodeError as e:
                            verbose_logger.warning(
                                f"Failed to parse trace event JSON: {e}"
                            )
                else:
                    verbose_logger.debug(f"Unknown event type: {event_type}")

            except Exception as e:
                # Best-effort: one malformed event should not kill the parse.
                verbose_logger.error(f"Error processing event: {e}")
                continue

        return events
|
||||
|
||||
def _parse_message_from_event(self, event, parser) -> Optional[str]:
|
||||
"""Extract message content from an AWS event, adapted from AWSEventStreamDecoder."""
|
||||
try:
|
||||
response_dict = event.to_response_dict()
|
||||
verbose_logger.debug(f"Response dict: {response_dict}")
|
||||
|
||||
# Use the same response shape parsing as the existing decoder
|
||||
parsed_response = parser.parse(
|
||||
response_dict, self._get_response_stream_shape()
|
||||
)
|
||||
verbose_logger.debug(f"Parsed response: {parsed_response}")
|
||||
|
||||
if response_dict["status_code"] != 200:
|
||||
decoded_body = response_dict["body"].decode()
|
||||
if isinstance(decoded_body, dict):
|
||||
error_message = decoded_body.get("message")
|
||||
elif isinstance(decoded_body, str):
|
||||
error_message = decoded_body
|
||||
else:
|
||||
error_message = ""
|
||||
exception_status = response_dict["headers"].get(":exception-type")
|
||||
error_message = exception_status + " " + error_message
|
||||
raise BedrockError(
|
||||
status_code=response_dict["status_code"],
|
||||
message=(
|
||||
json.dumps(error_message)
|
||||
if isinstance(error_message, dict)
|
||||
else error_message
|
||||
),
|
||||
)
|
||||
|
||||
if "chunk" in parsed_response:
|
||||
chunk = parsed_response.get("chunk")
|
||||
if not chunk:
|
||||
return None
|
||||
return chunk.get("bytes").decode()
|
||||
else:
|
||||
chunk = response_dict.get("body")
|
||||
if not chunk:
|
||||
return None
|
||||
return chunk.decode()
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error parsing message from event: {e}")
|
||||
return None
|
||||
|
||||
def _extract_headers_from_event(self, event) -> InvokeAgentEventHeaders:
|
||||
"""Extract headers from an AWS event for categorization."""
|
||||
try:
|
||||
response_dict = event.to_response_dict()
|
||||
headers = response_dict.get("headers", {})
|
||||
|
||||
# Extract the event-type and content-type headers that we care about
|
||||
return InvokeAgentEventHeaders(
|
||||
event_type=headers.get(":event-type", ""),
|
||||
content_type=headers.get(":content-type", ""),
|
||||
message_type=headers.get(":message-type", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error extracting headers: {e}")
|
||||
return InvokeAgentEventHeaders(
|
||||
event_type="", content_type="", message_type=""
|
||||
)
|
||||
|
||||
    def _get_response_stream_shape(self):
        """Get the response stream shape for parsing, reusing existing logic.

        Returns the botocore shape for "ResponseStream", or None when the
        service model cannot be loaded.
        """
        try:
            # Try to reuse the cached shape from the existing decoder
            from litellm.llms.bedrock.chat.invoke_handler import (
                get_response_stream_shape,
            )

            return get_response_stream_shape()
        except ImportError:
            # Fallback: create our own shape
            try:
                from botocore.loaders import Loader
                from botocore.model import ServiceModel

                loader = Loader()
                bedrock_service_dict = loader.load_service_model(
                    "bedrock-runtime", "service-2"
                )
                bedrock_service_model = ServiceModel(bedrock_service_dict)
                return bedrock_service_model.shape_for("ResponseStream")
            except Exception as e:
                verbose_logger.warning(f"Could not load response stream shape: {e}")
                return None
|
||||
|
||||
def _extract_response_content(self, events: InvokeAgentEventList) -> str:
|
||||
"""Extract the final response content from parsed events."""
|
||||
response_parts = []
|
||||
|
||||
for event in events:
|
||||
headers = event.get("headers", {})
|
||||
payload = event.get("payload")
|
||||
|
||||
event_type = headers.get(
|
||||
"event_type"
|
||||
) # Note: using event_type not event-type
|
||||
|
||||
if event_type == "chunk" and payload:
|
||||
# Extract base64 encoded content from chunk events
|
||||
chunk_payload: InvokeAgentChunkPayload = payload # type: ignore
|
||||
encoded_bytes = chunk_payload.get("bytes", "")
|
||||
if encoded_bytes:
|
||||
try:
|
||||
decoded_content = base64.b64decode(encoded_bytes).decode(
|
||||
"utf-8"
|
||||
)
|
||||
response_parts.append(decoded_content)
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to decode chunk content: {e}")
|
||||
|
||||
return "".join(response_parts)
|
||||
|
||||
    def _extract_usage_info(self, events: InvokeAgentEventList) -> InvokeAgentUsage:
        """Extract token usage information from trace events.

        Sums input/output token counts across all pre-processing traces and
        records the foundation model reported by the first orchestration
        trace that names one.
        """
        usage_info = InvokeAgentUsage(
            inputTokens=0,
            outputTokens=0,
            model=None,
        )

        response_model: Optional[str] = None

        for event in events:
            if not self._is_trace_event(event):
                continue

            trace_data = self._get_trace_data(event)
            if not trace_data:
                continue

            verbose_logger.debug(f"Trace event: {trace_data}")

            # Extract usage from pre-processing trace
            self._extract_and_update_preprocessing_usage(
                trace_data=trace_data,
                usage_info=usage_info,
            )

            # Extract model from orchestration trace
            if response_model is None:
                response_model = self._extract_orchestration_model(trace_data)

        usage_info["model"] = response_model
        return usage_info
|
||||
|
||||
def _is_trace_event(self, event: InvokeAgentEvent) -> bool:
|
||||
"""Check if the event is a trace event."""
|
||||
headers = event.get("headers", {})
|
||||
event_type = headers.get("event_type")
|
||||
payload = event.get("payload")
|
||||
return event_type == "trace" and payload is not None
|
||||
|
||||
def _get_trace_data(self, event: InvokeAgentEvent) -> Optional[InvokeAgentTrace]:
|
||||
"""Extract trace data from a trace event."""
|
||||
payload = event.get("payload")
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
trace_payload: InvokeAgentTracePayload = payload # type: ignore
|
||||
return trace_payload.get("trace", {})
|
||||
|
||||
def _extract_and_update_preprocessing_usage(
|
||||
self, trace_data: InvokeAgentTrace, usage_info: InvokeAgentUsage
|
||||
) -> None:
|
||||
"""Extract usage information from preprocessing trace."""
|
||||
pre_processing: Optional[InvokeAgentPreProcessingTrace] = trace_data.get(
|
||||
"preProcessingTrace"
|
||||
)
|
||||
if not pre_processing:
|
||||
return
|
||||
|
||||
model_output: Optional[InvokeAgentModelInvocationOutput] = (
|
||||
pre_processing.get("modelInvocationOutput")
|
||||
or InvokeAgentModelInvocationOutput()
|
||||
)
|
||||
if not model_output:
|
||||
return
|
||||
|
||||
metadata: Optional[InvokeAgentMetadata] = (
|
||||
model_output.get("metadata") or InvokeAgentMetadata()
|
||||
)
|
||||
if not metadata:
|
||||
return
|
||||
|
||||
usage: Optional[Union[InvokeAgentUsage, Dict]] = metadata.get("usage", {})
|
||||
if not usage:
|
||||
return
|
||||
|
||||
usage_info["inputTokens"] += usage.get("inputTokens", 0)
|
||||
usage_info["outputTokens"] += usage.get("outputTokens", 0)
|
||||
|
||||
def _extract_orchestration_model(
|
||||
self, trace_data: InvokeAgentTrace
|
||||
) -> Optional[str]:
|
||||
"""Extract model information from orchestration trace."""
|
||||
orchestration_trace: Optional[InvokeAgentOrchestrationTrace] = trace_data.get(
|
||||
"orchestrationTrace"
|
||||
)
|
||||
if not orchestration_trace:
|
||||
return None
|
||||
|
||||
model_invocation: Optional[InvokeAgentModelInvocationInput] = (
|
||||
orchestration_trace.get("modelInvocationInput")
|
||||
or InvokeAgentModelInvocationInput()
|
||||
)
|
||||
if not model_invocation:
|
||||
return None
|
||||
|
||||
return model_invocation.get("foundationModel")
|
||||
|
||||
def _build_model_response(
|
||||
self,
|
||||
content: str,
|
||||
model: str,
|
||||
usage_info: InvokeAgentUsage,
|
||||
model_response: ModelResponse,
|
||||
) -> ModelResponse:
|
||||
"""Build the final ModelResponse object."""
|
||||
|
||||
# Create the message content
|
||||
message = Message(content=content, role="assistant")
|
||||
|
||||
# Create choices
|
||||
choice = Choices(finish_reason="stop", index=0, message=message)
|
||||
|
||||
# Update model response
|
||||
model_response.choices = [choice]
|
||||
model_response.model = usage_info.get("model", model)
|
||||
|
||||
# Add usage information if available
|
||||
if usage_info:
|
||||
from litellm.types.utils import Usage
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_info.get("inputTokens", 0),
|
||||
completion_tokens=usage_info.get("outputTokens", 0),
|
||||
total_tokens=usage_info.get("inputTokens", 0)
|
||||
+ usage_info.get("outputTokens", 0),
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
# Get the raw binary content
|
||||
raw_content = raw_response.content
|
||||
verbose_logger.debug(
|
||||
f"Processing {len(raw_content)} bytes of AWS event stream data"
|
||||
)
|
||||
|
||||
# Parse the AWS event stream format
|
||||
events = self._parse_aws_event_stream(raw_content)
|
||||
verbose_logger.debug(f"Parsed {len(events)} events from stream")
|
||||
|
||||
# Extract response content from chunk events
|
||||
content = self._extract_response_content(events)
|
||||
|
||||
# Extract usage information from trace events
|
||||
usage_info = self._extract_usage_info(events)
|
||||
|
||||
# Build and return the model response
|
||||
return self._build_model_response(
|
||||
content=content,
|
||||
model=model,
|
||||
usage_info=usage_info,
|
||||
model_response=model_response,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"Error processing Bedrock Invoke Agent response: {str(e)}"
|
||||
)
|
||||
raise BedrockError(
|
||||
message=f"Error processing response: {str(e)}",
|
||||
status_code=raw_response.status_code,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return headers
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return BedrockError(status_code=status_code, message=error_message)
|
||||
|
||||
def should_fake_stream(
|
||||
self,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
return True
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,99 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
|
||||
|
||||
class AmazonAI21Config(AmazonInvokeConfig, BaseConfig):
    """
    Bedrock invoke-API config for Amazon / AI21 (Jurassic) models.

    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

    Supported params:

    - `maxTokens` (int32): max tokens to generate per result (default 16);
      generation stops after `maxTokens` if no `stopSequences` are given.
    - `temperature` (float): sampling temperature (default 0.7); 0 disables
      sampling and yields greedy decoding.
    - `topP` (float): nucleus sampling percentile (default 1).
    - `stopSequences` (array of strings): stop decoding on any of these.
    - `frequencyPenalty` (object): frequency penalty placeholder.
    - `presencePenalty` (object): presence penalty placeholder.
    - `countPenalty` (object): count penalty placeholder.
    """

    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    # NOTE(review): attribute spelled `frequencePenalty` (missing "y") while the
    # docstring says `frequencyPenalty`; kept as-is since it is public interface.
    frequencePenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
    ) -> None:
        # Mirror any explicitly-passed value onto the class, matching the
        # config pattern used across the provider configs in this package.
        for name, value in locals().copy().items():
            if name != "self" and value is not None:
                setattr(self.__class__, name, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the set, non-callable, non-dunder class attributes."""
        callable_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for attr, value in cls.__dict__.items():
            if attr.startswith("__") or attr.startswith("_abc"):
                continue
            if isinstance(value, callable_types) or value is None:
                continue
            config[attr] = value
        return config

    def get_supported_openai_params(self, model: str) -> List:
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate OpenAI param names into AI21 invoke-body field names."""
        name_map = {
            "max_tokens": "maxTokens",
            "temperature": "temperature",
            "top_p": "topP",
            "stream": "stream",
        }
        for openai_name, value in non_default_params.items():
            target = name_map.get(openai_name)
            if target is not None:
                optional_params[target] = value
        return optional_params
|
||||
@@ -0,0 +1,75 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.cohere.chat.transformation import CohereChatConfig
|
||||
|
||||
|
||||
class AmazonCohereConfig(AmazonInvokeConfig, CohereChatConfig):
    """
    Bedrock invoke-API config for Amazon / Cohere Command models.

    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

    Supported params:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `return_likelihood` (string) n/a
    """

    max_tokens: Optional[int] = None
    return_likelihood: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        return_likelihood: Optional[str] = None,
    ) -> None:
        # Mirror any explicitly-passed value onto the class (shared config pattern).
        for name, value in locals().copy().items():
            if name != "self" and value is not None:
                setattr(self.__class__, name, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the set, non-callable, non-dunder class attributes."""
        callable_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for attr, value in cls.__dict__.items():
            if attr.startswith("__") or attr.startswith("_abc"):
                continue
            if isinstance(value, callable_types) or value is None:
                continue
            config[attr] = value
        return config

    def get_supported_openai_params(self, model: str) -> List[str]:
        # Same OpenAI params as the native Cohere chat integration.
        return CohereChatConfig.get_supported_openai_params(self, model=model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Delegate the whole mapping to the native Cohere chat config.
        return CohereChatConfig.map_openai_params(
            self,
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
|
||||
@@ -0,0 +1,135 @@
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from httpx import Response
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
_parse_content_for_reasoning,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.bedrock import AmazonDeepSeekR1StreamingResponse
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionUsageBlock,
|
||||
Choices,
|
||||
Delta,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
from .amazon_llama_transformation import AmazonLlamaConfig
|
||||
|
||||
|
||||
class AmazonDeepSeekR1Config(AmazonLlamaConfig):
    """Llama-style invoke config with DeepSeek R1 reasoning extraction."""

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Split "<think>...</think>" reasoning out of the completion text and
        surface it via provider_specific_fields["reasoning_content"].
        """
        response = super().transform_response(
            model,
            raw_response,
            model_response,
            logging_obj,
            request_data,
            messages,
            optional_params,
            litellm_params,
            encoding,
            api_key,
            json_mode,
        )
        first_choice = cast(Choices, response.choices[0])
        prompt = cast(Optional[str], request_data.get("prompt"))
        content_text = cast(Optional[str], first_choice.message.get("content"))

        # The prompt ending in "<think>" means the model's output begins with
        # reasoning tokens but lacks the opening tag — re-add it before parsing.
        if prompt and prompt.strip().endswith("<think>") and content_text:
            reasoning, remaining = _parse_content_for_reasoning(
                "<think>" + content_text
            )
            extra_fields = first_choice.message.provider_specific_fields or {}
            if reasoning:
                extra_fields["reasoning_content"] = reasoning

            # Rebuild the message with reasoning stripped from the content.
            first_choice.message = Message(
                **{
                    **first_choice.message.model_dump(),
                    "content": remaining,
                    "provider_specific_fields": extra_fields,
                }
            )
        return response
|
||||
|
||||
|
||||
class AmazonDeepseekR1ResponseIterator(BaseModelResponseIterator):
    """Streaming iterator that routes DeepSeek R1 output into reasoning vs content."""

    def __init__(self, streaming_response: Any, sync_stream: bool) -> None:
        super().__init__(streaming_response=streaming_response, sync_stream=sync_stream)
        # Tokens before the "</think>" marker are reasoning; after it, content.
        self.has_finished_thinking = False

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """
        Parse one raw DeepSeek R1 stream chunk.

        DeepSeek R1 emits its chain-of-thought first, then a "</think>" marker,
        then the answer. The marker chunk itself is swallowed (its content is
        blanked) and flips the reasoning/content switch.

        Fix: removed the no-op `try: ... except Exception as e: raise e`
        wrapper — it neither handled nor annotated the exception.
        """
        typed_chunk = AmazonDeepSeekR1StreamingResponse(**chunk)  # type: ignore
        generated_content = typed_chunk["generation"]
        if generated_content == "</think>" and not self.has_finished_thinking:
            verbose_logger.debug(
                "Deepseek r1: </think> received, setting has_finished_thinking to True"
            )
            generated_content = ""
            self.has_finished_thinking = True

        prompt_token_count = typed_chunk.get("prompt_token_count") or 0
        generation_token_count = typed_chunk.get("generation_token_count") or 0
        usage = ChatCompletionUsageBlock(
            prompt_tokens=prompt_token_count,
            completion_tokens=generation_token_count,
            total_tokens=prompt_token_count + generation_token_count,
        )

        return ModelResponseStream(
            choices=[
                StreamingChoices(
                    finish_reason=typed_chunk["stop_reason"],
                    delta=Delta(
                        # Exactly one of content/reasoning_content is set per chunk.
                        content=(
                            generated_content if self.has_finished_thinking else None
                        ),
                        reasoning_content=(
                            None if self.has_finished_thinking else generated_content
                        ),
                    ),
                )
            ],
            usage=usage,
        )
|
||||
@@ -0,0 +1,80 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
|
||||
|
||||
class AmazonLlamaConfig(AmazonInvokeConfig, BaseConfig):
    """
    Bedrock invoke-API config for Amazon / Meta Llama models.

    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported params:

    - `max_gen_len` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    """

    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None

    def __init__(
        self,
        # NOTE(review): parameter name `maxTokenCount` does not match the
        # documented `max_gen_len` attribute (looks copied from another
        # config) — kept as-is for backward compatibility of keyword callers.
        maxTokenCount: Optional[int] = None,
        temperature: Optional[float] = None,
        # Fix: was annotated Optional[int]; the class attribute is Optional[float].
        topP: Optional[float] = None,
    ) -> None:
        locals_ = locals().copy()
        # Mirror any explicitly-passed value onto the class (shared config pattern).
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the set, non-callable, non-dunder class attributes."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List:
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate OpenAI param names to Llama invoke-body field names."""
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_gen_len"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params
|
||||
@@ -0,0 +1,119 @@
|
||||
import types
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
    """
    Bedrock invoke-API config for Amazon / Mistral models.

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html

    Supported params:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    - `stop` [string] stop sequences that end generation when produced.
    - `top_k` (float) top k for model
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        # Fix: was annotated Optional[int]; the class attribute is Optional[float].
        top_p: Optional[float] = None,
        top_k: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        locals_ = locals().copy()
        # Mirror any explicitly-passed value onto the class (shared config pattern).
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the set, non-callable, non-dunder class attributes."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        return ["max_tokens", "temperature", "top_p", "stop", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate OpenAI param names to Mistral invoke-body field names."""
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_tokens"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stop":
                optional_params["stop"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params

    @staticmethod
    def get_outputText(
        completion_response: dict, model_response: "ModelResponse"
    ) -> str:
        """Extract the output text from a Bedrock Mistral completion.

        Side effect: updates model_response.choices[0].finish_reason.

        Args:
            completion_response: parsed JSON body of the completion.
            model_response: ModelResponse to update.

        Returns:
            The generated text.

        Raises:
            BedrockError: if the response matches neither the OpenAI-style
                "choices" shape nor the Mistral "outputs" shape.
        """
        if "choices" in completion_response:
            outputText = completion_response["choices"][0]["message"]["content"]
            model_response.choices[0].finish_reason = completion_response["choices"][0][
                "finish_reason"
            ]
        elif "outputs" in completion_response:
            outputText = completion_response["outputs"][0]["text"]
            model_response.choices[0].finish_reason = completion_response["outputs"][0][
                "stop_reason"
            ]
        else:
            raise BedrockError(
                message="Unexpected mistral completion response", status_code=400
            )

        return outputText
|
||||
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Transformation for Bedrock Moonshot AI (Kimi K2) models.
|
||||
|
||||
Supports the Kimi K2 Thinking model available on Amazon Bedrock.
|
||||
Model format: bedrock/moonshot.kimi-k2-thinking-v1:0
|
||||
|
||||
Reference: https://aws.amazon.com/about-aws/whats-new/2025/12/amazon-bedrock-fully-managed-open-weight-models/
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
import re
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.llms.moonshot.chat.transformation import MoonshotChatConfig
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonMoonshotConfig(AmazonInvokeConfig, MoonshotChatConfig):
    """
    Configuration for Bedrock Moonshot AI (Kimi K2) models.

    Reference:
    https://aws.amazon.com/about-aws/whats-new/2025/12/amazon-bedrock-fully-managed-open-weight-models/
    https://platform.moonshot.ai/docs/api/chat

    Supported Params for the Amazon / Moonshot models:
    - `max_tokens` (integer) max tokens
    - `temperature` (float) temperature for model (0-1 for Moonshot)
    - `top_p` (float) top p for model
    - `stream` (bool) whether to stream responses
    - `tools` (list) tool definitions (supported on kimi-k2-thinking)
    - `tool_choice` (str|dict) tool choice specification (supported on kimi-k2-thinking)

    NOT supported on Bedrock:
    - `stop` sequences (Bedrock doesn't support stopSequences for this model)

    Note: kimi-k2-thinking DOES support tool calls, unlike kimi-thinking-preview.
    """

    # Routing prefixes that may precede the bare Bedrock model id, in strip order.
    _ROUTING_PREFIXES = ("bedrock/", "invoke/")

    def __init__(self, **kwargs):
        AmazonInvokeConfig.__init__(self, **kwargs)
        MoonshotChatConfig.__init__(self, **kwargs)

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return "bedrock"

    def _get_model_id(self, model: str) -> str:
        """
        Extract the actual model ID from the LiteLLM model name.

        Removes routing prefixes like:
        - bedrock/invoke/moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
        - invoke/moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
        - moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking

        Fix: prefix stripping now uses len(prefix) instead of the brittle
        hard-coded slice offsets model[8:] / model[7:].
        """
        for prefix in self._ROUTING_PREFIXES:
            if model.startswith(prefix):
                model = model[len(prefix):]

        # Remove any remaining provider prefix (e.g. moonshot/), but leave ARNs intact.
        if "/" in model and not model.startswith("arn:"):
            parts = model.split("/", 1)
            if len(parts) == 2:
                model = parts[1]

        return model

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        Get the supported OpenAI params for Moonshot AI models on Bedrock.

        Bedrock-specific limitations:
        - stopSequences is not supported on Bedrock (unlike native Moonshot API)
        - functions is not supported (use tools instead)
        - tool_choice doesn't support the "required" value

        Note: kimi-k2-thinking DOES support tool calls (unlike
        kimi-thinking-preview); the parent MoonshotChatConfig handles the
        kimi-thinking-preview exclusion.
        """
        excluded_params: List[str] = [
            "functions",
            "stop",
        ]  # Bedrock doesn't support stopSequences

        # Skip MoonshotChatConfig in the MRO to get the generic OpenAI-like base list.
        base_openai_params = super(
            MoonshotChatConfig, self
        ).get_supported_openai_params(model=model)
        final_params: List[str] = []
        for param in base_openai_params:
            if param not in excluded_params:
                final_params.append(param)

        return final_params

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI parameters to Moonshot AI parameters for Bedrock.

        Handles Moonshot AI specific limitations:
        - tool_choice doesn't support "required"
        - temperature <0.3 limitation for n>1
        - temperature range is [0, 1] (not [0, 2] like OpenAI)
        """
        return MoonshotChatConfig.map_openai_params(
            self,
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the request for Bedrock Moonshot AI models.

        Uses the Moonshot transformation logic which handles:
        - converting content lists to strings (Moonshot doesn't support list format)
        - adding the tool_choice="required" workaround message if needed
        - temperature and parameter validation
        """
        # Filter out AWS credentials using the existing method from BaseAWSLLM.
        self._get_boto_credentials_from_optional_params(optional_params, model)

        # Strip routing prefixes to get the actual model ID.
        clean_model_id = self._get_model_id(model)

        return MoonshotChatConfig.transform_request(
            self,
            model=clean_model_id,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

    def _extract_reasoning_from_content(
        self, content: str
    ) -> tuple[Optional[str], str]:
        """
        Extract reasoning content from <reasoning> tags in the response.

        Moonshot AI's Kimi K2 Thinking model returns reasoning in <reasoning>
        tags; this returns (reasoning_content, main_content), with
        reasoning_content None when no tags are present.
        """
        if not content:
            return None, content

        # Match a leading <reasoning>...</reasoning> block; DOTALL so the
        # reasoning may span multiple lines.
        reasoning_match = re.match(
            r"<reasoning>(.*?)</reasoning>\s*(.*)", content, re.DOTALL
        )

        if reasoning_match:
            reasoning_content = reasoning_match.group(1).strip()
            main_content = reasoning_match.group(2).strip()
            return reasoning_content, main_content

        return None, content

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: "ModelResponse",
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> "ModelResponse":
        """
        Transform the response from Bedrock Moonshot AI models.

        Moonshot AI uses an OpenAI-compatible response format but returns
        reasoning in <reasoning> tags. This method:
        1. calls the parent transformation,
        2. extracts reasoning content from <reasoning> tags,
        3. sets reasoning_content on the message object.
        """
        model_response = MoonshotChatConfig.transform_response(
            self,
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

        # Split <reasoning> blocks out of each non-streaming choice.
        if model_response.choices and len(model_response.choices) > 0:
            for choice in model_response.choices:
                # Only Choices (not StreamingChoices) carry a message attribute.
                if (
                    isinstance(choice, Choices)
                    and choice.message
                    and choice.message.content
                ):
                    (
                        reasoning_content,
                        main_content,
                    ) = self._extract_reasoning_from_content(choice.message.content)

                    if reasoning_content:
                        choice.message.reasoning_content = reasoning_content
                        # Strip the reasoning tags from the visible content.
                        choice.message.content = main_content

        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BedrockError:
        """Return the appropriate error class for Bedrock."""
        return BedrockError(status_code=status_code, message=error_message)
||||
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Handles transforming requests for `bedrock/invoke/{nova} models`
|
||||
|
||||
Inherits from `AmazonConverseConfig`
|
||||
|
||||
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ..converse_transformation import AmazonConverseConfig
|
||||
from .base_invoke_transformation import AmazonInvokeConfig
|
||||
|
||||
|
||||
class AmazonInvokeNovaConfig(AmazonInvokeConfig, AmazonConverseConfig):
    """
    Config for sending `nova` requests to `/bedrock/invoke/`

    Delegates parameter mapping and request/response transformation to
    `AmazonConverseConfig`, then post-processes the payload so it satisfies
    the stricter invoke-API schema: no empty `system` blocks, and only the
    fields declared on `BedrockInvokeNovaRequest`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_supported_openai_params(self, model: str) -> list:
        # Nova over invoke accepts the same OpenAI params as the Converse API.
        return AmazonConverseConfig.get_supported_openai_params(self, model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Reuse the Converse parameter mapping unchanged.
        return AmazonConverseConfig.map_openai_params(
            self, non_default_params, optional_params, model, drop_params
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build a Converse-style payload, then adapt it for the invoke API."""
        converse_payload = AmazonConverseConfig.transform_request(
            self,
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        nova_request = BedrockInvokeNovaRequest(**converse_payload)
        # /bedrock/invoke/ rejects empty `system` blocks — strip them in place.
        self._remove_empty_system_messages(nova_request)
        return self._filter_allowed_fields(nova_request)

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: Logging,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Nova invoke responses share the Converse response shape — delegate."""
        return AmazonConverseConfig.transform_response(
            self,
            model,
            raw_response,
            model_response,
            logging_obj,
            request_data,
            messages,
            optional_params,
            litellm_params,
            encoding,
            api_key,
            json_mode,
        )

    def _filter_allowed_fields(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> dict:
        """
        Drop any key not declared on the `BedrockInvokeNovaRequest` schema.
        """
        allowed = BedrockInvokeNovaRequest.__annotations__.keys()
        return {
            key: value
            for key, value in bedrock_invoke_nova_request.items()
            if key in allowed
        }

    def _remove_empty_system_messages(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> None:
        """
        In-place remove empty `system` messages from the request.

        /bedrock/invoke/ does not allow empty `system` messages.
        """
        system_blocks = bedrock_invoke_nova_request.get("system", None)
        if isinstance(system_blocks, list) and not system_blocks:
            bedrock_invoke_nova_request.pop("system", None)
|
||||
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Transformation for Bedrock imported models that use OpenAI Chat Completions format.
|
||||
|
||||
Use this for models imported into Bedrock that accept the OpenAI API format.
|
||||
Model format: bedrock/openai/<model-id>
|
||||
|
||||
Example: bedrock/openai/arn:aws:bedrock:us-east-1:123456789012:imported-model/abc123
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from litellm.passthrough.utils import CommonUtils
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonBedrockOpenAIConfig(OpenAIGPTConfig, BaseAWSLLM):
    """
    Configuration for Bedrock imported models that use OpenAI Chat Completions format.

    Handles request/response transformation for Bedrock imported models that
    accept the OpenAI API format directly: standard OpenAI parameter handling
    comes from `OpenAIGPTConfig`, while Bedrock-specific URL construction and
    AWS SigV4 signing come from `BaseAWSLLM`.

    Usage:
        model = "bedrock/openai/arn:aws:bedrock:us-east-1:123456789012:imported-model/abc123"
    """

    def __init__(self, **kwargs):
        OpenAIGPTConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return "bedrock"

    def _get_openai_model_id(self, model: str) -> str:
        """
        Strip the LiteLLM routing prefixes from the model name.

        Input format: bedrock/openai/<model-id>
        Returns: <model-id>
        """
        # Strip "bedrock/" first (if present), then "openai/".
        for prefix in ("bedrock/", "openai/"):
            if model.startswith(prefix):
                model = model[len(prefix) :]
        return model

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Build the Bedrock invoke URL for the given model.

        Streaming requests target `/invoke-with-response-stream`; everything
        else targets `/invoke`.
        """
        raw_model_id = self._get_openai_model_id(model)

        region = self._get_aws_region_name(
            optional_params=optional_params, model=model
        )

        runtime_endpoint_override = optional_params.get(
            "aws_bedrock_runtime_endpoint", None
        )
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=runtime_endpoint_override,
            aws_region_name=region,
        )

        # ARN model ids contain "/" which must be percent-encoded
        # (e.g. :imported-model/ -> :imported-model%2F).
        encoded_model_id = CommonUtils.encode_bedrock_runtime_modelid_arn(raw_model_id)

        suffix = "invoke-with-response-stream" if stream else "invoke"
        return f"{endpoint_url}/model/{encoded_model_id}/{suffix}"

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """
        Sign the request using AWS Signature Version 4.
        """
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            api_key=api_key,
            model=model,
            stream=stream,
            fake_stream=fake_stream,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the request to OpenAI Chat Completions format.

        `stream` is handled via the URL, and AWS auth params must not leak
        into the request body; both are removed before delegating to the
        standard OpenAI transformation.
        """
        optional_params.pop("stream", None)

        body_params = {
            key: value
            for key, value in optional_params.items()
            if key not in self.aws_authentication_params
        }

        return super().transform_request(
            model=self._get_openai_model_id(model),
            messages=messages,
            optional_params=body_params,
            litellm_params=litellm_params,
            headers=headers,
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate the environment and return headers.

        Bedrock uses AWS SigV4 rather than Bearer-token auth, so no api_key
        check is needed here.
        """
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BedrockError:
        """Return the appropriate error class for Bedrock."""
        return BedrockError(message=error_message, status_code=status_code)
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Handles transforming requests for `bedrock/invoke/{qwen2} models`
|
||||
|
||||
Inherits from `AmazonQwen3Config` since Qwen2 and Qwen3 architectures are mostly similar.
|
||||
The main difference is in the response format: Qwen2 uses "text" field while Qwen3 uses "generation" field.
|
||||
|
||||
Qwen2 + Invoke API Tutorial: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.amazon_qwen3_transformation import (
|
||||
AmazonQwen3Config,
|
||||
)
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
|
||||
|
||||
class AmazonQwen2Config(AmazonQwen3Config):
    """
    Config for sending `qwen2` requests to `/bedrock/invoke/`

    Inherits from AmazonQwen3Config since Qwen2 and Qwen3 architectures are mostly
    similar. The main difference is in the response format: Qwen2 uses the "text"
    field while Qwen3 uses the "generation" field.

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
    """

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform a Qwen2 Bedrock response into OpenAI format.

        Qwen2 replies carry the completion in the "text" field; "generation"
        (the Qwen3 field name) is also accepted for compatibility and takes
        precedence when present.
        """
        try:
            payload = (
                raw_response.json() if hasattr(raw_response, "json") else raw_response
            )

            # Prefer "generation" (Qwen3-style), fall back to "text" (Qwen2).
            completion_text = payload.get("generation", "") or payload.get("text", "")

            # Strip chat-template framing tokens if the model echoed them.
            assistant_open = "<|im_start|>assistant\n"
            turn_close = "<|im_end|>"
            if completion_text.startswith(assistant_open):
                completion_text = completion_text[len(assistant_open) :]
            if completion_text.endswith(turn_close):
                completion_text = completion_text[: -len(turn_close)]

            # Populate the pre-built model_response in place.
            if hasattr(model_response, "choices") and len(model_response.choices) > 0:
                first_choice = model_response.choices[0]
                first_choice.message.content = completion_text
                first_choice.finish_reason = "stop"

            # Copy usage through when the provider reports it.
            if "usage" in payload:
                reported = payload["usage"]
                setattr(
                    model_response,
                    "usage",
                    Usage(
                        prompt_tokens=reported.get("prompt_tokens", 0),
                        completion_tokens=reported.get("completion_tokens", 0),
                        total_tokens=reported.get("total_tokens", 0),
                    ),
                )

            return model_response

        except Exception as e:
            if logging_obj:
                logging_obj.post_call(
                    input=messages,
                    api_key=api_key,
                    original_response=raw_response,
                    additional_args={"error": str(e)},
                )
            raise e
|
||||
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Handles transforming requests for `bedrock/invoke/{qwen3} models`
|
||||
|
||||
Inherits from `AmazonInvokeConfig`
|
||||
|
||||
Qwen3 + Invoke API Tutorial: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
|
||||
"""
|
||||
|
||||
import json
from typing import Any, List, Optional

import httpx

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
|
||||
|
||||
|
||||
class AmazonQwen3Config(AmazonInvokeConfig, BaseConfig):
    """
    Config for sending `qwen3` requests to `/bedrock/invoke/`

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        # NOTE: following the established litellm config pattern, explicitly
        # provided values are stored on the *class* (not the instance).
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    def get_supported_openai_params(self, model: str) -> List[str]:
        """OpenAI params this config can translate for Qwen3 on invoke."""
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "top_k",
            "stop",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported OpenAI params through unchanged; ignore the rest."""
        passthrough = {"max_tokens", "temperature", "top_p", "top_k", "stop", "stream"}
        for key, value in non_default_params.items():
            if key in passthrough:
                optional_params[key] = value
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform an OpenAI-format request into the Qwen3 Bedrock invoke body.

        The messages are flattened into a single ChatML-style prompt string;
        sampling params are renamed where the invoke schema differs
        (`max_tokens` -> `max_gen_len`).
        """
        request_body: dict = {
            "prompt": self._convert_messages_to_prompt(messages),
        }

        # OpenAI param name -> Qwen3 invoke body field.
        param_map = {
            "max_tokens": "max_gen_len",
            "temperature": "temperature",
            "top_p": "top_p",
            "top_k": "top_k",
            "stop": "stop",
        }
        for openai_name, body_name in param_map.items():
            if openai_name in optional_params:
                request_body[body_name] = optional_params[openai_name]

        return request_body

    def _convert_messages_to_prompt(self, messages: List[AllMessageValues]) -> str:
        """
        Convert OpenAI messages into a ChatML (Qwen) prompt string.

        Supports tool calls, multimodal content placeholders, and the
        system/user/assistant/tool roles. A trailing assistant open tag is
        appended so the model continues the conversation.
        """
        prompt_parts = []

        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")
            tool_calls = message.get("tool_calls", [])

            if role == "system":
                prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                # Flatten multimodal content into text + vision placeholders.
                if isinstance(content, list):
                    fragments = []
                    for item in content:
                        if item.get("type") == "text":
                            fragments.append(item.get("text", ""))
                        elif item.get("type") == "image_url":
                            fragments.append(
                                "<|vision_start|><|image_pad|><|vision_end|>"
                            )
                    content = "".join(fragments)
                prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                if tool_calls and isinstance(tool_calls, list):
                    for tool_call in tool_calls:
                        function = tool_call.get("function", {})
                        # BUGFIX: serialize the call with json.dumps instead of
                        # hand-building the JSON. The previous f-string wrapped
                        # the already-JSON `arguments` string in unescaped
                        # double quotes, producing invalid JSON inside
                        # <tool_call> whenever the arguments contained quotes.
                        call_json = json.dumps(
                            {
                                "name": function.get("name", ""),
                                "arguments": function.get("arguments", ""),
                            }
                        )
                        prompt_parts.append(
                            f"<|im_start|>assistant\n<tool_call>\n{call_json}\n</tool_call><|im_end|>"
                        )
                else:
                    prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            elif role == "tool":
                prompt_parts.append(f"<|im_start|>tool\n{content}<|im_end|>")

        # Open an assistant turn for the model to complete.
        prompt_parts.append("<|im_start|>assistant\n")

        return "\n".join(prompt_parts)

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform a Qwen3 Bedrock response into OpenAI format.

        Qwen3 returns the completion in the "generation" field; chat-template
        framing tokens are stripped and usage is copied through when present.
        """
        try:
            payload = (
                raw_response.json() if hasattr(raw_response, "json") else raw_response
            )

            completion_text = payload.get("generation", "")

            # Strip chat-template framing tokens if the model echoed them.
            assistant_open = "<|im_start|>assistant\n"
            turn_close = "<|im_end|>"
            if completion_text.startswith(assistant_open):
                completion_text = completion_text[len(assistant_open) :]
            if completion_text.endswith(turn_close):
                completion_text = completion_text[: -len(turn_close)]

            # Populate the pre-built model_response in place.
            if hasattr(model_response, "choices") and len(model_response.choices) > 0:
                first_choice = model_response.choices[0]
                first_choice.message.content = completion_text
                first_choice.finish_reason = "stop"

            # Copy usage through when the provider reports it.
            if "usage" in payload:
                reported = payload["usage"]
                setattr(
                    model_response,
                    "usage",
                    Usage(
                        prompt_tokens=reported.get("prompt_tokens", 0),
                        completion_tokens=reported.get("completion_tokens", 0),
                        total_tokens=reported.get("total_tokens", 0),
                    ),
                )

            return model_response

        except Exception as e:
            if logging_obj:
                logging_obj.post_call(
                    input=messages,
                    api_key=api_key,
                    original_response=raw_response,
                    additional_args={"error": str(e)},
                )
            raise e
|
||||
@@ -0,0 +1,116 @@
|
||||
import re
|
||||
import types
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
|
||||
|
||||
class AmazonTitanConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

    Supported Params for the Amazon Titan models:

    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (int) top p for model
    """

    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    topP: Optional[int] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        # Following the litellm config pattern, explicitly provided values
        # are stored on the *class* so get_config() can see them.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the non-callable, non-dunder class attributes that are set."""
        skipped_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(v, skipped_types)
            and v is not None
        }

    def _map_and_modify_arg(
        self,
        supported_params: dict,
        provider: str,
        model: str,
        stop: Union[List[str], str],
    ):
        """
        filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`.
        """
        # NOTE(review): when `stop` is a plain string (not a list) and
        # drop_params is set, the filtered list stays empty and the stop
        # value is effectively dropped — presumably intentional; confirm.
        filtered_stop = None
        if "stop" in supported_params and litellm.drop_params:
            if provider == "bedrock" and "amazon" in model:
                filtered_stop = []
                if isinstance(stop, list):
                    # Titan only accepts "|", "||", ... or "User:" as stops.
                    filtered_stop = [
                        s for s in stop if re.match(r"^(\|+|User:)$", s)
                    ]
        if filtered_stop is not None:
            supported_params["stop"] = filtered_stop

        return supported_params

    def get_supported_openai_params(self, model: str) -> List[str]:
        """OpenAI params this config can translate for Titan on invoke."""
        return [
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Rename OpenAI params to their Titan invoke equivalents."""
        for key, value in non_default_params.items():
            if key in ("max_tokens", "max_completion_tokens"):
                optional_params["maxTokenCount"] = value
            elif key == "temperature":
                optional_params["temperature"] = value
            elif key == "stop":
                # Titan restricts stop sequences; filter to the allowed set.
                mapped = self._map_and_modify_arg(
                    {"stop": value}, provider="bedrock", model=model, stop=value
                )
                optional_params["stopSequences"] = mapped["stop"]
            elif key == "top_p":
                optional_params["topP"] = value
            elif key == "stream":
                optional_params["stream"] = value
        return optional_params
|
||||
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
Transforms OpenAI-style requests into TwelveLabs Pegasus 1.2 requests for Bedrock.
|
||||
|
||||
Reference:
|
||||
https://docs.twelvelabs.io/docs/models/pegasus
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.base_llm.base_utils import type_to_response_format_param
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
from litellm.utils import get_base64_str
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonTwelveLabsPegasusConfig(AmazonInvokeConfig, BaseConfig):
|
||||
"""
|
||||
Handles transforming OpenAI-style requests into Bedrock InvokeModel requests for
|
||||
`twelvelabs.pegasus-1-2-v1:0`.
|
||||
|
||||
Pegasus 1.2 requires an `inputPrompt` and a `mediaSource` that either references
|
||||
an S3 object or a base64-encoded clip. Optional OpenAI params (temperature,
|
||||
response_format, max_tokens) are translated to the TwelveLabs schema.
|
||||
"""
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"temperature",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param in {"max_tokens", "max_completion_tokens"}:
|
||||
optional_params["maxOutputTokens"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "response_format":
|
||||
optional_params["responseFormat"] = self._normalize_response_format(
|
||||
value
|
||||
)
|
||||
return optional_params
|
||||
|
||||
def _normalize_response_format(self, value: Any) -> Any:
|
||||
"""Normalize response_format to TwelveLabs format.
|
||||
|
||||
TwelveLabs expects:
|
||||
{
|
||||
"jsonSchema": {...}
|
||||
}
|
||||
|
||||
But OpenAI format is:
|
||||
{
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "...",
|
||||
"schema": {...}
|
||||
}
|
||||
}
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
# If it has json_schema field, extract and transform it
|
||||
if "json_schema" in value:
|
||||
json_schema = value["json_schema"]
|
||||
# Extract the schema if nested
|
||||
if isinstance(json_schema, dict) and "schema" in json_schema:
|
||||
return {"jsonSchema": json_schema["schema"]}
|
||||
# Otherwise use json_schema directly
|
||||
return {"jsonSchema": json_schema}
|
||||
# If it already has jsonSchema, return as is
|
||||
if "jsonSchema" in value:
|
||||
return value
|
||||
# Otherwise return the dict as is
|
||||
return value
|
||||
return type_to_response_format_param(response_format=value) or value
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
input_prompt = self._convert_messages_to_prompt(messages=messages)
|
||||
request_data: Dict[str, Any] = {"inputPrompt": input_prompt}
|
||||
|
||||
media_source = self._build_media_source(optional_params)
|
||||
if media_source is not None:
|
||||
request_data["mediaSource"] = media_source
|
||||
|
||||
# Handle temperature and maxOutputTokens
|
||||
for key in ("temperature", "maxOutputTokens"):
|
||||
if key in optional_params:
|
||||
request_data[key] = optional_params.get(key)
|
||||
|
||||
# Handle responseFormat - transform to TwelveLabs format
|
||||
if "responseFormat" in optional_params:
|
||||
response_format = optional_params["responseFormat"]
|
||||
transformed_format = self._normalize_response_format(response_format)
|
||||
if transformed_format:
|
||||
request_data["responseFormat"] = transformed_format
|
||||
|
||||
return request_data
|
||||
|
||||
def _build_media_source(self, optional_params: dict) -> Optional[dict]:
|
||||
direct_source = optional_params.get("mediaSource") or optional_params.get(
|
||||
"media_source"
|
||||
)
|
||||
if isinstance(direct_source, dict):
|
||||
return direct_source
|
||||
|
||||
base64_input = optional_params.get("video_base64") or optional_params.get(
|
||||
"base64_string"
|
||||
)
|
||||
if base64_input:
|
||||
return {"base64String": get_base64_str(base64_input)}
|
||||
|
||||
s3_uri = (
|
||||
optional_params.get("video_s3_uri")
|
||||
or optional_params.get("s3_uri")
|
||||
or optional_params.get("media_source_s3_uri")
|
||||
)
|
||||
if s3_uri:
|
||||
s3_location = {"uri": s3_uri}
|
||||
bucket_owner = (
|
||||
optional_params.get("video_s3_bucket_owner")
|
||||
or optional_params.get("s3_bucket_owner")
|
||||
or optional_params.get("media_source_bucket_owner")
|
||||
)
|
||||
if bucket_owner:
|
||||
s3_location["bucketOwner"] = bucket_owner
|
||||
return {"s3Location": s3_location}
|
||||
return None
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[AllMessageValues]) -> str:
|
||||
prompt_parts: List[str] = []
|
||||
for message in messages:
|
||||
role = message.get("role", "user")
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, list):
|
||||
text_fragments = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
text_fragments.append(item.get("text", ""))
|
||||
elif item_type == "image_url":
|
||||
text_fragments.append("<image>")
|
||||
elif item_type == "video_url":
|
||||
text_fragments.append("<video>")
|
||||
elif item_type == "audio_url":
|
||||
text_fragments.append("<audio>")
|
||||
elif isinstance(item, str):
|
||||
text_fragments.append(item)
|
||||
content = " ".join(text_fragments)
|
||||
prompt_parts.append(f"{role}: {content}")
|
||||
return "\n".join(part for part in prompt_parts if part).strip()
|
||||
|
||||
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    Transform TwelveLabs Pegasus response to LiteLLM format.

    TwelveLabs response format:
    {
        "message": "...",
        "finishReason": "stop" | "length"
    }

    LiteLLM format:
    ModelResponse with choices[0].message.content and finish_reason

    Raises:
        BedrockError: if the body is not valid JSON, or if the message
            content cannot be placed on the response object.
    """
    # Parse the raw HTTP body; surface parse failures as BedrockError with
    # the original body text for debuggability.
    try:
        completion_response = raw_response.json()
    except Exception as e:
        raise BedrockError(
            message=f"Error parsing response: {raw_response.text}, error: {str(e)}",
            status_code=raw_response.status_code,
        )

    verbose_logger.debug(
        "twelvelabs pegasus response: %s",
        json.dumps(completion_response, indent=4, default=str),
    )

    # Extract message content
    message_content = completion_response.get("message", "")

    # Extract finish reason and map to LiteLLM format
    finish_reason_raw = completion_response.get("finishReason", "stop")
    finish_reason = map_finish_reason(finish_reason_raw)

    # Set the response content. Content is only written when the choice has
    # a message slot and no tool_calls are already present; otherwise the
    # inner Exception is converted into a BedrockError below.
    try:
        if (
            message_content
            and hasattr(model_response.choices[0], "message")
            and getattr(model_response.choices[0].message, "tool_calls", None)
            is None
        ):
            model_response.choices[0].message.content = message_content  # type: ignore
            model_response.choices[0].finish_reason = finish_reason
        else:
            raise Exception("Unable to set message content")
    except Exception as e:
        raise BedrockError(
            message=f"Error setting response content: {str(e)}. Response: {completion_response}",
            status_code=raw_response.status_code,
        )

    # Calculate usage from headers. Bedrock reports token counts in response
    # headers; fall back to local token counting when headers are absent.
    bedrock_input_tokens = raw_response.headers.get(
        "x-amzn-bedrock-input-token-count", None
    )
    bedrock_output_tokens = raw_response.headers.get(
        "x-amzn-bedrock-output-token-count", None
    )

    prompt_tokens = int(
        bedrock_input_tokens or litellm.token_counter(messages=messages)
    )

    completion_tokens = int(
        bedrock_output_tokens
        or litellm.token_counter(
            text=model_response.choices[0].message.content,  # type: ignore
            count_response_tokens=True,
        )
    )

    model_response.created = int(time.time())
    model_response.model = model
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    setattr(model_response, "usage", usage)

    return model_response
@@ -0,0 +1,98 @@
|
||||
import types
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
|
||||
from .base_invoke_transformation import AmazonInvokeConfig
|
||||
|
||||
|
||||
class AmazonAnthropicConfig(AmazonInvokeConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Config for legacy (text-completions style) Amazon / Anthropic models.

    Supported Params for the Amazon / Anthropic models:

    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (integer) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    """

    # Class-level defaults; note that __init__ overwrites these on the CLASS
    # (not the instance), so configured values are shared by all instances.
    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        # Copy every non-None constructor argument onto the class so that
        # get_config() (which reads cls.__dict__) picks it up.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # Return all non-None, non-callable, non-dunder class attributes -
        # i.e. the effective configuration set via class defaults / __init__.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @staticmethod
    def get_legacy_anthropic_model_names():
        # Model ids that still use the legacy text-completions request shape.
        return [
            "anthropic.claude-v2",
            "anthropic.claude-instant-v1",
            "anthropic.claude-v2:1",
        ]

    def get_supported_openai_params(self, model: str):
        # OpenAI params this legacy Anthropic config can translate.
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "stop",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        # Translate OpenAI param names to the legacy Anthropic equivalents.
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                # both OpenAI max-token aliases map onto max_tokens_to_sample
                optional_params["max_tokens_to_sample"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "stream" and value is True:
                # only forwarded when explicitly True
                optional_params["stream"] = value
        return optional_params
||||
@@ -0,0 +1,206 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import (
|
||||
get_anthropic_beta_from_headers,
|
||||
remove_custom_field_from_tools,
|
||||
)
|
||||
from litellm.types.llms.anthropic import ANTHROPIC_TOOL_SEARCH_BETA_HEADER
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonAnthropicClaudeConfig(AmazonInvokeConfig, AnthropicConfig):
    """
    Reference:
    https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
    https://docs.anthropic.com/claude/docs/models-overview#model-comparison
    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html

    Supported Params for the Amazon / Anthropic Claude models (Claude 3, Claude 4, etc.):
    Supports anthropic_beta parameter for beta features like:
    - computer-use-2025-01-24 (Claude 3.7 Sonnet)
    - computer-use-2024-10-22 (Claude 3.5 Sonnet v2)
    - token-efficient-tools-2025-02-19 (Claude 3.7 Sonnet)
    - interleaved-thinking-2025-05-14 (Claude 4 models)
    - output-128k-2025-02-19 (Claude 3.7 Sonnet)
    - dev-full-thinking-2025-05-14 (Claude 4 models)
    - context-1m-2025-08-07 (Claude Sonnet 4)
    """

    # Bedrock requires this fixed anthropic_version in the request body.
    anthropic_version: str = "bedrock-2023-05-31"

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return "bedrock"

    def get_supported_openai_params(self, model: str) -> List[str]:
        # Delegate to the shared Anthropic config for the supported set.
        return AnthropicConfig.get_supported_openai_params(self, model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Force tool-based structured outputs for Bedrock Invoke
        # (similar to VertexAI fix in #19201)
        # Bedrock Invoke doesn't support output_format parameter
        original_model = model
        if "response_format" in non_default_params:
            # Use a model name that forces tool-based approach
            model = "claude-3-sonnet-20240229"

        optional_params = AnthropicConfig.map_openai_params(
            self,
            non_default_params,
            optional_params,
            model,
            drop_params,
        )

        # Restore original model name
        model = original_model

        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the Bedrock Invoke request body for a Claude Messages model.

        Starts from the standard Anthropic request, then strips fields
        Bedrock Invoke rejects and attaches the anthropic_version and
        anthropic_beta entries.
        """
        # Filter out AWS authentication parameters before passing to Anthropic transformation
        # AWS params should only be used for signing requests, not included in request body
        filtered_params = {
            k: v
            for k, v in optional_params.items()
            if k not in self.aws_authentication_params
        }
        filtered_params = self._normalize_bedrock_tool_search_tools(filtered_params)

        _anthropic_request = AnthropicConfig.transform_request(
            self,
            model=model,
            messages=messages,
            optional_params=filtered_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        # model/stream are carried in the URL for Bedrock Invoke, not the body.
        _anthropic_request.pop("model", None)
        _anthropic_request.pop("stream", None)
        # Bedrock Invoke doesn't support output_format parameter
        _anthropic_request.pop("output_format", None)
        # Bedrock Invoke doesn't support output_config parameter
        # Fixes: https://github.com/BerriAI/litellm/issues/22797
        _anthropic_request.pop("output_config", None)
        if "anthropic_version" not in _anthropic_request:
            _anthropic_request["anthropic_version"] = self.anthropic_version

        # Remove `custom` field from tools (Bedrock doesn't support it)
        # Claude Code sends `custom: {defer_loading: true}` on tool definitions,
        # which causes Bedrock to reject the request with "Extra inputs are not permitted"
        # Ref: https://github.com/BerriAI/litellm/issues/22847
        remove_custom_field_from_tools(_anthropic_request)

        tools = optional_params.get("tools")
        tool_search_used = self.is_tool_search_used(tools)
        programmatic_tool_calling_used = self.is_programmatic_tool_calling_used(tools)
        input_examples_used = self.is_input_examples_used(tools)

        # Merge caller-supplied beta headers with the automatically derived ones.
        beta_set = set(get_anthropic_beta_from_headers(headers))
        auto_betas = self.get_anthropic_beta_list(
            model=model,
            optional_params=optional_params,
            computer_tool_used=self.is_computer_tool_used(tools),
            prompt_caching_set=False,
            file_id_used=self.is_file_id_used(messages),
            mcp_server_used=self.is_mcp_server_used(optional_params.get("mcp_servers")),
        )
        beta_set.update(auto_betas)

        # When only plain tool search is used (no programmatic tool calling /
        # input examples), swap the generic tool-search beta header for the
        # Opus-4-specific one where applicable.
        if tool_search_used and not (
            programmatic_tool_calling_used or input_examples_used
        ):
            beta_set.discard(ANTHROPIC_TOOL_SEARCH_BETA_HEADER)
            if "opus-4" in model.lower() or "opus_4" in model.lower():
                beta_set.add("tool-search-tool-2025-10-19")

        # Filter out beta headers that Bedrock Invoke doesn't support
        # Uses centralized configuration from anthropic_beta_headers_config.json
        beta_list = list(beta_set)
        _anthropic_request["anthropic_beta"] = beta_list

        return _anthropic_request

    def _normalize_bedrock_tool_search_tools(self, optional_params: dict) -> dict:
        """
        Convert tool search entries to the format supported by the Bedrock Invoke API.
        """
        tools = optional_params.get("tools")
        if not tools or not isinstance(tools, list):
            return optional_params

        normalized_tools = []
        for tool in tools:
            tool_type = tool.get("type")
            if tool_type == "tool_search_tool_bm25_20251119":
                # Bedrock Invoke does not support the BM25 variant, so skip it.
                continue
            if tool_type == "tool_search_tool_regex_20251119":
                # Rename the dated regex variant to Bedrock's undated form.
                normalized_tool = tool.copy()
                normalized_tool["type"] = "tool_search_tool_regex"
                normalized_tool["name"] = normalized_tool.get(
                    "name", "tool_search_tool_regex"
                )
                normalized_tools.append(normalized_tool)
                continue
            normalized_tools.append(tool)

        optional_params["tools"] = normalized_tools
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # Response parsing is identical to the direct Anthropic API; delegate.
        return AnthropicConfig.transform_response(
            self,
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )
||||
@@ -0,0 +1,613 @@
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
cohere_message_pt,
|
||||
custom_prompt,
|
||||
deepseek_r1_pt,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.bedrock.chat.invoke_handler import make_call, make_sync_call
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
|
||||
|
||||
class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||
def __init__(self, **kwargs):
    # Explicitly initialize both bases with the same kwargs (cooperative
    # super() is not used here).
    BaseConfig.__init__(self, **kwargs)
    BaseAWSLLM.__init__(self, **kwargs)
||||
def get_supported_openai_params(self, model: str) -> List[str]:
    """
    This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
    """
    # Baseline param support; provider-specific configs extend this list.
    supported: List[str] = [
        "max_tokens",
        "max_completion_tokens",
        "stream",
    ]
    return supported
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """
    This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
    """
    for name, value in non_default_params.items():
        # Both OpenAI max-token aliases map onto Bedrock's `max_tokens`.
        if name in ("max_tokens", "max_completion_tokens"):
            optional_params["max_tokens"] = value
        elif name == "stream":
            optional_params["stream"] = value
    return optional_params
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """
    Get the complete url for the request.

    Resolves the Bedrock runtime endpoint (honoring a caller-supplied
    `aws_bedrock_runtime_endpoint`), then appends the streaming or
    non-streaming invoke path for the resolved model id.

    Returns:
        The fully-qualified invoke URL.
    """
    provider = self.get_bedrock_invoke_provider(model)
    modelId = self.get_bedrock_model_id(
        model=model,
        provider=provider,
        optional_params=optional_params,
    )
    ### SET RUNTIME ENDPOINT ###
    aws_bedrock_runtime_endpoint = optional_params.get(
        "aws_bedrock_runtime_endpoint", None
    )  # https://bedrock-runtime.{region_name}.amazonaws.com
    # get_runtime_endpoint also returns a proxy endpoint URL; it was
    # previously computed/suffixed here but never used, so it is discarded.
    endpoint_url, _ = self.get_runtime_endpoint(
        api_base=api_base,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_region_name=self._get_aws_region_name(
            optional_params=optional_params, model=model
        ),
    )

    # ai21 never uses the native streaming path (it is fake-streamed
    # elsewhere), so it always gets the plain /invoke endpoint.
    if (stream is not None and stream is True) and provider != "ai21":
        endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
    else:
        endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"

    return endpoint_url
|
||||
def sign_request(
    self,
    headers: dict,
    optional_params: dict,
    request_data: dict,
    api_base: str,
    api_key: Optional[str] = None,
    model: Optional[str] = None,
    stream: Optional[bool] = None,
    fake_stream: Optional[bool] = None,
) -> Tuple[dict, Optional[bytes]]:
    """SigV4-sign the outgoing request.

    Thin wrapper around the shared AWS signer that pins
    service_name="bedrock" and forwards every other argument unchanged.
    """
    signer_kwargs = dict(
        service_name="bedrock",
        headers=headers,
        optional_params=optional_params,
        request_data=request_data,
        api_base=api_base,
        api_key=api_key,
        model=model,
        stream=stream,
        fake_stream=fake_stream,
    )
    return self._sign_request(**signer_kwargs)
||||
|
||||
def _apply_config_to_params(self, config: dict, inference_params: dict) -> None:
|
||||
"""Apply config values to inference_params if not already set."""
|
||||
for k, v in config.items():
|
||||
if k not in inference_params:
|
||||
inference_params[k] = v
|
||||
|
||||
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """Build the provider-specific Bedrock Invoke request body.

    Dispatches on the provider inferred from the model id: some providers
    (anthropic, nova, twelvelabs, openai) delegate to their own config's
    transform_request; the rest get a prompt-string body assembled here.

    Raises:
        BedrockError: 404-style error for unknown providers.
    """
    ## SETUP ##
    stream = optional_params.pop("stream", None)
    custom_prompt_dict: dict = litellm_params.pop("custom_prompt_dict", None) or {}
    hf_model_name = litellm_params.get("hf_model_name", None)

    provider = self.get_bedrock_invoke_provider(model)

    # Render messages into a single prompt string (plus chat history for
    # cohere command-r models).
    prompt, chat_history = self.convert_messages_to_prompt(
        model=hf_model_name or model,
        messages=messages,
        provider=provider,
        custom_prompt_dict=custom_prompt_dict,
    )
    # Deep-copy so config defaults / mutations don't leak back to the caller,
    # then strip AWS auth params - those are for request signing only.
    inference_params = copy.deepcopy(optional_params)
    inference_params = {
        k: v
        for k, v in inference_params.items()
        if k not in self.aws_authentication_params
    }
    request_data: dict = {}
    if provider == "cohere":
        if model.startswith("cohere.command-r"):
            ## LOAD CONFIG
            config = litellm.AmazonCohereChatConfig().get_config()
            self._apply_config_to_params(config, inference_params)
            _data = {"message": prompt, **inference_params}
            if chat_history is not None:
                _data["chat_history"] = chat_history
            request_data = _data
        else:
            ## LOAD CONFIG
            config = litellm.AmazonCohereConfig.get_config()
            self._apply_config_to_params(config, inference_params)
            if stream is True:
                inference_params["stream"] = True  # cohere requires stream = True in inference params
            request_data = {"prompt": prompt, **inference_params}
    elif provider == "anthropic":
        # NOTE: anthropic delegation passes the ORIGINAL optional_params
        # (not the filtered inference_params); the anthropic config does its
        # own AWS-param filtering.
        transformed_request = (
            litellm.AmazonAnthropicClaudeConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        )

        return transformed_request
    elif provider == "nova":
        return litellm.AmazonInvokeNovaConfig().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
    elif provider == "ai21":
        ## LOAD CONFIG
        config = litellm.AmazonAI21Config.get_config()
        self._apply_config_to_params(config, inference_params)
        request_data = {"prompt": prompt, **inference_params}
    elif provider == "mistral":
        ## LOAD CONFIG
        config = litellm.AmazonMistralConfig.get_config()
        self._apply_config_to_params(config, inference_params)
        request_data = {"prompt": prompt, **inference_params}
    elif provider == "amazon":  # amazon titan
        ## LOAD CONFIG
        config = litellm.AmazonTitanConfig.get_config()
        self._apply_config_to_params(config, inference_params)
        # Titan nests sampling params under textGenerationConfig.
        request_data = {
            "inputText": prompt,
            "textGenerationConfig": inference_params,
        }
    elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
        ## LOAD CONFIG
        # deepseek_r1 reuses the llama request shape here.
        config = litellm.AmazonLlamaConfig.get_config()
        self._apply_config_to_params(config, inference_params)
        request_data = {"prompt": prompt, **inference_params}
    elif provider == "twelvelabs":
        return litellm.AmazonTwelveLabsPegasusConfig().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
    elif provider == "openai":
        # OpenAI imported models use OpenAI Chat Completions format
        return litellm.AmazonBedrockOpenAIConfig().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
    else:
        raise BedrockError(
            status_code=404,
            message="Bedrock Invoke HTTPX: Unknown provider={}, model={}. Try calling via converse route - `bedrock/converse/<model>`.".format(
                provider, model
            ),
        )

    return request_data
||||
|
||||
def transform_response(  # noqa: PLR0915
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """Parse a Bedrock Invoke response into a LiteLLM ModelResponse.

    Dispatches on provider: anthropic/nova/twelvelabs delegate to their own
    configs and return early; the remaining providers have their output text
    extracted here, then usage is computed from Bedrock's token-count headers
    (falling back to local token counting).

    Raises:
        BedrockError: on unparseable JSON, provider extraction failure (422),
            or when the content cannot be placed on the response object.
    """
    try:
        completion_response = raw_response.json()
    except Exception:
        raise BedrockError(
            message=raw_response.text, status_code=raw_response.status_code
        )
    verbose_logger.debug(
        "bedrock invoke response % s",
        json.dumps(completion_response, indent=4, default=str),
    )
    provider = self.get_bedrock_invoke_provider(model)
    outputText: Optional[str] = None
    try:
        if provider == "cohere":
            # Two cohere response shapes: command-r ("text") and legacy
            # ("generations").
            if "text" in completion_response:
                outputText = completion_response["text"]  # type: ignore
            elif "generations" in completion_response:
                outputText = completion_response["generations"][0]["text"]
                model_response.choices[0].finish_reason = map_finish_reason(
                    completion_response["generations"][0]["finish_reason"]
                )
        elif provider == "anthropic":
            return litellm.AmazonAnthropicClaudeConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        elif provider == "nova":
            return litellm.AmazonInvokeNovaConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
            )
        elif provider == "twelvelabs":
            return litellm.AmazonTwelveLabsPegasusConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        elif provider == "ai21":
            outputText = (
                completion_response.get("completions")[0].get("data").get("text")
            )
        elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
            outputText = completion_response["generation"]
        elif provider == "mistral":
            outputText = litellm.AmazonMistralConfig.get_outputText(
                completion_response, model_response
            )
        else:  # amazon titan
            outputText = completion_response.get("results")[0].get("outputText")
    except Exception as e:
        # Any shape mismatch in the provider-specific extraction above is
        # reported as a 422 with the raw body attached.
        raise BedrockError(
            message="Error processing={}, Received error={}".format(
                raw_response.text, str(e)
            ),
            status_code=422,
        )

    try:
        # Only write content when there is text and no tool_calls already on
        # the message; existing tool_calls are left untouched.
        if (
            outputText is not None
            and len(outputText) > 0
            and hasattr(model_response.choices[0], "message")
            and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
            is None
        ):
            model_response.choices[0].message.content = outputText  # type: ignore
        elif (
            hasattr(model_response.choices[0], "message")
            and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
            is not None
        ):
            pass
        else:
            raise Exception()
    except Exception as e:
        raise BedrockError(
            message="Error parsing received text={}.\nError-{}".format(
                outputText, str(e)
            ),
            status_code=raw_response.status_code,
        )

    ## CALCULATING USAGE - bedrock returns usage in the headers
    bedrock_input_tokens = raw_response.headers.get(
        "x-amzn-bedrock-input-token-count", None
    )
    bedrock_output_tokens = raw_response.headers.get(
        "x-amzn-bedrock-output-token-count", None
    )

    # Fall back to local token counting when Bedrock headers are absent.
    prompt_tokens = int(
        bedrock_input_tokens or litellm.token_counter(messages=messages)
    )

    completion_tokens = int(
        bedrock_output_tokens
        or litellm.token_counter(
            text=model_response.choices[0].message.content,  # type: ignore
            count_response_tokens=True,
        )
    )

    model_response.created = int(time.time())
    model_response.model = model
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    setattr(model_response, "usage", usage)

    return model_response
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """No-op environment validation for Bedrock Invoke.

    Auth is handled via SigV4 request signing (see sign_request), not
    API-key headers, so the incoming headers are returned unchanged.
    """
    return headers
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    # All invoke-route errors surface as BedrockError; headers are unused.
    return BedrockError(status_code=status_code, message=error_message)
||||
|
||||
@track_llm_api_timing()
async def get_async_custom_stream_wrapper(
    self,
    model: str,
    custom_llm_provider: str,
    logging_obj: LiteLLMLoggingObj,
    api_base: str,
    headers: dict,
    data: dict,
    messages: list,
    client: Optional[AsyncHTTPHandler] = None,
    json_mode: Optional[bool] = None,
    signed_json_body: Optional[bytes] = None,
) -> CustomStreamWrapper:
    """Build a lazy async streaming wrapper for a Bedrock Invoke call.

    The HTTP call is deferred: `make_call` is partially applied here and
    executed by CustomStreamWrapper when iteration begins. ai21 endpoints
    are fake-streamed (single response replayed as a stream).
    """
    # NOTE(review): unlike the sync variant, signed_json_body is accepted
    # but not forwarded to make_call here - confirm whether that is intended.
    streaming_response = CustomStreamWrapper(
        completion_stream=None,
        make_call=partial(
            make_call,
            client=client,
            api_base=api_base,
            headers=headers,
            data=json.dumps(data),
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            fake_stream=True if "ai21" in api_base else False,
            bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
            json_mode=json_mode,
        ),
        model=model,
        custom_llm_provider="bedrock",
        logging_obj=logging_obj,
    )
    return streaming_response
||||
|
||||
@track_llm_api_timing()
def get_sync_custom_stream_wrapper(
    self,
    model: str,
    custom_llm_provider: str,
    logging_obj: LiteLLMLoggingObj,
    api_base: str,
    headers: dict,
    data: dict,
    messages: list,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    json_mode: Optional[bool] = None,
    signed_json_body: Optional[bytes] = None,
) -> CustomStreamWrapper:
    """Build a lazy sync streaming wrapper for a Bedrock Invoke call.

    Mirrors the async variant but uses make_sync_call and forwards the
    pre-signed request body. ai21 endpoints are fake-streamed.
    """
    # A sync HTTP client is required; replace a missing/async client.
    if client is None or isinstance(client, AsyncHTTPHandler):
        client = _get_httpx_client(params={})
    streaming_response = CustomStreamWrapper(
        completion_stream=None,
        make_call=partial(
            make_sync_call,
            client=client,
            api_base=api_base,
            headers=headers,
            data=json.dumps(data),
            signed_json_body=signed_json_body,
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            fake_stream=True if "ai21" in api_base else False,
            bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
            json_mode=json_mode,
        ),
        model=model,
        custom_llm_provider="bedrock",
        logging_obj=logging_obj,
    )
    return streaming_response
||||
|
||||
@property
def has_custom_stream_wrapper(self) -> bool:
    # Invoke uses the get_*_custom_stream_wrapper methods above instead of
    # the generic streaming path.
    return True
||||
|
||||
@property
def supports_stream_param_in_request_body(self) -> bool:
    """
    Bedrock invoke does not allow passing `stream` in the request body.

    Streaming is selected via the invoke-with-response-stream URL instead.
    """
    return False
||||
|
||||
@staticmethod
def get_bedrock_invoke_provider(
    model: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
    """
    Helper function to get the bedrock provider from the model

    handles 4 scenarios:
    1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
    2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
    3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
    4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
    """
    # Strip a single leading "invoke/" route prefix if present.
    if model.startswith("invoke/"):
        model = model.replace("invoke/", "", 1)

    # Special case: Check for "nova" in model name first (before "amazon")
    # This handles amazon.nova-* models which would otherwise match "amazon" (Titan)
    if "nova" in model.lower():
        if "nova" in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, "nova")

    # Standard "provider.model-name" form: the segment before the first dot.
    _split_model = model.split(".")[0]
    if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
        return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)

    # If not a known provider, check for pattern with two slashes
    provider = AmazonInvokeConfig._get_provider_from_model_path(model)
    if provider is not None:
        return provider

    # Last resort: substring match against every known provider name.
    for provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
        if provider in model:
            return provider
    return None
||||
|
||||
@staticmethod
def _get_provider_from_model_path(
    model_path: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
    """
    Helper function to get the provider from a model path with format: provider/model-name

    Args:
        model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')

    Returns:
        Optional[str]: The provider name, or None if no valid provider found
    """
    parts = model_path.split("/")
    # NOTE(review): str.split always yields at least one element, so this
    # guard is vacuous; a "provider/model" shape check would be
    # `len(parts) >= 2`. Kept as-is to preserve behavior for slash-less
    # inputs whose whole value matches a provider literal.
    if len(parts) >= 1:
        provider = parts[0]
        if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
    return None
|
||||
def convert_messages_to_prompt(
    self, model, messages, provider, custom_prompt_dict
) -> Tuple[str, Optional[list]]:
    """
    Render chat `messages` into a single prompt string for Bedrock invoke.

    Args:
        model: model id; used to look up a registered custom prompt template.
        messages: OpenAI-style chat messages (dicts with 'role'/'content').
        provider: bedrock invoke provider (e.g. "anthropic", "cohere", ...),
            selects the prompt-formatting strategy.
        custom_prompt_dict: mapping of model id -> custom prompt template
            details ('roles', optional 'initial_prompt_value' /
            'final_prompt_value').

    Returns:
        (prompt, chat_history) — chat_history is populated only for cohere;
        it is None for every other provider and for custom prompts.
    """
    prompt = ""
    chat_history: Optional[list] = None
    ## CUSTOM PROMPT — a registered template for this model wins over the
    ## provider-based strategies below.
    if model in custom_prompt_dict:
        model_prompt_details = custom_prompt_dict[model]
        prompt = custom_prompt(
            role_dict=model_prompt_details["roles"],
            initial_prompt_value=model_prompt_details.get(
                "initial_prompt_value", ""
            ),
            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
            messages=messages,
        )
        return prompt, None
    ## ELSE — provider-specific formatting. anthropic/amazon/mistral/meta/
    ## llama all used identical prompt_factory calls; merged into one branch.
    if provider in ("anthropic", "amazon", "mistral", "meta", "llama"):
        prompt = prompt_factory(
            model=model, messages=messages, custom_llm_provider="bedrock"
        )
    elif provider == "cohere":
        prompt, chat_history = cohere_message_pt(messages=messages)
    elif provider == "deepseek_r1":
        prompt = deepseek_r1_pt(messages=messages)
    else:
        # Fallback: concatenate raw message contents. The original branched
        # on 'role' here, but every branch appended the same text, so the
        # role check was dead logic — collapsed to a single join.
        prompt = "".join(f"{message['content']}" for message in messages)
    return prompt, chat_history  # type: ignore
|
||||
Reference in New Issue
Block a user