chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1 @@
|
||||
`/chat/completion` calls routed via `openai.py`.
|
||||
@@ -0,0 +1,11 @@
|
||||
# Public interface of the Azure AI Agents integration: re-export the shared
# handler singleton plus the config/error types so callers can import them
# directly from this package.
from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
from litellm.llms.azure_ai.agents.transformation import (
    AzureAIAgentsConfig,
    AzureAIAgentsError,
)

__all__ = [
    "AzureAIAgentsConfig",
    "AzureAIAgentsError",
    "azure_ai_agents_handler",
]
|
||||
@@ -0,0 +1,659 @@
|
||||
"""
|
||||
Handler for Azure Foundry Agent Service API.
|
||||
|
||||
This handler executes the multi-step agent flow:
|
||||
1. Create thread (or use existing)
|
||||
2. Add messages to thread
|
||||
3. Create and poll a run
|
||||
4. Retrieve the assistant's response messages
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
|
||||
|
||||
Authentication: Uses Azure AD Bearer tokens (not API keys)
|
||||
Get token via: az account get-access-token --resource 'https://ai.azure.com'
|
||||
|
||||
Supports both polling-based and native streaming (SSE) modes.
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure_ai.agents.transformation import (
|
||||
AzureAIAgentsConfig,
|
||||
AzureAIAgentsError,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HTTPHandler = Any
|
||||
AsyncHTTPHandler = Any
|
||||
|
||||
|
||||
class AzureAIAgentsHandler:
    """
    Handler for Azure AI Agent Service.

    Executes the complete agent flow which requires multiple API calls.
    """

    def __init__(self):
        # Shared config object: supplies the default API version, polling
        # limits, and the helpers that extract the agent id from the model.
        self.config = AzureAIAgentsConfig()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# URL Builders
|
||||
# -------------------------------------------------------------------------
|
||||
# Azure Foundry Agents API uses /assistants, /threads, etc. directly
|
||||
# See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
# -------------------------------------------------------------------------
|
||||
def _build_thread_url(self, api_base: str, api_version: str) -> str:
|
||||
return f"{api_base}/threads?api-version={api_version}"
|
||||
|
||||
def _build_messages_url(
|
||||
self, api_base: str, thread_id: str, api_version: str
|
||||
) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
|
||||
|
||||
def _build_runs_url(self, api_base: str, thread_id: str, api_version: str) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/runs?api-version={api_version}"
|
||||
|
||||
def _build_run_status_url(
|
||||
self, api_base: str, thread_id: str, run_id: str, api_version: str
|
||||
) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/runs/{run_id}?api-version={api_version}"
|
||||
|
||||
def _build_list_messages_url(
|
||||
self, api_base: str, thread_id: str, api_version: str
|
||||
) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
|
||||
|
||||
def _build_create_thread_and_run_url(self, api_base: str, api_version: str) -> str:
|
||||
"""URL for the create-thread-and-run endpoint (supports streaming)."""
|
||||
return f"{api_base}/threads/runs?api-version={api_version}"
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Response Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
def _extract_content_from_messages(self, messages_data: dict) -> str:
|
||||
"""Extract assistant content from the messages response."""
|
||||
for msg in messages_data.get("data", []):
|
||||
if msg.get("role") == "assistant":
|
||||
for content_item in msg.get("content", []):
|
||||
if content_item.get("type") == "text":
|
||||
return content_item.get("text", {}).get("value", "")
|
||||
return ""
|
||||
|
||||
    def _build_model_response(
        self,
        model: str,
        content: str,
        model_response: ModelResponse,
        thread_id: str,
        messages: List[Dict[str, Any]],
    ) -> ModelResponse:
        """Build the ModelResponse from agent output.

        Args:
            model: Model name echoed back on the response.
            content: Assistant text extracted from the thread's messages.
            model_response: Pre-allocated response object, populated in place.
            thread_id: Thread used for the run; stored in ``_hidden_params``
                so callers can continue the conversation.
            messages: Original request messages, used only for token estimates.

        Returns:
            The same ``model_response`` object, populated.
        """
        from litellm.types.utils import Choices, Message, Usage

        # Single assistant choice; the finish reason is always "stop" because
        # the polling flow only reaches this point after a completed run.
        model_response.choices = [
            Choices(
                finish_reason="stop",
                index=0,
                message=Message(content=content, role="assistant"),
            )
        ]
        model_response.model = model

        # Store thread_id for conversation continuity
        if (
            not hasattr(model_response, "_hidden_params")
            or model_response._hidden_params is None
        ):
            model_response._hidden_params = {}
        model_response._hidden_params["thread_id"] = thread_id

        # Estimate token usage locally with the gpt-3.5-turbo tokenizer;
        # any failure here is logged and deliberately non-fatal.
        try:
            from litellm.utils import token_counter

            prompt_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
            completion_tokens = token_counter(
                model="gpt-3.5-turbo", text=content, count_response_tokens=True
            )
            setattr(
                model_response,
                "usage",
                Usage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                ),
            )
        except Exception as e:
            verbose_logger.warning(f"Failed to calculate token usage: {str(e)}")

        return model_response
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
model: str,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
optional_params: dict,
|
||||
headers: Optional[dict],
|
||||
) -> tuple:
|
||||
"""Prepare common parameters for completion.
|
||||
|
||||
Azure Foundry Agents API uses Bearer token authentication:
|
||||
- Authorization: Bearer <token> (Azure AD token from 'az account get-access-token --resource https://ai.azure.com')
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
# Azure Foundry Agents uses Bearer token authentication
|
||||
# The api_key here is expected to be an Azure AD token
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
api_version = optional_params.get(
|
||||
"api_version", self.config.DEFAULT_API_VERSION
|
||||
)
|
||||
agent_id = self.config._get_agent_id(model, optional_params)
|
||||
thread_id = optional_params.get("thread_id")
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Azure AI Agents completion - api_base: {api_base}, agent_id: {agent_id}"
|
||||
)
|
||||
|
||||
return headers, api_version, agent_id, thread_id, api_base
|
||||
|
||||
def _check_response(
|
||||
self, response: httpx.Response, expected_codes: List[int], error_msg: str
|
||||
):
|
||||
"""Check response status and raise error if not expected."""
|
||||
if response.status_code not in expected_codes:
|
||||
raise AzureAIAgentsError(
|
||||
status_code=response.status_code,
|
||||
message=f"{error_msg}: {response.text}",
|
||||
)
|
||||
|
||||
    # -------------------------------------------------------------------------
    # Sync Completion
    # -------------------------------------------------------------------------
    def completion(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_base: str,
        api_key: str,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        litellm_params: dict,
        timeout: float,
        client: Optional[HTTPHandler] = None,
        headers: Optional[dict] = None,
    ) -> ModelResponse:
        """Execute synchronous completion using Azure Agent Service.

        Orchestrates: prepare params -> run the multi-step agent flow
        (thread / messages / run / poll / fetch) -> build the ModelResponse.

        Raises:
            AzureAIAgentsError: on any non-success API response or run failure
                (raised from the helpers this calls).
        """
        from litellm.llms.custom_httpx.http_handler import _get_httpx_client

        # Lazily create a sync HTTP client when the caller did not supply one.
        if client is None:
            client = _get_httpx_client(
                params={"ssl_verify": litellm_params.get("ssl_verify", None)}
            )

        (
            headers,
            api_version,
            agent_id,
            thread_id,
            api_base,
        ) = self._prepare_completion_params(
            model, api_base, api_key, optional_params, headers
        )

        # Small closure so the agent-flow helper stays transport-agnostic:
        # it only sees "make a GET/POST request", not the concrete client.
        def make_request(
            method: str, url: str, json_data: Optional[dict] = None
        ) -> httpx.Response:
            if method == "GET":
                return client.get(url=url, headers=headers)
            return client.post(
                url=url,
                headers=headers,
                data=json.dumps(json_data) if json_data else None,
            )

        # Execute the agent flow
        thread_id, content = self._execute_agent_flow_sync(
            make_request=make_request,
            api_base=api_base,
            api_version=api_version,
            agent_id=agent_id,
            thread_id=thread_id,
            messages=messages,
            optional_params=optional_params,
        )

        return self._build_model_response(
            model, content, model_response, thread_id, messages
        )
|
||||
|
||||
    def _execute_agent_flow_sync(
        self,
        make_request: Callable,
        api_base: str,
        api_version: str,
        agent_id: str,
        thread_id: Optional[str],
        messages: List[Dict[str, Any]],
        optional_params: dict,
    ) -> Tuple[str, str]:
        """Execute the agent flow synchronously. Returns (thread_id, content).

        Steps: create thread (if none given) -> post messages -> create run ->
        poll run status -> read back the assistant's reply.

        Raises:
            AzureAIAgentsError: on HTTP failures, a failed/cancelled/expired
                run (500), or when polling exceeds MAX_POLL_ATTEMPTS (408).
        """

        # Step 1: Create thread if not provided
        if not thread_id:
            verbose_logger.debug(
                f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
            )
            response = make_request(
                "POST", self._build_thread_url(api_base, api_version), {}
            )
            self._check_response(response, [200, 201], "Failed to create thread")
            thread_id = response.json()["id"]
            verbose_logger.debug(f"Created thread: {thread_id}")

        # At this point thread_id is guaranteed to be a string
        assert thread_id is not None

        # Step 2: Add messages to thread
        # NOTE(review): both user and system messages are forwarded with role
        # "user"; instructions are passed separately on the run below.
        # Confirm this role mapping is intended.
        for msg in messages:
            if msg.get("role") in ["user", "system"]:
                url = self._build_messages_url(api_base, thread_id, api_version)
                response = make_request(
                    "POST", url, {"role": "user", "content": msg.get("content", "")}
                )
                self._check_response(response, [200, 201], "Failed to add message")

        # Step 3: Create run
        run_payload = {"assistant_id": agent_id}
        if "instructions" in optional_params:
            run_payload["instructions"] = optional_params["instructions"]

        response = make_request(
            "POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
        )
        self._check_response(response, [200, 201], "Failed to create run")
        run_id = response.json()["id"]
        verbose_logger.debug(f"Created run: {run_id}")

        # Step 4: Poll for completion. Bounded loop; the for/else raises a
        # 408 when the run never reaches a terminal state.
        status_url = self._build_run_status_url(
            api_base, thread_id, run_id, api_version
        )
        for _ in range(self.config.MAX_POLL_ATTEMPTS):
            response = make_request("GET", status_url)
            self._check_response(response, [200], "Failed to get run status")

            status = response.json().get("status")
            verbose_logger.debug(f"Run status: {status}")

            if status == "completed":
                break
            elif status in ["failed", "cancelled", "expired"]:
                # Surface the service's own error message when available.
                error_msg = (
                    response.json()
                    .get("last_error", {})
                    .get("message", "Unknown error")
                )
                raise AzureAIAgentsError(
                    status_code=500, message=f"Run {status}: {error_msg}"
                )

            time.sleep(self.config.POLL_INTERVAL_SECONDS)
        else:
            raise AzureAIAgentsError(
                status_code=408, message="Run timed out waiting for completion"
            )

        # Step 5: Get messages
        response = make_request(
            "GET", self._build_list_messages_url(api_base, thread_id, api_version)
        )
        self._check_response(response, [200], "Failed to get messages")

        content = self._extract_content_from_messages(response.json())
        return thread_id, content
|
||||
|
||||
    # -------------------------------------------------------------------------
    # Async Completion
    # -------------------------------------------------------------------------
    async def acompletion(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_base: str,
        api_key: str,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        litellm_params: dict,
        timeout: float,
        client: Optional[AsyncHTTPHandler] = None,
        headers: Optional[dict] = None,
    ) -> ModelResponse:
        """Execute asynchronous completion using Azure Agent Service.

        Async mirror of ``completion``: prepare params -> run the agent flow
        -> build the ModelResponse.

        Raises:
            AzureAIAgentsError: on any non-success API response or run failure
                (raised from the helpers this calls).
        """
        import litellm
        from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

        # Lazily create an async HTTP client when the caller did not supply one.
        if client is None:
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.AZURE_AI,
                params={"ssl_verify": litellm_params.get("ssl_verify", None)},
            )

        (
            headers,
            api_version,
            agent_id,
            thread_id,
            api_base,
        ) = self._prepare_completion_params(
            model, api_base, api_key, optional_params, headers
        )

        # Transport-agnostic request closure, matching the sync path's shape.
        async def make_request(
            method: str, url: str, json_data: Optional[dict] = None
        ) -> httpx.Response:
            if method == "GET":
                return await client.get(url=url, headers=headers)
            return await client.post(
                url=url,
                headers=headers,
                data=json.dumps(json_data) if json_data else None,
            )

        # Execute the agent flow
        thread_id, content = await self._execute_agent_flow_async(
            make_request=make_request,
            api_base=api_base,
            api_version=api_version,
            agent_id=agent_id,
            thread_id=thread_id,
            messages=messages,
            optional_params=optional_params,
        )

        return self._build_model_response(
            model, content, model_response, thread_id, messages
        )
|
||||
|
||||
    async def _execute_agent_flow_async(
        self,
        make_request: Callable,
        api_base: str,
        api_version: str,
        agent_id: str,
        thread_id: Optional[str],
        messages: List[Dict[str, Any]],
        optional_params: dict,
    ) -> Tuple[str, str]:
        """Execute the agent flow asynchronously. Returns (thread_id, content).

        Async mirror of ``_execute_agent_flow_sync``: create thread (if none
        given) -> post messages -> create run -> poll -> read back the reply.

        Raises:
            AzureAIAgentsError: on HTTP failures, a failed/cancelled/expired
                run (500), or when polling exceeds MAX_POLL_ATTEMPTS (408).
        """

        # Step 1: Create thread if not provided
        if not thread_id:
            verbose_logger.debug(
                f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
            )
            response = await make_request(
                "POST", self._build_thread_url(api_base, api_version), {}
            )
            self._check_response(response, [200, 201], "Failed to create thread")
            thread_id = response.json()["id"]
            verbose_logger.debug(f"Created thread: {thread_id}")

        # At this point thread_id is guaranteed to be a string
        assert thread_id is not None

        # Step 2: Add messages to thread
        # NOTE(review): both user and system messages are forwarded with role
        # "user"; instructions are passed separately on the run below.
        # Confirm this role mapping is intended.
        for msg in messages:
            if msg.get("role") in ["user", "system"]:
                url = self._build_messages_url(api_base, thread_id, api_version)
                response = await make_request(
                    "POST", url, {"role": "user", "content": msg.get("content", "")}
                )
                self._check_response(response, [200, 201], "Failed to add message")

        # Step 3: Create run
        run_payload = {"assistant_id": agent_id}
        if "instructions" in optional_params:
            run_payload["instructions"] = optional_params["instructions"]

        response = await make_request(
            "POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
        )
        self._check_response(response, [200, 201], "Failed to create run")
        run_id = response.json()["id"]
        verbose_logger.debug(f"Created run: {run_id}")

        # Step 4: Poll for completion. Bounded loop; the for/else raises a
        # 408 when the run never reaches a terminal state.
        status_url = self._build_run_status_url(
            api_base, thread_id, run_id, api_version
        )
        for _ in range(self.config.MAX_POLL_ATTEMPTS):
            response = await make_request("GET", status_url)
            self._check_response(response, [200], "Failed to get run status")

            status = response.json().get("status")
            verbose_logger.debug(f"Run status: {status}")

            if status == "completed":
                break
            elif status in ["failed", "cancelled", "expired"]:
                # Surface the service's own error message when available.
                error_msg = (
                    response.json()
                    .get("last_error", {})
                    .get("message", "Unknown error")
                )
                raise AzureAIAgentsError(
                    status_code=500, message=f"Run {status}: {error_msg}"
                )

            await asyncio.sleep(self.config.POLL_INTERVAL_SECONDS)
        else:
            raise AzureAIAgentsError(
                status_code=408, message="Run timed out waiting for completion"
            )

        # Step 5: Get messages
        response = await make_request(
            "GET", self._build_list_messages_url(api_base, thread_id, api_version)
        )
        self._check_response(response, [200], "Failed to get messages")

        content = self._extract_content_from_messages(response.json())
        return thread_id, content
|
||||
|
||||
    # -------------------------------------------------------------------------
    # Streaming Completion (Native SSE)
    # -------------------------------------------------------------------------
    async def acompletion_stream(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_base: str,
        api_key: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        litellm_params: dict,
        timeout: float,
        headers: Optional[dict] = None,
    ) -> AsyncIterator:
        """Execute async streaming completion using Azure Agent Service with native SSE.

        Uses the create-thread-and-run endpoint with ``"stream": True`` and
        yields OpenAI-compatible chunks parsed from the SSE response.

        Raises:
            AzureAIAgentsError: when the streaming request is rejected.
        """
        import litellm
        from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

        (
            headers,
            api_version,
            agent_id,
            thread_id,
            api_base,
        ) = self._prepare_completion_params(
            model, api_base, api_key, optional_params, headers
        )

        # Build payload for create-thread-and-run with streaming.
        # NOTE(review): as in the polling path, system messages are forwarded
        # with role "user" — confirm intended.
        thread_messages = []
        for msg in messages:
            if msg.get("role") in ["user", "system"]:
                thread_messages.append(
                    {"role": "user", "content": msg.get("content", "")}
                )

        payload: Dict[str, Any] = {
            "assistant_id": agent_id,
            "stream": True,
        }

        # Add thread with messages if we don't have an existing thread.
        # NOTE(review): when a thread_id IS provided, the new messages are not
        # attached to the payload anywhere in this function — verify the
        # continuation behavior is intended.
        if not thread_id:
            payload["thread"] = {"messages": thread_messages}

        if "instructions" in optional_params:
            payload["instructions"] = optional_params["instructions"]

        url = self._build_create_thread_and_run_url(api_base, api_version)
        verbose_logger.debug(f"Azure AI Agents streaming - URL: {url}")

        # Use LiteLLM's async HTTP client for streaming
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.AZURE_AI,
            params={"ssl_verify": litellm_params.get("ssl_verify", None)},
        )

        response = await client.post(
            url=url,
            headers=headers,
            data=json.dumps(payload),
            stream=True,
        )

        if response.status_code not in [200, 201]:
            # Drain the body so the server's error text can be surfaced.
            error_text = await response.aread()
            raise AzureAIAgentsError(
                status_code=response.status_code,
                message=f"Streaming request failed: {error_text.decode()}",
            )

        async for chunk in self._process_sse_stream(response, model):
            yield chunk
