297 lines
12 KiB
Python
297 lines
12 KiB
Python
|
|
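"""
Security hook to prevent user B from seeing responses created by user A.

Response ids returned by the Responses API are replaced with an encrypted id
that also encodes the creating key's user_id and team_id. When a later request
references one of these ids (e.g. `previous_response_id`, or retrieving,
deleting or cancelling a response), the id is decrypted, the caller's access is
checked, and the original response id is forwarded upstream.
"""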
|
||
|
|
|
||
|
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple, Union, cast
|
||
|
|
|
||
|
|
from fastapi import HTTPException
|
||
|
|
|
||
|
|
from litellm._logging import verbose_proxy_logger
|
||
|
|
from litellm.integrations.custom_logger import CustomLogger
|
||
|
|
from litellm.proxy._types import LitellmUserRoles
|
||
|
|
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||
|
|
decrypt_value_helper,
|
||
|
|
encrypt_value_helper,
|
||
|
|
)
|
||
|
|
from litellm.types.llms.openai import (
|
||
|
|
BaseLiteLLMOpenAIResponseObject,
|
||
|
|
ResponsesAPIResponse,
|
||
|
|
)
|
||
|
|
from litellm.types.utils import CallTypesLiteral, LLMResponseTypes, SpecialEnums
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from litellm.caching.caching import DualCache
|
||
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
||
|
|
|
||
|
|
|
||
|
|
class ResponsesIDSecurity(CustomLogger):
    """
    Bind Responses API response ids to the API key that created them.

    Response ids returned to clients are replaced by an encrypted id that also
    encodes the creating key's ``user_id`` and ``team_id``. When a later
    request references such an id, it is decrypted, the caller's access is
    checked, and the original response id is forwarded upstream.

    Setting ``general_settings["disable_responses_id_security"]`` turns off
    both the ownership checks and the response-id encryption.
    """

    # Prefix carried by provider response ids and by the encrypted ids we emit.
    _RESPONSE_ID_PREFIX = "resp_"
    # Call types that reference an existing response via ``data["response_id"]``.
    _RESPONSE_ID_CALL_TYPES = frozenset(
        {"aget_responses", "adelete_responses", "acancel_responses"}
    )
    # Every Responses API call type inspected by the pre-call hook.
    _RESPONSES_API_CALL_TYPES = _RESPONSE_ID_CALL_TYPES | {"aresponses"}

    def __init__(self):
        # Initialise the CustomLogger base so its instance attributes exist.
        super().__init__()

    async def async_pre_call_hook(
        self,
        user_api_key_dict: "UserAPIKeyAuth",
        cache: "DualCache",
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Decrypt and authorize a response id referenced by a Responses API call.

        For ``aresponses`` the id is read from ``data["previous_response_id"]``;
        for get/delete/cancel it is read from ``data["response_id"]``. Ids that
        are not LiteLLM-encrypted are passed through unchanged.

        Args:
            user_api_key_dict: Auth info of the calling key.
            cache: Unused by this hook.
            data: Request payload; the referenced id is replaced in place.
            call_type: The proxy call type of the request.

        Returns:
            None for non-Responses-API call types, otherwise ``data``.

        Raises:
            HTTPException: 403 when the caller may not access the response.
        """
        if call_type not in self._RESPONSES_API_CALL_TYPES:
            return None

        id_field = (
            "previous_response_id" if call_type == "aresponses" else "response_id"
        )
        referenced_id = data.get(id_field)
        if (
            referenced_id
            and isinstance(referenced_id, str)
            and self._is_encrypted_response_id(referenced_id)
        ):
            original_response_id, user_id, team_id = self._decrypt_response_id(
                referenced_id
            )
            self.check_user_access_to_response_id(user_id, team_id, user_api_key_dict)
            # Forward the real response id upstream.
            data[id_field] = original_response_id
        return data

    def check_user_access_to_response_id(
        self,
        response_id_user_id: Optional[str],
        response_id_team_id: Optional[str],
        user_api_key_dict: "UserAPIKeyAuth",
    ) -> bool:
        """
        Check that the caller may access a response owned by the given user/team.

        Proxy admins always pass. An empty/None owner user_id or team_id
        skips the corresponding check.

        Returns:
            True when access is allowed (including when the security feature
            is disabled via ``general_settings``).

        Raises:
            HTTPException: 403 when the owning user or team does not match.
        """
        from litellm.proxy.proxy_server import general_settings

        if user_api_key_dict.user_role in (
            LitellmUserRoles.PROXY_ADMIN.value,
            LitellmUserRoles.PROXY_ADMIN,
        ):
            return True

        security_disabled = general_settings.get("disable_responses_id_security", False)

        if response_id_user_id and response_id_user_id != user_api_key_dict.user_id:
            if security_disabled:
                verbose_proxy_logger.debug(
                    f"Responses ID Security is disabled. User {user_api_key_dict.user_id} is accessing a response id created by user {response_id_user_id}."
                )
                return True
            raise HTTPException(
                status_code=403,
                detail="Forbidden. The response id is not associated with the user, who this key belongs to. To disable this security feature, set general_settings::disable_responses_id_security to True in the config.yaml file.",
            )

        if response_id_team_id and response_id_team_id != user_api_key_dict.team_id:
            if security_disabled:
                verbose_proxy_logger.debug(
                    f"Responses ID Security is disabled. Response belongs to team {response_id_team_id} but user {user_api_key_dict.user_id} is accessing it with team id {user_api_key_dict.team_id}."
                )
                return True
            raise HTTPException(
                status_code=403,
                detail="Forbidden. The response id is not associated with the team, who this key belongs to. To disable this security feature, set general_settings::disable_responses_id_security to True in the config.yaml file.",
            )

        return True

    def _decrypt_managed_value(self, response_id: str) -> Optional[str]:
        """
        Return the decrypted LiteLLM-managed payload of ``response_id``, or None.

        The encrypted part is everything after the FIRST ``resp_``. Splitting on
        every occurrence would truncate an encrypted payload that happens to
        contain ``resp_`` itself.
        """
        _, separator, encrypted_part = response_id.partition(self._RESPONSE_ID_PREFIX)
        if not separator:
            return None

        decrypted_value = decrypt_value_helper(
            value=encrypted_part, key="response_id", return_original_value=True
        )
        if decrypted_value is None:
            return None
        # With return_original_value=True a non-encrypted value comes back as-is,
        # so the managed prefix is what identifies a LiteLLM-encrypted id.
        if not decrypted_value.startswith(
            SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value
        ):
            return None
        return decrypted_value

    def _is_encrypted_response_id(self, response_id: str) -> bool:
        """Return True if ``response_id`` is a LiteLLM-encrypted response id."""
        return self._decrypt_managed_value(response_id) is not None

    def _decrypt_response_id(
        self, response_id: str
    ) -> Tuple[str, Optional[str], Optional[str]]:
        """
        Decode an encrypted response id.

        Returns:
            - original_response_id: the original response id
            - user_id: the user id
            - team_id: the team id

            If ``response_id`` is not a LiteLLM-encrypted id, or its payload is
            malformed, returns ``(response_id, None, None)``.
        """
        decrypted_value = self._decrypt_managed_value(response_id)
        if decrypted_value is None:
            return response_id, None, None

        # Expected: "...response_id:{response_id};user_id:{user_id};team_id:{team_id}"
        parts = decrypted_value.split(";")
        if len(parts) < 2:
            # Fallback if format is unexpected
            return response_id, None, None

        # maxsplit=1 keeps any later occurrence of the label inside the value.
        original_response_id = parts[0].split("response_id:", 1)[-1]
        user_id = parts[1].split("user_id:", 1)[-1]
        # Tolerate a payload without a team_id segment instead of raising IndexError.
        team_id = parts[2].split("team_id:", 1)[-1] if len(parts) >= 3 else None
        return original_response_id, user_id, team_id

    def _get_signing_key(self) -> Optional[str]:
        """Get the signing key for encryption/decryption."""
        import os

        from litellm.proxy.proxy_server import master_key

        salt_key = os.getenv("LITELLM_SALT_KEY", None)
        if salt_key is None:
            salt_key = master_key
        return salt_key

    def _get_or_create_encrypted_id(
        self,
        original_id: str,
        user_api_key_dict: "UserAPIKeyAuth",
        request_cache: Optional[dict[str, str]],
    ) -> str:
        """
        Return the encrypted id for ``original_id``.

        The result is reused from ``request_cache`` when present. Otherwise it is
        stored there, so every streaming chunk of one request carries the same
        encrypted id.
        """
        if request_cache is not None and original_id in request_cache:
            return request_cache[original_id]

        plaintext = SpecialEnums.LITELLM_MANAGED_RESPONSE_API_RESPONSE_ID_COMPLETE_STR.value.format(
            original_id,
            user_api_key_dict.user_id or "",
            user_api_key_dict.team_id or "",
        )
        encrypted_id = f"resp_{encrypt_value_helper(value=plaintext)}"
        if request_cache is not None:
            request_cache[original_id] = encrypted_id
        return encrypted_id

    def _encrypt_response_id(
        self,
        response: BaseLiteLLMOpenAIResponseObject,
        user_api_key_dict: "UserAPIKeyAuth",
        request_cache: Optional[dict[str, str]] = None,
    ) -> BaseLiteLLMOpenAIResponseObject:
        """
        Replace the response id on ``response`` with an encrypted id.

        Handles objects with a top-level ``resp_``-prefixed ``id``. It also
        handles objects (e.g. streaming events) that wrap a
        ``ResponsesAPIResponse`` in a ``response`` attribute. The object is
        mutated in place and returned. If no signing key is configured, the
        response is returned unchanged.
        """
        signing_key = self._get_signing_key()
        if signing_key is None:
            verbose_proxy_logger.debug(
                "Response ID encryption is enabled but no signing key is configured. "
                "Please set LITELLM_SALT_KEY environment variable or configure a master_key. "
                "Skipping response ID encryption. "
                "See: https://docs.litellm.ai/docs/proxy/prod#5-set-litellm-salt-key"
            )
            return response

        response_id = getattr(response, "id", None)
        response_obj = getattr(response, "response", None)

        if isinstance(response_id, str) and response_id.startswith(
            self._RESPONSE_ID_PREFIX
        ):
            setattr(
                response,
                "id",
                self._get_or_create_encrypted_id(
                    response_id, user_api_key_dict, request_cache
                ),
            )
        elif response_obj and isinstance(response_obj, ResponsesAPIResponse):
            setattr(
                response_obj,
                "id",
                self._get_or_create_encrypted_id(
                    response_obj.id, user_api_key_dict, request_cache
                ),
            )
            setattr(response, "response", response_obj)
        return response

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: "UserAPIKeyAuth",
        response: LLMResponseTypes,
    ) -> Any:
        """
        Encrypt the response id of a non-streaming Responses API response.

        Only ``ResponsesAPIResponse`` objects are modified. Other response types,
        and all responses while the security feature is disabled, are returned
        unchanged.
        """
        from litellm.proxy.proxy_server import general_settings

        if general_settings.get("disable_responses_id_security", False):
            return response
        if isinstance(response, ResponsesAPIResponse):
            response = cast(
                ResponsesAPIResponse,
                self._encrypt_response_id(
                    response, user_api_key_dict, request_cache=None
                ),
            )
        return response

    async def async_post_call_streaming_iterator_hook(  # type: ignore
        self, user_api_key_dict: "UserAPIKeyAuth", response: Any, request_data: dict
    ) -> AsyncGenerator[BaseLiteLLMOpenAIResponseObject, None]:
        """
        Re-yield streaming chunks, encrypting response ids on ``/v1/responses``.

        A request-scoped cache ensures every chunk of one stream exposes the
        same encrypted id.
        """
        from litellm.proxy.proxy_server import general_settings

        # Route and settings do not change per chunk; evaluate them once.
        # NOTE(review): only the exact "/v1/responses" route is matched here,
        # while the non-streaming hook has no route check; confirm whether other
        # Responses API route aliases should also be covered.
        should_encrypt = user_api_key_dict.request_route == "/v1/responses" and not (
            general_settings.get("disable_responses_id_security", False)
        )

        # Create a request-scoped cache for consistent encryption across streaming chunks.
        request_encryption_cache: dict[str, str] = {}

        async for chunk in response:
            if should_encrypt and isinstance(chunk, BaseLiteLLMOpenAIResponseObject):
                chunk = self._encrypt_response_id(
                    chunk, user_api_key_dict, request_encryption_cache
                )
            yield chunk