Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/chatgpt/chat/transformation.py

80 lines
2.5 KiB
Python
Raw Normal View History

from typing import Any, List, Optional, Tuple
from litellm.exceptions import AuthenticationError
from litellm.llms.openai.openai import OpenAIConfig
from litellm.types.llms.openai import AllMessageValues
from ..authenticator import Authenticator
from ..common_utils import (
GetAccessTokenError,
ensure_chatgpt_session_id,
get_chatgpt_default_headers,
)
from .streaming_utils import ChatGPTToolCallNormalizer
class ChatGPTConfig(OpenAIConfig):
    """OpenAI-style chat config whose credentials come from an ``Authenticator``.

    The API base, access token and account id are all obtained from
    ``self.authenticator`` instead of from caller-supplied arguments.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        custom_llm_provider: str = "openai",
    ) -> None:
        """Initialise the base config and create the authenticator.

        The ``api_key``, ``api_base`` and ``custom_llm_provider`` arguments
        are accepted but not used by this constructor.
        """
        super().__init__()
        self.authenticator = Authenticator()

    def _get_openai_compatible_provider_info(
        self,
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        custom_llm_provider: str,
    ) -> Tuple[Optional[str], Optional[str], str]:
        """Return ``(api_base, api_key, custom_llm_provider)`` for a request.

        The incoming ``api_base`` and ``api_key`` are ignored; both values are
        taken from the authenticator. ``custom_llm_provider`` is returned
        unchanged.

        Raises:
            AuthenticationError: if the authenticator raises
                ``GetAccessTokenError`` while fetching the access token.
        """
        dynamic_api_base = self.authenticator.get_api_base()
        try:
            dynamic_api_key = self.authenticator.get_access_token()
        except GetAccessTokenError as e:
            # Chain explicitly so the traceback shows the token failure as the
            # direct cause of the authentication error.
            raise AuthenticationError(
                model=model,
                llm_provider=custom_llm_provider,
                message=str(e),
            ) from e
        return dynamic_api_base, dynamic_api_key, custom_llm_provider

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Build the request headers.

        Runs the parent validation first, then merges its result over the
        headers returned by ``get_chatgpt_default_headers``.

        Returns:
            A new dict with the default headers overlaid by the parent-validated
            headers (the latter win on key collisions).
        """
        validated_headers = super().validate_environment(
            headers, model, messages, optional_params, litellm_params, api_key, api_base
        )
        account_id = self.authenticator.get_account_id()
        # NOTE(review): presumably creates/records a session id on
        # litellm_params when absent -- confirm in common_utils.
        session_id = ensure_chatgpt_session_id(litellm_params)
        # A missing (or empty) api_key is passed to the helper as "".
        default_headers = get_chatgpt_default_headers(
            api_key or "", account_id, session_id
        )
        # Later unpacking wins: validated headers override the defaults.
        return {**default_headers, **validated_headers}

    def post_stream_processing(self, stream: Any) -> Any:
        """Wrap ``stream`` in a ``ChatGPTToolCallNormalizer`` and return it."""
        return ChatGPTToolCallNormalizer(stream)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map params via the parent, then make sure a ``"stream"`` key exists.

        Returns:
            The dict produced by the parent mapping, with ``"stream"`` set to
            ``False`` only when the key was not already present.
        """
        optional_params = super().map_openai_params(
            non_default_params, optional_params, model, drop_params
        )
        optional_params.setdefault("stream", False)
        return optional_params