chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
"""
LiteLLM RAG (Retrieval Augmented Generation) Module.
Provides an all-in-one API for document ingestion:
Upload -> (OCR) -> Chunk -> Embed -> Vector Store
"""
from litellm.rag.main import aingest, aquery, ingest, query
__all__ = ["ingest", "aingest", "query", "aquery"]
# Expose at litellm.rag level for convenience
async def arag_ingest(*args, **kwargs):
    """Async convenience alias: forward all arguments to ``aingest`` and return its awaited result."""
    ingest_result = await aingest(*args, **kwargs)
    return ingest_result
def rag_ingest(*args, **kwargs):
    """Sync convenience alias: forward all arguments to ``ingest`` and return its result."""
    ingest_result = ingest(*args, **kwargs)
    return ingest_result

View File

@@ -0,0 +1,19 @@
"""
RAG Ingestion classes for different providers.
"""
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
from litellm.rag.ingestion.gemini_ingestion import GeminiRAGIngestion
from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion
from litellm.rag.ingestion.s3_vectors_ingestion import S3VectorsRAGIngestion
from litellm.rag.ingestion.vertex_ai_ingestion import VertexAIRAGIngestion
__all__ = [
"BaseRAGIngestion",
"BedrockRAGIngestion",
"GeminiRAGIngestion",
"OpenAIRAGIngestion",
"S3VectorsRAGIngestion",
"VertexAIRAGIngestion",
]

View File

@@ -0,0 +1,358 @@
"""
Base RAG Ingestion class.
Provides abstract methods for:
- OCR
- Chunking
- Embedding
- Vector Store operations
Providers can inherit and override methods as needed.
"""
from __future__ import annotations
import base64
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
import litellm
from litellm._logging import verbose_logger
from litellm._uuid import uuid4
from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.rag.ingestion.file_parsers import extract_text_from_pdf
from litellm.rag.text_splitters import RecursiveCharacterTextSplitter
from litellm.types.rag import RAGIngestOptions, RAGIngestResponse
if TYPE_CHECKING:
from litellm import Router
class BaseRAGIngestion(ABC):
"""
Base class for RAG ingestion.
Providers should inherit from this class and override methods as needed.
For example, OpenAI handles embedding internally when attaching files to
vector stores, so it overrides the embedding step to be a no-op.
"""
def __init__(
    self,
    ingest_options: RAGIngestOptions,
    router: Optional["Router"] = None,
):
    """
    Capture the ingest options and unpack the per-stage configuration.

    Args:
        ingest_options: Mapping of pipeline settings. The keys read here are
            "ocr", "chunking_strategy", "embedding", "vector_store" and "name".
        router: Optional Router; when set, the OCR and embedding steps call it
            instead of the module-level litellm functions.
    """
    self.ingest_options = ingest_options
    self.router = router
    # Each ingestion object gets its own identifier, echoed back in the response.
    self.ingest_id = f"ingest_{uuid4()}"

    # Per-stage settings pulled out of the options mapping.
    self.ocr_config = ingest_options.get("ocr")
    chunking = ingest_options.get("chunking_strategy") or {"type": "auto"}
    self.chunking_strategy: Dict[str, Any] = cast(Dict[str, Any], chunking)
    self.embedding_config = ingest_options.get("embedding")
    store_settings = ingest_options.get("vector_store") or {}
    self.vector_store_config: Dict[str, Any] = cast(Dict[str, Any], store_settings)
    self.ingest_name = ingest_options.get("name")

    # Merge in a named credential if the vector_store config references one.
    self._load_credentials_from_config()
def _load_credentials_from_config(self) -> None:
    """
    Merge a named credential into ``self.vector_store_config``.

    Reads "litellm_credential_name" from the vector_store config. When it is
    set and ``litellm.credential_list`` is non-empty, the credential's values
    are looked up through ``CredentialAccessor`` and copied into the config.
    Keys the config already contains are left untouched.
    """
    from litellm.litellm_core_utils.credential_accessor import CredentialAccessor

    credential_name = self.vector_store_config.get("litellm_credential_name")
    if not (credential_name and litellm.credential_list):
        return

    resolved_values = CredentialAccessor.get_credential_values(credential_name)
    for key, value in resolved_values.items():
        # setdefault only writes when the key is absent, so explicit config wins.
        self.vector_store_config.setdefault(key, value)
@property
def custom_llm_provider(self) -> str:
    """Return the provider named in the vector_store config, or "openai" when none is set."""
    provider = self.vector_store_config.get("custom_llm_provider", "openai")
    return provider
async def upload(
    self,
    file_data: Optional[Tuple[str, bytes, str]] = None,
    file_url: Optional[str] = None,
    file_id: Optional[str] = None,
) -> Tuple[Optional[str], Optional[bytes], Optional[str], Optional[str]]:
    """
    Upload / prepare file for ingestion.

    The sources are checked in order: ``file_data``, then ``file_url``, then
    ``file_id``; the first truthy one wins.

    Args:
        file_data: Tuple of (filename, content_bytes, content_type)
        file_url: URL to fetch file from
        file_id: Existing file ID to use

    Returns:
        Tuple of (filename, file_content, content_type, existing_file_id).
        For ``file_id`` input only the last element is set.

    Raises:
        ValueError: If none of the three sources is provided.
    """
    if file_data:
        filename, file_content, content_type = file_data
        return filename, file_content, content_type, None

    if file_url:
        # Imported locally so this method does not depend on file-level imports.
        import posixpath
        from urllib.parse import urlsplit

        http_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.RAG)
        response = await http_client.get(file_url)
        response.raise_for_status()
        file_content = response.content
        # Derive the name from the URL *path* only, so query strings and
        # fragments (e.g. signed-URL tokens) never end up in the filename.
        filename = posixpath.basename(urlsplit(file_url).path) or "document"
        content_type = response.headers.get(
            "content-type", "application/octet-stream"
        )
        return filename, file_content, content_type, None

    if file_id:
        return None, None, None, file_id

    raise ValueError("Must provide file_data, file_url, or file_id")
async def ocr(
    self,
    file_content: Optional[bytes],
    content_type: Optional[str],
) -> Optional[str]:
    """
    Perform OCR on file content to extract text.

    The file is sent as a base64 data URL, through ``self.router.aocr`` when a
    router is set and through ``litellm.aocr`` otherwise.

    Args:
        file_content: Raw file bytes
        content_type: MIME type of the file; when missing, the data URL falls
            back to "application/octet-stream".

    Returns:
        The pages' markdown joined by blank lines, or None if OCR is not
        configured, there is no content, or the response has no pages.
    """
    if not self.ocr_config or not file_content:
        return None

    ocr_model = self.ocr_config.get("model", "mistral/mistral-ocr-latest")

    # Images and documents use different document types / URL keys.
    is_image = bool(content_type) and "image" in content_type
    url_key = "image_url" if is_image else "document_url"

    # A missing content type used to yield "data:None;base64,..."; use a
    # generic MIME type instead so the data URL stays well-formed.
    mime_type = content_type or "application/octet-stream"
    b64_content = base64.b64encode(file_content).decode("utf-8")
    document = {"type": url_key, url_key: f"data:{mime_type};base64,{b64_content}"}

    if self.router is not None:
        ocr_response = await self.router.aocr(model=ocr_model, document=document)
    else:
        ocr_response = await litellm.aocr(model=ocr_model, document=document)

    pages = getattr(ocr_response, "pages", None)
    if not pages:
        return None

    # Skip pages without markdown (attribute missing or None) so join cannot fail.
    page_texts = [
        page.markdown
        for page in pages
        if getattr(page, "markdown", None) is not None
    ]
    return "\n\n".join(page_texts)
def chunk(
    self,
    text: Optional[str],
    file_content: Optional[bytes],
    ocr_was_used: bool,
) -> List[str]:
    """
    Produce text chunks with RecursiveCharacterTextSplitter.

    OCR text is used when present. Otherwise, and only when OCR was not used,
    the raw bytes are decoded as UTF-8; bytes that fail to decode are handed
    to the PDF extractor if they start with the ``%PDF`` marker.

    Args:
        text: Text from OCR (if used)
        file_content: Raw file content bytes
        ocr_was_used: Whether OCR was performed

    Returns:
        List of text chunks; empty when no text could be obtained.
    """
    source_text: Optional[str] = text if text else None

    if source_text is None and file_content and not ocr_was_used:
        try:
            source_text = file_content.decode("utf-8")
        except UnicodeDecodeError:
            # Not valid UTF-8: only PDFs get a second chance via text extraction.
            if not file_content.startswith(b"%PDF"):
                verbose_logger.debug("Binary file detected, skipping text chunking")
                return []
            verbose_logger.debug("PDF detected, attempting text extraction")
            source_text = extract_text_from_pdf(file_content)
            if not source_text:
                verbose_logger.debug(
                    "PDF text extraction failed. Install 'pypdf' or 'PyPDF2' for PDF support, "
                    "or enable OCR with a vision model."
                )
                return []

    if not source_text:
        return []

    # Splitter settings come from the chunking strategy, with library defaults.
    strategy = self.chunking_strategy or {}
    splitter_kwargs: Dict[str, Any] = {
        "chunk_size": strategy.get("chunk_size", DEFAULT_CHUNK_SIZE),
        "chunk_overlap": strategy.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP),
    }
    custom_separators = strategy.get("separators", None)
    if custom_separators:
        splitter_kwargs["separators"] = custom_separators

    return RecursiveCharacterTextSplitter(**splitter_kwargs).split_text(source_text)
async def embed(
    self,
    chunks: List[str],
) -> Optional[List[List[float]]]:
    """
    Embed the text chunks with the configured embedding model.

    Args:
        chunks: List of text chunks

    Returns:
        One embedding per item in the response data, or None when no
        embedding config is set or there are no chunks.
    """
    config = self.embedding_config
    if not (config and chunks):
        return None

    model_name = config.get("model", "text-embedding-3-small")
    # Prefer the router when one was supplied; otherwise call litellm directly.
    embed_call = litellm.aembedding if self.router is None else self.router.aembedding
    response = await embed_call(model=model_name, input=chunks)

    vectors = []
    for item in response.data:
        vectors.append(item["embedding"])
    return vectors
@abstractmethod
async def store(
    self,
    file_content: Optional[bytes],
    filename: Optional[str],
    content_type: Optional[str],
    chunks: List[str],
    embeddings: Optional[List[List[float]]],
) -> Tuple[Optional[str], Optional[str]]:
    """
    Persist the ingested content in the provider's vector store.

    Abstract hook: each provider-specific subclass supplies the implementation.

    Args:
        file_content: Raw file bytes
        filename: Name of the file
        content_type: MIME type
        chunks: Text chunks (if chunking was done locally)
        embeddings: Embeddings (if embedding was done locally)

    Returns:
        Tuple of (vector_store_id, file_id)
    """
async def ingest(
    self,
    file_data: Optional[Tuple[str, bytes, str]] = None,
    file_url: Optional[str] = None,
    file_id: Optional[str] = None,
) -> RAGIngestResponse:
    """
    Run the whole pipeline: upload, OCR, chunk, embed, store.

    Args:
        file_data: Tuple of (filename, content_bytes, content_type)
        file_url: URL to fetch file from
        file_id: Existing file ID to use

    Returns:
        RAGIngestResponse with status "completed" and the resulting IDs, or
        status "failed" plus the error text when any step after upload raises.

    Raises:
        ValueError: If no input source is provided (raised by ``upload``,
            which runs outside the error-capturing block).
    """
    # Input resolution is deliberately outside the try block, so its errors
    # propagate to the caller instead of becoming a "failed" response.
    upload_result = await self.upload(
        file_data=file_data,
        file_url=file_url,
        file_id=file_id,
    )
    filename, file_content, content_type, existing_file_id = upload_result

    try:
        # OCR is optional and yields None when not configured.
        extracted_text = await self.ocr(
            file_content=file_content,
            content_type=content_type,
        )

        ocr_enabled = self.ocr_config is not None
        chunks = self.chunk(
            text=extracted_text,
            file_content=file_content,
            ocr_was_used=ocr_enabled,
        )

        # Embedding is optional; providers may override it to return None.
        embeddings = await self.embed(chunks=chunks)

        vector_store_id, result_file_id = await self.store(
            file_content=file_content,
            filename=filename,
            content_type=content_type,
            chunks=chunks,
            embeddings=embeddings,
        )

        # Fall back to the caller-supplied file ID when the store returned none.
        return RAGIngestResponse(
            id=self.ingest_id,
            status="completed",
            vector_store_id=vector_store_id or "",
            file_id=result_file_id or existing_file_id,
        )
    except Exception as e:
        verbose_logger.exception(f"RAG Pipeline failed: {e}")
        return RAGIngestResponse(
            id=self.ingest_id,
            status="failed",
            vector_store_id="",
            file_id=None,
            error=str(e),
        )

View File

@@ -0,0 +1,775 @@
"""
Bedrock-specific RAG Ingestion implementation.
Bedrock Knowledge Bases handle embedding internally when files are ingested,
so this implementation uploads files to S3 and triggers ingestion jobs.
Supports two modes:
1. Use existing KB: Provide vector_store_id (KB ID)
2. Auto-create KB: Don't provide vector_store_id - creates all AWS resources automatically
"""
from __future__ import annotations
import asyncio
import json
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
if TYPE_CHECKING:
from litellm import Router
from litellm.types.rag import RAGIngestOptions
def _get_str_or_none(value: Any) -> Optional[str]:
    """Return ``str(value)``, passing None through unchanged."""
    if value is None:
        return None
    return str(value)
def _get_int(value: Any, default: int) -> int:
    """Return ``int(value)``, or ``default`` when the value is None."""
    return default if value is None else int(value)
def _normalize_principal_arn(caller_arn: str, account_id: str) -> str:
    """
    Normalize a caller ARN to the format required by OpenSearch data access policies.

    OpenSearch Serverless data access policies require:
    - IAM users: arn:aws:iam::account-id:user/user-name
    - IAM roles: arn:aws:iam::account-id:role/role-name

    But get_caller_identity() returns for assumed roles:
    - arn:aws:sts::account-id:assumed-role/role-name/session-name

    Assumed-role ARNs are converted to the IAM role ARN form. The partition
    of the caller ARN (e.g. "aws", "aws-cn", "aws-us-gov") is preserved rather
    than assumed to be "aws". Any other ARN is returned unchanged.

    Args:
        caller_arn: ARN as reported for the calling identity.
        account_id: AWS account ID to place in the resulting role ARN.

    Returns:
        The IAM role ARN for assumed-role callers, otherwise ``caller_arn``.
    """
    if ":assumed-role/" not in caller_arn:
        return caller_arn

    # Format: arn:PARTITION:sts::ACCOUNT:assumed-role/ROLE-NAME/SESSION-NAME
    # The marker contains "/", so the split always yields at least two parts.
    role_name = caller_arn.split("/", 2)[1]

    # Keep the caller's partition; fall back to "aws" for unexpected shapes.
    arn_fields = caller_arn.split(":")
    partition = "aws"
    if len(arn_fields) >= 6 and arn_fields[0] == "arn" and arn_fields[1]:
        partition = arn_fields[1]

    return f"arn:{partition}:iam::{account_id}:role/{role_name}"
