chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,198 @@
"""
Transformation logic for context caching.
Why separate file? Make it easy to see how transformation works
"""
import re
from typing import List, Optional, Tuple, Literal
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody
from litellm.utils import is_cached_message
from ..common_utils import get_supports_system_message
from ..gemini.transformation import (
_gemini_convert_messages_with_history,
_transform_system_message,
)
def get_first_continuous_block_idx(
    filtered_messages: List[Tuple[int, AllMessageValues]]  # (idx, message)
) -> int:
    """
    Return the array position where the first run of consecutive
    original-message indices ends.

    Args:
        filtered_messages: Ordered (original_index, message) pairs.

    Returns:
        int: Position of the last element of the first consecutive run,
        or -1 when the list is empty.
    """
    if not filtered_messages:
        return -1

    previous_idx = filtered_messages[0][0]
    for position in range(1, len(filtered_messages)):
        message_idx = filtered_messages[position][0]
        if message_idx != previous_idx + 1:
            # Gap in the original indices: the continuous run ended at the
            # previous element.
            return position - 1
        previous_idx = message_idx

    # The whole list is one continuous run.
    return len(filtered_messages) - 1
def extract_ttl_from_cached_messages(messages: List[AllMessageValues]) -> Optional[str]:
    """
    Scan the messages for the first ephemeral cache_control entry that
    carries a valid TTL, and return that TTL.

    Args:
        messages: OpenAI-format messages to inspect.

    Returns:
        Optional[str]: TTL string such as "3600s", or None when no
        valid TTL is present.
    """
    for message in messages:
        if not is_cached_message(message):
            continue

        content = message.get("content")
        # Only list-style content can carry per-part cache_control metadata;
        # plain-string content has nowhere to attach a TTL.
        if not content or isinstance(content, str):
            continue

        for part in content:
            # Guard: parts may be non-dict values; only dicts expose .get().
            if not isinstance(part, dict):
                continue

            cache_control = part.get("cache_control")
            if not cache_control or not isinstance(cache_control, dict):
                continue
            if cache_control.get("type") != "ephemeral":
                continue

            ttl_value = cache_control.get("ttl")
            if ttl_value and _is_valid_ttl_format(ttl_value):
                return str(ttl_value)

    return None
