[flashv2] NaNs in bw pass for some inputs #334
Thanks for pushing Flash-Attention v2! The speedups are really huge and this will make so many workloads much faster!

I would like to update to Flashv2 in xformers, however we have a few tests failing due to NaNs in the BW pass, and I have managed to isolate a minimal repro here. Is it possible to have a look? Thanks a lot!

Repro code

Output
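The repro code and output above are collapsed and not captured here. As an illustration only, a minimal sketch that exercises the failure mode described (NaNs in the backward pass) via the public `flash_attn_func` API might look like the following; the shapes and the sequence length of 200 are assumptions, not taken from the original repro:

```python
# Not the original repro (that code is collapsed above); a minimal sketch:
# run FlashAttention-2 forward + backward with a sequence length that is
# NOT a multiple of 128, then check the input gradients for NaNs.
import torch
from flash_attn import flash_attn_func  # flash-attn v2.x

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 2, 200, 4, 64  # 200 % 128 != 0 (assumed shapes)
q, k, v = (
    torch.randn(batch, seqlen, nheads, headdim,
                device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)

out = flash_attn_func(q, k, v)  # output: (batch, seqlen, nheads, headdim)
out.sum().backward()

# On affected versions, q.grad in particular came back with NaNs.
for name, t in (("q", q), ("k", k), ("v", v)):
    print(f"{name}.grad has NaN: {torch.isnan(t.grad).any().item()}")
```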
Comments

Thanks for the bug report, I'll take a look today.
Hi, I hit the same problem today. It seems like only q gets NaN gradients.
Same here when using FlashAttention-2 to train a ViT.
Yes, this is an issue when the sequence length is not divisible by 128. I figured out the problem and I'll try to push the fix soon (traveling right now).
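To check this divisibility pattern on a given install, a sketch along these lines could be used (the dimensions and the set of lengths here are assumptions, not from the thread):

```python
# Sketch (assumed shapes): sweep sequence lengths around multiples of 128 to
# see whether only lengths not divisible by 128 produce NaN gradients.
import torch
from flash_attn import flash_attn_func

for seqlen in (127, 128, 129, 200, 256, 300):
    q, k, v = (
        torch.randn(1, seqlen, 4, 64, device="cuda",
                    dtype=torch.float16, requires_grad=True)
        for _ in range(3)
    )
    flash_attn_func(q, k, v).sum().backward()
    print(f"seqlen={seqlen:4d} divisible_by_128={seqlen % 128 == 0} "
          f"q_grad_nan={torch.isnan(q.grad).any().item()}")
```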
v2.0.3 should hopefully fix this problem. Can you try the latest version?
Tried and it works now.
Fixed by a4f148b and can be closed.