class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
"""
Bedrock Knowledge Base RAG ingestion.
Supports two modes:
1. **Use existing KB**: Provide vector_store_id
2. **Auto-create KB**: Don't provide vector_store_id - creates S3 bucket,
OpenSearch Serverless collection, IAM role, KB, and data source automatically
Optional config:
- vector_store_id: Existing KB ID (if not provided, auto-creates)
- s3_bucket: S3 bucket (auto-created if not provided)
- embedding_model: Bedrock embedding model (default: amazon.titan-embed-text-v2:0)
- wait_for_ingestion: Wait for completion (default: False)
- ingestion_timeout: Max seconds to wait (default: 300)
AWS Auth (uses BaseAWSLLM):
- aws_access_key_id, aws_secret_access_key, aws_session_token
- aws_region_name (default: us-west-2)
- aws_role_name, aws_session_name, aws_profile_name
- aws_web_identity_token, aws_sts_endpoint, aws_external_id
"""
def __init__(
    self,
    ingest_options: "RAGIngestOptions",
    router: Optional["Router"] = None,
):
    """
    Initialise both base classes and read the Bedrock settings from the
    vector_store config.

    Args:
        ingest_options: Pipeline options; Bedrock settings are read from the
            vector_store section that the base class exposes as
            ``self.vector_store_config``.
        router: Optional Router forwarded to the base ingestion class.
    """
    BaseRAGIngestion.__init__(self, ingest_options=ingest_options, router=router)
    BaseAWSLLM.__init__(self)

    config = self.vector_store_config

    # "vector_store_id" is the unified name; "knowledge_base_id" is also accepted.
    self.knowledge_base_id = config.get("vector_store_id") or config.get(
        "knowledge_base_id"
    )

    # User-supplied overrides, kept separate from the resolved values below.
    self._data_source_id = config.get("data_source_id")
    self._s3_bucket = config.get("s3_bucket")
    configured_prefix = config.get("s3_prefix")
    self._s3_prefix: Optional[str] = (
        str(configured_prefix) if configured_prefix else None
    )

    self.embedding_model = (
        config.get("embedding_model") or "amazon.titan-embed-text-v2:0"
    )
    self.wait_for_ingestion = config.get("wait_for_ingestion", False)
    self.ingestion_timeout: int = _get_int(config.get("ingestion_timeout"), 300)

    # Region is resolved through the BaseAWSLLM helper.
    configured_region = config.get("aws_region_name")
    region_arg = str(configured_region) if configured_region else None
    self.aws_region_name = self.get_aws_region_name_for_non_llm_api_calls(
        aws_region_name=region_arg
    )

    # Resolved lazily, on first use.
    self.data_source_id: Optional[str] = None
    self.s3_bucket: Optional[str] = None
    self.s3_prefix: str = self._s3_prefix or "data/"
    self._config_initialized = False

    # Names/IDs of resources this instance creates (for cleanup if needed).
    self._created_resources: Dict[str, Any] = {}
async def _ensure_config_initialized(self):
    """
    Resolve the Knowledge Base configuration once.

    With a knowledge base ID, the data source and S3 bucket are detected from
    the existing KB; without one, all resources are created. Later calls are
    no-ops.
    """
    if self._config_initialized:
        return

    if not self.knowledge_base_id:
        # Nothing to attach to: build the KB and its resources.
        await self._create_knowledge_base_infrastructure()
    else:
        # Existing KB: look up its data source and bucket.
        self._auto_detect_config()

    self._config_initialized = True
def _auto_detect_config(self):
    """
    Fill in ``data_source_id``, ``s3_bucket`` and ``s3_prefix`` from an
    existing Knowledge Base.

    User-supplied values (data_source_id, s3_bucket, s3_prefix) take
    precedence over detected ones.

    Raises:
        ValueError: If the KB lists no data sources, or the data source has no
            bucket ARN and no s3_bucket was configured.
    """
    verbose_logger.debug(
        f"Auto-detecting data source and S3 bucket for KB={self.knowledge_base_id}"
    )

    bedrock_agent = self._get_boto3_client("bedrock-agent")

    listing = bedrock_agent.list_data_sources(knowledgeBaseId=self.knowledge_base_id)
    data_sources = listing.get("dataSourceSummaries", [])
    if not data_sources:
        raise ValueError(
            f"No data sources found for Knowledge Base {self.knowledge_base_id}. "
            "Please create a data source first or provide data_source_id and s3_bucket."
        )

    # A configured data source ID wins; otherwise take the first one listed.
    if not self._data_source_id:
        self.data_source_id = data_sources[0]["dataSourceId"]
        verbose_logger.info(f"Auto-detected data source: {self.data_source_id}")
    else:
        self.data_source_id = self._data_source_id

    # The data source details carry the S3 configuration.
    ds_details = bedrock_agent.get_data_source(
        knowledgeBaseId=self.knowledge_base_id,
        dataSourceId=self.data_source_id,
    )
    data_source = ds_details.get("dataSource", {})
    source_configuration = data_source.get("dataSourceConfiguration", {})
    s3_config = source_configuration.get("s3Configuration", {})

    bucket_arn = s3_config.get("bucketArn", "")
    if not bucket_arn:
        # Nothing to detect from: a configured bucket is mandatory here.
        if not self._s3_bucket:
            raise ValueError(
                f"Could not auto-detect S3 bucket for data source {self.data_source_id}. "
                "Please provide s3_bucket in config."
            )
        self.s3_bucket = self._s3_bucket
        return

    # Bucket name is the last colon-separated field: arn:aws:s3:::bucket-name
    self.s3_bucket = self._s3_bucket or bucket_arn.rsplit(":", 1)[-1]
    verbose_logger.info(f"Auto-detected S3 bucket: {self.s3_bucket}")

    # Adopt the first inclusion prefix unless the user configured one.
    prefixes = s3_config.get("inclusionPrefixes", [])
    if prefixes and not self._s3_prefix:
        self.s3_prefix = prefixes[0]
async def _create_knowledge_base_infrastructure(self):
    """
    Provision everything a new Knowledge Base needs, in order: S3 bucket
    (unless one was configured), OpenSearch Serverless collection, vector
    index, IAM role, Knowledge Base, data source.

    On success ``self.s3_bucket``, ``self.knowledge_base_id`` and
    ``self.data_source_id`` are set.
    """
    verbose_logger.info("Creating new Bedrock Knowledge Base infrastructure...")

    # A short random suffix keeps generated resource names unique.
    unique_id = uuid.uuid4().hex[:8]
    kb_name = self.ingest_name or f"litellm-kb-{unique_id}"

    # Account ID and caller ARN feed the access and trust policies.
    caller_identity = self._get_boto3_client("sts").get_caller_identity()
    account_id = caller_identity["Account"]
    caller_arn = caller_identity["Arn"]

    # Step 1: reuse the configured bucket or create one.
    if self._s3_bucket:
        self.s3_bucket = self._s3_bucket
    else:
        self.s3_bucket = self._create_s3_bucket(unique_id)

    # Step 2: vector storage collection.
    collection_name, collection_arn = await self._create_opensearch_collection(
        unique_id, account_id, caller_arn
    )

    # Step 3: vector index inside the collection.
    await self._create_opensearch_index(collection_name)

    # Step 4: role that Bedrock assumes.
    role_arn = await self._create_bedrock_role(unique_id, account_id, collection_arn)

    # Step 5: the Knowledge Base itself.
    self.knowledge_base_id = await self._create_knowledge_base(
        kb_name, role_arn, collection_arn
    )

    # Step 6: data source pointing at the bucket.
    self.data_source_id = self._create_data_source(kb_name)

    verbose_logger.info(
        f"Created KB infrastructure: kb_id={self.knowledge_base_id}, "
        f"ds_id={self.data_source_id}, bucket={self.s3_bucket}"
    )
def _create_s3_bucket(self, unique_id: str) -> str:
    """
    Create the S3 bucket backing the KB data source and return its name.

    Args:
        unique_id: Suffix appended to "litellm-kb-" to form the bucket name.
    """
    s3 = self._get_boto3_client("s3")
    bucket_name = f"litellm-kb-{unique_id}"
    verbose_logger.debug(f"Creating S3 bucket: {bucket_name}")

    create_params: Dict[str, Any] = {"Bucket": bucket_name}
    # The location constraint is only sent for regions other than us-east-1.
    needs_location_constraint = self.aws_region_name != "us-east-1"
    if needs_location_constraint:
        create_params["CreateBucketConfiguration"] = {
            "LocationConstraint": self.aws_region_name
        }
    s3.create_bucket(**create_params)

    # Remember the bucket so it can be cleaned up later if needed.
    self._created_resources["s3_bucket"] = bucket_name
    verbose_logger.info(f"Created S3 bucket: {bucket_name}")
    return bucket_name
async def _create_opensearch_collection(
    self, unique_id: str, account_id: str, caller_arn: str
) -> Tuple[str, str]:
    """
    Create the OpenSearch Serverless collection used for vector storage.

    Creates an encryption policy, a network policy and a data access policy,
    then the collection itself, and polls until the collection is ACTIVE.

    Args:
        unique_id: Suffix appended to "litellm-kb-" to form the collection name.
        account_id: AWS account ID, used for the root principal.
        caller_arn: ARN of the calling identity; normalized and granted access.

    Returns:
        Tuple of (collection_name, collection_arn).

    Raises:
        TimeoutError: If the collection is not ACTIVE after 60 polls, 5s apart.
    """
    oss = self._get_boto3_client("opensearchserverless")
    collection_name = f"litellm-kb-{unique_id}"
    verbose_logger.debug(
        f"Creating OpenSearch Serverless collection: {collection_name}"
    )

    collection_resource = [f"collection/{collection_name}"]

    # Encryption policy (AWS-owned key).
    encryption_policy = {
        "Rules": [
            {
                "ResourceType": "collection",
                "Resource": collection_resource,
            }
        ],
        "AWSOwnedKey": True,
    }
    oss.create_security_policy(
        name=f"{collection_name}-enc",
        type="encryption",
        policy=json.dumps(encryption_policy),
    )

    # Network policy (public access for simplicity).
    network_policy = [
        {
            "Rules": [
                {
                    "ResourceType": "collection",
                    "Resource": collection_resource,
                },
                {
                    "ResourceType": "dashboard",
                    "Resource": collection_resource,
                },
            ],
            "AllowFromPublic": True,
        }
    ]
    oss.create_security_policy(
        name=f"{collection_name}-net",
        type="network",
        policy=json.dumps(network_policy),
    )

    # Data access policy: grant the account root and the caller in use.
    # Assumed-role caller ARNs are first converted to IAM role ARNs.
    normalized_caller_arn = _normalize_principal_arn(caller_arn, account_id)
    verbose_logger.debug(
        f"Caller ARN: {caller_arn}, Normalized: {normalized_caller_arn}"
    )
    # A set removes the duplicate when the caller is the root principal.
    principals = list({f"arn:aws:iam::{account_id}:root", normalized_caller_arn})
    access_policy = [
        {
            "Rules": [
                {
                    "ResourceType": "index",
                    "Resource": [f"index/{collection_name}/*"],
                    "Permission": ["aoss:*"],
                },
                {
                    "ResourceType": "collection",
                    "Resource": collection_resource,
                    "Permission": ["aoss:*"],
                },
            ],
            "Principal": principals,
        }
    ]
    oss.create_access_policy(
        name=f"{collection_name}-access",
        type="data",
        policy=json.dumps(access_policy),
    )

    creation = oss.create_collection(
        name=collection_name,
        type="VECTORSEARCH",
    )
    collection_id = creation["createCollectionDetail"]["id"]
    self._created_resources["opensearch_collection"] = collection_name

    # Poll with asyncio.sleep so the event loop is not blocked while waiting.
    verbose_logger.debug("Waiting for OpenSearch collection to be active...")
    became_active = False
    status_response = None
    for _ in range(60):  # 5 min timeout
        status_response = oss.batch_get_collection(ids=[collection_id])
        if status_response["collectionDetails"][0]["status"] == "ACTIVE":
            became_active = True
            break
        await asyncio.sleep(5)
    if not became_active:
        raise TimeoutError("OpenSearch collection did not become active in time")

    collection_arn = status_response["collectionDetails"][0]["arn"]
    verbose_logger.info(f"Created OpenSearch collection: {collection_name}")

    # Give the data access policy time to propagate before the index is created.
    verbose_logger.debug("Waiting for data access policy to propagate (60s)...")
    await asyncio.sleep(60)

    return collection_name, collection_arn
