diff --git a/google/generativeai/client.py b/google/generativeai/client.py index a75643f1a..f53007be3 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -3,10 +3,11 @@ import os import contextlib import inspect +import collections import dataclasses import pathlib from typing import Any, cast -from collections.abc import Sequence +from collections.abc import Sequence, Mapping import httplib2 from io import IOBase @@ -23,6 +24,11 @@ import googleapiclient.http import googleapiclient.discovery +from google.protobuf import struct_pb2 + +from proto.marshal.collections import maps +from proto.marshal.collections import repeated + try: from google.generativeai import version @@ -130,6 +136,70 @@ async def create_file(self, *args, **kwargs): ) +# This is to get around https://github.com/googleapis/proto-plus-python/issues/488 +def to_value(value) -> struct_pb2.Value: + """Return a protobuf Value object representing this value.""" + if isinstance(value, struct_pb2.Value): + return value + if value is None: + return struct_pb2.Value(null_value=0) + if isinstance(value, bool): + return struct_pb2.Value(bool_value=value) + if isinstance(value, (int, float)): + return struct_pb2.Value(number_value=float(value)) + if isinstance(value, str): + return struct_pb2.Value(string_value=value) + if isinstance(value, collections.abc.Sequence): + return struct_pb2.Value(list_value=to_list_value(value)) + if isinstance(value, collections.abc.Mapping): + return struct_pb2.Value(struct_value=to_mapping_value(value)) + raise ValueError("Unable to coerce value: %r" % value) + + +def to_list_value(value) -> struct_pb2.ListValue: + # We got a proto, or else something we sent originally. + # Preserve the instance we have. + if isinstance(value, struct_pb2.ListValue): + return value + if isinstance(value, repeated.RepeatedComposite): + return struct_pb2.ListValue(values=[v for v in value.pb]) + + # We got a list (or something list-like); convert it. + return struct_pb2.ListValue(values=[to_value(v) for v in value]) + + +def to_mapping_value(value) -> struct_pb2.Struct: + # We got a proto, or else something we sent originally. + # Preserve the instance we have. + if isinstance(value, struct_pb2.Struct): + return value + if isinstance(value, maps.MapComposite): + return struct_pb2.Struct( + fields={k: v for k, v in value.pb.items()}, + ) + + # We got a dict (or something dict-like); convert it. + return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()}) + + +class PredictionServiceClient(glm.PredictionServiceClient): + def predict(self, model=None, instances=None, parameters=None): + pr = protos.PredictRequest.pb() + request = pr( + model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters) + ) + return super().predict(request) + + +class PredictionServiceAsyncClient(glm.PredictionServiceAsyncClient): + async def predict(self, model=None, instances=None, parameters=None): + pr = protos.PredictRequest.pb() + request = pr( + model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters) + ) + return await super().predict(request) + + @dataclasses.dataclass class _ClientManager: client_config: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -220,15 +290,20 @@ def configure( self.clients = {} def make_client(self, name): - if name == "file": - cls = FileServiceClient - elif name == "file_async": - cls = FileServiceAsyncClient - elif name.endswith("_async"): - name = name.split("_")[0] - cls = getattr(glm, name.title() + "ServiceAsyncClient") - else: - cls = getattr(glm, name.title() + "ServiceClient") + local_clients = { + "file": FileServiceClient, + "file_async": FileServiceAsyncClient, + "prediction": PredictionServiceClient, + "prediction_async": PredictionServiceAsyncClient, + } + cls = local_clients.get(name, None) + + if cls is None: + if name.endswith("_async"): + name = name.split("_")[0] + cls = getattr(glm, name.title() + "ServiceAsyncClient") + else: + cls = getattr(glm, name.title() + "ServiceClient") # Attempt to configure using defaults. if not self.client_config: diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 23241a536..3eeababbb 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -16,45 +16,16 @@ from __future__ import annotations from collections.abc import Iterable, Mapping, Sequence -import io import inspect -import mimetypes -import pathlib -import typing from typing import Any, Callable, Union from typing_extensions import TypedDict import pydantic from google.generativeai.types import file_types +from google.generativeai.types.image_types import _image_types from google.generativeai import protos -if typing.TYPE_CHECKING: - import PIL.Image - import PIL.ImageFile - import IPython.display - - IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image) - ImageType = PIL.Image.Image | IPython.display.Image -else: - IMAGE_TYPES = () - try: - import PIL.Image - import PIL.ImageFile - - IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,) - except ImportError: - PIL = None - - try: - import IPython.display - - IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,) - except ImportError: - IPython = None - - ImageType = Union["PIL.Image.Image", "IPython.display.Image"] - __all__ = [ "BlobDict", @@ -97,62 +68,6 @@ def to_mode(x: ModeOptions) -> Mode: return _MODE[x] -def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob: - # If the image is a local file, return a file-based blob without any modification. - # Otherwise, return a lossless WebP blob (same quality with optimized size). - def file_blob(image: PIL.Image.Image) -> protos.Blob | None: - if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None: - return None - filename = str(image.filename) - if not pathlib.Path(filename).is_file(): - return None - - mime_type = image.get_format_mimetype() - image_bytes = pathlib.Path(filename).read_bytes() - - return protos.Blob(mime_type=mime_type, data=image_bytes) - - def webp_blob(image: PIL.Image.Image) -> protos.Blob: - # Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp - image_io = io.BytesIO() - image.save(image_io, format="webp", lossless=True) - image_io.seek(0) - - mime_type = "image/webp" - image_bytes = image_io.read() - - return protos.Blob(mime_type=mime_type, data=image_bytes) - - return file_blob(image) or webp_blob(image) - - -def image_to_blob(image: ImageType) -> protos.Blob: - if PIL is not None: - if isinstance(image, PIL.Image.Image): - return _pil_to_blob(image) - - if IPython is not None: - if isinstance(image, IPython.display.Image): - name = image.filename - if name is None: - raise ValueError( - "Conversion failed. The `IPython.display.Image` can only be converted if " - "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')." - ) - mime_type, _ = mimetypes.guess_type(name) - if mime_type is None: - mime_type = "image/unknown" - - return protos.Blob(mime_type=mime_type, data=image.data) - - raise TypeError( - "Image conversion failed. The input was expected to be of type `Image` " - "(either `PIL.Image.Image` or `IPython.display.Image`).\n" - f"However, received an object of type: {type(image)}.\n" - f"Object Value: {image}" - ) - - class BlobDict(TypedDict): mime_type: str data: bytes @@ -189,12 +104,7 @@ def is_blob_dict(d): return "mime_type" in d and "data" in d -if typing.TYPE_CHECKING: - BlobType = Union[ - protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image - ] # Any for the images -else: - BlobType = Union[protos.Blob, BlobDict, Any] +BlobType = Union[protos.Blob, BlobDict, _image_types.ImageType] # Any for the images def to_blob(blob: BlobType) -> protos.Blob: @@ -203,8 +113,8 @@ def to_blob(blob: BlobType) -> protos.Blob: if isinstance(blob, protos.Blob): return blob - elif isinstance(blob, IMAGE_TYPES): - return image_to_blob(blob) + elif isinstance(blob, _image_types.IMAGE_TYPES): + return _image_types.image_to_blob(blob) else: if isinstance(blob, Mapping): raise KeyError( diff --git a/google/generativeai/types/image_types/__init__.py b/google/generativeai/types/image_types/__init__.py new file mode 100644 index 000000000..6e9d0a3fe --- /dev/null +++ b/google/generativeai/types/image_types/__init__.py @@ -0,0 +1 @@ +from google.generativeai.types.image_types._image_types import * diff --git a/google/generativeai/types/image_types/_image_types.py b/google/generativeai/types/image_types/_image_types.py new file mode 100644 index 000000000..ddfea057f --- /dev/null +++ b/google/generativeai/types/image_types/_image_types.py @@ -0,0 +1,332 @@ +import base64 +import io +import json +import mimetypes +import os +import pathlib +import typing +from typing import Any, Dict, Optional, Union + +from google.generativeai import protos +from google.generativeai import client + +# pylint: disable=g-import-not-at-top +if typing.TYPE_CHECKING: + import PIL.Image + import PIL.ImageFile + import IPython.display + + IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image) + ImageType = PIL.Image.Image | IPython.display.Image +else: + IMAGE_TYPES = () + try: + import PIL.Image + import PIL.ImageFile + + IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,) + except ImportError: + PIL = None + + try: + import IPython.display + + IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,) + except ImportError: + IPython = None + + ImageType = Union["Image", "PIL.Image.Image", "IPython.display.Image"] +# pylint: enable=g-import-not-at-top + +__all__ = ["Image", "GeneratedImage", "check_watermark", "CheckWatermarkResult", "ImageType"] + + +def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob: + # If the image is a local file, return a file-based blob without any modification. + # Otherwise, return a lossless WebP blob (same quality with optimized size). + def file_blob(image: PIL.Image.Image) -> Union[protos.Blob, None]: + if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None: + return None + filename = str(image.filename) + if not pathlib.Path(filename).is_file(): + return None + + mime_type = image.get_format_mimetype() + image_bytes = pathlib.Path(filename).read_bytes() + + return protos.Blob(mime_type=mime_type, data=image_bytes) + + def webp_blob(image: PIL.Image.Image) -> protos.Blob: + # Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp + image_io = io.BytesIO() + image.save(image_io, format="webp", lossless=True) + image_io.seek(0) + + mime_type = "image/webp" + image_bytes = image_io.read() + + return protos.Blob(mime_type=mime_type, data=image_bytes) + + return file_blob(image) or webp_blob(image) + + +def image_to_blob(image: ImageType) -> protos.Blob: + if PIL is not None: + if isinstance(image, PIL.Image.Image): + return _pil_to_blob(image) + + if IPython is not None: + if isinstance(image, IPython.display.Image): + name = image.filename + if name is None: + raise ValueError( + "Conversion failed. The `IPython.display.Image` can only be converted if " + "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')." + ) + mime_type, _ = mimetypes.guess_type(name) + if mime_type is None: + mime_type = "image/unknown" + + return protos.Blob(mime_type=mime_type, data=image.data) + + if isinstance(image, Image): + return protos.Blob(mime_type=image._mime_type, data=image._image_bytes) + + raise TypeError( + "Image conversion failed. The input was expected to be of type `Image` " + "(either `PIL.Image.Image` or `IPython.display.Image`).\n" + f"However, received an object of type: {type(image)}.\n" + f"Object Value: {image}" + ) + + +class CheckWatermarkResult: + def __init__(self, predictions): + self._predictions = predictions + + @property + def decision(self): + return self._predictions[0]["decision"] + + def __str__(self): + return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])" + + def __bool__(self): + decision = self.decision + if decision == "ACCEPT": + return True + elif decision == "REJECT": + return False + else: + raise ValueError(f"Unrecognized result: {decision}") + + +def check_watermark( + img: Union[pathlib.Path, ImageType], model_id: str = "models/image-verification-001" +) -> "CheckWatermarkResult": + """Checks if an image has a Google-AI watermark. + + Args: + img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPython.display.Image`, or `google.generativeai.Image`. + model_id: Which version of the image-verification model to send the image to. + + Returns: + + """ + if isinstance(img, Image): + pass + elif isinstance(img, pathlib.Path): + img = Image.load_from_file(img) + elif IPython.display is not None and isinstance(img, IPython.display.Image): + img = Image(image_bytes=img.data) + elif PIL.Image is not None and isinstance(img, PIL.Image.Image): + blob = _pil_to_blob(img) + img = Image(image_bytes=blob.data) + elif isinstance(img, protos.Blob): + img = Image(image_bytes=img.data) + else: + raise TypeError( + f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}" + ) + + prediction_client = client.get_default_prediction_client() + if not model_id.startswith("models/"): + model_id = f"models/{model_id}" + + instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}} + parameters = {"watermarkVerification": True} + + response = prediction_client.predict( + model=model_id, instances=[instance], parameters=parameters + ) + + return CheckWatermarkResult(response.predictions) + + +class Image: + """Image.""" + + __module__ = "vertexai.vision_models" + + _loaded_bytes: Optional[bytes] = None + _loaded_image: Optional["PIL.Image.Image"] = None + + def __init__( + self, + image_bytes: Optional[bytes], + ): + """Creates an `Image` object. + + Args: + image_bytes: Image file bytes. Image can be in PNG or JPEG format. + """ + self._image_bytes = image_bytes + + @staticmethod + def load_from_file(location: os.PathLike) -> "Image": + """Loads image from local file. + + Args: + location: Local path from where to load + the image. + + Returns: + Loaded image as an `Image` object. + """ + # Load image from local path + image_bytes = pathlib.Path(location).read_bytes() + image = Image(image_bytes=image_bytes) + return image + + @property + def _image_bytes(self) -> bytes: + return self._loaded_bytes + + @_image_bytes.setter + def _image_bytes(self, value: bytes): + self._loaded_bytes = value + + @property + def _pil_image(self) -> "PIL.Image.Image": # type: ignore + if self._loaded_image is None: + if not PIL: + raise RuntimeError( + "The PIL module is not available. Please install the Pillow package." + ) + self._loaded_image = PIL.Image.open(io.BytesIO(self._image_bytes)) + return self._loaded_image + + @property + def _size(self): + return self._pil_image.size + + @property + def _mime_type(self) -> str: + """Returns the MIME type of the image.""" + import PIL + + return PIL.Image.MIME.get(self._pil_image.format, "image/jpeg") + + def show(self): + """Shows the image. + + This method only works when in a notebook environment. + """ + if PIL and IPython: + IPython.display.display(self._pil_image) + + def save(self, location: str): + """Saves image to a file. + + Args: + location: Local path where to save the image. + """ + pathlib.Path(location).write_bytes(self._image_bytes) + + def _as_base64_string(self) -> str: + """Encodes image using the base64 encoding. + + Returns: + Base64 encoding of the image as a string. + """ + # ! b64encode returns `bytes` object, not `str`. + # We need to convert `bytes` to `str`, otherwise we get service error: + # "received initial metadata size exceeds limit" + return base64.b64encode(self._image_bytes).decode("ascii") + + def _repr_png_(self): + return self._pil_image._repr_png_() # type:ignore + + check_watermark = check_watermark + + +_EXIF_USER_COMMENT_TAG_IDX = 0x9286 +_IMAGE_GENERATION_PARAMETERS_EXIF_KEY = ( + "google.cloud.vertexai.image_generation.image_generation_parameters" +) + + +class GeneratedImage(Image): + """Generated image.""" + + __module__ = "google.generativeai" + + def __init__( + self, + image_bytes: Optional[bytes], + generation_parameters: Dict[str, Any], + ): + """Creates a `GeneratedImage` object. + + Args: + image_bytes: Image file bytes. Image can be in PNG or JPEG format. + generation_parameters: Image generation parameter values. + """ + super().__init__(image_bytes=image_bytes) + self._generation_parameters = generation_parameters + + @property + def generation_parameters(self): + """Image generation parameters as a dictionary.""" + return self._generation_parameters + + @staticmethod + def load_from_file(location: os.PathLike) -> "GeneratedImage": + """Loads image from file. + + Args: + location: Local path from where to load the image. + + Returns: + Loaded image as a `GeneratedImage` object. + """ + base_image = Image.load_from_file(location=location) + exif = base_image._pil_image.getexif() # pylint: disable=protected-access + exif_comment_dict = json.loads(exif[_EXIF_USER_COMMENT_TAG_IDX]) + generation_parameters = exif_comment_dict[_IMAGE_GENERATION_PARAMETERS_EXIF_KEY] + return GeneratedImage( + image_bytes=base_image._image_bytes, # pylint: disable=protected-access + generation_parameters=generation_parameters, + ) + + def save(self, location: str, include_generation_parameters: bool = True): + """Saves image to a file. + + Args: + location: Local path where to save the image. + include_generation_parameters: Whether to include the image + generation parameters in the image's EXIF metadata. + """ + if include_generation_parameters: + if not self._generation_parameters: + raise ValueError("Image does not have generation parameters.") + if not PIL: + raise ValueError("The PIL module is required for saving generation parameters.") + + exif = self._pil_image.getexif() + exif[_EXIF_USER_COMMENT_TAG_IDX] = json.dumps( + {_IMAGE_GENERATION_PARAMETERS_EXIF_KEY: self._generation_parameters} + ) + self._pil_image.save(location, exif=exif) + else: + super().save(location=location) diff --git a/google/generativeai/vision_models/__init__.py b/google/generativeai/vision_models/__init__.py index 2a4a27e32..e1b62d39b 100644 --- a/google/generativeai/vision_models/__init__.py +++ b/google/generativeai/vision_models/__init__.py @@ -14,15 +14,15 @@ # """Classes for working with vision models.""" +from google.generativeai.types.image_types import check_watermark, Image, GeneratedImage + from google.generativeai.vision_models._vision_models import ( - check_watermark, - Image, - GeneratedImage, ImageGenerationModel, ImageGenerationResponse, ) __all__ = [ + "check_watermark", "Image", "GeneratedImage", "ImageGenerationModel", diff --git a/google/generativeai/vision_models/_vision_models.py b/google/generativeai/vision_models/_vision_models.py index 52ec689a9..f89ab86e6 100644 --- a/google/generativeai/vision_models/_vision_models.py +++ b/google/generativeai/vision_models/_vision_models.py @@ -16,88 +16,12 @@ """Classes for working with vision models.""" import base64 -import collections import dataclasses -import io -import json -import os -import pathlib import typing -from typing import Any, Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional from google.generativeai import client -from google.generativeai import protos -from google.generativeai.types import content_types - -from google.protobuf import struct_pb2 - -from proto.marshal.collections import maps -from proto.marshal.collections import repeated - - -# pylint: disable=g-import-not-at-top\ -if typing.TYPE_CHECKING: - from IPython import display as IPython_display -else: - try: - from IPython import display as IPython_display - except ImportError: - IPython_display = None - -if typing.TYPE_CHECKING: - import PIL.Image as PIL_Image -else: - try: - from PIL import Image as PIL_Image - except ImportError: - PIL_Image = None - - -# This is to get around https://github.com/googleapis/proto-plus-python/issues/488 -def to_value(value) -> struct_pb2.Value: - """Return a protobuf Value object representing this value.""" - if isinstance(value, struct_pb2.Value): - return value - if value is None: - return struct_pb2.Value(null_value=0) - if isinstance(value, bool): - return struct_pb2.Value(bool_value=value) - if isinstance(value, (int, float)): - return struct_pb2.Value(number_value=float(value)) - if isinstance(value, str): - return struct_pb2.Value(string_value=value) - if isinstance(value, collections.abc.Sequence): - return struct_pb2.Value(list_value=to_list_value(value)) - if isinstance(value, collections.abc.Mapping): - return struct_pb2.Value(struct_value=to_mapping_value(value)) - raise ValueError("Unable to coerce value: %r" % value) - - -def to_list_value(value) -> struct_pb2.ListValue: - # We got a proto, or else something we sent originally. - # Preserve the instance we have. - if isinstance(value, struct_pb2.ListValue): - return value - if isinstance(value, repeated.RepeatedComposite): - return struct_pb2.ListValue(values=[v for v in value.pb]) - - # We got a list (or something list-like); convert it. - return struct_pb2.ListValue(values=[to_value(v) for v in value]) - - -def to_mapping_value(value) -> struct_pb2.Struct: - # We got a proto, or else something we sent originally. - # Preserve the instance we have. - if isinstance(value, struct_pb2.Struct): - return value - if isinstance(value, maps.MapComposite): - return struct_pb2.Struct( - fields={k: v for k, v in value.pb.items()}, - ) - - # We got a dict (or something dict-like); convert it. - return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()}) - +from google.generativeai.types import image_types AspectRatio = Literal["1:1", "9:16", "16:9", "4:3", "3:4"] ASPECT_RATIOS = AspectRatio.__args__ # type: ignore @@ -111,171 +35,6 @@ def to_mapping_value(value) -> struct_pb2.Struct: PersonGeneration = Literal["dont_allow", "allow_adult"] PERSON_GENERATIONS = PersonGeneration.__args__ # type: ignore -ImageLikeType = Union["Image", pathlib.Path, content_types.ImageType] - - -def check_watermark( - img: ImageLikeType, model_id: str = "models/image-verification-001" -) -> "CheckWatermarkResult": - """Checks if an image has a Google-AI watermark. - - Args: - img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPython.display.Image`, or `google.generativeai.Image`. - model_id: Which version of the image-verification model to send the image to. - - Returns: - - """ - if isinstance(img, Image): - pass - elif isinstance(img, pathlib.Path): - img = Image.load_from_file(img) - elif IPython_display is not None and isinstance(img, IPython_display.Image): - img = Image(image_bytes=img.data) - elif PIL_Image is not None and isinstance(img, PIL_Image.Image): - blob = content_types._pil_to_blob(img) - img = Image(image_bytes=blob.data) - elif isinstance(img, protos.Blob): - img = Image(image_bytes=img.data) - else: - raise TypeError( - f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}" - ) - - prediction_client = client.get_default_prediction_client() - if not model_id.startswith("models/"): - model_id = f"models/{model_id}" - - instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}} - parameters = {"watermarkVerification": True} - - # This is to get around https://github.com/googleapis/proto-plus-python/issues/488 - pr = protos.PredictRequest.pb() - request = pr(model=model_id, instances=[to_value(instance)], parameters=to_value(parameters)) - - response = prediction_client.predict(request) - - return CheckWatermarkResult(response.predictions) - - -class Image: - """Image.""" - - __module__ = "vertexai.vision_models" - - _loaded_bytes: Optional[bytes] = None - _loaded_image: Optional["PIL_Image.Image"] = None - - def __init__( - self, - image_bytes: Optional[bytes], - ): - """Creates an `Image` object. - - Args: - image_bytes: Image file bytes. Image can be in PNG or JPEG format. - """ - self._image_bytes = image_bytes - - @staticmethod - def load_from_file(location: os.PathLike) -> "Image": - """Loads image from local file or Google Cloud Storage. - - Args: - location: Local path or Google Cloud Storage uri from where to load - the image. - - Returns: - Loaded image as an `Image` object. - """ - # Load image from local path - image_bytes = pathlib.Path(location).read_bytes() - image = Image(image_bytes=image_bytes) - return image - - @property - def _image_bytes(self) -> bytes: - return self._loaded_bytes - - @_image_bytes.setter - def _image_bytes(self, value: bytes): - self._loaded_bytes = value - - @property - def _pil_image(self) -> "PIL_Image.Image": # type: ignore - if self._loaded_image is None: - if not PIL_Image: - raise RuntimeError( - "The PIL module is not available. Please install the Pillow package." - ) - self._loaded_image = PIL_Image.open(io.BytesIO(self._image_bytes)) - return self._loaded_image - - @property - def _size(self): - return self._pil_image.size - - @property - def _mime_type(self) -> str: - """Returns the MIME type of the image.""" - if PIL_Image: - return PIL_Image.MIME.get(self._pil_image.format, "image/jpeg") - # Fall back to jpeg - return "image/jpeg" - - def show(self): - """Shows the image. - - This method only works when in a notebook environment. - """ - if PIL_Image and IPython_display: - IPython_display.display(self._pil_image) - - def save(self, location: str): - """Saves image to a file. - - Args: - location: Local path where to save the image. - """ - pathlib.Path(location).write_bytes(self._image_bytes) - - def _as_base64_string(self) -> str: - """Encodes image using the base64 encoding. - - Returns: - Base64 encoding of the image as a string. - """ - # ! b64encode returns `bytes` object, not `str`. - # We need to convert `bytes` to `str`, otherwise we get service error: - # "received initial metadata size exceeds limit" - return base64.b64encode(self._image_bytes).decode("ascii") - - def _repr_png_(self): - return self._pil_image._repr_png_() # type:ignore - - check_watermark = check_watermark - - -class CheckWatermarkResult: - def __init__(self, predictions): - self._predictions = predictions - - @property - def decision(self): - return self._predictions[0]["decision"] - - def __str__(self): - return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])" - - def __bool__(self): - decision = self.decision - if decision == "ACCEPT": - return True - elif decision == "REJECT": - return False - else: - raise ValueError(f"Unrecognized result: {decision}") - class ImageGenerationModel: """Generates images from text prompt. @@ -417,20 +176,16 @@ def _generate_images( parameters["personGeneration"] = person_generation shared_generation_parameters["person_generation"] = person_generation - # This is to get around https://github.com/googleapis/proto-plus-python/issues/488 - pr = protos.PredictRequest.pb() - request = pr( - model=self.model_name, instances=[to_value(instance)], parameters=to_value(parameters) + response = self._client.predict( + model=self.model_name, instances=[instance], parameters=parameters ) - response = self._client.predict(request) - - generated_images: List["GeneratedImage"] = [] + generated_images: List[image_types.GeneratedImage] = [] for idx, prediction in enumerate(response.predictions): generation_parameters = dict(shared_generation_parameters) generation_parameters["index_of_image_in_batch"] = idx encoded_bytes = prediction.get("bytesBase64Encoded") - generated_image = GeneratedImage( + generated_image = image_types.GeneratedImage( image_bytes=base64.b64decode(encoded_bytes) if encoded_bytes else None, generation_parameters=generation_parameters, ) @@ -507,84 +262,12 @@ class ImageGenerationResponse: __module__ = "vertexai.preview.vision_models" - images: List["GeneratedImage"] + images: List[image_types.GeneratedImage] - def __iter__(self) -> typing.Iterator["GeneratedImage"]: + def __iter__(self) -> typing.Iterator[image_types.GeneratedImage]: """Iterates through the generated images.""" yield from self.images - def __getitem__(self, idx: int) -> "GeneratedImage": + def __getitem__(self, idx: int) -> image_types.GeneratedImage: """Gets the generated image by index.""" return self.images[idx] - - -_EXIF_USER_COMMENT_TAG_IDX = 0x9286 -_IMAGE_GENERATION_PARAMETERS_EXIF_KEY = ( - "google.cloud.vertexai.image_generation.image_generation_parameters" -) - - -class GeneratedImage(Image): - """Generated image.""" - - __module__ = "google.generativeai" - - def __init__( - self, - image_bytes: Optional[bytes], - generation_parameters: Dict[str, Any], - ): - """Creates a `GeneratedImage` object. - - Args: - image_bytes: Image file bytes. Image can be in PNG or JPEG format. - generation_parameters: Image generation parameter values. - """ - super().__init__(image_bytes=image_bytes) - self._generation_parameters = generation_parameters - - @property - def generation_parameters(self): - """Image generation parameters as a dictionary.""" - return self._generation_parameters - - @staticmethod - def load_from_file(location: os.PathLike) -> "GeneratedImage": - """Loads image from file. - - Args: - location: Local path from where to load the image. - - Returns: - Loaded image as a `GeneratedImage` object. - """ - base_image = Image.load_from_file(location=location) - exif = base_image._pil_image.getexif() # pylint: disable=protected-access - exif_comment_dict = json.loads(exif[_EXIF_USER_COMMENT_TAG_IDX]) - generation_parameters = exif_comment_dict[_IMAGE_GENERATION_PARAMETERS_EXIF_KEY] - return GeneratedImage( - image_bytes=base_image._image_bytes, # pylint: disable=protected-access - generation_parameters=generation_parameters, - ) - - def save(self, location: str, include_generation_parameters: bool = True): - """Saves image to a file. - - Args: - location: Local path where to save the image. - include_generation_parameters: Whether to include the image - generation parameters in the image's EXIF metadata. - """ - if include_generation_parameters: - if not self._generation_parameters: - raise ValueError("Image does not have generation parameters.") - if not PIL_Image: - raise ValueError("The PIL module is required for saving generation parameters.") - - exif = self._pil_image.getexif() - exif[_EXIF_USER_COMMENT_TAG_IDX] = json.dumps( - {_IMAGE_GENERATION_PARAMETERS_EXIF_KEY: self._generation_parameters} - ) - self._pil_image.save(location, exif=exif) - else: - super().save(location=location) diff --git a/tests/test_content.py b/tests/test_content.py index 2031e40ae..8bec14a9c 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -22,6 +22,8 @@ from absl.testing import parameterized from google.generativeai import protos from google.generativeai.types import content_types +from google.generativeai.types import image_types +from google.generativeai.types.image_types import _image_types import IPython.display import PIL.Image @@ -90,7 +92,7 @@ class UnitTests(parameterized.TestCase): ["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")], ) def test_numpy_to_blob(self, image): - blob = content_types.image_to_blob(image) + blob = _image_types.image_to_blob(image) self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/webp") self.assertStartsWith(blob.data, b"RIFF \x00\x00\x00WEBPVP8L") @@ -98,9 +100,10 @@ def test_numpy_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_PNG_PATH)], ["IPython", IPython.display.Image(filename=TEST_PNG_PATH)], + ["image_types.Image", image_types.Image.load_from_file(TEST_PNG_PATH)], ) def test_png_to_blob(self, image): - blob = content_types.image_to_blob(image) + blob = _image_types.image_to_blob(image) self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @@ -108,9 +111,10 @@ def test_png_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_JPG_PATH)], ["IPython", IPython.display.Image(filename=TEST_JPG_PATH)], + ["image_types.Image", image_types.Image.load_from_file(TEST_JPG_PATH)], ) def test_jpg_to_blob(self, image): - blob = content_types.image_to_blob(image) + blob = _image_types.image_to_blob(image) self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/jpeg") self.assertStartsWith(blob.data, b"\xff\xd8\xff\xe0\x00\x10JFIF") @@ -118,9 +122,10 @@ def test_jpg_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_GIF_PATH)], ["IPython", IPython.display.Image(filename=TEST_GIF_PATH)], + ["image_types.Image", image_types.Image.load_from_file(TEST_GIF_PATH)], ) def test_gif_to_blob(self, image): - blob = content_types.image_to_blob(image) + blob = _image_types.image_to_blob(image) self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/gif") self.assertStartsWith(blob.data, b"GIF87a")