Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/proxy/db/tool_registry_writer.py

441 lines
15 KiB
Python
Raw Normal View History

"""
DB helpers for LiteLLM_ToolTable, the global tool registry.
Tools are auto-discovered from LLM responses and upserted here.
Admins use the management endpoints to read and update input_policy / output_policy.
"""
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ToolDiscoveryQueueItem
from litellm.types.tool_management import (
LiteLLM_ToolTableRow,
ToolPolicyOverrideRow,
)
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
def _row_to_model(row: Union[dict, Any]) -> LiteLLM_ToolTableRow:
    """
    Convert a Prisma model instance or dict to LiteLLM_ToolTableRow.

    Three input shapes are handled:

    * an object exposing a callable ``model_dump`` - its result is used;
    * a plain ``dict`` - used as-is;
    * any other object - the known columns are copied off it with
      ``getattr(..., None)``.

    Falsy values are normalised: ``tool_id`` / ``tool_name`` become ``""``,
    both policies become ``"untrusted"`` and ``call_count`` becomes ``0``.
    Every other column is passed through unchanged (``None`` when absent).
    """
    columns = (
        "tool_id",
        "tool_name",
        "origin",
        "input_policy",
        "output_policy",
        "call_count",
        "assignments",
        "key_hash",
        "team_id",
        "key_alias",
        "user_agent",
        "last_used_at",
        "created_at",
        "updated_at",
        "created_by",
        "updated_by",
    )

    model_dump = getattr(row, "model_dump", None)
    if callable(model_dump):
        row = model_dump()
    elif not isinstance(row, dict):
        # Attribute-style row: every column key ends up present, with None
        # standing in for attributes the object does not have.
        row = {column: getattr(row, column, None) for column in columns}

    return LiteLLM_ToolTableRow(
        # ``or ""`` rather than ``.get(key, "")``: the default of ``dict.get``
        # is only used when the key is missing, but the key can be present
        # with a None value (always the case after the getattr fallback).
        tool_id=row.get("tool_id") or "",
        tool_name=row.get("tool_name") or "",
        origin=row.get("origin"),
        input_policy=row.get("input_policy") or "untrusted",
        output_policy=row.get("output_policy") or "untrusted",
        call_count=int(row.get("call_count") or 0),
        assignments=row.get("assignments"),
        key_hash=row.get("key_hash"),
        team_id=row.get("team_id"),
        key_alias=row.get("key_alias"),
        user_agent=row.get("user_agent"),
        last_used_at=row.get("last_used_at"),
        created_at=row.get("created_at"),
        updated_at=row.get("updated_at"),
        created_by=row.get("created_by"),
        updated_by=row.get("updated_by"),
    )
async def batch_upsert_tools(
    prisma_client: "PrismaClient",
    items: List[ToolDiscoveryQueueItem],
) -> None:
    """
    Upsert one tool registry row per queue item, keyed on ``tool_name``.

    Items without a truthy ``tool_name`` are skipped.

    * Row does not exist yet: it is created with a fresh UUID ``tool_id``,
      both policies set to "untrusted" and ``call_count`` of 1.
    * Row already exists: only ``call_count`` is incremented and
      ``updated_at`` / ``last_used_at`` are refreshed, so stored policies are
      left as they are.

    Any exception is logged and swallowed. The loop sits inside a single
    ``try``, so a failing upsert stops processing of the remaining items.
    """
    if not items:
        return
    try:
        named_items = [entry for entry in items if entry.get("tool_name")]
        if not named_items:
            return

        # One timestamp is shared by every row written in this batch.
        timestamp = datetime.now(timezone.utc)
        tool_table = prisma_client.db.litellm_tooltable

        for entry in named_items:
            name = entry.get("tool_name", "")
            author = entry.get("created_by") or "system"

            on_create = {
                "tool_id": str(uuid.uuid4()),
                "tool_name": name,
                "origin": entry.get("origin") or "user_defined",
                "input_policy": "untrusted",
                "output_policy": "untrusted",
                "call_count": 1,
                "created_by": author,
                "updated_by": author,
                "key_hash": entry.get("key_hash"),
                "team_id": entry.get("team_id"),
                "key_alias": entry.get("key_alias"),
                "user_agent": entry.get("user_agent"),
                "last_used_at": timestamp,
            }
            on_update = {
                "call_count": {"increment": 1},
                "updated_at": timestamp,
                "last_used_at": timestamp,
            }

            await tool_table.upsert(
                where={"tool_name": name},
                data={"create": on_create, "update": on_update},
            )

        verbose_proxy_logger.debug(
            "tool_registry_writer: upserted %d tool(s)", len(named_items)
        )
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer batch_upsert_tools error: %s", e
        )
