Update amp custom_fwd, custom_bwd usage for torch 2.4.0 compatibility #54

Merged
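The `fla.utils` helpers themselves are not part of this diff. A minimal sketch of the kind of version-guarded shim they presumably implement (only the exported names `autocast_custom_fwd` / `autocast_custom_bwd` come from the PR; the dispatch logic below is an assumption):

# Sketch of a version-guarded AMP shim like the one fla.utils now exports.
# torch >= 2.4 deprecates torch.cuda.amp.custom_fwd/custom_bwd in favour of
# torch.amp.custom_fwd/custom_bwd, which take a device_type argument.
import functools

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse("2.4"):
    # New API: bind device_type="cuda" so the decorators behave like the old ones.
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")
else:
    # Old API: keep the original decorators on torch < 2.4.
    autocast_custom_fwd = torch.cuda.amp.custom_fwd
    autocast_custom_bwd = torch.cuda.amp.custom_bwd

Centralizing the device_type binding in one place lets every op file keep a single decorator name regardless of the installed torch version.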
7 changes: 3 additions & 4 deletions fla/ops/abc/recurrent_fuse.py
@@ -7,9 +7,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd

-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


 @triton.jit
@@ -284,7 +283,7 @@ class FusedRecurrentGatedABCFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(
         ctx,
         q: torch.Tensor,
@@ -374,7 +373,7 @@ def forward(

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors
         B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
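The same decorator swap repeats in every file below. As a quick sanity check of the pattern, a toy autograd Function decorated the new way could look like this (the `Square` class is purely illustrative and not part of the repository):

# Minimal illustration: the new decorators are drop-in replacements for the
# old torch.cuda.amp ones on autograd Functions.
import torch

from fla.utils import autocast_custom_bwd, autocast_custom_fwd


class Square(torch.autograd.Function):

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 2 * x * grad_out


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(4, device="cuda", requires_grad=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        y = Square.apply(x).sum()
    y.backward()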
7 changes: 3 additions & 4 deletions fla/ops/based/chunk_fuse.py
@@ -5,9 +5,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd

-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous

 # on-the-fly computation without materializing hidden statets into HBMs

@@ -305,7 +304,7 @@ class FusedChunkBasedFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale=1):
         B, H, T, K, V = *k.shape, v.shape[-1]

@@ -338,7 +337,7 @@ def forward(ctx, q, k, v, scale=1):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dz):
         q, k, v = ctx.saved_tensors
         B, H, T, K, V = *k.shape, v.shape[-1]
7 changes: 3 additions & 4 deletions fla/ops/based/parallel.py
@@ -6,9 +6,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd

-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous

 # Based: An Educational and Effective Sequence Mixer
 # https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
@@ -314,7 +313,7 @@ class ParallelBasedFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale):
         BTL, BTS = 128, 32
         assert BTL % BTS == 0
@@ -349,7 +348,7 @@ def forward(ctx, q, k, v, scale):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dz):
         q, k, v = ctx.saved_tensors
         scale = ctx.scale
6 changes: 3 additions & 3 deletions fla/ops/delta_rule/chunk.py
@@ -4,11 +4,11 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd

 from fla.ops.delta_rule.wy_fast import (bwd_prepare_wy_repr,
                                         fwd_prepare_wy_repr, fwd_recompute_w_u)
 from fla.ops.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd


 @triton.autotune(
@@ -491,7 +491,7 @@ class ChunkDeltaRuleFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
         # obtain WY representation. u is actually the new v.
         w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
@@ -512,7 +512,7 @@ def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, d_ht=None):
         q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
         BT = ctx.BT
7 changes: 3 additions & 4 deletions fla/ops/delta_rule/chunk_fuse.py
@@ -5,10 +5,9 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd

 from fla.ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


 # on-the-fly computation without materializing hidden statets into HBMs
@@ -327,7 +326,7 @@ class FusedChunkDeltaRuleFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
         # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory.
         assert checkpoint_level in [0, 1]
@@ -345,7 +344,7 @@ def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, d_final_state=None):
         q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
         chunk_size = ctx.chunk_size
7 changes: 3 additions & 4 deletions fla/ops/delta_rule/utils.py
@@ -4,10 +4,9 @@
 import triton
 import triton.language as tl
 from einops import rearrange
-from torch.cuda.amp import custom_bwd, custom_fwd

 from fla.ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


 # Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
@@ -191,7 +190,7 @@ def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):

 class WYRepresentationPrepration(torch.autograd.Function):
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     @staticmethod
     def forward(ctx, k, v, beta, chunk_size):
         o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
@@ -200,7 +199,7 @@ def forward(ctx, k, v, beta, chunk_size):
         return o_cumdecay, v_new

     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     @staticmethod
     def backward(ctx, do, do2):
         k, v, beta, o_cumdecay, v_new = ctx.saved_tensors
7 changes: 3 additions & 4 deletions fla/ops/delta_rule/wy_fast.py
@@ -4,9 +4,8 @@
 import triton
 import triton.language as tl
 from einops import rearrange
-from torch.cuda.amp import custom_bwd, custom_fwd

-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


 # Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
@@ -288,7 +287,7 @@ class WYRepresentationPrepration(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, k, v, beta, chunk_size=64):
         ctx.BT = chunk_size
         w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)
@@ -297,7 +296,7 @@ def forward(ctx, k, v, beta, chunk_size=64):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, dw, du):
         k, v, beta, A = ctx.saved_tensors
         BT = ctx.BT
7 changes: 3 additions & 4 deletions fla/ops/gla/chunk_fuse.py
@@ -12,11 +12,10 @@
 import triton.language as tl
 from einops import rearrange
 from packaging import version
-from torch.cuda.amp import custom_bwd, custom_fwd

 from fla.ops.gla.chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum,
                                     prepare_qg_kg)
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


 @triton.jit
@@ -304,7 +303,7 @@ class FusedChunkGLAFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
         ctx.g_dtype = g.dtype
         g_original = g
@@ -396,7 +395,7 @@ def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, g_origin, A, initial_state = ctx.saved_tensors
         B, H, T, K, V = *k.shape, v.shape[-1]
7 changes: 3 additions & 4 deletions fla/ops/gla/recurrent_fuse.py
@@ -7,9 +7,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd

-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous

 # on-the-fly computation without materializing hidden statets into HBMs

@@ -223,7 +222,7 @@ class FusedRecurrentGLAFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):
         B, H, T, K, V = *q.shape, v.shape[-1]
         # default scale
@@ -270,7 +269,7 @@ def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, gk, gv, initial_state, o = ctx.saved_tensors
         batch_size, n_heads, seq_len, K = q.shape
7 changes: 3 additions & 4 deletions fla/ops/linear_attn/chunk.py
@@ -6,10 +6,9 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd

 from fla.ops.linear_attn.utils import normalize_output
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


 @triton.jit
@@ -238,7 +237,7 @@ class ChunkLinearAttentionFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale, initial_state, output_final_state):
         B, H, T, K, V = *q.shape, v.shape[-1]
         BT = 64
@@ -282,7 +281,7 @@ def forward(ctx, q, k, v, scale, initial_state, output_final_state):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, h = ctx.saved_tensors

7 changes: 3 additions & 4 deletions fla/ops/linear_attn/chunk_fuse.py
@@ -7,10 +7,9 @@
 import triton
 import triton.language as tl
 from packaging import version
-from torch.cuda.amp import custom_bwd, custom_fwd

 from fla.ops.linear_attn.utils import normalize_output
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


 @triton.jit
@@ -208,7 +207,7 @@ class FusedChunkLinearAttentionFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale, initial_state, output_final_state):
         B, H, T, K, V = *k.shape, v.shape[-1]
         BT = 64
@@ -255,7 +254,7 @@ def forward(ctx, q, k, v, scale, initial_state, output_final_state):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, initial_state = ctx.saved_tensors
         B, H, T, K, V = *k.shape, v.shape[-1]
7 changes: 3 additions & 4 deletions fla/ops/rebased/parallel.py
@@ -4,9 +4,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd

-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous

 # Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models
 # https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py
@@ -339,7 +338,7 @@ class ParallelBasedFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale):
         BTL, BTS = 128, 32
         assert BTL % BTS == 0
@@ -374,7 +373,7 @@ def forward(ctx, q, k, v, scale):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dz):
         q, k, v = ctx.saved_tensors
         scale = ctx.scale
7 changes: 3 additions & 4 deletions fla/ops/retention/chunk.py
@@ -6,9 +6,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd

-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


 @triton.autotune(
@@ -375,7 +374,7 @@ class ChunkRetentionFunction(torch.autograd.Function):

     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, initial_state, output_final_state, scale, checkpoint_level):
         BT = 64
         h, final_state = chunk_fwd_h_fn(k, v, BT, initial_state, output_final_state)
@@ -388,7 +387,7 @@ def forward(ctx, q, k, v, initial_state, output_final_state, scale, checkpoint_level):

     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, d_ht=None):
         BT, scale = ctx.BT, ctx.scale
         q, k, v, h, initial_state = ctx.saved_tensors