[RFC]: Multi-Step Scheduling #6854
Comments
The result seems pretty impressive!
+1 on prioritizing this. Really great result!
Wow! These results look great. Happy to help out as we can.
@SolitaryThinker thanks for writing out a clear proposal, results look great! I have some follow-up questions:
@megha95 thanks for the great questions!
I think for this one, it is also easy to add different policies. For example, we could do an early return if the number of EOS tokens > X, or something like that. But we will need more benchmarks for these cases.
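As an illustration of the kind of policy hook meant here (a minimal sketch with hypothetical names, not part of the vLLM API):

```python
# Illustration only (hypothetical names): an early-exit policy that hands
# control back to the scheduler once more than X sequences in the batch
# have sampled an EOS token.
def should_exit_early(sampled_token_ids, eos_token_id, max_finished):
    num_finished = sum(1 for tok in sampled_token_ids if tok == eos_token_id)
    return num_finished > max_finished

# Inside a hypothetical multi-step loop:
# for step in range(num_steps):
#     sampled = run_decode_step(batch)
#     if should_exit_early(sampled, eos_token_id, max_finished=X):
#         break  # return to the scheduler before finishing all n steps
```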
@SolitaryThinker: how's the compatibility with PP+TP going?
@SolitaryThinker Thanks for the great work! I have a few comments:
@zhisbug PP is partly working now (debugging some hanging issues); going to continue focusing on getting PP+TP working as a priority.
I guess this is incompatible with guided decoding (#5423), correct? Since guided decoding needs to see output tokens on every decode step.
@jon-chuang It should be possible to make it compatible. Currently each step's pythonized output is available to the output_processor (detokenization) not immediately after the decode step, but after the next step's decode - as we perform the pythonization after launching the next step in order to keep the GPU as busy as possible. There are two things that would be needed to make it compatible with guided decoding:
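A minimal sketch of the deferred pythonization described above (hypothetical helper names, not the vLLM API):

```python
import torch

# Illustration only: step N's GPU-side sampler output is converted to Python
# objects after step N+1 has been launched, and only once a CUDA event
# confirms step N's sampling has finished.
def run_multi_step(execute_step, pythonize, model_input, num_steps):
    pending = None   # (gpu_sampler_output, cuda_event) from the previous step
    results = []
    for _ in range(num_steps):
        gpu_out = execute_step(model_input)  # launch decode + sampling (async on GPU)
        event = torch.cuda.Event()
        event.record()                       # marks "sampling for this step done"
        if pending is not None:
            prev_out, prev_event = pending
            prev_event.synchronize()         # wait only for that step's sampling
            results.append(pythonize(prev_out))
        pending = (gpu_out, event)
    prev_out, prev_event = pending           # drain the final step
    prev_event.synchronize()
    results.append(pythonize(prev_out))
    return results
```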
I think that #5423 mentions async. It will be good to coordinate efforts with the logit_processor API changes. But I think for the time being you should not wait for these features to land, and simply throw an incompatibility error until such a feature lands.
Motivation.
TL;DR: There is high CPU overhead associated with each decode batch due to the processing of inputs and generation of outputs. Multi-step decoding amortizes these overheads over n steps at a time.
The result is that the GPU is often idle, waiting on CPU operations (5-13 ms of GPU bubble).
Multi-step means multiple decode passes are performed before a GPU-CPU sync is done to invoke the vLLM scheduler and process sampled tokens. Currently the GPU->CPU memory transfer for sampled tokens is also synchronous with each decode step, causing bubbles on the GPU. With multi-step, this memory transfer can happen in a separate CUDA stream and is essentially free, as the CPU runs ahead of the GPU.
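A rough illustration of the "transfer on a separate CUDA stream" idea (hypothetical names, not the actual vLLM implementation):

```python
import torch

# Illustration only: do the sampled-token GPU->CPU copy on a side stream so
# the main stream can keep launching decode steps.
copy_stream = torch.cuda.Stream()

def async_copy_sampled_tokens(sampled_token_ids_gpu: torch.Tensor):
    # Make the copy wait for sampling on the main stream to finish.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        # For a truly asynchronous copy the destination should be pinned host memory.
        host_tokens = sampled_token_ids_gpu.to("cpu", non_blocking=True)
    done = torch.cuda.Event()
    done.record(copy_stream)
    # The caller calls done.synchronize() only when the tokens are actually
    # needed (e.g. after n steps) instead of blocking every decode step.
    return host_tokens, done
```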
See below for the source of performance improvement.
Torch Profiles
Baseline 8B on 1xH100
Multi-Step-8 8B on 1xH100
Benchmarks
MS = multi-step
MS-8 = 8 multi-steps before calling the vLLM scheduler and process_output
Proposed Change.
Extend `ExecuteModelRequest` (input to `Workers`) and `RequestOutput`/`SamplerOutput` to include metadata for the multi-step state, and modify the existing `ModelRunner` to properly handle multi-step state. `AsyncLLMEngine`/`LLMEngine` will need to be modified to be aware of multi-step in order to call into the vLLM scheduler after n steps instead of on every decode. The existing PP scheduling will not be changed.

High level Algorithm:
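Roughly, the flow described above could look like the following (a minimal sketch with hypothetical names, not the exact vLLM interfaces):

```python
# Sketch only: the engine consults the scheduler once per n steps; the worker
# then runs n decode steps back to back before outputs are processed.
def engine_step(scheduler, worker, output_processor, num_steps):
    seq_group_metadata_list = scheduler.schedule()          # called once per n steps
    execute_model_req = build_execute_model_request(        # hypothetical helper
        seq_group_metadata_list, num_steps=num_steps)
    sampler_outputs = worker.execute_model(execute_model_req)  # n decode steps
    for sampler_output in sampler_outputs:                  # per-step pythonized output
        output_processor.process_outputs(sampler_output)    # detokenize, stop checks, ...
```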
Details:
Multi-step states that need to be tracked for each (micro)batch (sketched below):
- `sampled_token_ids` - to keep track of sampled tokens still on the GPU
- `sampler_output_ready_event` - CUDA event to make sure we only pythonize once GPU sampling has finished

Core changes to Engine:

Core changes to `ModelRunner`:
- `advance_step`
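A minimal sketch of the state above and an `advance_step`-style in-place update (hypothetical names and fields, not the actual implementation):

```python
from dataclasses import dataclass, field
from typing import List, Optional

import torch

# Hypothetical container for the per-(micro)batch multi-step state listed above.
@dataclass
class MultiStepState:
    # Sampled token ids from each completed step, still resident on the GPU.
    sampled_token_ids: List[torch.Tensor] = field(default_factory=list)
    # CUDA event recorded after sampling; pythonization waits on it.
    sampler_output_ready_event: Optional[torch.cuda.Event] = None
    current_step: int = 0

# Hypothetical advance_step: prepare the next decode step entirely on the GPU,
# without a round trip through the CPU scheduler.
def advance_step(sampled_token_ids: torch.Tensor,
                 seq_lens: torch.Tensor,
                 state: MultiStepState) -> torch.Tensor:
    state.sampled_token_ids.append(sampled_token_ids)
    state.sampler_output_ready_event = torch.cuda.Event()
    state.sampler_output_ready_event.record()
    state.current_step += 1
    seq_lens += 1                     # in-place update of per-sequence lengths
    # The freshly sampled tokens become the next step's decode input.
    return sampled_token_ids
```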
Prototype:
The current prototype is based on speculative decoding's `TP1DraftModelRunner` logic. There are numerous additions for PP/TP support. For the prototype we created a non-spec-decode `MultiStepModelRunner` under `workers/`. The goal is to generalize this into the existing `ModelRunner` (removing the need for a new file) before merging.

Reasoning: PP + multi-step
TLDR: Since the current multi-step loop is inside the `ModelRunner`/`Worker`, PP scheduling in the Executor will cause bubbles between each step and will not interleave the steps of Batch 1 (VE1) with Batch 2 (VE2).
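To make the interleaving point concrete, a schematic of the two step orderings (illustration only; B = batch, VE = virtual engine):

```python
# Multi-step loop inside the Worker (current prototype): the Executor cannot
# interleave microbatches, so each step boundary leaves a pipeline bubble.
loop_in_worker = [
    "B1(VE1).step1", "B1(VE1).step2", "...", "B1(VE1).stepN",
    "B2(VE2).step1", "B2(VE2).step2", "...", "B2(VE2).stepN",
]

# Steps driven from the Executor/Engine: steps of the two microbatches can be
# interleaved, keeping both virtual engines' pipeline stages busy.
loop_in_executor = [
    "B1(VE1).step1", "B2(VE2).step1",
    "B1(VE1).step2", "B2(VE2).step2", "...",
]
```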
Feedback Period.
No response
CC List.
@zhisbug @Yard1 @WoosukKwon @rkooo567 @zhuohan123 @simon-mo @comaniac @megha95 @richardliaw
Any Other Things.
Many thanks to @Yard1 for extensive help with design and implementation!
Syncing with @megha on ongoing work to make the output_processor async; she proposed moving the sampler out of the model runner.