Skip to content

Commit

Permalink
Deduplicate Base Embedding Handler Code (#31534)
Browse files Browse the repository at this point in the history
* Deduplicate Base Embedding Handler Code

* linting
  • Loading branch information
jrmccluskey authored Jun 7, 2024
1 parent e5c2f6c commit 9a921d5
Showing 1 changed file with 43 additions and 88 deletions.
131 changes: 43 additions & 88 deletions sdks/python/apache_beam/ml/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,22 +587,20 @@ def load_transforms_from_artifact_location(artifact_location):
return _transform_attribute_manager.load_attributes(artifact_location)


class _TextEmbeddingHandler(ModelHandler):
class _EmbeddingHandler(ModelHandler):
"""
A ModelHandler intended to be work on list[dict[str, str]] inputs.
A ModelHandler intended to work on list[dict[str, Any]] inputs.
The inputs to the model handler are expected to be a list of dicts.
For example, if the original model is used with RunInference to take a
PCollection[E] to a PCollection[P], this ModelHandler would take a
PCollection[Dict[str, E]] to a PCollection[Dict[str, P]].
_TextEmbeddingHandler will accept an EmbeddingsManager instance, which
_EmbeddingHandler will accept an EmbeddingsManager instance, which
contains the details of the model to be loaded and the inference_fn to be
used. The purpose of _TextEmbeddingHandler is to generate embeddings for
text inputs using the EmbeddingsManager instance.
If the input is not a text column, a RuntimeError will be raised.
used. The purpose of _EmbeddingHandler is to generate embeddings for
general inputs using the EmbeddingsManager instance.
This is an internal class and offers no backwards compatibility guarantees.
Expand All @@ -619,12 +617,9 @@ def load_model(self):
return model

def _validate_column_data(self, batch):
if not isinstance(batch[0], (str, bytes)):
raise TypeError(
'Embeddings can only be generated on Dict[str, str].'
f'Got Dict[str, {type(batch[0])}] instead.')
pass

def _validate_batch(self, batch: Sequence[Dict[str, List[str]]]):
def _validate_batch(self, batch: Sequence[Dict[str, Any]]):
if not batch or not isinstance(batch[0], dict):
raise TypeError(
'Expected data to be dicts, got '
Expand Down Expand Up @@ -676,8 +671,7 @@ def run_inference(

def get_metrics_namespace(self) -> str:
return (
self._underlying.get_metrics_namespace() or
'BeamML_TextEmbeddingHandler')
self._underlying.get_metrics_namespace() or 'BeamML_EmbeddingHandler')

def batch_elements_kwargs(self) -> Mapping[str, Any]:
batch_sizes_map = {}
Expand All @@ -694,7 +688,41 @@ def validate_inference_args(self, _):
pass


class _ImageEmbeddingHandler(ModelHandler):
class _TextEmbeddingHandler(_EmbeddingHandler):
  """
  A ModelHandler intended to work on list[dict[str, str]] inputs.

  The inputs to the model handler are expected to be a list of dicts.

  For example, if the original model is used with RunInference to take a
  PCollection[E] to a PCollection[P], this ModelHandler would take a
  PCollection[Dict[str, E]] to a PCollection[Dict[str, P]].

  _TextEmbeddingHandler will accept an EmbeddingsManager instance, which
  contains the details of the model to be loaded and the inference_fn to be
  used. The purpose of _TextEmbeddingHandler is to generate embeddings for
  text inputs using the EmbeddingsManager instance.

  If the input is not a text column, a TypeError will be raised.

  This is an internal class and offers no backwards compatibility guarantees.

  Args:
    embeddings_manager: An EmbeddingsManager instance.
  """
  def _validate_column_data(self, batch):
    """Rejects non-text columns before they reach the text model.

    Only the first element is checked; the batch is assumed homogeneous.
    """
    if not isinstance(batch[0], (str, bytes)):
      raise TypeError(
          'Embeddings can only be generated on Dict[str, str]. '
          f'Got Dict[str, {type(batch[0])}] instead.')

  def get_metrics_namespace(self) -> str:
    """Returns the underlying handler's namespace, or a text-specific default."""
    return (
        self._underlying.get_metrics_namespace() or
        'BeamML_TextEmbeddingHandler')


class _ImageEmbeddingHandler(_EmbeddingHandler):
"""
A ModelHandler intended to work on list[dict[str, Image]] inputs.
Expand All @@ -717,15 +745,6 @@ class _ImageEmbeddingHandler(ModelHandler):
Args:
embeddings_manager: An EmbeddingsManager instance.
"""
def __init__(self, embeddings_manager: EmbeddingsManager):
  """Stores the embedding config and resolves its handler and target columns."""
  config = embeddings_manager
  self.embedding_config = config
  # Resolve once up front so later per-batch calls don't re-derive them.
  self._underlying = config.get_model_handler()
  self.columns = config.get_columns_to_apply()

def load_model(self):
  """Delegates model loading to the wrapped model handler."""
  return self._underlying.load_model()

def _validate_column_data(self, batch):
# Don't want to require framework-specific imports
# here, so just catch columns of primitives for now.
Expand All @@ -734,71 +753,7 @@ def _validate_column_data(self, batch):
'Embeddings can only be generated on Dict[str, Image].'
f'Got Dict[str, {type(batch[0])}] instead.')

