Skip to content

Commit

Permalink
refactor: replace deprecated torch.cuda.amp.autocast with torch.amp.autocast
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Sep 24, 2024
1 parent d8ade79 commit df69ff2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pypots/nn/modules/reformer/local_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from einops import rearrange
from einops import repeat, pack, unpack
from torch import nn, einsum
from torch.cuda.amp import autocast
from torch.amp import autocast

TOKEN_SELF_ATTN_VALUE = -5e4

Expand All @@ -28,7 +28,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@autocast(enabled=False)
@autocast("cuda", enabled=False)
def apply_rotary_pos_emb(q, k, freqs, scale=1):
q_len = q.shape[-2]
q_freqs = freqs[..., -q_len:, :]
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(self, dim, scale_base=None, use_xpos=False):
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.register_buffer("scale", scale, persistent=False)

@autocast(enabled=False)
@autocast("cuda", enabled=False)
def forward(self, x):
seq_len, device = x.shape[-2], x.device

Expand Down

0 comments on commit df69ff2

Please sign in to comment.