chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1 @@
`/chat/completion` calls routed via `openai.py`.

View File

@@ -0,0 +1,11 @@
from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
from litellm.llms.azure_ai.agents.transformation import (
AzureAIAgentsConfig,
AzureAIAgentsError,
)
__all__ = [
"AzureAIAgentsConfig",
"AzureAIAgentsError",
"azure_ai_agents_handler",
]

View File

@@ -0,0 +1,659 @@
"""
Handler for Azure Foundry Agent Service API.
This handler executes the multi-step agent flow:
1. Create thread (or use existing)
2. Add messages to thread
3. Create and poll a run
4. Retrieve the assistant's response messages
Model format: azure_ai/agents/<agent_id>
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
Authentication: Uses Azure AD Bearer tokens (not API keys)
Get token via: az account get-access-token --resource 'https://ai.azure.com'
Supports both polling-based and native streaming (SSE) modes.
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
"""
import asyncio
import json
import time
import uuid
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
List,
Optional,
Tuple,
)
import httpx
from litellm._logging import verbose_logger
from litellm.llms.azure_ai.agents.transformation import (
AzureAIAgentsConfig,
AzureAIAgentsError,
)
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
HTTPHandler = Any
AsyncHTTPHandler = Any
class AzureAIAgentsHandler:
    """
    Handler for Azure AI Agent Service.
    Executes the complete agent flow which requires multiple API calls.
    """

    def __init__(self):
        # Shared config object: supplies the default API version, polling
        # limits, and agent-id parsing helpers used throughout this handler.
        self.config = AzureAIAgentsConfig()
# -------------------------------------------------------------------------
# URL Builders
# -------------------------------------------------------------------------
# Azure Foundry Agents API uses /assistants, /threads, etc. directly
# See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
# -------------------------------------------------------------------------
def _build_thread_url(self, api_base: str, api_version: str) -> str:
return f"{api_base}/threads?api-version={api_version}"
def _build_messages_url(
self, api_base: str, thread_id: str, api_version: str
) -> str:
return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
def _build_runs_url(self, api_base: str, thread_id: str, api_version: str) -> str:
return f"{api_base}/threads/{thread_id}/runs?api-version={api_version}"
def _build_run_status_url(
self, api_base: str, thread_id: str, run_id: str, api_version: str
) -> str:
return f"{api_base}/threads/{thread_id}/runs/{run_id}?api-version={api_version}"
def _build_list_messages_url(
self, api_base: str, thread_id: str, api_version: str
) -> str:
return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
def _build_create_thread_and_run_url(self, api_base: str, api_version: str) -> str:
"""URL for the create-thread-and-run endpoint (supports streaming)."""
return f"{api_base}/threads/runs?api-version={api_version}"
# -------------------------------------------------------------------------
# Response Helpers
# -------------------------------------------------------------------------
def _extract_content_from_messages(self, messages_data: dict) -> str:
"""Extract assistant content from the messages response."""
for msg in messages_data.get("data", []):
if msg.get("role") == "assistant":
for content_item in msg.get("content", []):
if content_item.get("type") == "text":
return content_item.get("text", {}).get("value", "")
return ""
def _build_model_response(
    self,
    model: str,
    content: str,
    model_response: ModelResponse,
    thread_id: str,
    messages: List[Dict[str, Any]],
) -> ModelResponse:
    """Build the ModelResponse from agent output.

    Installs a single "stop" choice carrying *content*, records the
    thread_id in ``_hidden_params`` (so callers can continue the same
    conversation), and attaches an approximate token usage.
    """
    from litellm.types.utils import Choices, Message, Usage

    model_response.choices = [
        Choices(
            finish_reason="stop",
            index=0,
            message=Message(content=content, role="assistant"),
        )
    ]
    model_response.model = model
    # Store thread_id for conversation continuity across calls.
    if (
        not hasattr(model_response, "_hidden_params")
        or model_response._hidden_params is None
    ):
        model_response._hidden_params = {}
    model_response._hidden_params["thread_id"] = thread_id
    # Estimate token usage. Usage is not read from the API response here;
    # counts are approximated with the gpt-3.5-turbo tokenizer, so treat
    # them as estimates rather than billing-accurate numbers.
    try:
        from litellm.utils import token_counter

        prompt_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
        completion_tokens = token_counter(
            model="gpt-3.5-turbo", text=content, count_response_tokens=True
        )
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )
    except Exception as e:
        # Best-effort only: usage estimation must never fail the completion.
        verbose_logger.warning(f"Failed to calculate token usage: {str(e)}")
    return model_response
def _prepare_completion_params(
    self,
    model: str,
    api_base: str,
    api_key: str,
    optional_params: dict,
    headers: Optional[dict],
) -> Tuple[dict, str, str, Optional[str], str]:
    """Prepare common parameters for completion.

    Azure Foundry Agents API uses Bearer token authentication:
    - Authorization: Bearer <token> (Azure AD token from 'az account get-access-token --resource https://ai.azure.com')
    See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart

    Returns:
        Tuple of (headers, api_version, agent_id, thread_id, api_base);
        thread_id is None when no existing thread was supplied.
    """
    # Copy rather than mutate the caller's dict: callers rebind the
    # returned headers anyway, and silently editing a shared dict in place
    # could leak the Authorization header into unrelated requests.
    headers = dict(headers) if headers else {}
    headers["Content-Type"] = "application/json"
    # Azure Foundry Agents uses Bearer token authentication; the api_key
    # here is expected to be an Azure AD token.
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    api_version = optional_params.get(
        "api_version", self.config.DEFAULT_API_VERSION
    )
    agent_id = self.config._get_agent_id(model, optional_params)
    thread_id = optional_params.get("thread_id")
    # Normalise the base URL so the URL builders can append paths safely.
    api_base = api_base.rstrip("/")
    verbose_logger.debug(
        f"Azure AI Agents completion - api_base: {api_base}, agent_id: {agent_id}"
    )
    return headers, api_version, agent_id, thread_id, api_base
def _check_response(
self, response: httpx.Response, expected_codes: List[int], error_msg: str
):
"""Check response status and raise error if not expected."""
if response.status_code not in expected_codes:
raise AzureAIAgentsError(
status_code=response.status_code,
message=f"{error_msg}: {response.text}",
)
# -------------------------------------------------------------------------
# Sync Completion
# -------------------------------------------------------------------------
def completion(
    self,
    model: str,
    messages: List[Dict[str, Any]],
    api_base: str,
    api_key: str,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: float,
    client: Optional[HTTPHandler] = None,
    headers: Optional[dict] = None,
) -> ModelResponse:
    """Execute synchronous completion using Azure Agent Service.

    Prepares auth headers and routing params, runs the blocking
    create-thread / add-messages / create-run / poll flow, then packages
    the assistant's reply into *model_response*.

    Args:
        model: Model string, e.g. "azure_ai/agents/<agent_id>".
        messages: OpenAI-style chat messages.
        api_base: Project endpoint base URL.
        api_key: Azure AD bearer token.
        model_response: Response object to populate and return.
        client: Optional pre-built sync HTTP client; one is created from
            litellm's pool settings when omitted.
        headers: Optional extra request headers.

    Returns:
        The populated ModelResponse (thread_id stashed in _hidden_params).

    NOTE(review): ``timeout`` and ``logging_obj`` are accepted for
    interface parity but currently unused in this method.
    """
    from litellm.llms.custom_httpx.http_handler import _get_httpx_client

    if client is None:
        client = _get_httpx_client(
            params={"ssl_verify": litellm_params.get("ssl_verify", None)}
        )
    (
        headers,
        api_version,
        agent_id,
        thread_id,
        api_base,
    ) = self._prepare_completion_params(
        model, api_base, api_key, optional_params, headers
    )

    def make_request(
        method: str, url: str, json_data: Optional[dict] = None
    ) -> httpx.Response:
        # Small adapter so the shared flow only deals in (method, url, body).
        if method == "GET":
            return client.get(url=url, headers=headers)
        return client.post(
            url=url,
            headers=headers,
            data=json.dumps(json_data) if json_data else None,
        )

    # Execute the agent flow (thread -> messages -> run -> poll -> fetch).
    thread_id, content = self._execute_agent_flow_sync(
        make_request=make_request,
        api_base=api_base,
        api_version=api_version,
        agent_id=agent_id,
        thread_id=thread_id,
        messages=messages,
        optional_params=optional_params,
    )
    return self._build_model_response(
        model, content, model_response, thread_id, messages
    )
def _execute_agent_flow_sync(
    self,
    make_request: Callable,
    api_base: str,
    api_version: str,
    agent_id: str,
    thread_id: Optional[str],
    messages: List[Dict[str, Any]],
    optional_params: dict,
) -> Tuple[str, str]:
    """Execute the agent flow synchronously. Returns (thread_id, content).

    Args:
        make_request: Callable(method, url, json_data) -> httpx.Response,
            bound to the caller's HTTP client and headers.
        thread_id: Existing thread to continue, or None to create one.

    Raises:
        AzureAIAgentsError: on any non-expected HTTP status, a failed/
            cancelled/expired run, or when polling exhausts its budget.
    """
    # Step 1: Create thread if not provided
    if not thread_id:
        verbose_logger.debug(
            f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
        )
        response = make_request(
            "POST", self._build_thread_url(api_base, api_version), {}
        )
        self._check_response(response, [200, 201], "Failed to create thread")
        thread_id = response.json()["id"]
        verbose_logger.debug(f"Created thread: {thread_id}")
    # At this point thread_id is guaranteed to be a string
    assert thread_id is not None
    # Step 2: Add messages to thread.
    # NOTE(review): both "user" and "system" messages are forwarded with
    # role "user" (the thread API takes user-style messages); assistant
    # messages from history are skipped entirely — confirm this is intended.
    for msg in messages:
        if msg.get("role") in ["user", "system"]:
            url = self._build_messages_url(api_base, thread_id, api_version)
            response = make_request(
                "POST", url, {"role": "user", "content": msg.get("content", "")}
            )
            self._check_response(response, [200, 201], "Failed to add message")
    # Step 3: Create run (the run executes the agent against the thread)
    run_payload = {"assistant_id": agent_id}
    if "instructions" in optional_params:
        run_payload["instructions"] = optional_params["instructions"]
    response = make_request(
        "POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
    )
    self._check_response(response, [200, 201], "Failed to create run")
    run_id = response.json()["id"]
    verbose_logger.debug(f"Created run: {run_id}")
    # Step 4: Poll for completion (bounded by MAX_POLL_ATTEMPTS * interval)
    status_url = self._build_run_status_url(
        api_base, thread_id, run_id, api_version
    )
    for _ in range(self.config.MAX_POLL_ATTEMPTS):
        response = make_request("GET", status_url)
        self._check_response(response, [200], "Failed to get run status")
        status = response.json().get("status")
        verbose_logger.debug(f"Run status: {status}")
        if status == "completed":
            break
        elif status in ["failed", "cancelled", "expired"]:
            error_msg = (
                response.json()
                .get("last_error", {})
                .get("message", "Unknown error")
            )
            raise AzureAIAgentsError(
                status_code=500, message=f"Run {status}: {error_msg}"
            )
        time.sleep(self.config.POLL_INTERVAL_SECONDS)
    else:
        # for/else: loop exhausted without hitting "completed" -> timeout
        raise AzureAIAgentsError(
            status_code=408, message="Run timed out waiting for completion"
        )
    # Step 5: Get messages and extract the assistant's reply text
    response = make_request(
        "GET", self._build_list_messages_url(api_base, thread_id, api_version)
    )
    self._check_response(response, [200], "Failed to get messages")
    content = self._extract_content_from_messages(response.json())
    return thread_id, content
# -------------------------------------------------------------------------
# Async Completion
# -------------------------------------------------------------------------
async def acompletion(
    self,
    model: str,
    messages: List[Dict[str, Any]],
    api_base: str,
    api_key: str,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: float,
    client: Optional[AsyncHTTPHandler] = None,
    headers: Optional[dict] = None,
) -> ModelResponse:
    """Execute asynchronous completion using Azure Agent Service.

    Async twin of :meth:`completion`: same flow, but every HTTP call is
    awaited and polling sleeps via asyncio instead of blocking.

    NOTE(review): ``timeout`` and ``logging_obj`` are accepted for
    interface parity but currently unused in this method.
    """
    import litellm
    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

    if client is None:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.AZURE_AI,
            params={"ssl_verify": litellm_params.get("ssl_verify", None)},
        )
    (
        headers,
        api_version,
        agent_id,
        thread_id,
        api_base,
    ) = self._prepare_completion_params(
        model, api_base, api_key, optional_params, headers
    )

    async def make_request(
        method: str, url: str, json_data: Optional[dict] = None
    ) -> httpx.Response:
        # Adapter binding the async client + headers for the shared flow.
        if method == "GET":
            return await client.get(url=url, headers=headers)
        return await client.post(
            url=url,
            headers=headers,
            data=json.dumps(json_data) if json_data else None,
        )

    # Execute the agent flow (thread -> messages -> run -> poll -> fetch).
    thread_id, content = await self._execute_agent_flow_async(
        make_request=make_request,
        api_base=api_base,
        api_version=api_version,
        agent_id=agent_id,
        thread_id=thread_id,
        messages=messages,
        optional_params=optional_params,
    )
    return self._build_model_response(
        model, content, model_response, thread_id, messages
    )
async def _execute_agent_flow_async(
    self,
    make_request: Callable,
    api_base: str,
    api_version: str,
    agent_id: str,
    thread_id: Optional[str],
    messages: List[Dict[str, Any]],
    optional_params: dict,
) -> Tuple[str, str]:
    """Execute the agent flow asynchronously. Returns (thread_id, content).

    Mirrors :meth:`_execute_agent_flow_sync` step for step; keep the two
    in sync when changing either. *make_request* here is an async
    callable(method, url, json_data) -> httpx.Response.

    Raises:
        AzureAIAgentsError: on unexpected HTTP status, failed/cancelled/
            expired runs, or polling timeout.
    """
    # Step 1: Create thread if not provided
    if not thread_id:
        verbose_logger.debug(
            f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
        )
        response = await make_request(
            "POST", self._build_thread_url(api_base, api_version), {}
        )
        self._check_response(response, [200, 201], "Failed to create thread")
        thread_id = response.json()["id"]
        verbose_logger.debug(f"Created thread: {thread_id}")
    # At this point thread_id is guaranteed to be a string
    assert thread_id is not None
    # Step 2: Add messages to thread.
    # NOTE(review): user and system messages are both posted with role
    # "user"; assistant history is skipped — same behavior as the sync flow.
    for msg in messages:
        if msg.get("role") in ["user", "system"]:
            url = self._build_messages_url(api_base, thread_id, api_version)
            response = await make_request(
                "POST", url, {"role": "user", "content": msg.get("content", "")}
            )
            self._check_response(response, [200, 201], "Failed to add message")
    # Step 3: Create run
    run_payload = {"assistant_id": agent_id}
    if "instructions" in optional_params:
        run_payload["instructions"] = optional_params["instructions"]
    response = await make_request(
        "POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
    )
    self._check_response(response, [200, 201], "Failed to create run")
    run_id = response.json()["id"]
    verbose_logger.debug(f"Created run: {run_id}")
    # Step 4: Poll for completion (bounded by MAX_POLL_ATTEMPTS * interval)
    status_url = self._build_run_status_url(
        api_base, thread_id, run_id, api_version
    )
    for _ in range(self.config.MAX_POLL_ATTEMPTS):
        response = await make_request("GET", status_url)
        self._check_response(response, [200], "Failed to get run status")
        status = response.json().get("status")
        verbose_logger.debug(f"Run status: {status}")
        if status == "completed":
            break
        elif status in ["failed", "cancelled", "expired"]:
            error_msg = (
                response.json()
                .get("last_error", {})
                .get("message", "Unknown error")
            )
            raise AzureAIAgentsError(
                status_code=500, message=f"Run {status}: {error_msg}"
            )
        await asyncio.sleep(self.config.POLL_INTERVAL_SECONDS)
    else:
        # for/else: loop exhausted without "completed" -> timeout
        raise AzureAIAgentsError(
            status_code=408, message="Run timed out waiting for completion"
        )
    # Step 5: Get messages and extract the assistant's reply text
    response = await make_request(
        "GET", self._build_list_messages_url(api_base, thread_id, api_version)
    )
    self._check_response(response, [200], "Failed to get messages")
    content = self._extract_content_from_messages(response.json())
    return thread_id, content
