chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Policy Resolver - Resolves final guardrail list from policies.
|
||||
|
||||
Handles:
|
||||
- Inheritance chain resolution (inherit with add/remove)
|
||||
- Applying add/remove guardrails
|
||||
- Evaluating model conditions
|
||||
- Combining guardrails from multiple matching policies
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
GuardrailPipeline,
|
||||
Policy,
|
||||
PolicyMatchContext,
|
||||
ResolvedPolicy,
|
||||
)
|
||||
|
||||
|
||||
class PolicyResolver:
|
||||
"""
|
||||
Resolves the final list of guardrails from policies.
|
||||
|
||||
Handles:
|
||||
- Inheritance chains with add/remove operations
|
||||
- Model-based conditions
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def resolve_inheritance_chain(
|
||||
policy_name: str,
|
||||
policies: Dict[str, Policy],
|
||||
visited: Optional[Set[str]] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get the inheritance chain for a policy (from root to policy).
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
policies: Dictionary of all policies
|
||||
visited: Set of visited policies (for cycle detection)
|
||||
|
||||
Returns:
|
||||
List of policy names from root ancestor to the given policy
|
||||
"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if policy_name in visited:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Circular inheritance detected for policy '{policy_name}'"
|
||||
)
|
||||
return []
|
||||
|
||||
policy = policies.get(policy_name)
|
||||
if policy is None:
|
||||
return []
|
||||
|
||||
visited.add(policy_name)
|
||||
|
||||
if policy.inherit:
|
||||
parent_chain = PolicyResolver.resolve_inheritance_chain(
|
||||
policy_name=policy.inherit, policies=policies, visited=visited
|
||||
)
|
||||
return parent_chain + [policy_name]
|
||||
|
||||
return [policy_name]
|
||||
|
||||
@staticmethod
|
||||
def resolve_policy_guardrails(
|
||||
policy_name: str,
|
||||
policies: Dict[str, Policy],
|
||||
context: Optional[PolicyMatchContext] = None,
|
||||
) -> ResolvedPolicy:
|
||||
"""
|
||||
Resolve the final guardrails for a single policy, including inheritance.
|
||||
|
||||
This method:
|
||||
1. Resolves the inheritance chain
|
||||
2. Applies add/remove from each policy in the chain
|
||||
3. Evaluates model conditions (if context provided)
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to resolve
|
||||
policies: Dictionary of all policies
|
||||
context: Optional request context for evaluating conditions
|
||||
|
||||
Returns:
|
||||
ResolvedPolicy with final guardrails list
|
||||
"""
|
||||
from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator
|
||||
|
||||
inheritance_chain = PolicyResolver.resolve_inheritance_chain(
|
||||
policy_name=policy_name, policies=policies
|
||||
)
|
||||
|
||||
# Start with empty set of guardrails
|
||||
guardrails: Set[str] = set()
|
||||
|
||||
# Apply each policy in the chain (from root to leaf)
|
||||
for chain_policy_name in inheritance_chain:
|
||||
policy = policies.get(chain_policy_name)
|
||||
if policy is None:
|
||||
continue
|
||||
|
||||
# Check if policy condition matches (if context provided)
|
||||
if context is not None and policy.condition is not None:
|
||||
if not ConditionEvaluator.evaluate(
|
||||
condition=policy.condition,
|
||||
context=context,
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
f"Policy '{chain_policy_name}' condition did not match, skipping guardrails"
|
||||
)
|
||||
continue
|
||||
|
||||
# Add guardrails from guardrails.add
|
||||
for guardrail in policy.guardrails.get_add():
|
||||
guardrails.add(guardrail)
|
||||
|
||||
# Remove guardrails from guardrails.remove
|
||||
for guardrail in policy.guardrails.get_remove():
|
||||
guardrails.discard(guardrail)
|
||||
|
||||
return ResolvedPolicy(
|
||||
policy_name=policy_name,
|
||||
guardrails=list(guardrails),
|
||||
inheritance_chain=inheritance_chain,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def resolve_guardrails_for_context(
|
||||
context: PolicyMatchContext,
|
||||
policies: Optional[Dict[str, Policy]] = None,
|
||||
policy_names: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Resolve the final list of guardrails for a request context.
|
||||
|
||||
This:
|
||||
1. Finds all policies that match the context via policy_attachments (or policy_names if provided)
|
||||
2. Resolves each policy's guardrails (including inheritance)
|
||||
3. Evaluates model conditions
|
||||
4. Combines all guardrails (union)
|
||||
|
||||
Args:
|
||||
context: The request context
|
||||
policies: Dictionary of all policies (if None, uses global registry)
|
||||
policy_names: If provided, use this list instead of attachment matching
|
||||
|
||||
Returns:
|
||||
List of guardrail names to apply
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
|
||||
if policies is None:
|
||||
registry = get_policy_registry()
|
||||
if not registry.is_initialized():
|
||||
return []
|
||||
policies = registry.get_all_policies()
|
||||
|
||||
# Use provided policy names or get matching policies via attachments
|
||||
matching_policy_names = (
|
||||
policy_names
|
||||
if policy_names is not None
|
||||
else PolicyMatcher.get_matching_policies(context=context)
|
||||
)
|
||||
|
||||
if not matching_policy_names:
|
||||
verbose_proxy_logger.debug(
|
||||
f"No policies match context: team_alias={context.team_alias}, "
|
||||
f"key_alias={context.key_alias}, model={context.model}"
|
||||
)
|
||||
return []
|
||||
|
||||
# Resolve each matching policy and combine guardrails
|
||||
all_guardrails: Set[str] = set()
|
||||
|
||||
for policy_name in matching_policy_names:
|
||||
resolved = PolicyResolver.resolve_policy_guardrails(
|
||||
policy_name=policy_name,
|
||||
policies=policies,
|
||||
context=context,
|
||||
)
|
||||
all_guardrails.update(resolved.guardrails)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Policy '{policy_name}' contributes guardrails: {resolved.guardrails}"
|
||||
)
|
||||
|
||||
result = list(all_guardrails)
|
||||
verbose_proxy_logger.debug(f"Final guardrails for context: {result}")
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def resolve_pipelines_for_context(
|
||||
context: PolicyMatchContext,
|
||||
policies: Optional[Dict[str, Policy]] = None,
|
||||
policy_names: Optional[List[str]] = None,
|
||||
) -> List[Tuple[str, GuardrailPipeline]]:
|
||||
"""
|
||||
Resolve pipelines from matching policies for a request context.
|
||||
|
||||
Returns (policy_name, pipeline) tuples for policies that have pipelines.
|
||||
Guardrails managed by pipelines should be excluded from the flat
|
||||
guardrails list to avoid double execution.
|
||||
|
||||
Args:
|
||||
context: The request context
|
||||
policies: Dictionary of all policies (if None, uses global registry)
|
||||
policy_names: If provided, use this list instead of attachment matching
|
||||
|
||||
Returns:
|
||||
List of (policy_name, GuardrailPipeline) tuples
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
|
||||
if policies is None:
|
||||
registry = get_policy_registry()
|
||||
if not registry.is_initialized():
|
||||
return []
|
||||
policies = registry.get_all_policies()
|
||||
|
||||
matching_policy_names = (
|
||||
policy_names
|
||||
if policy_names is not None
|
||||
else PolicyMatcher.get_matching_policies(context=context)
|
||||
)
|
||||
if not matching_policy_names:
|
||||
return []
|
||||
|
||||
pipelines: List[Tuple[str, GuardrailPipeline]] = []
|
||||
for policy_name in matching_policy_names:
|
||||
policy = policies.get(policy_name)
|
||||
if policy is None:
|
||||
continue
|
||||
if policy.pipeline is not None:
|
||||
pipelines.append((policy_name, policy.pipeline))
|
||||
verbose_proxy_logger.debug(
|
||||
f"Policy '{policy_name}' has pipeline with "
|
||||
f"{len(policy.pipeline.steps)} steps"
|
||||
)
|
||||
|
||||
return pipelines
|
||||
|
||||
@staticmethod
|
||||
def get_pipeline_managed_guardrails(
|
||||
pipelines: List[Tuple[str, GuardrailPipeline]],
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the set of guardrail names managed by pipelines.
|
||||
|
||||
These guardrails should be excluded from normal independent execution.
|
||||
"""
|
||||
managed: Set[str] = set()
|
||||
for _policy_name, pipeline in pipelines:
|
||||
for step in pipeline.steps:
|
||||
managed.add(step.guardrail)
|
||||
return managed
|
||||
|
||||
@staticmethod
|
||||
def get_all_resolved_policies(
|
||||
policies: Optional[Dict[str, Policy]] = None,
|
||||
context: Optional[PolicyMatchContext] = None,
|
||||
) -> Dict[str, ResolvedPolicy]:
|
||||
"""
|
||||
Resolve all policies and return their final guardrails.
|
||||
|
||||
Useful for debugging and displaying policy configurations.
|
||||
|
||||
Args:
|
||||
policies: Dictionary of all policies (if None, uses global registry)
|
||||
context: Optional context for evaluating conditions
|
||||
|
||||
Returns:
|
||||
Dictionary mapping policy names to ResolvedPolicy objects
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
|
||||
if policies is None:
|
||||
registry = get_policy_registry()
|
||||
if not registry.is_initialized():
|
||||
return {}
|
||||
policies = registry.get_all_policies()
|
||||
|
||||
resolved: Dict[str, ResolvedPolicy] = {}
|
||||
|
||||
for policy_name in policies:
|
||||
resolved[policy_name] = PolicyResolver.resolve_policy_guardrails(
|
||||
policy_name=policy_name,
|
||||
policies=policies,
|
||||
context=context,
|
||||
)
|
||||
|
||||
return resolved
|
||||
Reference in New Issue
Block a user