Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add image embedding support to TFHub MLTransforms #31564

Merged
merged 5 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler

# Public API of this module: the text and the image TF Hub embedding configs.
__all__ = ['TensorflowHubTextEmbeddings', 'TensorflowHubImageEmbeddings']


# TODO: https://github.com/apache/beam/issues/30288
Expand Down Expand Up @@ -132,3 +133,42 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
model_handler=_TextEmbeddingHandler(self),
inference_args=self.inference_args,
))


class TensorflowHubImageEmbeddings(EmbeddingsManager):
  """MLTransform embedding config that embeds images with a TF Hub model."""
  def __init__(self, columns: List[str], hub_url: str, **kwargs):
    """
    Embedding config for tensorflow hub models. This config can be used with
    MLTransform to embed image data. Models are loaded using the RunInference
    PTransform with the help of a ModelHandler.

    Args:
      columns: The columns containing the images to be embedded.
      hub_url: The url of the tensorflow hub model.
      **kwargs: Forwarded unchanged to EmbeddingsManager. The options below
        are supplied this way:
        min_batch_size: The minimum batch size to be used for inference.
        max_batch_size: The maximum batch size to be used for inference.
        large_model: Whether to share the model across processes.
    """
    super().__init__(columns=columns, **kwargs)
    self.model_uri = hub_url

  def get_model_handler(self) -> ModelHandler:
    """
    Returns the _TensorflowHubModelHandler that loads the model at
    self.model_uri. No preprocessing model url is passed for images.
    """
    # override the default inference function
    return _TensorflowHubModelHandler(
        model_uri=self.model_uri,
        preprocessing_url=None,
        min_batch_size=self.min_batch_size,
        max_batch_size=self.max_batch_size,
        large_model=self.large_model,
    )

  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    """
    Returns a RunInference object that is used to run inference on the image
    input using _ImageEmbeddingHandler.
    """
    return (
        RunInference(
            model_handler=_ImageEmbeddingHandler(self),
            inference_args=self.inference_args,
        ))
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
import unittest
import uuid

import numpy as np

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform

# NOTE(review): presumably the text embedding model used by the text tests,
# which are not visible in this chunk -- confirm.
hub_url = 'https://tfhub.dev/google/nnlm-en-dim128/2'
# Image feature-vector model used by TFHubImageEmbeddingsTest, which asserts
# that each embedding it produces has 2048 elements.
hub_img_url = 'https://www.kaggle.com/models/google/resnet-v2/TensorFlow2/101-feature-vector/2' # pylint: disable=line-too-long
# Column name shared by the test input rows and the embedding configs.
test_query_column = 'test_query'
test_query = 'This is a test query'

Expand All @@ -32,6 +35,7 @@
from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings
except ImportError:
TensorflowHubTextEmbeddings = None # type: ignore
tf = None

# pylint: disable=ungrouped-imports
try:
Expand All @@ -40,6 +44,14 @@
except ImportError:
tft = None

# pylint: disable=ungrouped-imports
# Optional dependencies of the image embedding tests. If either import fails,
# both names are set to None so the tests that need them can be skipped.
try:
  from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubImageEmbeddings
  from PIL import Image
except ImportError:
  TensorflowHubImageEmbeddings = None # type: ignore
  Image = None


@unittest.skipIf(
TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.')
Expand Down Expand Up @@ -161,6 +173,54 @@ def test_with_int_data_types(self):
embedding_config))


# The import guard sets TensorflowHubImageEmbeddings to None when either the
# tensorflow_hub embeddings module or Pillow fails to import.
@unittest.skipIf(
    TensorflowHubImageEmbeddings is None,
    'Tensorflow or Pillow is not installed.')
class TFHubImageEmbeddingsTest(unittest.TestCase):
  """Tests for TensorflowHubImageEmbeddings applied through MLTransform."""
  def setUp(self) -> None:
    # Each test writes its MLTransform artifacts to a fresh temp directory.
    self.artifact_location = tempfile.mkdtemp()

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  def generateRandomImage(self, size: int):
    """Returns a random (size, size, 3) float array with values in [0, 1)."""
    # np.random.rand already samples from [0, 1), so no rescaling is needed.
    return np.random.rand(size, size, 3)

  @unittest.skipIf(Image is None, 'Pillow is not installed.')
  def test_sentence_transformer_image_embeddings(self):
    """Embeds one random 224x224 image and checks the embedding length."""
    embedding_config = TensorflowHubImageEmbeddings(
        hub_url=hub_img_url, columns=[test_query_column])
    img = self.generateRandomImage(224)
    with beam.Pipeline() as pipeline:
      result_pcoll = (
          pipeline
          | "CreateData" >> beam.Create([{
              test_query_column: img
          }])
          | "MLTransform" >> MLTransform(
              write_artifact_location=self.artifact_location).with_transform(
                  embedding_config))

      def assert_element(element):
        # The test expects 2048-element embeddings from the hub_img_url model.
        assert len(element[test_query_column]) == 2048

      _ = (result_pcoll | beam.Map(assert_element))

  def test_with_str_data_types(self):
    """A string in the image column must make the pipeline raise TypeError."""
    embedding_config = TensorflowHubImageEmbeddings(
        hub_url=hub_img_url, columns=[test_query_column])
    with self.assertRaises(TypeError):
      with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | "CreateData" >> beam.Create([{
                test_query_column: "img.jpg"
            }])
            | "MLTransform" >> MLTransform(
                write_artifact_location=self.artifact_location).with_transform(
                    embedding_config))


@unittest.skipIf(
TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.')
class TFHubEmbeddingsGCSArtifactLocationTest(TFHubEmbeddingsTest):
Expand Down
Loading