459 lines
18 KiB
Python
459 lines
18 KiB
Python
|
|
import hashlib
|
||
|
|
import json
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from typing import Any, Dict, List, Optional
|
||
|
|
|
||
|
|
import litellm
|
||
|
|
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||
|
|
from litellm.proxy.management_helpers.object_permission_utils import (
|
||
|
|
handle_update_object_permission_common,
|
||
|
|
)
|
||
|
|
from litellm.proxy.utils import PrismaClient
|
||
|
|
from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest
|
||
|
|
|
||
|
|
|
||
|
|
class AgentRegistry:
|
||
|
|
def __init__(self):
|
||
|
|
self.agent_list: List[AgentResponse] = []
|
||
|
|
|
||
|
|
def reset_agent_list(self):
|
||
|
|
self.agent_list = []
|
||
|
|
|
||
|
|
def register_agent(self, agent_config: AgentResponse):
|
||
|
|
self.agent_list.append(agent_config)
|
||
|
|
|
||
|
|
def deregister_agent(self, agent_name: str):
|
||
|
|
self.agent_list = [
|
||
|
|
agent for agent in self.agent_list if agent.agent_name != agent_name
|
||
|
|
]
|
||
|
|
|
||
|
|
def get_agent_list(self, agent_names: Optional[List[str]] = None):
|
||
|
|
if agent_names is not None:
|
||
|
|
return [
|
||
|
|
agent for agent in self.agent_list if agent.agent_name in agent_names
|
||
|
|
]
|
||
|
|
return self.agent_list
|
||
|
|
|
||
|
|
def get_public_agent_list(self) -> List[AgentResponse]:
|
||
|
|
public_agent_list: List[AgentResponse] = []
|
||
|
|
if litellm.public_agent_groups is None:
|
||
|
|
return public_agent_list
|
||
|
|
for agent in self.agent_list:
|
||
|
|
if agent.agent_id in litellm.public_agent_groups:
|
||
|
|
public_agent_list.append(agent)
|
||
|
|
return public_agent_list
|
||
|
|
|
||
|
|
def _create_agent_id(self, agent_config: AgentConfig) -> str:
|
||
|
|
return hashlib.sha256(
|
||
|
|
json.dumps(agent_config, sort_keys=True).encode()
|
||
|
|
).hexdigest()
|
||
|
|
|
||
|
|
def load_agents_from_config(self, agent_config: Optional[List[AgentConfig]] = None):
|
||
|
|
if agent_config is None:
|
||
|
|
return None
|
||
|
|
|
||
|
|
for agent_config_item in agent_config:
|
||
|
|
if not isinstance(agent_config_item, dict):
|
||
|
|
raise ValueError("agent_config must be a list of dictionaries")
|
||
|
|
|
||
|
|
agent_name = agent_config_item.get("agent_name")
|
||
|
|
agent_card_params = agent_config_item.get("agent_card_params")
|
||
|
|
if not all([agent_name, agent_card_params]):
|
||
|
|
continue
|
||
|
|
|
||
|
|
# create a stable hash id for config item
|
||
|
|
config_hash = self._create_agent_id(agent_config_item)
|
||
|
|
|
||
|
|
self.register_agent(agent_config=AgentResponse(agent_id=config_hash, **agent_config_item)) # type: ignore
|
||
|
|
|
||
|
|
def load_agents_from_db_and_config(
|
||
|
|
self,
|
||
|
|
agent_config: Optional[List[AgentConfig]] = None,
|
||
|
|
db_agents: Optional[List[Dict[str, Any]]] = None,
|
||
|
|
):
|
||
|
|
self.reset_agent_list()
|
||
|
|
|
||
|
|
if agent_config:
|
||
|
|
for agent_config_item in agent_config:
|
||
|
|
if not isinstance(agent_config_item, dict):
|
||
|
|
raise ValueError("agent_config must be a list of dictionaries")
|
||
|
|
|
||
|
|
self.register_agent(agent_config=AgentResponse(agent_id=self._create_agent_id(agent_config_item), **agent_config_item)) # type: ignore
|
||
|
|
|
||
|
|
if db_agents:
|
||
|
|
for db_agent in db_agents:
|
||
|
|
if not isinstance(db_agent, dict):
|
||
|
|
raise ValueError("db_agents must be a list of dictionaries")
|
||
|
|
|
||
|
|
self.register_agent(agent_config=AgentResponse(**db_agent)) # type: ignore
|
||
|
|
return self.agent_list
|
||
|
|
|
||
|
|
###########################################################
|
||
|
|
########### DB management helpers for agents ###########
|
||
|
|
############################################################
|
||
|
|
async def add_agent_to_db(
|
||
|
|
self, agent: AgentConfig, prisma_client: PrismaClient, created_by: str
|
||
|
|
) -> AgentResponse:
|
||
|
|
"""
|
||
|
|
Add an agent to the database
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
agent_name = agent.get("agent_name")
|
||
|
|
|
||
|
|
# Serialize litellm_params
|
||
|
|
litellm_params_obj: Any = agent.get("litellm_params", {})
|
||
|
|
if hasattr(litellm_params_obj, "model_dump"):
|
||
|
|
litellm_params_dict = litellm_params_obj.model_dump()
|
||
|
|
else:
|
||
|
|
litellm_params_dict = (
|
||
|
|
dict(litellm_params_obj) if litellm_params_obj else {}
|
||
|
|
)
|
||
|
|
litellm_params: str = safe_dumps(litellm_params_dict)
|
||
|
|
|
||
|
|
# Serialize agent_card_params
|
||
|
|
agent_card_params_obj: Any = agent.get("agent_card_params", {})
|
||
|
|
if hasattr(agent_card_params_obj, "model_dump"):
|
||
|
|
agent_card_params_dict = agent_card_params_obj.model_dump()
|
||
|
|
else:
|
||
|
|
agent_card_params_dict = (
|
||
|
|
dict(agent_card_params_obj) if agent_card_params_obj else {}
|
||
|
|
)
|
||
|
|
agent_card_params: str = safe_dumps(agent_card_params_dict)
|
||
|
|
|
||
|
|
# Handle object_permission (MCP tool access for agent)
|
||
|
|
object_permission_id: Optional[str] = None
|
||
|
|
if agent.get("object_permission") is not None:
|
||
|
|
agent_copy = dict(agent)
|
||
|
|
object_permission_id = await handle_update_object_permission_common(
|
||
|
|
agent_copy, None, prisma_client
|
||
|
|
)
|
||
|
|
|
||
|
|
# Serialize static_headers
|
||
|
|
static_headers_obj = agent.get("static_headers")
|
||
|
|
static_headers_val: Optional[str] = (
|
||
|
|
safe_dumps(dict(static_headers_obj)) if static_headers_obj else None
|
||
|
|
)
|
||
|
|
|
||
|
|
extra_headers_val: Optional[List[str]] = agent.get("extra_headers")
|
||
|
|
|
||
|
|
create_data: Dict[str, Any] = {
|
||
|
|
"agent_name": agent_name,
|
||
|
|
"litellm_params": litellm_params,
|
||
|
|
"agent_card_params": agent_card_params,
|
||
|
|
"created_by": created_by,
|
||
|
|
"updated_by": created_by,
|
||
|
|
"created_at": datetime.now(timezone.utc),
|
||
|
|
"updated_at": datetime.now(timezone.utc),
|
||
|
|
}
|
||
|
|
if static_headers_val is not None:
|
||
|
|
create_data["static_headers"] = static_headers_val
|
||
|
|
if extra_headers_val is not None:
|
||
|
|
create_data["extra_headers"] = extra_headers_val
|
||
|
|
if object_permission_id is not None:
|
||
|
|
create_data["object_permission_id"] = object_permission_id
|
||
|
|
|
||
|
|
for rate_field in (
|
||
|
|
"tpm_limit",
|
||
|
|
"rpm_limit",
|
||
|
|
"session_tpm_limit",
|
||
|
|
"session_rpm_limit",
|
||
|
|
):
|
||
|
|
_val = agent.get(rate_field)
|
||
|
|
if _val is not None:
|
||
|
|
create_data[rate_field] = _val
|
||
|
|
|
||
|
|
# Create agent in DB
|
||
|
|
created_agent = await prisma_client.db.litellm_agentstable.create(
|
||
|
|
data=create_data,
|
||
|
|
include={"object_permission": True},
|
||
|
|
)
|
||
|
|
|
||
|
|
created_agent_dict = created_agent.model_dump()
|
||
|
|
if created_agent.object_permission is not None:
|
||
|
|
try:
|
||
|
|
created_agent_dict[
|
||
|
|
"object_permission"
|
||
|
|
] = created_agent.object_permission.model_dump()
|
||
|
|
except Exception:
|
||
|
|
created_agent_dict[
|
||
|
|
"object_permission"
|
||
|
|
] = created_agent.object_permission.dict()
|
||
|
|
return AgentResponse(**created_agent_dict) # type: ignore
|
||
|
|
except Exception as e:
|
||
|
|
raise Exception(f"Error adding agent to DB: {str(e)}")
|
||
|
|
|
||
|
|
async def delete_agent_from_db(
|
||
|
|
self, agent_id: str, prisma_client: PrismaClient
|
||
|
|
) -> Dict[str, Any]:
|
||
|
|
"""
|
||
|
|
Delete an agent from the database
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
deleted_agent = await prisma_client.db.litellm_agentstable.delete(
|
||
|
|
where={"agent_id": agent_id}
|
||
|
|
)
|
||
|
|
return dict(deleted_agent)
|
||
|
|
except Exception as e:
|
||
|
|
raise Exception(f"Error deleting agent from DB: {str(e)}")
|
||
|
|
|
||
|
|
async def patch_agent_in_db(
|
||
|
|
self,
|
||
|
|
agent_id: str,
|
||
|
|
agent: PatchAgentRequest,
|
||
|
|
prisma_client: PrismaClient,
|
||
|
|
updated_by: str,
|
||
|
|
) -> AgentResponse:
|
||
|
|
"""
|
||
|
|
Patch an agent in the database.
|
||
|
|
|
||
|
|
Get the existing agent from the database and patch it with the new values.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
agent_id: The ID of the agent to patch
|
||
|
|
agent: The new agent values to patch
|
||
|
|
prisma_client: The Prisma client to use
|
||
|
|
updated_by: The user ID of the user who is patching the agent
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
The patched agent
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||
|
|
where={"agent_id": agent_id}
|
||
|
|
)
|
||
|
|
if existing_agent is not None:
|
||
|
|
existing_agent = dict(existing_agent)
|
||
|
|
|
||
|
|
if existing_agent is None:
|
||
|
|
raise Exception(f"Agent with ID {agent_id} not found")
|
||
|
|
|
||
|
|
augment_agent = {**existing_agent, **agent}
|
||
|
|
update_data: Dict[str, Any] = {}
|
||
|
|
if augment_agent.get("agent_name"):
|
||
|
|
update_data["agent_name"] = augment_agent.get("agent_name")
|
||
|
|
if augment_agent.get("litellm_params"):
|
||
|
|
update_data["litellm_params"] = safe_dumps(
|
||
|
|
augment_agent.get("litellm_params")
|
||
|
|
)
|
||
|
|
if augment_agent.get("agent_card_params"):
|
||
|
|
update_data["agent_card_params"] = safe_dumps(
|
||
|
|
augment_agent.get("agent_card_params")
|
||
|
|
)
|
||
|
|
|
||
|
|
for rate_field in (
|
||
|
|
"tpm_limit",
|
||
|
|
"rpm_limit",
|
||
|
|
"session_tpm_limit",
|
||
|
|
"session_rpm_limit",
|
||
|
|
):
|
||
|
|
if rate_field in agent:
|
||
|
|
update_data[rate_field] = agent.get(rate_field)
|
||
|
|
if "static_headers" in agent:
|
||
|
|
headers_value = agent.get("static_headers")
|
||
|
|
update_data["static_headers"] = safe_dumps(
|
||
|
|
dict(headers_value) if headers_value is not None else {}
|
||
|
|
)
|
||
|
|
if "extra_headers" in agent:
|
||
|
|
extra_headers_value = agent.get("extra_headers")
|
||
|
|
update_data["extra_headers"] = (
|
||
|
|
extra_headers_value if extra_headers_value is not None else []
|
||
|
|
)
|
||
|
|
if agent.get("object_permission") is not None:
|
||
|
|
agent_copy = dict(augment_agent)
|
||
|
|
existing_object_permission_id = existing_agent.get(
|
||
|
|
"object_permission_id"
|
||
|
|
)
|
||
|
|
object_permission_id = await handle_update_object_permission_common(
|
||
|
|
agent_copy,
|
||
|
|
existing_object_permission_id,
|
||
|
|
prisma_client,
|
||
|
|
)
|
||
|
|
if object_permission_id is not None:
|
||
|
|
update_data["object_permission_id"] = object_permission_id
|
||
|
|
# Patch agent in DB
|
||
|
|
patched_agent = await prisma_client.db.litellm_agentstable.update(
|
||
|
|
where={"agent_id": agent_id},
|
||
|
|
data={
|
||
|
|
**update_data,
|
||
|
|
"updated_by": updated_by,
|
||
|
|
"updated_at": datetime.now(timezone.utc),
|
||
|
|
},
|
||
|
|
include={"object_permission": True},
|
||
|
|
)
|
||
|
|
patched_agent_dict = patched_agent.model_dump()
|
||
|
|
if patched_agent.object_permission is not None:
|
||
|
|
try:
|
||
|
|
patched_agent_dict[
|
||
|
|
"object_permission"
|
||
|
|
] = patched_agent.object_permission.model_dump()
|
||
|
|
except Exception:
|
||
|
|
patched_agent_dict[
|
||
|
|
"object_permission"
|
||
|
|
] = patched_agent.object_permission.dict()
|
||
|
|
return AgentResponse(**patched_agent_dict) # type: ignore
|
||
|
|
except Exception as e:
|
||
|
|
raise Exception(f"Error patching agent in DB: {str(e)}")
|
||
|
|
|
||
|
|
async def update_agent_in_db(
|
||
|
|
self,
|
||
|
|
agent_id: str,
|
||
|
|
agent: AgentConfig,
|
||
|
|
prisma_client: PrismaClient,
|
||
|
|
updated_by: str,
|
||
|
|
) -> AgentResponse:
|
||
|
|
"""
|
||
|
|
Update an agent in the database
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
agent_name = agent.get("agent_name")
|
||
|
|
|
||
|
|
# Serialize litellm_params
|
||
|
|
litellm_params_obj: Any = agent.get("litellm_params", {})
|
||
|
|
if hasattr(litellm_params_obj, "model_dump"):
|
||
|
|
litellm_params_dict = litellm_params_obj.model_dump()
|
||
|
|
else:
|
||
|
|
litellm_params_dict = (
|
||
|
|
dict(litellm_params_obj) if litellm_params_obj else {}
|
||
|
|
)
|
||
|
|
litellm_params: str = safe_dumps(litellm_params_dict)
|
||
|
|
|
||
|
|
# Serialize agent_card_params
|
||
|
|
agent_card_params_obj: Any = agent.get("agent_card_params", {})
|
||
|
|
if hasattr(agent_card_params_obj, "model_dump"):
|
||
|
|
agent_card_params_dict = agent_card_params_obj.model_dump()
|
||
|
|
else:
|
||
|
|
agent_card_params_dict = (
|
||
|
|
dict(agent_card_params_obj) if agent_card_params_obj else {}
|
||
|
|
)
|
||
|
|
agent_card_params: str = safe_dumps(agent_card_params_dict)
|
||
|
|
|
||
|
|
# Serialize static_headers for update
|
||
|
|
static_headers_obj_u = agent.get("static_headers")
|
||
|
|
static_headers_val_u: str = (
|
||
|
|
safe_dumps(dict(static_headers_obj_u))
|
||
|
|
if static_headers_obj_u is not None
|
||
|
|
else safe_dumps({})
|
||
|
|
)
|
||
|
|
extra_headers_val_u: List[str] = agent.get("extra_headers") or []
|
||
|
|
|
||
|
|
update_data: Dict[str, Any] = {
|
||
|
|
"agent_name": agent_name,
|
||
|
|
"litellm_params": litellm_params,
|
||
|
|
"agent_card_params": agent_card_params,
|
||
|
|
"static_headers": static_headers_val_u,
|
||
|
|
"extra_headers": extra_headers_val_u,
|
||
|
|
"updated_by": updated_by,
|
||
|
|
"updated_at": datetime.now(timezone.utc),
|
||
|
|
}
|
||
|
|
|
||
|
|
for rate_field in (
|
||
|
|
"tpm_limit",
|
||
|
|
"rpm_limit",
|
||
|
|
"session_tpm_limit",
|
||
|
|
"session_rpm_limit",
|
||
|
|
):
|
||
|
|
_val = agent.get(rate_field)
|
||
|
|
if _val is not None:
|
||
|
|
update_data[rate_field] = _val
|
||
|
|
|
||
|
|
if agent.get("object_permission") is not None:
|
||
|
|
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||
|
|
where={"agent_id": agent_id}
|
||
|
|
)
|
||
|
|
existing_object_permission_id = (
|
||
|
|
existing_agent.object_permission_id
|
||
|
|
if existing_agent is not None
|
||
|
|
else None
|
||
|
|
)
|
||
|
|
agent_copy = dict(agent)
|
||
|
|
object_permission_id = await handle_update_object_permission_common(
|
||
|
|
agent_copy,
|
||
|
|
existing_object_permission_id,
|
||
|
|
prisma_client,
|
||
|
|
)
|
||
|
|
if object_permission_id is not None:
|
||
|
|
update_data["object_permission_id"] = object_permission_id
|
||
|
|
|
||
|
|
# Update agent in DB
|
||
|
|
updated_agent = await prisma_client.db.litellm_agentstable.update(
|
||
|
|
where={"agent_id": agent_id},
|
||
|
|
data=update_data,
|
||
|
|
include={"object_permission": True},
|
||
|
|
)
|
||
|
|
|
||
|
|
updated_agent_dict = updated_agent.model_dump()
|
||
|
|
if updated_agent.object_permission is not None:
|
||
|
|
try:
|
||
|
|
updated_agent_dict[
|
||
|
|
"object_permission"
|
||
|
|
] = updated_agent.object_permission.model_dump()
|
||
|
|
except Exception:
|
||
|
|
updated_agent_dict[
|
||
|
|
"object_permission"
|
||
|
|
] = updated_agent.object_permission.dict()
|
||
|
|
return AgentResponse(**updated_agent_dict) # type: ignore
|
||
|
|
except Exception as e:
|
||
|
|
raise Exception(f"Error updating agent in DB: {str(e)}")
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
async def get_all_agents_from_db(
|
||
|
|
prisma_client: PrismaClient,
|
||
|
|
) -> List[Dict[str, Any]]:
|
||
|
|
"""
|
||
|
|
Get all agents from the database
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
agents_from_db = await prisma_client.db.litellm_agentstable.find_many(
|
||
|
|
order={"created_at": "desc"},
|
||
|
|
include={"object_permission": True},
|
||
|
|
)
|
||
|
|
|
||
|
|
agents: List[Dict[str, Any]] = []
|
||
|
|
for agent in agents_from_db:
|
||
|
|
agent_dict = dict(agent)
|
||
|
|
# object_permission is eagerly loaded via include above
|
||
|
|
if agent.object_permission is not None:
|
||
|
|
try:
|
||
|
|
agent_dict[
|
||
|
|
"object_permission"
|
||
|
|
] = agent.object_permission.model_dump()
|
||
|
|
except Exception:
|
||
|
|
agent_dict["object_permission"] = agent.object_permission.dict()
|
||
|
|
agents.append(agent_dict)
|
||
|
|
|
||
|
|
return agents
|
||
|
|
except Exception as e:
|
||
|
|
raise Exception(f"Error getting agents from DB: {str(e)}")
|
||
|
|
|
||
|
|
def get_agent_by_id(
|
||
|
|
self,
|
||
|
|
agent_id: str,
|
||
|
|
) -> Optional[AgentResponse]:
|
||
|
|
"""
|
||
|
|
Get an agent by its ID from the database
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
for agent in self.agent_list:
|
||
|
|
if agent.agent_id == agent_id:
|
||
|
|
return agent
|
||
|
|
|
||
|
|
return None
|
||
|
|
except Exception as e:
|
||
|
|
raise Exception(f"Error getting agent from DB: {str(e)}")
|
||
|
|
|
||
|
|
def get_agent_by_name(self, agent_name: str) -> Optional[AgentResponse]:
|
||
|
|
"""
|
||
|
|
Get an agent by its name from the database
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
for agent in self.agent_list:
|
||
|
|
if agent.agent_name == agent_name:
|
||
|
|
return agent
|
||
|
|
|
||
|
|
return None
|
||
|
|
except Exception as e:
|
||
|
|
raise Exception(f"Error getting agent from DB: {str(e)}")
|
||
|
|
|
||
|
|
|
||
|
|
global_agent_registry = AgentRegistry()
|