Merge pull request #42 from hypnopump/recurrent_beta_vector
[DeltaNet] Adds beta as a vector option
yzhangcs authored Aug 9, 2024
2 parents a068379 + 8ab462d commit c0dde22
Showing 10 changed files with 248 additions and 220 deletions.
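For context: DeltaNet's delta-rule recurrence keeps a per-head state S of shape (d_k, d_v) and, at each step, writes the value v_t into S along the key k_t with a writing strength beta_t. Before this PR, beta_t was a single scalar per token; the PR adds the option of passing beta as a vector with one entry per value channel, so each channel gets its own writing strength. A minimal single-head sketch of one update (illustration only; delta_step and the sizes are hypothetical, not code from this commit):

import torch

def delta_step(S, k, v, beta):
    # One delta-rule update on state S of shape (d_k, d_v).
    # k: (d_k,), v: (d_v,), beta: scalar tensor or (d_v,) vector.
    v_old = k @ S                  # what S currently predicts for key k: (d_v,)
    v_new = beta * (v - v_old)     # scalar beta scales every channel equally,
                                   # a vector beta scales each value channel separately
    return S + torch.outer(k, v_new)

S = torch.zeros(4, 8)
k, v = torch.randn(4), torch.randn(8)
S = delta_step(S, k, v, torch.tensor(0.5))   # scalar beta (previous behaviour)
S = delta_step(S, k, v, torch.rand(8))       # per-channel vector beta (new option)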
2 changes: 1 addition & 1 deletion benchmarks/benchmark.py
@@ -265,4 +265,4 @@ def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
     if verbose:
         print(f"{desc} max memory: {mem}GB")
     torch.cuda.empty_cache()
-    return mem
+    return mem
4 changes: 2 additions & 2 deletions benchmarks/ops/benchmark_delta_rule.py
@@ -5,7 +5,7 @@
 from benchmark import benchmark_combined, benchmark_forward

 from fla.ops.delta_rule import (chunk_linear_attn_delta_rule,
-                                fused_recurrent_linear_attn_delta_rule)
+                                fused_recurrent_delta_rule)
 from fla.ops.retention import fused_chunk_retention


@@ -71,7 +71,7 @@ def time_fwd_bwd(func, *args, **kwargs):
         v2 = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)

         f_b = time_fwd_bwd(
-            fused_recurrent_linear_attn_delta_rule, q, k, v, verbose=False
+            fused_recurrent_delta_rule, q, k, v, verbose=False
         )
         time_f_b[config, "delta_recurrent"] = f_b
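The benchmark change is purely the rename. For reference, a rough, self-contained sketch of timing a forward+backward pass through the renamed kernel (made-up sizes; this is not the repo's time_fwd_bwd helper, and it assumes beta and the initial state are optional, as the q/k/v-only call above suggests):

import torch
from fla.ops.delta_rule import fused_recurrent_delta_rule

B, H, T, D = 8, 4, 2048, 64
q, k, v = (torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16,
                       requires_grad=True) for _ in range(3))

# warm-up so kernel compilation does not pollute the measurement
o, _ = fused_recurrent_delta_rule(q, k, v)
o.sum().backward()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(10):
    o, _ = fused_recurrent_delta_rule(q, k, v)
    o.sum().backward()
end.record()
torch.cuda.synchronize()
print(f"fwd+bwd: {start.elapsed_time(end) / 10:.2f} ms/iter")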
4 changes: 2 additions & 2 deletions fla/layers/delta_net.py
@@ -14,7 +14,7 @@

 from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
 from fla.ops.delta_rule import (chunk_delta_rule, fused_chunk_delta_rule,
-                                fused_recurrent_linear_attn_delta_rule)
+                                fused_recurrent_delta_rule)

 if TYPE_CHECKING:
     from fla.models.utils import Cache
@@ -206,7 +206,7 @@ def forward(
             beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
         state = past_key_values[self.layer_idx][-1] if use_cache else None
         if mode == 'fused_recurrent':
-            o, recurrent_state = fused_recurrent_linear_attn_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
+            o, recurrent_state = fused_recurrent_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
         elif mode == 'fused_chunk':
             assert self.chunk_size in [16, 32, 64]
             o, recurrent_state = fused_chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
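Beyond the rename, the call above shows how the layer threads an optional recurrent state through the kernel for cached generation. A sketch of the same pattern outside the layer (shapes and variable names are illustrative; it assumes the (batch, heads, seqlen, head_dim) layout and the positional initial-state argument shown above):

import torch
from fla.ops.delta_rule import fused_recurrent_delta_rule

B, H, T, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, T, D, device='cuda') for _ in range(3))
beta = torch.rand(B, H, T, device='cuda')            # per-token writing strength

# prefill: process the prompt and keep the final recurrent state
o, state = fused_recurrent_delta_rule(q, k, v, beta, None, output_final_state=True)

# one cached decoding step: a single new token seeded with the saved state
q1, k1, v1 = (t[:, :, -1:].contiguous() for t in (q, k, v))
beta1 = beta[:, :, -1:].contiguous()
o1, state = fused_recurrent_delta_rule(q1, k1, v1, beta1, state, output_final_state=True)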
3 changes: 1 addition & 2 deletions fla/ops/abc/chunk.py
@@ -8,8 +8,7 @@
 import triton
 import triton.language as tl

-from fla.ops.utils import (logcumsumexp_fwd_kernel, softmax_bwd_kernel,
-                           softmax_fwd_kernel)
+from fla.ops.utils import (logcumsumexp_fwd_kernel, softmax_bwd_kernel, softmax_fwd_kernel)
 from fla.utils import contiguous

6 changes: 3 additions & 3 deletions fla/ops/delta_rule/__init__.py
@@ -1,11 +1,11 @@
 # -*- coding: utf-8 -*-

-from .chunk_fuse import fused_chunk_delta_rule
-from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule
 from .chunk import chunk_delta_rule
+from .chunk_fuse import fused_chunk_delta_rule
+from .recurrent_fuse import fused_recurrent_delta_rule

 __all__ = [
     'fused_chunk_delta_rule',
-    'fused_recurrent_linear_attn_delta_rule',
+    'fused_recurrent_delta_rule',
     'chunk_delta_rule'
 ]
7 changes: 6 additions & 1 deletion fla/ops/delta_rule/naive.py
@@ -10,15 +10,20 @@ def delta_rule_recurrence(q, k, v, beta):
     o = torch.zeros_like(v)
     S = torch.zeros(b, h, d_k, d_v).to(v)
     q = q * (d_k ** -0.5)
+
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+
     for i in range(l):
         _k = k[:, :, i]
         _q = q[:, :, i]
         _v = v[:, :, i].clone()
         beta_i = beta[:, :, i]
         _v = _v - (S.clone() * _k[..., None]).sum(-2)
-        _v = _v * beta_i[..., None]
+        _v = _v * beta_i
         S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
         o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+
     return o

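With the new branch, the reference implementation accepts both a per-token and a per-channel beta. A small usage sketch (sizes are arbitrary; it assumes the (b, h, l, d) layout the variable names above suggest):

import torch
from fla.ops.delta_rule.naive import delta_rule_recurrence

b, h, l, d_k, d_v = 2, 4, 32, 16, 16
q = torch.randn(b, h, l, d_k)
k = torch.randn(b, h, l, d_k)
v = torch.randn(b, h, l, d_v)

o_scalar = delta_rule_recurrence(q, k, v, torch.rand(b, h, l))        # one beta per token
o_vector = delta_rule_recurrence(q, k, v, torch.rand(b, h, l, d_v))   # one beta per value channel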