chore: initial snapshot for gitea/github upload
litellm/proxy/rag_endpoints/__init__.py (new file)
@@ -0,0 +1,5 @@
"""RAG Endpoints for LiteLLM Proxy."""

from litellm.proxy.rag_endpoints.endpoints import router

__all__ = ["router"]
litellm/proxy/rag_endpoints/endpoints.py (new file)
@@ -0,0 +1,581 @@
"""
RAG Endpoints for LiteLLM Proxy.

Provides:
- /rag/ingest: All-in-one document ingestion pipeline (Upload -> Chunk -> Embed -> Vector Store)
- /rag/query: RAG query pipeline (Search -> Rerank -> LLM Completion)
"""

import base64
from typing import Any, Dict, Optional, Tuple

import orjson
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import ORJSONResponse

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import (
    _read_request_body,
    _safe_get_request_headers,
    get_form_data,
)

router = APIRouter()


def _build_file_metadata_entry(
    response: Any,
    file_data: Optional[Tuple[str, bytes, str]] = None,
    file_url: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Build a file metadata entry for storing in vector_store_metadata.

    Args:
        response: The response from litellm.aingest containing file_id
        file_data: Optional tuple of (filename, content, content_type)
        file_url: Optional URL if the file was ingested from a URL

    Returns:
        Dictionary with file metadata (file_id, filename, file_url, ingested_at, etc.)
    """
    from datetime import datetime, timezone

    # Extract file_id from the response (dict-like or attribute access)
    file_id = None
    if hasattr(response, "get"):
        file_id = response.get("file_id")
    elif hasattr(response, "file_id"):
        file_id = response.file_id

    # Extract file information from the file_data tuple
    filename = None
    file_size = None
    content_type = None

    if file_data:
        filename = file_data[0]
        file_size = len(file_data[1]) if len(file_data) > 1 else None
        content_type = file_data[2] if len(file_data) > 2 else None

    # Build the file metadata entry
    file_entry = {
        "file_id": file_id,
        "filename": filename,
        "file_url": file_url,
        "ingested_at": datetime.now(timezone.utc).isoformat(),
    }

    # Add optional fields if available
    if file_size is not None:
        file_entry["file_size"] = file_size
    if content_type is not None:
        file_entry["content_type"] = content_type

    return file_entry
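

# Illustrative sketch (not executed): for a 2,048-byte PDF uploaded as
# "report.pdf", the entry built above would look roughly like this
# (the file_id value is hypothetical; it comes from the aingest response):
#
#   {
#       "file_id": "file-abc123",
#       "filename": "report.pdf",
#       "file_url": None,
#       "ingested_at": "2025-01-01T12:00:00+00:00",
#       "file_size": 2048,
#       "content_type": "application/pdf",
#   }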


async def _save_vector_store_to_db_from_rag_ingest(
    response: Any,
    ingest_options: Dict[str, Any],
    prisma_client,
    user_api_key_dict: UserAPIKeyAuth,
    file_data: Optional[Tuple[str, bytes, str]] = None,
    file_url: Optional[str] = None,
) -> None:
    """
    Helper function to save a newly created vector store from RAG ingest to the database.

    This function:
    - Extracts the vector store ID and config from the ingest response
    - Checks if the vector store already exists in the database
    - Creates a new database entry if it doesn't exist
    - Adds the vector store to the registry
    - Tracks team_id and user_id for access control

    Args:
        response: The response from litellm.aingest()
        ingest_options: The ingest options containing the vector store config
        prisma_client: The Prisma database client
        user_api_key_dict: User API key authentication info
        file_data: Optional tuple of (filename, content, content_type)
        file_url: Optional URL if the file was ingested from a URL
    """
    from litellm.proxy.vector_store_endpoints.management_endpoints import (
        create_vector_store_in_db,
    )

    # Handle both dict and object responses
    if hasattr(response, "get"):
        vector_store_id = response.get("vector_store_id")
    elif hasattr(response, "vector_store_id"):
        vector_store_id = response.vector_store_id
    else:
        verbose_proxy_logger.warning(
            f"Unable to extract vector_store_id from response type: {type(response)}"
        )
        return

    if vector_store_id is None or not isinstance(vector_store_id, str):
        verbose_proxy_logger.warning(
            "Vector store ID is None or not a string, skipping database save"
        )
        return

    vector_store_config = ingest_options.get("vector_store", {})
    custom_llm_provider = vector_store_config.get("custom_llm_provider")

    # Extract litellm_vector_store_params for a custom name and description
    litellm_vector_store_params = ingest_options.get("litellm_vector_store_params", {})
    custom_vector_store_name = litellm_vector_store_params.get("vector_store_name")
    custom_vector_store_description = litellm_vector_store_params.get(
        "vector_store_description"
    )

    # Extract provider-specific params from vector_store_config to save as litellm_params.
    # This ensures params like aws_region_name, embedding_model, etc. are available for search.
    provider_specific_params = {}
    excluded_keys = {"custom_llm_provider", "vector_store_id"}
    for key, value in vector_store_config.items():
        if key not in excluded_keys and value is not None:
            provider_specific_params[key] = value

    # Build the file metadata entry using the helper above
    file_entry = _build_file_metadata_entry(
        response=response,
        file_data=file_data,
        file_url=file_url,
    )

    try:
        # Check if the vector store already exists in the database
        existing_vector_store = (
            await prisma_client.db.litellm_managedvectorstorestable.find_unique(
                where={"vector_store_id": vector_store_id}
            )
        )

        # Only create if it doesn't exist
        if existing_vector_store is None:
            verbose_proxy_logger.info(
                f"Saving newly created vector store {vector_store_id} to database"
            )

            # Initialize metadata with the first file
            initial_metadata = {"ingested_files": [file_entry]}

            # Use the custom name if provided, otherwise a default
            vector_store_name = (
                custom_vector_store_name or f"RAG Vector Store - {vector_store_id[:8]}"
            )
            vector_store_description = (
                custom_vector_store_description or "Created via RAG ingest endpoint"
            )

            await create_vector_store_in_db(
                vector_store_id=vector_store_id,
                custom_llm_provider=custom_llm_provider or "openai",
                prisma_client=prisma_client,
                vector_store_name=vector_store_name,
                vector_store_description=vector_store_description,
                vector_store_metadata=initial_metadata,
                litellm_params=provider_specific_params
                if provider_specific_params
                else None,
                team_id=user_api_key_dict.team_id,
                user_id=user_api_key_dict.user_id,
            )

            verbose_proxy_logger.info(
                f"Vector store {vector_store_id} saved to database successfully"
            )
        else:
            verbose_proxy_logger.info(
                f"Vector store {vector_store_id} already exists, appending file to metadata"
            )

            # Update the existing vector store with the new file
            existing_metadata = existing_vector_store.vector_store_metadata or {}
            if isinstance(existing_metadata, str):
                import json

                existing_metadata = json.loads(existing_metadata)

            ingested_files = existing_metadata.get("ingested_files", [])
            ingested_files.append(file_entry)
            existing_metadata["ingested_files"] = ingested_files

            # Update the vector store
            from litellm.proxy.utils import safe_dumps

            await prisma_client.db.litellm_managedvectorstorestable.update(
                where={"vector_store_id": vector_store_id},
                data={"vector_store_metadata": safe_dumps(existing_metadata)},
            )

            verbose_proxy_logger.info(
                f"Added file {file_entry.get('filename') or file_entry.get('file_url', 'Unknown')} to vector store {vector_store_id} metadata"
            )
    except Exception as db_error:
        # Log the error but don't fail the request, since ingestion succeeded
        verbose_proxy_logger.exception(
            f"Failed to save vector store {vector_store_id} to database: {db_error}"
        )
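

# Illustrative sketch (assumed values): the `ingest_options` dict this helper
# consumes carries the vector store config plus optional display params, e.g.:
#
#   {
#       "vector_store": {
#           "custom_llm_provider": "bedrock",
#           # provider-specific params, saved as litellm_params for later search
#           "aws_region_name": "us-east-1",
#       },
#       "litellm_vector_store_params": {
#           "vector_store_name": "My Docs",
#           "vector_store_description": "Internal knowledge base",
#       },
#   }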


async def parse_rag_ingest_request(
    request: Request,
) -> Tuple[
    Dict[str, Any], Optional[Tuple[str, bytes, str]], Optional[str], Optional[str]
]:
    """
    Parse a RAG ingest request.

    Supports:
    - Multipart form: file upload + the full request body as JSON in the 'request' form field
    - JSON body: URL-based ingestion, or an inline base64-encoded file

    Returns:
        Tuple of (ingest_options, file_data, file_url, file_id)
    """
    headers = _safe_get_request_headers(request)
    content_type = headers.get("content-type", "")

    file_data = None
    file_url = None
    file_id = None
    ingest_options: Dict[str, Any] = {}

    if "multipart/form-data" in content_type:
        # Form upload
        form_data = await get_form_data(request)

        # Get the uploaded file
        file_obj = form_data.get("file")
        if file_obj is not None and hasattr(file_obj, "read"):
            file_content = await file_obj.read()
            file_data = (file_obj.filename, file_content, file_obj.content_type)

        # Parse JSON from the 'request' form field (contains the full request body as JSON)
        request_json_str = form_data.get("request")
        if request_json_str:
            request_data = orjson.loads(request_json_str)
            ingest_options = request_data.get("ingest_options", {})
            file_url = request_data.get("file_url")
            file_id = request_data.get("file_id")

    else:
        # JSON body
        data = await _read_request_body(request)
        ingest_options = data.get("ingest_options", {})
        file_url = data.get("file_url")
        file_id = data.get("file_id")

        # Handle a base64-encoded file in the JSON body
        file_obj = data.get("file")
        if file_obj and isinstance(file_obj, dict):
            filename = file_obj.get("filename")
            content_b64 = file_obj.get("content")
            content_type = file_obj.get("content_type", "application/octet-stream")

            if filename and content_b64:
                try:
                    file_content = base64.b64decode(content_b64)
                    file_data = (filename, file_content, content_type)
                except Exception as e:
                    raise HTTPException(
                        status_code=400,
                        detail={"error": f"Invalid base64 content: {e}"},
                    )

    # Validate
    if file_data is None and file_url is None and file_id is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "Must provide file, file_url, or file_id"},
        )

    if "vector_store" not in ingest_options:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "ingest_options must contain 'vector_store' configuration"
            },
        )

    return ingest_options, file_data, file_url, file_id
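

# Illustrative sketch (assumed values): a JSON body carrying an inline
# base64-encoded file, as accepted by the parser above:
#
#   {
#       "file": {
#           "filename": "notes.txt",
#           "content": "aGVsbG8gd29ybGQ=",   # base64 for "hello world"
#           "content_type": "text/plain"
#       },
#       "ingest_options": {"vector_store": {"custom_llm_provider": "openai"}}
#   }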


@router.post(
    "/v1/rag/ingest",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["rag"],
)
@router.post(
    "/rag/ingest",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["rag"],
)
async def rag_ingest(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    RAG Ingest endpoint - all-in-one document ingestion pipeline.

    Supports a multipart form upload (for files) or a JSON body (for URLs).
    For form uploads, the request options go in the 'request' form field as JSON.

    ## Form upload (for files):
    ```bash
    curl -X POST "http://localhost:4000/v1/rag/ingest" \\
        -H "Authorization: Bearer sk-1234" \\
        -F file="@document.pdf" \\
        -F 'request={"ingest_options": {"vector_store": {"custom_llm_provider": "openai"}}}'
    ```

    ## JSON body (for URLs):
    ```bash
    curl -X POST "http://localhost:4000/v1/rag/ingest" \\
        -H "Authorization: Bearer sk-1234" \\
        -H "Content-Type: application/json" \\
        -d '{
            "file_url": "https://example.com/document.pdf",
            "ingest_options": {"vector_store": {"custom_llm_provider": "openai"}}
        }'
    ```

    ## Bedrock:
    ```bash
    curl -X POST "http://localhost:4000/v1/rag/ingest" \\
        -H "Authorization: Bearer sk-1234" \\
        -F file="@document.pdf" \\
        -F 'request={"ingest_options": {"vector_store": {"custom_llm_provider": "bedrock"}}}'
    ```
    """
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        llm_router,
        prisma_client,
        proxy_config,
        version,
    )

    try:
        # Parse the request
        ingest_options, file_data, file_url, file_id = await parse_rag_ingest_request(
            request
        )

        # INTERNAL_USER_VIEW_ONLY can ingest to existing vector stores only
        if (
            user_api_key_dict.user_role
            == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
            and not ingest_options.get("vector_store", {}).get("vector_store_id")
        ):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail={
                    "error": "internal_user_viewer role can only ingest files to an existing vector store. "
                    "Provide 'vector_store_id' in ingest_options.vector_store."
                },
            )

        # Add litellm data
        request_data: Dict[str, Any] = {}
        request_data = await add_litellm_data_to_request(
            data=request_data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        verbose_proxy_logger.debug(f"RAG Ingest - options: {ingest_options}")

        # Call ingest
        response = await litellm.aingest(
            ingest_options=ingest_options,
            file_data=file_data,
            file_url=file_url,
            file_id=file_id,
            router=llm_router,
            **request_data,
        )

        # Save the vector store to the database if it was newly created and prisma_client is available
        verbose_proxy_logger.debug(
            f"RAG Ingest - Checking database save conditions: prisma_client={prisma_client is not None}, response={response is not None}, response_type={type(response)}"
        )

        if prisma_client is not None and response is not None:
            await _save_vector_store_to_db_from_rag_ingest(
                response=response,
                ingest_options=ingest_options,
                prisma_client=prisma_client,
                user_api_key_dict=user_api_key_dict,
                file_data=file_data,
                file_url=file_url,
            )
        else:
            verbose_proxy_logger.warning(
                f"Skipping database save: prisma_client={prisma_client is not None}, response={response is not None}"
            )

        return response

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"RAG Ingest failed: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
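

# Client-side sketch (hypothetical script, not part of this module): mirrors
# the multipart curl example above, assuming a proxy on localhost:4000 with
# key "sk-1234" and httpx as the HTTP client:
#
#   import httpx
#
#   with open("document.pdf", "rb") as f:
#       resp = httpx.post(
#           "http://localhost:4000/v1/rag/ingest",
#           headers={"Authorization": "Bearer sk-1234"},
#           files={"file": ("document.pdf", f, "application/pdf")},
#           data={"request": '{"ingest_options": {"vector_store": {"custom_llm_provider": "openai"}}}'},
#       )
#   print(resp.json())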


@router.post(
    "/v1/rag/query",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["rag"],
)
@router.post(
    "/rag/query",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["rag"],
)
async def rag_query(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    RAG Query endpoint - search the vector store, optionally rerank, and generate an LLM response.

    This endpoint:
    1. Extracts the query from the last user message
    2. Searches the vector store for relevant context
    3. Optionally reranks the results
    4. Generates an LLM response with the retrieved context

    ## Example Request:
    ```bash
    curl -X POST "http://localhost:4000/v1/rag/query" \\
        -H "Authorization: Bearer sk-1234" \\
        -H "Content-Type: application/json" \\
        -d '{
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "What is LiteLLM?"}],
            "retrieval_config": {
                "vector_store_id": "vs_abc123",
                "custom_llm_provider": "openai",
                "top_k": 5
            }
        }'
    ```

    ## With Reranking:
    ```bash
    curl -X POST "http://localhost:4000/v1/rag/query" \\
        -H "Authorization: Bearer sk-1234" \\
        -H "Content-Type: application/json" \\
        -d '{
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "What is LiteLLM?"}],
            "retrieval_config": {
                "vector_store_id": "vs_abc123",
                "custom_llm_provider": "openai",
                "top_k": 10
            },
            "rerank": {
                "enabled": true,
                "model": "cohere/rerank-english-v3.0",
                "top_n": 3
            }
        }'
    ```
    """
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        llm_router,
        proxy_config,
        version,
    )

    try:
        # Parse the request body
        data = await _read_request_body(request)

        # Extract required fields
        model = data.get("model")
        messages = data.get("messages")
        retrieval_config = data.get("retrieval_config")
        rerank = data.get("rerank")
        stream = data.get("stream", False)

        # Validate required fields
        if not model:
            raise HTTPException(
                status_code=400,
                detail={"error": "model is required"},
            )
        if not messages:
            raise HTTPException(
                status_code=400,
                detail={"error": "messages is required"},
            )
        if not retrieval_config:
            raise HTTPException(
                status_code=400,
                detail={"error": "retrieval_config is required"},
            )
        if "vector_store_id" not in retrieval_config:
            raise HTTPException(
                status_code=400,
                detail={"error": "retrieval_config must contain 'vector_store_id'"},
            )

        # Add litellm data
        request_data: Dict[str, Any] = {}
        request_data = await add_litellm_data_to_request(
            data=request_data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        verbose_proxy_logger.debug(
            f"RAG Query - model: {model}, retrieval_config: {retrieval_config}"
        )

        # Call query
        response = await litellm.aquery(
            model=model,
            messages=messages,
            retrieval_config=retrieval_config,
            rerank=rerank,
            stream=stream,
            router=llm_router,
            **request_data,
        )

        return response

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"RAG Query failed: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
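

# Client-side sketch (hypothetical, same assumptions as the ingest example):
# mirrors the JSON curl example for /v1/rag/query:
#
#   import httpx
#
#   resp = httpx.post(
#       "http://localhost:4000/v1/rag/query",
#       headers={"Authorization": "Bearer sk-1234"},
#       json={
#           "model": "gpt-4o-mini",
#           "messages": [{"role": "user", "content": "What is LiteLLM?"}],
#           "retrieval_config": {
#               "vector_store_id": "vs_abc123",
#               "custom_llm_provider": "openai",
#               "top_k": 5,
#           },
#       },
#   )
#   print(resp.json())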