chore: initial public snapshot for github upload
@@ -0,0 +1,229 @@
# Cohere Rerank Guardrail Translation Handler

Handler for processing the rerank endpoint (`/v1/rerank`) with guardrails.

## Overview

This handler processes rerank requests by:

1. Extracting the query text from the request
2. Applying guardrails to the query
3. Updating the request with the guardrailed query
4. Returning the output unchanged (rankings are not text)

Note: Documents are not processed by guardrails as they represent the corpus being searched, not user input. Only the query is guardrailed.

## Data Format

### Input Format

**With String Documents:**

```json
{
  "model": "rerank-english-v3.0",
  "query": "What is the capital of France?",
  "documents": [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
    "Madrid is the capital of Spain."
  ],
  "top_n": 2
}
```

**With Dict Documents:**

```json
{
  "model": "rerank-english-v3.0",
  "query": "What is the capital of France?",
  "documents": [
    {"text": "Paris is the capital of France.", "id": "doc1"},
    {"text": "Berlin is the capital of Germany.", "id": "doc2"},
    {"text": "Madrid is the capital of Spain.", "id": "doc3"}
  ],
  "top_n": 2
}
```

### Output Format

```json
{
  "id": "rerank-abc123",
  "results": [
    {"index": 0, "relevance_score": 0.98},
    {"index": 2, "relevance_score": 0.12}
  ],
  "meta": {
    "billed_units": {"search_units": 1}
  }
}
```
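
Each `index` in `results` refers back to a position in the submitted `documents` list. A minimal sketch of mapping the response back onto the documents, using the example payloads above (the inline `documents` and `response` dicts here are illustrative stand-ins for a real API call):

```python
# Documents as submitted in the request (order matters: `index` points here).
documents = [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
    "Madrid is the capital of Spain.",
]

# Response shaped like the Output Format example above.
response = {
    "id": "rerank-abc123",
    "results": [
        {"index": 0, "relevance_score": 0.98},
        {"index": 2, "relevance_score": 0.12},
    ],
}

# Resolve each result's index back to the original document text.
ranked = [
    (documents[r["index"]], r["relevance_score"])
    for r in response["results"]
]
print(ranked[0])  # ('Paris is the capital of France.', 0.98)
```

Note that results arrive sorted by relevance, so with `top_n: 2` the second-ranked document (`index: 2`, Madrid) appears even though Berlin was skipped.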
## Usage

The handler is automatically discovered and applied when guardrails are used with the rerank endpoint.

### Example: Using Guardrails with Rerank

```bash
curl -X POST 'http://localhost:4000/v1/rerank' \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer your-api-key' \
  -d '{
    "model": "rerank-english-v3.0",
    "query": "What is machine learning?",
    "documents": [
      "Machine learning is a subset of AI.",
      "Deep learning uses neural networks.",
      "Python is a programming language."
    ],
    "guardrails": ["content_filter"],
    "top_n": 2
  }'
```

The guardrail is applied to the query only, not the documents.

### Example: PII Masking in Query

```bash
curl -X POST 'http://localhost:4000/v1/rerank' \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer your-api-key' \
  -d '{
    "model": "rerank-english-v3.0",
    "query": "Find documents about John Doe from john@example.com",
    "documents": [
      "Document 1 content here.",
      "Document 2 content here.",
      "Document 3 content here."
    ],
    "guardrails": ["mask_pii"],
    "top_n": 3
  }'
```

The query will be masked to: "Find documents about [NAME_REDACTED] from [EMAIL_REDACTED]"
### Example: Mixed Document Types

```bash
curl -X POST 'http://localhost:4000/v1/rerank' \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer your-api-key' \
  -d '{
    "model": "rerank-english-v3.0",
    "query": "Technical documentation",
    "documents": [
      {"text": "This is document 1", "metadata": {"source": "wiki"}},
      {"text": "This is document 2", "metadata": {"source": "docs"}},
      "This is document 3 as a plain string"
    ],
    "guardrails": ["content_moderation"]
  }'
```

## Implementation Details

### Input Processing

- **Query Field**: `query` (string)
  - Processing: Apply guardrail to query text
  - Result: Updated query

- **Documents Field**: `documents` (list)
  - Processing: Not processed (corpus being searched, not user input)
  - Result: Unchanged

### Output Processing

- **Processing**: Not applicable (output contains relevance scores, not text)
- **Result**: Response returned unchanged

## Use Cases

1. **PII Protection**: Remove PII from queries before reranking
2. **Content Filtering**: Filter inappropriate content from search queries
3. **Compliance**: Ensure queries meet regulatory or organizational policy requirements
4. **Data Sanitization**: Clean up query text before semantic search operations

## Extension

Override these methods to customize behavior:

- `process_input_messages()`: Customize how the query is processed
- `process_output_response()`: Currently a no-op, but can be overridden if needed
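
For example, a subclass could normalize the query before the guardrail runs. A minimal sketch of the override pattern; the real `CohereRerankHandler` lives in `handler.py` and applies the guardrail inside `process_input_messages`, but is stubbed here so the snippet runs standalone:

```python
import asyncio

# Stub standing in for litellm's CohereRerankHandler so this sketch is
# self-contained; the real base class applies the guardrail in this method.
class CohereRerankHandler:
    async def process_input_messages(self, data, guardrail_to_apply, litellm_logging_obj=None):
        return data

class LowercaseQueryHandler(CohereRerankHandler):
    """Normalize the query before the guardrail is applied."""

    async def process_input_messages(self, data, guardrail_to_apply, litellm_logging_obj=None):
        if isinstance(data.get("query"), str):
            data["query"] = data["query"].lower()
        # Delegate to the base class, which applies the guardrail.
        return await super().process_input_messages(
            data, guardrail_to_apply, litellm_logging_obj
        )

data = asyncio.run(
    LowercaseQueryHandler().process_input_messages(
        {"query": "Hello WORLD"}, guardrail_to_apply=None
    )
)
print(data["query"])  # hello world
```
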
## Supported Call Types

- `CallTypes.rerank` - Synchronous rerank
- `CallTypes.arerank` - Asynchronous rerank

## Notes

- Only the query is processed by guardrails
- Documents are not processed (they represent the corpus, not user input)
- Output processing is a no-op since rankings don't contain text
- Both sync and async call types use the same handler
- Works with all rerank providers (Cohere, Together AI, etc.)

## Common Patterns

### PII Masking in Search

```python
import litellm

response = litellm.rerank(
    model="rerank-english-v3.0",
    query="Find info about john@example.com",
    documents=[
        "Document 1 content.",
        "Document 2 content.",
        "Document 3 content."
    ],
    guardrails=["mask_pii"],
    top_n=2
)

# The query has PII masked before reranking:
# "Find info about [EMAIL_REDACTED]"
print(response.results)
```

### Content Filtering

```python
import litellm

response = litellm.rerank(
    model="rerank-english-v3.0",
    query="Search query here",
    documents=[
        {"text": "Document 1 content", "id": "doc1"},
        {"text": "Document 2 content", "id": "doc2"},
    ],
    guardrails=["content_filter"],
)
```

### Async Rerank with Guardrails

```python
import asyncio

import litellm

async def rerank_with_guardrails():
    response = await litellm.arerank(
        model="rerank-english-v3.0",
        query="Technical query",
        documents=["Doc 1", "Doc 2", "Doc 3"],
        guardrails=["sanitize"],
        top_n=2
    )
    return response

result = asyncio.run(rerank_with_guardrails())
```

@@ -0,0 +1,11 @@
"""Cohere Rerank handler for Unified Guardrails."""

from litellm.llms.cohere.rerank.guardrail_translation.handler import CohereRerankHandler
from litellm.types.utils import CallTypes

guardrail_translation_mappings = {
    CallTypes.rerank: CohereRerankHandler,
    CallTypes.arerank: CohereRerankHandler,
}

__all__ = ["guardrail_translation_mappings", "CohereRerankHandler"]
@@ -0,0 +1,107 @@
"""
Cohere Rerank Handler for Unified Guardrails

This module provides guardrail translation support for the rerank endpoint.
The handler processes only the 'query' parameter for guardrails.
"""

from typing import TYPE_CHECKING, Any, Optional

from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs

if TYPE_CHECKING:
    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.types.rerank import RerankResponse


class CohereRerankHandler(BaseTranslation):
    """
    Handler for processing rerank requests with guardrails.

    This class provides methods to:
    1. Process input query (pre-call hook)
    2. Process output response (post-call hook) - not applicable for rerank

    The handler specifically processes:
    - The 'query' parameter (string)

    Note: Documents are not processed by guardrails as they are the corpus
    being searched, not user input.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input query by applying guardrails.

        Args:
            data: Request data dictionary containing 'query'
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object passed through to the guardrail

        Returns:
            Modified data with guardrails applied to query only
        """
        # Process query only
        query = data.get("query")
        if query is not None and isinstance(query, str):
            inputs = GenericGuardrailAPIInputs(texts=[query])
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            data["query"] = guardrailed_texts[0] if guardrailed_texts else query

            verbose_proxy_logger.debug(
                "Rerank: Applied guardrail to query. "
                "Original length: %d, New length: %d",
                len(query),
                len(data["query"]),
            )
        else:
            verbose_proxy_logger.debug(
                "Rerank: No query to process or query is not a string"
            )

        return data

    async def process_output_response(
        self,
        response: "RerankResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response - not applicable for rerank.

        Rerank responses contain relevance scores and indices, not text,
        so there's nothing to apply guardrails to. This method returns
        the response unchanged.

        Args:
            response: Rerank response object with rankings
            guardrail_to_apply: The guardrail instance (unused)
            litellm_logging_obj: Optional logging object (unused)
            user_api_key_dict: User API key metadata (unused)

        Returns:
            Unmodified response (rankings don't need text guardrails)
        """
        verbose_proxy_logger.debug(
            "Rerank: Output processing not applicable "
            "(output contains relevance scores, not text)"
        )
        return response
@@ -0,0 +1,5 @@
"""
Cohere Rerank - uses `llm_http_handler.py` to make httpx requests

Request/Response transformation is handled in `transformation.py`
"""
@@ -0,0 +1,158 @@
from typing import Any, Dict, List, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse

from ..common_utils import CohereError


class CohereRerankConfig(BaseRerankConfig):
    """
    Reference: https://docs.cohere.com/v2/reference/rerank
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[dict] = None,
    ) -> str:
        if api_base:
            # Remove trailing slashes and ensure clean base URL
            api_base = api_base.rstrip("/")
            if not api_base.endswith("/v1/rerank"):
                api_base = f"{api_base}/v1/rerank"
            return api_base
        return "https://api.cohere.ai/v1/rerank"

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            "max_chunks_per_doc",
            "rank_fields",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> Dict:
        """
        Map Cohere rerank params.

        No mapping required - returns all supported params. Note that
        `max_tokens_per_doc` is accepted but not forwarded, since the v1
        rerank endpoint does not support it.
        """
        return dict(
            OptionalRerankParams(
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
            )
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> dict:
        if api_key is None:
            api_key = (
                get_secret_str("COHERE_API_KEY")
                or get_secret_str("CO_API_KEY")
                or litellm.cohere_key
            )

        if api_key is None:
            raise ValueError(
                "Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
            )

        default_headers = {
            "Authorization": f"Bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }

        # If 'Authorization' is provided in headers, it overrides the default.
        if "Authorization" in headers:
            default_headers["Authorization"] = headers["Authorization"]

        # Merge other headers, overriding any default ones
        return {**default_headers, **headers}

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Dict,
        headers: dict,
    ) -> dict:
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for Cohere rerank")
        if "documents" not in optional_rerank_params:
            raise ValueError("documents is required for Cohere rerank")
        rerank_request = RerankRequest(
            model=model,
            query=optional_rerank_params["query"],
            documents=optional_rerank_params["documents"],
            top_n=optional_rerank_params.get("top_n", None),
            rank_fields=optional_rerank_params.get("rank_fields", None),
            return_documents=optional_rerank_params.get("return_documents", None),
            max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
        )
        return rerank_request.model_dump(exclude_none=True)

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Transform Cohere rerank response.

        No transformation required - litellm follows the Cohere API response format.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise CohereError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        return RerankResponse(**raw_response_json)

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return CohereError(message=error_message, status_code=status_code)