async def _create_opensearch_index(self, collection_name: str):
    """
    Create the vector index in the OpenSearch collection, retrying on
    authorization errors.

    The request is signed with credentials resolved from the same
    vector_store settings that ``_get_boto3_client`` uses (including role,
    profile and web-identity settings), so the index call runs as the same
    principal as the other AWS calls made by this class.

    Args:
        collection_name: Name of the collection to create the index in.

    Raises:
        RuntimeError: If every attempt fails with an authorization/security
            error.
        Exception: Any other error from index creation is re-raised at once.
    """
    from opensearchpy import OpenSearch, RequestsHttpConnection
    from requests_aws4auth import AWS4Auth

    # Resolve signing credentials from the full set of auth settings.
    credential_keys = (
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        "aws_session_name",
        "aws_profile_name",
        "aws_role_name",
        "aws_web_identity_token",
        "aws_sts_endpoint",
        "aws_external_id",
    )
    credentials = self.get_credentials(
        aws_region_name=self.aws_region_name,
        **{
            key: _get_str_or_none(self.vector_store_config.get(key))
            for key in credential_keys
        },
    )

    # Look up the collection endpoint and strip the scheme to get the host.
    oss = self._get_boto3_client("opensearchserverless")
    collections = oss.batch_get_collection(names=[collection_name])
    endpoint = collections["collectionDetails"][0]["collectionEndpoint"]
    host = endpoint.replace("https://", "")

    auth = AWS4Auth(
        credentials.access_key,
        credentials.secret_key,
        self.aws_region_name,
        "aoss",
        session_token=credentials.token,
    )

    client = OpenSearch(
        hosts=[{"host": host, "port": 443}],
        http_auth=auth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
    )

    # Index and field names match the field mapping used when the KB is created.
    # NOTE(review): the vector dimension is fixed at 1024 regardless of
    # self.embedding_model — confirm it matches the configured model's output.
    index_name = "bedrock-kb-index"
    index_body = {
        "settings": {"index": {"knn": True, "knn.algo_param.ef_search": 512}},
        "mappings": {
            "properties": {
                "bedrock-knowledge-base-default-vector": {
                    "type": "knn_vector",
                    "dimension": 1024,
                    "method": {
                        "engine": "faiss",
                        "name": "hnsw",
                        "space_type": "l2",
                    },
                },
                "AMAZON_BEDROCK_METADATA": {"type": "text", "index": False},
                "AMAZON_BEDROCK_TEXT_CHUNK": {"type": "text"},
            }
        },
    }

    # The data access policy may take time to propagate, so authorization
    # failures are retried; anything else is raised immediately.
    max_retries = 8
    retry_delay = 20  # seconds
    last_error: Optional[Exception] = None

    for attempt in range(1, max_retries + 1):
        try:
            client.indices.create(index=index_name, body=index_body)
        except Exception as e:
            error_str = str(e).lower()
            is_auth_error = (
                "authorization_exception" in error_str
                or "security_exception" in error_str
            )
            if not is_auth_error:
                raise
            last_error = e
            if attempt == max_retries:
                # No further attempt follows, so waiting again would be pointless.
                break
            verbose_logger.warning(
                f"OpenSearch index creation attempt {attempt}/{max_retries} failed due to authorization. "
                f"Waiting {retry_delay}s for policy propagation..."
            )
            await asyncio.sleep(retry_delay)
        else:
            verbose_logger.info(f"Created OpenSearch index: {index_name}")
            return

    raise RuntimeError(
        f"Failed to create OpenSearch index after {max_retries} attempts. "
        f"Data access policy may not have propagated. Last error: {last_error}"
    ) from last_error
async def _create_bedrock_role(
    self, unique_id: str, account_id: str, collection_arn: str
) -> str:
    """
    Create the IAM role for the Knowledge Base and return its ARN.

    The trust policy lets bedrock.amazonaws.com assume the role for knowledge
    bases in this account and region. The inline policy allows invoking the
    embedding model, API access to the collection, and reading the S3 bucket.

    Args:
        unique_id: Suffix appended to "litellm-bedrock-kb-" to form the role name.
        account_id: AWS account ID used in the trust policy conditions.
        collection_arn: ARN of the OpenSearch Serverless collection.
    """
    iam = self._get_boto3_client("iam")
    role_name = f"litellm-bedrock-kb-{unique_id}"
    verbose_logger.debug(f"Creating IAM role: {role_name}")

    assume_role_statement = {
        "Effect": "Allow",
        "Principal": {"Service": "bedrock.amazonaws.com"},
        "Action": "sts:AssumeRole",
        "Condition": {
            "StringEquals": {"aws:SourceAccount": account_id},
            "ArnLike": {
                "aws:SourceArn": f"arn:aws:bedrock:{self.aws_region_name}:{account_id}:knowledge-base/*"
            },
        },
    }
    trust_policy = {
        "Version": "2012-10-17",
        "Statement": [assume_role_statement],
    }

    created_role = iam.create_role(
        RoleName=role_name,
        AssumeRolePolicyDocument=json.dumps(trust_policy),
    )
    role_arn = created_role["Role"]["Arn"]
    self._created_resources["iam_role"] = role_name

    # Inline permissions: one statement per service the role touches.
    invoke_model_statement = {
        "Effect": "Allow",
        "Action": ["bedrock:InvokeModel"],
        "Resource": [
            f"arn:aws:bedrock:{self.aws_region_name}::foundation-model/{self.embedding_model}"
        ],
    }
    collection_statement = {
        "Effect": "Allow",
        "Action": ["aoss:APIAccessAll"],
        "Resource": [collection_arn],
    }
    bucket_statement = {
        "Effect": "Allow",
        "Action": ["s3:GetObject", "s3:ListBucket"],
        "Resource": [
            f"arn:aws:s3:::{self.s3_bucket}",
            f"arn:aws:s3:::{self.s3_bucket}/*",
        ],
    }
    permissions_policy = {
        "Version": "2012-10-17",
        "Statement": [
            invoke_model_statement,
            collection_statement,
            bucket_statement,
        ],
    }
    iam.put_role_policy(
        RoleName=role_name,
        PolicyName=f"{role_name}-policy",
        PolicyDocument=json.dumps(permissions_policy),
    )

    # Pause for role propagation; asyncio.sleep keeps the event loop free.
    await asyncio.sleep(10)

    verbose_logger.info(f"Created IAM role: {role_arn}")
    return role_arn
async def _create_knowledge_base(
    self, kb_name: str, role_arn: str, collection_arn: str
) -> str:
    """
    Create the Bedrock Knowledge Base and wait until it is ACTIVE.

    Args:
        kb_name: Name for the Knowledge Base.
        role_arn: ARN of the IAM role the Knowledge Base uses.
        collection_arn: ARN of the OpenSearch Serverless collection.

    Returns:
        The new Knowledge Base ID.

    Raises:
        TimeoutError: If the KB is not ACTIVE after 30 polls, 2s apart.
    """
    bedrock_agent = self._get_boto3_client("bedrock-agent")
    verbose_logger.debug(f"Creating Knowledge Base: {kb_name}")

    kb_configuration = {
        "type": "VECTOR",
        "vectorKnowledgeBaseConfiguration": {
            "embeddingModelArn": f"arn:aws:bedrock:{self.aws_region_name}::foundation-model/{self.embedding_model}",
        },
    }
    # Index and field names match those used when the index is created.
    storage_configuration = {
        "type": "OPENSEARCH_SERVERLESS",
        "opensearchServerlessConfiguration": {
            "collectionArn": collection_arn,
            "fieldMapping": {
                "metadataField": "AMAZON_BEDROCK_METADATA",
                "textField": "AMAZON_BEDROCK_TEXT_CHUNK",
                "vectorField": "bedrock-knowledge-base-default-vector",
            },
            "vectorIndexName": "bedrock-kb-index",
        },
    }
    creation = bedrock_agent.create_knowledge_base(
        name=kb_name,
        roleArn=role_arn,
        knowledgeBaseConfiguration=kb_configuration,
        storageConfiguration=storage_configuration,
    )
    kb_id = creation["knowledgeBase"]["knowledgeBaseId"]
    self._created_resources["knowledge_base"] = kb_id

    # Poll with asyncio.sleep so the event loop is not blocked while waiting.
    verbose_logger.debug("Waiting for Knowledge Base to be active...")
    became_active = False
    for _ in range(30):
        kb_status = bedrock_agent.get_knowledge_base(knowledgeBaseId=kb_id)
        if kb_status["knowledgeBase"]["status"] == "ACTIVE":
            became_active = True
            break
        await asyncio.sleep(2)
    if not became_active:
        raise TimeoutError("Knowledge Base did not become active in time")

    verbose_logger.info(f"Created Knowledge Base: {kb_id}")
    return kb_id
def _create_data_source(self, kb_name: str) -> str:
    """
    Create an S3 data source on the Knowledge Base and return its ID.

    The data source points at ``self.s3_bucket`` and is restricted to
    ``self.s3_prefix``.

    Args:
        kb_name: Knowledge Base name; the data source is named "<kb_name>-s3-source".
    """
    bedrock_agent = self._get_boto3_client("bedrock-agent")
    verbose_logger.debug(f"Creating Data Source for KB: {self.knowledge_base_id}")

    s3_configuration = {
        "bucketArn": f"arn:aws:s3:::{self.s3_bucket}",
        "inclusionPrefixes": [self.s3_prefix],
    }
    creation = bedrock_agent.create_data_source(
        knowledgeBaseId=self.knowledge_base_id,
        name=f"{kb_name}-s3-source",
        dataSourceConfiguration={
            "type": "S3",
            "s3Configuration": s3_configuration,
        },
    )

    ds_id = creation["dataSource"]["dataSourceId"]
    # Remember the data source so it can be cleaned up later if needed.
    self._created_resources["data_source"] = ds_id
    verbose_logger.info(f"Created Data Source: {ds_id}")
    return ds_id
def _get_boto3_client(self, service_name: str):
    """
    Build a boto3 client for ``service_name``.

    Credentials are resolved through BaseAWSLLM's ``get_credentials`` from the
    AWS settings in the vector_store config.

    Args:
        service_name: boto3 service name, e.g. "s3" or "bedrock-agent".

    Raises:
        ImportError: If boto3 is not installed.
    """
    try:
        import boto3
    except ImportError:
        raise ImportError(
            "boto3 is required for Bedrock ingestion. Install with: pip install boto3"
        )

    # Each auth setting is read from the config and passed as str-or-None.
    auth_keys = (
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        "aws_session_name",
        "aws_profile_name",
        "aws_role_name",
        "aws_web_identity_token",
        "aws_sts_endpoint",
        "aws_external_id",
    )
    auth_kwargs = {
        key: _get_str_or_none(self.vector_store_config.get(key)) for key in auth_keys
    }
    credentials = self.get_credentials(
        aws_region_name=self.aws_region_name,
        **auth_kwargs,
    )

    # The session carries the resolved credentials and the region.
    session = boto3.Session(
        aws_access_key_id=credentials.access_key,
        aws_secret_access_key=credentials.secret_key,
        aws_session_token=credentials.token,
        region_name=self.aws_region_name,
    )
    return session.client(service_name)
async def embed(
    self,
    chunks: List[str],
) -> Optional[List[List[float]]]:
    """
    No-op embedding step: Bedrock embeds when files are ingested.

    Args:
        chunks: Ignored.

    Returns:
        Always None.
    """
    return None
async def store(
    self,
    file_content: Optional[bytes],
    filename: Optional[str],
    content_type: Optional[str],
    chunks: List[str],
    embeddings: Optional[List[List[float]]],
) -> Tuple[Optional[str], Optional[str]]:
    """
    Store content in Bedrock Knowledge Base.

    Bedrock workflow:
    1. Resolve the KB configuration (detect from existing KB or create one)
    2. Upload file to S3 bucket
    3. Start ingestion job
    4. (Optional) Wait for ingestion to complete

    When ``wait_for_ingestion`` is set, the job is polled every 2s until it
    leaves the STARTING/IN_PROGRESS states or ``ingestion_timeout`` seconds
    have passed. A FAILED job, an unknown status and a timeout are logged;
    none of them raises.

    Args:
        file_content: Raw file bytes
        filename: Name of the file
        content_type: MIME type
        chunks: Ignored - Bedrock handles chunking
        embeddings: Ignored - Bedrock handles embedding

    Returns:
        Tuple of (knowledge_base_id, file_key). ``file_key`` is None when
        there is no file content or filename to upload.
    """
    await self._ensure_config_initialized()

    if not file_content or not filename:
        verbose_logger.warning(
            "No file content or filename provided for Bedrock ingestion"
        )
        return _get_str_or_none(self.knowledge_base_id), None

    # Step 1: Upload file to S3
    s3_client = self._get_boto3_client("s3")
    s3_key = f"{self.s3_prefix.rstrip('/')}/{filename}"

    verbose_logger.debug(f"Uploading file to s3://{self.s3_bucket}/{s3_key}")
    s3_client.put_object(
        Bucket=self.s3_bucket,
        Key=s3_key,
        Body=file_content,
        ContentType=content_type or "application/octet-stream",
    )
    verbose_logger.info(f"Uploaded file to s3://{self.s3_bucket}/{s3_key}")

    # Step 2: Start ingestion job
    bedrock_agent = self._get_boto3_client("bedrock-agent")

    verbose_logger.debug(
        f"Starting ingestion job for KB={self.knowledge_base_id}, DS={self.data_source_id}"
    )
    ingestion_response = bedrock_agent.start_ingestion_job(
        knowledgeBaseId=self.knowledge_base_id,
        dataSourceId=self.data_source_id,
    )
    job_id = ingestion_response["ingestionJob"]["ingestionJobId"]
    verbose_logger.info(f"Started ingestion job: {job_id}")

    # Step 3: Wait for ingestion (optional) - asyncio.sleep avoids blocking
    if self.wait_for_ingestion:
        import time as time_module

        # A monotonic clock keeps the deadline immune to wall-clock changes.
        deadline = time_module.monotonic() + self.ingestion_timeout
        while time_module.monotonic() < deadline:
            job_status = bedrock_agent.get_ingestion_job(
                knowledgeBaseId=self.knowledge_base_id,
                dataSourceId=self.data_source_id,
                ingestionJobId=job_id,
            )
            status = job_status["ingestionJob"]["status"]
            verbose_logger.debug(f"Ingestion job {job_id} status: {status}")

            if status == "COMPLETE":
                stats = job_status["ingestionJob"].get("statistics", {})
                verbose_logger.info(
                    f"Ingestion complete: {stats.get('numberOfNewDocumentsIndexed', 0)} docs indexed"
                )
                break
            elif status == "FAILED":
                failure_reasons = job_status["ingestionJob"].get(
                    "failureReasons", []
                )
                verbose_logger.error(f"Ingestion failed: {failure_reasons}")
                break
            elif status in ("STARTING", "IN_PROGRESS"):
                await asyncio.sleep(2)
            else:
                verbose_logger.warning(f"Unknown ingestion status: {status}")
                break
        else:
            # Loop ended without a break: the job never reached a final state.
            verbose_logger.warning(
                f"Timed out after {self.ingestion_timeout}s waiting for ingestion job {job_id}"
            )

    return str(self.knowledge_base_id) if self.knowledge_base_id else None, s3_key

View File

@@ -0,0 +1,9 @@
"""
File parsers for RAG ingestion.
Provides text extraction utilities for various file formats.
"""
from .pdf_parser import extract_text_from_pdf
# Names exported by a star-import of this package.
__all__ = ["extract_text_from_pdf"]

View File

