[Backport] Make a FlashAttention Wrapper #6827

Merged 1 commit on Mar 27, 2024
22 changes: 22 additions & 0 deletions test/test_pallas.py
@@ -202,6 +202,28 @@ def attention(q, k, v):
expected_o = attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
def test_flash_attention_wrapper(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

def attention(q, k, v):
attn_weight = q @ k.transpose(-2, -1)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

o = flash_attention(q, k, v)
expected_o = attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
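For reference, a minimal standalone call to the new wrapper (outside the test harness above) might look like the sketch below; it is not part of this PR's diff and assumes a TPU v3+ device with a torch_xla build that includes this backport.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()

# Shapes follow the test above: [batch_size, num_heads, seq_len, d_model].
q = torch.randn(3, 2, 128, 4).to(device)
k = torch.randn(3, 2, 128, 4).to(device)
v = torch.randn(3, 2, 128, 4).to(device)

# The wrapper picks the block sizes; segment_ids and causal masking are not supported yet.
o = flash_attention(q, k, v)
print(o.shape)  # torch.Size([3, 2, 128, 4])
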
57 changes: 52 additions & 5 deletions torch_xla/experimental/custom_kernel.py
@@ -72,7 +72,6 @@ def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
-  from jax.experimental import pallas as pl

def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
if dtype == torch.float32:
@@ -97,23 +96,71 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
raise ValueError(f"Unsupported dtype: {dtype}")

# TODO: Maybe we can cache the payload for the same input.
-  def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, *args):
+  def wrapped_kernel(kernel: Callable,
+                     output_shape_dtype_fn: Callable,
+                     *args,
+                     static_argnames=[],
+                     **kwargs) -> Callable:
jax_args = []
for i, arg in enumerate(args):
if torch.is_tensor(arg):
-        # ShapedArray doesn't have any storage and thus is very suitable for generating the payload.
-        jax_meta_tensor = jax.core.ShapedArray(
+        # ShapeDtypeStruct doesn't have any storage and thus is very suitable for generating the payload.
+        jax_meta_tensor = jax.ShapeDtypeStruct(
arg.shape, convert_torch_dtype_to_jax(arg.dtype))
jax_args.append(jax_meta_tensor)
else:
# TODO: We can support more types here.
assert False, f"Unsupported argument type: {type(arg)}"

-    ir = jax.jit(kernel).lower(*jax_args).compiler_ir()
+    # Ignore the kwargs during execution; most of the time they are only used by the traced code.
+    ir = jax.jit(
+        kernel, static_argnames=static_argnames).lower(*jax_args,
+                                                       **kwargs).compiler_ir()
payload = _extract_backend_config(ir)
output_shape, output_dtype = output_shape_dtype_fn(*args)
output = torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device())
torch_xla._XLAC._xla_tpu_custom_call_(output, args, payload)
return output

return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)


# This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
# where we only take q, k, v, and segment_ids as input and set causal and block_sizes for the user.
def flash_attention(
q, # [batch_size, num_heads, q_seq_len, d_model]
k, # [batch_size, num_heads, kv_seq_len, d_model]
v, # [batch_size, num_heads, kv_seq_len, d_model]
segment_ids=None, # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
):
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
import jax
import jax.numpy as jnp
import jax.experimental.pallas.ops.tpu.flash_attention as tpu_flash_attention

# TODO: Support segment_ids and causal.
flash_attention_kernel = make_kernel_from_pallas(
tpu_flash_attention.flash_attention, lambda q, k, v: (q.shape, q.dtype))

# The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
# It yields much better performance than the default block_sizes.
return flash_attention_kernel(
q,
k,
v,
static_argnames=["block_sizes"],
block_sizes=tpu_flash_attention.BlockSizes(
block_q=min(512, q.shape[2]),
block_k_major=min(512, k.shape[2]),
block_k=min(512, k.shape[2]),
block_b=min(2, q.shape[0]),
block_q_major_dkv=min(512, q.shape[2]),
block_k_major_dkv=min(512, k.shape[2]),
block_q_dkv=min(512, q.shape[2]),
block_k_dkv=min(512, k.shape[2]),
block_q_dq=min(1024, q.shape[2]),
block_k_dq=min(256, k.shape[2]),
block_k_major_dq=min(512, k.shape[2]),
))
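
To illustrate the general mechanism flash_attention builds on, here is a hypothetical sketch (not part of this PR; the kernel and variable names below are illustrative) that wraps a simple element-wise Pallas kernel with make_kernel_from_pallas. It assumes a TPU device and a JAX install with Pallas support, and follows the same pattern as the existing add-vectors test in test_pallas.py.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

import jax
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Pallas kernel body: read both input refs and write their sum to the output ref.
  o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call turns the kernel body into a JAX-callable with a declared output shape.
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)


# The second argument maps the torch inputs to the output (shape, dtype);
# for an element-wise add, the output matches the first input.
pt_add = make_kernel_from_pallas(add_vectors, lambda x, y: (x.shape, x.dtype))

x = torch.arange(8, dtype=torch.int32).to(xm.xla_device())
y = torch.arange(8, dtype=torch.int32).to(xm.xla_device())
print(pt_add(x, y).cpu())  # tensor([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=torch.int32)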