From 8c4ab9b4706db5e87d2be2ac1f2b2f1471873ef9 Mon Sep 17 00:00:00 2001 From: Lingvo Maintenance Date: Mon, 18 Nov 2024 17:25:59 -0800 Subject: [PATCH] Call `_UpdatePaddingWithPackedInputMask` when `packed_input` is set and the input bundle has both `source_segment_id` and `query_segment_id`. Raise an error when only one of the two IDs is supplied, and emit a warning when their presence is inconsistent with the `packed_input` setting. PiperOrigin-RevId: 697811741 --- lingvo/core/attention.py | 113 +++++++++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 39 deletions(-) diff --git a/lingvo/core/attention.py b/lingvo/core/attention.py index 638edcdfc..8b810f089 100644 --- a/lingvo/core/attention.py +++ b/lingvo/core/attention.py @@ -552,21 +552,33 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor: if source_padding is not None: # source_padding is [source_length, multiplier, source_batch] now - if p.packed_input: - assert hasattr(inputs, 'source_segment_id') - assert hasattr(inputs, 'query_segment_id') + has_source_segment_id = hasattr(inputs, 'source_segment_id') + has_query_segment_id = hasattr(inputs, 'query_segment_id') + if has_source_segment_id != has_query_segment_id: + raise ValueError( + 'source_segment_id and query_segment_id must be supplied at the' + ' same time, but source_segment_id is %s and query_segment_id' + ' is %s.' + % ( + 'present' if has_source_segment_id else 'unavailable', + 'present' if has_query_segment_id else 'unavailable', + ) + ) + elif p.packed_input and has_source_segment_id: source_padding = self._UpdatePaddingWithPackedInputMask( source_padding, inputs.source_segment_id, inputs.query_segment_id ) - else: - if hasattr(inputs, 'source_segment_id'): - tf.logging.warning( - 'packed_input is False but source_segment_id is passed.' - ) - if hasattr(inputs, 'query_segment_id'): - tf.logging.warning( - 'packed_input is False but query_segment_id is passed.' 
- ) + elif p.packed_input or has_source_segment_id: + tf.logging.warning( + 'packed_input is %s but both source_segment_id and' + ' query_segment_id are %s in the inputs map. ID tensors should be' + ' supplied if and only if packed_input is set. Continuing without' + ' updating the padding tensor.' + % ( + p.packed_input, + 'present' if has_source_segment_id else 'unavailable', + ) + ) source_padding = tf.transpose(source_padding, [1, 2, 0]) # [multiplier, source_batch, source_length] @@ -728,9 +740,19 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor: source_padding = per_step_source_padding if source_padding is not None: - if p.packed_input: - assert hasattr(inputs, 'source_segment_id') - assert hasattr(inputs, 'query_segment_id') + has_source_segment_id = hasattr(inputs, 'source_segment_id') + has_query_segment_id = hasattr(inputs, 'query_segment_id') + if has_source_segment_id != has_query_segment_id: + raise ValueError( + 'source_segment_id and query_segment_id must be supplied at the' + ' same time, but source_segment_id is %s and query_segment_id' + ' is %s.' + % ( + 'present' if has_source_segment_id else 'unavailable', + 'present' if has_query_segment_id else 'unavailable', + ) + ) + elif p.packed_input and has_source_segment_id: source_padding = tf.expand_dims(source_padding, 1) source_padding = self._UpdatePaddingWithPackedInputMask( source_padding, @@ -738,15 +760,17 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor: inputs.query_segment_id, ) source_padding = tf.squeeze(source_padding, 1) - else: - if hasattr(inputs, 'source_segment_id'): - tf.logging.warning( - 'packed_input is False but source_segment_id is passed.' - ) - if hasattr(inputs, 'query_segment_id'): - tf.logging.warning( - 'packed_input is False but query_segment_id is passed.' - ) + elif p.packed_input or has_source_segment_id: + tf.logging.warning( + 'packed_input is %s but both source_segment_id and' + ' query_segment_id are %s in the inputs map. 
ID tensors should' + ' be supplied if and only if packed_input is set. Continuing' + ' without updating the padding tensor.' + % ( + p.packed_input, + 'present' if has_source_segment_id else 'unavailable', + ) + ) # [b, sl] source_padding = tf.transpose(source_padding) logits = tf.transpose(logits) @@ -1101,24 +1125,35 @@ def AttenProbs( source_padding = per_step_source_padding if source_padding is not None: - if p.packed_input: - assert hasattr(inputs, 'source_segment_id') - assert hasattr(inputs, 'query_segment_id') - source_padding = tf.transpose(source_padding, [1, 2, 0]) + source_padding = tf.transpose(source_padding, [1, 2, 0]) + has_source_segment_id = hasattr(inputs, 'source_segment_id') + has_query_segment_id = hasattr(inputs, 'query_segment_id') + if has_source_segment_id != has_query_segment_id: + raise ValueError( + 'source_segment_id and query_segment_id must be supplied at the' + ' same time, but source_segment_id is %s and query_segment_id' + ' is %s.' + % ( + 'present' if has_source_segment_id else 'unavailable', + 'present' if has_query_segment_id else 'unavailable', + ) + ) + elif p.packed_input and has_source_segment_id: source_padding = self._UpdatePaddingWithPackedInputMask( source_padding, inputs.source_segment_id, inputs.query_segment_id ) - source_padding = tf.transpose(source_padding, [1, 2, 0]) - else: - if hasattr(inputs, 'source_segment_id'): - tf.logging.warning( - 'packed_input is False but source_segment_id is passed.' - ) - if hasattr(inputs, 'query_segment_id'): - tf.logging.warning( - 'packed_input is False but query_segment_id is passed.' - ) - source_padding = tf.transpose(source_padding, [2, 0, 1]) + elif p.packed_input or has_source_segment_id: + tf.logging.warning( + 'packed_input is %s but both source_segment_id and' + ' query_segment_id are %s in the inputs map. ID tensors should be' + ' supplied if and only if packed_input is set. Continuing without' + ' updating the padding tensor.' 
+ % ( + p.packed_input, + 'present' if has_source_segment_id else 'unavailable', + ) + ) + source_padding = tf.transpose(source_padding, [1, 2, 0]) # => [n, source_batch, time] logits = tf.transpose(logits, [2, 0, 1])