From ed5254e386cbe91666bf60bde025d8206480accb Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Tue, 19 Mar 2024 01:28:02 +0000
Subject: [PATCH] Introduce flash_attention

Also add a test case and fix the linters.
---
 test/test_pallas.py                     | 22 ++++++++++
 torch_xla/experimental/custom_kernel.py | 57 ++++++++++++++++++++++---
 2 files changed, 74 insertions(+), 5 deletions(-)

diff --git a/test/test_pallas.py b/test/test_pallas.py
index 95ead568cfc..b6a6c613fe6 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -202,6 +202,28 @@ def attention(q, k, v):
     expected_o = attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
+  def test_flash_attention_wrapper(self):
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    def attention(q, k, v):
+      attn_weight = q @ k.transpose(-2, -1)
+      attn_weight = nn.functional.softmax(attn_weight, dim=-1)
+      attn_output = attn_weight @ v
+      return attn_output
+
+    q = torch.randn(3, 2, 128, 4).to("xla")
+    k = torch.randn(3, 2, 128, 4).to("xla")
+    v = torch.randn(3, 2, 128, 4).to("xla")
+
+    o = flash_attention(q, k, v)
+    expected_o = attention(q, k, v)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index b6b7a304b5b..135b9815b95 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -72,7 +72,6 @@ def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
   import jax
   import jax.numpy as jnp
   import jax._src.pallas.mosaic.pallas_call_registration
-  from jax.experimental import pallas as pl
 
   def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
     if dtype == torch.float32:
@@ -97,19 +96,26 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
       raise ValueError(f"Unsupported dtype: {dtype}")
 
   # TODO: Maybe we can cache the payload for the same input.
-  def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, *args):
+  def wrapped_kernel(kernel: Callable,
+                     output_shape_dtype_fn: Callable,
+                     *args,
+                     static_argnames=[],
+                     **kwargs) -> torch.Tensor:
     jax_args = []
     for i, arg in enumerate(args):
       if torch.is_tensor(arg):
-        # ShapedArray doesn't have any storage and thus is very suitable for generating the payload.
-        jax_meta_tensor = jax.core.ShapedArray(
+        # ShapeDtypeStruct doesn't have any storage and thus is very suitable for generating the payload.
+        jax_meta_tensor = jax.ShapeDtypeStruct(
             arg.shape, convert_torch_dtype_to_jax(arg.dtype))
         jax_args.append(jax_meta_tensor)
       else:
         # TODO: We can support more types here.
         assert False, f"Unsupported argument type: {type(arg)}"
 
-    ir = jax.jit(kernel).lower(*jax_args).compiler_ir()
+    # Here we ignore the kwargs for execution as, most of the time, kwargs are only used in traced code.
+    ir = jax.jit(
+        kernel, static_argnames=static_argnames).lower(*jax_args,
+                                                       **kwargs).compiler_ir()
     payload = _extract_backend_config(ir)
     output_shape, output_dtype = output_shape_dtype_fn(*args)
     output = torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device())
@@ -117,3 +123,44 @@ def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, *args):
     return output
 
   return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)
+
+
+# This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
+# where we only take q, k, v, and segment_ids as input and set causal and block_sizes for the users.
+def flash_attention(
+    q,  # [batch_size, num_heads, q_seq_len, d_model]
+    k,  # [batch_size, num_heads, kv_seq_len, d_model]
+    v,  # [batch_size, num_heads, kv_seq_len, d_model]
+    segment_ids=None,  # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
+):
+  # Import JAX within the function such that we don't need to call the jax_import_guard()
+  # in the global scope, which could cause problems for xmp.spawn.
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+  import jax.experimental.pallas.ops.tpu.flash_attention as tpu_flash_attention
+
+  # TODO: Support segment_ids and causal.
+  flash_attention_kernel = make_kernel_from_pallas(
+      tpu_flash_attention.flash_attention, lambda q, k, v: (q.shape, q.dtype))
+
+  # The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
+  # It yields much better performance than the default block_sizes.
+  return flash_attention_kernel(
+      q,
+      k,
+      v,
+      static_argnames=["block_sizes"],
+      block_sizes=tpu_flash_attention.BlockSizes(
+          block_q=min(512, q.shape[2]),
+          block_k_major=min(512, k.shape[2]),
+          block_k=min(512, k.shape[2]),
+          block_b=min(2, q.shape[0]),
+          block_q_major_dkv=min(512, q.shape[2]),
+          block_k_major_dkv=min(512, k.shape[2]),
+          block_q_dkv=min(512, q.shape[2]),
+          block_k_dkv=min(512, k.shape[2]),
+          block_q_dq=min(1024, q.shape[2]),
+          block_k_dq=min(256, k.shape[2]),
+          block_k_major_dq=min(512, k.shape[2]),
+      ))
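
Usage sketch (illustrative, not part of the patch): how the new flash_attention wrapper could be called from user code, mirroring the tensor shapes used in the test above. It assumes a TPU v3+ backend with torch_xla installed, matching the test's skip condition.

    import torch
    import torch_xla  # registers the "xla" device type
    from torch_xla.experimental.custom_kernel import flash_attention

    # Inputs follow [batch_size, num_heads, seq_len, d_model], as in the test.
    q = torch.randn(3, 2, 128, 4).to("xla")
    k = torch.randn(3, 2, 128, 4).to("xla")
    v = torch.randn(3, 2, 128, 4).to("xla")

    # Dispatches to the Pallas TPU flash attention kernel with the preset BlockSizes.
    o = flash_attention(q, k, v)
    assert o.shape == q.shape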