From 67b4fe6186971289de4881000df0c3b6d8c8c8cc Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 5 Oct 2021 09:20:07 -0700 Subject: [PATCH] move freqs to linspace --- README.md | 3 ++- rotary_embedding_torch/rotary_embedding_torch.py | 2 +- setup.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0f2d98c..3314688 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,8 @@ from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding, broadcat pos_emb = RotaryEmbedding( dim = 32, - freqs_for = 'pixel' + freqs_for = 'pixel', + max_freq = 256 ) # queries and keys for frequencies to be rotated into diff --git a/rotary_embedding_torch/rotary_embedding_torch.py b/rotary_embedding_torch/rotary_embedding_torch.py index 38d0353..55d03c5 100644 --- a/rotary_embedding_torch/rotary_embedding_torch.py +++ b/rotary_embedding_torch/rotary_embedding_torch.py @@ -74,7 +74,7 @@ def __init__( elif freqs_for == 'lang': freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) elif freqs_for == 'pixel': - freqs = torch.logspace(0., log(max_freq / 2) / log(2), dim // 2, base = 2) * pi + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi elif freqs_for == 'constant': freqs = torch.ones(num_freqs).float() else: diff --git a/setup.py b/setup.py index 6e48844..f575d2e 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'rotary-embedding-torch', packages = find_packages(), - version = '0.1.0', + version = '0.1.1', license='MIT', description = 'Rotary Embedding - Pytorch', author = 'Phil Wang',