chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.llms.vertex_ai.common_utils import (
VertexAIModelRoute,
get_vertex_ai_model_route,
)
from .vertex_gemini_transformation import VertexAIGeminiImageGenerationConfig
from .vertex_imagen_transformation import VertexAIImagenImageGenerationConfig
__all__ = [
"VertexAIGeminiImageGenerationConfig",
"VertexAIImagenImageGenerationConfig",
"get_vertex_ai_image_generation_config",
]
def get_vertex_ai_image_generation_config(model: str) -> BaseImageGenerationConfig:
"""
Get the appropriate image generation config for a Vertex AI model.
Routes to the correct transformation class based on the model type:
- Gemini image generation models use generateContent API (VertexAIGeminiImageGenerationConfig)
- Imagen models use predict API (VertexAIImagenImageGenerationConfig)
Args:
model: The model name (e.g., "gemini-2.5-flash-image", "imagegeneration@006")
Returns:
BaseImageGenerationConfig: The appropriate configuration class
"""
# Determine the model route
model_route = get_vertex_ai_model_route(model)
if model_route == VertexAIModelRoute.GEMINI:
# Gemini models use generateContent API
return VertexAIGeminiImageGenerationConfig()
else:
# Default to Imagen for other models (imagegeneration, etc.)
# This includes NON_GEMINI models like imagegeneration@006
return VertexAIImagenImageGenerationConfig()

View File

@@ -0,0 +1,36 @@
"""
Vertex AI Image Generation Cost Calculator
"""
import litellm
from litellm.litellm_core_utils.llm_cost_calc.utils import (
calculate_image_response_cost_from_usage,
)
from litellm.types.utils import ImageResponse
def cost_calculator(
model: str,
image_response: ImageResponse,
) -> float:
"""
Vertex AI Image Generation Cost Calculator
"""
_model_info = litellm.get_model_info(
model=model,
custom_llm_provider="vertex_ai",
)
token_based_cost = calculate_image_response_cost_from_usage(
model=model,
image_response=image_response,
custom_llm_provider="vertex_ai",
)
if token_based_cost is not None:
return token_based_cost
output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
num_images: int = 0
if image_response.data:
num_images = len(image_response.data)
return output_cost_per_image * num_images

View File

