From 7d3d242c2bd3efede7e719ee88f4c4c6045b75e5 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 11 Jun 2024 11:21:17 -0400 Subject: [PATCH 1/5] Add image embedding support to TFHub MLTransforms --- .../transforms/embeddings/tensorflow_hub.py | 42 ++++++++++++- .../embeddings/tensorflow_hub_test.py | 61 +++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py index 9e4480788257..61a6e93d4a22 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py @@ -29,9 +29,10 @@ from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _ImageEmbeddingHandler from apache_beam.ml.transforms.base import _TextEmbeddingHandler -__all__ = ['TensorflowHubTextEmbeddings'] +__all__ = ['TensorflowHubTextEmbeddings', 'TensorflowHubImageEmbeddings'] # TODO: https://github.com/apache/beam/issues/30288 @@ -132,3 +133,42 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: model_handler=_TextEmbeddingHandler(self), inference_args=self.inference_args, )) + + +class TensorflowHubImageEmbeddings(EmbeddingsManager): + def __init__(self, columns: List[str], hub_url: str, **kwargs): + """ + Embedding config for tensorflow hub models. This config can be used with + MLTransform to embed image data. Models are loaded using the RunInference + PTransform with the help of a ModelHandler. + + Args: + columns: The columns containing the images to be embedded. + hub_url: The url of the tensorflow hub model. + min_batch_size: The minimum batch size to be used for inference. + max_batch_size: The maximum batch size to be used for inference. + large_model: Whether to share the model across processes. + """ + super().__init__(columns=columns, **kwargs) + self.model_uri = hub_url + + def get_model_handler(self) -> ModelHandler: + # override the default inference function + return _TensorflowHubModelHandler( + model_uri=self.model_uri, + preprocessing_url=None, + min_batch_size=self.min_batch_size, + max_batch_size=self.max_batch_size, + large_model=self.large_model, + ) + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + """ + Returns a RunInference object that is used to run inference on the text + input using _TextEmbeddingHandler. + """ + return ( + RunInference( + model_handler=_ImageEmbeddingHandler(self), + inference_args=self.inference_args, + )) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py index b08ca8e2d8ea..ea52cf04964d 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py @@ -20,18 +20,23 @@ import unittest import uuid +import numpy as np + import apache_beam as beam from apache_beam.ml.transforms.base import MLTransform hub_url = 'https://tfhub.dev/google/nnlm-en-dim128/2' +hub_img_url = 'https://www.kaggle.com/models/google/resnet-v2/TensorFlow2/101-feature-vector/2' # pylint: disable=line-too-long test_query_column = 'test_query' test_query = 'This is a test query' # pylint: disable=ungrouped-imports try: from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings + import tensorflow as tf except ImportError: TensorflowHubTextEmbeddings = None # type: ignore + tf = None # type: ignore # pylint: disable=ungrouped-imports try: @@ -40,6 +45,14 @@ except ImportError: tft = None +# pylint: disable=ungrouped-imports +try: + from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubImageEmbeddings + from PIL import Image +except: + TensorflowHubImageEmbeddings = None # type: ignore + Image = None # type: ignore + @unittest.skipIf( TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') @@ -161,6 +174,54 @@ def test_with_int_data_types(self): embedding_config)) +@unittest.skipIf( + TensorflowHubImageEmbeddings is None, 'Tensorflow is not installed.') +class TFHubImageEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def generateRandomImage(self, size: int): + imarray = np.random.rand(size, size, 3) * 255 + return imarray / 255.0 + + @unittest.skipIf(Image is None, 'Pillow is not installed.') + def test_sentence_transformer_image_embeddings(self): + embedding_config = TensorflowHubImageEmbeddings( + hub_url=hub_img_url, columns=[test_query_column]) + img = self.generateRandomImage(224) + with beam.Pipeline() as pipeline: + result_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: img + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert len(element[test_query_column]) == 2048 + + _ = (result_pcoll | beam.Map(assert_element)) + + def test_with_str_data_types(self): + embedding_config = TensorflowHubImageEmbeddings( + hub_url=hub_img_url, columns=[test_query_column]) + with self.assertRaises(TypeError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: "img.jpg" + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + @unittest.skipIf( TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') class TFHubEmbeddingsGCSArtifactLocationTest(TFHubEmbeddingsTest): From d58104d1671530b96565b5562ec9d4b4ab4f50a5 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 11 Jun 2024 11:44:27 -0400 Subject: [PATCH 2/5] linting --- .../ml/transforms/embeddings/tensorflow_hub_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py index ea52cf04964d..3cc77c911034 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py @@ -33,7 +33,6 @@ # pylint: disable=ungrouped-imports try: from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings - import tensorflow as tf except ImportError: TensorflowHubTextEmbeddings = None # type: ignore tf = None # type: ignore @@ -49,7 +48,7 @@ try: from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubImageEmbeddings from PIL import Image -except: +except ImportError: TensorflowHubImageEmbeddings = None # type: ignore Image = None # type: ignore From e54fb0ddfa48df80adb775b11129ddab0d12bf70 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 11 Jun 2024 12:11:13 -0400 Subject: [PATCH 3/5] more linting --- .../ml/transforms/embeddings/tensorflow_hub_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py index 3cc77c911034..72bd072e4f43 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py @@ -35,7 +35,7 @@ from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings except ImportError: TensorflowHubTextEmbeddings = None # type: ignore - tf = None # type: ignore + tf = None # pylint: disable=ungrouped-imports try: @@ -50,7 +50,7 @@ from PIL import Image except ImportError: TensorflowHubImageEmbeddings = None # type: ignore - Image = None # type: ignore + Image = None @unittest.skipIf( From 4d138fdcd225f355257658b94049351735ecf32a Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 11 Jun 2024 12:25:54 -0400 Subject: [PATCH 4/5] formatting --- .../apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py index 72bd072e4f43..24bca5155fa7 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py @@ -50,7 +50,7 @@ from PIL import Image except ImportError: TensorflowHubImageEmbeddings = None # type: ignore - Image = None + Image = None @unittest.skipIf( From 196e3bc840dddb3041430e7f8964054da201869a Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Wed, 12 Jun 2024 10:11:21 -0400 Subject: [PATCH 5/5] typo --- .../apache_beam/ml/transforms/embeddings/tensorflow_hub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py index 61a6e93d4a22..f78ddf3ff04a 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py @@ -165,7 +165,7 @@ def get_model_handler(self) -> ModelHandler: def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: """ Returns a RunInference object that is used to run inference on the text - input using _TextEmbeddingHandler. + input using _ImageEmbeddingHandler. """ return ( RunInference(