async def list_tools(
    prisma_client: "PrismaClient",
    input_policy: Optional[str] = None,
) -> List[LiteLLM_ToolTableRow]:
    """
    List registry rows ordered by ``created_at`` descending.

    When ``input_policy`` is not None, only rows with that exact input policy
    are returned; otherwise no filter is applied. On any error the exception
    is logged and an empty list is returned.
    """
    try:
        filters: dict = {}
        if input_policy is not None:
            filters["input_policy"] = input_policy

        records = await prisma_client.db.litellm_tooltable.find_many(
            where=filters,
            order={"created_at": "desc"},
        )

        converted = []
        for record in records:
            converted.append(_row_to_model(record))
        return converted
    except Exception as e:
        verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e)
        return []
async def get_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> Optional[LiteLLM_ToolTableRow]:
    """
    Look up one registry row by its ``tool_name``.

    Returns the converted row, or None when no row matches. Errors are logged
    and also reported as None.
    """
    try:
        record = await prisma_client.db.litellm_tooltable.find_unique(
            where={"tool_name": tool_name},
        )
        return _row_to_model(record) if record is not None else None
    except Exception as e:
        verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e)
        return None
async def update_tool_policy(
    prisma_client: "PrismaClient",
    tool_name: str,
    updated_by: Optional[str],
    input_policy: Optional[str] = None,
    output_policy: Optional[str] = None,
) -> Optional[LiteLLM_ToolTableRow]:
    """
    Update input_policy and/or output_policy for a tool.

    Upserts the row if it does not exist yet: a new row gets a fresh UUID
    ``tool_id`` and any policy not supplied defaults to "untrusted". For an
    existing row only the policies that were passed (not None) are changed;
    ``updated_by`` / ``updated_at`` are always refreshed.

    Args:
        prisma_client: Prisma client wrapper exposing ``db.litellm_tooltable``.
        tool_name: Unique name of the tool to update.
        updated_by: Who made the change; falsy values are stored as "system".
        input_policy: New input policy, or None to leave it untouched.
        output_policy: New output policy, or None to leave it untouched.

    Returns:
        The row as written, or None if anything failed (the error is logged).
    """
    try:
        _updated_by = updated_by or "system"
        now = datetime.now(timezone.utc)
        create_data: dict = {
            "tool_id": str(uuid.uuid4()),
            "tool_name": tool_name,
            "input_policy": input_policy or "untrusted",
            "output_policy": output_policy or "untrusted",
            "created_by": _updated_by,
            "updated_by": _updated_by,
            "created_at": now,
            "updated_at": now,
        }
        update_data: dict = {
            "updated_by": _updated_by,
            "updated_at": now,
        }
        if input_policy is not None:
            update_data["input_policy"] = input_policy
        if output_policy is not None:
            update_data["output_policy"] = output_policy

        row = await prisma_client.db.litellm_tooltable.upsert(
            where={"tool_name": tool_name},
            data={
                "create": create_data,
                "update": update_data,
            },
        )
        # The upsert already hands back the written record. Using it avoids a
        # second query and avoids reporting a successful write as a failure
        # just because the follow-up read failed.
        if row is not None:
            return _row_to_model(row)
        # Defensive fallback for a client that returns nothing from upsert.
        return await get_tool(prisma_client, tool_name)
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer update_tool_policy error: %s", e
        )
        return None
