Skip to content

Commit

Permalink
Call _UpdatePaddingWithPackedInputMask when packed_input is …
Browse files Browse the repository at this point in the history
…set and the input bundle has both `source_segment_id` and `query_segment_id`.

Raise an error when only one of the two IDs is supplied; emit a warning when the presence of the IDs is inconsistent with the packed_input setting.

PiperOrigin-RevId: 697811741
  • Loading branch information
lingvo-bot authored and copybara-github committed Nov 19, 2024
1 parent 41c50c7 commit 8c4ab9b
Showing 1 changed file with 74 additions and 39 deletions.
113 changes: 74 additions & 39 deletions lingvo/core/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,21 +552,33 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:

if source_padding is not None:
# source_padding is [source_length, multiplier, source_batch] now
if p.packed_input:
assert hasattr(inputs, 'source_segment_id')
assert hasattr(inputs, 'query_segment_id')
has_source_segment_id = hasattr(inputs, 'source_segment_id')
has_query_segment_id = hasattr(inputs, 'query_segment_id')
if has_source_segment_id != has_query_segment_id:
raise ValueError(
'source_segment_id and query_segment_id must be supplied at the'
' same time, but source_segment_id is %s and query_segment_id'
' is %s.'
% (
'present' if has_source_segment_id else 'unavailable',
'present' if has_query_segment_id else 'unavailable',
)
)
elif p.packed_input and has_source_segment_id:
source_padding = self._UpdatePaddingWithPackedInputMask(
source_padding, inputs.source_segment_id, inputs.query_segment_id
)
else:
if hasattr(inputs, 'source_segment_id'):
tf.logging.warning(
'packed_input is False but source_segment_id is passed.'
)
if hasattr(inputs, 'query_segment_id'):
tf.logging.warning(
'packed_input is False but query_segment_id is passed.'
)
elif p.packed_input or has_source_segment_id:
tf.logging.warning(
'packed_input is %s but both source_segment_id and'
' query_segment_id are %s in the inputs map. ID tensors should be'
' supplied if and only if packed_input is set. Continuing without'
' updating the padding tensor.'
% (
p.packed_input,
'present' if has_source_segment_id else 'unavailable',
)
)
source_padding = tf.transpose(source_padding, [1, 2, 0])

# [multiplier, source_batch, source_length]
Expand Down Expand Up @@ -728,25 +740,37 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
source_padding = per_step_source_padding

if source_padding is not None:
if p.packed_input:
assert hasattr(inputs, 'source_segment_id')
assert hasattr(inputs, 'query_segment_id')
has_source_segment_id = hasattr(inputs, 'source_segment_id')
has_query_segment_id = hasattr(inputs, 'query_segment_id')
if has_source_segment_id != has_query_segment_id:
raise ValueError(
'source_segment_id and query_segment_id must be supplied at the'
' same time, but source_segment_id is %s and query_segment_id'
' is %s.'
% (
'present' if has_source_segment_id else 'unavailable',
'present' if has_query_segment_id else 'unavailable',
)
)
elif p.packed_input and has_source_segment_id:
source_padding = tf.expand_dims(source_padding, 1)
source_padding = self._UpdatePaddingWithPackedInputMask(
source_padding,
inputs.source_segment_id,
inputs.query_segment_id,
)
source_padding = tf.squeeze(source_padding, 1)
else:
if hasattr(inputs, 'source_segment_id'):
tf.logging.warning(
'packed_input is False but source_segment_id is passed.'
)
if hasattr(inputs, 'query_segment_id'):
tf.logging.warning(
'packed_input is False but query_segment_id is passed.'
)
elif p.packed_input or has_source_segment_id:
tf.logging.warning(
'packed_input is %s but both source_segment_id and'
' query_segment_id are %s in the inputs map. ID tensors should'
' be supplied if and only if packed_input is set. Continuing'
' without updating the padding tensor.'
% (
p.packed_input,
'present' if has_source_segment_id else 'unavailable',
)
)
# [b, sl]
source_padding = tf.transpose(source_padding)
logits = tf.transpose(logits)
Expand Down Expand Up @@ -1101,24 +1125,35 @@ def AttenProbs(
source_padding = per_step_source_padding

if source_padding is not None:
if p.packed_input:
assert hasattr(inputs, 'source_segment_id')
assert hasattr(inputs, 'query_segment_id')
source_padding = tf.transpose(source_padding, [1, 2, 0])
source_padding = tf.transpose(source_padding, [1, 2, 0])
has_source_segment_id = hasattr(inputs, 'source_segment_id')
has_query_segment_id = hasattr(inputs, 'query_segment_id')
if has_source_segment_id != has_query_segment_id:
raise ValueError(
'source_segment_id and query_segment_id must be supplied at the'
' same time, but source_segment_id is %s and query_segment_id'
' is %s.'
% (
'present' if has_source_segment_id else 'unavailable',
'present' if has_query_segment_id else 'unavailable',
)
)
elif p.packed_input and has_source_segment_id:
source_padding = self._UpdatePaddingWithPackedInputMask(
source_padding, inputs.source_segment_id, inputs.query_segment_id
)
source_padding = tf.transpose(source_padding, [1, 2, 0])
else:
if hasattr(inputs, 'source_segment_id'):
tf.logging.warning(
'packed_input is False but source_segment_id is passed.'
)
if hasattr(inputs, 'query_segment_id'):
tf.logging.warning(
'packed_input is False but query_segment_id is passed.'
)
source_padding = tf.transpose(source_padding, [2, 0, 1])
elif p.packed_input or has_source_segment_id:
tf.logging.warning(
'packed_input is %s but both source_segment_id and'
' query_segment_id are %s in the inputs map. ID tensors should be'
' supplied if and only if packed_input is set. Continuing without'
' updating the padding tensor.'
% (
p.packed_input,
'present' if has_source_segment_id else 'unavailable',
)
)
source_padding = tf.transpose(source_padding, [1, 2, 0])

# => [n, source_batch, time]
logits = tf.transpose(logits, [2, 0, 1])
Expand Down

0 comments on commit 8c4ab9b

Please sign in to comment.