chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Translate from OpenAI's `/v1/chat/completions` to Perplexity's `/v1/chat/completions`
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Usage, PromptTokensDetailsWrapper
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.types.llms.openai import ChatCompletionAnnotation
|
||||
from litellm.types.llms.openai import ChatCompletionAnnotationURLCitation
|
||||
|
||||
|
||||
class PerplexityChatConfig(OpenAIGPTConfig):
    """OpenAI-compatible chat transformation config for Perplexity's
    `/v1/chat/completions` endpoint."""

    @property
    def custom_llm_provider(self) -> Optional[str]:
        # Provider key used by litellm capability lookups (supports_reasoning, etc.).
        return "perplexity"
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
api_base = api_base or get_secret_str("PERPLEXITY_API_BASE") or "https://api.perplexity.ai" # type: ignore
|
||||
dynamic_api_key = (
|
||||
api_key
|
||||
or get_secret_str("PERPLEXITYAI_API_KEY")
|
||||
or get_secret_str("PERPLEXITY_API_KEY")
|
||||
)
|
||||
return api_base, dynamic_api_key
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
Perplexity supports a subset of OpenAI params
|
||||
|
||||
Ref: https://docs.perplexity.ai/api-reference/chat-completions
|
||||
|
||||
Eg. Perplexity does not support tools, tool_choice, function_call, functions, etc.
|
||||
"""
|
||||
base_openai_params = [
|
||||
"frequency_penalty",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"presence_penalty",
|
||||
"response_format",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_retries",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
try:
|
||||
if litellm.supports_reasoning(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
):
|
||||
base_openai_params.append("reasoning_effort")
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error checking if model supports reasoning: {e}")
|
||||
|
||||
try:
|
||||
if litellm.supports_web_search(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
):
|
||||
base_openai_params.append("web_search_options")
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error checking if model supports web search: {e}")
|
||||
|
||||
return base_openai_params
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
# Call the parent transform_response first to handle the standard transformation
|
||||
model_response = super().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
# Extract and enhance usage with Perplexity-specific fields
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
self._enhance_usage_with_perplexity_fields(
|
||||
model_response, raw_response_json
|
||||
)
|
||||
self._add_citations_as_annotations(model_response, raw_response_json)
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Error extracting Perplexity-specific usage fields: {e}"
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
||||
def _enhance_usage_with_perplexity_fields(
|
||||
self, model_response: ModelResponse, raw_response_json: dict
|
||||
) -> None:
|
||||
"""
|
||||
Extract citation tokens and search queries from Perplexity API response
|
||||
and add them to the usage object using standard LiteLLM fields.
|
||||
"""
|
||||
if not hasattr(model_response, "usage") or model_response.usage is None:
|
||||
# Create a usage object if it doesn't exist (when usage was None)
|
||||
model_response.usage = Usage( # type: ignore[attr-defined]
|
||||
prompt_tokens=0, completion_tokens=0, total_tokens=0
|
||||
)
|
||||
|
||||
usage = model_response.usage # type: ignore[attr-defined]
|
||||
|
||||
# Extract citation tokens count
|
||||
citations = raw_response_json.get("citations", [])
|
||||
citation_tokens = 0
|
||||
if citations:
|
||||
# Count total characters in citations as a proxy for citation tokens
|
||||
# This is an estimation - in practice, you might want to use proper tokenization
|
||||
total_citation_chars = sum(
|
||||
len(str(citation)) for citation in citations if citation
|
||||
)
|
||||
# Rough estimation: ~4 characters per token (OpenAI's general rule)
|
||||
if total_citation_chars > 0:
|
||||
citation_tokens = max(1, total_citation_chars // 4)
|
||||
|
||||
# Extract search queries count from usage or response metadata
|
||||
# Perplexity might include this in the usage object or as separate metadata
|
||||
perplexity_usage = raw_response_json.get("usage", {})
|
||||
|
||||
# Try to extract search queries from usage field first, then root level
|
||||
num_search_queries = perplexity_usage.get("num_search_queries")
|
||||
if num_search_queries is None:
|
||||
num_search_queries = raw_response_json.get("num_search_queries")
|
||||
if num_search_queries is None:
|
||||
num_search_queries = perplexity_usage.get("search_queries")
|
||||
if num_search_queries is None:
|
||||
num_search_queries = raw_response_json.get("search_queries")
|
||||
|
||||
# Create or update prompt_tokens_details to include web search requests and citation tokens
|
||||
if citation_tokens > 0 or (
|
||||
num_search_queries is not None and num_search_queries > 0
|
||||
):
|
||||
if usage.prompt_tokens_details is None:
|
||||
usage.prompt_tokens_details = PromptTokensDetailsWrapper()
|
||||
|
||||
# Store citation tokens count for cost calculation
|
||||
if citation_tokens > 0:
|
||||
setattr(usage, "citation_tokens", citation_tokens)
|
||||
|
||||
# Store search queries count in the standard web_search_requests field
|
||||
if num_search_queries is not None and num_search_queries > 0:
|
||||
usage.prompt_tokens_details.web_search_requests = num_search_queries
|
||||
|
||||
    def _add_citations_as_annotations(
        self, model_response: ModelResponse, raw_response_json: dict
    ) -> None:
        """
        Extract citations and search_results from Perplexity API response
        and add them as ChatCompletionAnnotation objects to the message.

        Annotations are anchored to literal `[n]` citation markers found in the
        message text; citations whose number never appears in the text produce
        no annotation. The raw `citations`/`search_results` lists are also
        attached to the response object for backward compatibility.
        """
        if not model_response.choices:
            return

        # Get the first choice (assuming single response)
        choice = model_response.choices[0]
        if not hasattr(choice, "message") or choice.message is None:
            return

        message = choice.message
        annotations = []

        # Extract citations from the response
        citations = raw_response_json.get("citations", [])
        search_results = raw_response_json.get("search_results", [])

        # Create a mapping of URLs to search result titles
        url_to_title = {}
        for result in search_results:
            if isinstance(result, dict) and "url" in result and "title" in result:
                url_to_title[result["url"]] = result["title"]

        # Get the message content to find citation positions
        content = getattr(message, "content", "")
        if not content:
            return

        # Find all citation markers like [1], [2], [3], [4] in the text
        import re

        citation_pattern = r"\[(\d+)\]"
        citation_matches = list(re.finditer(citation_pattern, content))

        # Create a mapping of citation numbers to URLs
        # (Perplexity's citations list is ordered; marker [1] refers to index 0.)
        citation_number_to_url = {}
        for i, citation in enumerate(citations):
            if isinstance(citation, str):
                citation_number_to_url[i + 1] = citation  # 1-indexed

        # Create annotations for each citation match found in the text
        for match in citation_matches:
            citation_number = int(match.group(1))
            if citation_number in citation_number_to_url:
                url = citation_number_to_url[citation_number]
                title = url_to_title.get(url, "")

                # Create the URL citation annotation with actual text positions
                # (start/end indices span the "[n]" marker itself).
                url_citation: ChatCompletionAnnotationURLCitation = {
                    "url": url,
                    "title": title,
                    "start_index": match.start(),
                    "end_index": match.end(),
                }

                annotation: ChatCompletionAnnotation = {
                    "type": "url_citation",
                    "url_citation": url_citation,
                }

                annotations.append(annotation)

        # Add annotations to the message if we have any
        if annotations:
            if not hasattr(message, "annotations") or message.annotations is None:
                message.annotations = []
            message.annotations.extend(annotations)

        # Also add the raw citations and search_results as attributes for backward compatibility
        if citations:
            setattr(model_response, "citations", citations)
        if search_results:
            setattr(model_response, "search_results", search_results)
|
||||
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Helper util for handling perplexity-specific cost calculation
|
||||
- e.g.: citation tokens, search queries
|
||||
"""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
from litellm.types.utils import Usage
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
|
||||
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - usage: LiteLLM Usage block, containing perplexity-specific usage information

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    ## USE PRE-CALCULATED COST FROM PERPLEXITY IF AVAILABLE
    ## Perplexity returns accurate cost in usage.cost.total_cost including request fees
    cost_info = getattr(usage, "cost", None)
    if cost_info is not None and isinstance(cost_info, dict):
        total_cost = cost_info.get("total_cost")
        if total_cost is not None:
            # Return total cost as completion_cost (prompt_cost=0) since Perplexity
            # doesn't break down by input/output in their cost object
            return (0.0, float(total_cost))

    ## FALLBACK: Calculate cost manually if Perplexity doesn't provide it
    ## GET MODEL INFO
    model_info = get_model_info(model=model, custom_llm_provider="perplexity")

    def _safe_float_cast(
        value: Union[str, int, float, None, object], default: float = 0.0
    ) -> float:
        """Safely cast a value to float with proper type handling for mypy."""
        if value is None:
            return default
        try:
            return float(value)  # type: ignore
        except (ValueError, TypeError):
            return default

    ## CALCULATE INPUT COST
    input_cost_per_token = _safe_float_cast(model_info.get("input_cost_per_token"))
    prompt_cost: float = (usage.prompt_tokens or 0) * input_cost_per_token

    ## ADD CITATION TOKENS COST (if present)
    # citation_tokens is set by the chat transformation's usage enhancement.
    citation_tokens = getattr(usage, "citation_tokens", 0) or 0
    citation_cost_value = model_info.get("citation_cost_per_token")
    if citation_tokens > 0 and citation_cost_value is not None:
        citation_cost_per_token = _safe_float_cast(citation_cost_value)
        prompt_cost += citation_tokens * citation_cost_per_token

    ## CALCULATE OUTPUT COST
    output_cost_per_token = _safe_float_cast(model_info.get("output_cost_per_token"))
    completion_cost: float = (usage.completion_tokens or 0) * output_cost_per_token

    ## ADD REASONING TOKENS COST (if present)
    reasoning_tokens = getattr(usage, "reasoning_tokens", 0) or 0
    # Also check completion_tokens_details if reasoning_tokens is not directly available
    if (
        reasoning_tokens == 0
        and hasattr(usage, "completion_tokens_details")
        and usage.completion_tokens_details
    ):
        reasoning_tokens = (
            getattr(usage.completion_tokens_details, "reasoning_tokens", 0) or 0
        )

    reasoning_cost_value = model_info.get("output_cost_per_reasoning_token")
    if reasoning_tokens > 0 and reasoning_cost_value is not None:
        reasoning_cost_per_token = _safe_float_cast(reasoning_cost_value)
        completion_cost += reasoning_tokens * reasoning_cost_per_token

    ## ADD SEARCH QUERIES COST (if present)
    num_search_queries = 0
    if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
        num_search_queries = (
            getattr(usage.prompt_tokens_details, "web_search_requests", 0) or 0
        )

    # Check both possible keys for search cost (legacy and current)
    search_cost_value = model_info.get(
        "search_queries_cost_per_query"
    ) or model_info.get("search_context_cost_per_query")
    if num_search_queries > 0 and search_cost_value is not None:
        # Handle both dict and float formats
        if isinstance(search_cost_value, dict):
            # Use the "low" size as default - tests expect 0.005 / 1000
            # NOTE(review): the /1000 converts a per-1k-request price to
            # per-request — confirm against the pricing table's units.
            search_cost_per_query = (
                _safe_float_cast(search_cost_value.get("search_context_size_low", 0))
                / 1000
            )
        else:
            search_cost_per_query = _safe_float_cast(search_cost_value)
        search_cost = num_search_queries * search_cost_per_query
        # Add search cost to completion cost (similar to how other providers handle it)
        completion_cost += search_cost

    return prompt_cost, completion_cost
