Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add other decoding methods (nbest, nbest oracle, nbest LG) for wenetspeech pruned rnnt2 #482

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
| | Dev | Test-Net | Test-Meeting |
|----------------------|-------|----------|--------------|
| greedy search | 7.80 | 8.75 | 13.49 |
| modified beam search| 7.76 | 8.71 | 13.41 |
| fast beam search | 7.94 | 8.74 | 13.80 |
| modified beam search | 7.76 | 8.71 | 13.41 |

We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)

Expand Down
48 changes: 46 additions & 2 deletions egs/wenetspeech/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ When training with the L subset, the WERs are
|------------------------------------|-------|----------|--------------|------------------------------------------|
| greedy search | 7.80 | 8.75 | 13.49 | --epoch 10, --avg 2, --max-duration 100 |
| modified beam search (beam size 4) | 7.76 | 8.71 | 13.41 | --epoch 10, --avg 2, --max-duration 100 |
| fast beam search (set as default) | 7.94 | 8.74 | 13.80 | --epoch 10, --avg 2, --max-duration 1500 |
| fast beam search (1best) | 7.94 | 8.74 | 13.80 | --epoch 10, --avg 2, --max-duration 1500 |
| fast beam search (nbest) | 9.82 | 10.98 | 16.37 | --epoch 10, --avg 2, --max-duration 600 |
| fast beam search (nbest oracle) | 6.88 | 7.18 | 11.77 | --epoch 10, --avg 2, --max-duration 600 |
| fast beam search (nbest LG, ngram_lm_scale=0.35) | 8.83 | 9.88 | 15.47 | --epoch 10, --avg 2, --max-duration 600 |

The training command for reproducing is given below:

Expand Down Expand Up @@ -59,7 +62,7 @@ avg=2
--decoding-method modified_beam_search \
--beam-size 4

## fast beam search
## fast beam search (1best)
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
Expand All @@ -70,6 +73,47 @@ avg=2
--beam 4 \
--max-contexts 4 \
--max-states 8

## fast beam search (nbest)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5

## fast beam search (nbest oracle WER)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5

## fast beam search (with LG)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--ngram-lm-scale 0.35 \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
```

When training with the M subset, the WERs are
Expand Down
1 change: 1 addition & 0 deletions egs/wenetspeech/ASR/local/compile_lg.py
31 changes: 31 additions & 0 deletions egs/wenetspeech/ASR/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,34 @@ if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
--lang-dir data/lang_char
fi
fi

# Stages 17 and 18 are only needed for LG-based decoding.
# If you don't want to use LG for decoding, the following steps are not necessary.
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
log "Stage 17: Prepare G"
# Builds the 3-gram language model G in two steps:
#   1. an ARPA file generated from the word-segmented text, and
#   2. a text-format FST converted from that ARPA file by kaldilm.
# Each step is skipped when its output file already exists.
#
# It will take about 20 minutes.
# We assume you have installed kaldilm; if not, please install
# it using: pip install kaldilm
lang_char_dir=data/lang_char
# Step 1: create the 3-gram ARPA LM from $lang_char_dir/text_words_segmentation.
if [ ! -f $lang_char_dir/3-gram.unpruned.arpa ]; then
python ./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lang_char_dir/text_words_segmentation \
-lm $lang_char_dir/3-gram.unpruned.arpa
fi

# Step 2: convert the ARPA LM to data/lm/G_3_gram.fst.txt, using
# $lang_char_dir/words.txt as the symbol table and '#0' as the
# disambiguation symbol.
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building LG
python3 -m kaldilm \
--read-symbol-table="$lang_char_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
fi
fi

# Stage 18: run ./local/compile_lg.py on the character lang directory.
# NOTE(review): presumably this writes LG.pt into $lang_char_dir, which the
# fast_beam_search_nbest_LG decoding method loads -- confirm in compile_lg.py.
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
log "Stage 18: Compile LG"
lang_char_dir=data/lang_char
python ./local/compile_lg.py --lang-dir $lang_char_dir
fi
156 changes: 153 additions & 3 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
--decoding-method modified_beam_search \
--beam-size 4

(3) fast beam search
(3) fast beam search (1best)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
Expand All @@ -48,6 +48,46 @@
--beam 4 \
--max-contexts 4 \
--max-states 8

(4) fast beam search (nbest)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5

(5) fast beam search (nbest oracle WER)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5

(6) fast beam search (with LG)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--ngram-lm-scale 0.35 \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""


Expand All @@ -63,13 +103,17 @@
from asr_datamodule import WenetSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model

from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
Expand Down Expand Up @@ -151,6 +195,11 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to
specify `--lang-dir`, which should contain `LG.pt`.
""",
)

Expand All @@ -173,6 +222,16 @@ def get_parser():
Used only when --decoding-method is fast_beam_search""",
)

parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.35,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)

parser.add_argument(
"--max-contexts",
type=int,
Expand Down Expand Up @@ -204,13 +263,32 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)

parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)

parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)

return parser


def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
Expand Down Expand Up @@ -267,6 +345,50 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp])
hyps.append(list(sentence))
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
Expand Down Expand Up @@ -331,6 +453,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Expand Down Expand Up @@ -373,6 +496,7 @@ def decode_dataset(
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
batch=batch,
)
Expand Down Expand Up @@ -454,6 +578,9 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
Expand All @@ -463,6 +590,13 @@ def main():
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.decoding_method == "fast_beam_search_nbest_LG":
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if (
params.decoding_method == "fast_beam_search_nbest"
or params.decoding_method == "fast_beam_search_nbest_oracle"
):
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
Expand All @@ -482,6 +616,11 @@ def main():
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1

graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)

logging.info(params)

logging.info("About to create model")
Expand Down Expand Up @@ -513,8 +652,18 @@ def main():
model.eval()
model.device = device

if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lg_filename = params.lang_dir + "/LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None

Expand Down Expand Up @@ -610,6 +759,7 @@ def main():
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
)
save_results(
Expand Down