Add Flash Attention v2 to Ops #1970

Merged · 6 commits · Jul 23, 2023
Changes from 3 commits
149 changes: 55 additions & 94 deletions python/triton/ops/flash_attention.py
@@ -17,7 +17,7 @@
@jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
L,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
@@ -26,7 +26,7 @@ def _fwd_kernel(
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
MODE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
@@ -55,37 +55,13 @@ def _fwd_kernel(
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# causal check on every loop iteration can be expensive
# and peeling the last iteration of the loop does not work well with ptxas
# so we have a mode to do the causal check in a separate kernel entirely
if MODE == 0: # entire non-causal attention
lo, hi = 0, N_CTX
if MODE == 1: # entire causal attention
lo, hi = 0, (start_m + 1) * BLOCK_M
if MODE == 2: # off band-diagonal
lo, hi = 0, start_m * BLOCK_M
if MODE == 3: # on band-diagonal
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
m_i = tl.load(m_ptrs)
l_i = tl.load(l_ptrs)
acc += tl.load(O_block_ptr).to(tl.float32)
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
# credits to: Adam P. Goucher (https://github.com/apgoucher):
# scale sm_scale by 1/log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
@@ -94,70 +70,61 @@
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(K.dtype.element_ty)
# advance block pointers to first iteration of the loop
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
# loop over k, v and update accumulator
lo = 0
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
# -- load k, v --
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, allow_tf32=True)
if MODE == 1 or MODE == 3:
if IS_CAUSAL:
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.math.exp2(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
qk += tl.dot(q, k, allow_tf32=True)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
beta = tl.math.exp2(m_ij - m_i_new)
l_i *= alpha
l_i_new = l_i + beta * l_ij
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_block_ptr)
p = p.to(V.dtype.element_ty)
acc += tl.dot(p, v, allow_tf32=True)
# update m_i and l_i
l_i = l_i_new
p = tl.math.exp2(qk - m_i_new[:, None])
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True)
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
# write back l and m
acc = acc / l_i[:, None]
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))

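The rewritten forward loop is the FlashAttention-2 update: a single rescaling factor alpha per key/value block, normalization by l deferred to the very end, and one log-sum-exp statistic written back instead of separate l and m. A plain PyTorch sketch of the same computation for one (batch, head) slice follows; the name ref_attention_fwd and the block size are illustrative assumptions, not code from this PR.

import torch

def ref_attention_fwd(q, k, v, sm_scale, causal=False, block_n=64):
    # q, k, v: (N_CTX, D_HEAD); returns (out, L) where L = m + log2(l),
    # the single statistic the new kernel stores for the backward pass.
    n_ctx, d_head = q.shape
    qk_scale = sm_scale * 1.44269504  # log2(e), so exp2 can replace exp
    m_i = torch.full((n_ctx,), float("-inf"), device=q.device)
    l_i = torch.zeros(n_ctx, device=q.device)
    acc = torch.zeros(n_ctx, d_head, device=q.device)
    offs_m = torch.arange(n_ctx, device=q.device)
    q_scaled = q.float() * qk_scale
    for start_n in range(0, n_ctx, block_n):
        k_blk = k[start_n:start_n + block_n].float()
        v_blk = v[start_n:start_n + block_n].float()
        qk = q_scaled @ k_blk.T
        if causal:
            offs_n = torch.arange(start_n, start_n + k_blk.shape[0], device=q.device)
            qk = qk.masked_fill(offs_m[:, None] < offs_n[None, :], float("-inf"))
        # single rescaling constant alpha, as in the v2 kernel
        m_i_new = torch.maximum(m_i, qk.max(dim=1).values)
        alpha = torch.exp2(m_i - m_i_new)
        p = torch.exp2(qk - m_i_new[:, None])
        acc = acc * alpha[:, None] + p @ v_blk
        l_i = l_i * alpha + p.sum(dim=1)
        m_i = m_i_new
    out = acc / l_i[:, None]  # normalize once, at the end
    return out.to(q.dtype), m_i + torch.log2(l_i)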

@jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
Out, DO,
Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)

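Because the v2 forward already writes out a fully normalized O, the preprocess above no longer rescales dO or allocates NewDO; it only computes the row-wise correction term delta. A minimal PyTorch equivalent for one (batch, head) slice (a sketch; the name ref_bwd_preprocess is illustrative):

import torch

def ref_bwd_preprocess(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # o, do: (N_CTX, D_HEAD); delta_i = sum_d O[i, d] * dO[i, d]
    return (o.float() * do.float()).sum(dim=1)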

@@ -166,7 +133,7 @@ def _bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale,
Out, DO,
DQ, DK, DV,
L, M,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
@@ -176,14 +143,14 @@ def _bwd_kernel_one_col_block(
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
MODE: tl.constexpr,
CAUSAL: tl.constexpr,
):
if SEQUENCE_PARALLEL:
DQ += stride_dqa.to(tl.int64) * start_n
if MODE == 0:
lo = 0
else:
if CAUSAL:
lo = start_n * BLOCK_M
else:
lo = 0
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -197,7 +164,7 @@ def _bwd_kernel_one_col_block(
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
# initialize dv and dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -211,14 +178,14 @@ def _bwd_kernel_one_col_block(
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
if MODE == 1:
if CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
else:
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= qk_scale
m = tl.load(m_ptrs + offs_m_curr)
p = tl.math.exp2(qk - m[:, None])
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True)
@@ -257,7 +224,7 @@ def _bwd_kernel(
Q, K, V, sm_scale,
Out, DO,
DQ, DK, DV,
L, M,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
@@ -266,7 +233,7 @@ def _bwd_kernel(
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
MODE: tl.constexpr,
CAUSAL: tl.constexpr,
# fmt: on
):
qk_scale = sm_scale * 1.44269504
@@ -288,7 +255,7 @@ def _bwd_kernel(
_bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L, M,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
@@ -298,14 +265,14 @@ def _bwd_kernel(
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
MODE=MODE,
CAUSAL=CAUSAL,
)
else:
start_n = tl.program_id(1)
_bwd_kernel_one_col_block(
Q, K, V, sm_scale, qk_scale, Out, DO,
DQ, DK, DV,
L, M,
L,
D,
stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
@@ -315,7 +282,7 @@ def _bwd_kernel(
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
MODE=MODE,
CAUSAL=CAUSAL,
)

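For context, here is the backward math that _bwd_kernel_one_col_block and _bwd_kernel fuse, written as a plain PyTorch reference. This is a sketch under the standard softmax-attention definitions; the name ref_attention_bwd and the assumption that queries and keys share one sequence length are mine, not the PR's. With S = sm_scale * Q K^T, P = softmax(S), O = P V and delta_i = sum_j O_ij dO_ij, the gradients are dV = P^T dO, dS = P * (dO V^T - delta), dQ = sm_scale * dS K and dK = sm_scale * dS^T Q.

import torch

def ref_attention_bwd(q, k, v, do, sm_scale, causal=False):
    # assumes q and k have the same sequence length (square score matrix)
    qf, kf, vf, dof = (t.float() for t in (q, k, v, do))
    s = qf @ kf.transpose(-2, -1) * sm_scale
    if causal:
        idx = torch.arange(s.shape[-1], device=s.device)
        s = s.masked_fill(idx[None, :] > idx[:, None], float("-inf"))
    p = torch.softmax(s, dim=-1)
    o = p @ vf
    delta = (o * dof).sum(dim=-1, keepdim=True)  # same delta as _bwd_preprocess
    dv = p.transpose(-2, -1) @ dof
    ds = p * (dof @ vf.transpose(-2, -1) - delta)
    dq = sm_scale * ds @ kf
    dk = sm_scale * ds.transpose(-2, -1) @ qf
    return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype)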

Expand All @@ -335,7 +302,6 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
o = torch.empty_like(q)
grid = (cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
if causal:
modes = [1] if q.shape[2] <= 2048 else [2, 3]
@@ -344,19 +310,19 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
for mode in modes:
_fwd_kernel[grid](
q, k, v, sm_scale,
L, m,
L,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=128, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk,
MODE=mode,
IS_CAUSAL=causal,
num_warps=num_warps,
num_stages=2)

ctx.save_for_backward(q, k, v, o, L, m)
ctx.save_for_backward(q, k, v, o, L)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
@@ -367,7 +333,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
@staticmethod
def backward(ctx, do):
BLOCK = 128
q, k, v, o, l, m = ctx.saved_tensors
q, k, v, o, L = ctx.saved_tensors
sequence_parallel = ctx.sequence_parallel
seq_len_kv = k.shape[2]
do = do.contiguous()
@@ -379,22 +345,17 @@ def backward(ctx, do):
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
if ctx.causal:
mode = 1
else:
mode = 0
delta = torch.empty_like(L)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
o, do,
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
q, k, v, ctx.sm_scale,
o, do_scaled,
o, do,
dq, dk, dv,
l, m,
L,
delta,
o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
@@ -403,7 +364,7 @@ def backward(ctx, do):
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
MODE=mode,
CAUSAL=ctx.causal,
num_warps=8,
num_stages=1,
)
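
A quick end-to-end smoke test of the updated op; this is a sketch that assumes the function is exposed as triton.ops.flash_attention.attention (i.e. _attention.apply), requires a CUDA device, and picks sm_scale equal to PyTorch's default scaled_dot_product_attention scaling so the two calls are comparable.

import torch
from triton.ops.flash_attention import attention  # assumed export

def smoke_test(Z=2, H=4, N_CTX=1024, D_HEAD=64, causal=True):
    torch.manual_seed(0)
    shape = (Z, H, N_CTX, D_HEAD)
    q, k, v = (torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True)
               for _ in range(3))
    sm_scale = D_HEAD ** -0.5  # matches SDPA's default 1/sqrt(D_HEAD)
    out = attention(q, k, v, causal, sm_scale)
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal)
    torch.testing.assert_close(out, ref, atol=2e-2, rtol=0)
    out.sum().backward()  # exercises the single-L backward path as well
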
8 changes: 3 additions & 5 deletions python/tutorials/06-fused-attention.py
@@ -123,7 +123,7 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO,
NewDO, Delta,
Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -134,7 +134,6 @@ def _bwd_preprocess(
# compute
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)


@@ -277,16 +276,15 @@ def backward(ctx, do):
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(L)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do,
do_scaled, delta,
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
o, do,
dq, dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
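
The tutorial's backward, like the op's, reconstructs the attention probabilities from the saved per-row statistic rather than re-normalizing dO. A small sketch of that reconstruction, assuming the forward stored L = m + log2(l) per query row; the names recompute_p, q_blk, k_blk and the optional absolute offsets for the causal mask are illustrative.

import torch

def recompute_p(q_blk, k_blk, sm_scale, L_rows, offs_m=None, offs_n=None):
    # q_blk: (BLOCK_M, D), k_blk: (BLOCK_N, D), L_rows: (BLOCK_M,) from the forward pass
    qk = (q_blk.float() @ k_blk.float().T) * (sm_scale * 1.44269504)
    if offs_m is not None and offs_n is not None:  # causal mask on absolute positions
        qk = qk.masked_fill(offs_m[:, None] < offs_n[None, :], float("-inf"))
    # exp2(qk - L) = exp2(qk - m) / l, i.e. the already-normalized softmax probabilities
    return torch.exp2(qk - L_rows[:, None])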