chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1 @@
|
||||
|
||||
260
llm-gateway-competitors/litellm-wheel-src/litellm/llms/sap/chat/handler.py
Executable file
260
llm-gateway-competitors/litellm-wheel-src/litellm/llms/sap/chat/handler.py
Executable file
@@ -0,0 +1,260 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncIterator, Iterator, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import OpenAIChatCompletionChunk
|
||||
|
||||
from ...custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Errors
|
||||
# -------------------------------
|
||||
class GenAIHubOrchestrationError(BaseLLMException):
    """Exception surfaced for SAP GenAI Hub orchestration failures."""

    def __init__(self, status_code: int, message: str):
        # Mirror the arguments as instance attributes so callers can
        # inspect them without relying on the base-class internals.
        self.status_code = status_code
        self.message = message
        super().__init__(status_code=status_code, message=message)
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Stream parsing helpers
|
||||
# -------------------------------
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def _is_terminal_chunk(chunk: OpenAIChatCompletionChunk) -> bool:
|
||||
"""OpenAI-shaped chunk is terminal if any choice has a non-None finish_reason."""
|
||||
try:
|
||||
for ch in chunk.choices or []:
|
||||
if ch.finish_reason is not None:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
class _StreamParser:
    """Normalize orchestration streaming events into OpenAI-like chunks."""

    @staticmethod
    def _from_orchestration_result(evt: dict) -> Optional[OpenAIChatCompletionChunk]:
        """
        Accepts orchestration_result shape and maps it to an OpenAI-like *chunk*.

        Returns None when the event carries no orchestration_result payload.
        """
        orc = evt.get("orchestration_result") or {}
        if not orc:
            return None

        # Identifier/timestamp fallbacks: prefer the orchestration result's
        # own fields, then the enclosing event's, then a synthetic value.
        return OpenAIChatCompletionChunk.model_validate(
            {
                "id": orc.get("id") or evt.get("request_id") or "stream-chunk",
                "object": orc.get("object") or "chat.completion.chunk",
                "created": orc.get("created") or evt.get("created") or _now_ts(),
                "model": orc.get("model") or "unknown",
                "choices": [
                    {
                        "index": c.get("index", 0),
                        "delta": c.get("delta") or {},
                        "finish_reason": c.get("finish_reason"),
                    }
                    for c in (orc.get("choices") or [])
                ],
            }
        )

    @staticmethod
    def to_openai_chunk(event_obj: dict) -> Optional[OpenAIChatCompletionChunk]:
        """
        Accepts:
        - {"final_result": <openai-style CHUNK>} (IMPORTANT: this is just another chunk, NOT terminal)
        - {"orchestration_result": {...}} (map to chunk)
        - already-openai-shaped chunks
        - other events (ignored -> returns None)
        Raises:
        - ValueError for in-stream error objects
        """
        # In-stream error per spec (surface as exception carrying the raw event)
        if "code" in event_obj or "error" in event_obj:
            raise ValueError(json.dumps(event_obj))

        # FINAL RESULT IS *NOT* TERMINAL: treat it as the next chunk
        if "final_result" in event_obj:
            fr = event_obj["final_result"] or {}
            # ensure it looks like an OpenAI chunk before validation
            if "object" not in fr:
                fr["object"] = "chat.completion.chunk"
            return OpenAIChatCompletionChunk.model_validate(fr)

        # Orchestration incremental delta
        if "orchestration_result" in event_obj:
            return _StreamParser._from_orchestration_result(event_obj)

        # Already an OpenAI-like chunk
        if "choices" in event_obj and "object" in event_obj:
            return OpenAIChatCompletionChunk.model_validate(event_obj)

        # Unknown / heartbeat / metrics
        return None
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Iterators
|
||||
# -------------------------------
|
||||
class SAPStreamIterator:
    """
    Sync iterator over an httpx streaming response that yields OpenAIChatCompletionChunk.
    Accepts both SSE `data: ...` and raw JSON lines. Closes on terminal chunk or [DONE].
    """

    def __init__(
        self,
        response: Iterator,
        event_prefix: str = "data: ",
        final_msg: str = "[DONE]",
    ):
        # The raw handle is kept alongside the iterator even though
        # iteration only uses `_iter`.
        self._resp = response
        self._iter = response
        self._prefix = event_prefix
        self._final = final_msg
        # One-way latch: once set, __next__ stops immediately.
        self._done = False

    def __iter__(self) -> Iterator[OpenAIChatCompletionChunk]:
        return self

    def __next__(self) -> OpenAIChatCompletionChunk:
        """Return the next parsed chunk, skipping blanks and non-JSON noise.

        Raises:
            StopIteration: on `[DONE]`, source exhaustion, or after the latch is set.
            ValueError: when the stream carries an in-band error object.
        """
        if self._done:
            raise StopIteration

        for raw in self._iter:
            line = (raw or "").strip()
            if not line:
                continue

            # Strip the SSE "data: " prefix when present; raw JSON lines pass through.
            payload = (
                line[len(self._prefix) :] if line.startswith(self._prefix) else line
            )
            if payload == self._final:
                self._safe_close()
                raise StopIteration

            try:
                obj = json.loads(payload)
            except Exception:
                # Unparseable lines (keep-alives, partial data) are skipped silently.
                continue

            try:
                chunk = _StreamParser.to_openai_chunk(obj)
            except ValueError as e:
                # In-band error: latch closed, then surface to the caller.
                self._safe_close()
                raise e

            if chunk is None:
                continue

            # Close on terminal (finish_reason present) so the NEXT call stops.
            if _is_terminal_chunk(chunk):
                self._safe_close()

            return chunk

        self._safe_close()
        raise StopIteration

    def _safe_close(self) -> None:
        # NOTE(review): despite the name, this only latches `_done`; the
        # underlying httpx stream is not closed here — confirm intended.
        if self._done:
            return
        else:
            self._done = True
