From 461746ae372823dda312398fe96e72fa4374e293 Mon Sep 17 00:00:00 2001 From: Ilya Platonov Date: Mon, 11 Apr 2016 16:35:08 -0700 Subject: [PATCH] Preparing code to move into kaldi repository. Fixing stuff and renaming into am_nnet. --- egs/apiai_decode/s5/README | 37 ---------------- src/nnet3/Makefile | 2 +- ...le.cc => online-nnet3-decodable-simple.cc} | 42 +++++++++---------- ...able.h => online-nnet3-decodable-simple.h} | 18 ++++---- src/online2/online-nnet3-decoding.cc | 4 +- src/online2/online-nnet3-decoding.h | 6 +-- .../online2-wav-nnet3-latgen-faster.cc | 6 +-- 7 files changed, 39 insertions(+), 76 deletions(-) delete mode 100644 egs/apiai_decode/s5/README rename src/nnet3/{online-nnet3-decodable.cc => online-nnet3-decodable-simple.cc} (84%) rename src/nnet3/{online-nnet3-decodable.h => online-nnet3-decodable-simple.h} (91%) diff --git a/egs/apiai_decode/s5/README b/egs/apiai_decode/s5/README deleted file mode 100644 index 919ddaa30b2..00000000000 --- a/egs/apiai_decode/s5/README +++ /dev/null @@ -1,37 +0,0 @@ -Example scripts on how to use a pre-trained chain enlgish model and kaldi base code to recognize any number of wav files. - -IMPORTANT: wav files must be in 16kHz, 16 bit little-endian format. - -Model: -English pretrained model were released by Api.ai under Creative Commons Attribution-ShareAlike 4.0 International Public License. 
-- Acustic data is mostly mobile recorded data -- Language model is based on Assistant.ai logs and good for understanding short commands, like "Wake me up at 7 am" -For more details, visit https://github.com/api-ai/api-ai-english-asr-model - -Usage: -- Ensure kaldi is compiled and this scripts are inside kaldi/egs// directory -- Run ./download-model.sh - to download pretrained chain model -- Run ./recognize-wav.sh test1.wav test2.wav to do recognition -- See output for recognition results - -Using steps/nnet3/decode.sh script: -You can use kaldi steps/nnet3/decode.sh, which will decode data and calculate Word Error Rate (WER) for it. -Steps: -- Run recognize-wav.sh test1.wav test2.wav, it will make data dir, calculate mfcc features for it and do decoding, you need only first two steps out of it -- If you want WER then edit data/test-corpus/text and replace NO_TRANSCRIPTION with expected text transcription for every wav file -- steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --cmd run.pl --nj 1 exp/api.ai-model/ data/test-corpus/ exp/api.ai-model/decode/ -- See exp/api.ai-model/decode/wer* files for WER and exp/api.ai-model/decode/log/ files for decoding output - -Online Decoder: -At the moment kaldi does not support online decoding for nnet3 models, but decoders can be found here https://github.com/api-ai/kaldi/ . -See http://kaldi.sourceforge.net/online_decoding.html for more information about kaldi online decoding. 
-Steps: - - Run ./local/create-corpus.sh data/test-corpus/ test1.wav test2.wav (or just run recognize-wav.sh) to create corpus - - If you want WER then edit data/test-corpus/text and replace NO_TRANSCRIPTION with expected text transcription for every wav file - - Make config file exp/api.ai-model/online.conf with following content -==CONTENT START== - --feature-type=mfcc - --mfcc-config=exp/api.ai-model/mfcc.conf -==CONTENT END== - - Run steps/online/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --cmd run.pl --nj 1 exp/api.ai-model/ data/test-corpus/ exp/api.ai-model/decode/ - - See exp/api.ai-model/decode/wer* files for WER and exp/api.ai-model/decode/log/ files for decoding output diff --git a/src/nnet3/Makefile b/src/nnet3/Makefile index 658a74c6c5e..e3aa89b943a 100644 --- a/src/nnet3/Makefile +++ b/src/nnet3/Makefile @@ -28,7 +28,7 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \ discriminative-supervision.o nnet-discriminative-example.o \ nnet-discriminative-diagnostics.o \ discriminative-training.o nnet-discriminative-training.o \ - online-nnet3-decodable.o + online-nnet3-decodable-simple.o LIBNAME = kaldi-nnet3 diff --git a/src/nnet3/online-nnet3-decodable.cc b/src/nnet3/online-nnet3-decodable-simple.cc similarity index 84% rename from src/nnet3/online-nnet3-decodable.cc rename to src/nnet3/online-nnet3-decodable-simple.cc index e5a23002737..3090aa487c8 100644 --- a/src/nnet3/online-nnet3-decodable.cc +++ b/src/nnet3/online-nnet3-decodable-simple.cc @@ -18,37 +18,37 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. 
-#include "nnet3/online-nnet3-decodable.h" +#include <nnet3/online-nnet3-decodable-simple.h> #include "nnet3/nnet-utils.h" namespace kaldi { namespace nnet3 { -DecodableNnet3Online::DecodableNnet3Online( - const AmNnetSimple &nnet, +DecodableNnet3SimpleOnline::DecodableNnet3SimpleOnline( + const AmNnetSimple &am_nnet, const TransitionModel &trans_model, const DecodableNnet3OnlineOptions &opts, OnlineFeatureInterface *input_feats): - compiler_(nnet.GetNnet(), opts.optimize_config), + compiler_(am_nnet.GetNnet(), opts.optimize_config), features_(input_feats), - nnet_(nnet), + am_nnet_(am_nnet), trans_model_(trans_model), opts_(opts), feat_dim_(input_feats->Dim()), - num_pdfs_(nnet.GetNnet().OutputDim("output")), + num_pdfs_(am_nnet.GetNnet().OutputDim("output")), begin_frame_(-1) { KALDI_ASSERT(opts_.max_nnet_batch_size > 0); - log_priors_ = nnet_.Priors(); + log_priors_ = am_nnet_.Priors(); KALDI_ASSERT((log_priors_.Dim() == 0 || log_priors_.Dim() == trans_model_.NumPdfs()) && "Priors in neural network must match with transition model (if exist)."); - ComputeSimpleNnetContext(nnet_.GetNnet(), &left_context_, &right_context_); + ComputeSimpleNnetContext(am_nnet_.GetNnet(), &left_context_, &right_context_); log_priors_.ApplyLog(); } -BaseFloat DecodableNnet3Online::LogLikelihood(int32 frame, int32 index) { +BaseFloat DecodableNnet3SimpleOnline::LogLikelihood(int32 frame, int32 index) { ComputeForFrame(frame); int32 pdf_id = trans_model_.TransitionIdToPdf(index); KALDI_ASSERT(frame >= begin_frame_ && @@ -57,12 +57,12 @@ BaseFloat DecodableNnet3Online::LogLikelihood(int32 frame, int32 index) { } -bool DecodableNnet3Online::IsLastFrame(int32 frame) const { +bool DecodableNnet3SimpleOnline::IsLastFrame(int32 frame) const { KALDI_ASSERT(false && "Method is not imlemented"); return false; } -int32 DecodableNnet3Online::NumFramesReady() const { +int32 DecodableNnet3SimpleOnline::NumFramesReady() const { int32 features_ready = features_->NumFramesReady(); if (features_ready == 0) return 0; @@ -70,18 +70,18 @@ int32 
DecodableNnet3Online::NumFramesReady() const { if (opts_.pad_input) { // normal case... we'll pad with duplicates of first + last frame to get the // required left and right context. - if (input_finished) return subsampling(features_ready); - else return std::max(0, subsampling(features_ready - right_context_)); + if (input_finished) return NumSubsampledFrames(features_ready); + else return std::max(0, NumSubsampledFrames(features_ready - right_context_)); } else { - return std::max(0, subsampling(features_ready - right_context_ - left_context_)); + return std::max(0, NumSubsampledFrames(features_ready - right_context_ - left_context_)); } } -int32 DecodableNnet3Online::subsampling(int32 num_frames) const { +int32 DecodableNnet3SimpleOnline::NumSubsampledFrames(int32 num_frames) const { return (num_frames) / opts_.frame_subsampling_factor; } -void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) { +void DecodableNnet3SimpleOnline::ComputeForFrame(int32 subsampled_frame) { int32 features_ready = features_->NumFramesReady(); bool input_finished = features_->IsLastFrame(features_ready - 1); KALDI_ASSERT(subsampled_frame >= 0); @@ -118,13 +118,13 @@ void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) { features_->GetFrame(t_modified, &row); } - int32 num_subsampled_frames = subsampling(input_frame_end - input_frame_begin - + int32 num_subsampled_frames = NumSubsampledFrames(input_frame_end - input_frame_begin - left_context_ - right_context_); // I'm not checking if the input feature vector is ok. // It should be done, but I'm not sure if it is the best place. // Maybe a new "nnet3 feature pipeline"? 
- int32 mfcc_dim = nnet_.GetNnet().InputDim("input"); - int32 ivector_dim = nnet_.GetNnet().InputDim("ivector"); + int32 mfcc_dim = am_nnet_.GetNnet().InputDim("input"); + int32 ivector_dim = am_nnet_.GetNnet().InputDim("ivector"); // MFCCs in the left chunk SubMatrix<BaseFloat> mfcc_mat = features.ColRange(0,mfcc_dim); @@ -143,7 +143,7 @@ void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) { begin_frame_ = subsampled_frame; } -void DecodableNnet3Online::DoNnetComputation( +void DecodableNnet3SimpleOnline::DoNnetComputation( int32 input_t_start, const MatrixBase<BaseFloat> &input_feats, const VectorBase<BaseFloat> &ivector, @@ -182,7 +182,7 @@ void DecodableNnet3Online::DoNnetComputation( const NnetComputation *computation = compiler_.Compile(request); Nnet *nnet_to_update = NULL; // we're not doing any update. NnetComputer computer(opts_.compute_config, *computation, - nnet_.GetNnet(), nnet_to_update); + am_nnet_.GetNnet(), nnet_to_update); CuMatrix<BaseFloat> input_feats_cu(input_feats); computer.AcceptInput("input", &input_feats_cu); diff --git a/src/nnet3/online-nnet3-decodable.h b/src/nnet3/online-nnet3-decodable-simple.h similarity index 91% rename from src/nnet3/online-nnet3-decodable.h rename to src/nnet3/online-nnet3-decodable-simple.h index 09a84752644..b5b85cdc412 100644 --- a/src/nnet3/online-nnet3-decodable.h +++ b/src/nnet3/online-nnet3-decodable-simple.h @@ -1,4 +1,4 @@ -// nnet3/online-nnet3-decodable.h +// nnet3/online-nnet3-decodable-simple.h // Copyright 2014 Johns Hopkins Universithy (author: Daniel Povey) // 2016 Api.ai (Author: Ilya Platonov) @@ -62,9 +62,9 @@ struct DecodableNnet3OnlineOptions { "frames (this will rarely make a difference)"); opts->Register("frame-subsampling-factor", &frame_subsampling_factor, - "Required if the frame-rate of the output (e.g. in 'chain' " - "models) is less than the frame-rate of the original " - "alignment."); + "Required if the frame-rate of the output (e.g. 
in 'chain' " + "models) is less than the frame-rate of the original " + "alignment."); // register the optimization options with the prefix "optimization". ParseOptions optimization_opts("optimization", opts); @@ -84,9 +84,9 @@ struct DecodableNnet3OnlineOptions { feature input from a matrix. */ -class DecodableNnet3Online: public DecodableInterface { +class DecodableNnet3SimpleOnline: public DecodableInterface { public: - DecodableNnet3Online(const AmNnetSimple &nnet, + DecodableNnet3SimpleOnline(const AmNnetSimple &am_nnet, const TransitionModel &trans_model, const DecodableNnet3OnlineOptions &opts, OnlineFeatureInterface *input_feats); @@ -108,7 +108,7 @@ class DecodableNnet3Online: public DecodableInterface { /// them (and possibly for some succeeding frames) void ComputeForFrame(int32 frame); // corrects number of frames by frame_subsampling_factor; - int32 subsampling(int32) const; + int32 NumSubsampledFrames(int32) const; void DoNnetComputation( int32 input_t_start, @@ -120,7 +120,7 @@ class DecodableNnet3Online: public DecodableInterface { CachingOptimizingCompiler compiler_; OnlineFeatureInterface *features_; - const AmNnetSimple &nnet_; + const AmNnetSimple &am_nnet_; const TransitionModel &trans_model_; DecodableNnet3OnlineOptions opts_; CuVector log_priors_; // log-priors taken from the model. @@ -143,7 +143,7 @@ class DecodableNnet3Online: public DecodableInterface { // opts_.max_nnet_batch_size. 
Matrix scaled_loglikes_; - KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableNnet3Online); + KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableNnet3SimpleOnline); }; } // namespace nnet3 diff --git a/src/online2/online-nnet3-decoding.cc b/src/online2/online-nnet3-decoding.cc index 97980a721d1..5583c4931e9 100644 --- a/src/online2/online-nnet3-decoding.cc +++ b/src/online2/online-nnet3-decoding.cc @@ -27,13 +27,13 @@ namespace kaldi { SingleUtteranceNnet3Decoder::SingleUtteranceNnet3Decoder( const OnlineNnet3DecodingConfig &config, const TransitionModel &tmodel, - const nnet3::AmNnetSimple &model, + const nnet3::AmNnetSimple &am_model, const fst::Fst &fst, OnlineNnet2FeaturePipeline *feature_pipeline): config_(config), feature_pipeline_(feature_pipeline), tmodel_(tmodel), - decodable_(model, tmodel, config.decodable_opts, feature_pipeline), + decodable_(am_model, tmodel, config.decodable_opts, feature_pipeline), decoder_(fst, config.decoder_opts) { decoder_.InitDecoding(); } diff --git a/src/online2/online-nnet3-decoding.h b/src/online2/online-nnet3-decoding.h index f9b440522c2..edfb9ef5f20 100644 --- a/src/online2/online-nnet3-decoding.h +++ b/src/online2/online-nnet3-decoding.h @@ -26,10 +26,10 @@ #include #include +#include "../nnet3/online-nnet3-decodable-simple.h" #include "matrix/matrix-lib.h" #include "util/common-utils.h" #include "base/kaldi-error.h" -#include "nnet3/online-nnet3-decodable.h" #include "online2/online-nnet2-feature-pipeline.h" #include "online2/online-endpoint.h" #include "decoder/lattice-faster-online-decoder.h" @@ -71,7 +71,7 @@ class SingleUtteranceNnet3Decoder { // class, it's owned externally. 
SingleUtteranceNnet3Decoder(const OnlineNnet3DecodingConfig &config, const TransitionModel &tmodel, - const nnet3::AmNnetSimple &model, + const nnet3::AmNnetSimple &am_model, const fst::Fst &fst, OnlineNnet2FeaturePipeline *feature_pipeline); @@ -116,7 +116,7 @@ class SingleUtteranceNnet3Decoder { const TransitionModel &tmodel_; - nnet3::DecodableNnet3Online decodable_; + nnet3::DecodableNnet3SimpleOnline decodable_; LatticeFasterOnlineDecoder decoder_; diff --git a/src/online2bin/online2-wav-nnet3-latgen-faster.cc b/src/online2bin/online2-wav-nnet3-latgen-faster.cc index 634a679e964..68ab093dadc 100644 --- a/src/online2bin/online2-wav-nnet3-latgen-faster.cc +++ b/src/online2bin/online2-wav-nnet3-latgen-faster.cc @@ -151,12 +151,12 @@ int main(int argc, char *argv[]) { } TransitionModel trans_model; - nnet3::AmNnetSimple nnet; + nnet3::AmNnetSimple am_nnet; { bool binary; Input ki(nnet3_rxfilename, &binary); trans_model.Read(ki.Stream(), binary); - nnet.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); } fst::Fst *decode_fst = ReadFstKaldi(fst_rxfilename); @@ -203,7 +203,7 @@ int main(int argc, char *argv[]) { SingleUtteranceNnet3Decoder decoder(nnet3_decoding_config, trans_model, - nnet, + am_nnet, *decode_fst, &feature_pipeline); OnlineTimer decoding_timer(utt);