"""
Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke`
In the Huggingface TGI format.
"""
from typing import TYPE_CHECKING, Any, List, Optional, Union
if TYPE_CHECKING:
from litellm.types.llms.openai import AllEmbeddingInputValues
from httpx._models import Headers, Response
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import Usage, EmbeddingResponse
from litellm.llms.voyage.embedding.transformation import VoyageEmbeddingConfig
from ..common_utils import SagemakerError
class SagemakerEmbeddingConfig(BaseEmbeddingConfig):
    """
    SageMaker embedding configuration factory for supporting embedding parameters.

    Translates OpenAI `/v1/embeddings` requests into the SageMaker `/invoke`
    payload used by Hugging Face TGI/TEI containers, and parses the container
    response back into an OpenAI-style `EmbeddingResponse`.
    """

    def __init__(self) -> None:
        pass

    @classmethod
    def get_model_config(cls, model: str) -> "BaseEmbeddingConfig":
        """
        Factory method to get the appropriate embedding config based on model type.

        Args:
            model: The model name

        Returns:
            Appropriate embedding config instance
        """
        # Voyage models deployed on SageMaker use the Voyage request/response
        # shape rather than the HF TGI/TEI shape.
        if "voyage" in model.lower():
            return VoyageEmbeddingConfig()
        return cls()

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        Return the OpenAI embedding params supported by this model.

        Plain HF models expose none; Voyage models delegate to the Voyage config.
        """
        if "voyage" in model.lower():
            return VoyageEmbeddingConfig().get_supported_openai_params(model)
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """No OpenAI params map for HF SageMaker models; pass optional_params through."""
        return optional_params

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap a provider error in the SageMaker-specific exception type."""
        return SagemakerError(
            message=error_message, status_code=status_code, headers=headers
        )

    def transform_embedding_request(
        self,
        model: str,
        input: "AllEmbeddingInputValues",
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform embedding request for Hugging Face models on SageMaker.

        HF containers expect an "inputs" field (plural), unlike OpenAI's "input".
        """
        return {"inputs": input, **optional_params}

    def transform_embedding_response(
        self,
        model: str,
        raw_response: Response,
        model_response: "EmbeddingResponse",
        logging_obj: Any,
        api_key: Optional[str] = None,
        request_data: Optional[dict] = None,
        optional_params: Optional[dict] = None,
        litellm_params: Optional[dict] = None,
    ) -> "EmbeddingResponse":
        """
        Transform embedding response for Hugging Face models on SageMaker.

        Handles both the raw array format (TEI) and the wrapped format
        (standard HF, an object with an "embedding" key).

        Args:
            model: Model name to echo back on the response.
            raw_response: HTTP response from the SageMaker endpoint.
            model_response: Response object to populate in place.
            logging_obj: Logging handle (unused here).
            api_key: Unused; kept for interface compatibility.
            request_data: The request payload that was sent (used to estimate
                prompt token usage from its "inputs" field).
            optional_params: Unused; kept for interface compatibility.
            litellm_params: Unused; kept for interface compatibility.

        Returns:
            The populated `model_response`.

        Raises:
            SagemakerError: If the body is not JSON or not a recognized
                embedding format.
        """
        # None sentinels instead of mutable `{}` defaults (shared across calls).
        request_data = request_data or {}
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise SagemakerError(
                message=f"Failed to parse response: {str(e)}",
                status_code=raw_response.status_code,
            )
        # Handle both raw array format (TEI) and wrapped format (standard HF)
        if isinstance(response_data, list):
            # TEI and some HF models return raw embedding arrays directly
            embeddings = response_data
        elif isinstance(response_data, dict) and "embedding" in response_data:
            # Standard HF format with "embedding" key
            embeddings = response_data["embedding"]
        else:
            raise SagemakerError(
                status_code=500,
                message=f"Unexpected response format. Expected list or dict with 'embedding' key, got: {type(response_data).__name__}",
            )
        if not isinstance(embeddings, list):
            raise SagemakerError(
                status_code=422,
                message=f"HF response not in expected format - {embeddings}",
            )
        model_response.object = "list"
        model_response.data = [
            {"object": "embedding", "index": idx, "embedding": embedding}
            for idx, embedding in enumerate(embeddings)
        ]
        model_response.model = model

        # Estimate usage from the request payload. "inputs" may be a single
        # string or a list of strings (see transform_embedding_request);
        # normalize to a list so a bare string is counted as words rather than
        # iterated character-by-character.
        input_texts = request_data.get("inputs", [])
        if isinstance(input_texts, str):
            input_texts = [input_texts]
        input_tokens = sum(
            len(text.split())  # Simple word count fallback; no tokenizer here
            for text in input_texts
            if isinstance(text, str)
        )
        model_response.usage = Usage(
            prompt_tokens=input_tokens,
            completion_tokens=0,
            total_tokens=input_tokens,
        )
        return model_response

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate environment for SageMaker embeddings.

        Auth is handled by the SageMaker transport (SigV4), so only the JSON
        content type header is needed here.
        """
        return {"Content-Type": "application/json"}