|
||||
|
||||
|
||||
class AsyncSAPStreamIterator:
    """Async counterpart of SAPStreamIterator; yields OpenAI-shaped chunks."""

    # Marker presumably consumed by litellm's streaming plumbing to select
    # the async path — TODO confirm against callers.
    sync_stream = False

    def __init__(
        self,
        response: AsyncIterator,
        event_prefix: str = "data: ",
        final_msg: str = "[DONE]",
    ):
        self._resp = response
        self._prefix = event_prefix
        self._final = final_msg
        # Bound lazily to `_resp` on the first __anext__ call.
        self._line_iter = None
        # One-way latch: once set, __anext__ stops immediately.
        self._done = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next parsed chunk, skipping blanks and non-JSON noise.

        Raises:
            StopAsyncIteration: on `[DONE]`, exhaustion, or transport errors.
            GenAIHubOrchestrationError: (502) when the stream carries an
                in-band error object.
        """
        if self._done:
            raise StopAsyncIteration

        if self._line_iter is None:
            self._line_iter = self._resp

        while True:
            try:
                raw = await self._line_iter.__anext__()
            except (StopAsyncIteration, httpx.ReadError, OSError):
                # Transport failures end the stream quietly rather than raising.
                await self._aclose()
                raise StopAsyncIteration

            line = (raw or "").strip()
            if not line:
                continue

            # Strip the SSE "data: " prefix when present; raw JSON lines pass through.
            payload = (
                line[len(self._prefix) :] if line.startswith(self._prefix) else line
            )
            if payload == self._final:
                await self._aclose()
                raise StopAsyncIteration
            try:
                obj = json.loads(payload)
            except Exception:
                # Unparseable lines (keep-alives, partial data) are skipped silently.
                continue

            try:
                chunk = _StreamParser.to_openai_chunk(obj)
            except ValueError as e:
                await self._aclose()
                # Unlike the sync iterator (which re-raises ValueError),
                # in-band errors surface here as a 502 orchestration error.
                raise GenAIHubOrchestrationError(502, str(e))

            if chunk is None:
                continue

            # If terminal, close BEFORE returning. Next __anext__() will stop immediately.
            if any(c.finish_reason is not None for c in (chunk.choices or [])):
                await self._aclose()

            return chunk

    async def _aclose(self):
        # NOTE(review): only latches `_done`; no transport close — confirm intended.
        if self._done:
            return
        else:
            self._done = True
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# LLM handler
|
||||
# -------------------------------
|
||||
class GenAIHubOrchestration(BaseLLMHTTPHandler):
    """HTTP handler for SAP GenAI Hub's orchestration completion endpoint."""

    def _add_stream_param_to_request_body(
        self, data: dict, provider_config: BaseConfig, fake_stream: bool
    ):
        """Force-enable streaming in the orchestration request body.

        Args:
            data: Request body; mutated in place and also returned.
            provider_config: Unused here; kept for interface compatibility.
            fake_stream: Unused here; kept for interface compatibility.

        Returns:
            The same ``data`` dict with ``config.stream.enabled`` set to True.

        Fix: the original indexed ``data["config"]`` in the else-branch and
        raised KeyError when the body had no ``config`` key yet; the config
        dict is now created on demand.
        """
        config = data.setdefault("config", {})
        stream = config.get("stream")
        if stream is not None:
            stream["enabled"] = True
        else:
            config["stream"] = {"enabled": True}
        return data
|
||||
@@ -0,0 +1,130 @@
|
||||
from typing import Union, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
def validate_different_content(v: Union[str, dict, list]) -> str:
    """Coerce OpenAI-style message ``content`` into a plain string.

    Accepted shapes:
      - empty container (``()``, ``{}``, ``[]``) -> ``""``
      - dict with a ``"text"`` key -> that value
      - list of parts: dicts with non-empty ``"text"`` and plain strings are
        kept (other items dropped) and joined with newlines
      - plain string -> returned unchanged

    Raises:
        ValueError: for any other content shape.

    Fix: removed the unreachable ``return v`` that followed the ``raise``.
    """
    if v in ((), {}, []):
        return ""
    if isinstance(v, dict) and "text" in v:
        return v["text"]
    if isinstance(v, list):
        parts = []
        for item in v:
            if isinstance(item, dict) and "text" in item:
                # Falsy text parts (e.g. "") are intentionally dropped.
                if item["text"]:
                    parts.append(item["text"])
            elif isinstance(item, str):
                parts.append(item)
        return "\n".join(parts)
    if isinstance(v, str):
        return v
    raise ValueError("Content must be a string")
|
||||
|
||||
|
||||
class TextContent(BaseModel):
    """Text part of a multimodal message (serialized as ``{"type": "text", ...}``)."""

    # Fixed discriminator; serialized under the "type" key via the alias.
    type_: Literal["text"] = Field(default="text", alias="type")
    text: str


class ImageURLContent(BaseModel):
    """Image reference used inside an ``image_url`` content part."""

    url: str
    # Rendering-detail hint; defaults to "auto".
    detail: str = "auto"


class ImageContent(BaseModel):
    """Image part of a multimodal message (serialized as ``{"type": "image_url", ...}``)."""

    type_: Literal["image_url"] = Field(default="image_url", alias="type")
    image_url: ImageURLContent
