Add Flash Attention v2 to Ops #1970
Conversation
It seems like the performance decreased significantly. Not sure why 🤔 It may be worth comparing the TTIR and TTGIR of the tutorial vs. what the ops generate.
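For comparing the IRs, here is a minimal sketch of one way to dump them (assuming a Triton 2.x build where launching a @triton.jit kernel returns a handle exposing the lowered IRs via `.asm`; the kernel names in the usage comment are placeholders):

```python
def dump_ir(compiled_kernel, tag):
    # .asm maps lowering stage -> text, e.g. "ttir", "ttgir", "llir", "ptx"
    # (exact keys are version-dependent).
    for stage in ("ttir", "ttgir"):
        with open(f"{tag}.{stage}", "w") as f:
            f.write(compiled_kernel.asm[stage])

# Hypothetical usage: launch the tutorial kernel and the ops kernel with
# identical shapes/configs, keep the returned handles, then diff the files:
#   h_tut = _attn_fwd_tutorial[grid](...); h_ops = _attn_fwd_ops[grid](...)
#   dump_ir(h_tut, "tutorial"); dump_ir(h_ops, "ops")
#   $ diff tutorial.ttgir ops.ttgir
```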
I think the main change in the forward was the switch back from the separate MODEs to handling CAUSAL in one pass. Converting MODE==3 to FA v2 might require a little work. I'll take a look at the TTIR/TTGIR as well as the MODEs.
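For reference, here is a minimal pure-PyTorch sketch (not the actual kernel in this PR) of the single-pass causal structure in FA v2: off-diagonal K/V blocks are processed without a mask, and only the diagonal block applies the causal mask, instead of dispatching on a runtime MODE. Float32 inputs of shape (seq, d) are assumed.

```python
import torch

def _online_step(qb, kb, vb, acc, m, l, mask=None):
    # One online-softmax update over a single K/V block.
    s = qb @ kb.T
    if mask is not None:
        s = s.masked_fill(~mask, float("-inf"))
    m_new = torch.maximum(m, s.max(dim=1).values)
    p = torch.exp(s - m_new[:, None])
    alpha = torch.exp(m - m_new)          # rescale previous accumulators
    l = l * alpha + p.sum(dim=1)
    acc = acc * alpha[:, None] + p @ vb
    return acc, m_new, l

def causal_attn_blockwise(q, k, v, BLOCK=64):
    seq, d = q.shape
    out = torch.empty_like(q)
    for qs in range(0, seq, BLOCK):
        qb = q[qs:qs + BLOCK] * d ** -0.5
        m = torch.full((qb.shape[0],), float("-inf"))
        l = torch.zeros(qb.shape[0])
        acc = torch.zeros(qb.shape[0], d)
        # Stage 1: blocks strictly below the diagonal -> no mask needed.
        for ks in range(0, qs, BLOCK):
            acc, m, l = _online_step(qb, k[ks:ks + BLOCK], v[ks:ks + BLOCK], acc, m, l)
        # Stage 2: the diagonal block -> apply the causal mask.
        kb, vb = k[qs:qs + BLOCK], v[qs:qs + BLOCK]
        mask = torch.arange(qb.shape[0])[:, None] >= torch.arange(kb.shape[0])[None, :]
        acc, m, l = _online_step(qb, kb, vb, acc, m, l, mask)
        out[qs:qs + BLOCK] = acc / l[:, None]
    return out
```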
Eh, I think this PR forgets to update the block size and number of stages.
Yep, you are correct. I adjusted the block sizes and stages; locally on my A6000, forward perf is much improved.
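For reference, a minimal sketch of exposing those knobs through triton.autotune (the tile sizes, stage counts, and kernel/argument names here are hypothetical, not the values used in this PR):

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=4, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 32}, num_stages=5, num_warps=4),
    ],
    key=["N_CTX"],  # re-tune when the sequence length changes
)
@triton.jit
def _attn_fwd(Q, K, V, Out, N_CTX, BLOCK_DMODEL: tl.constexpr,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    ...  # kernel body elided; same forward kernel, tile sizes now tuned per N_CTX
```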
Updated the forward-pass numbers for FA based on CI.
I also dropped do_scaled as it is no longer needed (no scaling is applied to do in v2).

Co-authored-by: Philippe Tillet <phil@openai.com>
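A minimal sketch (hypothetical argument names, contiguous row-major layout assumed, not the exact kernel in this PR) of the v2-style backward preprocess, which only computes delta = rowsum(dO ∘ O) and no longer writes a rescaled dO:

```python
import triton
import triton.language as tl

@triton.jit
def _bwd_preprocess(O, DO, Delta, D_HEAD: tl.constexpr, BLOCK_M: tl.constexpr):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_d = tl.arange(0, D_HEAD)
    # Load a BLOCK_M x D_HEAD tile of the output and its incoming gradient.
    o = tl.load(O + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
    # delta_i = sum_d o[i, d] * do[i, d]; consumed later when forming dS.
    delta = tl.sum(o * do, axis=1)
    tl.store(Delta + off_m, delta)
```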