"""
Vertex AI-specific RAG Ingestion implementation.

Vertex AI RAG Engine handles embedding and chunking internally when files are uploaded,
so this implementation skips the embedding step and directly uploads files to RAG corpora.

Based on: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/rag-api-v1
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|
|
|
from litellm._logging import verbose_logger
|
|
from litellm.llms.custom_httpx.http_handler import (
|
|
get_async_httpx_client,
|
|
httpxSpecialProvider,
|
|
)
|
|
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
|
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
|
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
|
|
|
if TYPE_CHECKING:
|
|
from litellm import Router
|
|
from litellm.types.rag import RAGIngestOptions
|
|
|
|
|
|
class VertexAIRAGIngestion(BaseRAGIngestion, VertexBase):
    """
    Vertex AI RAG Engine ingestion implementation.

    Key differences from base:
    - Embedding is handled by Vertex AI RAG Engine when files are uploaded
    - Files are uploaded using the RAG API (import or upload)
    - Chunking is done by Vertex AI RAG Engine (supports custom chunking config)
    - Supports Google Cloud Storage (GCS) and Google Drive sources
    - Supports custom parsing configurations (layout parser, LLM parser)
    """

    def __init__(
        self,
        ingest_options: "RAGIngestOptions",
        router: Optional["Router"] = None,
    ):
        """
        Set up ingestion and resolve Vertex AI project/location/credentials.

        Args:
            ingest_options: Generic RAG ingestion options (passed to base).
            router: Optional litellm Router instance (passed to base).
        """
        BaseRAGIngestion.__init__(self, ingest_options=ingest_options, router=router)
        VertexBase.__init__(self)

        # Vertex-specific settings ride along in the generic vector store config.
        params = dict(self.vector_store_config)

        # Resolve GCP project, location, and credentials via VertexBase helpers.
        # Location falls back to us-central1 when none is configured.
        self.project_id = self.safe_get_vertex_ai_project(params)
        self.location = self.get_vertex_ai_location(params) or "us-central1"
        self.vertex_credentials = self.safe_get_vertex_ai_credentials(params)
|
|
|
|
async def embed(
|
|
self,
|
|
chunks: List[str],
|
|
) -> Optional[List[List[float]]]:
|
|
"""
|
|
Vertex AI RAG Engine handles embedding internally - skip this step.
|
|
|
|
Returns:
|
|
None (Vertex AI embeds when files are uploaded to RAG corpus)
|
|
"""
|
|
# Vertex AI RAG Engine handles embedding when files are uploaded
|
|
return None
|
|
|
|
async def store(
|
|
self,
|
|
file_content: Optional[bytes],
|
|
filename: Optional[str],
|
|
content_type: Optional[str],
|
|
chunks: List[str],
|
|
embeddings: Optional[List[List[float]]],
|
|
) -> Tuple[Optional[str], Optional[str]]:
|
|
"""
|
|
Store content in Vertex AI RAG corpus.
|
|
|
|
Vertex AI workflow:
|
|
1. Create RAG corpus (if not provided)
|
|
2. Upload file using RAG API (Vertex AI handles chunking/embedding)
|
|
|
|
Args:
|
|
file_content: Raw file bytes
|
|
filename: Name of the file
|
|
content_type: MIME type
|
|
chunks: Ignored - Vertex AI handles chunking
|
|
embeddings: Ignored - Vertex AI handles embedding
|
|
|
|
Returns:
|
|
Tuple of (rag_corpus_id, file_id)
|
|
"""
|
|
if not self.project_id:
|
|
raise ValueError(
|
|
"vertex_project is required for Vertex AI RAG ingestion. "
|
|
"Set it in vector_store config."
|
|
)
|
|
|
|
# Get or create RAG corpus
|
|
rag_corpus_id = self.vector_store_config.get("vector_store_id")
|
|
if not rag_corpus_id:
|
|
rag_corpus_id = await self._create_rag_corpus(
|
|
display_name=self.ingest_name or "litellm-rag-corpus",
|
|
description=self.vector_store_config.get("description"),
|
|
)
|
|
|
|
# Upload file to RAG corpus
|
|
result_file_id = None
|
|
if file_content and filename and rag_corpus_id:
|
|
result_file_id = await self._upload_file_to_corpus(
|
|
rag_corpus_id=rag_corpus_id,
|
|
filename=filename,
|
|
file_content=file_content,
|
|
content_type=content_type,
|
|
)
|
|
|
|
return rag_corpus_id, result_file_id
|
|
|
|
async def _create_rag_corpus(
|
|
self,
|
|
display_name: str,
|
|
description: Optional[str] = None,
|
|
) -> str:
|
|
"""
|
|
Create a Vertex AI RAG corpus.
|
|
|
|
Args:
|
|
display_name: Display name for the corpus
|
|
description: Optional description
|
|
|
|
Returns:
|
|
RAG corpus ID (format: projects/{project}/locations/{location}/ragCorpora/{corpus_id})
|
|
"""
|
|
# Get access token using VertexBase method
|
|
access_token, project_id = self._ensure_access_token(
|
|
credentials=self.vertex_credentials,
|
|
project_id=self.project_id,
|
|
custom_llm_provider="vertex_ai",
|
|
)
|
|
|
|
# Use the project_id from token if not set
|
|
if not self.project_id:
|
|
self.project_id = project_id
|
|
|
|
# Construct URL using vertex base URL helper
|
|
base_url = get_vertex_base_url(self.location)
|
|
url = (
|
|
f"{base_url}/v1beta1/"
|
|
f"projects/{self.project_id}/locations/{self.location}/ragCorpora"
|
|
)
|
|
|
|
# Build request body with camelCase keys (Vertex AI API format)
|
|
request_body: Dict[str, Any] = {
|
|
"displayName": display_name,
|
|
}
|
|
|
|
if description:
|
|
request_body["description"] = description
|
|
|
|
# Add vector database config if specified
|
|
vector_db_config = self.vector_store_config.get("vector_db_config")
|
|
if vector_db_config:
|
|
request_body["vectorDbConfig"] = vector_db_config
|
|
|
|
# Add embedding model config if specified
|
|
embedding_model = self.vector_store_config.get("embedding_model")
|
|
if embedding_model:
|
|
if "vectorDbConfig" not in request_body:
|
|
request_body["vectorDbConfig"] = {}
|
|
request_body["vectorDbConfig"]["ragEmbeddingModelConfig"] = {
|
|
"vertexPredictionEndpoint": {"endpoint": embedding_model}
|
|
}
|
|
|
|
verbose_logger.debug(f"Creating RAG corpus: {url}")
|
|
verbose_logger.debug(f"Request body: {json.dumps(request_body, indent=2)}")
|
|
|
|
client = get_async_httpx_client(
|
|
llm_provider=httpxSpecialProvider.RAG,
|
|
params={"timeout": 60.0},
|
|
)
|
|
|
|
response = await client.post(
|
|
url,
|
|
json=request_body,
|
|
headers={
|
|
"Authorization": f"Bearer {access_token}",
|
|
"Content-Type": "application/json",
|
|
},
|
|
)
|
|
if response.status_code not in [200, 201]:
|
|
error_msg = f"Failed to create RAG corpus: {response.text}"
|
|
verbose_logger.error(error_msg)
|
|
raise Exception(error_msg)
|
|
|
|
response_data = response.json()
|
|
verbose_logger.debug(
|
|
f"Create corpus response: {json.dumps(response_data, indent=2)}"
|
|
)
|
|
|
|
# The response is a long-running operation
|
|
# Check if it's already done or if we need to poll
|
|
if response_data.get("done"):
|
|
# Operation completed immediately
|
|
corpus_name = response_data.get("response", {}).get("name", "")
|
|
else:
|
|
# Need to poll the operation
|
|
operation_name = response_data.get("name", "")
|
|
verbose_logger.debug(f"Polling operation: {operation_name}")
|
|
corpus_name = await self._poll_operation(
|
|
operation_name=operation_name,
|
|
access_token=access_token,
|
|
)
|
|
|
|
verbose_logger.debug(f"Created RAG corpus: {corpus_name}")
|
|
return corpus_name
|
|
|
|
async def _poll_operation(
|
|
self,
|
|
operation_name: str,
|
|
access_token: str,
|
|
max_retries: int = 30,
|
|
retry_delay: float = 2.0,
|
|
) -> str:
|
|
"""
|
|
Poll a long-running operation until it completes.
|
|
|
|
Args:
|
|
operation_name: The operation name (e.g., "operations/123456")
|
|
access_token: Access token for authentication
|
|
max_retries: Maximum number of polling attempts
|
|
retry_delay: Delay between polling attempts in seconds
|
|
|
|
Returns:
|
|
The corpus name from the completed operation
|
|
|
|
Raises:
|
|
Exception: If operation fails or times out
|
|
"""
|
|
import asyncio
|
|
|
|
base_url = get_vertex_base_url(self.location)
|
|
# Operation name is like: projects/{project}/locations/{location}/operations/{operation_id}
|
|
# We need to construct the full URL
|
|
url = f"{base_url}/v1beta1/{operation_name}"
|
|
|
|
client = get_async_httpx_client(
|
|
llm_provider=httpxSpecialProvider.RAG,
|
|
params={"timeout": 60.0},
|
|
)
|
|
|
|
for attempt in range(max_retries):
|
|
response = await client.get(
|
|
url,
|
|
headers={
|
|
"Authorization": f"Bearer {access_token}",
|
|
},
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
error_msg = f"Failed to poll operation: {response.text}"
|
|
verbose_logger.error(error_msg)
|
|
raise Exception(error_msg)
|
|
|
|
operation_data = response.json()
|
|
|
|
if operation_data.get("done"):
|
|
# Check for errors
|
|
if "error" in operation_data:
|
|
error = operation_data["error"]
|
|
raise Exception(f"Operation failed: {error}")
|
|
|
|
# Extract corpus name from response
|
|
corpus_name = operation_data.get("response", {}).get("name", "")
|
|
if corpus_name:
|
|
return corpus_name
|
|
else:
|
|
raise Exception(
|
|
f"No corpus name in operation response: {operation_data}"
|
|
)
|
|
|
|
verbose_logger.debug(
|
|
f"Operation not done yet, attempt {attempt + 1}/{max_retries}"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
|
|
raise Exception(f"Operation timed out after {max_retries} attempts")
|
|
|
|
async def _upload_file_to_corpus(
|
|
self,
|
|
rag_corpus_id: str,
|
|
filename: str,
|
|
file_content: bytes,
|
|
content_type: Optional[str],
|
|
) -> str:
|
|
"""
|
|
Upload a file to Vertex AI RAG corpus using multipart upload.
|
|
|
|
Args:
|
|
rag_corpus_id: RAG corpus resource name
|
|
filename: Name of the file
|
|
file_content: File content bytes
|
|
content_type: MIME type
|
|
|
|
Returns:
|
|
File ID or resource name
|
|
"""
|
|
# Get access token using VertexBase method
|
|
access_token, _ = self._ensure_access_token(
|
|
credentials=self.vertex_credentials,
|
|
project_id=self.project_id,
|
|
custom_llm_provider="vertex_ai",
|
|
)
|
|
|
|
# Construct upload URL using vertex base URL helper
|
|
base_url = get_vertex_base_url(self.location)
|
|
url = f"{base_url}/upload/v1beta1/" f"{rag_corpus_id}/ragFiles:upload"
|
|
|
|
# Build metadata for the file with snake_case keys (as per upload API docs)
|
|
metadata: Dict[str, Any] = {
|
|
"rag_file": {
|
|
"display_name": filename,
|
|
}
|
|
}
|
|
|
|
# Add description if provided
|
|
description = self.vector_store_config.get("file_description")
|
|
if description:
|
|
metadata["rag_file"]["description"] = description
|
|
|
|
# Add chunking configuration if provided
|
|
chunking_strategy = self.chunking_strategy
|
|
if chunking_strategy and isinstance(chunking_strategy, dict):
|
|
chunk_size = chunking_strategy.get("chunk_size")
|
|
chunk_overlap = chunking_strategy.get("chunk_overlap")
|
|
|
|
if chunk_size or chunk_overlap:
|
|
if "upload_rag_file_config" not in metadata:
|
|
metadata["upload_rag_file_config"] = {}
|
|
|
|
metadata["upload_rag_file_config"]["rag_file_transformation_config"] = {
|
|
"rag_file_chunking_config": {"fixed_length_chunking": {}}
|
|
}
|
|
|
|
chunking_config = metadata["upload_rag_file_config"][
|
|
"rag_file_transformation_config"
|
|
]["rag_file_chunking_config"]["fixed_length_chunking"]
|
|
|
|
if chunk_size:
|
|
chunking_config["chunk_size"] = chunk_size
|
|
if chunk_overlap:
|
|
chunking_config["chunk_overlap"] = chunk_overlap
|
|
|
|
verbose_logger.debug(f"Uploading file to RAG corpus: {url}")
|
|
verbose_logger.debug(f"Metadata: {json.dumps(metadata, indent=2)}")
|
|
|
|
# Prepare multipart form data
|
|
files = {
|
|
"metadata": (None, json.dumps(metadata), "application/json"),
|
|
"file": (
|
|
filename,
|
|
file_content,
|
|
content_type or "application/octet-stream",
|
|
),
|
|
}
|
|
client = get_async_httpx_client(
|
|
llm_provider=httpxSpecialProvider.RAG,
|
|
params={"timeout": 300.0}, # Longer timeout for large files
|
|
)
|
|
|
|
response = await client.post(
|
|
url,
|
|
files=files,
|
|
headers={
|
|
"Authorization": f"Bearer {access_token}",
|
|
"X-Goog-Upload-Protocol": "multipart",
|
|
},
|
|
)
|
|
|
|
if response.status_code not in [200, 201]:
|
|
error_msg = f"Failed to upload file: {response.text}"
|
|
verbose_logger.error(error_msg)
|
|
raise Exception(error_msg)
|
|
|
|
# Parse response to get file ID
|
|
try:
|
|
response_data = response.json()
|
|
# The response should contain the rag_file resource name
|
|
file_id = response_data.get("ragFile", {}).get("name", "")
|
|
if not file_id:
|
|
file_id = response_data.get("name", "")
|
|
|
|
verbose_logger.debug(f"Upload complete. File ID: {file_id}")
|
|
return file_id
|
|
except Exception as e:
|
|
verbose_logger.warning(f"Could not parse upload response: {e}")
|
|
return "uploaded"
|
|
|
|
async def _import_files_from_gcs(
|
|
self,
|
|
rag_corpus_id: str,
|
|
gcs_uris: List[str],
|
|
) -> str:
|
|
"""
|
|
Import files from Google Cloud Storage into RAG corpus.
|
|
|
|
Args:
|
|
rag_corpus_id: RAG corpus resource name
|
|
gcs_uris: List of GCS URIs (e.g., ["gs://bucket/file.pdf"])
|
|
|
|
Returns:
|
|
Operation name for tracking import progress
|
|
"""
|
|
# Get access token using VertexBase method
|
|
access_token, _ = self._ensure_access_token(
|
|
credentials=self.vertex_credentials,
|
|
project_id=self.project_id,
|
|
custom_llm_provider="vertex_ai",
|
|
)
|
|
|
|
# Construct import URL using vertex base URL helper
|
|
base_url = get_vertex_base_url(self.location)
|
|
url = f"{base_url}/v1beta1/" f"{rag_corpus_id}/ragFiles:import"
|
|
|
|
# Build request body with camelCase keys (Vertex AI API format)
|
|
request_body: Dict[str, Any] = {
|
|
"importRagFilesConfig": {"gcsSource": {"uris": gcs_uris}}
|
|
}
|
|
|
|
# Add chunking configuration if provided
|
|
chunking_strategy = self.chunking_strategy
|
|
if chunking_strategy and isinstance(chunking_strategy, dict):
|
|
chunk_size = chunking_strategy.get("chunk_size")
|
|
chunk_overlap = chunking_strategy.get("chunk_overlap")
|
|
|
|
if chunk_size or chunk_overlap:
|
|
request_body["importRagFilesConfig"]["ragFileChunkingConfig"] = {
|
|
"chunkSize": chunk_size or 1024,
|
|
"chunkOverlap": chunk_overlap or 200,
|
|
}
|
|
|
|
# Add max embedding requests per minute if specified
|
|
max_embedding_qpm = self.vector_store_config.get(
|
|
"max_embedding_requests_per_min"
|
|
)
|
|
if max_embedding_qpm:
|
|
request_body["importRagFilesConfig"][
|
|
"maxEmbeddingRequestsPerMin"
|
|
] = max_embedding_qpm
|
|
|
|
verbose_logger.debug(f"Importing files from GCS: {url}")
|
|
verbose_logger.debug(f"Request body: {json.dumps(request_body, indent=2)}")
|
|
|
|
client = get_async_httpx_client(
|
|
llm_provider=httpxSpecialProvider.RAG,
|
|
params={"timeout": 60.0},
|
|
)
|
|
|
|
response = await client.post(
|
|
url,
|
|
json=request_body,
|
|
headers={
|
|
"Authorization": f"Bearer {access_token}",
|
|
"Content-Type": "application/json",
|
|
},
|
|
)
|
|
|
|
if response.status_code not in [200, 201]:
|
|
error_msg = f"Failed to import files: {response.text}"
|
|
verbose_logger.error(error_msg)
|
|
raise Exception(error_msg)
|
|
|
|
response_data = response.json()
|
|
operation_name = response_data.get("name", "")
|
|
|
|
verbose_logger.debug(f"Import operation started: {operation_name}")
|
|
return operation_name
|