# -------------------------------------------------------------------------
# Streaming Completion (Native SSE)
# -------------------------------------------------------------------------
async def acompletion_stream(
    self,
    model: str,
    messages: List[Dict[str, Any]],
    api_base: str,
    api_key: str,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: float,
    headers: Optional[dict] = None,
) -> AsyncIterator:
    """Execute async streaming completion using Azure Agent Service with native SSE.

    Uses the create-thread-and-run endpoint with ``stream: true`` and
    yields OpenAI-style chunks parsed from the SSE response.

    NOTE(review): when an existing ``thread_id`` is supplied, the payload
    includes neither the thread id nor the converted messages — the new
    user messages appear to be dropped on that path. Confirm against the
    create-thread-and-run API contract.
    NOTE(review): ``timeout`` and ``logging_obj`` are currently unused.
    """
    import litellm
    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

    (
        headers,
        api_version,
        agent_id,
        thread_id,
        api_base,
    ) = self._prepare_completion_params(
        model, api_base, api_key, optional_params, headers
    )
    # Build payload for create-thread-and-run with streaming.
    # user/system messages are both forwarded with role "user".
    thread_messages = []
    for msg in messages:
        if msg.get("role") in ["user", "system"]:
            thread_messages.append(
                {"role": "user", "content": msg.get("content", "")}
            )
    payload: Dict[str, Any] = {
        "assistant_id": agent_id,
        "stream": True,
    }
    # Add thread with messages if we don't have an existing thread
    if not thread_id:
        payload["thread"] = {"messages": thread_messages}
    if "instructions" in optional_params:
        payload["instructions"] = optional_params["instructions"]
    url = self._build_create_thread_and_run_url(api_base, api_version)
    verbose_logger.debug(f"Azure AI Agents streaming - URL: {url}")
    # Use LiteLLM's async HTTP client for streaming
    client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.AZURE_AI,
        params={"ssl_verify": litellm_params.get("ssl_verify", None)},
    )
    response = await client.post(
        url=url,
        headers=headers,
        data=json.dumps(payload),
        stream=True,
    )
    if response.status_code not in [200, 201]:
        # Streamed responses must be read asynchronously before raising.
        error_text = await response.aread()
        raise AzureAIAgentsError(
            status_code=response.status_code,
            message=f"Streaming request failed: {error_text.decode()}",
        )
    # Re-yield parsed chunks so callers can `async for` over this method.
    async for chunk in self._process_sse_stream(response, model):
        yield chunk
async def _process_sse_stream(
    self,
    response: httpx.Response,
    model: str,
) -> AsyncIterator:
    """Process SSE stream and yield OpenAI-compatible streaming chunks.

    Parses ``event:`` / ``data:`` line pairs from the agent run stream:
    - thread.created        -> captures the new thread_id
    - thread.message.delta  -> yields text deltas as chat chunks
    - data: [DONE]          -> yields the final finish_reason="stop" chunk

    The captured thread_id is attached to each chunk via _hidden_params so
    callers can continue the conversation. No usage info is emitted.
    """
    from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices

    # Synthesized response identity shared by every chunk of this stream.
    response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
    created = int(time.time())
    thread_id = None
    current_event = None
    async for line in response.aiter_lines():
        line = line.strip()
        if line.startswith("event:"):
            # Remember the event type; it applies to the next data: line.
            current_event = line[6:].strip()
            continue
        if line.startswith("data:"):
            data_str = line[5:].strip()
            if data_str == "[DONE]":
                # Send final chunk with finish_reason and stop iterating.
                final_chunk = ModelResponseStream(
                    id=response_id,
                    created=created,
                    model=model,
                    object="chat.completion.chunk",
                    choices=[
                        StreamingChoices(
                            finish_reason="stop",
                            index=0,
                            delta=Delta(content=None),
                        )
                    ],
                )
                if thread_id:
                    final_chunk._hidden_params = {"thread_id": thread_id}
                yield final_chunk
                return
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                # Skip malformed/partial payloads rather than failing the stream.
                continue
            # Extract thread_id from thread.created event
            if current_event == "thread.created" and "id" in data:
                thread_id = data["id"]
                verbose_logger.debug(f"Stream created thread: {thread_id}")
            # Process message deltas - this is where the actual content comes
            if current_event == "thread.message.delta":
                delta_content = data.get("delta", {}).get("content", [])
                for content_item in delta_content:
                    if content_item.get("type") == "text":
                        text_value = content_item.get("text", {}).get("value", "")
                        if text_value:
                            chunk = ModelResponseStream(
                                id=response_id,
                                created=created,
                                model=model,
                                object="chat.completion.chunk",
                                choices=[
                                    StreamingChoices(
                                        finish_reason=None,
                                        index=0,
                                        delta=Delta(
                                            content=text_value, role="assistant"
                                        ),
                                    )
                                ],
                            )
                            if thread_id:
                                chunk._hidden_params = {"thread_id": thread_id}
                            yield chunk
# Module-level singleton used by the provider dispatch code; the handler
# holds no per-request state beyond its config object, so sharing is safe.
azure_ai_agents_handler = AzureAIAgentsHandler()

View File

@@ -0,0 +1,402 @@
"""
Transformation for Azure Foundry Agent Service API.
Azure Foundry Agent Service provides an Assistants-like API for running agents.
This follows the OpenAI Assistants pattern: create thread -> add messages -> create/poll run.
Model format: azure_ai/agents/<agent_id>
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
Authentication: Uses Azure AD Bearer tokens (not API keys)
Get token via: az account get-access-token --resource 'https://ai.azure.com'
The API uses these endpoints:
- POST /threads - Create a thread
- POST /threads/{thread_id}/messages - Add message to thread
- POST /threads/{thread_id}/runs - Create a run
- GET /threads/{thread_id}/runs/{run_id} - Poll run status
- GET /threads/{thread_id}/messages - List messages in thread
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
HTTPHandler = Any
AsyncHTTPHandler = Any
class AzureAIAgentsError(BaseLLMException):
    """Exception class for Azure AI Agent Service API errors."""

    # No extra behavior: the subclass exists so callers can catch
    # Azure-Agents-specific failures distinctly from other provider errors.
    pass
class AzureAIAgentsConfig(BaseConfig):
    """
    Configuration for Azure AI Agent Service API.
    Azure AI Agent Service is a fully managed service for building AI agents
    that can understand natural language and perform tasks.
    Model format: azure_ai/agents/<agent_id>
    The flow is:
    1. Create a thread
    2. Add user messages to the thread
    3. Create and poll a run
    4. Retrieve the assistant's response messages
    """

    # Default API version for Azure Foundry Agent Service
    # GA version: 2025-05-01, Preview: 2025-05-15-preview
    # See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    DEFAULT_API_VERSION = "2025-05-01"
    # Polling configuration: 60 attempts x 1.0s gives roughly a 60 second
    # ceiling on waiting for a run before the handler raises a 408.
    MAX_POLL_ATTEMPTS = 60
    POLL_INTERVAL_SECONDS = 1.0

    def __init__(self, **kwargs):
        # No extra state beyond BaseConfig; kwargs pass straight through.
        super().__init__(**kwargs)
@staticmethod
def is_azure_ai_agents_route(model: str) -> bool:
    """
    Return True when *model* addresses the Azure AI Agents route.
    Model format: azure_ai/agents/<agent_id>

    NOTE(review): this is a substring test, so any model containing
    "agents/" anywhere matches — confirm that is intended.
    """
    return model.find("agents/") != -1
@staticmethod
def get_agent_id_from_model(model: str) -> str:
    """
    Extract the agent ID from a model string.
    "azure_ai/agents/<agent_id>" or "agents/<agent_id>" -> "<agent_id>";
    strings without an "agents/" segment are returned unchanged.
    """
    # partition splits on the first occurrence, matching split(..., 1).
    _, separator, agent_id = model.partition("agents/")
    return agent_id if separator else model
def _get_openai_compatible_provider_info(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:
    """
    Resolve the Azure AI Agent Service base URL and key, falling back to
    the AZURE_AI_API_BASE / AZURE_AI_API_KEY environment secrets.

    Returns:
        Tuple of (api_base, api_key); either may still be None.
    """
    from litellm.secret_managers.main import get_secret_str

    resolved_base = api_base if api_base else get_secret_str("AZURE_AI_API_BASE")
    resolved_key = api_key if api_key else get_secret_str("AZURE_AI_API_KEY")
    return resolved_base, resolved_key
def get_supported_openai_params(self, model: str) -> List[str]:
    """
    Only ``stream`` is supported; the agent runtime owns every other
    generation setting, so no further OpenAI params are accepted.
    """
    supported: List[str] = ["stream"]
    return supported
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """
    Map OpenAI params to Azure Agents params.

    No translation is required: the agent runtime controls generation, so
    the optional params are passed back untouched.
    """
    mapped = optional_params
    return mapped
def _get_api_version(self, optional_params: dict) -> str:
    """Return the caller-supplied api_version if present, else the default."""
    if "api_version" in optional_params:
        return optional_params["api_version"]
    return self.DEFAULT_API_VERSION
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """
    Return the base project URL for Azure AI Agent Service.

    Per-operation paths are appended later by the handler:
    - /threads for creating threads
    - /threads/{thread_id}/messages for adding messages
    - /threads/{thread_id}/runs for creating runs
    This method only validates and normalises the base URL.

    Raises:
        ValueError: if no api_base was supplied.
    """
    if api_base is None:
        raise ValueError(
            "api_base is required for Azure AI Agents. Set it via AZURE_AI_API_BASE env var or api_base parameter."
        )
    # Strip any trailing slash so path joins stay clean.
    return api_base.rstrip("/")
def _get_agent_id(self, model: str, optional_params: dict) -> str:
    """
    Resolve the agent ID: an explicit ``agent_id``/``assistant_id`` in
    optional_params wins; otherwise it is parsed from the model string
    ("azure_ai/agents/<agent_id>", "agents/<agent_id>", or bare id).
    """
    for key in ("agent_id", "assistant_id"):
        candidate = optional_params.get(key)
        if candidate:
            return candidate
    # Fall back to extracting the id from the model name.
    return self.get_agent_id_from_model(model)
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Build the intermediate payload for the multi-step agent flow.

    The actual API calls happen in the custom handler; this only
    normalises message content to strings and collects routing params.
    """

    def _normalise(msg) -> Dict[str, Any]:
        # Flatten list-style content blocks into a single string and make
        # sure whatever remains really is a str.
        content = msg.get("content", "")
        if isinstance(content, list):
            content = convert_content_list_to_str(msg)
        if not isinstance(content, str):
            content = str(content)
        return {"role": msg.get("role", "user"), "content": content}

    payload: Dict[str, Any] = {
        "agent_id": self._get_agent_id(model, optional_params),
        "messages": [_normalise(m) for m in messages],
        "api_version": self._get_api_version(optional_params),
    }
    # Pass through thread_id (conversation continuity) and instructions
    # when the caller provided them.
    for passthrough in ("thread_id", "instructions"):
        if passthrough in optional_params:
            payload[passthrough] = optional_params[passthrough]
    verbose_logger.debug(f"Azure AI Agents request payload: {payload}")
    return payload
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Populate request headers for Azure Foundry Agents.

    Authentication is a Bearer Azure AD token (not an API key).
    Get token via: az account get-access-token --resource 'https://ai.azure.com'
    See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    """
    required = {"Content-Type": "application/json"}
    # api_key is expected to carry an Azure AD token.
    if api_key:
        required["Authorization"] = f"Bearer {api_key}"
    headers.update(required)
    return headers
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Wrap provider failures in the Azure-Agents-specific exception type."""
    error = AzureAIAgentsError(status_code=status_code, message=error_message)
    return error
def should_fake_stream(
    self,
    model: Optional[str],
    stream: Optional[bool],
    custom_llm_provider: Optional[str] = None,
) -> bool:
    """
    Always fake streaming on this config path: the run is polled to
    completion and the final response is replayed as a stream.

    NOTE(review): the handler also implements native SSE streaming via
    acompletion_stream — confirm this flag should stay unconditionally True.
    """
    return True
@property
def has_custom_stream_wrapper(self) -> bool:
    """No custom stream wrapper here — fake streaming is used instead."""
    return False

@property
def supports_stream_param_in_request_body(self) -> bool:
    """The agents request body carries no ``stream`` parameter."""
    return False
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    Transform the Azure Agents response to LiteLLM ModelResponse format.

    A no-op: the custom handler builds the ModelResponse itself, so this
    method simply hands back what it was given.
    """
    return model_response
@staticmethod
def completion(
    model: str,
    messages: List,
    api_base: str,
    api_key: Optional[str],
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: Union[float, int, Any],
    acompletion: bool,
    stream: Optional[bool] = False,
    headers: Optional[dict] = None,
) -> Any:
    """
    Dispatch method for Azure Foundry Agents completion.
    Routes to sync or async completion based on acompletion flag.
    Supports native streaming via SSE when stream=True and acompletion=True.
    Authentication: Uses Azure AD Bearer tokens.
    - Pass api_key directly as an Azure AD token
    - Or set up Azure AD credentials via environment variables for automatic token retrieval:
      - AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET (Service Principal)
    See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart

    Returns:
        A coroutine (async), an async generator (async + stream), or a
        populated ModelResponse (sync). Sync streaming is not supported.
    """
    from litellm.llms.azure.common_utils import get_azure_ad_token
    from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
    from litellm.types.router import GenericLiteLLMParams

    # If no api_key is provided, try to get Azure AD token
    if api_key is None:
        # Try to get Azure AD token using the existing Azure auth mechanisms.
        # This uses the scope for Azure AI (ai.azure.com) instead of the
        # cognitive-services scope, so the azure_scope override is required.
        azure_auth_params = dict(litellm_params) if litellm_params else {}
        azure_auth_params["azure_scope"] = "https://ai.azure.com/.default"
        api_key = get_azure_ad_token(GenericLiteLLMParams(**azure_auth_params))
    if api_key is None:
        raise ValueError(
            "api_key (Azure AD token) is required for Azure Foundry Agents. "
            "Either pass api_key directly, or set AZURE_TENANT_ID, AZURE_CLIENT_ID, "
            "and AZURE_CLIENT_SECRET environment variables for Service Principal auth. "
            "Manual token: az account get-access-token --resource 'https://ai.azure.com'"
        )
    if acompletion:
        if stream:
            # Native async streaming via SSE - return the async generator directly
            return azure_ai_agents_handler.acompletion_stream(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=api_key,
                logging_obj=logging_obj,
                optional_params=optional_params,
                litellm_params=litellm_params,
                timeout=timeout,
                headers=headers,
            )
        else:
            return azure_ai_agents_handler.acompletion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=api_key,
                model_response=model_response,
                logging_obj=logging_obj,
                optional_params=optional_params,
                litellm_params=litellm_params,
                timeout=timeout,
                headers=headers,
            )
    else:
        # Sync completion - streaming not supported for sync; the stream
        # flag is ignored on this path.
        return azure_ai_agents_handler.completion(
            model=model,
            messages=messages,
            api_base=api_base,
            api_key=api_key,
            model_response=model_response,
            logging_obj=logging_obj,
            optional_params=optional_params,
            litellm_params=litellm_params,
            timeout=timeout,
            headers=headers,
        )

View File

@@ -0,0 +1,16 @@
"""
Azure Anthropic provider - supports Claude models via Azure Foundry
"""
from .handler import AzureAnthropicChatCompletion
from .transformation import AzureAnthropicConfig
# Mandatory exports are always available; the messages transformation is an
# optional extra that only ships when its module (and dependencies) import.
__all__ = ["AzureAnthropicChatCompletion", "AzureAnthropicConfig"]

try:
    from .messages_transformation import AzureAnthropicMessagesConfig
except ImportError:
    pass
else:
    __all__.append("AzureAnthropicMessagesConfig")

View File

@@ -0,0 +1,19 @@
"""
Azure AI Anthropic CountTokens API implementation.
"""
from litellm.llms.azure_ai.anthropic.count_tokens.handler import (
AzureAIAnthropicCountTokensHandler,
)
from litellm.llms.azure_ai.anthropic.count_tokens.token_counter import (
AzureAIAnthropicTokenCounter,
)
from litellm.llms.azure_ai.anthropic.count_tokens.transformation import (
AzureAIAnthropicCountTokensConfig,
)
__all__ = [
"AzureAIAnthropicCountTokensHandler",
"AzureAIAnthropicCountTokensConfig",
"AzureAIAnthropicTokenCounter",
]

View File

@@ -0,0 +1,133 @@
"""
Azure AI Anthropic CountTokens API handler.
Uses httpx for HTTP requests with Azure authentication.
"""
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.anthropic.common_utils import AnthropicError
from litellm.llms.azure_ai.anthropic.count_tokens.transformation import (
AzureAIAnthropicCountTokensConfig,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
class AzureAIAnthropicCountTokensHandler(AzureAIAnthropicCountTokensConfig):
    """
    Handler for Azure AI Anthropic CountTokens API requests.

    Issues the HTTP call via LiteLLM's shared async httpx client, using the
    Azure-flavoured headers/endpoint provided by the base config class.
    """

    async def handle_count_tokens_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_key: str,
        api_base: str,
        litellm_params: Optional[Dict[str, Any]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using httpx with Azure authentication.

        Args:
            model: The model identifier (e.g., "claude-3-5-sonnet")
            messages: The messages to count tokens for
            api_key: The Azure AI API key
            api_base: The Azure AI API base URL
            litellm_params: Optional LiteLLM parameters
            timeout: Optional timeout for the request (defaults to litellm.request_timeout)
            tools: Optional tool definitions included in the token count
            system: Optional system prompt included in the token count

        Returns:
            Dictionary containing the Anthropic-compatible token count response.

        Raises:
            AnthropicError: If validation or the API request fails.
        """
        try:
            # Fail fast on malformed input before any network work.
            self.validate_request(model, messages)
            verbose_logger.debug(
                f"Processing Azure AI Anthropic CountTokens request for model: {model}"
            )

            # Build the Anthropic-format payload.
            payload = self.transform_request_to_count_tokens(
                model=model,
                messages=messages,
                tools=tools,
                system=system,
            )
            verbose_logger.debug(f"Transformed request: {payload}")

            url = self.get_count_tokens_endpoint(api_base)
            verbose_logger.debug(f"Making request to: {url}")

            request_headers = self.get_required_headers(
                api_key=api_key,
                litellm_params=litellm_params,
            )

            # Shared async client keyed to the azure_ai provider.
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.AZURE_AI
            )
            effective_timeout = (
                litellm.request_timeout if timeout is None else timeout
            )

            response = await client.post(
                url,
                headers=request_headers,
                json=payload,
                timeout=effective_timeout,
            )
            verbose_logger.debug(f"Response status: {response.status_code}")

            if response.status_code != 200:
                body_text = response.text
                verbose_logger.error(f"Azure AI Anthropic API error: {body_text}")
                raise AnthropicError(
                    status_code=response.status_code,
                    message=body_text,
                )

            parsed = response.json()
            verbose_logger.debug(f"Azure AI Anthropic response: {parsed}")
            # The response is already Anthropic-compatible; no transformation needed.
            return parsed
        except AnthropicError:
            # Already well-formed; propagate untouched.
            raise
        except httpx.HTTPStatusError as e:
            # Preserve the real status code from the HTTP layer.
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except Exception as e:
            # Anything else becomes a generic 500-style provider error.
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )

