Change multi_head_attention API #616

Merged: 1 commit, Apr 10, 2020
61 changes: 34 additions & 27 deletions python/src/nnabla/functions.py
@@ -1111,7 +1111,7 @@ def multi_head_attention(query, key, value, num_heads, q_weight, k_weight, v_wei

Computes multi-headed attention with query, key, and value.
We use the following notations to describe the inputs and outputs below.
:math:`L_T`: target sequence length, :math:`L_S`: source sequence length, :math:`B`: batch size, :math:`E`: embedding dimension, :math`H`: number of attention heads.
:math:`L_T`: target sequence length, :math:`L_S`: source sequence length, :math:`B`: batch size, :math:`D`: input dimension, :math:`E`: embedding dimension, :math:`H`: number of attention heads.

References:

@@ -1120,54 +1120,61 @@ def multi_head_attention(query, key, value, num_heads, q_weight, k_weight, v_wei
<https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>

Args:
query (~nnabla.Variable): Input N-D array with shape :math:`(L_T, B, E)`.
key (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, E_k)`.
value (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, E_v)`.
query (~nnabla.Variable): Input N-D array with shape :math:`(L_T, B, D_q)`.
key (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, D_k)`.
value (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, D_v)`.
num_heads (int): Number of attention heads. Note that the embedding dimension E must be divisible by the number of heads. Default is 12, which is conventional.
q_weight (~nnabla.Variable): Input N-D array with shape :math:`(E E)`.
k_weight (~nnabla.Variable): Input N-D array with shape :math:`(E_k, E)`.
v_weight (~nnabla.Variable): Input N-D array with shape :math:`(E_v, E)`.
out_weight (~nnabla.Variable): Input N-D array with shape :math:`(E, E)`.
q_weight (~nnabla.Variable): Input N-D array with shape :math:`(D_q, E)`.
k_weight (~nnabla.Variable): Input N-D array with shape :math:`(D_k, E)`.
v_weight (~nnabla.Variable): Input N-D array with shape :math:`(D_v, E_v)`.
out_weight (~nnabla.Variable): Input N-D array with shape :math:`(D_v, E_{out})`.
q_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
k_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
v_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
out_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
v_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E_v, )`.
out_bias (~nnabla.Variable, optional): Input N-D array with shape :math:`(E_{out}, )`.
attn_bias_k (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
attn_bias_v (~nnabla.Variable, optional): Input N-D array with shape :math:`(E, )`.
attn_bias_v (~nnabla.Variable, optional): Input N-D array with shape :math:`(E_v, )`.
dropout (float, optional): Dropout ratio applied to the attention weights. Default is 0.
additive_mask (~nnabla.Variable, optional): Input N-D array with shape :math:`(L_T, L_S)`. Values will be added to the attention layer to prevent attention to certain positions.
key_padding_mask (~nnabla.Variable, optional): Input N-D array with shape :math:`(B, L_S)`. Specified padding elements will be ignored by the attention layer. Values must be either 1 or 0.

Returns:
~nnabla.Variable: Output :math:`y` with shape :math:`(L_T, B, E)`
~nnabla.Variable: Output :math:`y` with shape :math:`(L_T, B, E_{out})`
~nnabla.Variable: Output :math:`h_n` with shape :math:`(B, L_T, L_S)`
'''

from . import functions as F

tgt_len, batch_size, embed_dim = query.shape
src_len, batch_size, kdim = key.shape
vdim = value.shape[2]
tgt_len, batch_size, _ = query.shape
src_len, batch_size, _ = key.shape
q_embed_dim = q_weight.shape[1]
k_embed_dim = k_weight.shape[1]
v_embed_dim = v_weight.shape[1]
out_dim = out_weight.shape[1]
assert src_len == value.shape[0]
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim
head_dim = q_embed_dim // num_heads
head_vdim = v_embed_dim // num_heads
assert q_embed_dim == k_embed_dim, "embedding dimensions must be the same for query and key."
assert head_dim * num_heads == q_embed_dim, "embedding dimension must be divisible by num_heads %d" % num_heads
assert head_vdim * \
num_heads == v_embed_dim, "v_embed_dim must be divisible by num_heads %d." % num_heads

if key_padding_mask is not None:
assert key_padding_mask.shape[0] == batch_size
assert key_padding_mask.shape[1] == src_len

# query:(L_T, B, E) --> q:(L_T, B, E)
q = F.affine(query, q_weight, q_bias, base_axis=2)
# key:(L_S, B, E_k) --> k:(L_S, B, E)
# key:(L_S, B, D_k) --> k:(L_S, B, E_k)
k = F.affine(key, k_weight, k_bias, base_axis=2)
# value:(L_S, B, E_v) --> v:(L_S, B, E)
# value:(L_S, B, D_v) --> v:(L_S, B, E_v)
v = F.affine(value, v_weight, v_bias, base_axis=2)

q *= float(head_dim) ** -0.5

if attn_bias_k is not None:
attn_bias_k = F.reshape(attn_bias_k, (1, 1, embed_dim))
attn_bias_v = F.reshape(attn_bias_v, (1, 1, embed_dim))
attn_bias_k = F.reshape(attn_bias_k, (1, 1, k_embed_dim))
attn_bias_v = F.reshape(attn_bias_v, (1, 1, v_embed_dim))
src_len += 1
assert attn_bias_k is not None
attn_bias_k = F.broadcast(
@@ -1184,11 +1191,11 @@ def multi_head_attention(query, key, value, num_heads, q_weight, k_weight, v_wei
key_padding_mask = F.pad(key_padding_mask, (0, 1))

q = F.transpose(
F.reshape(q, (tgt_len, batch_size * num_heads, head_dim)), (1, 0, 2)) # q:(B*H, L_T, dim_head)
F.reshape(q, (tgt_len, batch_size * num_heads, head_dim)), (1, 0, 2)) # q:(B*H, L_T, head_dim)
k = F.transpose(
F.reshape(k, (-1, batch_size * num_heads, head_dim)), (1, 0, 2)) # k:(B*H, L_S, dim_head)
F.reshape(k, (-1, batch_size * num_heads, head_dim)), (1, 0, 2)) # k:(B*H, L_S, head_dim)
v = F.transpose(
F.reshape(v, (-1, batch_size * num_heads, head_dim)), (1, 0, 2)) # v:(B*H, L_S, dim_head)
F.reshape(v, (-1, batch_size * num_heads, head_vdim)), (1, 0, 2)) # v:(B*H, L_S, head_vdim)

# attn_output_weights: (B*H, L_T, L_S)
attn_output_weights = F.batch_matmul(q, k, transpose_b=True)
@@ -1218,12 +1225,12 @@ def multi_head_attention(query, key, value, num_heads, q_weight, k_weight, v_wei
attn_output_weights = F.dropout(
attn_output_weights, p=dropout)

# (B*H, L_T, L_S) x (B*H, L_S, dim_head) --> (B*H, L_T, dim_head)
# (B*H, L_T, L_S) x (B*H, L_S, head_vdim) --> (B*H, L_T, head_vdim)
attn_output = F.batch_matmul(attn_output_weights, v)
assert list(attn_output.shape) == [
batch_size * num_heads, tgt_len, head_dim]
batch_size * num_heads, tgt_len, head_vdim]
attn_output = F.reshape(F.transpose(
attn_output, (1, 0, 2)), (tgt_len, batch_size, embed_dim)) # attn_output: (L_T, B, E)
attn_output, (1, 0, 2)), (tgt_len, batch_size, v_embed_dim)) # attn_output: (L_T, B, E_v)

attn_output = F.affine(attn_output, out_weight, out_bias, base_axis=2)

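For reference, a minimal usage sketch of the functional API with the shape conventions introduced in this diff. It is not part of the PR: the dimension values, the use of `nn.Variable.from_numpy_array` for the weights, and passing `None` for the optional biases are illustrative assumptions based on the updated docstring and on how `parametric_functions.py` calls this function.

```python
# Illustrative sketch only (not from the PR). Shapes follow the updated docstring:
# query (L_T, B, D_q), key (L_S, B, D_k), value (L_S, B, D_v);
# q_weight (D_q, E), k_weight (D_k, E), v_weight (D_v, E_v), out_weight (E_v, E_out).
import numpy as np
import nnabla as nn
import nnabla.functions as F

L_T, L_S, B = 10, 12, 4            # target length, source length, batch size
D_q, D_k, D_v = 64, 32, 48         # input dimensions (may now differ per input)
E, E_v, E_out = 96, 72, 64         # embedding dims; E and E_v divisible by num_heads
num_heads = 8

rng = np.random.RandomState(0)
query = nn.Variable.from_numpy_array(rng.randn(L_T, B, D_q).astype(np.float32))
key = nn.Variable.from_numpy_array(rng.randn(L_S, B, D_k).astype(np.float32))
value = nn.Variable.from_numpy_array(rng.randn(L_S, B, D_v).astype(np.float32))

qw = nn.Variable.from_numpy_array(rng.randn(D_q, E).astype(np.float32) * 0.02)
kw = nn.Variable.from_numpy_array(rng.randn(D_k, E).astype(np.float32) * 0.02)
vw = nn.Variable.from_numpy_array(rng.randn(D_v, E_v).astype(np.float32) * 0.02)
ow = nn.Variable.from_numpy_array(rng.randn(E_v, E_out).astype(np.float32) * 0.02)

# Biases and attention biases are optional; None is assumed to be accepted here,
# mirroring the call made from parametric_functions.py in this same diff.
y, w = F.multi_head_attention(query, key, value, num_heads,
                              qw, kw, vw, ow,
                              None, None, None, None, None, None, 0.0)
y.forward()
print(y.shape)  # expected (L_T, B, E_out) = (10, 4, 64) per the docstring
print(w.shape)  # expected (B, L_T, L_S) = (4, 10, 12) per the docstring
```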
62 changes: 35 additions & 27 deletions python/src/nnabla/parametric_functions.py
@@ -3681,12 +3681,12 @@ def weight_normalization(v, dim=0, eps=1e-12, fix_parameters=False):
('attn_bias_k', 'attention bias for k', '(E, 1)', True),
('attn_bias_v', 'attention bias for v', '(E, 1)', True),
])
def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, rng=None, with_bias=True, add_attn_bias=False, additive_mask=None, key_padding_mask=None, fix_parameters=False, param_init=None):
def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, k_embed_dim=None, v_embed_dim=None, out_dim=None, rng=None, with_bias=True, add_attn_bias=False, additive_mask=None, key_padding_mask=None, fix_parameters=False, param_init=None):
'''MultiHeadAttention.

Computes multi-headed attention with query, key, and value.
We use the following notations to describe the inputs and outputs below.
:math:`L_T`: target sequence length, :math:`L_S`: source sequence length, :math:`B`: batch size, :math:`E`: embedding dimension.
:math:`L_T`: target sequence length, :math:`L_S`: source sequence length, :math:`B`: batch size, :math:`D`: input dimension, :math:`E`: embedding dimension.

References:

@@ -3698,19 +3698,22 @@ def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, rng=None,

.. code-block:: python

q = nn.Variable((tgt_len, batch_size, embed_dim))
k = nn.Variable((src_len, batch_size, kdim))
v = nn.Variable((src_len, batch_size, vdim))
q = nn.Variable((tgt_len, batch_size, q_input_dim))
k = nn.Variable((src_len, batch_size, k_input_dim))
v = nn.Variable((src_len, batch_size, v_input_dim))

out, w = PF.multi_head_attention(q, k, v)
out.forward()

Args:
query (~nnabla.Variable): Input N-D array with shape :math:`(L_T, B, E)`.
key (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, E_k)`.
value (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, E_v)`.
query (~nnabla.Variable): Input N-D array with shape :math:`(L_T, B, D_q)`.
key (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, D_k)`.
value (~nnabla.Variable): Input N-D array with shape :math:`(L_S, B, D_v)`.
num_heads (int, optional): Number of attention heads. Note that the embedding dimension E must be divisible by the number of heads. Default is 12, which is conventional.
dropout (float, optional): Dropout ratio applied to the attention weights. Default is 0.
k_embed_dim (int, optional): Embedding dimension for key. If specified, the embedding dimensions for both query and key are set to this value. Otherwise, k_embed_dim is set to the same value as the embedding dimension for query.
v_embed_dim (int, optional): Embedding dimension for value. If not specified, it defaults to the dimension of the value input.
out_dim (int, optional): Embedding dimension for the output projection. If not specified, it defaults to the same value as v_embed_dim.
rng (numpy.random.RandomState, optional): Random generator for Initializer. Default is None.
with_bias (bool, optional): Specify whether to include the bias parameters. Default is True.
add_attn_bias (bool, optional): Specify whether to add attention bias parameters for key and value. Default is False.
@@ -3728,32 +3731,37 @@ def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, rng=None,
~nnabla.Variable: Output :math:`h_n` with shape :math:`(B, L_T, L_S)`
'''

embed_dim = query.shape[2]
kdim = key.shape[2]
vdim = value.shape[2]
if k_embed_dim is None:
q_embed_dim = k_embed_dim = query.shape[2]
else:
q_embed_dim = k_embed_dim
if v_embed_dim is None:
v_embed_dim = value.shape[2]
if out_dim == None:
out_dim = v_embed_dim

if param_init is None:
param_init = {}

q_weight = param_init.get('q_weight', UniformInitializer(
calc_uniform_lim_glorot(embed_dim, embed_dim), rng))
calc_uniform_lim_glorot(query.shape[2], q_embed_dim), rng))
k_weight = param_init.get('k_weight', UniformInitializer(
calc_uniform_lim_glorot(kdim, embed_dim), rng))
calc_uniform_lim_glorot(key.shape[2], k_embed_dim), rng))
v_weight = param_init.get('v_weight', UniformInitializer(
calc_uniform_lim_glorot(vdim, embed_dim), rng))
calc_uniform_lim_glorot(value.shape[2], v_embed_dim), rng))

qw = get_parameter_or_create(
"q_weight", (embed_dim, embed_dim), q_weight, True, not fix_parameters)
"q_weight", (query.shape[2], q_embed_dim), q_weight, True, not fix_parameters)
kw = get_parameter_or_create(
"k_weight", (kdim, embed_dim), k_weight, True, not fix_parameters)
"k_weight", (key.shape[2], k_embed_dim), k_weight, True, not fix_parameters)
vw = get_parameter_or_create(
"v_weight", (vdim, embed_dim), v_weight, True, not fix_parameters)
"v_weight", (value.shape[2], v_embed_dim), v_weight, True, not fix_parameters)

out_weight = param_init.get('out_weight', UniformInitializer(
calc_uniform_lim_glorot(embed_dim, embed_dim), rng))
calc_uniform_lim_glorot(v_embed_dim, out_dim), rng))

ow = get_parameter_or_create("out_weight", (
embed_dim, embed_dim), out_weight, True, not fix_parameters)
v_embed_dim, out_dim), out_weight, True, not fix_parameters)

qb = kb = vb = ob = None
if with_bias:
@@ -3763,25 +3771,25 @@ def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, rng=None,
out_bias = param_init.get('out_bias', ConstantInitializer())

qb = get_parameter_or_create(
"q_bias", (embed_dim, ), q_bias, True, not fix_parameters)
"q_bias", (q_embed_dim, ), q_bias, True, not fix_parameters)
kb = get_parameter_or_create(
"k_bias", (embed_dim, ), k_bias, True, not fix_parameters)
"k_bias", (k_embed_dim, ), k_bias, True, not fix_parameters)
vb = get_parameter_or_create(
"v_bias", (embed_dim, ), v_bias, True, not fix_parameters)
"v_bias", (v_embed_dim, ), v_bias, True, not fix_parameters)
ob = get_parameter_or_create(
"out_bias", (embed_dim, ), out_bias, True, not fix_parameters)
"out_bias", (out_dim, ), out_bias, True, not fix_parameters)

abk = abv = None
if add_attn_bias:
attn_bias_k = param_init.get('attn_bias_k', UniformInitializer(
calc_uniform_lim_glorot(embed_dim, 1), rng))
calc_uniform_lim_glorot(k_embed_dim, 1), rng))
attn_bias_v = param_init.get('attn_bias_v', UniformInitializer(
calc_uniform_lim_glorot(embed_dim, 1), rng))
calc_uniform_lim_glorot(v_embed_dim, 1), rng))

abk = get_parameter_or_create(
"attn_bias_k", (1, 1, embed_dim), attn_bias_k, True, not fix_parameters)
"attn_bias_k", (1, 1, k_embed_dim), attn_bias_k, True, not fix_parameters)
abv = get_parameter_or_create(
"attn_bias_v", (1, 1, embed_dim), attn_bias_v, True, not fix_parameters)
"attn_bias_v", (1, 1, v_embed_dim), attn_bias_v, True, not fix_parameters)

return F.multi_head_attention(query, key, value, num_heads, qw, kw, vw, ow, qb, kb, vb, ob, abk, abv, dropout, additive_mask=additive_mask, key_padding_mask=key_padding_mask)

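A corresponding sketch for the parametric API, exercising the new optional k_embed_dim, v_embed_dim, and out_dim arguments added in this PR. The concrete dimension values and the parameter scope name are made up for illustration.

```python
# Illustrative sketch only: the new keyword arguments let the projected
# dimensions differ from the raw input dimensions of query, key, and value.
import nnabla as nn
import nnabla.parametric_functions as PF

tgt_len, src_len, batch_size = 10, 12, 4
q = nn.Variable((tgt_len, batch_size, 64))   # (L_T, B, D_q)
k = nn.Variable((src_len, batch_size, 32))   # (L_S, B, D_k)
v = nn.Variable((src_len, batch_size, 48))   # (L_S, B, D_v)

with nn.parameter_scope("mha"):
    # Project query/key to 96 dims, value to 72 dims, and the output to 64 dims.
    # 96 and 72 are both divisible by num_heads=8, as the asserts require.
    out, w = PF.multi_head_attention(q, k, v, num_heads=8,
                                     k_embed_dim=96, v_embed_dim=72, out_dim=64)

print(out.shape)  # expected (tgt_len, batch_size, out_dim) = (10, 4, 64)
print(w.shape)    # expected (batch_size, tgt_len, src_len) = (4, 10, 12)
```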