272 lines
9.6 KiB
Python
272 lines
9.6 KiB
Python
|
|
from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
|
||
|
|
import litellm
|
||
|
|
from litellm._logging import verbose_logger
|
||
|
|
from litellm.constants import XAI_API_BASE
|
||
|
|
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||
|
|
filter_value_from_dict,
|
||
|
|
strip_name_from_messages,
|
||
|
|
)
|
||
|
|
from litellm.secret_managers.main import get_secret_str
|
||
|
|
from litellm.types.llms.openai import AllMessageValues
|
||
|
|
from litellm.types.utils import (
|
||
|
|
Choices,
|
||
|
|
ModelResponse,
|
||
|
|
ModelResponseStream,
|
||
|
|
PromptTokensDetailsWrapper,
|
||
|
|
Usage,
|
||
|
|
)
|
||
|
|
|
||
|
|
from ...openai.chat.gpt_transformation import (
|
||
|
|
OpenAIChatCompletionStreamingHandler,
|
||
|
|
OpenAIGPTConfig,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class XAIChatConfig(OpenAIGPTConfig):
    """Request/response transformation config for the xAI (Grok) chat API.

    xAI is OpenAI-compatible, so this builds on ``OpenAIGPTConfig`` and only
    overrides the xAI-specific differences visible below:

    - model-dependent supported-parameter lists (``stop``, ``frequency_penalty``,
      ``reasoning_effort``),
    - stripping the ``name`` field from messages before sending the request,
    - repairing the empty-string ``finish_reason`` xAI returns for tool calls,
    - surfacing xAI's ``num_sources_used`` web-search usage field.
    """

    @property
    def custom_llm_provider(self) -> Optional[str]:
        # Provider key used by litellm to route to this config.
        return "xai"

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Resolve ``(api_base, api_key)``.

        Precedence: explicit argument, then the XAI_API_BASE / XAI_API_KEY
        secrets, then the hard-coded default base URL (api_key has no default).
        """
        api_base = api_base or get_secret_str("XAI_API_BASE") or XAI_API_BASE  # type: ignore
        dynamic_api_key = api_key or get_secret_str("XAI_API_KEY")
        return api_base, dynamic_api_key

    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI-style params supported for ``model``.

        A base list is extended with ``stop``, ``frequency_penalty`` and
        ``reasoning_effort`` depending on per-model capability checks below.
        """
        base_openai_params = [
            "logit_bias",
            "logprobs",
            "max_tokens",
            "n",
            "presence_penalty",
            "response_format",
            "seed",
            "stream",
            "stream_options",
            "temperature",
            "tool_choice",
            "tools",
            "top_logprobs",
            "top_p",
            "user",
            "web_search_options",
        ]
        # for some reason, grok-3-mini does not support stop tokens
        #########################################################
        # stop tokens check
        #########################################################
        if self._supports_stop_reason(model):
            base_openai_params.append("stop")

        #########################################################
        # frequency penalty check
        #########################################################
        if self._supports_frequency_penalty(model):
            base_openai_params.append("frequency_penalty")

        #########################################################
        # reasoning check
        #########################################################
        try:
            if litellm.supports_reasoning(
                model=model, custom_llm_provider=self.custom_llm_provider
            ):
                base_openai_params.append("reasoning_effort")
        except Exception as e:
            # Best-effort: a failed capability lookup must not break param mapping.
            verbose_logger.debug(f"Error checking if model supports reasoning: {e}")

        return base_openai_params

    def _supports_stop_reason(self, model: str) -> bool:
        # NOTE(review): despite the name, this gates the `stop` request param,
        # not a response "stop reason". grok-3-mini, grok-4 and grok-code-fast
        # are excluded (substring match on the model name).
        if "grok-3-mini" in model:
            return False
        elif "grok-4" in model:
            return False
        elif "grok-code-fast" in model:
            return False
        return True

    def _supports_frequency_penalty(self, model: str) -> bool:
        """
        From manual testing, grok-4 does not support `frequency_penalty`.

        When sent, the request fails at the xAI API.
        """
        if "grok-4" in model:
            return False
        if "grok-code-fast" in model:
            return False
        return True

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool = False,
    ) -> dict:
        """Map user-supplied OpenAI params into ``optional_params`` for xAI.

        - ``max_completion_tokens`` is renamed to ``max_tokens``.
        - the ``strict`` key is stripped from each tool definition
          (presumably rejected by xAI — confirm against the xAI API).
        - every other supported param is passed through when not None.

        Mutates and returns ``optional_params``.
        """
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            elif param == "tools" and value is not None:
                tools = []
                for tool in value:
                    tool = filter_value_from_dict(tool, "strict")
                    if tool is not None:
                        tools.append(tool)
                if len(tools) > 0:
                    optional_params["tools"] = tools
            elif param in supported_openai_params:
                if value is not None:
                    optional_params[param] = value
        return optional_params

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        # Use the xAI-specific streaming handler instead of the generic
        # OpenAI one (it tolerates xAI's usage-only final chunk).
        return XAIChatCompletionStreamingHandler(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Handle https://github.com/BerriAI/litellm/issues/9720

        Filter out 'name' from messages before delegating to the standard
        OpenAI request transformation.
        """
        messages = strip_name_from_messages(messages)
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )

    @staticmethod
    def _fix_choice_finish_reason_for_tool_calls(choice: Choices) -> None:
        """
        Helper to fix finish_reason for tool calls when XAI API returns empty string.

        XAI API returns empty string for finish_reason when using tools,
        so we need to set it to "tool_calls" when tool_calls are present.

        Mutates ``choice`` in place.
        """
        if (
            choice.finish_reason == ""
            and choice.message.tool_calls
            and len(choice.message.tool_calls) > 0
        ):
            choice.finish_reason = "tool_calls"

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the response from the XAI API.

        XAI API returns empty string for finish_reason when using tools,
        so we need to fix this after the standard OpenAI transformation.

        Also handles X.AI web search usage tracking by extracting num_sources_used.
        """

        # First, let the parent class handle the standard transformation
        response = super().transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

        # Fix finish_reason for tool calls across all choices
        if response.choices:
            for choice in response.choices:
                if isinstance(choice, Choices):
                    self._fix_choice_finish_reason_for_tool_calls(choice)

        # Handle X.AI web search usage tracking.
        # Best-effort: any parsing failure is logged at debug level and the
        # response is still returned unmodified.
        try:
            raw_response_json = raw_response.json()
            self._enhance_usage_with_xai_web_search_fields(response, raw_response_json)
        except Exception as e:
            verbose_logger.debug(f"Error extracting X.AI web search usage: {e}")

        return response

    def _enhance_usage_with_xai_web_search_fields(
        self, model_response: ModelResponse, raw_response_json: dict
    ) -> None:
        """
        Extract num_sources_used from X.AI response and map it to web_search_requests.

        Mutates ``model_response.usage`` in place; no-op when the response has
        no usage object or the raw payload carries no positive
        ``usage.num_sources_used`` value.
        """
        if not hasattr(model_response, "usage") or model_response.usage is None:
            return

        usage: Usage = model_response.usage
        num_sources_used = None
        response_usage = raw_response_json.get("usage", {})
        if isinstance(response_usage, dict) and "num_sources_used" in response_usage:
            num_sources_used = response_usage.get("num_sources_used")

        # Map num_sources_used to web_search_requests for cost detection
        if num_sources_used is not None and num_sources_used > 0:
            if usage.prompt_tokens_details is None:
                usage.prompt_tokens_details = PromptTokensDetailsWrapper()

            usage.prompt_tokens_details.web_search_requests = int(num_sources_used)
            # Also keep the raw xAI field name on the usage object for callers
            # that look for it directly.
            setattr(usage, "num_sources_used", int(num_sources_used))
            verbose_logger.debug(f"X.AI web search sources used: {num_sources_used}")
|
||
|
|
class XAIChatCompletionStreamingHandler(OpenAIChatCompletionStreamingHandler):
    """Streaming handler that tolerates xAI's usage-only final chunk."""

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """
        Handle xAI-specific streaming behavior.

        xAI Grok sends a final chunk with empty choices array but with usage data
        when stream_options={"include_usage": True} is set.

        Example from xAI API:
        {"id":"...","object":"chat.completion.chunk","created":...,"model":"grok-4-1-fast-non-reasoning",
        "choices":[],"usage":{"prompt_tokens":171,"completion_tokens":2,"total_tokens":173,...}}
        """
        # A usage-bearing chunk may arrive with no choices at all; inject a
        # placeholder choice with an empty delta so the base OpenAI parser
        # still processes the chunk (and its usage data) normally.
        if not chunk.get("choices", []) and "usage" in chunk:
            chunk["choices"] = [{"index": 0, "delta": {}, "finish_reason": None}]

        return super().chunk_parser(chunk)