
Commit 7508594
remove SelfAttention test and warning filter
chiamp committed Nov 7, 2023
1 parent cbf7bea commit 7508594
Showing 2 changed files with 4 additions and 28 deletions.
pyproject.toml (2 changes: 0 additions & 2 deletions)
@@ -122,8 +122,6 @@ filterwarnings = [
"ignore:`flax.traverse_util.Traversal` will be deprecated.*:DeprecationWarning",
# Deprecated legacy checkpoint - just want to keep the tests running for a while
"ignore:Flax Checkpointing will soon be deprecated in favor of Orbax.*:DeprecationWarning",
# DeprecationWarning: SelfAttention will be deprecated soon.
"ignore:.*SelfAttention will be deprecated soon.*:DeprecationWarning",
# DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.
"ignore:.*The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.*:DeprecationWarning",
# DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed
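For context, the dropped filter existed only because a test still exercised nn.SelfAttention, whose __call__ emits the matched DeprecationWarning; with the tests below migrated to nn.MultiHeadDotProductAttention, nothing in the suite triggers it anymore. A minimal sketch of how that warning surfaces as of this commit (shapes and hyperparameters are arbitrary, not taken from the suite):

import warnings

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 6, 5))  # (batch, length, features)
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter('always')
  # Initializing the deprecated alias runs its __call__, which emits the warning.
  module = nn.SelfAttention(num_heads=8, qkv_features=16)
  _ = module.init(jax.random.key(0), x)
print([str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)])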
tests/linen/linen_attention_test.py (30 changes: 4 additions & 26 deletions)
@@ -33,7 +33,7 @@ class AttentionTest(parameterized.TestCase):
   def test_multihead_self_attention(self):
     rng = random.key(0)
     x = jnp.ones((4, 6, 5))
-    sa_module = nn.SelfAttention(
+    sa_module = nn.MultiHeadDotProductAttention(
         num_heads=8,
         qkv_features=16,
         kernel_init=initializers.ones,
@@ -47,7 +47,7 @@ def test_multihead_self_attention(self):
   def test_dtype_infer(self):
     rng = random.key(0)
     x = jnp.ones((4, 6, 5), jnp.complex64)
-    sa_module = nn.SelfAttention(
+    sa_module = nn.MultiHeadDotProductAttention(
         num_heads=8,
         qkv_features=16,
         kernel_init=initializers.ones,
@@ -186,7 +186,7 @@ def test_decoding(self, spatial_shape, attn_dims):
     inputs = random.normal(
         key1, (bs,) + spatial_shape + (num_heads * num_features,)
     )
-    module = nn.SelfAttention(
+    module = nn.MultiHeadDotProductAttention(
         num_heads=num_heads,
         qkv_features=num_heads * num_features,
         precision=lax.Precision.HIGHEST,
@@ -198,7 +198,7 @@ def test_decoding(self, spatial_shape, attn_dims):
     initial_vars = decode_module.init(key2, inputs)
     state, params = pop(initial_vars, 'params')
     causal_mask = nn.attention.make_causal_mask(jnp.ones((bs,) + spatial_shape))
-    y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, y))(
+    y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, mask=y))(
         inputs, causal_mask
     )

@@ -263,28 +263,6 @@ def get_receptive_field_1d(pos):
         'autoregressive self-attention.'
     )
 
-  def test_multihead_self_attention_equality(self):
-    rng = random.key(0)
-    q = jnp.ones((4, 2, 3, 5))
-    module_kwargs = {
-        'num_heads': 8,
-        'qkv_features': 16,
-        'kernel_init': initializers.ones,
-        'bias_init': initializers.zeros,
-        'deterministic': False,
-    }
-    sa_module0 = nn.MultiHeadDotProductAttention(**module_kwargs)
-    sa_module1 = nn.SelfAttention(**module_kwargs)
-    y0, v0 = sa_module0.init_with_output(rng, q)
-    with self.assertWarnsRegex(
-        DeprecationWarning, 'SelfAttention will be deprecated soon.'
-    ):
-      y1, v1 = sa_module1.init_with_output(rng, q)
-    self.assertTrue((y0 == y1).all())
-    self.assertTrue(
-        jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v0, v1))
-    )
-
   def test_multihead_kv_args(self):
     key1, key2 = random.split(random.key(0), 2)
     query = random.uniform(key1, (3, 5))
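Taken together, the test edits are a mechanical migration rather than a behavior change: calling nn.MultiHeadDotProductAttention with a single positional input performs self-attention (the input supplies queries, keys, and values), and the attention mask is now passed through the mask keyword instead of positionally. A minimal standalone sketch of the new usage, with arbitrary shapes and hyperparameters:

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 6, 5))  # (batch, length, features)
mha = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)

# A single positional input means self-attention: queries, keys, and values
# are all projections of x.
variables = mha.init(jax.random.key(0), x)

# Masks (e.g. a causal mask for decoding) go through the `mask` keyword.
causal_mask = nn.attention.make_causal_mask(jnp.ones(x.shape[:-1]))
y = mha.apply(variables, x, mask=causal_mask)
print(y.shape)  # the output keeps the query's feature dimension by default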
