Flash Attention 2 #795

Closed
patrickvonplaten opened this issue Jul 17, 2023 · 25 comments
@patrickvonplaten

🚀 Feature

Adding Flash Attention 2

Motivation

Flash Attention 2 has just been added to the original repo: https://github.com/Dao-AILab/flash-attention. It claims to be almost twice as fast as Flash Attention 1, which is a huge speed-up. How can we best add it to xformers?

Pitch

Flash Attention 2 is very fast at pretty much no extra cost

Alternatives

N/A

Additional context

Many libraries depend on xformers to run flash attention. It would be great to add it here.

@patrickvonplaten changed the title from Adding Flash Attention 2 to Flash Attention 2 on Jul 17, 2023
@Skylion007
Contributor

First step would be updating the flash_attention submodule. All the heuristics would probably need to be changed to prefer it until we get implementations in CUTLASS and Triton.

@danthe3rd
Contributor

Hey,
Thanks for opening this issue! I want to run some benchmarks and testing first, and hopefully we can have it in xFormers this week.

@Skylion007
Contributor

Clarification: FlashAttention2 actually uses CUTLASS.

@danthe3rd
Contributor

I have an initial prototype working, but I'm hitting some NaNs in Flash-Attention. I've opened an issue: Dao-AILab/flash-attention#334
We get great speedups across the board, I'll share some benchmarks soon :)

@danthe3rd
Contributor

Here are some benchmarks for the FW pass on A100:
https://pastebin.com/YEApkXBM

@lucidrains

lucidrains commented Jul 18, 2023

can we expect this to be upstreamed to pytorch 2.0's scaled_dot_product_attention? or should we open a separate issue?

@danthe3rd
Contributor

Hi @lucidrains ,
I would imagine this will be done at some point in the future, but it would be best to ask with an issue in the pytorch repo directly. Cc @drisspg

@lucidrains

ok sounds good, will do!

@Boreaso

Boreaso commented Jul 19, 2023

> Hi @lucidrains , I would imagine this will be done at some point in the future, but it would be best to ask with an issue in the pytorch repo directly. Cc @drisspg

Hi @danthe3rd, V100 GPUs are not currently supported in FlashAttention2. Any plans to support them in xformers?

@danthe3rd
Contributor

> Hi @danthe3rd, V100 GPUs are not currently supported in FlashAttention2. Any plans to support them in xformers?

V100 GPUs are already supported in xformers (we have our own reimplementation of Flash-Attention)

@danthe3rd
Contributor

As we update to Flash v2, Flash won't be available to Windows users, since Flash v2 is not available on Windows (Dao-AILab/flash-attention#345). We will fall back to our own reimplementation for Windows users.
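
For reference, here is a minimal sketch (mine, not from this thread) of explicitly requesting xFormers' own CUTLASS kernel, the fallback used on V100 and Windows; the `op=` argument and the `xops.fmha.cutlass` ops are assumptions based on the public `xformers.ops` API at the time:

```python
# Minimal sketch, assuming the public xformers.ops API with explicit op selection.
import torch
import xformers.ops as xops

# [batch, seqlen, heads, head_dim] layout expected by memory_efficient_attention.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

# Force xFormers' own CUTLASS forward/backward kernels (the V100/Windows
# fallback) instead of letting the dispatcher pick Flash-Attention.
out = xops.memory_efficient_attention(
    q, k, v,
    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp),
)
```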

@lucidrains

lucidrains commented Jul 19, 2023

@danthe3rd you'll probably update your in-house kernel to flash attention 2 though? which kernel is being used to train llama?

@danthe3rd
Contributor

I don't plan to update the in-house kernel anymore. There are a few things from Flash v2 which are already in there, but further work would be needed to get the full performance. Also, some changes won't work well within the available CUTLASS v2 abstractions we are using.
CUTLASS plans to add support for Windows within months though...

@lucidrains

@danthe3rd ahh ok, so Tri's implementation will be the best available, for the right hardware

thanks for clarifying!

@bhack

bhack commented Jul 20, 2023

triton-lang/triton#1970

@danthe3rd
Contributor

We also plan to update the Triton version in xFormers at a later stage, but for now we are focusing on the CUDA one from @tridao as it provides the best performance.

@bhack

bhack commented Jul 20, 2023

Just in case you are interested, there is also a parallel effort in the official PyTorch repo:
pytorch/pytorch#105474

@danthe3rd self-assigned this Jul 20, 2023
@danthe3rd
Contributor

danthe3rd commented Jul 20, 2023

We just merged initial support for Flash-Attention v2 in xformers:
cfea89f

Wheels will be available shortly, but in the meantime you can build it from source.

Summary:

  • Update third-party/flash-attention to the new repo/version
  • It's not available on Windows (where we will fall back on our home-made CUTLASS kernel, which is mostly as fast as Flash v1): [Flashv2] Windows support Dao-AILab/flash-attention#345
  • Limited to A100+ (whereas the previous Flash v1 worked with Sm75 as well)
  • Currently the BW pass only works when seqlen % 128 == 0 ([flashv2] NaNs in bw pass for some inputs Dao-AILab/flash-attention#334). We will update the third-party module once Tri has a fix
  • NOTE: Currently the BW pass is not deterministic. A fix will come later from Tri
  • TESTING: Because we only support a subset of the test cases, I made sure that when generating random shapes we filter out the incompatible ones. This ensures we always have 20 random shapes that are tested (see shape_not_supported_reasons)
  • DISPATCH: We now always dispatch to Flash with priority 1 (see the usage sketch below)
  • SUPPORT: This adds support for head dimensions up to 256 (although performance is much worse beyond 160)
  • TRITON: This disables the fMHA Triton implementation; I'm preparing an upgrade (cc @dianaml0). That's because it's imported directly from the Flash-Attention repo that we embed. In any case, we need an implementation that supports the post-MLIR rewrite of Triton
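
As a usage sketch (not part of the commit above, and based on my reading of the public `xformers.ops` API), this is how the new dispatch looks from the caller's side; the explicit `op=` selection via `xops.fmha.flash` is an assumption:

```python
# Minimal sketch, assuming the public xformers.ops API; shapes chosen to satisfy
# the current Flash v2 constraints mentioned above (e.g. seqlen % 128 == 0).
import torch
import xformers.ops as xops

q = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.float16)
v = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.float16)

# Default path: the dispatcher now prefers Flash-Attention v2 whenever the
# GPU (A100+), dtype, and shapes allow it.
out = xops.memory_efficient_attention(q, k, v)

# Explicit path: request the Flash ops directly; this errors out where Flash v2
# is unavailable (pre-A100 hardware, Windows).
out = xops.memory_efficient_attention(
    q, k, v,
    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
)
```

If I remember correctly, `python -m xformers.info` lists which fMHA ops were built and are usable on the current machine, which is a quick way to check whether the Flash v2 path is active.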

@ekagra-ranjan

> Here are some benchmarks for the FW pass on A100:
> https://pastebin.com/YEApkXBM

@danthe3rd Can you please share the schema needed to compare different rows in the table you shared?

  1. Like, what are the different numbers here: f16 1-16384-16-80, p=0.0, BiasT=NoneType?
  2. Which column represents Flash v1: is it cutlassF?
  3. Do you have any ideas as to why BiasT=Tensor takes 50% more time than BiasT=NoneType for cutlassF?

@WindowsXp-Beta

@ekagra-ranjan

  1. You can take a look at the benchmark code here. The output format is {dtype} {B}-{M}-{H}-{K}, where B, M, H, and K stand for batch size, sequence length, number of heads, and head dimension, respectively (see the sketch after this comment).
  2. cutlassF is xFormers' own reimplementation of Flash-Attention using CUTLASS (not Flash v1 itself).
  3. Great question. I'm also trying to figure it out; you can refer to Flash v2's blog and paper for more information. In fact, xFormers has already applied several of these optimizations, such as parallelizing over Q's seqlen in the FW pass and over K/V's seqlen in the BW pass. However, there are some differences. I believe their warp-partition policy plays an important role, as it can reduce shared-memory IO and warp synchronization significantly. Additionally, if you examine their code, you'll notice that they use three kernels for the BW pass while xFormers fuses these operations into one kernel. Moreover, xFormers uses CUTLASS 2.x while Flash v2 uses CUTLASS 3.x.

Since I'm also new to CUTLASS and CUDA programming, perhaps @danthe3rd can provide us with more insights?
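
To make the label concrete, here is a small sketch (my reading of the schema above, not the benchmark script itself) of the inputs that the row f16 1-16384-16-80, p=0.0, BiasT=NoneType would correspond to:

```python
# Sketch only: reconstructing the inputs behind "f16 1-16384-16-80, p=0.0,
# BiasT=NoneType" from the {dtype} {B}-{M}-{H}-{K} schema described above.
import torch
import xformers.ops as xops

B, M, H, K = 1, 16384, 16, 80  # batch, seqlen, number of heads, head dim
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# p is the dropout probability; BiasT=NoneType means no attn_bias is passed.
out = xops.memory_efficient_attention(q, k, v, attn_bias=None, p=0.0)
```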

@danthe3rd
Contributor

> Do you have any ideas as to why BiasT=Tensor takes 50% more time than BiasT=NoneType for cutlassF?

This is mainly for two reasons:
(1) fundamental reason: Flash is fast because it avoids memory IO (writing and reading the N^2 attention matrix multiple times). When you have a bias, you also need to read an N^2 tensor, which takes time.
(2) implementation: we have not really focused our efforts on the custom attention-bias setting. It's mainly used for prototyping, and the correct solution would be to fuse whatever attention bias you want to use directly in the kernel, to avoid the memory IO.
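
To illustrate the difference, a sketch of mine (assuming the `attn_bias` argument of `memory_efficient_attention` and the built-in `LowerTriangularMask` bias type, which is handled inside the kernel rather than materialized):

```python
# Sketch: a dense bias tensor vs. a bias that is fused in the kernel.
import torch
import xformers.ops as xops

B, M, H, K = 1, 2048, 8, 64
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# Dense bias: the kernel has to read an extra M x M tensor per head from
# global memory, which is what makes the BiasT=Tensor rows slower.
dense_bias = torch.randn(B, H, M, M, device="cuda", dtype=torch.float16)
out_dense = xops.memory_efficient_attention(q, k, v, attn_bias=dense_bias)

# Fused bias (here: a causal mask): no N^2 tensor is materialized or read,
# so it does not pay the extra memory-IO cost.
out_causal = xops.memory_efficient_attention(
    q, k, v, attn_bias=xops.LowerTriangularMask()
)
```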

> However, there are some differences. I believe their warp-partition policy plays an important role.

Yes, I believe this is one of the big things that explain some of the performance gap. I haven't investigated further at this point to understand more precisely where the gap comes from.

@WindowsXp-Beta

Thanks for explaining. I just realized that I misunderstood the third question. I thought it was about the gap between cutlassF and flash v2 lol.

@ekagra-ranjan

ekagra-ranjan commented Jul 27, 2023

Thank you @WindowsXp-Beta and @danthe3rd for your replies! This is helpful!

@tmm1
Contributor

tmm1 commented Aug 3, 2023

these should be resolved with #816

@tmm1 mentioned this issue Aug 3, 2023
@killawhale2

Shouldn't this issue be re-opened until #816 is merged? cc. @danthe3rd
