feat: support MLA decode #551
Conversation
The append (extend) kernel also needs to be considered with absorption, because we read the kv cache in the append kernel; the kv cache should then directly be the compressed_kv.
Maybe we can set a threshold on the number of appended tokens, based on some empirical value, to decide whether to use Mat Absorb or not. You can tweak the numbers in the above equations to see how the q_len and kv_len values affect the ratio values.
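For illustration only, such a dispatch could look like the sketch below; the threshold constant and names are placeholders, not part of this PR, and the actual crossover point would have to come from benchmarks or from a q_len/kv_len FLOPs ratio like the one discussed in the PR description.

```python
# Hypothetical knob: the name and default value are placeholders and would
# need tuning against real benchmarks.
MAT_ABSORB_APPEND_THRESHOLD = 64  # number of appended tokens

def choose_append_path(num_append_tokens: int) -> str:
    """Pick the append/extend variant based on how many tokens are appended.

    Short appends behave like decode (IO-bound), where Mat Absorb wins;
    long appends behave like prefill, where the vanilla path is cheaper.
    """
    if num_append_tokens <= MAT_ABSORB_APPEND_THRESHOLD:
        return "mat_absorb"   # attend over compressed_kv directly, absorbed weights
    return "vanilla"          # up-project the kv cache, original weights
```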
Hi @tsu-bin, thank you for contributing MLA, this is a great feature to have. Do you have any performance numbers yet?
hi @yzh119 I just added the benchmark code. Here is the result from my workspace (the cards were not fully vacant, so the results may fluctuate a little). It proves what I said above: the bottleneck is IO bandwidth. The current scheduling design tiles
hi @tsu-bin is it possible to get your email or WeChat? Recently I have also become very interested in the MLA kernel and I think we can work on it together. You can find my email in my profile.
hi @jason-huang03, I'm very glad to discuss this with you; I have already sent you a mail. Hope we can work out the most optimized solution for the MLA kernels.
Hi @tsu-bin
I think that makes sense. I can take over the work of implementing it with tensor cores (I'm refactoring the codebase with cutlass and I think it will be easier after that), but we can merge the cuda-cores implementation first.
Hi @yzh119 That's great. It's still a challenge to write mma code manually; are you planning to use CUTE to refactor the current prefill implementation?
Yes I'm working on that.
Sounds great!
hi @yzh119 the rebase is done. Please note that there are still some recent features, such as 'improve plan performance by using non-blocking memcpy #547', that still need to be applied to the new code from this PR. BTW, when you are refactoring the prefill code, maybe you can leave some room for code reuse to ease the upcoming MLA prefill implementation.
Sure, we are trying to unify all attention variants into data structures like this:
tests/test_mla_decode_kernel.py
output_mat_absorbed_use_torch = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=False)
output_mat_absorbed_use_flashinfer = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=True)
cos_sim_use_torch = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_torch.reshape(-1), dim=0)
Why use cosine similarity here?
Actually, during my development of the decode kernel, I tried hard to align the kernel implementation with the correct computation semantics, and I found this metric to be a good indicator of whether I was heading in the right direction when fixing discrepancies; the value went from 0.1 -> 0.95 -> 0.9999.
But now I just tried MSE as the metric, and its value is rather larger than I expected; it seems that cosine similarity is a relaxed standard.
So there is still some discrepancy; I will look into this issue.
hi @yzh119 I just updated the test case and added both MSE and WMAPE metrics. Below is the output from one run; it seems that the f32-to-f16 conversion can cause some precision loss. I still can't find anything wrong in the decode kernel implementation.
cos_use_torch_f32 = 1.0
wmape_use_torch_f32 = 9.26729843382549e-06
mse_use_torch_f32 = 0.0008153514936566353
cos_use_torch_f16 = 0.9997764825820923
wmape_use_torch_f16 = 0.016850485358918303
mse_use_torch_f16 = 3793.904296875
cos_use_flashinfer = 0.9999209642410278
wmape_use_flashinfer = 0.012483024544464835
mse_use_flashinfer = 1346.939453125
The MSE value is rather large because the elements of the output tensor are on the order of thousands; you can try it yourself.
Do you think WMAPE (which compares the tensors' magnitudes) and cosine similarity (which compares the tensors' angles) together are enough proof of the algorithm's correctness?
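For reference, a minimal sketch of the three metrics in torch (tensor shapes and names below are placeholders, not the ones in the test file); it also shows how MSE blows up when the output elements themselves are large while the relative error stays small:

```python
import torch
import torch.nn.functional as F

def compare_outputs(ref: torch.Tensor, out: torch.Tensor):
    """Cosine similarity (angle), WMAPE (magnitude) and MSE between two outputs."""
    ref_flat = ref.float().reshape(-1)
    out_flat = out.float().reshape(-1)
    cos = F.cosine_similarity(ref_flat, out_flat, dim=0)
    wmape = (ref_flat - out_flat).abs().sum() / ref_flat.abs().sum()
    mse = F.mse_loss(out_flat, ref_flat)
    return cos.item(), wmape.item(), mse.item()

# Example: a reference output with elements on the order of thousands and a
# slightly perturbed copy; cos is ~1 and WMAPE is small, but MSE is ~100.
ref = torch.randn(8, 128, 5120) * 1000
out = ref + torch.randn_like(ref) * 10
print(compare_outputs(ref, out))
```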
Thanks for the confirmation, I think cosine similarity is okay in this case.
Let's merge this PR first and investigate the numerical issue later.
Great contribution, thank you @tsu-bin!
Looking forward to your CUTE refactor of the prefill kernel; then we can continue to work on the MLA prefill kernel. Hope the complete implementation can be ready for production soon.
@zhyncs Hi, are there plans to port this to sglang in the future?
Hi, this PR implements the MLA decode algorithm. I would love to hear your thoughts on this design and implementation.
The mysterious Mat Absorb algorithm
In the DeepSeekV2 paper there were no specific formulas for how to do the param matrix absorption; it only vaguely said that it could be done.
I know there have also been some discussions on this, but I still can't find a convincing answer.
Here is my conclusion on this topic: Mat Absorb is only suitable for decode; do not use Mat Absorb for prefill. This means MLA decode and prefill should have different computation graphs and different sets of params, and Mat Absorb should merge the param matrices offline and materialize the merged matrices. You can find that the two sets of Mat Absorb merges are the two einsum ops in test_mla_decode_kernel.py.
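For illustration, a sketch of what that offline merge could look like in torch. The weight names and shapes below are my assumptions based on the DeepSeek-V2 config (kv_lora_rank 512, q_lora_rank 1536, 128 heads, 128-dim nope/value heads, hidden size 5120); the exact einsum subscripts in test_mla_decode_kernel.py may differ.

```python
import torch

# Assumed DeepSeek-V2 shapes; random stand-ins for the real weights.
num_heads, kv_lora_rank, q_lora_rank = 128, 512, 1536
qk_nope_head_dim, v_head_dim, hidden_size = 128, 128, 5120

W_UQ_nope = torch.randn(num_heads, qk_nope_head_dim, q_lora_rank)   # q_lora -> per-head q_nope
W_UK = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)       # compressed_kv -> per-head k_nope
W_UV = torch.randn(num_heads, v_head_dim, kv_lora_rank)             # compressed_kv -> per-head v
W_O = torch.randn(hidden_size, num_heads, v_head_dim)               # concat(per-head v) -> hidden

# Offline absorption 1: fold W_UK into the query up-projection, so at decode
# time q is projected straight into the 512-dim compressed_kv space.
W_UQ_UK = torch.einsum("hdq,hdc->hcq", W_UQ_nope, W_UK)   # (heads, kv_lora_rank, q_lora_rank)

# Offline absorption 2: fold W_UV into the output projection, so the 512-dim
# per-head attention output maps directly back to hidden_size.
W_UV_O = torch.einsum("ehd,hdc->hec", W_O, W_UV)          # (heads, hidden_size, kv_lora_rank)
```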
I'm going to state my reasons below. First let me depict the original MLA algorithm as a computation graph (the final o_proj is omitted). It can be regarded as a 128-head, (128+64)-dim MHA algorithm.
And after Mat Absorb, MLA becomes a special 128-head, (512+64)-dim MQA algorithm; please note that the compressed_kv is used directly as both K and V, without any projection. The detailed Mat Absorb algorithm can be found in test_mla_decode_kernel.py, in which DeepseekV2AttentionVanilla is the original DeepSeekV2 MLA inference implementation, copied from huggingface and slightly modified; we take DeepseekV2AttentionVanilla as a reference to verify the correctness of our Mat Absorb implementation. DeepseekV2AttentionMatAbsorbDecode is our Mat Absorb implementation; it has two versions of the inference function (run_proof_of_concept): one is implemented purely in torch, which can help make it clear how the Mat Absorb version of MLA inference works, and the other uses our new flashinfer MLA decode kernel, so you can also take it as a usage example.
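To make the MQA view concrete, here is a stripped-down pure-torch sketch of the absorbed decode step for a single request without paging; the variable names are mine, not those used in run_proof_of_concept.

```python
import torch

num_heads, kv_lora_rank, rope_dim, kv_len = 128, 512, 64, 4096
# Absorption does not change the logits, so the softmax scale still uses the
# vanilla 128+64 head dim (any rope-scaling correction is ignored here).
softmax_scale = (128 + rope_dim) ** -0.5

# Absorbed query: a 512-dim "nope" part per head plus a 64-dim rope part.
q_nope = torch.randn(num_heads, kv_lora_rank)
q_pe = torch.randn(num_heads, rope_dim)

# The cache itself: one shared (512+64)-dim entry per token, i.e. a single KV head.
compressed_kv = torch.randn(kv_len, kv_lora_rank)   # used as both K and V, no up-projection
k_pe = torch.randn(kv_len, rope_dim)                # rope part of the key, shared by all heads

scores = (q_nope @ compressed_kv.T + q_pe @ k_pe.T) * softmax_scale   # (num_heads, kv_len)
attn = torch.softmax(scores, dim=-1)
out = attn @ compressed_kv                                            # (num_heads, 512)
# The absorbed output projection (W_UV folded into o_proj) would then map these
# per-head 512-dim vectors back to hidden_size; omitted here.
```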
Now let's do some calculation to see whether the Mat Absorb version is performant (for the sake of convenience, we call the original MLA algorithm the Vanilla version).
The output is:
So we can conclude from the result above that for the decode case, the Mat Absorb version uses only about 1% of the computation of the Vanilla version while its memory footprint stays at the same level; but for the prefill case, both computation and memory footprint are much higher than the Vanilla version's. So there is no reason to use Mat Absorb for prefill, but it's worth a try for decode.
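As a rough stand-in for that calculation, here is my own back-of-envelope model of the trade-off (my approximation, not the PR's script; the constants follow the DeepSeek-V2 shapes discussed above):

```python
def mla_attn_macs(q_len: int, kv_len: int, absorbed: bool,
                  num_heads: int = 128, kv_lora_rank: int = 512,
                  nope_dim: int = 128, rope_dim: int = 64, v_dim: int = 128) -> int:
    """Very rough MAC count for the attention part of one MLA forward pass."""
    if absorbed:
        # fold W_UK/W_UV per query token, then (512+64)-dim scores and 512-dim values
        absorb = q_len * num_heads * nope_dim * kv_lora_rank * 2
        attn = q_len * num_heads * kv_len * (kv_lora_rank + rope_dim + kv_lora_rank)
        return absorb + attn
    # vanilla: up-project every cached token to per-head K/V, then (128+64)/128-dim MHA
    up_proj = kv_len * kv_lora_rank * num_heads * (nope_dim + v_dim)
    attn = q_len * num_heads * kv_len * (nope_dim + rope_dim + v_dim)
    return up_proj + attn

for q_len, kv_len in [(1, 4096), (4096, 4096)]:   # decode-like vs. prefill-like shapes
    ratio = mla_attn_macs(q_len, kv_len, True) / mla_attn_macs(q_len, kv_len, False)
    print(f"q_len={q_len:5d} kv_len={kv_len:5d} absorbed/vanilla ~ {ratio:.3f}")
```

With the decode-like shape (q_len=1) the absorbed/vanilla ratio comes out around one percent, while the prefill-like shape (q_len = kv_len) pushes it well above one, consistent with the conclusion above.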
The kernel implementation design
The new MLA decode kernel follows the same design concept as the current decode kernel and reuses much of the existing code base; we add some helper functions, such as dim3_offset, for better code readability.
The scheduling policy is also the same as the current one: we split the task along the kv-len dimension, and because num_kv_heads is now 1, we can no longer split the num_kv_heads dimension across blocks. One problem is that the 128-head, (512+64)-dim Q data is too large to fit into one SM's register file, or even its smem, which means we can't use a single SM/block to process one piece of Q data; we have to tile the num_qo_heads dimension into gridDim.y, which causes kv-cache data to be moved from gmem to smem multiple times, though this is inevitable.
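To illustrate the resulting launch geometry, a rough sketch follows; the tile sizes are placeholder values, and mapping the kv split to gridDim.x is also my assumption, not the constants or layout used by the actual kernel and its scheduler.

```python
# Illustrative only: when the 128-head, (512+64)-dim Q block cannot fit in one
# SM's registers/smem, the head dimension is tiled across gridDim.y, and each
# kv chunk then has to be reloaded from gmem once per head tile.
num_qo_heads = 128
heads_per_block = 32     # placeholder tile size
num_kv_chunks = 8        # placeholder kv-len split chosen by the scheduler

grid_x = num_kv_chunks                     # split along kv_len
grid_y = num_qo_heads // heads_per_block   # tile num_qo_heads across blocks
kv_reload_factor = grid_y                  # times each kv chunk moves gmem -> smem

print(f"gridDim = ({grid_x}, {grid_y}), each kv chunk loaded {kv_reload_factor}x")
```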
Further improvement
q_nope_vec per thread; we can also use smem to store more q_nope_vec. I would love to hear input from others.
BTW, the new variable and function naming may not follow the current convention; I'm willing to change it according to your advice.