From 259e1f13f620afc3440ea7dfa58d3315a92a99ef Mon Sep 17 00:00:00 2001
From: Marcus Chiam
Date: Wed, 10 Jan 2024 09:45:45 -0800
Subject: [PATCH] [#3572](https://github.com/google/flax/pull/3572) was rolled
 back because of internal test breakages. This PR re-adds the attention
 changes and fixes internal tests

PiperOrigin-RevId: 597276997
---
 docs/api_reference/flax.linen/layers.rst | 11 ++-
 flax/linen/__init__.py                   |  1 +
 flax/linen/attention.py                  | 86 +++++++++++++++++++++-
 tests/linen/linen_attention_test.py      | 90 ++++++++++++++++++++++++
 4 files changed, 184 insertions(+), 4 deletions(-)

diff --git a/docs/api_reference/flax.linen/layers.rst b/docs/api_reference/flax.linen/layers.rst
index b60f880621..19cbf07fa9 100644
--- a/docs/api_reference/flax.linen/layers.rst
+++ b/docs/api_reference/flax.linen/layers.rst
@@ -84,11 +84,15 @@ Attention

 .. flax_module::
   :module: flax.linen
-  :class: SelfAttention
+  :class: MultiHeadDotProductAttention

 .. flax_module::
   :module: flax.linen
-  :class: MultiHeadDotProductAttention
+  :class: MultiHeadAttention
+
+.. flax_module::
+  :module: flax.linen
+  :class: SelfAttention

 .. autofunction:: dot_product_attention_weights
 .. autofunction:: dot_product_attention
@@ -147,8 +151,9 @@ Recurrent
   WeightNorm
   Sequential
   Dropout
-  SelfAttention
   MultiHeadDotProductAttention
+  MultiHeadAttention
+  SelfAttention
   RNNCellBase
   LSTMCell
   OptimizedLSTMCell
diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index 4b19940e79..661b5821da 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -61,6 +61,7 @@
   tanh as tanh,
 )
 from .attention import (
+  MultiHeadAttention as MultiHeadAttention,
   MultiHeadDotProductAttention as MultiHeadDotProductAttention,
   SelfAttention as SelfAttention,
   combine_masks as combine_masks,
diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index 050d3d5e4c..6d1dbb8966 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -572,6 +572,89 @@ def __call__(
     return out


+class MultiHeadAttention(MultiHeadDotProductAttention):
+  """Multi-head dot-product attention.
+  Alias for ``MultiHeadDotProductAttention``.
+
+  **NOTE**: ``MultiHeadAttention`` is a wrapper of ``MultiHeadDotProductAttention``,
+  and so their implementations are identical. However, ``MultiHeadAttention`` layers
+  will, by default, be named ``MultiHeadAttention_{index}``, whereas ``MultiHeadDotProductAttention``
+  will be named ``MultiHeadDotProductAttention_{index}``. Therefore, this could affect
+  checkpointing, param collection names and RNG threading (since the layer name is
+  used when generating new RNGs) within the module.
+
+  Example usage::
+
+    >>> import flax.linen as nn
+    >>> import jax
+
+    >>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
+    >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
+    >>> shape = (4, 3, 2, 5)
+    >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
+    >>> variables = layer.init(jax.random.key(0), q)
+
+    >>> # different inputs for inputs_q, inputs_k and inputs_v
+    >>> out = layer.apply(variables, q, k, v)
+    >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
+    >>> out = layer.apply(variables, q, k)
+    >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
+    >>> out = layer.apply(variables, q)
+
+    >>> attention_kwargs = dict(
+    ...     num_heads=8,
+    ...     qkv_features=16,
+    ...     kernel_init=nn.initializers.ones,
+    ...     bias_init=nn.initializers.zeros,
+    ...     dropout_rate=0.5,
+    ...     deterministic=False,
+    ... )
+    >>> class Module(nn.Module):
+    ...   attention_kwargs: dict
+    ...
+    ...   @nn.compact
+    ...   def __call__(self, x, dropout_rng=None):
+    ...     out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
+    ...     out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
+    ...     return out1, out2
+    >>> module = Module(attention_kwargs)
+    >>> variables = module.init({'params': key1, 'dropout': key2}, q)
+
+    >>> # out1 and out2 are different.
+    >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
+    >>> # out3 and out4 are different.
+    >>> # out1 and out3 are different. out2 and out4 are different.
+    >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
+    >>> # out1 and out2 are the same.
+    >>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
+    >>> # out1 and out2 are the same as out3 and out4.
+    >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
+    >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
+
+  Attributes:
+    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
+      should be divisible by the number of heads.
+    dtype: the dtype of the computation (default: infer from inputs and params)
+    param_dtype: the dtype passed to parameter initializers (default: float32)
+    qkv_features: dimension of the key, query, and value.
+    out_features: dimension of the last projection.
+    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
+    dropout_rate: dropout rate.
+    deterministic: if false, the attention weight is masked randomly using
+      dropout, whereas if true, the attention weights are deterministic.
+    precision: numerical precision of the computation; see ``jax.lax.Precision``
+      for details.
+    kernel_init: initializer for the kernel of the Dense layers.
+    bias_init: initializer for the bias of the Dense layers.
+    use_bias: bool: whether pointwise QKVO dense transforms use bias.
+    attention_fn: dot_product_attention or compatible function. Accepts query,
+      key, value, and returns output of shape ``[bs, dim1, dim2, ..., dimN,
+      num_heads, value_channels]``
+    decode: whether to prepare and use an autoregressive cache.
+    normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
+  """
+
+
 class SelfAttention(MultiHeadDotProductAttention):
   """Self-attention special case of multi-head dot-product attention.
   This layer is deprecated in favor of ``MultiHeadDotProductAttention``.
@@ -590,6 +673,7 @@ def __call__(  # type: ignore
     mask: Optional[Array] = None,
     deterministic: Optional[bool] = None,
     dropout_rng: Optional[PRNGKey] = None,
+    sow_weights: bool = False,
   ):
     """Applies multi-head dot product self-attention on the input data.

@@ -615,7 +699,7 @@ def __call__(  # type: ignore
       DeprecationWarning,
     )
     return super().__call__(
-      inputs_q, mask=mask, deterministic=deterministic, dropout_rng=dropout_rng
+      inputs_q, mask=mask, deterministic=deterministic, dropout_rng=dropout_rng, sow_weights=sow_weights
     )


diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py
index 5c843ced33..9b8fa2a81d 100644
--- a/tests/linen/linen_attention_test.py
+++ b/tests/linen/linen_attention_test.py
@@ -420,5 +420,95 @@ def test_autoregressive_decode_with_x64(self):
     assert y1.shape == (1, 1, 4)
     assert y2.shape == (1, 1, 4)

+  def test_attention_alias_equivalence(self):
+    key1, key2 = random.split(random.key(0), 2)
+    query = random.uniform(key1, (3, 5))
+    key_value = random.uniform(key1, (9, 5))
+    attention_kwargs = dict(
+        num_heads=8,
+        qkv_features=16,
+        kernel_init=initializers.lecun_normal(),
+        bias_init=initializers.uniform(),
+        deterministic=False,
+    )
+    module1 = nn.MultiHeadDotProductAttention(**attention_kwargs)
+    module2 = nn.MultiHeadAttention(**attention_kwargs)
+    out1, v1 = module1.init_with_output(key2, query, key_value)
+    out2, v2 = module2.init_with_output(key2, query, key_value, key_value)
+    self.assertTrue((out1 == out2).all())
+    self.assertTrue(
+        jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v1, v2))
+    )
+
+  def test_attention_alias_submodule(self):
+    key1, key2 = random.split(random.key(0), 2)
+    query = random.uniform(key1, (3, 5))
+    key_value = random.uniform(key1, (9, 5))
+    attention_kwargs = dict(
+        num_heads=8,
+        qkv_features=16,
+        kernel_init=initializers.lecun_normal(),
+        bias_init=initializers.uniform(),
+        deterministic=False,
+    )
+
+    class Foo1(nn.Module):
+      attention_kwargs: dict
+
+      @nn.compact
+      def __call__(self, query, key):
+        return nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
+            query, key
+        )
+
+    class Foo2(nn.Module):
+      attention_kwargs: dict
+
+      @nn.compact
+      def __call__(self, query, key, value):
+        return nn.MultiHeadAttention(**self.attention_kwargs)(query, key, value)
+
+    module1 = Foo1(attention_kwargs)
+    module2 = Foo2(attention_kwargs)
+    out1, v1 = module1.init_with_output(key2, query, key_value)
+    out2, v2 = module2.init_with_output(key2, query, key_value, key_value)
+
+    # test different output and variables if layer names are different
+    self.assertTrue((out1 != out2).all())
+    v2['params']['MultiHeadDotProductAttention_0'] = v2['params'][
+        'MultiHeadAttention_0'
+    ]
+    del v2['params']['MultiHeadAttention_0']
+    self.assertTrue(
+        jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x != y).all(), v1, v2))
+    )
+
+    # test same output if variables are the same
+    v2 = jax.tree_map(lambda x: x, v1)
+    v2['params']['MultiHeadAttention_0'] = v2['params'][
+        'MultiHeadDotProductAttention_0'
+    ]
+    del v2['params']['MultiHeadDotProductAttention_0']
+    out2 = module2.apply(v2, query, key_value, key_value)
+    self.assertTrue((out1 == out2).all())
+
+    # test same output and variables if names are the same
+    class Foo2(nn.Module):
+      attention_kwargs: dict
+
+      @nn.compact
+      def __call__(self, query, key, value):
+        return nn.MultiHeadAttention(
+            **self.attention_kwargs, name='MultiHeadDotProductAttention_0'
+        )(query, key, value)
+
+    module2 = Foo2(attention_kwargs)
+    out2, v2 = module2.init_with_output(key2, query, key_value, key_value)
+    self.assertTrue((out1 == out2).all())
+    self.assertTrue(
+        jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v1, v2))
+    )
+
+
 if __name__ == '__main__':
   absltest.main()
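A minimal usage sketch, not part of the patch itself, showing how the two re-added pieces fit together: the new ``nn.MultiHeadAttention`` alias and the ``sow_weights`` argument that ``SelfAttention.__call__`` now forwards to its parent. It assumes, based on the original #3572 change, that ``sow_weights=True`` sows the attention weights into the ``'intermediates'`` collection, so that collection is marked mutable when calling ``apply``:

    import jax
    import flax.linen as nn

    # Sketch only: the MultiHeadAttention alias behaves exactly like
    # MultiHeadDotProductAttention; only the default layer name differs.
    layer = nn.MultiHeadAttention(num_heads=4, qkv_features=16, deterministic=True)
    x = jax.random.uniform(jax.random.key(0), (2, 7, 16))
    variables = layer.init(jax.random.key(1), x)

    # Assumption: sow_weights=True stores the attention weights in the
    # 'intermediates' collection, which therefore has to be mutable here.
    out, state = layer.apply(
        variables, x, sow_weights=True, mutable=['intermediates']
    )
    print(jax.tree_util.tree_map(lambda a: a.shape, state['intermediates']))

Because the alias changes the default layer name (``MultiHeadAttention_0`` instead of ``MultiHeadDotProductAttention_0``), parameters saved under the old name will not line up with a module that uses the alias unless the layer is given the old name explicitly via ``name='MultiHeadDotProductAttention_0'``, which is exactly what ``test_attention_alias_submodule`` exercises.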