Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/router_utils/prompt_caching_cache.py

260 lines
9.7 KiB
Python
Raw Normal View History

"""
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
"""
import hashlib
import json
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
from typing_extensions import TypedDict
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router
litellm_router = Router
Span = Union[_Span, Any]
else:
Span = Any
litellm_router = Any
class PromptCachingCacheValue(TypedDict):
    """Value stored in the prompt-caching cache: the deployment/model id recorded when a prompt-caching-supported prompt was called."""

    # Router deployment id; written by add_model_id / async_add_model_id and
    # read back by get_model_id / async_get_model_id.
    model_id: str
class PromptCachingCache:
    """
    Wrapper around the router cache that remembers which deployment (model id)
    served a prompt containing ``cache_control`` blocks.

    The cache key is derived only from the prompt's *cacheable prefix*
    (everything up to and including the last ``{"type": "ephemeral"}``
    cache_control marker) plus the tools, so follow-up requests that share the
    same prefix — but differ in trailing user messages — map to the same key
    and can be routed to the same deployment for provider-side cache hits.
    """

    # Default TTL (seconds) for stored model-id entries — store for 5 minutes.
    # Previously a magic number duplicated in the sync and async writers.
    DEFAULT_TTL_SECONDS: int = 300

    def __init__(self, cache: "DualCache"):
        """
        Args:
            cache: the router's DualCache used to persist model-id entries.
        """
        self.cache = cache
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """
        Serialize Pydantic objects, dicts, or lists into a stable,
        hash-friendly representation; fall back to ``str()``.

        - Objects exposing ``.dict()`` (Pydantic v1-style) use that method.
        - Dicts become canonical JSON (sorted keys, compact separators) so key
          order never changes the result.
        - Lists are serialized element-wise (recursively).
        - int/float/bool pass through unchanged.

        NOTE(review): the dict branch does NOT recurse into values before
        ``json.dumps``, so a dict holding non-JSON-serializable values raises
        ``TypeError``. Left as-is deliberately — changing serialization would
        change every existing cache key.
        """
        if hasattr(obj, "dict"):
            # Pydantic model: use its dict() representation.
            return obj.dict()
        elif isinstance(obj, dict):
            # Canonical JSON -> byte-stable serialization of the dict.
            return json.dumps(obj, sort_keys=True, separators=(",", ":"))
        elif isinstance(obj, list):
            # Handle each list element individually.
            return [PromptCachingCache.serialize_object(item) for item in obj]
        elif isinstance(obj, (int, float, bool)):
            return obj  # Keep primitive types as-is
        return str(obj)

    @staticmethod
    def extract_cacheable_prefix(
        messages: List["AllMessageValues"],
    ) -> List["AllMessageValues"]:
        """
        Extract the cacheable prefix from ``messages``.

        The cacheable prefix is everything UP TO AND INCLUDING the LAST
        content block (across all messages) that carries
        ``cache_control={"type": "ephemeral"}`` — including all earlier blocks
        even when they have no cache_control of their own.

        Args:
            messages: list of messages to extract the cacheable prefix from.

        Returns:
            The prefix as a (possibly truncated) list of messages, or ``[]``
            when no ephemeral cache_control marker is present.
        """
        if not messages:
            return messages

        # Locate the last ephemeral marker. content idx is None when the
        # marker sits at message level (string content), meaning the entire
        # message content is cacheable rather than a specific block.
        last_cacheable_message_idx = None
        last_cacheable_content_idx = None
        for msg_idx, message in enumerate(messages):
            content = message.get("content")

            # Message-level marker, a sibling of string content:
            # {"role": "user", "content": "...", "cache_control": {"type": "ephemeral"}}
            message_level_cache_control = message.get("cache_control")
            if (
                message_level_cache_control is not None
                and isinstance(message_level_cache_control, dict)
                and message_level_cache_control.get("type") == "ephemeral"
            ):
                last_cacheable_message_idx = msg_idx
                last_cacheable_content_idx = None

            # Block-level markers inside list-shaped content.
            if not isinstance(content, list):
                continue
            for content_idx, content_block in enumerate(content):
                if isinstance(content_block, dict):
                    cache_control = content_block.get("cache_control")
                    if (
                        cache_control is not None
                        and isinstance(cache_control, dict)
                        and cache_control.get("type") == "ephemeral"
                    ):
                        last_cacheable_message_idx = msg_idx
                        last_cacheable_content_idx = content_idx

        # No cacheable block found -> no cacheable prefix.
        if last_cacheable_message_idx is None:
            return []

        # Build the prefix: every message before the last cacheable one in
        # full, then that message truncated at the marked content block.
        cacheable_prefix: List["AllMessageValues"] = []
        for msg_idx, message in enumerate(messages):
            if msg_idx < last_cacheable_message_idx:
                cacheable_prefix.append(message)
            elif msg_idx == last_cacheable_message_idx:
                content = message.get("content")
                if isinstance(content, list) and last_cacheable_content_idx is not None:
                    # Copy the message keeping only blocks up to and including
                    # the last cacheable one.
                    message_copy = cast(
                        "AllMessageValues",
                        {
                            **message,
                            "content": content[: last_cacheable_content_idx + 1],
                        },
                    )
                    cacheable_prefix.append(message_copy)
                else:
                    # String content / message-level marker: keep whole message.
                    cacheable_prefix.append(message)
            else:
                # Everything after the last cacheable block is excluded.
                break
        return cacheable_prefix

    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List["AllMessageValues"]],
        tools: Optional[List["ChatCompletionToolParam"]],
    ) -> Optional[str]:
        """
        Build a stable cache key from the cacheable prefix of ``messages``
        plus ``tools``.

        Returns:
            ``"deployment:<sha256 hex>:prompt_caching"``, or ``None`` when
            both inputs are ``None`` or the messages contain no cacheable
            prefix.
        """
        if messages is None and tools is None:
            return None

        # Only include messages up to the last cache_control block, so
        # requests differing only in trailing user turns share a key.
        cacheable_messages = None
        if messages is not None:
            cacheable_messages = PromptCachingCache.extract_cacheable_prefix(messages)
            # No cacheable prefix found -> can't cache.
            if not cacheable_messages:
                return None

        # Use serialize_object for consistent and stable serialization.
        data_to_hash: dict = {}
        if cacheable_messages is not None:
            data_to_hash["messages"] = PromptCachingCache.serialize_object(
                cacheable_messages
            )
        if tools is not None:
            data_to_hash["tools"] = PromptCachingCache.serialize_object(tools)

        # Canonical JSON of the combined data, then SHA-256 for a stable key.
        data_to_hash_str = json.dumps(
            data_to_hash,
            sort_keys=True,
            separators=(",", ":"),
        )
        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
        return f"deployment:{hashed_data}:prompt_caching"

    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List["AllMessageValues"]],
        tools: Optional[List["ChatCompletionToolParam"]],
        ttl: int = DEFAULT_TTL_SECONDS,
    ) -> None:
        """
        Record that ``model_id`` served this cacheable prompt (sync).

        Args:
            ttl: seconds to keep the entry; defaults to DEFAULT_TTL_SECONDS.
        """
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # No cacheable prefix (or no inputs) -> no key -> nothing to store.
        if cache_key is None:
            return None
        self.cache.set_cache(
            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=ttl
        )
        return None

    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List["AllMessageValues"]],
        tools: Optional[List["ChatCompletionToolParam"]],
        ttl: int = DEFAULT_TTL_SECONDS,
    ) -> None:
        """
        Record that ``model_id`` served this cacheable prompt (async).

        Args:
            ttl: seconds to keep the entry; defaults to DEFAULT_TTL_SECONDS.
        """
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # No cacheable prefix (or no inputs) -> no key -> nothing to store.
        if cache_key is None:
            return None
        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=ttl,  # default: store for 5 minutes
        )
        return None

    async def async_get_model_id(
        self,
        messages: Optional[List["AllMessageValues"]],
        tools: Optional[List["ChatCompletionToolParam"]],
    ) -> Optional["PromptCachingCacheValue"]:
        """
        Get the model ID recorded for this prompt's cacheable prefix (async).

        The key depends only on the cacheable prefix (up to the last
        cache_control block), so requests with the same prefix but different
        trailing user messages resolve to the same entry.
        """
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        if cache_key is None:
            return None
        return await self.cache.async_get_cache(key=cache_key)

    def get_model_id(
        self,
        messages: Optional[List["AllMessageValues"]],
        tools: Optional[List["ChatCompletionToolParam"]],
    ) -> Optional["PromptCachingCacheValue"]:
        """Get the model ID recorded for this prompt's cacheable prefix (sync)."""
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # No cacheable prefix -> nothing could have been stored.
        if cache_key is None:
            return None
        return self.cache.get_cache(cache_key)