diff --git a/src/rnnlm/rnnlm-component-itf.h b/src/rnnlm/rnnlm-component-itf.h index f9cdd8b26b7..54c47d8b6f7 100644 --- a/src/rnnlm/rnnlm-component-itf.h +++ b/src/rnnlm/rnnlm-component-itf.h @@ -241,6 +241,9 @@ class LmOutputComponent { virtual ~LmOutputComponent() { } + virtual BaseFloat ComputeLogprobOfWordGivenHistory(const CuVectorBase &hidden, + int32 word_index) = 0; + virtual void Propagate(const CuMatrixBase &in, const vector &indexes, // objf is computed on the chosen indexes CuMatrixBase *out) const = 0; diff --git a/src/rnnlm/rnnlm-component.cc b/src/rnnlm/rnnlm-component.cc index 5e49a500d22..022d173171a 100644 --- a/src/rnnlm/rnnlm-component.cc +++ b/src/rnnlm/rnnlm-component.cc @@ -166,9 +166,16 @@ void NaturalGradientAffineImportanceSamplingComponent::Propagate(const CuMatrixB CuSubMatrix bias_params(params_.ColRange(params_.NumCols() - 1, 1)); CuSubMatrix linear_params(params_.ColRange(0, params_.NumCols() - 1)); out->Row(0).CopyColFromMat(bias_params, 0); - out->CopyRowsFromVec(out->Row(0)); + if (out->NumRows() > 1) + out->RowRange(1, out->NumRows() - 1).CopyRowsFromVec(out->Row(0)); out->AddMatMat(1.0, in, kNoTrans, linear_params, kTrans, 1.0); if (normalize) { + // TODO(Hxu) + // CuMatrix test_norm(*out); + // test_norm.ApplyExp(); + ComputeSamplingNonlinearity(*out, out); + KALDI_LOG << "average normalization term is " << exp(out->Sum() / out->NumRows() - 1); + out->ApplyLog(); out->ApplyLogSoftMaxPerRow(*out); } } @@ -861,13 +868,14 @@ void AffineImportanceSamplingComponent::Propagate(const CuMatrixBase out->RowRange(1, out->NumRows() - 1).CopyRowsFromVec(out->Row(0)); out->AddMatMat(1.0, in, kNoTrans, linear_params, kTrans, 1.0); if (normalize) { + // TODO(Hxu) + // CuMatrix test_norm(*out); + // test_norm.ApplyExp(); + ComputeSamplingNonlinearity(*out, out); + KALDI_LOG << "average normalization term is " << exp(out->Sum() / out->NumRows() - 1); + out->ApplyLog(); out->ApplyLogSoftMaxPerRow(*out); } - // TODO(Hxu) - CuMatrix test_norm(*out); -// test_norm.ApplyExp(); - ComputeSamplingNonlinearity(*out, &test_norm); - KALDI_LOG << "average normalization term is " << exp(test_norm.Sum() / test_norm.NumRows() - 1); } BaseFloat AffineImportanceSamplingComponent::ComputeLogprobOfWordGivenHistory( diff --git a/src/rnnlm/rnnlm-component.h b/src/rnnlm/rnnlm-component.h index 239e0593124..cf75e7c3f5d 100644 --- a/src/rnnlm/rnnlm-component.h +++ b/src/rnnlm/rnnlm-component.h @@ -229,7 +229,7 @@ class AffineImportanceSamplingComponent: public LmOutputComponent { bool normalize, CuMatrixBase *out) const; - BaseFloat ComputeLogprobOfWordGivenHistory(const CuVectorBase &hidden, + virtual BaseFloat ComputeLogprobOfWordGivenHistory(const CuVectorBase &hidden, int32 word_index); virtual void Backprop(const vector &indexes,