|
||||
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Perplexity AI Embedding API
|
||||
|
||||
Docs: https://docs.perplexity.ai/api-reference/embeddings-post
|
||||
|
||||
Supports models:
|
||||
- pplx-embed-v1-0.6b (1024 dims, 32 K context)
|
||||
- pplx-embed-v1-4b (2560 dims, 32 K context)
|
||||
|
||||
Perplexity returns embeddings as base64-encoded signed int8 values by default.
|
||||
This module decodes them into float arrays for OpenAI-compatible responses.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import struct
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||
from litellm.types.utils import EmbeddingResponse, Usage
|
||||
|
||||
|
||||
class PerplexityEmbeddingError(BaseLLMException):
    """Error raised for Perplexity embedding API failures.

    Builds a synthetic httpx request/response pair so downstream error
    handling always has `.request` / `.response` to inspect.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.perplexity.ai/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            # Was `headers: ... = {}` — a shared mutable default argument;
            # use None as the sentinel and substitute a fresh empty dict.
            headers=headers if headers is not None else {},
        )
|
||||
|
||||
|
||||
class PerplexityEmbeddingConfig(BaseEmbeddingConfig):
    """
    Transformation config for Perplexity's embeddings endpoint.

    Reference: https://docs.perplexity.ai/api-reference/embeddings-post
    """

    def __init__(self) -> None:
        # No per-instance configuration is needed.
        pass
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
if api_base:
|
||||
if not api_base.endswith("/embeddings"):
|
||||
api_base = f"{api_base}/v1/embeddings"
|
||||
return api_base
|
||||
return "https://api.perplexity.ai/v1/embeddings"
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
return [
|
||||
"dimensions",
|
||||
"encoding_format",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for k, v in non_default_params.items():
|
||||
if k == "dimensions":
|
||||
optional_params["dimensions"] = v
|
||||
elif k == "encoding_format":
|
||||
optional_params["encoding_format"] = v
|
||||
return optional_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is None:
|
||||
api_key = get_secret_str("PERPLEXITYAI_API_KEY") or get_secret_str(
|
||||
"PERPLEXITY_API_KEY"
|
||||
)
|
||||
return {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def transform_embedding_request(
|
||||
self,
|
||||
model: str,
|
||||
input: AllEmbeddingInputValues,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
return {
|
||||
"model": model,
|
||||
"input": input,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_embedding(embedding_value: Any) -> List[float]:
|
||||
"""
|
||||
Decode a Perplexity embedding into a list of floats.
|
||||
|
||||
Perplexity returns base64-encoded signed int8 values by default.
|
||||
If the value is already a list of numbers (e.g. from a mock or
|
||||
future float format), it is returned as-is.
|
||||
"""
|
||||
if isinstance(embedding_value, list):
|
||||
return embedding_value
|
||||
if isinstance(embedding_value, str):
|
||||
raw_bytes = base64.b64decode(embedding_value)
|
||||
count = len(raw_bytes)
|
||||
int8_values = struct.unpack(f"{count}b", raw_bytes)
|
||||
return [float(v) / 127.0 for v in int8_values]
|
||||
return embedding_value
|
||||
|
||||
    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        """Parse the raw embeddings response, decoding each base64 int8
        embedding into floats, and populate the OpenAI-compatible response.

        Raises:
            PerplexityEmbeddingError: if the response body is not valid JSON.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise PerplexityEmbeddingError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        model_response.model = raw_response_json.get("model", model)
        model_response.object = raw_response_json.get("object", "list")

        # Decode every item's embedding in place (copying each item dict so
        # the raw payload is not mutated).
        raw_data = raw_response_json.get("data", [])
        decoded_data: List[Dict[str, Any]] = []
        for item in raw_data:
            decoded_item = dict(item)
            decoded_item["embedding"] = self._decode_base64_embedding(
                item.get("embedding")
            )
            decoded_data.append(decoded_item)
        model_response.data = decoded_data

        usage_data = raw_response_json.get("usage", {})
        usage = Usage(
            # Fall back to total_tokens when prompt_tokens is missing/zero —
            # for embeddings all tokens are prompt tokens.
            prompt_tokens=usage_data.get("prompt_tokens", 0)
            or usage_data.get("total_tokens", 0),
            total_tokens=usage_data.get("total_tokens", 0),
        )
        model_response.usage = usage
        return model_response
