diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py
index 01fe6fca5..15b30405f 100644
--- a/benchmarks/benchmark.py
+++ b/benchmarks/benchmark.py
@@ -265,4 +265,4 @@ def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
     if verbose:
         print(f"{desc} max memory: {mem}GB")
     torch.cuda.empty_cache()
-    return mem
\ No newline at end of file
+    return mem
diff --git a/benchmarks/ops/benchmark_delta_rule.py b/benchmarks/ops/benchmark_delta_rule.py
index 4c749628d..8d2bf7f38 100644
--- a/benchmarks/ops/benchmark_delta_rule.py
+++ b/benchmarks/ops/benchmark_delta_rule.py
@@ -5,7 +5,7 @@
 from benchmark import benchmark_combined, benchmark_forward
 
 from fla.ops.delta_rule import (chunk_linear_attn_delta_rule,
-                                fused_recurrent_linear_attn_delta_rule)
+                                fused_recurrent_delta_rule)
 from fla.ops.retention import fused_chunk_retention
@@ -71,7 +71,7 @@ def time_fwd_bwd(func, *args, **kwargs):
             v2 = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
 
             f_b = time_fwd_bwd(
-                fused_recurrent_linear_attn_delta_rule, q, k, v, verbose=False
+                fused_recurrent_delta_rule, q, k, v, verbose=False
             )
             time_f_b[config, "delta_recurrent"] = f_b
diff --git a/fla/layers/delta_net.py b/fla/layers/delta_net.py
index c32427d87..4d0a2b371 100644
--- a/fla/layers/delta_net.py
+++ b/fla/layers/delta_net.py
@@ -14,7 +14,7 @@
 from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
 from fla.ops.delta_rule import (chunk_delta_rule, fused_chunk_delta_rule,
-                                fused_recurrent_linear_attn_delta_rule)
+                                fused_recurrent_delta_rule)
 
 if TYPE_CHECKING:
     from fla.models.utils import Cache
@@ -206,7 +206,7 @@ def forward(
             beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
         state = past_key_values[self.layer_idx][-1] if use_cache else None
         if mode == 'fused_recurrent':
-            o, recurrent_state = fused_recurrent_linear_attn_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
+            o, recurrent_state = fused_recurrent_delta_rule(q, k, v, beta, initial_state=state, output_final_state=use_cache)
         elif mode == 'fused_chunk':
             assert self.chunk_size in [16, 32, 64]
             o, recurrent_state = fused_chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
diff --git a/fla/ops/abc/chunk.py b/fla/ops/abc/chunk.py
index 599317e7d..1ebe6e08f 100644
--- a/fla/ops/abc/chunk.py
+++ b/fla/ops/abc/chunk.py
@@ -8,8 +8,7 @@
 import triton
 import triton.language as tl
 
-from fla.ops.utils import (logcumsumexp_fwd_kernel, softmax_bwd_kernel,
-                           softmax_fwd_kernel)
+from fla.ops.utils import (logcumsumexp_fwd_kernel, softmax_bwd_kernel, softmax_fwd_kernel)
 from fla.utils import contiguous
diff --git a/fla/ops/delta_rule/__init__.py b/fla/ops/delta_rule/__init__.py
index b0848b3e9..03cffaea2 100644
--- a/fla/ops/delta_rule/__init__.py
+++ b/fla/ops/delta_rule/__init__.py
@@ -1,11 +1,11 @@
 # -*- coding: utf-8 -*-
 
-from .chunk_fuse import fused_chunk_delta_rule
-from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule
 from .chunk import chunk_delta_rule
+from .chunk_fuse import fused_chunk_delta_rule
+from .recurrent_fuse import fused_recurrent_delta_rule
 
 __all__ = [
     'fused_chunk_delta_rule',
-    'fused_recurrent_linear_attn_delta_rule',
+    'fused_recurrent_delta_rule',
     'chunk_delta_rule'
 ]
diff --git a/fla/ops/delta_rule/naive.py b/fla/ops/delta_rule/naive.py
index 45ca247cb..1e4c628f0 100644
--- a/fla/ops/delta_rule/naive.py
+++ b/fla/ops/delta_rule/naive.py
@@ -10,15 +10,20 @@ def delta_rule_recurrence(q, k, v, beta):
     o = torch.zeros_like(v)
     S = torch.zeros(b, h, d_k, d_v).to(v)
     q = q * (d_k ** -0.5)
+
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+
     for i in range(l):
         _k = k[:, :, i]
         _q = q[:, :, i]
         _v = v[:, :, i].clone()
         beta_i = beta[:, :, i]
         _v = _v - (S.clone() * _k[..., None]).sum(-2)
-        _v = _v * beta_i[..., None]
+        _v = _v * beta_i
         S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
         o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+
     return o
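Note: with the `beta.ndim < v.ndim` guard above, the naive reference now accepts either a per-head scalar beta of shape [B, H, T] or a headwise beta of shape [B, H, T, V]; after the unsqueeze, the same broadcasted update covers both cases. A minimal sketch of the two call patterns (shapes are illustrative, not part of the diff):

import torch
from fla.ops.delta_rule.naive import delta_rule_recurrence

B, H, T, D = 2, 4, 32, 16
q = torch.randn(B, H, T, D)
k = torch.nn.functional.normalize(torch.randn(B, H, T, D), p=2, dim=-1)
v = torch.randn(B, H, T, D)

beta_scalar = torch.rand(B, H, T).sigmoid()       # one gate per head and step
beta_headwise = torch.rand(B, H, T, D).sigmoid()  # one gate per value channel

o_scalar = delta_rule_recurrence(q, k, v, beta_scalar)      # beta unsqueezed to [..., 1]
o_headwise = delta_rule_recurrence(q, k, v, beta_headwise)  # used as-is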
diff --git a/fla/ops/delta_rule/recurrent_fuse.py b/fla/ops/delta_rule/recurrent_fuse.py
index 6bd242649..dc1f032f0 100644
--- a/fla/ops/delta_rule/recurrent_fuse.py
+++ b/fla/ops/delta_rule/recurrent_fuse.py
@@ -15,33 +15,26 @@
 @triton.jit
 def fused_recurrent_fwd_kernel(
     # B: batch_size, H: n_heads, T: seq_len, D: d_head
-    q,  # query [B, H, L, D_head_K]
-    k,  # key [B, H, L, D_head_V]
-    v,  # value [B, H, L, D_head_V].
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
-    beta,  # beta [B, H, L]
+    beta,  # beta [B, H, L] or [B, H, L, V]
-    o,  # output [B, H, L, D_head_V]
-    initial_state,
-    final_state,  # final hidden state [B, H, D_head_K, D_head_V]
-
-
-    s_qk_h,  # stride size: L * D_head_K
-    s_qk_t,  # stride size: D_head_K
-    s_qk_d,  # stride size: 1
-
-    s_vo_h,  # stride size: L * D_head_V
-    s_vo_t,  # stride size: D_head_V
-    s_vo_d,  # stride size: 1
-
+    o,  # output [B, H, L, V]
+    h0,  # initial hidden state [B, H, K, V]
+    ht,  # final hidden state [B, H, K, V]
+    s_qk_h,  # stride size: L * K
+    s_vo_h,  # stride size: L * V
+    scale,  # K ** -0.5
     B,  # batch size
     H,  # n_heads
     T,  # seq_len
-    scale,  # D_head_K ** -0.5
+    K: tl.constexpr,  # key head dimension
+    V: tl.constexpr,  # value head dimension
     BK: tl.constexpr,  # BLOCK SIZE along the K dimension
     BV: tl.constexpr,  # BLOCK SIZE along the V dimension
-    DK: tl.constexpr,  # D_head_K
-    DV: tl.constexpr,  # D_head_V
     USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
     STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
+    IS_HEADWISE_BETA: tl.constexpr,  # whether beta is a headwise vector or a scalar
 ):
     # indices
@@ -50,47 +43,49 @@
     p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
     p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
     p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
-    p_beta = beta + i_bh * T
+    if IS_HEADWISE_BETA:
+        p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
+    else:
+        p_beta = beta + i_bh * T
     p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
 
-    mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
-    mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
+    mask_bk = (i_k * BK + tl.arange(0, BK)) < K
+    mask_bv = (i_v * BV + tl.arange(0, BV)) < V
     mask_kv = mask_bk[None, :] & mask_bv[:, None]
 
     h = tl.zeros([BV, BK], dtype=tl.float32)
 
     if USE_INITIAL_STATE:
-        p_init_s = initial_state + i_bh * DK * DV + \
-            (i_k * BK + tl.arange(0, BK)[None, :]) * \
-            DV + (i_v * BV + tl.arange(0, BV)[:, None])
-        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
+        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
+        h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
 
     for _ in range(0, T):
-        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
-        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
-        _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
-        _v_minus = tl.sum(h * _k[None, :], axis=1)
-        _v -= _v_minus
-        _beta = tl.load(p_beta).to(tl.float32)
+        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
+        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
+        _v_minus = tl.sum(h * b_k[None, :], axis=1)
+        b_v -= _v_minus
+        if IS_HEADWISE_BETA:
+            b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)
+        else:
+            b_beta = tl.load(p_beta).to(tl.float32)
         # in-place overwrite
-        tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)
-        _v *= _beta
-        h += _k[None, :] * _v[:, None]
-        _o = h * _q[None, :]
+        tl.store(p_v, b_v.to(p_v.dtype.element_ty), mask=mask_bv)
+        b_v *= b_beta
+        h += b_k[None, :] * b_v[:, None]
+        _o = h * b_q[None, :]
         _o = tl.sum(_o, axis=1)
         tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
-        p_q += DK
-        p_k += DK
-        p_o += DV
-        p_v += DV
-        p_beta += 1
+        p_q += K
+        p_k += K
+        p_o += V
+        p_v += V
+        p_beta += V if IS_HEADWISE_BETA else 1
 
     if STORE_FINAL_STATE:
-        p_final_s = final_state + i_bh * DK * DV + \
-            (i_k * BK + tl.arange(0, BK)[None, :]) * \
-            DV + (i_v * BV + tl.arange(0, BV)[:, None])
-        tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
+        p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
+        tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)
 
 
 # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
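For reference, the per-step update the forward kernel implements (a restatement of the loop above; h holds the transposed state $S_t^\top$ as a [BV, BK] tile, and the un-gated pseudo-value $v_t - S_{t-1}^\top k_t$ is written back into v in place for reuse by the backward pass):

$$u_t = \beta_t \odot \big(v_t - S_{t-1}^\top k_t\big), \qquad S_t = S_{t-1} + k_t u_t^\top, \qquad o_t = S_t^\top (\mathrm{scale} \cdot q_t)$$

Here $\beta_t$ is a scalar per head and step, or a length-V vector when IS_HEADWISE_BETA is set.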
@@ -98,82 +93,91 @@
 def fused_recurrent_bwd_kernel(
     # B: batch_size, H: n_heads, T: seq_len, D: d_head
     # NV: number of split in the V dimension. NK: number of split in the K dimension
-    q,  # query [B, H, L, D_head_K]
-    k,  # key [B, H, L, D_head_V]
-    v,  # value [B, H, L, D_head_V]
-    beta,  # beta [B, H, L]
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
+    beta,  # beta [B, H, L] or [B, H, L, V]
+
+    do,  # gradient of output [B, H, L, V]
+    dq,  # gradient of query [NV, B, H, L, K]
+    dk,  # gradient of key [NV, B, H, L, K]
+    dv,  # gradient of value [NK, B, H, L, V]
+    dbeta,  # gradient of beta [NV, B, H, L] or [NV, NK, B, H, L, V]
-    do,  # gradient of output [B, H, L, D_head_V]
-    dq,  # gradient of query [NV, B, H, L, D_head_K]
-    dk,  # gradient of key [NV, B, H, L, D_head_K]
-    dv,  # gradient of value [NK, B, H, L, D_head_V]
-    dbeta,  # gradient of beta [B, H, L]
+    # initial hidden state initialization [B, H, K, V]
+    h0,
-    # initial hidden state initialization [B, H, D_head_K, D_head_V]
-    initial_state,
+    s_qk_h,  # stride size: L * K
-    s_qk_h,  # stride size: L * D_head_K
-    s_qk_t,  # stride size: D_head_K
-    s_qk_d,  # stride size: 1
+    s_vo_h,  # stride size: L * V
-    s_vo_h,  # stride size: L * D_head_V
-    s_vo_t,  # stride size: D_head_V
-    s_vo_d,  # stride size: 1
+    NK,  # number of splits along the K dimension
+    scale,  # K ** -0.5
     B,  # batch_size
     H,  # n_heads
     T,  # seq_len
-    scale,  # D_head_K ** -0.5
+    K: tl.constexpr,  # key head dimension
+    V: tl.constexpr,  # value head dimension
     BK: tl.constexpr,  # BLOCK SIZE along the K dimension
     BV: tl.constexpr,  # BLOCK SIZE along the V dimension
-    DK: tl.constexpr,  # D_head_K
-    DV: tl.constexpr,  # D_head_V
     USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+    IS_HEADWISE_BETA: tl.constexpr,  # whether beta is a headwise vector or a scalar
 ):
     i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
-    mask_bk = i_k * BK + tl.arange(0, BK) < DK
-    mask_bv = i_v * BV + tl.arange(0, BV) < DV
-
-    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
-    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
-    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
-    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
-    p_beta = beta + i_bh * T + T - 1
-    p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1
-
-    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
-        BK + tl.arange(0, BK) + (T - 1) * DK
-    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
-        BV + tl.arange(0, BV) + (T - 1) * DV
+    mask_bk = i_k * BK + tl.arange(0, BK) < K
+    mask_bv = i_v * BV + tl.arange(0, BV) < V
+
+    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
+    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
+    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+    if IS_HEADWISE_BETA:
+        p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+    else:
+        p_beta = beta + i_bh * T + T - 1
+
+    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
+    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+    if IS_HEADWISE_BETA:
+        p_dbeta = dbeta + (i_bh + i_k * B * H + i_v * B * H * NK) * s_vo_h + tl.arange(0, BV) + (T - 1) * V
+    else:
+        p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1
 
     d_h = tl.zeros([BK, BV], dtype=tl.float32)
 
     for _ in range(T):
-        _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
-        _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
-        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
-        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
-        _beta = tl.load(p_beta).to(tl.float32)
-        d_h += _q[:, None] * _do[None, :]
-        d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)
-        d_v = tl.sum(d_h * _k[:, None], axis=0)
-
-        d_beta = tl.sum(d_v * _v)
-        d_v = d_v * _beta
+        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
+        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
+        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
+        if IS_HEADWISE_BETA:
+            b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)
+        else:
+            b_beta = tl.load(p_beta).to(tl.float32)
+        d_h += b_q[:, None] * b_do[None, :]
+        d_k = tl.sum(d_h * (b_v * b_beta)[None, :], axis=1)
+        d_v = tl.sum(d_h * b_k[:, None], axis=0)
+
+        d_beta = d_v * b_v if IS_HEADWISE_BETA else tl.sum(d_v * b_v)
+        d_v = d_v * b_beta
 
         tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
         tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
-        tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))
+        if IS_HEADWISE_BETA:
+            tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv)
+        else:
+            tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))
 
-        d_h -= _k[:, None] * d_v[None, :]
+        d_h -= b_k[:, None] * d_v[None, :]
 
-        p_do -= DV
-        p_q -= DK
-        p_k -= DK
-        p_v -= DV
-        p_dk -= DK
-        p_dv -= DV
-        p_dbeta -= 1
-        p_beta -= 1
+        p_do -= V
+        p_q -= K
+        p_k -= K
+        p_v -= V
+        p_dk -= K
+        p_dv -= V
+        p_dbeta -= V if IS_HEADWISE_BETA else 1
+        p_beta -= V if IS_HEADWISE_BETA else 1
 
     tl.debug_barrier()
@@ -182,28 +186,32 @@
     p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
     p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
     p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
-    p_beta = beta + i_bh * T
+    if IS_HEADWISE_BETA:
+        p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
+    else:
+        p_beta = beta + i_bh * T
     p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
     p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
-    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV
-    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK
+    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + V
+    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + K
 
     if USE_INITIAL_STATE:
         mask_kv = mask_bk[:, None] & mask_bv[None, :]
-        p_init_s = initial_state + i_bh * DK * DV + \
-            (i_k * BK + tl.arange(0, BK)[:, None]) * \
-            DV + (i_v * BV + tl.arange(0, BV)[None, :])
-        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
+        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
+        h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
 
     for i in range(0, T):
-        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
-        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
-        _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
-        _beta = tl.load(p_beta).to(tl.float32)
-        _v *= _beta
-
-        h += _k[:, None] * _v[None, :]
-        _d_q = h * _do[None, :]
+        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
+        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
+        if IS_HEADWISE_BETA:
+            b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)
+        else:
+            b_beta = tl.load(p_beta).to(tl.float32)
+        b_v *= b_beta
+
+        h += b_k[:, None] * b_v[None, :]
+        _d_q = h * b_do[None, :]
         d_q = tl.sum(_d_q, axis=1) * scale
         tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
@@ -213,100 +221,110 @@
         d_k -= tl.sum(d_v[None, :] * h, axis=1)
         tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
 
-        p_k += DK
-        p_do += DV
-        p_v += DV
-        p_dk += DK
-        p_dv += DV
-        p_dq += DK
-        p_beta += 1
+        p_k += K
+        p_do += V
+        p_v += V
+        p_dk += K
+        p_dv += V
+        p_dq += K
+        p_beta += V if IS_HEADWISE_BETA else 1
 
 
 class FusedRecurrentFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):
-        batch_size, n_heads, seq_len, d_head_qk = q.shape
-        d_head_v = v.shape[-1]
+    def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_state=False):
+        B, H, T, K, V = *q.shape, v.shape[-1]
 
-        scale = d_head_qk ** -0.5
-        BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8)
-        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
+        BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
+        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
         num_stages = 1
         num_warps = 1
         assert NK == 1, "NK > 1 is not supported yet"
 
-        o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
+        o = q.new_empty(NK, B, H, T, V)
         if output_final_state:
-            final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
+            final_state = q.new_empty(B, H, K, V)
         else:
             final_state = None
 
-        grid = (NV, NK, batch_size * n_heads)
+        grid = (NV, NK, B * H)
         fused_recurrent_fwd_kernel[grid](
             q, k, v, beta, o, initial_state, final_state,
-            q.stride(1), q.stride(2), q.stride(3),
-            v.stride(1), v.stride(2), v.stride(3),
-            batch_size, n_heads, seq_len, scale,
-            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
+            q.stride(1),
+            v.stride(1),
+            scale,
+            B=B, H=H, T=T, K=K, V=V,
+            BK=BK, BV=BV,
+            USE_INITIAL_STATE=initial_state is not None,
+            STORE_FINAL_STATE=final_state is not None,
+            IS_HEADWISE_BETA=beta.ndim == v.ndim,
             num_warps=num_warps,
             num_stages=num_stages,
-            USE_INITIAL_STATE=initial_state is not None,
-            STORE_FINAL_STATE=final_state is not None
        )
 
         o = o.sum(0)
         ctx.save_for_backward(q, k, v, beta, initial_state)
+        ctx.scale = scale
         return o, final_state
 
     @staticmethod
     @contiguous
-    def backward(ctx, do, d_final_state=None):
+    def backward(ctx, do, dht=None):
         q, k, v, beta, initial_state = ctx.saved_tensors
-        batch_size, n_heads, seq_len, d_head_qk = q.shape
-        d_head_v = v.shape[-1]
-        scale = d_head_qk ** -0.5
-        BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
-        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
+        B, H, T, K, V = *q.shape, v.shape[-1]
+        scale = ctx.scale
+        BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
+        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
         assert NK == 1, "NK > 1 is not supported yet"
         num_stages = 1
         num_warps = 2
 
-        dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
-        dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
-        dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
-        grid = (NV, NK, batch_size * n_heads)
-        dbeta = q.new_empty(NV, batch_size, n_heads, seq_len)
+        beta_vector = beta.ndim == v.ndim
+
+        dq = q.new_empty(NV, B, H, T, K)
+        dk = q.new_empty(NV, B, H, T, K)
+        dv = q.new_empty(NK, B, H, T, V)
+        if beta_vector:
+            dbeta = q.new_empty(NV, NK, B, H, T, V)
+        else:
+            dbeta = q.new_empty(NV, B, H, T)
+        grid = (NV, NK, B * H)
 
         fused_recurrent_bwd_kernel[grid](
             q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,
-            q.stride(1), q.stride(2), q.stride(3),
-            v.stride(1), v.stride(2), v.stride(3),
-            batch_size, n_heads, seq_len, scale,
-            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
+            q.stride(1),
+            v.stride(1),
+            NK, scale,
+            B=B, H=H, T=T, K=K, V=V,
+            BK=BK, BV=BV,
+            USE_INITIAL_STATE=initial_state is not None,
+            IS_HEADWISE_BETA=beta_vector,
             num_warps=num_warps,
-            num_stages=num_stages,
-            USE_INITIAL_STATE=initial_state is not None
+            num_stages=num_stages
        )
         dq = dq.sum(0)
         dk = dk.sum(0)
         dv = dv.sum(0)
-        dbeta = dbeta.sum(0)
+        dbeta = dbeta.sum((0, 1)) if beta_vector else dbeta.sum(0)
 
-        return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None
+        return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None, None
 
 
-def fused_recurrent_linear_attn_delta_rule(
+def fused_recurrent_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor = None,
+    scale: float = -1,
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
-    normalize: bool = False
+    normalize: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    if scale == -1:
+        scale = q.shape[-1] ** -0.5
     if initial_state is not None:
         initial_state = initial_state.detach()
     if beta is None:
         beta = torch.ones_like(q[..., 0])
-    o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state)
+    o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, scale, initial_state, output_final_state)
     return o, final_state
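A quick usage sketch of the renamed entry point (tensor shapes assumed for illustration; requires a CUDA device). Scalar and headwise beta are dispatched on beta.ndim == v.ndim, and the -1 sentinel for scale falls back to K ** -0.5:

import torch
from fla.ops.delta_rule import fused_recurrent_delta_rule

B, H, T, K, V = 2, 4, 128, 64, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.nn.functional.normalize(torch.randn(B, H, T, K, device='cuda'), p=2, dim=-1)
v = torch.randn(B, H, T, V, device='cuda')

# scalar beta: one forgetting gate per head and time step
beta = torch.rand(B, H, T, device='cuda').sigmoid()
o, _ = fused_recurrent_delta_rule(q, k, v, beta)

# headwise beta: one gate per value channel, enabled by this diff
beta_hw = torch.rand(B, H, T, V, device='cuda').sigmoid()
o_hw, ht = fused_recurrent_delta_rule(q, k, v, beta_hw, output_final_state=True)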
diff --git a/fla/ops/rwkv6/recurrent_fuse.py b/fla/ops/rwkv6/recurrent_fuse.py
index af251526f..4aa1055d9 100644
--- a/fla/ops/rwkv6/recurrent_fuse.py
+++ b/fla/ops/rwkv6/recurrent_fuse.py
@@ -347,7 +347,7 @@ def fused_recurrent_rwkv6(
     v: torch.Tensor,
     w: torch.Tensor,
     u: torch.Tensor,
-    scale: int = -1,
+    scale: float = -1,
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     causal: bool = True
diff --git a/tests/ops/test_delta.py b/tests/ops/test_delta.py
index 2fbc5e548..0f3df69c3 100644
--- a/tests/ops/test_delta.py
+++ b/tests/ops/test_delta.py
@@ -1,21 +1,56 @@
-from fla.ops.delta_rule import fused_chunk_delta_rule, chunk_delta_rule, fused_recurrent_linear_attn_delta_rule
+# -*- coding: utf-8 -*-
+
+import pytest
 import torch
-import time
-
-if __name__ == "__main__":
-    torch.set_default_dtype(torch.bfloat16)
-    seq_len = 2048
-    b = 8
-    h = 8
-    d = 256
-    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, d), p=2, dim=-1)
-    v = torch.randn(b, h, seq_len, d)
-    q = torch.randn(b, h, seq_len, d)
-    beta = torch.rand(b, h, seq_len).sigmoid()
-    require_grad = True
+
+from fla.ops.delta_rule import (chunk_delta_rule, fused_chunk_delta_rule,
+                                fused_recurrent_delta_rule)
+from fla.ops.delta_rule.naive import delta_rule_recurrence
+
+
+@pytest.mark.parametrize("B", [8])
+@pytest.mark.parametrize("H", [4])
+@pytest.mark.parametrize("T", [1024])
+@pytest.mark.parametrize("D", [128])
+@pytest.mark.parametrize("dtype", [torch.float])
+def test_fused_recurrent_naive_equivalence(B: int, H: int, T: int, D: int, dtype: torch.dtype):
+    q = torch.randn(B, H, T, D, dtype=dtype)
+    k = torch.nn.functional.normalize(torch.randn(B, H, T, D, dtype=dtype), p=2, dim=-1)
+    v = torch.randn(B, H, T, D, dtype=dtype)
+    beta = torch.rand(B, H, T, dtype=dtype).sigmoid()
     q, k, v, beta = map(lambda x: x.cuda().requires_grad_(True), (q, k, v, beta))
+    do = torch.rand_like(v)
+    o = delta_rule_recurrence(q.clone(), k.clone(), v.clone(), beta.clone())
+    o.backward(do, retain_graph=True)
+    q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad
+    q.grad = k.grad = v.grad = beta.grad = None
+
+    o2, _ = fused_recurrent_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone())
+    o2.backward(do, retain_graph=True)
+    q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
+    q.grad = k.grad = v.grad = beta.grad = None
+
+    assert o.allclose(o2, 0, 1e-3), f"Diff: {torch.abs(o - o2).max()}"
+    assert q_grad.allclose(q_grad2, 0, 1e-3), f"Diff: {torch.abs(q_grad - q_grad2).max()}"
+    assert k_grad.allclose(k_grad2, 0, 1e-3), f"Diff: {torch.abs(k_grad - k_grad2).max()}"
+    assert v_grad.allclose(v_grad2, 0, 1e-3), f"Diff: {torch.abs(v_grad - v_grad2).max()}"
+    assert beta_grad.allclose(beta_grad2, 0, 1e-3), f"Diff: {torch.abs(beta_grad - beta_grad2).max()}"
+
+
+@pytest.mark.parametrize("B", [4])
+@pytest.mark.parametrize("H", [8])
+@pytest.mark.parametrize("T", [512])
+@pytest.mark.parametrize("D", [64])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+def test_chunk_fused_equivalence(B: int, H: int, T: int, D: int, dtype: torch.dtype):
+    q = torch.randn(B, H, T, D, dtype=dtype)
+    k = torch.nn.functional.normalize(torch.randn(B, H, T, D, dtype=dtype), p=2, dim=-1)
+    v = torch.randn(B, H, T, D, dtype=dtype)
+    beta = torch.rand(B, H, T, dtype=dtype).sigmoid()
+    q, k, v, beta = map(lambda x: x.cuda().requires_grad_(True), (q, k, v, beta))
     do = torch.rand_like(v)
+    o2, _ = fused_chunk_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(), 16)
     o2.backward(do, retain_graph=True)
     q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
@@ -25,36 +60,9 @@
     o.backward(do, retain_graph=True)
     q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad
     q.grad = k.grad = v.grad = beta.grad = None
-    print((o- o2).abs().max())
-    # assert (o- o2).abs().max() < 1e-5
-    print((q_grad - q_grad2).abs().max())
-    print((k_grad - k_grad2).abs().max())
-    print((v_grad - v_grad2).abs().max())
-    print((beta_grad - beta_grad2).abs().max())
-    # breakpoint()
-
-    print("Start warmup")
-    for _ in range(30):
-        o2, _ = fused_chunk_delta_rule(q, k, v, beta, 16)
-        o2.backward(do, retain_graph=True)
-        o, _ = chunk_delta_rule(q, k, v, beta, 32)
-        o.backward(do, retain_graph=True)
-    torch.cuda.synchronize()
-    print("Warmup Done")
-
-    start = time.time()
-    for _ in range(100):
-        o2, _ = fused_chunk_delta_rule(q, k, v, beta, 16)
-        o2.backward(do, retain_graph=True)
-    torch.cuda.synchronize()
-    print(time.time() - start)
-
-    start = time.time()
-    for _ in range(100):
-        o2, _ = chunk_delta_rule(q, k, v, beta, 32)
-        o2.backward(do, retain_graph=True)
-    torch.cuda.synchronize()
-    print(time.time() - start)
-
-
+    assert o.allclose(o2, 0, 1e-2), f"Diff: {torch.abs(o - o2).max()}"
+    assert q_grad.allclose(q_grad2, 0, 1e-2), f"Diff: {torch.abs(q_grad - q_grad2).max()}"
+    assert k_grad.allclose(k_grad2, 0, 1e-2), f"Diff: {torch.abs(k_grad - k_grad2).max()}"
+    assert v_grad.allclose(v_grad2, 0, 1e-2), f"Diff: {torch.abs(v_grad - v_grad2).max()}"
+    assert beta_grad.allclose(beta_grad2, 0, 1e-2), f"Diff: {torch.abs(beta_grad - beta_grad2).max()}"
diff --git a/tests/test_padding.py b/tests/test_padding.py
index 864823238..25c54f95f 100644
--- a/tests/test_padding.py
+++ b/tests/test_padding.py
@@ -2,9 +2,7 @@
 # I made some modifications to the original code to make it work with the current version of the library.
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, PretrainedConfig
-import fla.models
-
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 model = AutoModelForCausalLM.from_pretrained('fla-hub/gla-340M-15B').to('cuda').to(torch.float32)
 tokenizer = AutoTokenizer.from_pretrained('fla-hub/gla-340M-15B')
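With the ad-hoc benchmark script converted to pytest, the equivalence checks above can be run directly via pytest tests/ops/test_delta.py (a CUDA device is required, since the inputs are moved to the GPU before the kernels are invoked).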