async def get_tools_by_names(
    prisma_client: "PrismaClient",
    tool_names: List[str],
) -> Dict[str, Tuple[str, str]]:
    """
    Map each known tool name to its ``(input_policy, output_policy)`` pair.

    Only names that have a row in the registry appear in the result. A policy
    that is missing or falsy on a row is reported as "untrusted". An empty
    ``tool_names`` short-circuits to ``{}`` without touching the DB; errors
    are logged and also yield ``{}``.
    """
    if not tool_names:
        return {}
    try:
        records = await prisma_client.db.litellm_tooltable.find_many(
            where={"tool_name": {"in": tool_names}},
        )
        policies: Dict[str, Tuple[str, str]] = {}
        for record in records:
            input_policy = getattr(record, "input_policy", "untrusted") or "untrusted"
            output_policy = (
                getattr(record, "output_policy", "untrusted") or "untrusted"
            )
            policies[record.tool_name] = (input_policy, output_policy)
        return policies
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer get_tools_by_names error: %s", e
        )
        return {}
async def list_overrides_for_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> List[ToolPolicyOverrideRow]:
    """
    Build override-style rows describing where ``tool_name`` is blocked.

    Object permissions whose ``blocked_tools`` contains the tool are fetched
    together with their ``verification_tokens`` and ``teams``. Each attached
    token yields one key-scoped row and each attached team yields one
    team-scoped row; for a given permission the key rows come first. Every row
    has ``input_policy="blocked"`` and no timestamps.

    On any error the exception is logged and an empty list is returned.
    """
    overrides: List[ToolPolicyOverrideRow] = []
    try:
        permissions = await prisma_client.db.litellm_objectpermissiontable.find_many(
            where={"blocked_tools": {"has": tool_name}},
            include={
                "verification_tokens": True,
                "teams": True,
            },
        )
        for permission in permissions:
            permission_id = getattr(permission, "object_permission_id", None) or ""
            key_rows = getattr(permission, "verification_tokens", []) or []
            team_rows = getattr(permission, "teams", []) or []

            # Key-scoped overrides: one per verification token.
            overrides.extend(
                ToolPolicyOverrideRow(
                    override_id=permission_id,
                    tool_name=tool_name,
                    team_id=None,
                    key_hash=getattr(key_row, "token", None),
                    input_policy="blocked",
                    key_alias=getattr(key_row, "key_alias", None),
                    created_at=None,
                    updated_at=None,
                )
                for key_row in key_rows
            )
            # Team-scoped overrides: the team alias is carried in key_alias.
            overrides.extend(
                ToolPolicyOverrideRow(
                    override_id=permission_id,
                    tool_name=tool_name,
                    team_id=getattr(team_row, "team_id", None),
                    key_hash=None,
                    input_policy="blocked",
                    key_alias=getattr(team_row, "team_alias", None),
                    created_at=None,
                    updated_at=None,
                )
                for team_row in team_rows
            )
        return overrides
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer list_overrides_for_tool error: %s", e
        )
        return []