|
||||
|
||||
|
||||
class FunctionObj(BaseModel):
    """A function invocation: its name and raw arguments string."""

    name: str
    arguments: str


class FunctionTool(BaseModel):
    """Function tool definition passed to the orchestration service."""

    description: str = ""
    name: str
    parameters: dict = {"type": "object", "properties": {}}
    strict: bool = False

    @field_validator("parameters", mode="before")
    @classmethod
    def ensure_object_type(cls, v: dict) -> dict:
        """Ensure parameters has type='object' as required by SAP Orchestration Service."""
        if not v:
            return {"type": "object", "properties": {}}
        if "type" not in v:
            v = {"type": "object", **v}
        if "properties" not in v:
            v["properties"] = {}
        return v


class ChatCompletionTool(BaseModel):
    """Tool wrapper serialized as ``{"type": "function", "function": {...}}``."""

    type_: Literal["function"] = Field(default="function", alias="type")
    function: FunctionTool


class MessageToolCall(BaseModel):
    """A tool call recorded on an assistant message."""

    id: str
    type_: Literal["function"] = Field(default="function", alias="type")
    function: FunctionObj
|
||||
|
||||
|
||||
class SAPMessage(BaseModel):
    """
    Model for SystemChatMessage and DeveloperChatMessage
    """

    role: Literal["system", "developer"] = "system"
    content: str

    # Coerce list/dict content shapes into a plain string before validation.
    _content_validator = field_validator("content", mode="before")(
        validate_different_content
    )


class SAPUserMessage(BaseModel):
    """User message; content may be plain text or multimodal parts."""

    role: Literal["user"] = "user"
    content: Union[
        str, TextContent, ImageContent, list[Union[TextContent, ImageContent]]
    ]


class SAPAssistantMessage(BaseModel):
    """Assistant message, optionally carrying tool calls."""

    role: Literal["assistant"] = "assistant"
    content: str = ""
    refusal: str = ""
    tool_calls: list[MessageToolCall] = []

    # Coerce list/dict content shapes into a plain string before validation.
    _content_validator = field_validator("content", mode="before")(
        validate_different_content
    )


class SAPToolChatMessage(BaseModel):
    """Tool-result message echoing the originating tool_call_id."""

    role: Literal["tool"] = "tool"
    tool_call_id: str
    content: str

    # Coerce list/dict content shapes into a plain string before validation.
    _content_validator = field_validator("content", mode="before")(
        validate_different_content
    )
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
    """Simple response_format selector: plain text or JSON-object mode."""

    type_: Literal["text", "json_object"] = Field(default="text", alias="type")


class JSONResponseSchema(BaseModel):
    """Named JSON schema payload for structured output."""

    description: str = ""
    name: str
    # Serialized under "schema"; trailing underscore avoids clashing with
    # pydantic's own `schema` attribute.
    schema_: dict = Field(default_factory=dict, alias="schema")
    strict: bool = False


class ResponseFormatJSONSchema(BaseModel):
    """response_format wrapper serialized as ``{"type": "json_schema", ...}``."""

    type_: Literal["json_schema"] = Field(default="json_schema", alias="type")
    json_schema: JSONResponseSchema