|
||||
|
||||
def get_error_class(
|
||||
self,
|
||||
error_message: str,
|
||||
status_code: int,
|
||||
headers: Union[dict, httpx.Headers],
|
||||
) -> BaseLLMException:
|
||||
return PerplexityEmbeddingError(
|
||||
message=error_message, status_code=status_code, headers=headers
|
||||
)
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Perplexity Agent API (Responses API) module
|
||||
"""
|
||||
|
||||
from .transformation import PerplexityResponsesConfig
|
||||
|
||||
__all__ = ["PerplexityResponsesConfig"]
|
||||
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Perplexity Responses API — OpenAI-compatible.
|
||||
|
||||
The only provider quirks:
|
||||
- cost returned as dict → handled by ResponseAPIUsage.parse_cost validator
|
||||
- preset models (preset/pro-search) → handled by transform_responses_api_request
|
||||
- HTTP 200 with status:"failed" → raised as exception in transform_response_api_response
|
||||
|
||||
Ref: https://docs.perplexity.ai/api-reference/responses-post
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import ResponseInputParam, ResponsesAPIResponse
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
|
||||
class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
    # OpenAI-compatible Responses API config for Perplexity; only provider
    # quirks (preset models, failed-status-on-200) are overridden here.

    def get_supported_openai_params(self, model: str) -> list:
        """Ref: https://docs.perplexity.ai/api-reference/responses-post"""
        return [
            "max_output_tokens",
            "stream",
            "temperature",
            "top_p",
            "tools",
            "reasoning",
            "instructions",
            "models",
        ]
