
[V1][Spec Decode] Ngram Spec Decode #12193

Open · wants to merge 20 commits into main
Conversation

@LiuXiaoxuanPKU (Collaborator) commented Jan 19, 2025

This PR adds ngram spec decode to V1. Design doc: here.
Major changes:

  1. Since we only implement ngram spec decode, we did not add a separate scheduler for running the drafting method. Instead, we always check whether an ngram lookup is needed before calling the scheduler.
  2. Add a new field _spec_token_ids in Request to track speculated tokens (a sketch of this bookkeeping follows the list).
  3. Changes to model_runner:
    3.1 Change _prepare_input to also return the logits of the speculated tokens.
    3.2 Change _prepare_input to add speculated tokens as input tokens.
    3.3 Change execute_model to generate multiple tokens per call. Concretely, it can add more than one token to input_batch and req_state.
  4. We only perform spec decode for requests in the running queue.
  5. We only support greedy decoding for now.
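
For illustration, here is a minimal sketch of the request-side bookkeeping from items 2 and 3. The method names mirror the ones that appear in the diff excerpts quoted later in this thread (spec_token_ids, clear_spec_tokens, append_output_token_ids); append_spec_token_ids is a made-up helper, and the class is heavily trimmed, so treat this as a sketch rather than the PR's actual implementation:

```python
from typing import List


class Request:
    """Trimmed sketch of a V1 Request with speculative-token tracking."""

    def __init__(self) -> None:
        self._output_token_ids: List[int] = []
        # Item 2: draft tokens proposed for this request; they are verified
        # by the target model on the next step.
        self._spec_token_ids: List[int] = []
        self.num_computed_tokens = 0

    @property
    def spec_token_ids(self) -> List[int]:
        return self._spec_token_ids

    def append_spec_token_ids(self, token_ids: List[int]) -> None:
        # Hypothetical helper: store the draft tokens for the next step.
        self._spec_token_ids.extend(token_ids)

    def clear_spec_tokens(self) -> None:
        self._spec_token_ids.clear()

    def append_output_token_ids(self, token_id: int) -> None:
        # Item 3.3: with spec decode, execute_model can append more than one
        # accepted token per step by calling this repeatedly.
        self._output_token_ids.append(token_id)
```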

What is missing

  • Change scheduling to only propose tokens for decoding requests.
  • Update stop checking for spec decode, where multiple tokens are generated in a single step.
  • For the ngram lookup logic, I currently just append dummy tokens directly instead of performing the lookup. We can move v0's lookup logic here (a minimal lookup sketch follows this list).
  • Check the correctness of this PR with chunked prefill (we only perform spec decode in the decoding phase).
  • More end-to-end tests & style.
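
The lookup sketch referenced above: this is not what the PR does yet (it currently appends dummy tokens), just a rough illustration of the prompt-lookup style ngram proposal that could be ported from v0. The function name, window size n, and proposal length k are illustrative assumptions:

```python
from typing import List, Optional


def propose_ngram_tokens(token_ids: List[int],
                         n: int = 3,
                         k: int = 4) -> Optional[List[int]]:
    """Propose up to k draft tokens by matching the last n tokens
    against an earlier occurrence in the sequence (prompt lookup)."""
    if len(token_ids) < n + 1:
        return None
    suffix = token_ids[-n:]
    # Search backwards for the most recent earlier occurrence of the suffix.
    for start in range(len(token_ids) - n - 1, -1, -1):
        if token_ids[start:start + n] == suffix:
            follow = token_ids[start + n:start + n + k]
            return follow if follow else None
    # No match: propose nothing, so the scheduler allocates no extra slots.
    return None


# Example: the suffix [2, 3, 4] appeared earlier, followed by [5, 2, 3, 4].
assert propose_ngram_tokens([1, 2, 3, 4, 5, 2, 3, 4]) == [5, 2, 3, 4]
```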

Tasks out of scope for this PR

  1. Optimize the performance of ngram lookup.
  2. Support non-greedy decoding.
  3. Add other spec decode methods.

[Update]
I will move the following two features into follow-up PRs:

  1. Guarantee the correctness of prefix caching + spec decode, because it will involve changing the behavior of the kv cache manager @comaniac.
  2. Change the scheduling policy to guarantee that at least one token is scheduled for each request. This is separated out because it touches the scheduling code and needs more careful thought/testing.

Minor: there is a minimal example/test in tests/v1/e2e/test_basic_specdecode.py. You can check it for the current usage and verify correctness with pytest -s tests/v1/e2e/test_basic_specdecode.py.


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to catch errors quickly. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀


mergify bot commented Jan 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LiuXiaoxuanPKU.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 19, 2025
@mergify mergify bot removed the needs-rebase label Jan 20, 2025
@NickLucche (Contributor):

Surely late here, but why is a speculative decoding-aware scheduler needed? Wouldn't it be possible to just assume multi-token generation per-step as default?

@comaniac (Collaborator) commented Jan 24, 2025

> Surely late here, but why is a speculative decoding-aware scheduler needed? Wouldn't it be possible to just assume multi-token generation per-step as default?

Because the scheduler has to know how many kv-cache slots are needed for each request. In v0 we use lookahead slots, which always allocate k lookahead slots per request when spec decode is enabled. However, this is inefficient when we don't have k spec tokens for every request, which can happen in the following cases:

  1. N-gram lookup won't propose any tokens if it fails to find a match.
  2. The draft model generates EOS.
  3. There are insufficient kv-cache slots.
  4. In dynamic speculative decoding, we control k based on the current traffic.

So in this design for v1, we first get the spec tokens, and let the target model scheduler allocate the exact number of slots accordingly.
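
To make the contrast concrete, here is a toy sketch of the two allocation strategies; the function names and numbers are purely illustrative, not vLLM code:

```python
from typing import Dict, List

K = 4  # fixed lookahead budget used by the v0-style approach


def slots_needed_v0_style(num_requests: int) -> int:
    # v0 lookahead slots: always reserve k extra slots per running request,
    # even when fewer (or zero) draft tokens were actually proposed.
    return num_requests * (1 + K)


def slots_needed_v1_style(spec_tokens: Dict[str, List[int]]) -> int:
    # v1 idea: propose first, then allocate exactly one slot for the next
    # token plus one slot per proposed draft token for each request.
    return sum(1 + len(tokens) for tokens in spec_tokens.values())


# Example: one request got a full proposal, one got nothing (no ngram match).
proposals = {"req-0": [11, 12, 13, 14], "req-1": []}
assert slots_needed_v0_style(len(proposals)) == 10
assert slots_needed_v1_style(proposals) == 6
```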

Resolved review thread: vllm/v1/core/scheduler.py
self._spec_token_ids = []

@property
def spec_token_ids(self) -> List[int]:
Collaborator:

This function should return ConstantList to be read-only. See output_token_ids and all_token_ids as references.
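
A sketch of what the suggested change could look like, mirroring how output_token_ids is exposed. The ConstantList below is a simplified stand-in for vLLM's read-only wrapper, and the Request class is trimmed to the relevant part:

```python
from collections.abc import Sequence
from typing import List


class ConstantList(Sequence):
    """Simplified stand-in for vLLM's read-only list wrapper."""

    def __init__(self, items: List[int]) -> None:
        self._items = items

    def __getitem__(self, idx):
        return self._items[idx]

    def __len__(self) -> int:
        return len(self._items)


class Request:
    def __init__(self) -> None:
        self._spec_token_ids: List[int] = []

    @property
    def spec_token_ids(self) -> ConstantList:
        # Read-only view: callers can iterate and index but not mutate,
        # matching how output_token_ids and all_token_ids are exposed.
        return ConstantList(self._spec_token_ids)
```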

Resolved review threads: vllm/v1/request.py, vllm/v1/core/scheduler.py
Comment on lines +167 to +173
# When calculating new full blocks, we exclude speculative tokens.
# We only cache blocks where token_ids are valid. KV cache of
# speculative tokens will be valid once these tokens are accepted
# (tracked by num_computed_tokens).
num_cached_tokens = request.num_computed_tokens + num_tokens - len(
    request.spec_token_ids)
num_full_blocks_after_append = num_cached_tokens // self.block_size
Collaborator:

I just realized that the current _cache_full_blocks may not support caching the same block twice. It may have issues in an edge case (block size 4, k=3):

Step 1: [0, 1, 2, 3] + [4, S0, S1, S2]
- The first block is already cached in the last step.
- Assuming all spec tokens are rejected.

Step 2: [0, 1, 2, 3] + [4, 5, S0, S1] + [S2]
- `5` is the bonus token.
- num_cached_tokens = 5 + 1 - 3 = 3
- num_full_blocks_after_append = 3 // 4 = 0
- So you attempt to cache the first block again.

Collaborator:

#12415 should fix this.

logger = init_logger(__name__)


class RejectionSampler(nn.Module):
Collaborator:

Can we use the FlashInfer kernel for rejection sampling?

Author:

Will add a TODO saying we can replace it with FlashInfer in the future.

Resolved review threads: vllm/v1/spec_decode/ngram_proposer.py, vllm/v1/worker/gpu_input_batch.py, vllm/v1/worker/gpu_model_runner.py
@NickLucche (Contributor):

I see, thanks a lot for elaborating @comaniac!

@@ -621,6 +663,8 @@ class SchedulerOutput:
num_scheduled_tokens: Dict[str, int]
total_num_scheduled_tokens: int
scheduled_encoder_inputs: Dict[str, List[int]]
use_spec_decode: bool
scheduled_spec_decode_tokens: Dict[str, List[int]]
Author:

Maybe a more general approach:

  1. Not specific to spec decode: add a function to the scheduler so it can schedule n new tokens (with token ids).
  2. Specific to spec decode: rejection sampling; since some tokens might be rejected, we need some way of rewinding.

@sroy745 (Collaborator) left a comment:

Thanks for the PR!

Resolved review threads: vllm/v1/spec_decode/ngram_proposer.py, vllm/v1/engine/core.py, vllm/v1/sample/rejection_sampler.py
@@ -361,11 +363,15 @@ def make_sampling_metadata(
# TODO - Replace this with incremental update to output token
# statistics.
output_token_ids.append(req_id_output_token_ids[req_id])
if rejection_sampling:
    assert req_id_to_spec_token_ids is not None
    spec_token_ids.append(req_id_to_spec_token_ids[req_id])
Collaborator:

Is it possible that there are no speculations for a req_id? If so, do we need to use req_id_to_spec_token_ids.get(req_id, [])?
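
A tiny illustration of the suggested defensive lookup, assuming requests with no proposal may simply be absent from the mapping:

```python
from typing import Dict, List

req_id_to_spec_token_ids: Dict[str, List[int]] = {"req-0": [7, 8, 9]}

# "req-1" had no ngram match, so it has no entry; default to an empty list
# rather than raising a KeyError.
for req_id in ("req-0", "req-1"):
    spec_token_ids = req_id_to_spec_token_ids.get(req_id, [])
    print(req_id, spec_token_ids)
```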

request.append_output_token_ids(token_id)
num_new_tokens = 1
if request.num_computed_tokens >= request.num_tokens:
    request.clear_spec_tokens()
Collaborator:

I am wondering whether we should always clear the spec tokens. What happens if the target model does not accept any of the spec tokens? Is that case handled here?

Resolved review threads: vllm/v1/core/scheduler.py, tests/v1/core/test_stop_checking.py

mergify bot commented Jan 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LiuXiaoxuanPKU.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 27, 2025
@@ -0,0 +1,47 @@
import torch
Author:

Use torch tensors or numpy to implement the first version of rejection sampling.
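
Since the PR supports only greedy decoding, a first torch-based version could reduce to a prefix match between the draft tokens and the target model's argmax tokens. The sketch below is illustrative only; the shapes, names, and single-request signature are assumptions, not the PR's RejectionSampler:

```python
import torch


def greedy_rejection_sample(draft_token_ids: torch.Tensor,
                            target_logits: torch.Tensor) -> torch.Tensor:
    """Accept the longest prefix of draft tokens that matches the target
    model's greedy (argmax) choices, then emit one token from the target.

    draft_token_ids: [k] proposed draft tokens for one request.
    target_logits:   [k + 1, vocab_size] target logits at each draft
                     position plus the position after the last draft token.
    Returns the accepted tokens followed by one target token, so between
    1 and k + 1 tokens are produced per step.
    """
    target_token_ids = target_logits.argmax(dim=-1)      # [k + 1]
    k = draft_token_ids.shape[0]
    matches = (draft_token_ids == target_token_ids[:k]).int()
    # Length of the matching prefix (stop at the first mismatch).
    num_accepted = int(torch.cumprod(matches, dim=0).sum().item())
    # At index num_accepted the target either fixes the first mismatch or,
    # if everything matched, contributes the bonus token.
    return torch.cat([
        draft_token_ids[:num_accepted],
        target_token_ids[num_accepted:num_accepted + 1],
    ])


# Example: drafts [3, 5, 9]; target agrees on 3 and 5 but prefers 7 over 9.
logits = torch.full((4, 10), -1.0)
for pos, tok in enumerate([3, 5, 7, 2]):
    logits[pos, tok] = 1.0
out = greedy_rejection_sample(torch.tensor([3, 5, 9]), logits)
assert out.tolist() == [3, 5, 7]
```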
