diff --git a/pyproject.toml b/pyproject.toml
index 3ac7851091..3ee045469c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -122,8 +122,6 @@ filterwarnings = [
   "ignore:`flax.traverse_util.Traversal` will be deprecated.*:DeprecationWarning",
   # Deprecated legacy checkpoint - just want to keep the tests running for a while
   "ignore:Flax Checkpointing will soon be deprecated in favor of Orbax.*:DeprecationWarning",
-  # DeprecationWarning: SelfAttention will be deprecated soon.
-  "ignore:.*SelfAttention will be deprecated soon.*:DeprecationWarning",
   # DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.
   "ignore:.*The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.*:DeprecationWarning",
   # DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed
diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py
index eececeb02d..cff474fa31 100644
--- a/tests/linen/linen_attention_test.py
+++ b/tests/linen/linen_attention_test.py
@@ -33,7 +33,7 @@ class AttentionTest(parameterized.TestCase):
   def test_multihead_self_attention(self):
     rng = random.key(0)
     x = jnp.ones((4, 6, 5))
-    sa_module = nn.SelfAttention(
+    sa_module = nn.MultiHeadDotProductAttention(
         num_heads=8,
         qkv_features=16,
         kernel_init=initializers.ones,
@@ -47,7 +47,7 @@ def test_multihead_self_attention(self):
   def test_dtype_infer(self):
     rng = random.key(0)
     x = jnp.ones((4, 6, 5), jnp.complex64)
-    sa_module = nn.SelfAttention(
+    sa_module = nn.MultiHeadDotProductAttention(
         num_heads=8,
         qkv_features=16,
         kernel_init=initializers.ones,
@@ -186,7 +186,7 @@ def test_decoding(self, spatial_shape, attn_dims):
     inputs = random.normal(
         key1, (bs,) + spatial_shape + (num_heads * num_features,)
     )
-    module = nn.SelfAttention(
+    module = nn.MultiHeadDotProductAttention(
         num_heads=num_heads,
         qkv_features=num_heads * num_features,
         precision=lax.Precision.HIGHEST,
@@ -198,7 +198,7 @@ def test_decoding(self, spatial_shape, attn_dims):
     initial_vars = decode_module.init(key2, inputs)
     state, params = pop(initial_vars, 'params')
     causal_mask = nn.attention.make_causal_mask(jnp.ones((bs,) + spatial_shape))
-    y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, y))(
+    y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, mask=y))(
         inputs, causal_mask
     )

@@ -263,28 +263,6 @@ def get_receptive_field_1d(pos):
         'autoregressive self-attention.'
     )

-  def test_multihead_self_attention_equality(self):
-    rng = random.key(0)
-    q = jnp.ones((4, 2, 3, 5))
-    module_kwargs = {
-        'num_heads': 8,
-        'qkv_features': 16,
-        'kernel_init': initializers.ones,
-        'bias_init': initializers.zeros,
-        'deterministic': False,
-    }
-    sa_module0 = nn.MultiHeadDotProductAttention(**module_kwargs)
-    sa_module1 = nn.SelfAttention(**module_kwargs)
-    y0, v0 = sa_module0.init_with_output(rng, q)
-    with self.assertWarnsRegex(
-        DeprecationWarning, 'SelfAttention will be deprecated soon.'
-    ):
-      y1, v1 = sa_module1.init_with_output(rng, q)
-    self.assertTrue((y0 == y1).all())
-    self.assertTrue(
-      jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v0, v1))
-    )
-
   def test_multihead_kv_args(self):
     key1, key2 = random.split(random.key(0), 2)
     query = random.uniform(key1, (3, 5))
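
Migration note (not part of the diff): a minimal sketch of the pattern the updated tests rely on. In recent Flax versions, self-attention is expressed with nn.MultiHeadDotProductAttention called on a single input (keys and values default to the query input), and the attention mask is passed by keyword, as in the updated test_decoding. The shapes and hyperparameter values below are illustrative, not taken from the repository.

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    x = jnp.ones((4, 6, 5))  # (batch, seq_len, features), illustrative shape

    # nn.SelfAttention(...) becomes nn.MultiHeadDotProductAttention(...);
    # with only the query input given, keys/values default to it.
    module = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
    variables = module.init(jax.random.key(0), x)

    # The mask is passed by keyword, matching the change in test_decoding.
    mask = nn.attention.make_causal_mask(jnp.ones((4, 6)))
    y = module.apply(variables, x, mask=mask)
    assert y.shape == (4, 6, 5)  # output features default to the input size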