class ToolPolicyRegistry:
    """
    In-memory registry of tool policies synced from DB.

    Hot path uses get_effective_policies only - no DB, no cache.
    """

    def __init__(self) -> None:
        # tool_name -> input_policy
        self._tool_input_policies: Dict[str, str] = {}
        # tool_name -> output_policy
        self._tool_output_policies: Dict[str, str] = {}
        # object_permission_id -> tool names blocked by that permission
        self._blocked_tools_by_op_id: Dict[str, List[str]] = {}
        # Set to True by the first successful sync_tool_policy_from_db().
        self._initialized: bool = False

    def is_initialized(self) -> bool:
        """Return True once a sync from the DB has completed successfully."""
        return self._initialized

    async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
        """
        Load all tool policies and object-permission blocked_tools from DB.

        Both tables are read and the new maps are built before any instance
        state is touched, so a failure partway through leaves the previously
        synced state intact instead of half-updated.

        Raises:
            Exception: whatever the DB layer raised, after logging it.
        """
        try:
            tools = await prisma_client.db.litellm_tooltable.find_many()
            perms = await prisma_client.db.litellm_objectpermissiontable.find_many()

            input_policies: Dict[str, str] = {
                row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted"
                for row in tools
            }
            output_policies: Dict[str, str] = {
                row.tool_name: getattr(row, "output_policy", "untrusted")
                or "untrusted"
                for row in tools
            }
            blocked_by_op_id: Dict[str, List[str]] = {}
            for row in perms:
                op_id = getattr(row, "object_permission_id", None)
                blocked = getattr(row, "blocked_tools", None) or []
                # Rows without an id cannot be looked up, so they are skipped.
                if op_id:
                    blocked_by_op_id[op_id] = list(blocked)

            # Publish everything together; there is no await between these
            # assignments.
            self._tool_input_policies = input_policies
            self._tool_output_policies = output_policies
            self._blocked_tools_by_op_id = blocked_by_op_id
            self._initialized = True
            verbose_proxy_logger.info(
                "ToolPolicyRegistry: synced %d tool policies and %d object permissions from DB",
                len(self._tool_input_policies),
                len(self._blocked_tools_by_op_id),
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "ToolPolicyRegistry sync_tool_policy_from_db error: %s", e
            )
            raise

    def get_input_policy(self, tool_name: str) -> str:
        """Return the synced input policy for ``tool_name`` ("untrusted" if unknown)."""
        return self._tool_input_policies.get(tool_name, "untrusted")

    def get_output_policy(self, tool_name: str) -> str:
        """Return the synced output policy for ``tool_name`` ("untrusted" if unknown)."""
        return self._tool_output_policies.get(tool_name, "untrusted")

    def get_effective_policies(
        self,
        tool_names: List[str],
        object_permission_id: Optional[str] = None,
        team_object_permission_id: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Return effective input_policy per tool from in-memory state.

        If tool is in key or team blocked_tools -> "blocked", else global
        input_policy or "untrusted". Permission ids are stripped of
        surrounding whitespace before lookup; None/blank ids are ignored.
        """
        if not tool_names:
            return {}
        blocked: set = set()
        for op_id in (object_permission_id, team_object_permission_id):
            if op_id and op_id.strip():
                blocked.update(self._blocked_tools_by_op_id.get(op_id.strip(), []))
        return {
            name: (
                "blocked"
                if name in blocked
                else self._tool_input_policies.get(name, "untrusted")
            )
            for name in tool_names
        }
# Process-wide registry instance; created on first use by
# get_tool_policy_registry().
_tool_policy_registry: Optional[ToolPolicyRegistry] = None


def get_tool_policy_registry() -> ToolPolicyRegistry:
    """Return the module-level ToolPolicyRegistry, creating it on first call."""
    global _tool_policy_registry
    registry = _tool_policy_registry
    if registry is None:
        registry = ToolPolicyRegistry()
        _tool_policy_registry = registry
    return registry
async def add_tool_to_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """
    Ensure ``tool_name`` is listed in an object permission's ``blocked_tools``.

    Returns:
        True when the tool is blocked after the call (it was already present,
        or it was appended and written back). False when either argument is
        empty, the permission row does not exist, or an error occurred (the
        error is logged).
    """
    if not (object_permission_id and tool_name):
        return False
    try:
        permission_table = prisma_client.db.litellm_objectpermissiontable
        existing = await permission_table.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if existing is None:
            return False

        blocked_tools = list(getattr(existing, "blocked_tools", []) or [])
        # Only write when the list actually changes.
        if tool_name not in blocked_tools:
            await permission_table.update(
                where={"object_permission_id": object_permission_id},
                data={"blocked_tools": blocked_tools + [tool_name]},
            )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer add_tool_to_object_permission_blocked error: %s", e
        )
        return False
async def remove_tool_from_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """
    Drop ``tool_name`` from an object permission's ``blocked_tools``.

    Every occurrence of the name is removed and the shortened list is written
    back.

    Returns:
        True when the tool was listed and has been removed. False when either
        argument is empty, the permission row does not exist, the tool was not
        in the list, or an error occurred (the error is logged).
    """
    if not (object_permission_id and tool_name):
        return False
    try:
        permission_table = prisma_client.db.litellm_objectpermissiontable
        existing = await permission_table.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if existing is None:
            return False

        blocked_tools = list(getattr(existing, "blocked_tools", []) or [])
        if tool_name in blocked_tools:
            remaining = [name for name in blocked_tools if name != tool_name]
            await permission_table.update(
                where={"object_permission_id": object_permission_id},
                data={"blocked_tools": remaining},
            )
            return True
        # Nothing to remove: reported as False, no write issued.
        return False
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer remove_tool_from_object_permission_blocked error: %s",
            e,
        )
        return False