[WIP][Spec Decode] Add multi-proposer support for variable and flexible speculative decoding #7947

Open · wants to merge 46 commits into base: main

Conversation

ShangmingCai (Contributor)

This PR adds multi-proposer support for speculative decoding, as mentioned in (#6300).

With this feature, various scheduling policies can be applied:

  1. Variable proposer according to workload: use a faster proposer (like '[ngram]') when the workload is lightweight, and a slower but more accurate proposer (e.g., a draft-model-based proposer) when the workload is heavy (a rough sketch of such a policy follows after this list).
  • The insight behind this:
    • For lightweight workloads, we can achieve lower TBT (time between tokens) and good goodput, since a faster proposer with an acceptable acceptance rate, like '[ngram]', has negligible proposal latency compared to the target model step latency (average_time_per_proposal_tok_ms < scoring_time_ms * 1%).
    • For heavy workloads, switch to the slower proposer to improve proposal quality and reduce wasted computation. In this situation, overall goodput matters far more than per-request latency, so the proposer's draft acceptance rate should be as high as possible to ensure that token throughput translates into goodput.
    • Under extremely heavy workloads, disable spec decode temporarily (disable_by_batch_size, or disable when the pending queue > 0).
    • This feature can be combined in the future with SmartSpec ([RFC]: Automate Speculative Decoding #4565), which dynamically determines the speculation length.
  2. Fine-grained flexibility via SpecDecodeParams: the engine maintainer can detect and choose a suitable proposer for specific requests.
  • The insight behind this:
    • Each proposer may have advantages for certain types of requests. For example, RAG-based proposers or '[ngram]' are well suited to RAG requests. If the engine is fixed to a RAG-based proposer for all requests, performance may degrade when the generation requires some creativity.
    • Draft-model-based and MLP-based proposers are more general and stable. They can serve as the backup when RAG is not hit.
    • We can disable spec decode for specific requests by setting the speculation length to zero or setting the proposer to "disable". ([Feature]: Is it possible to control whether to use speculative decoding when making a request? #6993) ([Feature]: API control over speculative decoding and prefix caching #7569)
    • This strategy reduces the average completion time of each request in our internal application scenario under a lightweight workload. However, if multiple proposers are required in the same batch, fine-grained processing gets tricky. (I haven't figured out this part yet.) Possible options:
      • Divide and conquer: if we use different proposers for different sequences in the same batch, every proposer has to wait for the slowest one to finish before the batch can be scored. Fast proposers with lower acceptance rates, like Ngram, get dragged down to the pace of the slowest proposer at every step, even though they could have completed more steps in the same time.
      • Proposal quality first: use SpecDecodeWorkerMetrics to dynamically select the proposer with the best draft_acceptance_rate.
      • Proposal latency first: always choose the proposer with the lowest proposal latency for the whole batch.
  • Other thoughts:
    • SpecDecodeParams is similar to LoRARequest: we need it if we want each request to have this kind of flexibility when using spec decode. Perhaps num_speculative_tokens can be moved from SequenceGroupMetadata to SpecDecodeParams so that the speculation length can also be scheduled per request in the future.
    • If we want engine-level scheduling rather than per-request scheduling, it would be better not to add SpecDecodeParams and incur extra metadata overhead. I am not sure which scheduling granularity is best for spec decode yet; it likely depends on whether one backend serves different applications.
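To make the workload-based policy above concrete, here is a minimal sketch of how such a switch could be decided per scheduling step. The thresholds, class names, and the select_proposer helper are illustrative assumptions for discussion, not code from this PR:

from dataclasses import dataclass

# Illustrative thresholds; real values would be tuned per deployment.
LIGHT_BATCH_SIZE = 8      # below this, proposal latency dominates -> prefer ngram
DISABLE_BATCH_SIZE = 64   # above this, disable spec decode entirely


@dataclass
class ProposerChoice:
    name: str                    # e.g. "[ngram]", a draft model name, or "disable"
    num_speculative_tokens: int  # 0 means spec decode is disabled


def select_proposer(batch_size: int, pending_requests: int) -> ProposerChoice:
    """Pick a proposer for the next step based on the current workload."""
    # Extremely heavy workload: disable speculation to avoid wasted compute.
    if batch_size >= DISABLE_BATCH_SIZE or pending_requests > 0:
        return ProposerChoice(name="disable", num_speculative_tokens=0)
    # Lightweight workload: the fast ngram proposer has negligible proposal latency.
    if batch_size <= LIGHT_BATCH_SIZE:
        return ProposerChoice(name="[ngram]", num_speculative_tokens=5)
    # Heavy workload: the slower draft-model proposer gives a higher acceptance rate.
    return ProposerChoice(name="draft-model", num_speculative_tokens=5)


if __name__ == "__main__":
    print(select_proposer(batch_size=4, pending_requests=0))    # -> ngram
    print(select_proposer(batch_size=32, pending_requests=0))   # -> draft model
    print(select_proposer(batch_size=128, pending_requests=3))  # -> disabled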

Since NGram is a lightweight implementation that can be enabled by default without many prerequisites, this PR uses it to implement a multi-proposer demo for now. More flexible choices will be added in the future.

Changed Code:

  • Add SpecDecodeParams (a rough sketch of what such a per-request object could contain follows this list)
  • Change the NGramWorker detection logic from ngram_prompt_lookup_max > 0 to the model name
  • Add MultiProposerWorker, which supports the NGram proposer as a backup paired with another slower but more accurate proposer.
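For reference only, this is one possible shape for a SpecDecodeParams-style object; the field names below are assumptions based on the description in this PR, not the PR's actual definition:

from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecDecodeParams:
    """Per-request speculative decoding parameters (illustrative sketch).

    Analogous to LoRARequest: attached to a request so the engine can pick
    a proposer, or disable speculation, for that request alone.
    """
    # "[ngram]", a draft model name, or "disable"; None -> engine default.
    proposer: Optional[str] = None
    # None -> engine default; 0 -> disable speculation for this request.
    num_speculative_tokens: Optional[int] = None


# Example: a RAG request that should use the ngram proposer.
rag_params = SpecDecodeParams(proposer="[ngram]", num_speculative_tokens=5)

# Example: opt a single request out of speculative decoding.
no_spec_params = SpecDecodeParams(proposer="disable", num_speculative_tokens=0)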

This PR is still a work in progress. Tests have not been added yet.

@cadedaniel Could you share your opinion when you have time to check on this? I think the design details should be settled after discussion.


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

ShangmingCai (Contributor, Author)

@cadedaniel Do you have time to look at this PR and give me some advice? Also, I am investigating how to integrate LoRA and multi-step scheduling with spec decode. At this point, vLLM supports many optimization options, but they cannot be used together, which is somewhat frustrating when it comes to production development.

cadedaniel (Collaborator)

I am afk for a few weeks unfortunately. cc @LiuXiaoxuanPKU @sroy745 @njhill vLLM spec decode experts

LiuXiaoxuanPKU (Collaborator) commented Sep 10, 2024

Hi, thanks for the contribution. Will take a look by the end of today.

LiuXiaoxuanPKU self-assigned this Sep 10, 2024
LiuXiaoxuanPKU (Collaborator) commented Sep 12, 2024

Thanks for the great PR description; it's well motivated. After scanning through the PR, I have some questions:

  1. Does the current design support multiple proposers of the same type, and is that a legitimate requirement? For example, using two draft models (maybe of different sizes) as the proposers. The current way of deciding the number of GPU/CPU blocks seems to assume a single draft model.
  2. It seems the current design mainly focuses on combining ngram and draft-model-based methods. I'm curious how compatible the current design is with using medusa/eagle as one of the proposers.

For multi-step scheduling and LoRA, I feel we can disable those two features when multiple proposers are used, because vLLM might move to async scheduling in the next 1-2 months.

Also, do you have any workloads/numbers that demonstrate the benefits of the multi-proposer method?

ShangmingCai (Contributor, Author)

Thank you very much for your time.

  1. Does the current design support multiple proposers of the same type, and is that a legitimate requirement? For example, using two draft models (maybe of different sizes) as the proposers. The current way of deciding the number of GPU/CPU blocks seems to assume a single draft model.

The original motivation is to support different types of proposers for various situations. However, I believe some users may train multiple draft models of the same type for different applications to match their data distribution. So the answer to the second question (whether it is a legitimate requirement) is yes: it is a legitimate but niche choice.

As for the first question (whether the current design supports it), I haven't refactored the create_worker logic of SpecDecodeWorker yet, since I think it would be better not to touch it before we discuss it thoroughly. Do we only need two proposers, or do we want more? For the former, we can add a parameter called backup_speculative_model or second_speculative_model (if Ngram is compulsory, we don't even need a new parameter, since we already have ngram_prompt_lookup_max). For the latter, do we pass in a list at configuration time, or treat proposers like LoRA and manage them with add_proposer and remove_proposer? This is not decided yet, but we can support it if we need it.
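To make the LoRA-style option a bit more concrete, here is a hedged sketch of what add_proposer / remove_proposer management could look like. The method names come from the discussion above, but the registry and default fallback shown here are assumptions, not what the PR currently implements:

from typing import Dict, Optional


class MultiProposerWorkerSketch:
    """Illustrative dispatcher that manages several proposer workers by name."""

    def __init__(self) -> None:
        self._proposers: Dict[str, object] = {}  # proposer name -> proposer worker
        self._default_name = "[ngram]"           # cheap default backup

    def add_proposer(self, name: str, proposer: object) -> None:
        # Analogous to add_lora: register a proposer at runtime.
        self._proposers[name] = proposer

    def remove_proposer(self, name: str) -> None:
        # Analogous to remove_lora: drop a proposer that is no longer needed.
        self._proposers.pop(name, None)

    def get_proposer(self, name: Optional[str] = None) -> Optional[object]:
        # Fall back to the default (e.g. ngram) when no per-request choice is given.
        return self._proposers.get(name or self._default_name)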

  2. It seems the current design mainly focuses on combining ngram and draft-model-based methods. I'm curious how compatible the current design is with using medusa/eagle as one of the proposers.

Sure. I have tested all the implementations that require additional weights (including a typical draft model, MLP, Medusa, and Eagle) together with Ngram. They are all compatible. You can think of the proposed class, MultiProposerWorker, as a wrapper or a dispatcher. However, I haven't tested combinations without Ngram yet.

Why Ngram? I thought a lot about what would be most practical for users. We both know that training a draft model is not trivial. For most users, Ngram can be enabled by default without many prerequisites, as we can see in (#5805):

Production Features

  • N-gram prompt lookup spec decode on by default

Therefore, as a first step, I implemented this demo that supports Ngram as a backup paired with another slower but more accurate proposer. If the community and users find this feature valuable and want to exploit it, we can make it more complete step by step.

Also, do you have any workloads/numbers that demonstrate the benefits of the multi-proposer method?

There is currently no public dataset that can support our experiments, so we collected some real requests and ran several experiments in an internal RAG chatbot application. Processing each request one by one, we use Ngram when RAG hits and a 0.5B draft model otherwise. With this rule, the average latency of all requests with MultiProposerWorker drops by ~16% compared to a pure NgramWorker and ~40% compared to a pure MultiStepWorker.
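The dispatch rule in that experiment boils down to something like the following sketch; the rag_hit flag and the draft model name are placeholders for our internal setup:

def choose_proposer_for_request(rag_hit: bool) -> str:
    """Per-request proposer selection used in the internal RAG chatbot test."""
    # RAG hit: the output largely copies retrieved context, so ngram lookup
    # proposes tokens almost for free with a decent acceptance rate.
    if rag_hit:
        return "[ngram]"
    # No RAG hit: fall back to the small draft model for better proposal quality.
    return "internal-0.5B-draft-model"  # placeholder name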

However, if requests that need different proposers arrive in batches, things get complicated, as I mentioned in the PR statement, due to continuous batching. Since we currently have only one target model instance and need to score the batch together, 'Divide and conquer' might be problematic. So I think it would be better to discuss this with the community before I turn the demo into a proper PR.

ShangmingCai (Contributor, Author)

Hello @LiuXiaoxuanPKU. Any news since our last conversation?

I think there are two ways to support dynamic speculative decoding with various proposers at runtime.

The first is to handle the engine-level or request-level scheduling gracefully, as we discussed before. The only tricky part is that when we switch from a proposer that does not require KV cache to one that does, we need to rerun the prefill phase of the new spec model for the requests that are already running.

The other is to merely provide an online API to make speculative decoding switchable, like LoRA, and leave the policy to the engine maintainer.
Example:

curl -X POST http://localhost:8000/v1/switch_spec_method \
  -H "Content-Type: application/json" \
  -d '{"speculative_model": "[ngram]"}'

(where the value could also be "None", "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4", "primary", or "backup")

So we can control whether to use spec decode, and which proposer to use, without shutting down the engine with Ctrl+C. The switch can happen cleanly once all previous requests have completed.
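On the server side, such an endpoint might look roughly like the FastAPI sketch below. The route name matches the example above, but the _EngineStub and its set_speculative_model hook are hypothetical placeholders for whatever engine-level switch we agree on:

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class SwitchSpecRequest(BaseModel):
    # "[ngram]", "None", a draft model name, "primary", or "backup"
    speculative_model: str


class _EngineStub:
    """Stand-in for the engine hook this API would need; not a real vLLM API."""

    async def set_speculative_model(self, name: str) -> None:
        # A real implementation would wait for in-flight requests to drain,
        # then swap (or disable) the proposer before accepting new requests.
        print(f"switching speculative model to {name!r}")


engine = _EngineStub()


@app.post("/v1/switch_spec_method")
async def switch_spec_method(req: SwitchSpecRequest):
    await engine.set_speculative_model(req.speculative_model)
    return {"status": "ok", "speculative_model": req.speculative_model}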

I know you have been busy removing Batch Expansion. Great work. Let me know what you think of this PR when you are available.

mergify bot commented Nov 26, 2024

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @ShangmingCai.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Nov 26, 2024
mergify bot removed the needs-rebase label on Dec 18, 2024