Commit

address #15
lucidrains committed Dec 14, 2023
1 parent 02a5d07 commit f3dc708
Showing 2 changed files with 51 additions and 26 deletions.
75 changes: 50 additions & 25 deletions rotary_embedding_torch/rotary_embedding_torch.py
@@ -92,15 +92,16 @@ def __init__(
         elif freqs_for == 'constant':
             freqs = torch.ones(num_freqs).float()

-        self.cache = dict()
-        self.cache_scale = dict()
+        self.tmp_store('cached_freqs', None)
+        self.tmp_store('cached_scales', None)

         self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)

         self.learned_freq = learned_freq

         # dummy for device

-        self.register_buffer('dummy', torch.tensor(0), persistent = False)
+        self.tmp_store('dummy', torch.tensor(0))

         # default sequence dimension

@@ -116,17 +117,20 @@ def __init__(

         self.use_xpos = use_xpos
         if not use_xpos:
-            self.register_buffer('scale', None)
+            self.tmp_store('scale', None)
             return

         scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
         self.scale_base = xpos_scale_base
-        self.register_buffer('scale', scale)
+        self.tmp_store('scale', scale)

     @property
     def device(self):
         return self.dummy.device

+    def tmp_store(self, key, value):
+        self.register_buffer(key, value, persistent = False)
+
     def get_seq_pos(self, seq_len, device, dtype, offset = 0):
         return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
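The new tmp_store helper is a thin wrapper around register_buffer(..., persistent = False): the cached tensors (dummy, scale, cached_freqs, cached_scales) follow the module when it is moved to another device or dtype, but are left out of state_dict(), so checkpoints are not bloated by cache contents. A minimal sketch of that buffer behavior, using a hypothetical Toy module that is not part of this repository:

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # non-persistent buffer: tracks .to()/.cuda() moves, skipped by state_dict()
        self.register_buffer('cached_freqs', torch.zeros(4), persistent = False)

toy = Toy()
print('cached_freqs' in toy.state_dict())  # False - the cache is not serialized
toy = toy.to(torch.float64)
print(toy.cached_freqs.dtype)              # torch.float64 - the cache moved with the module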

@@ -141,7 +145,7 @@ def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = N
             assert freq_seq_len >= seq_len
             seq_len = freq_seq_len

-        freqs = self.forward(lambda: self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), cache_key = f'freqs:{seq_len}|offset:{offset}')
+        freqs = self.forward(self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), seq_len = seq_len, offset = offset)

         if seq_dim == -3:
             freqs = rearrange(freqs, 'n d -> n 1 d')
@@ -168,8 +172,9 @@ def rotate_queries_and_keys(self, q, k, seq_dim = None):
         device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

         seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
-        freqs = self.forward(lambda: seq, cache_key = f'freqs:{seq_len}')
-        scale = self.get_scale(lambda: seq, cache_key = f'scale:{seq_len}').to(dtype)
+
+        freqs = self.forward(seq, seq_len = seq_len)
+        scale = self.get_scale(seq, seq_len = seq_len).to(dtype)

         if seq_dim == -3:
             freqs = rearrange(freqs, 'n d -> n 1 d')
@@ -183,23 +188,32 @@ def rotate_queries_and_keys(self, q, k, seq_dim = None):

         return rotated_q, rotated_k

-    def get_scale(self, t, cache_key = None):
+    @beartype
+    def get_scale(
+        self,
+        t: Tensor,
+        seq_len: Optional[int] = None,
+        offset = 0
+    ):
         assert self.use_xpos

-        if exists(cache_key) and cache_key in self.cache:
-            return self.cache[cache_key]
+        should_cache = exists(seq_len)

-        if callable(t):
-            t = t()
+        if (
+            should_cache and \
+            exists(self.cached_scales) and \
+            (seq_len + offset) <= self.cached_scales.shape[0]
+        ):
+            return self.cached_scales[offset:(offset + seq_len)]

         scale = 1.
         if self.use_xpos:
             power = (t - len(t) // 2) / self.scale_base
             scale = self.scale ** rearrange(power, 'n -> n 1')
             scale = torch.cat((scale, scale), dim = -1)

-        if exists(cache_key):
-            self.cache[cache_key] = scale
+        if should_cache:
+            self.tmp_store('cached_scales', scale)

         return scale
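This commit replaces the string-keyed dict cache with a single tensor buffer that is sliced by position: when the stored tensor already covers offset + seq_len rows, get_scale returns cached_scales[offset:(offset + seq_len)] without recomputing. A standalone sketch of that slicing pattern (the names lookup and cached are illustrative, not the library's API):

import torch

cached = None  # stands in for a buffer such as cached_scales

def lookup(seq_len, offset = 0):
    # return rows [offset, offset + seq_len) of a position-indexed table,
    # reusing the previously stored tensor whenever it is long enough
    global cached
    total = offset + seq_len
    if cached is not None and total <= cached.shape[0]:
        return cached[offset:total]
    cached = torch.arange(total).float()  # placeholder for the real computation
    return cached[offset:total]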

@@ -213,7 +227,7 @@ def get_axial_freqs(self, *dims):
             else:
                 pos = torch.arange(dim, device = self.device)

-            freqs = self.forward(pos, cache_key = dim)
+            freqs = self.forward(pos, seq_len = dim)

            all_axis = [None] * len(dims)
            all_axis[ind] = Colon
@@ -225,21 +239,32 @@ def get_axial_freqs(self, *dims):

         return torch.cat(all_freqs, dim = -1)

     @autocast(enabled = False)
-    def forward(self, t, cache_key = None):
-        should_cache = not self.learned_freq and exists(cache_key)
-
-        if should_cache and cache_key in self.cache:
-            return self.cache[cache_key]
-
-        if callable(t):
-            t = t()
+    @beartype
+    def forward(
+        self,
+        t: Tensor,
+        seq_len: Optional[int] = None,
+        offset = 0
+    ):
+        should_cache = (
+            not self.learned_freq and \
+            exists(seq_len) and \
+            self.freqs_for != 'pixel'
+        )
+
+        if (
+            should_cache and \
+            exists(self.cached_freqs) and \
+            (offset + seq_len) <= self.cached_freqs.shape[0]
+        ):
+            return self.cached_freqs[offset:(offset + seq_len)]

         freqs = self.freqs

         freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
         freqs = repeat(freqs, '... n -> ... (n r)', r = 2)

         if should_cache:
-            self.cache[cache_key] = freqs
+            self.tmp_store('cached_freqs', freqs)

         return freqs
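With the lambda and cache_key arguments gone, forward and get_scale take the positions tensor directly plus an optional seq_len (and offset), and caching happens internally whenever seq_len is supplied. A hedged usage sketch of the updated module; the constructor arguments and shapes follow the project's README, but treat the exact caching behavior described in the comment as an assumption based on this diff:

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim = 32)

q = torch.randn(1, 8, 1024, 64)  # (batch, heads, seq len, head dim)
k = torch.randn(1, 8, 1024, 64)

# frequencies for seq_len = 1024 are computed on the first call and then
# served from the cached_freqs buffer on subsequent calls
q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)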
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'rotary-embedding-torch',
   packages = find_packages(),
-  version = '0.4.0',
+  version = '0.5.0',
   license='MIT',
   description = 'Rotary Embedding - Pytorch',
   long_description_content_type = 'text/markdown',
