[flashv2] NaNs in bw pass for some inputs #334

Closed
danthe3rd opened this issue Jul 18, 2023 · 7 comments

@danthe3rd
Contributor

danthe3rd commented Jul 18, 2023

Thanks for pushing Flash-Attention v2! The speedups are really huge and this will make so many workloads much faster!
I would like to update xformers to Flash-Attention v2, but a few of our tests fail because of NaNs in the backward pass. I have isolated a minimal repro below. Could you take a look?
Thanks a lot!

Repro code

# Tested with
# flash-attn-2.0.0.post1 (d1a3b52f17b914c93bf740654387b566a7330687)
# torch==2.0.0
# NVIDIA A100-SXM4-80GB
# cuda 11.8
import torch
import flash_attn

torch.manual_seed(352)
q = torch.randn([1, 1, 1, 16], dtype=torch.float16, device="cuda") * 3
k, v = [torch.randn([1, 1, 1, 16], dtype=torch.float16, device="cuda") * 3 for _ in range(2)]
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)

grad = torch.full_like(q, 1.0)

out = flash_attn.flash_attn_func(q, k, v)
out.backward(grad)

print("Q gradient:", "NaNs!" if q.grad.isnan().any().item() else "OK")
print("K gradient:", "NaNs!" if k.grad.isnan().any().item() else "OK")
print("V gradient:", "NaNs!" if v.grad.isnan().any().item() else "OK")
print("Q: ", q.tolist())
print("K: ", k.tolist())
print("V: ", v.tolist())

Output

Q gradient: NaNs!
K gradient: OK
V gradient: OK
Q:  [[[[-2.705078125, 2.4375, 4.84375, 1.771484375, 1.064453125, 6.9296875, -3.859375, -4.0625, 2.37890625, -4.2578125, -1.05078125, 0.8740234375, 0.45166015625, 0.9140625, 9.0546875, -0.66259765625]]]]
K:  [[[[6.296875, 3.10546875, -1.3349609375, -0.03778076171875, 2.873046875, 0.081787109375, 1.984375, -0.85986328125, 0.028167724609375, 6.8984375, 4.2421875, -2.48828125, 0.2978515625, 3.55078125, -2.533203125, -3.24609375]]]]
V:  [[[[3.48046875, -3.859375, 2.8203125, -1.25, -2.10546875, 3.3984375, 1.28515625, -4.29296875, -1.08984375, 1.294921875, 1.2265625, 0.81005859375, -1.076171875, 0.45361328125, 1.958984375, 0.1907958984375]]]]
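
For reference, the finite gradient expected in this repro can be worked out by hand. With a single key, the softmax weight is exactly 1 and the output equals V, so the gradient with respect to Q should be all zeros. The sketch below is plain attention math in pure Python, not the FlashAttention kernel. The function name `attention_grad_q` and the short input vectors are illustrative, and the 1/sqrt(headdim) scaling is assumed.

```python
import math


def attention_grad_q(q, keys, values, grad_out):
    """Reference softmax attention for one query: returns (probs, dQ)."""
    d = len(q)
    scale = 1.0 / math.sqrt(d)  # assumed default softmax scaling
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    probs = [e / z for e in exps]
    # dP_j = grad_out . v_j
    dprobs = [sum(g * vj for g, vj in zip(grad_out, v)) for v in values]
    # Softmax backward: dS_j = P_j * (dP_j - sum_i P_i * dP_i)
    dot = sum(p * dp for p, dp in zip(probs, dprobs))
    dscores = [p * (dp - dot) for p, dp in zip(probs, dprobs)]
    # dQ = scale * sum_j dS_j * k_j
    dq = [scale * sum(ds * k[i] for ds, k in zip(dscores, keys)) for i in range(d)]
    return probs, dq


# Illustrative values only (not the tensors printed above); one key, as in the repro.
q = [-2.7, 2.4, 4.8, 1.8]
keys = [[6.3, 3.1, -1.3, -0.04]]
values = [[3.5, -3.9, 2.8, -1.25]]
grad_out = [1.0, 1.0, 1.0, 1.0]

probs, dq = attention_grad_q(q, keys, values, grad_out)
print("softmax weights:", probs)
print("dQ:", dq)
```

Any NaN in `q.grad` for a sequence length of 1 is therefore a kernel artifact rather than a property of the inputs.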
@tridao
Contributor

tridao commented Jul 18, 2023

Thanks for the bug report, I'll take a look today.

@zihaozou

Hi, I ran into the same problem today. It seems that only q gets NaN gradients.

@simonJJJ
Contributor

Same here when using FlashAttention-2 to train a ViT.

@tridao
Contributor

tridao commented Jul 25, 2023

Yes, this is an issue when the sequence length is not divisible by 128. I figured out the problem and I'll try to push the fix soon (traveling right now).
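
One way to probe the sequence-length dependence is to rerun the repro at lengths around multiples of 128. The sketch below is only that, a sketch: `seqlens_around_block` and `q_grad_has_nan` are hypothetical helper names, and the second function reuses the `flash_attn.flash_attn_func(q, k, v)` call and tensor layout from the repro above. It needs CUDA and flash-attn installed, so it imports them lazily.

```python
def seqlens_around_block(block=128, num_blocks=2):
    """Lengths just below, at, and just above each multiple of `block`, plus 1."""
    lengths = {1}
    for i in range(1, num_blocks + 1):
        lengths.update({i * block - 1, i * block, i * block + 1})
    return sorted(lengths)


def q_grad_has_nan(seqlen, headdim=16, seed=352):
    """Run the repro at one sequence length; True if q.grad contains NaNs."""
    import torch  # lazy imports: only needed on a machine with CUDA + flash-attn
    import flash_attn

    torch.manual_seed(seed)
    shape = [1, seqlen, 1, headdim]  # same layout as the repro, seqlen in dim 1
    q, k, v = [
        (torch.randn(shape, dtype=torch.float16, device="cuda") * 3).requires_grad_(True)
        for _ in range(3)
    ]
    out = flash_attn.flash_attn_func(q, k, v)
    out.backward(torch.full_like(q, 1.0))
    return q.grad.isnan().any().item()


# Lengths that are not multiples of 128 are the ones to watch.
suspects = [n for n in seqlens_around_block() if n % 128 != 0]
print("lengths to check:", suspects)
```

On a GPU machine, `{n: q_grad_has_nan(n) for n in seqlens_around_block()}` would give a per-length summary. Whether a given length actually fails also depends on the random inputs, so a single seed is not conclusive.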

@tridao
Contributor

tridao commented Aug 1, 2023

v2.0.3 should hopefully fix this problem. Can you guys try the latest version (v2.0.3)?
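
For downstream test suites, a small version guard can skip or flag the affected releases. This is a minimal sketch, assuming the fix first shipped in v2.0.3 as stated above. `release_tuple`, `has_fix`, and `installed_has_fix` are hypothetical helper names, and the parsing is deliberately simple (leading numeric components only) rather than a full version parser.

```python
FIXED_IN = (2, 0, 3)  # assumption: taken from the comment above, not verified here


def release_tuple(version):
    """Leading numeric components of a version string, e.g. '2.0.0.post1' -> (2, 0, 0)."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break  # stop at suffixes such as 'post1' or 'dev0'
        parts.append(int(piece))
    return tuple(parts[:3])


def has_fix(version, fixed_in=FIXED_IN):
    """True if `version` is at or after the release assumed to contain the fix."""
    return release_tuple(version) >= fixed_in


def installed_has_fix():
    """Check the installed package; raises PackageNotFoundError if it is missing."""
    from importlib.metadata import version

    return has_fix(version("flash-attn"))


print(has_fix("2.0.0.post1"), has_fix("2.0.3"))
```

The version from the original report, `2.0.0.post1`, parses to `(2, 0, 0)` and is treated as affected.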

@WindowsXp-Beta

Tried and it works now.

@tmm1
Contributor

tmm1 commented Aug 3, 2023

Fixed by a4f148b and can be closed.


6 participants