diff --git a/test/test_pallas.py b/test/test_pallas.py
index b6a6c613fe6..f2db81c7d65 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -224,6 +224,30 @@ def attention(q, k, v):
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
+  def test_flash_attention_wrapper_causal(self):
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    def attention(q, k, v):
+      attn_weight = q @ k.transpose(-2, -1)
+      attn_weight = nn.functional.softmax(attn_weight, dim=-1)
+      attn_output = attn_weight @ v
+      return attn_output
+
+    q = torch.randn(3, 2, 128, 4).to("xla")
+    k = torch.randn(3, 2, 128, 4).to("xla")
+    v = torch.randn(3, 2, 128, 4).to("xla")
+
+    # The causal mask is off by default in the wrapper, so we enable it explicitly here.
+    # It masks out the upper right triangle of the attention matrix, which speeds up the compute but also changes the output.
+    o = flash_attention(q, k, v, causal=True)
+    expected_o = attention(q, k, v)
+    self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 135b9815b95..c8086374963 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -126,12 +126,13 @@ def wrapped_kernel(kernel: Callable,
 
 
 # This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
-# where we only takes q, k, v, and segment_ids as input and set causal and block_sizes for the users.
+# where we only take q, k, v, segment_ids, and causal as input and set block_sizes for the users.
 def flash_attention(
     q,  # [batch_size, num_heads, q_seq_len, d_model]
     k,  # [batch_size, num_heads, kv_seq_len, d_model]
     v,  # [batch_size, num_heads, kv_seq_len, d_model]
     segment_ids=None,  # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
+    causal=False,
 ):
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
@@ -140,7 +141,7 @@ def flash_attention(
   import jax.numpy as jnp
   import jax.experimental.pallas.ops.tpu.flash_attention as tpu_flash_attention
 
-  # TODO: Support segment_ids and causal.
+  # TODO: Support segment_ids.
   flash_attention_kernel = make_kernel_from_pallas(
       tpu_flash_attention.flash_attention, lambda q, k, v: (q.shape, q.dtype))
 
@@ -150,7 +151,7 @@ def flash_attention(
       q,
       k,
       v,
-      static_argnames=["block_sizes"],
+      static_argnames=["block_sizes", "causal"],
       block_sizes=tpu_flash_attention.BlockSizes(
           block_q=min(512, q.shape[2]),
           block_k_major=min(512, k.shape[2]),
@@ -163,4 +164,5 @@ def flash_attention(
           block_q_dq=min(1024, q.shape[2]),
           block_k_dq=min(256, k.shape[2]),
           block_k_major_dq=min(512, k.shape[2]),
-      ))
+      ),
+      causal=causal)
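
For reference, a minimal usage sketch of the new causal flag. This is not part of the patch; it assumes a TPU device with torch_xla and JAX/Pallas installed, and tensor shapes follow the wrapper's [batch_size, num_heads, seq_len, d_model] convention.

import torch
import torch_xla  # registers the "xla" device
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# Default behavior is unchanged: no causal masking.
o_full = flash_attention(q, k, v)

# With causal=True the kernel skips the upper right triangle of the attention
# matrix, so the output differs from full attention but is cheaper to compute.
o_causal = flash_attention(q, k, v, causal=True)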