|
||||
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Translate from OpenAI's `/v1/chat/completions` to SAP Generative AI Hub's Orchestration Service`v2/completion`
|
||||
"""
|
||||
from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
Dict,
|
||||
Tuple,
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
)
|
||||
from functools import cached_property
|
||||
import litellm
|
||||
import httpx
|
||||
|
||||
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
from ..credentials import get_token_creator
|
||||
from .models import (
|
||||
SAPMessage,
|
||||
SAPAssistantMessage,
|
||||
SAPToolChatMessage,
|
||||
ChatCompletionTool,
|
||||
ResponseFormatJSONSchema,
|
||||
ResponseFormat,
|
||||
SAPUserMessage,
|
||||
)
|
||||
from .handler import (
|
||||
GenAIHubOrchestrationError,
|
||||
AsyncSAPStreamIterator,
|
||||
SAPStreamIterator,
|
||||
)
|
||||
|
||||
|
||||
def validate_dict(data: dict, model) -> dict:
    """Validate ``data`` against a pydantic ``model`` and dump it back using aliases."""
    instance = model(**data)
    return instance.model_dump(by_alias=True)
|
||||
|
||||
|
||||
class GenAIHubOrchestrationConfig(OpenAIGPTConfig):
    """Translate OpenAI `/v1/chat/completions` calls into SAP GenAI Hub's
    orchestration `v2/completion` shape: request/response mapping, auth
    headers, deployment discovery, and streaming iterator selection."""

    # OpenAI-compatible knobs mirrored as class attributes (litellm config
    # convention).
    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None
    tools: Optional[list] = None
    tool_choice: Optional[Union[str, dict]] = None
    # SAP-specific: which deployed model version to target.
    model_version: str = "latest"

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            # NOTE(review): values are written to the *class*, not the
            # instance (litellm config convention), so they persist across
            # instances — confirm this matches sibling provider configs.
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        # Lazily populated by run_env_setup() on first use.
        self.token_creator = None
        self._base_url = None
        self._resource_group = None

    def run_env_setup(self, service_key: Optional[str] = None) -> None:
        """Resolve credentials and cache token creator / base URL / resource group.

        Raises:
            GenAIHubOrchestrationError: (400) when credential resolution fails.
        """
        try:
            self.token_creator, self._base_url, self._resource_group = get_token_creator(service_key)  # type: ignore
        except ValueError as err:
            raise GenAIHubOrchestrationError(status_code=400, message=err.args[0])

    @property
    def headers(self) -> Dict[str, str]:
        """Auth and routing headers; triggers env setup on first access."""
        if self.token_creator is None:
            self.run_env_setup()
        access_token = self.token_creator()  # type: ignore
        return {
            "Authorization": access_token,
            "AI-Resource-Group": self.resource_group,
            "Content-Type": "application/json",
            "AI-Client-Type": "LiteLLM",
        }

    @property
    def base_url(self) -> str:
        # Lazy: resolve credentials on first access.
        if self._base_url is None:
            self.run_env_setup()
        return self._base_url  # type: ignore

    @property
    def resource_group(self) -> str:
        # Lazy: resolve credentials on first access.
        if self._resource_group is None:
            self.run_env_setup()
        return self._resource_group  # type: ignore

    @cached_property
    def deployment_url(self) -> str:
        """Discover the newest 'orchestration' deployment URL (cached per instance).

        NOTE(review): raises IndexError when no orchestration deployment
        exists — confirm whether a clearer error is wanted.
        """
        # Uses litellm's shared module-level HTTP client.
        client = litellm.module_level_client
        deployments = client.get(
            f"{self.base_url}/lm/deployments", headers=self.headers
        ).json()
        valid: List[Tuple[str, str]] = []
        for dep in deployments.get("resources", []):
            if dep.get("scenarioId") == "orchestration":
                # Confirm via the deployment's configuration that the
                # executable really is the orchestration service.
                cfg = client.get(
                    f'{self.base_url}/lm/configurations/{dep["configurationId"]}',
                    headers=self.headers,
                ).json()
                if cfg.get("executableId") == "orchestration":
                    valid.append((dep["deploymentUrl"], dep["createdAt"]))
        # newest first
        return sorted(valid, key=lambda x: x[1], reverse=True)[0][0]

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model):
        """Supported OpenAI params, pruned per backing model family."""
        params = [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "top_logprobs",
            "max_tokens",
            "max_completion_tokens",
            "prediction",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
            "function_call",
            "functions",
            "extra_headers",
            "parallel_tool_calls",
            "response_format",
            "timeout",
        ]
        # Remove response_format for providers that don't support it on SAP GenAI Hub
        if (
            model.startswith("amazon")
            or model.startswith("cohere")
            or model.startswith("alephalpha")
            or model == "gpt-4"
        ):
            params.remove("response_format")
        if model.startswith("gemini") or model.startswith("amazon"):
            params.remove("tool_choice")
        return params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Return auth headers; re-resolves credentials when an api_key is given."""
        if api_key:
            # api_key doubles as the SAP service key for credential setup.
            self.run_env_setup(api_key)
        return self.headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ):
        # The discovered deployment URL always wins; api_base is ignored.
        api_base_ = f"{self.deployment_url}/v2/completion"
        return api_base_

    def transform_request(
        self,
        model: str,
        messages: List[Dict[str, str]],  # type: ignore
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the SAP orchestration request body from OpenAI-style inputs."""
        # Filter out parameters that are not valid model params for SAP Orchestration API
        # - tools, model_version, deployment_url: handled separately
        excluded_params = {"tools", "model_version", "deployment_url"}

        # Filter strict for GPT models only - SAP AI Core doesn't accept it as a model param
        # LangChain agents pass strict=true at top level, which fails for GPT models
        # Anthropic models accept strict, so preserve it for them
        if model.startswith("gpt"):
            excluded_params.add("strict")

        model_params = {
            k: v for k, v in optional_params.items() if k not in excluded_params
        }

        model_version = optional_params.pop("model_version", "latest")
        template = []
        # Route each message through the pydantic model matching its role.
        for message in messages:
            if message["role"] == "user":
                template.append(validate_dict(message, SAPUserMessage))
            elif message["role"] == "assistant":
                template.append(validate_dict(message, SAPAssistantMessage))
            elif message["role"] == "tool":
                template.append(validate_dict(message, SAPToolChatMessage))
            else:
                template.append(validate_dict(message, SAPMessage))

        tools_ = optional_params.pop("tools", [])
        tools_ = [validate_dict(tool, ChatCompletionTool) for tool in tools_]
        if tools_ != []:
            tools = {"tools": tools_}
        else:
            tools = {}

        response_format = model_params.pop("response_format", {})
        resp_type = response_format.get("type", None)
        if resp_type:
            if resp_type == "json_schema":
                response_format = validate_dict(
                    response_format, ResponseFormatJSONSchema
                )
            else:
                response_format = validate_dict(response_format, ResponseFormat)
            response_format = {"response_format": response_format}
        # "stream" itself is signalled elsewhere; only stream_options map to
        # the SAP stream config block.
        model_params.pop("stream", False)
        stream_config = {}
        if "stream_options" in model_params:
            stream_options = model_params.pop("stream_options", {})
            stream_config["chunk_size"] = stream_options.get("chunk_size", 100)
            if "delimiters" in stream_options:
                stream_config["delimiters"] = stream_options.get("delimiters")
        config = {
            "config": {
                "modules": {
                    "prompt_templating": {
                        "prompt": {"template": template, **tools, **response_format},
                        "model": {
                            "name": model,
                            "params": model_params,
                            "version": model_version,
                        },
                    },
                },
                "stream": stream_config,
            }
        }

        return config

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Map the orchestration response's final_result into a ModelResponse."""
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )
        response = ModelResponse.model_validate(raw_response.json()["final_result"])

        # Strip markdown code blocks if JSON response_format was used with Anthropic models
        # SAP GenAI Hub with Anthropic models sometimes wraps JSON in ```json ... ```
        # based on prompt phrasing. GPT/Gemini models don't exhibit this behavior,
        # so we gate the stripping to avoid accidentally modifying valid responses.
        response_format = optional_params.get("response_format", {})
        if response_format.get("type") in ("json_object", "json_schema"):
            if model.startswith("anthropic"):
                response = self._strip_markdown_json(response)

        return response

    def _strip_markdown_json(self, response: ModelResponse) -> ModelResponse:
        """Strip markdown code block wrapper from JSON content if present.

        SAP GenAI Hub with Anthropic models sometimes returns JSON wrapped in
        markdown code blocks (```json ... ```) depending on prompt phrasing.
        This method strips that wrapper to ensure consistent JSON output.
        """
        import re

        for choice in response.choices or []:
            if choice.message and choice.message.content:
                content = choice.message.content.strip()
                # Match ```json ... ``` or ``` ... ```
                match = re.match(r"^```(?:json)?\s*\n?(.*?)\n?```$", content, re.DOTALL)
                if match:
                    choice.message.content = match.group(1).strip()

        return response

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], "ModelResponse"],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        """Wrap a raw line stream in the matching (a)sync SAP iterator."""
        if sync_stream:
            return SAPStreamIterator(response=streaming_response)  # type: ignore
        else:
            return AsyncSAPStreamIterator(response=streaming_response)  # type: ignore
