
[Performance]: Speculative Performance almost same or lower #5239

Closed
tolry418 opened this issue Jun 4, 2024 · 7 comments
Labels: performance, stale

Comments


tolry418 commented Jun 4, 2024

Proposal to improve performance

@LiuXiaoxuanPKU Good to see you again. Thank you for your work.
I understand your working group is releasing speculative decoding (SD) features bit by bit, so I'm wondering about the current SD version.
In my experiments, speculative decoding performs about the same as or worse than the normal setup (target model only), even at a low query-per-second rate. Is that because SD is still a work in progress?
I attached the results below.
Could you share your thoughts on them?

Report of performance regression

No response

Misc discussion on performance

Case: 300 example prompts (average input length 158), max output set to 100 tokens.
Target model: "Llama-2-70B-chat"; draft model: "TinyLlama-1.1B-chat-GPTQ".
Results attached below.
[attached benchmark results image]

tolry418 added the performance label Jun 4, 2024
LiuXiaoxuanPKU (Collaborator) commented:

Hi @tolry418, thanks for raising the issue! Yes, this performance is expected.

Please check out this issue for all the potential tasks for improving speculative decoding performance. Basically, all P0 issues are important for performance.

From what I see for the 70B model, my impressions are:
(1) 2 GPUs are not enough. SD needs spare compute; I would say at least 4 GPUs with a low request rate (0.5, 1, 1.5 QPS).
(2) What is your proposed length? If it is < 3, then the bonus token matters for performance. Currently, the bonus token is not enabled.
(3) What is your TP strategy for the draft model? I expect TP=1 for the draft model to improve performance.

We did test prompt lookup decoding and saw performance improvements on some workloads. Please check the results here.
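
For reference, a minimal sketch of enabling prompt lookup (n-gram) speculation through vLLM's offline LLM API; the target model is just the 70B model from this thread, and the argument names may vary between vLLM versions:

from vllm import LLM, SamplingParams

# Hypothetical example: 70B target with n-gram prompt lookup instead of a draft model.
# Argument names follow vLLM's speculative decoding docs and may differ across versions.
llm = LLM(
    model="meta-llama/Llama-2-70b-chat-hf",
    tensor_parallel_size=2,
    speculative_model="[ngram]",      # propose tokens by matching n-grams in the prompt
    num_speculative_tokens=5,         # proposal length k
    ngram_prompt_lookup_max=4,        # longest n-gram to match against the prompt
    use_v2_block_manager=True,
)
outputs = llm.generate(["The future of LLM serving is"],
                       SamplingParams(max_tokens=100))
print(outputs[0].outputs[0].text)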

tolry418 (Author) commented Jun 4, 2024

@LiuXiaoxuanPKU Thanks for your reply!
For the 70B experiment above, I used the settings below.

python -m vllm.entrypoints.api_server --port 8000 \
--model="meta-llama/Llama-2-70b-chat-hf" --speculative-model="TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"\
--use-v2-block-manager --num-speculative-tokens 5 --tensor-parallel-size=2 --enforce-eager \
--load-format='safetensors' --enable-prefix-caching 

I set the proposed length to 5.
Could you tell me how I can set TP only for the draft model? Doesn't '--tensor-parallel-size' apply to both models? Can I set it separately?
As for the bonus token, doesn't the current version already add one bonus token? I checked that a bonus token is appended to the proposed tokens. Did I get that wrong?

I also ran one more experiment to check the effect of GPU resources.
To tell the truth, I only have 2 GPUs, so I reduced the target model to Llama-2 13B and set the draft model to TheBloke/Llama-2-13B-chat-GPTQ, as below.

python -m vllm.entrypoints.api_server --port 8000 \
--model="meta-llama/Llama-2-13b-chat-hf" --speculative-model="TheBloke/Llama-2-13B-chat-GPTQ"\
--use-v2-block-manager --num-speculative-tokens 5 --tensor-parallel-size=1 --enforce-eager \
--load-format='safetensors' --enable-prefix-caching 

Here is the result: SD is still worse than the target model alone.
[attached benchmark results image]
I'll check the effect of prompt lookup decoding following your suggestion.


Dbxwz commented Jun 5, 2024

@tolry418 As batch size grows, the throughput of speculative decoding becomes worse than the baseline (target model only).
SD trades compute for latency: it spends extra GPU work to emit more tokens per step. Once the GPU is compute-bound, that extra work is wasted, because the target model still has to verify draft tokens that are sometimes rejected.
In practice the acceptance rate of draft tokens is about 60% even for EAGLE/lookahead, and plain draft-model SD is lower still. For instance, with the same budget of 100 token computations, the baseline emits 100 tokens while SD might only keep 60.
So SD is better suited to low-throughput scenarios, where the GPU has spare compute.
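
As a rough back-of-the-envelope illustration of that trade-off (my own sketch, assuming independent per-token acceptance and one bonus/corrected token per step; the numbers are illustrative, not measured):

# Back-of-the-envelope model: expected tokens emitted per SD step, assuming each of
# the k draft tokens is accepted independently with probability alpha and one
# bonus/corrected token is emitted per step.
def tokens_per_step(alpha: float, k: int) -> float:
    return sum(alpha ** i for i in range(k + 1))  # = (1 - alpha**(k+1)) / (1 - alpha)

alpha, k = 0.6, 5
emitted = tokens_per_step(alpha, k)               # ~2.38 tokens per step
verify_cost = (k + 1) / emitted                   # ~2.5x target-model token computations per emitted token
print(f"{emitted:.2f} tokens/step, {verify_cost:.2f}x target compute per emitted token")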

ShangmingCai (Contributor) commented Jun 5, 2024

I believe that the focus of speculative decoding is on the latency per prompt rather than the overall system throughput.

I did a similar experiment. I think the reason the current speculative decoding implementation is slow is that the draft model's proposal process is still time-consuming, sequential autoregressive decoding. Even if the draft model is 100x smaller, its inference time is not 100x lower, often not even a tenth of the target model's. So the proposal time is a hidden bottleneck when you use a 0.5B-7B model as the draft, and the larger k is, the greater the proposal overhead.

Therefore, I suppose implementations that do not require a draft model, such as Medusa/Eagle, will have greater potential in terms of performance gains.
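
A small illustrative model of that hidden bottleneck (assumed latencies, not measurements): the k sequential draft passes dominate the step time unless the draft is dramatically faster per token than the target.

# Illustrative (not measured) step-time model for draft-model speculation:
# k sequential draft forward passes plus one target verification pass per step.
def estimated_speedup(alpha: float, k: int, t_draft: float, t_target: float) -> float:
    expected_tokens = sum(alpha ** i for i in range(k + 1))  # accepted tokens + bonus token
    step_time = k * t_draft + t_target
    baseline_rate = 1.0 / t_target                           # 1 token per target step
    return (expected_tokens / step_time) / baseline_rate

# Hypothetical latencies: even a draft that is 5x faster per token gains little at k=5.
print(estimated_speedup(alpha=0.7, k=5, t_draft=0.20, t_target=1.0))  # ~1.47x
print(estimated_speedup(alpha=0.7, k=5, t_draft=0.05, t_target=1.0))  # ~2.35x (20x faster draft)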

stefanobranco commented:

You can set --speculative_draft_tensor_parallel_size 1 on the newest release; that should help.

Do you know your acceptance rate? In my experience, the choice of draft model has a pretty big impact on performance, and I suspect 5 speculative tokens with a quantized TinyLlama draft just leads to a lot of rejected tokens, so the larger model still does most of the work, with added overhead. Reducing the number of speculative tokens might also help overall performance.
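
For example, a sketch of the 70B setup above adjusted accordingly, written against the offline LLM API (speculative_draft_tensor_parallel_size mirrors the CLI flag mentioned above and only exists on newer releases; treat the argument names as version-dependent):

from vllm import LLM

# Sketch of the 70B setup with a TP=1 draft and shorter proposals.
# speculative_draft_tensor_parallel_size is only available on newer vLLM releases,
# and the exact argument name may differ by version; purely illustrative.
llm = LLM(
    model="meta-llama/Llama-2-70b-chat-hf",
    speculative_model="TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
    tensor_parallel_size=2,
    speculative_draft_tensor_parallel_size=1,  # keep the small draft model on one GPU
    num_speculative_tokens=2,                  # fewer speculated tokens -> fewer rejections
    use_v2_block_manager=True,
)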


This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions bot added the stale label Oct 26, 2024

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!

github-actions bot closed this as not planned Nov 27, 2024