Skip to content

Commit

Permalink
Preparing code to move into kaldi repository.
Browse files Browse the repository at this point in the history
Fixing stuff and renaming into am_nnet.
  • Loading branch information
Ilya Platonov committed Apr 12, 2016
1 parent 75f8b02 commit 461746a
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 76 deletions.
37 changes: 0 additions & 37 deletions egs/apiai_decode/s5/README

This file was deleted.

2 changes: 1 addition & 1 deletion src/nnet3/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \
discriminative-supervision.o nnet-discriminative-example.o \
nnet-discriminative-diagnostics.o \
discriminative-training.o nnet-discriminative-training.o \
online-nnet3-decodable.o
online-nnet3-decodable-simple.o


LIBNAME = kaldi-nnet3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,37 @@
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/online-nnet3-decodable.h"
#include <nnet3/online-nnet3-decodable-simple.h>
#include "nnet3/nnet-utils.h"

namespace kaldi {
namespace nnet3 {

DecodableNnet3Online::DecodableNnet3Online(
const AmNnetSimple &nnet,
DecodableNnet3SimpleOnline::DecodableNnet3SimpleOnline(
const AmNnetSimple &am_nnet,
const TransitionModel &trans_model,
const DecodableNnet3OnlineOptions &opts,
OnlineFeatureInterface *input_feats):
compiler_(nnet.GetNnet(), opts.optimize_config),
compiler_(am_nnet.GetNnet(), opts.optimize_config),
features_(input_feats),
nnet_(nnet),
am_nnet_(am_nnet),
trans_model_(trans_model),
opts_(opts),
feat_dim_(input_feats->Dim()),
num_pdfs_(nnet.GetNnet().OutputDim("output")),
num_pdfs_(am_nnet.GetNnet().OutputDim("output")),
begin_frame_(-1) {
KALDI_ASSERT(opts_.max_nnet_batch_size > 0);
log_priors_ = nnet_.Priors();
log_priors_ = am_nnet_.Priors();
KALDI_ASSERT((log_priors_.Dim() == 0 || log_priors_.Dim() == trans_model_.NumPdfs()) &&
"Priors in neural network must match with transition model (if exist).");

ComputeSimpleNnetContext(nnet_.GetNnet(), &left_context_, &right_context_);
ComputeSimpleNnetContext(am_nnet_.GetNnet(), &left_context_, &right_context_);
log_priors_.ApplyLog();
}



BaseFloat DecodableNnet3Online::LogLikelihood(int32 frame, int32 index) {
BaseFloat DecodableNnet3SimpleOnline::LogLikelihood(int32 frame, int32 index) {
ComputeForFrame(frame);
int32 pdf_id = trans_model_.TransitionIdToPdf(index);
KALDI_ASSERT(frame >= begin_frame_ &&
Expand All @@ -57,31 +57,31 @@ BaseFloat DecodableNnet3Online::LogLikelihood(int32 frame, int32 index) {
}


bool DecodableNnet3Online::IsLastFrame(int32 frame) const {
bool DecodableNnet3SimpleOnline::IsLastFrame(int32 frame) const {
KALDI_ASSERT(false && "Method is not imlemented");
return false;
}

int32 DecodableNnet3Online::NumFramesReady() const {
int32 DecodableNnet3SimpleOnline::NumFramesReady() const {
int32 features_ready = features_->NumFramesReady();
if (features_ready == 0)
return 0;
bool input_finished = features_->IsLastFrame(features_ready - 1);
if (opts_.pad_input) {
// normal case... we'll pad with duplicates of first + last frame to get the
// required left and right context.
if (input_finished) return subsampling(features_ready);
else return std::max<int32>(0, subsampling(features_ready - right_context_));
if (input_finished) return NumSubsampledFrames(features_ready);
else return std::max<int32>(0, NumSubsampledFrames(features_ready - right_context_));
} else {
return std::max<int32>(0, subsampling(features_ready - right_context_ - left_context_));
return std::max<int32>(0, NumSubsampledFrames(features_ready - right_context_ - left_context_));
}
}

int32 DecodableNnet3Online::subsampling(int32 num_frames) const {
int32 DecodableNnet3SimpleOnline::NumSubsampledFrames(int32 num_frames) const {
return (num_frames) / opts_.frame_subsampling_factor;
}

void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) {
void DecodableNnet3SimpleOnline::ComputeForFrame(int32 subsampled_frame) {
int32 features_ready = features_->NumFramesReady();
bool input_finished = features_->IsLastFrame(features_ready - 1);
KALDI_ASSERT(subsampled_frame >= 0);
Expand Down Expand Up @@ -118,13 +118,13 @@ void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) {
features_->GetFrame(t_modified, &row);
}

int32 num_subsampled_frames = subsampling(input_frame_end - input_frame_begin -
int32 num_subsampled_frames = NumSubsampledFrames(input_frame_end - input_frame_begin -
left_context_ - right_context_);
// I'm not checking if the input feature vector is ok.
// It should be done, but I'm not sure if it is the best place.
// Maybe a new "nnet3 feature pipeline"?
int32 mfcc_dim = nnet_.GetNnet().InputDim("input");
int32 ivector_dim = nnet_.GetNnet().InputDim("ivector");
int32 mfcc_dim = am_nnet_.GetNnet().InputDim("input");
int32 ivector_dim = am_nnet_.GetNnet().InputDim("ivector");
// MFCCs in the left chunk
SubMatrix<BaseFloat> mfcc_mat = features.ColRange(0,mfcc_dim);

Expand All @@ -143,7 +143,7 @@ void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) {
begin_frame_ = subsampled_frame;
}

void DecodableNnet3Online::DoNnetComputation(
void DecodableNnet3SimpleOnline::DoNnetComputation(
int32 input_t_start,
const MatrixBase<BaseFloat> &input_feats,
const VectorBase<BaseFloat> &ivector,
Expand Down Expand Up @@ -182,7 +182,7 @@ void DecodableNnet3Online::DoNnetComputation(
const NnetComputation *computation = compiler_.Compile(request);
Nnet *nnet_to_update = NULL; // we're not doing any update.
NnetComputer computer(opts_.compute_config, *computation,
nnet_.GetNnet(), nnet_to_update);
am_nnet_.GetNnet(), nnet_to_update);

CuMatrix<BaseFloat> input_feats_cu(input_feats);
computer.AcceptInput("input", &input_feats_cu);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// nnet3/online-nnet3-decodable.h
// nnet3/online-nnet3-decodable-simple.h

// Copyright 2014 Johns Hopkins University (author: Daniel Povey)
// 2016 Api.ai (Author: Ilya Platonov)
Expand Down Expand Up @@ -62,9 +62,9 @@ struct DecodableNnet3OnlineOptions {
"frames (this will rarely make a difference)");

opts->Register("frame-subsampling-factor", &frame_subsampling_factor,
"Required if the frame-rate of the output (e.g. in 'chain' "
"models) is less than the frame-rate of the original "
"alignment.");
"Required if the frame-rate of the output (e.g. in 'chain' "
"models) is less than the frame-rate of the original "
"alignment.");

// register the optimization options with the prefix "optimization".
ParseOptions optimization_opts("optimization", opts);
Expand All @@ -84,9 +84,9 @@ struct DecodableNnet3OnlineOptions {
feature input from a matrix.
*/

class DecodableNnet3Online: public DecodableInterface {
class DecodableNnet3SimpleOnline: public DecodableInterface {
public:
DecodableNnet3Online(const AmNnetSimple &nnet,
DecodableNnet3SimpleOnline(const AmNnetSimple &am_nnet,
const TransitionModel &trans_model,
const DecodableNnet3OnlineOptions &opts,
OnlineFeatureInterface *input_feats);
Expand All @@ -108,7 +108,7 @@ class DecodableNnet3Online: public DecodableInterface {
/// them (and possibly for some succeeding frames)
void ComputeForFrame(int32 frame);
// Converts a frame count into the number of subsampled frames by dividing it by frame_subsampling_factor.
int32 subsampling(int32) const;
int32 NumSubsampledFrames(int32) const;

void DoNnetComputation(
int32 input_t_start,
Expand All @@ -120,7 +120,7 @@ class DecodableNnet3Online: public DecodableInterface {
CachingOptimizingCompiler compiler_;

OnlineFeatureInterface *features_;
const AmNnetSimple &nnet_;
const AmNnetSimple &am_nnet_;
const TransitionModel &trans_model_;
DecodableNnet3OnlineOptions opts_;
CuVector<BaseFloat> log_priors_; // log-priors taken from the model.
Expand All @@ -143,7 +143,7 @@ class DecodableNnet3Online: public DecodableInterface {
// opts_.max_nnet_batch_size.
Matrix<BaseFloat> scaled_loglikes_;

KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableNnet3Online);
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableNnet3SimpleOnline);
};

} // namespace nnet3
Expand Down
4 changes: 2 additions & 2 deletions src/online2/online-nnet3-decoding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ namespace kaldi {
SingleUtteranceNnet3Decoder::SingleUtteranceNnet3Decoder(
const OnlineNnet3DecodingConfig &config,
const TransitionModel &tmodel,
const nnet3::AmNnetSimple &model,
const nnet3::AmNnetSimple &am_model,
const fst::Fst<fst::StdArc> &fst,
OnlineNnet2FeaturePipeline *feature_pipeline):
config_(config),
feature_pipeline_(feature_pipeline),
tmodel_(tmodel),
decodable_(model, tmodel, config.decodable_opts, feature_pipeline),
decodable_(am_model, tmodel, config.decodable_opts, feature_pipeline),
decoder_(fst, config.decoder_opts) {
decoder_.InitDecoding();
}
Expand Down
6 changes: 3 additions & 3 deletions src/online2/online-nnet3-decoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
#include <vector>
#include <deque>

#include "../nnet3/online-nnet3-decodable-simple.h"
#include "matrix/matrix-lib.h"
#include "util/common-utils.h"
#include "base/kaldi-error.h"
#include "nnet3/online-nnet3-decodable.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/online-endpoint.h"
#include "decoder/lattice-faster-online-decoder.h"
Expand Down Expand Up @@ -71,7 +71,7 @@ class SingleUtteranceNnet3Decoder {
// class, it's owned externally.
SingleUtteranceNnet3Decoder(const OnlineNnet3DecodingConfig &config,
const TransitionModel &tmodel,
const nnet3::AmNnetSimple &model,
const nnet3::AmNnetSimple &am_model,
const fst::Fst<fst::StdArc> &fst,
OnlineNnet2FeaturePipeline *feature_pipeline);

Expand Down Expand Up @@ -116,7 +116,7 @@ class SingleUtteranceNnet3Decoder {

const TransitionModel &tmodel_;

nnet3::DecodableNnet3Online decodable_;
nnet3::DecodableNnet3SimpleOnline decodable_;

LatticeFasterOnlineDecoder decoder_;

Expand Down
6 changes: 3 additions & 3 deletions src/online2bin/online2-wav-nnet3-latgen-faster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ int main(int argc, char *argv[]) {
}

TransitionModel trans_model;
nnet3::AmNnetSimple nnet;
nnet3::AmNnetSimple am_nnet;
{
bool binary;
Input ki(nnet3_rxfilename, &binary);
trans_model.Read(ki.Stream(), binary);
nnet.Read(ki.Stream(), binary);
am_nnet.Read(ki.Stream(), binary);
}

fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldi(fst_rxfilename);
Expand Down Expand Up @@ -203,7 +203,7 @@ int main(int argc, char *argv[]) {

SingleUtteranceNnet3Decoder decoder(nnet3_decoding_config,
trans_model,
nnet,
am_nnet,
*decode_fst,
&feature_pipeline);
OnlineTimer decoding_timer(utt);
Expand Down

0 comments on commit 461746a

Please sign in to comment.