chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
LiteLLM RAG (Retrieval Augmented Generation) Module.
|
||||
|
||||
Provides an all-in-one API for document ingestion:
|
||||
Upload -> (OCR) -> Chunk -> Embed -> Vector Store
|
||||
"""
|
||||
|
||||
from litellm.rag.main import aingest, aquery, ingest, query
|
||||
|
||||
__all__ = ["ingest", "aingest", "query", "aquery"]
|
||||
|
||||
|
||||
# Expose at litellm.rag level for convenience
|
||||
async def arag_ingest(*args, **kwargs):
|
||||
"""Alias for aingest."""
|
||||
return await aingest(*args, **kwargs)
|
||||
|
||||
|
||||
def rag_ingest(*args, **kwargs):
|
||||
"""Alias for ingest."""
|
||||
return ingest(*args, **kwargs)
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
RAG Ingestion classes for different providers.
|
||||
"""
|
||||
|
||||
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
|
||||
from litellm.rag.ingestion.gemini_ingestion import GeminiRAGIngestion
|
||||
from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion
|
||||
from litellm.rag.ingestion.s3_vectors_ingestion import S3VectorsRAGIngestion
|
||||
from litellm.rag.ingestion.vertex_ai_ingestion import VertexAIRAGIngestion
|
||||
|
||||
__all__ = [
|
||||
"BaseRAGIngestion",
|
||||
"BedrockRAGIngestion",
|
||||
"GeminiRAGIngestion",
|
||||
"OpenAIRAGIngestion",
|
||||
"S3VectorsRAGIngestion",
|
||||
"VertexAIRAGIngestion",
|
||||
]
|
||||
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
Base RAG Ingestion class.
|
||||
|
||||
Provides abstract methods for:
|
||||
- OCR
|
||||
- Chunking
|
||||
- Embedding
|
||||
- Vector Store operations
|
||||
|
||||
Providers can inherit and override methods as needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid4
|
||||
from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.rag.ingestion.file_parsers import extract_text_from_pdf
|
||||
from litellm.rag.text_splitters import RecursiveCharacterTextSplitter
|
||||
from litellm.types.rag import RAGIngestOptions, RAGIngestResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Router
|
||||
|
||||
|
||||
class BaseRAGIngestion(ABC):
|
||||
"""
|
||||
Base class for RAG ingestion.
|
||||
|
||||
Providers should inherit from this class and override methods as needed.
|
||||
For example, OpenAI handles embedding internally when attaching files to
|
||||
vector stores, so it overrides the embedding step to be a no-op.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ingest_options: RAGIngestOptions,
|
||||
router: Optional["Router"] = None,
|
||||
):
|
||||
self.ingest_options = ingest_options
|
||||
self.router = router
|
||||
self.ingest_id = f"ingest_{uuid4()}"
|
||||
|
||||
# Extract configs from options
|
||||
self.ocr_config = ingest_options.get("ocr")
|
||||
self.chunking_strategy: Dict[str, Any] = cast(
|
||||
Dict[str, Any],
|
||||
ingest_options.get("chunking_strategy") or {"type": "auto"},
|
||||
)
|
||||
self.embedding_config = ingest_options.get("embedding")
|
||||
self.vector_store_config: Dict[str, Any] = cast(
|
||||
Dict[str, Any], ingest_options.get("vector_store") or {}
|
||||
)
|
||||
self.ingest_name = ingest_options.get("name")
|
||||
|
||||
# Load credentials from litellm_credential_name if provided in vector_store config
|
||||
self._load_credentials_from_config()
|
||||
|
||||
def _load_credentials_from_config(self) -> None:
|
||||
"""
|
||||
Load credentials from litellm_credential_name if provided in vector_store config.
|
||||
|
||||
This allows users to specify a credential name in the vector_store config
|
||||
which will be resolved from litellm.credential_list.
|
||||
"""
|
||||
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
|
||||
|
||||
credential_name = self.vector_store_config.get("litellm_credential_name")
|
||||
if credential_name and litellm.credential_list:
|
||||
credential_values = CredentialAccessor.get_credential_values(
|
||||
credential_name
|
||||
)
|
||||
# Merge credentials into vector_store_config (don't overwrite existing values)
|
||||
for key, value in credential_values.items():
|
||||
if key not in self.vector_store_config:
|
||||
self.vector_store_config[key] = value
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> str:
|
||||
"""Get the vector store provider."""
|
||||
return self.vector_store_config.get("custom_llm_provider", "openai")
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
file_data: Optional[Tuple[str, bytes, str]] = None,
|
||||
file_url: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
) -> Tuple[Optional[str], Optional[bytes], Optional[str], Optional[str]]:
|
||||
"""
|
||||
Upload / prepare file for ingestion.
|
||||
|
||||
Args:
|
||||
file_data: Tuple of (filename, content_bytes, content_type)
|
||||
file_url: URL to fetch file from
|
||||
file_id: Existing file ID to use
|
||||
|
||||
Returns:
|
||||
Tuple of (filename, file_content, content_type, existing_file_id)
|
||||
"""
|
||||
if file_data:
|
||||
filename, file_content, content_type = file_data
|
||||
return filename, file_content, content_type, None
|
||||
|
||||
if file_url:
|
||||
http_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.RAG)
|
||||
response = await http_client.get(file_url)
|
||||
response.raise_for_status()
|
||||
file_content = response.content
|
||||
filename = file_url.split("/")[-1] or "document"
|
||||
content_type = response.headers.get(
|
||||
"content-type", "application/octet-stream"
|
||||
)
|
||||
return filename, file_content, content_type, None
|
||||
|
||||
if file_id:
|
||||
return None, None, None, file_id
|
||||
|
||||
raise ValueError("Must provide file_data, file_url, or file_id")
|
||||
|
||||
async def ocr(
|
||||
self,
|
||||
file_content: Optional[bytes],
|
||||
content_type: Optional[str],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Perform OCR on file content to extract text.
|
||||
|
||||
Args:
|
||||
file_content: Raw file bytes
|
||||
content_type: MIME type of the file
|
||||
|
||||
Returns:
|
||||
Extracted text or None if OCR not configured/needed
|
||||
"""
|
||||
if not self.ocr_config or not file_content:
|
||||
return None
|
||||
|
||||
ocr_model = self.ocr_config.get("model", "mistral/mistral-ocr-latest")
|
||||
|
||||
# Determine document type
|
||||
if content_type and "image" in content_type:
|
||||
doc_type, url_key = "image_url", "image_url"
|
||||
else:
|
||||
doc_type, url_key = "document_url", "document_url"
|
||||
|
||||
# Encode as base64 data URL
|
||||
b64_content = base64.b64encode(file_content).decode("utf-8")
|
||||
data_url = f"data:{content_type};base64,{b64_content}"
|
||||
|
||||
# Use router if available
|
||||
if self.router is not None:
|
||||
ocr_response = await self.router.aocr(
|
||||
model=ocr_model,
|
||||
document={"type": doc_type, url_key: data_url},
|
||||
)
|
||||
else:
|
||||
ocr_response = await litellm.aocr(
|
||||
model=ocr_model,
|
||||
document={"type": doc_type, url_key: data_url},
|
||||
)
|
||||
|
||||
# Extract text from pages
|
||||
if hasattr(ocr_response, "pages") and ocr_response.pages: # type: ignore
|
||||
return "\n\n".join(
|
||||
page.markdown for page in ocr_response.pages if hasattr(page, "markdown") # type: ignore
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def chunk(
|
||||
self,
|
||||
text: Optional[str],
|
||||
file_content: Optional[bytes],
|
||||
ocr_was_used: bool,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Split text into chunks using RecursiveCharacterTextSplitter.
|
||||
|
||||
Args:
|
||||
text: Text from OCR (if used)
|
||||
file_content: Raw file content bytes
|
||||
ocr_was_used: Whether OCR was performed
|
||||
|
||||
Returns:
|
||||
List of text chunks
|
||||
"""
|
||||
# Get text to chunk
|
||||
text_to_chunk: Optional[str] = None
|
||||
if text:
|
||||
text_to_chunk = text
|
||||
elif file_content and not ocr_was_used:
|
||||
# Try UTF-8 decode first
|
||||
try:
|
||||
text_to_chunk = file_content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# Check if it's a PDF and try to extract text
|
||||
if file_content.startswith(b"%PDF"):
|
||||
verbose_logger.debug("PDF detected, attempting text extraction")
|
||||
text_to_chunk = extract_text_from_pdf(file_content)
|
||||
if not text_to_chunk:
|
||||
verbose_logger.debug(
|
||||
"PDF text extraction failed. Install 'pypdf' or 'PyPDF2' for PDF support, "
|
||||
"or enable OCR with a vision model."
|
||||
)
|
||||
return []
|
||||
else:
|
||||
verbose_logger.debug("Binary file detected, skipping text chunking")
|
||||
return []
|
||||
|
||||
if not text_to_chunk:
|
||||
return []
|
||||
|
||||
# Extract RecursiveCharacterTextSplitter args
|
||||
splitter_args = self.chunking_strategy or {}
|
||||
chunk_size = splitter_args.get("chunk_size", DEFAULT_CHUNK_SIZE)
|
||||
chunk_overlap = splitter_args.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP)
|
||||
separators = splitter_args.get("separators", None)
|
||||
|
||||
# Build splitter kwargs
|
||||
splitter_kwargs: Dict[str, Any] = {
|
||||
"chunk_size": chunk_size,
|
||||
"chunk_overlap": chunk_overlap,
|
||||
}
|
||||
if separators:
|
||||
splitter_kwargs["separators"] = separators
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
|
||||
return text_splitter.split_text(text_to_chunk)
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
chunks: List[str],
|
||||
) -> Optional[List[List[float]]]:
|
||||
"""
|
||||
Generate embeddings for text chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks
|
||||
|
||||
Returns:
|
||||
List of embeddings or None
|
||||
"""
|
||||
if not self.embedding_config or not chunks:
|
||||
return None
|
||||
|
||||
embedding_model = self.embedding_config.get("model", "text-embedding-3-small")
|
||||
|
||||
if self.router is not None:
|
||||
response = await self.router.aembedding(model=embedding_model, input=chunks)
|
||||
else:
|
||||
response = await litellm.aembedding(model=embedding_model, input=chunks)
|
||||
|
||||
return [item["embedding"] for item in response.data]
|
||||
|
||||
@abstractmethod
|
||||
async def store(
|
||||
self,
|
||||
file_content: Optional[bytes],
|
||||
filename: Optional[str],
|
||||
content_type: Optional[str],
|
||||
chunks: List[str],
|
||||
embeddings: Optional[List[List[float]]],
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Store content in vector store.
|
||||
|
||||
This method must be implemented by provider-specific subclasses.
|
||||
|
||||
Args:
|
||||
file_content: Raw file bytes
|
||||
filename: Name of the file
|
||||
content_type: MIME type
|
||||
chunks: Text chunks (if chunking was done locally)
|
||||
embeddings: Embeddings (if embedding was done locally)
|
||||
|
||||
Returns:
|
||||
Tuple of (vector_store_id, file_id)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def ingest(
|
||||
self,
|
||||
file_data: Optional[Tuple[str, bytes, str]] = None,
|
||||
file_url: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
) -> RAGIngestResponse:
|
||||
"""
|
||||
Execute the full ingestion pipeline.
|
||||
|
||||
Args:
|
||||
file_data: Tuple of (filename, content_bytes, content_type)
|
||||
file_url: URL to fetch file from
|
||||
file_id: Existing file ID to use
|
||||
|
||||
Returns:
|
||||
RAGIngestResponse with status and IDs
|
||||
|
||||
Raises:
|
||||
ValueError: If no input source is provided
|
||||
"""
|
||||
# Step 1: Upload (raises ValueError if no input provided)
|
||||
filename, file_content, content_type, existing_file_id = await self.upload(
|
||||
file_data=file_data,
|
||||
file_url=file_url,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 2: OCR (optional)
|
||||
extracted_text = await self.ocr(
|
||||
file_content=file_content,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
# Step 3: Chunking
|
||||
chunks = self.chunk(
|
||||
text=extracted_text,
|
||||
file_content=file_content,
|
||||
ocr_was_used=self.ocr_config is not None,
|
||||
)
|
||||
|
||||
# Step 4: Embedding (optional - some providers handle this internally)
|
||||
embeddings = await self.embed(chunks=chunks)
|
||||
|
||||
# Step 5: Store in vector store
|
||||
vector_store_id, result_file_id = await self.store(
|
||||
file_content=file_content,
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
chunks=chunks,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
|
||||
return RAGIngestResponse(
|
||||
id=self.ingest_id,
|
||||
status="completed",
|
||||
vector_store_id=vector_store_id or "",
|
||||
file_id=result_file_id or existing_file_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"RAG Pipeline failed: {e}")
|
||||
return RAGIngestResponse(
|
||||
id=self.ingest_id,
|
||||
status="failed",
|
||||
vector_store_id="",
|
||||
file_id=None,
|
||||
error=str(e),
|
||||
)
|
||||
@@ -0,0 +1,775 @@
|
||||
"""
|
||||
Bedrock-specific RAG Ingestion implementation.
|
||||
|
||||
Bedrock Knowledge Bases handle embedding internally when files are ingested,
|
||||
so this implementation uploads files to S3 and triggers ingestion jobs.
|
||||
|
||||
Supports two modes:
|
||||
1. Use existing KB: Provide vector_store_id (KB ID)
|
||||
2. Auto-create KB: Don't provide vector_store_id - creates all AWS resources automatically
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Router
|
||||
from litellm.types.rag import RAGIngestOptions
|
||||
|
||||
|
||||
def _get_str_or_none(value: Any) -> Optional[str]:
|
||||
"""Cast config value to Optional[str]."""
|
||||
return str(value) if value is not None else None
|
||||
|
||||
|
||||
def _get_int(value: Any, default: int) -> int:
|
||||
"""Cast config value to int with default."""
|
||||
if value is None:
|
||||
return default
|
||||
return int(value)
|
||||
|
||||
|
||||
def _normalize_principal_arn(caller_arn: str, account_id: str) -> str:
|
||||
"""
|
||||
Normalize a caller ARN to the format required by OpenSearch data access policies.
|
||||
|
||||
OpenSearch Serverless data access policies require:
|
||||
- IAM users: arn:aws:iam::account-id:user/user-name
|
||||
- IAM roles: arn:aws:iam::account-id:role/role-name
|
||||
|
||||
But get_caller_identity() returns for assumed roles:
|
||||
- arn:aws:sts::account-id:assumed-role/role-name/session-name
|
||||
|
||||
This function converts assumed-role ARNs to the proper IAM role ARN format.
|
||||
"""
|
||||
if ":assumed-role/" in caller_arn:
|
||||
# Extract role name from assumed-role ARN
|
||||
# Format: arn:aws:sts::ACCOUNT:assumed-role/ROLE-NAME/SESSION-NAME
|
||||
parts = caller_arn.split("/")
|
||||
if len(parts) >= 2:
|
||||
role_name = parts[1]
|
||||
return f"arn:aws:iam::{account_id}:role/{role_name}"
|
||||
return caller_arn
|
||||
|
||||
|
||||
class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
|
||||
"""
|
||||
Bedrock Knowledge Base RAG ingestion.
|
||||
|
||||
Supports two modes:
|
||||
1. **Use existing KB**: Provide vector_store_id
|
||||
2. **Auto-create KB**: Don't provide vector_store_id - creates S3 bucket,
|
||||
OpenSearch Serverless collection, IAM role, KB, and data source automatically
|
||||
|
||||
Optional config:
|
||||
- vector_store_id: Existing KB ID (if not provided, auto-creates)
|
||||
- s3_bucket: S3 bucket (auto-created if not provided)
|
||||
- embedding_model: Bedrock embedding model (default: amazon.titan-embed-text-v2:0)
|
||||
- wait_for_ingestion: Wait for completion (default: True)
|
||||
- ingestion_timeout: Max seconds to wait (default: 300)
|
||||
|
||||
AWS Auth (uses BaseAWSLLM):
|
||||
- aws_access_key_id, aws_secret_access_key, aws_session_token
|
||||
- aws_region_name (default: us-west-2)
|
||||
- aws_role_name, aws_session_name, aws_profile_name
|
||||
- aws_web_identity_token, aws_sts_endpoint, aws_external_id
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ingest_options: "RAGIngestOptions",
|
||||
router: Optional["Router"] = None,
|
||||
):
|
||||
BaseRAGIngestion.__init__(self, ingest_options=ingest_options, router=router)
|
||||
BaseAWSLLM.__init__(self)
|
||||
|
||||
# Use vector_store_id as unified param (maps to knowledge_base_id)
|
||||
self.knowledge_base_id = self.vector_store_config.get(
|
||||
"vector_store_id"
|
||||
) or self.vector_store_config.get("knowledge_base_id")
|
||||
|
||||
# Optional config
|
||||
self._data_source_id = self.vector_store_config.get("data_source_id")
|
||||
self._s3_bucket = self.vector_store_config.get("s3_bucket")
|
||||
self._s3_prefix: Optional[str] = (
|
||||
str(self.vector_store_config.get("s3_prefix"))
|
||||
if self.vector_store_config.get("s3_prefix")
|
||||
else None
|
||||
)
|
||||
self.embedding_model = (
|
||||
self.vector_store_config.get("embedding_model")
|
||||
or "amazon.titan-embed-text-v2:0"
|
||||
)
|
||||
|
||||
self.wait_for_ingestion = self.vector_store_config.get(
|
||||
"wait_for_ingestion", False
|
||||
)
|
||||
self.ingestion_timeout: int = _get_int(
|
||||
self.vector_store_config.get("ingestion_timeout"), 300
|
||||
)
|
||||
|
||||
# Get AWS region using BaseAWSLLM method
|
||||
_aws_region = self.vector_store_config.get("aws_region_name")
|
||||
self.aws_region_name = self.get_aws_region_name_for_non_llm_api_calls(
|
||||
aws_region_name=str(_aws_region) if _aws_region else None
|
||||
)
|
||||
|
||||
# Will be set during initialization
|
||||
self.data_source_id: Optional[str] = None
|
||||
self.s3_bucket: Optional[str] = None
|
||||
self.s3_prefix: str = self._s3_prefix or "data/"
|
||||
self._config_initialized = False
|
||||
|
||||
# Track resources we create (for cleanup if needed)
|
||||
self._created_resources: Dict[str, Any] = {}
|
||||
|
||||
async def _ensure_config_initialized(self):
|
||||
"""Lazily initialize KB config - either detect from existing or create new."""
|
||||
if self._config_initialized:
|
||||
return
|
||||
|
||||
if self.knowledge_base_id:
|
||||
# Use existing KB - auto-detect data source and S3 bucket
|
||||
self._auto_detect_config()
|
||||
else:
|
||||
# No KB provided - create everything from scratch
|
||||
await self._create_knowledge_base_infrastructure()
|
||||
|
||||
self._config_initialized = True
|
||||
|
||||
def _auto_detect_config(self):
|
||||
"""Auto-detect data source ID and S3 bucket from existing Knowledge Base."""
|
||||
verbose_logger.debug(
|
||||
f"Auto-detecting data source and S3 bucket for KB={self.knowledge_base_id}"
|
||||
)
|
||||
|
||||
bedrock_agent = self._get_boto3_client("bedrock-agent")
|
||||
|
||||
# List data sources for this KB
|
||||
ds_response = bedrock_agent.list_data_sources(
|
||||
knowledgeBaseId=self.knowledge_base_id
|
||||
)
|
||||
data_sources = ds_response.get("dataSourceSummaries", [])
|
||||
|
||||
if not data_sources:
|
||||
raise ValueError(
|
||||
f"No data sources found for Knowledge Base {self.knowledge_base_id}. "
|
||||
"Please create a data source first or provide data_source_id and s3_bucket."
|
||||
)
|
||||
|
||||
# Use first data source (or user-provided override)
|
||||
if self._data_source_id:
|
||||
self.data_source_id = self._data_source_id
|
||||
else:
|
||||
self.data_source_id = data_sources[0]["dataSourceId"]
|
||||
verbose_logger.info(f"Auto-detected data source: {self.data_source_id}")
|
||||
|
||||
# Get data source details for S3 bucket
|
||||
ds_details = bedrock_agent.get_data_source(
|
||||
knowledgeBaseId=self.knowledge_base_id,
|
||||
dataSourceId=self.data_source_id,
|
||||
)
|
||||
|
||||
s3_config = (
|
||||
ds_details.get("dataSource", {})
|
||||
.get("dataSourceConfiguration", {})
|
||||
.get("s3Configuration", {})
|
||||
)
|
||||
|
||||
bucket_arn = s3_config.get("bucketArn", "")
|
||||
if bucket_arn:
|
||||
# Extract bucket name from ARN: arn:aws:s3:::bucket-name
|
||||
self.s3_bucket = self._s3_bucket or bucket_arn.split(":")[-1]
|
||||
verbose_logger.info(f"Auto-detected S3 bucket: {self.s3_bucket}")
|
||||
|
||||
# Use inclusion prefix if available
|
||||
prefixes = s3_config.get("inclusionPrefixes", [])
|
||||
if prefixes and not self._s3_prefix:
|
||||
self.s3_prefix = prefixes[0]
|
||||
else:
|
||||
if not self._s3_bucket:
|
||||
raise ValueError(
|
||||
f"Could not auto-detect S3 bucket for data source {self.data_source_id}. "
|
||||
"Please provide s3_bucket in config."
|
||||
)
|
||||
self.s3_bucket = self._s3_bucket
|
||||
|
||||
async def _create_knowledge_base_infrastructure(self):
|
||||
"""Create all AWS resources needed for a new Knowledge Base."""
|
||||
verbose_logger.info("Creating new Bedrock Knowledge Base infrastructure...")
|
||||
|
||||
# Generate unique names
|
||||
unique_id = uuid.uuid4().hex[:8]
|
||||
kb_name = self.ingest_name or f"litellm-kb-{unique_id}"
|
||||
|
||||
# Get AWS account ID and caller ARN (for data access policy)
|
||||
sts = self._get_boto3_client("sts")
|
||||
caller_identity = sts.get_caller_identity()
|
||||
account_id = caller_identity["Account"]
|
||||
caller_arn = caller_identity["Arn"]
|
||||
|
||||
# Step 1: Create S3 bucket (if not provided)
|
||||
self.s3_bucket = self._s3_bucket or self._create_s3_bucket(unique_id)
|
||||
|
||||
# Step 2: Create OpenSearch Serverless collection
|
||||
collection_name, collection_arn = await self._create_opensearch_collection(
|
||||
unique_id, account_id, caller_arn
|
||||
)
|
||||
|
||||
# Step 3: Create OpenSearch index
|
||||
await self._create_opensearch_index(collection_name)
|
||||
|
||||
# Step 4: Create IAM role for Bedrock
|
||||
role_arn = await self._create_bedrock_role(
|
||||
unique_id, account_id, collection_arn
|
||||
)
|
||||
|
||||
# Step 5: Create Knowledge Base
|
||||
self.knowledge_base_id = await self._create_knowledge_base(
|
||||
kb_name, role_arn, collection_arn
|
||||
)
|
||||
|
||||
# Step 6: Create Data Source
|
||||
self.data_source_id = self._create_data_source(kb_name)
|
||||
|
||||
verbose_logger.info(
|
||||
f"Created KB infrastructure: kb_id={self.knowledge_base_id}, "
|
||||
f"ds_id={self.data_source_id}, bucket={self.s3_bucket}"
|
||||
)
|
||||
|
||||
def _create_s3_bucket(self, unique_id: str) -> str:
|
||||
"""Create S3 bucket for KB data source."""
|
||||
s3 = self._get_boto3_client("s3")
|
||||
bucket_name = f"litellm-kb-{unique_id}"
|
||||
|
||||
verbose_logger.debug(f"Creating S3 bucket: {bucket_name}")
|
||||
|
||||
create_params: Dict[str, Any] = {"Bucket": bucket_name}
|
||||
if self.aws_region_name != "us-east-1":
|
||||
create_params["CreateBucketConfiguration"] = {
|
||||
"LocationConstraint": self.aws_region_name
|
||||
}
|
||||
|
||||
s3.create_bucket(**create_params)
|
||||
self._created_resources["s3_bucket"] = bucket_name
|
||||
|
||||
verbose_logger.info(f"Created S3 bucket: {bucket_name}")
|
||||
return bucket_name
|
||||
|
||||
async def _create_opensearch_collection(
|
||||
self, unique_id: str, account_id: str, caller_arn: str
|
||||
) -> Tuple[str, str]:
|
||||
"""Create OpenSearch Serverless collection for vector storage."""
|
||||
oss = self._get_boto3_client("opensearchserverless")
|
||||
collection_name = f"litellm-kb-{unique_id}"
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Creating OpenSearch Serverless collection: {collection_name}"
|
||||
)
|
||||
|
||||
# Create encryption policy
|
||||
oss.create_security_policy(
|
||||
name=f"{collection_name}-enc",
|
||||
type="encryption",
|
||||
policy=json.dumps(
|
||||
{
|
||||
"Rules": [
|
||||
{
|
||||
"ResourceType": "collection",
|
||||
"Resource": [f"collection/{collection_name}"],
|
||||
}
|
||||
],
|
||||
"AWSOwnedKey": True,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Create network policy (public access for simplicity)
|
||||
oss.create_security_policy(
|
||||
name=f"{collection_name}-net",
|
||||
type="network",
|
||||
policy=json.dumps(
|
||||
[
|
||||
{
|
||||
"Rules": [
|
||||
{
|
||||
"ResourceType": "collection",
|
||||
"Resource": [f"collection/{collection_name}"],
|
||||
},
|
||||
{
|
||||
"ResourceType": "dashboard",
|
||||
"Resource": [f"collection/{collection_name}"],
|
||||
},
|
||||
],
|
||||
"AllowFromPublic": True,
|
||||
}
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Create data access policy - include both root and actual caller ARN
|
||||
# This ensures the credentials being used have access to the collection
|
||||
# Normalize the caller ARN (convert assumed-role ARN to IAM role ARN if needed)
|
||||
normalized_caller_arn = _normalize_principal_arn(caller_arn, account_id)
|
||||
verbose_logger.debug(
|
||||
f"Caller ARN: {caller_arn}, Normalized: {normalized_caller_arn}"
|
||||
)
|
||||
|
||||
principals = [f"arn:aws:iam::{account_id}:root", normalized_caller_arn]
|
||||
# Deduplicate in case caller is root
|
||||
principals = list(set(principals))
|
||||
|
||||
oss.create_access_policy(
|
||||
name=f"{collection_name}-access",
|
||||
type="data",
|
||||
policy=json.dumps(
|
||||
[
|
||||
{
|
||||
"Rules": [
|
||||
{
|
||||
"ResourceType": "index",
|
||||
"Resource": [f"index/{collection_name}/*"],
|
||||
"Permission": ["aoss:*"],
|
||||
},
|
||||
{
|
||||
"ResourceType": "collection",
|
||||
"Resource": [f"collection/{collection_name}"],
|
||||
"Permission": ["aoss:*"],
|
||||
},
|
||||
],
|
||||
"Principal": principals,
|
||||
}
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Create collection
|
||||
response = oss.create_collection(
|
||||
name=collection_name,
|
||||
type="VECTORSEARCH",
|
||||
)
|
||||
collection_id = response["createCollectionDetail"]["id"]
|
||||
self._created_resources["opensearch_collection"] = collection_name
|
||||
|
||||
# Wait for collection to be active (use asyncio.sleep to avoid blocking)
|
||||
verbose_logger.debug("Waiting for OpenSearch collection to be active...")
|
||||
for _ in range(60): # 5 min timeout
|
||||
status_response = oss.batch_get_collection(ids=[collection_id])
|
||||
status = status_response["collectionDetails"][0]["status"]
|
||||
if status == "ACTIVE":
|
||||
break
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
raise TimeoutError("OpenSearch collection did not become active in time")
|
||||
|
||||
collection_arn = status_response["collectionDetails"][0]["arn"]
|
||||
verbose_logger.info(f"Created OpenSearch collection: {collection_name}")
|
||||
|
||||
# Wait for data access policy to propagate before returning
|
||||
# AWS recommends waiting 60+ seconds for policy propagation
|
||||
verbose_logger.debug("Waiting for data access policy to propagate (60s)...")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
return collection_name, collection_arn
|
||||
|
||||
async def _create_opensearch_index(self, collection_name: str):
|
||||
"""Create vector index in OpenSearch collection with retry logic."""
|
||||
from opensearchpy import OpenSearch, RequestsHttpConnection
|
||||
from requests_aws4auth import AWS4Auth
|
||||
|
||||
# Get credentials for signing
|
||||
credentials = self.get_credentials(
|
||||
aws_access_key_id=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_access_key_id")
|
||||
),
|
||||
aws_secret_access_key=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_secret_access_key")
|
||||
),
|
||||
aws_session_token=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_session_token")
|
||||
),
|
||||
aws_region_name=self.aws_region_name,
|
||||
)
|
||||
|
||||
# Get collection endpoint
|
||||
oss = self._get_boto3_client("opensearchserverless")
|
||||
collections = oss.batch_get_collection(names=[collection_name])
|
||||
endpoint = collections["collectionDetails"][0]["collectionEndpoint"]
|
||||
host = endpoint.replace("https://", "")
|
||||
|
||||
auth = AWS4Auth(
|
||||
credentials.access_key,
|
||||
credentials.secret_key,
|
||||
self.aws_region_name,
|
||||
"aoss",
|
||||
session_token=credentials.token,
|
||||
)
|
||||
|
||||
client = OpenSearch(
|
||||
hosts=[{"host": host, "port": 443}],
|
||||
http_auth=auth,
|
||||
use_ssl=True,
|
||||
verify_certs=True,
|
||||
connection_class=RequestsHttpConnection,
|
||||
)
|
||||
|
||||
index_name = "bedrock-kb-index"
|
||||
index_body = {
|
||||
"settings": {"index": {"knn": True, "knn.algo_param.ef_search": 512}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"bedrock-knowledge-base-default-vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 1024,
|
||||
"method": {
|
||||
"engine": "faiss",
|
||||
"name": "hnsw",
|
||||
"space_type": "l2",
|
||||
},
|
||||
},
|
||||
"AMAZON_BEDROCK_METADATA": {"type": "text", "index": False},
|
||||
"AMAZON_BEDROCK_TEXT_CHUNK": {"type": "text"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Retry logic for index creation - data access policy may take time to propagate
|
||||
max_retries = 8
|
||||
retry_delay = 20 # seconds
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
client.indices.create(index=index_name, body=index_body)
|
||||
verbose_logger.info(f"Created OpenSearch index: {index_name}")
|
||||
return
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_str = str(e)
|
||||
if (
|
||||
"authorization_exception" in error_str.lower()
|
||||
or "security_exception" in error_str.lower()
|
||||
):
|
||||
verbose_logger.warning(
|
||||
f"OpenSearch index creation attempt {attempt + 1}/{max_retries} failed due to authorization. "
|
||||
f"Waiting {retry_delay}s for policy propagation..."
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
# Non-auth error, raise immediately
|
||||
raise
|
||||
|
||||
# All retries exhausted
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenSearch index after {max_retries} attempts. "
|
||||
f"Data access policy may not have propagated. Last error: {last_error}"
|
||||
)
|
||||
|
||||
async def _create_bedrock_role(
|
||||
self, unique_id: str, account_id: str, collection_arn: str
|
||||
) -> str:
|
||||
"""Create IAM role for Bedrock KB."""
|
||||
iam = self._get_boto3_client("iam")
|
||||
role_name = f"litellm-bedrock-kb-{unique_id}"
|
||||
|
||||
verbose_logger.debug(f"Creating IAM role: {role_name}")
|
||||
|
||||
trust_policy = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"Service": "bedrock.amazonaws.com"},
|
||||
"Action": "sts:AssumeRole",
|
||||
"Condition": {
|
||||
"StringEquals": {"aws:SourceAccount": account_id},
|
||||
"ArnLike": {
|
||||
"aws:SourceArn": f"arn:aws:bedrock:{self.aws_region_name}:{account_id}:knowledge-base/*"
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = iam.create_role(
|
||||
RoleName=role_name,
|
||||
AssumeRolePolicyDocument=json.dumps(trust_policy),
|
||||
)
|
||||
role_arn = response["Role"]["Arn"]
|
||||
self._created_resources["iam_role"] = role_name
|
||||
|
||||
# Attach permissions policy
|
||||
permissions_policy = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": ["bedrock:InvokeModel"],
|
||||
"Resource": [
|
||||
f"arn:aws:bedrock:{self.aws_region_name}::foundation-model/{self.embedding_model}"
|
||||
],
|
||||
},
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": ["aoss:APIAccessAll"],
|
||||
"Resource": [collection_arn],
|
||||
},
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": ["s3:GetObject", "s3:ListBucket"],
|
||||
"Resource": [
|
||||
f"arn:aws:s3:::{self.s3_bucket}",
|
||||
f"arn:aws:s3:::{self.s3_bucket}/*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
iam.put_role_policy(
|
||||
RoleName=role_name,
|
||||
PolicyName=f"{role_name}-policy",
|
||||
PolicyDocument=json.dumps(permissions_policy),
|
||||
)
|
||||
|
||||
# Wait for role to propagate (use asyncio.sleep to avoid blocking)
|
||||
await asyncio.sleep(10)
|
||||
|
||||
verbose_logger.info(f"Created IAM role: {role_arn}")
|
||||
return role_arn
|
||||
|
||||
async def _create_knowledge_base(
|
||||
self, kb_name: str, role_arn: str, collection_arn: str
|
||||
) -> str:
|
||||
"""Create Bedrock Knowledge Base."""
|
||||
bedrock_agent = self._get_boto3_client("bedrock-agent")
|
||||
|
||||
verbose_logger.debug(f"Creating Knowledge Base: {kb_name}")
|
||||
|
||||
response = bedrock_agent.create_knowledge_base(
|
||||
name=kb_name,
|
||||
roleArn=role_arn,
|
||||
knowledgeBaseConfiguration={
|
||||
"type": "VECTOR",
|
||||
"vectorKnowledgeBaseConfiguration": {
|
||||
"embeddingModelArn": f"arn:aws:bedrock:{self.aws_region_name}::foundation-model/{self.embedding_model}",
|
||||
},
|
||||
},
|
||||
storageConfiguration={
|
||||
"type": "OPENSEARCH_SERVERLESS",
|
||||
"opensearchServerlessConfiguration": {
|
||||
"collectionArn": collection_arn,
|
||||
"fieldMapping": {
|
||||
"metadataField": "AMAZON_BEDROCK_METADATA",
|
||||
"textField": "AMAZON_BEDROCK_TEXT_CHUNK",
|
||||
"vectorField": "bedrock-knowledge-base-default-vector",
|
||||
},
|
||||
"vectorIndexName": "bedrock-kb-index",
|
||||
},
|
||||
},
|
||||
)
|
||||
kb_id = response["knowledgeBase"]["knowledgeBaseId"]
|
||||
self._created_resources["knowledge_base"] = kb_id
|
||||
|
||||
# Wait for KB to be active (use asyncio.sleep to avoid blocking)
|
||||
verbose_logger.debug("Waiting for Knowledge Base to be active...")
|
||||
for _ in range(30):
|
||||
kb_status = bedrock_agent.get_knowledge_base(knowledgeBaseId=kb_id)
|
||||
status = kb_status["knowledgeBase"]["status"]
|
||||
if status == "ACTIVE":
|
||||
break
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
raise TimeoutError("Knowledge Base did not become active in time")
|
||||
|
||||
verbose_logger.info(f"Created Knowledge Base: {kb_id}")
|
||||
return kb_id
|
||||
|
||||
def _create_data_source(self, kb_name: str) -> str:
|
||||
"""Create Data Source for the Knowledge Base."""
|
||||
bedrock_agent = self._get_boto3_client("bedrock-agent")
|
||||
|
||||
verbose_logger.debug(f"Creating Data Source for KB: {self.knowledge_base_id}")
|
||||
|
||||
response = bedrock_agent.create_data_source(
|
||||
knowledgeBaseId=self.knowledge_base_id,
|
||||
name=f"{kb_name}-s3-source",
|
||||
dataSourceConfiguration={
|
||||
"type": "S3",
|
||||
"s3Configuration": {
|
||||
"bucketArn": f"arn:aws:s3:::{self.s3_bucket}",
|
||||
"inclusionPrefixes": [self.s3_prefix],
|
||||
},
|
||||
},
|
||||
)
|
||||
ds_id = response["dataSource"]["dataSourceId"]
|
||||
self._created_resources["data_source"] = ds_id
|
||||
|
||||
verbose_logger.info(f"Created Data Source: {ds_id}")
|
||||
return ds_id
|
||||
|
||||
def _get_boto3_client(self, service_name: str):
|
||||
"""Get a boto3 client for the specified service using BaseAWSLLM auth."""
|
||||
try:
|
||||
import boto3
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"boto3 is required for Bedrock ingestion. Install with: pip install boto3"
|
||||
)
|
||||
|
||||
# Get credentials using BaseAWSLLM's get_credentials method
|
||||
credentials = self.get_credentials(
|
||||
aws_access_key_id=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_access_key_id")
|
||||
),
|
||||
aws_secret_access_key=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_secret_access_key")
|
||||
),
|
||||
aws_session_token=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_session_token")
|
||||
),
|
||||
aws_region_name=self.aws_region_name,
|
||||
aws_session_name=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_session_name")
|
||||
),
|
||||
aws_profile_name=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_profile_name")
|
||||
),
|
||||
aws_role_name=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_role_name")
|
||||
),
|
||||
aws_web_identity_token=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_web_identity_token")
|
||||
),
|
||||
aws_sts_endpoint=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_sts_endpoint")
|
||||
),
|
||||
aws_external_id=_get_str_or_none(
|
||||
self.vector_store_config.get("aws_external_id")
|
||||
),
|
||||
)
|
||||
|
||||
# Create session with credentials
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials.access_key,
|
||||
aws_secret_access_key=credentials.secret_key,
|
||||
aws_session_token=credentials.token,
|
||||
region_name=self.aws_region_name,
|
||||
)
|
||||
|
||||
return session.client(service_name)
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
chunks: List[str],
|
||||
) -> Optional[List[List[float]]]:
|
||||
"""
|
||||
Bedrock handles embedding internally - skip this step.
|
||||
|
||||
Returns:
|
||||
None (Bedrock embeds when files are ingested)
|
||||
"""
|
||||
return None
|
||||
|
||||
async def store(
|
||||
self,
|
||||
file_content: Optional[bytes],
|
||||
filename: Optional[str],
|
||||
content_type: Optional[str],
|
||||
chunks: List[str],
|
||||
embeddings: Optional[List[List[float]]],
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Store content in Bedrock Knowledge Base.
|
||||
|
||||
Bedrock workflow:
|
||||
1. Auto-detect data source and S3 bucket (if not provided)
|
||||
2. Upload file to S3 bucket
|
||||
3. Start ingestion job
|
||||
4. (Optional) Wait for ingestion to complete
|
||||
|
||||
Args:
|
||||
file_content: Raw file bytes
|
||||
filename: Name of the file
|
||||
content_type: MIME type
|
||||
chunks: Ignored - Bedrock handles chunking
|
||||
embeddings: Ignored - Bedrock handles embedding
|
||||
|
||||
Returns:
|
||||
Tuple of (knowledge_base_id, file_key)
|
||||
"""
|
||||
# Auto-detect data source and S3 bucket if needed
|
||||
await self._ensure_config_initialized()
|
||||
|
||||
if not file_content or not filename:
|
||||
verbose_logger.warning(
|
||||
"No file content or filename provided for Bedrock ingestion"
|
||||
)
|
||||
return _get_str_or_none(self.knowledge_base_id), None
|
||||
|
||||
# Step 1: Upload file to S3
|
||||
s3_client = self._get_boto3_client("s3")
|
||||
s3_key = f"{self.s3_prefix.rstrip('/')}/{filename}"
|
||||
|
||||
verbose_logger.debug(f"Uploading file to s3://{self.s3_bucket}/{s3_key}")
|
||||
s3_client.put_object(
|
||||
Bucket=self.s3_bucket,
|
||||
Key=s3_key,
|
||||
Body=file_content,
|
||||
ContentType=content_type or "application/octet-stream",
|
||||
)
|
||||
verbose_logger.info(f"Uploaded file to s3://{self.s3_bucket}/{s3_key}")
|
||||
|
||||
# Step 2: Start ingestion job
|
||||
bedrock_agent = self._get_boto3_client("bedrock-agent")
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Starting ingestion job for KB={self.knowledge_base_id}, DS={self.data_source_id}"
|
||||
)
|
||||
ingestion_response = bedrock_agent.start_ingestion_job(
|
||||
knowledgeBaseId=self.knowledge_base_id,
|
||||
dataSourceId=self.data_source_id,
|
||||
)
|
||||
job_id = ingestion_response["ingestionJob"]["ingestionJobId"]
|
||||
verbose_logger.info(f"Started ingestion job: {job_id}")
|
||||
|
||||
# Step 3: Wait for ingestion (optional) - use asyncio.sleep to avoid blocking
|
||||
if self.wait_for_ingestion:
|
||||
import time as time_module
|
||||
|
||||
start_time = time_module.time()
|
||||
while time_module.time() - start_time < self.ingestion_timeout:
|
||||
job_status = bedrock_agent.get_ingestion_job(
|
||||
knowledgeBaseId=self.knowledge_base_id,
|
||||
dataSourceId=self.data_source_id,
|
||||
ingestionJobId=job_id,
|
||||
)
|
||||
status = job_status["ingestionJob"]["status"]
|
||||
verbose_logger.debug(f"Ingestion job {job_id} status: {status}")
|
||||
|
||||
if status == "COMPLETE":
|
||||
stats = job_status["ingestionJob"].get("statistics", {})
|
||||
verbose_logger.info(
|
||||
f"Ingestion complete: {stats.get('numberOfNewDocumentsIndexed', 0)} docs indexed"
|
||||
)
|
||||
break
|
||||
elif status == "FAILED":
|
||||
failure_reasons = job_status["ingestionJob"].get(
|
||||
"failureReasons", []
|
||||
)
|
||||
verbose_logger.error(f"Ingestion failed: {failure_reasons}")
|
||||
break
|
||||
elif status in ("STARTING", "IN_PROGRESS"):
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
verbose_logger.warning(f"Unknown ingestion status: {status}")
|
||||
break
|
||||
|
||||
return str(self.knowledge_base_id) if self.knowledge_base_id else None, s3_key
|
||||
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
File parsers for RAG ingestion.
|
||||
|
||||
Provides text extraction utilities for various file formats.
|
||||
"""
|
||||
|
||||
from .pdf_parser import extract_text_from_pdf
|
||||
|
||||
__all__ = ["extract_text_from_pdf"]
|
||||
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
PDF text extraction utilities.
|
||||
|
||||
Provides text extraction from PDF files using pypdf or PyPDF2.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
def extract_text_from_pdf(file_content: bytes) -> Optional[str]:
|
||||
"""
|
||||
Extract text from PDF using pypdf if available.
|
||||
|
||||
Args:
|
||||
file_content: Raw PDF bytes
|
||||
|
||||
Returns:
|
||||
Extracted text or None if extraction fails
|
||||
"""
|
||||
try:
|
||||
from io import BytesIO
|
||||
|
||||
# Try pypdf first (most common)
|
||||
try:
|
||||
from pypdf import PdfReader as PypdfReader
|
||||
|
||||
pdf_file = BytesIO(file_content)
|
||||
reader = PypdfReader(pdf_file)
|
||||
|
||||
text_parts = []
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
|
||||
if text_parts:
|
||||
extracted_text = "\n\n".join(text_parts)
|
||||
verbose_logger.debug(
|
||||
f"Extracted {len(extracted_text)} characters from PDF using pypdf"
|
||||
)
|
||||
return extracted_text
|
||||
|
||||
except ImportError:
|
||||
verbose_logger.debug("pypdf not available, trying PyPDF2")
|
||||
|
||||
# Fallback to PyPDF2
|
||||
try:
|
||||
from PyPDF2 import PdfReader as PyPDF2Reader
|
||||
|
||||
pdf_file = BytesIO(file_content)
|
||||
reader = PyPDF2Reader(pdf_file)
|
||||
|
||||
text_parts = []
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
|
||||
if text_parts:
|
||||
extracted_text = "\n\n".join(text_parts)
|
||||
verbose_logger.debug(
|
||||
f"Extracted {len(extracted_text)} characters from PDF using PyPDF2"
|
||||
)
|
||||
return extracted_text
|
||||
|
||||
except ImportError:
|
||||
verbose_logger.debug(
|
||||
"PyPDF2 not available, PDF extraction requires OCR or pypdf/PyPDF2 library"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"PDF text extraction failed: {e}")
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
Gemini-specific RAG Ingestion implementation.
|
||||
|
||||
Gemini handles embedding and chunking internally when files are uploaded to File Search stores,
|
||||
so this implementation skips the embedding step and directly uploads files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.llms.gemini.common_utils import GeminiModelInfo
|
||||
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Router
|
||||
from litellm.types.rag import RAGIngestOptions
|
||||
|
||||
|
||||
class GeminiRAGIngestion(BaseRAGIngestion):
|
||||
"""
|
||||
Gemini-specific RAG ingestion using File Search API.
|
||||
|
||||
Key differences from base:
|
||||
- Embedding is handled by Gemini when files are uploaded to File Search stores
|
||||
- Files are uploaded using uploadToFileSearchStore API
|
||||
- Chunking is done by Gemini's File Search (supports custom white_space_config)
|
||||
- Supports custom metadata attachment
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ingest_options: "RAGIngestOptions",
|
||||
router: Optional["Router"] = None,
|
||||
):
|
||||
super().__init__(ingest_options=ingest_options, router=router)
|
||||
self.model_info = GeminiModelInfo()
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
chunks: List[str],
|
||||
) -> Optional[List[List[float]]]:
|
||||
"""
|
||||
Gemini handles embedding internally - skip this step.
|
||||
|
||||
Returns:
|
||||
None (Gemini embeds when files are uploaded to File Search store)
|
||||
"""
|
||||
# Gemini handles embedding when files are uploaded to File Search stores
|
||||
return None
|
||||
|
||||
async def store(
|
||||
self,
|
||||
file_content: Optional[bytes],
|
||||
filename: Optional[str],
|
||||
content_type: Optional[str],
|
||||
chunks: List[str],
|
||||
embeddings: Optional[List[List[float]]],
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Store content in Gemini File Search store.
|
||||
|
||||
Gemini workflow:
|
||||
1. Create File Search store (if not provided)
|
||||
2. Upload file using uploadToFileSearchStore (Gemini handles chunking/embedding)
|
||||
|
||||
Args:
|
||||
file_content: Raw file bytes
|
||||
filename: Name of the file
|
||||
content_type: MIME type
|
||||
chunks: Ignored - Gemini handles chunking
|
||||
embeddings: Ignored - Gemini handles embedding
|
||||
|
||||
Returns:
|
||||
Tuple of (vector_store_id, file_id)
|
||||
"""
|
||||
vector_store_id = self.vector_store_config.get("vector_store_id")
|
||||
|
||||
vector_store_config = cast(Dict[str, Any], self.vector_store_config)
|
||||
|
||||
# Get API credentials
|
||||
api_key = (
|
||||
cast(Optional[str], vector_store_config.get("api_key"))
|
||||
or GeminiModelInfo.get_api_key()
|
||||
)
|
||||
api_base = (
|
||||
cast(Optional[str], vector_store_config.get("api_base"))
|
||||
or GeminiModelInfo.get_api_base()
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"GEMINI_API_KEY or GOOGLE_API_KEY is required for Gemini File Search"
|
||||
)
|
||||
|
||||
if not api_base:
|
||||
raise ValueError("GEMINI_API_BASE is required")
|
||||
|
||||
api_version = "v1beta"
|
||||
base_url = f"{api_base}/{api_version}"
|
||||
|
||||
# Create File Search store if not provided
|
||||
if not vector_store_id:
|
||||
vector_store_id = await self._create_file_search_store(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
display_name=self.ingest_name or "litellm-rag-ingest",
|
||||
)
|
||||
|
||||
# Upload file to File Search store
|
||||
result_file_id = None
|
||||
if file_content and filename and vector_store_id:
|
||||
result_file_id = await self._upload_to_file_search_store(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
vector_store_id=vector_store_id,
|
||||
filename=filename,
|
||||
file_content=file_content,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
return vector_store_id, result_file_id
|
||||
|
||||
async def _create_file_search_store(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
display_name: str,
|
||||
) -> str:
|
||||
"""
|
||||
Create a Gemini File Search store.
|
||||
|
||||
Args:
|
||||
api_key: Gemini API key
|
||||
base_url: Base URL for Gemini API
|
||||
display_name: Display name for the store
|
||||
|
||||
Returns:
|
||||
Store name (format: fileSearchStores/xxxxxxx)
|
||||
"""
|
||||
url = f"{base_url}/fileSearchStores?key={api_key}"
|
||||
|
||||
request_body = {"displayName": display_name}
|
||||
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.RAG,
|
||||
params={"timeout": 60.0},
|
||||
)
|
||||
response = await client.post(
|
||||
url,
|
||||
json=request_body,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"Failed to create File Search store: {response.text}"
|
||||
verbose_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
response_data = response.json()
|
||||
store_name = response_data.get("name", "")
|
||||
|
||||
verbose_logger.debug(f"Created File Search store: {store_name}")
|
||||
return store_name
|
||||
|
||||
async def _upload_to_file_search_store(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
vector_store_id: str,
|
||||
filename: str,
|
||||
file_content: bytes,
|
||||
content_type: Optional[str],
|
||||
) -> str:
|
||||
"""
|
||||
Upload a file to Gemini File Search store using resumable upload.
|
||||
|
||||
Args:
|
||||
api_key: Gemini API key
|
||||
base_url: Base URL for Gemini API
|
||||
vector_store_id: File Search store name
|
||||
filename: Name of the file
|
||||
file_content: File content bytes
|
||||
content_type: MIME type
|
||||
|
||||
Returns:
|
||||
File ID or document name
|
||||
"""
|
||||
# Step 1: Initiate resumable upload
|
||||
upload_url = await self._initiate_resumable_upload(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
vector_store_id=vector_store_id,
|
||||
filename=filename,
|
||||
file_size=len(file_content),
|
||||
content_type=content_type or "application/octet-stream",
|
||||
)
|
||||
|
||||
# Step 2: Upload the file content
|
||||
file_id = await self._upload_file_content(
|
||||
upload_url=upload_url,
|
||||
file_content=file_content,
|
||||
)
|
||||
|
||||
return file_id
|
||||
|
||||
async def _initiate_resumable_upload(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
vector_store_id: str,
|
||||
filename: str,
|
||||
file_size: int,
|
||||
content_type: str,
|
||||
) -> str:
|
||||
"""
|
||||
Initiate a resumable upload session.
|
||||
|
||||
Returns:
|
||||
Upload URL for the resumable session
|
||||
"""
|
||||
# Construct the upload URL - need to use the full upload endpoint
|
||||
# base_url is like: https://generativelanguage.googleapis.com/v1beta
|
||||
# We need: https://generativelanguage.googleapis.com/upload/v1beta/{store_id}:uploadToFileSearchStore
|
||||
api_base = base_url.replace("/v1beta", "") # Get base without version
|
||||
url = f"{api_base}/upload/v1beta/{vector_store_id}:uploadToFileSearchStore?key={api_key}"
|
||||
|
||||
# Build request body with chunking config and metadata if provided
|
||||
request_body: Dict[str, Any] = {"displayName": filename}
|
||||
|
||||
# Add chunking configuration if provided
|
||||
chunking_strategy = self.chunking_strategy
|
||||
if chunking_strategy and isinstance(chunking_strategy, dict):
|
||||
white_space_config = chunking_strategy.get("white_space_config")
|
||||
if white_space_config:
|
||||
request_body["chunkingConfig"] = {
|
||||
"whiteSpaceConfig": {
|
||||
"maxTokensPerChunk": white_space_config.get(
|
||||
"max_tokens_per_chunk", 800
|
||||
),
|
||||
"maxOverlapTokens": white_space_config.get(
|
||||
"max_overlap_tokens", 400
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
# Add custom metadata if provided in vector_store_config
|
||||
custom_metadata = cast(
|
||||
Optional[List[Dict[str, Any]]],
|
||||
self.vector_store_config.get("custom_metadata"),
|
||||
)
|
||||
if custom_metadata:
|
||||
request_body["customMetadata"] = custom_metadata
|
||||
|
||||
headers = {
|
||||
"X-Goog-Upload-Protocol": "resumable",
|
||||
"X-Goog-Upload-Command": "start",
|
||||
"X-Goog-Upload-Header-Content-Length": str(file_size),
|
||||
"X-Goog-Upload-Header-Content-Type": content_type,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
verbose_logger.debug(f"Initiating resumable upload: {url}")
|
||||
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.RAG,
|
||||
params={"timeout": 60.0},
|
||||
)
|
||||
response = await client.post(
|
||||
url,
|
||||
json=request_body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 201]:
|
||||
error_msg = f"Failed to initiate upload: {response.text}"
|
||||
verbose_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
verbose_logger.debug(f"Initiate resumable upload response: {response.headers}")
|
||||
# Extract upload URL from response headers
|
||||
upload_url = response.headers.get("x-goog-upload-url")
|
||||
if not upload_url:
|
||||
raise Exception("No upload URL returned in response headers")
|
||||
|
||||
verbose_logger.debug(f"Got upload URL: {upload_url}")
|
||||
return upload_url
|
||||
|
||||
async def _upload_file_content(
|
||||
self,
|
||||
upload_url: str,
|
||||
file_content: bytes,
|
||||
) -> str:
|
||||
"""
|
||||
Upload file content to the resumable upload URL.
|
||||
|
||||
Returns:
|
||||
File ID or document name from the response
|
||||
"""
|
||||
headers = {
|
||||
"Content-Length": str(len(file_content)),
|
||||
"X-Goog-Upload-Offset": "0",
|
||||
"X-Goog-Upload-Command": "upload, finalize",
|
||||
}
|
||||
|
||||
verbose_logger.debug(f"Uploading file content ({len(file_content)} bytes)")
|
||||
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.RAG,
|
||||
params={"timeout": 300.0}, # Longer timeout for large files
|
||||
)
|
||||
response = await client.put(
|
||||
upload_url,
|
||||
content=file_content,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 201]:
|
||||
error_msg = f"Failed to upload file: {response.text}"
|
||||
verbose_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# Parse response to get file/document ID
|
||||
try:
|
||||
response_data = response.json()
|
||||
# The response should contain the document name or file reference
|
||||
file_id = response_data.get("name", "") or response_data.get(
|
||||
"file", {}
|
||||
).get("name", "")
|
||||
verbose_logger.debug(f"Upload complete. File ID: {file_id}")
|
||||
return file_id
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Could not parse upload response: {e}")
|
||||
# Return a placeholder if we can't get the ID
|
||||
return "uploaded"
|
||||
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
OpenAI-specific RAG Ingestion implementation.
|
||||
|
||||
OpenAI handles embedding internally when files are attached to vector stores,
|
||||
so this implementation skips the embedding step and directly uploads files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import litellm
|
||||
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
from litellm.vector_store_files.main import acreate as vector_store_file_acreate
|
||||
from litellm.vector_stores.main import acreate as vector_store_acreate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Router
|
||||
from litellm.types.rag import RAGIngestOptions
|
||||
|
||||
|
||||
class OpenAIRAGIngestion(BaseRAGIngestion):
|
||||
"""
|
||||
OpenAI-specific RAG ingestion.
|
||||
|
||||
Key differences from base:
|
||||
- Embedding is handled by OpenAI when attaching files to vector stores
|
||||
- Files are uploaded and attached to vector stores directly
|
||||
- Chunking is done by OpenAI's vector store (uses 'auto' strategy)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ingest_options: "RAGIngestOptions",
|
||||
router: Optional["Router"] = None,
|
||||
):
|
||||
super().__init__(ingest_options=ingest_options, router=router)
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
chunks: List[str],
|
||||
) -> Optional[List[List[float]]]:
|
||||
"""
|
||||
OpenAI handles embedding internally - skip this step.
|
||||
|
||||
Returns:
|
||||
None (OpenAI embeds when files are attached to vector store)
|
||||
"""
|
||||
# OpenAI handles embedding when files are attached to vector stores
|
||||
return None
|
||||
|
||||
async def store(
|
||||
self,
|
||||
file_content: Optional[bytes],
|
||||
filename: Optional[str],
|
||||
content_type: Optional[str],
|
||||
chunks: List[str],
|
||||
embeddings: Optional[List[List[float]]],
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Store content in OpenAI vector store.
|
||||
|
||||
OpenAI workflow:
|
||||
1. Create vector store (if not provided)
|
||||
2. Upload file to OpenAI
|
||||
3. Attach file to vector store (OpenAI handles chunking/embedding)
|
||||
|
||||
Args:
|
||||
file_content: Raw file bytes
|
||||
filename: Name of the file
|
||||
content_type: MIME type
|
||||
chunks: Ignored - OpenAI handles chunking
|
||||
embeddings: Ignored - OpenAI handles embedding
|
||||
|
||||
Returns:
|
||||
Tuple of (vector_store_id, file_id)
|
||||
"""
|
||||
vector_store_id = self.vector_store_config.get("vector_store_id")
|
||||
ttl_days = self.vector_store_config.get("ttl_days")
|
||||
|
||||
# Get credentials from vector_store_config (loaded from litellm_credential_name if provided)
|
||||
api_key = self.vector_store_config.get("api_key")
|
||||
api_base = self.vector_store_config.get("api_base")
|
||||
|
||||
# Create vector store if not provided
|
||||
if not vector_store_id:
|
||||
expires_after = (
|
||||
{"anchor": "last_active_at", "days": ttl_days} if ttl_days else None
|
||||
)
|
||||
create_response = await vector_store_acreate(
|
||||
name=self.ingest_name or "litellm-rag-ingest",
|
||||
custom_llm_provider="openai",
|
||||
expires_after=expires_after,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
vector_store_id = create_response.get("id")
|
||||
|
||||
# Upload file and attach to vector store
|
||||
result_file_id = None
|
||||
if file_content and filename and vector_store_id:
|
||||
# Upload file to OpenAI
|
||||
file_response = await litellm.acreate_file(
|
||||
file=(
|
||||
filename,
|
||||
file_content,
|
||||
content_type or "application/octet-stream",
|
||||
),
|
||||
purpose="assistants",
|
||||
custom_llm_provider="openai",
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
result_file_id = file_response.id
|
||||
|
||||
# Attach file to vector store (OpenAI handles chunking/embedding)
|
||||
await vector_store_file_acreate(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=result_file_id,
|
||||
custom_llm_provider="openai",
|
||||
chunking_strategy=cast(
|
||||
Optional[Dict[str, Any]], self.chunking_strategy
|
||||
),
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
return vector_store_id, result_file_id
|
||||
@@ -0,0 +1,594 @@
|
||||
"""
|
||||
S3 Vectors-specific RAG Ingestion implementation.
|
||||
|
||||
S3 Vectors is AWS's native vector storage service that provides:
|
||||
- Purpose-built vector buckets for storing and querying vectors
|
||||
- Vector indexes with configurable dimensions and distance metrics
|
||||
- Metadata filtering for semantic search
|
||||
|
||||
This implementation:
|
||||
1. Auto-creates vector buckets and indexes if not provided
|
||||
2. Uses LiteLLM's embedding API (supports any provider)
|
||||
3. Uses httpx + AWS SigV4 signing (no boto3 dependency for S3 Vectors APIs)
|
||||
4. Stores vectors with metadata using PutVectors API
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import (
|
||||
S3_VECTORS_DEFAULT_DIMENSION,
|
||||
S3_VECTORS_DEFAULT_DISTANCE_METRIC,
|
||||
S3_VECTORS_DEFAULT_NON_FILTERABLE_METADATA_KEYS,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Router
|
||||
from litellm.types.rag import RAGIngestOptions
|
||||
|
||||
|
||||
class S3VectorsRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
|
||||
"""
|
||||
S3 Vectors RAG ingestion using httpx + AWS SigV4 signing.
|
||||
|
||||
Workflow:
|
||||
1. Auto-create vector bucket if needed (CreateVectorBucket API)
|
||||
2. Auto-create vector index if needed (CreateVectorIndex API)
|
||||
3. Generate embeddings using LiteLLM (supports any provider)
|
||||
4. Store vectors with PutVectors API
|
||||
|
||||
Configuration:
|
||||
- vector_bucket_name: S3 vector bucket name (required)
|
||||
- index_name: Vector index name (auto-creates if not provided)
|
||||
- dimension: Vector dimension (default: S3_VECTORS_DEFAULT_DIMENSION)
|
||||
- distance_metric: "cosine" or "euclidean" (default: S3_VECTORS_DEFAULT_DISTANCE_METRIC)
|
||||
- non_filterable_metadata_keys: List of metadata keys to exclude from filtering
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ingest_options: "RAGIngestOptions",
|
||||
router: Optional["Router"] = None,
|
||||
):
|
||||
BaseRAGIngestion.__init__(self, ingest_options=ingest_options, router=router)
|
||||
BaseAWSLLM.__init__(self)
|
||||
|
||||
# Extract config
|
||||
self.vector_bucket_name = self.vector_store_config["vector_bucket_name"]
|
||||
self.index_name = self.vector_store_config.get("index_name")
|
||||
self.distance_metric = self.vector_store_config.get(
|
||||
"distance_metric", S3_VECTORS_DEFAULT_DISTANCE_METRIC
|
||||
)
|
||||
self.non_filterable_metadata_keys = self.vector_store_config.get(
|
||||
"non_filterable_metadata_keys",
|
||||
S3_VECTORS_DEFAULT_NON_FILTERABLE_METADATA_KEYS,
|
||||
)
|
||||
|
||||
# Get dimension from config (will be auto-detected on first use if not provided)
|
||||
self.dimension = self._get_dimension_from_config()
|
||||
|
||||
# Get AWS region using BaseAWSLLM method
|
||||
_aws_region = self.vector_store_config.get("aws_region_name")
|
||||
self.aws_region_name = self.get_aws_region_name_for_non_llm_api_calls(
|
||||
aws_region_name=str(_aws_region) if _aws_region else None
|
||||
)
|
||||
|
||||
# Create httpx client (similar to s3_v2.py)
|
||||
ssl_verify = self._get_ssl_verify(
|
||||
ssl_verify=self.vector_store_config.get("ssl_verify")
|
||||
)
|
||||
self.async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.RAG,
|
||||
params={"ssl_verify": ssl_verify} if ssl_verify is not None else None,
|
||||
)
|
||||
|
||||
# Track if infrastructure is initialized
|
||||
self._config_initialized = False
|
||||
|
||||
async def _get_dimension_from_embedding_request(self) -> int:
|
||||
"""
|
||||
Auto-detect dimension by making a test embedding request.
|
||||
|
||||
Makes a single embedding request with a test string to determine
|
||||
the output dimension of the embedding model.
|
||||
"""
|
||||
if not self.embedding_config or "model" not in self.embedding_config:
|
||||
return S3_VECTORS_DEFAULT_DIMENSION
|
||||
|
||||
try:
|
||||
model_name = self.embedding_config["model"]
|
||||
verbose_logger.debug(
|
||||
f"Auto-detecting dimension by making test embedding request to {model_name}"
|
||||
)
|
||||
|
||||
# Make a test embedding request
|
||||
test_input = "test"
|
||||
if self.router:
|
||||
response = await self.router.aembedding(
|
||||
model=model_name, input=[test_input]
|
||||
)
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
model=model_name, input=[test_input]
|
||||
)
|
||||
|
||||
# Get dimension from the response
|
||||
if response.data and len(response.data) > 0:
|
||||
dimension = len(response.data[0]["embedding"])
|
||||
verbose_logger.debug(
|
||||
f"Auto-detected dimension {dimension} for embedding model {model_name}"
|
||||
)
|
||||
return dimension
|
||||
except Exception as e:
|
||||
verbose_logger.warning(
|
||||
f"Could not auto-detect dimension from embedding model: {e}. "
|
||||
f"Using default dimension of {S3_VECTORS_DEFAULT_DIMENSION}."
|
||||
)
|
||||
|
||||
return S3_VECTORS_DEFAULT_DIMENSION
|
||||
|
||||
def _get_dimension_from_config(self) -> Optional[int]:
|
||||
"""
|
||||
Get vector dimension from config if explicitly provided.
|
||||
|
||||
Returns None if dimension should be auto-detected.
|
||||
"""
|
||||
if "dimension" in self.vector_store_config:
|
||||
return int(self.vector_store_config["dimension"])
|
||||
return None
|
||||
|
||||
async def _ensure_config_initialized(self):
|
||||
"""Lazily initialize S3 Vectors infrastructure."""
|
||||
if self._config_initialized:
|
||||
return
|
||||
|
||||
# Auto-detect dimension if not provided
|
||||
if self.dimension is None:
|
||||
self.dimension = await self._get_dimension_from_embedding_request()
|
||||
|
||||
# Ensure vector bucket exists
|
||||
await self._ensure_vector_bucket_exists()
|
||||
|
||||
# Ensure vector index exists
|
||||
if not self.index_name:
|
||||
# Auto-generate index name
|
||||
unique_id = uuid.uuid4().hex[:8]
|
||||
self.index_name = f"litellm-index-{unique_id}"
|
||||
|
||||
await self._ensure_vector_index_exists()
|
||||
|
||||
self._config_initialized = True
|
||||
|
||||
async def _sign_and_execute_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
data: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Helper to sign and execute AWS API requests using httpx + SigV4.
|
||||
|
||||
Pattern from litellm/integrations/s3_v2.py
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Missing botocore to call S3 Vectors. Run 'pip install boto3'."
|
||||
)
|
||||
|
||||
# Get AWS credentials using BaseAWSLLM's get_credentials method
|
||||
credentials = self.get_credentials(
|
||||
aws_access_key_id=self.vector_store_config.get("aws_access_key_id"),
|
||||
aws_secret_access_key=self.vector_store_config.get("aws_secret_access_key"),
|
||||
aws_session_token=self.vector_store_config.get("aws_session_token"),
|
||||
aws_region_name=self.aws_region_name,
|
||||
aws_session_name=self.vector_store_config.get("aws_session_name"),
|
||||
aws_profile_name=self.vector_store_config.get("aws_profile_name"),
|
||||
aws_role_name=self.vector_store_config.get("aws_role_name"),
|
||||
aws_web_identity_token=self.vector_store_config.get(
|
||||
"aws_web_identity_token"
|
||||
),
|
||||
aws_sts_endpoint=self.vector_store_config.get("aws_sts_endpoint"),
|
||||
aws_external_id=self.vector_store_config.get("aws_external_id"),
|
||||
)
|
||||
|
||||
# Prepare headers
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
if data:
|
||||
headers["Content-Type"] = "application/json"
|
||||
# Calculate SHA256 hash of the content
|
||||
content_hash = hashlib.sha256(data.encode("utf-8")).hexdigest()
|
||||
headers["x-amz-content-sha256"] = content_hash
|
||||
else:
|
||||
# For requests without body, use hash of empty string
|
||||
headers["x-amz-content-sha256"] = hashlib.sha256(b"").hexdigest()
|
||||
|
||||
# Prepare the request
|
||||
req = requests.Request(method, url, data=data, headers=headers)
|
||||
prepped = req.prepare()
|
||||
|
||||
# Sign the request
|
||||
aws_request = AWSRequest(
|
||||
method=prepped.method,
|
||||
url=prepped.url,
|
||||
data=prepped.body,
|
||||
headers=prepped.headers,
|
||||
)
|
||||
SigV4Auth(credentials, "s3vectors", self.aws_region_name).add_auth(aws_request)
|
||||
|
||||
# Prepare the signed headers
|
||||
signed_headers = dict(aws_request.headers.items())
|
||||
|
||||
# Make the request using specific method (pattern from s3_v2.py)
|
||||
method_upper = method.upper()
|
||||
if method_upper == "PUT":
|
||||
response = await self.async_httpx_client.put(
|
||||
url, data=data, headers=signed_headers
|
||||
)
|
||||
elif method_upper == "POST":
|
||||
response = await self.async_httpx_client.post(
|
||||
url, data=data, headers=signed_headers
|
||||
)
|
||||
elif method_upper == "GET":
|
||||
response = await self.async_httpx_client.get(url, headers=signed_headers)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
return response
|
||||
|
||||
async def _ensure_vector_bucket_exists(self):
|
||||
"""Create vector bucket if it doesn't exist using GetVectorBucket and CreateVectorBucket APIs."""
|
||||
verbose_logger.debug(
|
||||
f"Ensuring S3 vector bucket exists: {self.vector_bucket_name}"
|
||||
)
|
||||
|
||||
# Validate bucket name (AWS S3 naming rules)
|
||||
if len(self.vector_bucket_name) < 3:
|
||||
raise ValueError(
|
||||
f"Invalid vector_bucket_name '{self.vector_bucket_name}': "
|
||||
f"AWS S3 bucket names must be at least 3 characters long. "
|
||||
f"Please provide a valid bucket name (e.g., 'my-vector-bucket')."
|
||||
)
|
||||
if not self.vector_bucket_name.replace("-", "").replace(".", "").isalnum():
|
||||
raise ValueError(
|
||||
f"Invalid vector_bucket_name '{self.vector_bucket_name}': "
|
||||
f"AWS S3 bucket names can only contain lowercase letters, numbers, hyphens, and periods. "
|
||||
f"Please provide a valid bucket name (e.g., 'my-vector-bucket')."
|
||||
)
|
||||
|
||||
# Try to get bucket info using GetVectorBucket API
|
||||
get_url = f"https://s3vectors.{self.aws_region_name}.api.aws/GetVectorBucket"
|
||||
get_body = safe_dumps({"vectorBucketName": self.vector_bucket_name})
|
||||
|
||||
try:
|
||||
response = await self._sign_and_execute_request(
|
||||
"POST", get_url, data=get_body
|
||||
)
|
||||
if response.status_code == 200:
|
||||
verbose_logger.debug(f"Vector bucket {self.vector_bucket_name} exists")
|
||||
return
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Bucket check failed (may not exist): {e}, attempting to create"
|
||||
)
|
||||
|
||||
# Create vector bucket using CreateVectorBucket API
|
||||
try:
|
||||
verbose_logger.debug(f"Creating vector bucket: {self.vector_bucket_name}")
|
||||
create_url = (
|
||||
f"https://s3vectors.{self.aws_region_name}.api.aws/CreateVectorBucket"
|
||||
)
|
||||
create_body = safe_dumps({"vectorBucketName": self.vector_bucket_name})
|
||||
|
||||
response = await self._sign_and_execute_request(
|
||||
"POST", create_url, data=create_body
|
||||
)
|
||||
|
||||
if response.status_code in (200, 201):
|
||||
verbose_logger.info(f"Created vector bucket: {self.vector_bucket_name}")
|
||||
elif response.status_code == 409:
|
||||
# Bucket already exists (ConflictException)
|
||||
verbose_logger.debug(
|
||||
f"Vector bucket {self.vector_bucket_name} already exists"
|
||||
)
|
||||
else:
|
||||
verbose_logger.error(
|
||||
f"CreateVectorBucket failed: {response.status_code} - {response.text}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error creating vector bucket: {e}")
|
||||
raise
|
||||
|
||||
async def _ensure_vector_index_exists(self):
|
||||
"""Create vector index if it doesn't exist using GetIndex and CreateIndex APIs."""
|
||||
verbose_logger.debug(
|
||||
f"Ensuring vector index exists: {self.vector_bucket_name}/{self.index_name}"
|
||||
)
|
||||
|
||||
# Try to get index info using GetIndex API
|
||||
get_url = f"https://s3vectors.{self.aws_region_name}.api.aws/GetIndex"
|
||||
get_body = safe_dumps(
|
||||
{"vectorBucketName": self.vector_bucket_name, "indexName": self.index_name}
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._sign_and_execute_request(
|
||||
"POST", get_url, data=get_body
|
||||
)
|
||||
if response.status_code == 200:
|
||||
verbose_logger.debug(f"Vector index {self.index_name} exists")
|
||||
return
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Index check failed (may not exist): {e}, attempting to create"
|
||||
)
|
||||
|
||||
# Create vector index using CreateIndex API
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
f"Creating vector index: {self.index_name} with dimension={self.dimension}, metric={self.distance_metric}"
|
||||
)
|
||||
|
||||
# Prepare index configuration per AWS API docs
|
||||
index_config = {
|
||||
"vectorBucketName": self.vector_bucket_name,
|
||||
"indexName": self.index_name,
|
||||
"dataType": "float32",
|
||||
"dimension": self.dimension,
|
||||
"distanceMetric": self.distance_metric,
|
||||
}
|
||||
|
||||
if self.non_filterable_metadata_keys:
|
||||
index_config["metadataConfiguration"] = {
|
||||
"nonFilterableMetadataKeys": self.non_filterable_metadata_keys
|
||||
}
|
||||
|
||||
create_url = f"https://s3vectors.{self.aws_region_name}.api.aws/CreateIndex"
|
||||
response = await self._sign_and_execute_request(
|
||||
"POST", create_url, data=safe_dumps(index_config)
|
||||
)
|
||||
|
||||
if response.status_code in (200, 201):
|
||||
verbose_logger.info(f"Created vector index: {self.index_name}")
|
||||
elif response.status_code == 409:
|
||||
verbose_logger.debug(f"Vector index {self.index_name} already exists")
|
||||
else:
|
||||
verbose_logger.error(
|
||||
f"CreateIndex failed: {response.status_code} - {response.text}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error creating vector index: {e}")
|
||||
raise
|
||||
|
||||
async def _put_vectors(self, vectors: List[Dict[str, Any]]):
|
||||
"""
|
||||
Call PutVectors API to store vectors in S3 Vectors.
|
||||
|
||||
Args:
|
||||
vectors: List of vector objects with keys: "key", "data", "metadata"
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
f"Storing {len(vectors)} vectors in {self.vector_bucket_name}/{self.index_name}"
|
||||
)
|
||||
|
||||
url = f"https://s3vectors.{self.aws_region_name}.api.aws/PutVectors"
|
||||
|
||||
# Prepare request body per AWS API docs
|
||||
request_body = {
|
||||
"vectorBucketName": self.vector_bucket_name,
|
||||
"indexName": self.index_name,
|
||||
"vectors": vectors,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._sign_and_execute_request(
|
||||
"POST", url, data=safe_dumps(request_body)
|
||||
)
|
||||
|
||||
if response.status_code in (200, 201):
|
||||
verbose_logger.info(
|
||||
f"Successfully stored {len(vectors)} vectors in index {self.index_name}"
|
||||
)
|
||||
else:
|
||||
verbose_logger.error(
|
||||
f"PutVectors failed with status {response.status_code}: {response.text}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error storing vectors: {e}")
|
||||
raise
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
chunks: List[str],
|
||||
) -> Optional[List[List[float]]]:
|
||||
"""
|
||||
Generate embeddings using LiteLLM's embedding API.
|
||||
|
||||
Supports any embedding provider (OpenAI, Bedrock, Cohere, etc.)
|
||||
"""
|
||||
if not chunks:
|
||||
return None
|
||||
|
||||
# Use embedding config from ingest_options or default
|
||||
if not self.embedding_config:
|
||||
verbose_logger.warning(
|
||||
"No embedding config provided, using default text-embedding-3-small"
|
||||
)
|
||||
self.embedding_config = {"model": "text-embedding-3-small"}
|
||||
|
||||
embedding_model = self.embedding_config.get("model", "text-embedding-3-small")
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Generating embeddings for {len(chunks)} chunks using {embedding_model}"
|
||||
)
|
||||
|
||||
# Convert to list to ensure type compatibility
|
||||
input_chunks: List[str] = list(chunks)
|
||||
|
||||
if self.router:
|
||||
response = await self.router.aembedding(
|
||||
model=embedding_model, input=input_chunks
|
||||
)
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
model=embedding_model, input=input_chunks
|
||||
)
|
||||
|
||||
return [item["embedding"] for item in response.data]
|
||||
|
||||
async def store(
|
||||
self,
|
||||
file_content: Optional[bytes],
|
||||
filename: Optional[str],
|
||||
content_type: Optional[str],
|
||||
chunks: List[str],
|
||||
embeddings: Optional[List[List[float]]],
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Store vectors in S3 Vectors using PutVectors API.
|
||||
|
||||
Steps:
|
||||
1. Ensure vector bucket exists (auto-create if needed)
|
||||
2. Ensure vector index exists (auto-create if needed)
|
||||
3. Prepare vector data with metadata
|
||||
4. Call PutVectors API with httpx + SigV4 signing
|
||||
|
||||
Args:
|
||||
file_content: Raw file bytes (not used for S3 Vectors)
|
||||
filename: Name of the file
|
||||
content_type: MIME type (not used for S3 Vectors)
|
||||
chunks: Text chunks
|
||||
embeddings: Vector embeddings
|
||||
|
||||
Returns:
|
||||
Tuple of (index_name, filename)
|
||||
"""
|
||||
# Ensure infrastructure exists
|
||||
await self._ensure_config_initialized()
|
||||
|
||||
if not embeddings or not chunks:
|
||||
error_msg = (
|
||||
"No text content could be extracted from the file for embedding. "
|
||||
"Possible causes:\n"
|
||||
" 1. PDF files require OCR - add 'ocr' config with a vision model (e.g., 'anthropic/claude-3-5-sonnet-20241022')\n"
|
||||
" 2. Binary files cannot be processed - convert to text first\n"
|
||||
" 3. File is empty or contains no extractable text\n"
|
||||
"For PDFs, either enable OCR or use a PDF extraction library to convert to text before ingestion."
|
||||
)
|
||||
verbose_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Prepare vectors for PutVectors API
|
||||
vectors = []
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
# Build metadata dict
|
||||
metadata: Dict[str, str] = {
|
||||
"source_text": chunk, # Non-filterable (for reference)
|
||||
"chunk_index": str(i), # Filterable
|
||||
}
|
||||
|
||||
if filename:
|
||||
metadata["filename"] = filename # Filterable
|
||||
|
||||
vector_obj = {
|
||||
"key": f"{filename}_{i}" if filename else f"chunk_{i}",
|
||||
"data": {"float32": embedding},
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
vectors.append(vector_obj)
|
||||
|
||||
# Call PutVectors API
|
||||
await self._put_vectors(vectors)
|
||||
|
||||
# Return vector_store_id in format bucket_name:index_name for S3 Vectors search compatibility
|
||||
vector_store_id = f"{self.vector_bucket_name}:{self.index_name}"
|
||||
return vector_store_id, filename
|
||||
|
||||
async def query_vector_store(
|
||||
self, vector_store_id: str, query: str, top_k: int = 5
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Query S3 Vectors using QueryVectors API.
|
||||
|
||||
Args:
|
||||
vector_store_id: Index name
|
||||
query: Query text
|
||||
top_k: Number of results to return
|
||||
|
||||
Returns:
|
||||
Query results with vectors and metadata
|
||||
"""
|
||||
verbose_logger.debug(f"Querying index {vector_store_id} with query: {query}")
|
||||
|
||||
# Generate query embedding
|
||||
if not self.embedding_config:
|
||||
self.embedding_config = {"model": "text-embedding-3-small"}
|
||||
|
||||
embedding_model = self.embedding_config.get("model", "text-embedding-3-small")
|
||||
|
||||
response = await litellm.aembedding(model=embedding_model, input=[query])
|
||||
query_embedding = response.data[0]["embedding"]
|
||||
|
||||
# Call QueryVectors API
|
||||
url = f"https://s3vectors.{self.aws_region_name}.api.aws/QueryVectors"
|
||||
|
||||
request_body = {
|
||||
"vectorBucketName": self.vector_bucket_name,
|
||||
"indexName": vector_store_id,
|
||||
"queryVector": {"float32": query_embedding},
|
||||
"topK": top_k,
|
||||
"returnDistance": True,
|
||||
"returnMetadata": True,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._sign_and_execute_request(
|
||||
"POST", url, data=safe_dumps(request_body)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
results = response.json()
|
||||
verbose_logger.debug(
|
||||
f"Query returned {len(results.get('vectors', []))} results"
|
||||
)
|
||||
|
||||
# Check if query terms appear in results
|
||||
if results.get("vectors"):
|
||||
for result in results["vectors"]:
|
||||
metadata = result.get("metadata", {})
|
||||
source_text = metadata.get("source_text", "")
|
||||
if query.lower() in source_text.lower():
|
||||
return results
|
||||
|
||||
# Return results even if exact match not found
|
||||
return results
|
||||
else:
|
||||
verbose_logger.error(
|
||||
f"QueryVectors failed with status {response.status_code}: {response.text}"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error querying vectors: {e}")
|
||||
return None
|
||||
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
Vertex AI-specific RAG Ingestion implementation.
|
||||
|
||||
Vertex AI RAG Engine handles embedding and chunking internally when files are uploaded,
|
||||
so this implementation skips the embedding step and directly uploads files to RAG corpora.
|
||||
|
||||
Based on: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/rag-api-v1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Router
|
||||
from litellm.types.rag import RAGIngestOptions
|
||||
|
||||
|
||||
class VertexAIRAGIngestion(BaseRAGIngestion, VertexBase):
|
||||
"""
|
||||
Vertex AI RAG Engine ingestion implementation.
|
||||
|
||||
Key differences from base:
|
||||
- Embedding is handled by Vertex AI RAG Engine when files are uploaded
|
||||
- Files are uploaded using the RAG API (import or upload)
|
||||
- Chunking is done by Vertex AI RAG Engine (supports custom chunking config)
|
||||
- Supports Google Cloud Storage (GCS) and Google Drive sources
|
||||
- Supports custom parsing configurations (layout parser, LLM parser)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ingest_options: "RAGIngestOptions",
|
||||
router: Optional["Router"] = None,
|
||||
):
|
||||
BaseRAGIngestion.__init__(self, ingest_options=ingest_options, router=router)
|
||||
VertexBase.__init__(self)
|
||||
|
||||
# Extract Vertex AI specific configs from vector_store_config
|
||||
litellm_params = dict(self.vector_store_config)
|
||||
|
||||
# Get project, location, and credentials using VertexBase methods
|
||||
self.project_id = self.safe_get_vertex_ai_project(litellm_params)
|
||||
self.location = self.get_vertex_ai_location(litellm_params) or "us-central1"
|
||||
self.vertex_credentials = self.safe_get_vertex_ai_credentials(litellm_params)
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
chunks: List[str],
|
||||
) -> Optional[List[List[float]]]:
|
||||
"""
|
||||
Vertex AI RAG Engine handles embedding internally - skip this step.
|
||||
|
||||
Returns:
|
||||
None (Vertex AI embeds when files are uploaded to RAG corpus)
|
||||
"""
|
||||
# Vertex AI RAG Engine handles embedding when files are uploaded
|
||||
return None
|
||||
|
||||
async def store(
|
||||
self,
|
||||
file_content: Optional[bytes],
|
||||
filename: Optional[str],
|
||||
content_type: Optional[str],
|
||||
chunks: List[str],
|
||||
embeddings: Optional[List[List[float]]],
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Store content in Vertex AI RAG corpus.
|
||||
|
||||
Vertex AI workflow:
|
||||
1. Create RAG corpus (if not provided)
|
||||
2. Upload file using RAG API (Vertex AI handles chunking/embedding)
|
||||
|
||||
Args:
|
||||
file_content: Raw file bytes
|
||||
filename: Name of the file
|
||||
content_type: MIME type
|
||||
chunks: Ignored - Vertex AI handles chunking
|
||||
embeddings: Ignored - Vertex AI handles embedding
|
||||
|
||||
Returns:
|
||||
Tuple of (rag_corpus_id, file_id)
|
||||
"""
|
||||
if not self.project_id:
|
||||
raise ValueError(
|
||||
"vertex_project is required for Vertex AI RAG ingestion. "
|
||||
"Set it in vector_store config."
|
||||
)
|
||||
|
||||
# Get or create RAG corpus
|
||||
rag_corpus_id = self.vector_store_config.get("vector_store_id")
|
||||
if not rag_corpus_id:
|
||||
rag_corpus_id = await self._create_rag_corpus(
|
||||
display_name=self.ingest_name or "litellm-rag-corpus",
|
||||
description=self.vector_store_config.get("description"),
|
||||
)
|
||||
|
||||
# Upload file to RAG corpus
|
||||
result_file_id = None
|
||||
if file_content and filename and rag_corpus_id:
|
||||
result_file_id = await self._upload_file_to_corpus(
|
||||
rag_corpus_id=rag_corpus_id,
|
||||
filename=filename,
|
||||
file_content=file_content,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
return rag_corpus_id, result_file_id
|
||||
|
||||
async def _create_rag_corpus(
|
||||
self,
|
||||
display_name: str,
|
||||
description: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a Vertex AI RAG corpus.
|
||||
|
||||
Args:
|
||||
display_name: Display name for the corpus
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
RAG corpus ID (format: projects/{project}/locations/{location}/ragCorpora/{corpus_id})
|
||||
"""
|
||||
# Get access token using VertexBase method
|
||||
access_token, project_id = self._ensure_access_token(
|
||||
credentials=self.vertex_credentials,
|
||||
project_id=self.project_id,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
# Use the project_id from token if not set
|
||||
if not self.project_id:
|
||||
self.project_id = project_id
|
||||
|
||||
# Construct URL using vertex base URL helper
|
||||
base_url = get_vertex_base_url(self.location)
|
||||
url = (
|
||||
f"{base_url}/v1beta1/"
|
||||
f"projects/{self.project_id}/locations/{self.location}/ragCorpora"
|
||||
)
|
||||
|
||||
# Build request body with camelCase keys (Vertex AI API format)
|
||||
request_body: Dict[str, Any] = {
|
||||
"displayName": display_name,
|
||||
}
|
||||
|
||||
if description:
|
||||
request_body["description"] = description
|
||||
|
||||
# Add vector database config if specified
|
||||
vector_db_config = self.vector_store_config.get("vector_db_config")
|
||||
if vector_db_config:
|
||||
request_body["vectorDbConfig"] = vector_db_config
|
||||
|
||||
# Add embedding model config if specified
|
||||
embedding_model = self.vector_store_config.get("embedding_model")
|
||||
if embedding_model:
|
||||
if "vectorDbConfig" not in request_body:
|
||||
request_body["vectorDbConfig"] = {}
|
||||
request_body["vectorDbConfig"]["ragEmbeddingModelConfig"] = {
|
||||
"vertexPredictionEndpoint": {"endpoint": embedding_model}
|
||||
}
|
||||
|
||||
verbose_logger.debug(f"Creating RAG corpus: {url}")
|
||||
verbose_logger.debug(f"Request body: {json.dumps(request_body, indent=2)}")
|
||||
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.RAG,
|
||||
params={"timeout": 60.0},
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
url,
|
||||
json=request_body,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
error_msg = f"Failed to create RAG corpus: {response.text}"
|
||||
verbose_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
response_data = response.json()
|
||||
verbose_logger.debug(
|
||||
f"Create corpus response: {json.dumps(response_data, indent=2)}"
|
||||
)
|
||||
|
||||
# The response is a long-running operation
|
||||
# Check if it's already done or if we need to poll
|
||||
if response_data.get("done"):
|
||||
# Operation completed immediately
|
||||
corpus_name = response_data.get("response", {}).get("name", "")
|
||||
else:
|
||||
# Need to poll the operation
|
||||
operation_name = response_data.get("name", "")
|
||||
verbose_logger.debug(f"Polling operation: {operation_name}")
|
||||
corpus_name = await self._poll_operation(
|
||||
operation_name=operation_name,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
verbose_logger.debug(f"Created RAG corpus: {corpus_name}")
|
||||
return corpus_name
|
||||
|
||||
async def _poll_operation(
|
||||
self,
|
||||
operation_name: str,
|
||||
access_token: str,
|
||||
max_retries: int = 30,
|
||||
retry_delay: float = 2.0,
|
||||
) -> str:
|
||||
"""
|
||||
Poll a long-running operation until it completes.
|
||||
|
||||
Args:
|
||||
operation_name: The operation name (e.g., "operations/123456")
|
||||
access_token: Access token for authentication
|
||||
max_retries: Maximum number of polling attempts
|
||||
retry_delay: Delay between polling attempts in seconds
|
||||
|
||||
Returns:
|
||||
The corpus name from the completed operation
|
||||
|
||||
Raises:
|
||||
Exception: If operation fails or times out
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
base_url = get_vertex_base_url(self.location)
|
||||
# Operation name is like: projects/{project}/locations/{location}/operations/{operation_id}
|
||||
# We need to construct the full URL
|
||||
url = f"{base_url}/v1beta1/{operation_name}"
|
||||
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.RAG,
|
||||
params={"timeout": 60.0},
|
||||
)
|
||||
|
||||
for attempt in range(max_retries):
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"Failed to poll operation: {response.text}"
|
||||
verbose_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
operation_data = response.json()
|
||||
|
||||
if operation_data.get("done"):
|
||||
# Check for errors
|
||||
if "error" in operation_data:
|
||||
error = operation_data["error"]
|
||||
raise Exception(f"Operation failed: {error}")
|
||||
|
||||
# Extract corpus name from response
|
||||
corpus_name = operation_data.get("response", {}).get("name", "")
|
||||
if corpus_name:
|
||||
return corpus_name
|
||||
else:
|
||||
raise Exception(
|
||||
f"No corpus name in operation response: {operation_data}"
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Operation not done yet, attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
raise Exception(f"Operation timed out after {max_retries} attempts")
|
||||
|
||||
async def _upload_file_to_corpus(
|
||||
self,
|
||||
rag_corpus_id: str,
|
||||
filename: str,
|
||||
file_content: bytes,
|
||||
content_type: Optional[str],
|
||||
) -> str:
|
||||
"""
|
||||
Upload a file to Vertex AI RAG corpus using multipart upload.
|
||||
|
||||
Args:
|
||||
rag_corpus_id: RAG corpus resource name
|
||||
filename: Name of the file
|
||||
file_content: File content bytes
|
||||
content_type: MIME type
|
||||
|
||||
Returns:
|
||||
File ID or resource name
|
||||
"""
|
||||
# Get access token using VertexBase method
|
||||
access_token, _ = self._ensure_access_token(
|
||||
credentials=self.vertex_credentials,
|
||||
project_id=self.project_id,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
# Construct upload URL using vertex base URL helper
|
||||
base_url = get_vertex_base_url(self.location)
|
||||
url = f"{base_url}/upload/v1beta1/" f"{rag_corpus_id}/ragFiles:upload"
|
||||
|
||||
# Build metadata for the file with snake_case keys (as per upload API docs)
|
||||
metadata: Dict[str, Any] = {
|
||||
"rag_file": {
|
||||
"display_name": filename,
|
||||
}
|
||||
}
|
||||
|
||||
# Add description if provided
|
||||
description = self.vector_store_config.get("file_description")
|
||||
if description:
|
||||
metadata["rag_file"]["description"] = description
|
||||
|
||||
# Add chunking configuration if provided
|
||||
chunking_strategy = self.chunking_strategy
|
||||
if chunking_strategy and isinstance(chunking_strategy, dict):
|
||||
chunk_size = chunking_strategy.get("chunk_size")
|
||||
chunk_overlap = chunking_strategy.get("chunk_overlap")
|
||||
|
||||
if chunk_size or chunk_overlap:
|
||||
if "upload_rag_file_config" not in metadata:
|
||||
metadata["upload_rag_file_config"] = {}
|
||||
|
||||
metadata["upload_rag_file_config"]["rag_file_transformation_config"] = {
|
||||
"rag_file_chunking_config": {"fixed_length_chunking": {}}
|
||||
}
|
||||
|
||||
chunking_config = metadata["upload_rag_file_config"][
|
||||
"rag_file_transformation_config"
|
||||
]["rag_file_chunking_config"]["fixed_length_chunking"]
|
||||
|
||||
if chunk_size:
|
||||
chunking_config["chunk_size"] = chunk_size
|
||||
if chunk_overlap:
|
||||
chunking_config["chunk_overlap"] = chunk_overlap
|
||||
|
||||
verbose_logger.debug(f"Uploading file to RAG corpus: {url}")
|
||||
verbose_logger.debug(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||
|
||||
# Prepare multipart form data
|
||||
files = {
|
||||
"metadata": (None, json.dumps(metadata), "application/json"),
|
||||
"file": (
|
||||
filename,
|
||||
file_content,
|
||||
content_type or "application/octet-stream",
|
||||
),
|
||||
}
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.RAG,
|
||||
params={"timeout": 300.0}, # Longer timeout for large files
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
url,
|
||||
files=files,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"X-Goog-Upload-Protocol": "multipart",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 201]:
|
||||
error_msg = f"Failed to upload file: {response.text}"
|
||||
verbose_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# Parse response to get file ID
|
||||
try:
|
||||
response_data = response.json()
|
||||
# The response should contain the rag_file resource name
|
||||
file_id = response_data.get("ragFile", {}).get("name", "")
|
||||
if not file_id:
|
||||
file_id = response_data.get("name", "")
|
||||
|
||||
verbose_logger.debug(f"Upload complete. File ID: {file_id}")
|
||||
return file_id
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Could not parse upload response: {e}")
|
||||
return "uploaded"
|
||||
|
||||
async def _import_files_from_gcs(
|
||||
self,
|
||||
rag_corpus_id: str,
|
||||
gcs_uris: List[str],
|
||||
) -> str:
|
||||
"""
|
||||
Import files from Google Cloud Storage into RAG corpus.
|
||||
|
||||
Args:
|
||||
rag_corpus_id: RAG corpus resource name
|
||||
gcs_uris: List of GCS URIs (e.g., ["gs://bucket/file.pdf"])
|
||||
|
||||
Returns:
|
||||
Operation name for tracking import progress
|
||||
"""
|
||||
# Get access token using VertexBase method
|
||||
access_token, _ = self._ensure_access_token(
|
||||
credentials=self.vertex_credentials,
|
||||
project_id=self.project_id,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
# Construct import URL using vertex base URL helper
|
||||
base_url = get_vertex_base_url(self.location)
|
||||
url = f"{base_url}/v1beta1/" f"{rag_corpus_id}/ragFiles:import"
|
||||
|
||||
# Build request body with camelCase keys (Vertex AI API format)
|
||||
request_body: Dict[str, Any] = {
|
||||
"importRagFilesConfig": {"gcsSource": {"uris": gcs_uris}}
|
||||
}
|
||||
|
||||
# Add chunking configuration if provided
|
||||
chunking_strategy = self.chunking_strategy
|
||||
if chunking_strategy and isinstance(chunking_strategy, dict):
|
||||
chunk_size = chunking_strategy.get("chunk_size")
|
||||
chunk_overlap = chunking_strategy.get("chunk_overlap")
|
||||
|
||||
if chunk_size or chunk_overlap:
|
||||
request_body["importRagFilesConfig"]["ragFileChunkingConfig"] = {
|
||||
"chunkSize": chunk_size or 1024,
|
||||
"chunkOverlap": chunk_overlap or 200,
|
||||
}
|
||||
|
||||
# Add max embedding requests per minute if specified
|
||||
max_embedding_qpm = self.vector_store_config.get(
|
||||
"max_embedding_requests_per_min"
|
||||
)
|
||||
if max_embedding_qpm:
|
||||
request_body["importRagFilesConfig"][
|
||||
"maxEmbeddingRequestsPerMin"
|
||||
] = max_embedding_qpm
|
||||
|
||||
verbose_logger.debug(f"Importing files from GCS: {url}")
|
||||
verbose_logger.debug(f"Request body: {json.dumps(request_body, indent=2)}")
|
||||
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.RAG,
|
||||
params={"timeout": 60.0},
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
url,
|
||||
json=request_body,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 201]:
|
||||
error_msg = f"Failed to import files: {response.text}"
|
||||
verbose_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
response_data = response.json()
|
||||
operation_name = response_data.get("name", "")
|
||||
|
||||
verbose_logger.debug(f"Import operation started: {operation_name}")
|
||||
return operation_name
|
||||
441
llm-gateway-competitors/litellm-wheel-src/litellm/rag/main.py
Normal file
441
llm-gateway-competitors/litellm-wheel-src/litellm/rag/main.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""
|
||||
RAG Ingest API for LiteLLM.
|
||||
|
||||
Provides an all-in-one API for document ingestion:
|
||||
Upload -> (OCR) -> Chunk -> Embed -> Vector Store
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["ingest", "aingest", "query", "aquery"]
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
|
||||
from litellm.rag.ingestion.gemini_ingestion import GeminiRAGIngestion
|
||||
from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion
|
||||
from litellm.rag.ingestion.s3_vectors_ingestion import S3VectorsRAGIngestion
|
||||
from litellm.rag.ingestion.vertex_ai_ingestion import VertexAIRAGIngestion
|
||||
from litellm.rag.rag_query import RAGQuery
|
||||
from litellm.types.rag import (
|
||||
RAGIngestOptions,
|
||||
RAGIngestResponse,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Router
|
||||
|
||||
|
||||
# Registry of provider-specific ingestion classes
|
||||
INGESTION_REGISTRY: Dict[str, Type[BaseRAGIngestion]] = {
|
||||
"openai": OpenAIRAGIngestion,
|
||||
"bedrock": BedrockRAGIngestion,
|
||||
"gemini": GeminiRAGIngestion,
|
||||
"s3_vectors": S3VectorsRAGIngestion,
|
||||
"vertex_ai": VertexAIRAGIngestion,
|
||||
}
|
||||
|
||||
|
||||
def get_ingestion_class(provider: str) -> Type[BaseRAGIngestion]:
|
||||
"""
|
||||
Get the ingestion class for a given provider.
|
||||
|
||||
Args:
|
||||
provider: The vector store provider name (e.g., 'openai')
|
||||
|
||||
Returns:
|
||||
The ingestion class for the provider
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
ingestion_class = INGESTION_REGISTRY.get(provider)
|
||||
if ingestion_class is None:
|
||||
supported = ", ".join(INGESTION_REGISTRY.keys())
|
||||
raise ValueError(
|
||||
f"Provider '{provider}' is not supported for RAG ingestion. "
|
||||
f"Supported providers: {supported}"
|
||||
)
|
||||
return ingestion_class
|
||||
|
||||
|
||||
async def _execute_ingest_pipeline(
|
||||
ingest_options: RAGIngestOptions,
|
||||
file_data: Optional[Tuple[str, bytes, str]] = None,
|
||||
file_url: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
router: Optional["Router"] = None,
|
||||
) -> RAGIngestResponse:
|
||||
"""
|
||||
Execute the RAG ingest pipeline using provider-specific implementation.
|
||||
|
||||
Args:
|
||||
ingest_options: Configuration for the ingest pipeline
|
||||
file_data: Tuple of (filename, content_bytes, content_type)
|
||||
file_url: URL to fetch file from
|
||||
file_id: Existing file ID to use
|
||||
router: Optional LiteLLM router for load balancing
|
||||
|
||||
Returns:
|
||||
RAGIngestResponse with status and IDs
|
||||
"""
|
||||
# Get provider from vector store config
|
||||
vector_store_config = ingest_options.get("vector_store") or {}
|
||||
provider = vector_store_config.get("custom_llm_provider", "openai")
|
||||
|
||||
# Get provider-specific ingestion class
|
||||
ingestion_class = get_ingestion_class(provider)
|
||||
|
||||
# Create ingestion instance
|
||||
ingestion = ingestion_class(
|
||||
ingest_options=ingest_options,
|
||||
router=router,
|
||||
)
|
||||
|
||||
# Execute ingestion pipeline
|
||||
return await ingestion.ingest(
|
||||
file_data=file_data,
|
||||
file_url=file_url,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
|
||||
####### PUBLIC API ###################
|
||||
|
||||
|
||||
@client
|
||||
async def aingest(
|
||||
ingest_options: Dict[str, Any],
|
||||
file_data: Optional[Tuple[str, bytes, str]] = None,
|
||||
file: Optional[Dict[str, str]] = None,
|
||||
file_url: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
**kwargs,
|
||||
) -> RAGIngestResponse:
|
||||
"""
|
||||
Async: Ingest a document into a vector store.
|
||||
|
||||
Args:
|
||||
ingest_options: Configuration for the ingest pipeline
|
||||
file_data: Tuple of (filename, content_bytes, content_type)
|
||||
file: Dict with {filename, content (base64), content_type} - for JSON API
|
||||
file_url: URL to fetch file from
|
||||
file_id: Existing file ID to use
|
||||
|
||||
Example:
|
||||
```python
|
||||
response = await litellm.aingest(
|
||||
ingest_options={
|
||||
"vector_store": {
|
||||
"custom_llm_provider": "openai",
|
||||
"litellm_credential_name": "my-openai-creds", # optional
|
||||
}
|
||||
},
|
||||
file_url="https://example.com/doc.pdf",
|
||||
)
|
||||
```
|
||||
"""
|
||||
local_vars = locals()
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["aingest"] = True
|
||||
|
||||
func = partial(
|
||||
ingest,
|
||||
ingest_options=ingest_options,
|
||||
file_data=file_data,
|
||||
file=file,
|
||||
file_url=file_url,
|
||||
file_id=file_id,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=None,
|
||||
custom_llm_provider=ingest_options.get("vector_store", {}).get(
|
||||
"custom_llm_provider"
|
||||
),
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def _execute_query_pipeline(
|
||||
model: str,
|
||||
messages: List[Any],
|
||||
retrieval_config: Dict[str, Any],
|
||||
rerank: Optional[Dict[str, Any]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Execute the RAG query pipeline.
|
||||
"""
|
||||
# Extract router from kwargs - use it for completion if available
|
||||
# to properly resolve virtual model names
|
||||
router: Optional["Router"] = kwargs.pop("router", None)
|
||||
|
||||
# 1. Extract query from last user message
|
||||
query_text = RAGQuery.extract_query_from_messages(messages)
|
||||
if not query_text:
|
||||
raise ValueError("No query found in messages for RAG query")
|
||||
|
||||
# 2. Search vector store
|
||||
search_response = await litellm.vector_stores.asearch(
|
||||
vector_store_id=retrieval_config["vector_store_id"],
|
||||
query=query_text,
|
||||
max_num_results=retrieval_config.get("top_k", 10),
|
||||
custom_llm_provider=retrieval_config.get("custom_llm_provider", "openai"),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
rerank_response = None
|
||||
context_chunks = search_response.get("data", [])
|
||||
|
||||
# 3. Optional rerank
|
||||
if rerank and rerank.get("enabled"):
|
||||
documents = RAGQuery.extract_documents_from_search(search_response)
|
||||
if documents:
|
||||
rerank_response = await litellm.arerank(
|
||||
model=rerank["model"],
|
||||
query=query_text,
|
||||
documents=documents,
|
||||
top_n=rerank.get("top_n", 5),
|
||||
)
|
||||
context_chunks = RAGQuery.get_top_chunks_from_rerank(
|
||||
search_response, rerank_response
|
||||
)
|
||||
|
||||
# 4. Build context message and call completion
|
||||
context_message = RAGQuery.build_context_message(context_chunks)
|
||||
modified_messages = messages[:-1] + [context_message] + [messages[-1]]
|
||||
|
||||
# Use router if available to properly resolve virtual model names
|
||||
if router is not None:
|
||||
response = await router.acompletion(
|
||||
model=model,
|
||||
messages=modified_messages,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
messages=modified_messages,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# 5. Attach search results to response
|
||||
if not stream and isinstance(response, ModelResponse):
|
||||
response = RAGQuery.add_search_results_to_response(
|
||||
response=response,
|
||||
search_results=search_response,
|
||||
rerank_results=rerank_response,
|
||||
)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
|
||||
@client
|
||||
async def aquery(
|
||||
model: str,
|
||||
messages: List[Any],
|
||||
retrieval_config: Dict[str, Any],
|
||||
rerank: Optional[Dict[str, Any]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Async: Query a RAG pipeline.
|
||||
"""
|
||||
local_vars = locals()
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["aquery"] = True
|
||||
|
||||
func = partial(
|
||||
query,
|
||||
model=model,
|
||||
messages=messages,
|
||||
retrieval_config=retrieval_config,
|
||||
rerank=rerank,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=model,
|
||||
custom_llm_provider=retrieval_config.get("custom_llm_provider"),
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
@client
|
||||
def query(
|
||||
model: str,
|
||||
messages: List[Any],
|
||||
retrieval_config: Dict[str, Any],
|
||||
rerank: Optional[Dict[str, Any]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[ModelResponse, Coroutine[Any, Any, ModelResponse]]:
|
||||
"""
|
||||
Query a RAG pipeline.
|
||||
"""
|
||||
local_vars = locals()
|
||||
try:
|
||||
_is_async = kwargs.pop("aquery", False) is True
|
||||
|
||||
if _is_async:
|
||||
return _execute_query_pipeline(
|
||||
model=model,
|
||||
messages=messages,
|
||||
retrieval_config=retrieval_config,
|
||||
rerank=rerank,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
_execute_query_pipeline(
|
||||
model=model,
|
||||
messages=messages,
|
||||
retrieval_config=retrieval_config,
|
||||
rerank=rerank,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=model,
|
||||
custom_llm_provider=retrieval_config.get("custom_llm_provider"),
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
@client
|
||||
def ingest(
|
||||
ingest_options: Dict[str, Any],
|
||||
file_data: Optional[Tuple[str, bytes, str]] = None,
|
||||
file: Optional[Dict[str, str]] = None,
|
||||
file_url: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
**kwargs,
|
||||
) -> Union[RAGIngestResponse, Coroutine[Any, Any, RAGIngestResponse]]:
|
||||
"""
|
||||
Ingest a document into a vector store.
|
||||
|
||||
Args:
|
||||
ingest_options: Configuration for the ingest pipeline
|
||||
file_data: Tuple of (filename, content_bytes, content_type)
|
||||
file: Dict with {filename, content (base64), content_type} - for JSON API
|
||||
file_url: URL to fetch file from
|
||||
file_id: Existing file ID to use
|
||||
|
||||
Example:
|
||||
```python
|
||||
response = litellm.ingest(
|
||||
ingest_options={
|
||||
"vector_store": {
|
||||
"custom_llm_provider": "openai",
|
||||
"litellm_credential_name": "my-openai-creds", # optional
|
||||
}
|
||||
},
|
||||
file_data=("doc.txt", b"Hello world", "text/plain"),
|
||||
)
|
||||
```
|
||||
"""
|
||||
import base64
|
||||
|
||||
local_vars = locals()
|
||||
try:
|
||||
_is_async = kwargs.pop("aingest", False) is True
|
||||
router: Optional["Router"] = kwargs.get("router")
|
||||
|
||||
# Convert file dict to file_data tuple if provided
|
||||
if file is not None and file_data is None:
|
||||
filename = file.get("filename", "document")
|
||||
content_b64 = file.get("content", "")
|
||||
content_type = file.get("content_type", "application/octet-stream")
|
||||
content_bytes = base64.b64decode(content_b64)
|
||||
file_data = (filename, content_bytes, content_type)
|
||||
|
||||
if _is_async:
|
||||
return _execute_ingest_pipeline(
|
||||
ingest_options=ingest_options, # type: ignore
|
||||
file_data=file_data,
|
||||
file_url=file_url,
|
||||
file_id=file_id,
|
||||
router=router,
|
||||
)
|
||||
else:
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
_execute_ingest_pipeline(
|
||||
ingest_options=ingest_options, # type: ignore
|
||||
file_data=file_data,
|
||||
file_url=file_url,
|
||||
file_id=file_id,
|
||||
router=router,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=None,
|
||||
custom_llm_provider=ingest_options.get("vector_store", {}).get(
|
||||
"custom_llm_provider"
|
||||
),
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
@@ -0,0 +1,120 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.types.vector_stores import (
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchResponse,
|
||||
)
|
||||
|
||||
|
||||
class RAGQuery:
|
||||
CONTENT_PREFIX_STRING = "Context:\n\n"
|
||||
|
||||
@staticmethod
|
||||
def extract_query_from_messages(messages: List[AllMessageValues]) -> Optional[str]:
|
||||
"""
|
||||
Extract the query from the last user message.
|
||||
"""
|
||||
if not messages or len(messages) == 0:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
if not isinstance(last_message, dict) or "content" not in last_message:
|
||||
return None
|
||||
|
||||
content = last_message["content"]
|
||||
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list) and len(content) > 0:
|
||||
# Handle list of content items, extract text from first text item
|
||||
for item in content:
|
||||
if (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "text"
|
||||
and "text" in item
|
||||
):
|
||||
return item["text"]
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def build_context_message(context_chunks: List[Any]) -> ChatCompletionUserMessage:
|
||||
"""
|
||||
Process search results and build a context message.
|
||||
"""
|
||||
context_content = RAGQuery.CONTENT_PREFIX_STRING
|
||||
|
||||
for chunk in context_chunks:
|
||||
if isinstance(chunk, dict):
|
||||
result_content: Optional[List[VectorStoreResultContent]] = chunk.get(
|
||||
"content"
|
||||
)
|
||||
if result_content:
|
||||
for content_item in result_content:
|
||||
content_text: Optional[str] = content_item.get("text")
|
||||
if content_text:
|
||||
context_content += content_text + "\n\n"
|
||||
elif "text" in chunk: # Fallback for simple dict with text
|
||||
context_content += chunk["text"] + "\n\n"
|
||||
elif isinstance(chunk, str):
|
||||
context_content += chunk + "\n\n"
|
||||
|
||||
return {
|
||||
"role": "user",
|
||||
"content": context_content,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def add_search_results_to_response(
|
||||
response: ModelResponse,
|
||||
search_results: VectorStoreSearchResponse,
|
||||
rerank_results: Optional[Any] = None,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Add search results to the response choices.
|
||||
"""
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
for choice in response.choices:
|
||||
message = getattr(choice, "message", None)
|
||||
if message is not None:
|
||||
# Get existing provider_specific_fields or create new dict
|
||||
provider_fields = (
|
||||
getattr(message, "provider_specific_fields", None) or {}
|
||||
)
|
||||
|
||||
# Add search results
|
||||
provider_fields["search_results"] = search_results
|
||||
if rerank_results:
|
||||
provider_fields["rerank_results"] = rerank_results
|
||||
|
||||
# Set the provider_specific_fields
|
||||
setattr(message, "provider_specific_fields", provider_fields)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def extract_documents_from_search(
|
||||
search_response: Any,
|
||||
) -> List[Union[str, Dict[str, Any]]]:
|
||||
"""Extract text documents from vector store search response."""
|
||||
documents: List[Union[str, Dict[str, Any]]] = []
|
||||
for result in search_response.get("data", []):
|
||||
content_list = result.get("content", [])
|
||||
for content in content_list:
|
||||
if content.get("type") == "text" and content.get("text"):
|
||||
documents.append(content["text"])
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_top_chunks_from_rerank(
|
||||
search_response: Any, rerank_response: Any
|
||||
) -> List[Any]:
|
||||
"""Get the original search results corresponding to the top reranked results."""
|
||||
top_chunks = []
|
||||
original_results = search_response.get("data", [])
|
||||
for result in rerank_response.get("results", []):
|
||||
index = result.get("index")
|
||||
if index is not None and index < len(original_results):
|
||||
top_chunks.append(original_results[index])
|
||||
return top_chunks
|
||||
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Text splitting utilities for RAG ingestion.
|
||||
"""
|
||||
|
||||
from litellm.rag.text_splitters.recursive_character_text_splitter import (
|
||||
RecursiveCharacterTextSplitter,
|
||||
)
|
||||
|
||||
__all__ = ["RecursiveCharacterTextSplitter"]
|
||||
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
RecursiveCharacterTextSplitter for RAG ingestion.
|
||||
|
||||
A simple implementation that splits text recursively by different separators.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitter:
|
||||
"""
|
||||
Split text recursively by different separators.
|
||||
|
||||
Tries to split by the first separator, then recursively splits
|
||||
by subsequent separators if chunks are still too large.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
||||
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
|
||||
separators: Optional[List[str]] = None,
|
||||
):
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
self.separators = separators or ["\n\n", "\n", " ", ""]
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split text into chunks."""
|
||||
return self._split_text(text, self.separators)
|
||||
|
||||
def _split_text(
|
||||
self, text: str, separators: List[str], depth: int = 0
|
||||
) -> List[str]:
|
||||
"""Recursively split text using separators."""
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
|
||||
if depth > DEFAULT_MAX_RECURSE_DEPTH:
|
||||
# Max depth reached, return text as-is split into chunk_size pieces
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
for i in range(0, len(text), self.chunk_size)
|
||||
]
|
||||
|
||||
final_chunks: List[str] = []
|
||||
|
||||
# Get the appropriate separator
|
||||
separator = separators[-1]
|
||||
new_separators: List[str] = []
|
||||
|
||||
for i, sep in enumerate(separators):
|
||||
if sep == "":
|
||||
separator = sep
|
||||
break
|
||||
if sep in text:
|
||||
separator = sep
|
||||
new_separators = separators[i + 1 :]
|
||||
break
|
||||
|
||||
# Split by the chosen separator
|
||||
if separator:
|
||||
splits = text.split(separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
|
||||
# Merge splits into chunks
|
||||
good_splits: List[str] = []
|
||||
for split in splits:
|
||||
if len(split) < self.chunk_size:
|
||||
good_splits.append(split)
|
||||
else:
|
||||
# Chunk is too big, merge what we have and recurse
|
||||
if good_splits:
|
||||
merged = self._merge_splits(good_splits, separator)
|
||||
final_chunks.extend(merged)
|
||||
good_splits = []
|
||||
|
||||
if new_separators:
|
||||
# Recursively split with finer separators
|
||||
other_chunks = self._split_text(split, new_separators, depth + 1)
|
||||
final_chunks.extend(other_chunks)
|
||||
else:
|
||||
# No more separators, force split
|
||||
final_chunks.extend(self._force_split(split))
|
||||
|
||||
# Merge remaining good splits
|
||||
if good_splits:
|
||||
merged = self._merge_splits(good_splits, separator)
|
||||
final_chunks.extend(merged)
|
||||
|
||||
return final_chunks
|
||||
|
||||
def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
|
||||
"""Merge splits into chunks respecting chunk_size and chunk_overlap."""
|
||||
chunks: List[str] = []
|
||||
current_chunk: List[str] = []
|
||||
current_length = 0
|
||||
|
||||
for split in splits:
|
||||
split_len = len(split)
|
||||
sep_len = len(separator) if current_chunk else 0
|
||||
|
||||
if current_length + split_len + sep_len > self.chunk_size:
|
||||
if current_chunk:
|
||||
chunk_text = separator.join(current_chunk).strip()
|
||||
if chunk_text:
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Handle overlap
|
||||
while (
|
||||
current_length > self.chunk_overlap and len(current_chunk) > 1
|
||||
):
|
||||
removed = current_chunk.pop(0)
|
||||
current_length -= len(removed) + len(separator)
|
||||
|
||||
current_chunk.append(split)
|
||||
current_length += split_len + sep_len
|
||||
|
||||
# Add remaining
|
||||
if current_chunk:
|
||||
chunk_text = separator.join(current_chunk).strip()
|
||||
if chunk_text:
|
||||
chunks.append(chunk_text)
|
||||
|
||||
return chunks
|
||||
|
||||
def _force_split(self, text: str) -> List[str]:
|
||||
"""Force split text by chunk_size when no separator works."""
|
||||
chunks: List[str] = []
|
||||
start = 0
|
||||
|
||||
while start < len(text):
|
||||
end = start + self.chunk_size
|
||||
chunk = text[start:end].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
start = end - self.chunk_overlap if end < len(text) else len(text)
|
||||
|
||||
return chunks
|
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
RAG utility functions.
|
||||
|
||||
Provides provider configuration utilities similar to ProviderConfigManager.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
|
||||
|
||||
def get_rag_ingestion_class(custom_llm_provider: str) -> Type["BaseRAGIngestion"]:
|
||||
"""
|
||||
Get the appropriate RAG ingestion class for a provider.
|
||||
|
||||
Args:
|
||||
custom_llm_provider: The LLM provider name (e.g., "openai", "bedrock", "vertex_ai")
|
||||
|
||||
Returns:
|
||||
The ingestion class for the provider
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not supported
|
||||
"""
|
||||
from litellm.llms.vertex_ai.rag_engine.ingestion import VertexAIRAGIngestion
|
||||
from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
|
||||
from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion
|
||||
|
||||
provider_map = {
|
||||
"openai": OpenAIRAGIngestion,
|
||||
"bedrock": BedrockRAGIngestion,
|
||||
"vertex_ai": VertexAIRAGIngestion,
|
||||
}
|
||||
|
||||
ingestion_class = provider_map.get(custom_llm_provider)
|
||||
if ingestion_class is None:
|
||||
raise ValueError(
|
||||
f"RAG ingestion not supported for provider: {custom_llm_provider}. "
|
||||
f"Supported providers: {list(provider_map.keys())}"
|
||||
)
|
||||
|
||||
return ingestion_class
|
||||
|
||||
|
||||
def get_rag_transformation_class(custom_llm_provider: str):
|
||||
"""
|
||||
Get the appropriate RAG transformation class for a provider.
|
||||
|
||||
Args:
|
||||
custom_llm_provider: The LLM provider name
|
||||
|
||||
Returns:
|
||||
The transformation class for the provider, or None if not needed
|
||||
"""
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
from litellm.llms.vertex_ai.rag_engine.transformation import (
|
||||
VertexAIRAGTransformation,
|
||||
)
|
||||
|
||||
return VertexAIRAGTransformation
|
||||
|
||||
# OpenAI and Bedrock don't need special transformations
|
||||
return None
|
||||
Reference in New Issue
Block a user