|
||||
@@ -0,0 +1,332 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Callable, Dict, Final, List, Optional, Sequence, Tuple
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from threading import Lock
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from litellm import sap_service_key
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
|
||||
AUTH_ENDPOINT_SUFFIX = "/oauth/token"
|
||||
|
||||
CONFIG_FILE_ENV_VAR = "AICORE_CONFIG"
|
||||
HOME_PATH_ENV_VAR = "AICORE_HOME"
|
||||
PROFILE_ENV_VAR = "AICORE_PROFILE"
|
||||
|
||||
VCAP_SERVICES_ENV_VAR = "VCAP_SERVICES"
|
||||
VCAP_AICORE_SERVICE_NAME = "aicore"
|
||||
SERVICE_KEY_ENV_VAR = "AICORE_SERVICE_KEY"
|
||||
|
||||
DEFAULT_HOME_PATH = os.path.join(os.path.expanduser("~"), ".aicore")
|
||||
|
||||
|
||||
def _get_home() -> str:
    """Resolve the AI Core home directory, honoring the $AICORE_HOME override."""
    override = os.getenv(HOME_PATH_ENV_VAR)
    return override if override is not None else DEFAULT_HOME_PATH
|
||||
|
||||
|
||||
def _get_nested(d: Dict[str, Any], path: Sequence[str]) -> Any:
|
||||
cur: Any = d
|
||||
for k in path:
|
||||
if not isinstance(cur, dict) or k not in cur:
|
||||
raise KeyError(".".join(path))
|
||||
cur = cur[k]
|
||||
return cur
|
||||
|
||||
|
||||
def _load_json_env(var_name: str) -> Optional[Dict[str, Any]]:
|
||||
raw = os.environ.get(var_name)
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
return json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def _load_vcap() -> Dict[str, Any]:
    """Return the parsed $VCAP_SERVICES document, or an empty mapping."""
    services = _load_json_env(VCAP_SERVICES_ENV_VAR)
    # Treat missing/invalid/empty values uniformly as "no services".
    return services or {}
|
||||
|
||||
|
||||
def _get_vcap_service(label: str) -> Optional[Dict[str, Any]]:
    """Return the first VCAP service entry whose 'label' matches, else None."""
    for entries in _load_vcap().values():
        matching = (svc for svc in entries if svc.get("label") == label)
        found = next(matching, None)
        if found is not None:
            return found
    return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CredentialsValue:
    """Declarative spec for one credential field and where to find it."""

    # Canonical field name (also used as the env-var suffix: AICORE_<NAME>).
    name: str
    # Lookup path inside the service key's "credentials" object, if any.
    vcap_key: Optional[Tuple[str, ...]] = None
    # Fallback value when no source provides the field.
    default: Optional[str] = None
    # Optional post-processing (URL suffix normalization, PEM newline fixes).
    transform_fn: Optional[Callable[[str], str]] = None
