Skip to content

Commit

Permalink
[src] Fix various bugs that came up while testing attention component
Browse files Browse the repository at this point in the history
  • Loading branch information
danpovey committed Jul 3, 2017
1 parent 54aa27e commit ecffd43
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions src/nnet3/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ void AttentionForward(BaseFloat key_scale,
context_dim = queries.NumCols() - key_dim,
value_dim = values.NumCols();
KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
-             num_output_rows > num_input_rows &&
+             num_input_rows > num_output_rows &&
context_dim > 0 &&
-             (num_output_rows - num_input_rows) % (context_dim - 1) == 0 &&
+             (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
values.NumRows() == num_input_rows);
KALDI_ASSERT(c->NumRows() == num_output_rows &&
c->NumCols() == context_dim);
Expand Down Expand Up @@ -169,10 +169,10 @@ void AttentionBackward(BaseFloat key_scale,
context_dim = queries.NumCols() - key_dim,
value_dim = values.NumCols();
KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
-             num_output_rows > num_input_rows &&
+             num_input_rows > num_output_rows &&
context_dim > 0 &&
-             (num_output_rows - num_input_rows) % (context_dim - 1) == 0 &&
-             values.NumRows() == num_output_rows);
+             (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
+             values.NumRows() == num_input_rows);
KALDI_ASSERT(SameDim(keys, *keys_deriv) &&
SameDim(queries, *queries_deriv) &&
SameDim(values, *values_deriv));
Expand Down
6 changes: 3 additions & 3 deletions src/nnet3/nnet-attention-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ RestrictedAttentionComponent::Propagate(const ComponentPrecomputedIndexes *index
h * context_dim_, context_dim_),
out_part(*out, 0, out->NumRows(),
h * output_dim_per_head, output_dim_per_head);
-    PropagateOneHead(indexes->io, in, &c_part, &out_part);
+    PropagateOneHead(indexes->io, in_part, &c_part, &out_part);
}
return static_cast<void*>(memo);
}
Expand All @@ -162,7 +162,7 @@ void RestrictedAttentionComponent::PropagateOneHead(
CuMatrixBase<BaseFloat> *c,
CuMatrixBase<BaseFloat> *out) const {
int32 query_dim = key_dim_ + context_dim_,
-      full_value_dim = key_dim_ + (output_context_ ? context_dim_ : 0);
+      full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
KALDI_ASSERT(in.NumRows() == io.num_images * io.num_t_in &&
out->NumRows() == io.num_images * io.num_t_out &&
out->NumCols() == full_value_dim &&
Expand Down Expand Up @@ -311,7 +311,7 @@ void RestrictedAttentionComponent::BackpropOneHead(
CuMatrixBase<BaseFloat> *in_deriv) const {
// the easiest way to understand this is to compare it with PropagateOneHead().
int32 query_dim = key_dim_ + context_dim_,
-      full_value_dim = key_dim_ + (output_context_ ? context_dim_ : 0);
+      full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
KALDI_ASSERT(in_value.NumRows() == io.num_images * io.num_t_in &&
out_deriv.NumRows() == io.num_images * io.num_t_out &&
out_deriv.NumCols() == full_value_dim &&
Expand Down
2 changes: 1 addition & 1 deletion src/nnet3/nnet-attention-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class RestrictedAttentionComponent: public Component {
virtual Component* Copy() const {
return new RestrictedAttentionComponent(*this);
}

virtual void DeleteMemo(void *memo) const { delete static_cast<Memo*>(memo); }

// Some functions that are only to be reimplemented for GeneralComponents.

Expand Down

0 comments on commit ecffd43

Please sign in to comment.