[RWKV6] Update naive impls
yzhangcs committed Aug 31, 2024
1 parent 739ef15 commit 9ec376b
Showing 2 changed files with 19 additions and 81 deletions.
68 changes: 6 additions & 62 deletions fla/ops/rwkv6/chunk_naive.py
@@ -3,19 +3,14 @@
import torch
from einops import rearrange

from fla.ops.rwkv6.chunk import chunk_rwkv6
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6


def naive_chunk_rwkv6(
q,
k,
v,
w,
u,
chunk_size=32,
initial_state=None,
output_final_state=True,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
chunk_size: int = 32
):
assert q.shape[-2] % chunk_size == 0
orig_dtype = q.dtype
@@ -46,54 +41,3 @@ def naive_chunk_rwkv6(
o_intra[:, :, :, i] = intra_inter_o + intra_intra_o
o = o_inter + o_intra
return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype)


if __name__ == "__main__":
B = 4
H = 4
L = 4096
D = 100
dtype = torch.bfloat16
require_grad = True
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad)
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad)
v = torch.randn(B, H, L, 2*D).cuda().to(dtype).requires_grad_(require_grad)
w = torch.nn.functional.logsigmoid(torch.randn(B, H, L, D)).cuda().to(dtype).requires_grad_(require_grad)
u = (torch.randn(H, D).cuda().to(dtype)).requires_grad_(require_grad)
h = (torch.randn(B, H, D, 2*D).cuda().to(dtype)).requires_grad_(require_grad)
do = torch.rand_like(v).cuda()
o2, _ = chunk_rwkv6(q, k, v, w.clone(), u)
o, _ = fused_recurrent_rwkv6(q, k, v, w, u, scale=1.0)
o.backward(do)
dq, q.grad = q.grad.clone(), None
dk, k.grad = k.grad.clone(), None
dv, v.grad = v.grad.clone(), None
dw, w.grad = w.grad.clone(), None
du, u.grad = u.grad.clone(), None
dh, h.grad = h.grad.clone(), None

o2.backward(do)

def rmsre(pred, target, eps=1e-8):
return torch.sqrt(torch.mean(torch.square((pred - target) / (target.abs() + eps))))

def print_diff(name, grad1, grad2):
abs_diff = (grad1 - grad2).abs()
max_diff = abs_diff.max().item()
rmsre_value = rmsre(grad1, grad2).item()
print(f"{name}: Max Abs Diff = {max_diff:.6f}, RMSRE = {rmsre_value:.6f}")

print(f"o: {(o - o2).abs().max().item():.6f}")
print_diff("q", q.grad, dq)
print_diff("k", k.grad, dk)
print_diff("v", v.grad, dv)
print_diff("w", w.grad, dw)
print_diff("u", u.grad, du)
print_diff("h", h.grad, dh)

all_grads1 = torch.cat([q.grad.flatten(), k.grad.flatten(), v.grad.flatten(),
w.grad.flatten(), u.grad.flatten(), h.grad.flatten()])
all_grads2 = torch.cat([dq.flatten(), dk.flatten(), dv.flatten(),
dw.flatten(), du.flatten(), dh.flatten()])
overall_rmsre = rmsre(all_grads1, all_grads2).item()
print(f"\nOverall RMSRE: {overall_rmsre:.6f}")
32 changes: 13 additions & 19 deletions fla/ops/utils.py
@@ -137,7 +137,6 @@ def softmax_bwd_kernel(
tl.store(p_ds, b_ds.to(p_ds.dtype.element_ty), boundary_check=(0, 1))



@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
@@ -222,7 +221,6 @@ def chunk_global_cumsum_vector_kernel(
b_z += tl.sum(b_s, 0)



@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
@@ -251,6 +249,7 @@ def chunk_global_reversed_cumsum_scalar_kernel(
b_o = b_s - tl.cumsum(b_s, axis=0) + b_z[None]
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))


@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
@@ -280,7 +279,6 @@ def chunk_global_cumsum_scalar_kernel(
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))



@triton.autotune(
configs=[
triton.Config({'BS': 16}, num_warps=2),
@@ -317,6 +315,7 @@ def chunk_local_cumsum_vector_kernel(
b_o = tl.dot(m_s, b_s, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
@@ -342,7 +341,6 @@ def chunk_local_cumsum_scalar_kernel(
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))



def chunk_local_cumsum_vector(g, BT):
B, H, T, S = g.shape
NT = triton.cdiv(T, BT)
@@ -365,11 +363,11 @@ def chunk_local_cumsum_scalar(g, BT):
g_org, g = g, torch.empty_like(g, dtype=torch.float)
grid = (NT, B * H)
chunk_local_cumsum_scalar_kernel[grid](
g_org, g,
g_org, g,
T=T, BT=BT
)
return g


@contiguous
def chunk_local_cumsum(g, BT):
@@ -378,8 +376,8 @@ def chunk_local_cumsum(g, BT):
elif len(g.shape) == 4:
return chunk_local_cumsum_vector(g, BT)
else:
raise ValueError(f"Unsupported shape {g.shape}. Should be either (batch size, num head, seq len, dim) or (Batch size, num head, seq len)")

raise ValueError(f"Unsupported shape {
g.shape}. Should be either (batch size, num_heads, seq_len, dim) or (batch_size, num_heads, seq_len)")


@contiguous
@@ -400,7 +398,6 @@ def chunk_global_reversed_cumsum_vector(
return z



@contiguous
def chunk_global_reversed_cumsum_scalar(
s: torch.Tensor,
@@ -411,13 +408,12 @@ def chunk_global_reversed_cumsum_scalar(
grid = (B * H,)
z = torch.empty_like(s, dtype=dtype)
chunk_global_reversed_cumsum_scalar_kernel[grid](
s, z,
s, z,
T=T
)
return z



@contiguous
def chunk_global_cumsum_vector(
s: torch.Tensor,
@@ -436,7 +432,6 @@ def chunk_global_cumsum_vector(
return z



@contiguous
def chunk_global_cumsum_scalar(
s: torch.Tensor,
@@ -447,20 +442,21 @@ def chunk_global_cumsum_scalar(
grid = (B * H,)
z = torch.empty_like(s, dtype=dtype)
chunk_global_cumsum_scalar_kernel[grid](
s, z,
s, z,
T=T
)
return z


@contiguous
def chunk_global_cumsum(s, dtype=None):
if len(s.shape) == 3:
return chunk_global_cumsum_scalar(s, dtype)
elif len(s.shape) == 4:
return chunk_global_cumsum_vector(s, dtype)
else:
raise ValueError(f"Unsupported shape {s.shape}. Should be either (batch size, num head, seq len) or (Batch size, num head, seq len, dim)")

raise ValueError(f"Unsupported shape {s.shape}. "
f"Should be either [batch size, num_heads, seq_len] or [batch_size, num_heads, seq_len, dim]")


@contiguous
@@ -470,7 +466,5 @@ def chunk_global_reversed_cumsum(s, dtype=None):
elif len(s.shape) == 4:
return chunk_global_reversed_cumsum_vector(s, dtype)
else:
raise ValueError(f"Unsupported shape {s.shape}. Should be either (batch size, num head, seq len) or (Batch size, num head, seq len, dim)")



raise ValueError(f"Unsupported shape {s.shape}. "
f"Should be either [batch size, num_heads, seq_len] or [batch_size, num_heads, seq_len, dim]")
