diff --git a/CHANGELOG.md b/CHANGELOG.md
index 31bc0ecdb3..425c8c21b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## TBD
 ### Fixed
 - Removed dupliacated biases in the FusedMLP layers [#317]
+- Rotary embeddings respecting input types [#326]
 
 ### Added
 - Four blocksparsity layouts from DeepSpeed [#320]
diff --git a/tests/test_rotary_embeddings.py b/tests/test_rotary_embeddings.py
index d22292eab9..ad6a477dc3 100644
--- a/tests/test_rotary_embeddings.py
+++ b/tests/test_rotary_embeddings.py
@@ -52,16 +52,25 @@ def test_helper_methods():
 
 
 @pytest.mark.parametrize("device", DEVICES)
-def test_rotary_embeddings(device):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_rotary_embeddings(device, dtype):
     rotary = RotaryEmbedding(EMB).to(device)
 
     # Generate dummy inputs
-    q = torch.ones((BATCH, HEADS, SEQ, EMB), device=device)  # uniform on purpose
+    q = torch.ones(
+        (BATCH, HEADS, SEQ, EMB), device=device, dtype=dtype
+    )  # uniform on purpose
     k = q.clone()
 
     q_rot, k_rot = rotary(q, k)
 
+    assert q_rot.dtype == q.dtype
+    assert k_rot.dtype == k.dtype
+
     # Check that the sequences now encode relative position information
+    q, k = q.float(), k.float()
+    q_rot, k_rot = q_rot.float(), k_rot.float()
+
     att = torch.einsum("bhne,bhme->bhnm", q, k)
     att_rot = torch.einsum("bhne,bhme->bhnm", q_rot, k_rot)
 
diff --git a/xformers/components/positional_embedding/rotary.py b/xformers/components/positional_embedding/rotary.py
index 94bf5736f3..551089b3b3 100644
--- a/xformers/components/positional_embedding/rotary.py
+++ b/xformers/components/positional_embedding/rotary.py
@@ -61,16 +61,20 @@ def _update_cos_sin_tables(self, x, seq_dimension=1):
 
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
-        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+        if (
+            seq_len != self._seq_len_cached
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
+        ):
             self._seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
-                self.inv_freq
+            t = torch.arange(
+                x.shape[seq_dimension], device=x.device, dtype=torch.float32
             )
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
 
-            self._cos_cached = emb.cos()[None, None, :, :]
-            self._sin_cached = emb.sin()[None, None, :, :]
+            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
 
         return self._cos_cached, self._sin_cached
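
Below is a minimal usage sketch, not part of the patch, illustrating the behaviour the fix enables: the cached cos/sin tables are now rebuilt whenever the input dtype (as well as the device or sequence length) changes, so half-precision inputs come back in half precision. The import path and constructor follow the files touched above; the tensor shapes and the head dimension of 64 are arbitrary choices for illustration.

# Illustrative sketch only, not part of the diff.
import torch

from xformers.components.positional_embedding.rotary import RotaryEmbedding

device = "cuda" if torch.cuda.is_available() else "cpu"
rotary = RotaryEmbedding(64).to(device)  # 64 = head dimension, arbitrary here

# Inputs shaped (batch, heads, seq, head_dim), in float16
q = torch.randn(2, 4, 128, 64, device=device, dtype=torch.float16)
k = torch.randn(2, 4, 128, 64, device=device, dtype=torch.float16)

# With the dtype-aware cache, the rotated outputs keep the input dtype
q_rot, k_rot = rotary(q, k)
assert q_rot.dtype == torch.float16 and k_rot.dtype == torch.float16

# Feeding float32 afterwards invalidates the cached tables and yields float32
q32_rot, k32_rot = rotary(q.float(), k.float())
assert q32_rot.dtype == torch.float32 and k32_rot.dtype == torch.float32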