tentative fix
blefaudeux committed Jun 4, 2022
1 parent 5ccbcd9 commit 4652427
Showing 2 changed files with 15 additions and 7 deletions.
10 changes: 8 additions & 2 deletions tests/test_rotary_embeddings.py
@@ -52,15 +52,21 @@ def test_helper_methods():


 @pytest.mark.parametrize("device", DEVICES)
-def test_rotary_embeddings(device):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_rotary_embeddings(device, dtype):
     rotary = RotaryEmbedding(EMB).to(device)

     # Generate dummy inputs
-    q = torch.ones((BATCH, HEADS, SEQ, EMB), device=device)  # uniform on purpose
+    q = torch.ones(
+        (BATCH, HEADS, SEQ, EMB), device=device, dtype=dtype
+    )  # uniform on purpose
     k = q.clone()

     q_rot, k_rot = rotary(q, k)

+    assert q_rot.dtype == q.dtype
+    assert k_rot.dtype == k.dtype
+
     # Check that the sequences now encode relative position information
     att = torch.einsum("bhne,bhme->bhnm", q, k)
     att_rot = torch.einsum("bhne,bhme->bhnm", q_rot, k_rot)
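For context, a minimal usage sketch of the behaviour the updated test asserts: rotary embeddings applied to half-precision inputs should come back in the input dtype rather than being upcast by the cached cos/sin tables. The import path and the concrete shape values below are assumptions for illustration, not the constants defined in the test module.

import torch

from xformers.components.positional_embedding import RotaryEmbedding  # assumed export path

# Illustrative shapes, not the BATCH/HEADS/SEQ/EMB constants of the test suite
BATCH, HEADS, SEQ, EMB = 2, 4, 16, 32

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

rotary = RotaryEmbedding(EMB).to(device)
q = torch.ones((BATCH, HEADS, SEQ, EMB), device=device, dtype=dtype)
k = q.clone()

q_rot, k_rot = rotary(q, k)

# With this fix the rotated tensors keep the input dtype,
# which is exactly what the new asserts in the test check.
assert q_rot.dtype == q.dtype and k_rot.dtype == k.dtype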
12 changes: 7 additions & 5 deletions xformers/components/positional_embedding/rotary.py
@@ -61,12 +61,14 @@ def _update_cos_sin_tables(self, x, seq_dimension=1):

         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
-        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+        if (
+            seq_len != self._seq_len_cached
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
+        ):
             self._seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
-                self.inv_freq
-            )
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=x.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

             self._cos_cached = emb.cos()[None, None, :, :]
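The substance of the fix is the cache-invalidation rule: the cached cos/sin tables are now rebuilt not only when the sequence length or device of the incoming tensor changes, but also when its dtype differs from the cached tables, and the rebuilt tables are computed in that dtype. Below is a standalone sketch of that rule with hypothetical helper names, not the library code itself.

import torch


def tables_need_rebuild(cos_cached, seq_len_cached, x, seq_dimension=1):
    # Hypothetical helper restating the condition from _update_cos_sin_tables:
    # rebuild on a sequence-length, device, or (new in this commit) dtype change.
    seq_len = x.shape[seq_dimension]
    return (
        seq_len != seq_len_cached
        or cos_cached.device != x.device
        or cos_cached.dtype != x.dtype
    )


def build_tables(inv_freq, x, seq_dimension=1):
    # Hypothetical helper: build the cos/sin tables in the dtype and on the
    # device of the incoming tensor, mirroring the updated library code.
    t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=x.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq.to(x.dtype))
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]

One trade-off worth noting: with half-precision inputs the frequency tables themselves are now computed in half precision, which gives up a little accuracy in the angles in exchange for dtype consistency of the outputs.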
