Skip to content

Commit

Permalink
Merge pull request #13 from david-ryan-snyder/xvector
Browse files Browse the repository at this point in the history
xvector: nnet3-xvector-merge-egs, etc
  • Loading branch information
danpovey committed Feb 22, 2016
2 parents c3c99b0 + 72dfb91 commit 05f9c36
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 8 deletions.
1 change: 1 addition & 0 deletions egs/wsj/s5/steps/nnet3/xvector/get_egs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ echo $num_train_frames >$dir/info/num_frames
num_train_archives=$[($num_train_frames*$num_repeats)/$frames_per_iter + 1]
echo "$0: producing $num_train_archives archives for training"
echo $num_train_archives > $dir/info/num_archives
echo $num_diagnostic_archives > $dir/info/num_diagnostic_archives


if [ $nj -gt $num_train_archives ]; then
Expand Down
14 changes: 7 additions & 7 deletions egs/wsj/s5/steps/nnet3/xvector/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ while [ $x -lt $num_iters ]; do

# Set off jobs doing some diagnostics, in the background.
# Use the egs dir from the previous iteration for the diagnostics
$cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_valid.$x.log \
nnet3-xvector-compute-prob "$dir/$x.raw - |" \
"ark:nnet3-merge-egs ark:$egs_dir/valid_diagnostic.JOB.egs ark:- |" &
$cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_train.$x.log \
nnet3-xvector-compute-prob "nnet3-am-copy --raw=true $dir/$x.raw - |" \
"ark:nnet3-merge-egs ark:$egs_dir/train_diagnostic.JOB.egs ark:- |" &
$cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_valid.$x.JOB.log \
nnet3-xvector-compute-prob $dir/$x.raw \
"ark:nnet3-xvector-merge-egs ark:$egs_dir/valid_diagnostic_egs.JOB.ark ark:- |" &
$cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_train.$x.JOB.log \
nnet3-xvector-compute-prob $dir/$x.raw \
"ark:nnet3-xvector-merge-egs ark:$egs_dir/train_diagnostic_egs.JOB.ark ark:- |" &

if [ $x -gt 0 ]; then
$cmd $dir/log/progress.$x.log \
Expand Down Expand Up @@ -176,7 +176,7 @@ while [ $x -lt $num_iters ]; do
nnet3-xvector-train $parallel_train_opts --print-interval=10 \
--max-param-change=$max_param_change \
$dir/$x.raw \
"ark:nnet3-copy-egs ark:$egs_dir/egs.$archive.ark ark:- | nnet3-shuffle-egs --buffer-size=$shuffle_buffer_size --srand=$x ark:- ark:-| nnet3-merge-egs --minibatch-size=$minibatch_size --measure-output-frames=false --discard-partial-minibatches=true ark:- ark:- |" \
"ark:nnet3-copy-egs ark:$egs_dir/egs.$archive.ark ark:- | nnet3-shuffle-egs --buffer-size=$shuffle_buffer_size --srand=$x ark:- ark:-| nnet3-xvector-merge-egs --minibatch-size=$minibatch_size --discard-partial-minibatches=true ark:- ark:- |" \
$dir/$[$x+1].$n.raw || touch $dir/.error &
done
wait
Expand Down
3 changes: 2 additions & 1 deletion src/xvectorbin/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ LDFLAGS += $(CUDA_LDFLAGS)
LDLIBS += $(CUDA_LDLIBS)

BINFILES = nnet3-xvector-get-egs nnet3-xvector-compute-prob \
nnet3-xvector-show-progress nnet3-xvector-train
nnet3-xvector-show-progress nnet3-xvector-train \
nnet3-xvector-merge-egs

OBJFILES =

Expand Down
108 changes: 108 additions & 0 deletions src/xvectorbin/nnet3-xvector-merge-egs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// xvectorbin/nnet3-xvector-merge-egs.cc

// Copyright 2016 David Snyder
// 2012-2015 Johns Hopkins University (author: Daniel Povey)
// 2014 Vimal Manohar

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "nnet3/nnet-example.h"
#include "nnet3/nnet-example-utils.h"

int main(int argc, char *argv[]) {
try {
using namespace kaldi;
using namespace kaldi::nnet3;
typedef kaldi::int32 int32;

const char *usage =
"This copies nnet examples for xvector training from input to\n"
"output but while doing so it merges many NnetExample objects\n"
"into one, forming a minibatch consisting of single NnetExample.\n"
"Unlike nnet3-merge-egs, this binary does not expect the examples\n"
"to have any output.\n"
"\n"
"Usage: nnet3-xvector-merge-egs [options] <egs-rspecifier> "
"<egs-wspecifier>\n"
"e.g.\n"
"nnet3-xvector-merge-egs --minibatch-size=512 ark:1.egs ark:- "
"| nnet3-xvector-train ... \n"
"See also nnet3-copy-egs and nnet3-merge-egs\n";

bool compress = false;
int32 minibatch_size = 512;
bool discard_partial_minibatches = false;

ParseOptions po(usage);
po.Register("minibatch-size", &minibatch_size, "Target size of "
"minibatches when merging.");
po.Register("compress", &compress, "If true, compress the output examples "
"(not recommended unless you are writing to disk)");
po.Register("discard-partial-minibatches", &discard_partial_minibatches,
"discard any partial minibatches of 'uneven' size that may be "
"encountered at the end.");

po.Read(argc, argv);

if (po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}

std::string examples_rspecifier = po.GetArg(1),
examples_wspecifier = po.GetArg(2);

SequentialNnetExampleReader example_reader(examples_rspecifier);
NnetExampleWriter example_writer(examples_wspecifier);

std::vector<NnetExample> examples;
examples.reserve(minibatch_size);

int32 num_read = 0, num_written = 0;
while (!example_reader.Done()) {
const NnetExample &cur_eg = example_reader.Value();
examples.resize(examples.size() + 1);
examples.back() = cur_eg;
bool minibatch_ready =
static_cast<int32>(examples.size()) >= minibatch_size;

// Do Next() now, so we can test example_reader.Done() below .
example_reader.Next();
num_read++;

if (minibatch_ready || (!discard_partial_minibatches &&
(example_reader.Done() && !examples.empty()))) {
NnetExample merged_eg;
MergeExamples(examples, compress, &merged_eg);
std::ostringstream ostr;
ostr << "merged-" << num_written;
num_written++;
std::string output_key = ostr.str();
example_writer.Write(output_key, merged_eg);
examples.clear();
}
}
KALDI_LOG << "Merged " << num_read << " egs to " << num_written << '.';
return (num_written != 0 ? 0 : 1);
} catch(const std::exception &e) {
std::cerr << e.what() << '\n';
return -1;
}
}


0 comments on commit 05f9c36

Please sign in to comment.