forward row masking working
drisspg committed Aug 10, 2023
1 parent d702810 commit 7908525
Showing 4 changed files with 119 additions and 48 deletions.
42 changes: 39 additions & 3 deletions test/test_flash.py
@@ -7,7 +7,7 @@
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("bias_choice", [BiasMode.rel_pos, BiasMode.none, BiasMode.alibi])
@pytest.mark.parametrize("sm_scale", [None, 1])
def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16):
def test_flash_all(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16):
torch.manual_seed(20)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
@@ -62,8 +62,8 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16):
if N_CTX > BLOCK_M and causal:
# Since the kernel will not iterate over all seq_len_kv when causal
# We will only check the minimum rectangular block
attn_bias = attn_bias[:,:,:,:BLOCK_M]
mask = mask[:,:,:,:BLOCK_M]
attn_bias = attn_bias[:, :, :, :BLOCK_M]
mask = mask[:, :, :, :BLOCK_M]
torch.testing.assert_close(attn_bias, mask, atol=4e-2, rtol=0)

# compare
@@ -80,5 +80,41 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16):
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)


def test_flash_masked_block(dtype=torch.float16):
torch.manual_seed(20)
Z, H, N_CTX, D_HEAD = (6, 8, 256, 16)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
k = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
v = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)

sm_scale = 1 / (D_HEAD**0.5)

temp_mask = torch.ones((Z, H, N_CTX, N_CTX)).tril_(-1).bool()
ref_mask = torch.zeros_like(temp_mask, dtype=torch.float32)
ref_mask.masked_fill_(temp_mask, float("-inf"))
ref_mask = ref_mask.to(q.device).to(q.dtype)
dout = torch.randn_like(q)
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False):
ref_out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, scale=sm_scale, is_causal=False, attn_mask=ref_mask
)
tri_out, mask = attention(q, k, v, False, sm_scale, BiasMode.inverse_causal, True) # type: ignore

torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)
torch.testing.assert_close(ref_mask, mask.half(), atol=4e-2, rtol=0)


if __name__ == "__main__":
pytest.main([__file__])
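
For reference, the dense mask that BiasMode.inverse_causal is meant to reproduce can be built directly in PyTorch. The sketch below is illustrative rather than part of the commit (the helper name build_inverse_causal_mask is made up here); it mirrors inverse_causal_mask_triton, which drops every score whose query index is greater than its key index, and matches how ref_mask is built in test_flash_masked_block above.

import torch

def build_inverse_causal_mask(seq_len_q, seq_len_kv):
    # -inf strictly below the diagonal (query index m > key index n),
    # the dense analogue of tl.where(m > n, -inf, score) in the kernel
    q_idx = torch.arange(seq_len_q)[:, None]
    k_idx = torch.arange(seq_len_kv)[None, :]
    mask = torch.zeros(seq_len_q, seq_len_kv)
    return mask.masked_fill_(q_idx > k_idx, float("-inf"))

# Broadcast over (batch, heads) and pass it to SDPA as the reference, e.g.
# torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=sm_scale)
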
1 change: 1 addition & 0 deletions transformer_nuggets/flash/__init__.py
@@ -1 +1,2 @@
from transformer_nuggets.flash.flash_attention import *
from transformer_nuggets.flash.masks import *
66 changes: 21 additions & 45 deletions transformer_nuggets/flash/flash_attention.py
@@ -13,61 +13,30 @@

import torch
import enum

def build_causal_mask(seq_len_q, seq_len_kv):
temp_mask = (
torch.ones((seq_len_q, seq_len_kv))
.tril_()
.bool()
)
mask = torch.zeros_like(temp_mask, dtype=torch.float32)
mask.masked_fill_(temp_mask.logical_not(), float("-inf"))
return mask

def build_alibi_mask(n_queries, n_keys, n_heads, scale=None, causal=True):
if scale is None:
assert n_heads%8 == 0
m_0 = 2.0 ** (-8.0 / n_heads)
slopes = torch.pow(m_0, torch.arange(1, 1 + n_heads))[:, None, None]
base = -1 * (torch.arange(n_queries)[:, None] - torch.arange(n_keys)[None, :])
if scale is not None:
alibi_base = base * scale
else:
alibi_base = base * slopes
alibi_base = alibi_base.expand(n_heads, n_queries, n_keys)
if causal:
causal_mask = build_causal_mask(n_queries, n_keys)
causal_mask = causal_mask.expand(n_heads, n_queries, n_keys)
full_mask = alibi_base + causal_mask
else:
full_mask = alibi_base
return full_mask


@triton.jit
def rel_attention_triton(cur, m, n, head_num, num_heads):
bias = n - m
cur = cur + bias
return cur

@triton.jit
def alibi_attention_triton(cur, m, n, head_num, num_heads):
# 0 Indexing
alibi_scale = tl.math.exp2(-((head_num + 1) * 8.0 / num_heads))
bias = n - m
cur = cur + (alibi_scale * bias)
return cur
from transformer_nuggets.flash.masks import (
alibi_attention_triton, rel_attention_triton, inverse_causal_mask_triton
)

class BiasMode(enum.Enum):
none = 0
rel_pos = 1
alibi = 2
inverse_causal = 3

@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)

@triton.jit
def masked_row(rows):
""" rows is BLOCK_M slice of the QK score
Returns:
BLOCK_M vector of boolean values indicating whether this
Query x Key position is fully masked
"""
return rows == float("-inf")

@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
@@ -151,6 +120,9 @@ def _fwd_kernel(
qk = rel_attention_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz%H, H)
elif BIAS_CHOICE == BiasMode.alibi:
qk = alibi_attention_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz%H, H)
elif BIAS_CHOICE == BiasMode.inverse_causal:
# This should only be used for debugging
qk = inverse_causal_mask_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz%H, H)
if DEBUG_MASK and BIAS_CHOICE != BiasMode.none:
mask = qk - tl.dot(q,k)
if IS_CAUSAL:
@@ -160,12 +132,16 @@ def _fwd_kernel(
if IS_CAUSAL:
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
row_max = tl.max(qk, 1)
masked_out_rows = masked_row(row_max)
m_i_new = tl.maximum(m_i, row_max)
# TODO FIX ME
# alpha = tl.math.exp2(m_i - m_i_new)
# p = tl.math.exp2(qk - m_i_new[:, None])
alpha = tl.math.exp(m_i - m_i_new)
alpha = tl.where(masked_out_rows, 0, alpha)
p = tl.math.exp(qk - m_i_new[:, None])
p = tl.where(masked_out_rows[:, None], 0, p)
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
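
The masked_row / masked_out_rows additions above are the core of this commit. With a bias such as BiasMode.inverse_causal, an early key/value tile can be fully masked for some query rows: every score in that row of the tile is -inf, so the tile's row maximum is -inf, the running maximum m_i stays at -inf, and the online-softmax rescaling produces NaNs through exp(-inf - (-inf)). Forcing alpha and p to zero for those rows makes the tile contribute nothing instead. A small PyTorch illustration of one such update step (shapes and values are arbitrary, chosen only to show the NaN and the fix):

import torch

qk = torch.full((4, 8), float("-inf"))      # tile scores: every row fully masked
m_i = torch.full((4,), float("-inf"))       # running max, still at its initial value

row_max = qk.max(dim=1).values              # -inf for fully masked rows
m_i_new = torch.maximum(m_i, row_max)       # still -inf
alpha = torch.exp(m_i - m_i_new)            # exp(-inf - (-inf)) -> NaN
p = torch.exp(qk - m_i_new[:, None])        # NaN as well

masked_out_rows = row_max == float("-inf")  # same condition as masked_row() in the kernel
alpha = torch.where(masked_out_rows, torch.zeros_like(alpha), alpha)
p = torch.where(masked_out_rows[:, None], torch.zeros_like(p), p)
# the masked rows now add nothing to the accumulator or to the softmax denominator l_i
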
58 changes: 58 additions & 0 deletions transformer_nuggets/flash/masks.py
@@ -0,0 +1,58 @@
import torch
import triton
import triton.language as tl


def build_causal_mask(seq_len_q, seq_len_kv):
temp_mask = torch.ones((seq_len_q, seq_len_kv)).tril_().bool()
mask = torch.zeros_like(temp_mask, dtype=torch.float32)
mask.masked_fill_(temp_mask.logical_not(), float("-inf"))
return mask


def build_alibi_mask(n_queries, n_keys, n_heads, scale=None, causal=True):
if scale is None:
assert n_heads % 8 == 0
m_0 = 2.0 ** (-8.0 / n_heads)
slopes = torch.pow(m_0, torch.arange(1, 1 + n_heads))[:, None, None]
base = -1 * (torch.arange(n_queries)[:, None] - torch.arange(n_keys)[None, :])
if scale is not None:
alibi_base = base * scale
else:
alibi_base = base * slopes
alibi_base = alibi_base.expand(n_heads, n_queries, n_keys)
if causal:
causal_mask = build_causal_mask(n_queries, n_keys)
causal_mask = causal_mask.expand(n_heads, n_queries, n_keys)
full_mask = alibi_base + causal_mask
else:
full_mask = alibi_base
return full_mask


@triton.jit
def rel_attention_triton(cur, m, n, head_num, num_heads):
bias = n - m
cur = cur + bias
return cur


@triton.jit
def alibi_attention_triton(cur, m, n, head_num, num_heads):
# 0 Indexing
alibi_scale = tl.math.exp2(-((head_num + 1) * 8.0 / num_heads))
bias = n - m
cur = cur + (alibi_scale * bias)
return cur


@triton.jit
def causal_mask_triton(cur, m, n, head_num, num_heads):
cur = tl.where(m >= n, cur, float("-inf"))
return cur


@triton.jit
def inverse_causal_mask_triton(cur, m, n, head_num, num_heads):
cur = tl.where(m > n, float("-inf"), cur)
return cur
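
One consistency point between the two ALiBi paths that now live in masks.py: build_alibi_mask derives its slopes as powers of m_0 = 2 ** (-8 / n_heads) for heads 1 through n_heads, while alibi_attention_triton computes exp2(-(head_num + 1) * 8 / num_heads) with 0-indexed heads. These are the same numbers, which is what allows the test to compare the kernel's debug bias against the eager mask. A quick, illustrative check (not part of the commit):

import torch

n_heads = 8
m_0 = 2.0 ** (-8.0 / n_heads)
builder_slopes = torch.pow(m_0, torch.arange(1, 1 + n_heads).float())
kernel_slopes = torch.exp2(-(torch.arange(n_heads).float() + 1) * 8.0 / n_heads)
# both evaluate to 2 ** (-8 * h / n_heads) for h = 1..n_heads
assert torch.allclose(builder_slopes, kernel_slopes)
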