def _is_valid_ttl_format(ttl: str) -> bool:
"""
Validate TTL format. Should be a string ending with 's' for seconds.
Examples: "3600s", "7200s", "1.5s"
Args:
ttl: TTL string to validate
Returns:
bool: True if valid format, False otherwise
"""
if not isinstance(ttl, str):
return False
# TTL should end with 's' and contain a valid number before it
pattern = r"^([0-9]*\.?[0-9]+)s$"
match = re.match(pattern, ttl)
if not match:
return False
try:
# Ensure the numeric part is valid and positive
numeric_part = float(match.group(1))
return numeric_part > 0
except ValueError:
return False
def separate_cached_messages(
    messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
    """
    Split messages into the first continuous run of cached messages and
    everything else.

    Args:
        messages: List of messages to be separated.

    Returns:
        Tuple containing:
        - cached_messages: the first continuous block of cached messages.
        - non_cached_messages: all remaining messages, original order kept.
    """
    # Collect (original_index, message) pairs for every cached message.
    indexed_cached: List[Tuple[int, AllMessageValues]] = [
        (idx, msg)
        for idx, msg in enumerate(messages)
        if is_cached_message(message=msg)
    ]

    if not indexed_cached:
        # Nothing is cached — all messages pass through untouched.
        return [], messages

    # Only the first continuous block of cached messages is honored;
    # cached messages after a gap are treated as non-cached.
    end_pos = get_first_continuous_block_idx(indexed_cached)
    start_idx = indexed_cached[0][0]
    end_idx = indexed_cached[end_pos][0]

    cached_messages = messages[start_idx : end_idx + 1]
    non_cached_messages = messages[:start_idx] + messages[end_idx + 1 :]
    return cached_messages, non_cached_messages
def transform_openai_messages_to_gemini_context_caching(
    model: str,
    messages: List[AllMessageValues],
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    cache_key: str,
    vertex_project: Optional[str],
    vertex_location: Optional[str],
) -> CachedContentRequestBody:
    """
    Build a Google cachedContents request body from OpenAI-format messages.

    Args:
        model: Model name the cache is created for.
        messages: Messages to store in the cache.
        custom_llm_provider: Which Google surface is targeted.
        cache_key: Stored as the cache's displayName for later lookup.
        vertex_project: GCP project id (Vertex AI providers only).
        vertex_location: GCP region (Vertex AI providers only).

    Returns:
        CachedContentRequestBody ready to POST to the caching endpoint.
    """
    # Read the TTL first: the system-message transformation below rewrites
    # the message list and would hide the cache_control entries.
    ttl = extract_ttl_from_cached_messages(messages)

    supports_system_message = get_supports_system_message(
        model=model, custom_llm_provider=custom_llm_provider
    )
    transformed_system_messages, new_messages = _transform_system_message(
        supports_system_message=supports_system_message, messages=messages
    )
    transformed_messages = _gemini_convert_messages_with_history(
        messages=new_messages, model=model
    )

    model_name = "models/{}".format(model)
    if custom_llm_provider in ("vertex_ai", "vertex_ai_beta"):
        # Vertex AI expects the fully-qualified publisher model path.
        model_name = (
            f"projects/{vertex_project}/locations/{vertex_location}"
            f"/publishers/google/{model_name}"
        )

    data = CachedContentRequestBody(
        contents=transformed_messages,
        model=model_name,
        displayName=cache_key,
    )

    # Optional fields: only attach when actually present.
    if ttl:
        data["ttl"] = ttl
    if transformed_system_messages is not None:
        data["system_instruction"] = transformed_system_messages

    return data

View File

@@ -0,0 +1,578 @@
from typing import List, Literal, Optional, Tuple, Union
import httpx
import litellm
from litellm.caching.caching import Cache, LiteLLMCacheType
from litellm.constants import MINIMUM_PROMPT_CACHE_TOKEN_COUNT
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm._logging import verbose_logger
from litellm.llms.openai.openai import AllMessageValues
from litellm.utils import is_prompt_caching_valid_prompt
from litellm.types.llms.vertex_ai import (
CachedContentListAllResponseBody,
VertexAICachedContentResponseObject,
)
from ..common_utils import VertexAIError
from ..vertex_llm_base import VertexBase
from .transformation import (
separate_cached_messages,
transform_openai_messages_to_gemini_context_caching,
)
local_cache_obj = Cache(
type=LiteLLMCacheType.LOCAL
) # only used for calling 'get_cache_key' function
MAX_PAGINATION_PAGES = 100 # Reasonable upper bound for pagination
class ContextCachingEndpoints(VertexBase):
    """
    Covers context caching endpoints for Vertex AI + Google AI Studio

    v0: covers Google AI Studio
    """

    def __init__(self) -> None:
        # No per-instance state: all configuration is supplied per call.
        pass
def _get_token_and_url_context_caching(
self,
gemini_api_key: Optional[str],
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
api_base: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_auth_header: Optional[str],
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
Handles logic if it's google ai studio vs. vertex ai.
Returns
token, url
"""
if custom_llm_provider == "gemini":
auth_header = None
endpoint = "cachedContents"
url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
endpoint, gemini_api_key
)
elif custom_llm_provider == "vertex_ai":
auth_header = vertex_auth_header
endpoint = "cachedContents"
if vertex_location == "global":
url = f"https://aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
else:
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
else:
auth_header = vertex_auth_header
endpoint = "cachedContents"
if vertex_location == "global":
url = f"https://aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
else:
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
return self._check_custom_proxy(
api_base=api_base,
custom_llm_provider=custom_llm_provider,
gemini_api_key=gemini_api_key,
endpoint=endpoint,
stream=None,
auth_header=auth_header,
url=url,
model=None,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_api_version="v1beta1"
if custom_llm_provider == "vertex_ai_beta"
else "v1",
)
    def check_cache(
        self,
        cache_key: str,
        client: HTTPHandler,
        headers: dict,
        api_key: str,
        api_base: Optional[str],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
    ) -> Optional[str]:
        """
        Checks if content already cached.

        Lists existing cachedContents (following pagination) and matches on
        cache key == displayName, since Google doesn't let us set the name of
        the cache (their API docs are out of sync with actual implementation).

        Returns
        - cached_content_name - str - cached content name stored on google. (if found.)
        OR
        - None
        """
        _, base_url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        page_token: Optional[str] = None
        # Bounded loop guards against a buggy or endless nextPageToken chain.
        for _ in range(MAX_PAGINATION_PAGES):
            # Build URL with pagination token if present
            if page_token:
                separator = "&" if "?" in base_url else "?"
                url = f"{base_url}{separator}pageToken={page_token}"
            else:
                url = base_url

            try:
                ## LOGGING
                logging_obj.pre_call(
                    input="",
                    api_key="",
                    additional_args={
                        "complete_input_dict": {},
                        "api_base": url,
                        "headers": headers,
                    },
                )

                resp = client.get(url=url, headers=headers)
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 403:
                    # No permission to list caches — treat as a cache miss
                    # rather than failing the whole request.
                    return None
                raise VertexAIError(
                    status_code=e.response.status_code, message=e.response.text
                )
            except Exception as e:
                # Non-HTTP failure (network, JSON, etc.) — surface as a 500.
                raise VertexAIError(status_code=500, message=str(e))

            raw_response = resp.json()
            logging_obj.post_call(original_response=raw_response)

            if "cachedContents" not in raw_response:
                # Page carries no cached items — nothing to match against.
                return None

            all_cached_items = CachedContentListAllResponseBody(**raw_response)

            if "cachedContents" not in all_cached_items:
                return None

            # Check current page for matching cache_key
            for cached_item in all_cached_items["cachedContents"]:
                display_name = cached_item.get("displayName")
                if display_name is not None and display_name == cache_key:
                    return cached_item.get("name")

            # Check if there are more pages
            page_token = all_cached_items.get("nextPageToken")
            if not page_token:
                # No more pages, cache not found
                break

        return None
    async def async_check_cache(
        self,
        cache_key: str,
        client: AsyncHTTPHandler,
        headers: dict,
        api_key: str,
        api_base: Optional[str],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
    ) -> Optional[str]:
        """
        Checks if content already cached. Async mirror of ``check_cache``.

        Lists existing cachedContents (following pagination) and matches on
        cache key == displayName, since Google doesn't let us set the name of
        the cache (their API docs are out of sync with actual implementation).

        Returns
        - cached_content_name - str - cached content name stored on google. (if found.)
        OR
        - None
        """
        _, base_url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        page_token: Optional[str] = None
        # Bounded loop guards against a buggy or endless nextPageToken chain.
        for _ in range(MAX_PAGINATION_PAGES):
            # Build URL with pagination token if present
            if page_token:
                separator = "&" if "?" in base_url else "?"
                url = f"{base_url}{separator}pageToken={page_token}"
            else:
                url = base_url

            try:
                ## LOGGING
                logging_obj.pre_call(
                    input="",
                    api_key="",
                    additional_args={
                        "complete_input_dict": {},
                        "api_base": url,
                        "headers": headers,
                    },
                )

                resp = await client.get(url=url, headers=headers)
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 403:
                    # No permission to list caches — treat as a cache miss
                    # rather than failing the whole request.
                    return None
                raise VertexAIError(
                    status_code=e.response.status_code, message=e.response.text
                )
            except Exception as e:
                # Non-HTTP failure (network, JSON, etc.) — surface as a 500.
                raise VertexAIError(status_code=500, message=str(e))

            raw_response = resp.json()
            logging_obj.post_call(original_response=raw_response)

            if "cachedContents" not in raw_response:
                # Page carries no cached items — nothing to match against.
                return None

            all_cached_items = CachedContentListAllResponseBody(**raw_response)

            if "cachedContents" not in all_cached_items:
                return None

            # Check current page for matching cache_key
            for cached_item in all_cached_items["cachedContents"]:
                display_name = cached_item.get("displayName")
                if display_name is not None and display_name == cache_key:
                    return cached_item.get("name")

            # Check if there are more pages
            page_token = all_cached_items.get("nextPageToken")
            if not page_token:
                # No more pages, cache not found
                break

        return None
    def check_and_create_cache(
        self,
        messages: List[AllMessageValues],  # receives openai format messages
        optional_params: dict,  # cache the tools if present, in case cache content exists in messages
        api_key: str,
        api_base: Optional[str],
        model: str,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
        extra_headers: Optional[dict] = None,
        cached_content: Optional[str] = None,
    ) -> Tuple[List[AllMessageValues], dict, Optional[str]]:
        """
        Receives
        - messages: List of dict - messages in the openai format

        Returns
        - messages - List[dict] - filtered list of messages in the openai format.
        - optional_params - dict - request params ("tools" is removed when it
          gets stored in the cache instead).
        - cached_content - str - the cache content id, to be passed in the gemini request body

        Follows - https://ai.google.dev/api/caching#request-body
        """
        if cached_content is not None:
            # Caller already holds a cache id — nothing to create or look up.
            return messages, optional_params, cached_content

        cached_messages, non_cached_messages = separate_cached_messages(
            messages=messages
        )

        if len(cached_messages) == 0:
            return messages, optional_params, None

        # Gemini requires a minimum of 1024 tokens for context caching.
        # Skip caching if the cached content is too small to avoid API errors.
        if not is_prompt_caching_valid_prompt(
            model=model,
            messages=cached_messages,
            custom_llm_provider=custom_llm_provider,
        ):
            verbose_logger.debug(
                "Vertex AI context caching: cached content is below minimum token "
                "count (%d). Skipping context caching.",
                MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
            )
            return messages, optional_params, None

        # Tools are stored in the cache alongside the messages, so remove
        # them from the per-request params.
        tools = optional_params.pop("tools", None)

        ## AUTHORIZATION ##
        token, url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        headers = {
            "Content-Type": "application/json",
        }
        if token is not None:
            headers["Authorization"] = f"Bearer {token}"
        if extra_headers is not None:
            headers.update(extra_headers)

        if client is None or not isinstance(client, HTTPHandler):
            # Build a sync client on demand, honoring the caller's timeout.
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = HTTPHandler(**_params)  # type: ignore
        else:
            client = client

        ## CHECK IF CACHED ALREADY
        # Deterministic local key: the same messages/tools/model always map
        # to the same displayName on Google's side.
        generated_cache_key = local_cache_obj.get_cache_key(
            messages=cached_messages, tools=tools, model=model
        )
        google_cache_name = self.check_cache(
            cache_key=generated_cache_key,
            client=client,
            headers=headers,
            api_key=api_key,
            api_base=api_base,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )
        if google_cache_name:
            # Reuse the existing cache; only the non-cached messages are sent.
            return non_cached_messages, optional_params, google_cache_name

        ## TRANSFORM REQUEST
        cached_content_request_body = (
            transform_openai_messages_to_gemini_context_caching(
                model=model,
                messages=cached_messages,
                cache_key=generated_cache_key,
                custom_llm_provider=custom_llm_provider,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
            )
        )
        cached_content_request_body["tools"] = tools

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": cached_content_request_body,
                "api_base": url,
                "headers": headers,
            },
        )

        try:
            response = client.post(
                url=url, headers=headers, json=cached_content_request_body  # type: ignore
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        raw_response_cached = response.json()
        cached_content_response_obj = VertexAICachedContentResponseObject(
            name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
        )
        return (
            non_cached_messages,
            optional_params,
            cached_content_response_obj["name"],
        )
    async def async_check_and_create_cache(
        self,
        messages: List[AllMessageValues],  # receives openai format messages
        optional_params: dict,  # cache the tools if present, in case cache content exists in messages
        api_key: str,
        api_base: Optional[str],
        model: str,
        client: Optional[AsyncHTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
        extra_headers: Optional[dict] = None,
        cached_content: Optional[str] = None,
    ) -> Tuple[List[AllMessageValues], dict, Optional[str]]:
        """
        Async mirror of ``check_and_create_cache``.

        Receives
        - messages: List of dict - messages in the openai format

        Returns
        - messages - List[dict] - filtered list of messages in the openai format.
        - optional_params - dict - request params ("tools" is removed when it
          gets stored in the cache instead).
        - cached_content - str - the cache content id, to be passed in the gemini request body

        Follows - https://ai.google.dev/api/caching#request-body
        """
        if cached_content is not None:
            # Caller already holds a cache id — nothing to create or look up.
            return messages, optional_params, cached_content

        cached_messages, non_cached_messages = separate_cached_messages(
            messages=messages
        )

        if len(cached_messages) == 0:
            return messages, optional_params, None

        # Gemini requires a minimum of 1024 tokens for context caching.
        # Skip caching if the cached content is too small to avoid API errors.
        if not is_prompt_caching_valid_prompt(
            model=model,
            messages=cached_messages,
            custom_llm_provider=custom_llm_provider,
        ):
            verbose_logger.debug(
                "Vertex AI context caching: cached content is below minimum token "
                "count (%d). Skipping context caching.",
                MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
            )
            return messages, optional_params, None

        # Tools are stored in the cache alongside the messages, so remove
        # them from the per-request params.
        tools = optional_params.pop("tools", None)

        ## AUTHORIZATION ##
        token, url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        headers = {
            "Content-Type": "application/json",
        }
        if token is not None:
            headers["Authorization"] = f"Bearer {token}"
        if extra_headers is not None:
            headers.update(extra_headers)

        if client is None or not isinstance(client, AsyncHTTPHandler):
            # Build an async client on demand, honoring the caller's timeout.
            client = get_async_httpx_client(
                params={"timeout": timeout}, llm_provider=litellm.LlmProviders.VERTEX_AI
            )
        else:
            client = client

        ## CHECK IF CACHED ALREADY
        # Deterministic local key: the same messages/tools/model always map
        # to the same displayName on Google's side.
        generated_cache_key = local_cache_obj.get_cache_key(
            messages=cached_messages, tools=tools, model=model
        )
        google_cache_name = await self.async_check_cache(
            cache_key=generated_cache_key,
            client=client,
            headers=headers,
            api_key=api_key,
            api_base=api_base,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )
        if google_cache_name:
            # Reuse the existing cache; only the non-cached messages are sent.
            return non_cached_messages, optional_params, google_cache_name

        ## TRANSFORM REQUEST
        cached_content_request_body = (
            transform_openai_messages_to_gemini_context_caching(
                model=model,
                messages=cached_messages,
                cache_key=generated_cache_key,
                custom_llm_provider=custom_llm_provider,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
            )
        )
        cached_content_request_body["tools"] = tools

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": cached_content_request_body,
                "api_base": url,
                "headers": headers,
            },
        )

        try:
            response = await client.post(
                url=url, headers=headers, json=cached_content_request_body  # type: ignore
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        raw_response_cached = response.json()
        cached_content_response_obj = VertexAICachedContentResponseObject(
            name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
        )
        return (
            non_cached_messages,
            optional_params,
            cached_content_response_obj["name"],
        )
    def get_cache(self):
        # TODO: retrieving a single cache entry is not implemented yet.
        pass

    async def async_get_cache(self):
        # TODO: async retrieval of a single cache entry is not implemented yet.
        pass