
Prefix Cache Aware Scheduling [1/n] #10128

Merged · 4 commits into vllm-project:main · Nov 23, 2024

Conversation

@rickyyx (Contributor) commented Nov 7, 2024

TL;DR

Background

With the current implementation in main, there are at least two places where scheduling is not optimal:

  1. When deciding how many new tokens go into the scheduling budget, it does not take already computed tokens into account.
  2. When deciding whether a sequence can be allocated for prefill, it does not take already computed blocks into account.

This results in under-utilization of the KV cache and suboptimal scheduling decisions for a batch.

For more details, see #7883.
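
As a toy illustration of point (1): the numbers, block size, and variable names below are hypothetical and not vLLM's actual accounting code.

    # Toy illustration only; names and numbers are hypothetical.
    BLOCK_SIZE = 16
    prompt_len = 1000                 # tokens in the incoming prompt
    cached_prefix_len = 600           # tokens covered by a shared, cached prefix

    # Only whole blocks can be reused from the prefix cache.
    num_cached_tokens = (cached_prefix_len // BLOCK_SIZE) * BLOCK_SIZE   # 592

    # Cache-unaware accounting (current main): charge the whole prompt.
    budget_charged_before = prompt_len                                   # 1000

    # Prefix-cache-aware accounting (this PR): charge only uncached tokens.
    budget_charged_after = prompt_len - num_cached_tokens                # 408

In this toy example, with --max-num-batched-tokens 2048 and ignoring chunking, a cache-unaware scheduler fits at most two such prefills per step, while the cache-aware accounting fits five.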

High Level Approach

This PR addresses the issue by solving (1) only:

  1. Make scheduling prefix-cache aware: when deciding how many tokens to schedule, already cached tokens are now taken into account, so the scheduling budget only counts uncached tokens.
  2. The ComputedBlocksTracker tracks the block hashes for a sequence and tells the scheduler/block manager how many of that sequence's tokens are already cached.

At a high level, the major changes are as follows (a minimal sketch follows the list):

  1. When deciding how many tokens to batch, the scheduler now also accounts for new tokens that are already cached, so the scheduling budget only includes new, uncached tokens.
  2. The scheduler gets the cached-token information from the block manager, whose ComputedBlocksTracker tracks it for each sequence.
  3. For each sequence, a block hash is currently computed twice (once in the ComputedBlocksTracker, once in the allocator). A future PR could deduplicate this.
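
Below is a minimal, self-contained sketch of this flow. The class name matches the PR's ComputedBlocksTracker, but the method names, signatures, and hashing details are simplifications assumed for illustration and do not mirror vLLM's real code.

    # Self-contained sketch; method names, signatures, and hashing details are
    # assumptions for illustration and do not mirror vLLM's real code.
    from typing import Dict, List, Set


    class ComputedBlocksTracker:
        """Tracks block hashes per sequence and reports how many prompt tokens
        are already covered by cached (computed) blocks."""

        def __init__(self, cached_block_hashes: Set[int], block_size: int):
            self._cached = cached_block_hashes   # hashes known to the allocator
            self._block_size = block_size
            self._seq_hashes: Dict[int, List[int]] = {}

        def get_num_cached_tokens(self, seq_id: int, prompt: List[int]) -> int:
            hashes = self._seq_hashes.setdefault(seq_id, self._hash_blocks(prompt))
            cached_blocks = 0
            for h in hashes:                     # only a contiguous prefix counts
                if h not in self._cached:
                    break
                cached_blocks += 1
            return cached_blocks * self._block_size

        def _hash_blocks(self, tokens: List[int]) -> List[int]:
            hashes: List[int] = []
            prev = None
            full_len = len(tokens) // self._block_size * self._block_size
            for start in range(0, full_len, self._block_size):
                prev = hash((prev, tuple(tokens[start:start + self._block_size])))
                hashes.append(prev)
            return hashes


    def num_new_tokens_to_budget(prompt_len: int, num_cached: int,
                                 budget_left: int) -> int:
        """Charge only new, uncached tokens against the batch token budget."""
        return min(prompt_len - num_cached, budget_left)

The key point is that the scheduler asks the block manager (which owns the tracker) for the cached-token count before charging a prefill against the budget.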

Throughput Benchmark

Example command

    python benchmarks/benchmark_prefix_caching.py \
        --model meta-llama/Meta-Llama-3-8B-Instruct \
        --enable-prefix-caching \
        --use-v2-block-manager \
        --num-prompts 300 \
        --input-length-range 1000:1000 \
        --output-len 10 \
        --prefix-len $prefix_len \
        --max-num-batched-tokens 2048 \
        --enable-chunked-prefill
[Figure: throughput benchmark results for varying prefix lengths]

More details in this doc

Serving Benchmark

# Server
vllm serve meta-llama/Meta-Llama-3-8B-Instruct --disable-log-requests --enable-prefix-caching --enable-chunked-prefill

Serving Results on QPS=10

  • The improvement in TTFT is significant.
  • There is a slight improvement in request throughput.
  • There is a slight degradation in ITL (likely because a batch contains more prefill tokens).

python3 benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random \
--random-input-len 200 --random-output-len 30 --random-prefix-len 600 --seed 0 \
--request-rate 10 --num-prompts 500
============ Serving Benchmark Result ============
                                          PR           Main
Successful requests:                     500           500       
Benchmark duration (s):                  52.01         55.42     
Total input tokens:                      400000        400000    
Total generated tokens:                  14831         14809     
Request throughput (req/s):              9.61          9.02      
Output token throughput (tok/s):         285.16        267.23    
Total Token throughput (tok/s):          7976.00       7485.27   
---------------Time to First Token----------------
Mean TTFT (ms):                          238.19        2193.42   
Median TTFT (ms):                        209.70        2510.56   
P99 TTFT (ms):                           632.90        3922.97   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          81.33         60.74     
Median TPOT (ms):                        75.03         61.54     
P99 TPOT (ms):                           135.80        64.79     
---------------Inter-token Latency----------------
Mean ITL (ms):                           82.12         61.21     
Median ITL (ms):                         62.17         62.90     
P99 ITL (ms):                           285.48        116.98    
==================================================

Serving Results on QPS=15

With a higher request rate (QPS=15), the improvements in TTFT and request throughput are more significant (about 25% higher request throughput).

python3 benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input-len 200 --random-output-len 50 --random-prefix-len 600 --seed 0 \
--request-rate 15 \
--num-prompts 500
============ Serving Benchmark Result ============
                                          PR           Main
Successful requests:                     500           500       
Benchmark duration (s):                  46.94         58.66     
Total input tokens:                      400000        400000    
Total generated tokens:                  24685         24668     
Request throughput (req/s):              10.65         8.52      
Output token throughput (tok/s):         525.92        420.52    
Total Token throughput (tok/s):          9048.08       7239.31   
---------------Time to First Token----------------
Mean TTFT (ms):                          4797.81       11201.85  
Median TTFT (ms):                        4566.95       11076.55  
P99 TTFT (ms):                           10789.80      22585.37  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          160.69        64.68     
Median TPOT (ms):                        170.73        65.03     
P99 TPOT (ms):                           191.34        68.68     
---------------Inter-token Latency----------------
Mean ITL (ms):                           161.50        65.20     
Median ITL (ms):                         154.61        66.75     
P99 ITL (ms):                           847.72        124.48    
==================================================

Serving Results with No Shared Prefix

python3 benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input-len 800 --random-output-len 50 --random-prefix-len 0 --seed 0 --request-rate 10 --num-prompts 500
  • There is no overhead when the workload has no shared prefix.
============ Serving Benchmark Result ============
                                          PR           Main         
Successful requests:                     500           500          
Benchmark duration (s):                  127.43        127.65       
Total input tokens:                      400000        400000       
Total generated tokens:                  24350         24350        
Request throughput (req/s):              3.92          3.92         
Output token throughput (tok/s):         191.08        190.76       
Total Token throughput (tok/s):          3330.05       3324.43      
---------------Time to First Token----------------
Mean TTFT (ms):                          36577.91      36610.46     
Median TTFT (ms):                        36444.46      36417.98     
P99 TTFT (ms):                           73755.41      73895.10     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          144.37        144.74       
Median TPOT (ms):                        143.66        143.89       
P99 TPOT (ms):                           302.23        302.97       
---------------Inter-token Latency----------------
Mean ITL (ms):                           142.62        142.98       
Median ITL (ms):                         143.56        143.77       
P99 ITL (ms):                           286.82        287.27       
==================================================


github-actions bot commented Nov 7, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@rickyyx changed the title from "[WIP] Prefix Cache Aware Scheduling [1/n]" to "Prefix Cache Aware Scheduling [1/n]" on Nov 12, 2024

return num_uncached_new_tokens_seq, num_cached_new_tokens_seq

def _chunk_new_tokens_to_schedule(
@rickyyx (Contributor, Author) commented:

Mainly a refactor of the logic from main.
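
For context, here is a rough sketch of what a chunking helper like this does; the parameter names and exact behavior are assumptions for illustration, not the PR's actual implementation.

    # Illustrative only; parameter names and exact behavior are assumptions.
    def chunk_new_tokens_to_schedule(num_uncached_new_tokens: int,
                                     remaining_token_budget: int,
                                     enable_chunking: bool) -> int:
        """Decide how many uncached prefill tokens to schedule this step."""
        if num_uncached_new_tokens <= remaining_token_budget:
            return num_uncached_new_tokens      # the whole prefill fits
        if not enable_chunking:
            return 0                            # doesn't fit; try again next step
        return remaining_token_budget           # chunked prefill: fill the budget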

@comaniac (Collaborator) left a comment:

Great work! Left some comments.
Also cc @zhuohan123 @alexm-neuralmagic

Review threads on vllm/core/block/prefix_caching_block.py and vllm/core/scheduler.py.
@rickyyx requested a review from KuntaiDu as a code owner on November 13, 2024 at 21:26

mergify bot commented Nov 13, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @rickyyx.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


rickyyx commented Nov 13, 2024

Updates

  • Added an integration test for prefix caching with fully (and partially) cached prompts.



rickyyx commented Nov 13, 2024

lol - fml, i hate DCO.

@mergify bot removed the needs-rebase label on Nov 13, 2024
Signed-off-by: rickyx <rickyx@anyscale.com>
@comaniac added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Nov 14, 2024
@comaniac (Collaborator) left a comment:


LGTM. The remaining issue per offline discussion is to make sure each scheduled sequence has at least one token in the budget.

cc @zhuohan123 for another review.

Signed-off-by: rickyx <rickyx@anyscale.com>

rickyyx commented Nov 15, 2024

Updates

  • Added a test to make sure each scheduled sequence contributes at least one token to the budget (a fully cached prompt should not be scheduled with zero uncached tokens); a sketch of this invariant follows.
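
The invariant this test checks can be sketched as follows; this is a hedged illustration, and the function name and placement are assumptions rather than the PR's actual code.

    # Hedged illustration of the invariant; not the PR's actual code.
    def clamp_num_cached_tokens(prompt_len: int, num_cached_tokens: int) -> int:
        """Ensure a fully cached prompt still has at least one uncached token,
        so the sequence gets a non-zero slice of the scheduling budget and can
        run a forward pass to produce its first output token."""
        if num_cached_tokens >= prompt_len:
            return prompt_len - 1
        return num_cached_tokens


    assert clamp_num_cached_tokens(prompt_len=1000, num_cached_tokens=1000) == 999
    assert clamp_num_cached_tokens(prompt_len=1000, num_cached_tokens=592) == 592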

Signed-off-by: rickyx <rickyx@anyscale.com>
Signed-off-by: rickyx <rickyx@anyscale.com>
@youkaichao merged commit 4634a89 into vllm-project:main on Nov 23, 2024
46 of 50 checks passed
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 28, 2024
Signed-off-by: rickyx <rickyx@anyscale.com>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Signed-off-by: rickyx <rickyx@anyscale.com>
Labels: ci/build, documentation, frontend, ready