Implement Vertex AI Image Embedding MLTransform (#31514)
* Implement Vertex AI Image Embeddings support

* add unit test

* Adjust arg validation

* yapf

* linting

* address comments

* add image embeddings to __all__
jrmccluskey authored Jun 6, 2024
1 parent 8c6e1a4 commit 3cdbd3a
Showing 2 changed files with 136 additions and 1 deletion.
96 changes: 95 additions & 1 deletion sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py
@@ -33,11 +33,14 @@
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from vertexai.language_models import TextEmbeddingInput
from vertexai.language_models import TextEmbeddingModel
from vertexai.vision_models import Image
from vertexai.vision_models import MultiModalEmbeddingModel

__all__ = ["VertexAITextEmbeddings"]
__all__ = ["VertexAITextEmbeddings", "VertexAIImageEmbeddings"]

DEFAULT_TASK_TYPE = "RETRIEVAL_DOCUMENT"
# TODO: https://github.com/apache/beam/issues/29356
@@ -157,3 +160,94 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    return RunInference(
        model_handler=_TextEmbeddingHandler(self),
        inference_args=self.inference_args)


class _VertexAIImageEmbeddingHandler(ModelHandler):
  def __init__(
      self,
      model_name: str,
      dimension: Optional[int] = None,
      project: Optional[str] = None,
      location: Optional[str] = None,
      credentials: Optional[Credentials] = None,
  ):
    vertexai.init(project=project, location=location, credentials=credentials)
    self.model_name = model_name
    self.dimension = dimension

  def run_inference(
      self,
      batch: Sequence[Image],
      model: MultiModalEmbeddingModel,
      inference_args: Optional[Dict[str, Any]] = None,
  ) -> Iterable:
    embeddings = []
    # The maximum request size for multimodal embedding models is 1,
    # so each image in the batch is sent as its own request.
    for img in batch:
      embedding_response = model.get_embeddings(
          image=img, dimension=self.dimension)
      embeddings.append(embedding_response.image_embedding)
    return embeddings

  def load_model(self):
    model = MultiModalEmbeddingModel.from_pretrained(self.model_name)
    return model

  def __repr__(self):
    # The ModelHandler wrapper is an internal detail that is not exposed to
    # the user, so override __repr__ to report the public class name.
    return 'VertexAIImageEmbeddings'


class VertexAIImageEmbeddings(EmbeddingsManager):
  def __init__(
      self,
      model_name: str,
      columns: List[str],
      dimension: Optional[int],
      project: Optional[str] = None,
      location: Optional[str] = None,
      credentials: Optional[Credentials] = None,
      **kwargs):
    """
    Embedding config for Vertex AI Image Embedding models following
    https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-multimodal-embeddings # pylint: disable=line-too-long
    Image embeddings are generated for a batch of images using the Vertex AI
    API. Embeddings are returned in a list for each image in the batch. This
    transform makes remote calls to the Vertex AI service and may incur costs
    for use.

    Args:
      model_name: The name of the Vertex AI Multi-Modal Embedding model.
      columns: The columns containing the images to be embedded.
      dimension: The length of the embedding vector to generate. Must be one of
        128, 256, 512, or 1408. If not set, Vertex AI's default value is 1408.
      project: The default GCP project for API calls.
      location: The default location for API calls.
      credentials: Custom credentials for API calls.
        Defaults to environment credentials.
    """
    self.model_name = model_name
    self.project = project
    self.location = location
    self.credentials = credentials
    if dimension is not None and dimension not in (128, 256, 512, 1408):
      raise ValueError(
          "dimension argument must be one of 128, 256, 512, or 1408")
    self.dimension = dimension
    super().__init__(columns=columns, **kwargs)

  def get_model_handler(self) -> ModelHandler:
    return _VertexAIImageEmbeddingHandler(
        model_name=self.model_name,
        dimension=self.dimension,
        project=self.project,
        location=self.location,
        credentials=self.credentials,
    )

  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    return RunInference(
        model_handler=_ImageEmbeddingHandler(self),
        inference_args=self.inference_args)
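
For context, here is a minimal usage sketch of the new transform in a Beam pipeline, mirroring the unit test added below; the column name, GCS image URI, and artifact location are illustrative placeholders, not values from this commit.

import tempfile

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAIImageEmbeddings
from vertexai.vision_models import Image

# Embed the "image" column into 128-dimensional vectors using the
# Vertex AI multimodal embedding model.
embedding_config = VertexAIImageEmbeddings(
    model_name="multimodalembedding",
    columns=["image"],
    dimension=128)

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | "CreateData" >> beam.Create(
          [{"image": Image(gcs_uri="gs://your-bucket/path/to/image.jpg")}])
      | "MLTransform" >> MLTransform(
          write_artifact_location=tempfile.mkdtemp()).with_transform(
              embedding_config)
      | "Inspect" >> beam.Map(print))

Note that this makes remote calls to the Vertex AI service, so running it requires a GCP project with the Vertex AI API enabled and may incur costs.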
41 changes: 41 additions & 0 deletions sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py
@@ -27,8 +27,11 @@

try:
  from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAITextEmbeddings
  from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAIImageEmbeddings
  from vertexai.vision_models import Image
except ImportError:
  VertexAITextEmbeddings = None  # type: ignore
  VertexAIImageEmbeddings = None  # type: ignore

# pylint: disable=ungrouped-imports
try:
@@ -245,5 +248,43 @@ def test_mltransform_to_ptransform_with_vertex(self):
          ptransform_list[i]._model_handler._underlying.model_name, model_name)


@unittest.skipIf(
    VertexAIImageEmbeddings is None, 'Vertex AI Python SDK is not installed.')
class VertexAIImageEmbeddingsTest(unittest.TestCase):
  def setUp(self) -> None:
    self.artifact_location = tempfile.mkdtemp(prefix='_vertex_ai_image_test')
    self.gcs_artifact_location = os.path.join(
        'gs://temp-storage-for-perf-tests/vertex_ai_image', uuid.uuid4().hex)
    self.model_name = "multimodalembedding"
    self.image_path = "gs://apache-beam-ml/testing/inputs/vertex_images/sunflowers/1008566138_6927679c8a.jpg"  # pylint: disable=line-too-long

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  def test_vertex_ai_image_embedding(self):
    embedding_config = VertexAIImageEmbeddings(
        model_name=self.model_name, columns=[test_query_column], dimension=128)
    with beam.Pipeline() as pipeline:
      transformed_pcoll = (
          pipeline | "CreateData" >> beam.Create([{
              test_query_column: Image(gcs_uri=self.image_path)
          }])
          | "MLTransform" >> MLTransform(
              write_artifact_location=self.artifact_location).with_transform(
                  embedding_config))

      def assert_element(element):
        assert len(element[test_query_column]) == 128

      _ = (transformed_pcoll | beam.Map(assert_element))

  def test_improper_dimension(self):
    with self.assertRaises(ValueError):
      _ = VertexAIImageEmbeddings(
          model_name=self.model_name,
          columns=[test_query_column],
          dimension=127)


if __name__ == '__main__':
  unittest.main()

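For reference, this is a sketch of the per-image Vertex AI SDK call that _VertexAIImageEmbeddingHandler.run_inference issues for each element in a batch; the project, location, and GCS URI are placeholders.

import vertexai
from vertexai.vision_models import Image, MultiModalEmbeddingModel

# Placeholder project/location; the handler passes these to vertexai.init().
vertexai.init(project="your-gcp-project", location="us-central1")

model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
response = model.get_embeddings(
    image=Image(gcs_uri="gs://your-bucket/path/to/image.jpg"),
    dimension=128)  # omit to get Vertex AI's default length of 1408
print(len(response.image_embedding))  # -> 128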