chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,158 @@
|
||||
# OpenAI Text Completion Guardrail Translation Handler
|
||||
|
||||
Handler for processing OpenAI's text completion endpoint (`/v1/completions`) with guardrails.
|
||||
|
||||
## Overview
|
||||
|
||||
This handler processes text completion requests by:
|
||||
1. Extracting the text prompt(s) from the request
|
||||
2. Applying guardrails to the prompt text(s)
|
||||
3. Updating the request with the guardrailed prompt(s)
|
||||
4. Applying guardrails to the completion output text
|
||||
|
||||
## Data Format
|
||||
|
||||
### Input Format
|
||||
|
||||
**Single Prompt:**
|
||||
```json
|
||||
{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "Say this is a test",
|
||||
"max_tokens": 7,
|
||||
"temperature": 0
|
||||
}
|
||||
```
|
||||
|
||||
**Multiple Prompts (Batch):**
|
||||
```json
|
||||
{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": [
|
||||
"Tell me a joke",
|
||||
"Write a poem"
|
||||
],
|
||||
"max_tokens": 50
|
||||
}
|
||||
```
|
||||
|
||||
### Output Format
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||
"object": "text_completion",
|
||||
"created": 1589478378,
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"choices": [
|
||||
{
|
||||
"text": "\n\nThis is indeed a test",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"finish_reason": "length"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 12
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The handler is automatically discovered and applied when guardrails are used with the text completion endpoint.
|
||||
|
||||
### Example: Using Guardrails with Text Completion
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "Say this is a test",
|
||||
"guardrails": ["content_moderation"],
|
||||
"max_tokens": 7
|
||||
}'
|
||||
```
|
||||
|
||||
The guardrail will be applied to both:
|
||||
- **Input**: The prompt text before sending to the LLM
|
||||
- **Output**: The completion text in the response
|
||||
|
||||
### Example: PII Masking in Prompts and Completions
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "My name is John Doe and my email is john@example.com",
|
||||
"guardrails": ["mask_pii"],
|
||||
"metadata": {
|
||||
"guardrails": ["mask_pii"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### Example: Batch Prompts with Guardrails
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": [
|
||||
"Tell me about AI",
|
||||
"What is machine learning?"
|
||||
],
|
||||
"guardrails": ["content_filter"],
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Input Processing
|
||||
|
||||
- **Field**: `prompt` (string or list of strings)
|
||||
- **Processing**:
|
||||
- String prompts: Apply guardrail directly
|
||||
- List prompts: Apply guardrail to each string in the list
|
||||
- **Result**: Updated prompt(s) in request
|
||||
|
||||
### Output Processing
|
||||
|
||||
- **Field**: `choices[*].text` (string)
|
||||
- **Processing**: Applies guardrail to each completion text
|
||||
- **Result**: Updated completion texts in response
|
||||
|
||||
### Supported Prompt Types
|
||||
|
||||
1. **String**: Single prompt as a string
|
||||
2. **List of Strings**: Multiple prompts for batch completion
|
||||
3. **List of Lists**: Token-based prompts (passed through unchanged)
|
||||
|
||||
## Extension
|
||||
|
||||
Override these methods to customize behavior:
|
||||
|
||||
- `process_input_messages()`: Customize how prompts are processed
|
||||
- `process_output_response()`: Customize how completion texts are processed
|
||||
|
||||
## Supported Call Types
|
||||
|
||||
- `CallTypes.text_completion` - Synchronous text completion
|
||||
- `CallTypes.atext_completion` - Asynchronous text completion
|
||||
|
||||
## Notes
|
||||
|
||||
- The handler processes both input prompts and output completion texts
|
||||
- List prompts are processed individually (each string in the list)
|
||||
- Non-string prompt items (e.g., token lists) are passed through unchanged
|
||||
- Both sync and async call types use the same handler
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""OpenAI Text Completion handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.openai.completion.guardrail_translation.handler import (
|
||||
OpenAITextCompletionHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# Map both the sync and async text-completion call types to the same handler:
# prompt/output processing is identical regardless of how the call was made.
guardrail_translation_mappings = {
    CallTypes.text_completion: OpenAITextCompletionHandler,
    CallTypes.atext_completion: OpenAITextCompletionHandler,
}

__all__ = ["guardrail_translation_mappings", "OpenAITextCompletionHandler"]
|
||||
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
OpenAI Text Completion Handler for Unified Guardrails
|
||||
|
||||
This module provides guardrail translation support for OpenAI's text completion endpoint.
|
||||
The handler processes the 'prompt' parameter for guardrails.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.types.utils import TextCompletionResponse
|
||||
|
||||
|
||||
class OpenAITextCompletionHandler(BaseTranslation):
    """
    Handler for processing OpenAI text completion requests with guardrails.

    This class provides methods to:
    1. Process input prompt (pre-call hook)
    2. Process output response (post-call hook)

    The handler specifically processes the 'prompt' parameter which can be:
    - A single string
    - A list of strings (for batch completions)

    Non-string prompt entries (e.g. token lists) are passed through unchanged.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input prompt by applying guardrails to text content.

        Mutates ``data["prompt"]`` in place with the guardrailed text(s).

        Args:
            data: Request data dictionary containing 'prompt' parameter
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional litellm logging object, forwarded to
                the guardrail call

        Returns:
            Modified data with guardrails applied to prompt
        """
        prompt = data.get("prompt")
        if prompt is None:
            verbose_proxy_logger.debug(
                "OpenAI Text Completion: No prompt found in request data"
            )
            return data

        if isinstance(prompt, str):
            # Single string prompt
            inputs = GenericGuardrailAPIInputs(texts=[prompt])
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            # Fall back to the original prompt if the guardrail returned no texts.
            data["prompt"] = guardrailed_texts[0] if guardrailed_texts else prompt

            verbose_proxy_logger.debug(
                "OpenAI Text Completion: Applied guardrail to string prompt. "
                "Original length: %d, New length: %d",
                len(prompt),
                len(data["prompt"]),
            )

        elif isinstance(prompt, list):
            # List of string prompts (batch completion)
            texts_to_check = []
            text_indices = []  # Track which prompts are strings

            # Only string entries are guardrailed; non-string entries
            # (e.g. token lists) are skipped and left in place.
            for idx, p in enumerate(prompt):
                if isinstance(p, str):
                    texts_to_check.append(p)
                    text_indices.append(idx)

            if texts_to_check:
                inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
                # Include model information if available
                model = data.get("model")
                if model:
                    inputs["model"] = model
                guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                    inputs=inputs,
                    request_data=data,
                    input_type="request",
                    logging_obj=litellm_logging_obj,
                )
                guardrailed_texts = guardrailed_inputs.get("texts", [])

                # Replace guardrailed texts back at their original positions:
                # guardrail_idx indexes the guardrailed batch, prompt_idx the
                # position in the original prompt list.
                for guardrail_idx, prompt_idx in enumerate(text_indices):
                    if guardrail_idx < len(guardrailed_texts):
                        data["prompt"][prompt_idx] = guardrailed_texts[guardrail_idx]
                        verbose_proxy_logger.debug(
                            "OpenAI Text Completion: Applied guardrail to prompt[%d]. "
                            "Original length: %d, New length: %d",
                            prompt_idx,
                            len(texts_to_check[guardrail_idx]),
                            len(guardrailed_texts[guardrail_idx]),
                        )

        else:
            verbose_proxy_logger.warning(
                "OpenAI Text Completion: Unexpected prompt type: %s. Expected string or list.",
                type(prompt),
            )

        return data

    async def process_output_response(
        self,
        response: "TextCompletionResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response by applying guardrails to completion text.

        Mutates ``response.choices[*].text`` in place with the guardrailed
        texts.

        Args:
            response: Text completion response object
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata to pass to guardrails

        Returns:
            Modified response with guardrails applied to completion text
        """
        if not hasattr(response, "choices") or not response.choices:
            verbose_proxy_logger.debug(
                "OpenAI Text Completion: No choices in response to process"
            )
            return response

        # Collect all texts to check
        texts_to_check = []
        choice_indices = []  # positions of choices that carry a string `text`

        for idx, choice in enumerate(response.choices):
            if hasattr(choice, "text") and isinstance(choice.text, str):
                texts_to_check.append(choice.text)
                choice_indices.append(idx)

        # Apply guardrails in batch (one call for all choice texts)
        if texts_to_check:
            # Create a request_data dict with response info and user API key metadata
            request_data: dict = {"response": response}

            # Add user API key metadata with prefixed keys
            user_metadata = self.transform_user_api_key_dict_to_metadata(
                user_api_key_dict
            )
            if user_metadata:
                request_data["litellm_metadata"] = user_metadata

            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            # Include model information from the response if available
            if hasattr(response, "model") and response.model:
                inputs["model"] = response.model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=request_data,
                input_type="response",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])

            # Apply guardrailed texts back to choices at their original indices
            for guardrail_idx, choice_idx in enumerate(choice_indices):
                if guardrail_idx < len(guardrailed_texts):
                    original_text = response.choices[choice_idx].text
                    response.choices[choice_idx].text = guardrailed_texts[guardrail_idx]

                    verbose_proxy_logger.debug(
                        "OpenAI Text Completion: Applied guardrail to choice[%d] text. "
                        "Original length: %d, New length: %d",
                        choice_idx,
                        len(original_text),
                        len(guardrailed_texts[guardrail_idx]),
                    )

        return response
|
||||
@@ -0,0 +1,318 @@
|
||||
import json
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
|
||||
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
|
||||
from litellm.utils import ProviderConfigManager
|
||||
|
||||
from ..common_utils import BaseOpenAILLM, OpenAIError
|
||||
from .transformation import OpenAITextCompletionConfig
|
||||
|
||||
|
||||
class OpenAITextCompletion(BaseLLM):
    """
    Client for OpenAI's text completion endpoint (`/v1/completions`).

    Thin wrapper over the `openai` SDK that adds litellm pre/post-call
    logging, translates SDK exceptions into `OpenAIError`, and wraps
    streaming responses in `CustomStreamWrapper`. Supports all four
    variants: sync/async x streaming/non-streaming.
    """

    openai_text_completion_global_config = OpenAITextCompletionConfig()

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _raise_openai_error(e: Exception) -> None:
        """
        Translate any exception into an `OpenAIError` and raise it.

        Reads `status_code`, `headers`, `text`, and `response` off the
        exception when present (openai SDK errors carry these attributes);
        falls back to the response object's headers and a 500 status.
        This method never returns normally — it always raises.
        """
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_text = getattr(e, "text", str(e))
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response:
            error_headers = getattr(error_response, "headers", None)
        raise OpenAIError(
            status_code=status_code, message=error_text, headers=error_headers
        )

    def validate_environment(self, api_key):
        """Build base request headers; adds Bearer auth when `api_key` is set."""
        headers = {
            "content-type": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def completion(
        self,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        timeout: float,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        print_verbose: Optional[Callable] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        litellm_params=None,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ):
        """
        Execute a text completion request.

        Dispatches based on `acompletion` and `optional_params["stream"]`:
        returns a coroutine (async paths), a generator of chunks (sync
        streaming), or a `TextCompletionResponse` (sync non-streaming).

        Raises:
            OpenAIError: on missing model/messages or any SDK failure.
        """
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")

            # don't send max retries to the api, if set

            provider_config = ProviderConfigManager.get_provider_text_completion_config(
                model=model,
                provider=LlmProviders(custom_llm_provider),
            )

            data = provider_config.transform_text_completion_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                headers=headers,
            )
            # `max_retries` is an SDK/client option, not an API parameter.
            max_retries = data.pop("max_retries", 2)
            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )
            if acompletion is True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        api_key=api_key,
                        data=data,
                        headers=headers,
                        model_response=model_response,
                        model=model,
                        timeout=timeout,
                        max_retries=max_retries,
                        client=client,
                        organization=organization,
                    )
                else:
                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
            elif optional_params.get("stream", False):
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    api_key=api_key,
                    data=data,
                    headers=headers,
                    model_response=model_response,
                    model=model,
                    timeout=timeout,
                    max_retries=max_retries,  # type: ignore
                    client=client,
                    organization=organization,
                )
            else:
                # Sync, non-streaming path.
                if client is None:
                    openai_client = OpenAI(
                        api_key=api_key,
                        base_url=api_base,
                        http_client=litellm.client_session,
                        timeout=timeout,
                        max_retries=max_retries,  # type: ignore
                        organization=organization,
                    )
                else:
                    openai_client = client

                raw_response = openai_client.completions.with_raw_response.create(**data)  # type: ignore
                response = raw_response.parse()
                response_json = response.model_dump()

                ## LOGGING
                logging_obj.post_call(
                    api_key=api_key,
                    original_response=response_json,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )

                ## RESPONSE OBJECT
                return TextCompletionResponse(**response_json)
        except Exception as e:
            self._raise_openai_error(e)

    async def acompletion(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        timeout: float,
        max_retries: int,
        organization: Optional[str] = None,
        client=None,
    ):
        """
        Async, non-streaming text completion.

        Returns:
            TextCompletionResponse with `_hidden_params.original_response`
            set to the raw JSON-serialized provider response.

        Raises:
            OpenAIError: on any SDK failure.
        """
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=BaseOpenAILLM._get_async_http_client(),
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                )
            else:
                openai_aclient = client

            raw_response = await openai_aclient.completions.with_raw_response.create(
                **data
            )
            response = raw_response.parse()
            response_json = response.model_dump()

            ## LOGGING
            # NOTE(review): the sync path logs `response_json` here, while this
            # path logs the parsed object — confirm before unifying.
            logging_obj.post_call(
                api_key=api_key,
                original_response=response,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                },
            )
            ## RESPONSE OBJECT
            response_obj = TextCompletionResponse(**response_json)
            response_obj._hidden_params.original_response = json.dumps(response_json)
            return response_obj
        except Exception as e:
            self._raise_openai_error(e)

    def streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        api_base: Optional[str] = None,
        max_retries=None,
        client=None,
        organization=None,
    ):
        """
        Sync streaming text completion.

        Yields:
            Chunks from a `CustomStreamWrapper` around the SDK stream.

        Raises:
            OpenAIError: on request creation or mid-stream failure.
        """
        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,  # type: ignore
                organization=organization,
            )
        else:
            openai_client = client

        try:
            raw_response = openai_client.completions.with_raw_response.create(**data)
            response = raw_response.parse()
        except Exception as e:
            self._raise_openai_error(e)
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        try:
            for chunk in streamwrapper:
                yield chunk
        except Exception as e:
            # Errors raised while iterating the stream are translated too.
            self._raise_openai_error(e)

    async def async_streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        max_retries: int,
        api_base: Optional[str] = None,
        client=None,
        organization=None,
    ):
        """
        Async streaming text completion.

        Yields:
            Transformed chunks from a `CustomStreamWrapper` around the
            SDK's async stream.

        Raises:
            OpenAIError: on mid-stream failure.
        """
        if client is None:
            openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
            )
        else:
            openai_client = client

        raw_response = await openai_client.completions.with_raw_response.create(**data)
        response = raw_response.parse()
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        try:
            async for transformed_chunk in streamwrapper:
                yield transformed_chunk
        except Exception as e:
            self._raise_openai_error(e)
|
||||
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Support for gpt model family
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
|
||||
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
|
||||
from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
|
||||
|
||||
from ..chat.gpt_transformation import OpenAIGPTConfig
|
||||
from .utils import _transform_prompt
|
||||
|
||||
|
||||
class OpenAITextCompletionConfig(BaseTextCompletionConfig, OpenAIGPTConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/completions/create

    The class `OpenAITextCompletionConfig` provides configuration for the OpenAI's text completion API interface. Below are the parameters:

    - `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.

    - `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.

    - `frequency_penalty` (number or null): Defaults to 0. It is a number from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.

    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.

    - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.

    - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.

    - `temperature` (number or null): This optional parameter defines the sampling temperature to use.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    best_of: Optional[int] = None
    echo: Optional[bool] = None
    # OpenAI documents the penalty params as numbers in [-2.0, 2.0]
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[dict] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    suffix: Optional[str] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        logprobs: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # NOTE: stores values on the *class*, not the instance, so
                # they are shared across all instances — presumably the
                # litellm config convention; confirm before changing.
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the config dict for this class (delegates to the base)."""
        return super().get_config()

    def convert_to_chat_model_response_object(
        self,
        response_object: Optional[TextCompletionResponse] = None,
        model_response_object: Optional[ModelResponse] = None,
    ) -> ModelResponse:
        """
        Convert a text-completion response into a chat-style ModelResponse.

        Each text choice becomes an assistant `Message`; usage, id, and model
        are copied over when present, and the original response is stashed in
        `_hidden_params["original_response"]`.

        Raises:
            ValueError: if either argument is None.
        """
        try:
            ## RESPONSE OBJECT
            if response_object is None or model_response_object is None:
                raise ValueError("Error in response object format")
            choice_list: List[Choices] = []
            for idx, choice in enumerate(response_object["choices"]):
                message = Message(
                    content=choice["text"],
                    role="assistant",
                )
                # `choice` is rebound here from the source dict to the new
                # chat-style Choices object.
                choice = Choices(
                    finish_reason=choice["finish_reason"],
                    index=idx,
                    message=message,
                    logprobs=choice.get("logprobs", None),
                )
                choice_list.append(choice)
            model_response_object.choices = choice_list  # type: ignore

            if "usage" in response_object:
                setattr(model_response_object, "usage", response_object["usage"])

            if "id" in response_object:
                model_response_object.id = response_object["id"]

            if "model" in response_object:
                model_response_object.model = response_object["model"]

            model_response_object._hidden_params[
                "original_response"
            ] = response_object  # track original response, if users make a litellm.text_completion() request, we can return the original response
            return model_response_object
        except Exception as e:
            raise e

    def get_supported_openai_params(self, model: str) -> List:
        """
        Return the OpenAI params this config accepts.

        NOTE(review): includes chat-only params (functions, tools,
        response_format) — presumably filtered/mapped upstream before the
        completions call; confirm against the caller.
        """
        return [
            "functions",
            "function_call",
            "temperature",
            "top_p",
            "n",
            "stream",
            "stream_options",
            "stop",
            "max_tokens",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "response_format",
            "seed",
            "tools",
            "tool_choice",
            "max_retries",
            "logprobs",
            "top_logprobs",
            "extra_headers",
        ]

    def transform_text_completion_request(
        self,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the `/v1/completions` request body: flattens chat-style
        messages into a `prompt` and merges in the optional params.
        """
        prompt = _transform_prompt(messages)
        return {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import List, Union, cast
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
AllPromptValues,
|
||||
OpenAITextCompletionUserMessage,
|
||||
)
|
||||
|
||||
|
||||
def is_tokens_or_list_of_tokens(value: List):
    """
    Check whether *value* is a token-based prompt.

    Returns True when value is a list of ints (a single tokenized prompt)
    or a list of lists of ints (a batch of tokenized prompts); False for
    anything else. An empty list counts as a token prompt.
    """

    def _all_ints(items) -> bool:
        return all(isinstance(element, int) for element in items)

    if not isinstance(value, list):
        return False
    # Single tokenized prompt, e.g. [1, 2, 3]
    if _all_ints(value):
        return True
    # Batch of tokenized prompts, e.g. [[1, 2], [3]]
    return all(isinstance(entry, list) and _all_ints(entry) for entry in value)
|
||||
|
||||
|
||||
def _transform_prompt(
    messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
) -> AllPromptValues:
    """
    Convert chat-style messages into an OpenAI text-completion `prompt` value.

    - Single message with token content (list of ints / list of lists of
      ints): passed through unchanged.
    - Single message with text content: flattened to a string.
    - Multiple messages: each message's content flattened to a string,
      producing a batch prompt (list of strings). Token content is only
      expected in single-message arrays.

    Raises:
        Exception: propagated from `convert_content_list_to_str` if a
        message's content cannot be flattened.
    """
    if len(messages) == 1:  # base case: single message
        message_content = messages[0].get("content")
        if (
            message_content
            and isinstance(message_content, list)
            and is_tokens_or_list_of_tokens(message_content)
        ):
            # Token-based prompt: forward as-is.
            openai_prompt: AllPromptValues = cast(AllPromptValues, message_content)
        else:
            openai_prompt = convert_content_list_to_str(
                cast(AllMessageValues, messages[0])
            )
    else:
        # Batch prompt: one string per message. (The original wrapped each
        # conversion in a no-op try/except that re-raised the same exception;
        # removed — errors propagate identically without it.)
        prompt_str_list: List[str] = [
            convert_content_list_to_str(cast(AllMessageValues, m)) for m in messages
        ]
        openai_prompt = prompt_str_list
    return openai_prompt
|
||||
Reference in New Issue
Block a user