Skip to content

Commit

Permalink
can slice the rotary pos emb being passed into apply_rotary_emb at the right sequence dimension if it is passed in
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 29, 2024
1 parent e2224b5 commit 4da3d07
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
22 changes: 18 additions & 4 deletions rotary_embedding_torch/rotary_embedding_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from math import pi, log

import torch
from torch.nn import Module, ModuleList
from torch.amp import autocast
from torch.nn import Module, ModuleList
from torch import nn, einsum, broadcast_tensors, Tensor

from einops import rearrange, repeat
Expand All @@ -24,6 +24,12 @@ def broadcat(tensors, dim = -1):
broadcasted_tensors = broadcast_tensors(*tensors)
return torch.cat(broadcasted_tensors, dim = dim)

def slice_at_dim(t, dim_slice: slice, *, dim):
    """Apply `dim_slice` along a single dimension of tensor `t`.

    Builds a full indexing tuple of `slice(None)` (i.e. `:`) for every
    dimension, substituting `dim_slice` at position `dim`, so all other
    dimensions are left untouched. Negative `dim` counts from the end.
    """
    if dim < 0:
        dim = t.ndim + dim
    index = [slice(None) for _ in range(t.ndim)]
    index[dim] = dim_slice
    return t[tuple(index)]

# rotary embedding helper functions

def rotate_half(x):
Expand All @@ -33,12 +39,20 @@ def rotate_half(x):
return rearrange(x, '... d r -> ... (d r)')

@autocast('cuda', enabled = False)
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
def apply_rotary_emb(
freqs,
t,
start_index = 0,
scale = 1.,
seq_dim = -2,
freqs_seq_dim = None
):
dtype = t.dtype

if t.ndim == 3:
if t.ndim == 3 or exists(freqs_seq_dim):
freqs_seq_dim = default(freqs_seq_dim, 0)
seq_len = t.shape[seq_dim]
freqs = freqs[-seq_len:]
freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)

rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'rotary-embedding-torch',
packages = find_packages(),
version = '0.8.3',
version = '0.8.4',
license='MIT',
description = 'Rotary Embedding - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit 4da3d07

Please sign in to comment.