|
||||
|
||||
|
||||
# Declarative list of every credential field resolved by fetch_credentials().
CREDENTIAL_VALUES: Final[List[CredentialsValue]] = [
    CredentialsValue("client_id", ("clientid",)),
    CredentialsValue("client_secret", ("clientsecret",)),
    # OAuth token endpoint: ensure the /oauth/token suffix appears exactly once.
    CredentialsValue(
        "auth_url",
        ("url",),
        transform_fn=lambda url: url.rstrip("/")
        + ("" if url.endswith(AUTH_ENDPOINT_SUFFIX) else AUTH_ENDPOINT_SUFFIX),
    ),
    # AI API base URL: ensure the /v2 suffix appears exactly once.
    CredentialsValue(
        "base_url",
        ("serviceurls", "AI_API_URL"),
        transform_fn=lambda url: url.rstrip("/")
        + ("" if url.endswith("/v2") else "/v2"),
    ),
    CredentialsValue("resource_group", default="default"),
    # X.509 token endpoint; normalized the same way as auth_url.
    CredentialsValue(
        "cert_url",
        ("certurl",),
        transform_fn=lambda url: url.rstrip("/")
        + ("" if url.endswith(AUTH_ENDPOINT_SUFFIX) else AUTH_ENDPOINT_SUFFIX),
    ),
    # file paths (kept for config compatibility)
    CredentialsValue("cert_file_path"),
    CredentialsValue("key_file_path"),
    # inline PEMs from VCAP; escaped "\n" sequences become real newlines
    CredentialsValue(
        "cert_str", ("certificate",), transform_fn=lambda s: s.replace("\\n", "\n")
    ),
    CredentialsValue(
        "key_str", ("key",), transform_fn=lambda s: s.replace("\\n", "\n")
    ),
]
|
||||
|
||||
|
||||
def init_conf(profile: Optional[str] = None) -> Dict[str, Any]:
    """
    Loads config JSON from:
    1) $AICORE_CONFIG if set, otherwise
    2) $AICORE_HOME/config.json (or config_<profile>.json when profile is given/not default)
    Returns {} when nothing is found.

    Raises:
        KeyError: when the located file exists but is not valid JSON.
        FileNotFoundError: when an explicit path or non-default profile
            was requested but no file exists there.
    """
    home = Path(_get_home())
    profile = profile or os.environ.get(PROFILE_ENV_VAR)
    cfg_env = os.getenv(CONFIG_FILE_ENV_VAR)

    # Explicit $AICORE_CONFIG wins; otherwise derive the path from the profile.
    if cfg_env:
        cfg_path = Path(cfg_env)
    elif profile in (None, "", "default"):
        cfg_path = home / "config.json"
    else:
        cfg_path = home / f"config_{profile}.json"

    if cfg_path.exists():
        try:
            return json.loads(cfg_path.read_text(encoding="utf-8"))
        except json.JSONDecodeError:
            raise KeyError(f"{cfg_path} is not valid JSON. Please fix or remove it!")

    # If an explicit non-default profile was requested but not found, raise.
    if cfg_env or profile not in (None, "", "default"):
        raise FileNotFoundError(
            f"Unable to locate profile config file at '{cfg_path}' in AICORE_HOME '{home}'"
        )

    return {}
