Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/proxy/policy_engine/policy_registry.py
2026-03-26 20:06:14 +08:00

979 lines
36 KiB
Python

"""
Policy Registry - In-memory storage for policies.
Handles storing, retrieving, and managing policies.
Policies define WHAT guardrails to apply. WHERE they apply is defined
by policy_attachments (see AttachmentRegistry).
"""
import json
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from litellm._logging import verbose_proxy_logger
from litellm.types.proxy.policy_engine import (
GuardrailPipeline,
PipelineStep,
Policy,
PolicyCondition,
PolicyCreateRequest,
PolicyDBResponse,
PolicyGuardrails,
PolicyUpdateRequest,
PolicyVersionCompareResponse,
PolicyVersionListResponse,
)
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
# Prefix for policy version IDs in request body. Use policy_<uuid> to execute a specific version.
POLICY_VERSION_ID_PREFIX = "policy_"
def _row_to_policy_db_response(row: Any) -> PolicyDBResponse:
    """
    Convert a Prisma LiteLLM_PolicyTable row into a PolicyDBResponse.

    The versioning attributes are read with fallbacks, so a row object that
    does not expose them is reported as version 1, "production", latest,
    with no parent and no publish/production timestamps.
    """
    # Versioning columns: looked up defensively with defaults.
    versioning_fields = {
        "version_number": getattr(row, "version_number", 1),
        "version_status": getattr(row, "version_status", "production"),
        "parent_version_id": getattr(row, "parent_version_id", None),
        "is_latest": getattr(row, "is_latest", True),
        "published_at": getattr(row, "published_at", None),
        "production_at": getattr(row, "production_at", None),
    }
    # Policy content: missing guardrail lists are normalized to empty lists.
    content_fields = {
        "inherit": row.inherit,
        "description": row.description,
        "guardrails_add": row.guardrails_add or [],
        "guardrails_remove": row.guardrails_remove or [],
        "condition": row.condition,
        "pipeline": row.pipeline,
    }
    # Audit columns are copied through unchanged.
    audit_fields = {
        "created_at": row.created_at,
        "updated_at": row.updated_at,
        "created_by": row.created_by,
        "updated_by": row.updated_by,
    }
    return PolicyDBResponse(
        policy_id=row.policy_id,
        policy_name=row.policy_name,
        **versioning_fields,
        **content_fields,
        **audit_fields,
    )
class PolicyRegistry:
"""
In-memory registry for storing and managing policies.
This is a singleton that holds all loaded policies and provides
methods to access them.
Policies define WHAT guardrails to apply:
- Base guardrails via guardrails.add/remove
- Inheritance via inherit field
- Conditional guardrails via condition.model
"""
def __init__(self):
self._policies: Dict[str, Policy] = {}
self._policies_by_id: Dict[str, Tuple[str, Policy]] = {}
self._initialized: bool = False
def load_policies(self, policies_config: Dict[str, Any]) -> None:
    """
    Replace the registry contents with policies parsed from a config dict.

    Args:
        policies_config: Mapping of policy name to its raw definition
            (the raw config from the YAML file).

    Raises:
        ValueError: If any policy definition cannot be parsed. The original
            error is chained as the cause.
    """
    # Both lookup tables are reset before anything is parsed.
    self._policies = {}
    self._policies_by_id = {}

    for name, raw_definition in policies_config.items():
        try:
            parsed = self._parse_policy(name, raw_definition)
            self._policies[name] = parsed
            verbose_proxy_logger.debug(f"Loaded policy: {name}")
        except Exception as e:
            reason = str(e)
            verbose_proxy_logger.error(f"Error loading policy '{name}': {reason}")
            raise ValueError(f"Invalid policy '{name}': {reason}") from e

    self._initialized = True
    loaded_count = len(self._policies)
    verbose_proxy_logger.info(f"Loaded {loaded_count} policies")
