Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DeltaNet] Adds beta as a vector option #42

Merged
merged 6 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,4 @@ def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
if verbose:
print(f"{desc} max memory: {mem}GB")
torch.cuda.empty_cache()
return mem
return mem
4 changes: 2 additions & 2 deletions benchmarks/ops/benchmark_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from benchmark import benchmark_combined, benchmark_forward

from fla.ops.delta_rule import (chunk_linear_attn_delta_rule,
fused_recurrent_linear_attn_delta_rule)
fused_recurrent_delta_rule)
from fla.ops.retention import fused_chunk_retention


Expand Down Expand Up @@ -71,7 +71,7 @@ def time_fwd_bwd(func, *args, **kwargs):
v2 = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)

f_b = time_fwd_bwd(
fused_recurrent_linear_attn_delta_rule, q, k, v, verbose=False
fused_recurrent_delta_rule, q, k, v, verbose=False
)
time_f_b[config, "delta_recurrent"] = f_b

Expand Down
4 changes: 2 additions & 2 deletions fla/layers/delta_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.ops.delta_rule import (chunk_delta_rule, fused_chunk_delta_rule,
fused_recurrent_linear_attn_delta_rule)
fused_recurrent_delta_rule)

if TYPE_CHECKING:
from fla.models.utils import Cache
Expand Down Expand Up @@ -206,7 +206,7 @@ def forward(
beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
state = past_key_values[self.layer_idx][-1] if use_cache else None
if mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_linear_attn_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
o, recurrent_state = fused_recurrent_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
elif mode == 'fused_chunk':
assert self.chunk_size in [16, 32, 64]
o, recurrent_state = fused_chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
Expand Down
3 changes: 1 addition & 2 deletions fla/ops/abc/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import triton
import triton.language as tl

from fla.ops.utils import (logcumsumexp_fwd_kernel, softmax_bwd_kernel,
softmax_fwd_kernel)
from fla.ops.utils import (logcumsumexp_fwd_kernel, softmax_bwd_kernel, softmax_fwd_kernel)
from fla.utils import contiguous


Expand Down
6 changes: 3 additions & 3 deletions fla/ops/delta_rule/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-

from .chunk_fuse import fused_chunk_delta_rule
from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule
from .chunk import chunk_delta_rule
from .chunk_fuse import fused_chunk_delta_rule
from .recurrent_fuse import fused_recurrent_delta_rule

__all__ = [
'fused_chunk_delta_rule',
'fused_recurrent_linear_attn_delta_rule',
'fused_recurrent_delta_rule',
'chunk_delta_rule'
]
7 changes: 6 additions & 1 deletion fla/ops/delta_rule/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,20 @@ def delta_rule_recurrence(q, k, v, beta):
o = torch.zeros_like(v)
S = torch.zeros(b, h, d_k, d_v).to(v)
q = q * (d_k ** -0.5)

if beta.ndim < v.ndim:
beta = beta[..., None]

for i in range(l):
_k = k[:, :, i]
_q = q[:, :, i]
_v = v[:, :, i].clone()
beta_i = beta[:, :, i]
_v = _v - (S.clone() * _k[..., None]).sum(-2)
_v = _v * beta_i[..., None]
_v = _v * beta_i
S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)

return o


Expand Down
Loading
Loading