Skip to content

Commit

Permalink
Image feedback (#608)
Browse files Browse the repository at this point in the history
* Allow the use of generated images as inputs

Change-Id: I0956fb78272a8a8af2c5219d80a26dec944040a8

# Conflicts:
#	google/generativeai/vision_models/_vision_models.py

* types + formatting

Change-Id: I0cac4ba1de764d3c02c5eab7556d8324aeda1f93

* add files

Change-Id: Ie7f91cef171c1f813b52ff1b2a4daedf7ea19edd

* Fix 3.9

Change-Id: If9ff9ebc0b2bf16b91e741d862a9e2808c7a738a

* Fix 3.9

Change-Id: Iee02352ca21fa66da9b097d4dfa9454b67609e79

* fix pytype

Change-Id: Ic5c250f3f3ded2374abfbdbee6d62ea4cfb0f799

* fix pytype

Change-Id: I431c66e45e7582218b5de7a90eeeee01b80df664

* typo

Change-Id: I1bb15e1363c652f9c0b4a60dad834fce65a4f0a1

* reapply commits lost in merge

Change-Id: I7bfebdeaa217d93ed5d11aca31cf0b20afd38c02

* Update google/generativeai/client.py

* Remove GCS reference

Change-Id: I5c1b8cbccee0e13d8aca70582a76e0c089e040ed

* black .

Change-Id: I2c24f8798cb8103d35474e7e6d2e4fc3100825aa
  • Loading branch information
MarkDaoust authored Nov 1, 2024
1 parent aae0caf commit c72aa6b
Show file tree
Hide file tree
Showing 7 changed files with 443 additions and 437 deletions.
95 changes: 85 additions & 10 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import os
import contextlib
import inspect
import collections
import dataclasses
import pathlib
from typing import Any, cast
from collections.abc import Sequence
from collections.abc import Sequence, Mapping
import httplib2
from io import IOBase

Expand All @@ -23,6 +24,11 @@
import googleapiclient.http
import googleapiclient.discovery

from google.protobuf import struct_pb2

from proto.marshal.collections import maps
from proto.marshal.collections import repeated

try:
from google.generativeai import version

Expand Down Expand Up @@ -130,6 +136,70 @@ async def create_file(self, *args, **kwargs):
)


# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
def to_value(value) -> struct_pb2.Value:
"""Return a protobuf Value object representing this value."""
if isinstance(value, struct_pb2.Value):
return value
if value is None:
return struct_pb2.Value(null_value=0)
if isinstance(value, bool):
return struct_pb2.Value(bool_value=value)
if isinstance(value, (int, float)):
return struct_pb2.Value(number_value=float(value))
if isinstance(value, str):
return struct_pb2.Value(string_value=value)
if isinstance(value, collections.abc.Sequence):
return struct_pb2.Value(list_value=to_list_value(value))
if isinstance(value, collections.abc.Mapping):
return struct_pb2.Value(struct_value=to_mapping_value(value))
raise ValueError("Unable to coerce value: %r" % value)


def to_list_value(value) -> struct_pb2.ListValue:
# We got a proto, or else something we sent originally.
# Preserve the instance we have.
if isinstance(value, struct_pb2.ListValue):
return value
if isinstance(value, repeated.RepeatedComposite):
return struct_pb2.ListValue(values=[v for v in value.pb])

# We got a list (or something list-like); convert it.
return struct_pb2.ListValue(values=[to_value(v) for v in value])


def to_mapping_value(value) -> struct_pb2.Struct:
# We got a proto, or else something we sent originally.
# Preserve the instance we have.
if isinstance(value, struct_pb2.Struct):
return value
if isinstance(value, maps.MapComposite):
return struct_pb2.Struct(
fields={k: v for k, v in value.pb.items()},
)

# We got a dict (or something dict-like); convert it.
return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()})


class PredictionServiceClient(glm.PredictionServiceClient):
def predict(self, model=None, instances=None, parameters=None):
pr = protos.PredictRequest.pb()
request = pr(
model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters)
)
return super().predict(request)


class PredictionServiceAsyncClient(glm.PredictionServiceAsyncClient):
async def predict(self, model=None, instances=None, parameters=None):
pr = protos.PredictRequest.pb()
request = pr(
model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters)
)
return await super().predict(request)


