chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Bedrock Token Counter implementation using the CountTokens API.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.llms.bedrock.common_utils import BedrockError, get_bedrock_base_model
|
||||
from litellm.llms.bedrock.count_tokens.handler import BedrockCountTokensHandler
|
||||
from litellm.types.utils import LlmProviders, TokenCountResponse
|
||||
|
||||
|
||||
class BedrockTokenCounter(BaseTokenCounter):
    """Token counter implementation for AWS Bedrock provider using the CountTokens API."""

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Returns True if we should use the Bedrock CountTokens API for token counting.
        """
        is_bedrock_provider = custom_llm_provider == LlmProviders.BEDROCK.value
        return is_bedrock_provider

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens using AWS Bedrock's CountTokens API.

        Delegates to BedrockCountTokensHandler, which performs a remote API
        call against Bedrock's token counting endpoint instead of the local
        tiktoken-based counting.

        Args:
            model_to_use: The model identifier.
            messages: The messages to count tokens for.
            contents: Alternative content format (not used for Bedrock).
            deployment: Deployment configuration containing litellm_params.
            request_model: The original request model name.
            tools: Optional tool definitions included in the count request.
            system: Optional system prompt included in the count request.

        Returns:
            TokenCountResponse with the token count, or None when there are
            no messages or the handler produced no result. API failures are
            reported via an error-flagged TokenCountResponse, not raised.
        """
        if not messages:
            return None

        params = (deployment or {}).get("litellm_params", {})

        # Assemble the payload in the shape BedrockCountTokensHandler expects.
        payload: Dict[str, Any] = {"model": model_to_use, "messages": messages}
        if tools:
            payload["tools"] = tools
        if system:
            payload["system"] = system

        # Strip routing prefixes (bedrock/, converse/, ...) to get the base model.
        base_model = get_bedrock_base_model(model_to_use)

        try:
            count_handler = BedrockCountTokensHandler()
            api_result = await count_handler.handle_count_tokens_request(
                request_data=payload,
                litellm_params=params,
                resolved_model=base_model,
            )

            # Shape the raw handler result into a TokenCountResponse.
            if api_result is not None:
                return TokenCountResponse(
                    total_tokens=api_result.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type="bedrock_api",
                    original_response=api_result,
                )
        except BedrockError as e:
            # Surface Bedrock-reported failures with their original status code.
            verbose_logger.warning(
                f"Bedrock CountTokens API error: status={e.status_code}, message={e.message}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="bedrock_api",
                error=True,
                error_message=e.message,
                status_code=e.status_code,
            )
        except Exception as e:
            # Anything unexpected is reported as a generic 500-style failure.
            verbose_logger.warning(f"Error calling Bedrock CountTokens API: {e}")
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="bedrock_api",
                error=True,
                error_message=str(e),
                status_code=500,
            )

        # Handler returned None: nothing to report.
        return None
