Fix FA tutorial (triton-lang#485)
- Check correctness for fp8 inputs only when torch supports it
- Only run benchmark in fp16
zhanglx13 authored Jan 25, 2024
1 parent 3c6010d commit c631824
72 changes: 39 additions & 33 deletions python/tutorials/06-fused-attention.py
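The gating pattern behind the first bullet of the commit message, distilled as a minimal sketch (it reuses the names from the diff below; the standalone test name `test_dtype_gating`, the toy tensor shape, and the availability of a CUDA/ROCm device are illustrative assumptions, and `float8_e5m2fnuz` only exists in newer PyTorch builds):

import pytest
import torch

# Detect fp8 support at import time; older torch builds lack the dtype entirely.
TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')

name_to_torch_types = {'fp16': torch.float16}
if TORCH_HAS_FP8E5B16:
    name_to_torch_types['fp8'] = torch.float8_e5m2fnuz

@pytest.mark.parametrize('dtype', ['fp16', 'fp8'])
def test_dtype_gating(dtype):
    # Parametrize over fp8 unconditionally, but skip at run time when unsupported.
    if dtype == 'fp8' and not TORCH_HAS_FP8E5B16:
        pytest.skip("fp8 not supported by this torch build")
    x = torch.randn((2, 4), dtype=torch.float16, device="cuda")
    x = x.to(name_to_torch_types[dtype])  # inputs are generated in fp16, then cast
    assert x.dtype == name_to_torch_types[dtype]

This mirrors how the updated test below builds q, k, v in fp16 and casts q and k to the requested dtype, so fp8 coverage is exercised only when torch provides the type.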
@@ -17,16 +17,16 @@
import triton
import triton.language as tl

torch_dtype:tl.constexpr = torch.float16
TORCH_HAS_FP8 = False
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2')
TORCH_HAS_FP8E5FNUZ = hasattr(torch, 'float8_e5m2fnuz')
if TORCH_HAS_FP8E5:
    torch_dtype:tl.constexpr = torch.float8_e5m2
    TORCH_HAS_FP8 = True
if TORCH_HAS_FP8E5FNUZ:
    torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
    TORCH_HAS_FP8 = True
# Pick the fp8 data type

# AMD E4M3B8
# Note: When picking this f8 data type, scaling is required when using f8
# for the second gemm
#TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')

# AMD E5M2B16
TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,
@@ -555,7 +555,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
o = torch.empty_like(q, dtype=v.dtype)
if torch.version.hip is None:
    BLOCK_M = 128
    BLOCK_N = 64 if Lk <= 64 else 32
@@ -642,26 +642,33 @@ def backward(ctx, do):

attention = _attention.apply


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 1024, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          (4, 48, 1024, 128),
                          (4, 48, 2048, 128),
                          (4, 48, 4096, 128),
                          #(4, 48, 8192, 64),
                          #(4, 48, 16384, 64)
                          ])
name_to_torch_types = {
    'fp16': torch.float16,
}

if TORCH_HAS_FP8E5B16:
    name_to_torch_types['fp8'] = torch.float8_e5m2fnuz

@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype',
                         [(*shape, dtype)
                          for shape in [(4, 48, 1024, 64),
                                        (4, 48, 2048, 64),
                                        (4, 48, 4096, 64),
                                        (4, 48, 1024, 128),
                                        (4, 48, 2048, 128),
                                        (4, 48, 4096, 128)]
                          for dtype in ['fp16', 'fp8']])
@pytest.mark.parametrize('causal', [False, True])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype):
if dtype == 'fp8' and not TORCH_HAS_FP8E5B16:
    pytest.skip("fp8 not supported")
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
if TORCH_HAS_FP8:
    q = q.to(torch_dtype)
    k = k.to(torch_dtype)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()

q = q.to(name_to_torch_types[dtype])
k = k.to(name_to_torch_types[dtype])
sm_scale = 0.5
dout = torch.randn_like(q, dtype=torch.float16)
# reference implementation
@@ -674,7 +681,9 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
# triton implementation
tri_out = attention(q, k, v, causal, sm_scale)
# compare
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2)
atol = 1.4e-1 if dtype == 'fp8' else 1e-2
rtol = 1e-2 if dtype == 'fp8' else 0
torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol)


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
@@ -775,9 +784,6 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if mode == "fwd" and TORCH_HAS_FP8:
q = q.to(torch_dtype)
k = k.to(torch_dtype)
sm_scale = D_HEAD ** -0.5
fn = lambda: attention(q, k, v, causal, sm_scale)
if mode == 'bwd':
