Skip to content

Commit

Permalink
Merge pull request kaldi-asr#41 from hhadian/attention
Browse files Browse the repository at this point in the history
Add test template for AttentionForward/Backward
  • Loading branch information
danpovey authored Jul 3, 2017
2 parents ecffd43 + 0c75eaa commit f29f8a6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
29 changes: 28 additions & 1 deletion src/nnet3/attention-test.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// nnet3/attention-test.cc

// Copyright 2017 Johns Hopkins University (author: Hossein Hadian)
// Copyright 2017 Hossein Hadian

// See ../../COPYING for clarification regarding multiple authors
//
Expand Down Expand Up @@ -116,8 +116,35 @@ void UnitTestAttentionDotProductAndAddScales() {
AssertEqual(B, B2);
}

void TestAttentionForwardBackward() {
BaseFloat key_scale = 0.5 * RandInt(1, 3);
int32 output_num_rows = RandInt(1, 50),
value_dim = RandInt(1, 10), key_dim = RandInt(1, 10),
row_shift = RandInt(1, 5), context_dim = RandInt(2, 5),
num_extra_rows = (context_dim - 1) * row_shift,
input_num_rows = output_num_rows + num_extra_rows,
query_dim = key_dim + context_dim;
CuMatrix<BaseFloat> keys(input_num_rows, key_dim),
queries(output_num_rows, query_dim),
values(input_num_rows, value_dim),
C(output_num_rows, context_dim),
output(output_num_rows, value_dim + context_dim);

AttentionForward(key_scale, keys, queries, values, &C, &output);


CuMatrix<BaseFloat> keys_deriv(input_num_rows, key_dim),
queries_deriv(output_num_rows, query_dim),
values_deriv(input_num_rows, value_dim),
output_deriv(output_num_rows, value_dim + context_dim);

AttentionBackward(key_scale, keys, queries, values, C,
output_deriv, &keys_deriv, &queries_deriv, &values_deriv);
}

void UnitTestAttention() {
UnitTestAttentionDotProductAndAddScales();
TestAttentionForwardBackward();
}


Expand Down
2 changes: 1 addition & 1 deletion src/nnet3/attention.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// nnet3/attention.cc

// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// Johns Hopkins University (author: Hossein Hadian)
// Hossein Hadian

// See ../../COPYING for clarification regarding multiple authors
//
Expand Down
2 changes: 1 addition & 1 deletion src/nnet3/attention.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// nnet3/attention.h

// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// Johns Hopkins University (author: Hossein Hadian)
// Hossein Hadian

// See ../../COPYING for clarification regarding multiple authors
//
Expand Down

0 comments on commit f29f8a6

Please sign in to comment.