From 7025ec67a1635e79b0570b2ed017508feb351b20 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 22 Aug 2017 19:01:15 +0100 Subject: [PATCH 1/2] Simplify decoder structure --- decoder.py | 25 +++++++++++++------------ predict.py | 26 ++++++++++++-------------- test.py | 23 +++++++++-------------- 3 files changed, 34 insertions(+), 40 deletions(-) diff --git a/decoder.py b/decoder.py index 6bf67746..890407ca 100644 --- a/decoder.py +++ b/decoder.py @@ -17,15 +17,8 @@ import Levenshtein as Lev import torch -from enum import Enum from six.moves import xrange -try: - from pytorch_ctc import CTCBeamDecoder as CTCBD - from pytorch_ctc import Scorer, KenLMScorer -except ImportError: - print("warn: pytorch_ctc unavailable. Only greedy decoding is supported.") - class Decoder(object): """ @@ -134,17 +127,25 @@ def decode(self, probs, sizes=None): class BeamCTCDecoder(Decoder): - def __init__(self, labels, scorer, beam_width=20, top_paths=1, blank_index=0, space_index=28): + def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_index=28, lm_path=None, trie_path=None, + lm_alpha=None, lm_beta1=None, lm_beta2=None): super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index) self._beam_width = beam_width self._top_n = top_paths + try: - import pytorch_ctc + from pytorch_ctc import CTCBeamDecoder, Scorer, KenLMScorer except ImportError: raise ImportError("BeamCTCDecoder requires pytorch_ctc package.") - - self._decoder = CTCBD(scorer, labels, top_paths=top_paths, beam_width=beam_width, - blank_index=blank_index, space_index=space_index, merge_repeated=False) + if lm_path is not None: + scorer = KenLMScorer(labels, lm_path, trie_path) + scorer.set_lm_weight(lm_alpha) + scorer.set_word_weight(lm_beta1) + scorer.set_valid_word_weight(lm_beta2) + else: + scorer = Scorer() + self._decoder = CTCBeamDecoder(scorer, labels, top_paths=top_paths, beam_width=beam_width, + blank_index=blank_index, space_index=space_index, merge_repeated=False) def decode(self, probs, sizes=None): sizes = sizes.cpu() if sizes is not None else None diff --git a/predict.py b/predict.py index 56c23760..d8e0602b 100644 --- a/predict.py +++ b/predict.py @@ -2,11 +2,11 @@ import sys import time -import torch +from decoder import GreedyDecoder, BeamCTCDecoder + from torch.autograd import Variable from data.data_loader import SpectrogramParser -from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer from model import DeepSpeech parser = argparse.ArgumentParser(description='DeepSpeech prediction') @@ -18,8 +18,10 @@ parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use") beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder") beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use') -beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') -beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)') +beam_args.add_argument('--lm_path', default=None, type=str, + help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') +beam_args.add_argument('--trie_path', default=None, type=str, + help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)') beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight') beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)') beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)') @@ -33,15 +35,10 @@ audio_conf = DeepSpeech.get_audio_conf(model) if args.decoder == "beam": - scorer = None - if args.lm_path is not None: - scorer = KenLMScorer(labels, args.lm_path, args.trie_path) - scorer.set_lm_weight(args.lm_alpha) - scorer.set_word_weight(args.lm_beta1) - scorer.set_valid_word_weight(args.lm_beta2) - else: - scorer = Scorer() - decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_')) + decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), + blank_index=labels.index('_'), lm_path=args.lm_path, + trie_path=args.trie_path, lm_alpha=args.lm_alpha, lm_beta1=args.lm_beta1, + lm_beta2=args.lm_beta2) else: decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_')) @@ -56,4 +53,5 @@ t1 = time.time() print(decoded_output[0]) - print("Decoded {0:.2f} seconds of audio in {1:.2f} seconds".format(spect.size(3)*audio_conf['window_stride'], t1-t0), file=sys.stderr) + print("Decoded {0:.2f} seconds of audio in {1:.2f} seconds".format(spect.size(3) * audio_conf['window_stride'], + t1 - t0), file=sys.stderr) diff --git a/test.py b/test.py index ec52f01f..e9673fac 100644 --- a/test.py +++ b/test.py @@ -1,11 +1,9 @@ import argparse -import json -import torch from torch.autograd import Variable +from decoder import GreedyDecoder, BeamCTCDecoder from data.data_loader import SpectrogramDataset, AudioDataLoader -from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer from model import DeepSpeech parser = argparse.ArgumentParser(description='DeepSpeech prediction') @@ -19,8 +17,10 @@ parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use") beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder") beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use') -beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') -beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)') +beam_args.add_argument('--lm_path', default=None, type=str, + help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') +beam_args.add_argument('--trie_path', default=None, type=str, + help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)') beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight') beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)') beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)') @@ -34,15 +34,10 @@ audio_conf = DeepSpeech.get_audio_conf(model) if args.decoder == "beam": - scorer = None - if args.lm_path is not None: - scorer = KenLMScorer(labels, args.lm_path, args.trie_path) - scorer.set_lm_weight(args.lm_alpha) - scorer.set_word_weight(args.lm_beta1) - scorer.set_valid_word_weight(args.lm_beta2) - else: - scorer = Scorer() - decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_')) + decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), + blank_index=labels.index('_'), lm_path=args.lm_path, + trie_path=args.trie_path, lm_alpha=args.lm_alpha, lm_beta1=args.lm_beta1, + lm_beta2=args.lm_beta2) else: decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_')) From 683eabe8e8c868d09e465eac7c31d12e548e98d8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 24 Aug 2017 09:50:52 +0100 Subject: [PATCH 2/2] Added progress bar support --- requirements.txt | 1 + test.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bed6925e..4699cfa4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ torch visdom wget librosa +tqdm \ No newline at end of file diff --git a/test.py b/test.py index e9673fac..9a736751 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,8 @@ import argparse from torch.autograd import Variable +from tqdm import tqdm + from decoder import GreedyDecoder, BeamCTCDecoder from data.data_loader import SpectrogramDataset, AudioDataLoader @@ -46,7 +48,7 @@ test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) total_cer, total_wer = 0, 0 - for i, (data) in enumerate(test_loader): + for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)): inputs, targets, input_percentages, target_sizes = data inputs = Variable(inputs, volatile=True)