View File

@@ -0,0 +1,123 @@
"""
Azure AI Anthropic Token Counter implementation using the CountTokens API.
"""
import os
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_logger
from litellm.llms.azure_ai.anthropic.count_tokens.handler import (
AzureAIAnthropicCountTokensHandler,
)
from litellm.llms.base_llm.base_utils import BaseTokenCounter
from litellm.types.utils import LlmProviders, TokenCountResponse
# Global handler instance - reuse across all token counting requests
azure_ai_anthropic_count_tokens_handler = AzureAIAnthropicCountTokensHandler()


class AzureAIAnthropicTokenCounter(BaseTokenCounter):
    """Token counter implementation for Azure AI Anthropic provider using the CountTokens API."""

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Only engage the remote CountTokens API for the azure_ai provider.
        return custom_llm_provider == LlmProviders.AZURE_AI.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens using Azure AI Anthropic's CountTokens API.

        Args:
            model_to_use: The model identifier
            messages: The messages to count tokens for
            contents: Alternative content format (not used for Anthropic)
            deployment: Deployment configuration containing litellm_params
            request_model: The original request model name
            tools: Optional tool definitions forwarded to the API
            system: Optional system prompt forwarded to the API

        Returns:
            TokenCountResponse with token count; an error-flagged response on
            API failure; or None when input/credentials are missing.
        """
        # Imported lazily to avoid a module-load dependency on anthropic utils.
        from litellm.llms.anthropic.common_utils import AnthropicError

        if not messages:
            # Nothing to count (Anthropic path ignores `contents`).
            return None
        deployment = deployment or {}
        litellm_params = deployment.get("litellm_params", {})
        # Get Azure AI API key from deployment config or environment
        api_key = litellm_params.get("api_key")
        if not api_key:
            api_key = os.getenv("AZURE_AI_API_KEY")
        # Get API base from deployment config or environment
        api_base = litellm_params.get("api_base")
        if not api_base:
            api_base = os.getenv("AZURE_AI_API_BASE")
        if not api_key:
            verbose_logger.warning("No Azure AI API key found for token counting")
            return None
        if not api_base:
            verbose_logger.warning("No Azure AI API base found for token counting")
            return None
        try:
            # Delegate the HTTP call to the module-level shared handler.
            result = await azure_ai_anthropic_count_tokens_handler.handle_count_tokens_request(
                model=model_to_use,
                messages=messages,
                api_key=api_key,
                api_base=api_base,
                litellm_params=litellm_params,
                tools=tools,
                system=system,
            )
            if result is not None:
                return TokenCountResponse(
                    # Anthropic's count_tokens response carries `input_tokens`.
                    total_tokens=result.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type="azure_ai_anthropic_api",
                    original_response=result,
                )
        except AnthropicError as e:
            # API-level failure: surface status/message in an error response
            # instead of raising, so callers can fall back gracefully.
            verbose_logger.warning(
                f"Azure AI Anthropic CountTokens API error: status={e.status_code}, message={e.message}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="azure_ai_anthropic_api",
                error=True,
                error_message=e.message,
                status_code=e.status_code,
            )
        except Exception as e:
            # Unexpected failure: report as a generic 500-style error response.
            verbose_logger.warning(
                f"Error calling Azure AI Anthropic CountTokens API: {e}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="azure_ai_anthropic_api",
                error=True,
                error_message=str(e),
                status_code=500,
            )
        # Reached only if the handler returned None (not expected in practice).
        return None

View File

@@ -0,0 +1,90 @@
"""
Azure AI Anthropic CountTokens API transformation logic.
Extends the base Anthropic CountTokens transformation with Azure authentication.
"""
from typing import Any, Dict, Optional
from litellm.constants import ANTHROPIC_TOKEN_COUNTING_BETA_VERSION
from litellm.llms.anthropic.count_tokens.transformation import (
AnthropicCountTokensConfig,
)
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.types.router import GenericLiteLLMParams
class AzureAIAnthropicCountTokensConfig(AnthropicCountTokensConfig):
    """
    Configuration and transformation logic for Azure AI Anthropic CountTokens API.

    Extends AnthropicCountTokensConfig with Azure authentication.
    Azure AI Anthropic uses the same endpoint format but with Azure auth headers.
    """

    def get_required_headers(
        self,
        api_key: str,
        litellm_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, str]:
        """
        Get the required headers for the Azure AI Anthropic CountTokens API.

        Azure AI Anthropic uses Anthropic's native API format, which requires the
        x-api-key header for authentication (in addition to Azure's api-key header).

        Args:
            api_key: The Azure AI API key
            litellm_params: Optional LiteLLM parameters for additional auth config
                (never mutated; a defensive copy is taken)

        Returns:
            Dictionary of required headers with both x-api-key and Azure authentication
        """
        # Start with base headers including x-api-key for Anthropic API compatibility
        headers = {
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
            "anthropic-beta": ANTHROPIC_TOKEN_COUNTING_BETA_VERSION,
            "x-api-key": api_key,  # Azure AI Anthropic requires this header
        }
        # BUGFIX: operate on a copy so the caller's dict is never mutated when
        # we inject the api_key default (the previous code wrote into the
        # caller-supplied litellm_params in place).
        litellm_params = dict(litellm_params) if litellm_params else {}
        litellm_params.setdefault("api_key", api_key)
        litellm_params_obj = GenericLiteLLMParams(**litellm_params)
        # Get Azure auth headers (api-key or Authorization)
        azure_headers = BaseAzureLLM._base_validate_azure_environment(
            headers={}, litellm_params=litellm_params_obj
        )
        # Merge Azure auth headers (they take precedence over the base set).
        headers.update(azure_headers)
        return headers

    def get_count_tokens_endpoint(self, api_base: str) -> str:
        """
        Get the Azure AI Anthropic CountTokens API endpoint.

        Args:
            api_base: The Azure AI API base URL
                (e.g., https://my-resource.services.ai.azure.com or
                https://my-resource.services.ai.azure.com/anthropic)

        Returns:
            The endpoint URL for the CountTokens API
        """
        # Azure AI Anthropic endpoint format:
        # https://<resource>.services.ai.azure.com/anthropic/v1/messages/count_tokens
        api_base = api_base.rstrip("/")
        # Append /anthropic only when it is absent from the path entirely.
        # NOTE(review): a base like ".../anthropic/v1" would pass through
        # unchanged and yield a doubled /v1 — presumed out-of-contract input.
        if not api_base.endswith("/anthropic"):
            if "/anthropic" not in api_base:
                api_base = f"{api_base}/anthropic"
        # Add the count_tokens path
        return f"{api_base}/v1/messages/count_tokens"

View File

@@ -0,0 +1,226 @@
"""
Azure Anthropic handler - reuses AnthropicChatCompletion logic with Azure authentication
"""
import copy
import json
from typing import TYPE_CHECKING, Callable, Union
import httpx
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
)
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
from .transformation import AzureAnthropicConfig
if TYPE_CHECKING:
pass
class AzureAnthropicChatCompletion(AnthropicChatCompletion):
    """
    Azure Anthropic chat completion handler.

    Reuses all Anthropic logic but with Azure authentication.
    """

    def __init__(self) -> None:
        super().__init__()

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        acompletion=None,
        logger_fn=None,
        headers={},
        client=None,
    ):
        """
        Completion method that uses Azure authentication instead of Anthropic's x-api-key.

        All other logic is the same as AnthropicChatCompletion.

        Dispatches to one of four paths based on (acompletion, stream):
        async-stream, async, sync-stream, or plain sync POST.
        Returns a coroutine, a CustomStreamWrapper, or a ModelResponse
        accordingly.
        """
        # Deep-copy mutable inputs so the pops/mutations below never leak
        # back to the caller.
        optional_params = copy.deepcopy(optional_params)
        stream = optional_params.pop("stream", None)
        json_mode: bool = optional_params.pop("json_mode", False)
        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
        _is_function_call = False
        messages = copy.deepcopy(messages)
        # Use AzureAnthropicConfig for both azure_anthropic and azure_ai Claude models
        config = AzureAnthropicConfig()
        headers = config.validate_environment(
            api_key=api_key,
            headers=headers,
            model=model,
            messages=messages,
            optional_params={**optional_params, "is_vertex_request": is_vertex_request},
            litellm_params=litellm_params,
        )
        data = config.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": headers,
            },
        )
        print_verbose(f"_is_function_call: {_is_function_call}")
        if acompletion is True:
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                print_verbose("makes async azure anthropic streaming POST request")
                data["stream"] = stream
                # Async streaming path; only an AsyncHTTPHandler client may
                # be forwarded — anything else is dropped.
                return self.acompletion_stream_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    _is_function_call=_is_function_call,
                    json_mode=json_mode,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                    client=(
                        client
                        if client is not None and isinstance(client, AsyncHTTPHandler)
                        else None
                    ),
                )
            else:
                # Async non-streaming path.
                return self.acompletion_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    provider_config=config,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    _is_function_call=_is_function_call,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    client=client,
                    json_mode=json_mode,
                    timeout=timeout,
                )
        else:
            ## COMPLETION CALL
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                data["stream"] = stream
                # Import the make_sync_call from parent
                from litellm.llms.anthropic.chat.handler import make_sync_call

                completion_stream, response_headers = make_sync_call(
                    client=client,
                    api_base=api_base,
                    headers=headers,  # type: ignore
                    data=json.dumps(data),
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                    timeout=timeout,
                    json_mode=json_mode,
                )
                from litellm.llms.anthropic.common_utils import (
                    process_anthropic_headers,
                )

                return CustomStreamWrapper(
                    completion_stream=completion_stream,
                    model=model,
                    custom_llm_provider="azure_ai",
                    logging_obj=logging_obj,
                    _response_headers=process_anthropic_headers(response_headers),
                )
            else:
                # Plain sync POST; build a default client when none (or a
                # non-sync one) was supplied.
                if client is None or not isinstance(client, HTTPHandler):
                    from litellm.llms.custom_httpx.http_handler import _get_httpx_client

                    client = _get_httpx_client(params={"timeout": timeout})
                else:
                    client = client
                try:
                    response = client.post(
                        api_base,
                        headers=headers,
                        data=json.dumps(data),
                        timeout=timeout,
                    )
                except Exception as e:
                    from litellm.llms.anthropic.common_utils import AnthropicError

                    # Best-effort extraction of status/headers/body from
                    # whatever exception type the client raised.
                    status_code = getattr(e, "status_code", 500)
                    error_headers = getattr(e, "headers", None)
                    error_text = getattr(e, "text", str(e))
                    error_response = getattr(e, "response", None)
                    if error_headers is None and error_response:
                        error_headers = getattr(error_response, "headers", None)
                    if error_response and hasattr(error_response, "text"):
                        error_text = getattr(error_response, "text", error_text)
                    raise AnthropicError(
                        message=error_text,
                        status_code=status_code,
                        headers=error_headers,
                    )
                return config.transform_response(
                    model=model,
                    raw_response=response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    api_key=api_key,
                    request_data=data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    json_mode=json_mode,
                )

View File

@@ -0,0 +1,166 @@
"""
Azure Anthropic messages transformation config - extends AnthropicMessagesConfig with Azure authentication
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
AnthropicMessagesConfig,
)
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
pass
class AzureAnthropicMessagesConfig(AnthropicMessagesConfig):
    """
    Azure Anthropic messages configuration that extends AnthropicMessagesConfig.

    The only difference is authentication - Azure uses x-api-key header (not api-key)
    and Azure endpoint format.
    """

    def validate_anthropic_messages_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        """
        Validate environment and set up Azure authentication headers for /v1/messages endpoint.
        Azure Anthropic uses x-api-key header (not api-key).
        """
        # Convert dict to GenericLiteLLMParams if needed
        if isinstance(litellm_params, dict):
            if api_key and "api_key" not in litellm_params:
                litellm_params = {**litellm_params, "api_key": api_key}
            litellm_params_obj = GenericLiteLLMParams(**litellm_params)
        else:
            litellm_params_obj = litellm_params or GenericLiteLLMParams()
            if api_key and not litellm_params_obj.api_key:
                litellm_params_obj.api_key = api_key
        # Use Azure authentication logic
        headers = BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params_obj
        )
        # Azure Anthropic uses x-api-key header (not api-key)
        # Convert api-key to x-api-key if present
        if "api-key" in headers and "x-api-key" not in headers:
            headers["x-api-key"] = headers.pop("api-key")
        # Set anthropic-version header
        if "anthropic-version" not in headers:
            headers["anthropic-version"] = "2023-06-01"
        # Set content-type header
        if "content-type" not in headers:
            headers["content-type"] = "application/json"
        # Append beta-feature headers derived from optional_params.
        headers = self._update_headers_with_anthropic_beta(
            headers=headers,
            optional_params=optional_params,
        )
        return headers, api_base

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for Azure Anthropic /v1/messages endpoint.
        Azure Foundry endpoint format: https://<resource-name>.services.ai.azure.com/anthropic/v1/messages
        """
        from litellm.secret_managers.main import get_secret_str

        api_base = api_base or get_secret_str("AZURE_API_BASE")
        if api_base is None:
            raise ValueError(
                "Missing Azure API Base - Please set `api_base` or `AZURE_API_BASE` environment variable. "
                "Expected format: https://<resource-name>.services.ai.azure.com/anthropic"
            )
        # Ensure the URL ends with /v1/messages
        api_base = api_base.rstrip("/")
        if api_base.endswith("/v1/messages"):
            # Already correct
            pass
        elif api_base.endswith("/anthropic/v1/messages"):
            # Already correct
            pass
        else:
            # Check if /anthropic is already in the path
            if "/anthropic" in api_base:
                # /anthropic exists, ensure we end with /anthropic/v1/messages
                # Extract the base URL up to and including /anthropic
                # (anything after /anthropic is intentionally discarded)
                parts = api_base.split("/anthropic", 1)
                api_base = parts[0] + "/anthropic"
            else:
                # /anthropic not in path, add it
                api_base = api_base + "/anthropic"
            # Add /v1/messages
            api_base = api_base + "/v1/messages"
        return api_base

    def _remove_scope_from_cache_control(
        self, anthropic_messages_request: Dict
    ) -> None:
        """
        Remove `scope` field from cache_control for Azure AI Foundry.

        Azure AI Foundry's Anthropic endpoint does not support the `scope` field
        (e.g., "global" for cross-request caching). Only `type` and `ttl` are supported.
        Processes both `system` and `messages` content blocks.

        Mutates `anthropic_messages_request` in place.
        """

        def _sanitize(cache_control: Any) -> None:
            # Drop `scope` if present; leave other keys untouched.
            if isinstance(cache_control, dict):
                cache_control.pop("scope", None)

        def _process_content_list(content: list) -> None:
            for item in content:
                if isinstance(item, dict) and "cache_control" in item:
                    _sanitize(item["cache_control"])

        if "system" in anthropic_messages_request:
            system = anthropic_messages_request["system"]
            if isinstance(system, list):
                _process_content_list(system)
        if "messages" in anthropic_messages_request:
            for message in anthropic_messages_request["messages"]:
                if isinstance(message, dict) and "content" in message:
                    content = message["content"]
                    if isinstance(content, list):
                        _process_content_list(content)

    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        # Build the standard Anthropic request, then strip fields Azure AI
        # Foundry rejects (cache_control.scope).
        anthropic_messages_request = super().transform_anthropic_messages_request(
            model=model,
            messages=messages,
            anthropic_messages_optional_request_params=anthropic_messages_optional_request_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        self._remove_scope_from_cache_control(anthropic_messages_request)
        return anthropic_messages_request

View File

@@ -0,0 +1,117 @@
"""
Azure Anthropic transformation config - extends AnthropicConfig with Azure authentication
"""
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
pass
class AzureAnthropicConfig(AnthropicConfig):
    """
    Azure Anthropic configuration that extends AnthropicConfig.

    The only difference is authentication - Azure uses api-key header or Azure AD token
    instead of x-api-key header.
    """

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return "azure_ai"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: Union[dict, GenericLiteLLMParams],
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        """
        Validate environment and set up Azure authentication headers.

        Azure supports:
        1. API key via 'api-key' header
        2. Azure AD token via 'Authorization: Bearer <token>' header
        """
        # Convert dict to GenericLiteLLMParams if needed
        if isinstance(litellm_params, dict):
            # Ensure api_key is included if provided
            if api_key and "api_key" not in litellm_params:
                litellm_params = {**litellm_params, "api_key": api_key}
            litellm_params_obj = GenericLiteLLMParams(**litellm_params)
        else:
            litellm_params_obj = litellm_params or GenericLiteLLMParams()
            # Set api_key if provided and not already set
            if api_key and not litellm_params_obj.api_key:
                litellm_params_obj.api_key = api_key
        # Use Azure authentication logic
        headers = BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params_obj
        )
        # Get tools and other anthropic-specific setup
        tools = optional_params.get("tools")
        prompt_caching_set = self.is_cache_control_set(messages=messages)
        computer_tool_used = self.is_computer_tool_used(tools=tools)
        mcp_server_used = self.is_mcp_server_used(
            mcp_servers=optional_params.get("mcp_servers")
        )
        pdf_used = self.is_pdf_used(messages=messages)
        file_id_used = self.is_file_id_used(messages=messages)
        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
            anthropic_beta_header=headers.get("anthropic-beta")
        )
        # Get anthropic headers (but we'll replace x-api-key with Azure auth)
        anthropic_headers = self.get_anthropic_headers(
            computer_tool_used=computer_tool_used,
            prompt_caching_set=prompt_caching_set,
            pdf_used=pdf_used,
            api_key=api_key or "",  # Azure auth is already in headers
            file_id_used=file_id_used,
            is_vertex_request=optional_params.get("is_vertex_request", False),
            user_anthropic_beta_headers=user_anthropic_beta_headers,
            mcp_server_used=mcp_server_used,
        )
        # Merge headers - Azure auth (api-key or Authorization) takes precedence
        headers = {**anthropic_headers, **headers}
        # Ensure anthropic-version header is set
        if "anthropic-version" not in headers:
            headers["anthropic-version"] = "2023-06-01"
        return headers

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform request using parent AnthropicConfig, then remove unsupported params.

        Azure Anthropic doesn't support extra_body, max_retries, or stream_options parameters.
        """
        # Call parent transform_request
        data = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        # Remove unsupported parameters for Azure AI Anthropic
        data.pop("extra_body", None)
        data.pop("max_retries", None)
        data.pop("stream_options", None)
        return data

