From bcb707576c6a80eaf850aa80e8643d3497ec2bc4 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Mon, 6 Jun 2022 08:44:43 -0700
Subject: [PATCH] [fix] Rotary embeddings respecting input types (#326)

* tentative fix

* fixing arange not knowing half on cpu
---
 CHANGELOG.md                                  |  1 +
 tests/test_rotary_embeddings.py               | 13 +++++++++++--
 .../components/positional_embedding/rotary.py | 16 ++++++++++------
 3 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 31bc0ecdb..425c8c21b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## TBD
 ### Fixed
 - Removed dupliacated biases in the FusedMLP layers [#317]
+- Rotary embeddings respecting input types [#326]
 
 ### Added
 - Four blocksparsity layouts from DeepSpeed [#320]
diff --git a/tests/test_rotary_embeddings.py b/tests/test_rotary_embeddings.py
index d22292eab..ad6a477dc 100644
--- a/tests/test_rotary_embeddings.py
+++ b/tests/test_rotary_embeddings.py
@@ -52,16 +52,25 @@ def test_helper_methods():
 
 
 @pytest.mark.parametrize("device", DEVICES)
-def test_rotary_embeddings(device):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_rotary_embeddings(device, dtype):
     rotary = RotaryEmbedding(EMB).to(device)
 
     # Generate dummy inputs
-    q = torch.ones((BATCH, HEADS, SEQ, EMB), device=device)  # uniform on purpose
+    q = torch.ones(
+        (BATCH, HEADS, SEQ, EMB), device=device, dtype=dtype
+    )  # uniform on purpose
     k = q.clone()
 
     q_rot, k_rot = rotary(q, k)
 
+    assert q_rot.dtype == q.dtype
+    assert k_rot.dtype == k.dtype
+
     # Check that the sequences now encode relative position information
+    q, k = q.float(), k.float()
+    q_rot, k_rot = q_rot.float(), k_rot.float()
+
     att = torch.einsum("bhne,bhme->bhnm", q, k)
     att_rot = torch.einsum("bhne,bhme->bhnm", q_rot, k_rot)
 
diff --git a/xformers/components/positional_embedding/rotary.py b/xformers/components/positional_embedding/rotary.py
index 94bf5736f..551089b3b 100644
--- a/xformers/components/positional_embedding/rotary.py
+++ b/xformers/components/positional_embedding/rotary.py
@@ -61,16 +61,20 @@ def _update_cos_sin_tables(self, x, seq_dimension=1):
 
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
-        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+        if (
+            seq_len != self._seq_len_cached
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
+        ):
             self._seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
-                self.inv_freq
+            t = torch.arange(
+                x.shape[seq_dimension], device=x.device, dtype=torch.float32
             )
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
 
-            self._cos_cached = emb.cos()[None, None, :, :]
-            self._sin_cached = emb.sin()[None, None, :, :]
+            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
 
         return self._cos_cached, self._sin_cached
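
For reference, a minimal usage sketch of the behaviour this patch guarantees. It is only a sketch: the import path mirrors the module edited above, the class name and call signature come from the test, and the shapes, random values, and CUDA device choice are illustrative assumptions.

    import torch

    from xformers.components.positional_embedding.rotary import RotaryEmbedding

    BATCH, HEADS, SEQ, EMB = 2, 4, 128, 32  # illustrative sizes

    # Half-precision queries/keys on GPU
    rotary = RotaryEmbedding(EMB).to("cuda")
    q = torch.randn((BATCH, HEADS, SEQ, EMB), device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)

    q_rot, k_rot = rotary(q, k)

    # With this patch the cached cos/sin tables are rebuilt whenever the input
    # dtype changes and are cast to that dtype, so the rotated tensors match
    # the inputs instead of coming back in the cache's precision.
    assert q_rot.dtype == torch.float16
    assert k_rot.dtype == torch.float16

Note that inside the cache update the position tensor is still built with torch.arange(..., dtype=torch.float32); per the second commit, half-precision arange is not supported on CPU, so only the final cos/sin tables are cast to the input dtype.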