chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
RAGFlow provider for LiteLLM.
|
||||
|
||||
RAGFlow provides OpenAI-compatible APIs with unique path structures:
|
||||
- Chat endpoint: /api/v1/chats_openai/{chat_id}/chat/completions
|
||||
- Agent endpoint: /api/v1/agents_openai/{agent_id}/chat/completions
|
||||
"""
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
RAGFlow chat completion configuration.
|
||||
"""
|
||||
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
RAGFlow provider configuration for OpenAI-compatible API.
|
||||
|
||||
RAGFlow provides OpenAI-compatible APIs with unique path structures:
|
||||
- Chat endpoint: /api/v1/chats_openai/{chat_id}/chat/completions
|
||||
- Agent endpoint: /api/v1/agents_openai/{agent_id}/chat/completions
|
||||
|
||||
Model name format:
|
||||
- Chat: ragflow/chat/{chat_id}/{model_name}
|
||||
- Agent: ragflow/agent/{agent_id}/{model_name}
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm.llms.openai.openai import OpenAIConfig
|
||||
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class RAGFlowConfig(OpenAIConfig):
    """
    Configuration for RAGFlow OpenAI-compatible API.

    Handles both chat and agent endpoints by parsing the model name format:
    - ragflow/chat/{chat_id}/{model_name} for chat endpoints
    - ragflow/agent/{agent_id}/{model_name} for agent endpoints
    """

    def _parse_ragflow_model(self, model: str) -> Tuple[str, str, str]:
        """
        Parse RAGFlow model name format: ragflow/{endpoint_type}/{id}/{model_name}

        Args:
            model: Model name in format ragflow/chat/{chat_id}/{model} or
                ragflow/agent/{agent_id}/{model}

        Returns:
            Tuple of (endpoint_type, id, model_name)

        Raises:
            ValueError: If the model format is invalid
        """
        parts = model.split("/")
        if len(parts) < 4:
            raise ValueError(
                f"Invalid RAGFlow model format: {model}. "
                f"Expected format: ragflow/chat/{{chat_id}}/{{model}} or ragflow/agent/{{agent_id}}/{{model}}"
            )

        if parts[0] != "ragflow":
            raise ValueError(
                f"Invalid RAGFlow model format: {model}. Must start with 'ragflow/'"
            )

        endpoint_type = parts[1]
        if endpoint_type not in ["chat", "agent"]:
            raise ValueError(
                f"Invalid RAGFlow endpoint type: {endpoint_type}. Must be 'chat' or 'agent'"
            )

        entity_id = parts[2]
        # Model names may themselves contain slashes; re-join the remainder.
        model_name = "/".join(parts[3:])

        return endpoint_type, entity_id, model_name

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the RAGFlow API call.

        Constructs URL based on endpoint type:
        - Chat: /api/v1/chats_openai/{chat_id}/chat/completions
        - Agent: /api/v1/agents_openai/{agent_id}/chat/completions

        Args:
            api_base: Base API URL (e.g., http://ragflow-server:port or http://ragflow-server:port/v1)
            api_key: API key (not used in URL construction)
            model: Model name in format ragflow/{endpoint_type}/{id}/{model}
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters (may contain api_base)
            stream: Whether streaming is enabled

        Returns:
            Complete URL for the API call

        Raises:
            ValueError: If no api_base can be resolved, or the model format is invalid
        """
        # BUG FIX: litellm_params is a plain dict, so the previous
        # hasattr(litellm_params, "api_base") / attribute access never matched
        # and this fallback was dead code. Use dict lookup instead.
        if litellm_params and litellm_params.get("api_base"):
            api_base = api_base or litellm_params.get("api_base")

        # Resolve from input param, global litellm setting, or environment.
        # get_secret_str already wraps get_secret, so a single call suffices
        # (the original called both, redundantly).
        api_base = api_base or litellm.api_base or get_secret_str("RAGFLOW_API_BASE")

        if api_base is None:
            raise ValueError(
                "api_base is required for RAGFlow provider. Set it via api_base parameter, RAGFLOW_API_BASE environment variable, or litellm.api_base"
            )

        # Parse model name to extract endpoint type and ID
        endpoint_type, entity_id, _ = self._parse_ragflow_model(model)

        # Normalize: drop any trailing slash, then strip a /v1 or /api/v1
        # suffix because the full RAGFlow path is appended below.
        # Check /api/v1 first because an /api/v1 suffix also ends with /v1.
        api_base = api_base.rstrip("/")
        if api_base.endswith("/api/v1"):
            api_base = api_base[: -len("/api/v1")]
        elif api_base.endswith("/v1"):
            api_base = api_base[: -len("/v1")]

        # Construct the RAGFlow-specific path. (The old "ensure path starts
        # with /" check was dead code: both literals start with "/".)
        if endpoint_type == "chat":
            path = f"/api/v1/chats_openai/{entity_id}/chat/completions"
        else:  # agent
            path = f"/api/v1/agents_openai/{entity_id}/chat/completions"

        return f"{api_base}{path}"

    def _get_openai_compatible_provider_info(
        self,
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        custom_llm_provider: str,
    ) -> Tuple[Optional[str], Optional[str], str]:
        """
        Get OpenAI-compatible provider information for RAGFlow.

        Args:
            model: Model name (validated against the RAGFlow format)
            api_base: Base API URL (from input params)
            api_key: API key (from input params)
            custom_llm_provider: Custom LLM provider name

        Returns:
            Tuple of (api_base, api_key, custom_llm_provider)

        Raises:
            ValueError: If the model format is invalid
        """
        # Validate the model format early; the parsed pieces themselves are
        # consumed later by get_complete_url / validate_environment, so the
        # previously-bound (and unused) result is discarded here.
        self._parse_ragflow_model(model)

        # Resolve api_base from input param, global setting, or environment.
        dynamic_api_base = (
            api_base or litellm.api_base or get_secret_str("RAGFLOW_API_BASE")
        )

        # Resolve api_key from input param, global setting, or environment.
        dynamic_api_key = (
            api_key or litellm.api_key or get_secret_str("RAGFLOW_API_KEY")
        )

        return dynamic_api_base, dynamic_api_key, custom_llm_provider

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate environment and set up headers for RAGFlow API.

        Args:
            headers: Request headers (mutated in place)
            model: Model name
            messages: Chat messages
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters (may contain api_key)
            api_key: API key (from input params)
            api_base: Base API URL

        Returns:
            Updated headers dictionary
        """
        # BUG FIX: litellm_params is a plain dict, so the previous
        # hasattr(litellm_params, "api_key") / attribute access never matched
        # and this fallback was dead code. Use dict lookup instead.
        if litellm_params and litellm_params.get("api_key"):
            api_key = api_key or litellm_params.get("api_key")

        # Fall back to the global litellm setting or environment variable.
        api_key = api_key or litellm.api_key or get_secret_str("RAGFLOW_API_KEY")

        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"

        # Ensure Content-Type is set to application/json
        if "content-type" not in headers and "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"

        # Cache the actual model name so transform_request can place it in the
        # request body instead of the ragflow/... routing prefix.
        try:
            _, _, actual_model = self._parse_ragflow_model(model)
            litellm_params["_ragflow_actual_model"] = actual_model
        except ValueError:
            # If parsing fails, transform_request falls back to the raw name.
            pass

        return headers

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform request for RAGFlow API.

        Uses the actual model name extracted from the RAGFlow model format.

        Args:
            model: Model name in RAGFlow format
            messages: Chat messages
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters (may contain _ragflow_actual_model)
            headers: Request headers

        Returns:
            Transformed request dictionary
        """
        # Prefer the name cached by validate_environment; otherwise re-parse,
        # and as a last resort use the raw model name unchanged.
        actual_model = litellm_params.get("_ragflow_actual_model")
        if actual_model is None:
            try:
                _, _, actual_model = self._parse_ragflow_model(model)
            except ValueError:
                actual_model = model

        # Delegate to the OpenAI-compatible transformation with the real name.
        return super().transform_request(
            actual_model, messages, optional_params, litellm_params, headers
        )
|
||||
@@ -0,0 +1 @@
|
||||
# RAGFlow vector stores module
|
||||
@@ -0,0 +1,249 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
BaseVectorStoreAuthCredentials,
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreFileCounts,
|
||||
VectorStoreIndexEndpoints,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class RAGFlowVectorStoreConfig(BaseVectorStoreConfig):
    """
    Vector store configuration for RAGFlow datasets.

    RAGFlow vector stores are management-only: dataset creation is supported,
    search/retrieval is not.
    """

    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """
        Build auth headers from litellm_params or the RAGFLOW_API_KEY env var.

        Raises:
            ValueError: If no API key can be resolved.
        """
        api_key = litellm_params.get("api_key")
        if api_key is None:
            # Fall back to the environment variable.
            api_key = get_secret_str("RAGFLOW_API_KEY")
        if api_key is None:
            raise ValueError(
                "api_key is required (set RAGFLOW_API_KEY env var or pass in litellm_params)"
            )
        return {
            "headers": {
                "Authorization": f"Bearer {api_key}",
            },
        }

    def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
        """RAGFlow vector stores are management-only, no search support."""
        return {
            "read": [],
            "write": [],
        }

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate environment and set auth/content-type headers for RAGFlow.

        Raises:
            ValueError: If no API key can be resolved.
        """
        litellm_params = litellm_params or GenericLiteLLMParams()
        api_key = litellm_params.api_key or get_secret_str("RAGFLOW_API_KEY")

        if api_key is None:
            raise ValueError(
                "RAGFLOW_API_KEY is required (set env var or pass in litellm_params)"
            )

        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for the RAGFlow datasets API.

        Resolution order: explicit api_base, litellm_params["api_base"],
        RAGFLOW_API_BASE env var, then the default http://localhost:9380.
        """
        api_base = (
            api_base
            or litellm_params.get("api_base")
            or get_secret_str("RAGFLOW_API_BASE")
            or "http://localhost:9380"
        )

        # Remove trailing slashes before appending the datasets path.
        api_base = api_base.rstrip("/")

        # RAGFlow datasets API endpoint
        return f"{api_base}/api/v1/datasets"

    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict]:
        """RAGFlow vector stores are management-only, search is not supported."""
        raise NotImplementedError(
            "RAGFlow vector stores support dataset management only, not search/retrieval"
        )

    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """RAGFlow vector stores are management-only, search is not supported."""
        raise NotImplementedError(
            "RAGFlow vector stores support dataset management only, not search/retrieval"
        )

    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict]:
        """
        Transform create request to RAGFlow POST /api/v1/datasets format.

        Maps LiteLLM params to RAGFlow dataset creation parameters.
        RAGFlow-specific fields can be passed via metadata.

        Raises:
            ValueError: If name is missing, or if both chunk_method and
                pipeline_id are supplied.
        """
        url = api_base  # Already includes /api/v1/datasets from get_complete_url

        # name is required by RAGFlow
        name = vector_store_create_optional_params.get("name")
        if not name:
            raise ValueError("name is required for RAGFlow dataset creation")

        request_body: Dict[str, Any] = {
            "name": name,
        }

        # Copy any RAGFlow-specific fields supplied via metadata.
        metadata = vector_store_create_optional_params.get("metadata")
        if metadata:
            ragflow_fields = (
                "avatar",
                "description",
                "embedding_model",
                "permission",
                "chunk_method",
                "parser_config",
                "parse_type",
                "pipeline_id",
            )
            for field in ragflow_fields:
                if field in metadata:
                    request_body[field] = metadata[field]

        # chunk_method and pipeline_id are mutually exclusive in RAGFlow.
        if "chunk_method" in request_body and "pipeline_id" in request_body:
            raise ValueError(
                "chunk_method and pipeline_id are mutually exclusive. "
                "Specify either chunk_method or pipeline_id, not both."
            )

        # Default to RAGFlow's "naive" chunking when neither is specified.
        if "chunk_method" not in request_body and "pipeline_id" not in request_body:
            request_body["chunk_method"] = "naive"

        return url, request_body

    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        """
        Transform a RAGFlow create-dataset response to VectorStoreCreateResponse.

        RAGFlow response format:
        {
            "code": 0,
            "data": {
                "id": "...",
                "name": "...",
                "create_time": 1745836841611,  # milliseconds
                ...
            }
        }

        Raises:
            ValueError: If the response is missing the dataset id.
        """
        # Keep the try narrow: only JSON decoding can legitimately fail here.
        # The original wrapped the whole method in a broad try/except and
        # re-raised based on fragile string/attribute checks
        # ("RAGFlow response" in str(e), hasattr(e, "status_code")), which
        # could mask programming errors as API errors.
        try:
            response_json = response.json()
        except Exception as e:
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )

        # RAGFlow signals errors in-band via a non-zero "code".
        if response_json.get("code") != 0:
            raise self.get_error_class(
                error_message=response_json.get("message", "Unknown error"),
                status_code=response.status_code,
                headers=response.headers,
            )

        data = response_json.get("data", {})

        # The dataset id is mandatory for a usable response.
        dataset_id = data.get("id")
        if not dataset_id:
            raise ValueError("RAGFlow response missing dataset id")

        # RAGFlow reports create_time in milliseconds; the OpenAI-style
        # response uses Unix seconds (0 when absent).
        create_time_ms = data.get("create_time", 0)
        created_at = create_time_ms // 1000 if create_time_ms else 0

        return VectorStoreCreateResponse(
            id=dataset_id,
            object="vector_store",
            created_at=created_at,
            name=data.get("name"),
            bytes=0,  # RAGFlow doesn't provide bytes in response
            file_counts=VectorStoreFileCounts(
                in_progress=0,
                completed=0,
                failed=0,
                cancelled=0,
                total=0,
            ),
            status="completed",
            expires_after=None,
            expires_at=None,
            last_active_at=None,
            metadata=None,
        )
|
||||
Reference in New Issue
Block a user