Add Flash Attention v2 to Ops #1970

Merged: 6 commits into triton-lang:main on Jul 23, 2023

Conversation

@IzzyPutterman (Contributor)

I also dropped do_scaled, as it is no longer needed (no scaling is applied to dO in v2).
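For context, here is a minimal sketch (not the PR's code; the kernel and argument names are illustrative) of why the separate dO rescaling disappears: in a v1-style backward, dO was divided by the stored softmax denominators in a preprocess kernel, whereas a v2-style forward already writes a normalized O and a per-row logsumexp, so the preprocess only needs the row-wise sums delta = rowsum(O * dO).

```python
import triton
import triton.language as tl

@triton.jit
def _bwd_preprocess_sketch(O, DO, Delta, stride_row, N_CTX,
                           BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr):
    # One program handles BLOCK_M rows of O/dO for a single (batch, head) slice.
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_d = tl.arange(0, D_HEAD)
    mask = off_m < N_CTX
    ptrs = off_m[:, None] * stride_row + off_d[None, :]
    o = tl.load(O + ptrs, mask=mask[:, None], other=0.0).to(tl.float32)
    do = tl.load(DO + ptrs, mask=mask[:, None], other=0.0).to(tl.float32)
    # No division of dO by the softmax denominator here -- that was the v1-only step.
    delta = tl.sum(o * do, axis=1)   # delta_i = sum_d O[i, d] * dO[i, d]
    tl.store(Delta + off_m, delta, mask=mask)
```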

@ptillet (Collaborator) commented Jul 20, 2023

It seems like the performance significantly decreased; not sure why 🤔. It may be worth comparing the TTIR and TTGIR of the tutorial vs. what the ops generate.
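One way to do that comparison, assuming the compiled-kernel handle returned by a Triton 2.x launch exposes an `asm` dict with `ttir`/`ttgir` entries as in the tutorials of that period (`attn_fwd_kernel`, `grid`, and `args` are placeholders):

```python
# Hedged sketch: dump both IRs for the ops kernel, do the same for the tutorial
# kernel, then diff the files to spot scheduling/layout differences.
handle = attn_fwd_kernel[grid](*args)          # placeholder launch
for stage in ("ttir", "ttgir"):
    with open(f"ops_fwd.{stage}", "w") as f:
        f.write(handle.asm[stage])
# then e.g. `diff ops_fwd.ttgir tutorial_fwd.ttgir`
```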

@IzzyPutterman (Contributor Author)

I think the main change in the forward pass was the switch back from the MODEs to CAUSAL in one pass. Converting MODE == 3 to FA v2 might require a little work.

I'll take a look at the TTIR/TTGIR as well as the MODEs.
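To illustrate the difference (a pure-Python sketch; the helper name and parameters are placeholders, not the PR's code): a single-pass causal forward simply bounds the K/V loop per query block instead of dispatching on a MODE value, and only the tiles on the diagonal need an explicit causal mask.

```python
def kv_tile_starts(query_block_idx: int, n_ctx: int,
                   BLOCK_M: int, BLOCK_N: int, is_causal: bool) -> list[int]:
    """Start columns of the K/V tiles one query block visits in a single pass."""
    hi = min(n_ctx, (query_block_idx + 1) * BLOCK_M) if is_causal else n_ctx
    return list(range(0, hi, BLOCK_N))

# A causal query block only visits keys up to its own diagonal:
print(kv_tile_starts(query_block_idx=2, n_ctx=1024, BLOCK_M=128, BLOCK_N=64, is_causal=True))
# [0, 64, 128, 192, 256, 320]
```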

@ptillet (Collaborator) commented Jul 20, 2023

Eh, I think this PR forgets to update the block size and num_stages for the fwd pass.

@IzzyPutterman (Contributor Author)

Yep, you are correct. I adjusted the blocks and stages; locally on my A6000 the forward perf is much improved.
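What such an adjustment typically looks like (a hedged sketch; these config values are illustrative, not the ones landed in this PR):

```python
import triton

# v2-style forward kernels generally favor larger query tiles and more
# software-pipelining stages than the v1 defaults.
fwd_configs = [
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=4, num_warps=4),
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=3, num_warps=8),
]
# e.g. applied with @triton.autotune(configs=fwd_configs, key=["N_CTX"]) on the fwd kernel
```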

@IzzyPutterman (Contributor Author)

Updated the forward-pass numbers for FA based on CI.

ptillet merged commit de6f053 into triton-lang:main on Jul 23, 2023
5 of 6 checks passed
pingzhuu pushed a commit to siliconflow/triton that referenced this pull request on Apr 2, 2024

Co-authored-by: Philippe Tillet <phil@openai.com>