def _parse_policy(self, policy_name: str, policy_data: Dict[str, Any]) -> Policy:
    """
    Build a Policy object from a raw policy definition.

    Args:
        policy_name: Name of the policy (not used while parsing).
        policy_data: Raw definition with optional "guardrails", "condition",
            "pipeline", "inherit" and "description" entries.

    Returns:
        The parsed Policy.
    """
    # Guardrails: either a {"add": ..., "remove": ...} mapping, or the
    # legacy form where the value itself is the list of guardrails to add.
    raw_guardrails = policy_data.get("guardrails", {})
    if not isinstance(raw_guardrails, dict):
        guardrails = PolicyGuardrails(add=raw_guardrails or None)
    else:
        guardrails = PolicyGuardrails(
            add=raw_guardrails.get("add"),
            remove=raw_guardrails.get("remove"),
        )

    # Condition: only the "model" entry is read; empty/missing means none.
    raw_condition = policy_data.get("condition")
    condition = (
        PolicyCondition(model=raw_condition.get("model")) if raw_condition else None
    )

    # Pipeline: optional ordered guardrail execution.
    pipeline = PolicyRegistry._parse_pipeline(policy_data.get("pipeline"))

    return Policy(
        inherit=policy_data.get("inherit"),
        description=policy_data.get("description"),
        guardrails=guardrails,
        condition=condition,
        pipeline=pipeline,
    )
@staticmethod
def _parse_pipeline(
pipeline_data: Optional[Dict[str, Any]],
) -> Optional[GuardrailPipeline]:
"""Parse a pipeline configuration from raw data."""
if pipeline_data is None:
return None
steps_data = pipeline_data.get("steps", [])
steps = [
PipelineStep(**step_data) if isinstance(step_data, dict) else step_data
for step_data in steps_data
]
return GuardrailPipeline(
mode=pipeline_data.get("mode", "pre_call"),
steps=steps,
)
def get_policy(self, policy_name: str) -> Optional[Policy]:
"""
Get a policy by name.
Args:
policy_name: Name of the policy to retrieve
Returns:
Policy object if found, None otherwise
"""
return self._policies.get(policy_name)
def get_all_policies(self) -> Dict[str, Policy]:
"""
Get all loaded policies.
Returns:
Dictionary mapping policy names to Policy objects
"""
return self._policies.copy()
def get_policy_names(self) -> List[str]:
"""
Get list of all policy names.
Returns:
List of policy names
"""
return list(self._policies.keys())
def has_policy(self, policy_name: str) -> bool:
"""
Check if a policy exists.
Args:
policy_name: Name of the policy to check
Returns:
True if policy exists, False otherwise
"""
return policy_name in self._policies
def is_initialized(self) -> bool:
"""
Check if the registry has been initialized with policies.
Returns:
True if policies have been loaded, False otherwise
"""
return self._initialized
def clear(self) -> None:
"""
Clear all policies from the registry.
"""
self._policies = {}
self._initialized = False
def add_policy(self, policy_name: str, policy: Policy) -> None:
    """
    Store a policy under the given name, replacing any existing entry.

    Args:
        policy_name: Name to register the policy under.
        policy: Policy object to store.
    """
    self._policies.update({policy_name: policy})
    verbose_proxy_logger.debug(f"Added/updated policy: {policy_name}")
