""" Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions` """ import os from typing import ( TYPE_CHECKING, Any, AsyncIterator, Coroutine, Iterator, List, Literal, Optional, Tuple, Union, cast, overload, ) import httpx from pydantic import BaseModel from litellm.constants import RESPONSE_FORMAT_TOOL_NAME from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import ( _handle_invalid_parallel_tool_calls, _should_convert_tool_call_to_json_mode, ) from litellm.litellm_core_utils.prompt_templates.common_utils import ( strip_name_from_message, ) from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator from litellm.types.llms.anthropic import AllAnthropicToolsValues from litellm.types.llms.databricks import ( AllDatabricksContentValues, DatabricksChoice, DatabricksFunction, DatabricksResponse, DatabricksTool, ) from litellm.types.llms.openai import ( AllMessageValues, ChatCompletionRedactedThinkingBlock, ChatCompletionThinkingBlock, ChatCompletionToolChoiceFunctionParam, ChatCompletionToolChoiceObjectParam, ChatCompletionToolParam, ) from litellm.types.utils import ( ChatCompletionMessageToolCall, Choices, Message, ModelResponse, ModelResponseStream, ProviderField, Usage, ) from ...anthropic.chat.transformation import AnthropicConfig from ...openai_like.chat.transformation import OpenAILikeChatConfig from ..common_utils import DatabricksBase, DatabricksException def _sanitize_empty_content(message_dict: dict[str, Any]) -> None: """ Remove or filter content so empty text blocks are not sent. Databricks Model Serving uses Anthropic Messages API spec and rejects empty text blocks. """ content = message_dict.get("content") if content is None: message_dict.pop("content", None) return if isinstance(content, str): if not content.strip(): message_dict.pop("content") return if isinstance(content, list): if not content: message_dict.pop("content") return filtered = [ block for block in content if not ( isinstance(block, dict) and block.get("type") == "text" and not (block.get("text") or "").strip() ) ] if not filtered: message_dict.pop("content") else: message_dict["content"] = filtered if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj LiteLLMLoggingObj = _LiteLLMLoggingObj else: LiteLLMLoggingObj = Any class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig): """ Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request """ max_tokens: Optional[int] = None temperature: Optional[int] = None top_p: Optional[int] = None top_k: Optional[int] = None stop: Optional[Union[List[str], str]] = None n: Optional[int] = None def __init__( self, max_tokens: Optional[int] = None, temperature: Optional[int] = None, top_p: Optional[int] = None, top_k: Optional[int] = None, stop: Optional[Union[List[str], str]] = None, n: Optional[int] = None, ) -> None: locals_ = locals().copy() for key, value in locals_.items(): if key != "self" and value is not None: setattr(self.__class__, key, value) @classmethod def get_config(cls): return super().get_config() def get_required_params(self) -> List[ProviderField]: """For a given provider, return it's required fields with a description""" return [ ProviderField( field_name="api_key", field_type="string", field_description="Your Databricks API Key.", field_value="dapi...", ), ProviderField( field_name="api_base", field_type="string", field_description="Your Databricks API Base.", 
field_value="https://adb-..", ), ] def validate_environment( self, headers: dict, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: # Check for custom user agent in optional_params or environment # This allows partners building on LiteLLM to set their own telemetry # Use pop() to remove these keys so they don't get sent to the API custom_user_agent = ( optional_params.pop("user_agent", None) or optional_params.pop("databricks_user_agent", None) or litellm_params.get("user_agent") or os.getenv("LITELLM_USER_AGENT") or os.getenv("DATABRICKS_USER_AGENT") ) api_base, headers = self.databricks_validate_environment( api_base=api_base, api_key=api_key, endpoint_type="chat_completions", custom_endpoint=False, headers=headers, custom_user_agent=custom_user_agent, ) # Ensure Content-Type header is set headers["Content-Type"] = "application/json" return headers def get_complete_url( self, api_base: Optional[str], api_key: Optional[str], model: str, optional_params: dict, litellm_params: dict, stream: Optional[bool] = None, ) -> str: api_base = self._get_api_base(api_base) complete_url = f"{api_base}/chat/completions" return complete_url def get_supported_openai_params(self, model: Optional[str] = None) -> list: return [ "stream", "stop", "temperature", "top_p", "max_tokens", "max_completion_tokens", "n", "response_format", "tools", "tool_choice", "reasoning_effort", "thinking", ] def convert_anthropic_tool_to_databricks_tool( self, tool: Optional[AllAnthropicToolsValues] ) -> Optional[DatabricksTool]: if tool is None: return None # Build DatabricksFunction explicitly to avoid parameter conflicts function_params: DatabricksFunction = { "name": tool["name"], "parameters": cast(dict, tool.get("input_schema") or {}), } # Only add description if it exists description = tool.get("description") if description is not None: function_params["description"] = cast(Union[dict, str], description) return DatabricksTool( type="function", function=function_params, ) def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]: # if not claude, send as is if "claude" not in model: return tools # if claude, convert to anthropic tool and then to databricks tool anthropic_tools, _ = self._map_tools( tools=tools ) # unclear how mcp tool calling on databricks works databricks_tools = [ cast(DatabricksTool, self.convert_anthropic_tool_to_databricks_tool(tool)) for tool in anthropic_tools ] return databricks_tools def map_response_format_to_databricks_tool( self, model: str, value: Optional[dict], optional_params: dict, is_thinking_enabled: bool, ) -> Optional[DatabricksTool]: if value is None: return None tool = self.map_response_format_to_anthropic_tool( value, optional_params, is_thinking_enabled ) databricks_tool = self.convert_anthropic_tool_to_databricks_tool(tool) return databricks_tool def remove_cache_control_flag_from_messages_and_tools( self, model: str, # allows overrides to selectively run this messages: List[AllMessageValues], tools: Optional[List["ChatCompletionToolParam"]] = None, ) -> Tuple[List[AllMessageValues], Optional[List["ChatCompletionToolParam"]]]: """ Override the parent class method to preserve cache_control for models on Databricks. Databricks supports Anthropic-style cache control for Claude models. Databricks ignores the cache_control flag with other models. 
""" # TODO: Think about how to best design the request transformation so that # every request doesn't have to be transformed for to OpenAI and Anthropic request formats. return messages, tools def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str, drop_params: bool, replace_max_completion_tokens_with_max_tokens: bool = True, ) -> dict: is_thinking_enabled = self.is_thinking_enabled(non_default_params) mapped_params = super().map_openai_params( non_default_params, optional_params, model, drop_params ) if "tools" in mapped_params: mapped_params["tools"] = self._map_openai_to_dbrx_tool( model=model, tools=mapped_params["tools"] ) if ( "max_completion_tokens" in non_default_params and replace_max_completion_tokens_with_max_tokens ): mapped_params["max_tokens"] = non_default_params[ "max_completion_tokens" ] # most openai-compatible providers support 'max_tokens' not 'max_completion_tokens' mapped_params.pop("max_completion_tokens", None) if "response_format" in non_default_params and "claude" in model: _tool = self.map_response_format_to_databricks_tool( model, non_default_params["response_format"], mapped_params, is_thinking_enabled, ) if _tool is not None: self._add_tools_to_optional_params( optional_params=optional_params, tools=[_tool] ) optional_params["json_mode"] = True if not is_thinking_enabled: _tool_choice = ChatCompletionToolChoiceObjectParam( type="function", function=ChatCompletionToolChoiceFunctionParam( name=RESPONSE_FORMAT_TOOL_NAME ), ) optional_params["tool_choice"] = _tool_choice optional_params.pop( "response_format", None ) # unsupported for claude models - if json_schema -> convert to tool call if "reasoning_effort" in non_default_params and "claude" in model: optional_params["thinking"] = AnthropicConfig._map_reasoning_effort( reasoning_effort=non_default_params.get("reasoning_effort"), model=model ) optional_params.pop("reasoning_effort", None) ## handle thinking tokens self.update_optional_params_with_thinking_tokens( non_default_params=non_default_params, optional_params=mapped_params ) return mapped_params def _should_fake_stream(self, optional_params: dict) -> bool: """ Databricks doesn't support 'response_format' while streaming """ if optional_params.get("response_format") is not None: return True return False @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False ) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]: """ Databricks does not support: - 'name' in user message. """ new_messages = [] for idx, message in enumerate(messages): if isinstance(message, BaseModel): _message = message.model_dump(exclude_none=True) else: _message = message _message = strip_name_from_message(_message, allowed_name_roles=["user"]) # Move message-level cache_control into a content block when content is a string. 
if "cache_control" in _message and isinstance(_message.get("content"), str): _message = self._move_cache_control_into_string_content_block(_message) _sanitize_empty_content(cast(dict[str, Any], _message)) new_messages.append(_message) if is_async: return super()._transform_messages( messages=new_messages, model=model, is_async=cast(Literal[True], True) ) else: return super()._transform_messages( messages=new_messages, model=model, is_async=cast(Literal[False], False) ) def _move_cache_control_into_string_content_block( self, message: AllMessageValues ) -> AllMessageValues: """ Moves message-level cache_control into a content block when content is a string. Transforms: {"role": "user", "content": "text", "cache_control": {...}} Into: {"role": "user", "content": [{"type": "text", "text": "text", "cache_control": {...}}]} This is required for Anthropic's prompt caching API when cache_control is specified at the message level but content is a simple string (not already an array of content blocks). """ content = message.get("content") # Create new message with cache_control moved into content block transformed_message = cast(dict[str, Any], message.copy()) cache_control = transformed_message.pop("cache_control") transformed_message["content"] = [ { "type": "text", "text": content, "cache_control": cache_control, } ] return cast(AllMessageValues, transformed_message) @staticmethod def extract_content_str( content: Optional[AllDatabricksContentValues], ) -> Optional[str]: if content is None: return None if isinstance(content, str): return content elif isinstance(content, list): content_str = "" for item in content: if item.get("type") == "text": text_value = item.get("text", "") content_str += str(text_value) if text_value is not None else "" return content_str else: raise Exception(f"Unsupported content type: {type(content)}") @staticmethod def extract_reasoning_content( content: Optional[AllDatabricksContentValues], ) -> Tuple[ Optional[str], Optional[ List[ Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock] ] ], ]: """ Extract and return the reasoning content and thinking blocks """ if content is None: return None, None thinking_blocks: Optional[ List[ Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock] ] ] = None reasoning_content: Optional[str] = None if isinstance(content, list): for item in content: if item.get("type") == "reasoning": summary_list = item.get("summary", []) if isinstance(summary_list, list): for sum in summary_list: if reasoning_content is None: reasoning_content = "" reasoning_content += sum["text"] thinking_block = ChatCompletionThinkingBlock( type="thinking", thinking=sum.get("text", ""), signature=sum.get("signature", ""), ) if thinking_blocks is None: thinking_blocks = [] thinking_blocks.append(thinking_block) return reasoning_content, thinking_blocks @staticmethod def extract_citations( content: Optional[AllDatabricksContentValues], ) -> Optional[List[Any]]: if content is None: return None citations = [] if isinstance(content, list): for item in content: text = item.get("text", None) if citations_item := item.get("citations"): citations.append( [ {**citation, "supported_text": text} for citation in citations_item ] ) return citations or None def _transform_dbrx_choices( self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None ) -> List[Choices]: transformed_choices = [] for choice in choices: ## HANDLE JSON MODE - anthropic returns single function call] tool_calls = choice["message"].get("tool_calls", None) if 

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # Redact sensitive data before logging to prevent credential leakage
        redacted_request_data = self.redact_sensitive_data(request_data)

        ## LOGGING - Never log actual API keys
        logging_obj.post_call(
            input=messages,
            api_key="[REDACTED]",
            original_response=raw_response.text,
            additional_args={"complete_input_dict": redacted_request_data},
        )

        ## RESPONSE OBJECT
        try:
            completion_response = DatabricksResponse(**raw_response.json())  # type: ignore
        except Exception as e:
            response_headers = getattr(raw_response, "headers", None)
            raise DatabricksException(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )

        model_response.model = completion_response["model"]
        model_response.id = completion_response["id"]
        model_response.created = completion_response["created"]
        setattr(model_response, "usage", Usage(**completion_response["usage"]))

        model_response.choices = self._transform_dbrx_choices(  # type: ignore
            choices=completion_response["choices"],
            json_mode=json_mode,
        )

        return model_response

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return DatabricksChatResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class DatabricksChatResponseIterator(BaseModelResponseIterator):
    def __init__(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        super().__init__(streaming_response, sync_stream)

        self.json_mode = json_mode
        self._last_function_name = None  # Track the last seen function name
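
    # Sketch of why `_last_function_name` is needed (hypothetical stream, shown
    # as comments). Later tool-call deltas may carry only arguments, with no
    # function name, so they are matched against the saved name:
    #
    #   chunk 1 delta: {"tool_calls": [{"function": {"name": RESPONSE_FORMAT_TOOL_NAME,
    #                                                "arguments": ""}}]}
    #   chunk 2 delta: {"tool_calls": [{"function": {"arguments": '{"a": 1}'}}]}
    #   # chunk 2 has no name; _last_function_name (saved from chunk 1) decides
    #   # whether its arguments should be converted into plain content.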

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        try:
            translated_choices = []
            for choice in chunk["choices"]:
                tool_calls = choice["delta"].get("tool_calls")
                if tool_calls and self.json_mode:
                    # 1. Check if the function name is set and == RESPONSE_FORMAT_TOOL_NAME
                    # 2. If no function name, just args -> check last function name (saved via state variable)
                    # 3. Convert args to json
                    # 4. Convert json to message
                    # 5. Set content to message.content
                    # 6. Set tool_calls to None
                    from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
                    from litellm.llms.base_llm.base_utils import (
                        _convert_tool_response_to_message,
                    )

                    # Check if this chunk has a function name
                    function_name = tool_calls[0].get("function", {}).get("name")
                    if function_name is not None:
                        self._last_function_name = function_name

                    # If we have a saved function name that matches RESPONSE_FORMAT_TOOL_NAME,
                    # or this chunk has the matching function name
                    if (
                        self._last_function_name == RESPONSE_FORMAT_TOOL_NAME
                        or function_name == RESPONSE_FORMAT_TOOL_NAME
                    ):
                        # Convert tool calls to message format
                        message = _convert_tool_response_to_message(tool_calls)
                        if message is not None:
                            if message.content == "{}":  # empty json
                                message.content = ""
                            choice["delta"]["content"] = message.content
                            choice["delta"]["tool_calls"] = None
                elif tool_calls:
                    for _tc in tool_calls:
                        if _tc.get("function", {}).get("arguments") == "{}":
                            _tc["function"]["arguments"] = ""  # avoid invalid json

                if isinstance(choice["delta"].get("content"), list) and (
                    content := choice["delta"]["content"]
                ):
                    if citations := content[0].get("citations"):
                        # TODO: Databricks delta does not include supported text or chunk type.
                        # Add either here once Databricks supports it to enable citation linkage.
                        choice["delta"].setdefault("provider_specific_fields", {})[
                            "citation"
                        ] = citations[
                            0
                        ]  # Databricks content items always carry citations as a list of lists

                # extract the content str
                content_str = DatabricksConfig.extract_content_str(
                    choice["delta"].get("content")
                )

                # extract the reasoning content
                (
                    reasoning_content,
                    thinking_blocks,
                ) = DatabricksConfig.extract_reasoning_content(
                    choice["delta"].get("content")
                )

                choice["delta"]["content"] = content_str
                choice["delta"]["reasoning_content"] = reasoning_content
                choice["delta"]["thinking_blocks"] = thinking_blocks

                translated_choices.append(choice)
            return ModelResponseStream(
                id=chunk["id"],
                object="chat.completion.chunk",
                created=chunk["created"],
                model=chunk["model"],
                choices=translated_choices,
            )
        except KeyError as e:
            raise DatabricksException(
                message=f"KeyError: {e}, Got unexpected response from Databricks: {chunk}",
                status_code=400,
            )
        except Exception as e:
            raise e
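
if __name__ == "__main__":
    # Minimal local sanity check (illustrative only): exercise the static
    # extractors on a hand-built, Databricks-style content list. The payload
    # is hypothetical, not a recorded API response.
    _demo_content = [
        {
            "type": "reasoning",
            "summary": [{"text": "thinking about the answer", "signature": "sig"}],
        },
        {"type": "text", "text": "Hello"},
    ]
    print(DatabricksConfig.extract_content_str(_demo_content))  # -> "Hello"
    print(DatabricksConfig.extract_reasoning_content(_demo_content))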