rebase stuff
drisspg committed Mar 26, 2024
1 parent 740e924 commit bec2df0
Showing 3 changed files with 67 additions and 33 deletions.
8 changes: 4 additions & 4 deletions test/test_flash.py
@@ -58,7 +58,6 @@ def test_flash_all(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=tor
# Check attn_bias equivalence
if bias_choice != BiasMode.none:
BLOCK_M = 128
BLOCK_N = 64
mask = mask.half()
if N_CTX > BLOCK_M and causal:
# Since the kernel will not iterate over all seq_len_kv when causal
@@ -107,7 +106,7 @@ def test_flash_masked_block(dtype=torch.float16):
ref_mask.masked_fill_(temp_mask, float("-inf"))
ref_mask = ref_mask.to(q.device).to(q.dtype)
dout = torch.randn_like(q)
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False):
with sdpa_kernel(SDPBackend.MATH):
ref_out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, scale=sm_scale, is_causal=False, attn_mask=ref_mask
)
@@ -116,7 +115,7 @@ def test_flash_masked_block(dtype=torch.float16):
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None

tri_out, mask = attention(q, k, v, False, sm_scale, BiasMode.inverse_causal, True) # type: ignore

tri_out.half()
@@ -128,10 +127,11 @@ def test_flash_masked_block(dtype=torch.float16):
atol = 2e-2 * 6
torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)
torch.testing.assert_close(ref_mask, mask.half(), atol=4e-2, rtol=0)
breakpoint()

torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)


if __name__ == "__main__":
pytest.main([__file__])
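
For context on the context-manager change in the test above: it drops the deprecated torch.backends.cuda.sdp_kernel flags in favor of the newer sdpa_kernel/SDPBackend API. A minimal sketch of the new-style call, assuming a recent PyTorch release where these names live in torch.nn.attention, with illustrative shapes:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical (batch, heads, seq_len, head_dim) tensors for illustration only.
q = torch.randn(1, 4, 128, 64)
k = torch.randn(1, 4, 128, 64)
v = torch.randn(1, 4, 128, 64)

# Restrict SDPA to the math backend, as the updated test does for its reference output.
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
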
49 changes: 29 additions & 20 deletions transformer_nuggets/flash/flash_attention.py
@@ -6,39 +6,34 @@
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
"""

import enum

import torch

import triton
import triton.language as tl

import torch
import enum
from transformer_nuggets.flash.masks import (
alibi_attention_triton, rel_attention_triton, inverse_causal_mask_triton
alibi_attention_triton,
BiasMode,
inverse_causal_mask_triton,
rel_attention_triton,
)

class BiasMode(enum.Enum):
none = 0
rel_pos = 1
alibi = 2
inverse_causal = 3

@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)


@triton.jit
def masked_row(rows):
""" rows is BLOCK_M slice of the QK score
"""rows is BLOCK_M slice of the QK score
Returns:
BLOCK_M vector of boolean values indicating whether this
Query x Key position is fully masked
"""
return rows == float("-inf")
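
A host-side sketch of what masked_row detects, in plain PyTorch with illustrative tile sizes (the backward kernel applies the same test to tl.max(qk, 1)): a query row is fully masked exactly when its row-wise maximum score is -inf.

import torch

qk = torch.full((4, 6), float("-inf"))  # illustrative BLOCK_M x BLOCK_N tile of scores
qk[0, 2] = 1.0                          # row 0 keeps one visible key position
row_max = qk.max(dim=-1).values         # analogue of tl.max(qk, 1)
masked_out_rows = row_max == float("-inf")
assert masked_out_rows.tolist() == [False, True, True, True]
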


@triton.jit
def _fwd_kernel(
Q,
@@ -138,12 +133,18 @@ def _fwd_kernel(
qk += tl.dot(q, k)
# ~~~~~~~~~~~~~~~~~~~ This is all mask stuff ~~~~~~~~~~~~~~~~~~~
if BIAS_CHOICE == 1:
qk = rel_attention_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz%H, H)
qk = rel_attention_triton(
qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H
)
elif BIAS_CHOICE == 2:
qk = alibi_attention_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz%H, H)
qk = alibi_attention_triton(
qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H
)
elif BIAS_CHOICE == 3:
# This should only be used for debugging
qk = inverse_causal_mask_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz%H, H)
qk = inverse_causal_mask_triton(
qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H
)
if DEBUG_MASK and BIAS_CHOICE != BiasMode.none:
mask = qk - tl.dot(q, k)
if IS_CAUSAL:
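
The DEBUG_MASK branch above recovers the additive bias by subtracting the raw scores from the biased scores. A host-side PyTorch sketch of the same bookkeeping, with made-up shapes:

import torch

q = torch.randn(4, 8)
k = torch.randn(6, 8)
bias = torch.randn(4, 6)

scores = q @ k.T              # analogue of tl.dot(q, k)
biased = scores + bias        # analogue of qk after a *_attention_triton helper
recovered = biased - scores   # what `mask = qk - tl.dot(q, k)` records
torch.testing.assert_close(recovered, bias)
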
@@ -304,16 +305,22 @@ def _bwd_kernel(
qk *= qk_scale
# ~~~~~~~~~~~~~~~~~~~ This is all mask stuff ~~~~~~~~~~~~~~~~~~~
if BIAS_CHOICE == 1:
qk = rel_attention_triton(qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz%H, H)
qk = rel_attention_triton(
qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H
)
elif BIAS_CHOICE == 2:
qk = alibi_attention_triton(qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz%H, H)
qk = alibi_attention_triton(
qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz % H, H
)
elif BIAS_CHOICE == 3:
# This should only be used for debugging
qk = inverse_causal_mask_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz%H, H)
qk = inverse_causal_mask_triton(
qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz % H, H
)
# ~~~~~~~~~~~~~~~~~~~ This is the end of mask stuff ~~~~~~~~~~~~~~~~~~~
l_i = tl.load(l_ptrs + offs_m_curr)
row_max = tl.max(qk, 1)
masked_out_rows= masked_row(row_max)
masked_out_rows = masked_row(row_max)
# TODO fix me
# p = tl.math.exp2(qk - l_i[:, None])
p = tl.math.exp(qk - l_i[:, None])
@@ -356,7 +363,9 @@ def forward(ctx, q, k, v, causal, sm_scale, bias_choice: BiasMode, debug_mask=Fa
BLOCK_M = 128
BLOCK_N = 64
grid = (triton.cdiv(seq_len_qv, BLOCK_M), batch_size * num_heads, 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty(
(q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
)

scratch_space = None
if debug_mask:
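
As a worked example of the launch-grid arithmetic in the forward() hunk above, using assumed example shapes:

import triton

batch_size, num_heads, seq_len_qv = 2, 8, 512  # illustrative values
BLOCK_M = 128
grid = (triton.cdiv(seq_len_qv, BLOCK_M), batch_size * num_heads, 1)
assert grid == (4, 16, 1)  # one program per 128-query tile for each (batch, head) pair
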
43 changes: 34 additions & 9 deletions transformer_nuggets/flash/masks.py
@@ -1,32 +1,57 @@
import enum

import torch

import triton
import triton.language as tl


class BiasMode(enum.Enum):
none = 0
rel_pos = 1
alibi = 2
inverse_causal = 3


def build_causal_mask(seq_len_q, seq_len_kv):
temp_mask = torch.ones((seq_len_q, seq_len_kv)).tril_().bool()
mask = torch.zeros_like(temp_mask, dtype=torch.float32)
mask.masked_fill_(temp_mask.logical_not(), float("-inf"))
return mask


def build_alibi_mask(n_queries, n_keys, n_heads, scale=None, causal=True):
if scale is None:
def build_rel_mask(
n_queries: int,
n_keys: int,
n_heads: int,
mode: BiasMode,
causal=True,
):
"""Builds torch equivalent mask
Args:
n_queries: Number of queries.
n_keys: Number of keys.
n_heads: Number of attention heads.
mode: Bias mode for the attention mask.
causal: Whether to include causal mask. Defaults to True.
Returns:
torch.Tensor: The attention bias mask for the requested mode.
"""
if mode == BiasMode.alibi:
assert n_heads % 8 == 0
m_0 = 2.0 ** (-8.0 / n_heads)
slopes = torch.pow(m_0, torch.arange(1, 1 + n_heads))[:, None, None]
base = -1 * (torch.arange(n_queries)[:, None] - torch.arange(n_keys)[None, :])
if scale is not None:
alibi_base = base * scale
else:
alibi_base = base * slopes
alibi_base = alibi_base.expand(n_heads, n_queries, n_keys)
mask = base
mask = mask * slopes if mode == BiasMode.alibi else mask
mask = mask.expand(n_heads, n_queries, n_keys)
if causal:
causal_mask = build_causal_mask(n_queries, n_keys)
causal_mask = causal_mask.expand(n_heads, n_queries, n_keys)
full_mask = alibi_base + causal_mask
full_mask = mask + causal_mask
else:
full_mask = alibi_base
full_mask = mask
return full_mask
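
A minimal usage sketch of the reworked builder, with assumed sizes. For n_heads = 8 the ALiBi slope base is m_0 = 2 ** (-8 / 8) = 0.5, so the per-head slopes are 0.5, 0.25, ..., 1/256.

from transformer_nuggets.flash.masks import BiasMode, build_rel_mask

n_queries, n_keys, n_heads = 16, 16, 8  # n_heads must be a multiple of 8 for the alibi branch
alibi_mask = build_rel_mask(n_queries, n_keys, n_heads, BiasMode.alibi, causal=True)
rel_mask = build_rel_mask(n_queries, n_keys, n_heads, BiasMode.rel_pos, causal=False)
assert alibi_mask.shape == (n_heads, n_queries, n_keys)
assert rel_mask.shape == (n_heads, n_queries, n_keys)
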


