From 0c75eaa6c1d738eb5ca3931fb4640e3458726200 Mon Sep 17 00:00:00 2001 From: Hossein Hadian Date: Mon, 3 Jul 2017 01:40:17 -0400 Subject: [PATCH] Add test template for AttentionForward/Backward --- src/nnet3/attention-test.cc | 29 ++++++++++++++++++++++++++++- src/nnet3/attention.cc | 2 +- src/nnet3/attention.h | 2 +- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/nnet3/attention-test.cc b/src/nnet3/attention-test.cc index 62fc19c4cee..8acd09e46b4 100644 --- a/src/nnet3/attention-test.cc +++ b/src/nnet3/attention-test.cc @@ -1,6 +1,6 @@ // nnet3/attention-test.cc -// Copyright 2017 Johns Hopkins University (author: Hossein Hadian) +// Copyright 2017 Hossein Hadian // See ../../COPYING for clarification regarding multiple authors // @@ -116,8 +116,35 @@ void UnitTestAttentionDotProductAndAddScales() { AssertEqual(B, B2); } +void TestAttentionForwardBackward() { + BaseFloat key_scale = 0.5 * RandInt(1, 3); + int32 output_num_rows = RandInt(1, 50), + value_dim = RandInt(1, 10), key_dim = RandInt(1, 10), + row_shift = RandInt(1, 5), context_dim = RandInt(2, 5), + num_extra_rows = (context_dim - 1) * row_shift, + input_num_rows = output_num_rows + num_extra_rows, + query_dim = key_dim + context_dim; + CuMatrix keys(input_num_rows, key_dim), + queries(output_num_rows, query_dim), + values(input_num_rows, value_dim), + C(output_num_rows, context_dim), + output(output_num_rows, value_dim + context_dim); + + AttentionForward(key_scale, keys, queries, values, &C, &output); + + + CuMatrix keys_deriv(input_num_rows, key_dim), + queries_deriv(output_num_rows, query_dim), + values_deriv(input_num_rows, value_dim), + output_deriv(output_num_rows, value_dim + context_dim); + + AttentionBackward(key_scale, keys, queries, values, C, + output_deriv, &keys_deriv, &queries_deriv, &values_deriv); +} + void UnitTestAttention() { UnitTestAttentionDotProductAndAddScales(); + TestAttentionForwardBackward(); } diff --git a/src/nnet3/attention.cc b/src/nnet3/attention.cc index a00193c241e..bac26c05a5d 100644 --- a/src/nnet3/attention.cc +++ b/src/nnet3/attention.cc @@ -1,7 +1,7 @@ // nnet3/attention.cc // Copyright 2017 Johns Hopkins University (author: Daniel Povey) -// Johns Hopkins University (author: Hossein Hadian) +// Hossein Hadian // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/nnet3/attention.h b/src/nnet3/attention.h index d5a2577f96e..0993b78fc86 100644 --- a/src/nnet3/attention.h +++ b/src/nnet3/attention.h @@ -1,7 +1,7 @@ // nnet3/attention.h // Copyright 2017 Johns Hopkins University (author: Daniel Povey) -// Johns Hopkins University (author: Hossein Hadian) +// Hossein Hadian // See ../../COPYING for clarification regarding multiple authors //