@dataclasses.dataclass
class _ClientManager:
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
Expand Down Expand Up @@ -220,15 +290,20 @@ def configure(
self.clients = {}

def make_client(self, name):
if name == "file":
cls = FileServiceClient
elif name == "file_async":
cls = FileServiceAsyncClient
elif name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")
local_clients = {
"file": FileServiceClient,
"file_async": FileServiceAsyncClient,
"prediction": PredictionServiceClient,
"prediction_async": PredictionServiceAsyncClient,
}
cls = local_clients.get(name, None)

if cls is None:
if name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")

# Attempt to configure using defaults.
if not self.client_config:
Expand Down
98 changes: 4 additions & 94 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,16 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
import io
import inspect
import mimetypes
import pathlib
import typing
from typing import Any, Callable, Union
from typing_extensions import TypedDict

import pydantic

from google.generativeai.types import file_types
from google.generativeai.types.image_types import _image_types
from google.generativeai import protos

if typing.TYPE_CHECKING:
import PIL.Image
import PIL.ImageFile
import IPython.display

IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
ImageType = PIL.Image.Image | IPython.display.Image
else:
IMAGE_TYPES = ()
try:
import PIL.Image
import PIL.ImageFile

IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
except ImportError:
PIL = None

try:
import IPython.display

IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)
except ImportError:
IPython = None

ImageType = Union["PIL.Image.Image", "IPython.display.Image"]


__all__ = [
"BlobDict",
Expand Down Expand Up @@ -97,62 +68,6 @@ def to_mode(x: ModeOptions) -> Mode:
return _MODE[x]


def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
# If the image is a local file, return a file-based blob without any modification.
# Otherwise, return a lossless WebP blob (same quality with optimized size).
def file_blob(image: PIL.Image.Image) -> protos.Blob | None:
if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
return None
filename = str(image.filename)
if not pathlib.Path(filename).is_file():
return None

mime_type = image.get_format_mimetype()
image_bytes = pathlib.Path(filename).read_bytes()

return protos.Blob(mime_type=mime_type, data=image_bytes)

def webp_blob(image: PIL.Image.Image) -> protos.Blob:
# Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
image_io = io.BytesIO()
image.save(image_io, format="webp", lossless=True)
image_io.seek(0)

mime_type = "image/webp"
image_bytes = image_io.read()

return protos.Blob(mime_type=mime_type, data=image_bytes)

return file_blob(image) or webp_blob(image)


def image_to_blob(image: ImageType) -> protos.Blob:
if PIL is not None:
if isinstance(image, PIL.Image.Image):
return _pil_to_blob(image)

if IPython is not None:
if isinstance(image, IPython.display.Image):
name = image.filename
if name is None:
raise ValueError(
"Conversion failed. The `IPython.display.Image` can only be converted if "
"it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
)
mime_type, _ = mimetypes.guess_type(name)
if mime_type is None:
mime_type = "image/unknown"

return protos.Blob(mime_type=mime_type, data=image.data)

raise TypeError(
"Image conversion failed. The input was expected to be of type `Image` "
"(either `PIL.Image.Image` or `IPython.display.Image`).\n"
f"However, received an object of type: {type(image)}.\n"
f"Object Value: {image}"
)


class BlobDict(TypedDict):
mime_type: str
data: bytes
Expand Down Expand Up @@ -189,12 +104,7 @@ def is_blob_dict(d):
return "mime_type" in d and "data" in d


if typing.TYPE_CHECKING:
BlobType = Union[
protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image
] # Any for the images
else:
BlobType = Union[protos.Blob, BlobDict, Any]
BlobType = Union[protos.Blob, BlobDict, _image_types.ImageType] # Any for the images


def to_blob(blob: BlobType) -> protos.Blob:
Expand All @@ -203,8 +113,8 @@ def to_blob(blob: BlobType) -> protos.Blob:

if isinstance(blob, protos.Blob):
return blob
elif isinstance(blob, IMAGE_TYPES):
return image_to_blob(blob)
elif isinstance(blob, _image_types.IMAGE_TYPES):
return _image_types.image_to_blob(blob)
else:
if isinstance(blob, Mapping):
raise KeyError(
Expand Down
1 change: 1 addition & 0 deletions google/generativeai/types/image_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from google.generativeai.types.image_types._image_types import *
Loading

0 comments on commit c72aa6b

Please sign in to comment.