Commit

tmp

introduce flash_attention

Add test case

Fix the test

Fix linters
alanwaketan committed Mar 27, 2024
1 parent a805505 commit ed5254e
Showing 2 changed files with 74 additions and 5 deletions.
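In short, the new wrapper is meant to be a drop-in attention call on XLA tensors. A minimal usage sketch, distilled from the test added in this commit (the [3, 2, 128, 4] shapes and the TPU v3+ requirement come from that test, not from the API itself):

import torch
import torch_xla
from torch_xla.experimental.custom_kernel import flash_attention

# Tensors are [batch_size, num_heads, seq_len, head_dim], placed on the XLA (TPU) device.
q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

o = flash_attention(q, k, v)  # dispatches to the Pallas TPU flash attention kernel via a custom call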
test/test_pallas.py (22 additions, 0 deletions)

@@ -202,6 +202,28 @@ def attention(q, k, v):
     expected_o = attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
+  def test_flash_attention_wrapper(self):
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    def attention(q, k, v):
+      attn_weight = q @ k.transpose(-2, -1)
+      attn_weight = nn.functional.softmax(attn_weight, dim=-1)
+      attn_output = attn_weight @ v
+      return attn_output
+
+    q = torch.randn(3, 2, 128, 4).to("xla")
+    k = torch.randn(3, 2, 128, 4).to("xla")
+    v = torch.randn(3, 2, 128, 4).to("xla")
+
+    o = flash_attention(q, k, v)
+    expected_o = attention(q, k, v)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
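Note that the reference attention in the test is unscaled: it computes softmax(q @ k^T) @ v with no 1/sqrt(d_model) factor, so the comparison only holds if the Pallas kernel applies no softmax scaling by default (which appears to be the assumption here). A standalone restatement of the reference computation, for clarity only and not part of the commit:

import torch

def reference_attention(q, k, v):
  # Plain attention with no 1/sqrt(d_model) scaling, mirroring the test's reference.
  attn_weight = torch.nn.functional.softmax(q @ k.transpose(-2, -1), dim=-1)
  return attn_weight @ v

q, k, v = (torch.randn(3, 2, 128, 4) for _ in range(3))
print(reference_attention(q, k, v).shape)  # torch.Size([3, 2, 128, 4])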
torch_xla/experimental/custom_kernel.py (52 additions, 5 deletions)

@@ -72,7 +72,6 @@ def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
   import jax
   import jax.numpy as jnp
   import jax._src.pallas.mosaic.pallas_call_registration
-  from jax.experimental import pallas as pl
 
   def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
     if dtype == torch.float32:
@@ -97,23 +96,71 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
       raise ValueError(f"Unsupported dtype: {dtype}")
 
   # TODO: Maybe we can cache the payload for the same input.
-  def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, *args):
+  def wrapped_kernel(kernel: Callable,
+                     output_shape_dtype_fn: Callable,
+                     *args,
+                     static_argnames=[],
+                     **kwargs) -> Callable:
     jax_args = []
     for i, arg in enumerate(args):
       if torch.is_tensor(arg):
-        # ShapedArray doesn't have any storage and thus is very suitable for generating the payload.
-        jax_meta_tensor = jax.core.ShapedArray(
+        # ShapeDtypeStruct doesn't have any storage and thus is very suitable for generating the payload.
+        jax_meta_tensor = jax.ShapeDtypeStruct(
            arg.shape, convert_torch_dtype_to_jax(arg.dtype))
         jax_args.append(jax_meta_tensor)
       else:
         # TODO: We can support more types here.
         assert False, f"Unsupported argument type: {type(arg)}"
 
-    ir = jax.jit(kernel).lower(*jax_args).compiler_ir()
+    # Here we ignore the kwargs for execution as most of the time, the kwargs is only used in traced code.
+    ir = jax.jit(
+        kernel, static_argnames=static_argnames).lower(*jax_args,
+                                                       **kwargs).compiler_ir()
     payload = _extract_backend_config(ir)
     output_shape, output_dtype = output_shape_dtype_fn(*args)
     output = torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device())
     torch_xla._XLAC._xla_tpu_custom_call_(output, args, payload)
     return output
 
   return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)
+
+
+# This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
+# where we only takes q, k, v, and segment_ids as input and set causal and block_sizes for the users.
+def flash_attention(
+    q,  # [batch_size, num_heads, q_seq_len, d_model]
+    k,  # [batch_size, num_heads, kv_seq_len, d_model]
+    v,  # [batch_size, num_heads, kv_seq_len, d_model]
+    segment_ids=None,  # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
+):
+  # Import JAX within the function such that we don't need to call the jax_import_guard()
+  # in the global scope which could cause problems for xmp.spawn.
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+  import jax.experimental.pallas.ops.tpu.flash_attention as tpu_flash_attention
+
+  # TODO: Support segment_ids and causal.
+  flash_attention_kernel = make_kernel_from_pallas(
+      tpu_flash_attention.flash_attention, lambda q, k, v: (q.shape, q.dtype))
+
+  # The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
+  # It yields much better performance than the default block_sizes.
+  return flash_attention_kernel(
+      q,
+      k,
+      v,
+      static_argnames=["block_sizes"],
+      block_sizes=tpu_flash_attention.BlockSizes(
+          block_q=min(512, q.shape[2]),
+          block_k_major=min(512, k.shape[2]),
+          block_k=min(512, k.shape[2]),
+          block_b=min(2, q.shape[0]),
+          block_q_major_dkv=min(512, q.shape[2]),
+          block_k_major_dkv=min(512, k.shape[2]),
+          block_q_dkv=min(512, q.shape[2]),
+          block_k_dkv=min(512, k.shape[2]),
+          block_q_dq=min(1024, q.shape[2]),
+          block_k_dq=min(256, k.shape[2]),
+          block_k_major_dq=min(512, k.shape[2]),
+      ))
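One detail worth calling out: block_sizes is a plain Python configuration object, not a tensor, so flash_attention forwards it via static_argnames and wrapped_kernel only uses it while tracing and lowering the kernel (the kwargs are dropped for the actual execution, as the in-line comment notes). A small standalone sketch of the static_argnames mechanism, using a toy function and an illustrative transpose_b flag rather than the real kernel:

import jax
import jax.numpy as jnp

def matmul(a, b, transpose_b):
  # transpose_b drives Python-level control flow, so it must be a static (hashable)
  # argument, much like BlockSizes configures the kernel at trace time.
  if transpose_b:
    b = b.T
  return a @ b

# Lower against abstract shapes only; the static flag is resolved during tracing.
lowered = jax.jit(matmul, static_argnames=["transpose_b"]).lower(
    jax.ShapeDtypeStruct((128, 64), jnp.float32),
    jax.ShapeDtypeStruct((128, 64), jnp.float32),
    transpose_b=True)
print(lowered.compiler_ir())  # the transpose is already folded into the lowered module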
