diff --git a/egs/wsj/s5/steps/nnet3/xvector/get_egs.sh b/egs/wsj/s5/steps/nnet3/xvector/get_egs.sh index 2ab81395d47..7c74fff6090 100755 --- a/egs/wsj/s5/steps/nnet3/xvector/get_egs.sh +++ b/egs/wsj/s5/steps/nnet3/xvector/get_egs.sh @@ -161,6 +161,7 @@ echo $num_train_frames >$dir/info/num_frames num_train_archives=$[($num_train_frames*$num_repeats)/$frames_per_iter + 1] echo "$0: producing $num_train_archives archives for training" echo $num_train_archives > $dir/info/num_archives +echo $num_diagnostic_archives > $dir/info/num_diagnostic_archives if [ $nj -gt $num_train_archives ]; then diff --git a/egs/wsj/s5/steps/nnet3/xvector/train.sh b/egs/wsj/s5/steps/nnet3/xvector/train.sh index b66c95b3c39..c57d66f7019 100755 --- a/egs/wsj/s5/steps/nnet3/xvector/train.sh +++ b/egs/wsj/s5/steps/nnet3/xvector/train.sh @@ -132,12 +132,12 @@ while [ $x -lt $num_iters ]; do # Set off jobs doing some diagnostics, in the background. # Use the egs dir from the previous iteration for the diagnostics - $cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_valid.$x.log \ - nnet3-xvector-compute-prob "$dir/$x.raw - |" \ - "ark:nnet3-merge-egs ark:$egs_dir/valid_diagnostic.JOB.egs ark:- |" & - $cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_train.$x.log \ - nnet3-xvector-compute-prob "nnet3-am-copy --raw=true $dir/$x.raw - |" \ - "ark:nnet3-merge-egs ark:$egs_dir/train_diagnostic.JOB.egs ark:- |" & + $cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_valid.$x.JOB.log \ + nnet3-xvector-compute-prob $dir/$x.raw \ + "ark:nnet3-xvector-merge-egs ark:$egs_dir/valid_diagnostic_egs.JOB.ark ark:- |" & + $cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_train.$x.JOB.log \ + nnet3-xvector-compute-prob $dir/$x.raw \ + "ark:nnet3-xvector-merge-egs ark:$egs_dir/train_diagnostic_egs.JOB.ark ark:- |" & if [ $x -gt 0 ]; then $cmd $dir/log/progress.$x.log \ @@ -176,7 +176,7 @@ while [ $x -lt $num_iters ]; do nnet3-xvector-train $parallel_train_opts --print-interval=10 \ --max-param-change=$max_param_change \ $dir/$x.raw \ - "ark:nnet3-copy-egs ark:$egs_dir/egs.$archive.ark ark:- | nnet3-shuffle-egs --buffer-size=$shuffle_buffer_size --srand=$x ark:- ark:-| nnet3-merge-egs --minibatch-size=$minibatch_size --measure-output-frames=false --discard-partial-minibatches=true ark:- ark:- |" \ + "ark:nnet3-copy-egs ark:$egs_dir/egs.$archive.ark ark:- | nnet3-shuffle-egs --buffer-size=$shuffle_buffer_size --srand=$x ark:- ark:-| nnet3-xvector-merge-egs --minibatch-size=$minibatch_size --discard-partial-minibatches=true ark:- ark:- |" \ $dir/$[$x+1].$n.raw || touch $dir/.error & done wait diff --git a/src/xvectorbin/Makefile b/src/xvectorbin/Makefile index 1dc1bee6e0a..e0703ab8422 100644 --- a/src/xvectorbin/Makefile +++ b/src/xvectorbin/Makefile @@ -7,7 +7,8 @@ LDFLAGS += $(CUDA_LDFLAGS) LDLIBS += $(CUDA_LDLIBS) BINFILES = nnet3-xvector-get-egs nnet3-xvector-compute-prob \ - nnet3-xvector-show-progress nnet3-xvector-train + nnet3-xvector-show-progress nnet3-xvector-train \ + nnet3-xvector-merge-egs OBJFILES = diff --git a/src/xvectorbin/nnet3-xvector-merge-egs.cc b/src/xvectorbin/nnet3-xvector-merge-egs.cc new file mode 100644 index 00000000000..28dc9d2ee18 --- /dev/null +++ b/src/xvectorbin/nnet3-xvector-merge-egs.cc @@ -0,0 +1,108 @@ +// xvectorbin/nnet3-xvector-merge-egs.cc + +// Copyright 2016 David Snyder +// 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2014 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + + const char *usage = + "This copies nnet examples for xvector training from input to\n" + "output but while doing so it merges many NnetExample objects\n" + "into one, forming a minibatch consisting of single NnetExample.\n" + "Unlike nnet3-merge-egs, this binary does not expect the examples\n" + "to have any output.\n" + "\n" + "Usage: nnet3-xvector-merge-egs [options] " + "\n" + "e.g.\n" + "nnet3-xvector-merge-egs --minibatch-size=512 ark:1.egs ark:- " + "| nnet3-xvector-train ... \n" + "See also nnet3-copy-egs and nnet3-merge-egs\n"; + + bool compress = false; + int32 minibatch_size = 512; + bool discard_partial_minibatches = false; + + ParseOptions po(usage); + po.Register("minibatch-size", &minibatch_size, "Target size of " + "minibatches when merging."); + po.Register("compress", &compress, "If true, compress the output examples " + "(not recommended unless you are writing to disk)"); + po.Register("discard-partial-minibatches", &discard_partial_minibatches, + "discard any partial minibatches of 'uneven' size that may be " + "encountered at the end."); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string examples_rspecifier = po.GetArg(1), + examples_wspecifier = po.GetArg(2); + + SequentialNnetExampleReader example_reader(examples_rspecifier); + NnetExampleWriter example_writer(examples_wspecifier); + + std::vector examples; + examples.reserve(minibatch_size); + + int32 num_read = 0, num_written = 0; + while (!example_reader.Done()) { + const NnetExample &cur_eg = example_reader.Value(); + examples.resize(examples.size() + 1); + examples.back() = cur_eg; + bool minibatch_ready = + static_cast(examples.size()) >= minibatch_size; + + // Do Next() now, so we can test example_reader.Done() below . + example_reader.Next(); + num_read++; + + if (minibatch_ready || (!discard_partial_minibatches && + (example_reader.Done() && !examples.empty()))) { + NnetExample merged_eg; + MergeExamples(examples, compress, &merged_eg); + std::ostringstream ostr; + ostr << "merged-" << num_written; + num_written++; + std::string output_key = ostr.str(); + example_writer.Write(output_key, merged_eg); + examples.clear(); + } + } + KALDI_LOG << "Merged " << num_read << " egs to " << num_written << '.'; + return (num_written != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + +