chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,119 @@
|
||||
from typing import Dict, List, Mapping, Optional, Union
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.constants import PASS_THROUGH_HEADER_PREFIX
|
||||
|
||||
|
||||
class BasePassthroughUtils:
    """Static helpers for building the query string and headers of a forwarded request."""

    # Request headers that are dropped when the whole header set is forwarded.
    # Compared case-insensitively, like the pass-through prefix check below.
    _NON_FORWARDED_HEADERS = frozenset({"content-length", "host"})

    @staticmethod
    def get_merged_query_parameters(
        existing_url: httpx.URL,
        request_query_params: Mapping[str, Union[str, list]],
        default_query_params: Optional[Dict[str, Union[str, list]]] = None,
    ) -> Dict[str, Union[str, List[str]]]:
        """
        Merge query parameters from three sources into a single dict.

        Priority, lowest to highest:
            1. ``default_query_params``
            2. query parameters already present on ``existing_url``
            3. ``request_query_params`` (the client can override anything)

        Args:
            existing_url: Target URL whose query string is parsed.
            request_query_params: Query parameters from the incoming request.
            default_query_params: Optional fallback values.

        Returns:
            A new dict of merged parameters. For values parsed from the URL,
            a key that appears once maps to a plain string and a key that
            appears several times maps to a list of strings.
        """
        # parse_qs maps every key to a list; unwrap single-element lists.
        url_params = {
            key: values[0] if len(values) == 1 else values
            for key, values in parse_qs(existing_url.query.decode("utf-8")).items()
        }

        merged_params: Dict[str, Union[str, List[str]]] = {}
        if default_query_params:
            merged_params.update(default_query_params)
        merged_params.update(url_params)
        merged_params.update(request_query_params)
        return merged_params

    @staticmethod
    def forward_headers_from_request(
        request_headers: dict,
        headers: dict,
        forward_headers: Optional[bool] = False,
    ) -> dict:
        """
        Helper to forward headers from the original request.

        When ``forward_headers`` is True, every request header except
        ``content-length`` and ``host`` is forwarded, and entries in
        ``headers`` win on key conflicts.

        Headers whose name starts with PASS_THROUGH_HEADER_PREFIX are always
        forwarded with the prefix stripped, regardless of ``forward_headers``.
        e.g., 'x-pass-anthropic-beta: value' becomes 'anthropic-beta: value'

        Neither ``request_headers`` nor ``headers`` is modified; a new dict is
        returned.

        Args:
            request_headers: Headers of the incoming request.
            headers: Custom headers to send; ``None`` is treated as empty.
            forward_headers: Forward the full request header set only when
                this is exactly ``True``.

        Returns:
            The headers to send with the forwarded request.
        """
        custom_headers = headers or {}

        if forward_headers is True:
            excluded = BasePassthroughUtils._NON_FORWARDED_HEADERS
            # Filter into a new dict instead of popping from the caller's dict.
            forwarded = {
                name: value
                for name, value in request_headers.items()
                if name.lower() not in excluded
            }
            outgoing = {**forwarded, **custom_headers}
        else:
            # Copy so the prefix handling below does not write into the
            # caller's dict.
            outgoing = dict(custom_headers)

        # Always process prefixed headers (strip prefix and forward).
        prefix_length = len(PASS_THROUGH_HEADER_PREFIX)
        for header_name, header_value in request_headers.items():
            if header_name.lower().startswith(PASS_THROUGH_HEADER_PREFIX):
                outgoing[header_name[prefix_length:]] = header_value

        return outgoing
|
||||
|
||||
|
||||
class CommonUtils:
    """Utility helpers for endpoint strings."""

    # (pattern, replacement) pairs tried in order; the first one that matches
    # is applied and the rest are skipped.
    _ARN_SLASH_RULES = (
        # Custom model with 2 slashes (order matters - do this first)
        (r"(custom-model)/([a-z0-9.-]+)/([a-z0-9]+)", r"\1%2F\2%2F\3"),
        # All other resource types with 1 slash
        (r"(:application-inference-profile)/", r"\1%2F"),
        (r"(:inference-profile)/", r"\1%2F"),
        (r"(:foundation-model)/", r"\1%2F"),
        (r"(:imported-model)/", r"\1%2F"),
        (r"(:provisioned-model)/", r"\1%2F"),
        (r"(:prompt)/", r"\1%2F"),
        (r"(:endpoint)/", r"\1%2F"),
        (r"(:prompt-router)/", r"\1%2F"),
        (r"(:default-prompt-router)/", r"\1%2F"),
    )

    @staticmethod
    def encode_bedrock_runtime_modelid_arn(endpoint: str) -> str:
        """
        Encodes any "/" found in the modelId of an AWS Bedrock Runtime Endpoint when arns are passed in.

        - modelID value can be an ARN which contains slashes that SHOULD NOT be treated as path separators.
        e.g endpoint: /model/<modelId>/invoke
        <modelId> containing arns with slashes need to be encoded from
        arn:aws:bedrock:ap-southeast-1:123456789012:application-inference-profile/abdefg12334 =>
        arn:aws:bedrock:ap-southeast-1:123456789012:application-inference-profile%2Fabdefg12334
        so that it is treated as one part of the path.
        Otherwise, the endpoint will return a 500 error when passed to the Bedrock endpoint.

        ARNs from other partitions (e.g. ``arn:aws-us-gov:``, ``arn:aws-cn:``)
        are encoded the same way as ``arn:aws:`` ones.

        See the apis in https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Operations_Amazon_Bedrock_Runtime.html
        for more details on the regex patterns of modelId which we use in the regex logic below.

        Args:
            endpoint (str): The original endpoint string which may contain ARNs that contain slashes.

        Returns:
            str: The endpoint with properly encoded ARN slashes
        """
        import re

        # Early exit: no ARN detected. The optional group accepts partition
        # suffixes such as "-us-gov" or "-cn" after "arn:aws".
        if not re.search(r"arn:aws(?:-[a-z0-9-]+)?:", endpoint):
            return endpoint

        for pattern, replacement in CommonUtils._ARN_SLASH_RULES:
            encoded, substitutions = re.subn(pattern, replacement, endpoint)
            if substitutions:
                # Stop at the first rule that applied: each ARN has only one
                # resource type.
                return encoded

        return endpoint
|
||||
Reference in New Issue
Block a user