
[Core] Asynchronous Output Processor #7049

Merged

21 commits merged into vllm-project:main from output-proc-callback on Aug 27, 2024

Conversation

megha95 (Contributor) commented on Aug 1, 2024

This PR implements RFC #6913.

The following changes were made:

  1. Scheduler state updates are decoupled from the output processor. Previously, both _process_model_outputs and scheduler.schedule() updated the running deque and freed the memory of finished sequences.
  2. Adds _advance_to_next_step in LLMEngine, which performs the small but necessary work that must stay on the critical path, e.g., appending the new token id and updating prefill sequences' statuses to decode.
  3. Adds a callback argument to execute_model so that _process_model_outputs is triggered right before the sampler inside ModelRunner. Since the GPU always runs ahead of the CPU, this lets output processing overlap completely with the CUDA graph forward pass in each decoding step. A toy sketch of this pattern is shown below.
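
To make the callback flow in item 3 concrete, here is a minimal, self-contained sketch of the pattern. The names (ToyEngine, toy_execute_model) are hypothetical and only illustrate the control flow described above; this is not the actual vLLM implementation.

```python
from typing import Callable, List, Optional


def toy_execute_model(callback: Optional[Callable[[], None]]) -> str:
    # 1. Launch the forward pass. On a real GPU this launch is asynchronous,
    #    so the CPU is free while the kernels (e.g. a CUDA graph) run.
    hidden_states = "hidden_states"
    # 2. While the GPU runs ahead, process the *previous* step's outputs.
    if callback is not None:
        callback()
    # 3. Sample the next tokens from the (now finished) forward pass.
    return f"token_sampled_from_{hidden_states}"


class ToyEngine:
    def __init__(self) -> None:
        self._pending: List[str] = []  # outputs lag one step behind the GPU

    def _process_model_outputs(self) -> None:
        # Detokenize, check stop conditions, build RequestOutputs, and free
        # finished sequence groups, all for the previous step's outputs.
        done, self._pending = self._pending, []
        for _ in done:
            pass  # per-output processing elided

    def step(self) -> str:
        output = toy_execute_model(self._process_model_outputs)
        self._pending.append(output)
        return output


if __name__ == "__main__":
    engine = ToyEngine()
    for _ in range(3):
        engine.step()
```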

Results
Throughput improves by 11% for Llama 3.1-8B FP8 on 1x H100. Results are even better at high batch sizes: with my own benchmarking setup, I see a 1.3x increase in throughput at high request rates (RPS).

There's a follow-up PR # that combines the async output processor with multi-step scheduling and shows promising results: a 30% improvement in TPOT.
Without async output processor

CUDA_VISIBLE_DEVICES=0 python3 benchmark_throughput.py --model /mnt/workdisk/chengli/models/meta-llama/Meta-Llama-3.1-8B-Instruct --backend vllm --input-len 128 --output-len 256 --num-prompts 1000 --tensor-parallel 1 --quantization fp8 --disable-output-proc-callback

Throughput: 31.98 requests/s, 12279.04 tokens/s

With async output processor

CUDA_VISIBLE_DEVICES=0 python3 benchmark_throughput.py --model /mnt/workdisk/chengli/models/meta-llama/Meta-Llama-3.1-8B-Instruct --backend vllm --input-len 128 --output-len 256 --num-prompts 1000 --tensor-parallel 1 --quantization fp8 

Throughput: 35.64 requests/s, 13687.30 tokens/s

To do:

  • Update stopping criteria and scheduling of sequences for the case where a stop token id was generated in the previous step
  • Add support for AsyncLLMEngine
  • No-callback codepath for spec decoding, pipeline parallelism, and beam search
  • Run tests
  • Add tests
  • Add benchmark numbers

github-actions bot commented on Aug 1, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fastcheck build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run the full CI, as it is required for merging (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

megha95 force-pushed the output-proc-callback branch from ea57d35 to 6fcce3f on August 5, 2024, 23:26
megha95 marked this pull request as ready for review on August 8, 2024, 06:22
megha95 marked this pull request as a draft on August 8, 2024, 06:24
megha95 marked this pull request as ready for review on August 8, 2024, 21:47
alexm-neuralmagic (Collaborator) left a comment:

@megha95 This is good progress! Left some comments.

vllm/config.py (outdated):

@@ -129,6 +129,7 @@ def __init__(
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
        multimodal_config: Optional["MultiModalConfig"] = None,
        use_output_proc_callback: Optional[bool] = False,
Collaborator:

I think @WoosukKwon's proposal to make the async callback the default is a good direction, so we can see the benefits already from this PR. You can enable the callback by default and disable it only for the cases where it does not currently work (beam search, etc.). Follow-up PRs can add the missing features. A rough sketch of this idea is shown below.
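
As an illustration of the "enabled by default, disabled where unsupported" approach, here is a hedged sketch. The attribute names (use_output_proc_callback, use_beam_search, speculative_config, pipeline_parallel_size) are assumptions drawn from the discussion in this PR, not the exact vLLM config API.

```python
import logging

logger = logging.getLogger(__name__)


def maybe_disable_output_proc_callback(config) -> None:
    """Disable the async callback for configurations that don't support it yet."""
    unsupported = {
        "beam search": getattr(config, "use_beam_search", False),
        "speculative decoding": getattr(config, "speculative_config", None) is not None,
        "pipeline parallelism": getattr(config, "pipeline_parallel_size", 1) > 1,
    }
    for feature, enabled in unsupported.items():
        if enabled and config.use_output_proc_callback:
            # Warn (as suggested in the review) rather than failing silently.
            logger.warning(
                "Async output processing is not supported with %s; "
                "disabling it for this run.", feature)
            config.use_output_proc_callback = False
```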

required if the worker is to perform async forward pass to next step.
"""
for seq_group_metadata, sequence_group_outputs in zip(
seq_group_metadata_list, output):
Collaborator:

As we discussed, you most likely need a seq.is_finished() check here before calling append_token(...).

megha95 (Contributor, Author) replied on Aug 12, 2024:

Updated now. Thanks for pointing this out.
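
For readers following along, here is a toy, self-contained illustration of the guard discussed in this thread. The Seq class and advance_to_next_step function below are stand-ins, not vLLM types.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Seq:
    token_ids: List[int] = field(default_factory=list)
    finished: bool = False

    def is_finished(self) -> bool:
        return self.finished

    def append_token_id(self, token_id: int) -> None:
        self.token_ids.append(token_id)


def advance_to_next_step(seqs: List[Seq], new_tokens: List[int]) -> None:
    # Because output processing lags one step, a sequence may have hit a
    # stop condition in the previous step; skip those instead of appending.
    for seq, token in zip(seqs, new_tokens):
        if seq.is_finished():
            continue
        seq.append_token_id(token)


seqs = [Seq(), Seq(finished=True)]
advance_to_next_step(seqs, [11, 12])
assert seqs[0].token_ids == [11] and seqs[1].token_ids == []
```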

megha95 force-pushed the output-proc-callback branch from 64b8b0b to e153330 on August 12, 2024, 20:26
alexm-neuralmagic (Collaborator) left a comment:

Did a second pass, left comments.

WoosukKwon self-requested a review on August 13, 2024, 17:36
megha95 changed the title from "[DO NOT MERGE] Output Processor Callback" to "Output Processor Callback" on Aug 15, 2024
github-actions bot added the ready label on Aug 15, 2024
megha95 force-pushed the output-proc-callback branch from 8979669 to f8bcd7e on August 19, 2024, 18:03
vllm/config.py (outdated), comment on lines 313 to 314:
# TO DO: assert no pipeline parallelism
# TO DO: assert no spec decoding
Collaborator:

What are the issues in supporting these two features?

megha95 (Contributor, Author):

Supporting spec decoding will need more thought, mostly because execute_model inside spec decoding's step is quite different from the non-spec-decoding codepath, and we need to figure out the right place to call the callback. Pipeline parallelism should be easier and can be done in a follow-up PR.

A later commenter, quoting the reply above:

Hi, we see better performance with the Asynchronous Output Processor. Is there a schedule to support async in the spec decoding path? Thanks.

alexm-neuralmagic (Collaborator):

@comaniac it would be good to get your quick feedback on this PR.

vllm/config.py (outdated):

        self.use_output_proc_callback = False

        if speculative_config:
            self.use_output_proc_callback = False
Collaborator:

Also need a warning?

megha95 (Contributor, Author):

Added.

# For async output processing, we need to swap cache buffers between
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
Collaborator:

Why is it guaranteed to be available the second time?
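
For context, the code comment above describes a two-buffer rotation: the buffer used at step N is only reused at step N+2, presumably because the lagged output processing for step N has completed by then. A minimal, illustrative sketch of that idea (not the actual vLLM scheduler code):

```python
from typing import Any, List


class CachedSchedulerObjects:
    """Two alternating buffers: the one handed out at step N is not touched
    again until step N+2, i.e. the second schedule() call after its use."""

    def __init__(self) -> None:
        self._buffers: List[List[Any]] = [[], []]
        self._step = 0

    def next_buffer(self) -> List[Any]:
        buf = self._buffers[self._step % 2]
        buf.clear()  # contents are from two steps ago and already processed
        self._step += 1
        return buf
```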

        scheduler_outputs = self._schedule()
        now = time.time()

        if not self.cache_config.enable_prefix_caching:
            common_computed_block_nums = []

        # TODO: Combine multi-step and async postprocessor
        allow_output_proc_callback: bool = (
Collaborator:

Why not just combine this logic into the _allow_output_proc_callback function?

Collaborator:

It may be possible; I'll see if I can do it.

Comment on lines 388 to 391
# Avoid deque alloc
self.tmp_queue: Deque[SequenceGroup] = deque()
Collaborator:

Can you say more about this, or use a better name for this queue? It's hard to grasp its purpose from the rest of its uses.

megha95 (Contributor, Author):

It's a deque acting as a temporary buffer for the new sequences in the running queue. It's only used inside free_finished_seq_groups; we save some time by avoiding re-allocation on each step. See the sketch below.
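
A rough sketch of that pattern follows; the class and method names below are illustrative, not the actual scheduler code.

```python
from collections import deque
from typing import Callable, Deque


class TinyScheduler:
    def __init__(self) -> None:
        self.running: Deque[object] = deque()
        # Preallocated once and reused every step to avoid a fresh
        # deque allocation inside free_finished_seq_groups.
        self._tmp_queue: Deque[object] = deque()

    def free_finished_seq_groups(
            self, is_finished: Callable[[object], bool]) -> None:
        self._tmp_queue.clear()
        for seq_group in self.running:
            if not is_finished(seq_group):
                self._tmp_queue.append(seq_group)
        # Swap the buffers instead of allocating a new running deque.
        self.running, self._tmp_queue = self._tmp_queue, self.running
```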

Comment on lines +1554 to +1686
    # Skip double logging when using async output proc
    if finished_before and idx in finished_before:
        actual_num_batched_tokens -= 1
        continue
Collaborator:

I can't connect the comment and the code. Can you elaborate?

megha95 (Contributor, Author):

actual_num_batched_tokens -= 1 is required to avoid double-counting. With async output processing, the number of batched tokens has already been counted, since processing is lagged by a step. If a token belongs to a sequence that finished in the previous step, we don't count it again. This does not happen with non-async logging. A simplified sketch follows.
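
A simplified sketch of that accounting; the function below is hypothetical, with variable names mirroring the snippet above.

```python
from typing import Optional, Sequence, Set


def count_batched_tokens(outputs: Sequence[object],
                         num_scheduled_tokens: int,
                         finished_before: Optional[Set[int]] = None) -> int:
    """Return the token count used for this step's stats logging."""
    actual_num_batched_tokens = num_scheduled_tokens
    for idx, _ in enumerate(outputs):
        # With async output processing, sequences listed in finished_before
        # ended in the previous step and were already counted there, so
        # un-count them here to avoid double logging.
        if finished_before and idx in finished_before:
            actual_num_batched_tokens -= 1
            continue
        # ... normal per-output stats handling would go here ...
    return actual_num_batched_tokens
```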

alexm-neuralmagic (Collaborator):

@comaniac @WoosukKwon thanks for the detailed reviews. I think we have addressed all of the comments; we are going over the final CI tests with @megha95 and an issue Woosuk found with beam search.

megha95 changed the title from "Output Processor Callback" to "Asynchronous Output Processor" on Aug 23, 2024
megha95 requested a review from WoosukKwon on August 24, 2024, 00:47
megha95 changed the title from "Asynchronous Output Processor" to "[Core] Asynchronous Output Processor" on Aug 24, 2024
WoosukKwon (Collaborator) left a comment:
@megha95 @alexm-neuralmagic Huge thanks for the PR! Also, thanks a lot for the kind intro and explanation of this PR.

I'm happy with merging this PR once the remaining super-minor issues are fixed and the PR passes the CI again.

njhill (Member) left a comment:
Thanks @megha95 @alexm-neuralmagic for the great work!

Do you have a sense for how much of the speedup comes from removal of the scheduling from the critical path vs removal of detokenization?

megha95 (Contributor, Author) commented on Aug 27, 2024:

@njhill Thank you for the review. To answer your question:

Do you have a sense for how much of the speedup comes from removal of the scheduling from the critical path vs removal of detokenization?

We haven't removed scheduling from the critical path. Most of the speedup in this PR comes from removing detokenization, free_finished_seq_groups, and the creation of the RequestOutputs that are streamed back to the user.

WoosukKwon (Collaborator):
@megha95 @alexm-neuralmagic Thanks again for your great work!
@njhill Thanks for taking a look at the PR!

WoosukKwon merged commit 2eedede into vllm-project:main on Aug 27, 2024 (46 of 48 checks passed)
DarkLight1337 (Member) commented on Sep 2, 2024:

It seems that this PR is causing the LoRA tests to fail on main.

Before this PR on main: https://buildkite.com/vllm/ci-aws/builds/7622
In this PR: https://buildkite.com/vllm/ci-aws/builds/7645

Can those who worked on this PR look into the issue further?

youkaichao (Member):

cc @megha95 @WoosukKwon

megha95 (Contributor, Author) commented on Sep 3, 2024:

@youkaichao @DarkLight1337 Having a look.

triple-Mu pushed a commit to triple-Mu/vllm_official that referenced this pull request Sep 4, 2024
Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
rkooo567 (Collaborator):
@megha95 do you know what it takes to support this feature for the SPMD architecture? (Referring to "Async output processing can not be enabled with ray spmd".)

Jeffwan pushed a commit to aibrix/vllm that referenced this pull request Sep 19, 2024
Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
michalkuligowski pushed a commit to HabanaAI/vllm-fork that referenced this pull request Sep 27, 2024

This PR refers to vllm-project#7049 and implements the Asynchronous Output Processor on HPU. It is enabled by default; to disable it, pass the `--disable_async_output_proc` flag.

From my local test on the latest habana_main branch (commit 29fb5ed), the throughput improves from 3847 TPS to 4011 TPS.

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
mergify bot added the frontend label on Nov 5, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>

Labels: frontend, ready
Projects: None yet
10 participants