def _validate_batch(self, batch: Sequence[Dict[str, List[Any]]]):
if not batch or not isinstance(batch[0], dict):
raise TypeError(
'Expected data to be dicts, got '
f'{type(batch[0])} instead.')

def _process_batch(
    self,
    dict_batch: Dict[str, List[Any]],
    model: ModelT,
    inference_args: Optional[Dict[str, Any]]) -> Dict[str, List[Any]]:
  """
  Runs the underlying model over each configured column of a columnar batch.

  Args:
    dict_batch: mapping of column name to that column's batched values
      (one entry per input element).
    model: the model object produced by load_model().
    inference_args: optional extra arguments forwarded to the underlying
      handler's run_inference call.

  Returns:
    A dict with the same keys as dict_batch; configured columns hold the
    model's predictions, all other columns pass through unchanged.

  Raises:
    RuntimeError: if any configured column is absent from the input.
  """
  result: Dict[str, List[Any]] = collections.defaultdict(list)
  input_keys = dict_batch.keys()
  # Every column the config asks to embed must be present in the data.
  missing_columns_in_data = set(self.columns) - set(input_keys)
  if missing_columns_in_data:
    raise RuntimeError(
        f'Data does not contain the following columns '
        f': {missing_columns_in_data}.')
  for key, batch in dict_batch.items():
    if key in self.columns:
      self._validate_column_data(batch)
      prediction = self._underlying.run_inference(
          batch, model, inference_args)
      # Normalize ndarray predictions to plain lists so downstream
      # row-wise re-batching operates on uniform Python containers.
      if isinstance(prediction, np.ndarray):
        prediction = prediction.tolist()
        result[key] = prediction  # type: ignore[assignment]
      else:
        result[key] = prediction  # type: ignore[assignment]
    else:
      # Columns not configured for embedding are passed through as-is.
      result[key] = batch
  return result

def run_inference(
    self,
    batch: Sequence[Dict[str, List[str]]],
    model: ModelT,
    inference_args: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Union[List[float], List[str]]]]:
  """
  Generates embeddings for a batch of dict inputs.

  The inputs are expected to be a list of dicts. Each dict should have
  the same keys, and the shape should be of the same size for a single
  key across the batch.
  """
  self._validate_batch(batch)
  # Pivot row-oriented dicts into one columnar dict so each configured
  # column is embedded with a single call to the underlying handler.
  columnar = _convert_list_of_dicts_to_dict_of_lists(list_of_dicts=batch)
  embedded = self._process_batch(columnar, model, inference_args)
  # Pivot back to the original row-oriented layout for the caller.
  return _convert_dict_of_lists_to_lists_of_dict(dict_of_lists=embedded)

def get_metrics_namespace(self) -> str:
  """Namespace for inference metrics; falls back to an image-handler default."""
  namespace = self._underlying.get_metrics_namespace()
  return namespace or 'BeamML_ImageEmbeddingHandler'

def batch_elements_kwargs(self) -> Mapping[str, Any]:
  """Batching kwargs: the underlying handler's, else limits from the config."""
  config_limits = {}
  max_size = self.embedding_config.max_batch_size
  min_size = self.embedding_config.min_batch_size
  if max_size:
    config_limits['max_batch_size'] = max_size
  if min_size:
    config_limits['min_batch_size'] = min_size
  # The underlying handler's kwargs win when it supplies any.
  return self._underlying.batch_elements_kwargs() or config_limits

def __repr__(self):
return self._underlying.__repr__()

def validate_inference_args(self, _):
  """No-op: inference-arg validation is left to the underlying handler."""
  # Intentionally accepts anything without inspection.

0 comments on commit 9a921d5

Please sign in to comment.