[DeltaNet] change default mode to chunk; optimize WY backward pass as a matmul form

sustcsonglin committed Aug 22, 2024
1 parent 37cf2e2 commit 56e6121
Showing 3 changed files with 20 additions and 45 deletions.
4 changes: 2 additions & 2 deletions fla/layers/delta_net.py
@@ -52,8 +52,8 @@ def __init__(
         expand_k: float = 1.0,
         expand_v: float = 1.0,
         num_heads: int = 4,
-        mode: str = 'fused_chunk',
-        chunk_size: int = 16,
+        mode: str = 'chunk',
+        chunk_size: int = 64,
         use_beta: bool = True,
         use_gate: bool = False,
         use_output_norm: bool = True,
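
With the new defaults, constructing the layer without arguments selects the chunkwise kernel with 64-token chunks instead of the fused-chunk path with 16-token chunks. A minimal usage sketch (illustrative only, not part of the diff; it assumes DeltaNet is importable from fla.layers.delta_net, that the constructor accepts hidden_size and num_heads, and that a CUDA device is available for the Triton kernels):

import torch
from fla.layers.delta_net import DeltaNet

# Omitting mode/chunk_size is now equivalent to mode='chunk', chunk_size=64
# (previously 'fused_chunk' with chunk_size=16).
layer = DeltaNet(hidden_size=512, num_heads=4).to(device='cuda', dtype=torch.bfloat16)
x = torch.randn(2, 1024, 512, device='cuda', dtype=torch.bfloat16)
out = layer(x)
o = out[0] if isinstance(out, tuple) else out  # some versions return (output, ...) tuples
print(o.shape)  # expected: torch.Size([2, 1024, 512])
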
21 changes: 7 additions & 14 deletions fla/ops/delta_rule/chunk.py
@@ -10,17 +10,14 @@
                                         fwd_prepare_wy_repr, fwd_recompute_w_u)
 from fla.ops.utils import contiguous

-# from fla.ops.delta_rule.utils import bwd_prepare_wy_repr
-

 @triton.autotune(
     configs=[
         triton.Config({}, num_warps=1),
         triton.Config({}, num_warps=2),
         triton.Config({}, num_warps=4),
         triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
+        triton.Config({}, num_warps=16)
     ],
     key=["BT", "BK", "BV"],
 )
@@ -87,8 +84,7 @@ def fwd_prepare_dv(q, k, do, BT):
         triton.Config({}, num_warps=2),
         triton.Config({}, num_warps=4),
         triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
+        triton.Config({}, num_warps=16)
     ],
     key=["BT", "BK", "BV"],
 )
@@ -166,8 +162,7 @@ def chunk_delta_rule_fwd_kernel_h(
         triton.Config({}, num_warps=2),
         triton.Config({}, num_warps=4),
         triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
+        triton.Config({}, num_warps=16)
     ],
     key=["BT", "BK", "BV"],
 )
@@ -230,8 +225,7 @@ def chunk_linear_attn_fwd_kernel_o(
         triton.Config({}, num_warps=2),
         triton.Config({}, num_warps=4),
         triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
+        triton.Config({}, num_warps=16)
     ],
     key=["BT", "BK", "BV"],
 )
@@ -308,8 +302,7 @@ def chunk_delta_rule_bwd_kernel_dhu(
         triton.Config({}, num_warps=2),
         triton.Config({}, num_warps=4),
         triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
+        triton.Config({}, num_warps=16)
     ],
     key=["BT", "BK", "BV"],
 )
@@ -496,9 +489,9 @@ def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):

 class ChunkDeltaRuleFunction(torch.autograd.Function):

+    @staticmethod
     @contiguous
     @custom_fwd
-    @staticmethod
     def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
         # obtain WY representation. u is actually the new v.
         w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
@@ -517,9 +510,9 @@ def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoin
         ctx.BT = BT
         return o.to(q.dtype), final_state

+    @staticmethod
     @contiguous
     @custom_bwd
-    @staticmethod
     def backward(ctx, do, d_ht=None):
         q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
         BT = ctx.BT
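
The two hunks above only reorder @staticmethod relative to @contiguous and @custom_fwd/@custom_bwd. Since decorators apply bottom-up, the placement decides whether the other wrappers receive the raw forward/backward function or a staticmethod object (staticmethod objects are only directly callable from Python 3.10 onward). A stand-alone sketch of the application order, where trace is a hypothetical stand-in for @contiguous/@custom_fwd rather than fla code:

import functools

def trace(name):
    # stand-in for decorators such as @contiguous or @custom_fwd
    def deco(fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            print(f"entering {name}")
            return fn(*args, **kwargs)
        return wrapped
    return deco

class Demo:
    @staticmethod          # applied last: wraps the already-decorated function
    @trace("contiguous")
    @trace("custom_fwd")   # applied first: wraps the raw function
    def forward(x):
        return x + 1

Demo.forward(1)  # prints "entering contiguous", then "entering custom_fwd"
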
40 changes: 11 additions & 29 deletions fla/ops/delta_rule/wy_fast.py
@@ -18,8 +18,7 @@
         triton.Config({}, num_warps=2),
         triton.Config({}, num_warps=4),
         triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
+        triton.Config({}, num_warps=16)
     ],
     key=["BT", "BK", "BV"],
 )
@@ -93,8 +92,7 @@ def fwd_prepare_wy_repr_kernel(
         triton.Config({}, num_warps=2),
         triton.Config({}, num_warps=4),
         triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
+        triton.Config({}, num_warps=16)
     ],
     key=["BT", "BK", "BV"],
 )
@@ -150,8 +148,7 @@ def fwd_recompute_w_u_kernel(
         triton.Config({}, num_warps=2),
         triton.Config({}, num_warps=4),
         triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
+        triton.Config({}, num_warps=16)
     ],
     key=["BT", "BK", "BV"],
 )
@@ -196,38 +193,24 @@ def bwd_prepare_wy_repr_kernel(
         p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
         tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

-    tl.debug_barrier()
-    b_A2 = tl.zeros([BT, BT], dtype=tl.float32)
     for i_k in range(tl.cdiv(K, BK)):
         p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
         p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
         b_k = tl.load(p_k, boundary_check=(0, 1))
         b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
         b_dw = tl.load(p_dw, boundary_check=(0, 1))
         b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
-        b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)
         b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
         b_dk = b_dk_beta * b_beta[:, None]
         b_dbeta += tl.sum(b_dk_beta * b_k, 1)
         # store
         p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
         tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

-    b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])
-    b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)
-
-    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
-    tl.debug_barrier()
-
-    for i in range(BT-1, 0, -1):
-        mask = tl.arange(0, BT) == i
-        b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0)
-        b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
-        b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1)
-        b_dA = tl.where(mask[:, None], b_da2, b_dA)
-        b_dA += b_da[None, :] * b_a[:, None]
-
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
     b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
+    tl.debug_barrier()

     for i_k in range(tl.cdiv(K, BK)):
         p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
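
The hunk above is the "matmul form" named in the commit title: the sequential back-substitution over the chunk is replaced by two dense products, using the fact that for A = T^{-1} with T unit lower triangular, the gradient with respect to T is -(A^T dA A^T) restricted to the strictly lower triangle. A small PyTorch check of that identity (illustrative only; the chunk size, masking convention, and parametrization of T are stand-ins, not the kernel's exact tensors):

import torch

torch.manual_seed(0)
BT = 16
strict_lower = torch.tril(torch.ones(BT, BT, dtype=torch.bool), diagonal=-1)

# T is unit lower triangular; only its strictly lower entries are free parameters.
x = (0.1 * torch.randn(BT, BT)).requires_grad_()
T = torch.eye(BT) + x.masked_fill(~strict_lower, 0.0)
A = torch.linalg.inv(T)

dA = torch.randn(BT, BT)      # incoming gradient with respect to A
(A * dA).sum().backward()     # autograd reference for dL/dT on the strictly lower part
reference = x.grad

# Matmul form: dT = -(A^T @ dA @ A^T), masked to the strict lower triangle.
matmul_form = (-(A.t() @ dA @ A.t())).masked_fill(~strict_lower, 0.0).detach()
print(torch.allclose(reference, matmul_form, atol=1e-5))  # True
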
@@ -303,18 +286,18 @@ def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):

 class WYRepresentationPrepration(torch.autograd.Function):

+    @staticmethod
     @contiguous
     @custom_fwd
-    @staticmethod
-    def forward(ctx, k, v, beta, chunk_size):
+    def forward(ctx, k, v, beta, chunk_size=64):
         ctx.BT = chunk_size
         w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)
         ctx.save_for_backward(k, v, beta, A)
         return w, u

+    @staticmethod
     @contiguous
     @custom_bwd
-    @staticmethod
     def backward(ctx, dw, du):
         k, v, beta, A = ctx.saved_tensors
         BT = ctx.BT
@@ -357,7 +340,7 @@ def naive(k, v, beta, chunk_size):


 if __name__ == "__main__":
-    torch.set_default_dtype(torch.float32)
+    torch.set_default_dtype(torch.bfloat16)
     seq_len = 1024
     b = 4
     h = 4
@@ -378,8 +361,7 @@ def naive(k, v, beta, chunk_size):

     k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad
     k.grad = v.grad = beta.grad = None
-
-    o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone())
+    o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone(), 64)
     print((o1-o3).abs().max())
     print((o2-o4).abs().max())
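
The __main__ comparison now runs in bfloat16, so the printed max-abs gaps are dominated by bfloat16 rounding (roughly 8 mantissa bits, i.e. a relative noise floor around 1e-2) rather than by the kernels themselves. A hedged helper for reading those prints, illustrative only and not part of the repository:

import torch

def assert_close_bf16(a: torch.Tensor, b: torch.Tensor, rtol: float = 2e-2) -> None:
    # Compare in float32 and normalize by the reference magnitude instead of
    # eyeballing raw max-abs differences.
    gap = (a.float() - b.float()).abs().max().item()
    scale = b.float().abs().max().clamp_min(1e-6).item()
    assert gap <= rtol * scale, f"max abs diff {gap:.3e} exceeds {rtol:.0e} * {scale:.3e}"

For example, assert_close_bf16(o1, o3) replaces print((o1-o3).abs().max()).
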