|
||||
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
AWS Bedrock CountTokens API handler.
|
||||
|
||||
Simplified handler leveraging existing LiteLLM Bedrock infrastructure.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.llms.bedrock.count_tokens.transformation import BedrockCountTokensConfig
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
|
||||
|
||||
class BedrockCountTokensHandler(BedrockCountTokensConfig):
    """
    Simplified handler for AWS Bedrock CountTokens API requests.

    Uses existing LiteLLM infrastructure for authentication and request handling.
    """

    async def handle_count_tokens_request(
        self,
        request_data: Dict[str, Any],
        litellm_params: Dict[str, Any],
        resolved_model: str,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using existing LiteLLM patterns.

        Pipeline: validate -> resolve region -> transform to Bedrock format ->
        build endpoint -> SigV4-sign -> POST -> transform response back.

        Args:
            request_data: The incoming request payload
            litellm_params: LiteLLM configuration parameters
            resolved_model: The actual model ID resolved from router

        Returns:
            Dictionary containing token count response
            (``{"input_tokens": <int>}`` per transform_bedrock_response_to_anthropic)

        Raises:
            BedrockError: On a non-200 response from Bedrock (original status
                code preserved), on httpx HTTP errors (response status code
                preserved), or on any other failure (wrapped with status 500,
                including validation ValueErrors from
                validate_count_tokens_request).
        """
        try:
            # Validate the request
            self.validate_count_tokens_request(request_data)

            verbose_logger.debug(
                f"Processing CountTokens request for resolved model: {resolved_model}"
            )

            # Get AWS region using existing LiteLLM function
            aws_region_name = self._get_aws_region_name(
                optional_params=litellm_params,
                model=resolved_model,
                model_id=None,
            )

            verbose_logger.debug(f"Retrieved AWS region: {aws_region_name}")

            # Transform request to Bedrock format (supports both Converse and InvokeModel)
            bedrock_request = self.transform_anthropic_to_bedrock_count_tokens(
                request_data=request_data
            )

            verbose_logger.debug(f"Transformed request: {bedrock_request}")

            # Get endpoint URL using simplified function
            endpoint_url = self.get_bedrock_count_tokens_endpoint(
                resolved_model, aws_region_name
            )

            verbose_logger.debug(f"Making request to: {endpoint_url}")

            # Use existing _sign_request method from BaseAWSLLM
            # Extract api_key for bearer token auth if provided
            api_key = litellm_params.get("api_key", None)
            headers = {"Content-Type": "application/json"}
            # NOTE: signing must happen after the final body is built — any
            # later mutation of the payload would invalidate the signature.
            signed_headers, signed_body = self._sign_request(
                service_name="bedrock",
                headers=headers,
                optional_params=litellm_params,
                request_data=bedrock_request,
                api_base=endpoint_url,
                model=resolved_model,
                api_key=api_key,
            )

            async_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.BEDROCK
            )

            # NOTE(review): httpx prefers content= for a raw pre-encoded body;
            # data= still works for bytes but is deprecated for that use —
            # confirm signed_body's type before switching.
            response = await async_client.post(
                endpoint_url,
                headers=signed_headers,
                data=signed_body,
                timeout=30.0,
            )

            verbose_logger.debug(f"Response status: {response.status_code}")

            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error(f"AWS Bedrock error: {error_text}")
                raise BedrockError(
                    status_code=response.status_code,
                    message=error_text,
                )

            bedrock_response = response.json()

            verbose_logger.debug(f"Bedrock response: {bedrock_response}")

            # Transform response back to expected format
            final_response = self.transform_bedrock_response_to_anthropic(
                bedrock_response
            )

            verbose_logger.debug(f"Final response: {final_response}")

            return final_response

        except BedrockError:
            # Re-raise Bedrock exceptions as-is
            raise
        except httpx.HTTPStatusError as e:
            # HTTP errors - preserve the actual status code
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise BedrockError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except Exception as e:
            # Catch-all boundary: wrap anything unexpected as a 500 BedrockError.
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise BedrockError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )
|
||||
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
AWS Bedrock CountTokens API transformation logic.
|
||||
|
||||
This module handles the transformation of requests from Anthropic Messages API format
|
||||
to AWS Bedrock's CountTokens API format and vice versa.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.bedrock.common_utils import get_bedrock_base_model
|
||||
|
||||
|
||||
class BedrockCountTokensConfig(BaseAWSLLM):
    """
    Configuration and transformation logic for AWS Bedrock CountTokens API.

    AWS Bedrock CountTokens API Specification:
    - Endpoint: POST /model/{modelId}/count-tokens
    - Input formats: 'invokeModel' or 'converse'
    - Response: {"inputTokens": <number>}
    """

    def _detect_input_type(self, request_data: Dict[str, Any]) -> str:
        """
        Pick the CountTokens input format for this request.

        A list-valued ``messages`` key means Anthropic-style chat input and
        maps to 'converse'; anything else (raw prompt / native Bedrock body)
        maps to 'invokeModel'.

        Args:
            request_data: The original request data

        Returns:
            'converse' or 'invokeModel'
        """
        if isinstance(request_data.get("messages"), list):
            return "converse"
        return "invokeModel"

    def transform_anthropic_to_bedrock_count_tokens(
        self,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Transform request to Bedrock CountTokens format.
        Supports both Converse and InvokeModel input types.

        Input (Anthropic format):
        {
            "model": "claude-3-5-sonnet",
            "messages": [{"role": "user", "content": "Hello!"}]
        }

        Output (Converse): {"input": {"converse": {"messages": [...], ...}}}
        Output (InvokeModel): {"input": {"invokeModel": {"body": "{...}"}}}
        """
        if self._detect_input_type(request_data) == "converse":
            return self._transform_to_converse_format(request_data)
        return self._transform_to_invoke_model_format(request_data)

    def _transform_to_converse_format(
        self, request_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Transform to Converse input format, including system and tools."""
        bedrock_messages: List[Dict[str, Any]] = []
        for msg in request_data.get("messages", []):
            raw_content = msg.get("content", "")
            # String content becomes a single text block; list content is
            # assumed to already be block-shaped and passed through.
            if isinstance(raw_content, str):
                content_blocks: List[Dict[str, Any]] = [{"text": raw_content}]
            elif isinstance(raw_content, list):
                content_blocks = raw_content
            else:
                content_blocks = []
            bedrock_messages.append(
                {"role": msg.get("role"), "content": content_blocks}
            )

        converse_input: Dict[str, Any] = {"messages": bedrock_messages}

        # System prompt (string or block list) -> Bedrock system blocks.
        system_blocks = self._transform_system(request_data.get("system"))
        if system_blocks:
            converse_input["system"] = system_blocks

        # Anthropic tools -> Bedrock toolConfig.
        tool_config = self._transform_tools(request_data.get("tools"))
        if tool_config:
            converse_input["toolConfig"] = tool_config

        return {"input": {"converse": converse_input}}

    def _transform_system(self, system: Optional[Any]) -> List[Dict[str, Any]]:
        """Transform Anthropic system prompt to Bedrock system blocks."""
        if isinstance(system, str):
            return [{"text": system}]
        if isinstance(system, list):
            # Blocks format, e.g. [{"type": "text", "text": "..."}];
            # non-dict entries are dropped.
            return [
                {"text": block.get("text", "")}
                for block in system
                if isinstance(block, dict)
            ]
        # None or any other type: no system blocks.
        return []

    def _transform_tools(
        self, tools: Optional[List[Dict[str, Any]]]
    ) -> Optional[Dict[str, Any]]:
        """Transform Anthropic tools to Bedrock toolConfig format."""
        if not tools:
            return None

        tool_specs: List[Dict[str, Any]] = []
        for tool in tools:
            # Bedrock tool names must match [a-zA-Z][a-zA-Z0-9_]* and max 64 chars.
            safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", tool.get("name", ""))
            if safe_name and not safe_name[0].isalpha():
                safe_name = "t_" + safe_name
            safe_name = safe_name[:64]

            tool_specs.append(
                {
                    "toolSpec": {
                        "name": safe_name,
                        # Bedrock requires a description; fall back to the name.
                        "description": tool.get("description") or safe_name,
                        "inputSchema": {
                            "json": tool.get(
                                "input_schema",
                                {"type": "object", "properties": {}},
                            )
                        },
                    }
                }
            )

        return {"tools": tool_specs}

    def _transform_to_invoke_model_format(
        self, request_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Transform to InvokeModel input format."""
        import json

        # InvokeModel takes the raw serialized body that would be sent to the
        # model; 'model' is routing metadata, not model input, so drop it.
        body_data = dict(request_data)
        body_data.pop("model", None)

        return {"input": {"invokeModel": {"body": json.dumps(body_data)}}}

    def get_bedrock_count_tokens_endpoint(
        self, model: str, aws_region_name: str
    ) -> str:
        """
        Construct the AWS Bedrock CountTokens API endpoint.

        Args:
            model: The resolved model ID from router lookup
            aws_region_name: AWS region (e.g., "eu-west-1")

        Returns:
            Complete endpoint URL for CountTokens API
        """
        # Base model ID (region prefix removed), then strip any bedrock/ route prefix.
        model_id = get_bedrock_base_model(model)
        if model_id.startswith("bedrock/"):
            model_id = model_id[len("bedrock/"):]

        return (
            f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
            f"/model/{model_id}/count-tokens"
        )

    def transform_bedrock_response_to_anthropic(
        self, bedrock_response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Transform Bedrock CountTokens response to Anthropic format.

        Input (Bedrock):   {"inputTokens": 123}
        Output (Anthropic): {"input_tokens": 123}
        """
        return {"input_tokens": bedrock_response.get("inputTokens", 0)}

    def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None:
        """
        Validate the incoming count tokens request.
        Supports both Converse and InvokeModel input formats.

        Args:
            request_data: The request payload

        Raises:
            ValueError: If the request is invalid
        """
        if not request_data.get("model"):
            raise ValueError("model parameter is required")

        if self._detect_input_type(request_data) != "converse":
            # InvokeModel bodies are model-specific, so validation is minimal:
            # anything besides the 'model' field counts as content.
            if len(request_data) <= 1:
                raise ValueError("Request must contain content to count tokens")
            return

        # Converse format: messages-based validation.
        messages = request_data.get("messages", [])
        if not messages:
            raise ValueError("messages parameter is required for Converse input")
        if not isinstance(messages, list):
            raise ValueError("messages must be a list")

        for i, message in enumerate(messages):
            if not isinstance(message, dict):
                raise ValueError(f"Message {i} must be a dictionary")
            if "role" not in message:
                raise ValueError(f"Message {i} must have a 'role' field")
            if "content" not in message:
                raise ValueError(f"Message {i} must have a 'content' field")
||||
Reference in New Issue
Block a user