|
||||
|
||||
    @property
    def custom_llm_provider(self) -> LlmProviders:
        # Identifies this config as the Perplexity provider within litellm routing.
        return LlmProviders.PERPLEXITY
|
||||
|
||||
def validate_environment(
|
||||
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
api_key = (
|
||||
litellm_params.api_key
|
||||
or get_secret_str("PERPLEXITYAI_API_KEY")
|
||||
or get_secret_str("PERPLEXITY_API_KEY")
|
||||
)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], litellm_params: dict) -> str:
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("PERPLEXITY_API_BASE")
|
||||
or "https://api.perplexity.ai"
|
||||
)
|
||||
return f"{api_base.rstrip('/')}/v1/responses"
|
||||
|
||||
def _ensure_message_type(
|
||||
self, input: Union[str, ResponseInputParam]
|
||||
) -> Union[str, ResponseInputParam]:
|
||||
"""Ensure list input items have type='message' (required by Perplexity)."""
|
||||
if isinstance(input, str):
|
||||
return input
|
||||
if isinstance(input, list):
|
||||
result: List[Any] = []
|
||||
for item in input:
|
||||
if isinstance(item, dict) and "type" not in item:
|
||||
new_item = dict(item) # convert to plain dict to avoid TypedDict checking
|
||||
new_item["type"] = "message"
|
||||
result.append(new_item)
|
||||
else:
|
||||
result.append(item)
|
||||
return result
|
||||
return input
|
||||
|
||||
    def transform_responses_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Handle preset/ model prefix: send as {"preset": name} instead of {"model": name}.

        Non-preset models delegate to the OpenAI base transformation after
        input normalization.
        """
        input = self._ensure_message_type(input)
        if model.startswith("preset/"):
            # NOTE(review): _validate_input_param is not defined in this class;
            # presumably inherited from OpenAIResponsesAPIConfig — confirm.
            input = self._validate_input_param(input)
            data: Dict = {
                # Strip the "preset/" prefix; the remainder is the preset name.
                "preset": model[len("preset/") :],
                "input": input,
            }
            # Optional params ride alongside (and may override) the base keys.
            data.update(response_api_optional_request_params)
            return data
        return super().transform_responses_api_request(
            model=model,
            input=input,
            response_api_optional_request_params=response_api_optional_request_params,
            litellm_params=litellm_params,
            headers=headers,
        )
|
||||
|
||||
def transform_response_api_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIResponse:
|
||||
"""Check for Perplexity's status:'failed' on HTTP 200 before delegating to base."""
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raw_response_json = None
|
||||
|
||||
if (
|
||||
isinstance(raw_response_json, dict)
|
||||
and raw_response_json.get("status") == "failed"
|
||||
):
|
||||
error = raw_response_json.get("error", {})
|
||||
raise BaseLLMException(
|
||||
status_code=raw_response.status_code,
|
||||
message=error.get("message", "Unknown Perplexity error"),
|
||||
)
|
||||
|
||||
return super().transform_response_api_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
    def supports_native_websocket(self) -> bool:
        """Perplexity does not support native WebSocket for Responses API"""
        # Hard-coded capability flag; streaming happens over HTTP instead.
        return False
