Skip to content

Commit

Permalink
Merge pull request #449 from SeanNaren/feature/test
Browse files Browse the repository at this point in the history
Evaluation fixes
  • Loading branch information
Sean Naren authored Aug 15, 2019
2 parents e4d0d42 + b9a3e07 commit d5dbadf
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 65 deletions.
65 changes: 39 additions & 26 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,10 @@
parser.add_argument('--output-path', default=None, type=str, help="Where to save raw acoustic output")
parser = add_decoder_args(parser)
parser.add_argument('--save-output', action="store_true", help="Saves output of model from test")
args = parser.parse_args()

if __name__ == '__main__':
torch.set_grad_enabled(False)
device = torch.device("cuda" if args.cuda else "cpu")
model = load_model(device, args.model_path, args.cuda)

if args.decoder == "beam":
from decoder import BeamCTCDecoder

decoder = BeamCTCDecoder(model.labels, lm_path=args.lm_path, alpha=args.alpha, beta=args.beta,
cutoff_top_n=args.cutoff_top_n, cutoff_prob=args.cutoff_prob,
beam_width=args.beam_width, num_processes=args.lm_workers)
elif args.decoder == "greedy":
decoder = GreedyDecoder(model.labels, blank_index=model.labels.index('_'))
else:
decoder = None
target_decoder = GreedyDecoder(model.labels, blank_index=model.labels.index('_'))
test_dataset = SpectrogramDataset(audio_conf=model.audio_conf, manifest_filepath=args.test_manifest,
labels=model.labels, normalize=True)
test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
num_workers=args.num_workers)
def evaluate(test_loader, device, model, decoder, target_decoder, save_output=False, verbose=False):
model.eval()
total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
output_data = []
for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
Expand All @@ -56,7 +38,7 @@

out, output_sizes = model(inputs, input_sizes)

if args.save_output:
if save_output:
# add output to data array, and continue
output_data.append((out.cpu().numpy(), output_sizes.numpy()))

Expand All @@ -69,17 +51,48 @@
total_wer += wer_inst
total_cer += cer_inst
num_tokens += len(reference.split())
num_chars += len(reference)
if args.verbose:
num_chars += len(reference.replace(' ', ''))
if verbose:
print("Ref:", reference.lower())
print("Hyp:", transcript.lower())
print("WER:", float(wer_inst) / len(reference.split()), "CER:", float(cer_inst) / len(reference), "\n")

print("WER:", float(wer_inst) / len(reference.split()),
"CER:", float(cer_inst) / len(reference.replace(' ', '')), "\n")
wer = float(total_wer) / num_tokens
cer = float(total_cer) / num_chars
return wer * 100, cer * 100, output_data


if __name__ == '__main__':
    # Script entry point: parse CLI args, build the model, decoder and data
    # loader, run evaluation, and report aggregate WER/CER.
    args = parser.parse_args()
    torch.set_grad_enabled(False)  # inference only; skip autograd bookkeeping
    device = torch.device("cuda" if args.cuda else "cpu")
    model = load_model(device, args.model_path, args.cuda)

    if args.decoder == "beam":
        # Imported lazily: the beam decoder pulls in the optional LM/ctcdecode
        # dependency, which greedy-only users need not install.
        from decoder import BeamCTCDecoder

        decoder = BeamCTCDecoder(model.labels, lm_path=args.lm_path, alpha=args.alpha, beta=args.beta,
                                 cutoff_top_n=args.cutoff_top_n, cutoff_prob=args.cutoff_prob,
                                 beam_width=args.beam_width, num_processes=args.lm_workers)
    elif args.decoder == "greedy":
        decoder = GreedyDecoder(model.labels, blank_index=model.labels.index('_'))
    else:
        decoder = None
    # References are always decoded greedily, independent of the hypothesis decoder.
    target_decoder = GreedyDecoder(model.labels, blank_index=model.labels.index('_'))
    test_dataset = SpectrogramDataset(audio_conf=model.audio_conf, manifest_filepath=args.test_manifest,
                                      labels=model.labels, normalize=True)
    test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.num_workers)
    wer, cer, output_data = evaluate(test_loader=test_loader,
                                     device=device,
                                     model=model,
                                     decoder=decoder,
                                     target_decoder=target_decoder,
                                     save_output=args.save_output,
                                     verbose=args.verbose)

    # NOTE: evaluate() already returns percentages (wer * 100 / cer * 100);
    # do NOT rescale here — the pre-change ".format(wer=wer * 100, ...)" line
    # left in the pasted diff would double-scale and is dropped.
    print('Test Summary \t'
          'Average WER {wer:.3f}\t'
          'Average CER {cer:.3f}\t'.format(wer=wer, cer=cer))
    if args.save_output:
        # Raw acoustic outputs collected during evaluation, saved for later analysis.
        np.save(args.output_path, output_data)