|
||||
|
||||
    async def _process_sse_stream(
        self,
        response: httpx.Response,
        model: str,
    ) -> AsyncIterator:
        """Process SSE stream and yield OpenAI-compatible streaming chunks.

        Tracks the most recent ``event:`` name and parses each ``data:`` line:
        - ``thread.created``       -> captures thread_id for ``_hidden_params``
        - ``thread.message.delta`` -> yields one chunk per non-empty text delta
        - ``[DONE]``               -> yields a final chunk with finish_reason="stop"
        """
        from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices

        # Synthetic id/timestamp shared by every chunk of this stream.
        response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
        created = int(time.time())
        thread_id = None

        # Name of the last "event:" line seen; applies to subsequent "data:" lines.
        current_event = None

        async for line in response.aiter_lines():
            line = line.strip()

            if line.startswith("event:"):
                current_event = line[6:].strip()
                continue

            if line.startswith("data:"):
                data_str = line[5:].strip()

                if data_str == "[DONE]":
                    # Send final chunk with finish_reason
                    final_chunk = ModelResponseStream(
                        id=response_id,
                        created=created,
                        model=model,
                        object="chat.completion.chunk",
                        choices=[
                            StreamingChoices(
                                finish_reason="stop",
                                index=0,
                                delta=Delta(content=None),
                            )
                        ],
                    )
                    if thread_id:
                        final_chunk._hidden_params = {"thread_id": thread_id}
                    yield final_chunk
                    return

                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed/partial payloads rather than abort the stream.
                    continue

                # Extract thread_id from thread.created event
                if current_event == "thread.created" and "id" in data:
                    thread_id = data["id"]
                    verbose_logger.debug(f"Stream created thread: {thread_id}")

                # Process message deltas - this is where the actual content comes
                if current_event == "thread.message.delta":
                    delta_content = data.get("delta", {}).get("content", [])
                    for content_item in delta_content:
                        if content_item.get("type") == "text":
                            text_value = content_item.get("text", {}).get("value", "")
                            if text_value:
                                chunk = ModelResponseStream(
                                    id=response_id,
                                    created=created,
                                    model=model,
                                    object="chat.completion.chunk",
                                    choices=[
                                        StreamingChoices(
                                            finish_reason=None,
                                            index=0,
                                            delta=Delta(
                                                content=text_value, role="assistant"
                                            ),
                                        )
                                    ],
                                )
                                if thread_id:
                                    chunk._hidden_params = {"thread_id": thread_id}
                                yield chunk
|
||||
|
||||
|
||||
# Singleton instance
# Module-level handler shared across calls; it holds no per-request state
# (only its config object), so reusing one instance appears safe.
azure_ai_agents_handler = AzureAIAgentsHandler()
|
||||
@@ -0,0 +1,402 @@
|
||||
"""
|
||||
Transformation for Azure Foundry Agent Service API.
|
||||
|
||||
Azure Foundry Agent Service provides an Assistants-like API for running agents.
|
||||
This follows the OpenAI Assistants pattern: create thread -> add messages -> create/poll run.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
|
||||
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
|
||||
|
||||
Authentication: Uses Azure AD Bearer tokens (not API keys)
|
||||
Get token via: az account get-access-token --resource 'https://ai.azure.com'
|
||||
|
||||
The API uses these endpoints:
|
||||
- POST /threads - Create a thread
|
||||
- POST /threads/{thread_id}/messages - Add message to thread
|
||||
- POST /threads/{thread_id}/runs - Create a run
|
||||
- GET /threads/{thread_id}/runs/{run_id} - Poll run status
|
||||
- GET /threads/{thread_id}/messages - List messages in thread
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HTTPHandler = Any
|
||||
AsyncHTTPHandler = Any
|
||||
|
||||
|
||||
class AzureAIAgentsError(BaseLLMException):
    """Raised for Azure AI Agent Service API failures.

    Constructed throughout the handler with ``status_code`` and ``message``
    keyword arguments, forwarded to :class:`BaseLLMException`.
    """
|
||||
|
||||
|
||||
class AzureAIAgentsConfig(BaseConfig):
    """
    Configuration for Azure AI Agent Service API.

    Azure AI Agent Service is a fully managed service for building AI agents
    that can understand natural language and perform tasks.

    Model format: azure_ai/agents/<agent_id>

    The flow is:
    1. Create a thread
    2. Add user messages to the thread
    3. Create and poll a run
    4. Retrieve the assistant's response messages
    """

    # Default API version for Azure Foundry Agent Service
    # GA version: 2025-05-01, Preview: 2025-05-15-preview
    # See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    DEFAULT_API_VERSION = "2025-05-01"

    # Polling configuration
    MAX_POLL_ATTEMPTS = 60  # upper bound on run-status polls (~60s at 1s interval)
    POLL_INTERVAL_SECONDS = 1.0  # delay between successive status polls

    def __init__(self, **kwargs):
        # No extra state beyond BaseConfig; kwargs are forwarded unchanged.
        super().__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def is_azure_ai_agents_route(model: str) -> bool:
|
||||
"""
|
||||
Check if the model is an Azure AI Agents route.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
"""
|
||||
return "agents/" in model
|
||||
|
||||
@staticmethod
|
||||
def get_agent_id_from_model(model: str) -> str:
|
||||
"""
|
||||
Extract agent ID from the model string.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id> -> <agent_id>
|
||||
or: agents/<agent_id> -> <agent_id>
|
||||
"""
|
||||
if "agents/" in model:
|
||||
# Split on "agents/" and take the part after it
|
||||
parts = model.split("agents/", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[1]
|
||||
return model
|
||||
|
||||
    def _get_openai_compatible_provider_info(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Get Azure AI Agent Service API base and key from params or environment.

        Explicit arguments win; otherwise falls back to the
        AZURE_AI_API_BASE / AZURE_AI_API_KEY secrets.

        Returns:
            Tuple of (api_base, api_key); either may still be None.
        """
        from litellm.secret_managers.main import get_secret_str

        api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
        api_key = api_key or get_secret_str("AZURE_AI_API_KEY")

        return api_base, api_key
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""
|
||||
Azure Agents supports minimal OpenAI params since it's an agent runtime.
|
||||
"""
|
||||
return ["stream"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI params to Azure Agents params.
|
||||
"""
|
||||
return optional_params
|
||||
|
||||
def _get_api_version(self, optional_params: dict) -> str:
|
||||
"""Get API version from optional params or use default."""
|
||||
return optional_params.get("api_version", self.DEFAULT_API_VERSION)
|
||||
|
||||
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the base URL for Azure AI Agent Service.

        The actual endpoint varies per operation and is appended by the
        handler's URL builders:
        - /threads for creating threads
        - /threads/{thread_id}/messages for adding messages
        - /threads/{thread_id}/runs for creating runs

        This returns the base URL that will be modified for each operation.

        Raises:
            ValueError: when no api_base is configured.
        """
        if api_base is None:
            raise ValueError(
                "api_base is required for Azure AI Agents. Set it via AZURE_AI_API_BASE env var or api_base parameter."
            )

        # Remove trailing slash if present
        api_base = api_base.rstrip("/")

        # Return base URL - actual endpoints will be constructed during request
        return api_base
|
||||
|
||||
def _get_agent_id(self, model: str, optional_params: dict) -> str:
|
||||
"""
|
||||
Get the agent ID from model or optional_params.
|
||||
|
||||
model format: "azure_ai/agents/<agent_id>" or "agents/<agent_id>" or just "<agent_id>"
|
||||
"""
|
||||
agent_id = optional_params.get("agent_id") or optional_params.get(
|
||||
"assistant_id"
|
||||
)
|
||||
if agent_id:
|
||||
return agent_id
|
||||
|
||||
# Extract from model name using the static method
|
||||
return self.get_agent_id_from_model(model)
|
||||
|
||||
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the request for Azure Agents.

        This stores the necessary data for the multi-step agent flow.
        The actual API calls happen in the custom handler.

        Returns:
            Payload dict with agent_id, normalized string-content messages,
            api_version, and (when present) thread_id / instructions.
        """
        agent_id = self._get_agent_id(model, optional_params)

        # Convert messages to a format we can use
        converted_messages = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            # Handle content that might be a list (multi-part content blocks)
            if isinstance(content, list):
                content = convert_content_list_to_str(msg)

            # Ensure content is a string
            if not isinstance(content, str):
                content = str(content)

            converted_messages.append({"role": role, "content": content})

        payload: Dict[str, Any] = {
            "agent_id": agent_id,
            "messages": converted_messages,
            "api_version": self._get_api_version(optional_params),
        }

        # Pass through thread_id if provided (for continuing conversations)
        if "thread_id" in optional_params:
            payload["thread_id"] = optional_params["thread_id"]

        # Pass through any additional instructions
        if "instructions" in optional_params:
            payload["instructions"] = optional_params["instructions"]

        verbose_logger.debug(f"Azure AI Agents request payload: {payload}")
        return payload
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate and set up environment for Azure Foundry Agents requests.
|
||||
|
||||
Azure Foundry Agents uses Bearer token authentication with Azure AD tokens.
|
||||
Get token via: az account get-access-token --resource 'https://ai.azure.com'
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
# Azure Foundry Agents uses Bearer token authentication
|
||||
# The api_key here is expected to be an Azure AD token
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Wrap an upstream HTTP failure in the provider-specific exception type."""
    error = AzureAIAgentsError(message=error_message, status_code=status_code)
    return error
|
||||
|
||||
def should_fake_stream(
    self,
    model: Optional[str],
    stream: Optional[bool],
    custom_llm_provider: Optional[str] = None,
) -> bool:
    """
    Always True: the agent flow is executed by polling a run to completion,
    so "streaming" is simulated from the final materialized response.
    """
    return True
|
||||
|
||||
@property
def has_custom_stream_wrapper(self) -> bool:
    """False — there is no native SSE wrapper here; fake streaming is used."""
    return False
|
||||
|
||||
@property
def supports_stream_param_in_request_body(self) -> bool:
    """False — the agents request payload carries no ``stream`` field."""
    return False
|
||||
|
||||
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    No-op transformation hook.

    The Azure Agents flow is driven by a dedicated handler that builds the
    ModelResponse itself, so this base-config hook simply passes the
    response object through unchanged.
    """
    return model_response
|
||||
|
||||
@staticmethod
def completion(
    model: str,
    messages: List,
    api_base: str,
    api_key: Optional[str],
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: Union[float, int, Any],
    acompletion: bool,
    stream: Optional[bool] = False,
    headers: Optional[dict] = None,
) -> Any:
    """
    Dispatch an Azure Foundry Agents completion request.

    Routing:
      - ``acompletion and stream``  -> native async SSE streaming
      - ``acompletion``             -> async polling completion
      - otherwise                   -> sync polling completion (sync streaming
        is not supported)

    Authentication uses Azure AD Bearer tokens:
      - pass ``api_key`` directly as an Azure AD token, or
      - configure Service Principal credentials (AZURE_TENANT_ID,
        AZURE_CLIENT_ID, AZURE_CLIENT_SECRET) for automatic retrieval.

    Raises:
        ValueError: when no token is supplied and none can be acquired.

    See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    """
    from litellm.llms.azure.common_utils import get_azure_ad_token
    from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
    from litellm.types.router import GenericLiteLLMParams

    if api_key is None:
        # Fall back to Azure AD auth, overriding the token scope to the
        # Azure AI resource (instead of the default cognitive-services scope).
        auth_params = dict(litellm_params) if litellm_params else {}
        auth_params["azure_scope"] = "https://ai.azure.com/.default"
        api_key = get_azure_ad_token(GenericLiteLLMParams(**auth_params))

    if api_key is None:
        raise ValueError(
            "api_key (Azure AD token) is required for Azure Foundry Agents. "
            "Either pass api_key directly, or set AZURE_TENANT_ID, AZURE_CLIENT_ID, "
            "and AZURE_CLIENT_SECRET environment variables for Service Principal auth. "
            "Manual token: az account get-access-token --resource 'https://ai.azure.com'"
        )

    # Arguments common to every dispatch target.
    shared_kwargs = dict(
        model=model,
        messages=messages,
        api_base=api_base,
        api_key=api_key,
        logging_obj=logging_obj,
        optional_params=optional_params,
        litellm_params=litellm_params,
        timeout=timeout,
        headers=headers,
    )

    if acompletion and stream:
        # Native async streaming via SSE - return the async generator directly.
        return azure_ai_agents_handler.acompletion_stream(**shared_kwargs)
    if acompletion:
        return azure_ai_agents_handler.acompletion(
            model_response=model_response, **shared_kwargs
        )
    # Sync path (streaming not supported for sync).
    return azure_ai_agents_handler.completion(
        model_response=model_response, **shared_kwargs
    )