|
||||
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Calls Perplexity's /search endpoint to search the web.
|
||||
"""
|
||||
from typing import Dict, List, Optional, TypedDict, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.search.transformation import (
|
||||
BaseSearchConfig,
|
||||
SearchResponse,
|
||||
SearchResult,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class _PerplexitySearchRequestRequired(TypedDict):
    """Required fields for Perplexity Search API request."""

    # Split from PerplexitySearchRequest so `query` stays required while the
    # subclass marks everything else total=False (optional).
    query: Union[str, List[str]]  # Required - search query or queries
|
||||
|
||||
|
||||
class PerplexitySearchRequest(_PerplexitySearchRequestRequired, total=False):
    """
    Perplexity Search API request format.
    Based on: https://docs.perplexity.ai/api-reference/search-post
    """

    # total=False makes every field below optional; only the inherited
    # `query` field is required.
    max_results: int  # Optional - maximum number of results (1-20), default 10
    search_domain_filter: List[str]  # Optional - list of domains to filter (max 20)
    max_tokens_per_page: int  # Optional - max tokens per page, default 1024
    country: str  # Optional - country code filter (e.g., 'US', 'GB', 'DE')
|
||||
|
||||
|
||||
class PerplexitySearchConfig(BaseSearchConfig):
    """Config for calling Perplexity's /search endpoint to search the web."""

    # Default public API root; overridable via api_base / PERPLEXITY_API_BASE.
    PERPLEXITY_API_BASE = "https://api.perplexity.ai"

    @staticmethod
    def ui_friendly_name() -> str:
        # Display name used by the LiteLLM UI.
        return "Perplexity"
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: Dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""
|
||||
Validate environment and return headers.
|
||||
"""
|
||||
api_key = api_key or get_secret_str("PERPLEXITYAI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"PERPLEXITYAI_API_KEY is not set. Set `PERPLEXITYAI_API_KEY` environment variable."
|
||||
)
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
optional_params: dict,
|
||||
data: Optional[Union[Dict, List[Dict]]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Get complete URL for Search endpoint.
|
||||
"""
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("PERPLEXITY_API_BASE")
|
||||
or self.PERPLEXITY_API_BASE
|
||||
)
|
||||
|
||||
# append "/search" to the api base if it's not already there
|
||||
if not api_base.endswith("/search"):
|
||||
api_base = f"{api_base}/search"
|
||||
|
||||
return api_base
|
||||
|
||||
def transform_search_request(
|
||||
self,
|
||||
query: Union[str, List[str]],
|
||||
optional_params: dict,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""
|
||||
Transform Search request to Perplexity API format.
|
||||
|
||||
Note: LiteLLM's native spec is the perplexity search spec.
|
||||
|
||||
There's no transformation needed for the request data.
|
||||
|
||||
https://docs.perplexity.ai/api-reference/search-post
|
||||
|
||||
Args:
|
||||
query: Search query (string or list of strings)
|
||||
optional_params: Optional parameters for the request
|
||||
- max_results: Maximum number of search results (1-20)
|
||||
- search_domain_filter: List of domains to filter (max 20)
|
||||
- max_tokens_per_page: Max tokens per page (default 1024)
|
||||
- country: Country code filter (e.g., 'US', 'GB', 'DE')
|
||||
|
||||
Returns:
|
||||
Dict with typed request data following PerplexitySearchRequest spec
|
||||
"""
|
||||
request_data: PerplexitySearchRequest = {
|
||||
"query": query,
|
||||
}
|
||||
|
||||
# Add optional parameters following Perplexity API spec (only if not None)
|
||||
max_results = optional_params.get("max_results")
|
||||
if max_results is not None:
|
||||
request_data["max_results"] = max_results
|
||||
|
||||
search_domain_filter = optional_params.get("search_domain_filter")
|
||||
if search_domain_filter is not None:
|
||||
request_data["search_domain_filter"] = search_domain_filter
|
||||
|
||||
max_tokens_per_page = optional_params.get("max_tokens_per_page")
|
||||
if max_tokens_per_page is not None:
|
||||
request_data["max_tokens_per_page"] = max_tokens_per_page
|
||||
|
||||
country = optional_params.get("country")
|
||||
if country is not None:
|
||||
request_data["country"] = country
|
||||
|
||||
return dict(request_data)
|
||||
|
||||
def transform_search_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
**kwargs,
|
||||
) -> SearchResponse:
|
||||
"""
|
||||
Transform Perplexity API response to standard SearchResponse format.
|
||||
|
||||
Args:
|
||||
raw_response: Raw httpx response from Perplexity API
|
||||
logging_obj: Logging object for tracking
|
||||
|
||||
Returns:
|
||||
SearchResponse with standardized format
|
||||
"""
|
||||
response_json = raw_response.json()
|
||||
|
||||
# Transform results to SearchResult objects
|
||||
results = []
|
||||
for result in response_json.get("results", []):
|
||||
search_result = SearchResult(
|
||||
title=result.get("title", ""),
|
||||
url=result.get("url", ""),
|
||||
snippet=result.get("snippet", ""),
|
||||
date=result.get("date"),
|
||||
last_updated=result.get("last_updated"),
|
||||
)
|
||||
results.append(search_result)
|
||||
|
||||
return SearchResponse(
|
||||
results=results,
|
||||
object="search",
|
||||
)
|
||||
Reference in New Issue
Block a user