Skip to content

Commit

Permalink
make check_watermark a stand alone function
Browse files Browse the repository at this point in the history
Change-Id: I2d72620359dcc70fe8e720a14f78d83f75a42d90
  • Loading branch information
MarkDaoust committed Oct 17, 2024
1 parent 63e0501 commit 17b5d56
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 36 deletions.
1 change: 1 addition & 0 deletions google/generativeai/vision_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Classes for working with vision models."""

from google.generativeai.vision_models._vision_models import (
check_watermark,
Image,
GeneratedImage,
ImageGenerationModel,
Expand Down
82 changes: 46 additions & 36 deletions google/generativeai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,50 @@ def to_mapping_value(value) -> struct_pb2.Struct:
ImageLikeType = Union["Image", pathlib.Path, content_types.ImageType]


def check_watermark(
img: ImageLikeType, model_id: str = "models/image-verification-001"
) -> "CheckWatermarkResult":
"""Checks if an image has a Google-AI watermark.
Args:
img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPythin.display.Image`, or `google.generativeai.Image`.
model_id: Which version of the image-verification model to send the image to.
Returns:
"""
if isinstance(img, Image):
pass
elif isinstance(img, pathlib.Path):
img = Image.load_from_file(img)
elif IPython_display is not None and isinstance(img, IPython_display.Image):
img = Image(image_bytes=img.data)
elif PIL_Image is not None and isinstance(img, PIL_Image.Image):
blob = content_types._pil_to_blob(img)
img = Image(image_bytes=blob.data)
elif isinstance(img, protos.Blob):
img = Image(image_bytes=img.data)
else:
raise TypeError(
f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
)

prediction_client = client.get_default_prediction_client()
if not model_id.startswith("models/"):
model_id = f"models/{model_id}"

instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}}
parameters = {"watermarkVerification": True}

# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
pr = protos.PredictRequest.pb()
request = pr(model=model_id, instances=[to_value(instance)], parameters=to_value(parameters))

response = prediction_client.predict(request)

return CheckWatermarkResult(response.predictions)


class Image:
"""Image."""

Expand Down Expand Up @@ -209,41 +253,7 @@ def _as_base64_string(self) -> str:
def _repr_png_(self):
return self._pil_image._repr_png_() # type:ignore

def check_watermark(self: ImageLikeType, model_id: str = "models/image-verification-001"):
img = None
if isinstance(self, Image):
img = self
elif isinstance(self, pathlib.Path):
img = Image.load_from_file(self)
elif IPython_display is not None and isinstance(self, IPython_display.Image):
img = Image(image_bytes=self.data)
elif PIL_Image is not None and isinstance(self, PIL_Image.Image):
blob = content_types._pil_to_blob(self)
img = Image(image_bytes=blob.data)
elif isinstance(self, protos.Blob):
img = Image(image_bytes=self.data)
else:
raise TypeError(
f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
)

prediction_client = client.get_default_prediction_client()
if not model_id.startswith("models/"):
model_id = f"models/{model_id}"

# Note: Only a single prompt is supported by the service.
instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}}
parameters = {"watermarkVerification": True}

# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
pr = protos.PredictRequest.pb()
request = pr(
model=model_id, instances=[to_value(instance)], parameters=to_value(parameters)
)

response = prediction_client.predict(request)

return CheckWatermarkResult(response.predictions)
check_watermark = check_watermark


class CheckWatermarkResult:
Expand Down Expand Up @@ -539,7 +549,7 @@ def generation_parameters(self):
return self._generation_parameters

@staticmethod
def load_from_file(location: str) -> "GeneratedImage":
def load_from_file(location: os.PathLike) -> "GeneratedImage":
"""Loads image from file.
Args:
Expand Down

0 comments on commit 17b5d56

Please sign in to comment.