Skip to content

Commit

Permalink
✍️ fix transducer beam search for longer sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
nglehuy committed Jan 5, 2021
1 parent 4842078 commit 3d5c7f3
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions tensorflow_asr/models/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def initialize_beam(dynamic=False):
B = BeamHypothesis(
score=B.score.write(0, 0.0),
indices=B.indices.write(0, self.text_featurizer.blank),
prediction=B.prediction.write(0, tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank),
prediction=B.prediction.write(0, tf.ones([total * 2], dtype=tf.int32) * self.text_featurizer.blank),
states=B.states.write(0, self.predict_net.get_initial_state())
)

Expand Down Expand Up @@ -673,10 +673,7 @@ def false_fn():

b_score, b_indices, b_prediction, b_states, \
a_score, a_indices, a_prediction, a_states, A_i = tf.cond(
tf.equal(pred, self.text_featurizer.blank),
true_fn=true_fn,
false_fn=false_fn
)
tf.equal(pred, self.text_featurizer.blank), true_fn=true_fn, false_fn=false_fn)

B = BeamHypothesis(score=b_score, indices=b_indices, prediction=b_prediction, states=b_states)
A = BeamHypothesis(score=a_score, indices=a_indices, prediction=a_prediction, states=a_states)
Expand Down

0 comments on commit 3d5c7f3

Please sign in to comment.