View File

@@ -0,0 +1,4 @@
"""Azure AI Foundry Model Router support."""
from .transformation import AzureModelRouterConfig
__all__ = ["AzureModelRouterConfig"]

View File

@@ -0,0 +1,119 @@
"""
Transformation for Azure AI Foundry Model Router.
The Model Router is a special Azure AI deployment that automatically routes requests
to the best available model. It has specific cost tracking requirements.
"""
from typing import Any, List, Optional
from httpx import Response
from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
class AzureModelRouterConfig(AzureAIStudioConfig):
    """
    Configuration for Azure AI Foundry Model Router.

    Handles:
    - Stripping model_router prefix before sending to Azure API
    - Preserving full model path in responses for cost tracking
    - Calculating flat infrastructure costs for Model Router
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform request for Model Router.

        Strips the model_router/ prefix so only the deployment name is sent to
        Azure (e.g. model_router/azure-model-router -> azure-model-router).
        """
        from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo

        # Routing prefixes (model_router/) are a LiteLLM concept; Azure only
        # understands the bare deployment name.
        deployment_name: str = AzureFoundryModelInfo.get_base_model(model)
        return super().transform_request(
            deployment_name, messages, optional_params, litellm_params, headers
        )

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform response for Model Router.

        Delegates to the parent with the prefix-stripped model name; the
        parent extracts the actual routed model (e.g. gpt-5-nano-2025-08-07)
        from the raw Azure response for display and cost tracking.
        """
        from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo

        deployment_name: str = AzureFoundryModelInfo.get_base_model(model)
        return super().transform_response(
            model=deployment_name,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

    def calculate_additional_costs(
        self, model: str, prompt_tokens: int, completion_tokens: int
    ) -> Optional[dict]:
        """
        Calculate additional costs for Azure Model Router.

        Adds a flat infrastructure cost of $0.14 per M input tokens for using
        the Model Router.

        Args:
            model: The model name (should be a model router model)
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens (unused here)

        Returns:
            Dictionary with additional costs, or None if not applicable.
        """
        from litellm.llms.azure_ai.cost_calculator import (
            calculate_azure_model_router_flat_cost,
        )

        flat_cost = calculate_azure_model_router_flat_cost(
            model=model, prompt_tokens=prompt_tokens
        )
        return {"Azure Model Router Flat Cost": flat_cost} if flat_cost > 0 else None

View File

@@ -0,0 +1,3 @@
"""
LLM Calling done in `openai/openai.py`
"""

View File

@@ -0,0 +1,350 @@
import enum
from typing import Any, List, Optional, Tuple, cast
from urllib.parse import urlparse
import httpx
from httpx import Response
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
_audio_or_image_in_message_content,
convert_content_list_to_str,
)
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
from litellm.llms.openai.openai import OpenAIConfig
from litellm.llms.xai.chat.transformation import XAIChatConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import ModelResponse, ProviderField
from litellm.utils import _add_path_to_api_base, supports_tool_choice
class AzureFoundryErrorStrings(str, enum.Enum):
    # Known substrings from Azure Foundry error responses; presumably matched
    # against error bodies elsewhere to drive retry/param-drop behavior — not
    # visible from here, confirm at call sites.
    SET_EXTRA_PARAMETERS_TO_PASS_THROUGH = "Set extra-parameters to 'pass-through'"
class AzureAIStudioConfig(OpenAIConfig):
def get_supported_openai_params(self, model: str) -> List:
model_supports_tool_choice = True # azure ai supports this by default
if not supports_tool_choice(model=f"azure_ai/{model}"):
model_supports_tool_choice = False
supported_params = super().get_supported_openai_params(model)
if not model_supports_tool_choice:
filtered_supported_params = []
for param in supported_params:
if param != "tool_choice":
filtered_supported_params.append(param)
supported_params = filtered_supported_params
# Filter out unsupported parameters for specific models
if not self._supports_stop_reason(model):
supported_params = [param for param in supported_params if param != "stop"]
return supported_params
def _supports_stop_reason(self, model: str) -> bool:
"""
Check if the model supports stop tokens.
"""
if "grok" in model:
# Reuse Xai method for Grok model
xai_config = XAIChatConfig()
return xai_config._supports_stop_reason(model)
return True
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key:
if api_base and self._should_use_api_key_header(api_base):
headers["api-key"] = api_key
else:
headers["Authorization"] = f"Bearer {api_key}"
else:
# No api_key provided — fall back to Azure AD token-based auth
litellm_params_obj = GenericLiteLLMParams(
**(litellm_params if isinstance(litellm_params, dict) else {})
)
headers = BaseAzureLLM._base_validate_azure_environment(
headers=headers, litellm_params=litellm_params_obj
)
headers["Content-Type"] = "application/json"
return headers
def _should_use_api_key_header(self, api_base: str) -> bool:
"""
Returns True if the request should use `api-key` header for authentication.
"""
parsed_url = urlparse(api_base)
host = parsed_url.hostname
if host and (
host.endswith(".services.ai.azure.com")
or host.endswith(".openai.azure.com")
):
return True
return False
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Constructs a complete URL for the API request.

        Args:
        - api_base: Base URL, e.g.,
            "https://litellm8397336933.services.ai.azure.com"
            OR
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        - model: Model name.
        - optional_params: Additional query parameters, including "api_version".
        - stream: If streaming is required (optional).

        Returns:
        - A complete URL string, e.g.,
        "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        """
        if api_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
            )
        original_url = httpx.URL(api_base)
        # Extract api_version or use default
        api_version = cast(Optional[str], litellm_params.get("api_version"))
        # Create a new dictionary with existing params
        query_params = dict(original_url.params)
        # Add api_version if needed (an api-version already present in
        # api_base wins over litellm_params).
        if "api-version" not in query_params and api_version:
            query_params["api-version"] = api_version
        # Add the path to the base URL. AI Services hosts use the
        # /models/chat/completions route; other hosts the bare route.
        if "services.ai.azure.com" in api_base:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/models/chat/completions"
            )
        else:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/chat/completions"
            )
        # Use the new query_params dictionary
        final_url = httpx.URL(new_url).copy_with(params=query_params)
        return str(final_url)
    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return it's required fields with a description"""
        # field_value holds an example/placeholder shown to the user, not a real value.
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Azure AI Studio API Key.",
                field_value="zEJ...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Azure AI Studio API Base.",
                field_value="https://Mistral-serverless.",
            ),
        ]
    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
    ) -> List:
        """
        - Azure AI Studio doesn't support content as a list. This handles:
            1. Transforms list content to a string.
            2. If message contains an image or audio, send as is (user-intended)

        NOTE: mutates the message dicts in place and returns the same list.
        """
        for message in messages:
            # Do nothing if the message contains an image or audio
            if _audio_or_image_in_message_content(message):
                continue
            # Flatten list-shaped content into a single string
            texts = convert_content_list_to_str(message=message)
            if texts:
                message["content"] = texts
        return messages
def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
try:
if "/" in model:
model = model.split("/", 1)[1]
if (
model in litellm.open_ai_chat_completion_models
or model in litellm.open_ai_text_completion_models
or model in litellm.open_ai_embedding_models
):
return True
except Exception:
return False
return False
def _get_openai_compatible_provider_info(
self,
model: str,
api_base: Optional[str],
api_key: Optional[str],
custom_llm_provider: str,
) -> Tuple[Optional[str], Optional[str], str]:
api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")
if self._is_azure_openai_model(model=model, api_base=api_base):
verbose_logger.debug(
"Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
model
)
)
custom_llm_provider = "azure"
return api_base, dynamic_api_key, custom_llm_provider
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
extra_body = optional_params.pop("extra_body", {})
if extra_body and isinstance(extra_body, dict):
optional_params.update(extra_body)
optional_params.pop("max_retries", None)
return super().transform_request(
model, messages, optional_params, litellm_params, headers
)
def transform_response(
self,
model: str,
raw_response: Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
model_response.model = f"azure_ai/{model}"
return super().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
)
def should_retry_llm_api_inside_llm_translation_on_http_error(
self, e: httpx.HTTPStatusError, litellm_params: dict
) -> bool:
should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
error_text = e.response.text
if should_drop_params and "Extra inputs are not permitted" in error_text:
return True
elif (
"unknown field: parameter index is not a valid field" in error_text
): # remove index from tool calls
return True
elif (
AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
in error_text
): # remove extra-parameters from tool calls
return True
return super().should_retry_llm_api_inside_llm_translation_on_http_error(
e=e, litellm_params=litellm_params
)
    @property
    def max_retry_on_unprocessable_entity_error(self) -> int:
        # Two retries so each known 422 cause (tool-call `index` fields and
        # disallowed extra parameters) can be repaired and retried once.
        return 2
    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        """
        Repair *request_data* after an HTTP 422 so the call can be retried.

        Handles two known Azure AI failure modes, keyed off the error text:
        - tool-call messages carrying an `index` field the API rejects
        - extra parameters rejected when extra-parameters pass-through is off
        Finally drops any remaining unsupported params via the generic helper.
        """
        _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
        if (
            "unknown field: parameter index is not a valid field" in e.response.text
            and _messages is not None
        ):
            # Mutates the message list in place
            litellm.remove_index_from_tool_calls(
                messages=_messages,
            )
        elif (
            AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
            in e.response.text
        ):
            request_data = self._drop_extra_params_from_request_data(
                request_data, e.response.text
            )
        data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
        return data
def _drop_extra_params_from_request_data(
self, request_data: dict, error_text: str
) -> dict:
params_to_drop = self._extract_params_to_drop_from_error_text(error_text)
if params_to_drop:
for param in params_to_drop:
if param in request_data:
request_data.pop(param, None)
return request_data
def _extract_params_to_drop_from_error_text(
self, error_text: str
) -> Optional[List[str]]:
"""
Error text looks like this"
"Extra parameters ['stream_options', 'extra-parameters'] are not allowed when extra-parameters is not set or set to be 'error'.
"""
import re
# Extract parameters within square brackets
match = re.search(r"\[(.*?)\]", error_text)
if not match:
return []
# Parse the extracted string into a list of parameter names
params_str = match.group(1)
params = []
for param in params_str.split(","):
# Clean up the parameter name (remove quotes, spaces)
clean_param = param.strip().strip("'").strip('"')
if clean_param:
params.append(clean_param)
return params

View File

@@ -0,0 +1,176 @@
from typing import List, Literal, Optional
import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo, BaseTokenCounter
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
class AzureFoundryModelInfo(BaseLLMModelInfo):
    """Model info for Azure AI / Azure Foundry models."""

    def __init__(self, model: Optional[str] = None):
        # Model name is optional; it is only needed by model-specific helpers
        # such as get_token_counter().
        self._model = model

    @staticmethod
    def get_azure_ai_route(model: str) -> Literal["agents", "model_router", "default"]:
        """
        Get the Azure AI route for the given model.
        Similar to BedrockModelInfo.get_bedrock_route().

        Supported routes:
        - agents: azure_ai/agents/<agent_id>
        - model_router: azure_ai/model_router/<actual-model-name> or models with "model-router"/"model_router" in name
        - default: standard models
        """
        if "agents/" in model:
            return "agents"
        # Detect model router by prefix (model_router/<name>) or by name
        # containing "model-router"/"model_router". The plain substring checks
        # already cover the prefixed forms, so two checks suffice (the former
        # extra `model_router/` / `model-router/` checks were unreachable).
        model_lower = model.lower()
        if "model-router" in model_lower or "model_router" in model_lower:
            return "model_router"
        return "default"

    @staticmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        """Resolve API base: explicit arg > litellm.api_base > AZURE_AI_API_BASE env."""
        return api_base or litellm.api_base or get_secret_str("AZURE_AI_API_BASE")

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        """Resolve API key: explicit arg > litellm keys > AZURE_AI_API_KEY env."""
        return (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("AZURE_AI_API_KEY")
        )

    @property
    def api_version(self) -> Optional[str]:
        """Resolve the API version: litellm.api_version > AZURE_API_VERSION env.

        Fix: a ``@property`` getter only ever receives ``self``; the previous
        extra ``api_version`` parameter could never be supplied by callers and
        was therefore always ``None``. It has been removed.
        """
        return litellm.api_version or get_secret_str("AZURE_API_VERSION")

    def get_token_counter(self) -> Optional[BaseTokenCounter]:
        """
        Factory method to create a token counter for Azure AI.

        Returns:
            AzureAIAnthropicTokenCounter for Claude models, None otherwise.
        """
        # Only return a token counter for Claude models
        if self._model and "claude" in self._model.lower():
            from litellm.llms.azure_ai.anthropic.count_tokens.token_counter import (
                AzureAIAnthropicTokenCounter,
            )

            return AzureAIAnthropicTokenCounter()
        return None

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Returns a list of models supported by Azure AI.
        Azure AI doesn't have a standard model listing endpoint,
        so this returns an empty list.
        """
        return []

    #########################################################
    # Not implemented methods
    #########################################################

    @staticmethod
    def strip_model_router_prefix(model: str) -> str:
        """
        Strip the model_router prefix from model name.

        Examples:
        - "model_router/gpt-4o" -> "gpt-4o"
        - "model-router/gpt-4o" -> "gpt-4o"
        - "gpt-4o" -> "gpt-4o"

        Args:
            model: Model name potentially with model_router prefix

        Returns:
            Model name without the prefix
        """
        if "model_router/" in model:
            return model.split("model_router/", 1)[1]
        if "model-router/" in model:
            return model.split("model-router/", 1)[1]
        return model

    @staticmethod
    def get_base_model(model: str) -> str:
        """
        Get the base model name, stripping any Azure AI routing prefixes.

        Args:
            model: Model name potentially with routing prefixes

        Returns:
            Base model name
        """
        # Strip model_router prefix if present
        return AzureFoundryModelInfo.strip_model_router_prefix(model)

    @staticmethod
    def get_azure_ai_config_for_model(model: str):
        """
        Get the appropriate Azure AI config class for the given model.

        Routes to specialized configs based on model type:
        - Model Router: AzureModelRouterConfig
        - Claude models: AzureAnthropicConfig
        - Default: AzureAIStudioConfig

        Args:
            model: The model name

        Returns:
            The appropriate config instance
        """
        azure_ai_route = AzureFoundryModelInfo.get_azure_ai_route(model)
        if azure_ai_route == "model_router":
            from litellm.llms.azure_ai.azure_model_router.transformation import (
                AzureModelRouterConfig,
            )

            return AzureModelRouterConfig()
        elif "claude" in model.lower():
            from litellm.llms.azure_ai.anthropic.transformation import (
                AzureAnthropicConfig,
            )

            return AzureAnthropicConfig()
        else:
            from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig

            return AzureAIStudioConfig()

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Azure Foundry sends api key in query params"""
        # Not supported for this provider; callers must not rely on it.
        raise NotImplementedError(
            "Azure Foundry does not support environment validation"
        )

View File

@@ -0,0 +1,135 @@
"""
Azure AI cost calculation helper.
Handles Azure AI Foundry Model Router flat cost and other Azure AI specific pricing.
"""
from typing import Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import Usage
from litellm.utils import get_model_info
def _is_azure_model_router(model: str) -> bool:
"""
Check if the model is Azure AI Foundry Model Router.
Detects patterns like:
- "azure-model-router"
- "model-router"
- "model_router/<actual-model>"
- "model-router/<actual-model>"
Args:
model: The model name
Returns:
bool: True if this is a model router model
"""
model_lower = model.lower()
return (
"model-router" in model_lower
or "model_router" in model_lower
or model_lower == "azure-model-router"
)
def calculate_azure_model_router_flat_cost(model: str, prompt_tokens: int) -> float:
    """
    Compute the Model Router routing surcharge for *prompt_tokens*.

    Args:
        model: The model name (should be a model router model)
        prompt_tokens: Number of prompt tokens

    Returns:
        float: The flat cost in USD, or 0.0 when the model is not a router
        model or no router pricing is configured.
    """
    if not _is_azure_model_router(model):
        return 0.0
    # Router pricing lives under the generic "model_router" key in
    # model_prices_and_context_window.json (no actual-model-name suffix).
    router_pricing = get_model_info(model="model_router", custom_llm_provider="azure_ai")
    per_token_cost = router_pricing.get("input_cost_per_token", 0)
    return prompt_tokens * per_token_cost if per_token_cost > 0 else 0.0
