Skip to content

Commit

Permalink
Merge pull request #141 from SeanNaren/fix-inference
Browse files Browse the repository at this point in the history
Refactor of testing/prediction, added progress
  • Loading branch information
Sean Naren authored Aug 24, 2017
2 parents 3f10ec0 + 683eabe commit fc8efd8
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 41 deletions.
25 changes: 13 additions & 12 deletions decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,8 @@

import Levenshtein as Lev
import torch
from enum import Enum
from six.moves import xrange

try:
from pytorch_ctc import CTCBeamDecoder as CTCBD
from pytorch_ctc import Scorer, KenLMScorer
except ImportError:
print("warn: pytorch_ctc unavailable. Only greedy decoding is supported.")


class Decoder(object):
"""
Expand Down Expand Up @@ -134,17 +127,25 @@ def decode(self, probs, sizes=None):


class BeamCTCDecoder(Decoder):
def __init__(self, labels, scorer, beam_width=20, top_paths=1, blank_index=0, space_index=28):
def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_index=28, lm_path=None, trie_path=None,
lm_alpha=None, lm_beta1=None, lm_beta2=None):
super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index)
self._beam_width = beam_width
self._top_n = top_paths

try:
import pytorch_ctc
from pytorch_ctc import CTCBeamDecoder, Scorer, KenLMScorer
except ImportError:
raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")

self._decoder = CTCBD(scorer, labels, top_paths=top_paths, beam_width=beam_width,
blank_index=blank_index, space_index=space_index, merge_repeated=False)
if lm_path is not None:
scorer = KenLMScorer(labels, lm_path, trie_path)
scorer.set_lm_weight(lm_alpha)
scorer.set_word_weight(lm_beta1)
scorer.set_valid_word_weight(lm_beta2)
else:
scorer = Scorer()
self._decoder = CTCBeamDecoder(scorer, labels, top_paths=top_paths, beam_width=beam_width,
blank_index=blank_index, space_index=space_index, merge_repeated=False)

def decode(self, probs, sizes=None):
sizes = sizes.cpu() if sizes is not None else None
Expand Down
26 changes: 12 additions & 14 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import sys
import time

import torch
from decoder import GreedyDecoder, BeamCTCDecoder

from torch.autograd import Variable

from data.data_loader import SpectrogramParser
from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer
from model import DeepSpeech

parser = argparse.ArgumentParser(description='DeepSpeech prediction')
Expand All @@ -18,8 +18,10 @@
parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder")
beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use')
beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
beam_args.add_argument('--lm_path', default=None, type=str,
help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
beam_args.add_argument('--trie_path', default=None, type=str,
help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight')
beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)')
beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)')
Expand All @@ -33,15 +35,10 @@
audio_conf = DeepSpeech.get_audio_conf(model)

if args.decoder == "beam":
scorer = None
if args.lm_path is not None:
scorer = KenLMScorer(labels, args.lm_path, args.trie_path)
scorer.set_lm_weight(args.lm_alpha)
scorer.set_word_weight(args.lm_beta1)
scorer.set_valid_word_weight(args.lm_beta2)
else:
scorer = Scorer()
decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_'))
decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '),
blank_index=labels.index('_'), lm_path=args.lm_path,
trie_path=args.trie_path, lm_alpha=args.lm_alpha, lm_beta1=args.lm_beta1,
lm_beta2=args.lm_beta2)
else:
decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_'))

Expand All @@ -56,4 +53,5 @@
t1 = time.time()

print(decoded_output[0])
print("Decoded {0:.2f} seconds of audio in {1:.2f} seconds".format(spect.size(3)*audio_conf['window_stride'], t1-t0), file=sys.stderr)
print("Decoded {0:.2f} seconds of audio in {1:.2f} seconds".format(spect.size(3) * audio_conf['window_stride'],
t1 - t0), file=sys.stderr)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ torch
visdom
wget
librosa
tqdm
27 changes: 12 additions & 15 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse
import json

import torch
from torch.autograd import Variable
from tqdm import tqdm

from decoder import GreedyDecoder, BeamCTCDecoder

from data.data_loader import SpectrogramDataset, AudioDataLoader
from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer
from model import DeepSpeech

parser = argparse.ArgumentParser(description='DeepSpeech prediction')
Expand All @@ -19,8 +19,10 @@
parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder")
beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use')
beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
beam_args.add_argument('--lm_path', default=None, type=str,
help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
beam_args.add_argument('--trie_path', default=None, type=str,
help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight')
beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)')
beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)')
Expand All @@ -34,15 +36,10 @@
audio_conf = DeepSpeech.get_audio_conf(model)

if args.decoder == "beam":
scorer = None
if args.lm_path is not None:
scorer = KenLMScorer(labels, args.lm_path, args.trie_path)
scorer.set_lm_weight(args.lm_alpha)
scorer.set_word_weight(args.lm_beta1)
scorer.set_valid_word_weight(args.lm_beta2)
else:
scorer = Scorer()
decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_'))
decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '),
blank_index=labels.index('_'), lm_path=args.lm_path,
trie_path=args.trie_path, lm_alpha=args.lm_alpha, lm_beta1=args.lm_beta1,
lm_beta2=args.lm_beta2)
else:
decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_'))

Expand All @@ -51,7 +48,7 @@
test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
num_workers=args.num_workers)
total_cer, total_wer = 0, 0
for i, (data) in enumerate(test_loader):
for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
inputs, targets, input_percentages, target_sizes = data

inputs = Variable(inputs, volatile=True)
Expand Down

0 comments on commit fc8efd8

Please sign in to comment.