From 12f119ef18dd3d36973b885fa58a244f788c71a9 Mon Sep 17 00:00:00 2001
From: Hainan Xu
Date: Tue, 25 Apr 2017 14:21:32 -0400
Subject: [PATCH] add interface for lattice-rescoring

---
 src/rnnlm/rnnlm-component-itf.h |  6 +++--
 src/rnnlm/rnnlm-component.cc    | 14 +++++++-----
 src/rnnlm/rnnlm-component.h     |  3 +++
 src/rnnlm/rnnlm-training.cc     | 40 +++++++++++++++++----------------
 4 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/src/rnnlm/rnnlm-component-itf.h b/src/rnnlm/rnnlm-component-itf.h
index 5c6f463bf60..5603831792c 100644
--- a/src/rnnlm/rnnlm-component-itf.h
+++ b/src/rnnlm/rnnlm-component-itf.h
@@ -121,7 +121,8 @@ class LmInputComponent {
 
   virtual LmInputComponent* Copy() const = 0;
 
-  LmInputComponent() {}
+  LmInputComponent(): learning_rate_(0.001), learning_rate_factor_(1.0),
+                      is_gradient_(false), max_change_(0.0) {}
 
   virtual void Add(BaseFloat alpha, const LmInputComponent &other) = 0;
   virtual void Scale(BaseFloat scale) = 0;
@@ -208,7 +209,8 @@ class LmOutputComponent {
   BaseFloat LearningRate() const { return learning_rate_; }
   BaseFloat MaxChange() const { return max_change_; }
 
-  LmOutputComponent() {}
+  LmOutputComponent(): learning_rate_(0.001), learning_rate_factor_(1.0),
+                       is_gradient_(false), max_change_(0.0) {}
 
   virtual void Add(BaseFloat alpha, const LmOutputComponent &other) = 0;
   virtual void Scale(BaseFloat scale) = 0;
diff --git a/src/rnnlm/rnnlm-component.cc b/src/rnnlm/rnnlm-component.cc
index d356923ee9a..48eb3e73a7d 100644
--- a/src/rnnlm/rnnlm-component.cc
+++ b/src/rnnlm/rnnlm-component.cc
@@ -90,7 +90,6 @@ void NaturalGradientAffineImportanceSamplingComponent::Init(std::string matrix_f
 
 void NaturalGradientAffineImportanceSamplingComponent::InitFromConfig(ConfigLine *cfl) {
   bool ok = true;
-  this->learning_rate_factor_ = 1.0;
   std::string matrix_filename;
   std::string unigram_filename;
   int32 input_dim = -1, output_dim = -1;
@@ -603,7 +602,6 @@ void LmNaturalGradientLinearComponent::Init(
   int32 input_dim = mat.NumCols() - 1, output_dim = mat.NumRows();
   linear_params_.Resize(output_dim, input_dim);
   linear_params_.CopyFromMat(mat.Range(0, output_dim, 0, input_dim));
-  is_gradient_ = false;  // not configurable; there's no reason you'd want this
   update_count_ = 0.0;
   active_scaling_count_ = 0.0;
   max_change_scale_stats_ = 0.0;
@@ -633,7 +631,6 @@ void LmNaturalGradientLinearComponent::Init(
               << "to activate the per-component max change mechanism.";
   KALDI_ASSERT(max_change_per_sample >= 0.0);
   max_change_per_sample_ = max_change_per_sample;
-  is_gradient_ = false;  // not configurable; there's no reason you'd want this
   update_count_ = 0.0;
   active_scaling_count_ = 0.0;
   max_change_scale_stats_ = 0.0;
@@ -819,7 +816,6 @@ void AffineImportanceSamplingComponent::InitFromConfig(ConfigLine *cfl) {
       params_.ColRange(params_.NumCols() - 1, 1).AddMat(1.0, g, kTrans);
     }
   }
-  this->learning_rate_factor_ = 1.0;  // TODO(hxu) quick fix
   if (cfl->HasUnusedValues())
     KALDI_ERR << "Could not process these elements in initializer: "
               << cfl->UnusedValues();
@@ -868,6 +864,14 @@ void AffineImportanceSamplingComponent::Propagate(const CuMatrixBase
   }
 }
 
+BaseFloat AffineImportanceSamplingComponent::ComputeLogprobOfWordGivenHistory(
+    const CuVectorBase<BaseFloat> &hidden,
+    int32 word_index) {
+  CuSubVector<BaseFloat> param = params_.Row(word_index);
+  BaseFloat ans = VecVec(hidden, param);
+  return ans;
+}
+
 void AffineImportanceSamplingComponent::Backprop(
     const CuMatrixBase<BaseFloat> &in_value,
     const CuMatrixBase<BaseFloat> &out_value,  // out_value
@@ -1122,8 +1126,6 @@ void LmLinearComponent::InitFromConfig(ConfigLine *cfl) {
 
   if (!ok)
     KALDI_ERR << "Bad initializer " << cfl->WholeLine();
-  is_gradient_ = false;  // not configurable; there's no reason you'd want this
-  max_change_ = 1.0;
 }
 
 void LmLinearComponent::Propagate(const SparseMatrix<BaseFloat> &sp,
diff --git a/src/rnnlm/rnnlm-component.h b/src/rnnlm/rnnlm-component.h
index 7f57b73f9a8..239e0593124 100644
--- a/src/rnnlm/rnnlm-component.h
+++ b/src/rnnlm/rnnlm-component.h
@@ -229,6 +229,9 @@ class AffineImportanceSamplingComponent: public LmOutputComponent {
                          bool normalize,
                          CuMatrixBase<BaseFloat> *out) const;
 
+  BaseFloat ComputeLogprobOfWordGivenHistory(const CuVectorBase<BaseFloat> &hidden,
+                                             int32 word_index);
+
   virtual void Backprop(const vector<int32> &indexes,
                         const CuMatrixBase<BaseFloat> &in_value,
                         const CuMatrixBase<BaseFloat> &,  // out_value
diff --git a/src/rnnlm/rnnlm-training.cc b/src/rnnlm/rnnlm-training.cc
index 350d41298b5..7e452a05660 100644
--- a/src/rnnlm/rnnlm-training.cc
+++ b/src/rnnlm/rnnlm-training.cc
@@ -101,19 +101,14 @@ void LmNnetSamplingTrainer::ProcessEgInputs(const NnetExample& eg,
                                             const LmInputComponent& a,
                                             const SparseMatrix<BaseFloat> **old_input,
                                             CuMatrix<BaseFloat> *new_input) {
-  for (size_t i = 0; i < eg.io.size(); i++) {
-    const NnetIo &io = eg.io[i];
+//  for (size_t i = 0; i < eg.io.size(); i++) {
+  const NnetIo &io = eg.io[0];
 
-    if (io.name == "input") {
-      KALDI_ASSERT(old_input != NULL && new_input != NULL);
-      new_input->Resize(io.features.NumRows(),
-                        a.OutputDim(),
-                        kSetZero);
+  KALDI_ASSERT (io.name == "input");
+  new_input->Resize(io.features.NumRows(), a.OutputDim(), kSetZero);
 
-      *old_input = &io.features.GetSparseMatrix();
-      a.Propagate(io.features.GetSparseMatrix(), new_input);
-    }
-  }
+  *old_input = &io.features.GetSparseMatrix();
+  a.Propagate(io.features.GetSparseMatrix(), new_input);
 }
 
 void LmNnetSamplingTrainer::Train(const NnetExample &eg) {
@@ -493,21 +488,27 @@ void LmNnetSamplingTrainer::ComputeObjfAndDerivSample(
   std::vector<int32> outputs;  // outputs[i] is the correct word for row i
                                // will be initialized with SparsMatrixToVector
                                // the size will be post.NumRows() = t * minibatch_size
+                               // words for the same *sentence* would be grouped together
+                               // not the same time-step
+  std::set<int32> outputs_set;
   SparseMatrixToVector(post, &outputs);
 
-  std::vector<BaseFloat> selected_probs(num_samples * t);  // selected_probs[i * t + j] is the prob of
+  std::vector<BaseFloat> selected_probs;  // selected_probs[i * t + j] is the prob of
                                           // selecting samples[j][i]
+                                          // words for the same time step t would be grouped together TODO(hxu)
 
   int minibatch_size = (*old_output)->NumRows() / t;
   KALDI_ASSERT(outputs.size() == t * minibatch_size);
 
   CuMatrix<BaseFloat> out((*old_output)->NumRows(), num_samples, kSetZero);
 
+  // words for the same sentence would be grouped together
   if (num_samples == output_projection.OutputDim()) {
     output_projection.Propagate(**old_output, &out);
   } else {
+    selected_probs.resize(num_samples * t);
     // need to parallelize this loop
     for (int i = 0; i < t; i++) {
       CuSubMatrix<BaseFloat> this_in((**old_output).Data() + i * (**old_output).Stride(),
                                      minibatch_size, (**old_output).NumCols(),
                                      (**old_output).Stride() * t);
@@ -517,7 +518,8 @@ void LmNnetSamplingTrainer::ComputeObjfAndDerivSample(
       vector<int32> indexes(num_samples);
       for (int j = 0; j < num_samples; j++) {
         indexes[j] = samples[i][j].first;
-        selected_probs[j + t * i] = samples[i][j].second;
+        selected_probs[j * t + i] = samples[i][j].second;
+//        selected_probs[j + t * i] = samples[i][j].second;
       }
       output_projection.Propagate(this_in, indexes, &this_out);
     }
@@ -528,13 +530,14 @@ void LmNnetSamplingTrainer::ComputeObjfAndDerivSample(
   ComputeSamplingNonlinearity(out, &f_out);
   *tot_weight = post.NumRows();
 
-  vector<int32> correct_indexes(out.NumRows(), -1);
+  vector<int32> correct_indexes;  //(out.NumRows(), -1);
+  SparseMatrix<BaseFloat> supervision_cpu;
+  // grouped same as output (words in a sentence group together)
 //
   if (num_samples == output_projection.OutputDim()) {
-    for (int j = 0; j < outputs.size(); j++) {
-      correct_indexes[j] = outputs[j];
-    }
+    VectorToSparseMatrix(outputs, out.NumCols(), &supervision_cpu);
   } else {
+    correct_indexes.resize(out.NumRows(), -1);
     // TODO(hxu) not tested it yet
     for (int j = 0; j < t; j++) {
       unordered_map<int32, int32> word2pos;
@@ -550,10 +553,9 @@ void LmNnetSamplingTrainer::ComputeObjfAndDerivSample(
 //      KALDI_ASSERT(outputs[j + i * t] == samples[j][correct_indexes[j + i * t]].first);
       }
     }
+    VectorToSparseMatrix(correct_indexes, out.NumCols(), &supervision_cpu);
   }
 
-  SparseMatrix<BaseFloat> supervision_cpu;
-  VectorToSparseMatrix(correct_indexes, out.NumCols(), &supervision_cpu);
   CuSparseMatrix<BaseFloat> supervision_gpu(supervision_cpu);
   *tot_objf = TraceMatSmat(out, supervision_gpu, kTrans);  // first part of the objf
               // (the objf regarding the positive reward for getting correct labels)
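
Reviewer note: the sketch below shows one way the new ComputeLogprobOfWordGivenHistory() hook might be called from lattice-rescoring code. It is illustrative only and not part of the patch; the helper name ScoreCandidateWords and the idea of scoring all candidate words leaving one lattice state are assumptions. Note that, as committed, the method returns the raw dot product between the hidden vector and the word's row of the output projection, i.e. an unnormalized log-probability, so any normalization over the vocabulary (and advancing the recurrent state word by word) is left to the caller.

// Illustrative sketch -- not part of this patch.  Assumes the caller has
// already run the RNNLM forward to obtain the hidden vector for a history.
// Namespace qualifiers are omitted for brevity.
#include <vector>
#include "base/kaldi-common.h"
#include "cudamatrix/cu-vector.h"
#include "rnnlm/rnnlm-component.h"

// Scores a set of candidate words (e.g. the words on the arcs leaving one
// lattice state) against the hidden vector of their shared history, using
// the interface added in this patch.  The returned values are unnormalized
// log-probabilities (dot products with the output-projection rows).
void ScoreCandidateWords(AffineImportanceSamplingComponent *output_projection,
                         const CuVectorBase<BaseFloat> &hidden,
                         const std::vector<int32> &candidate_words,
                         std::vector<BaseFloat> *scores) {
  scores->resize(candidate_words.size());
  for (size_t i = 0; i < candidate_words.size(); i++) {
    (*scores)[i] = output_projection->ComputeLogprobOfWordGivenHistory(
        hidden, candidate_words[i]);
  }
}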