Hi, this PR implements the MLA decode algorithm. I would love to hear your thoughts on this design and implementation.

### The mysterious Mat Absorb algorithm

The DeepSeekV2 paper gives no concrete formulas for how the parameter matrices are absorbed; it only vaguely says:

> Fortunately, due to the associative law of matrix multiplication, we can absorb $W^{UK}$ into $W^{UQ}$, and $W^{UV}$ into $W^{O}$

I know there has been some discussion on this, but I still couldn't find a convincing answer. Here is my conclusion on this topic: Mat Absorb is only suitable for decode, do not use Mat Absorb for prefill. This means MLA decode and prefill should have different computation graphs and different sets of parameters, and Mat Absorb should merge the parameter matrices offline and materialize the merged matrices. You can find the two Mat Absorb merges as two einsum ops in [test_mla_decode_kernel.py](https://github.com/flashinfer-ai/flashinfer/blob/cd74b457981d165a0afc5317e1a65a495e32b3c9/tests/test_mla_decode_kernel.py#L209).

```
# Now we merge W_UQ and W_UK (absorb W_UK into W_UQ)
# q~q_lora_rank  n~num_heads  d~qk_nope_head_dim  l~kv_lora_rank
self.W_UQ_UK = torch.einsum("q n d, l n d -> q n l", W_UQ, W_UK).flatten(start_dim=1)  # [1536, 65536]

# Merge W_UV and W_O (absorb W_UV into W_O)
# l~kv_lora_rank  n~num_heads  d~v_head_dim  h~hidden_size
self.W_UV_O = torch.einsum("l n d, h n d -> n l h", W_UV, W_O).flatten(start_dim=0, end_dim=1)  # [65536, 5120]
```

Let me state my reasoning below. First, here is the original MLA algorithm as a computation graph (the final o_proj is omitted). It can be regarded as a 128-head / (128+64)-dim MHA algorithm.

<img width="891" alt="image" src="https://github.com/user-attachments/assets/d2410816-0898-4a4e-afcf-86ad78044237">

After Mat Absorb, MLA becomes a special 128-head / (512+64)-dim MQA algorithm; please note that compressed_kv is used as both K and V directly, without any projection. The detailed Mat Absorb algorithm can be found in [test_mla_decode_kernel.py](https://github.com/flashinfer-ai/flashinfer/blob/cd74b457981d165a0afc5317e1a65a495e32b3c9/tests/test_mla_decode_kernel.py#L206). There, `DeepseekV2AttentionVanilla` is the original DeepSeekV2 MLA inference implementation, copied from huggingface and modified slightly; we take `DeepseekV2AttentionVanilla` as the reference to verify the correctness of our Mat Absorb implementation. `DeepseekV2AttentionMatAbsorbDecode` is our Mat Absorb implementation. It has two versions of the inference function (`run_proof_of_concept`): one is implemented purely in torch, which should make it clear how the Mat Absorb version of MLA inference works, and the other uses our new flashinfer MLA decode kernel, which you can also take as a usage example.

<img width="980" alt="image" src="https://github.com/user-attachments/assets/46e03c8d-666c-49f9-869b-af862602050c">
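To make the MQA view concrete, here is a minimal standalone torch sketch of the absorbed decode path. This is my own illustration, not the code from this PR; the authoritative version is `DeepseekV2AttentionMatAbsorbDecode.run_proof_of_concept` in test_mla_decode_kernel.py. Weight names follow the merged matrices above, the shapes assume DeepSeek-V2 sizes (q_lora_rank=1536, kv_lora_rank=512, rope dim 64, 128 heads, hidden 5120), and random tensors stand in for real weights and cache.

```
import math
import torch

bsz, kv_len = 2, 64
num_heads, q_lora_rank, kv_lora_rank = 128, 1536, 512
rope_dim, nope_dim, hidden = 64, 128, 5120

# Merged / remaining projection weights (random stand-ins for illustration).
W_QR    = torch.randn(q_lora_rank, num_heads * rope_dim)       # c_Q -> q_pe
W_UQ_UK = torch.randn(q_lora_rank, num_heads * kv_lora_rank)   # absorbed W_UQ x W_UK
W_UV_O  = torch.randn(num_heads * kv_lora_rank, hidden)        # absorbed W_UV x W_O

c_Q           = torch.randn(bsz, q_lora_rank)                  # decode: q_len == 1
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)         # from KV cache, used as both K and V
k_pe          = torch.randn(bsz, kv_len, rope_dim)             # from KV cache, RoPE already applied

q_pe   = (c_Q @ W_QR).view(bsz, num_heads, rope_dim)
q_nope = (c_Q @ W_UQ_UK).view(bsz, num_heads, kv_lora_rank)

# MQA: all 128 query heads attend to the same single (512+64)-dim KV head.
scores = (torch.einsum("bnl,bkl->bnk", q_nope, compressed_kv) +
          torch.einsum("bnr,bkr->bnk", q_pe, k_pe))
attn = torch.softmax(scores / math.sqrt(nope_dim + rope_dim), dim=-1)  # scale of the original 192-dim head
out  = torch.einsum("bnk,bkl->bnl", attn, compressed_kv)               # V is compressed_kv itself
hidden_states = out.reshape(bsz, num_heads * kv_lora_rank) @ W_UV_O    # [bsz, 5120]
```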
Now let's do some calculation to see whether the Mat Absorb version is performant (for convenience, we call the original MLA algorithm the Vanilla version).

```
# We calculate the number of float ops needed by this part of the MLA computation graph:
# the input tensors are c_Q and the cached k_pe and compressed_kv, the output tensor is the output hidden states.
# We omit the calculation from the input hidden states to c_Q and the cached k_pe and compressed_kv,
# because it's the same for both the vanilla version and the mat-absorb version.
def num_float_ops_vanilla(q_len, kv_len):
    return (
        q_len*1536*(128*192) +                                            # from c_Q to q_pe and q_nope, corresponding to q_b_proj
        kv_len*512*(128*(128+128)) +                                      # from compressed_kv to k_nope and value_states, corresponding to kv_b_proj
        128 * (q_len*64*kv_len + q_len*128*kv_len + q_len*kv_len*128) +   # 128 heads MHA
        q_len * (128*128) * 5120                                          # from MHA output to output hidden states, corresponding to o_proj
    )

def mem_footprint_vanilla(q_len, kv_len):
    return (
        q_len*1536 + 1536*(128*192) +
        kv_len*512 + 512*(128*(128+128)) +
        128 * ((q_len*64 + 64*kv_len) + (q_len*128 + 128*kv_len)) +
        q_len * (128*128) + (128*128) * 5120
    )

def num_float_ops_mat_absorb(q_len, kv_len):
    return (
        q_len*1536*(128*64) +                                             # from c_Q to q_pe, corresponding to W_QR
        q_len*1536*(128*512) +                                            # from c_Q to q_nope, corresponding to W_UQ_UK
        128 * (q_len*64*kv_len + q_len*512*kv_len + q_len*kv_len*512) +   # 128 heads MQA
        q_len * (128*512) * 5120                                          # from MQA output to output hidden states, corresponding to W_UV_O
    )

def mem_footprint_mat_absorb(q_len, kv_len):
    return (
        q_len*1536 + 1536*(128*64) + 1536*(128*512) +
        128 * (q_len*64 + q_len*512) + 1*(64*kv_len + 512*kv_len) +
        q_len * (128*512) + (128*512) * 5120                              # from MQA output to output hidden states, corresponding to W_UV_O
    )

kv_len = 10000
print(f"prefill: num_float_ops mat_absorb vs vanilla ratio ~ {num_float_ops_mat_absorb(kv_len, kv_len) / num_float_ops_vanilla(kv_len, kv_len)}")
print(f"prefill: mem_footprint mat_absorb vs vanilla ratio ~ {mem_footprint_mat_absorb(kv_len, kv_len) / mem_footprint_vanilla(kv_len, kv_len)}")
print(f"decode: num_float_ops mat_absorb vs vanilla ratio ~ {num_float_ops_mat_absorb(1, kv_len) / num_float_ops_vanilla(1, kv_len)}")
print(f"decode: mem_footprint mat_absorb vs vanilla ratio ~ {mem_footprint_mat_absorb(1, kv_len) / mem_footprint_vanilla(1, kv_len)}")
```

The output is:

```
prefill: num_float_ops mat_absorb vs vanilla ratio ~ 3.3602009088734754
prefill: mem_footprint mat_absorb vs vanilla ratio ~ 2.2874373717252205
decode: num_float_ops mat_absorb vs vanilla ratio ~ 0.010941137164898957
decode: mem_footprint mat_absorb vs vanilla ratio ~ 1.167867978048944
```

So we can conclude from the results above: for the decode case, the Mat Absorb version uses only about 1% of the computation of the Vanilla version, and its memory footprint is at the same level as the Vanilla version; for the prefill case, both computation and memory footprint are much higher than the Vanilla version. So there is no reason to use Mat Absorb for prefill, but it's worth a try for decode.

### The kernel implementation design

The new MLA decode kernel follows the same design concept as the current decode kernel and reuses much of the existing code base; we added some helper functions, such as `dim3_offset`, for better code readability. The scheduling policy is also the same as the current one: we split work along the kv-len dimension, and because num_kv_heads is now 1, we can no longer split the num_kv_heads dimension across blocks. One problem is that the 128-head / (512+64)-dim Q data is too large to fit into one SM's register file or even its shared memory, which means we can't use a single SM/block to process one Q; we have to tile the num_qo_heads dimension into gridDim.y, which causes the kv-cache data to be moved from gmem to smem multiple times, though this is inevitable.
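To give a feel for why the query heads must be tiled across blocks, here is a rough back-of-the-envelope sketch. The byte counts and tile sizes are my own assumptions (fp16 storage, illustrative heads-per-block choices), not numbers taken from the kernel code:

```
# Size of the full per-token Q after Mat Absorb, and how tiling the heads
# into gridDim.y multiplies the kv-cache traffic from gmem.
num_qo_heads, head_dim, bytes_per_elem = 128, 512 + 64, 2   # assuming fp16

q_bytes = num_qo_heads * head_dim * bytes_per_elem
print(f"full Q for one token: {q_bytes / 1024:.0f} KiB")    # ~144 KiB, more than one block can
                                                            # realistically hold alongside kv tiles

# If each block only processes a tile of the heads, every tile re-reads the
# same compressed_kv + k_pe stream from gmem.
for heads_per_block in (8, 16, 32):                         # hypothetical tile sizes
    tiles = num_qo_heads // heads_per_block
    print(f"{heads_per_block:3d} heads/block -> gridDim.y = {tiles:2d}, kv stream read {tiles}x")
```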
### Further improvement

- Tensor-core implementation: the current MLA models (DeepSeek-V2-Lite, DeepSeek-V2, MiniCPM3) all have a num_qo_heads large enough to feed mma fragments. In my opinion this may bring only limited performance improvement, though, because per the analysis above the bottleneck is IO bandwidth, not compute intensity.
- Load more Q head data per thread and per block. The more Q head data we load, the fewer blocks are needed and the less kv data has to be moved from gmem to smem. We can hold more `q_nope_vec` per thread, and we can also use smem to store more `q_nope_vec`.

I would love to hear input from others. BTW, the new variable and function naming may not follow the current convention; I'm willing to change it according to your advice.

---------

Co-authored-by: tsu-bin <tsubin@gmail.com>