chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from litellm.proxy.guardrails.guardrail_hooks.onyx.onyx import OnyxGuardrail
|
||||
from litellm.types.guardrails import SupportedGuardrailIntegrations
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.guardrails import Guardrail, LitellmParams
|
||||
|
||||
|
||||
def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
    """Build an OnyxGuardrail from proxy config and register it with litellm.

    Reads connection and mode settings from ``litellm_params``, wires the
    resulting guardrail into litellm's global callback pipeline, and returns
    the instance so the caller keeps a handle to it.
    """
    import litellm

    callback = OnyxGuardrail(
        api_base=litellm_params.api_base,
        api_key=litellm_params.api_key,
        guardrail_name=guardrail.get("guardrail_name", ""),
        event_hook=litellm_params.mode,
        default_on=litellm_params.default_on,
    )

    # Registration is what actually activates the guardrail for proxy traffic.
    litellm.logging_callback_manager.add_litellm_callback(callback)
    return callback
|
||||
|
||||
|
||||
# Maps a supported guardrail integration name -> initializer the proxy calls
# at startup to construct and register the guardrail from config.
guardrail_initializer_registry = {
    SupportedGuardrailIntegrations.ONYX.value: initialize_guardrail,
}

# Maps a supported guardrail integration name -> guardrail class, for code
# paths that instantiate the class directly rather than via the initializer.
guardrail_class_registry = {
    SupportedGuardrailIntegrations.ONYX.value: OnyxGuardrail,
}
|
||||
@@ -0,0 +1,141 @@
|
||||
# +-------------------------------------------------------------+
|
||||
#
|
||||
# Use Onyx Guardrails for your LLM calls
|
||||
# https://onyx.security/
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Type
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
|
||||
|
||||
|
||||
class OnyxGuardrail(CustomGuardrail):
    """Guardrail that validates LLM requests/responses against an Onyx Guard server.

    On each hook invocation the payload is POSTed to the Onyx evaluation
    endpoint; if the server reports the content is not allowed, an
    ``HTTPException`` (400) is raised to block the call.

    https://onyx.security/
    """

    def __init__(
        self,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: Optional[float] = 10.0,
        **kwargs,
    ):
        """Configure the async HTTP client and resolve connection settings.

        Args:
            api_base: Onyx Guard server base URL; falls back to the
                ``ONYX_API_BASE`` env var, then the public default.
            api_key: Onyx API key; falls back to the ``ONYX_API_KEY`` env var.
            timeout: Per-request timeout in seconds (connect timeout is fixed
                at 5s); falls back to the ``ONYX_TIMEOUT`` env var.
            **kwargs: Passed through to ``CustomGuardrail`` (e.g.
                ``guardrail_name``, ``event_hook``, ``default_on``).

        Raises:
            ValueError: If no API key is available from either source.
        """
        # float(), not int(): the parameter is typed float, and an env value
        # like "2.5" would make int() raise ValueError (and int() would
        # silently truncate fractional timeouts).
        timeout = timeout or float(os.getenv("ONYX_TIMEOUT", 10.0))
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback,
            params={"timeout": httpx.Timeout(timeout=timeout, connect=5.0)},
        )
        self.api_base = api_base or os.getenv(
            "ONYX_API_BASE",
            "https://ai-guard.onyx.security",
        )
        self.api_key = api_key or os.getenv("ONYX_API_KEY")
        if not self.api_key:
            raise ValueError("ONYX_API_KEY environment variable is not set")
        self.optional_params = kwargs
        super().__init__(**kwargs)
        verbose_proxy_logger.info(f"OnyxGuard initialized with server: {self.api_base}")

    async def _validate_with_guard_server(
        self,
        payload: Any,
        input_type: Literal["request", "response"],
        conversation_id: str,
    ) -> dict:
        """
        Call external Onyx Guard server for validation.

        Raises:
            HTTPException: 400 when the server marks the payload as not allowed.
            httpx.HTTPStatusError: When the guard server returns an error status.

        Returns:
            The parsed JSON verdict from the guard server.
        """
        # NOTE(review): the API key travels in the URL path and may end up in
        # access logs — confirm this matches Onyx's documented API contract.
        response = await self.async_handler.post(
            f"{self.api_base}/guard/evaluate/v1/{self.api_key}/litellm",
            json={
                "payload": payload,
                "input_type": input_type,
                "conversation_id": conversation_id,
            },
            headers={
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()
        result = response.json()
        # Missing "allowed" is treated as allowed (fail-open on schema drift).
        if not result.get("allowed", True):
            detection_message = "Unknown violation"
            if "violated_rules" in result:
                detection_message = ", ".join(result["violated_rules"])
            verbose_proxy_logger.warning(
                f"Request blocked by Onyx Guard. Violations: {detection_message}."
            )
            raise HTTPException(
                status_code=400,
                detail=f"Request blocked by Onyx Guard. Violations: {detection_message}.",
            )
        return result

    @log_guardrail_information
    async def apply_guardrail(
        self,
        inputs: GenericGuardrailAPIInputs,
        request_data: dict,
        input_type: Literal["request", "response"],
        logging_obj: Optional["LiteLLMLoggingObj"] = None,
    ) -> GenericGuardrailAPIInputs:
        """Guardrail hook: validate the payload, blocking only on a verdict.

        Non-HTTP errors (network failures, payload-shape issues) are logged
        and the inputs are returned unchanged — the guardrail is best-effort
        except when the server explicitly blocks.
        """
        conversation_id = (
            logging_obj.litellm_call_id if logging_obj else str(uuid.uuid4())
        )

        verbose_proxy_logger.info(
            "Running Onyx Guard apply_guardrail hook",
            extra={"conversation_id": conversation_id, "input_type": input_type},
        )
        payload = {}
        if input_type == "request":
            payload = request_data.get("proxy_server_request", {})
        else:
            try:
                response = ModelResponse(**request_data)
                parsed = response.json()
                # NOTE(review): assumes the serialized ModelResponse carries a
                # "response" key; if absent, an empty payload is sent — confirm
                # against the ModelResponse schema in use.
                payload = parsed.get("response", {})
            except Exception as e:
                verbose_proxy_logger.error(
                    f"Error in converting request_data to ModelResponse: {str(e)}",
                    extra={
                        "conversation_id": conversation_id,
                        "input_type": input_type,
                    },
                )
                # Fall back to sending the raw request_data for evaluation.
                payload = request_data

        try:
            await self._validate_with_guard_server(payload, input_type, conversation_id)
            return inputs
        except HTTPException as e:
            # A block verdict must propagate to the caller.
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in apply_guardrail guard: {str(e)}",
                extra={"conversation_id": conversation_id, "input_type": input_type},
            )
            return inputs

    @staticmethod
    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
        """Return the pydantic config model used to validate this guardrail's config."""
        from litellm.types.proxy.guardrails.guardrail_hooks.onyx import (
            OnyxGuardrailConfigModel,
        )

        return OnyxGuardrailConfigModel
|
||||
Reference in New Issue
Block a user