lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/sap/embed/transformation.py
"""
Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
"""
from typing import Optional, List, Dict, Literal, Union
from pydantic import BaseModel, Field
from functools import cached_property

import httpx

from litellm.llms.base_llm.embedding.transformation import (
    BaseEmbeddingConfig,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse

from ..chat.handler import GenAIHubOrchestrationError
from ..credentials import get_token_creator


class Usage(BaseModel):
    prompt_tokens: int
    total_tokens: int


class EmbeddingItem(BaseModel):
    object: Literal["embedding"]
    embedding: List[float] = Field(
        ..., description="Vector of floats (length varies by model)."
    )
    index: int


class FinalResult(BaseModel):
    object: Literal["list"]
    data: List[EmbeddingItem]
    model: str
    usage: Usage


class EmbeddingsResponse(BaseModel):
    request_id: str
    final_result: FinalResult


class EmbeddingModel(BaseModel):
    name: str
    version: str = "latest"
    params: dict = Field(default_factory=dict, validation_alias="parameters")


# These request models mirror the body assembled in
# transform_embedding_request below: config.modules.embeddings.model.
class EmbeddingModuleConfig(BaseModel):
    model: EmbeddingModel


class EmbeddingsModules(BaseModel):
    embeddings: EmbeddingModuleConfig


class EmbeddingRequestConfig(BaseModel):
    modules: EmbeddingsModules


class EmbeddingInput(BaseModel):
    text: Union[str, List[str]]
    type: Literal["text", "document", "query"] = "text"


class EmbeddingRequest(BaseModel):
    config: EmbeddingRequestConfig
    input: EmbeddingInput
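
# A minimal sketch of the request body these models mirror (the shape is taken
# from transform_embedding_request below; the model name and input text are
# placeholders, not real values):
#
#   {
#       "config": {
#           "modules": {
#               "embeddings": {
#                   "model": {
#                       "name": "text-embedding-3-small",
#                       "version": "latest",
#                       "params": {},
#                   }
#               }
#           }
#       },
#       "input": {"text": ["hello world"], "type": "text"},
#   }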


def validate_dict(data: dict, model) -> dict:
    """Round-trip ``data`` through a pydantic ``model`` to apply defaults,
    aliases, and validation, returning a plain dict."""
    return model(**data).model_dump()
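
# For example (placeholder model name), the defaults declared on
# EmbeddingModel are filled in:
#
#   validate_dict({"name": "text-embedding-3-small"}, EmbeddingModel)
#   # -> {"name": "text-embedding-3-small", "version": "latest", "params": {}}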


class GenAIHubEmbeddingConfig(BaseEmbeddingConfig):
    def __init__(self):
        super().__init__()
        self._access_token_data = {}
        self.token_creator, self.base_url, self.resource_group = get_token_creator()

    @property
    def headers(self) -> Dict:
        # token_creator is expected to return a complete Authorization header
        # value, since no "Bearer " prefix is added below
        access_token = self.token_creator()
        # headers for completions and embeddings requests
        headers = {
            "Authorization": access_token,
            "AI-Resource-Group": self.resource_group,
            "Content-Type": "application/json",
            "AI-Client-Type": "LiteLLM",
        }
        return headers

    @cached_property
    def deployment_url(self) -> str:
        """Resolve the URL of the most recently created orchestration
        deployment; cached for the lifetime of this config object."""
        with httpx.Client(timeout=30) as client:
            valid_deployments = []
            deployments = client.get(
                self.base_url + "/lm/deployments", headers=self.headers
            ).json()
            for deployment in deployments.get("resources", []):
                if deployment.get("scenarioId") == "orchestration":
                    config_details = client.get(
                        self.base_url
                        + f'/lm/configurations/{deployment["configurationId"]}',
                        headers=self.headers,
                    ).json()
                    if config_details.get("executableId") == "orchestration":
                        valid_deployments.append(
                            (deployment["deploymentUrl"], deployment["createdAt"])
                        )
            if not valid_deployments:
                raise GenAIHubOrchestrationError(
                    404, "No orchestration deployment found for this resource group."
                )
            # newest deployment wins
            return sorted(valid_deployments, key=lambda x: x[1], reverse=True)[0][0]
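
    # Illustrative shape of the /lm/deployments response, showing only the
    # keys read above (values are placeholders, not real SAP AI Core output):
    #
    #   {
    #       "resources": [
    #           {
    #               "scenarioId": "orchestration",
    #               "configurationId": "abc123",
    #               "deploymentUrl": "https://<host>/v2/inference/deployments/d0",
    #               "createdAt": "2024-01-01T00:00:00Z",
    #           }
    #       ]
    #   }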

    def get_error_class(self, error_message, status_code, headers):
        return GenAIHubOrchestrationError(status_code, error_message)

    def get_supported_openai_params(self, model: str) -> list:
        # text-embedding-3 models additionally accept the `dimensions` parameter
        if "text-embedding-3" in model:
            return ["encoding_format", "dimensions"]
        return ["encoding_format"]
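
    # Example: get_supported_openai_params("text-embedding-3-small")
    # returns ["encoding_format", "dimensions"]; any other model name
    # returns ["encoding_format"] only.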

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return optional_params

    def validate_environment(self, headers: dict, *args, **kwargs) -> dict:
        return self.headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        return self.deployment_url.rstrip("/") + "/v2/embeddings"
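
    # The complete URL is the resolved deployment URL plus "/v2/embeddings",
    # e.g. "https://<ai-api-host>/v2/inference/deployments/<id>/v2/embeddings"
    # (host and deployment id here are placeholders).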

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        model_dict = {}
        model_dict["name"] = model
        model_dict["version"] = optional_params.get("version", "latest")
        # EmbeddingModel declares validation_alias="parameters", so the alias
        # key must be used here; with a "params" key, pydantic would ignore
        # the value and silently fall back to the empty-dict default
        model_dict["parameters"] = optional_params.get("parameters", {})
        input_dict = {"text": input}
        body = {
            "config": {
                "modules": {
                    "embeddings": {"model": validate_dict(model_dict, EmbeddingModel)}
                }
            },
            "input": validate_dict(input_dict, EmbeddingInput),
        }
        return body
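
    # A minimal sketch of this transform in isolation (placeholder values;
    # instantiating the config requires GenAI Hub credentials, since
    # get_token_creator() runs in __init__):
    #
    #   config = GenAIHubEmbeddingConfig()
    #   body = config.transform_embedding_request(
    #       model="text-embedding-3-small",
    #       input=["hello", "world"],
    #       optional_params={},
    #       headers={},
    #   )
    #   # body matches the EmbeddingRequest shape documented above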

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        # the orchestration response nests an OpenAI-style embedding payload
        # under "final_result" (see EmbeddingsResponse above), so it can be
        # validated directly into litellm's EmbeddingResponse
        return EmbeddingResponse.model_validate(raw_response.json()["final_result"])
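
    # Illustrative "final_result" payload (matching FinalResult above; the
    # numbers and model name are placeholders):
    #
    #   {
    #       "object": "list",
    #       "data": [{"object": "embedding", "embedding": [0.1, -0.2], "index": 0}],
    #       "model": "text-embedding-3-small",
    #       "usage": {"prompt_tokens": 2, "total_tokens": 2},
    #   }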