Skip to content

Commit

Permalink
Implement Hugging Face Image Embedding MLTransform (#31536)
Browse files Browse the repository at this point in the history
* Implement Hugging Face Image Embedding MLTransform

* correct imports

* Simplify to original sentence transformer class
  • Loading branch information
jrmccluskey committed Jun 7, 2024
1 parent 9edd413 commit fe54c21
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
19 changes: 16 additions & 3 deletions sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler

try:
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
model_name: str,
columns: List[str],
max_seq_length: Optional[int] = None,
image_model: bool = False,
**kwargs):
"""
Embedding config for sentence-transformers. This config can be used with
Expand All @@ -122,16 +124,21 @@ def __init__(
Args:
model_name: Name of the model to use. The model should be hosted on
HuggingFace Hub or compatible with sentence_transformers.
HuggingFace Hub or compatible with sentence_transformers. For image
embedding models, see
https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models # pylint: disable=line-too-long
for a list of available sentence_transformers models.
columns: List of columns to be embedded.
max_seq_length: Max sequence length to use for the model if applicable.
image_model: Whether the model is generating image embeddings.
min_batch_size: The minimum batch size to be used for inference.
max_batch_size: The maximum batch size to be used for inference.
large_model: Whether to share the model across processes.
"""
super().__init__(columns, **kwargs)
self.model_name = model_name
self.max_seq_length = max_seq_length
self.image_model = image_model

def get_model_handler(self):
return _SentenceTransformerModelHandler(
Expand All @@ -144,8 +151,14 @@ def get_model_handler(self):
large_model=self.large_model)

def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
# wrap the model handler in a _TextEmbeddingHandler since
# the SentenceTransformerEmbeddings works on text input data.
# wrap the model handler in an appropriate embedding handler to provide
# some type checking.
if self.image_model:
return (
RunInference(
model_handler=_ImageEmbeddingHandler(self),
inference_args=self.inference_args,
))
return (
RunInference(
model_handler=_TextEmbeddingHandler(self),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
try:
from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerEmbeddings
from apache_beam.ml.transforms.embeddings.huggingface import InferenceAPIEmbeddings
from PIL import Image
import torch
except ImportError:
SentenceTransformerEmbeddings = None # type: ignore
Expand All @@ -46,10 +47,17 @@
except ImportError:
tft = None

# pylint: disable=ungrouped-imports
try:
from PIL import Image
except ImportError:
Image = None

_HF_TOKEN = os.environ.get('HF_INFERENCE_TOKEN')
test_query = "This is a test"
test_query_column = "feature_1"
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
IMAGE_MODEL_NAME = "clip-ViT-B-32"
_parameterized_inputs = [
([{
test_query_column: 'That is a happy person'
Expand Down Expand Up @@ -85,7 +93,7 @@
@unittest.skipIf(
SentenceTransformerEmbeddings is None,
'sentence-transformers is not installed.')
class SentenceTrasformerEmbeddingsTest(unittest.TestCase):
class SentenceTransformerEmbeddingsTest(unittest.TestCase):
def setUp(self) -> None:
self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_')
# this bucket has TTL and will be deleted periodically
Expand Down Expand Up @@ -277,6 +285,48 @@ def test_mltransform_to_ptransform_with_sentence_transformer(self):
self.assertEqual(
ptransform_list[i]._model_handler._underlying.model_name, model_name)

def generateRandomImage(self, size: int):
imarray = np.random.rand(size, size, 3) * 255
return Image.fromarray(imarray.astype('uint8')).convert('RGBA')

@unittest.skipIf(Image is None, 'Pillow is not installed.')
def test_sentence_transformer_image_embeddings(self):
embedding_config = SentenceTransformerEmbeddings(
model_name=IMAGE_MODEL_NAME,
columns=[test_query_column],
image_model=True)
img = self.generateRandomImage(256)
with beam.Pipeline() as pipeline:
result_pcoll = (
pipeline
| "CreateData" >> beam.Create([{
test_query_column: img
}])
| "MLTransform" >> MLTransform(
write_artifact_location=self.artifact_location).with_transform(
embedding_config))

def assert_element(element):
assert len(element[test_query_column]) == 512

_ = (result_pcoll | beam.Map(assert_element))

def test_sentence_transformer_images_with_str_data_types(self):
embedding_config = SentenceTransformerEmbeddings(
model_name=IMAGE_MODEL_NAME,
columns=[test_query_column],
image_model=True)
with self.assertRaises(TypeError):
with beam.Pipeline() as pipeline:
_ = (
pipeline
| "CreateData" >> beam.Create([{
test_query_column: "image.jpg"
}])
| "MLTransform" >> MLTransform(
write_artifact_location=self.artifact_location).with_transform(
embedding_config))


@unittest.skipIf(_HF_TOKEN is None, 'HF_TOKEN environment variable not set.')
class HuggingfaceInferenceAPITest(unittest.TestCase):
Expand Down

0 comments on commit fe54c21

Please sign in to comment.