Optimize MLA/GQA/MQA Triton decoding #1138
Conversation
Tested on A100-80G:
Llama-3-8B
Reproduce:
Nice work! TL;DR: the k/v reuse moves from the L2 cache into the thread block. Is that right? @ispobock
ref #905 (comment) After a brief look, the throughput has roughly doubled compared to the previous MLA version, great work! cc @merrymercy @Ying1123 @hnyls2002
overall LGTM @ispobock
Currently all CIs have passed, including the code path Llama 3 takes when FlashInfer is disabled. The benchmark and eval results for this PR also meet expectations. Verification of DeepSeek V2 on A100 TP8 and H100 TP8 can be done later, and we will keep analyzing with nsys and ncu whether there is further room for optimization.
Following yesterday's brief discussion, this is mainly a quick implementation by @ispobock; thanks a lot as well to @grimoire for the reference implementation in InternLM/lmdeploy#1649 and to @lzhangzz for the discussion comments.
@MARD1NO and @yzh119, if you are interested, you are welcome to help review and give optimization suggestions. Thanks.
@Xu-Chen @lxww302 I noticed that you have used SGLang's DeepSeek V2 TP8 MLA implementation before. Could you help verify the performance of the new version on devices you have, e.g. A100 TP8, A800 TP8, H100 TP8? Thanks very much!
git clone -b decode_gqa_opt https://github.com/ispobock/sglang.git
cd sglang
pip install --upgrade pip
pip install -e "python[all]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --port 30000 --trust-remote-code --disable-radix-cache --enable-mla --tp=8
python3 -m sglang.bench_serving --backend sglang
I have 8x H100s; I executed your command.
Thanks! Is it H100 SXM or NVL? @81549361
Could you collect the env info with
Not sure if this could be helpful or not, but I ran llmperf for both the main branch and the incoming branch. Overall this PR seems to make things much faster.
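(Side note for reproducing this: llmperf's openai backend reads the target endpoint from environment variables; the host, port, and key below are placeholders, not values taken from this thread.)
export OPENAI_API_BASE="http://127.0.0.1:30000/v1"  # placeholder: the SGLang server's OpenAI-compatible endpoint
export OPENAI_API_KEY="EMPTY"                       # placeholder: any value works when the server does not enforce an API key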
llmperf command used:
python token_benchmark_ray.py \
--model "${MODEL}" \
--mean-input-tokens 1500 \
--stddev-input-tokens 150 \
--mean-output-tokens 245 \
--stddev-output-tokens 20 \
--max-num-completed-requests "64" \
--timeout 7200 \
--num-concurrent-requests "8" \
--llm-api openai \
--additional-sampling-params '{}'
main branch
{
"version": "2023-08-31",
"mean_input_tokens": 1500,
"stddev_input_tokens": 150,
"mean_output_tokens": 245,
"stddev_output_tokens": 20,
"num_concurrent_requests": 8,
"results_inter_token_latency_s_quantiles_p25": 0.03990331099470551,
"results_inter_token_latency_s_quantiles_p50": 0.057948063652443406,
"results_inter_token_latency_s_quantiles_p75": 0.08040066503004678,
"results_inter_token_latency_s_quantiles_p90": 0.08383243498141633,
"results_inter_token_latency_s_quantiles_p95": 0.08516111126646178,
"results_inter_token_latency_s_quantiles_p99": 0.10164050496592587,
"results_inter_token_latency_s_mean": 0.06027883582796916,
"results_inter_token_latency_s_min": 0.03675615620323733,
"results_inter_token_latency_s_max": 0.1020314351556132,
"results_inter_token_latency_s_stddev": 0.0211621866217624,
"results_ttft_s_quantiles_p25": 0.4133454477414489,
"results_ttft_s_quantiles_p50": 1.016814228380099,
"results_ttft_s_quantiles_p75": 11.284791270736605,
"results_ttft_s_quantiles_p90": 11.749069100199268,
"results_ttft_s_quantiles_p95": 11.803535583987832,
"results_ttft_s_quantiles_p99": 11.955875016311182,
"results_ttft_s_mean": 5.338054827436281,
"results_ttft_s_min": 0.2691499590873718,
"results_ttft_s_max": 12.148427874781191,
"results_ttft_s_stddev": 5.495650480946165,
"results_end_to_end_latency_s_quantiles_p25": 11.498506030999124,
"results_end_to_end_latency_s_quantiles_p50": 15.51382327103056,
"results_end_to_end_latency_s_quantiles_p75": 22.9230548851192,
"results_end_to_end_latency_s_quantiles_p90": 23.657817971240732,
"results_end_to_end_latency_s_quantiles_p95": 23.97725157707464,
"results_end_to_end_latency_s_quantiles_p99": 24.61372328522615,
"results_end_to_end_latency_s_mean": 16.84320118615142,
"results_end_to_end_latency_s_min": 3.5896931253373623,
"results_end_to_end_latency_s_max": 25.067169249989092,
"results_end_to_end_latency_s_stddev": 6.076063540076458,
"results_request_output_throughput_token_per_s_quantiles_p25": 12.432897921487776,
"results_request_output_throughput_token_per_s_quantiles_p50": 17.950591526918625,
"results_request_output_throughput_token_per_s_quantiles_p75": 25.023589881617227,
"results_request_output_throughput_token_per_s_quantiles_p90": 25.61754857375858,
"results_request_output_throughput_token_per_s_quantiles_p95": 26.080372795146523,
"results_request_output_throughput_token_per_s_quantiles_p99": 27.12744569799552,
"results_request_output_throughput_token_per_s_mean": 18.7890127702506,
"results_request_output_throughput_token_per_s_min": 9.773737854436295,
"results_request_output_throughput_token_per_s_max": 27.204481327432568,
"results_request_output_throughput_token_per_s_stddev": 6.462698432888159,
"results_number_input_tokens_quantiles_p25": 1419.75,
"results_number_input_tokens_quantiles_p50": 1513.5,
"results_number_input_tokens_quantiles_p75": 1585.25,
"results_number_input_tokens_quantiles_p90": 1726.1000000000001,
"results_number_input_tokens_quantiles_p95": 1812.2499999999998,
"results_number_input_tokens_quantiles_p99": 1942.5299999999997,
"results_number_input_tokens_mean": 1515.53125,
"results_number_input_tokens_min": "1125",
"results_number_input_tokens_max": "1986",
"results_number_input_tokens_stddev": 157.1251617922921,
"results_number_output_tokens_quantiles_p25": 271.25,
"results_number_output_tokens_quantiles_p50": 287.0,
"results_number_output_tokens_quantiles_p75": 304.5,
"results_number_output_tokens_quantiles_p90": 318.0,
"results_number_output_tokens_quantiles_p95": 326.4,
"results_number_output_tokens_quantiles_p99": 340.37,
"results_number_output_tokens_mean": 280.546875,
"results_number_output_tokens_min": "78",
"results_number_output_tokens_max": "341",
"results_number_output_tokens_stddev": 43.62427229119711,
"results_num_requests_started": 64,
"results_error_rate": 0.0,
"results_number_errors": 0,
"results_error_code_frequency": "{}",
"results_mean_output_throughput_token_per_s": 122.91809365087381,
"results_num_completed_requests": 64,
"results_num_completed_requests_per_min": 26.288247263678944,
"timestamp": 1723922364
}
incoming branch
{
"version": "2023-08-31",
"mean_input_tokens": 1500,
"stddev_input_tokens": 150,
"mean_output_tokens": 245,
"stddev_output_tokens": 20,
"num_concurrent_requests": 8,
"results_inter_token_latency_s_quantiles_p25": 0.04048058146969138,
"results_inter_token_latency_s_quantiles_p50": 0.04134249718749723,
"results_inter_token_latency_s_quantiles_p75": 0.042773683461634744,
"results_inter_token_latency_s_quantiles_p90": 0.04477736409998821,
"results_inter_token_latency_s_quantiles_p95": 0.04621570852103804,
"results_inter_token_latency_s_quantiles_p99": 0.04943066709057319,
"results_inter_token_latency_s_mean": 0.04202164194913325,
"results_inter_token_latency_s_min": 0.03828613981456747,
"results_inter_token_latency_s_max": 0.05096760665209523,
"results_inter_token_latency_s_stddev": 0.0023344492257422154,
"results_ttft_s_quantiles_p25": 0.3779949996387586,
"results_ttft_s_quantiles_p50": 0.403224729700014,
"results_ttft_s_quantiles_p75": 0.44007199979387224,
"results_ttft_s_quantiles_p90": 0.4766438877210021,
"results_ttft_s_quantiles_p95": 0.4872294148663059,
"results_ttft_s_quantiles_p99": 0.49447528753429654,
"results_ttft_s_mean": 0.4035295032663271,
"results_ttft_s_min": 0.2787872082553804,
"results_ttft_s_max": 0.49528229096904397,
"results_ttft_s_stddev": 0.05853017613187361,
"results_end_to_end_latency_s_quantiles_p25": 10.952284958562814,
"results_end_to_end_latency_s_quantiles_p50": 11.724067542003468,
"results_end_to_end_latency_s_quantiles_p75": 12.392438833485357,
"results_end_to_end_latency_s_quantiles_p90": 12.949160708626732,
"results_end_to_end_latency_s_quantiles_p95": 13.369823349895887,
"results_end_to_end_latency_s_quantiles_p99": 13.602660472076385,
"results_end_to_end_latency_s_mean": 11.063488117179077,
"results_end_to_end_latency_s_min": 2.310943207703531,
"results_end_to_end_latency_s_max": 13.658869832754135,
"results_end_to_end_latency_s_stddev": 2.5735290879206163,
"results_request_output_throughput_token_per_s_quantiles_p25": 23.376963498120137,
"results_request_output_throughput_token_per_s_quantiles_p50": 24.13135072660546,
"results_request_output_throughput_token_per_s_quantiles_p75": 24.70095651189223,
"results_request_output_throughput_token_per_s_quantiles_p90": 25.105406335351436,
"results_request_output_throughput_token_per_s_quantiles_p95": 25.318698051259776,
"results_request_output_throughput_token_per_s_quantiles_p99": 26.00064578019821,
"results_request_output_throughput_token_per_s_mean": 23.819321580789712,
"results_request_output_throughput_token_per_s_min": 19.61920693264775,
"results_request_output_throughput_token_per_s_max": 26.11816971864744,
"results_request_output_throughput_token_per_s_stddev": 1.3040854008387603,
"results_number_input_tokens_quantiles_p25": 1419.75,
"results_number_input_tokens_quantiles_p50": 1513.5,
"results_number_input_tokens_quantiles_p75": 1585.25,
"results_number_input_tokens_quantiles_p90": 1726.1000000000001,
"results_number_input_tokens_quantiles_p95": 1812.2499999999998,
"results_number_input_tokens_quantiles_p99": 1942.5299999999997,
"results_number_input_tokens_mean": 1515.53125,
"results_number_input_tokens_min": "1125",
"results_number_input_tokens_max": "1986",
"results_number_input_tokens_stddev": 157.1251617922921,
"results_number_output_tokens_quantiles_p25": 265.75,
"results_number_output_tokens_quantiles_p50": 285.0,
"results_number_output_tokens_quantiles_p75": 296.25,
"results_number_output_tokens_quantiles_p90": 317.0,
"results_number_output_tokens_quantiles_p95": 322.0,
"results_number_output_tokens_quantiles_p99": 338.84999999999997,
"results_number_output_tokens_mean": 265.484375,
"results_number_output_tokens_min": "47",
"results_number_output_tokens_max": "342",
"results_number_output_tokens_stddev": 66.06466101119273,
"results_num_requests_started": 64,
"results_error_rate": 0.0,
"results_number_errors": 0,
"results_error_code_frequency": "{}",
"results_mean_output_throughput_token_per_s": 162.73324599263228,
"results_num_completed_requests": 64,
"results_num_completed_requests_per_min": 36.77803923322394,
"timestamp": 1723922279
}
What is your startup command?
@81549361 The startup command I used for both is the same:
@81549361 Did you add
Awesome! Will test DeepSeek-V2-Chat on 8*A800 next week.
Tested on A800-80G: DeepSeek-V2-Lite
Main branch (DeepSeek-V2-Lite-Chat on 1 * A800-80G)
This PR (DeepSeek-V2-Lite-Chat on 1 * A800-80G)
Tested DeepSeek-V2-Chat-0628 on 8*A800
serve
python3 -m sglang.launch_server \
--model-path /data/model-cache/deepseek-ai/DeepSeek-V2-Chat-0628 \
--served-model-name deepseek-chat \
--tp 8 \
--enable-mla \
--disable-radix-cache \
--mem-fraction-static 0.87 \
--schedule-conservativeness 0.1 \
--chunked-prefill-size 32768 \
--max-prefill-tokens 163840 \
--trust-remote-code \
--host 0.0.0.0 \
--port 50521
test
python3 -m sglang.bench_serving \
--backend sglang \
--dataset-name sharegpt \
--dataset-path /data/model-cache/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json \
--model /data/model-cache/deepseek-ai/DeepSeek-V2-Chat-0628 \
--port 50521
result
Should I use the base model?
@halexan You don’t need to set this
Tested DeepSeek-V2-Chat-0628 on 8*A800
test
This PR (DeepSeek-V2-Chat-0628 on 8 * A800-80G)
Does your 8*A800 have NVLink?
Yes
H100 SXM TP8 with DeepSeek V2, current PR
Compared to the main branch, it has improved by about 35%.
main branch
I plan to merge this PR first; the compatibility support for FP8 will be completed in another PR. @ispobock @merrymercy @Ying1123 @hnyls2002
To further improve performance, both W8A8 (FP8) and FP8 KV cache are necessary and should be supported for DeepSeek V2.
Furthermore, we should pay attention to the MLA implementation in FlashInfer (flashinfer-ai/flashinfer#237).
@jon-chuang When do you expect to complete the support for MLA in FlashInfer? Could you share an approximate timeline? Thanks.
@ispobock - do you mind telling a bit more about how you spotted this issue or this optimization?
@microwish Yeah, we did the profiling first and found that the decoding kernel took most of the time. Then we inspected the kernel with ncu and got some directions for optimizing the memory access.
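(For anyone who wants to reproduce that kind of analysis, here is a rough sketch of the workflow; the model path, port, and kernel-name filter are placeholders, not values taken from this PR.)
# 1. Timeline view: wrap the server launch, drive load from another shell, and see which kernels dominate decode.
nsys profile -o decode_trace --trace=cuda,nvtx --duration=120 \
    python3 -m sglang.launch_server --model-path <model> --port 30000 --disable-radix-cache
# in another shell: python3 -m sglang.bench_serving --backend sglang
# 2. Kernel deep dive: replay a few launches of the decoding kernel and inspect its memory sections.
ncu --set full --target-processes all -k "regex:<decode_kernel_name>" -c 5 -o decode_kernel \
    python3 -m sglang.launch_server --model-path <model> --port 30000 --disable-radix-cache
In ncu's full section set, the Memory Workload Analysis section is usually the most relevant one for spotting redundant global-memory loads.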
Motivation
Optimize memory access for MLA/GQA/MQA decoding.
Modification
One thread block handles BLOCK_H q heads that share the same k/v head. Inspired by InternLM/lmdeploy#1649.
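For illustration only, below is a minimal Triton sketch of that grouping idea for plain GQA/MQA decoding. It is not the PR's actual kernel (which also covers MLA and splits work over the sequence in a flash-decoding style); all names, layouts, and block sizes are assumptions: q is [batch, q_heads, head_dim], k/v are [batch, seq_len, kv_heads, head_dim] with a contiguous last dimension, BLOCK_H equals the GQA group size, and BLOCK_H/BLOCK_N/HEAD_DIM are at least 16 so tl.dot is legal.
import torch
import triton
import triton.language as tl

@triton.jit
def _gqa_decode_kernel(
    Q, K, V, Out, sm_scale,
    stride_qb, stride_qh,
    stride_kb, stride_ks, stride_kh,
    stride_vb, stride_vs, stride_vh,
    stride_ob, stride_oh,
    seq_len, q_head_num,
    BLOCK_H: tl.constexpr, BLOCK_N: tl.constexpr, HEAD_DIM: tl.constexpr,
):
    # grid = (batch, kv_heads): each program owns one kv head and the BLOCK_H
    # query heads mapped to it, so K/V tiles are loaded once per block
    # instead of once per query head.
    batch_id = tl.program_id(0)
    kv_head_id = tl.program_id(1)

    offs_h = kv_head_id * BLOCK_H + tl.arange(0, BLOCK_H)  # query heads of this group
    offs_d = tl.arange(0, HEAD_DIM)
    mask_h = offs_h < q_head_num

    # Load all BLOCK_H query vectors at once: [BLOCK_H, HEAD_DIM]
    q_ptrs = Q + batch_id * stride_qb + offs_h[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(q_ptrs, mask=mask_h[:, None], other=0.0)

    # Online-softmax accumulators (one row per query head)
    m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM], dtype=tl.float32)

    for start_n in range(0, seq_len, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        mask_n = offs_n < seq_len
        # The shared kv head is loaded once and reused by every query head in the block.
        k_ptrs = (K + batch_id * stride_kb + kv_head_id * stride_kh
                  + offs_n[:, None] * stride_ks + offs_d[None, :])
        v_ptrs = (V + batch_id * stride_vb + kv_head_id * stride_vh
                  + offs_n[:, None] * stride_vs + offs_d[None, :])
        k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)
        v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)

        qk = tl.dot(q, tl.trans(k)) * sm_scale              # [BLOCK_H, BLOCK_N]
        qk = tl.where(mask_n[None, :], qk, float("-inf"))

        m_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl.exp(qk - m_new[:, None])
        alpha = tl.exp(m_i - m_new)
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i = m_new

    out = acc / l_i[:, None]
    o_ptrs = Out + batch_id * stride_ob + offs_h[:, None] * stride_oh + offs_d[None, :]
    tl.store(o_ptrs, out.to(Out.dtype.element_ty), mask=mask_h[:, None])

def gqa_decode(q, k, v):
    # q: [batch, q_heads, head_dim]; k, v: [batch, seq_len, kv_heads, head_dim]
    batch, q_heads, head_dim = q.shape
    seq_len, kv_heads = k.shape[1], k.shape[2]
    block_h = q_heads // kv_heads  # assume the group size is a power of two >= 16
    out = torch.empty_like(q)
    _gqa_decode_kernel[(batch, kv_heads)](
        q, k, v, out, head_dim ** -0.5,
        q.stride(0), q.stride(1),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        out.stride(0), out.stride(1),
        seq_len, q_heads,
        BLOCK_H=block_h, BLOCK_N=64, HEAD_DIM=head_dim,
    )
    return out
The point of the grouping is that in GQA/MQA (and in MLA after weight absorption, which behaves like MQA at decode time) many query heads attend to the same k/v head, so loading K/V once per block removes the redundant global-memory traffic that a one-query-head-per-block kernel would incur. A production kernel also has to handle groups smaller than the minimum tl.dot tile (e.g. by padding the head tile), which this sketch ignores.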