From ecffd4384e4a0526d401a7f5a05a3a61f7039051 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 3 Jul 2017 00:53:11 -0400
Subject: [PATCH] [src] Fix various bugs that came up while testing attention
 component

---
 src/nnet3/attention.cc                | 10 +++++-----
 src/nnet3/nnet-attention-component.cc |  6 +++---
 src/nnet3/nnet-attention-component.h  |  2 +-
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/nnet3/attention.cc b/src/nnet3/attention.cc
index 55b957c65d8..a00193c241e 100644
--- a/src/nnet3/attention.cc
+++ b/src/nnet3/attention.cc
@@ -108,9 +108,9 @@ void AttentionForward(BaseFloat key_scale,
       context_dim = queries.NumCols() - key_dim,
       value_dim = values.NumCols();
   KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
-               num_output_rows > num_input_rows &&
+               num_input_rows > num_output_rows &&
                context_dim > 0 &&
-               (num_output_rows - num_input_rows) % (context_dim - 1) == 0 &&
+               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
                values.NumRows() == num_input_rows);
   KALDI_ASSERT(c->NumRows() == num_output_rows &&
                c->NumCols() == context_dim);
@@ -169,10 +169,10 @@ void AttentionBackward(BaseFloat key_scale,
       context_dim = queries.NumCols() - key_dim,
       value_dim = values.NumCols();
   KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
-               num_output_rows > num_input_rows &&
+               num_input_rows > num_output_rows &&
                context_dim > 0 &&
-               (num_output_rows - num_input_rows) % (context_dim - 1) == 0 &&
-               values.NumRows() == num_output_rows);
+               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
+               values.NumRows() == num_input_rows);
   KALDI_ASSERT(SameDim(keys, *keys_deriv) &&
                SameDim(queries, *queries_deriv) &&
                SameDim(values, *values_deriv));
diff --git a/src/nnet3/nnet-attention-component.cc b/src/nnet3/nnet-attention-component.cc
index 3d7fb649603..fc12671b9a0 100644
--- a/src/nnet3/nnet-attention-component.cc
+++ b/src/nnet3/nnet-attention-component.cc
@@ -151,7 +151,7 @@ RestrictedAttentionComponent::Propagate(const ComponentPrecomputedIndexes *index
                h * context_dim_, context_dim_),
         out_part(*out, 0, out->NumRows(),
                  h * output_dim_per_head, output_dim_per_head);
-    PropagateOneHead(indexes->io, in, &c_part, &out_part);
+    PropagateOneHead(indexes->io, in_part, &c_part, &out_part);
   }
   return static_cast<void*>(memo);
 }
@@ -162,7 +162,7 @@ void RestrictedAttentionComponent::PropagateOneHead(
     CuMatrixBase<BaseFloat> *c,
     CuMatrixBase<BaseFloat> *out) const {
   int32 query_dim = key_dim_ + context_dim_,
-      full_value_dim = key_dim_ + (output_context_ ? context_dim_ : 0);
+      full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
   KALDI_ASSERT(in.NumRows() == io.num_images * io.num_t_in &&
                out->NumRows() == io.num_images * io.num_t_out &&
                out->NumCols() == full_value_dim &&
@@ -311,7 +311,7 @@ void RestrictedAttentionComponent::BackpropOneHead(
     CuMatrixBase<BaseFloat> *in_deriv) const {
   // the easiest way to understand this is to compare it with PropagateOneHead().
   int32 query_dim = key_dim_ + context_dim_,
-      full_value_dim = key_dim_ + (output_context_ ? context_dim_ : 0);
+      full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
   KALDI_ASSERT(in_value.NumRows() == io.num_images * io.num_t_in &&
                out_deriv.NumRows() == io.num_images * io.num_t_out &&
                out_deriv.NumCols() == full_value_dim &&
diff --git a/src/nnet3/nnet-attention-component.h b/src/nnet3/nnet-attention-component.h
index 9434e2c0cf1..4ff3ac47979 100644
--- a/src/nnet3/nnet-attention-component.h
+++ b/src/nnet3/nnet-attention-component.h
@@ -151,7 +151,7 @@ class RestrictedAttentionComponent: public Component {
 
   virtual Component* Copy() const {
     return new RestrictedAttentionComponent(*this); }
-  
+
   virtual void DeleteMemo(void *memo) const { delete static_cast<Memo*>(memo); }
 
   // Some functions that are only to be reimplemented for GeneralComponents.
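Note, appended for review context rather than as part of the patch: the direction of
the inequality matters because restricted attention consumes more input frames than
it produces output frames (each query attends over a window of context_dim input
positions), so the original asserts would have rejected the intended configurations.
The standalone C++ sketch below illustrates the dimension bookkeeping the corrected
asserts encode; the struct and helper names (AttentionDims, NumInputRows,
FullValueDim, row_shift) are illustrative assumptions, not Kaldi APIs, and only the
arithmetic mirrors the patch.

#include <cassert>

// Minimal sketch of the dimension relationships checked by the corrected
// KALDI_ASSERTs in AttentionForward()/AttentionBackward().  All names here
// are made up for illustration.
struct AttentionDims {
  int num_output_rows;  // rows of 'queries': one per output frame
  int context_dim;      // input positions each query attends over
  int row_shift;        // input rows between consecutive context positions
  int value_dim;        // columns of 'values'
  bool output_context;  // if true, attention weights are appended to output
};

// The input must cover the output rows plus (context_dim - 1) extra shifts,
// which is why (num_input_rows - num_output_rows) % (context_dim - 1) == 0.
int NumInputRows(const AttentionDims &d) {
  return d.num_output_rows + (d.context_dim - 1) * d.row_shift;
}

// Output width: the attention-weighted values, optionally followed by the
// context weights -- hence value_dim_ (not key_dim_) in full_value_dim.
int FullValueDim(const AttentionDims &d) {
  return d.value_dim + (d.output_context ? d.context_dim : 0);
}

int main() {
  AttentionDims d = {/*num_output_rows=*/50, /*context_dim=*/11,
                     /*row_shift=*/3, /*value_dim=*/40,
                     /*output_context=*/true};
  int num_input_rows = NumInputRows(d);  // 50 + 10 * 3 = 80
  // These mirror the corrected asserts:
  assert(num_input_rows > d.num_output_rows);
  assert((num_input_rows - d.num_output_rows) % (d.context_dim - 1) == 0);
  assert(FullValueDim(d) == 40 + 11);
  return 0;
}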