Skip to content

Commit

Permalink
Simplify to original sentence transformer class
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmccluskey committed Jun 7, 2024
1 parent e6d6719 commit 40fa9df
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 66 deletions.
57 changes: 15 additions & 42 deletions sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
model_name: str,
columns: List[str],
max_seq_length: Optional[int] = None,
image_model: bool = False,
**kwargs):
"""
Embedding config for sentence-transformers. This config can be used with
Expand All @@ -123,16 +124,21 @@ def __init__(
Args:
model_name: Name of the model to use. The model should be hosted on
HuggingFace Hub or compatible with sentence_transformers.
HuggingFace Hub or compatible with sentence_transformers. For image
embedding models, see
https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models # pylint: disable=line-too-long
for a list of available sentence_transformers models.
columns: List of columns to be embedded.
max_seq_length: Max sequence length to use for the model if applicable.
image_model: Whether the model is generating image embeddings.
min_batch_size: The minimum batch size to be used for inference.
max_batch_size: The maximum batch size to be used for inference.
large_model: Whether to share the model across processes.
"""
super().__init__(columns, **kwargs)
self.model_name = model_name
self.max_seq_length = max_seq_length
self.image_model = image_model

def get_model_handler(self):
return _SentenceTransformerModelHandler(
Expand All @@ -145,54 +151,21 @@ def get_model_handler(self):
large_model=self.large_model)

def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
  """Return a RunInference PTransform for this embedding config.

  The model handler is wrapped in an embedding handler that provides
  type checking on the inputs: _ImageEmbeddingHandler when this config
  was created with image_model=True, _TextEmbeddingHandler otherwise.
  """
  # wrap the model handler in an appropriate embedding handler to provide
  # some type checking.
  if self.image_model:
    return (
        RunInference(
            model_handler=_ImageEmbeddingHandler(self),
            inference_args=self.inference_args,
        ))
  return (
      RunInference(
          model_handler=_TextEmbeddingHandler(self),
          inference_args=self.inference_args,
      ))


class SentenceTransformerImageEmbeddings(EmbeddingsManager):
  def __init__(self, model_name: str, columns: List[str], **kwargs):
    """
    Embedding config for sentence-transformers. This config can be used with
    MLTransform to embed image data. Models are loaded using the RunInference
    PTransform with the help of ModelHandler.

    Args:
      model_name: Name of the model to use. The model should be hosted on
        HuggingFace Hub or compatible with sentence_transformers. See
        https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models # pylint: disable=line-too-long
        for a list of sentence_transformers models.
      columns: List of columns to be embedded.
      min_batch_size: The minimum batch size to be used for inference.
      max_batch_size: The maximum batch size to be used for inference.
      large_model: Whether to share the model across processes.
    """
    super().__init__(columns, **kwargs)
    self.model_name = model_name

  def get_model_handler(self):
    """Return the RunInference model handler for this config."""
    # Image models load through the same SentenceTransformer class as text
    # models; only the downstream input wrapping differs.
    return _SentenceTransformerModelHandler(
        model_class=SentenceTransformer,
        model_name=self.model_name,
        load_model_args=self.load_model_args,
        min_batch_size=self.min_batch_size,
        max_batch_size=self.max_batch_size,
        large_model=self.large_model)

  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    # wrap the model handler in an _ImageEmbeddingHandler since this
    # config works on image input data.
    return (
        RunInference(
            model_handler=_ImageEmbeddingHandler(self),
            inference_args=self.inference_args,
        ))


class _InferenceAPIHandler(ModelHandler):
def __init__(self, config: 'InferenceAPIEmbeddings'):
super().__init__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
# pylint: disable=ungrouped-imports
try:
from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerEmbeddings
from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerImageEmbeddings
from apache_beam.ml.transforms.embeddings.huggingface import InferenceAPIEmbeddings
from PIL import Image
import torch
Expand All @@ -58,6 +57,7 @@
test_query = "This is a test"
test_query_column = "feature_1"
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
IMAGE_MODEL_NAME = "clip-ViT-B-32"
_parameterized_inputs = [
([{
test_query_column: 'That is a happy person'
Expand Down Expand Up @@ -93,7 +93,7 @@
@unittest.skipIf(
SentenceTransformerEmbeddings is None,
'sentence-transformers is not installed.')
class SentenceTrasformerEmbeddingsTest(unittest.TestCase):
class SentenceTransformerEmbeddingsTest(unittest.TestCase):
def setUp(self) -> None:
self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_')
# this bucket has TTL and will be deleted periodically
Expand Down Expand Up @@ -285,31 +285,16 @@ def test_mltransform_to_ptransform_with_sentence_transformer(self):
self.assertEqual(
ptransform_list[i]._model_handler._underlying.model_name, model_name)


@pytest.mark.no_xdist
@unittest.skipIf(
SentenceTransformerEmbeddings is None,
'sentence-transformers is not installed.')
@unittest.skipIf(Image is None, 'Pillow is not installed.')
class SentenceTransformerImageEmbeddingsTest(unittest.TestCase):
def setUp(self) -> None:
self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_')
# this bucket has TTL and will be deleted periodically
self.gcs_artifact_location = os.path.join(
'gs://temp-storage-for-perf-tests/sentence_transformers',
uuid.uuid4().hex)
self.model_name = "clip-ViT-B-32"

def tearDown(self) -> None:
shutil.rmtree(self.artifact_location)

def generateRandomImage(self, size: int):
  # Build a size x size RGBA test image from uniform random RGB noise.
  pixels = (np.random.rand(size, size, 3) * 255).astype('uint8')
  return Image.fromarray(pixels).convert('RGBA')

@unittest.skipIf(Image is None, 'Pillow is not installed.')
def test_sentence_transformer_image_embeddings(self):
embedding_config = SentenceTransformerImageEmbeddings(
model_name=self.model_name, columns=[test_query_column])
embedding_config = SentenceTransformerEmbeddings(
model_name=IMAGE_MODEL_NAME,
columns=[test_query_column],
image_model=True)
img = self.generateRandomImage(256)
with beam.Pipeline() as pipeline:
result_pcoll = (
Expand All @@ -327,8 +312,10 @@ def assert_element(element):
_ = (result_pcoll | beam.Map(assert_element))

def test_sentence_transformer_images_with_str_data_types(self):
embedding_config = SentenceTransformerImageEmbeddings(
model_name=self.model_name, columns=[test_query_column])
embedding_config = SentenceTransformerEmbeddings(
model_name=IMAGE_MODEL_NAME,
columns=[test_query_column],
image_model=True)
with self.assertRaises(TypeError):
with beam.Pipeline() as pipeline:
_ = (
Expand Down

0 comments on commit 40fa9df

Please sign in to comment.