def cost_per_token(
    model: str,
    usage: Usage,
    response_time_ms: Optional[float] = 0.0,
    request_model: Optional[str] = None,
) -> Tuple[float, float]:
    """
    Calculate the cost per token for Azure AI models.

    For Azure AI Foundry Model Router:
    - Adds a flat cost of $0.14 per million input tokens (from model_prices_and_context_window.json)
    - Plus the cost of the actual model used (handled by generic_cost_per_token)

    Args:
        model: str, the model name without provider prefix (from response)
        usage: LiteLLM Usage block
        response_time_ms: Optional response time in milliseconds (currently unused)
        request_model: Optional[str], the original request model name (to detect router usage)

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd

    Raises:
        ValueError: If the model is not found in the cost map and cost cannot be calculated
            (except for Model Router models where we return just the routing flat cost)
    """
    prompt_cost = 0.0
    completion_cost = 0.0
    # Determine if this was a model router request
    # Check both the response model and the request model
    is_router_request = _is_azure_model_router(model) or (
        request_model is not None and _is_azure_model_router(request_model)
    )
    # Calculate base cost using generic cost calculator
    # This may raise an exception if the model is not in the cost map
    try:
        prompt_cost, completion_cost = generic_cost_per_token(
            model=model,
            usage=usage,
            custom_llm_provider="azure_ai",
        )
    except Exception as e:
        # For Model Router, the model name (e.g., "azure-model-router") may not be in the cost map
        # because it's a routing service, not an actual model. In this case, we continue
        # to calculate just the routing flat cost.
        if not _is_azure_model_router(model):
            # Re-raise for non-router models - they should have pricing defined
            raise
        verbose_logger.debug(
            f"Azure AI Model Router: model '{model}' not in cost map, calculating routing flat cost only. Error: {e}"
        )
    # Add flat cost for Azure Model Router
    # The flat cost is defined in model_prices_and_context_window.json for azure_ai/model_router
    if is_router_request:
        # Use the request model for flat cost calculation if available, otherwise use response model
        router_model_for_calc = request_model if request_model else model
        router_flat_cost = calculate_azure_model_router_flat_cost(
            router_model_for_calc, usage.prompt_tokens
        )
        # router_flat_cost > 0 implies usage.prompt_tokens > 0, so the
        # per-token division in the debug message below cannot divide by zero.
        if router_flat_cost > 0:
            verbose_logger.debug(
                f"Azure AI Model Router flat cost: ${router_flat_cost:.6f} "
                f"({usage.prompt_tokens} tokens × ${router_flat_cost / usage.prompt_tokens:.9f}/token)"
            )
            # Add flat cost to prompt cost
            prompt_cost += router_flat_cost
    return prompt_cost, completion_cost

View File

@@ -0,0 +1 @@
from .handler import AzureAIEmbedding

View File

@@ -0,0 +1,98 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
Why separate file? Make it easy to see how transformation works
Converts
- Cohere request format
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""
from typing import List, Optional, Tuple
from litellm.types.llms.azure_ai import ImageEmbeddingInput, ImageEmbeddingRequest
from litellm.types.llms.openai import EmbeddingCreateParams
from litellm.types.utils import EmbeddingResponse, Usage
from litellm.utils import is_base64_encoded
class AzureAICohereConfig:
    """Transforms OpenAI /v1/embeddings requests into Azure AI Cohere's
    format: base64 image inputs are routed to the image-embedding request,
    plain text to the standard /v1/embeddings request, and provider response
    headers are folded back into usage/model fields."""

    def __init__(self) -> None:
        pass

    def _map_azure_model_group(self, model: str) -> str:
        # Map Azure marketplace offer names back to public Cohere model names.
        if model == "offer-cohere-embed-multili-paygo":
            return "Cohere-embed-v3-multilingual"
        elif model == "offer-cohere-embed-english-paygo":
            return "Cohere-embed-v3-english"
        return model

    def _transform_request_image_embeddings(
        self, input: List[str], optional_params: dict
    ) -> ImageEmbeddingRequest:
        """
        Assume all str in list is base64 encoded string
        """
        image_input: List[ImageEmbeddingInput] = []
        for i in input:
            embedding_input = ImageEmbeddingInput(image=i)
            image_input.append(embedding_input)
        return ImageEmbeddingRequest(input=image_input, **optional_params)

    def _transform_request(
        self, input: List[str], optional_params: dict, model: str
    ) -> Tuple[ImageEmbeddingRequest, EmbeddingCreateParams, List[int]]:
        """
        Return the list of input to `/image/embeddings`, `/v1/embeddings`, list of image_embedding_idx for recombination
        """
        image_embeddings: List[str] = []
        image_embedding_idx: List[int] = []
        for idx, i in enumerate(input):
            """
            - is base64 -> route to image embeddings
            - is ImageEmbeddingInput -> route to image embeddings
            - else -> route to `/v1/embeddings`
            """
            if is_base64_encoded(i):
                image_embeddings.append(i)
                image_embedding_idx.append(idx)
        ## REMOVE IMAGE EMBEDDINGS FROM input list
        filtered_input = [
            item for idx, item in enumerate(input) if idx not in image_embedding_idx
        ]
        v1_embeddings_request = EmbeddingCreateParams(
            input=filtered_input, model=model, **optional_params
        )
        # NOTE(review): optional_params are forwarded unchanged to BOTH the
        # image and the text payloads — confirm both endpoints accept them.
        image_embeddings_request = self._transform_request_image_embeddings(
            input=image_embeddings, optional_params=optional_params
        )
        return image_embeddings_request, v1_embeddings_request, image_embedding_idx

    def _transform_response(self, response: EmbeddingResponse) -> EmbeddingResponse:
        """Patch usage and model name from provider response headers (mutates
        and returns the same response object)."""
        additional_headers: Optional[dict] = response._hidden_params.get(
            "additional_headers"
        )
        if additional_headers:
            # CALCULATE USAGE — provider reports the token count in a header
            input_tokens: Optional[str] = additional_headers.get(
                "llm_provider-num_tokens"
            )
            if input_tokens:
                if response.usage:
                    response.usage.prompt_tokens = int(input_tokens)
                else:
                    response.usage = Usage(prompt_tokens=int(input_tokens))
            # SET MODEL — translate the Azure model-group header to a public name
            base_model: Optional[str] = additional_headers.get(
                "llm_provider-azureml-model-group"
            )
            if base_model:
                response.model = self._map_azure_model_group(base_model)
        return response

View File

@@ -0,0 +1,292 @@
from typing import List, Optional, Union
from openai import OpenAI
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.openai.openai import OpenAIChatCompletion
from litellm.types.llms.azure_ai import ImageEmbeddingRequest
from litellm.types.utils import EmbeddingResponse
from litellm.utils import convert_to_model_response_object
from .cohere_transformation import AzureAICohereConfig
class AzureAIEmbedding(OpenAIChatCompletion):
    """Azure AI embedding handler.

    Splits mixed inputs: base64 image strings go to the provider's
    `/images/embeddings` endpoint, text goes to the OpenAI-compatible
    `/v1/embeddings` route, and the results are re-assembled in the
    original input order."""

    def _process_response(
        self,
        image_embedding_responses: Optional[List],
        text_embedding_responses: Optional[List],
        image_embeddings_idx: List[int],
        model_response: EmbeddingResponse,
        input: List,
    ):
        """Re-interleave image and text embedding results back into the
        original input order, then apply header-based response fix-ups."""
        combined_responses = []
        if (
            image_embedding_responses is not None
            and text_embedding_responses is not None
        ):
            # Combine and order the results
            text_idx = 0
            image_idx = 0
            for idx in range(len(input)):
                if idx in image_embeddings_idx:
                    combined_responses.append(image_embedding_responses[image_idx])
                    image_idx += 1
                else:
                    combined_responses.append(text_embedding_responses[text_idx])
                    text_idx += 1
            model_response.data = combined_responses
        elif image_embedding_responses is not None:
            model_response.data = image_embedding_responses
        elif text_embedding_responses is not None:
            model_response.data = text_embedding_responses
        response = AzureAICohereConfig()._transform_response(response=model_response)  # type: ignore
        return response

    async def async_image_embedding(
        self,
        model: str,
        data: ImageEmbeddingRequest,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> EmbeddingResponse:
        """POST the image-embedding payload to {api_base}/images/embeddings (async)."""
        if client is None or not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.AZURE_AI,
                params={"timeout": timeout},
            )
        url = "{}/images/embeddings".format(api_base)
        response = await client.post(
            url=url,
            json=data,  # type: ignore
            headers={"Authorization": "Bearer {}".format(api_key)},
        )
        embedding_response = response.json()
        embedding_headers = dict(response.headers)
        returned_response: EmbeddingResponse = convert_to_model_response_object(  # type: ignore
            response_object=embedding_response,
            model_response_object=model_response,
            response_type="embedding",
            stream=False,
            _response_headers=embedding_headers,
        )
        return returned_response

    def image_embedding(
        self,
        model: str,
        data: ImageEmbeddingRequest,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ):
        """POST the image-embedding payload to {api_base}/images/embeddings (sync)."""
        if api_base is None:
            raise ValueError(
                "api_base is None. Please set AZURE_AI_API_BASE or dynamically via `api_base` param, to make the request."
            )
        if api_key is None:
            raise ValueError(
                "api_key is None. Please set AZURE_AI_API_KEY or dynamically via `api_key` param, to make the request."
            )
        if client is None or not isinstance(client, HTTPHandler):
            client = HTTPHandler(timeout=timeout, concurrent_limit=1)
        url = "{}/images/embeddings".format(api_base)
        response = client.post(
            url=url,
            json=data,  # type: ignore
            headers={"Authorization": "Bearer {}".format(api_key)},
        )
        embedding_response = response.json()
        embedding_headers = dict(response.headers)
        returned_response: EmbeddingResponse = convert_to_model_response_object(  # type: ignore
            response_object=embedding_response,
            model_response_object=model_response,
            response_type="embedding",
            stream=False,
            _response_headers=embedding_headers,
        )
        return returned_response

    async def async_embedding(
        self,
        model: str,
        input: List,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
    ) -> EmbeddingResponse:
        """Async fan-out: image inputs -> /images/embeddings, text -> /v1/embeddings,
        then recombine results in input order."""
        (
            image_embeddings_request,
            v1_embeddings_request,
            image_embeddings_idx,
        ) = AzureAICohereConfig()._transform_request(
            input=input, optional_params=optional_params, model=model
        )
        image_embedding_responses: Optional[List] = None
        text_embedding_responses: Optional[List] = None
        if image_embeddings_request["input"]:
            image_response = await self.async_image_embedding(
                model=model,
                data=image_embeddings_request,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
            )
            image_embedding_responses = image_response.data
            if image_embedding_responses is None:
                raise Exception("/image/embeddings route returned None Embeddings.")
        if v1_embeddings_request["input"]:
            # NOTE(review): the FULL input (including any base64 images) is
            # forwarded here rather than v1_embeddings_request["input"] —
            # confirm whether the filtered text-only list should be sent.
            response: EmbeddingResponse = await super().embedding(  # type: ignore
                model=model,
                input=input,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
                aembedding=True,
            )
            text_embedding_responses = response.data
            if text_embedding_responses is None:
                raise Exception("/v1/embeddings route returned None Embeddings.")
        return self._process_response(
            image_embedding_responses=image_embedding_responses,
            text_embedding_responses=text_embedding_responses,
            image_embeddings_idx=image_embeddings_idx,
            model_response=model_response,
            input=input,
        )

    def embedding(
        self,
        model: str,
        input: List,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        aembedding=None,
        max_retries: Optional[int] = None,
        shared_session=None,
    ) -> EmbeddingResponse:
        """
        - Separate image url from text
        -> route image url call to `/image/embeddings`
        -> route text call to `/v1/embeddings` (OpenAI route)

        assemble result in-order, and return

        NOTE: when aembedding=True this returns the coroutine from
        async_embedding (despite the EmbeddingResponse annotation); the
        caller is expected to await it.
        """
        if aembedding is True:
            return self.async_embedding(  # type: ignore
                model,
                input,
                timeout,
                logging_obj,
                model_response,
                optional_params,
                api_key,
                api_base,
                client,
            )
        (
            image_embeddings_request,
            v1_embeddings_request,
            image_embeddings_idx,
        ) = AzureAICohereConfig()._transform_request(
            input=input, optional_params=optional_params, model=model
        )
        image_embedding_responses: Optional[List] = None
        text_embedding_responses: Optional[List] = None
        if image_embeddings_request["input"]:
            image_response = self.image_embedding(
                model=model,
                data=image_embeddings_request,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
            )
            image_embedding_responses = image_response.data
            if image_embedding_responses is None:
                raise Exception("/image/embeddings route returned None Embeddings.")
        if v1_embeddings_request["input"]:
            # NOTE(review): full input forwarded here as well — see
            # async_embedding. Only OpenAI clients are passed through.
            response: EmbeddingResponse = super().embedding(  # type: ignore
                model,
                input,
                timeout,
                logging_obj,
                model_response,
                optional_params,
                api_key,
                api_base,
                client=(
                    client
                    if client is not None and isinstance(client, OpenAI)
                    else None
                ),
                aembedding=aembedding,
                shared_session=shared_session,
            )
            text_embedding_responses = response.data
            if text_embedding_responses is None:
                raise Exception("/v1/embeddings route returned None Embeddings.")
        return self._process_response(
            image_embedding_responses=image_embedding_responses,
            text_embedding_responses=text_embedding_responses,
            image_embeddings_idx=image_embeddings_idx,
            model_response=model_response,
            input=input,
        )

View File

@@ -0,0 +1,28 @@
from litellm.llms.azure_ai.image_generation.flux_transformation import (
AzureFoundryFluxImageGenerationConfig,
)
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from .flux2_transformation import AzureFoundryFlux2ImageEditConfig
from .transformation import AzureFoundryFluxImageEditConfig
__all__ = ["AzureFoundryFluxImageEditConfig", "AzureFoundryFlux2ImageEditConfig"]
def get_azure_ai_image_edit_config(model: str) -> BaseImageEditConfig:
    """
    Pick the image edit config for an Azure AI model.

    - FLUX 2 models use JSON with base64 image
    - FLUX 1 models use multipart/form-data
    """
    # FLUX 2 first — it has its own JSON-based request format
    if AzureFoundryFluxImageGenerationConfig.is_flux2_model(model):
        return AzureFoundryFlux2ImageEditConfig()
    # Normalize away separators; an empty model name defaults to FLUX 1
    normalized = model.lower().replace("-", "").replace("_", "")
    if not normalized or "flux" in normalized:
        return AzureFoundryFluxImageEditConfig()
    raise ValueError(f"Model {model} is not supported for Azure AI image editing.")

View File

