Skip to content

Commit

Permalink
Merge pull request #470 from speechmatics/bugfix-doubleprecision
Browse files Browse the repository at this point in the history
Fix compilation issues when using KALDI_DOUBLEPRECISION=1
  • Loading branch information
danpovey committed Jan 28, 2016
2 parents e418aa5 + f587473 commit 75a8334
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 21 deletions.
12 changes: 6 additions & 6 deletions src/feat/signal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void ElementwiseProductOfFft(const Vector<BaseFloat> &a, Vector<BaseFloat> *b) {
void ConvolveSignals(const Vector<BaseFloat> &filter, Vector<BaseFloat> *signal) {
int32 signal_length = signal->Dim();
int32 filter_length = filter.Dim();
Vector<float> signal_padded(signal_length + filter_length - 1);
Vector<BaseFloat> signal_padded(signal_length + filter_length - 1);
signal_padded.SetZero();
for (int32 i = 0; i < signal_length; i++) {
for (int32 j = 0; j < filter_length; j++) {
Expand All @@ -54,11 +54,11 @@ void FFTbasedConvolveSignals(const Vector<BaseFloat> &filter, Vector<BaseFloat>

SplitRadixRealFft<BaseFloat> srfft(fft_length);

Vector<float> filter_padded(fft_length);
Vector<BaseFloat> filter_padded(fft_length);
filter_padded.Range(0, filter_length).CopyFromVec(filter);
srfft.Compute(filter_padded.Data(), true);

Vector<float> signal_padded(fft_length);
Vector<BaseFloat> signal_padded(fft_length);
signal_padded.Range(0, signal_length).CopyFromVec(*signal);
srfft.Compute(signal_padded.Data(), true);

Expand All @@ -83,13 +83,13 @@ void FFTbasedBlockConvolveSignals(const Vector<BaseFloat> &filter, Vector<BaseFl
KALDI_VLOG(1) << "Block size is " << block_length;
SplitRadixRealFft<BaseFloat> srfft(fft_length);

Vector<float> filter_padded(fft_length);
Vector<BaseFloat> filter_padded(fft_length);
filter_padded.Range(0, filter_length).CopyFromVec(filter);
srfft.Compute(filter_padded.Data(), true);

Vector<float> temp_pad(filter_length - 1);
Vector<BaseFloat> temp_pad(filter_length - 1);
temp_pad.SetZero();
Vector<float> signal_block_padded(fft_length);
Vector<BaseFloat> signal_block_padded(fft_length);

for (int32 po = 0; po < signal_length; po += block_length) {
// get a block of the signal
Expand Down
8 changes: 4 additions & 4 deletions src/lm/kaldi-rnnlm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ KaldiRnnlmWrapper::KaldiRnnlmWrapper(

BaseFloat KaldiRnnlmWrapper::GetLogProb(
int32 word, const std::vector<int32> &wseq,
const std::vector<BaseFloat> &context_in,
std::vector<BaseFloat> *context_out) {
const std::vector<float> &context_in,
std::vector<float> *context_out) {

std::vector<std::string> wseq_symbols(wseq.size());
for (int32 i = 0; i < wseq_symbols.size(); ++i) {
Expand All @@ -79,7 +79,7 @@ RnnlmDeterministicFst::RnnlmDeterministicFst(int32 max_ngram_order,

// Uses empty history for <s>.
std::vector<Label> bos;
std::vector<BaseFloat> bos_context(rnnlm->GetHiddenLayerSize(), 1.0f);
std::vector<float> bos_context(rnnlm->GetHiddenLayerSize(), 1.0);
state_to_wseq_.push_back(bos);
state_to_context_.push_back(bos_context);
wseq_to_state_[bos] = 0;
Expand All @@ -101,7 +101,7 @@ bool RnnlmDeterministicFst::GetArc(StateId s, Label ilabel, fst::StdArc *oarc) {
KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());

std::vector<Label> wseq = state_to_wseq_[s];
std::vector<BaseFloat> new_context(rnnlm_->GetHiddenLayerSize());
std::vector<float> new_context(rnnlm_->GetHiddenLayerSize());
BaseFloat logprob = rnnlm_->GetLogProb(ilabel, wseq,
state_to_context_[s], &new_context);

Expand Down
6 changes: 3 additions & 3 deletions src/lm/kaldi-rnnlm.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class KaldiRnnlmWrapper {
int32 GetEos() const { return eos_; }

BaseFloat GetLogProb(int32 word, const std::vector<int32> &wseq,
const std::vector<BaseFloat> &context_in,
std::vector<BaseFloat> *context_out);
const std::vector<float> &context_in,
std::vector<float> *context_out);

private:
rnnlm::CRnnLM rnnlm_;
Expand Down Expand Up @@ -96,7 +96,7 @@ class RnnlmDeterministicFst

KaldiRnnlmWrapper *rnnlm_;
int32 max_ngram_order_;
std::vector<std::vector<BaseFloat> > state_to_context_;
std::vector<std::vector<float> > state_to_context_;
};

} // namespace kaldi
Expand Down
12 changes: 6 additions & 6 deletions src/nnet2/nnet-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3933,8 +3933,8 @@ void Convolutional1dComponent::Propagate(const ChunkInfo &in_info,
}

// apply all filters
AddMatMatBatched(1.0f, tgt_batch, patch_batch, kNoTrans, filter_params_batch,
kTrans, 1.0f);
AddMatMatBatched<BaseFloat>(1.0, tgt_batch, patch_batch, kNoTrans, filter_params_batch,
kTrans, 1.0);

// release memory
delete filter_params_elem;
Expand Down Expand Up @@ -4060,8 +4060,8 @@ void Convolutional1dComponent::Backprop(const ChunkInfo &in_info,
p * num_filters, num_filters)));
filter_params_batch.push_back(filter_params_elem);
}
AddMatMatBatched(1.0f, patch_deriv_batch, out_deriv_batch, kNoTrans,
filter_params_batch, kNoTrans, 0.0f);
AddMatMatBatched<BaseFloat>(1.0, patch_deriv_batch, out_deriv_batch, kNoTrans,
filter_params_batch, kNoTrans, 0.0);

// release memory
delete filter_params_elem;
Expand Down Expand Up @@ -4275,8 +4275,8 @@ void Convolutional1dComponent::Update(const CuMatrixBase<BaseFloat> &in_value,
p * filter_dim, filter_dim)));
}

AddMatMatBatched(1.0f, filters_grad_batch, diff_patch_batch, kTrans, patch_batch,
kNoTrans, 1.0f);
AddMatMatBatched<BaseFloat>(1.0, filters_grad_batch, diff_patch_batch, kTrans, patch_batch,
kNoTrans, 1.0);

// add the row blocks together to filters_grad
filters_grad.AddMatBlocks(1.0, filters_grad_blocks_batch);
Expand Down
2 changes: 1 addition & 1 deletion src/nnet2/nnet-precondition-online-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ void UnitTestPreconditionDirectionsOnline() {
AssertEqual(trace1, trace2 * gamma2 * gamma2, 1.0e-02);

AssertEqual(Mcopy1, Mcopy2);
AssertEqual(row_prod1, row_prod2, 1.0e-02f);
AssertEqual<BaseFloat>(row_prod1, row_prod2, 1.0e-02);
AssertEqual(gamma1, gamma2, 1.0e-02);

// make sure positive definite
Expand Down
2 changes: 1 addition & 1 deletion src/nnet3/natural-gradient-online-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ void UnitTestPreconditionDirectionsOnline() {
AssertEqual(trace1, trace2 * gamma2 * gamma2, 1.0e-02);

AssertEqual(Mcopy1, Mcopy2);
AssertEqual(row_prod1, row_prod2, 1.0e-02f);
AssertEqual<BaseFloat>(row_prod1, row_prod2, 1.0e-02);
AssertEqual(gamma1, gamma2, 1.0e-02);

// make sure positive definite
Expand Down

0 comments on commit 75a8334

Please sign in to comment.