From a8048d8c742187753461efabaaae3615cc3e9878 Mon Sep 17 00:00:00 2001
From: andrewshinsony
Date: Tue, 7 Apr 2020 22:52:47 +0900
Subject: [PATCH] change attention api

---
 python/src/nnabla/functions.py            | 61 ++++++++++++----------
 python/src/nnabla/parametric_functions.py | 62 +++++++++++++----------
 python/test/test_parametric_functions.py  | 57 ++++++++++++---------
 3 files changed, 101 insertions(+), 79 deletions(-)

diff --git a/python/src/nnabla/functions.py b/python/src/nnabla/functions.py
index 3e8c63a38..eecbbbe3b 100644
--- a/python/src/nnabla/functions.py
+++ b/python/src/nnabla/functions.py
@@ -1111,7 +1111,7 @@ def multi_head_attention(query, key, value, num_heads, q_weight, k_weight, v_wei
    Computes multi-headed attention with query, key, and value.
    We use the following notations to describe the inputs and outputs below.
-    :math:`L_T`: target sequence length, :math:`L_S`: source sequence length, :math:`B`: batch size, :math:`E`: embedding dimension, :math`H`: number of attention heads.
+    :math:`L_T`: target sequence length, :math:`L_S`: source sequence length, :math:`B`: batch size, :math:`D`: input dimension, :math:`E`: embedding dimension, :math:`H`: number of attention heads.

    References:

@@ -1120,37 +1120,44 @@ def multi_head_attention(query, key, value, num_heads, q_weight, k_weight, v_wei

    Args:
-        query (~nnabla.Variable): Input N-D array with shape :math:`(L_T, B, E)`.
-        key (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, E_k)`.
-        value (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, E_v)`.
+        query (~nnabla.Variable): Input N-D array with shape :math:`(L_T, B, D_q)`.
+        key (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, D_k)`.
+        value (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, D_v)`.
        num_heads (int): Number of attention heads. Note that embedding dimension E must be divisible by the number of heads. Default is 12 which is conventional.
-        q_weight (~nnabla.Variable): Input N-D array with shape :math:`(E E)`.
-        k_weight (~nnabla.Variable): Input N-D array with shape :math:`(E_k, E)`.
-        v_weight (~nnabla.Variable): Input N-D array with shape :math:`(E_v, E)`.
-        out_weight (~nnabla.Variable): Input N-D array with shape :math:`(E, E)`.
+        q_weight (~nnabla.Variable): Input N-D array with shape :math:`(D_q, E)`.
+        k_weight (~nnabla.Variable): Input N-D array with shape :math:`(D_k, E)`.
+        v_weight (~nnabla.Variable): Input N-D array with shape :math:`(D_v, E_v)`.
+        out_weight (~nnabla.Variable): Input N-D array with shape :math:`(E_v, E_{out})`.
        q_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
        k_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
-        v_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
-        out_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
+        v_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E_v, )`.
+        out_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E_{out}, )`.
        attn_bias_k (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
-        attn_bias_v (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
+        attn_bias_v (~nnabla.Variable, optional): Input N-D array with shape :math:`(E_v, )`.
        dropout (float, optional): Dropout ratio applied to parameters. Default is 0.
        additive_mask (~nnabla.Variable, optional): Input N-D array with shape :math:`(L_T, L_S)`.
            Values will be added to the attention layer to prevent attention to certain positions.
        key_padding_mask (~nnabla.Variable, optional): Input N-D array with shape :math:`(B, L_S)`.
            Specified padding elements will be ignored by the attention layer. Values must be either 1 or 0.

    Returns:
-        ~nnabla.Variable: Output :math:`y` with shape :math:`(L_T, B, E)`
+        ~nnabla.Variable: Output :math:`y` with shape :math:`(L_T, B, E_{out})`
        ~nnabla.Variable: Output :math:`h_n` with shape :math:`(B, L_T, L_S)`
    '''
    from . import functions as F

-    tgt_len, batch_size, embed_dim = query.shape
-    src_len, batch_size, kdim = key.shape
-    vdim = value.shape[2]
+    tgt_len, batch_size, _ = query.shape
+    src_len, batch_size, _ = key.shape
+    q_embed_dim = q_weight.shape[1]
+    k_embed_dim = k_weight.shape[1]
+    v_embed_dim = v_weight.shape[1]
+    out_dim = out_weight.shape[1]
    assert src_len == value.shape[0]
-    head_dim = embed_dim // num_heads
-    assert head_dim * num_heads == embed_dim
+    head_dim = q_embed_dim // num_heads
+    head_vdim = v_embed_dim // num_heads
+    assert q_embed_dim == k_embed_dim, "embedding dimensions must be the same for query and key."
+    assert head_dim * num_heads == q_embed_dim, "embedding dimension must be divisible by num_heads %d" % num_heads
+    assert head_vdim * \
+        num_heads == v_embed_dim, "v_embed_dim must be divisible by num_heads %d." % num_heads

    if key_padding_mask is not None:
        assert key_padding_mask.shape[0] == batch_size
@@ -1158,16 +1165,16 @@ def multi_head_attention(query, key, value, num_heads, q_weight, k_weight, v_wei

    # query:(L_T, B, E) --> q:(L_T, B, E)
    q = F.affine(query, q_weight, q_bias, base_axis=2)
-    # key:(L_S, B, E_k) --> k:(L_S, B, E)
+    # key:(L_S, B, D_k) --> k:(L_S, B, E_k)
    k = F.affine(key, k_weight, k_bias, base_axis=2)
-    # value:(L_S, B, E_v) --> v:(L_S, B, E)
+    # value:(L_S, B, D_v) --> v:(L_S, B, E_v)
    v = F.affine(value, v_weight, v_bias, base_axis=2)

    q *= float(head_dim) ** -0.5

    if attn_bias_k is not None:
-        attn_bias_k = F.reshape(attn_bias_k, (1, 1, embed_dim))
-        attn_bias_v = F.reshape(attn_bias_v, (1, 1, embed_dim))
+        attn_bias_k = F.reshape(attn_bias_k, (1, 1, k_embed_dim))
+        attn_bias_v = F.reshape(attn_bias_v, (1, 1, v_embed_dim))
        src_len += 1
        assert attn_bias_k is not None
        attn_bias_k = F.broadcast(
@@ -1184,11 +1191,11 @@ def multi_head_attention(query, key, value, num_heads, q_weight, k_weight, v_wei
            key_padding_mask = F.pad(key_padding_mask, (0, 1))

    q = F.transpose(
-        F.reshape(q, (tgt_len, batch_size * num_heads, head_dim)), (1, 0, 2))  # q:(B*H, L_T, dim_head)
+        F.reshape(q, (tgt_len, batch_size * num_heads, head_dim)), (1, 0, 2))  # q:(B*H, L_T, head_dim)
    k = F.transpose(
-        F.reshape(k, (-1, batch_size * num_heads, head_dim)), (1, 0, 2))  # k:(B*H, L_S, dim_head)
+        F.reshape(k, (-1, batch_size * num_heads, head_dim)), (1, 0, 2))  # k:(B*H, L_S, head_dim)
    v = F.transpose(
-        F.reshape(v, (-1, batch_size * num_heads, head_dim)), (1, 0, 2))  # v:(B*H, L_S, dim_head)
+        F.reshape(v, (-1, batch_size * num_heads, head_vdim)), (1, 0, 2))  # v:(B*H, L_S, head_vdim)

    # attn_output_weights: (B*H, L_T, L_S)
    attn_output_weights = F.batch_matmul(q, k, transpose_b=True)
@@ -1218,12 +1225,12 @@ def multi_head_attention(query, key, value, num_heads, q_weight, k_weight, v_wei
        attn_output_weights = F.dropout(
            attn_output_weights, p=dropout)

-    # (B*H, L_T, L_S) x (B*H, L_S, dim_head) --> (B*H, L_T, dim_head)
+    # (B*H, L_T, L_S) x (B*H, L_S, head_vdim) --> (B*H, L_T, head_vdim)
    attn_output = F.batch_matmul(attn_output_weights, v)
    assert list(attn_output.shape) == [
-        batch_size * num_heads, tgt_len, head_dim]
+        batch_size * num_heads, tgt_len, head_vdim]
    attn_output = F.reshape(F.transpose(
-        attn_output, (1, 0, 2)), (tgt_len, batch_size, embed_dim))  # attn_output: (L_T, B, E)
+        attn_output, (1, 0, 2)), (tgt_len, batch_size, v_embed_dim))  # attn_output: (L_T, B, E_v)

    attn_output = F.affine(attn_output, out_weight, out_bias, base_axis=2)

diff --git a/python/src/nnabla/parametric_functions.py b/python/src/nnabla/parametric_functions.py
index 41bd4474b..18283fb59 100644
--- a/python/src/nnabla/parametric_functions.py
+++ b/python/src/nnabla/parametric_functions.py
@@ -3681,12 +3681,12 @@ def weight_normalization(v, dim=0, eps=1e-12, fix_parameters=False):
    ('attn_bias_k', 'attention bias for k', '(E, 1)', True),
    ('attn_bias_v', 'attention bias for v', '(E, 1)', True),
])
-def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, rng=None, with_bias=True, add_attn_bias=False, additive_mask=None, key_padding_mask=None, fix_parameters=False, param_init=None):
+def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, k_embed_dim=None, v_embed_dim=None, out_dim=None, rng=None, with_bias=True, add_attn_bias=False, additive_mask=None, key_padding_mask=None, fix_parameters=False, param_init=None):
    '''MultiHeadAttention.

    Computes multi-headed attention with query, key, and value.
    We use the following notations to describe the inputs and outputs below.
-    :math:`L_T`: target sequence length, :math:`L_S`: source sequence length, :math:`B`: batch size, :math:`E`: embedding dimension.
+    :math:`L_T`: target sequence length, :math:`L_S`: source sequence length, :math:`B`: batch size, :math:`D`: input dimension, :math:`E`: embedding dimension.

    References:

@@ -3698,19 +3698,22 @@ def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, rng=None,

    .. code-block:: python

-        q = nn.Variable((tgt_len, batch_size, embed_dim))
-        k = nn.Variable((src_len, batch_size, kdim))
-        v = nn.Variable((src_len, batch_size, vdim))
+        q = nn.Variable((tgt_len, batch_size, q_input_dim))
+        k = nn.Variable((src_len, batch_size, k_input_dim))
+        v = nn.Variable((src_len, batch_size, v_input_dim))

        out, w = PF.multi_head_attention(q, k, v)
        out.forward()

    Args:
-        query (~nnabla.Variable): Input N-D array with shape :math:`(L_T, B, E)`.
-        key (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, E_k)`.
-        value (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, E_v)`.
+        query (~nnabla.Variable): Input N-D array with shape :math:`(L_T, B, D_q)`.
+        key (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, D_k)`.
+        value (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, D_v)`.
        num_heads (int, optional): Number of attention heads. Note that embedding dimension E must be divisible by the number of heads. Default is 12 which is conventional.
        dropout (float, optional): Dropout ratio applied to parameters. Default is 0.
+        k_embed_dim (int, optional): Embedding dimension for key. If specified, the embedding dimensions for both query and key are set to this value. Otherwise, both default to the input dimension of query.
+        v_embed_dim (int, optional): Embedding dimension for value. If not specified, it defaults to the input dimension of value.
+        out_dim (int, optional): Dimension of the output. If not specified, it defaults to the embedding dimension for value.
        rng (numpy.random.RandomState, optional): Random generator for Initializer. Default is None.
        with_bias (bool, optional): Specify whether to include the bias parameters. Default is True.
        add_attn_bias (bool, optional): Specify whether to add attention bias parameters for key and value. Default is False.
@@ -3728,32 +3731,37 @@ def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, rng=None,
        ~nnabla.Variable: Output :math:`h_n` with shape :math:`(B, L_T, L_S)`
    '''

-    embed_dim = query.shape[2]
-    kdim = key.shape[2]
-    vdim = value.shape[2]
+    if k_embed_dim is None:
+        q_embed_dim = k_embed_dim = query.shape[2]
+    else:
+        q_embed_dim = k_embed_dim
+    if v_embed_dim is None:
+        v_embed_dim = value.shape[2]
+    if out_dim is None:
+        out_dim = v_embed_dim

    if param_init is None:
        param_init = {}

    q_weight = param_init.get('q_weight', UniformInitializer(
-        calc_uniform_lim_glorot(embed_dim, embed_dim), rng))
+        calc_uniform_lim_glorot(query.shape[2], q_embed_dim), rng))
    k_weight = param_init.get('k_weight', UniformInitializer(
-        calc_uniform_lim_glorot(kdim, embed_dim), rng))
+        calc_uniform_lim_glorot(key.shape[2], k_embed_dim), rng))
    v_weight = param_init.get('v_weight', UniformInitializer(
-        calc_uniform_lim_glorot(vdim, embed_dim), rng))
+        calc_uniform_lim_glorot(value.shape[2], v_embed_dim), rng))

    qw = get_parameter_or_create(
-        "q_weight", (embed_dim, embed_dim), q_weight, True, not fix_parameters)
+        "q_weight", (query.shape[2], q_embed_dim), q_weight, True, not fix_parameters)
    kw = get_parameter_or_create(
-        "k_weight", (kdim, embed_dim), k_weight, True, not fix_parameters)
+        "k_weight", (key.shape[2], k_embed_dim), k_weight, True, not fix_parameters)
    vw = get_parameter_or_create(
-        "v_weight", (vdim, embed_dim), v_weight, True, not fix_parameters)
+        "v_weight", (value.shape[2], v_embed_dim), v_weight, True, not fix_parameters)

    out_weight = param_init.get('out_weight', UniformInitializer(
-        calc_uniform_lim_glorot(embed_dim, embed_dim), rng))
+        calc_uniform_lim_glorot(v_embed_dim, out_dim), rng))
    ow = get_parameter_or_create("out_weight", (
-        embed_dim, embed_dim), out_weight, True, not fix_parameters)
+        v_embed_dim, out_dim), out_weight, True, not fix_parameters)

    qb = kb = vb = ob = None
    if with_bias:
@@ -3763,25 +3771,25 @@ def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, rng=None,
        out_bias = param_init.get('out_bias', ConstantInitializer())

        qb = get_parameter_or_create(
-            "q_bias", (embed_dim, ), q_bias, True, not fix_parameters)
+            "q_bias", (q_embed_dim, ), q_bias, True, not fix_parameters)
        kb = get_parameter_or_create(
-            "k_bias", (embed_dim, ), k_bias, True, not fix_parameters)
+            "k_bias", (k_embed_dim, ), k_bias, True, not fix_parameters)
        vb = get_parameter_or_create(
-            "v_bias", (embed_dim, ), v_bias, True, not fix_parameters)
+            "v_bias", (v_embed_dim, ), v_bias, True, not fix_parameters)
        ob = get_parameter_or_create(
-            "out_bias", (embed_dim, ), out_bias, True, not fix_parameters)
+            "out_bias", (out_dim, ), out_bias, True, not fix_parameters)

    abk = abv = None
    if add_attn_bias:
        attn_bias_k = param_init.get('attn_bias_k', UniformInitializer(
-            calc_uniform_lim_glorot(embed_dim, 1), rng))
+            calc_uniform_lim_glorot(k_embed_dim, 1), rng))
        attn_bias_v = param_init.get('attn_bias_v', UniformInitializer(
-            calc_uniform_lim_glorot(embed_dim, 1), rng))
+            calc_uniform_lim_glorot(v_embed_dim, 1), rng))
        abk = get_parameter_or_create(
-            "attn_bias_k", (1, 1, embed_dim), attn_bias_k, True, not fix_parameters)
+            "attn_bias_k", (1, 1, k_embed_dim), attn_bias_k, True, not fix_parameters)
        abv = get_parameter_or_create(
-            "attn_bias_v", (1, 1, embed_dim), attn_bias_v, True, not fix_parameters)
+            "attn_bias_v", (1, 1, v_embed_dim), attn_bias_v, True, not fix_parameters)

    return F.multi_head_attention(query, key, value, num_heads, qw, kw, vw, ow, qb, kb, vb, ob, abk, abv, dropout, additive_mask=additive_mask, key_padding_mask=key_padding_mask)

diff --git a/python/test/test_parametric_functions.py b/python/test/test_parametric_functions.py
index 2c36dde99..7e4dfbaaa 100644
--- a/python/test/test_parametric_functions.py
+++ b/python/test/test_parametric_functions.py
@@ -1356,20 +1356,20 @@ def test_pf_min_max_quantized_convolution_execution(g_rng, inshape, outmaps,

@pytest.mark.parametrize("ctx, func_name", ctxs)
@pytest.mark.parametrize("src_len, tgt_len, batch_size", [
    (2, 3, 2)])
-@pytest.mark.parametrize("embed_dim, num_heads, dropout, kdim, vdim", [
-    (12, 6, 0.0, 12, 12),
-    (12, 12, 0.0, 10, 10)])
+@pytest.mark.parametrize("q_input_dim, k_input_dim, v_input_dim, k_embed_dim, v_embed_dim, out_dim, num_heads, dropout", [
+    (16, 16, 16, 12, 12, 12, 6, 0.0),
+    (16, 15, 14, 12, 24, 24, 12, 0.0)])
@pytest.mark.parametrize("with_bias", [True, False])
@pytest.mark.parametrize("fix_parameters", [True, False])
@pytest.mark.parametrize("add_attn_bias", [True, False])
@pytest.mark.parametrize("rng", [None, True])
@pytest.mark.parametrize('param_init', [None, True])
-def test_pf_multi_head_attention_execution(g_rng, src_len, tgt_len, batch_size, embed_dim, num_heads, dropout, rng, with_bias, add_attn_bias, kdim, vdim, fix_parameters, param_init, ctx, func_name):
+def test_pf_multi_head_attention_execution(g_rng, src_len, tgt_len, batch_size, q_input_dim, k_input_dim, v_input_dim, k_embed_dim, v_embed_dim, out_dim, num_heads, dropout, rng, with_bias, add_attn_bias, fix_parameters, param_init, ctx, func_name):

-    q_shape = (embed_dim, embed_dim)
-    k_shape = (kdim, embed_dim)
-    v_shape = (vdim, embed_dim)
-    o_shape = (embed_dim, embed_dim)
+    q_shape = (q_input_dim, k_embed_dim)
+    k_shape = (k_input_dim, k_embed_dim)
+    v_shape = (v_input_dim, v_embed_dim)
+    o_shape = (v_embed_dim, out_dim)

    q_weight = process_param_init(I.NormalInitializer(), q_shape, g_rng)
    k_weight = process_param_init(I.NormalInitializer(), k_shape, g_rng)
@@ -1383,11 +1383,14 @@ def test_pf_multi_head_attention_execution(g_rng, src_len, tgt_len, batch_size,
                                  out_weight=out_weight)

    if with_bias:
-        b_shape = (embed_dim, )
-        q_bias = process_param_init(I.ConstantInitializer(), b_shape, g_rng)
-        k_bias = process_param_init(I.ConstantInitializer(), b_shape, g_rng)
-        v_bias = process_param_init(I.ConstantInitializer(), b_shape, g_rng)
-        out_bias = process_param_init(I.ConstantInitializer(), b_shape, g_rng)
+        qb_shape = (k_embed_dim, )
+        kb_shape = (k_embed_dim, )
+        vb_shape = (v_embed_dim, )
+        ob_shape = (out_dim, )
+        q_bias = process_param_init(I.ConstantInitializer(), qb_shape, g_rng)
+        k_bias = process_param_init(I.ConstantInitializer(), kb_shape, g_rng)
+        v_bias = process_param_init(I.ConstantInitializer(), vb_shape, g_rng)
+        out_bias = process_param_init(I.ConstantInitializer(), ob_shape, g_rng)

        param_init['q_bias'] = q_bias
        param_init['k_bias'] = k_bias
@@ -1395,11 +1398,12 @@ def test_pf_multi_head_attention_execution(g_rng, src_len, tgt_len, batch_size,
        param_init['v_bias'] = v_bias
        param_init['out_bias'] = out_bias

    if add_attn_bias:
-        kv_shape = (1, 1, embed_dim)
+        attnk_shape = (1, 1, k_embed_dim)
+        attnv_shape = (1, 1, v_embed_dim)
        attn_bias_k = process_param_init(
-            I.NormalInitializer(), kv_shape, g_rng)
+            I.NormalInitializer(), attnk_shape, g_rng)
        attn_bias_v = process_param_init(
-            I.NormalInitializer(), kv_shape, g_rng)
+            I.NormalInitializer(), attnv_shape, g_rng)

        param_init['attn_bias_k'] = attn_bias_k
        param_init['attn_bias_v'] = attn_bias_v
@@ -1409,6 +1413,9 @@ def test_pf_multi_head_attention_execution(g_rng, src_len, tgt_len, batch_size,
    kw = {}
    insert_if_not_none(kw, 'num_heads', num_heads)
    insert_if_not_default(kw, 'dropout', dropout, 0.0)
+    insert_if_not_none(kw, 'k_embed_dim', k_embed_dim)
+    insert_if_not_none(kw, 'v_embed_dim', v_embed_dim)
+    insert_if_not_none(kw, 'out_dim', out_dim)
    insert_if_not_none(kw, 'rng', rng)
    insert_if_not_default(kw, 'with_bias', with_bias, True)
    insert_if_not_default(kw, 'add_attn_bias', add_attn_bias, False)
@@ -1416,11 +1423,11 @@ def test_pf_multi_head_attention_execution(g_rng, src_len, tgt_len, batch_size,
    insert_if_not_none(kw, 'param_init', param_init)

    q = nn.Variable.from_numpy_array(
-        g_rng.randn(tgt_len, batch_size, embed_dim).astype(np.float32), need_grad=True)
+        g_rng.randn(tgt_len, batch_size, q_input_dim).astype(np.float32), need_grad=True)
    k = nn.Variable.from_numpy_array(
-        g_rng.randn(src_len, batch_size, kdim).astype(np.float32), need_grad=True)
+        g_rng.randn(src_len, batch_size, k_input_dim).astype(np.float32), need_grad=True)
    v = nn.Variable.from_numpy_array(
-        g_rng.randn(src_len, batch_size, vdim).astype(np.float32), need_grad=True)
+        g_rng.randn(src_len, batch_size, v_input_dim).astype(np.float32), need_grad=True)

    # Check execution
    y, w = PF.multi_head_attention(q, k, v, **kw)
@@ -1463,18 +1470,18 @@ def test_pf_multi_head_attention_execution(g_rng, src_len, tgt_len, batch_size,
    assert ow.need_grad

    if with_bias:
-        assert qb.shape == b_shape
-        assert kb.shape == b_shape
-        assert vb.shape == b_shape
-        assert ob.shape == b_shape
+        assert qb.shape == qb_shape
+        assert kb.shape == kb_shape
+        assert vb.shape == vb_shape
+        assert ob.shape == ob_shape
        assert qb.need_grad
        assert kb.need_grad
        assert vb.need_grad
        assert ob.need_grad

    if add_attn_bias:
-        assert abk.shape == kv_shape
-        assert abv.shape == kv_shape
+        assert abk.shape == attnk_shape
+        assert abv.shape == attnv_shape
        assert abk.need_grad
        assert abv.need_grad
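Usage sketch (illustration only, not part of the diff): the example below mirrors the second test case above and assumes the new `k_embed_dim`, `v_embed_dim`, and `out_dim` keyword arguments behave as documented in the updated docstrings, so the input dimensions of query, key, and value may all differ from the embedding dimensions.

.. code-block:: python

    import nnabla as nn
    import nnabla.parametric_functions as PF

    tgt_len, src_len, batch_size = 3, 2, 2
    q = nn.Variable((tgt_len, batch_size, 16))  # query: (L_T, B, D_q)
    k = nn.Variable((src_len, batch_size, 15))  # key:   (L_S, B, D_k)
    v = nn.Variable((src_len, batch_size, 14))  # value: (L_S, B, D_v)

    # k_embed_dim sets E for query and key, v_embed_dim sets E_v, out_dim sets E_out.
    y, w = PF.multi_head_attention(q, k, v, num_heads=12,
                                   k_embed_dim=12, v_embed_dim=24, out_dim=24)

    print(y.shape)  # expected (L_T, B, out_dim) = (3, 2, 24)
    print(w.shape)  # expected (B, L_T, L_S) = (2, 3, 2)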