@@ -0,0 +1,282 @@
import json
from typing import Any, Dict, List, Optional
import httpx
from openai.types.image import Image
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from litellm.types.utils import ImageResponse
class VertexImageGeneration(VertexLLM):
def process_image_generation_response(
self,
json_response: Dict[str, Any],
model_response: ImageResponse,
model: Optional[str] = None,
) -> ImageResponse:
if "predictions" not in json_response:
raise litellm.InternalServerError(
message=f"image generation response does not contain 'predictions', got {json_response}",
llm_provider="vertex_ai",
model=model,
)
predictions = json_response["predictions"]
response_data: List[Image] = []
for prediction in predictions:
bytes_base64_encoded = prediction["bytesBase64Encoded"]
image_object = Image(b64_json=bytes_base64_encoded)
response_data.append(image_object)
model_response.data = response_data
return model_response
def transform_optional_params(self, optional_params: Optional[dict]) -> dict:
"""
Transform the optional params to the format expected by the Vertex AI API.
For example, "aspect_ratio" is transformed to "aspectRatio".
"""
default_params = {
"sampleCount": 1,
}
if optional_params is None:
return default_params
def snake_to_camel(snake_str: str) -> str:
"""Convert snake_case to camelCase"""
components = snake_str.split("_")
return components[0] + "".join(word.capitalize() for word in components[1:])
transformed_params = default_params.copy()
for key, value in optional_params.items():
if "_" in key:
camel_case_key = snake_to_camel(key)
transformed_params[camel_case_key] = value
else:
transformed_params[key] = value
return transformed_params
def image_generation(
self,
prompt: str,
api_base: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
model_response: ImageResponse,
logging_obj: Any,
model: str = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[Any] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
aimg_generation=False,
extra_headers: Optional[dict] = None,
) -> ImageResponse:
if aimg_generation is True:
return self.aimage_generation( # type: ignore
prompt=prompt,
api_base=api_base,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
model=model,
client=client,
optional_params=optional_params,
timeout=timeout,
logging_obj=logging_obj,
model_response=model_response,
)
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
else:
sync_handler = client # type: ignore
# url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
auth_header: Optional[str] = None
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=None,
auth_header=auth_header,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=False,
custom_llm_provider="vertex_ai",
api_base=api_base,
should_use_v1beta1_features=False,
mode="image_generation",
)
optional_params = optional_params or {
"sampleCount": 1
} # default optional params
# Transform optional params to camelCase format
optional_params = self.transform_optional_params(optional_params)
request_data = {
"instances": [{"prompt": prompt}],
"parameters": optional_params,
}
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": optional_params,
"api_base": api_base,
"headers": headers,
},
)
response = sync_handler.post(
url=api_base,
headers=headers,
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
json_response = response.json()
return self.process_image_generation_response(
json_response, model_response, model
)
async def aimage_generation(
self,
prompt: str,
api_base: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
model_response: ImageResponse,
logging_obj: Any,
model: str = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
extra_headers: Optional[dict] = None,
):
response = None
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
params={"timeout": timeout},
)
else:
self.async_handler = client # type: ignore
# make POST request to
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
"""
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d {
"instances": [
{
"prompt": "a cat"
}
],
"parameters": {
"sampleCount": 1
}
} \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header: Optional[str] = None
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=None,
auth_header=auth_header,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=False,
custom_llm_provider="vertex_ai",
api_base=api_base,
should_use_v1beta1_features=False,
mode="image_generation",
)
# Transform optional params to camelCase format
optional_params = self.transform_optional_params(optional_params)
request_data = {
"instances": [{"prompt": prompt}],
"parameters": optional_params,
}
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": optional_params,
"api_base": api_base,
"headers": headers,
},
)
response = await self.async_handler.post(
url=api_base,
headers=headers,
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
json_response = response.json()
return self.process_image_generation_response(
json_response, model_response, model
)
def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
if "predictions" in json_response:
if "bytesBase64Encoded" in json_response["predictions"][0]:
return True
return False

View File

@@ -0,0 +1,327 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import httpx
import litellm
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ImageObject,
ImageResponse,
ImageUsage,
ImageUsageInputTokensDetails,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class VertexAIGeminiImageGenerationConfig(BaseImageGenerationConfig, VertexLLM):
"""
Vertex AI Gemini Image Generation Configuration
Uses generateContent API for Gemini image generation models on Vertex AI
Supports models like gemini-2.5-flash-image, gemini-3-pro-image-preview, etc.
"""
def __init__(self) -> None:
BaseImageGenerationConfig.__init__(self)
VertexLLM.__init__(self)
def get_supported_openai_params(self, model: str) -> list:
"""
Gemini image generation supported parameters
Includes native Gemini imageConfig params (aspectRatio, imageSize)
in both camelCase and snake_case variants.
"""
return [
"n",
"size",
"aspectRatio",
"aspect_ratio",
"imageSize",
"image_size",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_params = self.get_supported_openai_params(model)
mapped_params = {}
for k, v in non_default_params.items():
if k not in optional_params.keys():
if k in supported_params:
# Map OpenAI parameters to Gemini format
if k == "n":
mapped_params["candidate_count"] = v
elif k == "size":
# Map OpenAI size format to Gemini aspectRatio
mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(v)
elif k in ("aspectRatio", "aspect_ratio"):
mapped_params["aspectRatio"] = v
elif k in ("imageSize", "image_size"):
mapped_params["imageSize"] = v
else:
mapped_params[k] = v
return mapped_params
def _map_size_to_aspect_ratio(self, size: str) -> str:
"""
Map OpenAI size format to Gemini aspect ratio format
"""
aspect_ratio_map = {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1280x896": "4:3",
"896x1280": "3:4",
}
return aspect_ratio_map.get(size, "1:1")
def _resolve_vertex_project(self) -> Optional[str]:
return (
getattr(self, "_vertex_project", None)
or os.environ.get("VERTEXAI_PROJECT")
or getattr(litellm, "vertex_project", None)
or get_secret_str("VERTEXAI_PROJECT")
)
def _resolve_vertex_location(self) -> Optional[str]:
return (
getattr(self, "_vertex_location", None)
or os.environ.get("VERTEXAI_LOCATION")
or os.environ.get("VERTEX_LOCATION")
or getattr(litellm, "vertex_location", None)
or get_secret_str("VERTEXAI_LOCATION")
or get_secret_str("VERTEX_LOCATION")
)
def _resolve_vertex_credentials(self) -> Optional[str]:
return (
getattr(self, "_vertex_credentials", None)
or os.environ.get("VERTEXAI_CREDENTIALS")
or getattr(litellm, "vertex_credentials", None)
or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
or get_secret_str("VERTEXAI_CREDENTIALS")
)
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for Vertex AI Gemini generateContent API
"""
# Use the model name as provided, handling vertex_ai prefix
model_name = model
if model.startswith("vertex_ai/"):
model_name = model.replace("vertex_ai/", "")
# If a custom api_base is provided, use it directly
# This allows users to use proxies or mock endpoints
if api_base:
return api_base.rstrip("/")
# First check litellm_params (where vertex_ai_project/vertex_ai_location are passed)
# then fall back to environment variables and other sources
vertex_project = (
self.safe_get_vertex_ai_project(litellm_params)
or self._resolve_vertex_project()
)
vertex_location = (
self.safe_get_vertex_ai_location(litellm_params)
or self._resolve_vertex_location()
)
if not vertex_project or not vertex_location:
raise ValueError(
"vertex_project and vertex_location are required for Vertex AI"
)
base_url = get_vertex_base_url(vertex_location)
return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:generateContent"
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
headers = headers or {}
# If a custom api_base is provided, skip credential validation
# This allows users to use proxies or mock endpoints without needing Vertex AI credentials
_api_base = litellm_params.get("api_base") or api_base
if _api_base is not None:
return headers
# First check litellm_params (where vertex_ai_project/vertex_ai_credentials are passed)
# then fall back to environment variables and other sources
vertex_project = (
self.safe_get_vertex_ai_project(litellm_params)
or self._resolve_vertex_project()
)
vertex_credentials = (
self.safe_get_vertex_ai_credentials(litellm_params)
or self._resolve_vertex_credentials()
)
access_token, _ = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
return self.set_headers(access_token, headers)
def transform_image_generation_request(
self,
model: str,
prompt: str,
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
Transform the image generation request to Gemini format
Uses generateContent API with responseModalities: ["IMAGE"]
"""
# Prepare messages with the prompt
contents = [{"role": "user", "parts": [{"text": prompt}]}]
# Prepare generation config
generation_config: Dict[str, Any] = {"responseModalities": ["IMAGE"]}
# Handle image-specific config parameters
image_config: Dict[str, Any] = {}
# Map aspectRatio
if "aspectRatio" in optional_params:
image_config["aspectRatio"] = optional_params["aspectRatio"]
elif "aspect_ratio" in optional_params:
image_config["aspectRatio"] = optional_params["aspect_ratio"]
# Map imageSize (for Gemini 3 Pro)
if "imageSize" in optional_params:
image_config["imageSize"] = optional_params["imageSize"]
elif "image_size" in optional_params:
image_config["imageSize"] = optional_params["image_size"]
if image_config:
generation_config["imageConfig"] = image_config
# Handle candidate_count (n parameter)
if "candidate_count" in optional_params:
generation_config["candidateCount"] = optional_params["candidate_count"]
elif "n" in optional_params:
generation_config["candidateCount"] = optional_params["n"]
request_body: Dict[str, Any] = {
"contents": contents,
"generationConfig": generation_config,
}
return request_body
def _transform_image_usage(self, usage: dict) -> ImageUsage:
input_tokens_details = ImageUsageInputTokensDetails(
image_tokens=0,
text_tokens=0,
)
tokens_details = usage.get("promptTokensDetails", [])
for details in tokens_details:
if isinstance(details, dict) and (modality := details.get("modality")):
token_count = details.get("tokenCount", 0)
if modality == "TEXT":
input_tokens_details.text_tokens += token_count
elif modality == "IMAGE":
input_tokens_details.image_tokens += token_count
return ImageUsage(
input_tokens=usage.get("promptTokenCount", 0),
input_tokens_details=input_tokens_details,
output_tokens=usage.get("candidatesTokenCount", 0),
total_tokens=usage.get("totalTokenCount", 0),
)
def transform_image_generation_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ImageResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ImageResponse:
"""
Transform Gemini image generation response to litellm ImageResponse format
"""
try:
response_data = raw_response.json()
except Exception as e:
raise self.get_error_class(
error_message=f"Error transforming image generation response: {e}",
status_code=raw_response.status_code,
headers=raw_response.headers,
)
if not model_response.data:
model_response.data = []
# Gemini image generation models return in candidates format
candidates = response_data.get("candidates", [])
for candidate in candidates:
content = candidate.get("content", {})
parts = content.get("parts", [])
for part in parts:
# Look for inlineData with image
if "inlineData" in part:
inline_data = part["inlineData"]
if "data" in inline_data:
thought_sig = part.get("thoughtSignature")
model_response.data.append(
ImageObject(
b64_json=inline_data["data"],
url=None,
provider_specific_fields={
"thought_signature": thought_sig
}
if thought_sig
else None,
)
)
if usage_metadata := response_data.get("usageMetadata", None):
model_response.usage = self._transform_image_usage(usage_metadata)
return model_response

View File

@@ -0,0 +1,256 @@
import os
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
import litellm
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageGenerationOptionalParams,
)
from litellm.types.utils import ImageObject, ImageResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class VertexAIImagenImageGenerationConfig(BaseImageGenerationConfig, VertexLLM):
"""
Vertex AI Imagen Image Generation Configuration
Uses predict API for Imagen models on Vertex AI
Supports models like imagegeneration@006
"""
def __init__(self) -> None:
BaseImageGenerationConfig.__init__(self)
VertexLLM.__init__(self)
def get_supported_openai_params(
self, model: str
) -> List[OpenAIImageGenerationOptionalParams]:
"""
Imagen API supported parameters
"""
return ["n", "size"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_params = self.get_supported_openai_params(model)
mapped_params = {}
for k, v in non_default_params.items():
if k not in optional_params.keys():
if k in supported_params:
# Map OpenAI parameters to Imagen format
if k == "n":
mapped_params["sampleCount"] = v
elif k == "size":
# Map OpenAI size format to Imagen aspectRatio
mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(v)
else:
mapped_params[k] = v
return mapped_params
def _map_size_to_aspect_ratio(self, size: str) -> str:
"""
Map OpenAI size format to Imagen aspect ratio format
"""
aspect_ratio_map = {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1280x896": "4:3",
"896x1280": "3:4",
}
return aspect_ratio_map.get(size, "1:1")
def _resolve_vertex_project(self) -> Optional[str]:
return (
getattr(self, "_vertex_project", None)
or os.environ.get("VERTEXAI_PROJECT")
or getattr(litellm, "vertex_project", None)
or get_secret_str("VERTEXAI_PROJECT")
)
def _resolve_vertex_location(self) -> Optional[str]:
return (
getattr(self, "_vertex_location", None)
or os.environ.get("VERTEXAI_LOCATION")
or os.environ.get("VERTEX_LOCATION")
or getattr(litellm, "vertex_location", None)
or get_secret_str("VERTEXAI_LOCATION")
or get_secret_str("VERTEX_LOCATION")
)
def _resolve_vertex_credentials(self) -> Optional[str]:
return (
getattr(self, "_vertex_credentials", None)
or os.environ.get("VERTEXAI_CREDENTIALS")
or getattr(litellm, "vertex_credentials", None)
or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
or get_secret_str("VERTEXAI_CREDENTIALS")
)
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for Vertex AI Imagen predict API
"""
# Use the model name as provided, handling vertex_ai prefix
model_name = model
if model.startswith("vertex_ai/"):
model_name = model.replace("vertex_ai/", "")
# If a custom api_base is provided, use it directly
# This allows users to use proxies or mock endpoints
if api_base:
return api_base.rstrip("/")
# First check litellm_params (where vertex_ai_project/vertex_ai_location are passed)
# then fall back to environment variables and other sources
vertex_project = (
self.safe_get_vertex_ai_project(litellm_params)
or self._resolve_vertex_project()
)
vertex_location = (
self.safe_get_vertex_ai_location(litellm_params)
or self._resolve_vertex_location()
)
if not vertex_project or not vertex_location:
raise ValueError(
"vertex_project and vertex_location are required for Vertex AI"
)
base_url = get_vertex_base_url(vertex_location)
return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:predict"
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
headers = headers or {}
# If a custom api_base is provided, skip credential validation
# This allows users to use proxies or mock endpoints without needing Vertex AI credentials
_api_base = litellm_params.get("api_base") or api_base
if _api_base is not None:
return headers
# First check litellm_params (where vertex_ai_project/vertex_ai_credentials are passed)
# then fall back to environment variables and other sources
vertex_project = (
self.safe_get_vertex_ai_project(litellm_params)
or self._resolve_vertex_project()
)
vertex_credentials = (
self.safe_get_vertex_ai_credentials(litellm_params)
or self._resolve_vertex_credentials()
)
access_token, _ = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
return self.set_headers(access_token, headers)
def transform_image_generation_request(
self,
model: str,
prompt: str,
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
Transform the image generation request to Imagen format
Uses predict API with instances and parameters
"""
# Default parameters
default_params = {
"sampleCount": 1,
}
# Merge with optional params
parameters = {**default_params, **optional_params}
request_body = {
"instances": [{"prompt": prompt}],
"parameters": parameters,
}
return request_body
def transform_image_generation_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ImageResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ImageResponse:
"""
Transform Imagen image generation response to litellm ImageResponse format
"""
try:
response_data = raw_response.json()
except Exception as e:
raise self.get_error_class(
error_message=f"Error transforming image generation response: {e}",
status_code=raw_response.status_code,
headers=raw_response.headers,
)
if not model_response.data:
model_response.data = []
# Imagen format - predictions with generated images
predictions = response_data.get("predictions", [])
for prediction in predictions:
# Imagen returns images as bytesBase64Encoded
if "bytesBase64Encoded" in prediction:
model_response.data.append(
ImageObject(
b64_json=prediction["bytesBase64Encoded"],
url=None,
)
)
return model_response