[flashv2/BW] nan in some configurations #443

Closed
danthe3rd opened this issue Aug 11, 2023 · 7 comments

danthe3rd (Contributor) commented Aug 11, 2023

Hi, after upgrading Flash v2 to 2.0.4 (in facebookresearch/xformers#816), we still see some test failures in xformers. Here is a simple repro.

Repro code:

# Tested with
# flash-attn 2.0.4 (v2.0.4 - d30f2e1cd50185c98ed88c0684b4a603f15bee37)
# torch==2.0.0
# NVIDIA A100-SXM4-80GB
# cuda 11.8
import torch
import flash_attn


q_cuseqlen = torch.tensor([0, 76, 110, 256], device='cuda', dtype=torch.int32)  # 3 query sequences (lengths 76, 34, 146)
k_cuseqlen = torch.tensor([0, 1, 2, 3], device='cuda', dtype=torch.int32)  # 3 key sequences of 1 key each
Mq = 256  # total queries across sequences (also passed as max_seqlen_q below)
Mk = 3  # total keys across sequences (also passed as max_seqlen_k below)
H = 1  # number of heads
K = 32  # head dimension

torch.manual_seed(0)
q = torch.randn([Mq, H, K], dtype=torch.float16, device="cuda") * 3
k, v = [torch.randn([Mk, H, K], dtype=torch.float16, device="cuda") * 3 for _ in range(2)]
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)

grad = torch.full_like(q, 1.0)

out = flash_attn.flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=True)  # causal=True is what triggers the bug (see below)
out.backward(grad)

print("flash_attn", flash_attn.__version__)
print("Q gradient:", "NaNs!" if q.grad.isnan().any().item() else "OK")
print("K gradient:", "NaNs!" if k.grad.isnan().any().item() else "OK")
print("V gradient:", "NaNs!" if v.grad.isnan().any().item() else "OK")

Output:

flash_attn 2.0.4
Q gradient: NaNs!
K gradient: OK
V gradient: OK

tmm1 (Contributor) commented Aug 11, 2023

I replicated the results above on a 3090 as well. The result is the same when changing float16 to bfloat16.
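
For reference, a minimal sketch of that variant, assuming the setup from the repro script above (only the dtype changes):

# Hedged sketch: same repro with bfloat16 instead of float16; per the
# report above, q.grad still comes back with NaNs.
q = torch.randn([Mq, H, K], dtype=torch.bfloat16, device="cuda") * 3
k, v = [torch.randn([Mk, H, K], dtype=torch.bfloat16, device="cuda") * 3 for _ in range(2)]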

tridao (Contributor) commented Aug 11, 2023

Thanks for the bug report.
Something I don't quite understand: the cuseqlen says there are 2 sequences, one from index 0 to 46 and one from 46 to 256. However, the K and V passed in only have length 2, so they don't agree with what cuseqlen describes.

When I changed K and V to have length 256, the gradients were OK.

Do you mean to pass in a different cuseqlen for K and V?
When I pass in a different cuseqlen_k = torch.tensor([0, 1, 2], device=device, dtype=torch.int32), the gradients are also OK.
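
In other words (a hedged reading of the comment above), each cu_seqlens tensor must be consistent with the packed tensor it describes. A quick sanity check, reusing the names from the repro script:

# Hedged sketch: invariants the varlen interface appears to expect
# (my reading of the comment above, not an official check from the library).
assert q_cuseqlen[-1].item() == q.shape[0]  # last entry == total number of queries
assert k_cuseqlen[-1].item() == k.shape[0]  # last entry == total number of keys
assert len(q_cuseqlen) == len(k_cuseqlen)   # same number of sequences for q and k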

danthe3rd (Contributor, Author) commented

Whoops, my bad indeed. Let me close this and reopen once I figure out my issue.

danthe3rd (Contributor, Author) commented Aug 16, 2023

Reopening: I fixed the repro script above.
The issue only happens with causal=True (although in this case, with 1 key per sequence, it is equivalent to setting causal=False); see the sketch below.
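
A quick way to see this, appended to the end of the repro script (a hedged sketch, not part of the original report):

# Hedged sketch: rerun the same call with causal=False; per the comment
# above, the gradients then come out clean.
q.grad = k.grad = v.grad = None  # reset the gradients from the first run
out = flash_attn.flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=False)
out.backward(grad)
print("Q gradient (causal=False):", "NaNs!" if q.grad.isnan().any().item() else "OK")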

danthe3rd reopened this Aug 16, 2023
tridao (Contributor) commented Aug 16, 2023

I can reproduce the bug now, thank you @danthe3rd! I'm investigating.

tridao (Contributor) commented Aug 16, 2023

I've (hopefully) fixed this in v2.0.8. CI is building all the CUDA wheels now. Thanks again for the bug report!
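
To confirm the fix is picked up locally, a minimal sketch (assuming the packaging module is available; the version bound is the v2.0.8 mentioned above):

# Hedged sketch: check that the installed flash-attn is at least v2.0.8,
# the release said to contain the fix.
from packaging import version
import flash_attn

assert version.parse(flash_attn.__version__) >= version.parse("2.0.8")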

danthe3rd (Contributor, Author) commented

Confirming that all xformers tests pass now on A100 :)
Thanks a lot for the prompt fix!
