add interface for lattice-rescoring
hainan-xv committed Apr 25, 2017
1 parent 38ef328 commit 12f119e
Showing 4 changed files with 36 additions and 27 deletions.
6 changes: 4 additions & 2 deletions src/rnnlm/rnnlm-component-itf.h
@@ -121,7 +121,8 @@ class LmInputComponent {
 
   virtual LmInputComponent* Copy() const = 0;
 
-  LmInputComponent() {}
+  LmInputComponent(): learning_rate_(0.001), learning_rate_factor_(1.0),
+                      is_gradient_(false), max_change_(0.0) {}
 
   virtual void Add(BaseFloat alpha, const LmInputComponent &other) = 0;
   virtual void Scale(BaseFloat scale) = 0;
@@ -208,7 +209,8 @@ class LmOutputComponent {
   BaseFloat LearningRate() const { return learning_rate_; }
   BaseFloat MaxChange() const { return max_change_; }
 
-  LmOutputComponent() {}
+  LmOutputComponent(): learning_rate_(0.001), learning_rate_factor_(1.0),
+                      is_gradient_(false), max_change_(0.0) {}
 
   virtual void Add(BaseFloat alpha, const LmOutputComponent &other) = 0;
   virtual void Scale(BaseFloat scale) = 0;
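
With these defaults set once in the base-class constructors, subclasses no longer need to assign learning_rate_factor_, is_gradient_ or max_change_ themselves, which is what the deletions in rnnlm-component.cc below amount to. A minimal sketch of the pattern, assuming illustrative class names rather than the real Kaldi ones:

#include <iostream>

typedef float BaseFloat;

class ComponentBase {  // stands in for LmInputComponent / LmOutputComponent
 public:
  // One authoritative set of defaults, run before any subclass code.
  ComponentBase(): learning_rate_(0.001), learning_rate_factor_(1.0),
                   is_gradient_(false), max_change_(0.0) {}
  BaseFloat LearningRate() const { return learning_rate_; }
 protected:
  BaseFloat learning_rate_, learning_rate_factor_, max_change_;
  bool is_gradient_;
};

class DerivedComponent: public ComponentBase {
 public:
  // InitFromConfig() now only reads values the config overrides;
  // it no longer re-assigns learning_rate_factor_ etc. "just in case".
  void InitFromConfig() {}
};

int main() {
  DerivedComponent c;
  std::cout << c.LearningRate() << "\n";  // 0.001, from the base constructor
  return 0;
}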
14 changes: 8 additions & 6 deletions src/rnnlm/rnnlm-component.cc
@@ -90,7 +90,6 @@ void NaturalGradientAffineImportanceSamplingComponent::Init(std::string matrix_f
 
 void NaturalGradientAffineImportanceSamplingComponent::InitFromConfig(ConfigLine *cfl) {
   bool ok = true;
-  this->learning_rate_factor_ = 1.0;
   std::string matrix_filename;
   std::string unigram_filename;
   int32 input_dim = -1, output_dim = -1;
@@ -603,7 +602,6 @@ void LmNaturalGradientLinearComponent::Init(
   int32 input_dim = mat.NumCols() - 1, output_dim = mat.NumRows();
   linear_params_.Resize(output_dim, input_dim);
   linear_params_.CopyFromMat(mat.Range(0, output_dim, 0, input_dim));
-  is_gradient_ = false; // not configurable; there's no reason you'd want this
   update_count_ = 0.0;
   active_scaling_count_ = 0.0;
   max_change_scale_stats_ = 0.0;
@@ -633,7 +631,6 @@ void LmNaturalGradientLinearComponent::Init(
               << "to activate the per-component max change mechanism.";
   KALDI_ASSERT(max_change_per_sample >= 0.0);
   max_change_per_sample_ = max_change_per_sample;
-  is_gradient_ = false; // not configurable; there's no reason you'd want this
   update_count_ = 0.0;
   active_scaling_count_ = 0.0;
   max_change_scale_stats_ = 0.0;
@@ -819,7 +816,6 @@ void AffineImportanceSamplingComponent::InitFromConfig(ConfigLine *cfl) {
       params_.ColRange(params_.NumCols() - 1, 1).AddMat(1.0, g, kTrans);
     }
   }
-  this->learning_rate_factor_ = 1.0; // TODO(hxu) quick fix
   if (cfl->HasUnusedValues())
     KALDI_ERR << "Could not process these elements in initializer: "
               << cfl->UnusedValues();
@@ -868,6 +864,14 @@ void AffineImportanceSamplingComponent::Propagate(const CuMatrixBase<BaseFloat>
   }
 }
 
+BaseFloat AffineImportanceSamplingComponent::ComputeLogprobOfWordGivenHistory(
+    const CuVectorBase<BaseFloat> &hidden,
+    int32 word_index) {
+  CuSubVector<BaseFloat> param = params_.Row(word_index);
+  BaseFloat ans = VecVec(hidden, param);
+  return ans;
+}
+
 void AffineImportanceSamplingComponent::Backprop(
     const CuMatrixBase<BaseFloat> &in_value,
     const CuMatrixBase<BaseFloat> &out_value, // out_value
@@ -1122,8 +1126,6 @@ void LmLinearComponent::InitFromConfig(ConfigLine *cfl) {
   if (!ok)
     KALDI_ERR << "Bad initializer " << cfl->WholeLine();
 
-  is_gradient_ = false; // not configurable; there's no reason you'd want this
-  max_change_ = 1.0;
 }
 
 void LmLinearComponent::Propagate(const SparseMatrix<BaseFloat> &sp,
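
The new ComputeLogprobOfWordGivenHistory() is the lattice-rescoring interface the commit message refers to: given the hidden vector computed for a word history, it scores one candidate word as the dot product of that vector with the word's row of params_. As written it returns an unnormalized log-score; no log-partition term is subtracted. A self-contained sketch of the same computation, assuming plain std::vector stand-ins and hypothetical names rather than Kaldi's CUDA types:

#include <cassert>
#include <cstddef>
#include <vector>

typedef float BaseFloat;

// params[w] is the output-layer row for word w (one row per vocabulary word).
// Mirrors VecVec(hidden, params_.Row(word_index)) from the diff above.
BaseFloat ScoreWordGivenHistory(const std::vector<BaseFloat> &hidden,
                                const std::vector<std::vector<BaseFloat> > &params,
                                int word_index) {
  assert(word_index >= 0 && (std::size_t)word_index < params.size());
  const std::vector<BaseFloat> &row = params[word_index];
  assert(row.size() == hidden.size());
  BaseFloat sum = 0.0;
  for (std::size_t d = 0; d < hidden.size(); ++d)
    sum += hidden[d] * row[d];
  return sum;  // unnormalized log-score of word_index given the history
}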
3 changes: 3 additions & 0 deletions src/rnnlm/rnnlm-component.h
@@ -229,6 +229,9 @@ class AffineImportanceSamplingComponent: public LmOutputComponent {
                          bool normalize,
                          CuMatrixBase<BaseFloat> *out) const;
 
+  BaseFloat ComputeLogprobOfWordGivenHistory(const CuVectorBase<BaseFloat> &hidden,
+                                             int32 word_index);
+
   virtual void Backprop(const vector<int> &indexes,
                         const CuMatrixBase<BaseFloat> &in_value,
                         const CuMatrixBase<BaseFloat> &, // out_value
40 changes: 21 additions & 19 deletions src/rnnlm/rnnlm-training.cc
@@ -101,19 +101,14 @@ void LmNnetSamplingTrainer::ProcessEgInputs(const NnetExample& eg,
                                             const LmInputComponent& a,
                                             const SparseMatrix<BaseFloat> **old_input,
                                             CuMatrix<BaseFloat> *new_input) {
-  for (size_t i = 0; i < eg.io.size(); i++) {
-    const NnetIo &io = eg.io[i];
+  // for (size_t i = 0; i < eg.io.size(); i++) {
+  const NnetIo &io = eg.io[0];
 
-    if (io.name == "input") {
-      KALDI_ASSERT(old_input != NULL && new_input != NULL);
-      new_input->Resize(io.features.NumRows(),
-                        a.OutputDim(),
-                        kSetZero);
+  KALDI_ASSERT (io.name == "input");
+  new_input->Resize(io.features.NumRows(), a.OutputDim(), kSetZero);
 
-      *old_input = &io.features.GetSparseMatrix();
-      a.Propagate(io.features.GetSparseMatrix(), new_input);
-    }
-  }
+  *old_input = &io.features.GetSparseMatrix();
+  a.Propagate(io.features.GetSparseMatrix(), new_input);
 }
 
 void LmNnetSamplingTrainer::Train(const NnetExample &eg) {
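
ProcessEgInputs() now assumes the input NnetIo is the first entry of eg.io rather than searching all entries by name, keeping the name check as an assertion. A small sketch of the tightened contract, with minimal stand-in types (not Kaldi's):

#include <cassert>
#include <string>
#include <vector>

// Minimal stand-ins for NnetIo / NnetExample, just enough to show the contract.
struct Io { std::string name; };
struct Example { std::vector<Io> io; };

// The input is required to be the first io entry; anything else is a bug
// in whoever built the example, not something to silently skip.
const Io &GetInput(const Example &eg) {
  assert(!eg.io.empty());
  assert(eg.io[0].name == "input");
  return eg.io[0];
}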
@@ -493,21 +488,27 @@ void LmNnetSamplingTrainer::ComputeObjfAndDerivSample(
   std::vector<int> outputs; // outputs[i] is the correct word for row i
                             // will be initialized with SparsMatrixToVector
                             // the size will be post.NumRows() = t * minibatch_size
+                            // words for the same *sentence* would be grouped together
+                            // not the same time-step
 
   std::set<int> outputs_set;
 
   SparseMatrixToVector(post, &outputs);
 
-  std::vector<double> selected_probs(num_samples * t); // selected_probs[i * t + j] is the prob of
+  std::vector<double> selected_probs; // selected_probs[i * t + j] is the prob of
                                       // selecting samples[j][i]
+                                      // words for the same time step t would be grouped together TODO(hxu)
 
   int minibatch_size = (*old_output)->NumRows() / t;
   KALDI_ASSERT(outputs.size() == t * minibatch_size);
 
   CuMatrix<BaseFloat> out((*old_output)->NumRows(), num_samples, kSetZero);
+  // words for the same sentence would be grouped together
 
   if (num_samples == output_projection.OutputDim()) {
     output_projection.Propagate(**old_output, &out);
   } else {
+    selected_probs.resize(num_samples * t);
+    // need to parallelize this loop
     for (int i = 0; i < t; i++) {
       CuSubMatrix<BaseFloat> this_in((**old_output).Data() + i * (**old_output).Stride(), minibatch_size, (**old_output).NumCols(), (**old_output).Stride() * t);
@@ -517,7 +518,8 @@ void LmNnetSamplingTrainer::ComputeObjfAndDerivSample(
       vector<int> indexes(num_samples);
       for (int j = 0; j < num_samples; j++) {
         indexes[j] = samples[i][j].first;
-        selected_probs[j + t * i] = samples[i][j].second;
+        selected_probs[j * t + i] = samples[i][j].second;
+        // selected_probs[j + t * i] = samples[i][j].second;
       }
       output_projection.Propagate(this_in, indexes, &this_out);
     }
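
The selected_probs index change above is a real bounds fix, not a cosmetic swap: the vector holds num_samples * t entries, the loop runs j over num_samples and i over t, and j * t + i always stays below num_samples * t, while the old j + t * i can run past the end (for example when t > num_samples) or map two (j, i) pairs to the same slot. A quick check with small, hypothetical sizes:

#include <cassert>

int main() {
  const int t = 3, num_samples = 2;   // hypothetical sizes, t != num_samples
  const int size = num_samples * t;   // selected_probs holds 6 entries
  const int j = 1, i = 2;             // sample 1 at time step 2

  assert(j * t + i == 5);             // new layout: last valid slot, in range
  assert(j + t * i == 7);             // old layout: past the end of 6 entries
  assert(j * t + i < size && !(j + t * i < size));
  return 0;
}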
@@ -528,13 +530,14 @@ void LmNnetSamplingTrainer::ComputeObjfAndDerivSample(
   ComputeSamplingNonlinearity(out, &f_out);
 
   *tot_weight = post.NumRows();
-  vector<int32> correct_indexes(out.NumRows(), -1);
+  vector<int32> correct_indexes; //(out.NumRows(), -1);
+  SparseMatrix<BaseFloat> supervision_cpu;
   // grouped same as output (words in a sentence group together)
+  //
   if (num_samples == output_projection.OutputDim()) {
-    for (int j = 0; j < outputs.size(); j++) {
-      correct_indexes[j] = outputs[j];
-    }
+    VectorToSparseMatrix(outputs, out.NumCols(), &supervision_cpu);
   } else {
+    correct_indexes.resize(out.NumRows(), -1);
     // TODO(hxu) not tested it yet
     for (int j = 0; j < t; j++) {
       unordered_map<int32, int32> word2pos;
@@ -550,10 +553,9 @@ void LmNnetSamplingTrainer::ComputeObjfAndDerivSample(
       // KALDI_ASSERT(outputs[j + i * t] == samples[j][correct_indexes[j + i * t]].first);
       }
     }
+    VectorToSparseMatrix(correct_indexes, out.NumCols(), &supervision_cpu);
   }
 
-  SparseMatrix<BaseFloat> supervision_cpu;
-  VectorToSparseMatrix(correct_indexes, out.NumCols(), &supervision_cpu);
   CuSparseMatrix<BaseFloat> supervision_gpu(supervision_cpu);
   *tot_objf = TraceMatSmat(out, supervision_gpu, kTrans); // first part of the objf
   // (the objf regarding the positive reward for getting correct labels)
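
The refactor above builds supervision_cpu inside each branch: from the full outputs vector when every word is scored, or from the per-row sampled positions in correct_indexes otherwise. Judging from how TraceMatSmat consumes it, VectorToSparseMatrix presumably produces one one-hot row per training row, with the 1.0 at the given column; a guessed minimal equivalent (hypothetical, not Kaldi's implementation):

#include <cstddef>
#include <vector>

// One nonzero per row: rows[r] carries a 1.0 at column indexes[r], so that
// TraceMatSmat(out, supervision, kTrans) sums out(r, indexes[r]) over rows.
struct OneHotRow { int col; float val; };

std::vector<OneHotRow> MakeOneHotRows(const std::vector<int> &indexes) {
  std::vector<OneHotRow> rows(indexes.size());
  for (std::size_t r = 0; r < indexes.size(); ++r) {
    rows[r].col = indexes[r];
    rows[r].val = 1.0f;
  }
  return rows;
}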
