chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,158 @@
# OpenAI Text Completion Guardrail Translation Handler
Handler for processing OpenAI's text completion endpoint (`/v1/completions`) with guardrails.
## Overview
This handler processes text completion requests by:
1. Extracting the text prompt(s) from the request
2. Applying guardrails to the prompt text(s)
3. Updating the request with the guardrailed prompt(s)
4. Applying guardrails to the completion output text
## Data Format
### Input Format
**Single Prompt:**
```json
{
"model": "gpt-3.5-turbo-instruct",
"prompt": "Say this is a test",
"max_tokens": 7,
"temperature": 0
}
```
**Multiple Prompts (Batch):**
```json
{
"model": "gpt-3.5-turbo-instruct",
"prompt": [
"Tell me a joke",
"Write a poem"
],
"max_tokens": 50
}
```
### Output Format
```json
{
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "gpt-3.5-turbo-instruct",
"choices": [
{
"text": "\n\nThis is indeed a test",
"index": 0,
"logprobs": null,
"finish_reason": "length"
}
],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
}
```
## Usage
The handler is automatically discovered and applied when guardrails are used with the text completion endpoint.
### Example: Using Guardrails with Text Completion
```bash
curl -X POST 'http://localhost:4000/v1/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "gpt-3.5-turbo-instruct",
"prompt": "Say this is a test",
"guardrails": ["content_moderation"],
"max_tokens": 7
}'
```
The guardrail will be applied to both:
- **Input**: The prompt text before sending to the LLM
- **Output**: The completion text in the response
### Example: PII Masking in Prompts and Completions
```bash
curl -X POST 'http://localhost:4000/v1/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "gpt-3.5-turbo-instruct",
"prompt": "My name is John Doe and my email is john@example.com",
"guardrails": ["mask_pii"],
"metadata": {
"guardrails": ["mask_pii"]
}
}'
```
### Example: Batch Prompts with Guardrails
```bash
curl -X POST 'http://localhost:4000/v1/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"model": "gpt-3.5-turbo-instruct",
"prompt": [
"Tell me about AI",
"What is machine learning?"
],
"guardrails": ["content_filter"],
"max_tokens": 100
}'
```
## Implementation Details
### Input Processing
- **Field**: `prompt` (string or list of strings)
- **Processing**:
- String prompts: Apply guardrail directly
- List prompts: Apply guardrail to each string in the list
- **Result**: Updated prompt(s) in request
### Output Processing
- **Field**: `choices[*].text` (string)
- **Processing**: Applies guardrail to each completion text
- **Result**: Updated completion texts in response
### Supported Prompt Types
1. **String**: Single prompt as a string
2. **List of Strings**: Multiple prompts for batch completion
3. **List of Lists**: Token-based prompts (passed through unchanged)
## Extension
Override these methods to customize behavior:
- `process_input_messages()`: Customize how prompts are processed
- `process_output_response()`: Customize how completion texts are processed
## Supported Call Types
- `CallTypes.text_completion` - Synchronous text completion
- `CallTypes.atext_completion` - Asynchronous text completion
## Notes
- The handler processes both input prompts and output completion texts
- List prompts are processed individually (each string in the list)
- Non-string prompt items (e.g., token lists) are passed through unchanged
- Both sync and async call types use the same handler

View File

@@ -0,0 +1,13 @@
"""OpenAI Text Completion handler for Unified Guardrails."""
from litellm.llms.openai.completion.guardrail_translation.handler import (
OpenAITextCompletionHandler,
)
from litellm.types.utils import CallTypes
guardrail_translation_mappings = {
CallTypes.text_completion: OpenAITextCompletionHandler,
CallTypes.atext_completion: OpenAITextCompletionHandler,
}
__all__ = ["guardrail_translation_mappings", "OpenAITextCompletionHandler"]

View File