53 changes: 14 additions & 39 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
import torch.utils.data.distributed
from apex.fp16_utils import FP16_Optimizer
from apex.parallel import DistributedDataParallel
from tqdm import tqdm
from warpctc_pytorch import CTCLoss

from data.data_loader import AudioDataLoader, SpectrogramDataset, BucketingSampler, DistributedBucketingSampler
from decoder import GreedyDecoder
from logger import VisdomLogger, TensorBoardLogger
from model import DeepSpeech, supported_rnns
from test import evaluate
from utils import convert_model_to_half, reduce_tensor, check_loss

parser = argparse.ArgumentParser(description='DeepSpeech training')
Expand Down Expand Up @@ -150,7 +150,7 @@ def update(self, val, n=1):
print("Loading checkpoint model %s" % args.continue_from)
package = torch.load(args.continue_from, map_location=lambda storage, loc: storage)
model = DeepSpeech.load_model_package(package)
labels = model.labels
labels = model.labels
audio_conf = model.audio_conf
if not args.finetune: # Don't want to restart training
optim_state = package['optim_dict']
Expand Down Expand Up @@ -304,44 +304,19 @@ def update(self, val, n=1):
'Average Loss {loss:.3f}\t'.format(epoch + 1, epoch_time=epoch_time, loss=avg_loss))

start_iter = 0 # Reset start iteration for next epoch
total_cer, total_wer = 0, 0
model.eval()
with torch.no_grad():
for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
inputs, targets, input_percentages, target_sizes = data
input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
inputs = inputs.to(device)

# unflatten targets
split_targets = []
offset = 0
for size in target_sizes:
split_targets.append(targets[offset:offset + size])
offset += size

out, output_sizes = model(inputs, input_sizes)

decoded_output, _ = decoder.decode(out, output_sizes)
target_strings = decoder.convert_to_strings(split_targets)
wer, cer = 0, 0
for x in range(len(target_strings)):
transcript, reference = decoded_output[x][0], target_strings[x][0]
wer += decoder.wer(transcript, reference) / float(len(reference.split()))
cer += decoder.cer(transcript, reference) / float(len(reference))
total_cer += cer
total_wer += wer
del out
wer = total_wer / len(test_loader.dataset)
cer = total_cer / len(test_loader.dataset)
wer *= 100
cer *= 100
loss_results[epoch] = avg_loss
wer_results[epoch] = wer
cer_results[epoch] = cer
print('Validation Summary Epoch: [{0}]\t'
'Average WER {wer:.3f}\t'
'Average CER {cer:.3f}\t'.format(
epoch + 1, wer=wer, cer=cer))
wer, cer, output_data = evaluate(test_loader=test_loader,
device=device,
model=model,
decoder=decoder,
target_decoder=decoder)
loss_results[epoch] = avg_loss
wer_results[epoch] = wer
cer_results[epoch] = cer
print('Validation Summary Epoch: [{0}]\t'
'Average WER {wer:.3f}\t'
'Average CER {cer:.3f}\t'.format(
epoch + 1, wer=wer, cer=cer))

values = {
'loss_results': loss_results,
Expand Down

0 comments on commit d5dbadf

Please sign in to comment.