Skip to content

Commit

Permalink
Add suffix for vocab files in tft transforms (#29720)
Browse files Browse the repository at this point in the history
* Add suffix for vocab files

* add test

* default params to None similar to TFT

* fix lint

* Update tft.py
  • Loading branch information
AnandInguva committed Dec 12, 2023
1 parent 90e79ae commit 276aa02
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 11 deletions.
28 changes: 18 additions & 10 deletions sdks/python/apache_beam/ml/transforms/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def __init__(
num_oov_buckets: Any lookup of an out-of-vocabulary token will return a
bucket ID based on its hash if `num_oov_buckets` is greater than zero.
Otherwise it is assigned the `default_value`.
vocab_filename: The file name for the vocabulary file. If not provided,
the default name would be `compute_and_apply_vocab'
vocab_filename: The file name for the vocabulary file. The vocab file
will be suffixed with the column name.
NOTE in order to make your pipelines resilient to implementation
details please set `vocab_filename` when you are using
the vocab_filename on a downstream component.
Expand All @@ -183,8 +183,7 @@ def __init__(
self._top_k = top_k
self._frequency_threshold = frequency_threshold
self._num_oov_buckets = num_oov_buckets
self._vocab_filename = vocab_filename if vocab_filename else (
'compute_and_apply_vocab')
self._vocab_filename = vocab_filename
self._name = name
self.split_string_by_delimiter = split_string_by_delimiter

Expand All @@ -196,14 +195,17 @@ def apply_transform(
data = self._split_string_with_delimiter(
data, self.split_string_by_delimiter)

vocab_filename = self._vocab_filename
if vocab_filename:
vocab_filename = vocab_filename + f'_{output_column_name}'
return {
output_column_name: tft.compute_and_apply_vocabulary(
x=data,
default_value=self._default_value,
top_k=self._top_k,
frequency_threshold=self._frequency_threshold,
num_oov_buckets=self._num_oov_buckets,
vocab_filename=self._vocab_filename,
vocab_filename=vocab_filename,
name=self._name)
}

Expand Down Expand Up @@ -535,7 +537,7 @@ def __init__(
ngram_range: Tuple[int, int] = (1, 1),
ngrams_separator: Optional[str] = None,
compute_word_count: bool = False,
key_vocab_filename: str = 'key_vocab_mapping',
key_vocab_filename: Optional[str] = None,
name: Optional[str] = None,
):
"""
Expand All @@ -558,7 +560,9 @@ def __init__(
compute_word_count: A boolean that specifies whether to compute
the unique word count over the entire dataset. Defaults to False.
key_vocab_filename: The file name for the key vocabulary file when
compute_word_count is True.
compute_word_count is True. If empty, a file name
will be chosen based on the current scope. If provided, the vocab
file will be suffixed with the column name.
name: A name for the operation (optional).
Note that original order of the input may not be preserved.
Expand All @@ -585,10 +589,14 @@ def apply_transform(self, data: tf.SparseTensor, output_col_name: str):
data, self.split_string_by_delimiter)
output = tft.bag_of_words(
data, self.ngram_range, self.ngrams_separator, self.name)
# word counts are written to the key_vocab_filename
self.compute_word_count_fn(data, self.key_vocab_filename)
# word counts are written to the file only if compute_word_count is True
key_vocab_filename = self.key_vocab_filename
if key_vocab_filename:
key_vocab_filename = key_vocab_filename + f'_{output_col_name}'
self.compute_word_count_fn(data, key_vocab_filename)
return {output_col_name: output}


def count_unqiue_words(data: tf.SparseTensor, output_vocab_name: str) -> None:
def count_unqiue_words(
data: tf.SparseTensor, output_vocab_name: Optional[str]) -> None:
tft.count_per_key(data, key_vocabulary_filename=output_vocab_name)
81 changes: 80 additions & 1 deletion sdks/python/apache_beam/ml/transforms/tft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,85 @@ def test_string_split_with_multiple_delimiters(self):
]
assert_that(result, equal_to(expected_result, equals_fn=np.array_equal))

def test_multiple_columns_with_default_vocab_name(self):
data = [{
'x': ['I', 'like', 'pie'], 'y': ['Apach', 'Beam', 'is', 'awesome']
},
{
'x': ['yum', 'yum', 'pie'],
'y': ['Beam', 'is', 'a', 'unified', 'model']
}]
with beam.Pipeline() as p:
result = (
p
| "Create" >> beam.Create(data)
| "MLTransform" >> base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
tft.ComputeAndApplyVocabulary(columns=['x', 'y'])))

expected_data_x = [np.array([3, 2, 1]), np.array([0, 0, 1])]

expected_data_y = [np.array([6, 1, 0, 4]), np.array([1, 0, 5, 2, 3])]

actual_data_x = (result | beam.Map(lambda x: x.x))
actual_data_y = (result | beam.Map(lambda x: x.y))

assert_that(
actual_data_x,
equal_to(expected_data_x, equals_fn=np.array_equal),
label='x')
assert_that(
actual_data_y,
equal_to(expected_data_y, equals_fn=np.array_equal),
label='y')
files = os.listdir(self.artifact_location)
files.remove(base._ATTRIBUTE_FILE_NAME)
assert len(files) == 1
tft_vocab_assets = os.listdir(
os.path.join(
self.artifact_location, files[0], 'transform_fn', 'assets'))
assert len(tft_vocab_assets) == 2

def test_multiple_columns_with_vocab_name(self):
data = [{
'x': ['I', 'like', 'pie'], 'y': ['Apach', 'Beam', 'is', 'awesome']
},
{
'x': ['yum', 'yum', 'pie'],
'y': ['Beam', 'is', 'a', 'unified', 'model']
}]
with beam.Pipeline() as p:
result = (
p
| "Create" >> beam.Create(data)
| "MLTransform" >> base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
tft.ComputeAndApplyVocabulary(
columns=['x', 'y'], vocab_filename='my_vocab')))

expected_data_x = [np.array([3, 2, 1]), np.array([0, 0, 1])]

expected_data_y = [np.array([6, 1, 0, 4]), np.array([1, 0, 5, 2, 3])]

actual_data_x = (result | beam.Map(lambda x: x.x))
actual_data_y = (result | beam.Map(lambda x: x.y))

assert_that(
actual_data_x,
equal_to(expected_data_x, equals_fn=np.array_equal),
label='x')
assert_that(
actual_data_y,
equal_to(expected_data_y, equals_fn=np.array_equal),
label='y')
files = os.listdir(self.artifact_location)
files.remove(base._ATTRIBUTE_FILE_NAME)
assert len(files) == 1
tft_vocab_assets = os.listdir(
os.path.join(
self.artifact_location, files[0], 'transform_fn', 'assets'))
assert len(tft_vocab_assets) == 2


class TFIDIFTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -717,7 +796,7 @@ def validate_count_per_key(key_vocab_filename):
self.artifact_location,
files[0],
'transform_fn/assets',
key_vocab_filename)
key_vocab_filename + '_x')
with open(key_vocab_location, 'r') as f:
key_vocab_list = [line.strip() for line in f]
return key_vocab_list
Expand Down

0 comments on commit 276aa02

Please sign in to comment.