Skip to content

Commit

Permalink
adding the ComputeLogProb interface to LmOutputComponent
Browse files Browse the repository at this point in the history
  • Loading branch information
hainan-xv committed Apr 28, 2017
1 parent af2a4c2 commit 7de1e33
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/rnnlm/rnnlm-component-itf.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ class LmOutputComponent {

virtual ~LmOutputComponent() { }

virtual BaseFloat ComputeLogprobOfWordGivenHistory(const CuVectorBase<BaseFloat> &hidden,
int32 word_index) = 0;

virtual void Propagate(const CuMatrixBase<BaseFloat> &in,
const vector<int> &indexes, // objf is computed on the chosen indexes
CuMatrixBase<BaseFloat> *out) const = 0;
Expand Down
20 changes: 14 additions & 6 deletions src/rnnlm/rnnlm-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,16 @@ void NaturalGradientAffineImportanceSamplingComponent::Propagate(const CuMatrixB
CuSubMatrix<BaseFloat> bias_params(params_.ColRange(params_.NumCols() - 1, 1));
CuSubMatrix<BaseFloat> linear_params(params_.ColRange(0, params_.NumCols() - 1));
out->Row(0).CopyColFromMat(bias_params, 0);
out->CopyRowsFromVec(out->Row(0));
if (out->NumRows() > 1)
out->RowRange(1, out->NumRows() - 1).CopyRowsFromVec(out->Row(0));
out->AddMatMat(1.0, in, kNoTrans, linear_params, kTrans, 1.0);
if (normalize) {
// TODO(Hxu)
// CuMatrix<BaseFloat> test_norm(*out);
// test_norm.ApplyExp();
ComputeSamplingNonlinearity(*out, out);
KALDI_LOG << "average normalization term is " << exp(out->Sum() / out->NumRows() - 1);
out->ApplyLog();
out->ApplyLogSoftMaxPerRow(*out);
}
}
Expand Down Expand Up @@ -861,13 +868,14 @@ void AffineImportanceSamplingComponent::Propagate(const CuMatrixBase<BaseFloat>
out->RowRange(1, out->NumRows() - 1).CopyRowsFromVec(out->Row(0));
out->AddMatMat(1.0, in, kNoTrans, linear_params, kTrans, 1.0);
if (normalize) {
// TODO(Hxu)
// CuMatrix<BaseFloat> test_norm(*out);
// test_norm.ApplyExp();
ComputeSamplingNonlinearity(*out, out);
KALDI_LOG << "average normalization term is " << exp(out->Sum() / out->NumRows() - 1);
out->ApplyLog();
out->ApplyLogSoftMaxPerRow(*out);
}
// TODO(Hxu)
CuMatrix<BaseFloat> test_norm(*out);
// test_norm.ApplyExp();
ComputeSamplingNonlinearity(*out, &test_norm);
KALDI_LOG << "average normalization term is " << exp(test_norm.Sum() / test_norm.NumRows() - 1);
}

BaseFloat AffineImportanceSamplingComponent::ComputeLogprobOfWordGivenHistory(
Expand Down
2 changes: 1 addition & 1 deletion src/rnnlm/rnnlm-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class AffineImportanceSamplingComponent: public LmOutputComponent {
bool normalize,
CuMatrixBase<BaseFloat> *out) const;

BaseFloat ComputeLogprobOfWordGivenHistory(const CuVectorBase<BaseFloat> &hidden,
virtual BaseFloat ComputeLogprobOfWordGivenHistory(const CuVectorBase<BaseFloat> &hidden,
int32 word_index);

virtual void Backprop(const vector<int> &indexes,
Expand Down

0 comments on commit 7de1e33

Please sign in to comment.