|
||||
|
||||
|
||||
def _env_name(name: str) -> str:
|
||||
return f"AICORE_{name.upper()}"
|
||||
|
||||
|
||||
def _resolve_value(
    cred: CredentialsValue,
    *,
    kwargs: Dict[str, Any],
    env: Dict[str, str],
    config: Dict[str, Any],
    service_like: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Resolve one credential field.

    Priority: explicit kwargs > environment (AICORE_<NAME>) > config file
    (prefixed or plain key) > service-key/VCAP credentials > declared default.
    """
    env_key = _env_name(cred.name)

    # 1) explicit kwargs
    value = kwargs.get(cred.name)
    if value is not None:
        return value

    # 2) environment variables (primary name)
    value = env.get(env_key)
    if value is not None:
        return value

    # 3) config file (accept both prefixed and plain keys)
    for candidate in (env_key, cred.name):
        value = config.get(candidate)
        if value is not None:
            return value

    # 4) service-like source (AICORE_SERVICE_KEY first, else VCAP)
    if service_like and cred.vcap_key:
        try:
            value = _get_nested(service_like, ("credentials",) + cred.vcap_key)
        except KeyError:
            value = None
        if value is not None:
            return value

    # 5) default
    return cred.default
|
||||
|
||||
|
||||
def fetch_credentials(
    service_key: Optional[str] = None, profile: Optional[str] = None, **kwargs
) -> Dict[str, str]:
    """
    Resolution order per key:
    kwargs
    > env (AICORE_<NAME>)
    > config (AICORE_<NAME> or plain <name>)
    > service-like source from JSON in $AICORE_SERVICE_KEY (same structure as a VCAP service object)
      falling back to service entry in $VCAP_SERVICES with label 'aicore'
    > default
    """
    config = init_conf(profile)
    env = os.environ  # snapshot for testability
    service_like = None

    if not config:
        # Prefer AICORE_SERVICE_KEY if present; otherwise fall back to the VCAP service.
        # NOTE(review): `service_key` is annotated as str but is consumed as a
        # dict-like service object by _resolve_value — confirm callers pass
        # the parsed form.
        service_like = (
            service_key
            or sap_service_key
            or _load_json_env(SERVICE_KEY_ENV_VAR)
            or _get_vcap_service(VCAP_AICORE_SERVICE_NAME)
        )

    out: Dict[str, str] = {}
    for cred in CREDENTIAL_VALUES:
        value = _resolve_value(cred, kwargs=kwargs, env=env, config=config, service_like=service_like)  # type: ignore
        if value is None:
            continue
        if cred.transform_fn:
            value = cred.transform_fn(value)
        out[cred.name] = value
    # X.509 flow: the cert token URL replaces the standard auth URL.
    if "cert_url" in out.keys():
        out["auth_url"] = out.pop("cert_url")
    return out
|
||||
|
||||
|
||||
def get_token_creator(
    service_key: Optional[str] = None,
    profile: Optional[str] = None,
    *,
    timeout: float = 30.0,
    expiry_buffer_minutes: int = 60,
    **overrides,
) -> Tuple[Callable[[], str], str, str]:
    """
    Creates a callable that fetches and caches an OAuth2 bearer token
    using credentials from `fetch_credentials()`.

    The callable:
    - Automatically loads credentials via fetch_credentials(profile, **overrides)
    - Fetches a new token only if expired or near expiry
    - Caches token thread-safely with a configurable refresh buffer

    Args:
        service_key: Optional explicit service-key object forwarded to
            fetch_credentials
        profile: Optional AICore profile name
        timeout: HTTP request timeout in seconds (default 30s)
        expiry_buffer_minutes: Refresh the token this many minutes before expiry
        overrides: Any explicit credential overrides (client_id, client_secret, etc.)

    Returns:
        Tuple of (token getter returning a valid "Bearer <token>" string,
        base_url, resource_group).

    Raises:
        ValueError: if auth_url/client_id are missing, or if not exactly one
            auth mode (secret, PEM strings, or PEM file paths) is configured.
    """

    # Resolve credentials once; the closures below reuse them on every refresh.
    credentials: Dict[str, str] = fetch_credentials(
        service_key=service_key, profile=profile, **overrides
    )

    auth_url = credentials.get("auth_url")
    client_id = credentials.get("client_id")
    client_secret = credentials.get("client_secret")
    cert_str = credentials.get("cert_str")
    key_str = credentials.get("key_str")
    cert_file_path = credentials.get("cert_file_path")
    key_file_path = credentials.get("key_file_path")

    # Sanity check
    if not auth_url or not client_id:
        raise ValueError(
            "fetch_credentials did not return valid 'auth_url' or 'client_id'"
        )

    # Exactly one authentication mode must be configured.
    modes = [
        client_secret is not None,
        (cert_str is not None and key_str is not None),
        (cert_file_path is not None and key_file_path is not None),
    ]
    if sum(bool(m) for m in modes) != 1:
        raise ValueError(
            "Invalid credentials: provide exactly one of client_secret, "
            "(cert_str & key_str), or (cert_file_path & key_file_path)."
        )

    lock = Lock()
    token: Optional[str] = None
    token_expiry: Optional[datetime] = None

    def _parse_token_response(resp) -> tuple[str, datetime]:
        # Turn a token-endpoint response into ("Bearer <token>", absolute expiry).
        try:
            resp.raise_for_status()
            payload = resp.json()
            access_token = payload["access_token"]
            expires_in = int(payload.get("expires_in", 3600))
            expiry_date = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
            return f"Bearer {access_token}", expiry_date
        except Exception as e:
            msg = getattr(resp, "text", str(e))
            raise RuntimeError(f"Token request failed: {msg}") from e

    def _request_token(cert_pair=None) -> tuple[str, datetime]:
        # POST the client-credentials grant; attach the mTLS cert when given.
        data = {"grant_type": "client_credentials", "client_id": client_id}
        if client_secret:
            data["client_secret"] = client_secret

        if cert_pair is not None:
            # BUGFIX: the cert/key pair was previously accepted but never used
            # (the cert-bearing client was commented out), so certificate-based
            # auth silently ran without mTLS. The pair must be attached to the
            # HTTP client itself; the shared client has no cert configured.
            import httpx  # local import: only the mTLS path needs a dedicated client

            with httpx.Client(cert=cert_pair, timeout=timeout) as client:
                resp = client.post(auth_url, data=data)
                return _parse_token_response(resp)

        client = _get_httpx_client()
        resp = client.post(auth_url, data=data)
        return _parse_token_response(resp)

    def _fetch_token() -> tuple[str, datetime]:
        # Case 1: secret-based auth
        if client_secret:
            return _request_token()
        # Case 2: cert/key provided inline as PEM strings — materialize them as
        # temp files for the duration of the request (the client needs paths).
        if cert_str and key_str:
            cert_str_fixed = cert_str.replace("\\n", "\n")
            key_str_fixed = key_str.replace("\\n", "\n")
            with tempfile.TemporaryDirectory() as tmp:
                cert_path = os.path.join(tmp, "cert.pem")
                key_path = os.path.join(tmp, "key.pem")
                with open(cert_path, "w") as f:
                    f.write(cert_str_fixed)
                with open(key_path, "w") as f:
                    f.write(key_str_fixed)
                # Request happens inside the `with` so the files still exist.
                return _request_token(cert_pair=(cert_path, key_path))
        # Case 3: file-based cert/key
        return _request_token(cert_pair=(cert_file_path, key_file_path))

    def get_token() -> str:
        """Return a cached 'Bearer <token>', refreshing near/after expiry."""
        nonlocal token, token_expiry
        with lock:
            now = datetime.now(timezone.utc)
            if (
                token is None
                or token_expiry is None
                or token_expiry - now < timedelta(minutes=expiry_buffer_minutes)
            ):
                token, token_expiry = _fetch_token()
            return token

    return get_token, credentials["base_url"], credentials["resource_group"]
|
||||
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Literal, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from functools import cached_property
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.embedding.transformation import (
|
||||
BaseEmbeddingConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from ..chat.handler import GenAIHubOrchestrationError
|
||||
from ..credentials import get_token_creator
|
||||
|
||||
|
||||
class Usage(BaseModel):
    """Token accounting block returned with an embeddings response."""

    # Tokens consumed by the input text(s).
    prompt_tokens: int
    # Total tokens billed for the request.
    total_tokens: int
|
||||
|
||||
|
||||
class EmbeddingItem(BaseModel):
    """A single embedding vector within the response list (OpenAI-shaped)."""

    # Always the literal "embedding".
    object: Literal["embedding"]
    embedding: List[float] = Field(
        ..., description="Vector of floats (length varies by model)."
    )
    # Position of this item relative to the request inputs.
    index: int
|
||||
|
||||
|
||||
class FinalResult(BaseModel):
    """OpenAI-shaped embeddings payload nested under the provider's `final_result`."""

    # Always the literal "list".
    object: Literal["list"]
    data: List[EmbeddingItem]
    model: str
    usage: Usage
|
||||
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
    """Top-level orchestration response envelope for an embeddings request."""

    # Provider-assigned request identifier.
    request_id: str
    # The OpenAI-shaped result that callers actually consume.
    final_result: FinalResult
|
||||
|
||||
|
||||
class EmbeddingModel(BaseModel):
    """Model selector for the orchestration `embeddings` module."""

    name: str
    version: str = "latest"
    # Validated from the incoming "parameters" key (validation_alias) but
    # dumped under the field name "params".
    # NOTE(review): with validation_alias set, the plain "params" key used by
    # transform_embedding_request may be rejected at validation time unless
    # populate_by_name is enabled — confirm.
    params: dict = Field(default_factory=dict, validation_alias="parameters")
|
||||
|
||||
|
||||
class EmbeddingsModules(BaseModel):
    """Wrapper for the `embeddings` module block of the orchestration config."""

    embeddings: EmbeddingModel
|
||||
|
||||
|
||||
class EmbeddingInput(BaseModel):
    """Input text(s) for the embeddings request."""

    # A single string or a batch of strings.
    text: Union[str, List[str]]
    # Embedding intent; defaults to the generic "text".
    type: Literal["text", "document", "query"] = "text"
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
    """Full embeddings request body: module configuration plus input."""

    config: EmbeddingsModules
    input: EmbeddingInput
|
||||
|
||||
|
||||
def validate_dict(data: dict, model) -> dict:
    """Round-trip *data* through *model* to validate it, returning a plain dict."""
    validated = model(**data)
    return validated.model_dump()
|
||||
|
||||
|
||||
class GenAIHubEmbeddingConfig(BaseEmbeddingConfig):
    """LiteLLM embedding config for SAP GenAI Hub orchestration.

    Resolves credentials and an OAuth token factory at construction time and
    translates between the OpenAI embeddings API shape and the orchestration
    deployment's `/v2/embeddings` route.
    """

    def __init__(self):
        super().__init__()
        # Scratch cache for token data (not used by the methods visible here).
        self._access_token_data = {}
        # token_creator yields fresh "Bearer <token>" strings on demand;
        # base_url / resource_group come from the resolved AI Core credentials.
        self.token_creator, self.base_url, self.resource_group = get_token_creator()

    @property
    def headers(self) -> Dict:
        """Build auth + routing headers; the token is (re)fetched on each access."""
        access_token = self.token_creator()
        # headers for completions and embeddings requests
        headers = {
            "Authorization": access_token,
            "AI-Resource-Group": self.resource_group,
            "Content-Type": "application/json",
            "AI-Client-Type": "LiteLLM",
        }
        return headers

    @cached_property
    def deployment_url(self) -> str:
        """Discover and cache the newest orchestration deployment URL.

        Lists deployments, keeps those whose scenario AND configuration
        executable are both "orchestration", then returns the most recently
        created one. Cached for the lifetime of this config instance.

        NOTE(review): raises IndexError when no matching deployment exists —
        confirm whether a clearer error is wanted here.
        """
        with httpx.Client(timeout=30) as client:
            valid_deployments = []
            deployments = client.get(
                self.base_url + "/lm/deployments", headers=self.headers
            ).json()
            for deployment in deployments.get("resources", []):
                if deployment["scenarioId"] == "orchestration":
                    # Cross-check the configuration's executable as well; the
                    # scenario alone is not sufficient to identify the target.
                    config_details = client.get(
                        self.base_url
                        + f'/lm/configurations/{deployment["configurationId"]}',
                        headers=self.headers,
                    ).json()
                    if config_details["executableId"] == "orchestration":
                        valid_deployments.append(
                            (deployment["deploymentUrl"], deployment["createdAt"])
                        )
            # Sort by createdAt descending and take the newest deployment URL.
            return sorted(valid_deployments, key=lambda x: x[1], reverse=True)[0][0]

    def get_error_class(self, error_message, status_code, headers):
        """Map provider errors onto the shared orchestration exception type."""
        return GenAIHubOrchestrationError(status_code, error_message)

    def get_supported_openai_params(self, model: str) -> list:
        """Return supported OpenAI params; only text-embedding-3 models accept `dimensions`."""
        if "text-embedding-3" in model:
            return ["encoding_format", "dimensions"]
        else:
            return [
                "encoding_format",
            ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Pass optional params through unchanged (no remapping needed)."""
        return optional_params

    def validate_environment(self, headers: dict, *args, **kwargs) -> dict:
        """Ignore caller-supplied headers; always return freshly built auth headers."""
        return self.headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Compose the embeddings endpoint from the discovered deployment URL.

        Note: api_base/api_key are ignored; the URL always comes from
        deployment discovery.
        """
        url = self.deployment_url.rstrip("/") + "/v2/embeddings"
        return url

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Translate OpenAI-style embedding arguments into the orchestration body."""
        model_dict = {}
        model_dict["name"] = model
        model_dict["version"] = optional_params.get("version", "latest")
        model_dict["params"] = optional_params.get("parameters", {})
        input_dict = {"text": input}
        body = {
            "config": {
                "modules": {
                    # Validate sub-dicts against their pydantic schemas before sending.
                    "embeddings": {"model": validate_dict(model_dict, EmbeddingModel)}
                }
            },
            "input": validate_dict(input_dict, EmbeddingInput),
        }
        return body

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """Unwrap the provider envelope; `final_result` is already OpenAI-shaped."""
        return EmbeddingResponse.model_validate(raw_response.json()["final_result"])
|
||||
Reference in New Issue
Block a user