chore: initial public snapshot for GitHub upload
This commit is contained in:
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Support for Snowflake REST API
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ChatCompletionMessageToolCall, Function, ModelResponse
|
||||
|
||||
from ...openai_like.chat.transformation import OpenAIGPTConfig
|
||||
|
||||
from ..utils import SnowflakeBaseConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class SnowflakeConfig(SnowflakeBaseConfig, OpenAIGPTConfig):
    """
    Chat-completions config for the Snowflake Cortex LLM REST API.

    Reference: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api

    Snowflake Cortex LLM REST API supports function calling with specific models
    (e.g., Claude 3.5 Sonnet). This config handles transformation between OpenAI
    format and Snowflake's tool_spec format.
    """

    @classmethod
    def get_config(cls):
        return super().get_config()

    def _transform_tool_calls_from_snowflake_to_openai(
        self, content_list: List[Dict[str, Any]]
    ) -> Tuple[str, Optional[List[ChatCompletionMessageToolCall]]]:
        """
        Transform Snowflake tool calls to OpenAI format.

        Args:
            content_list: Snowflake's content_list array containing text and tool_use items

        Returns:
            Tuple of (text_content, tool_calls); tool_calls is None when the
            response contained no tool_use items.

        Snowflake format in content_list:
            {
                "type": "tool_use",
                "tool_use": {
                    "tool_use_id": "tooluse_...",
                    "name": "get_weather",
                    "input": {"location": "Paris"}
                }
            }

        OpenAI format (returned tool_calls):
            ChatCompletionMessageToolCall(
                id="tooluse_...",
                type="function",
                function=Function(name="get_weather", arguments='{"location": "Paris"}')
            )
        """
        text_content = ""
        tool_calls: List[ChatCompletionMessageToolCall] = []

        # NOTE: the index was previously tracked via enumerate() but never used.
        for content_item in content_list:
            if content_item.get("type") == "text":
                text_content += content_item.get("text", "")

            ## TOOL CALLING
            elif content_item.get("type") == "tool_use":
                tool_use_data = content_item.get("tool_use", {})
                tool_call = ChatCompletionMessageToolCall(
                    id=tool_use_data.get("tool_use_id", ""),
                    type="function",
                    function=Function(
                        name=tool_use_data.get("name", ""),
                        # OpenAI expects arguments as a JSON string, not a dict
                        arguments=json.dumps(tool_use_data.get("input", {})),
                    ),
                )
                tool_calls.append(tool_call)

        return text_content, tool_calls if tool_calls else None

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform a raw Snowflake Cortex HTTP response into a litellm ModelResponse.

        Converts Snowflake's message `content_list` (text + tool_use items) into
        OpenAI-style `content` + `tool_calls`, then prefixes the model name with
        "snowflake/".
        """
        response_json = raw_response.json()

        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=response_json,
            additional_args={"complete_input_dict": request_data},
        )

        ## RESPONSE TRANSFORMATION
        # Snowflake returns content_list (not content) with tool_use objects
        # We need to transform this to OpenAI's format with content + tool_calls
        if "choices" in response_json and len(response_json["choices"]) > 0:
            choice = response_json["choices"][0]
            if "message" in choice and "content_list" in choice["message"]:
                content_list = choice["message"]["content_list"]
                (
                    text_content,
                    tool_calls,
                ) = self._transform_tool_calls_from_snowflake_to_openai(content_list)

                # Update the choice message with OpenAI format
                choice["message"]["content"] = text_content
                if tool_calls:
                    choice["message"]["tool_calls"] = tool_calls

                # Remove Snowflake-specific content_list
                del choice["message"]["content_list"]

        returned_response = ModelResponse(**response_json)

        returned_response.model = "snowflake/" + (returned_response.model or "")

        if model is not None:
            returned_response._hidden_params["model"] = model
        return returned_response

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Return the full Snowflake Cortex chat-completions endpoint URL.

        If api_base is not provided, the base URL is derived from the Snowflake
        account id (see SnowflakeBaseConfig._get_api_base).
        """
        api_base = self._get_api_base(api_base, optional_params)

        return f"{api_base}/cortex/inference:complete"

    def _transform_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Transform OpenAI tool format to Snowflake tool format.

        Args:
            tools: List of tools in OpenAI format

        Returns:
            List of tools in Snowflake format

        OpenAI format:
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "...",
                    "parameters": {...}
                }
            }

        Snowflake format:
            {
                "tool_spec": {
                    "type": "generic",
                    "name": "get_weather",
                    "description": "...",
                    "input_schema": {...}
                }
            }
        """
        snowflake_tools: List[Dict[str, Any]] = []
        for tool in tools:
            # Non-"function" tool entries are silently skipped.
            if tool.get("type") == "function":
                function = tool.get("function", {})
                snowflake_tool: Dict[str, Any] = {
                    "tool_spec": {
                        "type": "generic",
                        "name": function.get("name"),
                        # Default to an empty object schema when no parameters given
                        "input_schema": function.get(
                            "parameters",
                            {"type": "object", "properties": {}},
                        ),
                    }
                }
                # Add description if present
                if "description" in function:
                    snowflake_tool["tool_spec"]["description"] = function["description"]

                snowflake_tools.append(snowflake_tool)

        return snowflake_tools

    def _transform_tool_choice(
        self, tool_choice: Union[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Transform OpenAI tool_choice format to Snowflake format.

        Snowflake requires tool_choice to be an object, not a string.
        Ref: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/reference/cortex-inference#post--api-v2-cortex-inference-complete-req-body-schema

        Args:
            tool_choice: Tool choice in OpenAI format (str or dict)

        Returns:
            Tool choice in Snowflake format (always an object, never a string)

        OpenAI format (string):
            "auto", "required", "none"

        OpenAI format (dict):
            {"type": "function", "function": {"name": "get_weather"}}

        Snowflake format:
            {"type": "auto"} / {"type": "any"} / {"type": "none"}
            {"type": "tool", "name": ["get_weather"]}

        Snowflake's API (like Anthropic) requires tool_choice as an object
        with a "type" field, not as a bare string.
        """
        if isinstance(tool_choice, str):
            # Snowflake requires object format, not string.
            # Map OpenAI string values to Snowflake object format.
            # "required" maps to "any" (Snowflake/Anthropic convention).
            _type_map = {
                "auto": "auto",
                "required": "any",
                "none": "none",
            }
            mapped_type = _type_map.get(tool_choice, tool_choice)
            return {"type": mapped_type}

        if isinstance(tool_choice, dict):
            if tool_choice.get("type") == "function":
                function_name = tool_choice.get("function", {}).get("name")
                if function_name:
                    return {
                        "type": "tool",
                        "name": [function_name],  # Snowflake expects array
                    }

        # Unknown dict shapes are passed through unchanged.
        return tool_choice

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the Snowflake Cortex request body from OpenAI-style inputs.

        `tools` and `tool_choice` are rewritten into Snowflake's tool_spec /
        tool-choice-object formats; all other optional params and any
        `extra_body` entries are passed through verbatim.
        """
        stream: bool = optional_params.pop("stream", None) or False
        extra_body = optional_params.pop("extra_body", {})

        ## TOOL CALLING
        # Transform tools from OpenAI format to Snowflake's tool_spec format
        tools = optional_params.pop("tools", None)
        if tools:
            optional_params["tools"] = self._transform_tools(tools)

        # Transform tool_choice from OpenAI format to Snowflake's tool name array format
        tool_choice = optional_params.pop("tool_choice", None)
        if tool_choice:
            optional_params["tool_choice"] = self._transform_tool_choice(tool_choice)

        return {
            "model": model,
            "messages": messages,
            "stream": stream,
            **optional_params,
            **extra_body,
        }
|
||||
@@ -0,0 +1,34 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class SnowflakeBase:
    def validate_environment(
        self,
        headers: dict,
        JWT: Optional[str] = None,
    ) -> dict:
        """
        Return headers to use for Snowflake completion request

        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
        Expected headers:
        {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": "Bearer " + <JWT>,
            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
        }
        """
        # A JWT is mandatory for key-pair auth against Snowflake Cortex.
        if JWT is None:
            raise ValueError("Missing Snowflake JWT key")

        snowflake_headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": "Bearer " + JWT,
            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
        }
        # Merge into the caller-supplied headers in place and hand them back.
        headers.update(snowflake_headers)
        return headers
|
||||
@@ -0,0 +1,69 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from ..utils import SnowflakeException, SnowflakeBaseConfig
|
||||
|
||||
|
||||
class SnowflakeEmbeddingConfig(SnowflakeBaseConfig, BaseEmbeddingConfig):
    """
    Embedding config for the Snowflake Cortex embed endpoint.

    source: https://docs.snowflake.com/developer-guide/snowflake-rest-api/reference/cortex-embed
    """

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        # Resolve the account-scoped base URL, then append the embed route.
        resolved_base = self._get_api_base(api_base, optional_params)
        return resolved_base + "/cortex/inference:embed"

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        # Snowflake expects the embedding input under "text", not OpenAI's "input".
        request_body: dict = {"text": input, "model": model}
        request_body.update(optional_params)
        return request_body

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        payload = raw_response.json()
        # Snowflake nests each vector one level deep; flatten to a 1-D embedding.
        for row in payload["data"]:
            row["embedding"] = row["embedding"][0]
        embedding_response = EmbeddingResponse(**payload)

        # Namespace the reported model under the snowflake/ provider prefix.
        embedding_response.model = "snowflake/" + (embedding_response.model or "")

        if model is not None:
            embedding_response._hidden_params["model"] = model
        return embedding_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        # Wrap provider errors in the Snowflake-specific exception type.
        return SnowflakeException(
            message=error_message, status_code=status_code, headers=headers
        )
|
||||
@@ -0,0 +1,118 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class SnowflakeException(BaseLLMException):
    """Snowflake AI Endpoints exception handling class"""
|
||||
|
||||
|
||||
class SnowflakeBaseConfig:
    """Shared helpers for the Snowflake Cortex chat and embedding configs."""

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return the OpenAI params Snowflake Cortex supports (model-independent)."""
        return [
            "temperature",
            "max_tokens",
            "top_p",
            "response_format",
            "tools",
            "tool_choice",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        If any supported_openai_params are in non_default_params, add them to optional_params, so they are used in API call

        Args:
            non_default_params (dict): Non-default parameters to filter.
            optional_params (dict): Optional parameters to update.
            model (str): Model name for parameter support check.
            drop_params (bool): Unused here; unsupported params are always skipped.

        Returns:
            dict: Updated optional_params with supported non-default parameters.
        """
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def _get_api_base(self, api_base: Optional[str], optional_params: dict) -> str:
        """
        Resolve the Snowflake API base URL.

        When no api_base is given, builds
        ``https://<account_id>.snowflakecomputing.com/api/v2`` using
        ``account_id`` from optional_params (popped) or the
        SNOWFLAKE_ACCOUNT_ID secret. The result is normalized to end with
        ``/api/v2`` and carry no trailing slash.

        Raises:
            ValueError: If no api_base is given and no account id can be found.
        """
        if not api_base:
            if "account_id" in optional_params:
                account_id = optional_params.pop("account_id")
            else:
                account_id = get_secret_str("SNOWFLAKE_ACCOUNT_ID")
            if account_id is None:
                raise ValueError("Missing snowflake account_id")
            api_base = f"https://{account_id}.snowflakecomputing.com/api/v2"

        # Normalize: strip trailing slashes, then ensure the /api/v2 suffix.
        api_base = api_base.rstrip("/")
        if not api_base.endswith("/api/v2"):
            api_base += "/api/v2"
        return api_base

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Return headers to use for Snowflake completion request

        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
        Expected headers:
        {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": "Bearer " + <JWT>,
            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
        }
        """

        auth_type = "KEYPAIR_JWT"

        if api_key is None:
            raise ValueError("Missing Snowflake JWT key")
        else:
            # Programmatic access tokens are supplied as "pat/<token>"; strip
            # the prefix and switch the Snowflake auth token type accordingly.
            pat_key_prefix = "pat/"
            if api_key.startswith(pat_key_prefix):
                api_key = api_key[len(pat_key_prefix) :]
                auth_type = "PROGRAMMATIC_ACCESS_TOKEN"

        headers.update(
            {
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": "Bearer " + api_key,
                "X-Snowflake-Authorization-Token-Type": auth_type,
            }
        )
        return headers

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Return (api_base, api_key), defaulting the key to the SNOWFLAKE_JWT secret."""
        dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
        return api_base, dynamic_api_key
|
||||
Reference in New Issue
Block a user