@@ -0,0 +1,76 @@
"""
PDF text extraction utilities.
Provides text extraction from PDF files using pypdf or PyPDF2.
"""
from typing import Optional
from litellm._logging import verbose_logger
def extract_text_from_pdf(file_content: bytes) -> Optional[str]:
    """
    Extract text from a PDF, trying pypdf first and PyPDF2 as a fallback.

    Pages that yield no text are skipped; the remaining page texts are
    joined with a blank line between them.

    Args:
        file_content: Raw PDF bytes

    Returns:
        Extracted text, or None if neither library is importable, no page
        yields text, or extraction raises
    """
    from io import BytesIO

    def _collect_page_text(reader_cls) -> Optional[str]:
        # Both backends are driven through the same `.pages` / `.extract_text()` calls.
        document = reader_cls(BytesIO(file_content))
        page_texts = [
            page_text
            for page_text in (page.extract_text() for page in document.pages)
            if page_text
        ]
        return "\n\n".join(page_texts) if page_texts else None

    try:
        # Preferred backend: pypdf
        try:
            from pypdf import PdfReader as PypdfReader

            extracted_text = _collect_page_text(PypdfReader)
            if extracted_text is not None:
                verbose_logger.debug(
                    f"Extracted {len(extracted_text)} characters from PDF using pypdf"
                )
                return extracted_text
        except ImportError:
            verbose_logger.debug("pypdf not available, trying PyPDF2")

        # Fallback backend: PyPDF2 (also tried when pypdf found no text)
        try:
            from PyPDF2 import PdfReader as PyPDF2Reader

            extracted_text = _collect_page_text(PyPDF2Reader)
            if extracted_text is not None:
                verbose_logger.debug(
                    f"Extracted {len(extracted_text)} characters from PDF using PyPDF2"
                )
                return extracted_text
        except ImportError:
            verbose_logger.debug(
                "PyPDF2 not available, PDF extraction requires OCR or pypdf/PyPDF2 library"
            )
    except Exception as e:
        # Any parsing failure is treated as "no text available".
        verbose_logger.debug(f"PDF text extraction failed: {e}")

    return None

View File

@@ -0,0 +1,339 @@
"""
Gemini-specific RAG Ingestion implementation.
Gemini handles embedding and chunking internally when files are uploaded to File Search stores,
so this implementation skips the embedding step and directly uploads files.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.llms.gemini.common_utils import GeminiModelInfo
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
if TYPE_CHECKING:
from litellm import Router
from litellm.types.rag import RAGIngestOptions
class GeminiRAGIngestion(BaseRAGIngestion):
    """
    Gemini-specific RAG ingestion using File Search API.

    Key differences from base:
    - Embedding is handled by Gemini when files are uploaded to File Search stores
    - Files are uploaded using uploadToFileSearchStore API
    - Chunking is done by Gemini's File Search (supports custom white_space_config)
    - Supports custom metadata attachment
    """

    # Version segment appended to the API base for all File Search calls.
    _API_VERSION = "v1beta"

    def __init__(
        self,
        ingest_options: "RAGIngestOptions",
        router: Optional["Router"] = None,
    ):
        super().__init__(ingest_options=ingest_options, router=router)
        self.model_info = GeminiModelInfo()

    @staticmethod
    def _strip_query(url: str) -> str:
        """Return `url` without its query string so credentials never reach logs."""
        return url.split("?", 1)[0]

    async def embed(
        self,
        chunks: List[str],
    ) -> Optional[List[List[float]]]:
        """
        Gemini handles embedding internally - skip this step.

        Returns:
            None (Gemini embeds when files are uploaded to File Search store)
        """
        return None

    async def store(
        self,
        file_content: Optional[bytes],
        filename: Optional[str],
        content_type: Optional[str],
        chunks: List[str],
        embeddings: Optional[List[List[float]]],
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Store content in Gemini File Search store.

        Gemini workflow:
        1. Create File Search store (if not provided)
        2. Upload file using uploadToFileSearchStore (Gemini handles chunking/embedding)

        Args:
            file_content: Raw file bytes
            filename: Name of the file
            content_type: MIME type
            chunks: Ignored - Gemini handles chunking
            embeddings: Ignored - Gemini handles embedding

        Returns:
            Tuple of (vector_store_id, file_id); file_id is None when no
            file content or filename was supplied.

        Raises:
            ValueError: If no API key or API base can be resolved.
        """
        vector_store_config = cast(Dict[str, Any], self.vector_store_config)
        vector_store_id = vector_store_config.get("vector_store_id")

        # Explicit config wins; otherwise fall back to GeminiModelInfo defaults.
        api_key = (
            cast(Optional[str], vector_store_config.get("api_key"))
            or GeminiModelInfo.get_api_key()
        )
        api_base = (
            cast(Optional[str], vector_store_config.get("api_base"))
            or GeminiModelInfo.get_api_base()
        )

        if not api_key:
            raise ValueError(
                "GEMINI_API_KEY or GOOGLE_API_KEY is required for Gemini File Search"
            )
        if not api_base:
            raise ValueError("GEMINI_API_BASE is required")

        # Strip a trailing slash so the joined URL never contains "//v1beta".
        base_url = f"{api_base.rstrip('/')}/{self._API_VERSION}"

        # Create File Search store if not provided
        if not vector_store_id:
            vector_store_id = await self._create_file_search_store(
                api_key=api_key,
                base_url=base_url,
                display_name=self.ingest_name or "litellm-rag-ingest",
            )

        # Upload file to File Search store
        result_file_id = None
        if file_content and filename and vector_store_id:
            result_file_id = await self._upload_to_file_search_store(
                api_key=api_key,
                base_url=base_url,
                vector_store_id=vector_store_id,
                filename=filename,
                file_content=file_content,
                content_type=content_type,
            )

        return vector_store_id, result_file_id

    async def _create_file_search_store(
        self,
        api_key: str,
        base_url: str,
        display_name: str,
    ) -> str:
        """
        Create a Gemini File Search store.

        Args:
            api_key: Gemini API key
            base_url: Base URL for Gemini API
            display_name: Display name for the store

        Returns:
            Store name (format: fileSearchStores/xxxxxxx)

        Raises:
            Exception: If the API does not answer 200 or the response
                carries no store name.
        """
        url = f"{base_url}/fileSearchStores?key={api_key}"
        request_body = {"displayName": display_name}

        client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.RAG,
            params={"timeout": 60.0},
        )
        response = await client.post(
            url,
            json=request_body,
            headers={"Content-Type": "application/json"},
        )

        if response.status_code != 200:
            error_msg = f"Failed to create File Search store: {response.text}"
            verbose_logger.error(error_msg)
            raise Exception(error_msg)

        response_data = response.json()
        store_name = response_data.get("name", "")
        if not store_name:
            # An empty name would make store() skip the upload silently.
            error_msg = "File Search store created but no store name was returned"
            verbose_logger.error(error_msg)
            raise Exception(error_msg)

        verbose_logger.debug(f"Created File Search store: {store_name}")
        return store_name

    async def _upload_to_file_search_store(
        self,
        api_key: str,
        base_url: str,
        vector_store_id: str,
        filename: str,
        file_content: bytes,
        content_type: Optional[str],
    ) -> str:
        """
        Upload a file to Gemini File Search store using resumable upload.

        Args:
            api_key: Gemini API key
            base_url: Base URL for Gemini API
            vector_store_id: File Search store name
            filename: Name of the file
            file_content: File content bytes
            content_type: MIME type (defaults to "application/octet-stream")

        Returns:
            File ID or document name
        """
        # Step 1: Initiate resumable upload
        upload_url = await self._initiate_resumable_upload(
            api_key=api_key,
            base_url=base_url,
            vector_store_id=vector_store_id,
            filename=filename,
            file_size=len(file_content),
            content_type=content_type or "application/octet-stream",
        )

        # Step 2: Upload the file content
        return await self._upload_file_content(
            upload_url=upload_url,
            file_content=file_content,
        )

    async def _initiate_resumable_upload(
        self,
        api_key: str,
        base_url: str,
        vector_store_id: str,
        filename: str,
        file_size: int,
        content_type: str,
    ) -> str:
        """
        Initiate a resumable upload session.

        The request body carries the display name plus, when configured, the
        white-space chunking settings and custom metadata.

        Returns:
            Upload URL for the resumable session

        Raises:
            Exception: If the API does not answer 200/201 or returns no
                upload URL header.
        """
        # base_url is like: https://generativelanguage.googleapis.com/v1beta
        # We need: https://generativelanguage.googleapis.com/upload/v1beta/{store_id}:uploadToFileSearchStore
        # Only the trailing version segment is removed, so an api_base that
        # merely contains "/v1beta" elsewhere is left intact.
        version_suffix = f"/{self._API_VERSION}"
        api_root = (
            base_url[: -len(version_suffix)]
            if base_url.endswith(version_suffix)
            else base_url
        )
        endpoint = (
            f"{api_root}/upload/v1beta/{vector_store_id}:uploadToFileSearchStore"
        )
        url = f"{endpoint}?key={api_key}"

        request_body: Dict[str, Any] = {"displayName": filename}

        # Add chunking configuration if provided
        chunking_strategy = self.chunking_strategy
        if chunking_strategy and isinstance(chunking_strategy, dict):
            white_space_config = chunking_strategy.get("white_space_config")
            if white_space_config:
                request_body["chunkingConfig"] = {
                    "whiteSpaceConfig": {
                        "maxTokensPerChunk": white_space_config.get(
                            "max_tokens_per_chunk", 800
                        ),
                        "maxOverlapTokens": white_space_config.get(
                            "max_overlap_tokens", 400
                        ),
                    }
                }

        # Add custom metadata if provided in vector_store_config
        custom_metadata = cast(
            Optional[List[Dict[str, Any]]],
            self.vector_store_config.get("custom_metadata"),
        )
        if custom_metadata:
            request_body["customMetadata"] = custom_metadata

        headers = {
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Length": str(file_size),
            "X-Goog-Upload-Header-Content-Type": content_type,
            "Content-Type": "application/json",
        }

        # Log the endpoint only: the full URL carries the API key.
        verbose_logger.debug(f"Initiating resumable upload: {endpoint}")

        client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.RAG,
            params={"timeout": 60.0},
        )
        response = await client.post(
            url,
            json=request_body,
            headers=headers,
        )

        if response.status_code not in [200, 201]:
            error_msg = f"Failed to initiate upload: {response.text}"
            verbose_logger.error(error_msg)
            raise Exception(error_msg)

        # Headers are not dumped: they include the session upload URL.
        verbose_logger.debug(
            f"Initiate resumable upload response status: {response.status_code}"
        )

        # Extract upload URL from response headers
        upload_url = response.headers.get("x-goog-upload-url")
        if not upload_url:
            raise Exception("No upload URL returned in response headers")

        # The session URL's query may embed credentials/session ids; redact it.
        verbose_logger.debug(f"Got upload URL: {self._strip_query(upload_url)}")
        return upload_url

    async def _upload_file_content(
        self,
        upload_url: str,
        file_content: bytes,
    ) -> str:
        """
        Upload file content to the resumable upload URL.

        Returns:
            File ID or document name from the response, or the placeholder
            "uploaded" when the response body cannot be parsed.

        Raises:
            Exception: If the upload does not answer 200/201.
        """
        headers = {
            "Content-Length": str(len(file_content)),
            "X-Goog-Upload-Offset": "0",
            "X-Goog-Upload-Command": "upload, finalize",
        }

        verbose_logger.debug(f"Uploading file content ({len(file_content)} bytes)")

        client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.RAG,
            params={"timeout": 300.0},  # Longer timeout for large files
        )
        response = await client.put(
            upload_url,
            content=file_content,
            headers=headers,
        )

        if response.status_code not in [200, 201]:
            error_msg = f"Failed to upload file: {response.text}"
            verbose_logger.error(error_msg)
            raise Exception(error_msg)

        # Parse response to get file/document ID
        try:
            response_data = response.json()
            # Prefer the top-level name, then fall back to the nested file name.
            file_id = response_data.get("name", "") or response_data.get(
                "file", {}
            ).get("name", "")
            verbose_logger.debug(f"Upload complete. File ID: {file_id}")
            return file_id
        except Exception as e:
            verbose_logger.warning(f"Could not parse upload response: {e}")
            # The upload itself succeeded, so return a placeholder id.
            return "uploaded"

View File

@@ -0,0 +1,128 @@
"""
OpenAI-specific RAG Ingestion implementation.
OpenAI handles embedding internally when files are attached to vector stores,
so this implementation skips the embedding step and directly uploads files.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
import litellm
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
from litellm.vector_store_files.main import acreate as vector_store_file_acreate
from litellm.vector_stores.main import acreate as vector_store_acreate
if TYPE_CHECKING:
from litellm import Router
from litellm.types.rag import RAGIngestOptions
class OpenAIRAGIngestion(BaseRAGIngestion):
    """
    RAG ingestion backed by OpenAI vector stores.

    Compared with the base implementation:
    - No local embedding step: OpenAI embeds files once they are attached
    - Files are uploaded and then attached to the vector store directly
    - Chunking is performed by the vector store itself
    """

    def __init__(
        self,
        ingest_options: "RAGIngestOptions",
        router: Optional["Router"] = None,
    ):
        super().__init__(ingest_options=ingest_options, router=router)

    async def embed(
        self,
        chunks: List[str],
    ) -> Optional[List[List[float]]]:
        """
        No-op: embedding happens on OpenAI's side at attach time.

        Returns:
            None
        """
        return None

    async def store(
        self,
        file_content: Optional[bytes],
        filename: Optional[str],
        content_type: Optional[str],
        chunks: List[str],
        embeddings: Optional[List[List[float]]],
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Upload a file and attach it to an OpenAI vector store.

        Steps:
        1. Create the vector store unless an id is configured
        2. Upload the file to OpenAI
        3. Attach the file to the vector store

        Args:
            file_content: Raw file bytes
            filename: Name of the file
            content_type: MIME type (defaults to "application/octet-stream")
            chunks: Ignored - OpenAI handles chunking
            embeddings: Ignored - OpenAI handles embedding

        Returns:
            Tuple of (vector_store_id, file_id); file_id is None when
            nothing was uploaded.
        """
        config = self.vector_store_config
        # Credentials come from vector_store_config (loaded from
        # litellm_credential_name if provided).
        api_key = config.get("api_key")
        api_base = config.get("api_base")
        store_id = config.get("vector_store_id")

        if not store_id:
            ttl_days = config.get("ttl_days")
            expiry = None
            if ttl_days:
                expiry = {"anchor": "last_active_at", "days": ttl_days}
            created_store = await vector_store_acreate(
                name=self.ingest_name or "litellm-rag-ingest",
                custom_llm_provider="openai",
                expires_after=expiry,
                api_key=api_key,
                api_base=api_base,
            )
            store_id = created_store.get("id")

        # Nothing to upload, or no store to attach to.
        if not (file_content and filename and store_id):
            return store_id, None

        file_tuple = (
            filename,
            file_content,
            content_type or "application/octet-stream",
        )
        uploaded_file = await litellm.acreate_file(
            file=file_tuple,
            purpose="assistants",
            custom_llm_provider="openai",
            api_key=api_key,
            api_base=api_base,
        )
        uploaded_file_id = uploaded_file.id

        # Attaching triggers OpenAI-side chunking and embedding.
        await vector_store_file_acreate(
            vector_store_id=store_id,
            file_id=uploaded_file_id,
            custom_llm_provider="openai",
            chunking_strategy=cast(Optional[Dict[str, Any]], self.chunking_strategy),
            api_key=api_key,
            api_base=api_base,
        )

        return store_id, uploaded_file_id