@@ -0,0 +1,194 @@
"""
OpenAI Text Completion Handler for Unified Guardrails
This module provides guardrail translation support for OpenAI's text completion endpoint.
The handler processes the 'prompt' parameter for guardrails.
"""
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.utils import TextCompletionResponse
class OpenAITextCompletionHandler(BaseTranslation):
"""
Handler for processing OpenAI text completion requests with guardrails.
This class provides methods to:
1. Process input prompt (pre-call hook)
2. Process output response (post-call hook)
The handler specifically processes the 'prompt' parameter which can be:
- A single string
- A list of strings (for batch completions)
"""
async def process_input_messages(
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input prompt by applying guardrails to text content.
Args:
data: Request data dictionary containing 'prompt' parameter
guardrail_to_apply: The guardrail instance to apply
Returns:
Modified data with guardrails applied to prompt
"""
prompt = data.get("prompt")
if prompt is None:
verbose_proxy_logger.debug(
"OpenAI Text Completion: No prompt found in request data"
)
return data
if isinstance(prompt, str):
# Single string prompt
inputs = GenericGuardrailAPIInputs(texts=[prompt])
# Include model information if available
model = data.get("model")
if model:
inputs["model"] = model
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
request_data=data,
input_type="request",
logging_obj=litellm_logging_obj,
)
guardrailed_texts = guardrailed_inputs.get("texts", [])
data["prompt"] = guardrailed_texts[0] if guardrailed_texts else prompt
verbose_proxy_logger.debug(
"OpenAI Text Completion: Applied guardrail to string prompt. "
"Original length: %d, New length: %d",
len(prompt),
len(data["prompt"]),
)
elif isinstance(prompt, list):
# List of string prompts (batch completion)
texts_to_check = []
text_indices = [] # Track which prompts are strings
for idx, p in enumerate(prompt):
if isinstance(p, str):
texts_to_check.append(p)
text_indices.append(idx)
if texts_to_check:
inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
# Include model information if available
model = data.get("model")
if model:
inputs["model"] = model
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
request_data=data,
input_type="request",
logging_obj=litellm_logging_obj,
)
guardrailed_texts = guardrailed_inputs.get("texts", [])
# Replace guardrailed texts back
for guardrail_idx, prompt_idx in enumerate(text_indices):
if guardrail_idx < len(guardrailed_texts):
data["prompt"][prompt_idx] = guardrailed_texts[guardrail_idx]
verbose_proxy_logger.debug(
"OpenAI Text Completion: Applied guardrail to prompt[%d]. "
"Original length: %d, New length: %d",
prompt_idx,
len(texts_to_check[guardrail_idx]),
len(guardrailed_texts[guardrail_idx]),
)
else:
verbose_proxy_logger.warning(
"OpenAI Text Completion: Unexpected prompt type: %s. Expected string or list.",
type(prompt),
)
return data
async def process_output_response(
self,
response: "TextCompletionResponse",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
) -> Any:
"""
Process output response by applying guardrails to completion text.
Args:
response: Text completion response object
guardrail_to_apply: The guardrail instance to apply
litellm_logging_obj: Optional logging object
user_api_key_dict: User API key metadata to pass to guardrails
Returns:
Modified response with guardrails applied to completion text
"""
if not hasattr(response, "choices") or not response.choices:
verbose_proxy_logger.debug(
"OpenAI Text Completion: No choices in response to process"
)
return response
# Collect all texts to check
texts_to_check = []
choice_indices = []
for idx, choice in enumerate(response.choices):
if hasattr(choice, "text") and isinstance(choice.text, str):
texts_to_check.append(choice.text)
choice_indices.append(idx)
# Apply guardrails in batch
if texts_to_check:
# Create a request_data dict with response info and user API key metadata
request_data: dict = {"response": response}
# Add user API key metadata with prefixed keys
user_metadata = self.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
request_data["litellm_metadata"] = user_metadata
inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
# Include model information from the response if available
if hasattr(response, "model") and response.model:
inputs["model"] = response.model
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="response",
logging_obj=litellm_logging_obj,
)
guardrailed_texts = guardrailed_inputs.get("texts", [])
# Apply guardrailed texts back to choices
for guardrail_idx, choice_idx in enumerate(choice_indices):
if guardrail_idx < len(guardrailed_texts):
original_text = response.choices[choice_idx].text
response.choices[choice_idx].text = guardrailed_texts[guardrail_idx]
verbose_proxy_logger.debug(
"OpenAI Text Completion: Applied guardrail to choice[%d] text. "
"Original length: %d, New length: %d",
choice_idx,
len(original_text),
len(guardrailed_texts[guardrail_idx]),
)
return response

View File

@@ -0,0 +1,318 @@
import json
from typing import Callable, List, Optional, Union
from openai import AsyncOpenAI, OpenAI
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base import BaseLLM
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
from litellm.utils import ProviderConfigManager
from ..common_utils import BaseOpenAILLM, OpenAIError
from .transformation import OpenAITextCompletionConfig
class OpenAITextCompletion(BaseLLM):
openai_text_completion_global_config = OpenAITextCompletionConfig()
def __init__(self) -> None:
super().__init__()
def validate_environment(self, api_key):
headers = {
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completion(
self,
model_response: ModelResponse,
api_key: str,
model: str,
messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
timeout: float,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
litellm_params=None,
logger_fn=None,
client=None,
organization: Optional[str] = None,
headers: Optional[dict] = None,
):
try:
if headers is None:
headers = self.validate_environment(api_key=api_key)
if model is None or messages is None:
raise OpenAIError(status_code=422, message="Missing model or messages")
# don't send max retries to the api, if set
provider_config = ProviderConfigManager.get_provider_text_completion_config(
model=model,
provider=LlmProviders(custom_llm_provider),
)
data = provider_config.transform_text_completion_request(
model=model,
messages=messages,
optional_params=optional_params,
headers=headers,
)
max_retries = data.pop("max_retries", 2)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"headers": headers,
"api_base": api_base,
"complete_input_dict": data,
},
)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries,
client=client,
organization=organization,
)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore
elif optional_params.get("stream", False):
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries, # type: ignore
client=client,
organization=organization,
)
else:
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
raw_response = openai_client.completions.with_raw_response.create(**data) # type: ignore
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
api_key=api_key,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return TextCompletionResponse(**response_json)
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
async def acompletion(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
api_key: str,
model: str,
timeout: float,
max_retries: int,
organization: Optional[str] = None,
client=None,
):
try:
if client is None:
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=BaseOpenAILLM._get_async_http_client(),
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
raw_response = await openai_aclient.completions.with_raw_response.create(
**data
)
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
response_obj = TextCompletionResponse(**response_json)
response_obj._hidden_params.original_response = json.dumps(response_json)
return response_obj
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
def streaming(
self,
logging_obj,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
max_retries=None,
client=None,
organization=None,
):
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
try:
raw_response = openai_client.completions.with_raw_response.create(**data)
response = raw_response.parse()
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
try:
for chunk in streamwrapper:
yield chunk
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
async def async_streaming(
self,
logging_obj,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
max_retries: int,
api_base: Optional[str] = None,
client=None,
organization=None,
):
if client is None:
openai_client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
raw_response = await openai_client.completions.with_raw_response.create(**data)
response = raw_response.parse()
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
try:
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)

View File

@@ -0,0 +1,158 @@
"""
Support for gpt model family
"""
from typing import List, Optional, Union
from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
from ..chat.gpt_transformation import OpenAIGPTConfig
from .utils import _transform_prompt
class OpenAITextCompletionConfig(BaseTextCompletionConfig, OpenAIGPTConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/completions/create
The class `OpenAITextCompletionConfig` provides configuration for the OpenAI's text completion API interface. Below are the parameters:
- `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.
- `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.
- `frequency_penalty` (number or null): Defaults to 0. It is a numbers from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.
- `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.
- `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.
- `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.
- `temperature` (number or null): This optional parameter defines the sampling temperature to use.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
best_of: Optional[int] = None
echo: Optional[bool] = None
frequency_penalty: Optional[int] = None
logit_bias: Optional[dict] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
suffix: Optional[str] = None
def __init__(
self,
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[int] = None,
logit_bias: Optional[dict] = None,
logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
suffix: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def convert_to_chat_model_response_object(
self,
response_object: Optional[TextCompletionResponse] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list: List[Choices] = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["text"],
role="assistant",
)
choice = Choices(
finish_reason=choice["finish_reason"],
index=idx,
message=message,
logprobs=choice.get("logprobs", None),
)
choice_list.append(choice)
model_response_object.choices = choice_list # type: ignore
if "usage" in response_object:
setattr(model_response_object, "usage", response_object["usage"])
if "id" in response_object:
model_response_object.id = response_object["id"]
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params[
"original_response"
] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
return model_response_object
except Exception as e:
raise e
def get_supported_openai_params(self, model: str) -> List:
return [
"functions",
"function_call",
"temperature",
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
"max_retries",
"logprobs",
"top_logprobs",
"extra_headers",
]
def transform_text_completion_request(
self,
model: str,
messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
optional_params: dict,
headers: dict,
) -> dict:
prompt = _transform_prompt(messages)
return {
"model": model,
"prompt": prompt,
**optional_params,
}

View File

@@ -0,0 +1,50 @@
from typing import List, Union, cast
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.types.llms.openai import (
AllMessageValues,
AllPromptValues,
OpenAITextCompletionUserMessage,
)
def is_tokens_or_list_of_tokens(value: List):
# Check if it's a list of integers (tokens)
if isinstance(value, list) and all(isinstance(item, int) for item in value):
return True
# Check if it's a list of lists of integers (list of tokens)
if isinstance(value, list) and all(
isinstance(item, list) and all(isinstance(i, int) for i in item)
for item in value
):
return True
return False
def _transform_prompt(
messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
) -> AllPromptValues:
if len(messages) == 1: # base case
message_content = messages[0].get("content")
if (
message_content
and isinstance(message_content, list)
and is_tokens_or_list_of_tokens(message_content)
):
openai_prompt: AllPromptValues = cast(AllPromptValues, message_content)
else:
openai_prompt = ""
content = convert_content_list_to_str(cast(AllMessageValues, messages[0]))
openai_prompt += content
else:
prompt_str_list: List[str] = []
for m in messages:
try: # expect list of int/list of list of int to be a 1 message array only.
content = convert_content_list_to_str(cast(AllMessageValues, m))
prompt_str_list.append(content)
except Exception as e:
raise e
openai_prompt = prompt_str_list
return openai_prompt