RetNet: relative position #49

Closed · fkodom opened this issue Aug 2, 2023 · 5 comments

fkodom commented Aug 2, 2023

I believe there is a difference between the relative position implemented here and what is described in the paper. The issue I see is in `theta_shift` and `rotate_every_two`:

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')

# ...

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

You can see here that `theta_shift` is applied to `q` and `k`, which have shape `(bsz, self.num_heads, tgt_len, self.key_dim)` after the transpose.

class MultiScaleRetention(nn.Module):
    # ...
    def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
        bsz, tgt_len, _ = x.size()
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        g = self.g_proj(x)

        k *= self.scaling
        q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
        k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

Why does `rotate_every_two` shuffle elements along the `key_dim` axis? This is not what is described in the paper (Equations 3 and 4):

[Screenshot: Equations (3) and (4) from the RetNet paper]

The relative position embedding should depend only on the sequence positions (m, n) and the theta parameters. For that reason, I wonder whether `rotate_every_two` is a bug?


fkodom commented Aug 2, 2023

Update: I see now that `rotate_every_two` is effectively multiplying by i, if we view an embedding vector of length d as a complex-valued vector of length d // 2, where odd-numbered indices correspond to the imaginary components.
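
A quick sanity check of that interpretation (a sketch of mine using `torch.view_as_complex`, not code from the repo):

```python
import torch

def rotate_every_two(x):
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def as_complex(x):
    # Pair (even, odd) features as (real, imaginary) parts.
    return torch.view_as_complex(x.unflatten(-1, (-1, 2)).contiguous())

x = torch.randn(2, 4, 8, 16)
# rotate_every_two is exactly multiplication by i in the complex view.
assert torch.allclose(as_complex(rotate_every_two(x)), 1j * as_complex(x))
```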

Still, this does not seem equivalent to what is described in the paper. The same `theta_shift` operation is applied to both `q` and `k`, whereas the paper only performs conjugation on `k`.

Then, I suppose that `theta_shift` is applying Euler's formula:

# Euler identity
e ** (i * x) = cos(x) + i * sin(x)

We can view `theta_shift` as multiplying complex-valued `q` by the complex exponential e ** (i * theta):

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

which effectively increments n -> n + 1 in the exponential e ** (i * n * theta).
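
Numerically, that reading checks out (again a sketch, not repo code; the sin/cos layout below, with each angle duplicated across an even/odd feature pair, is my assumption):

```python
import torch

def rotate_every_two(x):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

tgt_len, key_dim = 8, 16
# Assumed layout: one angle per complex pair, repeated for the even/odd slots.
angles = torch.rand(tgt_len, key_dim // 2)
sin = angles.repeat_interleave(2, dim=-1).sin()
cos = angles.repeat_interleave(2, dim=-1).cos()

x = torch.randn(tgt_len, key_dim)
xc = torch.view_as_complex(x.unflatten(-1, (-1, 2)).contiguous())
shifted = theta_shift(x, sin, cos)
shifted_c = torch.view_as_complex(shifted.unflatten(-1, (-1, 2)).contiguous())

# theta_shift == elementwise multiplication by e^(i * angle) in the complex view.
assert torch.allclose(shifted_c, xc * torch.exp(1j * angles), atol=1e-6)
```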

To my current understanding, this is correct when applied to `q`, and almost correct when applied to the conjugate of `k`. We should take the complex conjugate of `k` after applying `theta_shift`:

def complex_conjugate(x):
    # Very similar to `rotate_every_two` from earlier
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((x1, -x2), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')

class MultiScaleRetention(nn.Module):
    # ...
    def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
        # ...
        qr = theta_shift(q, sin, cos)
        kr = complex_conjugate(theta_shift(k, sin, cos))

Finally, it seems that retention does not actually view `q` or `k` as complex-valued vectors, just regular, real-valued embeddings. That explains why methods like `MultiScaleRetention.parallel_forward` don't account for complex values. (TBH, I'm still a little unclear on why that is, but at least it makes the code match the math.)

shumingma assigned shumingma and donglixp and unassigned shumingma Aug 3, 2023

sunyt32 commented Aug 3, 2023

$\mathbb{R}^d$ and $\mathbb{C}^{d/2}$ are isomorphic; we use the real-valued view for simplicity. `complex_conjugate` is not necessary, since $(ke^{-im\theta})^* = k^* e^{im\theta}$.
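
Here is a small sketch (not from the repo) of why the explicit conjugate drops out: rotating both q and k by their own positions and taking an ordinary real dot product already recovers the relative phase $e^{i(n-m)\theta}$.

```python
import torch

def rotate_every_two(x):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

def as_complex(x):
    return torch.view_as_complex(x.unflatten(-1, (-1, 2)).contiguous())

d = 16
theta = torch.rand(d // 2)
q, k = torch.randn(d), torch.randn(d)
n, m = 7, 3  # arbitrary sequence positions

def sincos(pos):
    a = (pos * theta).repeat_interleave(2)
    return a.sin(), a.cos()

qr = theta_shift(q, *sincos(n))
kr = theta_shift(k, *sincos(m))

# Real dot product == Re( q_c * conj(k_c) * e^(i(n-m)theta) ): the relative
# position falls out without ever conjugating k explicitly.
expected = (as_complex(q) * as_complex(k).conj()
            * torch.exp(1j * (n - m) * theta)).sum().real
assert torch.allclose(qr @ kr, expected, atol=1e-5)
```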


fkodom commented Aug 8, 2023

Just realized that I can put LaTeX into these comments. Definitely would have made my original question cleaner.

I see that $\mathbb{C}^{d/2}$ and $\mathbb{R}^d$ are easily interchangeable. Sounds like isomorphism is the technically correct term. 😅

> `complex_conjugate` is not necessary, where $(ke^{-im\theta})^*=k^*e^{im\theta}$

In that case, shouldn't the expression be $\left(K_m e^{-im\theta}\right)^{\dagger} = K_m^* e^{im\theta}$? It just doesn't feel correct to treat $K_m$ as both a real- and a complex-valued tensor within the same expression.

But I guess it's not important. Whether the frequency is positive or negative, it's still a waveform with the same frequency magnitude, and that probably doesn't affect anything noticeably.

@sunyt32 Thanks for the response! 🙏

fkodom closed this as completed Aug 8, 2023
bin123apple commented

@fkodom I think this may be due to $(e^{im\theta})^T = e^{-im\theta}$. Starting from a simple $2 \times 2$ example, with $a = \cos m\theta$ and $b = \sin m\theta$: $(e^{im\theta})^T = \begin{bmatrix} a & -b \\ b & a \end{bmatrix}^T = \begin{bmatrix} a & b \\ -b & a \end{bmatrix} = e^{-im\theta}$. For the $d \times d$ situation, based on relative position embedding papers such as RoFormer, this conclusion should also hold.
So: $o(n) = \sum Q_n(\gamma e^{i\theta})^{n-m}{K_m}^Tv_m $
$= \sum (Q_n\gamma^n e^{in\theta})(\gamma^{-m}e^{-im\theta}{K_m}^T)v_m$
$= \sum (Q_n\gamma^n e^{in\theta})(\gamma^{-m}(e^{im\theta})^T{K_m}^T)v_m $
$= \sum (\gamma^n Q_n e^{in\theta})(\gamma^{-m}(K_m e^{im\theta})^T)v_m $
I think this conclusion is correct and corresponds to the code, but it is obviously not the same as Eq. (3) in the paper.
Then I checked Eq. (3) again, and I think the final form of Eq. (3) should perhaps be $\sum (\gamma^n Q_n e^{in\theta})(\gamma^{-m}(K_m)^Te^{-im\theta})v_m$ instead of $\sum (\gamma^n Q_n e^{in\theta})(\gamma^{-m}(K_me^{-im\theta}))^Tv_m$ (because $(K_m e^{-im\theta})^T$ is expanded there as if $e^{-im\theta}$ and $(K_m)^T$ commute)?
It would be great if the author @sunyt32 could help explain whether my understanding is correct or point out where I'm wrong. And by the way, this work is outstanding!
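
A quick numeric check of the transpose identity in the $2 \times 2$ view (my own snippet, not from the repo):

```python
import math
import torch

def rot(t):
    # 2x2 rotation matrix representing e^(i t).
    return torch.tensor([[math.cos(t), -math.sin(t)],
                         [math.sin(t),  math.cos(t)]])

t = 1.234  # an arbitrary angle m * theta
assert torch.allclose(rot(t).T, rot(-t))  # (e^{i m theta})^T = e^{-i m theta}

# Equivalently, the transpose undoes the rotation:
v = torch.randn(2)
assert torch.allclose(rot(t).T @ (rot(t) @ v), v, atol=1e-6)
```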


sunyt32 commented Oct 7, 2023

@bin123apple You are right under the $2\times 2$ real-number view, which is the same as the implementation in RoFormer. Besides, for the complex view, there is also the implementation in LLaMA, where RoPE is applied by transforming q, k into complex tensors.

In a nutshell, if you treat $Q, K$ as real matrices, you can follow RoFormer. If you treat them as complex matrices, you can follow LLaMA.
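
For reference, a minimal sketch of that complex-view approach (loosely in the style of LLaMA's RoPE; the names and shapes here are assumptions for illustration, not the torchscale code):

```python
import torch

def apply_rope_complex(x, freqs):
    # x: (..., seq_len, dim) real; freqs: (seq_len, dim // 2) complex, e^(i n theta).
    xc = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)).contiguous())
    return torch.view_as_real(xc * freqs).flatten(-2).type_as(x)

seq_len, dim = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # n * theta
freqs = torch.polar(torch.ones_like(angles), angles)           # e^(i n theta)

q = torch.randn(2, seq_len, dim)
k = torch.randn(2, seq_len, dim)
q_rot, k_rot = apply_rope_complex(q, freqs), apply_rope_complex(k, freqs)
```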
