diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py
index 74221409cc7a..2f6005f79da0 100644
--- a/python/test/regression/test_performance.py
+++ b/python/test/regression/test_performance.py
@@ -155,27 +155,27 @@ def test_elementwise(N, dtype_str):
 
 flash_attention_data = {
     "a100": {
-        (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.433,
-        (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.392,
-        (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.106,
+        (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.532,
+        (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
+        (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.150,
         (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.204,
         (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
         (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.089,
-        (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.242,
-        (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.220,
-        (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.069,
+        (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.298,
+        (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.263,
+        (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.095,
         (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.136,
         (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
         (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052,
-        (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.432,
-        (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.392,
-        (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.107,
+        (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.525,
+        (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
+        (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,
         (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.265,
         (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.257,
         (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.128,
-        (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.251,
-        (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.220,
-        (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.069,
+        (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.297,
+        (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.263,
+        (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.095,
         (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
         (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.138,
         (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.076,
diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py
index 2c498f869656..0ea329d1a7e6 100644
--- a/python/triton/ops/flash_attention.py
+++ b/python/triton/ops/flash_attention.py
@@ -17,7 +17,7 @@
 @jit
 def _fwd_kernel(
     Q, K, V, sm_scale,
-    L, M,
+    L,
     Out,
     stride_qz, stride_qh, stride_qm, stride_qk,
     stride_kz, stride_kh, stride_kn, stride_kk,
@@ -26,7 +26,7 @@ def _fwd_kernel(
     Z, H, N_CTX,
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    MODE: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
 ):
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
@@ -55,14 +55,6 @@ def _fwd_kernel(
         block_shape=(BLOCK_N, BLOCK_DMODEL),
         order=(1, 0)
     )
-    O_block_ptr = tl.make_block_ptr(
-        base=Out + qvk_offset,
-        shape=(N_CTX, BLOCK_DMODEL),
-        strides=(stride_om, stride_on),
-        offsets=(start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
-        order=(1, 0)
-    )
     # initialize offsets
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
@@ -70,22 +62,6 @@ def _fwd_kernel(
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-    # causal check on every loop iteration can be expensive
-    # and peeling the last iteration of the loop does not work well with ptxas
-    # so we have a mode to do the causal check in a separate kernel entirely
-    if MODE == 0:  # entire non-causal attention
-        lo, hi = 0, N_CTX
-    if MODE == 1:  # entire causal attention
-        lo, hi = 0, (start_m + 1) * BLOCK_M
-    if MODE == 2:  # off band-diagonal
-        lo, hi = 0, start_m * BLOCK_M
-    if MODE == 3:  # on band-diagonal
-        l_ptrs = L + off_hz * N_CTX + offs_m
-        m_ptrs = M + off_hz * N_CTX + offs_m
-        m_i = tl.load(m_ptrs)
-        l_i = tl.load(l_ptrs)
-        acc += tl.load(O_block_ptr).to(tl.float32)
-        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
     # credits to: Adam P. Goucher (https://github.com/apgoucher):
     # scale sm_scale by 1/log_2(e) and use
     # 2^x instead of exp in the loop because CSE and LICM
@@ -94,57 +70,51 @@ def _fwd_kernel(
     # load q: it will stay in SRAM throughout
     q = tl.load(Q_block_ptr)
     q = (q * qk_scale).to(K.dtype.element_ty)
-    # advance block pointers to first iteration of the loop
-    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
-    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
-    # loop over k, v and update accumulator
+    lo = 0
+    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
     for start_n in range(lo, hi, BLOCK_N):
-        start_n = tl.multiple_of(start_n, BLOCK_N)
-        # -- compute qk ----
+        # -- load k, v --
         k = tl.load(K_block_ptr)
+        v = tl.load(V_block_ptr)
+        # -- compute qk ---
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k, allow_tf32=True)
-        if MODE == 1 or MODE == 3:
+        if IS_CAUSAL:
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
-        # -- compute m_ij, p, l_ij
-        m_ij = tl.max(qk, 1)
-        p = tl.math.exp2(qk - m_ij[:, None])
-        l_ij = tl.sum(p, 1)
-        # -- update m_i and l_i
-        m_i_new = tl.maximum(m_i, m_ij)
+        qk += tl.dot(q, k, allow_tf32=True)
+        # -- compute scaling constant ---
+        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
         alpha = tl.math.exp2(m_i - m_i_new)
-        beta = tl.math.exp2(m_ij - m_i_new)
-        l_i *= alpha
-        l_i_new = l_i + beta * l_ij
-        # scale p
-        p_scale = beta / l_i_new
-        p = p * p_scale[:, None]
-        # scale acc
-        acc_scale = l_i / l_i_new
-        acc = acc * acc_scale[:, None]
-        # update acc
-        v = tl.load(V_block_ptr)
-        p = p.to(V.dtype.element_ty)
-        acc += tl.dot(p, v, allow_tf32=True)
-        # update m_i and l_i
-        l_i = l_i_new
+        p = tl.math.exp2(qk - m_i_new[:, None])
+        # -- scale and update acc --
+        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
+        acc *= acc_scale[:, None]
+        acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True)
+        # -- update m_i and l_i --
+        l_i = l_i * alpha + tl.sum(p, 1)
         m_i = m_i_new
         # update pointers
         K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
         V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
     # write back l and m
+    acc = acc / l_i[:, None]
     l_ptrs = L + off_hz * N_CTX + offs_m
-    m_ptrs = M + off_hz * N_CTX + offs_m
-    tl.store(l_ptrs, l_i)
-    tl.store(m_ptrs, m_i)
+    tl.store(l_ptrs, m_i + tl.math.log2(l_i))
     # write back O
+    O_block_ptr = tl.make_block_ptr(
+        base=Out + qvk_offset,
+        shape=(N_CTX, BLOCK_DMODEL),
+        strides=(stride_om, stride_on),
+        offsets=(start_m * BLOCK_M, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0)
+    )
     tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
 
 
 @jit
 def _bwd_preprocess(
-    Out, DO, L,
-    NewDO, Delta,
+    Out, DO,
+    Delta,
     BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
 ):
     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -152,12 +122,9 @@ def _bwd_preprocess(
     # load
     o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
     do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
-    denom = tl.load(L + off_m).to(tl.float32)
     # compute
-    do = do / denom[:, None]
     delta = tl.sum(o * do, axis=1)
     # write-back
-    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
     tl.store(Delta + off_m, delta)
 
 
@@ -166,7 +133,7 @@ def _bwd_kernel_one_col_block(
     Q, K, V, sm_scale, qk_scale,
     Out, DO,
     DQ, DK, DV,
-    L, M,
+    L,
     D,
     stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
     stride_kz, stride_kh, stride_kn, stride_kk,
@@ -176,14 +143,14 @@ def _bwd_kernel_one_col_block(
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
-    MODE: tl.constexpr,
+    CAUSAL: tl.constexpr,
 ):
     if SEQUENCE_PARALLEL:
         DQ += stride_dqa.to(tl.int64) * start_n
-    if MODE == 0:
-        lo = 0
-    else:
+    if CAUSAL:
         lo = start_n * BLOCK_M
+    else:
+        lo = 0
     # initialize row/col offsets
     offs_qm = lo + tl.arange(0, BLOCK_M)
     offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -197,7 +164,7 @@ def _bwd_kernel_one_col_block(
     dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
     # pointer to row-wise quantities in value-like data
     D_ptrs = D + off_hz * N_CTX
-    m_ptrs = M + off_hz * N_CTX
+    l_ptrs = L + off_hz * N_CTX
     # initialize dv amd dk
     dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -211,14 +178,14 @@ def _bwd_kernel_one_col_block(
         q = tl.load(q_ptrs)
         # recompute p = softmax(qk, dim=-1).T
         # NOTE: `do` is pre-divided by `l`; no normalization here
-        if MODE == 1:
+        if CAUSAL:
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
         else:
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, tl.trans(k))
         qk *= qk_scale
-        m = tl.load(m_ptrs + offs_m_curr)
-        p = tl.math.exp2(qk - m[:, None])
+        l_i = tl.load(l_ptrs + offs_m_curr)
+        p = tl.math.exp2(qk - l_i[:, None])
         # compute dv
         do = tl.load(do_ptrs)
         dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True)
@@ -257,7 +224,7 @@ def _bwd_kernel(
     Q, K, V, sm_scale,
     Out, DO,
     DQ, DK, DV,
-    L, M,
+    L,
     D,
     stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
     stride_kz, stride_kh, stride_kn, stride_kk,
@@ -266,7 +233,7 @@ def _bwd_kernel(
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
-    MODE: tl.constexpr,
+    CAUSAL: tl.constexpr,
     # fmt: on
 ):
     qk_scale = sm_scale * 1.44269504
@@ -288,7 +255,7 @@ def _bwd_kernel(
             _bwd_kernel_one_col_block(
                 Q, K, V, sm_scale, qk_scale,
                 Out, DO, DQ, DK, DV,
-                L, M,
+                L,
                 D,
                 stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
                 stride_kz, stride_kh, stride_kn, stride_kk,
@@ -298,14 +265,14 @@ def _bwd_kernel(
                 BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
                 BLOCK_N=BLOCK_N,
                 SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
-                MODE=MODE,
+                CAUSAL=CAUSAL,
             )
     else:
         start_n = tl.program_id(1)
         _bwd_kernel_one_col_block(
             Q, K, V, sm_scale, qk_scale,
             Out, DO, DQ, DK, DV,
-            L, M,
+            L,
             D,
             stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,
             stride_kz, stride_kh, stride_kn, stride_kk,
@@ -315,7 +282,7 @@ def _bwd_kernel(
             BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
             BLOCK_N=BLOCK_N,
             SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
-            MODE=MODE,
+            CAUSAL=CAUSAL,
         )
 
 
@@ -327,36 +294,31 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
         capability = torch.cuda.get_device_capability()
         if capability[0] < 8:
             raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
-        BLOCK = 128
+        BLOCK_M = 128
+        BLOCK_N = 64
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
         assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
-        grid = (cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
+        grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
-        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8
-        if causal:
-            modes = [1] if q.shape[2] <= 2048 else [2, 3]
-        else:
-            modes = [0]
-        for mode in modes:
-            _fwd_kernel[grid](
-                q, k, v, sm_scale,
-                L, m,
-                o,
-                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-                o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-                q.shape[0], q.shape[1], q.shape[2],
-                BLOCK_M=128, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk,
-                MODE=mode,
-                num_warps=num_warps,
-                num_stages=2)
+        _fwd_kernel[grid](
+            q, k, v, sm_scale,
+            L,
+            o,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
+            IS_CAUSAL=causal,
+            num_warps=num_warps,
+            num_stages=4)
 
-        ctx.save_for_backward(q, k, v, o, L, m)
+        ctx.save_for_backward(q, k, v, o, L)
         ctx.grid = grid
         ctx.sm_scale = sm_scale
         ctx.BLOCK_DMODEL = Lk
@@ -367,7 +329,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
     @staticmethod
     def backward(ctx, do):
         BLOCK = 128
-        q, k, v, o, l, m = ctx.saved_tensors
+        q, k, v, o, L = ctx.saved_tensors
         sequence_parallel = ctx.sequence_parallel
         seq_len_kv = k.shape[2]
         do = do.contiguous()
@@ -379,22 +341,17 @@ def backward(ctx, do):
         dq = torch.zeros_like(q, dtype=torch.float32)
         dk = torch.empty_like(k)
         dv = torch.empty_like(v)
-        do_scaled = torch.empty_like(do)
-        delta = torch.empty_like(l)
-        if ctx.causal:
-            mode = 1
-        else:
-            mode = 0
+        delta = torch.empty_like(L)
         _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
-            o, do, l,
-            do_scaled, delta,
+            o, do,
+            delta,
             BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
         _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
             q, k, v, ctx.sm_scale,
-            o, do_scaled,
+            o, do,
             dq, dk, dv,
-            l, m,
+            L,
             delta,
             o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),
             k.stride(0), k.stride(1), k.stride(2), k.stride(3),
@@ -403,7 +360,7 @@ def backward(ctx, do):
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
             BLOCK_DMODEL=ctx.BLOCK_DMODEL,
             SEQUENCE_PARALLEL=sequence_parallel,
-            MODE=mode,
+            CAUSAL=ctx.causal,
             num_warps=8,
             num_stages=1,
         )
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 7098bf3c679c..7abae03a2641 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -123,7 +123,7 @@ def _fwd_kernel(
 @triton.jit
 def _bwd_preprocess(
     Out, DO,
-    NewDO, Delta,
+    Delta,
     BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
 ):
     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -134,7 +134,6 @@ def _bwd_preprocess(
     # compute
     delta = tl.sum(o * do, axis=1)
     # write-back
-    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
     tl.store(Delta + off_m, delta)
 
 
@@ -277,16 +276,15 @@ def backward(ctx, do):
         dq = torch.zeros_like(q, dtype=torch.float32)
         dk = torch.empty_like(k)
         dv = torch.empty_like(v)
-        do_scaled = torch.empty_like(do)
         delta = torch.empty_like(L)
         _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
             o, do,
-            do_scaled, delta,
+            delta,
             BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
         _bwd_kernel[(ctx.grid[1],)](
             q, k, v, ctx.sm_scale,
-            o, do_scaled,
+            o, do,
             dq, dk, dv,
             L, delta,
             q.stride(0), q.stride(1), q.stride(2), q.stride(3),
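
For reference, the math this patch leans on: the forward kernel now normalizes acc by l_i before the store and keeps a single per-row statistic L = m_i + log2(l_i) (a base-2 log-sum-exp), so the backward can recompute the probabilities directly as exp2(qk * qk_scale - L) and the NewDO/do_scaled rescaling pass disappears. Below is a minimal sketch of both identities in plain PyTorch; it is not the Triton kernels, and the helper name streaming_attention_check is illustrative only.

import torch

def streaming_attention_check(q, k, v, sm_scale, block=16):
    # Illustrative reference (not the Triton kernels): checks the two identities
    # the patch relies on, using the same base-2 trick as the kernels.
    qk_scale = sm_scale * 1.44269504                 # sm_scale / ln(2)
    s = (q @ k.T) * qk_scale                         # base-2 logits
    m = torch.full((q.shape[0],), float("-inf"))     # running max, as m_i
    l = torch.zeros(q.shape[0])                      # running sum, as l_i
    acc = torch.zeros(q.shape[0], v.shape[1])
    for start in range(0, k.shape[0], block):        # online-softmax loop over K/V blocks
        s_blk = s[:, start:start + block]
        m_new = torch.maximum(m, s_blk.max(dim=1).values)
        alpha = torch.exp2(m - m_new)                # rescale previous partial results
        p = torch.exp2(s_blk - m_new[:, None])
        acc = acc * alpha[:, None] + p @ v[start:start + block]
        l = l * alpha + p.sum(dim=1)
        m = m_new
    out = acc / l[:, None]                           # normalization now done in-kernel
    L = m + torch.log2(l)                            # the single statistic the forward stores
    p_ref = torch.softmax((q @ k.T) * sm_scale, dim=1)
    out_ref = p_ref @ v
    p_from_L = torch.exp2(s - L[:, None])            # what the backward recomputes
    return torch.allclose(out, out_ref, atol=1e-5) and torch.allclose(p_from_L, p_ref, atol=1e-5)

assert streaming_attention_check(torch.randn(32, 64), torch.randn(48, 64), torch.randn(48, 64), 0.125)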