diff --git a/rotary_embedding_torch/rotary_embedding_torch.py b/rotary_embedding_torch/rotary_embedding_torch.py index 14cbb60..c89cf80 100644 --- a/rotary_embedding_torch/rotary_embedding_torch.py +++ b/rotary_embedding_torch/rotary_embedding_torch.py @@ -105,8 +105,8 @@ def rotate_queries_and_keys(self, q, k, seq_dim = -2): assert self.use_xpos device, seq_len = q.device, q.shape[seq_dim] seq = torch.arange(seq_len, device = device) - freqs = self.forward(lambda: seq, cache_key = seq_len) - scale = self.get_scale(lambda: seq, cache_key = seq_len) + freqs = self.forward(lambda: seq, cache_key = f'freqs:{seq_len}') + scale = self.get_scale(lambda: seq, cache_key = f'scale:{seq_len}') rotated_q = apply_rotary_emb(freqs, q, scale = scale) rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1) return rotated_q, rotated_k @@ -127,7 +127,7 @@ def get_scale(self, t, cache_key = None): scale = torch.cat((scale, scale), dim = -1) if exists(cache_key): - self.cache[cache_key] = freqs + self.cache[cache_key] = scale return scale diff --git a/setup.py b/setup.py index f074122..e7f2f03 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'rotary-embedding-torch', packages = find_packages(), - version = '0.2.1', + version = '0.2.2', license='MIT', description = 'Rotary Embedding - Pytorch', long_description_content_type = 'text/markdown',