""" Vertex AI-specific RAG Ingestion implementation. Vertex AI RAG Engine handles embedding and chunking internally when files are uploaded, so this implementation skips the embedding step and directly uploads files to RAG corpora. Based on: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/model-reference/rag-api-v1 """ from __future__ import annotations import json from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from litellm._logging import verbose_logger from litellm.llms.custom_httpx.http_handler import ( get_async_httpx_client, httpxSpecialProvider, ) from litellm.llms.vertex_ai.common_utils import get_vertex_base_url from litellm.llms.vertex_ai.vertex_llm_base import VertexBase from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion if TYPE_CHECKING: from litellm import Router from litellm.types.rag import RAGIngestOptions class VertexAIRAGIngestion(BaseRAGIngestion, VertexBase): """ Vertex AI RAG Engine ingestion implementation. Key differences from base: - Embedding is handled by Vertex AI RAG Engine when files are uploaded - Files are uploaded using the RAG API (import or upload) - Chunking is done by Vertex AI RAG Engine (supports custom chunking config) - Supports Google Cloud Storage (GCS) and Google Drive sources - Supports custom parsing configurations (layout parser, LLM parser) """ def __init__( self, ingest_options: "RAGIngestOptions", router: Optional["Router"] = None, ): BaseRAGIngestion.__init__(self, ingest_options=ingest_options, router=router) VertexBase.__init__(self) # Extract Vertex AI specific configs from vector_store_config litellm_params = dict(self.vector_store_config) # Get project, location, and credentials using VertexBase methods self.project_id = self.safe_get_vertex_ai_project(litellm_params) self.location = self.get_vertex_ai_location(litellm_params) or "us-central1" self.vertex_credentials = self.safe_get_vertex_ai_credentials(litellm_params) async def embed( self, chunks: List[str], ) -> Optional[List[List[float]]]: """ Vertex AI RAG Engine handles embedding internally - skip this step. Returns: None (Vertex AI embeds when files are uploaded to RAG corpus) """ # Vertex AI RAG Engine handles embedding when files are uploaded return None async def store( self, file_content: Optional[bytes], filename: Optional[str], content_type: Optional[str], chunks: List[str], embeddings: Optional[List[List[float]]], ) -> Tuple[Optional[str], Optional[str]]: """ Store content in Vertex AI RAG corpus. Vertex AI workflow: 1. Create RAG corpus (if not provided) 2. Upload file using RAG API (Vertex AI handles chunking/embedding) Args: file_content: Raw file bytes filename: Name of the file content_type: MIME type chunks: Ignored - Vertex AI handles chunking embeddings: Ignored - Vertex AI handles embedding Returns: Tuple of (rag_corpus_id, file_id) """ if not self.project_id: raise ValueError( "vertex_project is required for Vertex AI RAG ingestion. " "Set it in vector_store config." ) # Get or create RAG corpus rag_corpus_id = self.vector_store_config.get("vector_store_id") if not rag_corpus_id: rag_corpus_id = await self._create_rag_corpus( display_name=self.ingest_name or "litellm-rag-corpus", description=self.vector_store_config.get("description"), ) # Upload file to RAG corpus result_file_id = None if file_content and filename and rag_corpus_id: result_file_id = await self._upload_file_to_corpus( rag_corpus_id=rag_corpus_id, filename=filename, file_content=file_content, content_type=content_type, ) return rag_corpus_id, result_file_id async def _create_rag_corpus( self, display_name: str, description: Optional[str] = None, ) -> str: """ Create a Vertex AI RAG corpus. Args: display_name: Display name for the corpus description: Optional description Returns: RAG corpus ID (format: projects/{project}/locations/{location}/ragCorpora/{corpus_id}) """ # Get access token using VertexBase method access_token, project_id = self._ensure_access_token( credentials=self.vertex_credentials, project_id=self.project_id, custom_llm_provider="vertex_ai", ) # Use the project_id from token if not set if not self.project_id: self.project_id = project_id # Construct URL using vertex base URL helper base_url = get_vertex_base_url(self.location) url = ( f"{base_url}/v1beta1/" f"projects/{self.project_id}/locations/{self.location}/ragCorpora" ) # Build request body with camelCase keys (Vertex AI API format) request_body: Dict[str, Any] = { "displayName": display_name, } if description: request_body["description"] = description # Add vector database config if specified vector_db_config = self.vector_store_config.get("vector_db_config") if vector_db_config: request_body["vectorDbConfig"] = vector_db_config # Add embedding model config if specified embedding_model = self.vector_store_config.get("embedding_model") if embedding_model: if "vectorDbConfig" not in request_body: request_body["vectorDbConfig"] = {} request_body["vectorDbConfig"]["ragEmbeddingModelConfig"] = { "vertexPredictionEndpoint": {"endpoint": embedding_model} } verbose_logger.debug(f"Creating RAG corpus: {url}") verbose_logger.debug(f"Request body: {json.dumps(request_body, indent=2)}") client = get_async_httpx_client( llm_provider=httpxSpecialProvider.RAG, params={"timeout": 60.0}, ) response = await client.post( url, json=request_body, headers={ "Authorization": f"Bearer {access_token}", "Content-Type": "application/json", }, ) if response.status_code not in [200, 201]: error_msg = f"Failed to create RAG corpus: {response.text}" verbose_logger.error(error_msg) raise Exception(error_msg) response_data = response.json() verbose_logger.debug( f"Create corpus response: {json.dumps(response_data, indent=2)}" ) # The response is a long-running operation # Check if it's already done or if we need to poll if response_data.get("done"): # Operation completed immediately corpus_name = response_data.get("response", {}).get("name", "") else: # Need to poll the operation operation_name = response_data.get("name", "") verbose_logger.debug(f"Polling operation: {operation_name}") corpus_name = await self._poll_operation( operation_name=operation_name, access_token=access_token, ) verbose_logger.debug(f"Created RAG corpus: {corpus_name}") return corpus_name async def _poll_operation( self, operation_name: str, access_token: str, max_retries: int = 30, retry_delay: float = 2.0, ) -> str: """ Poll a long-running operation until it completes. Args: operation_name: The operation name (e.g., "operations/123456") access_token: Access token for authentication max_retries: Maximum number of polling attempts retry_delay: Delay between polling attempts in seconds Returns: The corpus name from the completed operation Raises: Exception: If operation fails or times out """ import asyncio base_url = get_vertex_base_url(self.location) # Operation name is like: projects/{project}/locations/{location}/operations/{operation_id} # We need to construct the full URL url = f"{base_url}/v1beta1/{operation_name}" client = get_async_httpx_client( llm_provider=httpxSpecialProvider.RAG, params={"timeout": 60.0}, ) for attempt in range(max_retries): response = await client.get( url, headers={ "Authorization": f"Bearer {access_token}", }, ) if response.status_code != 200: error_msg = f"Failed to poll operation: {response.text}" verbose_logger.error(error_msg) raise Exception(error_msg) operation_data = response.json() if operation_data.get("done"): # Check for errors if "error" in operation_data: error = operation_data["error"] raise Exception(f"Operation failed: {error}") # Extract corpus name from response corpus_name = operation_data.get("response", {}).get("name", "") if corpus_name: return corpus_name else: raise Exception( f"No corpus name in operation response: {operation_data}" ) verbose_logger.debug( f"Operation not done yet, attempt {attempt + 1}/{max_retries}" ) await asyncio.sleep(retry_delay) raise Exception(f"Operation timed out after {max_retries} attempts") async def _upload_file_to_corpus( self, rag_corpus_id: str, filename: str, file_content: bytes, content_type: Optional[str], ) -> str: """ Upload a file to Vertex AI RAG corpus using multipart upload. Args: rag_corpus_id: RAG corpus resource name filename: Name of the file file_content: File content bytes content_type: MIME type Returns: File ID or resource name """ # Get access token using VertexBase method access_token, _ = self._ensure_access_token( credentials=self.vertex_credentials, project_id=self.project_id, custom_llm_provider="vertex_ai", ) # Construct upload URL using vertex base URL helper base_url = get_vertex_base_url(self.location) url = f"{base_url}/upload/v1beta1/" f"{rag_corpus_id}/ragFiles:upload" # Build metadata for the file with snake_case keys (as per upload API docs) metadata: Dict[str, Any] = { "rag_file": { "display_name": filename, } } # Add description if provided description = self.vector_store_config.get("file_description") if description: metadata["rag_file"]["description"] = description # Add chunking configuration if provided chunking_strategy = self.chunking_strategy if chunking_strategy and isinstance(chunking_strategy, dict): chunk_size = chunking_strategy.get("chunk_size") chunk_overlap = chunking_strategy.get("chunk_overlap") if chunk_size or chunk_overlap: if "upload_rag_file_config" not in metadata: metadata["upload_rag_file_config"] = {} metadata["upload_rag_file_config"]["rag_file_transformation_config"] = { "rag_file_chunking_config": {"fixed_length_chunking": {}} } chunking_config = metadata["upload_rag_file_config"][ "rag_file_transformation_config" ]["rag_file_chunking_config"]["fixed_length_chunking"] if chunk_size: chunking_config["chunk_size"] = chunk_size if chunk_overlap: chunking_config["chunk_overlap"] = chunk_overlap verbose_logger.debug(f"Uploading file to RAG corpus: {url}") verbose_logger.debug(f"Metadata: {json.dumps(metadata, indent=2)}") # Prepare multipart form data files = { "metadata": (None, json.dumps(metadata), "application/json"), "file": ( filename, file_content, content_type or "application/octet-stream", ), } client = get_async_httpx_client( llm_provider=httpxSpecialProvider.RAG, params={"timeout": 300.0}, # Longer timeout for large files ) response = await client.post( url, files=files, headers={ "Authorization": f"Bearer {access_token}", "X-Goog-Upload-Protocol": "multipart", }, ) if response.status_code not in [200, 201]: error_msg = f"Failed to upload file: {response.text}" verbose_logger.error(error_msg) raise Exception(error_msg) # Parse response to get file ID try: response_data = response.json() # The response should contain the rag_file resource name file_id = response_data.get("ragFile", {}).get("name", "") if not file_id: file_id = response_data.get("name", "") verbose_logger.debug(f"Upload complete. File ID: {file_id}") return file_id except Exception as e: verbose_logger.warning(f"Could not parse upload response: {e}") return "uploaded" async def _import_files_from_gcs( self, rag_corpus_id: str, gcs_uris: List[str], ) -> str: """ Import files from Google Cloud Storage into RAG corpus. Args: rag_corpus_id: RAG corpus resource name gcs_uris: List of GCS URIs (e.g., ["gs://bucket/file.pdf"]) Returns: Operation name for tracking import progress """ # Get access token using VertexBase method access_token, _ = self._ensure_access_token( credentials=self.vertex_credentials, project_id=self.project_id, custom_llm_provider="vertex_ai", ) # Construct import URL using vertex base URL helper base_url = get_vertex_base_url(self.location) url = f"{base_url}/v1beta1/" f"{rag_corpus_id}/ragFiles:import" # Build request body with camelCase keys (Vertex AI API format) request_body: Dict[str, Any] = { "importRagFilesConfig": {"gcsSource": {"uris": gcs_uris}} } # Add chunking configuration if provided chunking_strategy = self.chunking_strategy if chunking_strategy and isinstance(chunking_strategy, dict): chunk_size = chunking_strategy.get("chunk_size") chunk_overlap = chunking_strategy.get("chunk_overlap") if chunk_size or chunk_overlap: request_body["importRagFilesConfig"]["ragFileChunkingConfig"] = { "chunkSize": chunk_size or 1024, "chunkOverlap": chunk_overlap or 200, } # Add max embedding requests per minute if specified max_embedding_qpm = self.vector_store_config.get( "max_embedding_requests_per_min" ) if max_embedding_qpm: request_body["importRagFilesConfig"][ "maxEmbeddingRequestsPerMin" ] = max_embedding_qpm verbose_logger.debug(f"Importing files from GCS: {url}") verbose_logger.debug(f"Request body: {json.dumps(request_body, indent=2)}") client = get_async_httpx_client( llm_provider=httpxSpecialProvider.RAG, params={"timeout": 60.0}, ) response = await client.post( url, json=request_body, headers={ "Authorization": f"Bearer {access_token}", "Content-Type": "application/json", }, ) if response.status_code not in [200, 201]: error_msg = f"Failed to import files: {response.text}" verbose_logger.error(error_msg) raise Exception(error_msg) response_data = response.json() operation_name = response_data.get("name", "") verbose_logger.debug(f"Import operation started: {operation_name}") return operation_name