chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
VertexAIModelRoute,
|
||||
get_vertex_ai_model_route,
|
||||
)
|
||||
|
||||
from .cost_calculator import cost_calculator
|
||||
from .vertex_gemini_transformation import VertexAIGeminiImageEditConfig
|
||||
from .vertex_imagen_transformation import VertexAIImagenImageEditConfig
|
||||
|
||||
__all__ = [
|
||||
"VertexAIGeminiImageEditConfig",
|
||||
"VertexAIImagenImageEditConfig",
|
||||
"get_vertex_ai_image_edit_config",
|
||||
"cost_calculator",
|
||||
]
|
||||
|
||||
|
||||
def get_vertex_ai_image_edit_config(model: str) -> BaseImageEditConfig:
    """
    Return the image edit config matching a Vertex AI model.

    Routes to the correct transformation class based on the model type:
    - Gemini models are served via the generateContent API
      (VertexAIGeminiImageEditConfig)
    - All other models (e.g. Imagen models such as "imagegeneration@006")
      fall back to the predict API (VertexAIImagenImageEditConfig)

    Args:
        model: The model name (e.g., "gemini-2.5-flash", "imagegeneration@006")

    Returns:
        BaseImageEditConfig: The appropriate configuration instance
    """
    is_gemini = get_vertex_ai_model_route(model) == VertexAIModelRoute.GEMINI
    # Non-Gemini routes (NON_GEMINI models like imagegeneration@006) default
    # to the Imagen predict-API config.
    return (
        VertexAIGeminiImageEditConfig()
        if is_gemini
        else VertexAIImagenImageEditConfig()
    )
|
||||
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Vertex AI Image Edit Cost Calculator
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: Any,
) -> float:
    """
    Vertex AI image edit cost calculator.

    Mirrors image generation pricing: charge per returned image based on
    model metadata (`output_cost_per_image`).

    Args:
        model: The Vertex AI model name used for the edit request.
        image_response: The completed response; must be an ``ImageResponse``.

    Returns:
        float: ``output_cost_per_image`` times the number of returned images.

    Raises:
        ValueError: If ``image_response`` is not an ``ImageResponse``.
    """
    # Fail fast on an invalid response type BEFORE the model-info lookup —
    # previously the lookup ran first, wasting work on inputs that would be
    # rejected anyway.
    if not isinstance(image_response, ImageResponse):
        raise ValueError(
            f"image_response must be of type ImageResponse got type={type(image_response)}"
        )

    model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider="vertex_ai",
    )

    # Models without image pricing metadata are treated as free (0.0).
    output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0

    num_images = len(image_response.data or [])
    return output_cost_per_image * num_images
