chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Mistral OCR transformation module."""
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Mistral OCR handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.mistral.ocr.guardrail_translation.handler import OCRHandler
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# The synchronous and asynchronous OCR call types share a single handler.
guardrail_translation_mappings = {
    call_type: OCRHandler for call_type in (CallTypes.ocr, CallTypes.aocr)
}

__all__ = ["guardrail_translation_mappings", "OCRHandler"]
|
||||
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
OCR Handler for Unified Guardrails
|
||||
|
||||
Provides guardrail translation support for the OCR endpoint.
|
||||
Processes the extracted markdown text from OCR pages.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.llms.base_llm.ocr.transformation import OCRResponse
|
||||
|
||||
|
||||
class OCRHandler(BaseTranslation):
    """
    Guardrail translation handler for the OCR endpoint.

    Input: an OCR request carries a document reference (URL or data URI),
    not free text. The reference string is handed to the guardrail so that
    guardrails which validate or filter document sources can inspect it.

    Output: an OCR response carries extracted markdown per page. The
    markdown of every non-empty page is sent through the guardrail and the
    returned text is written back onto the page it came from.
    """

    @staticmethod
    def _extract_document_reference(document: dict) -> Optional[str]:
        """
        Return the URL string referenced by an OCR document dict.

        Args:
            document: The request's ``document`` dict.

        Returns:
            The non-empty URL string for ``document_url`` / ``image_url``
            documents, or None for any other type or a missing/invalid URL.
        """
        doc_type = document.get("type")
        if doc_type not in ("document_url", "image_url"):
            return None

        # For both supported types the URL lives under a key named after
        # the type itself.
        reference = document.get(doc_type)

        # Image documents may also be supplied in object form
        # ({"url": "..."}); unwrap it so the URL is still checked instead of
        # silently bypassing the guardrail.
        if doc_type == "image_url" and isinstance(reference, dict):
            reference = reference.get("url")

        if reference and isinstance(reference, str):
            return reference
        return None

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Apply a guardrail to the document reference of an OCR request.

        The document URL is extracted from ``data["document"]`` and passed
        to the guardrail as the only text to check, together with the
        request's model name when one is set.

        Args:
            data: Request data containing the 'document' parameter
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object

        Returns:
            The same ``data`` dict that was passed in. The value returned by
            the guardrail is not written back into the request.
        """
        document = data.get("document")
        if not isinstance(document, dict):
            verbose_proxy_logger.debug(
                "OCR guardrail: No valid document found in request data"
            )
            return data

        reference = self._extract_document_reference(document)
        if reference is None:
            # Nothing checkable (unsupported document type or no URL).
            return data

        inputs = GenericGuardrailAPIInputs(texts=[reference])
        model = data.get("model")
        if model:
            inputs["model"] = model

        # The result is intentionally discarded: the URL is only inspected,
        # never rewritten.
        # NOTE(review): presumably a guardrail rejects a request by raising
        # - confirm against the CustomGuardrail contract.
        await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )

        return data

    async def process_output_response(
        self,
        response: "OCRResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Apply a guardrail to the markdown extracted for each OCR page.

        Pages with empty or missing markdown are skipped. The remaining
        page texts are sent to the guardrail in page order, and each
        returned text is written back to the page it was taken from.

        Args:
            response: OCRResponse with pages containing markdown text
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata

        Returns:
            The same response object, with page markdown replaced by the
            guardrailed text where the guardrail returned one.
        """
        pages = getattr(response, "pages", None)
        if not pages:
            verbose_proxy_logger.debug("OCR guardrail: No pages found in OCR response")
            return response

        # Keep references to the pages that actually carry text so results
        # can be mapped back positionally.
        pages_with_text: List[Any] = [
            page for page in pages if getattr(page, "markdown", None)
        ]
        if not pages_with_text:
            return response

        inputs = GenericGuardrailAPIInputs(
            texts=[page.markdown for page in pages_with_text]
        )
        model = getattr(response, "model", None)
        if model:
            inputs["model"] = model

        # Merge the metadata derived from the user API key into the inputs.
        if user_api_key_dict is not None:
            metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
            inputs.update(metadata)  # type: ignore

        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data={},
            input_type="response",
            logging_obj=litellm_logging_obj,
        )

        # Tolerate a missing or None "texts" entry; zip() stops at the
        # shorter sequence, so surplus or missing texts never misalign pages.
        guardrailed_texts = guardrailed_inputs.get("texts") or []
        updated_pages = 0
        for page, guardrailed_text in zip(pages_with_text, guardrailed_texts):
            page.markdown = guardrailed_text
            updated_pages += 1

        # Report the number of pages actually rewritten.
        verbose_proxy_logger.debug(
            "OCR guardrail: Applied guardrail to %d pages",
            updated_pages,
        )

        return response
|
||||
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
Mistral OCR transformation implementation.
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.ocr.transformation import (
|
||||
BaseOCRConfig,
|
||||
DocumentType,
|
||||
OCRRequestData,
|
||||
OCRResponse,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class MistralOCRConfig(BaseOCRConfig):
    """
    Request/response transformation settings for Mistral OCR.

    Reference: https://docs.mistral.ai/api/#tag/ocr
    """

    def __init__(self) -> None:
        super().__init__()

    def get_supported_ocr_params(self, model: str) -> list:
        """
        List the optional OCR parameters forwarded to Mistral.

        Supported parameters:
        - pages: List of page numbers to process
        - include_image_base64: Whether to include base64 encoded images
        - image_limit: Maximum number of images to return
        - image_min_size: Minimum size of images to include
        - bbox_annotation_format: Format for bounding box annotations
        - document_annotation_format: Format for document annotations

        The same list is returned for every ``model``.
        """
        supported = [
            "pages",
            "include_image_base64",
            "image_limit",
            "image_min_size",
            "bbox_annotation_format",
            "document_annotation_format",
        ]
        return supported

    def map_ocr_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
    ) -> dict:
        """
        Keep only the caller-supplied parameters that Mistral OCR supports.

        Names and values are passed through untouched; anything not listed
        by ``get_supported_ocr_params`` is dropped. ``optional_params`` is
        not read.
        """
        allowed = self.get_supported_ocr_params(model=model)
        return {
            name: value
            for name, value in non_default_params.items()
            if name in allowed
        }

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> Dict:
        """
        Resolve the Mistral API key and build the request headers.

        The key comes from ``api_key`` or, when that is None, from the
        ``MISTRAL_API_KEY`` secret. Caller-supplied ``headers`` are merged
        last, so they take precedence over the generated Authorization
        header. No Content-Type header is added here.

        Raises:
            ValueError: If no API key can be resolved.
        """
        resolved_key = (
            api_key if api_key is not None else get_secret_str("MISTRAL_API_KEY")
        )
        if resolved_key is None:
            raise ValueError(
                "Missing Mistral API Key - A call is being made to Mistral but no key is set either in the environment variables or via params"
            )

        merged_headers = {"Authorization": f"Bearer {resolved_key}"}
        merged_headers.update(headers)
        return merged_headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """
        Build the full URL of the Mistral OCR endpoint.

        Defaults to https://api.mistral.ai/v1/ocr when ``api_base`` is None.
        Trailing slashes are stripped, and "/v1" is only appended when the
        base does not already end with it.
        """
        base = "https://api.mistral.ai/v1" if api_base is None else api_base
        base = base.rstrip("/")

        # Avoid producing ".../v1/v1/ocr" for bases that already include /v1.
        suffix = "/ocr" if base.endswith("/v1") else "/v1/ocr"
        return f"{base}{suffix}"

    def transform_ocr_request(
        self,
        model: str,
        document: DocumentType,
        optional_params: dict,
        headers: dict,
        **kwargs,
    ) -> OCRRequestData:
        """
        Assemble the JSON body for a Mistral OCR request.

        Resulting body:
        {
            "model": "mistral-ocr-latest",
            "document": {
                "type": "document_url",
                "document_url": "<https-url or data-uri>"
            },
            "pages": [0],  # optional
            "include_image_base64": false,  # optional
            ...
        }

        Args:
            model: Model name (e.g., "mistral-ocr-latest")
            document: Document dict in Mistral format, forwarded unchanged
            optional_params: Already mapped optional parameters
            headers: Request headers (not used here)

        Returns:
            OCRRequestData carrying the JSON body and no files

        Raises:
            ValueError: If ``document`` is not a dict.
        """
        verbose_logger.debug(f"Mistral OCR transform_ocr_request - model: {model}")

        if not isinstance(document, dict):
            raise ValueError(f"Expected document dict, got {type(document)}")

        # The document dict is forwarded as received; optional parameters
        # are layered on top of the two required fields.
        payload = dict(model=model, document=document)
        payload.update(optional_params)

        # JSON body only - there are no multipart files.
        return OCRRequestData(data=payload, files=None)

    def transform_ocr_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: Any,
        **kwargs,
    ) -> OCRResponse:
        """
        Wrap Mistral's JSON reply in an OCRResponse without reshaping it.

        Expected reply:
        {
            "pages": [
                {
                    "index": 0,
                    "markdown": "extracted text content",
                    "images": [...],
                    "dimensions": {...}
                },
                ...
            ],
            "model": "mistral-ocr-2505-completion",
            "document_annotation": null,
            "usage_info": {...}
        }

        Missing "pages" default to an empty list and a missing "model"
        falls back to the requested ``model``. Any error raised while
        parsing or building the response is logged and re-raised.
        """
        try:
            body = raw_response.json()

            verbose_logger.debug(f"Mistral OCR response keys: {body.keys()}")

            # Fields are copied across one-to-one - no transformation.
            fields = {
                "pages": body.get("pages", []),
                "model": body.get("model", model),
                "document_annotation": body.get("document_annotation"),
                "usage_info": body.get("usage_info"),
                "object": "ocr",
            }
            return OCRResponse(**fields)
        except Exception as exc:
            verbose_logger.error(f"Error parsing Mistral OCR response: {exc}")
            raise exc
|
||||
Reference in New Issue
Block a user