def remove_policy(self, policy_name: str) -> bool:
"""
Remove a policy by name.
Args:
policy_name: Name of the policy to remove
Returns:
True if policy was removed, False if it didn't exist
"""
if policy_name in self._policies:
del self._policies[policy_name]
verbose_proxy_logger.debug(f"Removed policy: {policy_name}")
return True
return False
# ─────────────────────────────────────────────────────────────────────────
# Database CRUD Methods
# ─────────────────────────────────────────────────────────────────────────
async def add_policy_to_db(
    self,
    policy_request: PolicyCreateRequest,
    prisma_client: "PrismaClient",
    created_by: Optional[str] = None,
) -> PolicyDBResponse:
    """
    Persist a new policy and register it in memory.

    The row is created as version 1 with status "production" and
    is_latest=True. After the insert, the same definition is parsed and
    stored in the name-keyed in-memory registry.

    Args:
        policy_request: The policy creation request.
        prisma_client: The Prisma client instance.
        created_by: User who created the policy; also written as updated_by.

    Returns:
        PolicyDBResponse built from the created row.

    Raises:
        Exception: Any failure is logged and re-raised with an
            "Error adding policy to DB" prefix.
    """
    try:
        timestamp = datetime.now(timezone.utc)

        # Mandatory columns for a brand-new v1 production policy.
        row_data: Dict[str, Any] = {
            "policy_name": policy_request.policy_name,
            "version_number": 1,
            "version_status": "production",
            "is_latest": True,
            "production_at": timestamp,
            "guardrails_add": policy_request.guardrails_add or [],
            "guardrails_remove": policy_request.guardrails_remove or [],
            "created_at": timestamp,
            "updated_at": timestamp,
        }

        # Optional columns are written only when a value was supplied.
        for column in ("inherit", "description"):
            value = getattr(policy_request, column)
            if value is not None:
                row_data[column] = value
        if created_by is not None:
            row_data["created_by"] = created_by
            row_data["updated_by"] = created_by

        # JSON columns are serialized to strings before the insert.
        if policy_request.condition is not None:
            row_data["condition"] = json.dumps(policy_request.condition.model_dump())
        if policy_request.pipeline is not None:
            # Building the model validates the pipeline before it is stored.
            pipeline_model = GuardrailPipeline(**policy_request.pipeline)
            row_data["pipeline"] = json.dumps(pipeline_model.model_dump())

        db_row = await prisma_client.db.litellm_policytable.create(data=row_data)

        # Mirror the new policy into the in-memory registry.
        if policy_request.condition:
            condition_definition = policy_request.condition.model_dump()
        else:
            condition_definition = None
        in_memory_definition = {
            "inherit": policy_request.inherit,
            "description": policy_request.description,
            "guardrails": {
                "add": policy_request.guardrails_add,
                "remove": policy_request.guardrails_remove,
            },
            "condition": condition_definition,
            "pipeline": policy_request.pipeline,
        }
        parsed = self._parse_policy(policy_request.policy_name, in_memory_definition)
        self.add_policy(policy_request.policy_name, parsed)

        return _row_to_policy_db_response(db_row)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error adding policy to DB: {e}")
        raise Exception(f"Error adding policy to DB: {str(e)}")
async def update_policy_in_db(
    self,
    policy_id: str,
    policy_request: PolicyUpdateRequest,
    prisma_client: "PrismaClient",
    updated_by: Optional[str] = None,
) -> PolicyDBResponse:
    """
    Apply a partial update to a draft policy version in the database.

    Only fields that are not None on the request are written. updated_at and
    updated_by are always written.

    Args:
        policy_id: The ID of the policy version to update.
        policy_request: The policy update request.
        prisma_client: The Prisma client instance.
        updated_by: User who updated the policy.

    Returns:
        PolicyDBResponse built from the updated row.

    Raises:
        Exception: If the version does not exist, is not in "draft" status,
            or the update fails. All failures are logged and re-raised with
            an "Error updating policy in DB" prefix.
    """
    try:
        current_row = await prisma_client.db.litellm_policytable.find_unique(
            where={"policy_id": policy_id}
        )
        if current_row is None:
            raise Exception(f"Policy with ID {policy_id} not found")

        # Rows without the attribute are treated as production, i.e. not editable.
        version_status = getattr(current_row, "version_status", "production")
        if version_status != "draft":
            raise Exception(
                f"Only draft versions can be updated. This policy has status '{version_status}'."
            )

        changes: Dict[str, Any] = {
            "updated_at": datetime.now(timezone.utc),
            "updated_by": updated_by,
        }

        # Plain columns: copied through when the request sets them.
        plain_columns = (
            "policy_name",
            "inherit",
            "description",
            "guardrails_add",
            "guardrails_remove",
        )
        for column in plain_columns:
            value = getattr(policy_request, column)
            if value is not None:
                changes[column] = value

        # JSON columns are serialized to strings before the update.
        if policy_request.condition is not None:
            changes["condition"] = json.dumps(policy_request.condition.model_dump())
        if policy_request.pipeline is not None:
            # Building the model validates the pipeline before it is stored.
            pipeline_model = GuardrailPipeline(**policy_request.pipeline)
            changes["pipeline"] = json.dumps(pipeline_model.model_dump())

        db_row = await prisma_client.db.litellm_policytable.update(
            where={"policy_id": policy_id},
            data=changes,
        )

        # The in-memory tables are intentionally not touched here; the
        # ID-keyed cache of draft versions is rebuilt by sync_policies_from_db.
        return _row_to_policy_db_response(db_row)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating policy in DB: {e}")
        raise Exception(f"Error updating policy in DB: {str(e)}")