|
||||
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Azure Anthropic provider - supports Claude models via Azure Foundry
|
||||
"""
|
||||
from .handler import AzureAnthropicChatCompletion
|
||||
from .transformation import AzureAnthropicConfig
|
||||
|
||||
# The messages (/v1/messages) transformation is imported defensively: if it
# (or a transitive dependency) fails to import, the package still exposes
# the chat-completion handler and config.
try:
    from .messages_transformation import AzureAnthropicMessagesConfig

    __all__ = [
        "AzureAnthropicChatCompletion",
        "AzureAnthropicConfig",
        "AzureAnthropicMessagesConfig",
    ]
except ImportError:
    # Fall back to exporting only the core chat-completion API.
    __all__ = ["AzureAnthropicChatCompletion", "AzureAnthropicConfig"]
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Azure AI Anthropic CountTokens API implementation.
|
||||
"""
|
||||
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.handler import (
|
||||
AzureAIAnthropicCountTokensHandler,
|
||||
)
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.token_counter import (
|
||||
AzureAIAnthropicTokenCounter,
|
||||
)
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.transformation import (
|
||||
AzureAIAnthropicCountTokensConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureAIAnthropicCountTokensHandler",
|
||||
"AzureAIAnthropicCountTokensConfig",
|
||||
"AzureAIAnthropicTokenCounter",
|
||||
]
|
||||
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Azure AI Anthropic CountTokens API handler.
|
||||
|
||||
Uses httpx for HTTP requests with Azure authentication.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.anthropic.common_utils import AnthropicError
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.transformation import (
|
||||
AzureAIAnthropicCountTokensConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
|
||||
|
||||
class AzureAIAnthropicCountTokensHandler(AzureAIAnthropicCountTokensConfig):
    """
    Executes CountTokens requests against the Azure AI Anthropic endpoint.

    Inherits request/endpoint transformation from
    AzureAIAnthropicCountTokensConfig and adds the async HTTP round trip
    (via LiteLLM's shared async httpx client) with Azure authentication.
    """

    async def handle_count_tokens_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_key: str,
        api_base: str,
        litellm_params: Optional[Dict[str, Any]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Perform one CountTokens call and return the raw Anthropic-format body.

        Args:
            model: The model identifier (e.g., "claude-3-5-sonnet").
            messages: The messages to count tokens for.
            api_key: The Azure AI API key.
            api_base: The Azure AI API base URL.
            litellm_params: Optional LiteLLM parameters (feeds auth headers).
            timeout: Optional request timeout (defaults to litellm.request_timeout).
            tools: Optional tool definitions to include in the count.
            system: Optional system prompt to include in the count.

        Returns:
            Dictionary containing the token count response.

        Raises:
            AnthropicError: On validation failure, a non-200 response, an
                httpx status error (original status preserved), or any other
                failure (surfaced with status 500).
        """
        try:
            self.validate_request(model, messages)

            verbose_logger.debug(
                f"Processing Azure AI Anthropic CountTokens request for model: {model}"
            )

            # Build the Anthropic-format request body.
            payload = self.transform_request_to_count_tokens(
                model=model,
                messages=messages,
                tools=tools,
                system=system,
            )
            verbose_logger.debug(f"Transformed request: {payload}")

            url = self.get_count_tokens_endpoint(api_base)
            verbose_logger.debug(f"Making request to: {url}")

            # Headers carry both x-api-key and Azure auth.
            request_headers = self.get_required_headers(
                api_key=api_key,
                litellm_params=litellm_params,
            )

            # Shared LiteLLM async httpx client for the azure_ai provider.
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.AZURE_AI
            )

            effective_timeout = (
                litellm.request_timeout if timeout is None else timeout
            )

            response = await client.post(
                url,
                headers=request_headers,
                json=payload,
                timeout=effective_timeout,
            )
            verbose_logger.debug(f"Response status: {response.status_code}")

            if response.status_code != 200:
                body_text = response.text
                verbose_logger.error(f"Azure AI Anthropic API error: {body_text}")
                raise AnthropicError(
                    status_code=response.status_code,
                    message=body_text,
                )

            parsed = response.json()
            verbose_logger.debug(f"Azure AI Anthropic response: {parsed}")

            # Already Anthropic-compatible - return untouched.
            return parsed

        except AnthropicError:
            # Already the provider-specific error type; propagate unchanged.
            raise
        except httpx.HTTPStatusError as e:
            # Preserve the upstream status code rather than collapsing to 500.
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except Exception as e:
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )
|
||||
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Azure AI Anthropic Token Counter implementation using the CountTokens API.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.handler import (
|
||||
AzureAIAnthropicCountTokensHandler,
|
||||
)
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.types.utils import LlmProviders, TokenCountResponse
|
||||
|
||||
# Global handler instance - reuse across all token counting requests
|
||||
azure_ai_anthropic_count_tokens_handler = AzureAIAnthropicCountTokensHandler()
|
||||
|
||||
|
||||
class AzureAIAnthropicTokenCounter(BaseTokenCounter):
    """Token counter for the Azure AI Anthropic provider, backed by the CountTokens API."""

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Use the remote counting API only for the azure_ai provider."""
        return custom_llm_provider == LlmProviders.AZURE_AI.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens via Azure AI Anthropic's CountTokens API.

        Args:
            model_to_use: The model identifier.
            messages: The messages to count tokens for.
            contents: Alternative content format (unused for Anthropic).
            deployment: Deployment configuration containing litellm_params.
            request_model: The original request model name.
            tools: Optional tool definitions to include in the count.
            system: Optional system prompt to include in the count.

        Returns:
            TokenCountResponse with the count, an error-flagged
            TokenCountResponse on API failure, or None when counting is not
            possible (no messages / missing credentials).
        """
        from litellm.llms.anthropic.common_utils import AnthropicError

        if not messages:
            return None

        params = (deployment or {}).get("litellm_params", {})

        # Credentials: deployment config wins; environment is the fallback.
        api_key = params.get("api_key") or os.getenv("AZURE_AI_API_KEY")
        api_base = params.get("api_base") or os.getenv("AZURE_AI_API_BASE")

        if not api_key:
            verbose_logger.warning("No Azure AI API key found for token counting")
            return None
        if not api_base:
            verbose_logger.warning("No Azure AI API base found for token counting")
            return None

        try:
            raw = await azure_ai_anthropic_count_tokens_handler.handle_count_tokens_request(
                model=model_to_use,
                messages=messages,
                api_key=api_key,
                api_base=api_base,
                litellm_params=params,
                tools=tools,
                system=system,
            )
            if raw is not None:
                return TokenCountResponse(
                    total_tokens=raw.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type="azure_ai_anthropic_api",
                    original_response=raw,
                )
        except AnthropicError as e:
            verbose_logger.warning(
                f"Azure AI Anthropic CountTokens API error: status={e.status_code}, message={e.message}"
            )
            # Surface the failure as an error-flagged response so callers can
            # distinguish "API failed" from "counting not applicable" (None).
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="azure_ai_anthropic_api",
                error=True,
                error_message=e.message,
                status_code=e.status_code,
            )
        except Exception as e:
            verbose_logger.warning(
                f"Error calling Azure AI Anthropic CountTokens API: {e}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="azure_ai_anthropic_api",
                error=True,
                error_message=str(e),
                status_code=500,
            )

        # Handler completed but produced no result.
        return None
|
||||
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Azure AI Anthropic CountTokens API transformation logic.
|
||||
|
||||
Extends the base Anthropic CountTokens transformation with Azure authentication.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from litellm.constants import ANTHROPIC_TOKEN_COUNTING_BETA_VERSION
|
||||
from litellm.llms.anthropic.count_tokens.transformation import (
|
||||
AnthropicCountTokensConfig,
|
||||
)
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
|
||||
class AzureAIAnthropicCountTokensConfig(AnthropicCountTokensConfig):
    """
    Configuration and transformation logic for Azure AI Anthropic CountTokens API.

    Extends AnthropicCountTokensConfig with Azure authentication.
    Azure AI Anthropic uses the same endpoint format but with Azure auth headers.
    """

    def get_required_headers(
        self,
        api_key: str,
        litellm_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, str]:
        """
        Get the required headers for the Azure AI Anthropic CountTokens API.

        Azure AI Anthropic uses Anthropic's native API format, which requires the
        x-api-key header for authentication (in addition to Azure's api-key header).

        Args:
            api_key: The Azure AI API key
            litellm_params: Optional LiteLLM parameters for additional auth config
                (read-only; never mutated)

        Returns:
            Dictionary of required headers with both x-api-key and Azure authentication
        """
        # Base headers including x-api-key for Anthropic API compatibility.
        headers = {
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
            "anthropic-beta": ANTHROPIC_TOKEN_COUNTING_BETA_VERSION,
            "x-api-key": api_key,  # Azure AI Anthropic requires this header
        }

        # BUGFIX: work on a shallow copy so the caller's dict is never mutated
        # when injecting api_key (previously the key leaked back into the
        # caller's litellm_params).
        litellm_params = dict(litellm_params) if litellm_params else {}
        litellm_params.setdefault("api_key", api_key)

        litellm_params_obj = GenericLiteLLMParams(**litellm_params)

        # Get Azure auth headers (api-key or Authorization) and merge them in.
        azure_headers = BaseAzureLLM._base_validate_azure_environment(
            headers={}, litellm_params=litellm_params_obj
        )
        headers.update(azure_headers)

        return headers

    def get_count_tokens_endpoint(self, api_base: str) -> str:
        """
        Get the Azure AI Anthropic CountTokens API endpoint.

        Args:
            api_base: The Azure AI API base URL
                (e.g., https://my-resource.services.ai.azure.com or
                https://my-resource.services.ai.azure.com/anthropic)

        Returns:
            The endpoint URL for the CountTokens API
        """
        # Target format:
        # https://<resource>.services.ai.azure.com/anthropic/v1/messages/count_tokens
        api_base = api_base.rstrip("/")

        # Append /anthropic only when it appears nowhere in the URL.
        # NOTE(review): a base like ".../anthropic/extra" passes through
        # unchanged here — confirm that is intended for nested paths.
        if not api_base.endswith("/anthropic"):
            if "/anthropic" not in api_base:
                api_base = f"{api_base}/anthropic"

        # Add the count_tokens path
        return f"{api_base}/v1/messages/count_tokens"
|
||||
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
Azure Anthropic handler - reuses AnthropicChatCompletion logic with Azure authentication
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Callable, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
from .transformation import AzureAnthropicConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AzureAnthropicChatCompletion(AnthropicChatCompletion):
    """
    Azure Anthropic chat completion handler.

    Reuses all Anthropic chat logic (request/response transformation,
    streaming) but authenticates with Azure via AzureAnthropicConfig.
    """

    def __init__(self) -> None:
        super().__init__()

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        acompletion=None,
        logger_fn=None,
        headers=None,
        client=None,
    ):
        """
        Completion entry point using Azure authentication instead of
        Anthropic's x-api-key; all other logic mirrors AnthropicChatCompletion.

        Dispatches on (acompletion, stream):
          - async + stream -> acompletion_stream_function (SSE)
          - async          -> acompletion_function
          - sync + stream  -> make_sync_call wrapped in CustomStreamWrapper
          - sync           -> direct POST + transform_response

        Raises:
            AnthropicError: when the synchronous POST fails.
        """
        # BUGFIX: `headers` previously defaulted to a mutable `{}` shared
        # across calls; default to None and create a fresh dict per call.
        if headers is None:
            headers = {}

        # Deep-copy mutable inputs so the pops below don't leak to callers.
        optional_params = copy.deepcopy(optional_params)
        stream = optional_params.pop("stream", None)
        json_mode: bool = optional_params.pop("json_mode", False)
        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
        _is_function_call = False
        messages = copy.deepcopy(messages)

        # Use AzureAnthropicConfig for both azure_anthropic and azure_ai Claude models
        config = AzureAnthropicConfig()

        headers = config.validate_environment(
            api_key=api_key,
            headers=headers,
            model=model,
            messages=messages,
            optional_params={**optional_params, "is_vertex_request": is_vertex_request},
            litellm_params=litellm_params,
        )

        data = config.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": headers,
            },
        )
        print_verbose(f"_is_function_call: {_is_function_call}")
        if acompletion is True:
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                print_verbose("makes async azure anthropic streaming POST request")
                data["stream"] = stream
                return self.acompletion_stream_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    _is_function_call=_is_function_call,
                    json_mode=json_mode,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                    # Only reuse a caller-supplied client if it is async.
                    client=(
                        client
                        if client is not None and isinstance(client, AsyncHTTPHandler)
                        else None
                    ),
                )
            else:
                return self.acompletion_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    provider_config=config,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    _is_function_call=_is_function_call,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    client=client,
                    json_mode=json_mode,
                    timeout=timeout,
                )
        else:
            ## COMPLETION CALL
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                data["stream"] = stream
                # Import the make_sync_call from parent
                from litellm.llms.anthropic.chat.handler import make_sync_call

                completion_stream, response_headers = make_sync_call(
                    client=client,
                    api_base=api_base,
                    headers=headers,  # type: ignore
                    data=json.dumps(data),
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                    timeout=timeout,
                    json_mode=json_mode,
                )
                from litellm.llms.anthropic.common_utils import (
                    process_anthropic_headers,
                )

                return CustomStreamWrapper(
                    completion_stream=completion_stream,
                    model=model,
                    custom_llm_provider="azure_ai",
                    logging_obj=logging_obj,
                    _response_headers=process_anthropic_headers(response_headers),
                )

            else:
                # Fall back to a fresh sync client unless the caller supplied
                # a usable HTTPHandler.
                if client is None or not isinstance(client, HTTPHandler):
                    from litellm.llms.custom_httpx.http_handler import _get_httpx_client

                    client = _get_httpx_client(params={"timeout": timeout})

                try:
                    response = client.post(
                        api_base,
                        headers=headers,
                        data=json.dumps(data),
                        timeout=timeout,
                    )
                except Exception as e:
                    from litellm.llms.anthropic.common_utils import AnthropicError

                    # Pull as much detail off the underlying httpx error as is
                    # available (status, headers, response text).
                    status_code = getattr(e, "status_code", 500)
                    error_headers = getattr(e, "headers", None)
                    error_text = getattr(e, "text", str(e))
                    error_response = getattr(e, "response", None)
                    if error_headers is None and error_response:
                        error_headers = getattr(error_response, "headers", None)
                    if error_response and hasattr(error_response, "text"):
                        error_text = getattr(error_response, "text", error_text)
                    raise AnthropicError(
                        message=error_text,
                        status_code=status_code,
                        headers=error_headers,
                    )

                return config.transform_response(
                    model=model,
                    raw_response=response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    api_key=api_key,
                    request_data=data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    json_mode=json_mode,
                )
|
||||
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Azure Anthropic messages transformation config - extends AnthropicMessagesConfig with Azure authentication
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
|
||||
AnthropicMessagesConfig,
|
||||
)
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AzureAnthropicMessagesConfig(AnthropicMessagesConfig):
    """
    Azure Anthropic messages configuration that extends AnthropicMessagesConfig.
    The only difference is authentication - Azure uses x-api-key header (not api-key)
    and Azure endpoint format.
    """

    def validate_anthropic_messages_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        """
        Validate environment and set up Azure authentication headers for /v1/messages endpoint.
        Azure Anthropic uses x-api-key header (not api-key).

        Returns:
            Tuple of (headers with auth/version/content-type applied, api_base unchanged).
        """
        # Convert dict to GenericLiteLLMParams if needed
        if isinstance(litellm_params, dict):
            # Merge api_key via a new dict so the caller's dict is untouched.
            if api_key and "api_key" not in litellm_params:
                litellm_params = {**litellm_params, "api_key": api_key}
            litellm_params_obj = GenericLiteLLMParams(**litellm_params)
        else:
            litellm_params_obj = litellm_params or GenericLiteLLMParams()
            # NOTE(review): this sets api_key on the caller's params object
            # in place — confirm that side effect is acceptable.
            if api_key and not litellm_params_obj.api_key:
                litellm_params_obj.api_key = api_key

        # Use Azure authentication logic (fills api-key or Authorization).
        headers = BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params_obj
        )

        # Azure Anthropic uses x-api-key header (not api-key)
        # Convert api-key to x-api-key if present
        if "api-key" in headers and "x-api-key" not in headers:
            headers["x-api-key"] = headers.pop("api-key")

        # Set anthropic-version header
        if "anthropic-version" not in headers:
            headers["anthropic-version"] = "2023-06-01"

        # Set content-type header
        if "content-type" not in headers:
            headers["content-type"] = "application/json"

        # Let the base config add any anthropic-beta feature headers implied
        # by optional_params.
        headers = self._update_headers_with_anthropic_beta(
            headers=headers,
            optional_params=optional_params,
        )

        return headers, api_base

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for Azure Anthropic /v1/messages endpoint.
        Azure Foundry endpoint format: https://<resource-name>.services.ai.azure.com/anthropic/v1/messages

        Raises:
            ValueError: when neither api_base nor AZURE_API_BASE is set.
        """
        from litellm.secret_managers.main import get_secret_str

        api_base = api_base or get_secret_str("AZURE_API_BASE")
        if api_base is None:
            raise ValueError(
                "Missing Azure API Base - Please set `api_base` or `AZURE_API_BASE` environment variable. "
                "Expected format: https://<resource-name>.services.ai.azure.com/anthropic"
            )

        # Ensure the URL ends with /v1/messages
        api_base = api_base.rstrip("/")
        if api_base.endswith("/v1/messages"):
            # Already correct
            pass
        elif api_base.endswith("/anthropic/v1/messages"):
            # Already correct (unreachable in practice: the previous branch
            # also matches this suffix; kept for explicitness)
            pass
        else:
            # Check if /anthropic is already in the path
            if "/anthropic" in api_base:
                # /anthropic exists, ensure we end with /anthropic/v1/messages
                # Extract the base URL up to and including /anthropic
                # (anything after the first /anthropic is discarded)
                parts = api_base.split("/anthropic", 1)
                api_base = parts[0] + "/anthropic"
            else:
                # /anthropic not in path, add it
                api_base = api_base + "/anthropic"
            # Add /v1/messages
            api_base = api_base + "/v1/messages"

        return api_base

    def _remove_scope_from_cache_control(
        self, anthropic_messages_request: Dict
    ) -> None:
        """
        Remove `scope` field from cache_control for Azure AI Foundry.

        Azure AI Foundry's Anthropic endpoint does not support the `scope` field
        (e.g., "global" for cross-request caching). Only `type` and `ttl` are supported.

        Processes both `system` and `messages` content blocks. Mutates the
        request dict in place; returns None.
        """

        def _sanitize(cache_control: Any) -> None:
            # Drop the unsupported key; tolerate non-dict values silently.
            if isinstance(cache_control, dict):
                cache_control.pop("scope", None)

        def _process_content_list(content: list) -> None:
            # Sanitize every content block that carries a cache_control.
            for item in content:
                if isinstance(item, dict) and "cache_control" in item:
                    _sanitize(item["cache_control"])

        if "system" in anthropic_messages_request:
            system = anthropic_messages_request["system"]
            if isinstance(system, list):
                _process_content_list(system)

        if "messages" in anthropic_messages_request:
            for message in anthropic_messages_request["messages"]:
                if isinstance(message, dict) and "content" in message:
                    content = message["content"]
                    if isinstance(content, list):
                        _process_content_list(content)

    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Build the Anthropic /v1/messages request via the base config, then
        strip cache_control `scope` fields that Azure AI Foundry rejects.
        """
        anthropic_messages_request = super().transform_anthropic_messages_request(
            model=model,
            messages=messages,
            anthropic_messages_optional_request_params=anthropic_messages_optional_request_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        self._remove_scope_from_cache_control(anthropic_messages_request)
        return anthropic_messages_request
|
||||
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Azure Anthropic transformation config - extends AnthropicConfig with Azure authentication
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AzureAnthropicConfig(AnthropicConfig):
    """
    Azure Anthropic configuration that extends AnthropicConfig.

    The only difference is authentication - Azure uses the `api-key` header or
    an Azure AD token (`Authorization: Bearer <token>`) instead of the
    Anthropic `x-api-key` header.
    """

    @property
    def custom_llm_provider(self) -> Optional[str]:
        # Identifies this config as the Azure AI flavour of Anthropic.
        return "azure_ai"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: Union[dict, GenericLiteLLMParams],
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        """
        Validate environment and set up Azure authentication headers.

        Azure supports:
        1. API key via 'api-key' header
        2. Azure AD token via 'Authorization: Bearer <token>' header

        Returns the merged header dict; Azure auth headers take precedence
        over the Anthropic defaults computed here.
        """
        # Normalize litellm_params to a GenericLiteLLMParams object so the
        # Azure auth helper below gets a uniform input.
        if isinstance(litellm_params, dict):
            # Ensure api_key is included if provided
            if api_key and "api_key" not in litellm_params:
                litellm_params = {**litellm_params, "api_key": api_key}
            litellm_params_obj = GenericLiteLLMParams(**litellm_params)
        else:
            litellm_params_obj = litellm_params or GenericLiteLLMParams()
        # Set api_key if provided and not already set on the params object
        if api_key and not litellm_params_obj.api_key:
            litellm_params_obj.api_key = api_key

        # Azure authentication: populates either `api-key` or an
        # Azure AD `Authorization: Bearer` header.
        headers = BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params_obj
        )

        # Gather Anthropic feature flags that decide which beta headers
        # the request needs (prompt caching, computer tools, PDFs, files, MCP).
        tools = optional_params.get("tools")
        prompt_caching_set = self.is_cache_control_set(messages=messages)
        computer_tool_used = self.is_computer_tool_used(tools=tools)
        mcp_server_used = self.is_mcp_server_used(
            mcp_servers=optional_params.get("mcp_servers")
        )
        pdf_used = self.is_pdf_used(messages=messages)
        file_id_used = self.is_file_id_used(messages=messages)
        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
            anthropic_beta_header=headers.get("anthropic-beta")
        )

        # Get the standard Anthropic headers. The x-api-key value set here is
        # irrelevant because the Azure auth header (already in `headers`)
        # wins in the merge below.
        anthropic_headers = self.get_anthropic_headers(
            computer_tool_used=computer_tool_used,
            prompt_caching_set=prompt_caching_set,
            pdf_used=pdf_used,
            api_key=api_key or "",  # Azure auth is already in headers
            file_id_used=file_id_used,
            is_vertex_request=optional_params.get("is_vertex_request", False),
            user_anthropic_beta_headers=user_anthropic_beta_headers,
            mcp_server_used=mcp_server_used,
        )
        # Merge headers - Azure auth (api-key or Authorization) takes
        # precedence because `headers` is spread last.
        headers = {**anthropic_headers, **headers}

        # Ensure the required anthropic-version header is set
        if "anthropic-version" not in headers:
            headers["anthropic-version"] = "2023-06-01"

        return headers

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform request using parent AnthropicConfig, then remove unsupported params.

        Azure Anthropic doesn't support extra_body, max_retries, or
        stream_options parameters, so they are stripped from the payload.
        """
        # Call parent transform_request to build the Anthropic-shaped body
        data = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        # Remove parameters the Azure AI Anthropic endpoint rejects
        data.pop("extra_body", None)
        data.pop("max_retries", None)
        data.pop("stream_options", None)

        return data
|
||||
@@ -0,0 +1,4 @@
|
||||
"""Azure AI Foundry Model Router support."""
|
||||
from .transformation import AzureModelRouterConfig
|
||||
|
||||
__all__ = ["AzureModelRouterConfig"]
|
||||
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
Transformation for Azure AI Foundry Model Router.
|
||||
|
||||
The Model Router is a special Azure AI deployment that automatically routes requests
|
||||
to the best available model. It has specific cost tracking requirements.
|
||||
"""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from httpx import Response
|
||||
|
||||
from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig
|
||||
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
class AzureModelRouterConfig(AzureAIStudioConfig):
    """
    Configuration for the Azure AI Foundry Model Router.

    Responsibilities:
    - strip the ``model_router``/``model-router`` routing prefix before the
      request is sent to Azure,
    - let the parent surface the model Azure actually routed to in responses,
    - report the router's flat infrastructure surcharge for cost tracking.
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Strip the routing prefix, then delegate to the parent transform.

        Example: ``model_router/azure-model-router`` -> ``azure-model-router``.
        """
        from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo

        # Only the bare deployment name may be sent to the Azure API.
        deployment_name: str = AzureFoundryModelInfo.get_base_model(model)

        return super().transform_request(
            deployment_name, messages, optional_params, litellm_params, headers
        )

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Delegate response transformation using the prefix-stripped model.

        The parent extracts the concrete model reported in the raw Azure
        response (e.g. ``gpt-5-nano-2025-08-07``) so that display and cost
        tracking reflect the model that actually served the request.
        """
        from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo

        # Strip routing prefixes for API compatibility before delegating.
        deployment_name: str = AzureFoundryModelInfo.get_base_model(model)

        return super().transform_response(
            model=deployment_name,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

    def calculate_additional_costs(
        self, model: str, prompt_tokens: int, completion_tokens: int
    ) -> Optional[dict]:
        """
        Calculate additional costs for the Azure Model Router.

        The router charges a flat infrastructure fee per input token on top of
        the routed model's own price.

        Args:
            model: The model name (should be a model router model).
            prompt_tokens: Number of prompt tokens.
            completion_tokens: Number of completion tokens (unused by the
                flat-cost calculation).

        Returns:
            Dictionary with the additional cost entry, or None when no
            surcharge applies.
        """
        from litellm.llms.azure_ai.cost_calculator import (
            calculate_azure_model_router_flat_cost,
        )

        flat_cost = calculate_azure_model_router_flat_cost(
            model=model, prompt_tokens=prompt_tokens
        )
        if flat_cost <= 0:
            return None
        return {"Azure Model Router Flat Cost": flat_cost}
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
LLM Calling done in `openai/openai.py`
|
||||
"""
|
||||
@@ -0,0 +1,350 @@
|
||||
import enum
|
||||
from typing import Any, List, Optional, Tuple, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from httpx import Response
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
_audio_or_image_in_message_content,
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
|
||||
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
|
||||
from litellm.llms.openai.openai import OpenAIConfig
|
||||
from litellm.llms.xai.chat.transformation import XAIChatConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import ModelResponse, ProviderField
|
||||
from litellm.utils import _add_path_to_api_base, supports_tool_choice
|
||||
|
||||
|
||||
class AzureFoundryErrorStrings(str, enum.Enum):
    # Known substrings of Azure Foundry error responses. These are matched
    # against the HTTP error body to decide whether a request should be
    # retried after rewriting its parameters.
    SET_EXTRA_PARAMETERS_TO_PASS_THROUGH = "Set extra-parameters to 'pass-through'"
|
||||
|
||||
|
||||
class AzureAIStudioConfig(OpenAIConfig):
    """
    OpenAI-compatible chat configuration for Azure AI Studio / Azure AI Foundry.

    Extends OpenAIConfig with:
    - Azure authentication (`api-key` header, Bearer token, or Azure AD fallback),
    - URL construction for `.services.ai.azure.com` endpoints,
    - per-model filtering of supported OpenAI params,
    - retry-time request rewriting for Azure-specific 422 errors.
    """

    def get_supported_openai_params(self, model: str) -> List:
        """Return the OpenAI params this Azure AI model supports.

        Filters out `tool_choice` for models that don't support it and
        `stop` for models without stop-token support (currently Grok).
        """
        model_supports_tool_choice = True  # azure ai supports this by default
        if not supports_tool_choice(model=f"azure_ai/{model}"):
            model_supports_tool_choice = False
        supported_params = super().get_supported_openai_params(model)
        if not model_supports_tool_choice:
            filtered_supported_params = []
            for param in supported_params:
                if param != "tool_choice":
                    filtered_supported_params.append(param)
            supported_params = filtered_supported_params

        # Filter out unsupported parameters for specific models
        if not self._supports_stop_reason(model):
            supported_params = [param for param in supported_params if param != "stop"]

        return supported_params

    def _supports_stop_reason(self, model: str) -> bool:
        """
        Check if the model supports stop tokens.

        Grok models delegate to XAIChatConfig; all other models are assumed
        to support stop tokens.
        """
        if "grok" in model:
            # Reuse xAI's logic for Grok models hosted on Azure
            xai_config = XAIChatConfig()
            return xai_config._supports_stop_reason(model)
        return True

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Populate authentication headers for the Azure AI request.

        Precedence:
        1. api_key + an Azure-hosted api_base -> `api-key` header,
        2. api_key otherwise -> `Authorization: Bearer <key>`,
        3. no api_key -> Azure AD token-based auth via BaseAzureLLM.
        """
        if api_key:
            if api_base and self._should_use_api_key_header(api_base):
                # Azure-hosted endpoints authenticate with the api-key header
                headers["api-key"] = api_key
            else:
                headers["Authorization"] = f"Bearer {api_key}"
        else:
            # No api_key provided — fall back to Azure AD token-based auth
            litellm_params_obj = GenericLiteLLMParams(
                **(litellm_params if isinstance(litellm_params, dict) else {})
            )
            headers = BaseAzureLLM._base_validate_azure_environment(
                headers=headers, litellm_params=litellm_params_obj
            )

        headers["Content-Type"] = "application/json"

        return headers

    def _should_use_api_key_header(self, api_base: str) -> bool:
        """
        Returns True if the request should use the `api-key` header for
        authentication (i.e. the host is an Azure services/OpenAI endpoint).
        """
        parsed_url = urlparse(api_base)
        host = parsed_url.hostname
        if host and (
            host.endswith(".services.ai.azure.com")
            or host.endswith(".openai.azure.com")
        ):
            return True
        return False

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Constructs a complete URL for the API request.

        Args:
        - api_base: Base URL, e.g.,
            "https://litellm8397336933.services.ai.azure.com"
            OR
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        - model: Model name.
        - optional_params: Additional query parameters, including "api_version".
        - stream: If streaming is required (optional).

        Returns:
        - A complete URL string, e.g.,
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"

        Raises:
            ValueError: if api_base is not supplied.
        """
        if api_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
            )
        original_url = httpx.URL(api_base)

        # Extract api_version or use default
        api_version = cast(Optional[str], litellm_params.get("api_version"))

        # Preserve any query params already present on the api_base
        query_params = dict(original_url.params)

        # Add api_version if the caller supplied one and it's not already set
        if "api-version" not in query_params and api_version:
            query_params["api-version"] = api_version

        # Append the chat-completions path. Foundry "services" endpoints use
        # the /models/... route; other endpoints use the plain OpenAI route.
        if "services.ai.azure.com" in api_base:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/models/chat/completions"
            )
        else:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/chat/completions"
            )

        # Re-attach the (possibly augmented) query parameters
        final_url = httpx.URL(new_url).copy_with(params=query_params)

        return str(final_url)

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Azure AI Studio API Key.",
                field_value="zEJ...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Azure AI Studio API Base.",
                field_value="https://Mistral-serverless.",
            ),
        ]

    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
    ) -> List:
        """
        - Azure AI Studio doesn't support content as a list. This handles:
            1. Transforms list content to a string.
            2. If message contains an image or audio, send as is (user-intended)
        """
        for message in messages:
            # Do nothing if the message contains an image or audio
            if _audio_or_image_in_message_content(message):
                continue

            texts = convert_content_list_to_str(message=message)
            if texts:
                message["content"] = texts
        return messages

    def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
        """Return True when `model` is a known Azure OpenAI model name.

        Strips a single provider prefix (e.g. "azure_ai/gpt-4o" -> "gpt-4o")
        before checking litellm's OpenAI model registries.
        """
        try:
            if "/" in model:
                model = model.split("/", 1)[1]
            if (
                model in litellm.open_ai_chat_completion_models
                or model in litellm.open_ai_text_completion_models
                or model in litellm.open_ai_embedding_models
            ):
                return True

        except Exception:
            # Defensive: any failure here just means "not an OpenAI model"
            return False
        return False

    def _get_openai_compatible_provider_info(
        self,
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        custom_llm_provider: str,
    ) -> Tuple[Optional[str], Optional[str], str]:
        """Resolve api_base/api_key from args or env, and reroute Azure
        OpenAI models to the 'azure' provider.

        Returns:
            (api_base, dynamic_api_key, custom_llm_provider)
        """
        api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
        dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")

        if self._is_azure_openai_model(model=model, api_base=api_base):
            verbose_logger.debug(
                "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
                    model
                )
            )
            custom_llm_provider = "azure"
        return api_base, dynamic_api_key, custom_llm_provider

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Flatten `extra_body` into the top-level params and drop
        `max_retries` (not accepted by the Azure AI API), then delegate
        to the parent OpenAI transformation.
        """
        extra_body = optional_params.pop("extra_body", {})
        if extra_body and isinstance(extra_body, dict):
            optional_params.update(extra_body)
        optional_params.pop("max_retries", None)
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Tag the response with the azure_ai/ provider prefix, then delegate.

        NOTE(review): super().transform_response may overwrite
        `model_response.model` from the raw response — confirm the intended
        precedence of this pre-assignment.
        """
        model_response.model = f"azure_ai/{model}"
        return super().transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

    def should_retry_llm_api_inside_llm_translation_on_http_error(
        self, e: httpx.HTTPStatusError, litellm_params: dict
    ) -> bool:
        """Return True for Azure error bodies that can be fixed by rewriting
        the request (dropping extra params, removing tool-call indices)."""
        should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
        error_text = e.response.text

        if should_drop_params and "Extra inputs are not permitted" in error_text:
            return True
        elif (
            "unknown field: parameter index is not a valid field" in error_text
        ):  # remove index from tool calls
            return True
        elif (
            AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
            in error_text
        ):  # remove extra-parameters from tool calls
            return True
        return super().should_retry_llm_api_inside_llm_translation_on_http_error(
            e=e, litellm_params=litellm_params
        )

    @property
    def max_retry_on_unprocessable_entity_error(self) -> int:
        # Cap on 422-triggered request rewrites before giving up.
        return 2

    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        """Rewrite request_data after a 422, based on the error body:

        - tool-call `index` rejections -> strip indices from tool calls,
        - extra-parameters rejections -> drop the offending params,
        then apply the generic unprocessable-entity param dropping.
        """
        _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
        if (
            "unknown field: parameter index is not a valid field" in e.response.text
            and _messages is not None
        ):
            litellm.remove_index_from_tool_calls(
                messages=_messages,
            )
        elif (
            AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
            in e.response.text
        ):
            request_data = self._drop_extra_params_from_request_data(
                request_data, e.response.text
            )
        data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
        return data

    def _drop_extra_params_from_request_data(
        self, request_data: dict, error_text: str
    ) -> dict:
        """Remove from request_data every param named in the error text."""
        params_to_drop = self._extract_params_to_drop_from_error_text(error_text)
        if params_to_drop:
            for param in params_to_drop:
                if param in request_data:
                    request_data.pop(param, None)
        return request_data

    def _extract_params_to_drop_from_error_text(
        self, error_text: str
    ) -> Optional[List[str]]:
        """
        Parse the parameter names out of an Azure "extra parameters" error.

        Error text looks like:
            "Extra parameters ['stream_options', 'extra-parameters'] are not
            allowed when extra-parameters is not set or set to be 'error'."

        Returns the cleaned parameter names, or an empty list when no
        bracketed list is found.
        """
        import re

        # Extract parameters within square brackets
        match = re.search(r"\[(.*?)\]", error_text)
        if not match:
            return []

        # Parse the extracted string into a list of parameter names
        params_str = match.group(1)
        params = []
        for param in params_str.split(","):
            # Clean up the parameter name (remove quotes, spaces)
            clean_param = param.strip().strip("'").strip('"')
            if clean_param:
                params.append(clean_param)
        return params
|
||||
@@ -0,0 +1,176 @@
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo, BaseTokenCounter
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class AzureFoundryModelInfo(BaseLLMModelInfo):
    """Model info for Azure AI / Azure Foundry models.

    Provides route detection (agents / model router / default), credential
    resolution, routing-prefix stripping, and config-class dispatch.
    """

    def __init__(self, model: Optional[str] = None):
        # Model name is optional; it is only consulted by model-specific
        # helpers such as get_token_counter().
        self._model = model

    @staticmethod
    def get_azure_ai_route(model: str) -> Literal["agents", "model_router", "default"]:
        """
        Get the Azure AI route for the given model.

        Similar to BedrockModelInfo.get_bedrock_route().

        Supported routes:
        - agents: azure_ai/agents/<agent_id>
        - model_router: azure_ai/model_router/<actual-model-name> or models with
          "model-router"/"model_router" anywhere in the name
        - default: standard models
        """
        if "agents/" in model:
            return "agents"
        # Any occurrence of "model-router"/"model_router" marks the model as a
        # Model Router deployment. (The original also tested the "…/" prefixed
        # forms separately, but those are subsumed by the substring checks.)
        model_lower = model.lower()
        if "model-router" in model_lower or "model_router" in model_lower:
            return "model_router"
        return "default"

    @staticmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        """Resolve the API base: explicit arg > litellm.api_base > env secret."""
        return api_base or litellm.api_base or get_secret_str("AZURE_AI_API_BASE")

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        """Resolve the API key: explicit arg > litellm globals > env secret."""
        return (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("AZURE_AI_API_KEY")
        )

    @property
    def api_version(self) -> Optional[str]:
        """Resolve the API version from litellm config or the environment.

        Fix: a property getter can never receive extra arguments, so the dead
        `api_version` parameter of the original signature has been removed.
        Behavior is unchanged (the parameter always defaulted to None).
        """
        return litellm.api_version or get_secret_str("AZURE_API_VERSION")

    def get_token_counter(self) -> Optional[BaseTokenCounter]:
        """
        Factory method to create a token counter for Azure AI.

        Returns:
            AzureAIAnthropicTokenCounter for Claude models, None otherwise.
        """
        # Only return a token counter for Claude models
        if self._model and "claude" in self._model.lower():
            from litellm.llms.azure_ai.anthropic.count_tokens.token_counter import (
                AzureAIAnthropicTokenCounter,
            )

            return AzureAIAnthropicTokenCounter()
        return None

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Returns a list of models supported by Azure AI.

        Azure AI doesn't have a standard model listing endpoint,
        so this returns an empty list.
        """
        return []

    @staticmethod
    def strip_model_router_prefix(model: str) -> str:
        """
        Strip the model_router prefix from a model name.

        Examples:
        - "model_router/gpt-4o" -> "gpt-4o"
        - "model-router/gpt-4o" -> "gpt-4o"
        - "gpt-4o" -> "gpt-4o"

        Args:
            model: Model name potentially with a model_router prefix

        Returns:
            Model name without the prefix
        """
        if "model_router/" in model:
            return model.split("model_router/", 1)[1]
        if "model-router/" in model:
            return model.split("model-router/", 1)[1]
        return model

    @staticmethod
    def get_base_model(model: str) -> str:
        """
        Get the base model name, stripping any Azure AI routing prefixes.

        Args:
            model: Model name potentially with routing prefixes

        Returns:
            Base model name
        """
        # Currently only the model_router prefix needs stripping
        return AzureFoundryModelInfo.strip_model_router_prefix(model)

    @staticmethod
    def get_azure_ai_config_for_model(model: str):
        """
        Get the appropriate Azure AI config class for the given model.

        Routes to specialized configs based on model type:
        - Model Router: AzureModelRouterConfig
        - Claude models: AzureAnthropicConfig
        - Default: AzureAIStudioConfig

        Args:
            model: The model name

        Returns:
            The appropriate config instance
        """
        azure_ai_route = AzureFoundryModelInfo.get_azure_ai_route(model)

        if azure_ai_route == "model_router":
            from litellm.llms.azure_ai.azure_model_router.transformation import (
                AzureModelRouterConfig,
            )

            return AzureModelRouterConfig()
        elif "claude" in model.lower():
            from litellm.llms.azure_ai.anthropic.transformation import (
                AzureAnthropicConfig,
            )

            return AzureAnthropicConfig()
        else:
            from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig

            return AzureAIStudioConfig()

    #########################################################
    # Not implemented methods
    #########################################################

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Azure Foundry sends api key in query params"""
        raise NotImplementedError(
            "Azure Foundry does not support environment validation"
        )
|
||||
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Azure AI cost calculation helper.
|
||||
Handles Azure AI Foundry Model Router flat cost and other Azure AI specific pricing.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
|
||||
from litellm.types.utils import Usage
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
|
||||
def _is_azure_model_router(model: str) -> bool:
|
||||
"""
|
||||
Check if the model is Azure AI Foundry Model Router.
|
||||
|
||||
Detects patterns like:
|
||||
- "azure-model-router"
|
||||
- "model-router"
|
||||
- "model_router/<actual-model>"
|
||||
- "model-router/<actual-model>"
|
||||
|
||||
Args:
|
||||
model: The model name
|
||||
|
||||
Returns:
|
||||
bool: True if this is a model router model
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
return (
|
||||
"model-router" in model_lower
|
||||
or "model_router" in model_lower
|
||||
or model_lower == "azure-model-router"
|
||||
)
|
||||
|
||||
|
||||
def calculate_azure_model_router_flat_cost(model: str, prompt_tokens: int) -> float:
    """
    Calculate the flat cost for Azure AI Foundry Model Router.

    The router charges a flat per-input-token infrastructure fee on top of the
    routed model's own price; the rate comes from the cost map entry for
    azure_ai "model_router" (`input_cost_per_token`).

    Args:
        model: The model name (should be a model router model)
        prompt_tokens: Number of prompt tokens

    Returns:
        float: The flat cost in USD, or 0.0 if not applicable
    """
    if not _is_azure_model_router(model):
        return 0.0

    # Get the model router pricing from model_prices_and_context_window.json.
    # Use "model_router" as the key (without the actual model name suffix).
    model_info = get_model_info(model="model_router", custom_llm_provider="azure_ai")
    # Robustness fix: `.get(key, 0)` still yields None when the key exists
    # with a null value, which would make the `>` comparison raise.
    router_flat_cost_per_token = model_info.get("input_cost_per_token") or 0.0

    if router_flat_cost_per_token > 0:
        return prompt_tokens * router_flat_cost_per_token

    return 0.0
|
||||
|
||||
|
||||
def cost_per_token(
    model: str,
    usage: Usage,
    response_time_ms: Optional[float] = 0.0,
    request_model: Optional[str] = None,
) -> Tuple[float, float]:
    """
    Calculate the cost per token for Azure AI models.

    For Azure AI Foundry Model Router:
    - Adds a flat cost per million input tokens (from model_prices_and_context_window.json)
    - Plus the cost of the actual model used (handled by generic_cost_per_token)

    Args:
        model: str, the model name without provider prefix (from response)
        usage: LiteLLM Usage block
        response_time_ms: Optional response time in milliseconds (currently unused here)
        request_model: Optional[str], the original request model name (to detect router usage)

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd

    Raises:
        ValueError: If the model is not found in the cost map and cost cannot be calculated
        (except for Model Router models where we return just the routing flat cost)
    """
    prompt_cost = 0.0
    completion_cost = 0.0

    # Determine if this was a model router request.
    # Check both the response model and the request model, because the
    # response may report the routed-to model rather than the router itself.
    is_router_request = _is_azure_model_router(model) or (
        request_model is not None and _is_azure_model_router(request_model)
    )

    # Calculate base cost using the generic cost calculator.
    # This may raise if the model is not in the cost map.
    try:
        prompt_cost, completion_cost = generic_cost_per_token(
            model=model,
            usage=usage,
            custom_llm_provider="azure_ai",
        )
    except Exception as e:
        # For Model Router, the model name (e.g., "azure-model-router") may not
        # be in the cost map because it's a routing service, not an actual
        # model. In this case, continue and bill just the routing flat cost.
        if not _is_azure_model_router(model):
            # Re-raise for non-router models - they should have pricing defined
            raise
        verbose_logger.debug(
            f"Azure AI Model Router: model '{model}' not in cost map, calculating routing flat cost only. Error: {e}"
        )

    # Add the flat router surcharge on top of the base cost.
    # The rate is defined in model_prices_and_context_window.json for
    # azure_ai/model_router.
    if is_router_request:
        # Prefer the request model for flat-cost calculation when available,
        # since it is the name that carries the router marker.
        router_model_for_calc = request_model if request_model else model
        router_flat_cost = calculate_azure_model_router_flat_cost(
            router_model_for_calc, usage.prompt_tokens
        )

        if router_flat_cost > 0:
            # Division is safe here: a positive flat cost implies
            # prompt_tokens > 0 (flat cost = tokens * per-token rate).
            verbose_logger.debug(
                f"Azure AI Model Router flat cost: ${router_flat_cost:.6f} "
                f"({usage.prompt_tokens} tokens × ${router_flat_cost / usage.prompt_tokens:.9f}/token)"
            )

            # The surcharge is attributed to the prompt side of the bill
            prompt_cost += router_flat_cost

    return prompt_cost, completion_cost
|
||||
@@ -0,0 +1 @@
|
||||
from .handler import AzureAIEmbedding
|
||||
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Covers
|
||||
- Cohere request format
|
||||
|
||||
Docs - https://docs.cohere.com/reference/embed
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from litellm.types.llms.azure_ai import ImageEmbeddingInput, ImageEmbeddingRequest
|
||||
from litellm.types.llms.openai import EmbeddingCreateParams
|
||||
from litellm.types.utils import EmbeddingResponse, Usage
|
||||
from litellm.utils import is_base64_encoded
|
||||
|
||||
|
||||
class AzureAICohereConfig:
    """
    Transformation config for Azure AI Cohere embeddings.

    Splits a mixed embedding request into a text (`/v1/embeddings`) and an
    image (`/images/embeddings`) sub-request, and normalizes the response
    model name from the Azure marketplace offer id back to the public
    Cohere model id.
    """

    def __init__(self) -> None:
        pass

    def _map_azure_model_group(self, model: str) -> str:
        # Translate Azure marketplace offer names to public Cohere model ids;
        # anything unrecognized passes through unchanged.
        offer_to_model = {
            "offer-cohere-embed-multili-paygo": "Cohere-embed-v3-multilingual",
            "offer-cohere-embed-english-paygo": "Cohere-embed-v3-english",
        }
        return offer_to_model.get(model, model)

    def _transform_request_image_embeddings(
        self, input: List[str], optional_params: dict
    ) -> ImageEmbeddingRequest:
        """
        Build the image embedding request; every entry in `input` is
        assumed to be a base64-encoded string.
        """
        image_inputs: List[ImageEmbeddingInput] = [
            ImageEmbeddingInput(image=entry) for entry in input
        ]
        return ImageEmbeddingRequest(input=image_inputs, **optional_params)

    def _transform_request(
        self, input: List[str], optional_params: dict, model: str
    ) -> Tuple[ImageEmbeddingRequest, EmbeddingCreateParams, List[int]]:
        """
        Split `input` into the `/images/embeddings` payload, the
        `/v1/embeddings` payload, and the original indices of the image
        entries (used downstream to reassemble results in order).
        """
        base64_items: List[str] = []
        base64_positions: List[int] = []
        # base64-encoded entries route to image embeddings; everything else
        # routes to the standard `/v1/embeddings` endpoint.
        for position, entry in enumerate(input):
            if is_base64_encoded(entry):
                base64_items.append(entry)
                base64_positions.append(position)

        # Remove the image entries from the text-embedding input list.
        text_items = [
            entry
            for position, entry in enumerate(input)
            if position not in base64_positions
        ]

        v1_embeddings_request = EmbeddingCreateParams(
            input=text_items, model=model, **optional_params
        )
        image_embeddings_request = self._transform_request_image_embeddings(
            input=base64_items, optional_params=optional_params
        )

        return image_embeddings_request, v1_embeddings_request, base64_positions

    def _transform_response(self, response: EmbeddingResponse) -> EmbeddingResponse:
        """
        Backfill usage and model name from the provider response headers,
        when present.
        """
        additional_headers: Optional[dict] = response._hidden_params.get(
            "additional_headers"
        )
        if not additional_headers:
            return response

        # Usage: the provider reports token counts via a response header.
        input_tokens: Optional[str] = additional_headers.get(
            "llm_provider-num_tokens"
        )
        if input_tokens:
            if response.usage:
                response.usage.prompt_tokens = int(input_tokens)
            else:
                response.usage = Usage(prompt_tokens=int(input_tokens))

        # Model: map the Azure model-group header back to the Cohere model id.
        base_model: Optional[str] = additional_headers.get(
            "llm_provider-azureml-model-group"
        )
        if base_model:
            response.model = self._map_azure_model_group(base_model)

        return response
|
||||
@@ -0,0 +1,292 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.openai.openai import OpenAIChatCompletion
|
||||
from litellm.types.llms.azure_ai import ImageEmbeddingRequest
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
from litellm.utils import convert_to_model_response_object
|
||||
|
||||
from .cohere_transformation import AzureAICohereConfig
|
||||
|
||||
|
||||
class AzureAIEmbedding(OpenAIChatCompletion):
    """
    Azure AI embedding handler.

    Splits the request input into text and image (base64) entries, routes
    text to the OpenAI-compatible `/v1/embeddings` route and images to
    `/images/embeddings`, then merges the results back into the original
    input order.
    """

    def _process_response(
        self,
        image_embedding_responses: Optional[List],
        text_embedding_responses: Optional[List],
        image_embeddings_idx: List[int],
        model_response: EmbeddingResponse,
        input: List,
    ):
        """
        Merge image and text embedding results back into the original input
        order, then apply the Cohere response transformation.

        `image_embeddings_idx` holds the positions in `input` that were
        base64 images; entries at those positions come from
        `image_embedding_responses`, all others from
        `text_embedding_responses` in order.
        """
        combined_responses = []
        if (
            image_embedding_responses is not None
            and text_embedding_responses is not None
        ):
            # Combine and order the results: walk the original input and pull
            # from the image or text result list depending on the index.
            text_idx = 0
            image_idx = 0

            for idx in range(len(input)):
                if idx in image_embeddings_idx:
                    combined_responses.append(image_embedding_responses[image_idx])
                    image_idx += 1
                else:
                    combined_responses.append(text_embedding_responses[text_idx])
                    text_idx += 1

            model_response.data = combined_responses
        elif image_embedding_responses is not None:
            model_response.data = image_embedding_responses
        elif text_embedding_responses is not None:
            model_response.data = text_embedding_responses

        response = AzureAICohereConfig()._transform_response(response=model_response)  # type: ignore

        return response

    async def async_image_embedding(
        self,
        model: str,
        data: ImageEmbeddingRequest,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> EmbeddingResponse:
        """
        Async POST of the image embedding payload to
        `{api_base}/images/embeddings`, converting the raw JSON into an
        EmbeddingResponse.
        """
        if client is None or not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.AZURE_AI,
                params={"timeout": timeout},
            )

        url = "{}/images/embeddings".format(api_base)

        response = await client.post(
            url=url,
            json=data,  # type: ignore
            headers={"Authorization": "Bearer {}".format(api_key)},
        )

        embedding_response = response.json()
        embedding_headers = dict(response.headers)
        returned_response: EmbeddingResponse = convert_to_model_response_object(  # type: ignore
            response_object=embedding_response,
            model_response_object=model_response,
            response_type="embedding",
            stream=False,
            _response_headers=embedding_headers,
        )
        return returned_response

    def image_embedding(
        self,
        model: str,
        data: ImageEmbeddingRequest,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ):
        """
        Sync POST of the image embedding payload to
        `{api_base}/images/embeddings`, converting the raw JSON into an
        EmbeddingResponse.

        Raises:
            ValueError: if api_base or api_key is not set.
        """
        if api_base is None:
            raise ValueError(
                "api_base is None. Please set AZURE_AI_API_BASE or dynamically via `api_base` param, to make the request."
            )
        if api_key is None:
            raise ValueError(
                "api_key is None. Please set AZURE_AI_API_KEY or dynamically via `api_key` param, to make the request."
            )

        if client is None or not isinstance(client, HTTPHandler):
            client = HTTPHandler(timeout=timeout, concurrent_limit=1)

        url = "{}/images/embeddings".format(api_base)

        response = client.post(
            url=url,
            json=data,  # type: ignore
            headers={"Authorization": "Bearer {}".format(api_key)},
        )

        embedding_response = response.json()
        embedding_headers = dict(response.headers)
        returned_response: EmbeddingResponse = convert_to_model_response_object(  # type: ignore
            response_object=embedding_response,
            model_response_object=model_response,
            response_type="embedding",
            stream=False,
            _response_headers=embedding_headers,
        )
        return returned_response

    async def async_embedding(
        self,
        model: str,
        input: List,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
    ) -> EmbeddingResponse:
        """
        Async version of `embedding`: split input, fan out to the image and
        text routes, and recombine in original order.
        """
        (
            image_embeddings_request,
            v1_embeddings_request,
            image_embeddings_idx,
        ) = AzureAICohereConfig()._transform_request(
            input=input, optional_params=optional_params, model=model
        )

        image_embedding_responses: Optional[List] = None
        text_embedding_responses: Optional[List] = None

        if image_embeddings_request["input"]:
            image_response = await self.async_image_embedding(
                model=model,
                data=image_embeddings_request,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
            )

            image_embedding_responses = image_response.data
            if image_embedding_responses is None:
                raise Exception("/image/embeddings route returned None Embeddings.")

        if v1_embeddings_request["input"]:
            # BUGFIX: send only the filtered text inputs. Previously the full
            # `input` (including base64 images) was re-sent to /v1/embeddings,
            # which embedded images as text and misaligned the recombination
            # in _process_response for mixed text+image inputs.
            response: EmbeddingResponse = await super().embedding(  # type: ignore
                model=model,
                input=v1_embeddings_request["input"],
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
                aembedding=True,
            )
            text_embedding_responses = response.data
            if text_embedding_responses is None:
                raise Exception("/v1/embeddings route returned None Embeddings.")

        return self._process_response(
            image_embedding_responses=image_embedding_responses,
            text_embedding_responses=text_embedding_responses,
            image_embeddings_idx=image_embeddings_idx,
            model_response=model_response,
            input=input,
        )

    def embedding(
        self,
        model: str,
        input: List,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        aembedding=None,
        max_retries: Optional[int] = None,
        shared_session=None,
    ) -> EmbeddingResponse:
        """
        - Separate image url from text
            -> route image url call to `/image/embeddings`
            -> route text call to `/v1/embeddings` (OpenAI route)

        assemble result in-order, and return
        """
        if aembedding is True:
            return self.async_embedding(  # type: ignore
                model,
                input,
                timeout,
                logging_obj,
                model_response,
                optional_params,
                api_key,
                api_base,
                client,
            )

        (
            image_embeddings_request,
            v1_embeddings_request,
            image_embeddings_idx,
        ) = AzureAICohereConfig()._transform_request(
            input=input, optional_params=optional_params, model=model
        )

        image_embedding_responses: Optional[List] = None
        text_embedding_responses: Optional[List] = None

        if image_embeddings_request["input"]:
            image_response = self.image_embedding(
                model=model,
                data=image_embeddings_request,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
            )

            image_embedding_responses = image_response.data
            if image_embedding_responses is None:
                raise Exception("/image/embeddings route returned None Embeddings.")

        if v1_embeddings_request["input"]:
            # BUGFIX: pass the filtered text-only input (second positional
            # argument) instead of the full `input`; see async_embedding.
            response: EmbeddingResponse = super().embedding(  # type: ignore
                model,
                v1_embeddings_request["input"],
                timeout,
                logging_obj,
                model_response,
                optional_params,
                api_key,
                api_base,
                client=(
                    client
                    if client is not None and isinstance(client, OpenAI)
                    else None
                ),
                aembedding=aembedding,
                shared_session=shared_session,
            )

            text_embedding_responses = response.data
            if text_embedding_responses is None:
                raise Exception("/v1/embeddings route returned None Embeddings.")

        return self._process_response(
            image_embedding_responses=image_embedding_responses,
            text_embedding_responses=text_embedding_responses,
            image_embeddings_idx=image_embeddings_idx,
            model_response=model_response,
            input=input,
        )
|
||||
@@ -0,0 +1,28 @@
|
||||
from litellm.llms.azure_ai.image_generation.flux_transformation import (
|
||||
AzureFoundryFluxImageGenerationConfig,
|
||||
)
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
|
||||
from .flux2_transformation import AzureFoundryFlux2ImageEditConfig
|
||||
from .transformation import AzureFoundryFluxImageEditConfig
|
||||
|
||||
__all__ = ["AzureFoundryFluxImageEditConfig", "AzureFoundryFlux2ImageEditConfig"]
|
||||
|
||||
|
||||
def get_azure_ai_image_edit_config(model: str) -> BaseImageEditConfig:
    """
    Pick the image edit config for an Azure AI model.

    - FLUX 2 models use JSON with a base64 image
    - FLUX 1 models use multipart/form-data
    """
    # FLUX 2 models get the JSON-based config.
    if AzureFoundryFluxImageGenerationConfig.is_flux2_model(model):
        return AzureFoundryFlux2ImageEditConfig()

    # Default to the FLUX 1 config for other FLUX models; an empty model
    # name also falls through to the FLUX 1 config.
    normalized = model.lower().replace("-", "").replace("_", "")
    if not normalized or "flux" in normalized:
        return AzureFoundryFluxImageEditConfig()

    raise ValueError(f"Model {model} is not supported for Azure AI image editing.")
|
||||
@@ -0,0 +1,172 @@
|
||||
import base64
|
||||
from io import BufferedReader
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo
|
||||
from litellm.llms.azure_ai.image_generation.flux_transformation import (
|
||||
AzureFoundryFluxImageGenerationConfig,
|
||||
)
|
||||
from litellm.llms.openai.image_edit.transformation import OpenAIImageEditConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.llms.openai import FileTypes
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
|
||||
class AzureFoundryFlux2ImageEditConfig(OpenAIImageEditConfig):
    """
    Azure AI Foundry FLUX 2 image edit config

    Supports FLUX 2 models (e.g., flux.2-pro) for image editing.
    Uses the same /providers/blackforestlabs/v1/flux-2-pro endpoint as image generation,
    with the image passed as base64 in JSON body.
    """

    def get_supported_openai_params(self, model: str) -> list:
        """
        FLUX 2 supports a subset of OpenAI image edit params.
        """
        return [
            "prompt",
            "image",
            "model",
            "n",
            "size",
        ]

    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """
        Map OpenAI params to FLUX 2 params.

        FLUX 2 uses the same param names as OpenAI for supported params;
        unsupported or None-valued params are dropped.
        """
        mapped_params: Dict[str, Any] = {}
        supported_params = self.get_supported_openai_params(model)

        for key, value in dict(image_edit_optional_params).items():
            if key in supported_params and value is not None:
                mapped_params[key] = value

        return mapped_params

    def use_multipart_form_data(self) -> bool:
        """FLUX 2 uses JSON requests, not multipart/form-data."""
        return False

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Validate Azure AI Foundry environment and set up authentication.

        Raises:
            ValueError: if no API key is resolvable from the argument or env.
        """
        api_key = AzureFoundryModelInfo.get_api_key(api_key)

        if not api_key:
            raise ValueError(
                f"Azure AI API key is required for model {model}. Set AZURE_AI_API_KEY environment variable or pass api_key parameter."
            )

        # Azure AI Foundry authenticates with the Api-Key header.
        headers.update(
            {
                "Api-Key": api_key,
                "Content-Type": "application/json",
            }
        )
        return headers

    def transform_image_edit_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles]:
        """
        Transform image edit request for FLUX 2.

        FLUX 2 uses the same endpoint for generation and editing,
        with the image passed as base64 in the JSON body.

        Raises:
            ValueError: if prompt or image is missing.
        """
        if prompt is None:
            raise ValueError("FLUX 2 image edit requires a prompt.")

        if image is None:
            raise ValueError("FLUX 2 image edit requires an image.")

        image_b64 = self._convert_image_to_base64(image)

        # Build request body with required params
        request_body: Dict[str, Any] = {
            "prompt": prompt,
            "image": image_b64,
            "model": model,
        }

        # Add mapped optional params (already filtered by map_openai_params)
        request_body.update(image_edit_optional_request_params)

        # Return JSON body and empty files list (FLUX 2 doesn't use multipart)
        return request_body, []

    def _convert_image_to_base64(self, image: Any) -> str:
        """
        Convert an image to a base64-encoded string.

        Accepts raw bytes, any file-like object exposing ``read()``, or a
        list of such values (only the first list element is used).
        """
        # Handle list of images (take first one)
        if isinstance(image, list):
            if len(image) == 0:
                raise ValueError("Empty image list provided")
            image = image[0]

        if isinstance(image, bytes):
            image_bytes = image
        elif hasattr(image, "read"):
            image_bytes = image.read()  # type: ignore
            # Reset the stream for potential reuse. Previously this was done
            # only for BufferedReader; now applied to any seekable object
            # (e.g. BytesIO, SpooledTemporaryFile) for consistency.
            if hasattr(image, "seek"):
                try:
                    image.seek(0)
                except (OSError, ValueError):
                    # Non-seekable or closed stream - nothing to reset.
                    pass
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

        return base64.b64encode(image_bytes).decode("utf-8")

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Constructs a complete URL for Azure AI Foundry FLUX 2 image edits.

        Uses the same /providers/blackforestlabs/v1/flux-2-pro endpoint as
        image generation.

        Raises:
            ValueError: if api_base cannot be resolved.
        """
        api_base = AzureFoundryModelInfo.get_api_base(api_base)

        if api_base is None:
            raise ValueError(
                "Azure AI API base is required. Set AZURE_AI_API_BASE environment variable or pass api_base parameter."
            )

        # "preview" is the documented fallback version for this endpoint.
        api_version = (
            litellm_params.get("api_version")
            or litellm.api_version
            or get_secret_str("AZURE_AI_API_VERSION")
            or "preview"
        )

        return AzureFoundryFluxImageGenerationConfig.get_flux2_image_generation_url(
            api_base=api_base,
            model=model,
            api_version=api_version,
        )
|
||||
@@ -0,0 +1,101 @@
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo
|
||||
from litellm.llms.openai.image_edit.transformation import OpenAIImageEditConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.utils import _add_path_to_api_base
|
||||
|
||||
|
||||
class AzureFoundryFluxImageEditConfig(OpenAIImageEditConfig):
    """
    Azure AI Foundry FLUX image edit config

    Supports FLUX models including FLUX-1-kontext-pro for image editing.

    Azure AI Foundry FLUX models handle image editing through the /images/edits endpoint,
    same as standard Azure OpenAI models. The request format uses multipart/form-data
    with image files and prompt.
    """

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Validate the Azure AI Foundry environment and attach authentication.

        Azure AI Foundry authenticates with the ``Api-Key`` header format.
        """
        resolved_key = AzureFoundryModelInfo.get_api_key(api_key)

        if not resolved_key:
            raise ValueError(
                f"Azure AI API key is required for model {model}. Set AZURE_AI_API_KEY environment variable or pass api_key parameter."
            )

        headers["Api-Key"] = resolved_key  # Azure AI Foundry uses Api-Key header format
        return headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Build the complete URL for an Azure AI Foundry image edits request.

        Azure AI Foundry FLUX models handle image editing through the
        /images/edits endpoint.

        Args:
            - model: Model name (deployment name for Azure AI Foundry)
            - api_base: Base URL for Azure AI endpoint
            - litellm_params: Additional parameters including api_version

        Returns:
            - Complete URL for the image edits endpoint

        Raises:
            ValueError: if api_base or api_version cannot be resolved.
        """
        base_url = AzureFoundryModelInfo.get_api_base(api_base)

        if base_url is None:
            raise ValueError(
                "Azure AI API base is required. Set AZURE_AI_API_BASE environment variable or pass api_base parameter."
            )

        resolved_version = (
            litellm_params.get("api_version")
            or litellm.api_version
            or get_secret_str("AZURE_AI_API_VERSION")
        )
        if resolved_version is None:
            # API version is mandatory for Azure AI Foundry
            raise ValueError(
                "Azure API version is required. Set AZURE_AI_API_VERSION environment variable or pass api_version parameter."
            )

        # Append the deployment segment only when the base URL does not
        # already include it; the edits route always ends in /images/edits.
        ending_path = (
            "/images/edits"
            if "/openai/deployments/" in base_url
            else f"/openai/deployments/{model}/images/edits"
        )
        edits_url = _add_path_to_api_base(
            api_base=base_url,
            ending_path=ending_path,
        )

        # Attach the api-version query parameter.
        return str(
            httpx.URL(edits_url).copy_with(params={"api-version": resolved_version})
        )
|
||||
@@ -0,0 +1,33 @@
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
|
||||
from .dall_e_2_transformation import AzureFoundryDallE2ImageGenerationConfig
|
||||
from .dall_e_3_transformation import AzureFoundryDallE3ImageGenerationConfig
|
||||
from .flux_transformation import AzureFoundryFluxImageGenerationConfig
|
||||
from .gpt_transformation import AzureFoundryGPTImageGenerationConfig
|
||||
|
||||
__all__ = [
|
||||
"AzureFoundryFluxImageGenerationConfig",
|
||||
"AzureFoundryGPTImageGenerationConfig",
|
||||
"AzureFoundryDallE2ImageGenerationConfig",
|
||||
"AzureFoundryDallE3ImageGenerationConfig",
|
||||
]
|
||||
|
||||
|
||||
def get_azure_ai_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """
    Select the image generation config for an Azure AI model.

    Matching is done on a normalized model name (lowercased, with `-` and
    `_` removed); an empty model name defaults to dall-e-2. Anything not
    matching dall-e-2 / dall-e-3 / flux falls back to the gpt-image-1
    style config.
    """
    model = model.lower()
    model = model.replace("-", "")
    model = model.replace("_", "")
    if model == "" or "dalle2" in model:  # empty model is dall-e-2
        return AzureFoundryDallE2ImageGenerationConfig()
    elif "dalle3" in model:
        return AzureFoundryDallE3ImageGenerationConfig()
    elif "flux" in model:
        return AzureFoundryFluxImageGenerationConfig()
    else:
        # Fixed the log message: it previously named a nonexistent
        # "AzureGPTImageGenerationConfig" class.
        verbose_logger.debug(
            f"Using AzureFoundryGPTImageGenerationConfig for model: {model}. This follows the gpt-image-1 model format."
        )
        return AzureFoundryGPTImageGenerationConfig()
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: Any,
) -> float:
    """
    Azure AI image generation cost calculator.

    Charges the model's flat per-image cost (from the litellm model info
    map) multiplied by the number of images in the response.

    Raises:
        ValueError: if `image_response` is not an ImageResponse.
    """
    _model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider=litellm.LlmProviders.AZURE_AI.value,
    )
    # Missing pricing info is treated as zero cost per image.
    per_image_cost: float = _model_info.get("output_cost_per_image") or 0.0
    if not isinstance(image_response, ImageResponse):
        raise ValueError(
            f"image_response must be of type ImageResponse got type={type(image_response)}"
        )
    image_count: int = len(image_response.data) if image_response.data else 0
    return per_image_cost * image_count
|
||||
@@ -0,0 +1,9 @@
|
||||
from litellm.llms.openai.image_generation import DallE2ImageGenerationConfig
|
||||
|
||||
|
||||
class AzureFoundryDallE2ImageGenerationConfig(DallE2ImageGenerationConfig):
    """
    Azure dall-e-2 image generation config.

    Inherits all behavior from the OpenAI DallE2ImageGenerationConfig
    unchanged; exists as a distinct type for Azure AI Foundry routing.
    """

    pass
|
||||
@@ -0,0 +1,9 @@
|
||||
from litellm.llms.openai.image_generation import DallE3ImageGenerationConfig
|
||||
|
||||
|
||||
class AzureFoundryDallE3ImageGenerationConfig(DallE3ImageGenerationConfig):
    """
    Azure dall-e-3 image generation config.

    Inherits all behavior from the OpenAI DallE3ImageGenerationConfig
    unchanged; exists as a distinct type for Azure AI Foundry routing.
    """

    pass
|
||||
@@ -0,0 +1,68 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.llms.openai.image_generation import GPTImageGenerationConfig
|
||||
|
||||
|
||||
class AzureFoundryFluxImageGenerationConfig(GPTImageGenerationConfig):
    """
    Azure Foundry flux image generation config.

    From manual testing this model follows the gpt-image-1 image generation
    config (Azure Foundry does not document supported params at the time of
    writing); our test suite confirms GPTImageGenerationConfig works here.
    """

    @staticmethod
    def get_flux2_image_generation_url(
        api_base: Optional[str],
        model: str,
        api_version: Optional[str],
    ) -> str:
        """
        Build the complete URL for Azure AI FLUX 2 image generation.

        FLUX 2 models on Azure AI use a provider-specific path rather than
        the standard Azure OpenAI deployment path:
        - Standard: /openai/deployments/{model}/images/generations
        - FLUX 2:   /providers/blackforestlabs/v1/flux-2-pro

        Args:
            api_base: Base URL (e.g., https://litellm-ci-cd-prod.services.ai.azure.com)
            model: Model name (e.g., flux.2-pro); note the endpoint path is
                fixed to flux-2-pro regardless of this value
            api_version: API version (defaults to "preview" when unset)

        Returns:
            Complete URL for the FLUX 2 image generation endpoint

        Raises:
            ValueError: if api_base is None.
        """
        if api_base is None:
            raise ValueError(
                "api_base is required for Azure AI FLUX 2 image generation"
            )

        base = api_base.rstrip("/")
        version = api_version or "preview"

        # A base that already contains /providers/ is treated as a complete
        # endpoint path; only append the version query when it is missing.
        if "/providers/" in base:
            return base if "?" in base else f"{base}?api-version={version}"

        # Construct the FLUX 2 provider path
        # (model name flux.2-pro maps to endpoint flux-2-pro).
        return f"{base}/providers/blackforestlabs/v1/flux-2-pro?api-version={version}"

    @staticmethod
    def is_flux2_model(model: str) -> bool:
        """
        Return True when `model` is an Azure AI FLUX 2 model.

        Args:
            model: Model name (e.g., flux.2-pro, azure_ai/flux.2-pro)
        """
        canonical = model.lower().replace(".", "-").replace("_", "-")
        return "flux-2" in canonical or "flux2" in canonical
|
||||
@@ -0,0 +1,9 @@
|
||||
from litellm.llms.openai.image_generation import GPTImageGenerationConfig
|
||||
|
||||
|
||||
class AzureFoundryGPTImageGenerationConfig(GPTImageGenerationConfig):
    """
    Azure gpt-image-1 image generation config.

    Inherits all behavior from the OpenAI GPTImageGenerationConfig
    unchanged; exists as a distinct type for Azure AI Foundry routing.
    """

    pass
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Azure AI OCR module."""
|
||||
from .common_utils import get_azure_ai_ocr_config
|
||||
from .document_intelligence.transformation import (
|
||||
AzureDocumentIntelligenceOCRConfig,
|
||||
)
|
||||
from .transformation import AzureAIOCRConfig
|
||||
|
||||
__all__ = [
|
||||
"AzureAIOCRConfig",
|
||||
"AzureDocumentIntelligenceOCRConfig",
|
||||
"get_azure_ai_ocr_config",
|
||||
]
|
||||
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Common utilities for Azure AI OCR providers.
|
||||
|
||||
This module provides routing logic to determine which OCR configuration to use
|
||||
based on the model name.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig
|
||||
|
||||
|
||||
def get_azure_ai_ocr_config(model: str) -> Optional["BaseOCRConfig"]:
    """
    Determine which Azure AI OCR configuration to use based on the model name.

    Azure AI supports multiple OCR services:
    - Azure Document Intelligence: azure_ai/doc-intelligence/<model>
    - Mistral OCR (via Azure AI): azure_ai/<model>

    Args:
        model: The model name (e.g., "azure_ai/doc-intelligence/prebuilt-read",
               "azure_ai/pixtral-12b-2409")

    Returns:
        OCR configuration instance for the specified model

    Examples:
        >>> get_azure_ai_ocr_config("azure_ai/doc-intelligence/prebuilt-read")
        <AzureDocumentIntelligenceOCRConfig object>

        >>> get_azure_ai_ocr_config("azure_ai/pixtral-12b-2409")
        <AzureAIOCRConfig object>
    """
    # Resolve the concrete config classes lazily at call time.
    from litellm.llms.azure_ai.ocr.document_intelligence.transformation import (
        AzureDocumentIntelligenceOCRConfig,
    )
    from litellm.llms.azure_ai.ocr.transformation import AzureAIOCRConfig

    # Azure Document Intelligence models are identified by name substring.
    is_document_intelligence = (
        "doc-intelligence" in model or "documentintelligence" in model
    )
    if is_document_intelligence:
        verbose_logger.debug(
            f"Routing {model} to Azure Document Intelligence OCR config"
        )
        return AzureDocumentIntelligenceOCRConfig()

    # Default to Mistral-based OCR for other azure_ai models.
    verbose_logger.debug(f"Routing {model} to Azure AI (Mistral) OCR config")
    return AzureAIOCRConfig()
|
||||
@@ -0,0 +1,4 @@
|
||||
"""Azure Document Intelligence OCR module."""
|
||||
from .transformation import AzureDocumentIntelligenceOCRConfig
|
||||
|
||||
__all__ = ["AzureDocumentIntelligenceOCRConfig"]
|
||||
@@ -0,0 +1,698 @@
|
||||
"""
|
||||
Azure Document Intelligence OCR transformation implementation.
|
||||
|
||||
Azure Document Intelligence (formerly Form Recognizer) provides advanced document analysis capabilities.
|
||||
This implementation transforms between Mistral OCR format and Azure Document Intelligence API v4.0.
|
||||
|
||||
Note: Azure Document Intelligence API is async - POST returns 202 Accepted with Operation-Location header.
|
||||
The operation location must be polled until the analysis completes.
|
||||
"""
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import (
|
||||
AZURE_DOCUMENT_INTELLIGENCE_API_VERSION,
|
||||
AZURE_DOCUMENT_INTELLIGENCE_DEFAULT_DPI,
|
||||
AZURE_OPERATION_POLLING_TIMEOUT,
|
||||
)
|
||||
from litellm.llms.base_llm.ocr.transformation import (
|
||||
BaseOCRConfig,
|
||||
DocumentType,
|
||||
OCRPage,
|
||||
OCRPageDimensions,
|
||||
OCRRequestData,
|
||||
OCRResponse,
|
||||
OCRUsageInfo,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class AzureDocumentIntelligenceOCRConfig(BaseOCRConfig):
|
||||
"""
|
||||
Azure Document Intelligence OCR transformation configuration.
|
||||
|
||||
Supports Azure Document Intelligence v4.0 (2024-11-30) API.
|
||||
Model route: azure_ai/doc-intelligence/<model>
|
||||
|
||||
Supported models:
|
||||
- prebuilt-layout: Extracts text with markdown, tables, and structure (closest to Mistral OCR)
|
||||
- prebuilt-read: Basic text extraction optimized for reading
|
||||
- prebuilt-document: General document analysis
|
||||
|
||||
Reference: https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/
|
||||
"""
|
||||
|
||||
    def __init__(self) -> None:
        """Initialize the config via the parent constructor; no extra state is kept here."""
        super().__init__()
|
||||
|
||||
def get_supported_ocr_params(self, model: str) -> list:
|
||||
"""
|
||||
Get supported OCR parameters for Azure Document Intelligence.
|
||||
|
||||
Azure DI has minimal optional parameters compared to Mistral OCR.
|
||||
Most Mistral-specific params are ignored during transformation.
|
||||
"""
|
||||
return []
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: Dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""
|
||||
Validate environment and return headers for Azure Document Intelligence.
|
||||
|
||||
Authentication uses Ocp-Apim-Subscription-Key header.
|
||||
"""
|
||||
# Get API key from environment if not provided
|
||||
if api_key is None:
|
||||
api_key = get_secret_str("AZURE_DOCUMENT_INTELLIGENCE_API_KEY")
|
||||
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"Missing Azure Document Intelligence API Key - Set AZURE_DOCUMENT_INTELLIGENCE_API_KEY environment variable or pass api_key parameter"
|
||||
)
|
||||
|
||||
# Validate API base/endpoint is provided
|
||||
if api_base is None:
|
||||
api_base = get_secret_str("AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT")
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"Missing Azure Document Intelligence Endpoint - Set AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT environment variable or pass api_base parameter"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
**headers,
|
||||
}
|
||||
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Get complete URL for Azure Document Intelligence endpoint.
|
||||
|
||||
Format: {endpoint}/documentintelligence/documentModels/{modelId}:analyze?api-version=2024-11-30
|
||||
|
||||
Note: API version 2024-11-30 uses /documentintelligence/ path (not /formrecognizer/)
|
||||
|
||||
Args:
|
||||
api_base: Azure Document Intelligence endpoint (e.g., https://your-resource.cognitiveservices.azure.com)
|
||||
model: Model ID (e.g., "prebuilt-layout", "prebuilt-read")
|
||||
optional_params: Optional parameters
|
||||
|
||||
Returns: Complete URL for Azure DI analyze endpoint
|
||||
"""
|
||||
if api_base is None:
|
||||
api_base = get_secret_str("AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT")
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"Missing Azure Document Intelligence Endpoint - Set AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT environment variable or pass api_base parameter"
|
||||
)
|
||||
|
||||
# Ensure no trailing slash
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
# Extract model ID from full model path if needed
|
||||
# Model can be "prebuilt-layout" or "azure_ai/doc-intelligence/prebuilt-layout"
|
||||
model_id = model
|
||||
if "/" in model:
|
||||
# Extract the last part after the last slash
|
||||
model_id = model.split("/")[-1]
|
||||
|
||||
# Azure Document Intelligence analyze endpoint
|
||||
# Note: API version 2024-11-30+ uses /documentintelligence/ (not /formrecognizer/)
|
||||
return f"{api_base}/documentintelligence/documentModels/{model_id}:analyze?api-version={AZURE_DOCUMENT_INTELLIGENCE_API_VERSION}"
|
||||
|
||||
def _extract_base64_from_data_uri(self, data_uri: str) -> str:
|
||||
"""
|
||||
Extract base64 content from a data URI.
|
||||
|
||||
Args:
|
||||
data_uri: Data URI like "data:application/pdf;base64,..."
|
||||
|
||||
Returns:
|
||||
Base64 string without the data URI prefix
|
||||
"""
|
||||
# Match pattern: data:[<mediatype>][;base64],<data>
|
||||
match = re.match(r"data:([^;]+)(?:;base64)?,(.+)", data_uri)
|
||||
if match:
|
||||
return match.group(2)
|
||||
return data_uri
|
||||
|
||||
def transform_ocr_request(
|
||||
self,
|
||||
model: str,
|
||||
document: DocumentType,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
**kwargs,
|
||||
) -> OCRRequestData:
|
||||
"""
|
||||
Transform OCR request to Azure Document Intelligence format.
|
||||
|
||||
Mistral OCR format:
|
||||
{
|
||||
"document": {
|
||||
"type": "document_url",
|
||||
"document_url": "https://example.com/doc.pdf"
|
||||
}
|
||||
}
|
||||
|
||||
Azure DI format:
|
||||
{
|
||||
"urlSource": "https://example.com/doc.pdf"
|
||||
}
|
||||
OR
|
||||
{
|
||||
"base64Source": "base64_encoded_content"
|
||||
}
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
document: Document dict from user (Mistral format)
|
||||
optional_params: Already mapped optional parameters
|
||||
headers: Request headers
|
||||
|
||||
Returns:
|
||||
OCRRequestData with JSON data
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
f"Azure Document Intelligence transform_ocr_request - model: {model}"
|
||||
)
|
||||
|
||||
if not isinstance(document, dict):
|
||||
raise ValueError(f"Expected document dict, got {type(document)}")
|
||||
|
||||
# Extract document URL from Mistral format
|
||||
doc_type = document.get("type")
|
||||
document_url = None
|
||||
|
||||
if doc_type == "document_url":
|
||||
document_url = document.get("document_url", "")
|
||||
elif doc_type == "image_url":
|
||||
document_url = document.get("image_url", "")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid document type: {doc_type}. Must be 'document_url' or 'image_url'"
|
||||
)
|
||||
|
||||
if not document_url:
|
||||
raise ValueError("Document URL is required")
|
||||
|
||||
# Build Azure DI request
|
||||
data: Dict[str, Any] = {}
|
||||
|
||||
# Check if it's a data URI (base64)
|
||||
if document_url.startswith("data:"):
|
||||
# Extract base64 content
|
||||
base64_content = self._extract_base64_from_data_uri(document_url)
|
||||
data["base64Source"] = base64_content
|
||||
verbose_logger.debug("Using base64Source for Azure Document Intelligence")
|
||||
else:
|
||||
# Regular URL
|
||||
data["urlSource"] = document_url
|
||||
verbose_logger.debug("Using urlSource for Azure Document Intelligence")
|
||||
|
||||
# Azure DI doesn't support most Mistral-specific params
|
||||
# Ignore pages, include_image_base64, etc.
|
||||
|
||||
return OCRRequestData(data=data, files=None)
|
||||
|
||||
def _extract_page_markdown(self, page_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract text from Azure DI page and format as markdown.
|
||||
|
||||
Azure DI provides text in 'lines' array. We concatenate them with newlines.
|
||||
|
||||
Args:
|
||||
page_data: Azure DI page object
|
||||
|
||||
Returns:
|
||||
Markdown-formatted text
|
||||
"""
|
||||
lines = page_data.get("lines", [])
|
||||
if not lines:
|
||||
return ""
|
||||
|
||||
# Extract text content from each line
|
||||
text_lines = [line.get("content", "") for line in lines]
|
||||
|
||||
# Join with newlines to preserve structure
|
||||
return "\n".join(text_lines)
|
||||
|
||||
def _convert_dimensions(
|
||||
self, width: float, height: float, unit: str
|
||||
) -> OCRPageDimensions:
|
||||
"""
|
||||
Convert Azure DI dimensions to pixels.
|
||||
|
||||
Azure DI provides dimensions in inches. We convert to pixels using configured DPI.
|
||||
|
||||
Args:
|
||||
width: Width in specified unit
|
||||
height: Height in specified unit
|
||||
unit: Unit of measurement (e.g., "inch")
|
||||
|
||||
Returns:
|
||||
OCRPageDimensions with pixel values
|
||||
"""
|
||||
# Convert to pixels using configured DPI
|
||||
dpi = AZURE_DOCUMENT_INTELLIGENCE_DEFAULT_DPI
|
||||
if unit == "inch":
|
||||
width_px = int(width * dpi)
|
||||
height_px = int(height * dpi)
|
||||
else:
|
||||
# If unit is not inches, assume it's already in pixels
|
||||
width_px = int(width)
|
||||
height_px = int(height)
|
||||
|
||||
return OCRPageDimensions(width=width_px, height=height_px, dpi=dpi)
|
||||
|
||||
@staticmethod
|
||||
def _check_timeout(start_time: float, timeout_secs: int) -> None:
|
||||
"""
|
||||
Check if operation has timed out.
|
||||
|
||||
Args:
|
||||
start_time: Start time of the operation
|
||||
timeout_secs: Timeout duration in seconds
|
||||
|
||||
Raises:
|
||||
TimeoutError: If operation has exceeded timeout
|
||||
"""
|
||||
if time.time() - start_time > timeout_secs:
|
||||
raise TimeoutError(
|
||||
f"Azure Document Intelligence operation polling timed out after {timeout_secs} seconds"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_retry_after(response: httpx.Response) -> int:
|
||||
"""
|
||||
Get retry-after duration from response headers.
|
||||
|
||||
Args:
|
||||
response: HTTP response
|
||||
|
||||
Returns:
|
||||
Retry-after duration in seconds (default: 2)
|
||||
"""
|
||||
retry_after = int(response.headers.get("retry-after", "2"))
|
||||
verbose_logger.debug(f"Retry polling after: {retry_after} seconds")
|
||||
return retry_after
|
||||
|
||||
@staticmethod
|
||||
def _check_operation_status(response: httpx.Response) -> str:
|
||||
"""
|
||||
Check Azure DI operation status from response.
|
||||
|
||||
Args:
|
||||
response: HTTP response from operation endpoint
|
||||
|
||||
Returns:
|
||||
Operation status string
|
||||
|
||||
Raises:
|
||||
ValueError: If operation failed or status is unknown
|
||||
"""
|
||||
try:
|
||||
result = response.json()
|
||||
status = result.get("status")
|
||||
|
||||
verbose_logger.debug(f"Azure DI operation status: {status}")
|
||||
|
||||
if status == "succeeded":
|
||||
return "succeeded"
|
||||
elif status == "failed":
|
||||
error_msg = result.get("error", {}).get("message", "Unknown error")
|
||||
raise ValueError(
|
||||
f"Azure Document Intelligence analysis failed: {error_msg}"
|
||||
)
|
||||
elif status in ["running", "notStarted"]:
|
||||
return "running"
|
||||
else:
|
||||
raise ValueError(f"Unknown operation status: {status}")
|
||||
|
||||
except Exception as e:
|
||||
if "succeeded" in str(e) or "failed" in str(e):
|
||||
raise
|
||||
# If we can't parse JSON, something went wrong
|
||||
raise ValueError(f"Failed to parse Azure DI operation response: {e}")
|
||||
|
||||
def _poll_operation_sync(
|
||||
self,
|
||||
operation_url: str,
|
||||
headers: Dict[str, str],
|
||||
timeout_secs: int,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll Azure Document Intelligence operation until completion (sync).
|
||||
|
||||
Azure DI POST returns 202 with Operation-Location header.
|
||||
We need to poll that URL until status is "succeeded" or "failed".
|
||||
|
||||
Args:
|
||||
operation_url: The Operation-Location URL to poll
|
||||
headers: Request headers (including auth)
|
||||
timeout_secs: Total timeout in seconds
|
||||
|
||||
Returns:
|
||||
Final response with completed analysis
|
||||
"""
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
|
||||
client = _get_httpx_client()
|
||||
start_time = time.time()
|
||||
|
||||
verbose_logger.debug(f"Polling Azure DI operation: {operation_url}")
|
||||
|
||||
while True:
|
||||
self._check_timeout(start_time=start_time, timeout_secs=timeout_secs)
|
||||
|
||||
# Poll the operation status
|
||||
response = client.get(url=operation_url, headers=headers)
|
||||
|
||||
# Check operation status
|
||||
status = self._check_operation_status(response=response)
|
||||
|
||||
if status == "succeeded":
|
||||
return response
|
||||
elif status == "running":
|
||||
# Wait before polling again
|
||||
retry_after = self._get_retry_after(response=response)
|
||||
time.sleep(retry_after)
|
||||
|
||||
async def _poll_operation_async(
|
||||
self,
|
||||
operation_url: str,
|
||||
headers: Dict[str, str],
|
||||
timeout_secs: int,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll Azure Document Intelligence operation until completion (async).
|
||||
|
||||
Args:
|
||||
operation_url: The Operation-Location URL to poll
|
||||
headers: Request headers (including auth)
|
||||
timeout_secs: Total timeout in seconds
|
||||
|
||||
Returns:
|
||||
Final response with completed analysis
|
||||
"""
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
|
||||
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.AZURE_AI)
|
||||
start_time = time.time()
|
||||
|
||||
verbose_logger.debug(f"Polling Azure DI operation (async): {operation_url}")
|
||||
|
||||
while True:
|
||||
self._check_timeout(start_time=start_time, timeout_secs=timeout_secs)
|
||||
|
||||
# Poll the operation status
|
||||
response = await client.get(url=operation_url, headers=headers)
|
||||
|
||||
# Check operation status
|
||||
status = self._check_operation_status(response=response)
|
||||
|
||||
if status == "succeeded":
|
||||
return response
|
||||
elif status == "running":
|
||||
# Wait before polling again
|
||||
retry_after = self._get_retry_after(response=response)
|
||||
await asyncio.sleep(retry_after)
|
||||
|
||||
def transform_ocr_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: Any,
|
||||
**kwargs,
|
||||
) -> OCRResponse:
|
||||
"""
|
||||
Transform Azure Document Intelligence response to Mistral OCR format.
|
||||
|
||||
Handles async operation polling: If response is 202 Accepted, polls Operation-Location
|
||||
until analysis completes.
|
||||
|
||||
Azure DI response (after polling):
|
||||
{
|
||||
"status": "succeeded",
|
||||
"analyzeResult": {
|
||||
"content": "Full document text...",
|
||||
"pages": [
|
||||
{
|
||||
"pageNumber": 1,
|
||||
"width": 8.5,
|
||||
"height": 11,
|
||||
"unit": "inch",
|
||||
"lines": [{"content": "text", "boundingBox": [...]}]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
Mistral OCR format:
|
||||
{
|
||||
"pages": [
|
||||
{
|
||||
"index": 0,
|
||||
"markdown": "extracted text",
|
||||
"dimensions": {"width": 816, "height": 1056, "dpi": 96}
|
||||
}
|
||||
],
|
||||
"model": "azure_ai/doc-intelligence/prebuilt-layout",
|
||||
"usage_info": {"pages_processed": 1},
|
||||
"object": "ocr"
|
||||
}
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
raw_response: Raw HTTP response from Azure DI (may be 202 Accepted)
|
||||
logging_obj: Logging object
|
||||
|
||||
Returns:
|
||||
OCRResponse in Mistral format
|
||||
"""
|
||||
try:
|
||||
# Check if we got 202 Accepted (async operation started)
|
||||
if raw_response.status_code == 202:
|
||||
verbose_logger.debug(
|
||||
"Azure DI returned 202 Accepted, polling operation..."
|
||||
)
|
||||
|
||||
# Get Operation-Location header
|
||||
operation_url = raw_response.headers.get("Operation-Location")
|
||||
if not operation_url:
|
||||
raise ValueError(
|
||||
"Azure Document Intelligence returned 202 but no Operation-Location header found"
|
||||
)
|
||||
|
||||
# Get headers for polling (need auth)
|
||||
poll_headers = {
|
||||
"Ocp-Apim-Subscription-Key": raw_response.request.headers.get(
|
||||
"Ocp-Apim-Subscription-Key", ""
|
||||
)
|
||||
}
|
||||
|
||||
# Get timeout from kwargs or use default
|
||||
timeout_secs = AZURE_OPERATION_POLLING_TIMEOUT
|
||||
|
||||
# Poll until operation completes
|
||||
raw_response = self._poll_operation_sync(
|
||||
operation_url=operation_url,
|
||||
headers=poll_headers,
|
||||
timeout_secs=timeout_secs,
|
||||
)
|
||||
|
||||
# Now parse the completed response
|
||||
response_json = raw_response.json()
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Azure Document Intelligence response status: {response_json.get('status')}"
|
||||
)
|
||||
|
||||
# Check if request succeeded
|
||||
status = response_json.get("status")
|
||||
if status != "succeeded":
|
||||
raise ValueError(
|
||||
f"Azure Document Intelligence analysis failed with status: {status}"
|
||||
)
|
||||
|
||||
# Extract analyze result
|
||||
analyze_result = response_json.get("analyzeResult", {})
|
||||
azure_pages = analyze_result.get("pages", [])
|
||||
|
||||
# Transform pages to Mistral format
|
||||
mistral_pages = []
|
||||
for azure_page in azure_pages:
|
||||
page_number = azure_page.get("pageNumber", 1)
|
||||
index = page_number - 1 # Convert to 0-based index
|
||||
|
||||
# Extract markdown text
|
||||
markdown = self._extract_page_markdown(azure_page)
|
||||
|
||||
# Convert dimensions
|
||||
width = azure_page.get("width", 8.5)
|
||||
height = azure_page.get("height", 11)
|
||||
unit = azure_page.get("unit", "inch")
|
||||
dimensions = self._convert_dimensions(
|
||||
width=width, height=height, unit=unit
|
||||
)
|
||||
|
||||
# Build OCR page
|
||||
ocr_page = OCRPage(
|
||||
index=index, markdown=markdown, dimensions=dimensions
|
||||
)
|
||||
mistral_pages.append(ocr_page)
|
||||
|
||||
# Build usage info
|
||||
usage_info = OCRUsageInfo(
|
||||
pages_processed=len(mistral_pages), doc_size_bytes=None
|
||||
)
|
||||
|
||||
# Return Mistral OCR response
|
||||
return OCRResponse(
|
||||
pages=mistral_pages,
|
||||
model=model,
|
||||
usage_info=usage_info,
|
||||
object="ocr",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"Error parsing Azure Document Intelligence response: {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
async def async_transform_ocr_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: Any,
|
||||
**kwargs,
|
||||
) -> OCRResponse:
|
||||
"""
|
||||
Async transform Azure Document Intelligence response to Mistral OCR format.
|
||||
|
||||
Handles async operation polling: If response is 202 Accepted, polls Operation-Location
|
||||
until analysis completes using async polling.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
raw_response: Raw HTTP response from Azure DI (may be 202 Accepted)
|
||||
logging_obj: Logging object
|
||||
|
||||
Returns:
|
||||
OCRResponse in Mistral format
|
||||
"""
|
||||
try:
|
||||
# Check if we got 202 Accepted (async operation started)
|
||||
if raw_response.status_code == 202:
|
||||
verbose_logger.debug(
|
||||
"Azure DI returned 202 Accepted, polling operation (async)..."
|
||||
)
|
||||
|
||||
# Get Operation-Location header
|
||||
operation_url = raw_response.headers.get("Operation-Location")
|
||||
if not operation_url:
|
||||
raise ValueError(
|
||||
"Azure Document Intelligence returned 202 but no Operation-Location header found"
|
||||
)
|
||||
|
||||
# Get headers for polling (need auth)
|
||||
poll_headers = {
|
||||
"Ocp-Apim-Subscription-Key": raw_response.request.headers.get(
|
||||
"Ocp-Apim-Subscription-Key", ""
|
||||
)
|
||||
}
|
||||
|
||||
# Get timeout from kwargs or use default
|
||||
timeout_secs = AZURE_OPERATION_POLLING_TIMEOUT
|
||||
|
||||
# Poll until operation completes (async)
|
||||
raw_response = await self._poll_operation_async(
|
||||
operation_url=operation_url,
|
||||
headers=poll_headers,
|
||||
timeout_secs=timeout_secs,
|
||||
)
|
||||
|
||||
# Now parse the completed response
|
||||
response_json = raw_response.json()
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Azure Document Intelligence response status: {response_json.get('status')}"
|
||||
)
|
||||
|
||||
# Check if request succeeded
|
||||
status = response_json.get("status")
|
||||
if status != "succeeded":
|
||||
raise ValueError(
|
||||
f"Azure Document Intelligence analysis failed with status: {status}"
|
||||
)
|
||||
|
||||
# Extract analyze result
|
||||
analyze_result = response_json.get("analyzeResult", {})
|
||||
azure_pages = analyze_result.get("pages", [])
|
||||
|
||||
# Transform pages to Mistral format
|
||||
mistral_pages = []
|
||||
for azure_page in azure_pages:
|
||||
page_number = azure_page.get("pageNumber", 1)
|
||||
index = page_number - 1 # Convert to 0-based index
|
||||
|
||||
# Extract markdown text
|
||||
markdown = self._extract_page_markdown(azure_page)
|
||||
|
||||
# Convert dimensions
|
||||
width = azure_page.get("width", 8.5)
|
||||
height = azure_page.get("height", 11)
|
||||
unit = azure_page.get("unit", "inch")
|
||||
dimensions = self._convert_dimensions(
|
||||
width=width, height=height, unit=unit
|
||||
)
|
||||
|
||||
# Build OCR page
|
||||
ocr_page = OCRPage(
|
||||
index=index, markdown=markdown, dimensions=dimensions
|
||||
)
|
||||
mistral_pages.append(ocr_page)
|
||||
|
||||
# Build usage info
|
||||
usage_info = OCRUsageInfo(
|
||||
pages_processed=len(mistral_pages), doc_size_bytes=None
|
||||
)
|
||||
|
||||
# Return Mistral OCR response
|
||||
return OCRResponse(
|
||||
pages=mistral_pages,
|
||||
model=model,
|
||||
usage_info=usage_info,
|
||||
object="ocr",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"Error parsing Azure Document Intelligence response (async): {e}"
|
||||
)
|
||||
raise e
|
||||
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Azure AI OCR transformation implementation.
|
||||
"""
|
||||
from typing import Dict, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.image_handling import (
|
||||
async_convert_url_to_base64,
|
||||
convert_url_to_base64,
|
||||
)
|
||||
from litellm.llms.base_llm.ocr.transformation import DocumentType, OCRRequestData
|
||||
from litellm.llms.mistral.ocr.transformation import MistralOCRConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class AzureAIOCRConfig(MistralOCRConfig):
|
||||
"""
|
||||
Azure AI OCR transformation configuration.
|
||||
|
||||
Azure AI uses Mistral's OCR API but with a different endpoint format.
|
||||
Inherits transformation logic from MistralOCRConfig since they use the same format.
|
||||
|
||||
Reference: Azure AI Foundry OCR documentation
|
||||
|
||||
Important: Azure AI only supports base64 data URIs (data:image/..., data:application/pdf;base64,...).
|
||||
Regular URLs are not supported.
|
||||
"""
|
||||
|
||||
    def __init__(self) -> None:
        """Initialize the config via the parent (Mistral OCR) constructor; no extra state."""
        super().__init__()
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: Dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""
|
||||
Validate environment and return headers for Azure AI OCR.
|
||||
|
||||
Azure AI uses Bearer token authentication with AZURE_AI_API_KEY.
|
||||
"""
|
||||
# Get API key from environment if not provided
|
||||
if api_key is None:
|
||||
api_key = get_secret_str("AZURE_AI_API_KEY")
|
||||
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"Missing Azure AI API Key - A call is being made to Azure AI but no key is set either in the environment variables or via params"
|
||||
)
|
||||
|
||||
# Validate API base is provided
|
||||
if api_base is None:
|
||||
api_base = get_secret_str("AZURE_AI_API_BASE")
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"Missing Azure AI API Base - Set AZURE_AI_API_BASE environment variable or pass api_base parameter"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**headers,
|
||||
}
|
||||
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Get complete URL for Azure AI OCR endpoint.
|
||||
|
||||
Azure AI endpoint format: https://<api_base>/providers/mistral/azure/ocr
|
||||
|
||||
Args:
|
||||
api_base: Azure AI API base URL
|
||||
model: Model name (not used in URL construction)
|
||||
optional_params: Optional parameters
|
||||
|
||||
Returns: Complete URL for Azure AI OCR endpoint
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"Missing Azure AI API Base - Set AZURE_AI_API_BASE environment variable or pass api_base parameter"
|
||||
)
|
||||
|
||||
# Ensure no trailing slash
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
# Azure AI OCR endpoint format
|
||||
return f"{api_base}/providers/mistral/azure/ocr"
|
||||
|
||||
def _convert_url_to_data_uri_sync(self, url: str) -> str:
|
||||
"""
|
||||
Synchronously convert a URL to a base64 data URI.
|
||||
|
||||
Azure AI OCR doesn't have internet access, so we need to fetch URLs
|
||||
and convert them to base64 data URIs.
|
||||
|
||||
Args:
|
||||
url: The URL to convert
|
||||
|
||||
Returns:
|
||||
Base64 data URI string
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
f"Azure AI OCR: Converting URL to base64 data URI (sync): {url}"
|
||||
)
|
||||
|
||||
# Fetch and convert to base64 data URI
|
||||
# convert_url_to_base64 already returns a full data URI like "data:image/jpeg;base64,..."
|
||||
data_uri = convert_url_to_base64(url=url)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Azure AI OCR: Converted URL to data URI (length: {len(data_uri)})"
|
||||
)
|
||||
|
||||
return data_uri
|
||||
|
||||
async def _convert_url_to_data_uri_async(self, url: str) -> str:
|
||||
"""
|
||||
Asynchronously convert a URL to a base64 data URI.
|
||||
|
||||
Azure AI OCR doesn't have internet access, so we need to fetch URLs
|
||||
and convert them to base64 data URIs.
|
||||
|
||||
Args:
|
||||
url: The URL to convert
|
||||
|
||||
Returns:
|
||||
Base64 data URI string
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
f"Azure AI OCR: Converting URL to base64 data URI (async): {url}"
|
||||
)
|
||||
|
||||
# Fetch and convert to base64 data URI asynchronously
|
||||
# async_convert_url_to_base64 already returns a full data URI like "data:image/jpeg;base64,..."
|
||||
data_uri = await async_convert_url_to_base64(url=url)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Azure AI OCR: Converted URL to data URI (length: {len(data_uri)})"
|
||||
)
|
||||
|
||||
return data_uri
|
||||
|
||||
def transform_ocr_request(
|
||||
self,
|
||||
model: str,
|
||||
document: DocumentType,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
**kwargs,
|
||||
) -> OCRRequestData:
|
||||
"""
|
||||
Transform OCR request for Azure AI, converting URLs to base64 data URIs (sync).
|
||||
|
||||
Azure AI OCR doesn't have internet access, so we automatically fetch
|
||||
any URLs and convert them to base64 data URIs synchronously.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
document: Document dict from user
|
||||
optional_params: Already mapped optional parameters
|
||||
headers: Request headers
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
OCRRequestData with JSON data
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
f"Azure AI OCR transform_ocr_request (sync) - model: {model}"
|
||||
)
|
||||
|
||||
if not isinstance(document, dict):
|
||||
raise ValueError(f"Expected document dict, got {type(document)}")
|
||||
|
||||
# Check if we need to convert URL to base64
|
||||
doc_type = document.get("type")
|
||||
transformed_document = document.copy()
|
||||
|
||||
if doc_type == "document_url":
|
||||
document_url = document.get("document_url", "")
|
||||
# If it's not already a data URI, convert it
|
||||
if document_url and not document_url.startswith("data:"):
|
||||
verbose_logger.debug(
|
||||
"Azure AI OCR: Converting document URL to base64 data URI (sync)"
|
||||
)
|
||||
data_uri = self._convert_url_to_data_uri_sync(url=document_url)
|
||||
transformed_document["document_url"] = data_uri
|
||||
elif doc_type == "image_url":
|
||||
image_url = document.get("image_url", "")
|
||||
# If it's not already a data URI, convert it
|
||||
if image_url and not image_url.startswith("data:"):
|
||||
verbose_logger.debug(
|
||||
"Azure AI OCR: Converting image URL to base64 data URI (sync)"
|
||||
)
|
||||
data_uri = self._convert_url_to_data_uri_sync(url=image_url)
|
||||
transformed_document["image_url"] = data_uri
|
||||
|
||||
# Call parent's transform to build the request
|
||||
return super().transform_ocr_request(
|
||||
model=model,
|
||||
document=transformed_document,
|
||||
optional_params=optional_params,
|
||||
headers=headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def async_transform_ocr_request(
|
||||
self,
|
||||
model: str,
|
||||
document: DocumentType,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
**kwargs,
|
||||
) -> OCRRequestData:
|
||||
"""
|
||||
Transform OCR request for Azure AI, converting URLs to base64 data URIs (async).
|
||||
|
||||
Azure AI OCR doesn't have internet access, so we automatically fetch
|
||||
any URLs and convert them to base64 data URIs asynchronously.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
document: Document dict from user
|
||||
optional_params: Already mapped optional parameters
|
||||
headers: Request headers
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
OCRRequestData with JSON data
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
f"Azure AI OCR async_transform_ocr_request - model: {model}"
|
||||
)
|
||||
|
||||
if not isinstance(document, dict):
|
||||
raise ValueError(f"Expected document dict, got {type(document)}")
|
||||
|
||||
# Check if we need to convert URL to base64
|
||||
doc_type = document.get("type")
|
||||
transformed_document = document.copy()
|
||||
|
||||
if doc_type == "document_url":
|
||||
document_url = document.get("document_url", "")
|
||||
# If it's not already a data URI, convert it
|
||||
if document_url and not document_url.startswith("data:"):
|
||||
verbose_logger.debug(
|
||||
"Azure AI OCR: Converting document URL to base64 data URI (async)"
|
||||
)
|
||||
data_uri = await self._convert_url_to_data_uri_async(url=document_url)
|
||||
transformed_document["document_url"] = data_uri
|
||||
elif doc_type == "image_url":
|
||||
image_url = document.get("image_url", "")
|
||||
# If it's not already a data URI, convert it
|
||||
if image_url and not image_url.startswith("data:"):
|
||||
verbose_logger.debug(
|
||||
"Azure AI OCR: Converting image URL to base64 data URI (async)"
|
||||
)
|
||||
data_uri = await self._convert_url_to_data_uri_async(url=image_url)
|
||||
transformed_document["image_url"] = data_uri
|
||||
|
||||
# Call parent's transform to build the request
|
||||
return super().transform_ocr_request(
|
||||
model=model,
|
||||
document=transformed_document,
|
||||
optional_params=optional_params,
|
||||
headers=headers,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Azure AI Rerank - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.utils import RerankResponse
|
||||
from litellm.utils import _add_path_to_api_base
|
||||
|
||||
|
||||
class AzureAIRerankConfig(CohereRerankConfig):
    """
    Azure AI Rerank provider config.

    Azure AI exposes rerank with the same request/response contract as
    Cohere's `/rerank`, so this class only customizes endpoint resolution,
    auth headers, and base-model normalization on top of ``CohereRerankConfig``.
    """

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[dict] = None,
    ) -> str:
        """Resolve the full rerank endpoint URL from the configured api_base.

        Accepts a full v1/v2 rerank endpoint, a bare version path
        (``/v1``, ``/v2``, ``/providers/cohere/v2``), or a plain resource URL.

        Raises:
            ValueError: if ``api_base`` is missing or not an absolute URL.
        """
        if api_base is None:
            raise ValueError(
                "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
            )
        parsed = httpx.URL(api_base)
        if not parsed.is_absolute_url:
            raise ValueError(
                "Azure AI API Base must be an absolute URL including scheme (e.g. "
                "'https://<resource>.services.ai.azure.com'). "
                f"Got api_base={api_base!r}."
            )

        trimmed_path = parsed.path.rstrip("/")
        base_url = str(parsed.copy_with(path=trimmed_path or "/"))

        # Caller already supplied a complete rerank endpoint -> use it as-is:
        # - https://<resource>.services.ai.azure.com/v1/rerank
        # - https://<resource>.services.ai.azure.com/providers/cohere/v2/rerank
        if trimmed_path.endswith(("/v1/rerank", "/v2/rerank")):
            return base_url

        # Caller supplied only the version segment -> append "/rerank".
        # (".../providers/cohere/v2" is covered by the "/v2" suffix check.)
        if trimmed_path.endswith(("/v1", "/v2")):
            return _add_path_to_api_base(
                api_base=base_url,
                ending_path="/rerank",
            )

        # Backwards compatible default: Azure AI rerank was originally exposed
        # under /v1/rerank.
        return _add_path_to_api_base(api_base=api_base, ending_path="/v1/rerank")

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> dict:
        """Build request headers, resolving the API key from env/config when absent.

        Raises:
            ValueError: if no API key can be resolved.
        """
        if api_key is None:
            api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key

        if api_key is None:
            raise ValueError(
                "Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
            )

        # Caller-supplied headers win over the defaults (including Authorization).
        return {
            "Authorization": f"Bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
            **headers,
        }

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """Delegate to the Cohere transform, then normalize the reported model name."""
        response = super().transform_rerank_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=request_data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )
        azure_model_group = response._hidden_params.get(
            "llm_provider-azureml-model-group"
        )
        response._hidden_params["model"] = self._get_base_model(azure_model_group)
        return response

    def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
        """Map an Azure model-group identifier to its litellm base model name.

        Unknown (or missing) model groups are passed through unchanged.
        """
        offer_to_model = {
            "offer-cohere-rerank-mul-paygo": "azure_ai/cohere-rerank-v3-multilingual",
            "offer-cohere-rerank-eng-paygo": "azure_ai/cohere-rerank-v3-english",
        }
        return offer_to_model.get(azure_model_group, azure_model_group)
|
||||
@@ -0,0 +1,3 @@
|
||||
from litellm.llms.azure_ai.vector_stores.transformation import AzureAIVectorStoreConfig
|
||||
|
||||
__all__ = ["AzureAIVectorStoreConfig"]
|
||||
@@ -0,0 +1,257 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
BaseVectorStoreAuthCredentials,
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreIndexEndpoints,
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
|
||||
# Import the logging type only for static type-checkers; at runtime the alias
# degrades to Any so annotations still work without importing the logging
# module (presumably to avoid an import cycle — confirm against the package).
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AzureAIVectorStoreConfig(BaseVectorStoreConfig, BaseAzureLLM):
    """
    Configuration for Azure AI Search Vector Store.

    This implementation uses the Azure AI Search API for vector store
    operations. Query embeddings are generated via ``litellm.embedding`` and
    sent as a vector query; ``vector_store_id`` maps to the search index name.
    """

    def __init__(self):
        super().__init__()

    def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
        """Return the read/write HTTP endpoints used against an Azure AI Search index."""
        return {
            "read": [("GET", "/docs/search"), ("POST", "/docs/search")],
            "write": [("PUT", "/docs")],
        }

    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """Build auth headers from ``litellm_params``.

        Raises:
            ValueError: if no ``api_key`` is present.
        """
        api_key = litellm_params.get("api_key")
        if api_key is None:
            raise ValueError("api_key is required")

        # Azure AI Search authenticates with an "api-key" header (not Bearer).
        return {
            "headers": {
                "api-key": api_key,
            }
        }

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Merge the base Azure auth headers with a JSON content type."""
        basic_headers = self._base_validate_azure_environment(headers, litellm_params)
        basic_headers.update({"Content-Type": "application/json"})
        return basic_headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the base endpoint for the Azure AI Search API.

        Expected format: https://{search_service_name}.search.windows.net

        Raises:
            ValueError: if neither ``api_base`` nor
                ``litellm_params['azure_search_service_name']`` is provided.
        """
        if api_base:
            return api_base.rstrip("/")

        # Fall back to building the endpoint from the service name.
        search_service_name = litellm_params.get("azure_search_service_name")
        if not search_service_name:
            raise ValueError(
                "Azure AI Search service name is required. "
                "Provide it via litellm_params['azure_search_service_name'] or api_base parameter"
            )

        return f"https://{search_service_name}.search.windows.net"

    def _generate_query_vector(
        self, query: str, embedding_model: str, embedding_config: dict
    ) -> List[float]:
        """Embed *query* with ``litellm.embedding`` and return the raw vector."""
        try:
            embedding_response = litellm.embedding(
                model=embedding_model,
                input=[query],
                **embedding_config,
            )
            return embedding_response.data[0]["embedding"]
        except Exception as e:
            # Chain the cause so the underlying provider error stays debuggable.
            raise Exception(f"Failed to generate embedding for query: {str(e)}") from e

    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Transform a search request into an Azure AI Search vector query.

        Generates an embedding for the query via ``litellm.embedding`` and
        builds the ``/docs/search`` request body. ``vector_store_id`` is used
        as the index name.

        Raises:
            ValueError: if ``litellm_params`` is missing
                ``litellm_embedding_model`` or ``litellm_embedding_config``.
            Exception: if embedding generation fails.
        """
        # Azure AI Search takes a single query string.
        if isinstance(query, list):
            query = " ".join(query)

        # NOTE: the config keys are "litellm_embedding_model" /
        # "litellm_embedding_config" — the error messages below name the keys
        # the code actually reads (previously they pointed at the wrong keys).
        embedding_model = litellm_params.get("litellm_embedding_model")
        if not embedding_model:
            raise ValueError(
                "litellm_embedding_model is required in litellm_params for Azure AI Search. "
                "Example: litellm_params['litellm_embedding_model'] = 'azure/text-embedding-3-large'"
            )

        embedding_config = litellm_params.get("litellm_embedding_config", {})
        if not embedding_config:
            raise ValueError(
                "litellm_embedding_config is required in litellm_params for Azure AI Search. "
                "Example: litellm_params['litellm_embedding_config'] = "
                "{'api_base': 'https://<resource>.cognitiveservices.azure.com/', "
                "'api_key': 'os.environ/AZURE_API_KEY', 'api_version': '2025-09-01'}"
            )

        # Vector field in the index schema (defaults to "contentVector").
        vector_field = litellm_params.get("azure_search_vector_field", "contentVector")

        # Number of results to return.
        top_k = vector_store_search_optional_params.get("top_k", 10)

        query_vector = self._generate_query_vector(
            query=query,
            embedding_model=embedding_model,
            embedding_config=embedding_config,
        )

        # vector_store_id doubles as the Azure AI Search index name.
        url = f"{api_base}/indexes/{vector_store_id}/docs/search?api-version=2024-07-01"

        # Request body for Azure AI Search with vector search.
        request_body = {
            "search": "*",  # Match all documents; ranking comes from vector similarity
            "vectorQueries": [
                {
                    "vector": query_vector,
                    "fields": vector_field,
                    "kind": "vector",
                    "k": top_k,  # Number of nearest neighbors to return
                }
            ],
            "select": "id,content",  # Fields to return (customize based on schema)
            "top": top_k,
        }

        #########################################################
        # Update logging object with details of the request
        #########################################################
        litellm_logging_obj.model_call_details["input"] = query
        litellm_logging_obj.model_call_details["embedding_model"] = embedding_model
        litellm_logging_obj.model_call_details["top_k"] = top_k

        return url, request_body

    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """
        Transform an Azure AI Search API response to the standard search response.

        Handles the format from Azure AI Search which returns:
        {
            "value": [
                {
                    "id": "...",
                    "content": "...",
                    "@search.score": 0.95,
                    ... (other fields)
                }
            ]
        }
        """
        try:
            response_json = response.json()
            results = response_json.get("value", [])

            search_results: List[VectorStoreSearchResult] = []
            for result in results:
                document_id = result.get("id", "")
                text_content = result.get("content", "")

                content = [
                    VectorStoreResultContent(
                        text=text_content,
                        type="text",
                    )
                ]

                # Relevance score assigned by Azure AI Search.
                score = result.get("@search.score", 0.0)

                # Carry over all metadata except system/already-processed fields.
                attributes = {
                    key: value
                    for key, value in result.items()
                    if key not in ("id", "content", "contentVector", "@search.score")
                }
                # Always include document_id in attributes.
                attributes["document_id"] = document_id

                search_results.append(
                    VectorStoreSearchResult(
                        score=score,
                        content=content,
                        # Document ID is used for both file_id and filename.
                        file_id=document_id,
                        filename=f"Document {document_id}",
                        attributes=attributes,
                    )
                )

            return VectorStoreSearchResponse(
                object="vector_store.search_results.page",
                search_query=litellm_logging_obj.model_call_details.get("input", ""),
                data=search_results,
            )

        except Exception as e:
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict]:
        """Creating indexes is not supported through this integration."""
        raise NotImplementedError

    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        """Creating indexes is not supported through this integration."""
        raise NotImplementedError
|
||||
Reference in New Issue
Block a user