View File

@@ -0,0 +1,594 @@
"""
S3 Vectors-specific RAG Ingestion implementation.
S3 Vectors is AWS's native vector storage service that provides:
- Purpose-built vector buckets for storing and querying vectors
- Vector indexes with configurable dimensions and distance metrics
- Metadata filtering for semantic search
This implementation:
1. Auto-creates vector buckets and indexes if not provided
2. Uses LiteLLM's embedding API (supports any provider)
3. Uses httpx + AWS SigV4 signing (no boto3 dependency for S3 Vectors APIs)
4. Stores vectors with metadata using PutVectors API
"""
from __future__ import annotations
import hashlib
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import litellm
from litellm._logging import verbose_logger
from litellm.constants import (
S3_VECTORS_DEFAULT_DIMENSION,
S3_VECTORS_DEFAULT_DISTANCE_METRIC,
S3_VECTORS_DEFAULT_NON_FILTERABLE_METADATA_KEYS,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
if TYPE_CHECKING:
from litellm import Router
from litellm.types.rag import RAGIngestOptions
class S3VectorsRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
"""
S3 Vectors RAG ingestion using httpx + AWS SigV4 signing.
Workflow:
1. Auto-create vector bucket if needed (CreateVectorBucket API)
2. Auto-create vector index if needed (CreateVectorIndex API)
3. Generate embeddings using LiteLLM (supports any provider)
4. Store vectors with PutVectors API
Configuration:
- vector_bucket_name: S3 vector bucket name (required)
- index_name: Vector index name (auto-creates if not provided)
- dimension: Vector dimension (default: S3_VECTORS_DEFAULT_DIMENSION)
- distance_metric: "cosine" or "euclidean" (default: S3_VECTORS_DEFAULT_DISTANCE_METRIC)
- non_filterable_metadata_keys: List of metadata keys to exclude from filtering
"""
def __init__(
self,
ingest_options: "RAGIngestOptions",
router: Optional["Router"] = None,
):
BaseRAGIngestion.__init__(self, ingest_options=ingest_options, router=router)
BaseAWSLLM.__init__(self)
# Extract config
self.vector_bucket_name = self.vector_store_config["vector_bucket_name"]
self.index_name = self.vector_store_config.get("index_name")
self.distance_metric = self.vector_store_config.get(
"distance_metric", S3_VECTORS_DEFAULT_DISTANCE_METRIC
)
self.non_filterable_metadata_keys = self.vector_store_config.get(
"non_filterable_metadata_keys",
S3_VECTORS_DEFAULT_NON_FILTERABLE_METADATA_KEYS,
)
# Get dimension from config (will be auto-detected on first use if not provided)
self.dimension = self._get_dimension_from_config()
# Get AWS region using BaseAWSLLM method
_aws_region = self.vector_store_config.get("aws_region_name")
self.aws_region_name = self.get_aws_region_name_for_non_llm_api_calls(
aws_region_name=str(_aws_region) if _aws_region else None
)
# Create httpx client (similar to s3_v2.py)
ssl_verify = self._get_ssl_verify(
ssl_verify=self.vector_store_config.get("ssl_verify")
)
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.RAG,
params={"ssl_verify": ssl_verify} if ssl_verify is not None else None,
)
# Track if infrastructure is initialized
self._config_initialized = False
async def _get_dimension_from_embedding_request(self) -> int:
"""
Auto-detect dimension by making a test embedding request.
Makes a single embedding request with a test string to determine
the output dimension of the embedding model.
"""
if not self.embedding_config or "model" not in self.embedding_config:
return S3_VECTORS_DEFAULT_DIMENSION
try:
model_name = self.embedding_config["model"]
verbose_logger.debug(
f"Auto-detecting dimension by making test embedding request to {model_name}"
)
# Make a test embedding request
test_input = "test"
if self.router:
response = await self.router.aembedding(
model=model_name, input=[test_input]
)
else:
response = await litellm.aembedding(
model=model_name, input=[test_input]
)
# Get dimension from the response
if response.data and len(response.data) > 0:
dimension = len(response.data[0]["embedding"])
verbose_logger.debug(
f"Auto-detected dimension {dimension} for embedding model {model_name}"
)
return dimension
except Exception as e:
verbose_logger.warning(
f"Could not auto-detect dimension from embedding model: {e}. "
f"Using default dimension of {S3_VECTORS_DEFAULT_DIMENSION}."
)
return S3_VECTORS_DEFAULT_DIMENSION
def _get_dimension_from_config(self) -> Optional[int]:
"""
Get vector dimension from config if explicitly provided.
Returns None if dimension should be auto-detected.
"""
if "dimension" in self.vector_store_config:
return int(self.vector_store_config["dimension"])
return None
async def _ensure_config_initialized(self):
"""Lazily initialize S3 Vectors infrastructure."""
if self._config_initialized:
return
# Auto-detect dimension if not provided
if self.dimension is None:
self.dimension = await self._get_dimension_from_embedding_request()
# Ensure vector bucket exists
await self._ensure_vector_bucket_exists()
# Ensure vector index exists
if not self.index_name:
# Auto-generate index name
unique_id = uuid.uuid4().hex[:8]
self.index_name = f"litellm-index-{unique_id}"
await self._ensure_vector_index_exists()
self._config_initialized = True
async def _sign_and_execute_request(
self,
method: str,
url: str,
data: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
) -> Any:
"""
Helper to sign and execute AWS API requests using httpx + SigV4.
Pattern from litellm/integrations/s3_v2.py
"""
try:
import requests
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
except ImportError:
raise ImportError(
"Missing botocore to call S3 Vectors. Run 'pip install boto3'."
)
# Get AWS credentials using BaseAWSLLM's get_credentials method
credentials = self.get_credentials(
aws_access_key_id=self.vector_store_config.get("aws_access_key_id"),
aws_secret_access_key=self.vector_store_config.get("aws_secret_access_key"),
aws_session_token=self.vector_store_config.get("aws_session_token"),
aws_region_name=self.aws_region_name,
aws_session_name=self.vector_store_config.get("aws_session_name"),
aws_profile_name=self.vector_store_config.get("aws_profile_name"),
aws_role_name=self.vector_store_config.get("aws_role_name"),
aws_web_identity_token=self.vector_store_config.get(
"aws_web_identity_token"
),
aws_sts_endpoint=self.vector_store_config.get("aws_sts_endpoint"),
aws_external_id=self.vector_store_config.get("aws_external_id"),
)
# Prepare headers
if headers is None:
headers = {}
if data:
headers["Content-Type"] = "application/json"
# Calculate SHA256 hash of the content
content_hash = hashlib.sha256(data.encode("utf-8")).hexdigest()
headers["x-amz-content-sha256"] = content_hash
else:
# For requests without body, use hash of empty string
headers["x-amz-content-sha256"] = hashlib.sha256(b"").hexdigest()
# Prepare the request
req = requests.Request(method, url, data=data, headers=headers)
prepped = req.prepare()
# Sign the request
aws_request = AWSRequest(
method=prepped.method,
url=prepped.url,
data=prepped.body,
headers=prepped.headers,
)
SigV4Auth(credentials, "s3vectors", self.aws_region_name).add_auth(aws_request)
# Prepare the signed headers
signed_headers = dict(aws_request.headers.items())
# Make the request using specific method (pattern from s3_v2.py)
method_upper = method.upper()
if method_upper == "PUT":
response = await self.async_httpx_client.put(
url, data=data, headers=signed_headers
)
elif method_upper == "POST":
response = await self.async_httpx_client.post(
url, data=data, headers=signed_headers
)
elif method_upper == "GET":
response = await self.async_httpx_client.get(url, headers=signed_headers)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
return response
async def _ensure_vector_bucket_exists(self):
"""Create vector bucket if it doesn't exist using GetVectorBucket and CreateVectorBucket APIs."""
verbose_logger.debug(
f"Ensuring S3 vector bucket exists: {self.vector_bucket_name}"
)
# Validate bucket name (AWS S3 naming rules)
if len(self.vector_bucket_name) < 3:
raise ValueError(
f"Invalid vector_bucket_name '{self.vector_bucket_name}': "
f"AWS S3 bucket names must be at least 3 characters long. "
f"Please provide a valid bucket name (e.g., 'my-vector-bucket')."
)
if not self.vector_bucket_name.replace("-", "").replace(".", "").isalnum():
raise ValueError(
f"Invalid vector_bucket_name '{self.vector_bucket_name}': "
f"AWS S3 bucket names can only contain lowercase letters, numbers, hyphens, and periods. "
f"Please provide a valid bucket name (e.g., 'my-vector-bucket')."
)
# Try to get bucket info using GetVectorBucket API
get_url = f"https://s3vectors.{self.aws_region_name}.api.aws/GetVectorBucket"
get_body = safe_dumps({"vectorBucketName": self.vector_bucket_name})
try:
response = await self._sign_and_execute_request(
"POST", get_url, data=get_body
)
if response.status_code == 200:
verbose_logger.debug(f"Vector bucket {self.vector_bucket_name} exists")
return
except Exception as e:
verbose_logger.debug(
f"Bucket check failed (may not exist): {e}, attempting to create"
)
# Create vector bucket using CreateVectorBucket API
try:
verbose_logger.debug(f"Creating vector bucket: {self.vector_bucket_name}")
create_url = (
f"https://s3vectors.{self.aws_region_name}.api.aws/CreateVectorBucket"
)
create_body = safe_dumps({"vectorBucketName": self.vector_bucket_name})
response = await self._sign_and_execute_request(
"POST", create_url, data=create_body
)
if response.status_code in (200, 201):
verbose_logger.info(f"Created vector bucket: {self.vector_bucket_name}")
elif response.status_code == 409:
# Bucket already exists (ConflictException)
verbose_logger.debug(
f"Vector bucket {self.vector_bucket_name} already exists"
)
else:
verbose_logger.error(
f"CreateVectorBucket failed: {response.status_code} - {response.text}"
)
response.raise_for_status()
except Exception as e:
verbose_logger.exception(f"Error creating vector bucket: {e}")
raise
async def _ensure_vector_index_exists(self):
"""Create vector index if it doesn't exist using GetIndex and CreateIndex APIs."""
verbose_logger.debug(
f"Ensuring vector index exists: {self.vector_bucket_name}/{self.index_name}"
)
# Try to get index info using GetIndex API
get_url = f"https://s3vectors.{self.aws_region_name}.api.aws/GetIndex"
get_body = safe_dumps(
{"vectorBucketName": self.vector_bucket_name, "indexName": self.index_name}
)
try:
response = await self._sign_and_execute_request(
"POST", get_url, data=get_body
)
if response.status_code == 200:
verbose_logger.debug(f"Vector index {self.index_name} exists")
return
except Exception as e:
verbose_logger.debug(
f"Index check failed (may not exist): {e}, attempting to create"
)
# Create vector index using CreateIndex API
try:
verbose_logger.debug(
f"Creating vector index: {self.index_name} with dimension={self.dimension}, metric={self.distance_metric}"
)
# Prepare index configuration per AWS API docs
index_config = {
"vectorBucketName": self.vector_bucket_name,
"indexName": self.index_name,
"dataType": "float32",
"dimension": self.dimension,
"distanceMetric": self.distance_metric,
}
if self.non_filterable_metadata_keys:
index_config["metadataConfiguration"] = {
"nonFilterableMetadataKeys": self.non_filterable_metadata_keys
}
create_url = f"https://s3vectors.{self.aws_region_name}.api.aws/CreateIndex"
response = await self._sign_and_execute_request(
"POST", create_url, data=safe_dumps(index_config)
)
if response.status_code in (200, 201):
verbose_logger.info(f"Created vector index: {self.index_name}")
elif response.status_code == 409:
verbose_logger.debug(f"Vector index {self.index_name} already exists")
else:
verbose_logger.error(
f"CreateIndex failed: {response.status_code} - {response.text}"
)
response.raise_for_status()
except Exception as e:
verbose_logger.exception(f"Error creating vector index: {e}")
raise
async def _put_vectors(self, vectors: List[Dict[str, Any]]):
"""
Call PutVectors API to store vectors in S3 Vectors.
Args:
vectors: List of vector objects with keys: "key", "data", "metadata"
"""
verbose_logger.debug(
f"Storing {len(vectors)} vectors in {self.vector_bucket_name}/{self.index_name}"
)
url = f"https://s3vectors.{self.aws_region_name}.api.aws/PutVectors"
# Prepare request body per AWS API docs
request_body = {
"vectorBucketName": self.vector_bucket_name,
"indexName": self.index_name,
"vectors": vectors,
}
try:
response = await self._sign_and_execute_request(
"POST", url, data=safe_dumps(request_body)
)
if response.status_code in (200, 201):
verbose_logger.info(
f"Successfully stored {len(vectors)} vectors in index {self.index_name}"
)
else:
verbose_logger.error(
f"PutVectors failed with status {response.status_code}: {response.text}"
)
response.raise_for_status()
except Exception as e:
verbose_logger.exception(f"Error storing vectors: {e}")
raise
async def embed(
self,
chunks: List[str],
) -> Optional[List[List[float]]]:
"""
Generate embeddings using LiteLLM's embedding API.
Supports any embedding provider (OpenAI, Bedrock, Cohere, etc.)
"""
if not chunks:
return None
# Use embedding config from ingest_options or default
if not self.embedding_config:
verbose_logger.warning(
"No embedding config provided, using default text-embedding-3-small"
)
self.embedding_config = {"model": "text-embedding-3-small"}
embedding_model = self.embedding_config.get("model", "text-embedding-3-small")
verbose_logger.debug(
f"Generating embeddings for {len(chunks)} chunks using {embedding_model}"
)
# Convert to list to ensure type compatibility
input_chunks: List[str] = list(chunks)
if self.router:
response = await self.router.aembedding(
model=embedding_model, input=input_chunks
)
else:
response = await litellm.aembedding(
model=embedding_model, input=input_chunks
)
return [item["embedding"] for item in response.data]
async def store(
self,
file_content: Optional[bytes],
filename: Optional[str],
content_type: Optional[str],
chunks: List[str],
embeddings: Optional[List[List[float]]],
) -> Tuple[Optional[str], Optional[str]]:
"""
Store vectors in S3 Vectors using PutVectors API.
Steps:
1. Ensure vector bucket exists (auto-create if needed)
2. Ensure vector index exists (auto-create if needed)
3. Prepare vector data with metadata
4. Call PutVectors API with httpx + SigV4 signing
Args:
file_content: Raw file bytes (not used for S3 Vectors)
filename: Name of the file
content_type: MIME type (not used for S3 Vectors)
chunks: Text chunks
embeddings: Vector embeddings
Returns:
Tuple of (index_name, filename)
"""
# Ensure infrastructure exists
await self._ensure_config_initialized()
if not embeddings or not chunks:
error_msg = (
"No text content could be extracted from the file for embedding. "
"Possible causes:\n"
" 1. PDF files require OCR - add 'ocr' config with a vision model (e.g., 'anthropic/claude-3-5-sonnet-20241022')\n"
" 2. Binary files cannot be processed - convert to text first\n"
" 3. File is empty or contains no extractable text\n"
"For PDFs, either enable OCR or use a PDF extraction library to convert to text before ingestion."
)
verbose_logger.error(error_msg)
raise ValueError(error_msg)
# Prepare vectors for PutVectors API
vectors = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
# Build metadata dict
metadata: Dict[str, str] = {
"source_text": chunk, # Non-filterable (for reference)
"chunk_index": str(i), # Filterable
}
if filename:
metadata["filename"] = filename # Filterable
vector_obj = {
"key": f"{filename}_{i}" if filename else f"chunk_{i}",
"data": {"float32": embedding},
"metadata": metadata,
}
vectors.append(vector_obj)
# Call PutVectors API
await self._put_vectors(vectors)
# Return vector_store_id in format bucket_name:index_name for S3 Vectors search compatibility
vector_store_id = f"{self.vector_bucket_name}:{self.index_name}"
return vector_store_id, filename
async def query_vector_store(
    self, vector_store_id: str, query: str, top_k: int = 5
) -> Optional[Dict[str, Any]]:
    """
    Query S3 Vectors using the QueryVectors API.

    The query text is embedded with ``litellm.aembedding`` (model taken from
    ``self.embedding_config``, defaulting to ``text-embedding-3-small``) and
    the resulting vector is sent to the index.

    Args:
        vector_store_id: Index to search. Accepts a bare index name or a
            ``bucket_name:index_name`` identifier; in the latter case the
            bucket part is used for this request instead of
            ``self.vector_bucket_name``.
        query: Query text
        top_k: Number of results to return

    Returns:
        The decoded QueryVectors response body, or None if the request
        fails or returns a non-200 status. Errors raised while embedding the
        query propagate to the caller.
    """
    verbose_logger.debug(f"Querying index {vector_store_id} with query: {query}")

    # A "bucket:index" identifier must be split before it is sent: the API
    # expects the bucket and the index name in separate request fields.
    bucket_name = self.vector_bucket_name
    index_name = vector_store_id
    if ":" in vector_store_id:
        parsed_bucket, _, parsed_index = vector_store_id.partition(":")
        bucket_name = parsed_bucket or bucket_name
        index_name = parsed_index

    # Generate query embedding
    if not self.embedding_config:
        self.embedding_config = {"model": "text-embedding-3-small"}
    embedding_model = self.embedding_config.get("model", "text-embedding-3-small")
    embedding_response = await litellm.aembedding(
        model=embedding_model, input=[query]
    )
    query_embedding = embedding_response.data[0]["embedding"]

    # Call QueryVectors API
    url = f"https://s3vectors.{self.aws_region_name}.api.aws/QueryVectors"
    request_body = {
        "vectorBucketName": bucket_name,
        "indexName": index_name,
        "queryVector": {"float32": query_embedding},
        "topK": top_k,
        "returnDistance": True,
        "returnMetadata": True,
    }

    try:
        http_response = await self._sign_and_execute_request(
            "POST", url, data=safe_dumps(request_body)
        )
        if http_response.status_code != 200:
            verbose_logger.error(
                f"QueryVectors failed with status {http_response.status_code}: {http_response.text}"
            )
            return None

        results = http_response.json()
        verbose_logger.debug(
            f"Query returned {len(results.get('vectors', []))} results"
        )
        # Results are returned as-is; no lexical filtering is applied on top
        # of the vector similarity ranking.
        return results
    except Exception as e:
        verbose_logger.exception(f"Error querying vectors: {e}")
        return None

