Output inconsistent for autoregressive performer #82

Open
GanjinZero opened this issue Jan 24, 2022 · 2 comments

@GanjinZero

I want to apply the autoregressive Performer for decoding.

import torch
from performer_pytorch import Performer
from torch import nn

attn = Performer(
    dim = 128,
    heads = 4,
    depth = 2,
    local_attn_heads = 2,
    dim_head = 32,
    causal = True,
    use_scalenorm=True
).cuda()
attn.fix_projection_matrices_()  # freeze the random feature projections so the two runs are comparable
attn.eval()

# full sequence of length 900, then the same first 100 positions on their own
x = torch.cat([torch.ones((1,50,128)), torch.ones((1,50,128)) * 0.5, torch.ones((1,400,128)) * (-0.1), torch.ones((1,400,128)) * 0.1], dim=1).cuda()
y = attn(x)

x0 = x[0,0:100].unsqueeze(0)
y0 = attn(x0)

# for a causal model this difference should be (numerically) zero
print((y[0][0:100]-y0[0][0:100]).norm())

The output is tensor(0.0003, device='cuda:0', grad_fn=).

If I turn off use_scalenorm, the output is tensor(0.0085, device='cuda:0', grad_fn=).
This shows that the output of the autoregressive Performer is inconsistent: the outputs for the first 100 positions change depending on the tokens that come after them.
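
For a strictly causal model, running on a prefix should reproduce the full-sequence outputs for that prefix up to floating point noise. A minimal sketch of that invariant check, reusing y and y0 from the script above (the tolerance value is an arbitrary assumption, not something from the library):

# the prefix outputs should not depend on the 800 positions that follow them;
# the 1e-6 tolerance is an arbitrary choice for fp32
print(torch.allclose(y[:, :100], y0, atol=1e-6))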

@GanjinZero (Author)

I guess the cause is k = create_kernel(k, is_query = False) in FastAttention.forward: the softmax_kernel operation contains the line data_dash = ratio * (torch.exp(data_dash - diag_data - torch.amax(data_dash, dim=(-1, -2), keepdim=True)) + eps).
Here torch.amax(data_dash, dim=(-1, -2), keepdim=True) takes the maximum over the whole sequence, so it contains information from later time steps, and that information leaks into the hidden states of earlier time steps.
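
A minimal illustration of the leak (not the library's code; shapes and values are made up): the stabilizing max computed over the full sequence generally differs from the one computed over a prefix, so the kernel features of early positions change once later positions are appended.

import torch

torch.manual_seed(0)
# (batch, heads, seq, features) -- illustrative shapes, not the library's internals
data_dash = torch.randn(1, 1, 900, 32)

full_max   = torch.amax(data_dash, dim=(-1, -2), keepdim=True)                 # max over all 900 positions
prefix_max = torch.amax(data_dash[..., :100, :], dim=(-1, -2), keepdim=True)   # max over the first 100 only

# the two stabilizers differ, so exp(data_dash - diag_data - max) scales
# positions 0..99 differently in the full-sequence and prefix-only runs
print(full_max.item(), prefix_max.item())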

@lucidrains (Owner)

@GanjinZero oh shoot, yea, those maxes are for numerical stability, but i think they should be detached: fc8b784. Can you let me know if this resolves the issue on your end?
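
For readers following along, the idea of the fix is to treat the stabilizer as a constant by detaching it from the autograd graph. A rough sketch of that change, with dummy tensors standing in for the real softmax_kernel variables (the actual diff in fc8b784 may look different):

import torch

# dummy stand-ins for the variables in the quoted softmax_kernel line
data_dash = torch.randn(1, 1, 10, 16, requires_grad=True)
diag_data = torch.randn(1, 1, 10, 1)
ratio, eps = 16 ** -0.5, 1e-4

# detach the max so no gradient flows through the numerical-stability term
stabilizer = torch.amax(data_dash, dim=(-1, -2), keepdim=True).detach()
data_prime = ratio * (torch.exp(data_dash - diag_data - stabilizer) + eps)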
