Replies: 5 comments 8 replies
-
Re blockwise fill, we can probably adapt some of the utility kernels code for the memset kernel, probably as a more general op. But yeah, I think having this as a kernel could work just fine. [off topic] One potential thought is whether we can extend the accelerator gemm setup to cover …
-
One thing I'd want to keep an eye on is opportunities to reuse buffers.
-
I've read through and updated with a more realistic example lowering. I hope this is clearer.
-
GEMM-SOFTMAX-GEMM is only part of the FlashAttention optimization; do we have a plan to tackle the other algorithms, especially those that can help with memory efficiency?
-
Added a section for updates/comparison against FA2. @zhanglx13 @sjw36 @sunway513 @krzysz00
-
This design discussion is about how to implement gemm-softmax-gemm fusion in rocMLIR.
It is based on the algorithm presented in the paper known as "FlashAttention": https://arxiv.org/pdf/2205.14135.pdf
Problem definition
The attention subgraph found in most transformer models can be expressed as follows:
O = ATTENTION(Q, K, V)

Let $$Q, K, V \in R^{N \times d}$$; then

$$O = \text{SOFTMAX}(QK^T)\,V$$

where SOFTMAX is

$$\text{SOFTMAX}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

Note: SOFTMAX is conducted along the row axis.
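For concreteness, this is what the unfused computation looks like in NumPy terms (a sketch only; shapes follow the definitions above):

```python
import numpy as np

def attention_reference(Q, K, V):
    """Unfused reference: GEMM -> row-wise SOFTMAX -> GEMM.

    Q, K, V are (N, d) arrays, matching the definitions above.
    """
    S = Q @ K.T                              # first GEMM, (N, N)
    S = S - S.max(axis=-1, keepdims=True)    # subtract row max for numerical stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)    # row-wise SOFTMAX
    return P @ V                             # second GEMM, (N, d)
```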
So to do this in a single fused kernel, we need to figure out a way to tile the whole computation.
Solution: FlashAttention
The paper describes a solution that roughly translates to the following pseudo code:
DISCLAIMER: this could be error prone as I decoded it in a day, so a review is more than welcome here.
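In NumPy terms, the tiling roughly looks like the sketch below (a restatement of the paper's Algorithm 1, not the actual kernel code; the Br/Bc block sizes and the assumption that N divides them evenly are simplifications of mine):

```python
import numpy as np

def flash_attention_v1(Q, K, V, Br=64, Bc=64):
    """Tiled GEMM-softmax-GEMM in the style of FlashAttention (v1).

    O, l, m are the running output, row sums and row maxima; every K/V
    column block re-reads and re-writes them for the row blocks it touches.
    Assumes N is divisible by Br and Bc to keep the sketch short.
    """
    N, d = Q.shape
    O = np.zeros((N, d))
    l = np.zeros((N, 1))                     # running softmax row sums
    m = np.full((N, 1), -np.inf)             # running softmax row maxima

    for j in range(0, N, Bc):                # outer loop over K/V column blocks
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        for i in range(0, N, Br):            # inner loop over Q row blocks
            Sij = Q[i:i + Br] @ Kj.T                          # (Br, Bc) score tile
            mij = Sij.max(axis=-1, keepdims=True)
            Pij = np.exp(Sij - mij)
            lij = Pij.sum(axis=-1, keepdims=True)

            mi, li = m[i:i + Br], l[i:i + Br]
            m_new = np.maximum(mi, mij)
            l_new = np.exp(mi - m_new) * li + np.exp(mij - m_new) * lij

            # Correct the previously accumulated output, add the new block,
            # and renormalize by the updated row sum (v1 does this every time).
            O[i:i + Br] = (li * np.exp(mi - m_new) * O[i:i + Br]
                           + np.exp(mij - m_new) * (Pij @ Vj)) / l_new
            m[i:i + Br], l[i:i + Br] = m_new, l_new
    return O
```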
Restrictions
O, l, m are read and written by every block in a row of blocks, i.e. every K/V column block re-reads and re-writes the same Q row block's running O, l, m.

Short-term solution
CK more or less uses the same algorithm for the compute.
Basically, we introduce a rock.gridwise_attention op that lowers roughly to the following pseudo-IR:
Updates from Flash Attention 2
Paper: https://tridao.me/publications/flash2/flash2.pdf
There are 3 main changes in FlashAttention-2:
Algorithm change
In v1, the partial output O is fully renormalized on every inner iteration: it is rescaled by the running row-max correction factor and divided by the updated row sum each time a new block is accumulated.
NOTE: softmax normalization is an elementwise subtraction of the row max followed by a division by the row sum.
In v2, only the row-max correction factor is applied inside the loop; the division by the row sum is deferred to a single rescaling after the last block has been accumulated.
As a summary, in v2 the number of ops needed for the correction is reduced.
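A minimal sketch of that difference, per Q row block, using the same notation as above (my restatement, not the kernel code):

```python
import numpy as np

def attention_row_block_v2(Qi, K, V, Bc=64):
    """v2-style inner loop for one Q row block: only the exp(m_old - m_new)
    correction is applied per iteration; the division by the row sum l is
    done once after the last K/V block (contrast with the v1 sketch above,
    which multiplies by l_old and divides by l_new on every iteration)."""
    Br, d = Qi.shape
    N = K.shape[0]
    Oi = np.zeros((Br, d))
    li = np.zeros((Br, 1))
    mi = np.full((Br, 1), -np.inf)

    for j in range(0, N, Bc):                        # loop over K/V column blocks
        Sij = Qi @ K[j:j + Bc].T
        m_new = np.maximum(mi, Sij.max(axis=-1, keepdims=True))
        Pij = np.exp(Sij - m_new)
        corr = np.exp(mi - m_new)                    # the only per-iteration correction
        li = corr * li + Pij.sum(axis=-1, keepdims=True)
        Oi = corr * Oi + Pij @ V[j:j + Bc]
        mi = m_new

    return Oi / li                                   # single normalization at the end
```

Note that this sketch already reflects the loop-order swap described next: each Q row block owns its own loop over the K/V column blocks.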
Parallelism change
"These ideas of swapping the order of the loop (outer loop over row blocks and inner loop over column blocks, instead of the other way round in the original FlashAttention paper), as well as parallelizing over the sequence length dimension were first suggested and implemented by Phil Tillet in the Triton [17] implementation." from the paper
Wave partitioning change
I think this is how CK does it anyway, i.e. splitting Q across warps such that the QK^T rows stay within a warp for iteration / reductions.
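Illustratively (the names and the contiguous-slice layout below are assumptions of this sketch, not the actual CK/rocMLIR mapping):

```python
def q_rows_owned_by_wave(wave_id, waves_per_block, block_m):
    """Each wave owns a contiguous slice of the workgroup's Q-tile rows, so
    the row-wise max/sum reductions over its QK^T rows never cross waves
    and need no LDS round-trip."""
    rows_per_wave = block_m // waves_per_block
    start = wave_id * rows_per_wave
    return range(start, start + rows_per_wave)
```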
Features required
Long-term ideas for generality
This is basically a mix of two main features:

- Fusion of softmax into a preceding gemm
- Fusion of a gemm into a preceding gemm (that may itself be fused with softmax/element-wise ops)
Short-term solution: high-level task breakdown

- New OP - blockwise_reduction
- New OP - blockwise_fill (a sketch of the intended semantics of both new blockwise ops follows this list)
- New OP - gridwise.attention op and its lowering (explained as an example above)
- [Optional] IMPROVE - blockwise_* LDS->Reg ops.
  NOTE: Add a member function to obtain a virtual memref type with an annotation of the per-thread sub-view of the output.
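A rough sketch of the semantics intended for the two new blockwise ops, expressed in NumPy terms (the real ops work on LDS/register tiles and their exact signatures may differ):

```python
import numpy as np

def blockwise_fill(tile_shape, value, dtype=np.float32):
    """blockwise_fill: materialize a workgroup tile holding a constant,
    e.g. 0 for the O/l accumulators or -inf for the running row max."""
    return np.full(tile_shape, value, dtype=dtype)

def blockwise_reduction(score_tile, kind):
    """blockwise_reduction: per-row reduction over a (Br, Bc) score tile,
    producing the softmax statistics (row max or row sum)."""
    if kind == "max":
        return score_tile.max(axis=-1, keepdims=True)
    if kind == "sum":
        return score_tile.sum(axis=-1, keepdims=True)
    raise ValueError(f"unsupported reduction kind: {kind}")
```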