async def delete_policy_from_db(
self,
policy_id: str,
prisma_client: "PrismaClient",
) -> Dict[str, Any]:
"""
Delete a policy version from the database.
If the deleted version was production, it is removed from the in-memory
registry. No other version is auto-promoted; admin must explicitly promote.
Args:
policy_id: The ID of the policy version to delete
prisma_client: The Prisma client instance
Returns:
Dict with "message" and optional "warning" if production was deleted.
"""
try:
policy = await prisma_client.db.litellm_policytable.find_unique(
where={"policy_id": policy_id}
)
if policy is None:
raise Exception(f"Policy with ID {policy_id} not found")
version_status = getattr(policy, "version_status", "production")
policy_name = policy.policy_name
# Delete from DB
await prisma_client.db.litellm_policytable.delete(
where={"policy_id": policy_id}
)
result: Dict[str, Any] = {
"message": f"Policy {policy_id} deleted successfully"
}
# Remove from in-memory registry only if this was the production version
if version_status == "production":
self.remove_policy(policy_name)
result["warning"] = (
"Production version was deleted. No other version was promoted. "
"Promote another version to production if this policy should remain active."
)
return result
except Exception as e:
verbose_proxy_logger.exception(f"Error deleting policy from DB: {e}")
raise Exception(f"Error deleting policy from DB: {str(e)}")
async def get_policy_by_id_from_db(
self,
policy_id: str,
prisma_client: "PrismaClient",
) -> Optional[PolicyDBResponse]:
"""
Get a policy by ID from the database.
Args:
policy_id: The ID of the policy to retrieve
prisma_client: The Prisma client instance
Returns:
PolicyDBResponse if found, None otherwise
"""
try:
policy = await prisma_client.db.litellm_policytable.find_unique(
where={"policy_id": policy_id}
)
if policy is None:
return None
return _row_to_policy_db_response(policy)
except Exception as e:
verbose_proxy_logger.exception(f"Error getting policy from DB: {e}")
raise Exception(f"Error getting policy from DB: {str(e)}")
def get_policy_by_id_for_request(
self, policy_id: str
) -> Optional[Tuple[str, Policy]]:
"""
Return a policy version by ID from in-memory cache (no DB access).
Used when the request body specifies policy_<uuid> to execute a specific version
(e.g. published or draft). The cache is populated by sync_policies_from_db,
which loads draft and published versions keyed by policy_id.
Args:
policy_id: The policy version ID (raw UUID, no prefix)
Returns:
(policy_name, Policy) if found, None otherwise
"""
return self._policies_by_id.get(policy_id)
async def get_all_policies_from_db(
self,
prisma_client: "PrismaClient",
version_status: Optional[str] = None,
) -> List[PolicyDBResponse]:
"""
Get all policies from the database, optionally filtered by version_status.
Args:
prisma_client: The Prisma client instance
version_status: If set, only return policies with this status
("draft", "published", "production").
Returns:
List of PolicyDBResponse objects
"""
try:
where: Dict[str, Any] = {}
if version_status is not None:
where["version_status"] = version_status
policies = await prisma_client.db.litellm_policytable.find_many(
where=where if where else None,
order={"created_at": "desc"},
)
return [_row_to_policy_db_response(p) for p in policies]
except Exception as e:
verbose_proxy_logger.exception(f"Error getting policies from DB: {e}")
raise Exception(f"Error getting policies from DB: {str(e)}")
async def sync_policies_from_db(
self,
prisma_client: "PrismaClient",
) -> None:
"""
Sync policies from the database to in-memory registry.
- Production versions are loaded into _policies (by policy name) for resolution.
- Draft and published versions are loaded into _policies_by_id so request-body
policy_<uuid> overrides can be resolved without DB access in the hot path.
"""
try:
self._policies = {}
production = await self.get_all_policies_from_db(
prisma_client, version_status="production"
)
for policy_response in production:
policy = self._parse_policy(
policy_response.policy_name,
{
"inherit": policy_response.inherit,
"description": policy_response.description,
"guardrails": {
"add": policy_response.guardrails_add,
"remove": policy_response.guardrails_remove,
},
"condition": policy_response.condition,
"pipeline": policy_response.pipeline,
},
)
self.add_policy(policy_response.policy_name, policy)
self._policies_by_id = {}
non_production = await prisma_client.db.litellm_policytable.find_many(
where={"version_status": {"in": ["draft", "published"]}},
order={"created_at": "desc"},
)
for row in non_production:
policy = self._parse_policy(
row.policy_name,
{
"inherit": row.inherit,
"description": row.description,
"guardrails": {
"add": row.guardrails_add or [],
"remove": row.guardrails_remove or [],
},
"condition": row.condition,
"pipeline": row.pipeline,
},
)
self._policies_by_id[row.policy_id] = (row.policy_name, policy)
self._initialized = True
verbose_proxy_logger.info(
f"Synced {len(production)} production policies and {len(non_production)} "
"draft/published (by ID) from DB to in-memory registry"
)
except Exception as e:
verbose_proxy_logger.exception(f"Error syncing policies from DB: {e}")
raise Exception(f"Error syncing policies from DB: {str(e)}")
async def resolve_guardrails_from_db(
    self,
    policy_name: str,
    prisma_client: "PrismaClient",
) -> List[str]:
    """
    Resolve the guardrails of a policy using production rows from the DB.

    All production versions are loaded into a temporary name -> Policy map,
    which is handed to PolicyResolver so inheritance is resolved against
    production policies only. The registry's own tables are not modified.

    Args:
        policy_name: Name of the policy to resolve.
        prisma_client: The Prisma client instance.

    Returns:
        Sorted list of resolved guardrail names.

    Raises:
        Exception: Failures are logged and re-raised with an
            "Error resolving guardrails from DB" prefix.
    """
    from litellm.proxy.policy_engine.policy_resolver import PolicyResolver

    try:
        production_rows = await self.get_all_policies_from_db(
            prisma_client, version_status="production"
        )

        # Throwaway lookup used only for this resolution.
        production_policies = {}
        for db_policy in production_rows:
            definition = {
                "inherit": db_policy.inherit,
                "description": db_policy.description,
                "guardrails": {
                    "add": db_policy.guardrails_add,
                    "remove": db_policy.guardrails_remove,
                },
                "condition": db_policy.condition,
                "pipeline": db_policy.pipeline,
            }
            production_policies[db_policy.policy_name] = self._parse_policy(
                db_policy.policy_name, definition
            )

        resolution = PolicyResolver.resolve_policy_guardrails(
            policy_name=policy_name,
            policies=production_policies,
            context=None,  # No context needed for simple resolution
        )
        return sorted(resolution.guardrails)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error resolving guardrails from DB: {e}")
        raise Exception(f"Error resolving guardrails from DB: {str(e)}")
