chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,978 @@
|
||||
"""
|
||||
Policy Registry - In-memory storage for policies.
|
||||
|
||||
Handles storing, retrieving, and managing policies.
|
||||
|
||||
Policies define WHAT guardrails to apply. WHERE they apply is defined
|
||||
by policy_attachments (see AttachmentRegistry).
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
GuardrailPipeline,
|
||||
PipelineStep,
|
||||
Policy,
|
||||
PolicyCondition,
|
||||
PolicyCreateRequest,
|
||||
PolicyDBResponse,
|
||||
PolicyGuardrails,
|
||||
PolicyUpdateRequest,
|
||||
PolicyVersionCompareResponse,
|
||||
PolicyVersionListResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
# Prefix for policy version IDs in request body. Use policy_<uuid> to execute a specific version.
|
||||
POLICY_VERSION_ID_PREFIX = "policy_"
|
||||
|
||||
|
||||
def _row_to_policy_db_response(row: Any) -> PolicyDBResponse:
|
||||
"""Build PolicyDBResponse from a Prisma LiteLLM_PolicyTable row."""
|
||||
return PolicyDBResponse(
|
||||
policy_id=row.policy_id,
|
||||
policy_name=row.policy_name,
|
||||
version_number=getattr(row, "version_number", 1),
|
||||
version_status=getattr(row, "version_status", "production"),
|
||||
parent_version_id=getattr(row, "parent_version_id", None),
|
||||
is_latest=getattr(row, "is_latest", True),
|
||||
published_at=getattr(row, "published_at", None),
|
||||
production_at=getattr(row, "production_at", None),
|
||||
inherit=row.inherit,
|
||||
description=row.description,
|
||||
guardrails_add=row.guardrails_add or [],
|
||||
guardrails_remove=row.guardrails_remove or [],
|
||||
condition=row.condition,
|
||||
pipeline=row.pipeline,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
created_by=row.created_by,
|
||||
updated_by=row.updated_by,
|
||||
)
|
||||
|
||||
|
||||
class PolicyRegistry:
|
||||
"""
|
||||
In-memory registry for storing and managing policies.
|
||||
|
||||
This is a singleton that holds all loaded policies and provides
|
||||
methods to access them.
|
||||
|
||||
Policies define WHAT guardrails to apply:
|
||||
- Base guardrails via guardrails.add/remove
|
||||
- Inheritance via inherit field
|
||||
- Conditional guardrails via condition.model
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._policies: Dict[str, Policy] = {}
|
||||
self._policies_by_id: Dict[str, Tuple[str, Policy]] = {}
|
||||
self._initialized: bool = False
|
||||
|
||||
def load_policies(self, policies_config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Load policies from a configuration dictionary.
|
||||
|
||||
Args:
|
||||
policies_config: Dictionary mapping policy names to policy definitions.
|
||||
This is the raw config from the YAML file.
|
||||
"""
|
||||
self._policies = {}
|
||||
self._policies_by_id = {}
|
||||
|
||||
for policy_name, policy_data in policies_config.items():
|
||||
try:
|
||||
policy = self._parse_policy(policy_name, policy_data)
|
||||
self._policies[policy_name] = policy
|
||||
verbose_proxy_logger.debug(f"Loaded policy: {policy_name}")
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Error loading policy '{policy_name}': {str(e)}"
|
||||
)
|
||||
raise ValueError(f"Invalid policy '{policy_name}': {str(e)}") from e
|
||||
|
||||
self._initialized = True
|
||||
verbose_proxy_logger.info(f"Loaded {len(self._policies)} policies")
|
||||
|
||||
def _parse_policy(self, policy_name: str, policy_data: Dict[str, Any]) -> Policy:
|
||||
"""
|
||||
Parse a policy from raw configuration data.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
policy_data: Raw policy configuration
|
||||
|
||||
Returns:
|
||||
Parsed Policy object
|
||||
"""
|
||||
# Parse guardrails
|
||||
guardrails_data = policy_data.get("guardrails", {})
|
||||
if isinstance(guardrails_data, dict):
|
||||
guardrails = PolicyGuardrails(
|
||||
add=guardrails_data.get("add"),
|
||||
remove=guardrails_data.get("remove"),
|
||||
)
|
||||
else:
|
||||
# Handle legacy format where guardrails might be a list
|
||||
guardrails = PolicyGuardrails(
|
||||
add=guardrails_data if guardrails_data else None
|
||||
)
|
||||
|
||||
# Parse condition (simple model-based condition)
|
||||
condition = None
|
||||
condition_data = policy_data.get("condition")
|
||||
if condition_data:
|
||||
condition = PolicyCondition(model=condition_data.get("model"))
|
||||
|
||||
# Parse pipeline (optional ordered guardrail execution)
|
||||
pipeline = PolicyRegistry._parse_pipeline(policy_data.get("pipeline"))
|
||||
|
||||
return Policy(
|
||||
inherit=policy_data.get("inherit"),
|
||||
description=policy_data.get("description"),
|
||||
guardrails=guardrails,
|
||||
condition=condition,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_pipeline(
|
||||
pipeline_data: Optional[Dict[str, Any]],
|
||||
) -> Optional[GuardrailPipeline]:
|
||||
"""Parse a pipeline configuration from raw data."""
|
||||
if pipeline_data is None:
|
||||
return None
|
||||
|
||||
steps_data = pipeline_data.get("steps", [])
|
||||
steps = [
|
||||
PipelineStep(**step_data) if isinstance(step_data, dict) else step_data
|
||||
for step_data in steps_data
|
||||
]
|
||||
|
||||
return GuardrailPipeline(
|
||||
mode=pipeline_data.get("mode", "pre_call"),
|
||||
steps=steps,
|
||||
)
|
||||
|
||||
def get_policy(self, policy_name: str) -> Optional[Policy]:
|
||||
"""
|
||||
Get a policy by name.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to retrieve
|
||||
|
||||
Returns:
|
||||
Policy object if found, None otherwise
|
||||
"""
|
||||
return self._policies.get(policy_name)
|
||||
|
||||
def get_all_policies(self) -> Dict[str, Policy]:
|
||||
"""
|
||||
Get all loaded policies.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping policy names to Policy objects
|
||||
"""
|
||||
return self._policies.copy()
|
||||
|
||||
def get_policy_names(self) -> List[str]:
|
||||
"""
|
||||
Get list of all policy names.
|
||||
|
||||
Returns:
|
||||
List of policy names
|
||||
"""
|
||||
return list(self._policies.keys())
|
||||
|
||||
def has_policy(self, policy_name: str) -> bool:
|
||||
"""
|
||||
Check if a policy exists.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to check
|
||||
|
||||
Returns:
|
||||
True if policy exists, False otherwise
|
||||
"""
|
||||
return policy_name in self._policies
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
"""
|
||||
Check if the registry has been initialized with policies.
|
||||
|
||||
Returns:
|
||||
True if policies have been loaded, False otherwise
|
||||
"""
|
||||
return self._initialized
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all policies from the registry.
|
||||
"""
|
||||
self._policies = {}
|
||||
self._initialized = False
|
||||
|
||||
def add_policy(self, policy_name: str, policy: Policy) -> None:
|
||||
"""
|
||||
Add or update a single policy.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
policy: Policy object to add
|
||||
"""
|
||||
self._policies[policy_name] = policy
|
||||
verbose_proxy_logger.debug(f"Added/updated policy: {policy_name}")
|
||||
|
||||
def remove_policy(self, policy_name: str) -> bool:
|
||||
"""
|
||||
Remove a policy by name.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to remove
|
||||
|
||||
Returns:
|
||||
True if policy was removed, False if it didn't exist
|
||||
"""
|
||||
if policy_name in self._policies:
|
||||
del self._policies[policy_name]
|
||||
verbose_proxy_logger.debug(f"Removed policy: {policy_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Database CRUD Methods
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def add_policy_to_db(
|
||||
self,
|
||||
policy_request: PolicyCreateRequest,
|
||||
prisma_client: "PrismaClient",
|
||||
created_by: Optional[str] = None,
|
||||
) -> PolicyDBResponse:
|
||||
"""
|
||||
Add a policy to the database.
|
||||
|
||||
Args:
|
||||
policy_request: The policy creation request
|
||||
prisma_client: The Prisma client instance
|
||||
created_by: User who created the policy
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse with the created policy
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Build data dict; new policy is v1 production
|
||||
data: Dict[str, Any] = {
|
||||
"policy_name": policy_request.policy_name,
|
||||
"version_number": 1,
|
||||
"version_status": "production",
|
||||
"is_latest": True,
|
||||
"production_at": now,
|
||||
"guardrails_add": policy_request.guardrails_add or [],
|
||||
"guardrails_remove": policy_request.guardrails_remove or [],
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
# Only add optional fields if they have values
|
||||
if policy_request.inherit is not None:
|
||||
data["inherit"] = policy_request.inherit
|
||||
if policy_request.description is not None:
|
||||
data["description"] = policy_request.description
|
||||
if created_by is not None:
|
||||
data["created_by"] = created_by
|
||||
data["updated_by"] = created_by
|
||||
if policy_request.condition is not None:
|
||||
data["condition"] = json.dumps(policy_request.condition.model_dump())
|
||||
if policy_request.pipeline is not None:
|
||||
validated_pipeline = GuardrailPipeline(**policy_request.pipeline)
|
||||
data["pipeline"] = json.dumps(validated_pipeline.model_dump())
|
||||
|
||||
created_policy = await prisma_client.db.litellm_policytable.create(
|
||||
data=data
|
||||
)
|
||||
|
||||
# Also add to in-memory registry
|
||||
policy = self._parse_policy(
|
||||
policy_request.policy_name,
|
||||
{
|
||||
"inherit": policy_request.inherit,
|
||||
"description": policy_request.description,
|
||||
"guardrails": {
|
||||
"add": policy_request.guardrails_add,
|
||||
"remove": policy_request.guardrails_remove,
|
||||
},
|
||||
"condition": (
|
||||
policy_request.condition.model_dump()
|
||||
if policy_request.condition
|
||||
else None
|
||||
),
|
||||
"pipeline": policy_request.pipeline,
|
||||
},
|
||||
)
|
||||
self.add_policy(policy_request.policy_name, policy)
|
||||
|
||||
return _row_to_policy_db_response(created_policy)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error adding policy to DB: {e}")
|
||||
raise Exception(f"Error adding policy to DB: {str(e)}")
|
||||
|
||||
async def update_policy_in_db(
|
||||
self,
|
||||
policy_id: str,
|
||||
policy_request: PolicyUpdateRequest,
|
||||
prisma_client: "PrismaClient",
|
||||
updated_by: Optional[str] = None,
|
||||
) -> PolicyDBResponse:
|
||||
"""
|
||||
Update a policy in the database. Only draft versions can be updated.
|
||||
|
||||
Args:
|
||||
policy_id: The ID of the policy to update
|
||||
policy_request: The policy update request
|
||||
prisma_client: The Prisma client instance
|
||||
updated_by: User who updated the policy
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse with the updated policy
|
||||
|
||||
Raises:
|
||||
Exception: If policy is not in draft status (only drafts are editable).
|
||||
"""
|
||||
try:
|
||||
existing = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
if existing is None:
|
||||
raise Exception(f"Policy with ID {policy_id} not found")
|
||||
version_status = getattr(existing, "version_status", "production")
|
||||
if version_status != "draft":
|
||||
raise Exception(
|
||||
f"Only draft versions can be updated. This policy has status '{version_status}'."
|
||||
)
|
||||
|
||||
# Build update data - only include fields that are set
|
||||
update_data: Dict[str, Any] = {
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"updated_by": updated_by,
|
||||
}
|
||||
|
||||
if policy_request.policy_name is not None:
|
||||
update_data["policy_name"] = policy_request.policy_name
|
||||
if policy_request.inherit is not None:
|
||||
update_data["inherit"] = policy_request.inherit
|
||||
if policy_request.description is not None:
|
||||
update_data["description"] = policy_request.description
|
||||
if policy_request.guardrails_add is not None:
|
||||
update_data["guardrails_add"] = policy_request.guardrails_add
|
||||
if policy_request.guardrails_remove is not None:
|
||||
update_data["guardrails_remove"] = policy_request.guardrails_remove
|
||||
if policy_request.condition is not None:
|
||||
update_data["condition"] = json.dumps(
|
||||
policy_request.condition.model_dump()
|
||||
)
|
||||
if policy_request.pipeline is not None:
|
||||
validated_pipeline = GuardrailPipeline(**policy_request.pipeline)
|
||||
update_data["pipeline"] = json.dumps(validated_pipeline.model_dump())
|
||||
|
||||
updated_policy = await prisma_client.db.litellm_policytable.update(
|
||||
where={"policy_id": policy_id},
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
# Do NOT update in-memory registry: drafts are not loaded into memory.
|
||||
|
||||
return _row_to_policy_db_response(updated_policy)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating policy in DB: {e}")
|
||||
raise Exception(f"Error updating policy in DB: {str(e)}")
|
||||
|
||||
async def delete_policy_from_db(
|
||||
self,
|
||||
policy_id: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Delete a policy version from the database.
|
||||
|
||||
If the deleted version was production, it is removed from the in-memory
|
||||
registry. No other version is auto-promoted; admin must explicitly promote.
|
||||
|
||||
Args:
|
||||
policy_id: The ID of the policy version to delete
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
Dict with "message" and optional "warning" if production was deleted.
|
||||
"""
|
||||
try:
|
||||
policy = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
|
||||
if policy is None:
|
||||
raise Exception(f"Policy with ID {policy_id} not found")
|
||||
|
||||
version_status = getattr(policy, "version_status", "production")
|
||||
policy_name = policy.policy_name
|
||||
|
||||
# Delete from DB
|
||||
await prisma_client.db.litellm_policytable.delete(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"message": f"Policy {policy_id} deleted successfully"
|
||||
}
|
||||
|
||||
# Remove from in-memory registry only if this was the production version
|
||||
if version_status == "production":
|
||||
self.remove_policy(policy_name)
|
||||
result["warning"] = (
|
||||
"Production version was deleted. No other version was promoted. "
|
||||
"Promote another version to production if this policy should remain active."
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting policy from DB: {e}")
|
||||
raise Exception(f"Error deleting policy from DB: {str(e)}")
|
||||
|
||||
async def get_policy_by_id_from_db(
|
||||
self,
|
||||
policy_id: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> Optional[PolicyDBResponse]:
|
||||
"""
|
||||
Get a policy by ID from the database.
|
||||
|
||||
Args:
|
||||
policy_id: The ID of the policy to retrieve
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
policy = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
|
||||
if policy is None:
|
||||
return None
|
||||
|
||||
return _row_to_policy_db_response(policy)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting policy from DB: {e}")
|
||||
raise Exception(f"Error getting policy from DB: {str(e)}")
|
||||
|
||||
def get_policy_by_id_for_request(
|
||||
self, policy_id: str
|
||||
) -> Optional[Tuple[str, Policy]]:
|
||||
"""
|
||||
Return a policy version by ID from in-memory cache (no DB access).
|
||||
|
||||
Used when the request body specifies policy_<uuid> to execute a specific version
|
||||
(e.g. published or draft). The cache is populated by sync_policies_from_db,
|
||||
which loads draft and published versions keyed by policy_id.
|
||||
|
||||
Args:
|
||||
policy_id: The policy version ID (raw UUID, no prefix)
|
||||
|
||||
Returns:
|
||||
(policy_name, Policy) if found, None otherwise
|
||||
"""
|
||||
return self._policies_by_id.get(policy_id)
|
||||
|
||||
async def get_all_policies_from_db(
|
||||
self,
|
||||
prisma_client: "PrismaClient",
|
||||
version_status: Optional[str] = None,
|
||||
) -> List[PolicyDBResponse]:
|
||||
"""
|
||||
Get all policies from the database, optionally filtered by version_status.
|
||||
|
||||
Args:
|
||||
prisma_client: The Prisma client instance
|
||||
version_status: If set, only return policies with this status
|
||||
("draft", "published", "production").
|
||||
|
||||
Returns:
|
||||
List of PolicyDBResponse objects
|
||||
"""
|
||||
try:
|
||||
where: Dict[str, Any] = {}
|
||||
if version_status is not None:
|
||||
where["version_status"] = version_status
|
||||
|
||||
policies = await prisma_client.db.litellm_policytable.find_many(
|
||||
where=where if where else None,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
|
||||
return [_row_to_policy_db_response(p) for p in policies]
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting policies from DB: {e}")
|
||||
raise Exception(f"Error getting policies from DB: {str(e)}")
|
||||
|
||||
async def sync_policies_from_db(
|
||||
self,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> None:
|
||||
"""
|
||||
Sync policies from the database to in-memory registry.
|
||||
- Production versions are loaded into _policies (by policy name) for resolution.
|
||||
- Draft and published versions are loaded into _policies_by_id so request-body
|
||||
policy_<uuid> overrides can be resolved without DB access in the hot path.
|
||||
"""
|
||||
try:
|
||||
self._policies = {}
|
||||
production = await self.get_all_policies_from_db(
|
||||
prisma_client, version_status="production"
|
||||
)
|
||||
for policy_response in production:
|
||||
policy = self._parse_policy(
|
||||
policy_response.policy_name,
|
||||
{
|
||||
"inherit": policy_response.inherit,
|
||||
"description": policy_response.description,
|
||||
"guardrails": {
|
||||
"add": policy_response.guardrails_add,
|
||||
"remove": policy_response.guardrails_remove,
|
||||
},
|
||||
"condition": policy_response.condition,
|
||||
"pipeline": policy_response.pipeline,
|
||||
},
|
||||
)
|
||||
self.add_policy(policy_response.policy_name, policy)
|
||||
|
||||
self._policies_by_id = {}
|
||||
non_production = await prisma_client.db.litellm_policytable.find_many(
|
||||
where={"version_status": {"in": ["draft", "published"]}},
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
for row in non_production:
|
||||
policy = self._parse_policy(
|
||||
row.policy_name,
|
||||
{
|
||||
"inherit": row.inherit,
|
||||
"description": row.description,
|
||||
"guardrails": {
|
||||
"add": row.guardrails_add or [],
|
||||
"remove": row.guardrails_remove or [],
|
||||
},
|
||||
"condition": row.condition,
|
||||
"pipeline": row.pipeline,
|
||||
},
|
||||
)
|
||||
self._policies_by_id[row.policy_id] = (row.policy_name, policy)
|
||||
|
||||
self._initialized = True
|
||||
verbose_proxy_logger.info(
|
||||
f"Synced {len(production)} production policies and {len(non_production)} "
|
||||
"draft/published (by ID) from DB to in-memory registry"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error syncing policies from DB: {e}")
|
||||
raise Exception(f"Error syncing policies from DB: {str(e)}")
|
||||
|
||||
async def resolve_guardrails_from_db(
|
||||
self,
|
||||
policy_name: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> List[str]:
|
||||
"""
|
||||
Resolve all guardrails for a policy from the database.
|
||||
|
||||
Uses the existing PolicyResolver to handle inheritance chain resolution.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to resolve
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
List of resolved guardrail names
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_resolver import PolicyResolver
|
||||
|
||||
try:
|
||||
# Load only production versions so inheritance resolves against production
|
||||
policies = await self.get_all_policies_from_db(
|
||||
prisma_client, version_status="production"
|
||||
)
|
||||
|
||||
# Build a temporary in-memory map for resolution
|
||||
temp_policies = {}
|
||||
for policy_response in policies:
|
||||
policy = self._parse_policy(
|
||||
policy_response.policy_name,
|
||||
{
|
||||
"inherit": policy_response.inherit,
|
||||
"description": policy_response.description,
|
||||
"guardrails": {
|
||||
"add": policy_response.guardrails_add,
|
||||
"remove": policy_response.guardrails_remove,
|
||||
},
|
||||
"condition": policy_response.condition,
|
||||
"pipeline": policy_response.pipeline,
|
||||
},
|
||||
)
|
||||
temp_policies[policy_response.policy_name] = policy
|
||||
|
||||
# Use the existing PolicyResolver to resolve guardrails
|
||||
resolved_policy = PolicyResolver.resolve_policy_guardrails(
|
||||
policy_name=policy_name,
|
||||
policies=temp_policies,
|
||||
context=None, # No context needed for simple resolution
|
||||
)
|
||||
|
||||
return sorted(resolved_policy.guardrails)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error resolving guardrails from DB: {e}")
|
||||
raise Exception(f"Error resolving guardrails from DB: {str(e)}")
|
||||
|
||||
async def get_versions_by_policy_name(
|
||||
self,
|
||||
policy_name: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> PolicyVersionListResponse:
|
||||
"""
|
||||
Get all versions of a policy by name, ordered by version_number descending.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
PolicyVersionListResponse with policy_name and list of versions
|
||||
"""
|
||||
try:
|
||||
rows = await prisma_client.db.litellm_policytable.find_many(
|
||||
where={"policy_name": policy_name},
|
||||
order={"version_number": "desc"},
|
||||
)
|
||||
versions = [_row_to_policy_db_response(r) for r in rows]
|
||||
return PolicyVersionListResponse(
|
||||
policy_name=policy_name,
|
||||
versions=versions,
|
||||
total_count=len(versions),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting versions: {e}")
|
||||
raise Exception(f"Error getting versions: {str(e)}")
|
||||
|
||||
async def create_new_version(
|
||||
self,
|
||||
policy_name: str,
|
||||
prisma_client: "PrismaClient",
|
||||
source_policy_id: Optional[str] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> PolicyDBResponse:
|
||||
"""
|
||||
Create a new draft version of a policy. Copies all fields from the source.
|
||||
Source is current production if source_policy_id is None.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
prisma_client: The Prisma client instance
|
||||
source_policy_id: Policy ID to clone from; if None, use current production
|
||||
created_by: User who created the version
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse for the new draft version
|
||||
"""
|
||||
try:
|
||||
if source_policy_id is not None:
|
||||
source = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": source_policy_id}
|
||||
)
|
||||
if source is None:
|
||||
raise Exception(f"Source policy {source_policy_id} not found")
|
||||
if source.policy_name != policy_name:
|
||||
raise Exception(
|
||||
f"Source policy name '{source.policy_name}' does not match '{policy_name}'"
|
||||
)
|
||||
else:
|
||||
# Find current production version for this policy_name
|
||||
prod = await prisma_client.db.litellm_policytable.find_first(
|
||||
where={
|
||||
"policy_name": policy_name,
|
||||
"version_status": "production",
|
||||
}
|
||||
)
|
||||
if prod is None:
|
||||
raise Exception(
|
||||
f"No production version found for policy '{policy_name}'"
|
||||
)
|
||||
source = prod
|
||||
|
||||
# Next version number
|
||||
latest = await prisma_client.db.litellm_policytable.find_first(
|
||||
where={"policy_name": policy_name},
|
||||
order={"version_number": "desc"},
|
||||
)
|
||||
next_num = (latest.version_number + 1) if latest else 1
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# Set is_latest=False on all existing versions for this policy_name
|
||||
await prisma_client.db.litellm_policytable.update_many(
|
||||
where={"policy_name": policy_name},
|
||||
data={"is_latest": False},
|
||||
)
|
||||
|
||||
data: Dict[str, Any] = {
|
||||
"policy_name": policy_name,
|
||||
"version_number": next_num,
|
||||
"version_status": "draft",
|
||||
"parent_version_id": source.policy_id,
|
||||
"is_latest": True,
|
||||
"published_at": None,
|
||||
"production_at": None,
|
||||
"inherit": source.inherit,
|
||||
"description": source.description,
|
||||
"guardrails_add": source.guardrails_add or [],
|
||||
"guardrails_remove": source.guardrails_remove or [],
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by,
|
||||
}
|
||||
# Prisma expects Json fields as JSON strings on create (same as add_policy_to_db)
|
||||
if source.condition is not None:
|
||||
data["condition"] = (
|
||||
json.dumps(source.condition)
|
||||
if isinstance(source.condition, dict)
|
||||
else source.condition
|
||||
)
|
||||
if source.pipeline is not None:
|
||||
data["pipeline"] = (
|
||||
json.dumps(source.pipeline)
|
||||
if isinstance(source.pipeline, dict)
|
||||
else source.pipeline
|
||||
)
|
||||
|
||||
created = await prisma_client.db.litellm_policytable.create(data=data)
|
||||
return _row_to_policy_db_response(created)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating new version: {e}")
|
||||
raise Exception(f"Error creating new version: {str(e)}")
|
||||
|
||||
async def update_version_status(
|
||||
self,
|
||||
policy_id: str,
|
||||
new_status: str,
|
||||
prisma_client: "PrismaClient",
|
||||
updated_by: Optional[str] = None,
|
||||
) -> PolicyDBResponse:
|
||||
"""
|
||||
Update a policy version's status. Valid transitions:
|
||||
- draft -> published (sets published_at)
|
||||
- published -> production (sets production_at, demotes current production to published, updates in-memory)
|
||||
- production -> published (demotes, removes from in-memory)
|
||||
- draft -> production: NOT allowed (must publish first)
|
||||
- published -> draft: NOT allowed
|
||||
|
||||
Args:
|
||||
policy_id: The policy version ID
|
||||
new_status: "published" or "production"
|
||||
prisma_client: The Prisma client instance
|
||||
updated_by: User who updated
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse for the updated version
|
||||
"""
|
||||
try:
|
||||
if new_status not in ("published", "production"):
|
||||
raise Exception(
|
||||
f"Invalid status '{new_status}'. Use 'published' or 'production'."
|
||||
)
|
||||
|
||||
row = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
if row is None:
|
||||
raise Exception(f"Policy with ID {policy_id} not found")
|
||||
|
||||
current = getattr(row, "version_status", "production")
|
||||
policy_name = row.policy_name
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if new_status == "published":
|
||||
if current != "draft":
|
||||
raise Exception(
|
||||
f"Only draft versions can be published. Current status: '{current}'."
|
||||
)
|
||||
updated = await prisma_client.db.litellm_policytable.update(
|
||||
where={"policy_id": policy_id},
|
||||
data={
|
||||
"version_status": "published",
|
||||
"published_at": now,
|
||||
"updated_at": now,
|
||||
"updated_by": updated_by,
|
||||
},
|
||||
)
|
||||
return _row_to_policy_db_response(updated)
|
||||
|
||||
# new_status == "production"
|
||||
if current not in ("draft", "published"):
|
||||
raise Exception(
|
||||
f"Only draft or published versions can be promoted to production. Current: '{current}'."
|
||||
)
|
||||
# Plan: "draft -> production" NOT allowed
|
||||
if current == "draft":
|
||||
raise Exception(
|
||||
"Cannot promote draft directly to production. Publish the version first."
|
||||
)
|
||||
|
||||
# Demote current production to published
|
||||
await prisma_client.db.litellm_policytable.update_many(
|
||||
where={
|
||||
"policy_name": policy_name,
|
||||
"version_status": "production",
|
||||
},
|
||||
data={
|
||||
"version_status": "published",
|
||||
"updated_at": now,
|
||||
"updated_by": updated_by,
|
||||
},
|
||||
)
|
||||
|
||||
# Promote this version to production
|
||||
updated = await prisma_client.db.litellm_policytable.update(
|
||||
where={"policy_id": policy_id},
|
||||
data={
|
||||
"version_status": "production",
|
||||
"production_at": now,
|
||||
"updated_at": now,
|
||||
"updated_by": updated_by,
|
||||
},
|
||||
)
|
||||
|
||||
# Update in-memory registry: remove old production (by name), add this one
|
||||
self.remove_policy(policy_name)
|
||||
policy = self._parse_policy(
|
||||
policy_name,
|
||||
{
|
||||
"inherit": updated.inherit,
|
||||
"description": updated.description,
|
||||
"guardrails": {
|
||||
"add": updated.guardrails_add or [],
|
||||
"remove": updated.guardrails_remove or [],
|
||||
},
|
||||
"condition": updated.condition,
|
||||
"pipeline": updated.pipeline,
|
||||
},
|
||||
)
|
||||
self.add_policy(policy_name, policy)
|
||||
|
||||
return _row_to_policy_db_response(updated)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating version status: {e}")
|
||||
raise Exception(f"Error updating version status: {str(e)}")
|
||||
|
||||
async def compare_versions(
|
||||
self,
|
||||
policy_id_a: str,
|
||||
policy_id_b: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> PolicyVersionCompareResponse:
|
||||
"""
|
||||
Compare two policy versions and return field-by-field diffs.
|
||||
|
||||
Args:
|
||||
policy_id_a: First policy version ID
|
||||
policy_id_b: Second policy version ID
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
PolicyVersionCompareResponse with both versions and field_diffs
|
||||
"""
|
||||
try:
|
||||
a = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id_a}
|
||||
)
|
||||
b = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id_b}
|
||||
)
|
||||
if a is None:
|
||||
raise Exception(f"Policy {policy_id_a} not found")
|
||||
if b is None:
|
||||
raise Exception(f"Policy {policy_id_b} not found")
|
||||
|
||||
resp_a = _row_to_policy_db_response(a)
|
||||
resp_b = _row_to_policy_db_response(b)
|
||||
|
||||
# Compare fields that are part of policy content (not metadata)
|
||||
compare_fields = [
|
||||
"inherit",
|
||||
"description",
|
||||
"guardrails_add",
|
||||
"guardrails_remove",
|
||||
"condition",
|
||||
"pipeline",
|
||||
]
|
||||
field_diffs: Dict[str, Dict[str, Any]] = {}
|
||||
for field in compare_fields:
|
||||
val_a = getattr(resp_a, field)
|
||||
val_b = getattr(resp_b, field)
|
||||
if val_a != val_b:
|
||||
field_diffs[field] = {"version_a": val_a, "version_b": val_b}
|
||||
|
||||
return PolicyVersionCompareResponse(
|
||||
version_a=resp_a,
|
||||
version_b=resp_b,
|
||||
field_diffs=field_diffs,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error comparing versions: {e}")
|
||||
raise Exception(f"Error comparing versions: {str(e)}")
|
||||
|
||||
async def delete_all_versions(
|
||||
self,
|
||||
policy_name: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Delete all versions of a policy. Also removes from in-memory registry.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
Dict with success message
|
||||
"""
|
||||
try:
|
||||
await prisma_client.db.litellm_policytable.delete_many(
|
||||
where={"policy_name": policy_name}
|
||||
)
|
||||
self.remove_policy(policy_name)
|
||||
return {
|
||||
"message": f"All versions of policy '{policy_name}' deleted successfully"
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting all versions: {e}")
|
||||
raise Exception(f"Error deleting all versions: {str(e)}")
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_policy_registry: Optional[PolicyRegistry] = None
|
||||
|
||||
|
||||
def get_policy_registry() -> PolicyRegistry:
|
||||
"""
|
||||
Get the global PolicyRegistry singleton.
|
||||
|
||||
Returns:
|
||||
The global PolicyRegistry instance
|
||||
"""
|
||||
global _policy_registry
|
||||
if _policy_registry is None:
|
||||
_policy_registry = PolicyRegistry()
|
||||
return _policy_registry
|
||||
Reference in New Issue
Block a user