|
||||
@@ -0,0 +1,298 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from io import BufferedReader, BytesIO
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
from litellm.images.utils import ImageEditRequestUtils
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse, OpenAIImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIGeminiImageEditConfig(BaseImageEditConfig, VertexLLM):
    """
    Vertex AI Gemini Image Edit Configuration.

    Uses the generateContent API for Gemini models on Vertex AI: the input
    image(s) and the edit prompt are sent as inline parts of a single user
    turn, and edited images come back as base64 ``inlineData`` parts.
    """

    # Only the OpenAI "size" parameter is supported; it is translated to a
    # Gemini aspect ratio in map_openai_params. Everything else is dropped.
    SUPPORTED_PARAMS: List[str] = ["size"]

    def __init__(self) -> None:
        # Explicitly initialize both bases — this class uses multiple
        # inheritance and does not rely on cooperative super() in the bases.
        BaseImageEditConfig.__init__(self)
        VertexLLM.__init__(self)

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return a fresh copy of the supported OpenAI param names."""
        return list(self.SUPPORTED_PARAMS)

    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict[str, Any]:
        """Map supported OpenAI params to their Gemini request equivalents.

        Unsupported params are silently filtered out regardless of
        ``drop_params`` (the flag is currently unused here).
        """
        supported_params = self.get_supported_openai_params(model)
        filtered_params = {
            key: value
            for key, value in image_edit_optional_params.items()
            if key in supported_params
        }

        mapped_params: Dict[str, Any] = {}

        # OpenAI "size" (e.g. "1024x1024") -> Gemini "aspectRatio" (e.g. "1:1")
        if "size" in filtered_params:
            mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(
                filtered_params["size"]  # type: ignore[arg-type]
            )

        return mapped_params

    def _resolve_vertex_project(self) -> Optional[str]:
        """Resolve the GCP project id.

        Order: instance attribute -> env var -> litellm module setting ->
        secret manager.
        """
        return (
            getattr(self, "_vertex_project", None)
            or os.environ.get("VERTEXAI_PROJECT")
            or getattr(litellm, "vertex_project", None)
            or get_secret_str("VERTEXAI_PROJECT")
        )

    def _resolve_vertex_location(self) -> Optional[str]:
        """Resolve the Vertex AI region, accepting both env-var spellings."""
        return (
            getattr(self, "_vertex_location", None)
            or os.environ.get("VERTEXAI_LOCATION")
            or os.environ.get("VERTEX_LOCATION")
            or getattr(litellm, "vertex_location", None)
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )

    def _resolve_vertex_credentials(self) -> Optional[str]:
        """Resolve service-account credentials (path or JSON string).

        Falls back to GOOGLE_APPLICATION_CREDENTIALS before the secret
        manager lookup.
        """
        return (
            getattr(self, "_vertex_credentials", None)
            or os.environ.get("VERTEXAI_CREDENTIALS")
            or getattr(litellm, "vertex_credentials", None)
            or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Attach a Vertex AI access token to ``headers``.

        When a custom ``api_base`` is supplied (directly or via
        ``litellm_params``) the headers are returned untouched — no
        credentials are resolved. ``api_key`` is accepted for interface
        compatibility but not used.
        """
        headers = headers or {}
        litellm_params = litellm_params or {}

        # If a custom api_base is provided, skip credential validation.
        # This allows users to use proxies or mock endpoints without needing
        # Vertex AI credentials.
        _api_base = litellm_params.get("api_base") or api_base
        if _api_base is not None:
            return headers

        # First check litellm_params (where vertex_ai_project /
        # vertex_ai_credentials are passed), then fall back to environment
        # variables and other sources.
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_credentials = (
            self.safe_get_vertex_ai_credentials(litellm_params)
            or self._resolve_vertex_credentials()
        )
        access_token, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        return self.set_headers(access_token, headers)

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for the Vertex AI Gemini generateContent API.

        Raises:
            ValueError: If no project or location can be resolved (and no
                custom ``api_base`` was given).
        """
        # Use the model name as provided, stripping any "vertex_ai/" prefix.
        model_name = model
        if model.startswith("vertex_ai/"):
            model_name = model.replace("vertex_ai/", "")

        # If a custom api_base is provided, use it directly (proxies / mock
        # endpoints). NOTE(review): the :generateContent suffix is NOT
        # appended in this case — the api_base must already be a complete
        # endpoint; confirm with callers.
        if api_base:
            return api_base.rstrip("/")

        # First check litellm_params (where vertex_ai_project /
        # vertex_ai_location are passed), then fall back to environment
        # variables and other sources.
        vertex_project = (
            self.safe_get_vertex_ai_project(litellm_params)
            or self._resolve_vertex_project()
        )
        vertex_location = (
            self.safe_get_vertex_ai_location(litellm_params)
            or self._resolve_vertex_location()
        )

        if not vertex_project or not vertex_location:
            raise ValueError(
                "vertex_project and vertex_location are required for Vertex AI"
            )

        base_url = get_vertex_base_url(vertex_location)

        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:generateContent"

    def transform_image_edit_request(  # type: ignore[override]
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict[str, Any],
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict[str, Any], Optional[RequestFiles]]:
        """Build the generateContent request body for an image edit.

        Returns a ``(payload, files)`` tuple. NOTE(review): despite the
        declared Dict return type, ``payload`` is actually a JSON *string*
        (``json.dumps`` result) hidden behind a cast — confirm the caller
        expects pre-serialized data before changing this.

        Raises:
            ValueError: If no usable input image is provided.
        """
        inline_parts = self._prepare_inline_image_parts(image) if image else []
        if not inline_parts:
            raise ValueError("Vertex AI Gemini image edit requires at least one image.")

        # Build parts list with image(s) first, then the prompt (if provided).
        parts = inline_parts.copy()
        if prompt is not None and prompt != "":
            parts.append({"text": prompt})

        # Correct format for Vertex AI Gemini image editing: a single USER turn.
        contents = {"role": "USER", "parts": parts}

        request_body: Dict[str, Any] = {"contents": contents}

        # Generation config: request image output explicitly.
        generation_config: Dict[str, Any] = {"response_modalities": ["IMAGE"]}

        # Add image-specific configuration (aspect ratio mapped from "size").
        image_config: Dict[str, Any] = {}
        if "aspectRatio" in image_edit_optional_request_params:
            image_config["aspect_ratio"] = image_edit_optional_request_params[
                "aspectRatio"
            ]

        if image_config:
            generation_config["image_config"] = image_config

        request_body["generationConfig"] = generation_config

        payload: Any = json.dumps(request_body)
        empty_files = cast(RequestFiles, [])
        return cast(
            Tuple[Dict[str, Any], Optional[RequestFiles]], (payload, empty_files)
        )

    def transform_image_edit_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: Any,
    ) -> ImageResponse:
        """Convert a generateContent response into an OpenAI-style ImageResponse.

        Collects every ``inlineData`` part (base64 image data) from every
        candidate. Candidates/parts without image data are skipped silently.

        Raises:
            The provider error class if the response body is not valid JSON.
        """
        model_response = ImageResponse()
        try:
            response_json = raw_response.json()
        except Exception as exc:
            raise self.get_error_class(
                error_message=f"Error transforming image edit response: {exc}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            )

        candidates = response_json.get("candidates", [])
        data_list: List[ImageObject] = []

        for candidate in candidates:
            content = candidate.get("content", {})
            parts = content.get("parts", [])
            for part in parts:
                inline_data = part.get("inlineData")
                if inline_data and inline_data.get("data"):
                    data_list.append(
                        ImageObject(
                            b64_json=inline_data["data"],
                            url=None,
                        )
                    )

        model_response.data = cast(List[OpenAIImage], data_list)
        return model_response

    def _map_size_to_aspect_ratio(self, size: str) -> str:
        """Map OpenAI size format to Gemini aspect ratio format.

        Unknown sizes fall back to square ("1:1").
        """
        aspect_ratio_map = {
            "1024x1024": "1:1",
            "1792x1024": "16:9",
            "1024x1792": "9:16",
            "1280x896": "4:3",
            "896x1280": "3:4",
        }
        return aspect_ratio_map.get(size, "1:1")

    def _prepare_inline_image_parts(
        self, image: Union[FileTypes, List[FileTypes]]
    ) -> List[Dict[str, Any]]:
        """Encode one or more input images as Gemini ``inlineData`` parts.

        ``None`` entries in a list are skipped. The MIME type is detected via
        ImageEditRequestUtils; image bytes are base64-encoded.
        """
        images: List[FileTypes]
        if isinstance(image, list):
            images = image
        else:
            images = [image]

        inline_parts: List[Dict[str, Any]] = []
        for img in images:
            if img is None:
                continue

            mime_type = ImageEditRequestUtils.get_image_content_type(img)
            image_bytes = self._read_all_bytes(img)
            inline_parts.append(
                {
                    "inlineData": {
                        "mimeType": mime_type,
                        "data": base64.b64encode(image_bytes).decode("utf-8"),
                    }
                }
            )

        return inline_parts

    def _read_all_bytes(self, image: FileTypes) -> bytes:
        """Read the full contents of a supported image input as bytes.

        Supports raw ``bytes``, ``BytesIO``, and ``BufferedReader``; streams
        are read from position 0 and their original position is restored.
        NOTE(review): narrower than the Imagen variant — str/Path inputs are
        rejected here; confirm whether that is intentional.

        Raises:
            ValueError: For any other input type.
        """
        if isinstance(image, bytes):
            return image
        if isinstance(image, BytesIO):
            current_pos = image.tell()
            image.seek(0)
            data = image.read()
            image.seek(current_pos)
            return data
        if isinstance(image, BufferedReader):
            current_pos = image.tell()
            image.seek(0)
            data = image.read()
            image.seek(current_pos)
            return data
        raise ValueError("Unsupported image type for Vertex AI Gemini image edit.")
|
||||
@@ -0,0 +1,365 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from io import BufferedRandom, BufferedReader, BytesIO
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.vertex_ai.common_utils import get_vertex_base_url
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse, OpenAIImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIImagenImageEditConfig(BaseImageEditConfig, VertexLLM):
    """
    Vertex AI Imagen Image Edit Configuration.

    Uses the predict API for Imagen models on Vertex AI: input images (and an
    optional mask) are sent as base64 "reference images", and edited images
    come back as ``bytesBase64Encoded`` predictions.
    """

    # Supported OpenAI params: "n" -> sampleCount, "size" -> aspectRatio,
    # "mask" -> a MASK reference image. Everything else is dropped.
    SUPPORTED_PARAMS: List[str] = ["n", "size", "mask"]

    def __init__(self) -> None:
        # Explicitly initialize both bases — multiple inheritance without
        # cooperative super() in the base classes.
        BaseImageEditConfig.__init__(self)
        VertexLLM.__init__(self)

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return a fresh copy of the supported OpenAI param names."""
        return list(self.SUPPORTED_PARAMS)

    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict[str, Any]:
        """Map supported OpenAI params to their Imagen request equivalents.

        Unsupported params are silently filtered out regardless of
        ``drop_params`` (the flag is currently unused here).
        """
        supported_params = self.get_supported_openai_params(model)
        filtered_params = {
            key: value
            for key, value in image_edit_optional_params.items()
            if key in supported_params
        }

        mapped_params: Dict[str, Any] = {}

        # Map OpenAI parameters to Imagen format
        if "n" in filtered_params:
            mapped_params["sampleCount"] = filtered_params["n"]

        # OpenAI "size" (e.g. "1024x1024") -> Imagen "aspectRatio" (e.g. "1:1")
        if "size" in filtered_params:
            mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(
                filtered_params["size"]  # type: ignore[arg-type]
            )

        # Pass the mask through unchanged; it is consumed later by
        # _prepare_reference_images.
        if "mask" in filtered_params:
            mapped_params["mask"] = filtered_params["mask"]

        return mapped_params

    def _resolve_vertex_project(self) -> Optional[str]:
        """Resolve the GCP project id.

        Order: instance attribute -> env var -> litellm module setting ->
        secret manager.
        """
        return (
            getattr(self, "_vertex_project", None)
            or os.environ.get("VERTEXAI_PROJECT")
            or getattr(litellm, "vertex_project", None)
            or get_secret_str("VERTEXAI_PROJECT")
        )

    def _resolve_vertex_location(self) -> Optional[str]:
        """Resolve the Vertex AI region, accepting both env-var spellings."""
        return (
            getattr(self, "_vertex_location", None)
            or os.environ.get("VERTEXAI_LOCATION")
            or os.environ.get("VERTEX_LOCATION")
            or getattr(litellm, "vertex_location", None)
            or get_secret_str("VERTEXAI_LOCATION")
            or get_secret_str("VERTEX_LOCATION")
        )

    def _resolve_vertex_credentials(self) -> Optional[str]:
        """Resolve service-account credentials (path or JSON string).

        Falls back to GOOGLE_APPLICATION_CREDENTIALS before the secret
        manager lookup.
        """
        return (
            getattr(self, "_vertex_credentials", None)
            or os.environ.get("VERTEXAI_CREDENTIALS")
            or getattr(litellm, "vertex_credentials", None)
            or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Attach a Vertex AI access token to ``headers``.

        ``api_key`` is accepted for interface compatibility but not used.
        NOTE(review): unlike the Gemini config, this variant takes no
        litellm_params / api_base and always resolves credentials from
        instance/env/secret sources — confirm whether per-request
        vertex_ai_project / vertex_ai_credentials should be honored here too.
        """
        headers = headers or {}
        vertex_project = self._resolve_vertex_project()
        vertex_credentials = self._resolve_vertex_credentials()
        access_token, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        return self.set_headers(access_token, headers)

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for the Vertex AI Imagen predict API.

        NOTE(review): ``litellm_params`` is ignored here (the Gemini config
        consults it first) — confirm whether per-request project/location
        should take precedence. Project/location are required even when a
        custom ``api_base`` is given, since they appear in the URL path.

        Raises:
            ValueError: If no project or location can be resolved.
        """
        vertex_project = self._resolve_vertex_project()
        vertex_location = self._resolve_vertex_location()

        if not vertex_project or not vertex_location:
            raise ValueError(
                "vertex_project and vertex_location are required for Vertex AI"
            )

        # Use the model name as provided, stripping any "vertex_ai/" prefix.
        model_name = model
        if model.startswith("vertex_ai/"):
            model_name = model.replace("vertex_ai/", "")

        # Custom api_base replaces only the host portion; the standard
        # projects/locations/... path is still appended below.
        if api_base:
            base_url = api_base.rstrip("/")
        else:
            base_url = get_vertex_base_url(vertex_location)

        return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:predict"

    def transform_image_edit_request(  # type: ignore[override]
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict[str, Any],
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict[str, Any], Optional[RequestFiles]]:
        """Build the predict-API request body for an Imagen image edit.

        Returns a ``(payload, files)`` tuple. NOTE(review): despite the
        declared Dict return type, ``payload`` is actually a JSON *string*
        (``json.dumps`` result) hidden behind a cast — confirm the caller
        expects pre-serialized data before changing this.

        Raises:
            ValueError: If no reference image is provided, or the prompt is
                missing (Imagen requires a prompt, unlike Gemini).
        """
        # Prepare reference images in the correct Imagen format
        if image is None:
            raise ValueError(
                "Vertex AI Imagen image edit requires at least one reference image."
            )
        reference_images = self._prepare_reference_images(
            image, image_edit_optional_request_params
        )
        if not reference_images:
            raise ValueError(
                "Vertex AI Imagen image edit requires at least one reference image."
            )

        if prompt is None:
            raise ValueError("Vertex AI Imagen image edit requires a prompt.")

        # Correct Imagen instances format: one instance carrying the prompt
        # and all reference images (including any mask reference).
        instances = [{"prompt": prompt, "referenceImages": reference_images}]

        # Extract OpenAI parameters and set sensible defaults for Vertex
        # AI-specific parameters (not exposed to users).
        sample_count = image_edit_optional_request_params.get("sampleCount", 1)
        edit_mode = "EDIT_MODE_INPAINT_INSERTION"  # Default edit mode
        base_steps = 50  # Default number of steps

        # Imagen parameters with correct structure
        parameters = {
            "sampleCount": sample_count,
            "editMode": edit_mode,
            "editConfig": {"baseSteps": base_steps},
        }

        # Set default values for Vertex AI-specific parameters (not
        # configurable by users via the OpenAI API).
        parameters["guidanceScale"] = 7.5  # Default guidance scale
        # NOTE(review): this serializes as JSON null — confirm the predict
        # API treats a null seed as "choose randomly" rather than rejecting it.
        parameters["seed"] = None  # Let Vertex AI choose random seed

        request_body: Dict[str, Any] = {
            "instances": instances,
            "parameters": parameters,
        }

        payload: Any = json.dumps(request_body)
        empty_files = cast(RequestFiles, [])
        return cast(
            Tuple[Dict[str, Any], Optional[RequestFiles]], (payload, empty_files)
        )

    def transform_image_edit_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: Any,
    ) -> ImageResponse:
        """Convert a predict-API response into an OpenAI-style ImageResponse.

        Collects every ``bytesBase64Encoded`` prediction; predictions without
        image data are skipped silently.

        Raises:
            The provider error class if the response body is not valid JSON.
        """
        model_response = ImageResponse()
        try:
            response_json = raw_response.json()
        except Exception as exc:
            raise self.get_error_class(
                error_message=f"Error transforming image edit response: {exc}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            )

        predictions = response_json.get("predictions", [])
        data_list: List[ImageObject] = []

        for prediction in predictions:
            # Imagen returns images as bytesBase64Encoded
            if "bytesBase64Encoded" in prediction:
                data_list.append(
                    ImageObject(
                        b64_json=prediction["bytesBase64Encoded"],
                        url=None,
                    )
                )

        model_response.data = cast(List[OpenAIImage], data_list)
        return model_response

    def _map_size_to_aspect_ratio(self, size: str) -> str:
        """Map OpenAI size format to Imagen aspect ratio format.

        Unknown sizes fall back to square ("1:1").
        """
        aspect_ratio_map = {
            "1024x1024": "1:1",
            "1792x1024": "16:9",
            "1024x1792": "9:16",
            "1280x896": "4:3",
            "896x1280": "3:4",
        }
        return aspect_ratio_map.get(size, "1:1")

    def _prepare_reference_images(
        self,
        image: Union[FileTypes, List[FileTypes]],
        image_edit_optional_request_params: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        """
        Prepare reference images in the correct Imagen API format.

        Each input image becomes a REFERENCE_TYPE_RAW entry with a 1-based
        ``referenceId``; ``None`` entries are skipped. If a "mask" is present
        in the optional params, it is appended as a REFERENCE_TYPE_MASK entry
        for inpainting.
        """
        images: List[FileTypes]
        if isinstance(image, list):
            images = image
        else:
            images = [image]

        reference_images: List[Dict[str, Any]] = []

        for idx, img in enumerate(images):
            if img is None:
                continue

            image_bytes = self._read_all_bytes(img)
            base64_data = base64.b64encode(image_bytes).decode("utf-8")

            # Create reference image structure
            reference_image = {
                "referenceType": "REFERENCE_TYPE_RAW",
                "referenceId": idx + 1,
                "referenceImage": {"bytesBase64Encoded": base64_data},
            }

            reference_images.append(reference_image)

        # Handle mask image if provided (for inpainting)
        mask_image = image_edit_optional_request_params.get("mask")
        if mask_image is not None:
            mask_bytes = self._read_all_bytes(mask_image)
            mask_base64 = base64.b64encode(mask_bytes).decode("utf-8")

            mask_reference = {
                "referenceType": "REFERENCE_TYPE_MASK",
                "referenceId": len(reference_images) + 1,
                "referenceImage": {"bytesBase64Encoded": mask_base64},
                "maskImageConfig": {
                    "maskMode": "MASK_MODE_USER_PROVIDED",
                    "dilation": 0.03,  # Default dilation value (not configurable via OpenAI API)
                },
            }
            reference_images.append(mask_reference)

        return reference_images

    def _read_all_bytes(
        self, image: Any, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
    ) -> bytes:
        """Read the full contents of a (possibly wrapped) image input as bytes.

        Handles, in order: lists/tuples (first non-None element), dicts
        (``data``/``bytes``/``content`` keys — str values are tried as
        base64 — or a ``path`` key), raw bytes/bytearray, BytesIO and
        buffered file objects (read from 0, position restored when possible),
        filesystem paths, and finally any object with a ``read()`` method.
        Recursion through containers is bounded by ``max_depth``.

        Raises:
            ValueError: On exceeding ``max_depth``, a missing path, an empty
                container, or an unsupported input type.
        """
        if depth > max_depth:
            raise ValueError(
                f"Max recursion depth {max_depth} reached while reading image bytes for Vertex AI Imagen image edit."
            )

        if isinstance(image, (list, tuple)):
            # Only the first non-None element is used; the rest are ignored.
            for item in image:
                if item is not None:
                    return self._read_all_bytes(
                        item, depth=depth + 1, max_depth=max_depth
                    )
            raise ValueError("Unsupported image type for Vertex AI Imagen image edit.")

        if isinstance(image, dict):
            for key in ("data", "bytes", "content"):
                if key in image and image[key] is not None:
                    value = image[key]
                    if isinstance(value, str):
                        # Try base64 first; on failure, fall through to the
                        # next candidate key.
                        try:
                            return base64.b64decode(value)
                        except Exception:
                            continue
                    return self._read_all_bytes(
                        value, depth=depth + 1, max_depth=max_depth
                    )
            if "path" in image:
                return self._read_all_bytes(
                    image["path"], depth=depth + 1, max_depth=max_depth
                )

        if isinstance(image, bytes):
            return image
        if isinstance(image, bytearray):
            return bytes(image)
        if isinstance(image, BytesIO):
            current_pos = image.tell()
            image.seek(0)
            data = image.read()
            image.seek(current_pos)
            return data
        if isinstance(image, (BufferedReader, BufferedRandom)):
            # tell() may fail on non-seekable streams; only seek/restore when
            # the position could be captured.
            stream_pos: Optional[int] = None
            try:
                stream_pos = image.tell()
            except Exception:
                stream_pos = None
            if stream_pos is not None:
                image.seek(0)
            data = image.read()
            if stream_pos is not None:
                image.seek(stream_pos)
            return data
        if isinstance(image, (str, Path)):
            path_obj = Path(image)
            if not path_obj.exists():
                raise ValueError(
                    f"Mask/image path does not exist for Vertex AI Imagen image edit: {path_obj}"
                )
            return path_obj.read_bytes()
        if hasattr(image, "read"):
            # Last-resort duck typing: any readable object; text reads are
            # encoded as UTF-8.
            data = image.read()
            if isinstance(data, str):
                data = data.encode("utf-8")
            return data
        raise ValueError(
            f"Unsupported image type for Vertex AI Imagen image edit. Got type={type(image)}"
        )
|
||||
Reference in New Issue
Block a user