add xpos from new microsoft paper
lucidrains committed Dec 22, 2022
1 parent 517ee2c commit 4b2b86d
Showing 3 changed files with 71 additions and 35 deletions.
45 changes: 22 additions & 23 deletions README.md
@@ -71,40 +71,31 @@ q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)
```

-## Learned Rotations
+## Length Extrapolatable Rotary Embeddings

-For injecting learned rotations into a network. Experiments pending
-
-Update: doesn't seem to do anything -_-, will keep trying...
+In <a href="https://arxiv.org/abs/2212.10554v1">this paper</a>, they were able to fix the length extrapolation issue with rotary embeddings by giving them a decay similar to ALiBi's. They named this technique XPos, and you can use it by setting `use_xpos = True` on initialization.

```python
import torch
-from torch import nn
-from rotary_embedding_torch import apply_learned_rotations

-x = torch.randn(1, 1024, 512)
-
-# you can only rotate in (dim // 2) values
-# ex. for 512, you can only rotate in 256 values
-
-# say you have two sets of learned rotations of 128 values each
-
-rots1 = nn.Linear(512, 128)(x)
-rots2 = nn.Linear(512, 128)(x)
+from rotary_embedding_torch import RotaryEmbedding

-# you rotate in 256 (128 x 2) at first
+# instantiate the positional embedding in your transformer and pass to all your attention layers

-x = apply_learned_rotations(rots1, x, start_index = 0)
+rotary_emb = RotaryEmbedding(
+    dim = 32,
+    use_xpos = True   # set this to True to make rotary embeddings extrapolate better to sequence lengths greater than the one used at training time
+)

-# then you start at index 256 and rotate in the last (128 x 2)
+# mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc)

-x = apply_learned_rotations(rots2, x, start_index = 256)
+q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head)
+k = torch.randn(1, 8, 1024, 64) # keys

-# you could also concat the rotations together and pass it in all at once
+# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

-rots = torch.cat((rots1, rots2), dim = -1)
+# instead of using `rotate_queries_or_keys`, you will use `rotate_queries_and_keys`, the rest is taken care of

-x = apply_learned_rotations(rots, x)
+q, k = rotary_emb.rotate_queries_and_keys(q, k)
```

## Citations
@@ -119,3 +110,11 @@
    primaryClass = {cs.CL}
}
```
+
+```bibtex
+@inproceedings{Sun2022ALT,
+    title = {A Length-Extrapolatable Transformer},
+    author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
+    year = {2022}
+}
+```
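
As the README comment above notes, the rotations are applied after the heads have been split out and before the attention dot product. For orientation, here is a minimal sketch of that wiring inside a causal attention block; the `Attention` module, its dimensions, and the masking code are illustrative assumptions, and only `RotaryEmbedding` comes from this repository.

```python
# illustrative attention block (not part of this repository) showing where
# rotate_queries_and_keys is applied: after splitting heads, before the dot product

import torch
from torch import nn
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rotary_emb = RotaryEmbedding(dim = 32, use_xpos = True)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # rotary embedding with xpos decay, applied to queries and keys only
        q, k = self.rotary_emb.rotate_queries_and_keys(q, k)

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # xpos assumes causal attention, so mask out future positions
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), device = x.device).triu(j - i + 1).bool()
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
```

Because XPos scales the queries and inversely scales the keys, no extra decay term needs to be added to the attention logits; the causal mask is the only other ingredient assumed here.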
59 changes: 48 additions & 11 deletions rotary_embedding_torch/rotary_embedding_torch.py
@@ -1,4 +1,3 @@
-from inspect import isfunction
from math import pi, log

import torch
@@ -37,13 +36,13 @@ def rotate_half(x):
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')

-def apply_rotary_emb(freqs, t, start_index = 0):
+def apply_rotary_emb(freqs, t, start_index = 0, scale = 1.):
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
    t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
-    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
+    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    return torch.cat((t_left, t, t_right), dim = -1)

# learned rotation helpers
@@ -67,7 +66,9 @@ def __init__(
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
-        learned_freq = False
+        learned_freq = False,
+        use_xpos = False,
+        xpos_scale_base = 512,
    ):
        super().__init__()
        if exists(custom_freqs):
@@ -82,23 +83,59 @@
            raise ValueError(f'unknown modality {freqs_for}')

        self.cache = dict()
+        self.cache_scale = dict()
+        self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)

-        if learned_freq:
-            self.freqs = nn.Parameter(freqs)
-        else:
-            self.register_buffer('freqs', freqs)
+        self.use_xpos = use_xpos
+        if not use_xpos:
+            self.register_buffer('scale', None)
+            return
+
+        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+        self.scale_base = xpos_scale_base
+        self.register_buffer('scale', scale)

    def rotate_queries_or_keys(self, t, seq_dim = -2):
-        device = t.device
-        seq_len = t.shape[seq_dim]
+        assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
+        device, seq_len = t.device, t.shape[seq_dim]
        freqs = self.forward(lambda: torch.arange(seq_len, device = device), cache_key = seq_len)
        return apply_rotary_emb(freqs, t)

+    def rotate_queries_and_keys(self, q, k, seq_dim = -2):
+        assert self.use_xpos
+        device, seq_len = q.device, q.shape[seq_dim]
+        seq = torch.arange(seq_len, device = device)
+        freqs = self.forward(lambda: seq, cache_key = seq_len)
+        scale = self.get_scale(lambda: seq, cache_key = seq_len)
+        rotated_q = apply_rotary_emb(freqs, q, scale = scale)
+        rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1)
+        return rotated_q, rotated_k
+
+    def get_scale(self, t, cache_key = None):
+        assert self.use_xpos
+
+        if exists(cache_key) and cache_key in self.cache_scale:
+            return self.cache_scale[cache_key]
+
+        if callable(t):
+            t = t()
+
+        scale = 1.
+        if self.use_xpos:
+            power = (t - len(t) // 2) / self.scale_base
+            scale = self.scale ** rearrange(power, 'n -> n 1')
+            scale = torch.cat((scale, scale), dim = -1)
+
+        if exists(cache_key):
+            self.cache_scale[cache_key] = scale
+
+        return scale
+
    def forward(self, t, cache_key = None):
        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]

-        if isfunction(t):
+        if callable(t):
            t = t()

        freqs = self.freqs
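The new scaling path above (the `scale` argument of `apply_rotary_emb`, plus `get_scale` and `rotate_queries_and_keys`) works because the scale applied to the queries and the inverted scale applied to the keys cancel into a factor that depends only on the distance between the two positions. Below is a simplified standalone check of that identity, reusing the scale and power formulas from this commit; the variable names are illustrative and not part of the library.

```python
# standalone check (not library code): the xpos scale on queries times the inverse
# scale on keys reduces to a factor that depends only on the relative offset m - n

import torch

dim, seq_len, scale_base = 64, 1024, 512

# per-dimension decay rates in (0, 1), as in the commit
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

# per-position exponent, centered on the middle of the sequence
power = (torch.arange(seq_len) - seq_len // 2) / scale_base

q_scale = scale ** power[:, None]    # multiplies the rotated queries
k_scale = scale ** -power[:, None]   # multiplies the rotated keys

m, n = 700, 650                      # arbitrary query / key positions (m >= n under causal attention)
combined = q_scale[m] * k_scale[n]   # factor picked up by the dot product of q_m and k_n
relative = scale ** ((m - n) / scale_base)

assert torch.allclose(combined, relative, atol = 1e-6)

# since scale < 1, the factor shrinks as m - n grows, giving the ALiBi-like decay
```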
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'rotary-embedding-torch',
  packages = find_packages(),
-  version = '0.1.5',
+  version = '0.2.0',
  license='MIT',
  description = 'Rotary Embedding - Pytorch',
  author = 'Phil Wang',