@@ -0,0 +1,172 @@
import base64
from io import BufferedReader
from typing import Any, Dict, Optional, Tuple
from httpx._types import RequestFiles
import litellm
from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo
from litellm.llms.azure_ai.image_generation.flux_transformation import (
AzureFoundryFluxImageGenerationConfig,
)
from litellm.llms.openai.image_edit.transformation import OpenAIImageEditConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.llms.openai import FileTypes
from litellm.types.router import GenericLiteLLMParams
class AzureFoundryFlux2ImageEditConfig(OpenAIImageEditConfig):
    """
    Azure AI Foundry FLUX 2 image edit config

    Supports FLUX 2 models (e.g., flux.2-pro) for image editing.
    Uses the same /providers/blackforestlabs/v1/flux-2-pro endpoint as image generation,
    with the image passed as base64 in JSON body.
    """

    def get_supported_openai_params(self, model: str) -> list:
        """
        FLUX 2 supports a subset of OpenAI image edit params
        """
        return [
            "prompt",
            "image",
            "model",
            "n",
            "size",
        ]

    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """
        Map OpenAI params to FLUX 2 params.
        FLUX 2 uses the same param names as OpenAI for supported params.

        NOTE: unsupported params are always dropped here regardless of
        the *drop_params* flag.
        """
        mapped_params: Dict[str, Any] = {}
        supported_params = self.get_supported_openai_params(model)
        for key, value in dict(image_edit_optional_params).items():
            if key in supported_params and value is not None:
                mapped_params[key] = value
        return mapped_params

    def use_multipart_form_data(self) -> bool:
        """FLUX 2 uses JSON requests, not multipart/form-data."""
        return False

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Validate Azure AI Foundry environment and set up authentication.

        Uses the Azure "Api-Key" header (not Bearer auth) and forces a JSON
        content type since FLUX 2 requests are JSON bodies.
        """
        api_key = AzureFoundryModelInfo.get_api_key(api_key)
        if not api_key:
            raise ValueError(
                f"Azure AI API key is required for model {model}. Set AZURE_AI_API_KEY environment variable or pass api_key parameter."
            )
        headers.update(
            {
                "Api-Key": api_key,
                "Content-Type": "application/json",
            }
        )
        return headers

    def transform_image_edit_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles]:
        """
        Transform image edit request for FLUX 2.

        FLUX 2 uses the same endpoint for generation and editing,
        with the image passed as base64 in the JSON body.

        Returns the JSON body plus an empty files list (no multipart upload).
        """
        if prompt is None:
            raise ValueError("FLUX 2 image edit requires a prompt.")
        if image is None:
            raise ValueError("FLUX 2 image edit requires an image.")
        image_b64 = self._convert_image_to_base64(image)
        # Build request body with required params
        request_body: Dict[str, Any] = {
            "prompt": prompt,
            "image": image_b64,
            "model": model,
        }
        # Add mapped optional params (already filtered by map_openai_params)
        request_body.update(image_edit_optional_request_params)
        # Return JSON body and empty files list (FLUX 2 doesn't use multipart)
        return request_body, []

    def _convert_image_to_base64(self, image: Any) -> str:
        """Convert image file to base64 string"""
        # Handle list of images (take first one — FLUX 2 edits one image)
        if isinstance(image, list):
            if len(image) == 0:
                raise ValueError("Empty image list provided")
            image = image[0]
        if isinstance(image, BufferedReader):
            image_bytes = image.read()
            image.seek(0)  # Reset file pointer for potential reuse
        elif isinstance(image, bytes):
            image_bytes = image
        elif hasattr(image, "read"):
            # Any other file-like object; NOTE: its position is not reset
            image_bytes = image.read()  # type: ignore
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")
        return base64.b64encode(image_bytes).decode("utf-8")

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Constructs a complete URL for Azure AI Foundry FLUX 2 image edits.
        Uses the same /providers/blackforestlabs/v1/flux-2-pro endpoint as image generation.
        """
        api_base = AzureFoundryModelInfo.get_api_base(api_base)
        if api_base is None:
            raise ValueError(
                "Azure AI API base is required. Set AZURE_AI_API_BASE environment variable or pass api_base parameter."
            )
        # api_version resolution order: per-call > global > env > "preview"
        api_version = (
            litellm_params.get("api_version")
            or litellm.api_version
            or get_secret_str("AZURE_AI_API_VERSION")
            or "preview"
        )
        return AzureFoundryFluxImageGenerationConfig.get_flux2_image_generation_url(
            api_base=api_base,
            model=model,
            api_version=api_version,
        )

View File

@@ -0,0 +1,101 @@
from typing import Optional
import httpx
import litellm
from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo
from litellm.llms.openai.image_edit.transformation import OpenAIImageEditConfig
from litellm.secret_managers.main import get_secret_str
from litellm.utils import _add_path_to_api_base
class AzureFoundryFluxImageEditConfig(OpenAIImageEditConfig):
    """
    Azure AI Foundry FLUX image edit config

    Supports FLUX models including FLUX-1-kontext-pro for image editing.
    Azure AI Foundry FLUX models handle image editing through the /images/edits endpoint,
    same as standard Azure OpenAI models. The request format uses multipart/form-data
    with image files and prompt.
    """

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Validate Azure AI Foundry environment and set up authentication
        Uses Api-Key header format
        """
        api_key = AzureFoundryModelInfo.get_api_key(api_key)
        if not api_key:
            raise ValueError(
                f"Azure AI API key is required for model {model}. Set AZURE_AI_API_KEY environment variable or pass api_key parameter."
            )
        headers.update(
            {
                "Api-Key": api_key,  # Azure AI Foundry uses Api-Key header format
            }
        )
        return headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Constructs a complete URL for Azure AI Foundry image edits API request.

        Azure AI Foundry FLUX models handle image editing through the /images/edits
        endpoint.

        Args:
        - model: Model name (deployment name for Azure AI Foundry)
        - api_base: Base URL for Azure AI endpoint
        - litellm_params: Additional parameters including api_version

        Returns:
        - Complete URL for the image edits endpoint
        """
        api_base = AzureFoundryModelInfo.get_api_base(api_base)
        if api_base is None:
            raise ValueError(
                "Azure AI API base is required. Set AZURE_AI_API_BASE environment variable or pass api_base parameter."
            )
        # api_version resolution order: per-call > global > env; unlike FLUX 2
        # there is no "preview" fallback — it must be provided.
        api_version = (
            litellm_params.get("api_version")
            or litellm.api_version
            or get_secret_str("AZURE_AI_API_VERSION")
        )
        if api_version is None:
            # API version is mandatory for Azure AI Foundry
            raise ValueError(
                "Azure API version is required. Set AZURE_AI_API_VERSION environment variable or pass api_version parameter."
            )
        # Add the path to the base URL using the model as deployment name
        # Azure AI Foundry FLUX models use /images/edits for editing
        # If api_base already points at a deployment, append only the edits path;
        # otherwise build the full /openai/deployments/<model>/images/edits path.
        if "/openai/deployments/" in api_base:
            new_url = _add_path_to_api_base(
                api_base=api_base,
                ending_path="/images/edits",
            )
        else:
            new_url = _add_path_to_api_base(
                api_base=api_base,
                ending_path=f"/openai/deployments/{model}/images/edits",
            )
        # Use the new query_params dictionary
        # copy_with replaces the query string, keeping only api-version
        final_url = httpx.URL(new_url).copy_with(params={"api-version": api_version})
        return str(final_url)

View File

@@ -0,0 +1,33 @@
from litellm._logging import verbose_logger
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from .dall_e_2_transformation import AzureFoundryDallE2ImageGenerationConfig
from .dall_e_3_transformation import AzureFoundryDallE3ImageGenerationConfig
from .flux_transformation import AzureFoundryFluxImageGenerationConfig
from .gpt_transformation import AzureFoundryGPTImageGenerationConfig
__all__ = [
"AzureFoundryFluxImageGenerationConfig",
"AzureFoundryGPTImageGenerationConfig",
"AzureFoundryDallE2ImageGenerationConfig",
"AzureFoundryDallE3ImageGenerationConfig",
]
def get_azure_ai_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """Route an Azure AI Foundry image-generation model to its config class.

    Matching is done on a normalized model name (lowercased, with "-" and "_"
    removed) so e.g. "dall-e-2", "dalle_2" and "DALL-E-2" all match.

    Args:
        model: Raw model/deployment name. An empty string defaults to dall-e-2.

    Returns:
        The image-generation config instance for the detected model family;
        unrecognized models fall back to the gpt-image-1 style config.
    """
    model = model.lower()
    model = model.replace("-", "")
    model = model.replace("_", "")
    if model == "" or "dalle2" in model:  # empty model is dall-e-2
        return AzureFoundryDallE2ImageGenerationConfig()
    elif "dalle3" in model:
        return AzureFoundryDallE3ImageGenerationConfig()
    elif "flux" in model:
        return AzureFoundryFluxImageGenerationConfig()
    else:
        # Fix: log the class actually returned (message previously said
        # "AzureGPTImageGenerationConfig", which does not exist here).
        verbose_logger.debug(
            f"Using AzureFoundryGPTImageGenerationConfig for model: {model}. This follows the gpt-image-1 model format."
        )
        return AzureFoundryGPTImageGenerationConfig()

View File

@@ -0,0 +1,27 @@
from typing import Any
import litellm
from litellm.types.utils import ImageResponse
def cost_calculator(
    model: str,
    image_response: Any,
) -> float:
    """Compute the cost of an Azure AI image-generation call.

    Cost = per-image output cost (from the litellm model-info map) multiplied
    by the number of images returned.

    Args:
        model: Model name used for the generation.
        image_response: Must be a ``litellm.types.utils.ImageResponse``.

    Returns:
        Total cost (0.0 when the model has no per-image price or the response
        contains no images).

    Raises:
        ValueError: If ``image_response`` is not an ImageResponse.
    """
    # Validate the response type up front so invalid input fails fast and we
    # never do the model-info lookup for a response we cannot price.
    if not isinstance(image_response, ImageResponse):
        raise ValueError(
            f"image_response must be of type ImageResponse got type={type(image_response)}"
        )

    _model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider=litellm.LlmProviders.AZURE_AI.value,
    )
    output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0

    num_images: int = len(image_response.data) if image_response.data else 0
    return output_cost_per_image * num_images

View File

@@ -0,0 +1,9 @@
from litellm.llms.openai.image_generation import DallE2ImageGenerationConfig
class AzureFoundryDallE2ImageGenerationConfig(DallE2ImageGenerationConfig):
    """
    Azure dall-e-2 image generation config.

    Inherits all parameter mapping/validation from the OpenAI dall-e-2 config;
    no Azure-specific overrides are needed for this model.
    """

    pass

View File

@@ -0,0 +1,9 @@
from litellm.llms.openai.image_generation import DallE3ImageGenerationConfig
class AzureFoundryDallE3ImageGenerationConfig(DallE3ImageGenerationConfig):
    """
    Azure dall-e-3 image generation config.

    Inherits all parameter mapping/validation from the OpenAI dall-e-3 config;
    no Azure-specific overrides are needed for this model.
    """

    pass

View File

@@ -0,0 +1,68 @@
from typing import Optional
from litellm.llms.openai.image_generation import GPTImageGenerationConfig
class AzureFoundryFluxImageGenerationConfig(GPTImageGenerationConfig):
    """
    Azure Foundry FLUX image generation config.

    From manual testing this family follows the gpt-image-1 parameter surface
    (Azure Foundry publishes no docs on supported params at the time of
    writing); our test suite confirms GPTImageGenerationConfig works for it.
    """

    @staticmethod
    def get_flux2_image_generation_url(
        api_base: Optional[str],
        model: str,
        api_version: Optional[str],
    ) -> str:
        """Build the Azure AI FLUX 2 image-generation URL.

        FLUX 2 models on Azure AI do not use the standard Azure OpenAI path
        (/openai/deployments/{model}/images/generations); they are served from
        the provider path /providers/blackforestlabs/v1/flux-2-pro.

        Args:
            api_base: Base URL (e.g. https://litellm-ci-cd-prod.services.ai.azure.com).
            model: Model name (e.g. flux.2-pro).
            api_version: API version; defaults to "preview" when not given.

        Returns:
            Complete URL for the FLUX 2 image generation endpoint.

        Raises:
            ValueError: If api_base is None.
        """
        if api_base is None:
            raise ValueError(
                "api_base is required for Azure AI FLUX 2 image generation"
            )

        base = api_base.rstrip("/")
        version = api_version or "preview"

        # A base that already contains /providers/ is treated as a complete
        # endpoint path; only the api-version query param may still be missing.
        if "/providers/" in base:
            return base if "?" in base else f"{base}?api-version={version}"

        # Model name flux.2-pro maps to endpoint flux-2-pro.
        return f"{base}/providers/blackforestlabs/v1/flux-2-pro?api-version={version}"

    @staticmethod
    def is_flux2_model(model: str) -> bool:
        """Return True when *model* names an Azure AI FLUX 2 model.

        Normalizes "." and "_" to "-" so flux.2-pro, flux_2-pro and
        azure_ai/flux-2-pro all match.
        """
        canonical = model.lower().replace(".", "-").replace("_", "-")
        return "flux-2" in canonical or "flux2" in canonical

View File

@@ -0,0 +1,9 @@
from litellm.llms.openai.image_generation import GPTImageGenerationConfig
class AzureFoundryGPTImageGenerationConfig(GPTImageGenerationConfig):
    """
    Azure gpt-image-1 image generation config.

    Inherits all parameter handling from the OpenAI gpt-image-1 config; Azure
    Foundry exposes the same request surface for this model.
    """

    pass

View File

@@ -0,0 +1,12 @@
"""Azure AI OCR module."""
from .common_utils import get_azure_ai_ocr_config
from .document_intelligence.transformation import (
AzureDocumentIntelligenceOCRConfig,
)
from .transformation import AzureAIOCRConfig
__all__ = [
"AzureAIOCRConfig",
"AzureDocumentIntelligenceOCRConfig",
"get_azure_ai_ocr_config",
]

View File

@@ -0,0 +1,52 @@
"""
Common utilities for Azure AI OCR providers.
This module provides routing logic to determine which OCR configuration to use
based on the model name.
"""
from typing import TYPE_CHECKING, Optional
from litellm._logging import verbose_logger
if TYPE_CHECKING:
from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig
def get_azure_ai_ocr_config(model: str) -> Optional["BaseOCRConfig"]:
    """
    Pick the Azure AI OCR configuration for a model name.

    Azure AI supports multiple OCR services:
    - Azure Document Intelligence: azure_ai/doc-intelligence/<model>
    - Mistral OCR (via Azure AI): azure_ai/<model>

    Args:
        model: The model name (e.g., "azure_ai/doc-intelligence/prebuilt-read",
               "azure_ai/pixtral-12b-2409")

    Returns:
        OCR configuration instance for the specified model.

    Examples:
        >>> get_azure_ai_ocr_config("azure_ai/doc-intelligence/prebuilt-read")
        <AzureDocumentIntelligenceOCRConfig object>
        >>> get_azure_ai_ocr_config("azure_ai/pixtral-12b-2409")
        <AzureAIOCRConfig object>
    """
    # Imported lazily to avoid import cycles between the OCR modules.
    from litellm.llms.azure_ai.ocr.document_intelligence.transformation import (
        AzureDocumentIntelligenceOCRConfig,
    )
    from litellm.llms.azure_ai.ocr.transformation import AzureAIOCRConfig

    uses_document_intelligence = (
        "doc-intelligence" in model or "documentintelligence" in model
    )
    if uses_document_intelligence:
        verbose_logger.debug(
            f"Routing {model} to Azure Document Intelligence OCR config"
        )
        return AzureDocumentIntelligenceOCRConfig()

    # Everything else on azure_ai goes through the Mistral-compatible OCR API.
    verbose_logger.debug(f"Routing {model} to Azure AI (Mistral) OCR config")
    return AzureAIOCRConfig()

View File

@@ -0,0 +1,4 @@
"""Azure Document Intelligence OCR module."""
from .transformation import AzureDocumentIntelligenceOCRConfig
__all__ = ["AzureDocumentIntelligenceOCRConfig"]

View File

@@ -0,0 +1,698 @@
"""
Azure Document Intelligence OCR transformation implementation.
Azure Document Intelligence (formerly Form Recognizer) provides advanced document analysis capabilities.
This implementation transforms between Mistral OCR format and Azure Document Intelligence API v4.0.
Note: Azure Document Intelligence API is async - POST returns 202 Accepted with Operation-Location header.
The operation location must be polled until the analysis completes.
"""
import asyncio
import re
import time
from typing import Any, Dict, Optional
import httpx
from litellm._logging import verbose_logger
from litellm.constants import (
AZURE_DOCUMENT_INTELLIGENCE_API_VERSION,
AZURE_DOCUMENT_INTELLIGENCE_DEFAULT_DPI,
AZURE_OPERATION_POLLING_TIMEOUT,
)
from litellm.llms.base_llm.ocr.transformation import (
BaseOCRConfig,
DocumentType,
OCRPage,
OCRPageDimensions,
OCRRequestData,
OCRResponse,
OCRUsageInfo,
)
from litellm.secret_managers.main import get_secret_str
class AzureDocumentIntelligenceOCRConfig(BaseOCRConfig):
    """
    Azure Document Intelligence OCR transformation configuration.
    Supports Azure Document Intelligence v4.0 (2024-11-30) API.
    Model route: azure_ai/doc-intelligence/<model>
    Supported models:
    - prebuilt-layout: Extracts text with markdown, tables, and structure (closest to Mistral OCR)
    - prebuilt-read: Basic text extraction optimized for reading
    - prebuilt-document: General document analysis
    Reference: https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/
    """

    def __init__(self) -> None:
        # No provider-specific state; all behavior comes from the base OCR
        # config plus the overrides defined on this class.
        super().__init__()