View File

@@ -0,0 +1,478 @@
"""
Vertex AI-specific RAG Ingestion implementation.
Vertex AI RAG Engine handles embedding and chunking internally when files are uploaded,
so this implementation skips the embedding step and directly uploads files to RAG corpora.
Based on: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/rag-api-v1
"""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
if TYPE_CHECKING:
from litellm import Router
from litellm.types.rag import RAGIngestOptions
class VertexAIRAGIngestion(BaseRAGIngestion, VertexBase):
    """
    Vertex AI RAG Engine ingestion implementation.

    Differences from the base pipeline, as implemented here:
    - ``embed`` is a no-op: the RAG Engine embeds files after upload
    - ``store`` creates a RAG corpus when no ``vector_store_id`` is configured
      and uploads the raw file with a multipart request
    - Chunk size/overlap from ``chunking_strategy`` are forwarded to the RAG
      Engine instead of being applied locally
    - ``_import_files_from_gcs`` starts an import from Google Cloud Storage URIs
    """

    def __init__(
        self,
        ingest_options: "RAGIngestOptions",
        router: Optional["Router"] = None,
    ):
        """
        Initialise both base classes and resolve Vertex AI settings.

        Args:
            ingest_options: Ingest configuration passed to the base class
            router: Optional LiteLLM router passed to the base class
        """
        BaseRAGIngestion.__init__(self, ingest_options=ingest_options, router=router)
        VertexBase.__init__(self)

        # Extract Vertex AI specific configs from vector_store_config
        litellm_params = dict(self.vector_store_config)

        # Get project, location, and credentials using VertexBase methods
        self.project_id = self.safe_get_vertex_ai_project(litellm_params)
        self.location = self.get_vertex_ai_location(litellm_params) or "us-central1"
        self.vertex_credentials = self.safe_get_vertex_ai_credentials(litellm_params)

    async def embed(
        self,
        chunks: List[str],
    ) -> Optional[List[List[float]]]:
        """
        Vertex AI RAG Engine handles embedding internally - skip this step.

        Returns:
            None (Vertex AI embeds when files are uploaded to RAG corpus)
        """
        return None

    async def store(
        self,
        file_content: Optional[bytes],
        filename: Optional[str],
        content_type: Optional[str],
        chunks: List[str],
        embeddings: Optional[List[List[float]]],
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Store content in Vertex AI RAG corpus.

        Vertex AI workflow:
        1. Create RAG corpus (if no ``vector_store_id`` is configured)
        2. Upload file using RAG API (Vertex AI handles chunking/embedding)

        Args:
            file_content: Raw file bytes
            filename: Name of the file
            content_type: MIME type
            chunks: Ignored - Vertex AI handles chunking
            embeddings: Ignored - Vertex AI handles embedding

        Returns:
            Tuple of (rag_corpus_id, file_id). ``file_id`` is None when no
            file content or filename was supplied.

        Raises:
            ValueError: If no Vertex AI project could be resolved.
        """
        if not self.project_id:
            raise ValueError(
                "vertex_project is required for Vertex AI RAG ingestion. "
                "Set it in vector_store config."
            )

        # Get or create RAG corpus
        rag_corpus_id = self.vector_store_config.get("vector_store_id")
        if not rag_corpus_id:
            rag_corpus_id = await self._create_rag_corpus(
                display_name=self.ingest_name or "litellm-rag-corpus",
                description=self.vector_store_config.get("description"),
            )

        # Upload file to RAG corpus
        result_file_id = None
        if file_content and filename and rag_corpus_id:
            result_file_id = await self._upload_file_to_corpus(
                rag_corpus_id=rag_corpus_id,
                filename=filename,
                file_content=file_content,
                content_type=content_type,
            )

        return rag_corpus_id, result_file_id

    async def _create_rag_corpus(
        self,
        display_name: str,
        description: Optional[str] = None,
    ) -> str:
        """
        Create a Vertex AI RAG corpus.

        Args:
            display_name: Display name for the corpus
            description: Optional description

        Returns:
            RAG corpus ID (format: projects/{project}/locations/{location}/ragCorpora/{corpus_id})

        Raises:
            Exception: If the create request is rejected, the operation
                reports an error, or no corpus name is returned.
        """
        # Get access token using VertexBase method
        access_token, project_id = self._ensure_access_token(
            credentials=self.vertex_credentials,
            project_id=self.project_id,
            custom_llm_provider="vertex_ai",
        )

        # Use the project_id from token if not set
        if not self.project_id:
            self.project_id = project_id

        # Construct URL using vertex base URL helper
        base_url = get_vertex_base_url(self.location)
        url = (
            f"{base_url}/v1beta1/"
            f"projects/{self.project_id}/locations/{self.location}/ragCorpora"
        )

        # Build request body with camelCase keys (Vertex AI API format)
        request_body: Dict[str, Any] = {
            "displayName": display_name,
        }
        if description:
            request_body["description"] = description

        # Add vector database config if specified. Work on a shallow copy so
        # that adding the embedding model below never mutates the caller's
        # vector_store_config.
        vector_db_config = self.vector_store_config.get("vector_db_config")
        if vector_db_config:
            request_body["vectorDbConfig"] = dict(vector_db_config)

        # Add embedding model config if specified
        embedding_model = self.vector_store_config.get("embedding_model")
        if embedding_model:
            request_body.setdefault("vectorDbConfig", {})
            request_body["vectorDbConfig"]["ragEmbeddingModelConfig"] = {
                "vertexPredictionEndpoint": {"endpoint": embedding_model}
            }

        verbose_logger.debug(f"Creating RAG corpus: {url}")
        verbose_logger.debug(f"Request body: {json.dumps(request_body, indent=2)}")

        client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.RAG,
            params={"timeout": 60.0},
        )
        response = await client.post(
            url,
            json=request_body,
            headers={
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
            },
        )
        if response.status_code not in [200, 201]:
            error_msg = f"Failed to create RAG corpus: {response.text}"
            verbose_logger.error(error_msg)
            raise Exception(error_msg)

        response_data = response.json()
        verbose_logger.debug(
            f"Create corpus response: {json.dumps(response_data, indent=2)}"
        )

        # The response is a long-running operation
        # Check if it's already done or if we need to poll
        if response_data.get("done"):
            # Operation completed immediately: apply the same validation
            # that _poll_operation applies to a polled result.
            if "error" in response_data:
                error = response_data["error"]
                raise Exception(f"Operation failed: {error}")
            corpus_name = response_data.get("response", {}).get("name", "")
            if not corpus_name:
                raise Exception(
                    f"No corpus name in operation response: {response_data}"
                )
        else:
            # Need to poll the operation
            operation_name = response_data.get("name", "")
            verbose_logger.debug(f"Polling operation: {operation_name}")
            corpus_name = await self._poll_operation(
                operation_name=operation_name,
                access_token=access_token,
            )

        verbose_logger.debug(f"Created RAG corpus: {corpus_name}")
        return corpus_name

    async def _poll_operation(
        self,
        operation_name: str,
        access_token: str,
        max_retries: int = 30,
        retry_delay: float = 2.0,
    ) -> str:
        """
        Poll a long-running operation until it completes.

        Args:
            operation_name: The operation name (e.g., "operations/123456")
            access_token: Access token for authentication
            max_retries: Maximum number of polling attempts
            retry_delay: Delay between polling attempts in seconds

        Returns:
            The corpus name from the completed operation

        Raises:
            Exception: If operation fails or times out
        """
        import asyncio

        base_url = get_vertex_base_url(self.location)
        # Operation name is like: projects/{project}/locations/{location}/operations/{operation_id}
        # We need to construct the full URL
        url = f"{base_url}/v1beta1/{operation_name}"

        client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.RAG,
            params={"timeout": 60.0},
        )

        for attempt in range(max_retries):
            response = await client.get(
                url,
                headers={
                    "Authorization": f"Bearer {access_token}",
                },
            )
            if response.status_code != 200:
                error_msg = f"Failed to poll operation: {response.text}"
                verbose_logger.error(error_msg)
                raise Exception(error_msg)

            operation_data = response.json()
            if operation_data.get("done"):
                # Check for errors
                if "error" in operation_data:
                    error = operation_data["error"]
                    raise Exception(f"Operation failed: {error}")

                # Extract corpus name from response
                corpus_name = operation_data.get("response", {}).get("name", "")
                if corpus_name:
                    return corpus_name
                raise Exception(
                    f"No corpus name in operation response: {operation_data}"
                )

            verbose_logger.debug(
                f"Operation not done yet, attempt {attempt + 1}/{max_retries}"
            )
            await asyncio.sleep(retry_delay)

        raise Exception(f"Operation timed out after {max_retries} attempts")

    async def _upload_file_to_corpus(
        self,
        rag_corpus_id: str,
        filename: str,
        file_content: bytes,
        content_type: Optional[str],
    ) -> str:
        """
        Upload a file to Vertex AI RAG corpus using multipart upload.

        Args:
            rag_corpus_id: RAG corpus resource name
            filename: Name of the file
            file_content: File content bytes
            content_type: MIME type; "application/octet-stream" is sent when
                it is missing

        Returns:
            File resource name from the response, or the literal "uploaded"
            when the upload succeeded but the response could not be parsed.

        Raises:
            Exception: If the upload request returns a non-200/201 status.
        """
        # Get access token using VertexBase method
        access_token, _ = self._ensure_access_token(
            credentials=self.vertex_credentials,
            project_id=self.project_id,
            custom_llm_provider="vertex_ai",
        )

        # Construct upload URL using vertex base URL helper
        base_url = get_vertex_base_url(self.location)
        url = f"{base_url}/upload/v1beta1/{rag_corpus_id}/ragFiles:upload"

        # Build metadata for the file with snake_case keys (as per upload API docs)
        metadata: Dict[str, Any] = {
            "rag_file": {
                "display_name": filename,
            }
        }

        # Add description if provided
        description = self.vector_store_config.get("file_description")
        if description:
            metadata["rag_file"]["description"] = description

        # Add chunking configuration if provided; only the values that were
        # actually set are forwarded.
        chunking_strategy = self.chunking_strategy
        if chunking_strategy and isinstance(chunking_strategy, dict):
            chunk_size = chunking_strategy.get("chunk_size")
            chunk_overlap = chunking_strategy.get("chunk_overlap")
            if chunk_size or chunk_overlap:
                fixed_length_chunking: Dict[str, Any] = {}
                if chunk_size:
                    fixed_length_chunking["chunk_size"] = chunk_size
                if chunk_overlap:
                    fixed_length_chunking["chunk_overlap"] = chunk_overlap
                metadata["upload_rag_file_config"] = {
                    "rag_file_transformation_config": {
                        "rag_file_chunking_config": {
                            "fixed_length_chunking": fixed_length_chunking
                        }
                    }
                }

        verbose_logger.debug(f"Uploading file to RAG corpus: {url}")
        verbose_logger.debug(f"Metadata: {json.dumps(metadata, indent=2)}")

        # Prepare multipart form data
        files = {
            "metadata": (None, json.dumps(metadata), "application/json"),
            "file": (
                filename,
                file_content,
                content_type or "application/octet-stream",
            ),
        }

        client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.RAG,
            params={"timeout": 300.0},  # Longer timeout for large files
        )
        response = await client.post(
            url,
            files=files,
            headers={
                "Authorization": f"Bearer {access_token}",
                "X-Goog-Upload-Protocol": "multipart",
            },
        )
        if response.status_code not in [200, 201]:
            error_msg = f"Failed to upload file: {response.text}"
            verbose_logger.error(error_msg)
            raise Exception(error_msg)

        # Parse response to get file ID. The upload itself already succeeded,
        # so a malformed body is deliberately downgraded to a warning.
        try:
            response_data = response.json()
            # The response should contain the rag_file resource name
            file_id = response_data.get("ragFile", {}).get("name", "")
            if not file_id:
                file_id = response_data.get("name", "")
            verbose_logger.debug(f"Upload complete. File ID: {file_id}")
            return file_id
        except Exception as e:
            verbose_logger.warning(f"Could not parse upload response: {e}")
            return "uploaded"

    async def _import_files_from_gcs(
        self,
        rag_corpus_id: str,
        gcs_uris: List[str],
    ) -> str:
        """
        Import files from Google Cloud Storage into RAG corpus.

        Args:
            rag_corpus_id: RAG corpus resource name
            gcs_uris: List of GCS URIs (e.g., ["gs://bucket/file.pdf"])

        Returns:
            Operation name for tracking import progress

        Raises:
            Exception: If the import request returns a non-200/201 status.
        """
        # Get access token using VertexBase method
        access_token, _ = self._ensure_access_token(
            credentials=self.vertex_credentials,
            project_id=self.project_id,
            custom_llm_provider="vertex_ai",
        )

        # Construct import URL using vertex base URL helper
        base_url = get_vertex_base_url(self.location)
        url = f"{base_url}/v1beta1/{rag_corpus_id}/ragFiles:import"

        # Build request body with camelCase keys (Vertex AI API format)
        request_body: Dict[str, Any] = {
            "importRagFilesConfig": {"gcsSource": {"uris": gcs_uris}}
        }

        # Add chunking configuration if provided; a missing value falls back
        # to 1024 (size) / 200 (overlap) once either one is set.
        chunking_strategy = self.chunking_strategy
        if chunking_strategy and isinstance(chunking_strategy, dict):
            chunk_size = chunking_strategy.get("chunk_size")
            chunk_overlap = chunking_strategy.get("chunk_overlap")
            if chunk_size or chunk_overlap:
                request_body["importRagFilesConfig"]["ragFileChunkingConfig"] = {
                    "chunkSize": chunk_size or 1024,
                    "chunkOverlap": chunk_overlap or 200,
                }

        # Add max embedding requests per minute if specified
        max_embedding_qpm = self.vector_store_config.get(
            "max_embedding_requests_per_min"
        )
        if max_embedding_qpm:
            request_body["importRagFilesConfig"][
                "maxEmbeddingRequestsPerMin"
            ] = max_embedding_qpm

        verbose_logger.debug(f"Importing files from GCS: {url}")
        verbose_logger.debug(f"Request body: {json.dumps(request_body, indent=2)}")

        client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.RAG,
            params={"timeout": 60.0},
        )
        response = await client.post(
            url,
            json=request_body,
            headers={
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
            },
        )
        if response.status_code not in [200, 201]:
            error_msg = f"Failed to import files: {response.text}"
            verbose_logger.error(error_msg)
            raise Exception(error_msg)

        response_data = response.json()
        operation_name = response_data.get("name", "")
        verbose_logger.debug(f"Import operation started: {operation_name}")
        return operation_name

View File

@@ -0,0 +1,441 @@
"""
RAG Ingest API for LiteLLM.
Provides an all-in-one API for document ingestion:
Upload -> (OCR) -> Chunk -> Embed -> Vector Store
"""
from __future__ import annotations
__all__ = ["ingest", "aingest", "query", "aquery"]
import asyncio
import contextvars
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Coroutine,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
import httpx
import litellm
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
from litellm.rag.ingestion.gemini_ingestion import GeminiRAGIngestion
from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion
from litellm.rag.ingestion.s3_vectors_ingestion import S3VectorsRAGIngestion
from litellm.rag.ingestion.vertex_ai_ingestion import VertexAIRAGIngestion
from litellm.rag.rag_query import RAGQuery
from litellm.types.rag import (
RAGIngestOptions,
RAGIngestResponse,
)
from litellm.types.utils import ModelResponse
from litellm.utils import client
if TYPE_CHECKING:
from litellm import Router
# Registry of provider-specific ingestion classes.
# Keys are the `custom_llm_provider` values read from the vector store config;
# get_ingestion_class() resolves providers through this mapping.
INGESTION_REGISTRY: Dict[str, Type[BaseRAGIngestion]] = {
"openai": OpenAIRAGIngestion,
"bedrock": BedrockRAGIngestion,
"gemini": GeminiRAGIngestion,
"s3_vectors": S3VectorsRAGIngestion,
"vertex_ai": VertexAIRAGIngestion,
}
def get_ingestion_class(provider: str) -> Type[BaseRAGIngestion]:
    """
    Resolve a vector store provider name to its ingestion class.

    Args:
        provider: The vector store provider name (e.g., 'openai')

    Returns:
        The class registered for ``provider`` in ``INGESTION_REGISTRY``

    Raises:
        ValueError: If no class is registered for ``provider``
    """
    resolved = INGESTION_REGISTRY.get(provider)
    if resolved is not None:
        return resolved

    # Unknown provider: list what is available so the caller can fix the config.
    supported = ", ".join(INGESTION_REGISTRY)
    raise ValueError(
        f"Provider '{provider}' is not supported for RAG ingestion. "
        f"Supported providers: {supported}"
    )
async def _execute_ingest_pipeline(
    ingest_options: RAGIngestOptions,
    file_data: Optional[Tuple[str, bytes, str]] = None,
    file_url: Optional[str] = None,
    file_id: Optional[str] = None,
    router: Optional["Router"] = None,
) -> RAGIngestResponse:
    """
    Run the ingest pipeline with the provider-specific implementation.

    The provider is read from ``ingest_options["vector_store"]["custom_llm_provider"]``
    and defaults to "openai" when the key is absent.

    Args:
        ingest_options: Configuration for the ingest pipeline
        file_data: Tuple of (filename, content_bytes, content_type)
        file_url: URL to fetch file from
        file_id: Existing file ID to use
        router: Optional LiteLLM router for load balancing

    Returns:
        RAGIngestResponse with status and IDs
    """
    store_settings = ingest_options.get("vector_store") or {}
    provider_name = store_settings.get("custom_llm_provider", "openai")

    # Look up the provider implementation and build a pipeline instance.
    pipeline = get_ingestion_class(provider_name)(
        ingest_options=ingest_options,
        router=router,
    )

    return await pipeline.ingest(
        file_data=file_data,
        file_url=file_url,
        file_id=file_id,
    )
####### PUBLIC API ###################
@client
async def aingest(
    ingest_options: Dict[str, Any],
    file_data: Optional[Tuple[str, bytes, str]] = None,
    file: Optional[Dict[str, str]] = None,
    file_url: Optional[str] = None,
    file_id: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    **kwargs,
) -> RAGIngestResponse:
    """
    Async: Ingest a document into a vector store.

    The call is delegated to the sync ``ingest`` entry point in the default
    executor with ``aingest=True``; ``ingest`` then hands back a coroutine
    that is awaited here.

    Args:
        ingest_options: Configuration for the ingest pipeline
        file_data: Tuple of (filename, content_bytes, content_type)
        file: Dict with {filename, content (base64), content_type} - for JSON API
        file_url: URL to fetch file from
        file_id: Existing file ID to use
        timeout: Forwarded unchanged to ``ingest``

    Returns:
        RAGIngestResponse produced by the provider pipeline.

    Raises:
        Any failure is passed through ``litellm.exception_type`` for mapping.

    Example:
        ```python
        response = await litellm.aingest(
            ingest_options={
                "vector_store": {
                    "custom_llm_provider": "openai",
                    "litellm_credential_name": "my-openai-creds",  # optional
                }
            },
            file_url="https://example.com/doc.pdf",
        )
        ```
    """
    local_vars = locals()
    try:
        loop = asyncio.get_event_loop()
        kwargs["aingest"] = True

        func = partial(
            ingest,
            ingest_options=ingest_options,
            file_data=file_data,
            file=file,
            file_url=file_url,
            file_id=file_id,
            timeout=timeout,
            **kwargs,
        )

        # Run inside a copy of the current context so context variables set
        # by the caller remain visible in the executor thread.
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)

        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response
        return response
    except Exception as e:
        # "vector_store" may be present but None; fall back to an empty dict
        # so that building the error never masks the original exception.
        vector_store_config = ingest_options.get("vector_store") or {}
        raise litellm.exception_type(
            model=None,
            custom_llm_provider=vector_store_config.get("custom_llm_provider"),
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
async def _execute_query_pipeline(
    model: str,
    messages: List[Any],
    retrieval_config: Dict[str, Any],
    rerank: Optional[Dict[str, Any]] = None,
    stream: bool = False,
    **kwargs,
) -> ModelResponse:
    """
    Execute the RAG query pipeline.

    Steps: take the query from the last message, search the vector store,
    optionally rerank the hits, insert a context message before the last
    message, and run the completion. For non-streaming ``ModelResponse``
    results the search (and rerank) output is attached to the response.

    Raises:
        ValueError: If no query text can be extracted from ``messages``.
    """
    # A router, when supplied, is used for the completion call so that
    # virtual model names are resolved; it must not leak into other kwargs.
    router: Optional["Router"] = kwargs.pop("router", None)

    # 1. Query text comes from the last message
    query_text = RAGQuery.extract_query_from_messages(messages)
    if not query_text:
        raise ValueError("No query found in messages for RAG query")

    # 2. Vector store search
    search_response = await litellm.vector_stores.asearch(
        vector_store_id=retrieval_config["vector_store_id"],
        query=query_text,
        max_num_results=retrieval_config.get("top_k", 10),
        custom_llm_provider=retrieval_config.get("custom_llm_provider", "openai"),
        **kwargs,
    )
    context_chunks = search_response.get("data", [])
    rerank_response = None

    # 3. Optional rerank; only runs when there is text to rerank
    if rerank and rerank.get("enabled"):
        candidate_docs = RAGQuery.extract_documents_from_search(search_response)
        if candidate_docs:
            rerank_response = await litellm.arerank(
                model=rerank["model"],
                query=query_text,
                documents=candidate_docs,
                top_n=rerank.get("top_n", 5),
            )
            context_chunks = RAGQuery.get_top_chunks_from_rerank(
                search_response, rerank_response
            )

    # 4. Place the context message directly before the final message
    context_message = RAGQuery.build_context_message(context_chunks)
    modified_messages = messages[:-1] + [context_message, messages[-1]]

    run_completion = litellm.acompletion if router is None else router.acompletion
    response = await run_completion(
        model=model,
        messages=modified_messages,
        stream=stream,
        **kwargs,
    )

    # 5. Attach search results to non-streaming responses
    if isinstance(response, ModelResponse) and not stream:
        response = RAGQuery.add_search_results_to_response(
            response=response,
            search_results=search_response,
            rerank_results=rerank_response,
        )

    return response  # type: ignore[return-value]
@client
async def aquery(
    model: str,
    messages: List[Any],
    retrieval_config: Dict[str, Any],
    rerank: Optional[Dict[str, Any]] = None,
    stream: bool = False,
    **kwargs,
) -> ModelResponse:
    """
    Async: Query a RAG pipeline.

    Delegates to the sync ``query`` entry point in the default executor with
    ``aquery=True`` and awaits the coroutine it hands back. Failures are
    passed through ``litellm.exception_type``.
    """
    local_vars = locals()
    try:
        loop = asyncio.get_event_loop()
        kwargs["aquery"] = True

        sync_call = partial(
            query,
            model=model,
            messages=messages,
            retrieval_config=retrieval_config,
            rerank=rerank,
            stream=stream,
            **kwargs,
        )
        # Execute within a copy of the caller's context variables.
        call_in_context = partial(contextvars.copy_context().run, sync_call)
        outcome = await loop.run_in_executor(None, call_in_context)

        if not asyncio.iscoroutine(outcome):
            return outcome
        return await outcome
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=retrieval_config.get("custom_llm_provider"),
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
@client
def query(
    model: str,
    messages: List[Any],
    retrieval_config: Dict[str, Any],
    rerank: Optional[Dict[str, Any]] = None,
    stream: bool = False,
    **kwargs,
) -> Union[ModelResponse, Coroutine[Any, Any, ModelResponse]]:
    """
    Query a RAG pipeline.

    Args:
        model: Model used for the final completion
        messages: Chat messages; the query is taken from the last one
        retrieval_config: Vector store settings (``vector_store_id`` is
            required by the pipeline; ``top_k`` and ``custom_llm_provider``
            are optional)
        rerank: Optional rerank settings
        stream: Whether to request a streaming completion
        **kwargs: Forwarded to the pipeline. The internal flag
            ``aquery=True`` makes this function return the un-awaited
            pipeline coroutine instead of running it.

    Returns:
        The pipeline result, or a coroutine producing it when called with
        ``aquery=True``.

    Raises:
        Any failure is passed through ``litellm.exception_type`` for mapping.
    """
    local_vars = locals()
    try:
        _is_async = kwargs.pop("aquery", False) is True

        pipeline = _execute_query_pipeline(
            model=model,
            messages=messages,
            retrieval_config=retrieval_config,
            rerank=rerank,
            stream=stream,
            **kwargs,
        )
        if _is_async:
            return pipeline

        # asyncio.get_event_loop() raises RuntimeError when the thread has no
        # current loop (worker threads, after asyncio.run(), Python 3.14+).
        # Create and register a loop in that case, or if the loop is closed.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(pipeline)
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=retrieval_config.get("custom_llm_provider"),
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
@client
def ingest(
    ingest_options: Dict[str, Any],
    file_data: Optional[Tuple[str, bytes, str]] = None,
    file: Optional[Dict[str, str]] = None,
    file_url: Optional[str] = None,
    file_id: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    **kwargs,
) -> Union[RAGIngestResponse, Coroutine[Any, Any, RAGIngestResponse]]:
    """
    Ingest a document into a vector store.

    Args:
        ingest_options: Configuration for the ingest pipeline
        file_data: Tuple of (filename, content_bytes, content_type)
        file: Dict with {filename, content (base64), content_type} - for JSON API.
            Only used when ``file_data`` is not given.
        file_url: URL to fetch file from
        file_id: Existing file ID to use
        timeout: Accepted for API compatibility; not read by this function
        **kwargs: ``router`` is forwarded to the pipeline. The internal flag
            ``aingest=True`` makes this function return the un-awaited
            pipeline coroutine instead of running it.

    Returns:
        RAGIngestResponse, or a coroutine producing it when called with
        ``aingest=True``.

    Raises:
        Any failure is passed through ``litellm.exception_type`` for mapping.

    Example:
        ```python
        response = litellm.ingest(
            ingest_options={
                "vector_store": {
                    "custom_llm_provider": "openai",
                    "litellm_credential_name": "my-openai-creds",  # optional
                }
            },
            file_data=("doc.txt", b"Hello world", "text/plain"),
        )
        ```
    """
    import base64

    local_vars = locals()
    try:
        _is_async = kwargs.pop("aingest", False) is True
        router: Optional["Router"] = kwargs.get("router")

        # Convert file dict to file_data tuple if provided
        if file is not None and file_data is None:
            filename = file.get("filename", "document")
            content_b64 = file.get("content", "")
            content_type = file.get("content_type", "application/octet-stream")
            content_bytes = base64.b64decode(content_b64)
            file_data = (filename, content_bytes, content_type)

        pipeline = _execute_ingest_pipeline(
            ingest_options=ingest_options,  # type: ignore
            file_data=file_data,
            file_url=file_url,
            file_id=file_id,
            router=router,
        )
        if _is_async:
            return pipeline

        # asyncio.get_event_loop() raises RuntimeError when the thread has no
        # current loop (worker threads, after asyncio.run(), Python 3.14+).
        # Create and register a loop in that case, or if the loop is closed.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(pipeline)
    except Exception as e:
        # "vector_store" may be present but None; fall back to an empty dict
        # so that building the error never masks the original exception.
        vector_store_config = ingest_options.get("vector_store") or {}
        raise litellm.exception_type(
            model=None,
            custom_llm_provider=vector_store_config.get("custom_llm_provider"),
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )

View File

@@ -0,0 +1,120 @@
from typing import Any, Dict, List, Optional, Union
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from litellm.types.utils import ModelResponse
from litellm.types.vector_stores import (
VectorStoreResultContent,
VectorStoreSearchResponse,
)
class RAGQuery:
    """
    Stateless helpers for the RAG query pipeline.

    All methods are static and operate on plain dict-like search / rerank
    payloads and chat messages.
    """

    # Prefix placed in front of the concatenated context chunks.
    CONTENT_PREFIX_STRING = "Context:\n\n"

    @staticmethod
    def extract_query_from_messages(
        messages: "List[AllMessageValues]",
    ) -> Optional[str]:
        """
        Extract the query from the last message.

        Args:
            messages: Chat messages; only the last one is inspected.

        Returns:
            The message content if it is a string, the text of the first
            ``{"type": "text"}`` item if it is a list, otherwise None.
        """
        if not messages:
            return None

        last_message = messages[-1]
        if not isinstance(last_message, dict) or "content" not in last_message:
            return None

        content = last_message["content"]
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Handle list of content items, extract text from first text item
            for item in content:
                if (
                    isinstance(item, dict)
                    and item.get("type") == "text"
                    and "text" in item
                ):
                    return item["text"]
        return None

    @staticmethod
    def build_context_message(
        context_chunks: List[Any],
    ) -> "ChatCompletionUserMessage":
        """
        Build a user message that carries the retrieved context.

        Args:
            context_chunks: Search results. Each chunk may be a dict with a
                ``content`` list of ``{"text": ...}`` items, a dict with a
                plain ``text`` key, or a string. Anything else is ignored.

        Returns:
            A ``{"role": "user", "content": ...}`` message whose content is
            the prefix followed by each chunk text and a blank line.
        """
        parts: List[str] = [RAGQuery.CONTENT_PREFIX_STRING]

        for chunk in context_chunks:
            if isinstance(chunk, str):
                parts.append(chunk + "\n\n")
                continue
            if not isinstance(chunk, dict):
                continue

            result_content = chunk.get("content")
            if result_content:
                for content_item in result_content:
                    # Skip malformed entries instead of failing on .get
                    if not isinstance(content_item, dict):
                        continue
                    content_text = content_item.get("text")
                    if content_text:
                        parts.append(content_text + "\n\n")
            elif "text" in chunk:  # Fallback for simple dict with text
                parts.append(chunk["text"] + "\n\n")

        return {
            "role": "user",
            "content": "".join(parts),
        }

    @staticmethod
    def add_search_results_to_response(
        response: "ModelResponse",
        search_results: "VectorStoreSearchResponse",
        rerank_results: Optional[Any] = None,
    ) -> "ModelResponse":
        """
        Add search results to the response choices.

        For every choice that has a ``message``, ``search_results`` (and
        ``rerank_results`` when truthy) are stored under
        ``message.provider_specific_fields``.

        Returns:
            The same ``response`` object, modified in place.
        """
        if hasattr(response, "choices") and response.choices:
            for choice in response.choices:
                message = getattr(choice, "message", None)
                if message is None:
                    continue
                # Get existing provider_specific_fields or create new dict
                provider_fields = (
                    getattr(message, "provider_specific_fields", None) or {}
                )
                provider_fields["search_results"] = search_results
                if rerank_results:
                    provider_fields["rerank_results"] = rerank_results
                setattr(message, "provider_specific_fields", provider_fields)
        return response

    @staticmethod
    def extract_documents_from_search(
        search_response: Any,
    ) -> List[Union[str, Dict[str, Any]]]:
        """
        Extract text documents from vector store search response.

        Missing or None ``data`` / ``content`` entries are treated as empty.
        """
        documents: List[Union[str, Dict[str, Any]]] = []
        for result in search_response.get("data") or []:
            for content in result.get("content") or []:
                if content.get("type") == "text" and content.get("text"):
                    documents.append(content["text"])
        return documents

    @staticmethod
    def get_top_chunks_from_rerank(
        search_response: Any, rerank_response: Any
    ) -> List[Any]:
        """
        Get the original search results corresponding to the top reranked results.

        Rerank entries whose ``index`` is missing, not an int, or outside
        ``0 <= index < len(data)`` are skipped; a negative index would
        otherwise silently select a result from the end of the list.
        """
        original_results = search_response.get("data") or []
        top_chunks: List[Any] = []
        for result in rerank_response.get("results") or []:
            index = result.get("index")
            if isinstance(index, int) and 0 <= index < len(original_results):
                top_chunks.append(original_results[index])
        return top_chunks

View File

@@ -0,0 +1,9 @@
"""
Text splitting utilities for RAG ingestion.
"""
from litellm.rag.text_splitters.recursive_character_text_splitter import (
RecursiveCharacterTextSplitter,
)
__all__ = ["RecursiveCharacterTextSplitter"]

View File

@@ -0,0 +1,141 @@
"""
RecursiveCharacterTextSplitter for RAG ingestion.
A simple implementation that splits text recursively by different separators.
"""
from typing import List, Optional
from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
class RecursiveCharacterTextSplitter:
    """
    Split text recursively by different separators.

    Tries to split by the first separator found in the text, then recursively
    splits oversized pieces with the remaining (finer) separators. Small
    pieces are merged back into chunks of at most ``chunk_size`` characters,
    with up to ``chunk_overlap`` characters carried over between chunks.
    """

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        separators: Optional[List[str]] = None,
    ):
        """
        Args:
            chunk_size: Maximum chunk length in characters
            chunk_overlap: Maximum number of characters repeated at the start
                of the next chunk
            separators: Separators to try in order; defaults to paragraph,
                line, space, then per-character ("")
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        return self._split_text(text, self.separators)

    def _split_text(
        self, text: str, separators: List[str], depth: int = 0
    ) -> List[str]:
        """
        Recursively split text using separators.

        Args:
            text: Text to split
            separators: Separators still available at this recursion level
            depth: Current recursion depth; past DEFAULT_MAX_RECURSE_DEPTH the
                text is cut into fixed ``chunk_size`` slices
        """
        from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH

        if depth > DEFAULT_MAX_RECURSE_DEPTH:
            # Max depth reached, return text as-is split into chunk_size pieces
            return [
                text[i : i + self.chunk_size]
                for i in range(0, len(text), self.chunk_size)
            ]

        final_chunks: List[str] = []

        # Pick the first separator that occurs in the text; "" always matches
        # and means "split per character". Fall back to the last separator.
        separator = separators[-1]
        new_separators: List[str] = []
        for i, sep in enumerate(separators):
            if sep == "":
                separator = sep
                break
            if sep in text:
                separator = sep
                new_separators = separators[i + 1 :]
                break

        # Split by the chosen separator
        splits = text.split(separator) if separator else list(text)

        # Collect pieces that fit; flush and recurse when one is too big
        good_splits: List[str] = []
        for split in splits:
            if len(split) < self.chunk_size:
                good_splits.append(split)
                continue

            # Piece is too big: merge what we have so far, then break it down
            if good_splits:
                final_chunks.extend(self._merge_splits(good_splits, separator))
                good_splits = []
            if new_separators:
                # Recursively split with finer separators
                final_chunks.extend(
                    self._split_text(split, new_separators, depth + 1)
                )
            else:
                # No more separators, force split
                final_chunks.extend(self._force_split(split))

        # Merge remaining good splits
        if good_splits:
            final_chunks.extend(self._merge_splits(good_splits, separator))

        return final_chunks

    def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
        """
        Merge splits into chunks respecting chunk_size and chunk_overlap.

        After a chunk is emitted, leading pieces are dropped until the tail
        that is carried over is no longer than ``chunk_overlap`` AND still
        leaves room for the next piece. The tail may become empty, so no text
        is repeated when ``chunk_overlap`` is 0 and a carried-over tail can
        never push the next chunk past ``chunk_size``.
        """
        sep_len = len(separator)
        chunks: List[str] = []
        window: List[str] = []  # pieces of the chunk being built
        window_len = 0  # == len(separator.join(window))

        for piece in splits:
            piece_len = len(piece)
            joiner_len = sep_len if window else 0

            if window and window_len + joiner_len + piece_len > self.chunk_size:
                chunk_text = separator.join(window).strip()
                if chunk_text:
                    chunks.append(chunk_text)

                # Trim the carried-over tail from the front
                while window and (
                    window_len > self.chunk_overlap
                    or window_len + sep_len + piece_len > self.chunk_size
                ):
                    dropped = window.pop(0)
                    window_len -= len(dropped) + (sep_len if window else 0)

            window_len += piece_len + (sep_len if window else 0)
            window.append(piece)

        # Add remaining
        if window:
            chunk_text = separator.join(window).strip()
            if chunk_text:
                chunks.append(chunk_text)

        return chunks

    def _force_split(self, text: str) -> List[str]:
        """
        Force split text by chunk_size when no separator works.

        Consecutive windows overlap by ``chunk_overlap`` characters. The
        window always advances by at least one character, so an overlap that
        is greater than or equal to ``chunk_size`` cannot stall the loop.
        """
        chunks: List[str] = []
        stride = max(1, self.chunk_size - self.chunk_overlap)
        start = 0
        while start < len(text):
            end = start + self.chunk_size
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            if end >= len(text):
                break
            start += stride
        return chunks

View File

@@ -0,0 +1,64 @@
"""
RAG utility functions.
Provides provider configuration utilities similar to ProviderConfigManager.
"""
from typing import TYPE_CHECKING, Type
if TYPE_CHECKING:
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
def get_rag_ingestion_class(custom_llm_provider: str) -> Type["BaseRAGIngestion"]:
    """
    Resolve the RAG ingestion class registered for a provider.

    The provider modules are imported inside the function rather than at
    module import time.

    Args:
        custom_llm_provider: The LLM provider name (e.g., "openai", "bedrock", "vertex_ai")

    Returns:
        The ingestion class mapped to the given provider name

    Raises:
        ValueError: If no ingestion class is mapped to the provider
    """
    from litellm.llms.vertex_ai.rag_engine.ingestion import VertexAIRAGIngestion
    from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
    from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion

    registry = {
        "openai": OpenAIRAGIngestion,
        "bedrock": BedrockRAGIngestion,
        "vertex_ai": VertexAIRAGIngestion,
    }

    # Guard clause: reject unknown providers and list the known ones.
    if custom_llm_provider not in registry:
        supported = list(registry.keys())
        raise ValueError(
            f"RAG ingestion not supported for provider: {custom_llm_provider}. "
            f"Supported providers: {supported}"
        )

    return registry[custom_llm_provider]
def get_rag_transformation_class(custom_llm_provider: str):
    """
    Look up the RAG transformation class for a provider, if it has one.

    Args:
        custom_llm_provider: The LLM provider name

    Returns:
        The transformation class for "vertex_ai"; None for every other
        provider name
    """
    # Only Vertex AI has a dedicated transformation; OpenAI and Bedrock
    # (and anything else) get None.
    if custom_llm_provider != "vertex_ai":
        return None

    # Imported here so the Vertex AI module is only loaded when requested.
    from litellm.llms.vertex_ai.rag_engine.transformation import (
        VertexAIRAGTransformation,
    )

    return VertexAIRAGTransformation