chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,698 @@
|
||||
# litellm/proxy/guardrails/guardrail_registry.py
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Type, cast
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.guardrails.guardrail_hooks.grayswan import GraySwanGuardrail
|
||||
from litellm.proxy.guardrails.guardrail_hooks.grayswan import (
|
||||
initialize_guardrail as initialize_grayswan,
|
||||
)
|
||||
from litellm.proxy.types_utils.utils import get_instance_fn
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.guardrails import (
|
||||
Guardrail,
|
||||
GuardrailEventHooks,
|
||||
LakeraCategoryThresholds,
|
||||
LitellmParams,
|
||||
SupportedGuardrailIntegrations,
|
||||
)
|
||||
|
||||
from .guardrail_initializers import (
|
||||
initialize_bedrock,
|
||||
initialize_hide_secrets,
|
||||
initialize_lakera,
|
||||
initialize_lakera_v2,
|
||||
initialize_presidio,
|
||||
initialize_tool_permission,
|
||||
)
|
||||
|
||||
guardrail_initializer_registry = {
|
||||
SupportedGuardrailIntegrations.BEDROCK.value: initialize_bedrock,
|
||||
SupportedGuardrailIntegrations.LAKERA.value: initialize_lakera,
|
||||
SupportedGuardrailIntegrations.LAKERA_V2.value: initialize_lakera_v2,
|
||||
SupportedGuardrailIntegrations.PRESIDIO.value: initialize_presidio,
|
||||
SupportedGuardrailIntegrations.HIDE_SECRETS.value: initialize_hide_secrets,
|
||||
SupportedGuardrailIntegrations.TOOL_PERMISSION.value: initialize_tool_permission,
|
||||
SupportedGuardrailIntegrations.GRAYSWAN.value: initialize_grayswan,
|
||||
}
|
||||
|
||||
guardrail_class_registry: Dict[str, Type[CustomGuardrail]] = {
|
||||
SupportedGuardrailIntegrations.GRAYSWAN.value: GraySwanGuardrail
|
||||
}
|
||||
|
||||
|
||||
def get_guardrail_initializer_from_hooks():
|
||||
"""
|
||||
Get guardrail initializers by discovering them from the guardrail_hooks directory structure.
|
||||
|
||||
Scans the guardrail_hooks directory for subdirectories containing __init__.py files
|
||||
with either guardrail_initializer_registry or initialize_guardrail functions.
|
||||
|
||||
Returns:
|
||||
Dict[str, Callable]: A dictionary mapping guardrail types to their initializer functions
|
||||
"""
|
||||
discovered_initializers = {}
|
||||
|
||||
try:
|
||||
# Get the path to the guardrail_hooks directory
|
||||
current_dir = os.path.dirname(__file__)
|
||||
hooks_dir = os.path.join(current_dir, "guardrail_hooks")
|
||||
|
||||
if not os.path.exists(hooks_dir):
|
||||
verbose_proxy_logger.debug("guardrail_hooks directory not found")
|
||||
return discovered_initializers
|
||||
|
||||
# Scan each subdirectory in guardrail_hooks
|
||||
for item in os.listdir(hooks_dir):
|
||||
item_path = os.path.join(hooks_dir, item)
|
||||
|
||||
# Skip files and __pycache__ directories
|
||||
if not os.path.isdir(item_path) or item.startswith("__"):
|
||||
continue
|
||||
|
||||
# Check if the directory has an __init__.py file
|
||||
init_file = os.path.join(item_path, "__init__.py")
|
||||
if not os.path.exists(init_file):
|
||||
continue
|
||||
|
||||
module_path = f"litellm.proxy.guardrails.guardrail_hooks.{item}"
|
||||
try:
|
||||
# Import the module
|
||||
verbose_proxy_logger.debug(f"Discovering guardrails in: {module_path}")
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
|
||||
# Check for guardrail_initializer_registry dictionary
|
||||
if hasattr(module, "guardrail_initializer_registry"):
|
||||
registry = getattr(module, "guardrail_initializer_registry")
|
||||
if isinstance(registry, dict):
|
||||
discovered_initializers.update(registry)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Found guardrail_initializer_registry in {module_path}: {list(registry.keys())}"
|
||||
)
|
||||
|
||||
# Check for standalone initialize_guardrail function (fallback for directory-based guardrails)
|
||||
elif hasattr(module, "initialize_guardrail"):
|
||||
# For directories with just initialize_guardrail, use the directory name as the key
|
||||
initialize_fn = getattr(module, "initialize_guardrail")
|
||||
discovered_initializers[item] = initialize_fn
|
||||
verbose_proxy_logger.debug(
|
||||
f"Found initialize_guardrail function in {module_path}"
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
verbose_proxy_logger.error(f"Could not import {module_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error processing {module_path}: {e}")
|
||||
continue
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Discovered {len(discovered_initializers)} guardrail initializers: {list(discovered_initializers.keys())}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error discovering guardrail initializers: {e}")
|
||||
|
||||
return discovered_initializers
|
||||
|
||||
|
||||
def get_guardrail_class_from_hooks():
|
||||
"""
|
||||
Get guardrail classes by discovering them from the guardrail_hooks directory structure.
|
||||
"""
|
||||
"""
|
||||
Get guardrail initializers by discovering them from the guardrail_hooks directory structure.
|
||||
|
||||
Scans the guardrail_hooks directory for subdirectories containing __init__.py files
|
||||
with either guardrail_initializer_registry or initialize_guardrail functions.
|
||||
|
||||
Returns:
|
||||
Dict[str, Callable]: A dictionary mapping guardrail types to their initializer functions
|
||||
"""
|
||||
discovered_classes = {}
|
||||
|
||||
try:
|
||||
# Get the path to the guardrail_hooks directory
|
||||
current_dir = os.path.dirname(__file__)
|
||||
hooks_dir = os.path.join(current_dir, "guardrail_hooks")
|
||||
|
||||
if not os.path.exists(hooks_dir):
|
||||
verbose_proxy_logger.debug("guardrail_hooks directory not found")
|
||||
return discovered_classes
|
||||
|
||||
# Scan each subdirectory in guardrail_hooks
|
||||
for item in os.listdir(hooks_dir):
|
||||
item_path = os.path.join(hooks_dir, item)
|
||||
|
||||
# Skip files and __pycache__ directories
|
||||
if not os.path.isdir(item_path) or item.startswith("__"):
|
||||
continue
|
||||
|
||||
# Check if the directory has an __init__.py file
|
||||
init_file = os.path.join(item_path, "__init__.py")
|
||||
|
||||
if not os.path.exists(init_file):
|
||||
continue
|
||||
|
||||
module_path = f"litellm.proxy.guardrails.guardrail_hooks.{item}"
|
||||
|
||||
try:
|
||||
# Import the module
|
||||
verbose_proxy_logger.debug(f"Discovering guardrails in: {module_path}")
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
|
||||
# Check for guardrail_initializer_registry dictionary
|
||||
if hasattr(module, "guardrail_class_registry"):
|
||||
registry = getattr(module, "guardrail_class_registry")
|
||||
if isinstance(registry, dict):
|
||||
discovered_classes.update(registry)
|
||||
|
||||
except ImportError as e:
|
||||
verbose_proxy_logger.debug(f"Could not import {module_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error processing {module_path}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error discovering guardrail initializers: {e}")
|
||||
|
||||
return discovered_classes
|
||||
|
||||
|
||||
guardrail_class_registry.update(get_guardrail_class_from_hooks())
|
||||
|
||||
|
||||
# Merge with dynamically discovered guardrail initializers
|
||||
_discovered_initializers = get_guardrail_initializer_from_hooks()
|
||||
|
||||
guardrail_initializer_registry.update(_discovered_initializers)
|
||||
|
||||
|
||||
class GuardrailRegistry:
|
||||
"""
|
||||
Registry for guardrails
|
||||
|
||||
Handles adding, removing, and getting guardrails in DB + in memory
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
###########################################################
|
||||
########### In memory management helpers for guardrails ###########
|
||||
############################################################
|
||||
def get_initialized_guardrail_callback(
|
||||
self, guardrail_name: str
|
||||
) -> Optional[CustomGuardrail]:
|
||||
"""
|
||||
Returns the initialized guardrail callback for a given guardrail name
|
||||
"""
|
||||
active_guardrails = (
|
||||
litellm.logging_callback_manager.get_custom_loggers_for_type(
|
||||
callback_type=CustomGuardrail
|
||||
)
|
||||
)
|
||||
for active_guardrail in active_guardrails:
|
||||
if isinstance(active_guardrail, CustomGuardrail):
|
||||
if active_guardrail.guardrail_name == guardrail_name:
|
||||
return active_guardrail
|
||||
return None
|
||||
|
||||
###########################################################
|
||||
########### DB management helpers for guardrails ###########
|
||||
############################################################
|
||||
async def add_guardrail_to_db(
|
||||
self, guardrail: Guardrail, prisma_client: PrismaClient
|
||||
):
|
||||
"""
|
||||
Add a guardrail to the database
|
||||
"""
|
||||
try:
|
||||
guardrail_name = guardrail.get("guardrail_name")
|
||||
# Properly serialize LitellmParams Pydantic model to dict
|
||||
litellm_params_obj: Any = guardrail.get("litellm_params", {})
|
||||
if hasattr(litellm_params_obj, "model_dump"):
|
||||
litellm_params_dict = litellm_params_obj.model_dump()
|
||||
else:
|
||||
litellm_params_dict = (
|
||||
dict(litellm_params_obj) if litellm_params_obj else {}
|
||||
)
|
||||
litellm_params: str = safe_dumps(litellm_params_dict)
|
||||
guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))
|
||||
|
||||
# Create guardrail in DB
|
||||
created_guardrail = await prisma_client.db.litellm_guardrailstable.create(
|
||||
data={
|
||||
"guardrail_name": guardrail_name,
|
||||
"litellm_params": litellm_params,
|
||||
"guardrail_info": guardrail_info,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
# Add guardrail_id to the returned guardrail object
|
||||
guardrail_dict = dict(guardrail)
|
||||
guardrail_dict["guardrail_id"] = created_guardrail.guardrail_id
|
||||
|
||||
return guardrail_dict
|
||||
except Exception as e:
|
||||
raise Exception(f"Error adding guardrail to DB: {str(e)}")
|
||||
|
||||
async def delete_guardrail_from_db(
|
||||
self, guardrail_id: str, prisma_client: PrismaClient
|
||||
):
|
||||
"""
|
||||
Delete a guardrail from the database
|
||||
"""
|
||||
try:
|
||||
# Delete from DB
|
||||
await prisma_client.db.litellm_guardrailstable.delete(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
|
||||
return {"message": f"Guardrail {guardrail_id} deleted successfully"}
|
||||
except Exception as e:
|
||||
raise Exception(f"Error deleting guardrail from DB: {str(e)}")
|
||||
|
||||
async def update_guardrail_in_db(
|
||||
self, guardrail_id: str, guardrail: Guardrail, prisma_client: PrismaClient
|
||||
):
|
||||
"""
|
||||
Update a guardrail in the database
|
||||
"""
|
||||
try:
|
||||
guardrail_name = guardrail.get("guardrail_name")
|
||||
# Properly serialize LitellmParams Pydantic model to dict
|
||||
litellm_params_obj: Any = guardrail.get("litellm_params", {})
|
||||
if hasattr(litellm_params_obj, "model_dump"):
|
||||
litellm_params_dict = litellm_params_obj.model_dump()
|
||||
else:
|
||||
litellm_params_dict = (
|
||||
dict(litellm_params_obj) if litellm_params_obj else {}
|
||||
)
|
||||
litellm_params: str = safe_dumps(litellm_params_dict)
|
||||
guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))
|
||||
|
||||
# Update in DB
|
||||
updated_guardrail = await prisma_client.db.litellm_guardrailstable.update(
|
||||
where={"guardrail_id": guardrail_id},
|
||||
data={
|
||||
"guardrail_name": guardrail_name,
|
||||
"litellm_params": litellm_params,
|
||||
"guardrail_info": guardrail_info,
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
|
||||
# Convert to dict and return
|
||||
return dict(updated_guardrail)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error updating guardrail in DB: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
async def get_all_guardrails_from_db(
|
||||
prisma_client: PrismaClient,
|
||||
) -> List[Guardrail]:
|
||||
"""
|
||||
Get all active guardrails from the database.
|
||||
Only rows with status == "active" are returned (pending_review and rejected are excluded).
|
||||
"""
|
||||
try:
|
||||
guardrails_from_db = (
|
||||
await prisma_client.db.litellm_guardrailstable.find_many(
|
||||
where={"status": "active"},
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
)
|
||||
|
||||
guardrails: List[Guardrail] = []
|
||||
for guardrail in guardrails_from_db:
|
||||
guardrails.append(Guardrail(**(dict(guardrail)))) # type: ignore
|
||||
|
||||
return guardrails
|
||||
except Exception as e:
|
||||
raise Exception(f"Error getting guardrails from DB: {str(e)}")
|
||||
|
||||
async def get_guardrail_by_id_from_db(
|
||||
self, guardrail_id: str, prisma_client: PrismaClient
|
||||
) -> Optional[Guardrail]:
|
||||
"""
|
||||
Get a guardrail by its ID from the database
|
||||
"""
|
||||
try:
|
||||
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
|
||||
if not guardrail:
|
||||
return None
|
||||
|
||||
return Guardrail(**(dict(guardrail))) # type: ignore
|
||||
except Exception as e:
|
||||
raise Exception(f"Error getting guardrail from DB: {str(e)}")
|
||||
|
||||
async def get_guardrail_by_name_from_db(
|
||||
self, guardrail_name: str, prisma_client: PrismaClient
|
||||
) -> Optional[Guardrail]:
|
||||
"""
|
||||
Get a guardrail by its name from the database
|
||||
"""
|
||||
try:
|
||||
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
where={"guardrail_name": guardrail_name}
|
||||
)
|
||||
|
||||
if not guardrail:
|
||||
return None
|
||||
|
||||
return Guardrail(**(dict(guardrail))) # type: ignore
|
||||
except Exception as e:
|
||||
raise Exception(f"Error getting guardrail from DB: {str(e)}")
|
||||
|
||||
|
||||
class InMemoryGuardrailHandler:
|
||||
"""
|
||||
Class that handles initializing guardrails and adding them to the CallbackManager
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.IN_MEMORY_GUARDRAILS: Dict[str, Guardrail] = {}
|
||||
"""
|
||||
Guardrail id to Guardrail object mapping
|
||||
"""
|
||||
|
||||
self.guardrail_id_to_custom_guardrail: Dict[str, Optional[CustomGuardrail]] = {}
|
||||
"""
|
||||
Guardrail id to CustomGuardrail object mapping
|
||||
"""
|
||||
|
||||
def initialize_guardrail(
|
||||
self,
|
||||
guardrail: Guardrail,
|
||||
config_file_path: Optional[str] = None,
|
||||
llm_router: Optional["Router"] = None,
|
||||
) -> Optional[Guardrail]:
|
||||
"""
|
||||
Initialize a guardrail from a dictionary and add it to the litellm callback manager
|
||||
|
||||
Returns a Guardrail object if the guardrail is initialized successfully
|
||||
"""
|
||||
guardrail_id = guardrail.get("guardrail_id") or str(uuid.uuid4())
|
||||
guardrail["guardrail_id"] = guardrail_id
|
||||
if guardrail_id in self.IN_MEMORY_GUARDRAILS:
|
||||
verbose_proxy_logger.debug(
|
||||
"guardrail_id already exists in IN_MEMORY_GUARDRAILS"
|
||||
)
|
||||
return self.IN_MEMORY_GUARDRAILS[guardrail_id]
|
||||
|
||||
custom_guardrail_callback: Optional[CustomGuardrail] = None
|
||||
litellm_params_data = guardrail["litellm_params"]
|
||||
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
|
||||
|
||||
if isinstance(litellm_params_data, dict):
|
||||
litellm_params = LitellmParams(**litellm_params_data)
|
||||
else:
|
||||
litellm_params = litellm_params_data
|
||||
|
||||
if (
|
||||
"category_thresholds" in litellm_params_data
|
||||
and litellm_params_data["category_thresholds"]
|
||||
):
|
||||
lakera_category_thresholds = LakeraCategoryThresholds(
|
||||
**litellm_params_data["category_thresholds"]
|
||||
)
|
||||
litellm_params.category_thresholds = lakera_category_thresholds
|
||||
|
||||
if litellm_params.api_key and litellm_params.api_key.startswith("os.environ/"):
|
||||
litellm_params.api_key = str(get_secret(litellm_params.api_key))
|
||||
|
||||
if litellm_params.api_base and litellm_params.api_base.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
litellm_params.api_base = str(get_secret(litellm_params.api_base))
|
||||
|
||||
guardrail_type = litellm_params.guardrail
|
||||
|
||||
if guardrail_type is None:
|
||||
raise ValueError("guardrail_type is required")
|
||||
|
||||
initializer = guardrail_initializer_registry.get(guardrail_type)
|
||||
|
||||
if initializer:
|
||||
# Try to call with llm_router first, fall back to without if it fails
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(initializer)
|
||||
if "llm_router" in sig.parameters:
|
||||
custom_guardrail_callback = initializer(
|
||||
litellm_params, guardrail, llm_router # type: ignore
|
||||
)
|
||||
else:
|
||||
custom_guardrail_callback = initializer(litellm_params, guardrail)
|
||||
elif isinstance(guardrail_type, str) and "." in guardrail_type:
|
||||
custom_guardrail_callback = self.initialize_custom_guardrail(
|
||||
guardrail=cast(dict, guardrail),
|
||||
guardrail_type=guardrail_type,
|
||||
litellm_params=litellm_params,
|
||||
config_file_path=config_file_path,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
|
||||
|
||||
parsed_guardrail = Guardrail(
|
||||
guardrail_id=guardrail.get("guardrail_id"),
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
# store references to the guardrail in memory
|
||||
self.IN_MEMORY_GUARDRAILS[guardrail_id] = parsed_guardrail
|
||||
self.guardrail_id_to_custom_guardrail[guardrail_id] = custom_guardrail_callback
|
||||
|
||||
return parsed_guardrail
|
||||
|
||||
def initialize_custom_guardrail(
|
||||
self,
|
||||
guardrail: Dict,
|
||||
guardrail_type: str,
|
||||
litellm_params: LitellmParams,
|
||||
config_file_path: Optional[str] = None,
|
||||
) -> Optional[CustomGuardrail]:
|
||||
"""
|
||||
Initialize a Custom Guardrail from a python file or module path
|
||||
|
||||
This initializes it by adding it to the litellm callback manager
|
||||
"""
|
||||
if not config_file_path:
|
||||
raise Exception(
|
||||
"GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Initializing custom guardrail: %s",
|
||||
guardrail_type,
|
||||
)
|
||||
|
||||
_guardrail_class = get_instance_fn(
|
||||
guardrail_type, config_file_path=config_file_path
|
||||
)
|
||||
|
||||
mode = litellm_params.mode
|
||||
if mode is None:
|
||||
raise ValueError(
|
||||
f"mode is required for guardrail {guardrail_type} please set mode to one of the following: {', '.join(GuardrailEventHooks)}"
|
||||
)
|
||||
|
||||
default_on = litellm_params.default_on
|
||||
|
||||
# Extract additional params from litellm_params to pass to custom guardrail
|
||||
# This matches the behavior of other guardrail initializers (e.g., initialize_lakera)
|
||||
# and aligns with the documented behavior for custom guardrails
|
||||
if hasattr(litellm_params, "model_dump"):
|
||||
extra_params = litellm_params.model_dump(exclude_none=True)
|
||||
else:
|
||||
extra_params = dict(litellm_params) if litellm_params else {}
|
||||
|
||||
# Remove params that are handled explicitly or are internal
|
||||
for key in ["guardrail", "mode", "default_on"]:
|
||||
extra_params.pop(key, None)
|
||||
|
||||
_guardrail_callback = _guardrail_class(
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=mode,
|
||||
default_on=default_on,
|
||||
**extra_params,
|
||||
)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback) # type: ignore
|
||||
|
||||
return _guardrail_callback
|
||||
|
||||
def update_in_memory_guardrail(
|
||||
self, guardrail_id: str, guardrail: Guardrail
|
||||
) -> None:
|
||||
"""
|
||||
Update a guardrail in memory
|
||||
|
||||
- updates the guardrail in memory
|
||||
- updates the guardrail params in litellm.callback_manager
|
||||
"""
|
||||
self.IN_MEMORY_GUARDRAILS[guardrail_id] = guardrail
|
||||
|
||||
custom_guardrail_callback = self.guardrail_id_to_custom_guardrail.get(
|
||||
guardrail_id
|
||||
)
|
||||
if custom_guardrail_callback:
|
||||
updated_litellm_params = cast(
|
||||
LitellmParams, guardrail.get("litellm_params", {})
|
||||
)
|
||||
custom_guardrail_callback.update_in_memory_litellm_params(
|
||||
litellm_params=updated_litellm_params
|
||||
)
|
||||
|
||||
def delete_in_memory_guardrail(self, guardrail_id: str) -> None:
|
||||
"""
|
||||
Delete a guardrail in memory and remove from litellm callbacks.
|
||||
"""
|
||||
# Remove from in-memory storage
|
||||
self.IN_MEMORY_GUARDRAILS.pop(guardrail_id, None)
|
||||
|
||||
# Remove the callback from litellm.callbacks
|
||||
custom_guardrail_callback = self.guardrail_id_to_custom_guardrail.pop(
|
||||
guardrail_id, None
|
||||
)
|
||||
if custom_guardrail_callback:
|
||||
litellm.logging_callback_manager.remove_callback_from_list_by_object(
|
||||
callback_list=litellm.callbacks,
|
||||
obj=custom_guardrail_callback,
|
||||
require_self=False,
|
||||
)
|
||||
|
||||
def list_in_memory_guardrails(self) -> List[Guardrail]:
|
||||
"""
|
||||
List all guardrails in memory
|
||||
"""
|
||||
return list(self.IN_MEMORY_GUARDRAILS.values())
|
||||
|
||||
def get_guardrail_by_id(self, guardrail_id: str) -> Optional[Guardrail]:
|
||||
"""
|
||||
Get a guardrail by its ID from memory
|
||||
"""
|
||||
return self.IN_MEMORY_GUARDRAILS.get(guardrail_id)
|
||||
|
||||
def _has_guardrail_params_changed(
|
||||
self, guardrail_id: str, new_guardrail: Guardrail
|
||||
) -> bool:
|
||||
"""
|
||||
Check if guardrail params or name have changed compared to in-memory version.
|
||||
Returns True if params/name changed or guardrail doesn't exist in memory.
|
||||
"""
|
||||
existing = self.IN_MEMORY_GUARDRAILS.get(guardrail_id)
|
||||
if existing is None:
|
||||
return True
|
||||
|
||||
# Compare guardrail_name
|
||||
if existing.get("guardrail_name") != new_guardrail.get("guardrail_name"):
|
||||
return True
|
||||
|
||||
# Compare litellm_params
|
||||
existing_params = existing.get("litellm_params")
|
||||
new_params = new_guardrail.get("litellm_params")
|
||||
|
||||
# Convert to dicts for comparison
|
||||
existing_dict = (
|
||||
existing_params.model_dump()
|
||||
if isinstance(existing_params, LitellmParams)
|
||||
else existing_params
|
||||
)
|
||||
new_dict = (
|
||||
new_params.model_dump()
|
||||
if isinstance(new_params, LitellmParams)
|
||||
else new_params
|
||||
)
|
||||
|
||||
# Compare and identify specific differences
|
||||
changed_fields = {}
|
||||
if existing_dict is not None and new_dict is not None:
|
||||
all_keys = set(existing_dict.keys()) | set(new_dict.keys())
|
||||
for key in all_keys:
|
||||
old_val = existing_dict.get(key)
|
||||
new_val = new_dict.get(key)
|
||||
if old_val != new_val:
|
||||
changed_fields[key] = {"old": old_val, "new": new_val}
|
||||
elif existing_dict != new_dict:
|
||||
changed_fields = {"litellm_params": {"old": existing_dict, "new": new_dict}}
|
||||
|
||||
# Log differences if any found
|
||||
if changed_fields:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Guardrail params changed. Differences: {changed_fields}"
|
||||
)
|
||||
|
||||
# Return True if any fields changed
|
||||
return len(changed_fields) > 0
|
||||
|
||||
def reinitialize_guardrail(
|
||||
self, guardrail: Guardrail, config_file_path: Optional[str] = None
|
||||
) -> Optional[Guardrail]:
|
||||
"""
|
||||
Force re-initialization of a guardrail even if it exists in memory.
|
||||
Removes old callback from litellm.callbacks and creates fresh instance.
|
||||
"""
|
||||
guardrail_id = guardrail.get("guardrail_id")
|
||||
if not guardrail_id:
|
||||
verbose_proxy_logger.error(
|
||||
"Cannot reinitialize guardrail without guardrail_id"
|
||||
)
|
||||
return None
|
||||
|
||||
# Remove from memory if exists (also removes from callbacks)
|
||||
if guardrail_id in self.IN_MEMORY_GUARDRAILS:
|
||||
self.delete_in_memory_guardrail(guardrail_id)
|
||||
|
||||
# Initialize fresh (will add new callback to litellm.callbacks)
|
||||
return self.initialize_guardrail(
|
||||
guardrail=guardrail, config_file_path=config_file_path
|
||||
)
|
||||
|
||||
def sync_guardrail_from_db(
|
||||
self, guardrail: Guardrail, config_file_path: Optional[str] = None
|
||||
) -> Optional[Guardrail]:
|
||||
"""
|
||||
Sync a guardrail from DB - initializes if new, re-initializes if changed.
|
||||
This is the method to call during DB polling.
|
||||
"""
|
||||
guardrail_id = guardrail.get("guardrail_id")
|
||||
if not guardrail_id:
|
||||
verbose_proxy_logger.error("Cannot sync guardrail without guardrail_id")
|
||||
return None
|
||||
|
||||
if self._has_guardrail_params_changed(guardrail_id, guardrail):
|
||||
guardrail_name = guardrail.get("guardrail_name", "Unknown")
|
||||
verbose_proxy_logger.info(
|
||||
f"Guardrail '{guardrail_name}' (ID: {guardrail_id}) params changed, re-initializing..."
|
||||
)
|
||||
return self.reinitialize_guardrail(
|
||||
guardrail=guardrail, config_file_path=config_file_path
|
||||
)
|
||||
|
||||
return self.IN_MEMORY_GUARDRAILS.get(guardrail_id)
|
||||
|
||||
|
||||
########################################################
|
||||
# In Memory Guardrail Handler for LiteLLM Proxy
|
||||
########################################################
|
||||
IN_MEMORY_GUARDRAIL_HANDLER = InMemoryGuardrailHandler()
|
||||
########################################################
|
||||
Reference in New Issue
Block a user