chore: initial snapshot for gitea/github upload

commit a699a1ac98
Author: Your Name
Date: 2026-03-26 16:04:46 +08:00

3497 changed files with 1586237 additions and 0 deletions


@@ -0,0 +1,34 @@
from typing import TYPE_CHECKING

from litellm.types.guardrails import SupportedGuardrailIntegrations

from .aim import AimGuardrail

if TYPE_CHECKING:
    from litellm.types.guardrails import Guardrail, LitellmParams


def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
    import litellm
    from litellm.proxy.guardrails.guardrail_hooks.aim import AimGuardrail

    _aim_callback = AimGuardrail(
        api_base=litellm_params.api_base,
        api_key=litellm_params.api_key,
        guardrail_name=guardrail.get("guardrail_name", ""),
        event_hook=litellm_params.mode,
        default_on=litellm_params.default_on,
    )
    litellm.logging_callback_manager.add_litellm_callback(_aim_callback)

    return _aim_callback


guardrail_initializer_registry = {
    SupportedGuardrailIntegrations.AIM.value: initialize_guardrail,
}

guardrail_class_registry = {
    SupportedGuardrailIntegrations.AIM.value: AimGuardrail,
}
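For orientation, a minimal sketch of how these registries could be consumed by a loader. This is illustrative only, not the proxy's actual loading code: the SimpleNamespace stand-in and every value below are placeholders, and only the fields initialize_guardrail actually reads are set.

from types import SimpleNamespace

from litellm.types.guardrails import SupportedGuardrailIntegrations

# Stand-in for a LitellmParams instance; values are placeholders.
litellm_params = SimpleNamespace(
    api_base="https://api.aim.security",
    api_key=None,  # falls back to the AIM_API_KEY environment variable
    mode="pre_call",
    default_on=True,
)
guardrail_config = {"guardrail_name": "aim-guard"}

initializer = guardrail_initializer_registry[SupportedGuardrailIntegrations.AIM.value]
aim_callback = initializer(litellm_params, guardrail_config)  # registers and returns an AimGuardrail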


@@ -0,0 +1,318 @@
# +-------------------------------------------------------------+
#
# Use Aim Security Guardrails for your LLM calls
# https://www.aim.security/
#
# +-------------------------------------------------------------+
import asyncio
import json
import os
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Type, Union

from fastapi import HTTPException
from pydantic import BaseModel
from websockets.asyncio.client import ClientConnection, connect

from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm._version import version as litellm_version
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import (
    CallTypesLiteral,
    Choices,
    EmbeddingResponse,
    ImageResponse,
    ModelResponse,
    ModelResponseStream,
)

if TYPE_CHECKING:
    from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel


class AimGuardrailMissingSecrets(Exception):
    pass
class AimGuardrail(CustomGuardrail):
    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
    ):
        ssl_verify = kwargs.pop("ssl_verify", None)
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback,
            params={"ssl_verify": ssl_verify} if ssl_verify is not None else None,
        )
        self.api_key = api_key or os.environ.get("AIM_API_KEY")
        if not self.api_key:
            msg = (
                "Couldn't get Aim api key, either set the `AIM_API_KEY` in the environment or "
                "pass it as a parameter to the guardrail in the config file"
            )
            raise AimGuardrailMissingSecrets(msg)
        self.api_base = (
            api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
        )
        self.ws_api_base = self.api_base.replace("http://", "ws://").replace(
            "https://", "wss://"
        )
        self.dlp_entities: list[dict] = []
        self._max_dlp_entities = 100
        super().__init__(**kwargs)
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Union[Exception, str, dict, None]:
        verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")

        return await self.call_aim_guardrail(
            data, hook="pre_call", key_alias=user_api_key_dict.key_alias
        )

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: CallTypesLiteral,
    ) -> Union[Exception, str, dict, None]:
        verbose_proxy_logger.debug("Inside AIM Moderation Hook")

        await self.call_aim_guardrail(
            data, hook="moderation", key_alias=user_api_key_dict.key_alias
        )
        return data
    async def call_aim_guardrail(
        self, data: dict, hook: str, key_alias: Optional[str]
    ) -> dict:
        user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
        call_id = data.get("litellm_call_id")
        headers = self._build_aim_headers(
            hook=hook,
            key_alias=key_alias,
            user_email=user_email,
            litellm_call_id=call_id,
        )
        response = await self.async_handler.post(
            f"{self.api_base}/fw/v1/analyze",
            headers=headers,
            json={"messages": data.get("messages", [])},
        )
        response.raise_for_status()
        res = response.json()
        required_action = res.get("required_action")
        action_type = required_action and required_action.get("action_type", None)
        if action_type is None:
            verbose_proxy_logger.debug("Aim: No required action specified")
            return data
        if action_type == "monitor_action":
            verbose_proxy_logger.info("Aim: monitor action")
        elif action_type == "block_action":
            self._handle_block_action(res["analysis_result"], required_action)
        elif action_type == "anonymize_action":
            return self._anonymize_request(res, data)
        else:
            verbose_proxy_logger.error(f"Aim: {action_type} action")
        return data
    def _handle_block_action(self, analysis_result: Any, required_action: Any) -> None:
        detection_message = required_action.get("detection_message", None)
        verbose_proxy_logger.info(
            "Aim: violation detected, enabled policies: {policies}".format(
                policies=list(analysis_result["policy_drill_down"].keys()),
            ),
        )
        raise HTTPException(status_code=400, detail=detection_message)

    def _anonymize_request(self, res: Any, data: dict) -> dict:
        verbose_proxy_logger.info("Aim: anonymize action")
        redacted_chat = res.get("redacted_chat")
        if not redacted_chat:
            return data
        data["messages"] = [
            {
                "role": message["role"],
                "content": message["content"],
            }
            for message in redacted_chat["all_redacted_messages"]
        ]
        return data
    async def call_aim_guardrail_on_output(
        self, request_data: dict, output: str, hook: str, key_alias: Optional[str]
    ) -> Optional[dict]:
        user_email = (
            request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
        )
        call_id = request_data.get("litellm_call_id")
        response = await self.async_handler.post(
            f"{self.api_base}/fw/v1/analyze",
            headers=self._build_aim_headers(
                hook=hook,
                key_alias=key_alias,
                user_email=user_email,
                litellm_call_id=call_id,
            ),
            json={
                "messages": request_data.get("messages", [])
                + [{"role": "assistant", "content": output}]
            },
        )
        response.raise_for_status()
        res = response.json()
        required_action = res.get("required_action")
        action_type = required_action and required_action.get("action_type", None)
        if action_type and action_type == "block_action":
            return self._handle_block_action_on_output(
                res["analysis_result"], required_action
            )
        redacted_chat = res.get("redacted_chat", None)
        if action_type and action_type == "anonymize_action" and redacted_chat:
            return {
                "redacted_output": redacted_chat["all_redacted_messages"][-1]["content"]
            }
        return {"redacted_output": output}
    def _handle_block_action_on_output(
        self, analysis_result: Any, required_action: Any
    ) -> Optional[dict]:
        detection_message = required_action.get("detection_message", None)
        verbose_proxy_logger.info(
            "Aim: detected: {detected}, enabled policies: {policies}".format(
                detected=True,
                policies=list(analysis_result["policy_drill_down"].keys()),
            ),
        )
        return {"detection_message": detection_message}
    def _build_aim_headers(
        self,
        *,
        hook: str,
        key_alias: Optional[str],
        user_email: Optional[str],
        litellm_call_id: Optional[str],
    ):
        """
        A helper function to build the HTTP headers required by Aim guardrails.
        """
        return (
            {
                "Authorization": f"Bearer {self.api_key}",
                # Used by Aim to apply only the guardrails that should be applied in a specific request phase.
                "x-aim-litellm-hook": hook,
                # Used by Aim to track the LiteLLM version and provide backward compatibility.
                "x-aim-litellm-version": litellm_version,
            }
            # Used by Aim to correlate the input and output of a single call.
            | ({"x-aim-call-id": litellm_call_id} if litellm_call_id else {})
            # Used by Aim to track guardrail violations by user.
            | ({"x-aim-user-email": user_email} if user_email else {})
            | (
                {
                    # Used by Aim to apply only the guardrails that are associated with the key alias.
                    "x-aim-gateway-key-alias": key_alias,
                }
                if key_alias
                else {}
            )
        )
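    # Illustrative example (all values are placeholders): for hook="output" with a
    # call id, a user email, and a key alias present, _build_aim_headers returns roughly:
    #     {
    #         "Authorization": "Bearer <AIM_API_KEY>",
    #         "x-aim-litellm-hook": "output",
    #         "x-aim-litellm-version": "<litellm version>",
    #         "x-aim-call-id": "<litellm_call_id>",
    #         "x-aim-user-email": "user@example.com",
    #         "x-aim-gateway-key-alias": "<key alias>",
    #     }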
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
    ) -> Any:
        if (
            isinstance(response, ModelResponse)
            and response.choices
            and isinstance(response.choices[0], Choices)
        ):
            content = response.choices[0].message.content or ""
            aim_output_guardrail_result = await self.call_aim_guardrail_on_output(
                data, content, hook="output", key_alias=user_api_key_dict.key_alias
            )
            if aim_output_guardrail_result and aim_output_guardrail_result.get(
                "detection_message"
            ):
                raise HTTPException(
                    status_code=400,
                    detail=aim_output_guardrail_result.get("detection_message"),
                )
            if aim_output_guardrail_result and aim_output_guardrail_result.get(
                "redacted_output"
            ):
                response.choices[0].message.content = aim_output_guardrail_result.get(
                    "redacted_output"
                )
        return response
    async def async_post_call_streaming_iterator_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
        request_data: dict,
    ) -> AsyncGenerator[ModelResponseStream, None]:
        user_email = (
            request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
        )
        call_id = request_data.get("litellm_call_id")
        async with connect(
            f"{self.ws_api_base}/fw/v1/analyze/stream",
            additional_headers=self._build_aim_headers(
                hook="output",
                key_alias=user_api_key_dict.key_alias,
                user_email=user_email,
                litellm_call_id=call_id,
            ),
        ) as websocket:
            sender = asyncio.create_task(
                self.forward_the_stream_to_aim(websocket, response)
            )
            while True:
                result = json.loads(await websocket.recv())
                if verified_chunk := result.get("verified_chunk"):
                    yield ModelResponseStream.model_validate(verified_chunk)
                else:
                    sender.cancel()
                    if result.get("done"):
                        return
                    if blocking_message := result.get("blocking_message"):
                        from litellm.proxy.proxy_server import StreamingCallbackError

                        raise StreamingCallbackError(blocking_message)
                    verbose_proxy_logger.error(
                        f"Unknown message received from AIM: {result}"
                    )
                    return
    async def forward_the_stream_to_aim(
        self,
        websocket: ClientConnection,
        response_iter,
    ) -> None:
        async for chunk in response_iter:
            if isinstance(chunk, BaseModel):
                chunk = chunk.model_dump_json()
            if isinstance(chunk, dict):
                chunk = json.dumps(chunk)
            await websocket.send(chunk)
        await websocket.send(json.dumps({"done": True}))
    @staticmethod
    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
        from litellm.types.proxy.guardrails.guardrail_hooks.aim import (
            AimGuardrailConfigModel,
        )

        return AimGuardrailConfigModel
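To round out the snapshot, a minimal usage sketch for the guardrail class above. It assumes AIM_API_KEY is set in the environment and the Aim API is reachable over the network; the guardrail name, message payload, and call id are placeholders, and the constructor keyword arguments simply mirror what initialize_guardrail passes through to CustomGuardrail.

import asyncio

from litellm.proxy.guardrails.guardrail_hooks.aim import AimGuardrail


async def main() -> None:
    # Reads AIM_API_KEY / AIM_API_BASE from the environment when not passed explicitly.
    guardrail = AimGuardrail(guardrail_name="aim-guard", default_on=True)
    data = {
        "messages": [{"role": "user", "content": "My card number is 4111 1111 1111 1111"}],
        "litellm_call_id": "example-call-id",
    }
    # Raises fastapi.HTTPException(400) on a block_action; otherwise returns the
    # (possibly anonymized) request data.
    result = await guardrail.call_aim_guardrail(data, hook="pre_call", key_alias=None)
    print(result["messages"])


if __name__ == "__main__":
    asyncio.run(main())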