chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
|
||||
from litellm.llms.azure_ai.agents.transformation import (
|
||||
AzureAIAgentsConfig,
|
||||
AzureAIAgentsError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureAIAgentsConfig",
|
||||
"AzureAIAgentsError",
|
||||
"azure_ai_agents_handler",
|
||||
]
|
||||
@@ -0,0 +1,659 @@
|
||||
"""
|
||||
Handler for Azure Foundry Agent Service API.
|
||||
|
||||
This handler executes the multi-step agent flow:
|
||||
1. Create thread (or use existing)
|
||||
2. Add messages to thread
|
||||
3. Create and poll a run
|
||||
4. Retrieve the assistant's response messages
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
|
||||
|
||||
Authentication: Uses Azure AD Bearer tokens (not API keys)
|
||||
Get token via: az account get-access-token --resource 'https://ai.azure.com'
|
||||
|
||||
Supports both polling-based and native streaming (SSE) modes.
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure_ai.agents.transformation import (
|
||||
AzureAIAgentsConfig,
|
||||
AzureAIAgentsError,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HTTPHandler = Any
|
||||
AsyncHTTPHandler = Any
|
||||
|
||||
|
||||
class AzureAIAgentsHandler:
|
||||
"""
|
||||
Handler for Azure AI Agent Service.
|
||||
|
||||
Executes the complete agent flow which requires multiple API calls.
|
||||
"""
|
||||
|
||||
    def __init__(self):
        # One config instance shared by all requests this handler serves;
        # it supplies the API-version default and polling limits.
        self.config = AzureAIAgentsConfig()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# URL Builders
|
||||
# -------------------------------------------------------------------------
|
||||
# Azure Foundry Agents API uses /assistants, /threads, etc. directly
|
||||
# See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
# -------------------------------------------------------------------------
|
||||
def _build_thread_url(self, api_base: str, api_version: str) -> str:
|
||||
return f"{api_base}/threads?api-version={api_version}"
|
||||
|
||||
def _build_messages_url(
|
||||
self, api_base: str, thread_id: str, api_version: str
|
||||
) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
|
||||
|
||||
def _build_runs_url(self, api_base: str, thread_id: str, api_version: str) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/runs?api-version={api_version}"
|
||||
|
||||
def _build_run_status_url(
|
||||
self, api_base: str, thread_id: str, run_id: str, api_version: str
|
||||
) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/runs/{run_id}?api-version={api_version}"
|
||||
|
||||
def _build_list_messages_url(
|
||||
self, api_base: str, thread_id: str, api_version: str
|
||||
) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
|
||||
|
||||
def _build_create_thread_and_run_url(self, api_base: str, api_version: str) -> str:
|
||||
"""URL for the create-thread-and-run endpoint (supports streaming)."""
|
||||
return f"{api_base}/threads/runs?api-version={api_version}"
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Response Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
def _extract_content_from_messages(self, messages_data: dict) -> str:
|
||||
"""Extract assistant content from the messages response."""
|
||||
for msg in messages_data.get("data", []):
|
||||
if msg.get("role") == "assistant":
|
||||
for content_item in msg.get("content", []):
|
||||
if content_item.get("type") == "text":
|
||||
return content_item.get("text", {}).get("value", "")
|
||||
return ""
|
||||
|
||||
def _build_model_response(
|
||||
self,
|
||||
model: str,
|
||||
content: str,
|
||||
model_response: ModelResponse,
|
||||
thread_id: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> ModelResponse:
|
||||
"""Build the ModelResponse from agent output."""
|
||||
from litellm.types.utils import Choices, Message, Usage
|
||||
|
||||
model_response.choices = [
|
||||
Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=Message(content=content, role="assistant"),
|
||||
)
|
||||
]
|
||||
model_response.model = model
|
||||
|
||||
# Store thread_id for conversation continuity
|
||||
if (
|
||||
not hasattr(model_response, "_hidden_params")
|
||||
or model_response._hidden_params is None
|
||||
):
|
||||
model_response._hidden_params = {}
|
||||
model_response._hidden_params["thread_id"] = thread_id
|
||||
|
||||
# Estimate token usage
|
||||
try:
|
||||
from litellm.utils import token_counter
|
||||
|
||||
prompt_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
|
||||
completion_tokens = token_counter(
|
||||
model="gpt-3.5-turbo", text=content, count_response_tokens=True
|
||||
)
|
||||
setattr(
|
||||
model_response,
|
||||
"usage",
|
||||
Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to calculate token usage: {str(e)}")
|
||||
|
||||
return model_response
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
model: str,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
optional_params: dict,
|
||||
headers: Optional[dict],
|
||||
) -> tuple:
|
||||
"""Prepare common parameters for completion.
|
||||
|
||||
Azure Foundry Agents API uses Bearer token authentication:
|
||||
- Authorization: Bearer <token> (Azure AD token from 'az account get-access-token --resource https://ai.azure.com')
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
# Azure Foundry Agents uses Bearer token authentication
|
||||
# The api_key here is expected to be an Azure AD token
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
api_version = optional_params.get(
|
||||
"api_version", self.config.DEFAULT_API_VERSION
|
||||
)
|
||||
agent_id = self.config._get_agent_id(model, optional_params)
|
||||
thread_id = optional_params.get("thread_id")
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Azure AI Agents completion - api_base: {api_base}, agent_id: {agent_id}"
|
||||
)
|
||||
|
||||
return headers, api_version, agent_id, thread_id, api_base
|
||||
|
||||
def _check_response(
|
||||
self, response: httpx.Response, expected_codes: List[int], error_msg: str
|
||||
):
|
||||
"""Check response status and raise error if not expected."""
|
||||
if response.status_code not in expected_codes:
|
||||
raise AzureAIAgentsError(
|
||||
status_code=response.status_code,
|
||||
message=f"{error_msg}: {response.text}",
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Sync Completion
|
||||
# -------------------------------------------------------------------------
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
timeout: float,
|
||||
client: Optional[HTTPHandler] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> ModelResponse:
|
||||
"""Execute synchronous completion using Azure Agent Service."""
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
|
||||
if client is None:
|
||||
client = _get_httpx_client(
|
||||
params={"ssl_verify": litellm_params.get("ssl_verify", None)}
|
||||
)
|
||||
|
||||
(
|
||||
headers,
|
||||
api_version,
|
||||
agent_id,
|
||||
thread_id,
|
||||
api_base,
|
||||
) = self._prepare_completion_params(
|
||||
model, api_base, api_key, optional_params, headers
|
||||
)
|
||||
|
||||
def make_request(
|
||||
method: str, url: str, json_data: Optional[dict] = None
|
||||
) -> httpx.Response:
|
||||
if method == "GET":
|
||||
return client.get(url=url, headers=headers)
|
||||
return client.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
data=json.dumps(json_data) if json_data else None,
|
||||
)
|
||||
|
||||
# Execute the agent flow
|
||||
thread_id, content = self._execute_agent_flow_sync(
|
||||
make_request=make_request,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
agent_id=agent_id,
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
return self._build_model_response(
|
||||
model, content, model_response, thread_id, messages
|
||||
)
|
||||
|
||||
def _execute_agent_flow_sync(
|
||||
self,
|
||||
make_request: Callable,
|
||||
api_base: str,
|
||||
api_version: str,
|
||||
agent_id: str,
|
||||
thread_id: Optional[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
optional_params: dict,
|
||||
) -> Tuple[str, str]:
|
||||
"""Execute the agent flow synchronously. Returns (thread_id, content)."""
|
||||
|
||||
# Step 1: Create thread if not provided
|
||||
if not thread_id:
|
||||
verbose_logger.debug(
|
||||
f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
|
||||
)
|
||||
response = make_request(
|
||||
"POST", self._build_thread_url(api_base, api_version), {}
|
||||
)
|
||||
self._check_response(response, [200, 201], "Failed to create thread")
|
||||
thread_id = response.json()["id"]
|
||||
verbose_logger.debug(f"Created thread: {thread_id}")
|
||||
|
||||
# At this point thread_id is guaranteed to be a string
|
||||
assert thread_id is not None
|
||||
|
||||
# Step 2: Add messages to thread
|
||||
for msg in messages:
|
||||
if msg.get("role") in ["user", "system"]:
|
||||
url = self._build_messages_url(api_base, thread_id, api_version)
|
||||
response = make_request(
|
||||
"POST", url, {"role": "user", "content": msg.get("content", "")}
|
||||
)
|
||||
self._check_response(response, [200, 201], "Failed to add message")
|
||||
|
||||
# Step 3: Create run
|
||||
run_payload = {"assistant_id": agent_id}
|
||||
if "instructions" in optional_params:
|
||||
run_payload["instructions"] = optional_params["instructions"]
|
||||
|
||||
response = make_request(
|
||||
"POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
|
||||
)
|
||||
self._check_response(response, [200, 201], "Failed to create run")
|
||||
run_id = response.json()["id"]
|
||||
verbose_logger.debug(f"Created run: {run_id}")
|
||||
|
||||
# Step 4: Poll for completion
|
||||
status_url = self._build_run_status_url(
|
||||
api_base, thread_id, run_id, api_version
|
||||
)
|
||||
for _ in range(self.config.MAX_POLL_ATTEMPTS):
|
||||
response = make_request("GET", status_url)
|
||||
self._check_response(response, [200], "Failed to get run status")
|
||||
|
||||
status = response.json().get("status")
|
||||
verbose_logger.debug(f"Run status: {status}")
|
||||
|
||||
if status == "completed":
|
||||
break
|
||||
elif status in ["failed", "cancelled", "expired"]:
|
||||
error_msg = (
|
||||
response.json()
|
||||
.get("last_error", {})
|
||||
.get("message", "Unknown error")
|
||||
)
|
||||
raise AzureAIAgentsError(
|
||||
status_code=500, message=f"Run {status}: {error_msg}"
|
||||
)
|
||||
|
||||
time.sleep(self.config.POLL_INTERVAL_SECONDS)
|
||||
else:
|
||||
raise AzureAIAgentsError(
|
||||
status_code=408, message="Run timed out waiting for completion"
|
||||
)
|
||||
|
||||
# Step 5: Get messages
|
||||
response = make_request(
|
||||
"GET", self._build_list_messages_url(api_base, thread_id, api_version)
|
||||
)
|
||||
self._check_response(response, [200], "Failed to get messages")
|
||||
|
||||
content = self._extract_content_from_messages(response.json())
|
||||
return thread_id, content
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Async Completion
|
||||
# -------------------------------------------------------------------------
|
||||
async def acompletion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
timeout: float,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> ModelResponse:
|
||||
"""Execute asynchronous completion using Azure Agent Service."""
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
|
||||
if client is None:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.AZURE_AI,
|
||||
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
|
||||
)
|
||||
|
||||
(
|
||||
headers,
|
||||
api_version,
|
||||
agent_id,
|
||||
thread_id,
|
||||
api_base,
|
||||
) = self._prepare_completion_params(
|
||||
model, api_base, api_key, optional_params, headers
|
||||
)
|
||||
|
||||
async def make_request(
|
||||
method: str, url: str, json_data: Optional[dict] = None
|
||||
) -> httpx.Response:
|
||||
if method == "GET":
|
||||
return await client.get(url=url, headers=headers)
|
||||
return await client.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
data=json.dumps(json_data) if json_data else None,
|
||||
)
|
||||
|
||||
# Execute the agent flow
|
||||
thread_id, content = await self._execute_agent_flow_async(
|
||||
make_request=make_request,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
agent_id=agent_id,
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
return self._build_model_response(
|
||||
model, content, model_response, thread_id, messages
|
||||
)
|
||||
|
||||
async def _execute_agent_flow_async(
|
||||
self,
|
||||
make_request: Callable,
|
||||
api_base: str,
|
||||
api_version: str,
|
||||
agent_id: str,
|
||||
thread_id: Optional[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
optional_params: dict,
|
||||
) -> Tuple[str, str]:
|
||||
"""Execute the agent flow asynchronously. Returns (thread_id, content)."""
|
||||
|
||||
# Step 1: Create thread if not provided
|
||||
if not thread_id:
|
||||
verbose_logger.debug(
|
||||
f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
|
||||
)
|
||||
response = await make_request(
|
||||
"POST", self._build_thread_url(api_base, api_version), {}
|
||||
)
|
||||
self._check_response(response, [200, 201], "Failed to create thread")
|
||||
thread_id = response.json()["id"]
|
||||
verbose_logger.debug(f"Created thread: {thread_id}")
|
||||
|
||||
# At this point thread_id is guaranteed to be a string
|
||||
assert thread_id is not None
|
||||
|
||||
# Step 2: Add messages to thread
|
||||
for msg in messages:
|
||||
if msg.get("role") in ["user", "system"]:
|
||||
url = self._build_messages_url(api_base, thread_id, api_version)
|
||||
response = await make_request(
|
||||
"POST", url, {"role": "user", "content": msg.get("content", "")}
|
||||
)
|
||||
self._check_response(response, [200, 201], "Failed to add message")
|
||||
|
||||
# Step 3: Create run
|
||||
run_payload = {"assistant_id": agent_id}
|
||||
if "instructions" in optional_params:
|
||||
run_payload["instructions"] = optional_params["instructions"]
|
||||
|
||||
response = await make_request(
|
||||
"POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
|
||||
)
|
||||
self._check_response(response, [200, 201], "Failed to create run")
|
||||
run_id = response.json()["id"]
|
||||
verbose_logger.debug(f"Created run: {run_id}")
|
||||
|
||||
# Step 4: Poll for completion
|
||||
status_url = self._build_run_status_url(
|
||||
api_base, thread_id, run_id, api_version
|
||||
)
|
||||
for _ in range(self.config.MAX_POLL_ATTEMPTS):
|
||||
response = await make_request("GET", status_url)
|
||||
self._check_response(response, [200], "Failed to get run status")
|
||||
|
||||
status = response.json().get("status")
|
||||
verbose_logger.debug(f"Run status: {status}")
|
||||
|
||||
if status == "completed":
|
||||
break
|
||||
elif status in ["failed", "cancelled", "expired"]:
|
||||
error_msg = (
|
||||
response.json()
|
||||
.get("last_error", {})
|
||||
.get("message", "Unknown error")
|
||||
)
|
||||
raise AzureAIAgentsError(
|
||||
status_code=500, message=f"Run {status}: {error_msg}"
|
||||
)
|
||||
|
||||
await asyncio.sleep(self.config.POLL_INTERVAL_SECONDS)
|
||||
else:
|
||||
raise AzureAIAgentsError(
|
||||
status_code=408, message="Run timed out waiting for completion"
|
||||
)
|
||||
|
||||
# Step 5: Get messages
|
||||
response = await make_request(
|
||||
"GET", self._build_list_messages_url(api_base, thread_id, api_version)
|
||||
)
|
||||
self._check_response(response, [200], "Failed to get messages")
|
||||
|
||||
content = self._extract_content_from_messages(response.json())
|
||||
return thread_id, content
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Streaming Completion (Native SSE)
|
||||
# -------------------------------------------------------------------------
|
||||
async def acompletion_stream(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
timeout: float,
|
||||
headers: Optional[dict] = None,
|
||||
) -> AsyncIterator:
|
||||
"""Execute async streaming completion using Azure Agent Service with native SSE."""
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
|
||||
(
|
||||
headers,
|
||||
api_version,
|
||||
agent_id,
|
||||
thread_id,
|
||||
api_base,
|
||||
) = self._prepare_completion_params(
|
||||
model, api_base, api_key, optional_params, headers
|
||||
)
|
||||
|
||||
# Build payload for create-thread-and-run with streaming
|
||||
thread_messages = []
|
||||
for msg in messages:
|
||||
if msg.get("role") in ["user", "system"]:
|
||||
thread_messages.append(
|
||||
{"role": "user", "content": msg.get("content", "")}
|
||||
)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"assistant_id": agent_id,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
# Add thread with messages if we don't have an existing thread
|
||||
if not thread_id:
|
||||
payload["thread"] = {"messages": thread_messages}
|
||||
|
||||
if "instructions" in optional_params:
|
||||
payload["instructions"] = optional_params["instructions"]
|
||||
|
||||
url = self._build_create_thread_and_run_url(api_base, api_version)
|
||||
verbose_logger.debug(f"Azure AI Agents streaming - URL: {url}")
|
||||
|
||||
# Use LiteLLM's async HTTP client for streaming
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.AZURE_AI,
|
||||
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
data=json.dumps(payload),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 201]:
|
||||
error_text = await response.aread()
|
||||
raise AzureAIAgentsError(
|
||||
status_code=response.status_code,
|
||||
message=f"Streaming request failed: {error_text.decode()}",
|
||||
)
|
||||
|
||||
async for chunk in self._process_sse_stream(response, model):
|
||||
yield chunk
|
||||
|
||||
async def _process_sse_stream(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
model: str,
|
||||
) -> AsyncIterator:
|
||||
"""Process SSE stream and yield OpenAI-compatible streaming chunks."""
|
||||
from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices
|
||||
|
||||
response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
|
||||
created = int(time.time())
|
||||
thread_id = None
|
||||
|
||||
current_event = None
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
line = line.strip()
|
||||
|
||||
if line.startswith("event:"):
|
||||
current_event = line[6:].strip()
|
||||
continue
|
||||
|
||||
if line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
|
||||
if data_str == "[DONE]":
|
||||
# Send final chunk with finish_reason
|
||||
final_chunk = ModelResponseStream(
|
||||
id=response_id,
|
||||
created=created,
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
delta=Delta(content=None),
|
||||
)
|
||||
],
|
||||
)
|
||||
if thread_id:
|
||||
final_chunk._hidden_params = {"thread_id": thread_id}
|
||||
yield final_chunk
|
||||
return
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Extract thread_id from thread.created event
|
||||
if current_event == "thread.created" and "id" in data:
|
||||
thread_id = data["id"]
|
||||
verbose_logger.debug(f"Stream created thread: {thread_id}")
|
||||
|
||||
# Process message deltas - this is where the actual content comes
|
||||
if current_event == "thread.message.delta":
|
||||
delta_content = data.get("delta", {}).get("content", [])
|
||||
for content_item in delta_content:
|
||||
if content_item.get("type") == "text":
|
||||
text_value = content_item.get("text", {}).get("value", "")
|
||||
if text_value:
|
||||
chunk = ModelResponseStream(
|
||||
id=response_id,
|
||||
created=created,
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
delta=Delta(
|
||||
content=text_value, role="assistant"
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
if thread_id:
|
||||
chunk._hidden_params = {"thread_id": thread_id}
|
||||
yield chunk
|
||||
|
||||
|
||||
# Singleton instance shared by the package (re-exported from __init__);
# a single handler suffices because its only state is the config object.
azure_ai_agents_handler = AzureAIAgentsHandler()
|
||||
@@ -0,0 +1,402 @@
|
||||
"""
|
||||
Transformation for Azure Foundry Agent Service API.
|
||||
|
||||
Azure Foundry Agent Service provides an Assistants-like API for running agents.
|
||||
This follows the OpenAI Assistants pattern: create thread -> add messages -> create/poll run.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
|
||||
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
|
||||
|
||||
Authentication: Uses Azure AD Bearer tokens (not API keys)
|
||||
Get token via: az account get-access-token --resource 'https://ai.azure.com'
|
||||
|
||||
The API uses these endpoints:
|
||||
- POST /threads - Create a thread
|
||||
- POST /threads/{thread_id}/messages - Add message to thread
|
||||
- POST /threads/{thread_id}/runs - Create a run
|
||||
- GET /threads/{thread_id}/runs/{run_id} - Poll run status
|
||||
- GET /threads/{thread_id}/messages - List messages in thread
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HTTPHandler = Any
|
||||
AsyncHTTPHandler = Any
|
||||
|
||||
|
||||
class AzureAIAgentsError(BaseLLMException):
    """Exception raised for Azure AI Agent Service API errors.

    Inherits status_code/message handling from BaseLLMException so
    litellm's generic exception mapping can surface these to callers.
    """

    pass
|
||||
|
||||
|
||||
class AzureAIAgentsConfig(BaseConfig):
|
||||
"""
|
||||
Configuration for Azure AI Agent Service API.
|
||||
|
||||
Azure AI Agent Service is a fully managed service for building AI agents
|
||||
that can understand natural language and perform tasks.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
|
||||
The flow is:
|
||||
1. Create a thread
|
||||
2. Add user messages to the thread
|
||||
3. Create and poll a run
|
||||
4. Retrieve the assistant's response messages
|
||||
"""
|
||||
|
||||
    # Default API version for Azure Foundry Agent Service.
    # GA version: 2025-05-01, Preview: 2025-05-15-preview
    # See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    DEFAULT_API_VERSION = "2025-05-01"

    # Polling configuration: the handler checks run status up to
    # MAX_POLL_ATTEMPTS times, sleeping POLL_INTERVAL_SECONDS between
    # checks (~60s worst-case wait before a 408 is raised).
    MAX_POLL_ATTEMPTS = 60
    POLL_INTERVAL_SECONDS = 1.0

    def __init__(self, **kwargs):
        # No agent-specific state; defer entirely to BaseConfig.
        super().__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def is_azure_ai_agents_route(model: str) -> bool:
|
||||
"""
|
||||
Check if the model is an Azure AI Agents route.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
"""
|
||||
return "agents/" in model
|
||||
|
||||
@staticmethod
|
||||
def get_agent_id_from_model(model: str) -> str:
|
||||
"""
|
||||
Extract agent ID from the model string.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id> -> <agent_id>
|
||||
or: agents/<agent_id> -> <agent_id>
|
||||
"""
|
||||
if "agents/" in model:
|
||||
# Split on "agents/" and take the part after it
|
||||
parts = model.split("agents/", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[1]
|
||||
return model
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Get Azure AI Agent Service API base and key from params or environment.
|
||||
|
||||
Returns:
|
||||
Tuple of (api_base, api_key)
|
||||
"""
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
|
||||
api_key = api_key or get_secret_str("AZURE_AI_API_KEY")
|
||||
|
||||
return api_base, api_key
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""
|
||||
Azure Agents supports minimal OpenAI params since it's an agent runtime.
|
||||
"""
|
||||
return ["stream"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI params to Azure Agents params.
|
||||
"""
|
||||
return optional_params
|
||||
|
||||
def _get_api_version(self, optional_params: dict) -> str:
|
||||
"""Get API version from optional params or use default."""
|
||||
return optional_params.get("api_version", self.DEFAULT_API_VERSION)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the base URL for Azure AI Agent Service.
|
||||
|
||||
The actual endpoint will vary based on the operation:
|
||||
- /openai/threads for creating threads
|
||||
- /openai/threads/{thread_id}/messages for adding messages
|
||||
- /openai/threads/{thread_id}/runs for creating runs
|
||||
|
||||
This returns the base URL that will be modified for each operation.
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"api_base is required for Azure AI Agents. Set it via AZURE_AI_API_BASE env var or api_base parameter."
|
||||
)
|
||||
|
||||
# Remove trailing slash if present
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
# Return base URL - actual endpoints will be constructed during request
|
||||
return api_base
|
||||
|
||||
def _get_agent_id(self, model: str, optional_params: dict) -> str:
|
||||
"""
|
||||
Get the agent ID from model or optional_params.
|
||||
|
||||
model format: "azure_ai/agents/<agent_id>" or "agents/<agent_id>" or just "<agent_id>"
|
||||
"""
|
||||
agent_id = optional_params.get("agent_id") or optional_params.get(
|
||||
"assistant_id"
|
||||
)
|
||||
if agent_id:
|
||||
return agent_id
|
||||
|
||||
# Extract from model name using the static method
|
||||
return self.get_agent_id_from_model(model)
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the request for Azure Agents.
|
||||
|
||||
This stores the necessary data for the multi-step agent flow.
|
||||
The actual API calls happen in the custom handler.
|
||||
"""
|
||||
agent_id = self._get_agent_id(model, optional_params)
|
||||
|
||||
# Convert messages to a format we can use
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# Handle content that might be a list
|
||||
if isinstance(content, list):
|
||||
content = convert_content_list_to_str(msg)
|
||||
|
||||
# Ensure content is a string
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
|
||||
converted_messages.append({"role": role, "content": content})
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"agent_id": agent_id,
|
||||
"messages": converted_messages,
|
||||
"api_version": self._get_api_version(optional_params),
|
||||
}
|
||||
|
||||
# Pass through thread_id if provided (for continuing conversations)
|
||||
if "thread_id" in optional_params:
|
||||
payload["thread_id"] = optional_params["thread_id"]
|
||||
|
||||
# Pass through any additional instructions
|
||||
if "instructions" in optional_params:
|
||||
payload["instructions"] = optional_params["instructions"]
|
||||
|
||||
verbose_logger.debug(f"Azure AI Agents request payload: {payload}")
|
||||
return payload
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate and set up environment for Azure Foundry Agents requests.
|
||||
|
||||
Azure Foundry Agents uses Bearer token authentication with Azure AD tokens.
|
||||
Get token via: az account get-access-token --resource 'https://ai.azure.com'
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
# Azure Foundry Agents uses Bearer token authentication
|
||||
# The api_key here is expected to be an Azure AD token
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Map a raw HTTP failure onto the provider-specific exception type."""
    # Every Azure Foundry Agents failure surfaces as AzureAIAgentsError.
    error = AzureAIAgentsError(status_code=status_code, message=error_message)
    return error
|
||||
|
||||
def should_fake_stream(
    self,
    model: Optional[str],
    stream: Optional[bool],
    custom_llm_provider: Optional[str] = None,
) -> bool:
    """
    Report whether streaming must be simulated for this provider.

    The agent flow is poll-based (create run -> poll -> fetch messages),
    so no real token stream exists; callers always fake a stream from
    the final response.
    """
    return True
|
||||
|
||||
@property
def has_custom_stream_wrapper(self) -> bool:
    """No provider-native stream wrapper exists; fake streaming is used."""
    return False
|
||||
|
||||
@property
def supports_stream_param_in_request_body(self) -> bool:
    """
    Azure Agents requests never carry a ``stream`` field in the body,
    so this capability flag is always off.
    """
    return False
|
||||
|
||||
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    No-op response transformation.

    The Azure Agents integration assembles its ModelResponse inside the
    dedicated custom handler (multi-step thread/run flow), so this hook
    simply echoes back the ``model_response`` it received.
    """
    return model_response
|
||||
|
||||
@staticmethod
def completion(
    model: str,
    messages: List,
    api_base: str,
    api_key: Optional[str],
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: Union[float, int, Any],
    acompletion: bool,
    stream: Optional[bool] = False,
    headers: Optional[dict] = None,
) -> Any:
    """
    Entry point for Azure Foundry Agents completions.

    Resolves an Azure AD token when no ``api_key`` is supplied, then routes
    to the agent handler: async streaming (SSE), async non-streaming, or
    sync (streaming is unsupported synchronously).

    Authentication uses Azure AD Bearer tokens:
    - pass ``api_key`` directly as an Azure AD token, or
    - set AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET for
      automatic Service Principal token retrieval.

    See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    """
    from litellm.llms.azure.common_utils import get_azure_ad_token
    from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
    from litellm.types.router import GenericLiteLLMParams

    if api_key is None:
        # No explicit token: fall back to litellm's Azure AD machinery,
        # overriding the token scope to the Azure AI resource
        # (ai.azure.com) instead of the default cognitive-services scope.
        auth_params = dict(litellm_params) if litellm_params else {}
        auth_params["azure_scope"] = "https://ai.azure.com/.default"
        api_key = get_azure_ad_token(GenericLiteLLMParams(**auth_params))

    if api_key is None:
        raise ValueError(
            "api_key (Azure AD token) is required for Azure Foundry Agents. "
            "Either pass api_key directly, or set AZURE_TENANT_ID, AZURE_CLIENT_ID, "
            "and AZURE_CLIENT_SECRET environment variables for Service Principal auth. "
            "Manual token: az account get-access-token --resource 'https://ai.azure.com'"
        )

    # Keyword arguments shared by every dispatch target.
    common_kwargs: Dict[str, Any] = dict(
        model=model,
        messages=messages,
        api_base=api_base,
        api_key=api_key,
        logging_obj=logging_obj,
        optional_params=optional_params,
        litellm_params=litellm_params,
        timeout=timeout,
        headers=headers,
    )

    if acompletion and stream:
        # Native async streaming over SSE: hand back the async generator.
        return azure_ai_agents_handler.acompletion_stream(**common_kwargs)
    if acompletion:
        return azure_ai_agents_handler.acompletion(
            model_response=model_response, **common_kwargs
        )
    # Sync path; streaming is not supported for sync calls.
    return azure_ai_agents_handler.completion(
        model_response=model_response, **common_kwargs
    )
|
||||
Reference in New Issue
Block a user