diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py
index 10b0e5edc1..49b5578142 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py
@@ -81,18 +81,17 @@ def decode_dataset(
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
-
+        # hyps is a list, every element is decode result of a sentence.
         hyps = hubert_model.ctc_greedy_search(batch)
 
         texts = batch["supervisions"]["text"]
-        assert len(hyps) == len(texts)
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         this_batch = []
-
-        for hyp_text, ref_text in zip(hyps, texts):
+        assert len(hyps) == len(texts)
+        for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
             ref_words = ref_text.split()
             hyp_words = hyp_text.split()
-            this_batch.append((ref_words, hyp_words))
-
+            this_batch.append((cut_id, ref_words, hyp_words))
         results["ctc_greedy_search"].extend(this_batch)
 
         num_cuts += len(texts)
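
After this change, every entry appended to results["ctc_greedy_search"] is a (cut_id, ref_words, hyp_words) triple instead of a (ref_words, hyp_words) pair, so downstream reporting can attribute errors to individual cuts. The sketch below only illustrates that shape; the save_results helper and the output path are assumptions made for the example, not part of this patch.

# Illustrative sketch, not part of the patch: shows the per-cut result shape
# produced by decode_dataset() above and one simple way to dump it for review.
from pathlib import Path
from typing import Dict, List, Tuple

# (cut_id, ref_words, hyp_words), as appended to this_batch in the diff above.
ResultTriple = Tuple[str, List[str], List[str]]


def save_results(results: Dict[str, List[ResultTriple]], out_dir: Path) -> None:
    # Write one "cut_id: ref=... hyp=..." line per cut for every decoding method.
    out_dir.mkdir(parents=True, exist_ok=True)
    for method, triples in results.items():
        with open(out_dir / f"recogs-{method}.txt", "w", encoding="utf-8") as f:
            for cut_id, ref_words, hyp_words in triples:
                f.write(f"{cut_id}:\tref={' '.join(ref_words)}\thyp={' '.join(hyp_words)}\n")


# Example with the structure built in the loop above (values are made up).
save_results(
    {"ctc_greedy_search": [("1089-134686-0001", ["HELLO", "WORLD"], ["HELLO", "WORD"])]},
    Path("exp/hubert_decode"),
)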