RetNet: relative position #49
Comments
Update: I see now that `rotate_every_two` is implementing a complex rotation over pairs of channels. Still, this does not seem equivalent to what was described in the paper. The same rotation is applied to both `q` and `k`. Then, I suppose the code is leaning on the Euler identity:

```python
# Euler identity
# e ** (i * x) = cos(x) + i * sin(x)
```

We can view each (even, odd) pair of channels as one complex number and the position encoding as multiplication by `cos(x) + i * sin(x)`, which is what `theta_shift` does (its `rotate_every_two` helper is sketched below):

```python
def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)
```
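For context, the `rotate_every_two` helper that `theta_shift` calls looks roughly like this (paraphrased from the GPT-J-style rotary helper, so the repo's exact version may differ slightly):

```python
def rotate_every_two(x):
    x1 = x[:, :, :, ::2]   # "real" parts: even-indexed channels
    x2 = x[:, :, :, 1::2]  # "imaginary" parts: odd-indexed channels
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # interleave back: rearrange(x, '... d j -> ... (d j)')
```

So `x * cos + rotate_every_two(x) * sin` is exactly the real-valued form of multiplying each (even, odd) pair by `cos + i * sin`.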
`theta_shift` effectively increments the complex phase of each channel pair by the position-dependent angle. To my current understanding, this is correct when applied to `q`, but it seems `k` should additionally go through a complex conjugate, something like:
```python
def complex_conjugate(x):
    # Very similar to `rotate_every_two` from earlier
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((x1, -x2), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
```
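As a quick sanity check of what this does under the paired-channel view (even channels as real parts, odd channels as imaginary parts, carrying over the convention from the snippets above): it leaves the real parts alone and negates the imaginary parts.

```python
import torch

x = torch.randn(2, 4, 6, 8)  # (bsz, num_heads, tgt_len, key_dim)
out = complex_conjugate(x)
print(torch.allclose(out[..., ::2], x[..., ::2]))     # True: real parts unchanged
print(torch.allclose(out[..., 1::2], -x[..., 1::2]))  # True: imaginary parts negated
```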
The forward pass would then presumably look like:

```python
class MultiScaleRetention(nn.Module):
    # ...
    def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
        # ...
        qr = theta_shift(q, sin, cos)
        kr = complex_conjugate(theta_shift(k, sin, cos))
```

Finally, it seems that retention does not view `q` and `k` as complex-valued anywhere else, so I do not see where the conjugation would otherwise come from.
Just realized that I can put LaTeX into these comments. Definitely would have made my original question cleaner. I see that the sign of the exponent is flipped relative to how I wrote it above. In that case, shouldn't the expression use the opposite sign? But I guess it's not important: whether the frequency is positive or negative, it's still a waveform with the same magnitude of frequency, and that probably doesn't affect anything noticeably. @sunyt32 Thanks for the response! 🙏
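A toy check of that intuition with plain complex numbers (my own example, treating one channel pair of `q` and `k` as single complex values): flipping the sign of the angle for both `q` and `k` only mirrors the relative phase, so the frequency magnitude is unchanged.

```python
import cmath

theta, n, m = 0.3, 7, 2
q, k = 1.0 + 2.0j, 0.5 - 1.5j  # one channel pair of q and k, viewed as complex numbers

def score(sign, n, m):
    # real part of (rotated q) * conj(rotated k) -- what the real dot product computes
    qn = q * cmath.exp(sign * 1j * n * theta)
    km = k * cmath.exp(sign * 1j * m * theta)
    return (qn * km.conjugate()).real

# flipping the sign of theta for both q and k just swaps the roles of n and m
print(abs(score(-1, n, m) - score(+1, m, n)) < 1e-12)  # True
```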
@fkodom I think this may be due to:
@bin123apple You are right under the complex-number interpretation. In a nutshell, if you treat the paired channels of `q` and `k` as complex numbers, the inner product in `QK^T` already supplies the conjugate on `k`.
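To make that concrete, here is a small sanity check of my own (the helper definitions just mirror the snippets earlier in this thread, so the repo's exact code may differ): applying the *same* `theta_shift` rotation to both `q` and `k` and then taking an ordinary real dot product yields a score that depends only on the relative position `n - m`.

```python
import torch

def rotate_every_two(x):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

torch.manual_seed(0)
dim = 8
# one rotation angle per (even, odd) channel pair, repeated for both channels
theta = torch.rand(dim // 2).repeat_interleave(2)
q, k = torch.randn(dim), torch.randn(dim)

def score(n, m):
    qn = theta_shift(q, torch.sin(n * theta), torch.cos(n * theta))
    km = theta_shift(k, torch.sin(m * theta), torch.cos(m * theta))
    return qn @ km  # plain real dot product, no explicit conjugation

print(torch.allclose(score(5, 2), score(9, 6)))  # True: same n - m
print(torch.allclose(score(5, 2), score(7, 3)))  # False: different n - m
```

In other words, the conjugation from the paper shows up implicitly through the transpose in `QK^T`, which seems to be why the implementation can rotate `q` and `k` with the same sign.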
I believe there is a difference between the relative position implemented here and what is described in the paper. The issue I see is in `theta_shift` and `rotate_every_two`.

You can see here that `theta_shift` is applied to `q` and `k`, which have input shape `(bsz, self.num_heads, tgt_len, self.key_dim)` (after transpose). Why does `rotate_every_two` shuffle elements along the `key_dim` axis? This is not what was described in the paper (Equations 3, 4).

Relative position embedding should depend only on the sequence positions (`m`, `n`) and the `theta` parameters. For that reason, I wonder if `rotate_every_two` is a bug?
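For reference, my reading of the relevant part of the paper (paraphrased from memory, so the exact notation of Equations 3–4 may differ) is roughly:

$$
o_n = \sum_{m=1}^{n} \gamma^{\,n-m} \left(Q_n e^{i n \theta}\right) \left(K_m e^{i m \theta}\right)^{\dagger} v_m ,
$$

where the conjugate transpose on the key term is what produces the relative factor $e^{i (n-m) \theta}$, depending only on $n - m$.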