[#3572](#3572) was rolled back because of internal test breakages. This PR re-adds the attention changes again and fixes internal tests.

PiperOrigin-RevId: 597276997
chiamp authored and Flax Authors committed Jan 11, 2024
1 parent ed44f52 commit 259e1f1
Showing 4 changed files with 184 additions and 4 deletions.
11 changes: 8 additions & 3 deletions docs/api_reference/flax.linen/layers.rst
@@ -84,11 +84,15 @@ Attention

.. flax_module::
:module: flax.linen
:class: SelfAttention
:class: MultiHeadDotProductAttention

.. flax_module::
:module: flax.linen
:class: MultiHeadDotProductAttention
:class: MultiHeadAttention

.. flax_module::
:module: flax.linen
:class: SelfAttention

.. autofunction:: dot_product_attention_weights
.. autofunction:: dot_product_attention
@@ -147,8 +151,9 @@ Recurrent
WeightNorm
Sequential
Dropout
SelfAttention
MultiHeadDotProductAttention
MultiHeadAttention
SelfAttention
RNNCellBase
LSTMCell
OptimizedLSTMCell
1 change: 1 addition & 0 deletions flax/linen/__init__.py
@@ -61,6 +61,7 @@
tanh as tanh,
)
from .attention import (
MultiHeadAttention as MultiHeadAttention,
MultiHeadDotProductAttention as MultiHeadDotProductAttention,
SelfAttention as SelfAttention,
combine_masks as combine_masks,
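With the new export, the alias is reachable directly from `flax.linen`. A quick, illustrative check (not part of the diff) that the new name is simply a subclass alias of the existing layer:

```python
# Illustrative only: confirms the relationship established by the diff above,
# namely that MultiHeadAttention subclasses MultiHeadDotProductAttention.
import flax.linen as nn

assert issubclass(nn.MultiHeadAttention, nn.MultiHeadDotProductAttention)
```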
86 changes: 85 additions & 1 deletion flax/linen/attention.py
@@ -572,6 +572,89 @@ def __call__(
return out


class MultiHeadAttention(MultiHeadDotProductAttention):
"""Multi-head dot-product attention.
Alias for ``MultiHeadDotProductAttention``.
**NOTE**: ``MultiHeadAttention`` is a wrapper of ``MultiHeadDotProductAttention``,
and so their implementations are identical. However ``MultiHeadAttention`` layers
will, by default, be named ``MultiHeadAttention_{index}``, whereas ``MultiHeadDotProductAttention``
will be named ``MultiHeadDotProductAttention_{index}``. Therefore, this could affect
checkpointing, param collection names and RNG threading (since the layer name is
used when generating new RNG's) within the module.
Example usage::
>>> import flax.linen as nn
>>> import jax
>>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)
>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)
>>> attention_kwargs = dict(
... num_heads=8,
... qkv_features=16,
... kernel_init=nn.initializers.ones,
... bias_init=nn.initializers.zeros,
... dropout_rate=0.5,
... deterministic=False,
... )
>>> class Module(nn.Module):
... attention_kwargs: dict
...
... @nn.compact
... def __call__(self, x, dropout_rng=None):
... out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
... out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
... return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)
>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
dtype: the dtype of the computation (default: infer from inputs and params)
param_dtype: the dtype passed to parameter initializers (default: float32)
qkv_features: dimension of the key, query, and value.
out_features: dimension of the last projection
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rate: dropout rate
deterministic: if false, the attention weights are masked randomly using
dropout, whereas if true, the attention weights are deterministic.
precision: numerical precision of the computation; see ``jax.lax.Precision``
for details.
kernel_init: initializer for the kernel of the Dense layers.
bias_init: initializer for the bias of the Dense layers.
use_bias: bool: whether pointwise QKVO dense transforms use bias.
attention_fn: dot_product_attention or compatible function. Accepts query,
key, value, and returns output of shape ``[bs, dim1, dim2, ..., dimN,
num_heads, value_channels]``
decode: whether to prepare and use an autoregressive cache.
normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
"""


class SelfAttention(MultiHeadDotProductAttention):
"""Self-attention special case of multi-head dot-product attention.
This layer is deprecated in favor of ``MultiHeadDotProductAttention``.
@@ -590,6 +673,7 @@ def __call__( # type: ignore
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
dropout_rng: Optional[PRNGKey] = None,
sow_weights: bool = False,
):
"""Applies multi-head dot product self-attention on the input data.
@@ -615,7 +699,7 @@ def __call__( # type: ignore
DeprecationWarning,
)
return super().__call__(
inputs_q, mask=mask, deterministic=deterministic, dropout_rng=dropout_rng
inputs_q, mask=mask, deterministic=deterministic, dropout_rng=dropout_rng, sow_weights=sow_weights
)


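The `sow_weights` passthrough above lets `SelfAttention` forward the flag to its parent class. Below is a minimal sketch, not part of the commit, of how the flag could be exercised; it assumes that `sow_weights=True` records the attention weights in the `'intermediates'` collection, as `MultiHeadDotProductAttention` does.

```python
# Hedged sketch: exercising the sow_weights flag that the SelfAttention
# __call__ signature above now forwards to its parent class. Assumes the
# weights are sown into the 'intermediates' collection.
import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.SelfAttention(num_heads=2, qkv_features=8)
x = jnp.ones((1, 4, 8))
variables = layer.init(jax.random.key(0), x)

# mutable=['intermediates'] is needed so the sown values are returned.
out, state = layer.apply(
    variables, x, sow_weights=True, mutable=['intermediates']
)
# state['intermediates'] should then hold the per-head attention weights.
```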
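Because the auto-generated submodule name differs between the two classes (as the `MultiHeadAttention` docstring notes and the tests below verify), a checkpoint saved against the old name has to be re-keyed before loading into a module that uses the alias. The following is an illustrative sketch under that assumption; the `Old`/`New` module names, shapes, and hyperparameters are placeholders, not taken from the commit.

```python
# Illustrative sketch of migrating params when swapping
# MultiHeadDotProductAttention for the MultiHeadAttention alias
# inside a parent module.
import jax
import jax.numpy as jnp
import flax.linen as nn


class Old(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=16)(x)


class New(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.MultiHeadAttention(num_heads=4, qkv_features=16)(x)


x = jnp.ones((2, 3, 16))
old_vars = Old().init(jax.random.key(0), x)

# The submodule name is part of the param tree, so re-key it for the new class.
params = dict(old_vars['params'])
params['MultiHeadAttention_0'] = params.pop('MultiHeadDotProductAttention_0')
out = New().apply({'params': params}, x)
```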
90 changes: 90 additions & 0 deletions tests/linen/linen_attention_test.py
@@ -420,5 +420,95 @@ def test_autoregressive_decode_with_x64(self):
assert y1.shape == (1, 1, 4)
assert y2.shape == (1, 1, 4)

def test_attention_alias_equivalence(self):
key1, key2 = random.split(random.key(0), 2)
query = random.uniform(key1, (3, 5))
key_value = random.uniform(key1, (9, 5))
attention_kwargs = dict(
num_heads=8,
qkv_features=16,
kernel_init=initializers.lecun_normal(),
bias_init=initializers.uniform(),
deterministic=False,
)
module1 = nn.MultiHeadDotProductAttention(**attention_kwargs)
module2 = nn.MultiHeadAttention(**attention_kwargs)
out1, v1 = module1.init_with_output(key2, query, key_value)
out2, v2 = module2.init_with_output(key2, query, key_value, key_value)
self.assertTrue((out1 == out2).all())
self.assertTrue(
jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v1, v2))
)

def test_attention_alias_submodule(self):
key1, key2 = random.split(random.key(0), 2)
query = random.uniform(key1, (3, 5))
key_value = random.uniform(key1, (9, 5))
attention_kwargs = dict(
num_heads=8,
qkv_features=16,
kernel_init=initializers.lecun_normal(),
bias_init=initializers.uniform(),
deterministic=False,
)

class Foo1(nn.Module):
attention_kwargs: dict

@nn.compact
def __call__(self, query, key):
return nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
query, key
)

class Foo2(nn.Module):
attention_kwargs: dict

@nn.compact
def __call__(self, query, key, value):
return nn.MultiHeadAttention(**self.attention_kwargs)(query, key, value)

module1 = Foo1(attention_kwargs)
module2 = Foo2(attention_kwargs)
out1, v1 = module1.init_with_output(key2, query, key_value)
out2, v2 = module2.init_with_output(key2, query, key_value, key_value)

# test different output and variables if layer names are different
self.assertTrue((out1 != out2).all())
v2['params']['MultiHeadDotProductAttention_0'] = v2['params'][
'MultiHeadAttention_0'
]
del v2['params']['MultiHeadAttention_0']
self.assertTrue(
jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x != y).all(), v1, v2))
)

# test same output if variables are the same
v2 = jax.tree_map(lambda x: x, v1)
v2['params']['MultiHeadAttention_0'] = v2['params'][
'MultiHeadDotProductAttention_0'
]
del v2['params']['MultiHeadDotProductAttention_0']
out2 = module2.apply(v2, query, key_value, key_value)
self.assertTrue((out1 == out2).all())

# test same output and variables if names are the same
class Foo2(nn.Module):
attention_kwargs: dict

@nn.compact
def __call__(self, query, key, value):
return nn.MultiHeadAttention(
**self.attention_kwargs, name='MultiHeadDotProductAttention_0'
)(query, key, value)

module2 = Foo2(attention_kwargs)
out2, v2 = module2.init_with_output(key2, query, key_value, key_value)
self.assertTrue((out1 == out2).all())
self.assertTrue(
jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v1, v2))
)


if __name__ == '__main__':
absltest.main()
