Skip to content

Commit

Permalink
Don't supply query_segment_id when packed_input is disabled.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696646145
  • Loading branch information
lingvo-bot authored and copybara-github committed Nov 14, 2024
1 parent cfd911c commit 41c50c7
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions lingvo/core/rnn_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ def GeneratePackedInputResetMask(segment_id, is_reverse=False):
class IdentitySeqLayer(base_layer.BaseLayer):
"""A no-op sequence layer."""

def __init__(self, params):
super().__init__(params)

def zero_state(self, theta, batch_size):
del theta
del batch_size
return py_utils.NestedMap()

def FPropFullSequence(self, theta, inputs, paddings):
del theta
del paddings
return inputs


Expand Down Expand Up @@ -868,6 +869,7 @@ def zero_state(self,
Returns:
state0 - A `.NestedMap` containing initial states of RNN and attention.
"""
del atten_state_dim

p = self.params
atten = self.atten
Expand All @@ -894,6 +896,7 @@ def zero_state(self,
return state0

def reset_atten_state(self, theta, state, inputs):
del theta
state.atten = inputs.reset_mask * state.atten
if isinstance(state.atten_state, py_utils.NestedMap):
if 'inner' not in state.atten_state:
Expand Down Expand Up @@ -986,14 +989,21 @@ def CellFn(theta, state0, inputs):
py_utils.NestedMap(
act=act, padding=inputs.padding, reset_mask=inputs.reset_mask))

query_segment_id = (
tf.cast(tf.squeeze(inputs.segment_id, 1), py_utils.FPropDtype(p))
if p.packed_input
else None
)

state1.atten, state1.atten_probs, state1.atten_state = (
self.atten.ComputeContextVectorWithSource(
theta.atten,
theta.packed_src,
self.cell.GetOutput(state1.rnn),
state0_mod.atten_state,
query_segment_id=tf.cast(
tf.squeeze(inputs.segment_id, 1), py_utils.FPropDtype(p))))
query_segment_id=query_segment_id,
)
)
return state1, py_utils.NestedMap()

if p.packed_input:
Expand Down

0 comments on commit 41c50c7

Please sign in to comment.