import base64 import datetime import hashlib import json from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Protocol, Tuple, Union, ) from urllib.parse import urlparse import httpx import litellm from litellm.litellm_core_utils.logging_utils import track_llm_api_timing from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, _get_httpx_client, get_async_httpx_client, version, ) from litellm.llms.oci.common_utils import OCIError from litellm.types.llms.oci import ( CohereChatRequest, CohereMessage, CohereChatResult, CohereParameterDefinition, CohereStreamChunk, CohereTool, CohereToolCall, OCIChatRequestPayload, OCICompletionPayload, OCICompletionResponse, OCIContentPartUnion, OCIImageContentPart, OCIImageUrl, OCIMessage, OCIRoles, OCIServingMode, OCIStreamChunk, OCITextContentPart, OCIToolCall, OCIToolDefinition, OCIVendors, ) from litellm.types.llms.openai import AllMessageValues from litellm.types.utils import ( Delta, LlmProviders, ModelResponse, ModelResponseStream, StreamingChoices, ) from litellm.utils import ( ChatCompletionMessageToolCall, CustomStreamWrapper, Usage, ) if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj LiteLLMLoggingObj = _LiteLLMLoggingObj else: LiteLLMLoggingObj = Any class OCISignerProtocol(Protocol): """ Protocol for OCI request signers (e.g., oci.signer.Signer). This protocol defines the interface expected for OCI SDK signer objects. Compatible with the OCI Python SDK's Signer class. See: https://docs.oracle.com/en-us/iaas/tools/python/latest/api/signing.html """ def do_request_sign( self, request: Any, *, enforce_content_headers: bool = False ) -> None: """ Sign an HTTP request by adding authentication headers. Args: request: Request object with method, url, headers, body, and path_url attributes enforce_content_headers: Whether to enforce content-type and content-length headers """ ... @dataclass class OCIRequestWrapper: """ Wrapper for HTTP requests compatible with OCI signer interface. This class wraps request data in a format compatible with OCI SDK signers, which expect objects with method, url, headers, body, and path_url attributes. """ method: str url: str headers: dict body: bytes @property def path_url(self) -> str: """Returns the path + query string for OCI signing.""" parsed_url = urlparse(self.url) return parsed_url.path + ("?" + parsed_url.query if parsed_url.query else "") def sha256_base64(data: bytes) -> str: digest = hashlib.sha256(data).digest() return base64.b64encode(digest).decode() def build_signature_string(method, path, headers, signed_headers): lines = [] for header in signed_headers: if header == "(request-target)": value = f"{method.lower()} {path}" else: value = headers[header] lines.append(f"{header}: {value}") return "\n".join(lines) def load_private_key_from_str(key_str: str): try: from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa except ImportError as e: raise ImportError( "cryptography package is required for OCI authentication. " "Please install it with: pip install cryptography" ) from e key = serialization.load_pem_private_key( key_str.encode("utf-8"), password=None, ) if not isinstance(key, rsa.RSAPrivateKey): raise TypeError( "The provided private key is not an RSA key, which is required for OCI signing." ) return key def load_private_key_from_file(file_path: str): """Loads a private key from a file path""" try: with open(file_path, "r", encoding="utf-8") as f: key_str = f.read().strip() except FileNotFoundError: raise FileNotFoundError(f"Private key file not found: {file_path}") except OSError as e: raise OSError(f"Failed to read private key file '{file_path}': {e}") from e if not key_str: raise ValueError(f"Private key file is empty: {file_path}") return load_private_key_from_str(key_str) def get_vendor_from_model(model: str) -> OCIVendors: """ Extracts the vendor from the model name. Args: model (str): The model name. Returns: str: The vendor name. """ vendor = model.split(".")[0].lower() if vendor == "cohere": return OCIVendors.COHERE else: return OCIVendors.GENERIC # 5 minute timeout (models may need to load) STREAMING_TIMEOUT = 60 * 5 class OCIChatConfig(BaseConfig): """ Configuration class for OCI's API interface. """ def __init__( self, ) -> None: locals_ = locals().copy() for key, value in locals_.items(): if key != "self" and value is not None: setattr(self.__class__, key, value) # mark the class as using a custom stream wrapper because the default only iterates on lines setattr(self.__class__, "has_custom_stream_wrapper", True) self.openai_to_oci_generic_param_map = { "stream": "isStream", "max_tokens": "maxTokens", "max_completion_tokens": "maxTokens", "temperature": "temperature", "tools": "tools", "frequency_penalty": "frequencyPenalty", "logprobs": "logProbs", "logit_bias": "logitBias", "n": "numGenerations", "presence_penalty": "presencePenalty", "seed": "seed", "stop": "stop", "tool_choice": "toolChoice", "top_p": "topP", "max_retries": False, "top_logprobs": False, "modalities": False, "prediction": False, "stream_options": False, "function_call": False, "functions": False, "extra_headers": False, "parallel_tool_calls": False, "audio": False, "web_search_options": False, "response_format": "responseFormat", } # Cohere and Gemini use the same parameter mapping as GENERIC self.openai_to_oci_cohere_param_map = ( self.openai_to_oci_generic_param_map.copy() ) def get_supported_openai_params(self, model: str) -> List[str]: supported_params = [] vendor = get_vendor_from_model(model) if vendor == OCIVendors.COHERE: open_ai_to_oci_param_map = self.openai_to_oci_cohere_param_map open_ai_to_oci_param_map.pop("tool_choice") open_ai_to_oci_param_map.pop("max_retries") else: open_ai_to_oci_param_map = self.openai_to_oci_generic_param_map for key, value in open_ai_to_oci_param_map.items(): if value: supported_params.append(key) return supported_params def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str, drop_params: bool, ) -> dict: adapted_params = {} vendor = get_vendor_from_model(model) if vendor == OCIVendors.COHERE: open_ai_to_oci_param_map = self.openai_to_oci_cohere_param_map else: open_ai_to_oci_param_map = self.openai_to_oci_generic_param_map all_params = {**non_default_params, **optional_params} for key, value in all_params.items(): alias = open_ai_to_oci_param_map.get(key) if alias is False: # Workaround for mypy issue if drop_params or litellm.drop_params: continue raise Exception(f"param `{key}` is not supported on OCI") if alias is None: adapted_params[key] = value continue adapted_params[alias] = value if alias == "responseFormat": adapted_params["response_format"] = value return adapted_params def _sign_with_oci_signer( self, headers: dict, optional_params: dict, request_data: dict, api_base: str, ) -> Tuple[dict, bytes]: """ Sign request using OCI SDK Signer object. Args: headers: Request headers to be signed optional_params: Optional parameters including oci_signer request_data: The request body dict to be sent in HTTP request api_base: The complete URL for the HTTP request Returns: Tuple of (signed_headers, encoded_body) Raises: OCIError: If signing fails ValueError: If HTTP method is unsupported """ oci_signer = optional_params.get("oci_signer") body = json.dumps(request_data).encode("utf-8") method = str(optional_params.get("method", "POST")).upper() if method not in ["POST", "GET", "PUT", "DELETE", "PATCH"]: raise ValueError(f"Unsupported HTTP method: {method}") prepared_headers = headers.copy() prepared_headers.setdefault("content-type", "application/json") prepared_headers.setdefault("content-length", str(len(body))) request_wrapper = OCIRequestWrapper( method=method, url=api_base, headers=prepared_headers, body=body ) if oci_signer is None: raise ValueError( "oci_signer cannot be None when calling _sign_with_oci_signer" ) try: oci_signer.do_request_sign(request_wrapper, enforce_content_headers=True) except Exception as e: raise OCIError( status_code=500, message=( f"Failed to sign request with provided oci_signer: {str(e)}. " "The signer must implement the OCI SDK Signer interface with a " "do_request_sign(request, enforce_content_headers=True) method. " "See: https://docs.oracle.com/en-us/iaas/tools/python/latest/api/signing.html" ), ) from e headers.update(request_wrapper.headers) return headers, body def _sign_with_manual_credentials( self, headers: dict, optional_params: dict, request_data: dict, api_base: str, ) -> Tuple[dict, None]: """ Sign request using manual OCI credentials. Args: headers: Request headers to be signed optional_params: Optional parameters including OCI credentials request_data: The request body dict to be sent in HTTP request api_base: The complete URL for the HTTP request Returns: Tuple of (signed_headers, None) Raises: Exception: If required credentials are missing ImportError: If cryptography package is not installed """ oci_region = optional_params.get("oci_region", "us-ashburn-1") api_base = ( api_base or litellm.api_base or f"https://inference.generativeai.{oci_region}.oci.oraclecloud.com" ) oci_user = optional_params.get("oci_user") oci_fingerprint = optional_params.get("oci_fingerprint") oci_tenancy = optional_params.get("oci_tenancy") oci_key = optional_params.get("oci_key") oci_key_file = optional_params.get("oci_key_file") if ( not oci_user or not oci_fingerprint or not oci_tenancy or not (oci_key or oci_key_file) ): raise Exception( "Missing required parameters: oci_user, oci_fingerprint, oci_tenancy, " "and at least one of oci_key or oci_key_file." ) method = str(optional_params.get("method", "POST")).upper() body = json.dumps(request_data).encode("utf-8") parsed = urlparse(api_base) path = parsed.path or "/" host = parsed.netloc date = datetime.datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT") content_type = headers.get("content-type", "application/json") content_length = str(len(body)) x_content_sha256 = sha256_base64(body) headers_to_sign = { "date": date, "host": host, "content-type": content_type, "content-length": content_length, "x-content-sha256": x_content_sha256, } signed_headers = [ "date", "(request-target)", "host", "content-length", "content-type", "x-content-sha256", ] signing_string = build_signature_string( method, path, headers_to_sign, signed_headers ) try: from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding except ImportError as e: raise ImportError( "cryptography package is required for OCI authentication. " "Please install it with: pip install cryptography" ) from e # Handle oci_key - it should be a string (PEM content) oci_key_content = None if oci_key: if isinstance(oci_key, str): oci_key_content = oci_key # Fix common issues with PEM content # Replace escaped newlines with actual newlines oci_key_content = oci_key_content.replace("\\n", "\n") # Ensure proper line endings if "\r\n" in oci_key_content: oci_key_content = oci_key_content.replace("\r\n", "\n") else: raise OCIError( status_code=400, message=f"oci_key must be a string containing the PEM private key content. " f"Got type: {type(oci_key).__name__}", ) private_key = ( load_private_key_from_str(oci_key_content) if oci_key_content else load_private_key_from_file(oci_key_file) if oci_key_file else None ) if private_key is None: raise OCIError( status_code=400, message="Private key is required for OCI authentication. Please provide either oci_key or oci_key_file.", ) signature = private_key.sign( signing_string.encode("utf-8"), padding.PKCS1v15(), hashes.SHA256(), ) signature_b64 = base64.b64encode(signature).decode() key_id = f"{oci_tenancy}/{oci_user}/{oci_fingerprint}" authorization = ( 'Signature version="1",' f'keyId="{key_id}",' 'algorithm="rsa-sha256",' f'headers="{" ".join(signed_headers)}",' f'signature="{signature_b64}"' ) headers.update( { "authorization": authorization, "date": date, "host": host, "content-type": content_type, "content-length": content_length, "x-content-sha256": x_content_sha256, } ) return headers, None def sign_request( self, headers: dict, optional_params: dict, request_data: dict, api_base: str, api_key: Optional[str] = None, model: Optional[str] = None, stream: Optional[bool] = None, fake_stream: Optional[bool] = None, ) -> Tuple[dict, Optional[bytes]]: """ Sign the OCI request by adding authentication headers. Supports two signing modes: 1. OCI SDK Signer: Use an oci_signer object to sign the request 2. Manual Signing: Use OCI credentials to manually sign the request Args: headers: Request headers to be signed optional_params: Optional parameters including auth credentials or oci_signer request_data: The request body dict to be sent in HTTP request api_base: The complete URL for the HTTP request api_key: Optional API key (not used for OCI) model: Optional model name stream: Optional streaming flag fake_stream: Optional fake streaming flag Returns: Tuple of (signed_headers, encoded_body): - If oci_signer is provided: Returns (headers, body) where body is the encoded JSON - If manual credentials are provided: Returns (headers, None) as body is not returned for the manual signing path Raises: OCIError: If signing fails with oci_signer Exception: If required credentials are missing ImportError: If cryptography package is not installed (manual signing only) Example: >>> from oci.signer import Signer >>> signer = Signer( ... tenancy="ocid1.tenancy.oc1..", ... user="ocid1.user.oc1..", ... fingerprint="xx:xx:xx", ... private_key_file_location="~/.oci/key.pem" ... ) >>> headers, body = config.sign_request( ... headers={}, ... optional_params={"oci_signer": signer}, ... request_data={"message": "Hello"}, ... api_base="https://inference.generativeai.us-ashburn-1.oci.oraclecloud.com/..." ... ) """ oci_signer = optional_params.get("oci_signer") # If a signer is provided, use it for request signing if oci_signer is not None: return self._sign_with_oci_signer( headers, optional_params, request_data, api_base ) # Standard manual credential signing return self._sign_with_manual_credentials( headers, optional_params, request_data, api_base ) def validate_environment( self, headers: dict, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: """ Validate the OCI environment and credentials. Supports two authentication modes: 1. OCI SDK Signer: Pass an oci_signer object (e.g., oci.signer.Signer) 2. Manual Credentials: Pass oci_user, oci_fingerprint, oci_tenancy, and oci_key/oci_key_file Args: headers: Request headers to populate model: Model name messages: List of chat messages optional_params: Optional parameters including authentication credentials litellm_params: LiteLLM parameters api_key: Optional API key (not used for OCI) api_base: Optional API base URL Returns: Updated headers dict Raises: Exception: If required parameters are missing or invalid """ oci_signer = optional_params.get("oci_signer") oci_region = optional_params.get("oci_region", "us-ashburn-1") # Determine api_base api_base = ( api_base or litellm.api_base or f"https://inference.generativeai.{oci_region}.oci.oraclecloud.com" ) if not api_base: raise Exception( "Either `api_base` must be provided or `litellm.api_base` must be set. " "Alternatively, you can set the `oci_region` optional parameter to use the default OCI region." ) # Validate credentials only if signer is not provided if oci_signer is None: oci_user = optional_params.get("oci_user") oci_fingerprint = optional_params.get("oci_fingerprint") oci_tenancy = optional_params.get("oci_tenancy") oci_key = optional_params.get("oci_key") oci_key_file = optional_params.get("oci_key_file") oci_compartment_id = optional_params.get("oci_compartment_id") if ( not oci_user or not oci_fingerprint or not oci_tenancy or not (oci_key or oci_key_file) or not oci_compartment_id ): raise Exception( "Missing required parameters: oci_user, oci_fingerprint, oci_tenancy, oci_compartment_id " "and at least one of oci_key or oci_key_file. " "Alternatively, provide an oci_signer object from the OCI SDK." ) # Common header setup headers.update( { "content-type": "application/json", "user-agent": f"litellm/{version}", } ) if not messages: raise Exception( "kwarg `messages` must be an array of messages that follow the openai chat standard" ) return headers def get_complete_url( self, api_base: Optional[str], api_key: Optional[str], model: str, optional_params: dict, litellm_params: dict, stream: Optional[bool] = None, ) -> str: oci_region = optional_params.get("oci_region", "us-ashburn-1") return f"https://inference.generativeai.{oci_region}.oci.oraclecloud.com/20231130/actions/chat" def _get_optional_params(self, vendor: OCIVendors, optional_params: dict) -> Dict: selected_params = {} if vendor == OCIVendors.COHERE: open_ai_to_oci_param_map = self.openai_to_oci_cohere_param_map # remove tool_choice from the map open_ai_to_oci_param_map.pop("tool_choice") # Add default values for Cohere API selected_params = { "maxTokens": 600, "temperature": 1, "topK": 0, "topP": 0.75, "frequencyPenalty": 0, } else: open_ai_to_oci_param_map = self.openai_to_oci_generic_param_map # Map OpenAI params to OCI params for openai_key, oci_key in open_ai_to_oci_param_map.items(): if oci_key and openai_key in optional_params: selected_params[oci_key] = optional_params[openai_key] # type: ignore[index] # Also check for already-mapped OCI params (for backward compatibility) for oci_value in open_ai_to_oci_param_map.values(): if ( oci_value and oci_value in optional_params and oci_value not in selected_params ): selected_params[oci_value] = optional_params[oci_value] # type: ignore[index] if "tools" in selected_params: if vendor == OCIVendors.COHERE: selected_params["tools"] = self.adapt_tool_definitions_to_cohere_standard( # type: ignore[assignment] selected_params["tools"] # type: ignore[arg-type] ) else: selected_params["tools"] = adapt_tool_definition_to_oci_standard( # type: ignore[assignment] selected_params["tools"], vendor # type: ignore[arg-type] ) # Transform response_format type to OCI uppercase format if "responseFormat" in selected_params: rf = selected_params["responseFormat"] if isinstance(rf, dict) and "type" in rf: rf_payload = dict(rf) selected_params["responseFormat"] = rf_payload response_type = rf_payload["type"] schema_payload: Optional[Any] = None if "json_schema" in rf_payload: raw_schema_payload = rf_payload.pop("json_schema") if isinstance(raw_schema_payload, dict): schema_payload = dict(raw_schema_payload) else: schema_payload = raw_schema_payload if schema_payload is not None: rf_payload["jsonSchema"] = schema_payload if vendor == OCIVendors.COHERE: # Cohere expects lower-case type values rf_payload["type"] = response_type else: format_type = response_type.upper() if format_type == "JSON": format_type = "JSON_OBJECT" rf_payload["type"] = format_type return selected_params def adapt_messages_to_cohere_standard( self, messages: List[AllMessageValues] ) -> List[CohereMessage]: """Build chat history for Cohere models.""" chat_history = [] for msg in messages[:-1]: # All messages except the last one role = msg.get("role") content = msg.get("content") if isinstance(content, list): # Extract text from content array text_content = "" for content_item in content: if ( isinstance(content_item, dict) and content_item.get("type") == "text" ): text_content += content_item.get("text", "") content = text_content # Ensure content is a string if not isinstance(content, str): content = str(content) if content is not None else "" # Handle tool calls tool_calls: Optional[List[CohereToolCall]] = None if role == "assistant" and "tool_calls" in msg and msg.get("tool_calls"): # type: ignore[union-attr,typeddict-item] tool_calls = [] for tool_call in msg["tool_calls"]: # type: ignore[union-attr,typeddict-item] # Parse arguments if they're a JSON string raw_arguments: Any = tool_call.get("function", {}).get( "arguments", {} ) if isinstance(raw_arguments, str): try: arguments: Dict[str, Any] = json.loads(raw_arguments) except json.JSONDecodeError: arguments = {} else: arguments = raw_arguments tool_calls.append( CohereToolCall( name=str(tool_call.get("function", {}).get("name", "")), parameters=arguments, ) ) if role == "user": chat_history.append(CohereMessage(role="USER", message=content)) elif role == "assistant": chat_history.append( CohereMessage(role="CHATBOT", message=content, toolCalls=tool_calls) ) elif role == "tool": # Tool messages need special handling chat_history.append( CohereMessage( role="TOOL", message=content, toolCalls=None, # Tool messages don't have tool calls ) ) return chat_history def adapt_tool_definitions_to_cohere_standard( self, tools: List[Dict[str, Any]] ) -> List[CohereTool]: """Adapt tool definitions to Cohere format.""" cohere_tools = [] for tool in tools: function_def = tool.get("function", {}) parameters = function_def.get("parameters", {}).get("properties", {}) required = function_def.get("parameters", {}).get("required", []) parameter_definitions = {} for param_name, param_schema in parameters.items(): parameter_definitions[param_name] = CohereParameterDefinition( description=param_schema.get("description", ""), type=param_schema.get("type", "string"), isRequired=param_name in required, ) cohere_tools.append( CohereTool( name=function_def.get("name", ""), description=function_def.get("description", ""), parameterDefinitions=parameter_definitions, ) ) return cohere_tools def _extract_text_content(self, content: Any) -> str: """Extract text content from message content.""" if isinstance(content, str): return content elif isinstance(content, list): text_content = "" for content_item in content: if ( isinstance(content_item, dict) and content_item.get("type") == "text" ): text_content += content_item.get("text", "") return text_content return str(content) def transform_request( self, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, headers: dict, ) -> dict: oci_compartment_id = optional_params.get("oci_compartment_id", None) if not oci_compartment_id: raise Exception("kwarg `oci_compartment_id` is required for OCI requests") vendor = get_vendor_from_model(model) oci_serving_mode = optional_params.get("oci_serving_mode", "ON_DEMAND") if oci_serving_mode not in ["ON_DEMAND", "DEDICATED"]: raise Exception( "kwarg `oci_serving_mode` must be either 'ON_DEMAND' or 'DEDICATED'" ) if oci_serving_mode == "DEDICATED": oci_endpoint_id = optional_params.get("oci_endpoint_id", model) servingMode = OCIServingMode( servingType="DEDICATED", endpointId=oci_endpoint_id, ) else: servingMode = OCIServingMode( servingType="ON_DEMAND", modelId=model, ) # Build request based on vendor type if vendor == OCIVendors.COHERE: # For Cohere, we need to use the specific Cohere format # Extract the last user message as the main message user_messages = [msg for msg in messages if msg.get("role") == "user"] if not user_messages: raise Exception("No user message found for Cohere model") # Extract system messages into preambleOverride system_messages = [msg for msg in messages if msg.get("role") == "system"] preamble_override = None if system_messages: preamble = "\n".join( self._extract_text_content(msg["content"]) for msg in system_messages ) if preamble: preamble_override = preamble # Create Cohere-specific chat request optional_cohere_params = self._get_optional_params( OCIVendors.COHERE, optional_params ) chat_request = CohereChatRequest( apiFormat="COHERE", message=self._extract_text_content(user_messages[-1]["content"]), chatHistory=self.adapt_messages_to_cohere_standard(messages), preambleOverride=preamble_override, **optional_cohere_params, ) data = OCICompletionPayload( compartmentId=oci_compartment_id, servingMode=servingMode, chatRequest=chat_request, ) else: # Use generic format for other vendors data = OCICompletionPayload( compartmentId=oci_compartment_id, servingMode=servingMode, chatRequest=OCIChatRequestPayload( apiFormat=vendor.value, messages=adapt_messages_to_generic_oci_standard(messages), **self._get_optional_params(vendor, optional_params), ), ) return data.model_dump(exclude_none=True) def _handle_cohere_response( self, json_response: dict, model: str, model_response: ModelResponse ) -> ModelResponse: """Handle Cohere-specific response format.""" cohere_response = CohereChatResult(**json_response) # Cohere response format (uses camelCase) model_id = model # Set basic response info model_response.model = model_id model_response.created = int(datetime.datetime.now().timestamp()) # Extract the response text response_text = cohere_response.chatResponse.text oci_finish_reason = cohere_response.chatResponse.finishReason # Map finish reason if oci_finish_reason == "COMPLETE": finish_reason = "stop" elif oci_finish_reason == "MAX_TOKENS": finish_reason = "length" else: finish_reason = "stop" # Handle tool calls tool_calls: Optional[List[Dict[str, Any]]] = None if cohere_response.chatResponse.toolCalls: tool_calls = [] for tool_call in cohere_response.chatResponse.toolCalls: tool_calls.append( { "id": f"call_{len(tool_calls)}", # Generate a simple ID "type": "function", "function": { "name": tool_call.name, "arguments": json.dumps(tool_call.parameters), }, } ) # Create choice from litellm.types.utils import Choices choice = Choices( index=0, message={ "role": "assistant", "content": response_text, "tool_calls": tool_calls, }, finish_reason=finish_reason, ) model_response.choices = [choice] # Extract usage info usage_info = cohere_response.chatResponse.usage from litellm.types.utils import Usage model_response.usage = Usage( # type: ignore[attr-defined] prompt_tokens=usage_info.promptTokens, # type: ignore[union-attr] completion_tokens=usage_info.completionTokens, # type: ignore[union-attr] total_tokens=usage_info.totalTokens, # type: ignore[union-attr] ) return model_response def _handle_generic_response( self, json: dict, model: str, model_response: ModelResponse, raw_response: httpx.Response, ) -> ModelResponse: """Handle generic OCI response format.""" try: completion_response = OCICompletionResponse(**json) except TypeError as e: raise OCIError( message=f"Response cannot be casted to OCICompletionResponse: {str(e)}", status_code=raw_response.status_code, ) iso_str = completion_response.chatResponse.timeCreated dt = datetime.datetime.fromisoformat(iso_str.replace("Z", "+00:00")) model_response.created = int(dt.timestamp()) model_response.model = completion_response.modelId message = model_response.choices[0].message # type: ignore response_message = completion_response.chatResponse.choices[0].message if response_message.content and response_message.content[0].type == "TEXT": message.content = response_message.content[0].text if response_message.toolCalls: message.tool_calls = adapt_tools_to_openai_standard( response_message.toolCalls ) usage = Usage( prompt_tokens=completion_response.chatResponse.usage.promptTokens, completion_tokens=completion_response.chatResponse.usage.completionTokens, total_tokens=completion_response.chatResponse.usage.totalTokens, ) model_response.usage = usage # type: ignore return model_response def transform_response( self, model: str, raw_response: httpx.Response, model_response: ModelResponse, logging_obj: LiteLLMLoggingObj, request_data: dict, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, encoding: Any, api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: json = raw_response.json() # noqa: F811 error = json.get("error") if error is not None: raise OCIError( message=str(json["error"]), status_code=raw_response.status_code, ) if not isinstance(json, dict): raise OCIError( message="Invalid response format from OCI", status_code=raw_response.status_code, ) vendor = get_vendor_from_model(model) # Handle response based on vendor type if vendor == OCIVendors.COHERE: model_response = self._handle_cohere_response(json, model, model_response) else: model_response = self._handle_generic_response( json, model, model_response, raw_response ) model_response._hidden_params["additional_headers"] = raw_response.headers return model_response @track_llm_api_timing() def get_sync_custom_stream_wrapper( self, model: str, custom_llm_provider: str, logging_obj: LiteLLMLoggingObj, api_base: str, headers: dict, data: dict, messages: list, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, json_mode: Optional[bool] = None, signed_json_body: Optional[bytes] = None, ) -> "OCIStreamWrapper": if "stream" in data: del data["stream"] if client is None or isinstance(client, AsyncHTTPHandler): client = _get_httpx_client(params={}) try: response = client.post( api_base, headers=headers, data=json.dumps(data), stream=True, logging_obj=logging_obj, timeout=STREAMING_TIMEOUT, ) except httpx.HTTPStatusError as e: raise OCIError(status_code=e.response.status_code, message=e.response.text) if response.status_code != 200: raise OCIError(status_code=response.status_code, message=response.text) completion_stream = response.iter_text() streaming_response = OCIStreamWrapper( completion_stream=completion_stream, model=model, custom_llm_provider=custom_llm_provider, logging_obj=logging_obj, ) return streaming_response @track_llm_api_timing() async def get_async_custom_stream_wrapper( self, model: str, custom_llm_provider: str, logging_obj: LiteLLMLoggingObj, api_base: str, headers: dict, data: dict, messages: list, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, json_mode: Optional[bool] = None, signed_json_body: Optional[bytes] = None, ) -> "OCIStreamWrapper": if "stream" in data: del data["stream"] if client is None or isinstance(client, HTTPHandler): client = get_async_httpx_client(llm_provider=LlmProviders.BYTEZ, params={}) try: response = await client.post( api_base, headers=headers, data=json.dumps(data), stream=True, logging_obj=logging_obj, timeout=STREAMING_TIMEOUT, ) except httpx.HTTPStatusError as e: raise OCIError(status_code=e.response.status_code, message=e.response.text) if response.status_code != 200: raise OCIError(status_code=response.status_code, message=response.text) completion_stream = response.aiter_text() async def split_chunks(completion_stream: AsyncIterator[str]): async for item in completion_stream: for chunk in item.split("\n\n"): if not chunk: continue yield chunk.strip() streaming_response = OCIStreamWrapper( completion_stream=split_chunks(completion_stream), model=model, custom_llm_provider=custom_llm_provider, logging_obj=logging_obj, ) return streaming_response def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: return OCIError(status_code=status_code, message=error_message) open_ai_to_generic_oci_role_map: Dict[str, OCIRoles] = { "system": "SYSTEM", "user": "USER", "assistant": "ASSISTANT", "tool": "TOOL", } def adapt_messages_to_generic_oci_standard_content_message( role: str, content: Union[str, list] ) -> OCIMessage: new_content: List[OCIContentPartUnion] = [] if isinstance(content, str): return OCIMessage( role=open_ai_to_generic_oci_role_map[role], content=[OCITextContentPart(text=content)], toolCalls=None, toolCallId=None, ) # content is a list of content items: # [ # {"type": "text", "text": "Hello"}, # {"type": "image_url", "image_url": "https://example.com/image.png"} # ] for content_item in content: if not isinstance(content_item, dict): raise Exception("Each content item must be a dictionary") type = content_item.get("type") if not isinstance(type, str): raise Exception("Prop `type` is not a string") if type not in ["text", "image_url"]: raise Exception(f"Prop `{type}` is not supported") if type == "text": text = content_item.get("text") if not isinstance(text, str): raise Exception("Prop `text` is not a string") new_content.append(OCITextContentPart(text=text)) elif type == "image_url": image_url = content_item.get("image_url") # Handle both OpenAI format (object with url) and string format if isinstance(image_url, dict): image_url = image_url.get("url") if not isinstance(image_url, str): raise Exception( "Prop `image_url` must be a string or an object with a `url` property" ) new_content.append(OCIImageContentPart(imageUrl=OCIImageUrl(url=image_url))) return OCIMessage( role=open_ai_to_generic_oci_role_map[role], content=new_content, toolCalls=None, toolCallId=None, ) def adapt_messages_to_generic_oci_standard_tool_call( role: str, tool_calls: list ) -> OCIMessage: tool_calls_formated = [] for tool_call in tool_calls: if not isinstance(tool_call, dict): raise Exception("Each tool call must be a dictionary") if tool_call.get("type") != "function": raise Exception("OCI only supports function tools") tool_call_id = tool_call.get("id") if not isinstance(tool_call_id, str): raise Exception("Prop `id` is not a string") tool_function = tool_call.get("function") if not isinstance(tool_function, dict): raise Exception("Prop `function` is not a dictionary") function_name = tool_function.get("name") if not isinstance(function_name, str): raise Exception("Prop `name` is not a string") arguments = tool_call["function"].get("arguments", "{}") if not isinstance(arguments, str): raise Exception("Prop `arguments` is not a string") # tool_calls_formated.append(OCIToolCall( # id=tool_call_id, # type="FUNCTION", # function=OCIFunction( # name=function_name, # arguments=arguments # ) # )) tool_calls_formated.append( OCIToolCall( id=tool_call_id, type="FUNCTION", name=function_name, arguments=arguments, ) ) return OCIMessage( role=open_ai_to_generic_oci_role_map[role], content=None, toolCalls=tool_calls_formated, toolCallId=None, ) def adapt_messages_to_generic_oci_standard_tool_response( role: str, tool_call_id: str, content: str ) -> OCIMessage: return OCIMessage( role=open_ai_to_generic_oci_role_map[role], content=[OCITextContentPart(text=content)], toolCalls=None, toolCallId=tool_call_id, ) def adapt_messages_to_generic_oci_standard( messages: List[AllMessageValues], ) -> List[OCIMessage]: new_messages = [] for message in messages: role = message["role"] content = message.get("content") tool_calls = message.get("tool_calls") tool_call_id = message.get("tool_call_id") if role == "assistant" and tool_calls is not None: if not isinstance(tool_calls, list): raise Exception("Prop `tool_calls` must be a list of tool calls") new_messages.append( adapt_messages_to_generic_oci_standard_tool_call(role, tool_calls) ) elif role in ["system", "user", "assistant"] and content is not None: if not isinstance(content, (str, list)): raise Exception( "Prop `content` must be a string or a list of content items" ) new_messages.append( adapt_messages_to_generic_oci_standard_content_message(role, content) ) elif role == "tool": if not isinstance(tool_call_id, str): raise Exception("Prop `tool_call_id` is required and must be a string") if not isinstance(content, str): raise Exception("Prop `content` is not a string") new_messages.append( adapt_messages_to_generic_oci_standard_tool_response( role, tool_call_id, content ) ) return new_messages def adapt_tool_definition_to_oci_standard(tools: List[Dict], vendor: OCIVendors): new_tools = [] for tool in tools: if tool["type"] != "function": raise Exception("OCI only supports function tools") tool_function = tool.get("function") if not isinstance(tool_function, dict): raise Exception("Prop `function` is not a dictionary") new_tool = OCIToolDefinition( type="FUNCTION", name=tool_function.get("name"), description=tool_function.get("description", ""), parameters=tool_function.get("parameters", {}), ) new_tools.append(new_tool) return new_tools def adapt_tools_to_openai_standard( tools: List[OCIToolCall], ) -> List[ChatCompletionMessageToolCall]: new_tools = [] for tool in tools: new_tool = ChatCompletionMessageToolCall( id=tool.id, type="function", function={ "name": tool.name, "arguments": tool.arguments, }, ) new_tools.append(new_tool) return new_tools class OCIStreamWrapper(CustomStreamWrapper): """ Custom stream wrapper for OCI responses. This class is used to handle streaming responses from OCI's API. """ def __init__( self, **kwargs: Any, ): super().__init__(**kwargs) def chunk_creator(self, chunk: Any): if not isinstance(chunk, str): raise ValueError(f"Chunk is not a string: {chunk}") if not chunk.startswith("data:"): raise ValueError(f"Chunk does not start with 'data:': {chunk}") dict_chunk = json.loads(chunk[5:]) # Remove 'data: ' prefix and parse JSON # Check if this is a Cohere stream chunk if "apiFormat" in dict_chunk and dict_chunk.get("apiFormat") == "COHERE": return self._handle_cohere_stream_chunk(dict_chunk) else: return self._handle_generic_stream_chunk(dict_chunk) def _handle_cohere_stream_chunk(self, dict_chunk: dict): """Handle Cohere-specific streaming chunks.""" try: typed_chunk = CohereStreamChunk(**dict_chunk) except TypeError as e: raise ValueError(f"Chunk cannot be casted to CohereStreamChunk: {str(e)}") if typed_chunk.index is None: typed_chunk.index = 0 # Extract text content text = typed_chunk.text or "" # Map finish reason to standard format finish_reason = typed_chunk.finishReason if finish_reason == "COMPLETE": finish_reason = "stop" elif finish_reason == "MAX_TOKENS": finish_reason = "length" elif finish_reason is None: finish_reason = None else: finish_reason = "stop" # For Cohere, we don't have tool calls in the streaming format tool_calls = None return ModelResponseStream( choices=[ StreamingChoices( index=typed_chunk.index if typed_chunk.index else 0, delta=Delta( content=text, tool_calls=tool_calls, provider_specific_fields=None, thinking_blocks=None, reasoning_content=None, ), finish_reason=finish_reason, ) ] ) def _handle_generic_stream_chunk(self, dict_chunk: dict): """Handle generic OCI streaming chunks.""" # Fix missing required fields in tool calls before Pydantic validation # OCI streams tool calls progressively, so early chunks may be missing required fields if dict_chunk.get("message") and dict_chunk["message"].get("toolCalls"): for tool_call in dict_chunk["message"]["toolCalls"]: if "arguments" not in tool_call: tool_call["arguments"] = "" if "id" not in tool_call: tool_call["id"] = "" if "name" not in tool_call: tool_call["name"] = "" try: typed_chunk = OCIStreamChunk(**dict_chunk) except TypeError as e: raise ValueError(f"Chunk cannot be casted to OCIStreamChunk: {str(e)}") if typed_chunk.index is None: typed_chunk.index = 0 text = "" if typed_chunk.message and typed_chunk.message.content: for item in typed_chunk.message.content: if isinstance(item, OCITextContentPart): text += item.text elif isinstance(item, OCIImageContentPart): raise ValueError( "OCI does not support image content in streaming responses" ) else: raise ValueError( f"Unsupported content type in OCI response: {item.type}" ) tool_calls = None if typed_chunk.message and typed_chunk.message.toolCalls: tool_calls = adapt_tools_to_openai_standard(typed_chunk.message.toolCalls) return ModelResponseStream( choices=[ StreamingChoices( index=typed_chunk.index if typed_chunk.index else 0, delta=Delta( content=text, tool_calls=( [tool.model_dump() for tool in tool_calls] if tool_calls else None ), provider_specific_fields=None, # OCI does not have provider specific fields in the response thinking_blocks=None, # OCI does not have thinking blocks in the response reasoning_content=None, # OCI does not have reasoning content in the response ), finish_reason=typed_chunk.finishReason, ) ] )