def get_supported_ocr_params(self, model: str) -> list:
"""
Get supported OCR parameters for Azure Document Intelligence.
Azure DI has minimal optional parameters compared to Mistral OCR.
Most Mistral-specific params are ignored during transformation.
"""
return []
def validate_environment(
self,
headers: Dict,
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
litellm_params: Optional[dict] = None,
**kwargs,
) -> Dict:
"""
Validate environment and return headers for Azure Document Intelligence.
Authentication uses Ocp-Apim-Subscription-Key header.
"""
# Get API key from environment if not provided
if api_key is None:
api_key = get_secret_str("AZURE_DOCUMENT_INTELLIGENCE_API_KEY")
if api_key is None:
raise ValueError(
"Missing Azure Document Intelligence API Key - Set AZURE_DOCUMENT_INTELLIGENCE_API_KEY environment variable or pass api_key parameter"
)
# Validate API base/endpoint is provided
if api_base is None:
api_base = get_secret_str("AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT")
if api_base is None:
raise ValueError(
"Missing Azure Document Intelligence Endpoint - Set AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT environment variable or pass api_base parameter"
)
headers = {
"Ocp-Apim-Subscription-Key": api_key,
"Content-Type": "application/json",
**headers,
}
return headers
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: dict,
litellm_params: Optional[dict] = None,
**kwargs,
) -> str:
"""
Get complete URL for Azure Document Intelligence endpoint.
Format: {endpoint}/documentintelligence/documentModels/{modelId}:analyze?api-version=2024-11-30
Note: API version 2024-11-30 uses /documentintelligence/ path (not /formrecognizer/)
Args:
api_base: Azure Document Intelligence endpoint (e.g., https://your-resource.cognitiveservices.azure.com)
model: Model ID (e.g., "prebuilt-layout", "prebuilt-read")
optional_params: Optional parameters
Returns: Complete URL for Azure DI analyze endpoint
"""
if api_base is None:
api_base = get_secret_str("AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT")
if api_base is None:
raise ValueError(
"Missing Azure Document Intelligence Endpoint - Set AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT environment variable or pass api_base parameter"
)
# Ensure no trailing slash
api_base = api_base.rstrip("/")
# Extract model ID from full model path if needed
# Model can be "prebuilt-layout" or "azure_ai/doc-intelligence/prebuilt-layout"
model_id = model
if "/" in model:
# Extract the last part after the last slash
model_id = model.split("/")[-1]
# Azure Document Intelligence analyze endpoint
# Note: API version 2024-11-30+ uses /documentintelligence/ (not /formrecognizer/)
return f"{api_base}/documentintelligence/documentModels/{model_id}:analyze?api-version={AZURE_DOCUMENT_INTELLIGENCE_API_VERSION}"
def _extract_base64_from_data_uri(self, data_uri: str) -> str:
"""
Extract base64 content from a data URI.
Args:
data_uri: Data URI like "data:application/pdf;base64,..."
Returns:
Base64 string without the data URI prefix
"""
# Match pattern: data:[<mediatype>][;base64],<data>
match = re.match(r"data:([^;]+)(?:;base64)?,(.+)", data_uri)
if match:
return match.group(2)
return data_uri
def transform_ocr_request(
self,
model: str,
document: DocumentType,
optional_params: dict,
headers: dict,
**kwargs,
) -> OCRRequestData:
"""
Transform OCR request to Azure Document Intelligence format.
Mistral OCR format:
{
"document": {
"type": "document_url",
"document_url": "https://example.com/doc.pdf"
}
}
Azure DI format:
{
"urlSource": "https://example.com/doc.pdf"
}
OR
{
"base64Source": "base64_encoded_content"
}
Args:
model: Model name
document: Document dict from user (Mistral format)
optional_params: Already mapped optional parameters
headers: Request headers
Returns:
OCRRequestData with JSON data
"""
verbose_logger.debug(
f"Azure Document Intelligence transform_ocr_request - model: {model}"
)
if not isinstance(document, dict):
raise ValueError(f"Expected document dict, got {type(document)}")
# Extract document URL from Mistral format
doc_type = document.get("type")
document_url = None
if doc_type == "document_url":
document_url = document.get("document_url", "")
elif doc_type == "image_url":
document_url = document.get("image_url", "")
else:
raise ValueError(
f"Invalid document type: {doc_type}. Must be 'document_url' or 'image_url'"
)
if not document_url:
raise ValueError("Document URL is required")
# Build Azure DI request
data: Dict[str, Any] = {}
# Check if it's a data URI (base64)
if document_url.startswith("data:"):
# Extract base64 content
base64_content = self._extract_base64_from_data_uri(document_url)
data["base64Source"] = base64_content
verbose_logger.debug("Using base64Source for Azure Document Intelligence")
else:
# Regular URL
data["urlSource"] = document_url
verbose_logger.debug("Using urlSource for Azure Document Intelligence")
# Azure DI doesn't support most Mistral-specific params
# Ignore pages, include_image_base64, etc.
return OCRRequestData(data=data, files=None)
def _extract_page_markdown(self, page_data: Dict[str, Any]) -> str:
"""
Extract text from Azure DI page and format as markdown.
Azure DI provides text in 'lines' array. We concatenate them with newlines.
Args:
page_data: Azure DI page object
Returns:
Markdown-formatted text
"""
lines = page_data.get("lines", [])
if not lines:
return ""
# Extract text content from each line
text_lines = [line.get("content", "") for line in lines]
# Join with newlines to preserve structure
return "\n".join(text_lines)
def _convert_dimensions(
self, width: float, height: float, unit: str
) -> OCRPageDimensions:
"""
Convert Azure DI dimensions to pixels.
Azure DI provides dimensions in inches. We convert to pixels using configured DPI.
Args:
width: Width in specified unit
height: Height in specified unit
unit: Unit of measurement (e.g., "inch")
Returns:
OCRPageDimensions with pixel values
"""
# Convert to pixels using configured DPI
dpi = AZURE_DOCUMENT_INTELLIGENCE_DEFAULT_DPI
if unit == "inch":
width_px = int(width * dpi)
height_px = int(height * dpi)
else:
# If unit is not inches, assume it's already in pixels
width_px = int(width)
height_px = int(height)
return OCRPageDimensions(width=width_px, height=height_px, dpi=dpi)
@staticmethod
def _check_timeout(start_time: float, timeout_secs: int) -> None:
"""
Check if operation has timed out.
Args:
start_time: Start time of the operation
timeout_secs: Timeout duration in seconds
Raises:
TimeoutError: If operation has exceeded timeout
"""
if time.time() - start_time > timeout_secs:
raise TimeoutError(
f"Azure Document Intelligence operation polling timed out after {timeout_secs} seconds"
)
@staticmethod
def _get_retry_after(response: httpx.Response) -> int:
"""
Get retry-after duration from response headers.
Args:
response: HTTP response
Returns:
Retry-after duration in seconds (default: 2)
"""
retry_after = int(response.headers.get("retry-after", "2"))
verbose_logger.debug(f"Retry polling after: {retry_after} seconds")
return retry_after
@staticmethod
def _check_operation_status(response: httpx.Response) -> str:
"""
Check Azure DI operation status from response.
Args:
response: HTTP response from operation endpoint
Returns:
Operation status string
Raises:
ValueError: If operation failed or status is unknown
"""
try:
result = response.json()
status = result.get("status")
verbose_logger.debug(f"Azure DI operation status: {status}")
if status == "succeeded":
return "succeeded"
elif status == "failed":
error_msg = result.get("error", {}).get("message", "Unknown error")
raise ValueError(
f"Azure Document Intelligence analysis failed: {error_msg}"
)
elif status in ["running", "notStarted"]:
return "running"
else:
raise ValueError(f"Unknown operation status: {status}")
except Exception as e:
if "succeeded" in str(e) or "failed" in str(e):
raise
# If we can't parse JSON, something went wrong
raise ValueError(f"Failed to parse Azure DI operation response: {e}")
    def _poll_operation_sync(
        self,
        operation_url: str,
        headers: Dict[str, str],
        timeout_secs: int,
    ) -> httpx.Response:
        """
        Poll Azure Document Intelligence operation until completion (sync).
        Azure DI POST returns 202 with Operation-Location header.
        We need to poll that URL until status is "succeeded" or "failed".
        Args:
            operation_url: The Operation-Location URL to poll
            headers: Request headers (including auth)
            timeout_secs: Total timeout in seconds
        Returns:
            Final response with completed analysis
        Raises:
            TimeoutError: if polling exceeds timeout_secs
            ValueError: if the operation reports failure or an unknown status
                (raised by _check_operation_status)
        """
        from litellm.llms.custom_httpx.http_handler import _get_httpx_client

        client = _get_httpx_client()
        start_time = time.time()
        verbose_logger.debug(f"Polling Azure DI operation: {operation_url}")
        while True:
            # Abort once the total polling budget is exhausted.
            self._check_timeout(start_time=start_time, timeout_secs=timeout_secs)
            # Poll the operation status
            response = client.get(url=operation_url, headers=headers)
            # Check operation status ("succeeded" or "running"; failures raise)
            status = self._check_operation_status(response=response)
            if status == "succeeded":
                return response
            elif status == "running":
                # Wait before polling again, honoring the server's retry-after hint
                retry_after = self._get_retry_after(response=response)
                time.sleep(retry_after)
    async def _poll_operation_async(
        self,
        operation_url: str,
        headers: Dict[str, str],
        timeout_secs: int,
    ) -> httpx.Response:
        """
        Poll Azure Document Intelligence operation until completion (async).
        Args:
            operation_url: The Operation-Location URL to poll
            headers: Request headers (including auth)
            timeout_secs: Total timeout in seconds
        Returns:
            Final response with completed analysis
        Raises:
            TimeoutError: if polling exceeds timeout_secs
            ValueError: if the operation reports failure or an unknown status
                (raised by _check_operation_status)
        """
        import litellm
        from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.AZURE_AI)
        start_time = time.time()
        verbose_logger.debug(f"Polling Azure DI operation (async): {operation_url}")
        while True:
            # Abort once the total polling budget is exhausted.
            self._check_timeout(start_time=start_time, timeout_secs=timeout_secs)
            # Poll the operation status
            response = await client.get(url=operation_url, headers=headers)
            # Check operation status ("succeeded" or "running"; failures raise)
            status = self._check_operation_status(response=response)
            if status == "succeeded":
                return response
            elif status == "running":
                # Wait before polling again, honoring the server's retry-after hint
                retry_after = self._get_retry_after(response=response)
                await asyncio.sleep(retry_after)
    def transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: Any,
        **kwargs,
    ) -> OCRResponse:
        """
        Transform Azure Document Intelligence response to Mistral OCR format.
        Handles async operation polling: If response is 202 Accepted, polls Operation-Location
        until analysis completes.
        Azure DI response (after polling):
        {
            "status": "succeeded",
            "analyzeResult": {
                "content": "Full document text...",
                "pages": [
                    {
                        "pageNumber": 1,
                        "width": 8.5,
                        "height": 11,
                        "unit": "inch",
                        "lines": [{"content": "text", "boundingBox": [...]}]
                    }
                ]
            }
        }
        Mistral OCR format:
        {
            "pages": [
                {
                    "index": 0,
                    "markdown": "extracted text",
                    "dimensions": {"width": 816, "height": 1056, "dpi": 96}
                }
            ],
            "model": "azure_ai/doc-intelligence/prebuilt-layout",
            "usage_info": {"pages_processed": 1},
            "object": "ocr"
        }
        Args:
            model: Model name
            raw_response: Raw HTTP response from Azure DI (may be 202 Accepted)
            logging_obj: Logging object
        Returns:
            OCRResponse in Mistral format
        Raises:
            ValueError: if the 202 response lacks an Operation-Location header
                or the analysis did not reach "succeeded"
            TimeoutError: if operation polling exceeds the configured timeout
        """
        try:
            # Check if we got 202 Accepted (async operation started)
            if raw_response.status_code == 202:
                verbose_logger.debug(
                    "Azure DI returned 202 Accepted, polling operation..."
                )
                # Get Operation-Location header
                operation_url = raw_response.headers.get("Operation-Location")
                if not operation_url:
                    raise ValueError(
                        "Azure Document Intelligence returned 202 but no Operation-Location header found"
                    )
                # Get headers for polling (need auth) - reuse the subscription
                # key from the original request.
                poll_headers = {
                    "Ocp-Apim-Subscription-Key": raw_response.request.headers.get(
                        "Ocp-Apim-Subscription-Key", ""
                    )
                }
                # Get timeout from kwargs or use default
                timeout_secs = AZURE_OPERATION_POLLING_TIMEOUT
                # Poll until operation completes (blocks this thread)
                raw_response = self._poll_operation_sync(
                    operation_url=operation_url,
                    headers=poll_headers,
                    timeout_secs=timeout_secs,
                )
            # Now parse the completed response
            response_json = raw_response.json()
            verbose_logger.debug(
                f"Azure Document Intelligence response status: {response_json.get('status')}"
            )
            # Check if request succeeded
            status = response_json.get("status")
            if status != "succeeded":
                raise ValueError(
                    f"Azure Document Intelligence analysis failed with status: {status}"
                )
            # Extract analyze result
            analyze_result = response_json.get("analyzeResult", {})
            azure_pages = analyze_result.get("pages", [])
            # Transform pages to Mistral format
            mistral_pages = []
            for azure_page in azure_pages:
                page_number = azure_page.get("pageNumber", 1)
                index = page_number - 1  # Convert to 0-based index
                # Extract markdown text
                markdown = self._extract_page_markdown(azure_page)
                # Convert dimensions (defaults: US Letter size, inches)
                width = azure_page.get("width", 8.5)
                height = azure_page.get("height", 11)
                unit = azure_page.get("unit", "inch")
                dimensions = self._convert_dimensions(
                    width=width, height=height, unit=unit
                )
                # Build OCR page
                ocr_page = OCRPage(
                    index=index, markdown=markdown, dimensions=dimensions
                )
                mistral_pages.append(ocr_page)
            # Build usage info (Azure DI does not report document size)
            usage_info = OCRUsageInfo(
                pages_processed=len(mistral_pages), doc_size_bytes=None
            )
            # Return Mistral OCR response
            return OCRResponse(
                pages=mistral_pages,
                model=model,
                usage_info=usage_info,
                object="ocr",
            )
        except Exception as e:
            # Log and re-raise so callers see the original failure.
            verbose_logger.error(
                f"Error parsing Azure Document Intelligence response: {e}"
            )
            raise e
    async def async_transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: Any,
        **kwargs,
    ) -> OCRResponse:
        """
        Async transform Azure Document Intelligence response to Mistral OCR format.
        Handles async operation polling: If response is 202 Accepted, polls Operation-Location
        until analysis completes using async polling.
        Args:
            model: Model name
            raw_response: Raw HTTP response from Azure DI (may be 202 Accepted)
            logging_obj: Logging object
        Returns:
            OCRResponse in Mistral format
        Raises:
            ValueError: if the 202 response lacks an Operation-Location header
                or the analysis did not reach "succeeded"
            TimeoutError: if operation polling exceeds the configured timeout
        """
        try:
            # Check if we got 202 Accepted (async operation started)
            if raw_response.status_code == 202:
                verbose_logger.debug(
                    "Azure DI returned 202 Accepted, polling operation (async)..."
                )
                # Get Operation-Location header
                operation_url = raw_response.headers.get("Operation-Location")
                if not operation_url:
                    raise ValueError(
                        "Azure Document Intelligence returned 202 but no Operation-Location header found"
                    )
                # Get headers for polling (need auth) - reuse the subscription
                # key from the original request.
                poll_headers = {
                    "Ocp-Apim-Subscription-Key": raw_response.request.headers.get(
                        "Ocp-Apim-Subscription-Key", ""
                    )
                }
                # Get timeout from kwargs or use default
                timeout_secs = AZURE_OPERATION_POLLING_TIMEOUT
                # Poll until operation completes (async, yields to event loop)
                raw_response = await self._poll_operation_async(
                    operation_url=operation_url,
                    headers=poll_headers,
                    timeout_secs=timeout_secs,
                )
            # Now parse the completed response
            response_json = raw_response.json()
            verbose_logger.debug(
                f"Azure Document Intelligence response status: {response_json.get('status')}"
            )
            # Check if request succeeded
            status = response_json.get("status")
            if status != "succeeded":
                raise ValueError(
                    f"Azure Document Intelligence analysis failed with status: {status}"
                )
            # Extract analyze result
            analyze_result = response_json.get("analyzeResult", {})
            azure_pages = analyze_result.get("pages", [])
            # Transform pages to Mistral format
            mistral_pages = []
            for azure_page in azure_pages:
                page_number = azure_page.get("pageNumber", 1)
                index = page_number - 1  # Convert to 0-based index
                # Extract markdown text
                markdown = self._extract_page_markdown(azure_page)
                # Convert dimensions (defaults: US Letter size, inches)
                width = azure_page.get("width", 8.5)
                height = azure_page.get("height", 11)
                unit = azure_page.get("unit", "inch")
                dimensions = self._convert_dimensions(
                    width=width, height=height, unit=unit
                )
                # Build OCR page
                ocr_page = OCRPage(
                    index=index, markdown=markdown, dimensions=dimensions
                )
                mistral_pages.append(ocr_page)
            # Build usage info (Azure DI does not report document size)
            usage_info = OCRUsageInfo(
                pages_processed=len(mistral_pages), doc_size_bytes=None
            )
            # Return Mistral OCR response
            return OCRResponse(
                pages=mistral_pages,
                model=model,
                usage_info=usage_info,
                object="ocr",
            )
        except Exception as e:
            # Log and re-raise so callers see the original failure.
            verbose_logger.error(
                f"Error parsing Azure Document Intelligence response (async): {e}"
            )
            raise e

View File

@@ -0,0 +1,281 @@
"""
Azure AI OCR transformation implementation.
"""
from typing import Dict, Optional
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.image_handling import (
async_convert_url_to_base64,
convert_url_to_base64,
)
from litellm.llms.base_llm.ocr.transformation import DocumentType, OCRRequestData
from litellm.llms.mistral.ocr.transformation import MistralOCRConfig
from litellm.secret_managers.main import get_secret_str
class AzureAIOCRConfig(MistralOCRConfig):
    """
    Azure AI OCR transformation configuration.
    Azure AI uses Mistral's OCR API but with a different endpoint format.
    Inherits transformation logic from MistralOCRConfig since they use the same format.
    Reference: Azure AI Foundry OCR documentation
    Important: Azure AI only supports base64 data URIs (data:image/..., data:application/pdf;base64,...).
    Regular URLs are not supported, so this config fetches them and inlines the
    content as data URIs before sending the request.
    """

    def __init__(self) -> None:
        # No extra state beyond the Mistral base config.
        super().__init__()