async def get_versions_by_policy_name(
    self,
    policy_name: str,
    prisma_client: "PrismaClient",
) -> PolicyVersionListResponse:
    """
    List every version of a policy, highest version_number first.

    Args:
        policy_name: Name of the policy.
        prisma_client: The Prisma client instance.

    Returns:
        PolicyVersionListResponse with the policy name, its versions and
        their count.

    Raises:
        Exception: Failures are logged and re-raised with an
            "Error getting versions" prefix.
    """
    try:
        db_rows = await prisma_client.db.litellm_policytable.find_many(
            where={"policy_name": policy_name},
            order={"version_number": "desc"},
        )
        version_list = []
        for db_row in db_rows:
            version_list.append(_row_to_policy_db_response(db_row))
        return PolicyVersionListResponse(
            policy_name=policy_name,
            versions=version_list,
            total_count=len(version_list),
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting versions: {e}")
        raise Exception(f"Error getting versions: {str(e)}")
async def create_new_version(
    self,
    policy_name: str,
    prisma_client: "PrismaClient",
    source_policy_id: Optional[str] = None,
    created_by: Optional[str] = None,
) -> PolicyDBResponse:
    """
    Create a new draft version of a policy by copying a source version.

    The source is the row identified by source_policy_id, or the current
    production version of policy_name when source_policy_id is None. The new
    row gets the next version_number, status "draft", is_latest=True and
    parent_version_id pointing at the source. All existing versions of the
    policy are marked is_latest=False first.

    This method only writes to the database; it does not modify the
    in-memory registry tables.

    Args:
        policy_name: Name of the policy.
        prisma_client: The Prisma client instance.
        source_policy_id: Policy ID to clone from; if None, use current production.
        created_by: User who created the version; also written as updated_by.

    Returns:
        PolicyDBResponse for the new draft version.

    Raises:
        Exception: If the source is missing, belongs to a different policy
            name, no production version exists, or a DB call fails. All
            failures are logged and re-raised with an
            "Error creating new version" prefix.
    """
    try:
        if source_policy_id is not None:
            # Explicit source: it must exist and belong to the same policy name.
            source = await prisma_client.db.litellm_policytable.find_unique(
                where={"policy_id": source_policy_id}
            )
            if source is None:
                raise Exception(f"Source policy {source_policy_id} not found")
            if source.policy_name != policy_name:
                raise Exception(
                    f"Source policy name '{source.policy_name}' does not match '{policy_name}'"
                )
        else:
            # Find current production version for this policy_name
            prod = await prisma_client.db.litellm_policytable.find_first(
                where={
                    "policy_name": policy_name,
                    "version_status": "production",
                }
            )
            if prod is None:
                raise Exception(
                    f"No production version found for policy '{policy_name}'"
                )
            source = prod
        # Next version number: one above the highest existing version.
        # NOTE(review): the read, update_many and create below are separate
        # calls with no transaction visible here; confirm whether concurrent
        # callers are guarded elsewhere (e.g. by a DB uniqueness constraint).
        latest = await prisma_client.db.litellm_policytable.find_first(
            where={"policy_name": policy_name},
            order={"version_number": "desc"},
        )
        next_num = (latest.version_number + 1) if latest else 1
        now = datetime.now(timezone.utc)
        # Set is_latest=False on all existing versions for this policy_name
        await prisma_client.db.litellm_policytable.update_many(
            where={"policy_name": policy_name},
            data={"is_latest": False},
        )
        # The new draft copies the source's content fields and starts with
        # no published/production timestamps.
        data: Dict[str, Any] = {
            "policy_name": policy_name,
            "version_number": next_num,
            "version_status": "draft",
            "parent_version_id": source.policy_id,
            "is_latest": True,
            "published_at": None,
            "production_at": None,
            "inherit": source.inherit,
            "description": source.description,
            "guardrails_add": source.guardrails_add or [],
            "guardrails_remove": source.guardrails_remove or [],
            "created_at": now,
            "updated_at": now,
            "created_by": created_by,
            "updated_by": created_by,
        }
        # Prisma expects Json fields as JSON strings on create (same as add_policy_to_db).
        # Dict values are serialized; anything else is passed through as-is.
        if source.condition is not None:
            data["condition"] = (
                json.dumps(source.condition)
                if isinstance(source.condition, dict)
                else source.condition
            )
        if source.pipeline is not None:
            data["pipeline"] = (
                json.dumps(source.pipeline)
                if isinstance(source.pipeline, dict)
                else source.pipeline
            )
        created = await prisma_client.db.litellm_policytable.create(data=data)
        return _row_to_policy_db_response(created)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error creating new version: {e}")
        raise Exception(f"Error creating new version: {str(e)}")
async def update_version_status(
    self,
    policy_id: str,
    new_status: str,
    prisma_client: "PrismaClient",
    updated_by: Optional[str] = None,
) -> PolicyDBResponse:
    """
    Move a policy version to "published" or "production".

    Transitions accepted by this method:
    - draft -> published: sets published_at.
    - published -> production: sets production_at, demotes any existing
      production rows of the same policy name to "published", and replaces
      the policy's entry in the name-keyed in-memory registry.

    Transitions rejected by this method:
    - draft -> production (the version must be published first).
    - Anything starting from production: the "published" branch accepts
      only drafts, so a production version cannot be demoted directly here.
    - Any new_status other than "published" or "production" (e.g. "draft").

    The ID-keyed in-memory cache (_policies_by_id) is not modified here.

    Args:
        policy_id: The policy version ID.
        new_status: "published" or "production".
        prisma_client: The Prisma client instance.
        updated_by: User who updated.

    Returns:
        PolicyDBResponse for the updated version.

    Raises:
        Exception: For an invalid status, a missing version, a disallowed
            transition, or a DB failure. All failures are logged and
            re-raised with an "Error updating version status" prefix.
    """
    try:
        if new_status not in ("published", "production"):
            raise Exception(
                f"Invalid status '{new_status}'. Use 'published' or 'production'."
            )
        row = await prisma_client.db.litellm_policytable.find_unique(
            where={"policy_id": policy_id}
        )
        if row is None:
            raise Exception(f"Policy with ID {policy_id} not found")
        # Rows without the attribute are treated as production.
        current = getattr(row, "version_status", "production")
        policy_name = row.policy_name
        now = datetime.now(timezone.utc)
        if new_status == "published":
            # Publishing is a DB-only change and is allowed from draft only.
            if current != "draft":
                raise Exception(
                    f"Only draft versions can be published. Current status: '{current}'."
                )
            updated = await prisma_client.db.litellm_policytable.update(
                where={"policy_id": policy_id},
                data={
                    "version_status": "published",
                    "published_at": now,
                    "updated_at": now,
                    "updated_by": updated_by,
                },
            )
            return _row_to_policy_db_response(updated)
        # new_status == "production"
        # This first check rejects versions that are already production.
        if current not in ("draft", "published"):
            raise Exception(
                f"Only draft or published versions can be promoted to production. Current: '{current}'."
            )
        # Plan: "draft -> production" NOT allowed
        if current == "draft":
            raise Exception(
                "Cannot promote draft directly to production. Publish the version first."
            )
        # Demote current production to published
        await prisma_client.db.litellm_policytable.update_many(
            where={
                "policy_name": policy_name,
                "version_status": "production",
            },
            data={
                "version_status": "published",
                "updated_at": now,
                "updated_by": updated_by,
            },
        )
        # Promote this version to production
        updated = await prisma_client.db.litellm_policytable.update(
            where={"policy_id": policy_id},
            data={
                "version_status": "production",
                "production_at": now,
                "updated_at": now,
                "updated_by": updated_by,
            },
        )
        # Update in-memory registry: remove old production (by name), add this one
        self.remove_policy(policy_name)
        policy = self._parse_policy(
            policy_name,
            {
                "inherit": updated.inherit,
                "description": updated.description,
                "guardrails": {
                    "add": updated.guardrails_add or [],
                    "remove": updated.guardrails_remove or [],
                },
                "condition": updated.condition,
                "pipeline": updated.pipeline,
            },
        )
        self.add_policy(policy_name, policy)
        return _row_to_policy_db_response(updated)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating version status: {e}")
        raise Exception(f"Error updating version status: {str(e)}")
async def compare_versions(
    self,
    policy_id_a: str,
    policy_id_b: str,
    prisma_client: "PrismaClient",
) -> PolicyVersionCompareResponse:
    """
    Diff the content of two policy versions field by field.

    Only content fields are compared (inherit, description, guardrails_add,
    guardrails_remove, condition, pipeline); metadata such as timestamps and
    version numbers is ignored.

    Args:
        policy_id_a: First policy version ID.
        policy_id_b: Second policy version ID.
        prisma_client: The Prisma client instance.

    Returns:
        PolicyVersionCompareResponse with both versions and field_diffs,
        where each differing field maps to its "version_a"/"version_b" values.

    Raises:
        Exception: If either version is missing or a lookup fails. All
            failures are logged and re-raised with an
            "Error comparing versions" prefix.
    """
    try:
        table = prisma_client.db.litellm_policytable
        # Both rows are fetched before either is checked for existence.
        row_a = await table.find_unique(where={"policy_id": policy_id_a})
        row_b = await table.find_unique(where={"policy_id": policy_id_b})
        if row_a is None:
            raise Exception(f"Policy {policy_id_a} not found")
        if row_b is None:
            raise Exception(f"Policy {policy_id_b} not found")

        response_a = _row_to_policy_db_response(row_a)
        response_b = _row_to_policy_db_response(row_b)

        content_fields = (
            "inherit",
            "description",
            "guardrails_add",
            "guardrails_remove",
            "condition",
            "pipeline",
        )
        differences: Dict[str, Dict[str, Any]] = {}
        for field_name in content_fields:
            left = getattr(response_a, field_name)
            right = getattr(response_b, field_name)
            if left == right:
                continue
            differences[field_name] = {"version_a": left, "version_b": right}

        return PolicyVersionCompareResponse(
            version_a=response_a,
            version_b=response_b,
            field_diffs=differences,
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error comparing versions: {e}")
        raise Exception(f"Error comparing versions: {str(e)}")
async def delete_all_versions(
self,
policy_name: str,
prisma_client: "PrismaClient",
) -> Dict[str, str]:
"""
Delete all versions of a policy. Also removes from in-memory registry.
Args:
policy_name: Name of the policy
prisma_client: The Prisma client instance
Returns:
Dict with success message
"""
try:
await prisma_client.db.litellm_policytable.delete_many(
where={"policy_name": policy_name}
)
self.remove_policy(policy_name)
return {
"message": f"All versions of policy '{policy_name}' deleted successfully"
}
except Exception as e:
verbose_proxy_logger.exception(f"Error deleting all versions: {e}")
raise Exception(f"Error deleting all versions: {str(e)}")
# Module-level registry instance; stays None until first requested.
_policy_registry: Optional[PolicyRegistry] = None


def get_policy_registry() -> PolicyRegistry:
    """
    Return the module-wide PolicyRegistry, creating it on first use.

    Returns:
        The shared PolicyRegistry instance.
    """
    global _policy_registry
    if _policy_registry is not None:
        return _policy_registry
    _policy_registry = PolicyRegistry()
    return _policy_registry