def validate_environment(
self,
headers: Dict,
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
litellm_params: Optional[dict] = None,
**kwargs,
) -> Dict:
"""
Validate environment and return headers for Azure AI OCR.
Azure AI uses Bearer token authentication with AZURE_AI_API_KEY.
"""
# Get API key from environment if not provided
if api_key is None:
api_key = get_secret_str("AZURE_AI_API_KEY")
if api_key is None:
raise ValueError(
"Missing Azure AI API Key - A call is being made to Azure AI but no key is set either in the environment variables or via params"
)
# Validate API base is provided
if api_base is None:
api_base = get_secret_str("AZURE_AI_API_BASE")
if api_base is None:
raise ValueError(
"Missing Azure AI API Base - Set AZURE_AI_API_BASE environment variable or pass api_base parameter"
)
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
**headers,
}
return headers
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: dict,
litellm_params: Optional[dict] = None,
**kwargs,
) -> str:
"""
Get complete URL for Azure AI OCR endpoint.
Azure AI endpoint format: https://<api_base>/providers/mistral/azure/ocr
Args:
api_base: Azure AI API base URL
model: Model name (not used in URL construction)
optional_params: Optional parameters
Returns: Complete URL for Azure AI OCR endpoint
"""
if api_base is None:
raise ValueError(
"Missing Azure AI API Base - Set AZURE_AI_API_BASE environment variable or pass api_base parameter"
)
# Ensure no trailing slash
api_base = api_base.rstrip("/")
# Azure AI OCR endpoint format
return f"{api_base}/providers/mistral/azure/ocr"
def _convert_url_to_data_uri_sync(self, url: str) -> str:
"""
Synchronously convert a URL to a base64 data URI.
Azure AI OCR doesn't have internet access, so we need to fetch URLs
and convert them to base64 data URIs.
Args:
url: The URL to convert
Returns:
Base64 data URI string
"""
verbose_logger.debug(
f"Azure AI OCR: Converting URL to base64 data URI (sync): {url}"
)
# Fetch and convert to base64 data URI
# convert_url_to_base64 already returns a full data URI like "data:image/jpeg;base64,..."
data_uri = convert_url_to_base64(url=url)
verbose_logger.debug(
f"Azure AI OCR: Converted URL to data URI (length: {len(data_uri)})"
)
return data_uri
async def _convert_url_to_data_uri_async(self, url: str) -> str:
"""
Asynchronously convert a URL to a base64 data URI.
Azure AI OCR doesn't have internet access, so we need to fetch URLs
and convert them to base64 data URIs.
Args:
url: The URL to convert
Returns:
Base64 data URI string
"""
verbose_logger.debug(
f"Azure AI OCR: Converting URL to base64 data URI (async): {url}"
)
# Fetch and convert to base64 data URI asynchronously
# async_convert_url_to_base64 already returns a full data URI like "data:image/jpeg;base64,..."
data_uri = await async_convert_url_to_base64(url=url)
verbose_logger.debug(
f"Azure AI OCR: Converted URL to data URI (length: {len(data_uri)})"
)
return data_uri
def transform_ocr_request(
self,
model: str,
document: DocumentType,
optional_params: dict,
headers: dict,
**kwargs,
) -> OCRRequestData:
"""
Transform OCR request for Azure AI, converting URLs to base64 data URIs (sync).
Azure AI OCR doesn't have internet access, so we automatically fetch
any URLs and convert them to base64 data URIs synchronously.
Args:
model: Model name
document: Document dict from user
optional_params: Already mapped optional parameters
headers: Request headers
**kwargs: Additional arguments
Returns:
OCRRequestData with JSON data
"""
verbose_logger.debug(
f"Azure AI OCR transform_ocr_request (sync) - model: {model}"
)
if not isinstance(document, dict):
raise ValueError(f"Expected document dict, got {type(document)}")
# Check if we need to convert URL to base64
doc_type = document.get("type")
transformed_document = document.copy()
if doc_type == "document_url":
document_url = document.get("document_url", "")
# If it's not already a data URI, convert it
if document_url and not document_url.startswith("data:"):
verbose_logger.debug(
"Azure AI OCR: Converting document URL to base64 data URI (sync)"
)
data_uri = self._convert_url_to_data_uri_sync(url=document_url)
transformed_document["document_url"] = data_uri
elif doc_type == "image_url":
image_url = document.get("image_url", "")
# If it's not already a data URI, convert it
if image_url and not image_url.startswith("data:"):
verbose_logger.debug(
"Azure AI OCR: Converting image URL to base64 data URI (sync)"
)
data_uri = self._convert_url_to_data_uri_sync(url=image_url)
transformed_document["image_url"] = data_uri
# Call parent's transform to build the request
return super().transform_ocr_request(
model=model,
document=transformed_document,
optional_params=optional_params,
headers=headers,
**kwargs,
)
async def async_transform_ocr_request(
self,
model: str,
document: DocumentType,
optional_params: dict,
headers: dict,
**kwargs,
) -> OCRRequestData:
"""
Transform OCR request for Azure AI, converting URLs to base64 data URIs (async).
Azure AI OCR doesn't have internet access, so we automatically fetch
any URLs and convert them to base64 data URIs asynchronously.
Args:
model: Model name
document: Document dict from user
optional_params: Already mapped optional parameters
headers: Request headers
**kwargs: Additional arguments
Returns:
OCRRequestData with JSON data
"""
verbose_logger.debug(
f"Azure AI OCR async_transform_ocr_request - model: {model}"
)
if not isinstance(document, dict):
raise ValueError(f"Expected document dict, got {type(document)}")
# Check if we need to convert URL to base64
doc_type = document.get("type")
transformed_document = document.copy()
if doc_type == "document_url":
document_url = document.get("document_url", "")
# If it's not already a data URI, convert it
if document_url and not document_url.startswith("data:"):
verbose_logger.debug(
"Azure AI OCR: Converting document URL to base64 data URI (async)"
)
data_uri = await self._convert_url_to_data_uri_async(url=document_url)
transformed_document["document_url"] = data_uri
elif doc_type == "image_url":
image_url = document.get("image_url", "")
# If it's not already a data URI, convert it
if image_url and not image_url.startswith("data:"):
verbose_logger.debug(
"Azure AI OCR: Converting image URL to base64 data URI (async)"
)
data_uri = await self._convert_url_to_data_uri_async(url=image_url)
transformed_document["image_url"] = data_uri
# Call parent's transform to build the request
return super().transform_ocr_request(
model=model,
document=transformed_document,
optional_params=optional_params,
headers=headers,
**kwargs,
)

View File

@@ -0,0 +1,5 @@
"""
Azure AI Rerank - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View File

@@ -0,0 +1,125 @@
"""
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
"""
from typing import Optional
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import RerankResponse
from litellm.utils import _add_path_to_api_base
class AzureAIRerankConfig(CohereRerankConfig):
    """
    Rerank support for Azure AI.

    Azure AI exposes reranking through the same wire format as Cohere's
    `/rerank` endpoint, so this config only customises URL resolution,
    auth headers, and base-model detection on top of CohereRerankConfig.
    """

    # Maps the `azureml-model-group` response header to a litellm model name.
    _MODEL_GROUP_TO_BASE_MODEL = {
        "offer-cohere-rerank-mul-paygo": "azure_ai/cohere-rerank-v3-multilingual",
        "offer-cohere-rerank-eng-paygo": "azure_ai/cohere-rerank-v3-english",
    }

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[dict] = None,
    ) -> str:
        """
        Resolve the full rerank endpoint URL from the configured api_base.

        Accepts bases that already end in a rerank path, bases that end at a
        version segment (``/v1``, ``/v2``, ``/providers/cohere/v2``), and bare
        resource URLs (which default to the legacy ``/v1/rerank`` path).
        """
        if api_base is None:
            raise ValueError(
                "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
            )
        parsed = httpx.URL(api_base)
        if not parsed.is_absolute_url:
            raise ValueError(
                "Azure AI API Base must be an absolute URL including scheme (e.g. "
                "'https://<resource>.services.ai.azure.com'). "
                f"Got api_base={api_base!r}."
            )
        trimmed_path = parsed.path.rstrip("/")
        # Caller supplied a complete rerank endpoint - use it verbatim.
        if trimmed_path.endswith(("/v1/rerank", "/v2/rerank")):
            return str(parsed.copy_with(path=trimmed_path or "/"))
        # Caller supplied only a version segment - append "/rerank".
        if trimmed_path.endswith(("/v1", "/v2", "/providers/cohere/v2")):
            return _add_path_to_api_base(
                api_base=str(parsed.copy_with(path=trimmed_path or "/")),
                ending_path="/rerank",
            )
        # Backwards-compatible default: Azure AI rerank was originally
        # exposed under /v1/rerank.
        return _add_path_to_api_base(api_base=api_base, ending_path="/v1/rerank")

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> dict:
        """
        Build request headers; caller-supplied headers take precedence over
        the defaults (including Authorization).
        """
        if api_key is None:
            api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key
        if api_key is None:
            raise ValueError(
                "Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
            )
        request_headers = {
            "Authorization": headers.get("Authorization", f"Bearer {api_key}"),
            "accept": "application/json",
            "content-type": "application/json",
        }
        # Caller headers override every default.
        request_headers.update(headers)
        return request_headers

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Delegate to the Cohere transform, then stamp the underlying Azure
        base-model name (derived from the azureml-model-group header) into
        the response's hidden params.
        """
        transformed = super().transform_rerank_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=request_data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )
        model_group = transformed._hidden_params.get(
            "llm_provider-azureml-model-group"
        )
        transformed._hidden_params["model"] = self._get_base_model(model_group)
        return transformed

    def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
        """Translate a known azureml model group to its litellm model name."""
        if azure_model_group is None:
            return None
        return self._MODEL_GROUP_TO_BASE_MODEL.get(
            azure_model_group, azure_model_group
        )

View File

@@ -0,0 +1,3 @@
from litellm.llms.azure_ai.vector_stores.transformation import AzureAIVectorStoreConfig
__all__ = ["AzureAIVectorStoreConfig"]

View File

@@ -0,0 +1,257 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
import litellm
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
BaseVectorStoreAuthCredentials,
VectorStoreCreateOptionalRequestParams,
VectorStoreCreateResponse,
VectorStoreIndexEndpoints,
VectorStoreResultContent,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchResponse,
VectorStoreSearchResult,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AzureAIVectorStoreConfig(BaseVectorStoreConfig, BaseAzureLLM):
"""
Configuration for Azure AI Search Vector Store
This implementation uses the Azure AI Search API for vector store operations.
Supports vector search with embeddings generated via litellm.embeddings.
"""
def __init__(self):
    # No Azure-Search-specific state to set up; defer to the shared
    # BaseVectorStoreConfig / BaseAzureLLM initialisers.
    super().__init__()
def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
    """Map read/write vector-store operations to Azure AI Search routes."""
    read_routes = [("GET", "/docs/search"), ("POST", "/docs/search")]
    write_routes = [("PUT", "/docs")]
    return {"read": read_routes, "write": write_routes}
def get_auth_credentials(
    self, litellm_params: dict
) -> BaseVectorStoreAuthCredentials:
    """
    Build auth credentials for Azure AI Search, which authenticates via
    an ``api-key`` request header rather than a Bearer token.

    Raises:
        ValueError: if no api_key is present in litellm_params.
    """
    api_key = litellm_params.get("api_key")
    if api_key is None:
        raise ValueError("api_key is required")
    auth_headers = {"api-key": api_key}
    return {"headers": auth_headers}
def validate_environment(
    self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
    """Return request headers with Azure auth applied and a JSON content type."""
    request_headers = self._base_validate_azure_environment(headers, litellm_params)
    request_headers["Content-Type"] = "application/json"
    return request_headers
def get_complete_url(
    self,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """
    Resolve the Azure AI Search base endpoint.

    Uses ``api_base`` verbatim (sans trailing slash) when provided;
    otherwise derives ``https://{service}.search.windows.net`` from
    ``litellm_params['azure_search_service_name']``.
    """
    search_service = litellm_params.get("azure_search_service_name")
    if api_base:
        return api_base.rstrip("/")
    if search_service:
        return f"https://{search_service}.search.windows.net"
    raise ValueError(
        "Azure AI Search service name is required. "
        "Provide it via litellm_params['azure_search_service_name'] or api_base parameter"
    )
def transform_search_vector_store_request(
    self,
    vector_store_id: str,
    query: Union[str, List[str]],
    vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
    api_base: str,
    litellm_logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> Tuple[str, Dict[str, Any]]:
    """
    Transform a vector-store search into an Azure AI Search request.

    Azure AI Search only accepts pre-computed vectors, so the query is
    embedded client-side via ``litellm.embedding`` using the model and
    config supplied in ``litellm_params``, then a vector-search POST body
    is built against the index named by ``vector_store_id``.

    Args:
        vector_store_id: Azure AI Search index name.
        query: Query string (lists are joined with spaces).
        vector_store_search_optional_params: May carry ``top_k`` (default 10).
        api_base: Resolved search-service base URL.
        litellm_logging_obj: Logging object; request details are recorded on it.
        litellm_params: Must carry ``litellm_embedding_model`` and
            ``litellm_embedding_config``; may carry ``azure_search_vector_field``.

    Returns:
        Tuple of (request URL, JSON request body).

    Raises:
        ValueError: if the required embedding model/config are missing.
        Exception: if embedding generation fails (original error chained).
    """
    # Azure AI Search takes a single query string; flatten list inputs.
    if isinstance(query, list):
        query = " ".join(query)

    # NOTE: error messages must reference the *actual* keys read below
    # (litellm_embedding_model / litellm_embedding_config), otherwise users
    # following the message would still hit this error.
    embedding_model = litellm_params.get("litellm_embedding_model")
    if not embedding_model:
        raise ValueError(
            "litellm_embedding_model is required in litellm_params for Azure AI Search. "
            "Example: litellm_params['litellm_embedding_model'] = 'azure/text-embedding-3-large'"
        )
    embedding_config = litellm_params.get("litellm_embedding_config", {})
    if not embedding_config:
        raise ValueError(
            "litellm_embedding_config is required in litellm_params for Azure AI Search. "
            "Example: litellm_params['litellm_embedding_config'] = {'api_base': 'https://<resource>.cognitiveservices.azure.com/', 'api_key': 'os.environ/AZURE_API_KEY', 'api_version': '2025-09-01'}"
        )

    # Vector field name defaults to the common Azure sample-schema field.
    vector_field = litellm_params.get("azure_search_vector_field", "contentVector")
    # Number of results to return.
    top_k = vector_store_search_optional_params.get("top_k", 10)

    # Generate the query embedding via litellm.embedding.
    try:
        embedding_response = litellm.embedding(
            model=embedding_model,
            input=[query],
            **embedding_config,
        )
        query_vector = embedding_response.data[0]["embedding"]
    except Exception as e:
        # Chain the original error so callers can inspect the root cause.
        raise Exception(f"Failed to generate embedding for query: {str(e)}") from e

    # vector_store_id doubles as the Azure AI Search index name.
    index_name = vector_store_id
    url = f"{api_base}/indexes/{index_name}/docs/search?api-version=2024-07-01"

    # Vector-search request body for Azure AI Search.
    request_body = {
        "search": "*",  # match all documents; ranking comes from vector similarity
        "vectorQueries": [
            {
                "vector": query_vector,
                "fields": vector_field,
                "kind": "vector",
                "k": top_k,  # number of nearest neighbours to return
            }
        ],
        "select": "id,content",  # fields to return (customize based on schema)
        "top": top_k,
    }

    #########################################################
    # Update logging object with details of the request
    #########################################################
    litellm_logging_obj.model_call_details["input"] = query
    litellm_logging_obj.model_call_details["embedding_model"] = embedding_model
    litellm_logging_obj.model_call_details["top_k"] = top_k

    return url, request_body
def transform_search_vector_store_response(
    self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
) -> VectorStoreSearchResponse:
    """
    Convert an Azure AI Search response into the generic vector-store
    search response format.

    Azure AI Search returns hits under ``"value"``, each carrying an
    ``"@search.score"`` relevance score plus the stored document fields,
    e.g.::

        {"value": [{"id": "...", "content": "...", "@search.score": 0.95, ...}]}
    """
    try:
        payload = response.json()
        hits = payload.get("value", [])

        # Fields handled explicitly (or internal) - everything else
        # becomes a result attribute.
        system_fields = {"id", "content", "contentVector", "@search.score"}

        transformed: List[VectorStoreSearchResult] = []
        for hit in hits:
            doc_id = hit.get("id", "")
            extra_attributes = {
                field: value
                for field, value in hit.items()
                if field not in system_fields
            }
            # Always expose the document id as an attribute too.
            extra_attributes["document_id"] = doc_id

            transformed.append(
                VectorStoreSearchResult(
                    # Relevance score from Azure AI Search.
                    score=hit.get("@search.score", 0.0),
                    content=[
                        VectorStoreResultContent(
                            text=hit.get("content", ""),
                            type="text",
                        )
                    ],
                    # Document id stands in for both file id and filename.
                    file_id=doc_id,
                    filename=f"Document {doc_id}",
                    attributes=extra_attributes,
                )
            )

        return VectorStoreSearchResponse(
            object="vector_store.search_results.page",
            search_query=litellm_logging_obj.model_call_details.get("input", ""),
            data=transformed,
        )
    except Exception as e:
        raise self.get_error_class(
            error_message=str(e),
            status_code=response.status_code,
            headers=response.headers,
        )
def transform_create_vector_store_request(
    self,
    vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
    api_base: str,
) -> Tuple[str, Dict]:
    # Index creation is not supported through this integration; Azure AI
    # Search indexes must be provisioned out-of-band and referenced by
    # name via vector_store_id.
    raise NotImplementedError
def transform_create_vector_store_response(
    self, response: httpx.Response
) -> VectorStoreCreateResponse:
    # Unreachable in practice: the matching create-request transform above
    # also raises NotImplementedError.
    raise NotImplementedError