[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker #5348

Merged
Changes from 36 commits
Commits
51 commits
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
bbf1484
Integrate Typical Acceptance Sampler into spec decode worker
sroy745 Jun 7, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
3495673
Fixing tests
sroy745 Jun 9, 2024
26c7c57
adding missing commit
sroy745 Jun 10, 2024
090f0bf
reverting changes to conftest
sroy745 Jun 10, 2024
733cc6e
reverting changes to conftest
sroy745 Jun 10, 2024
19ca0c9
Merge branch 'main' into spec_decode_integrate_accpetance_sampler
sroy745 Jun 10, 2024
acf8d2c
Dummy commit
sroy745 Jun 10, 2024
2d2b02b
Merge branch 'spec_decode_integrate_accpetance_sampler' of https://gi…
sroy745 Jun 10, 2024
2010b35
Revert unnecessary commits
sroy745 Jun 10, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
7fa64b6
Merge remote-tracking branch 'origin/main' into spec_decode_integrate…
sroy745 Jun 10, 2024
dea6fbd
Pass only one sampler which can either be the RejectionSampler of the…
sroy745 Jun 10, 2024
c3383db
Fix test scripture
sroy745 Jun 10, 2024
b15abba
Fix tests
sroy745 Jun 11, 2024
6ca731c
Fix tests
sroy745 Jun 11, 2024
483c671
Pass only 1 verification_sampler which can either be rejectionSampler…
sroy745 Jun 11, 2024
2c6d06c
Update metrics.py to take the base sampler class
sroy745 Jun 11, 2024
027b485
Fix tests and comments
sroy745 Jun 11, 2024
ded92ac
Fix test fixture and default values of args
sroy745 Jun 11, 2024
738871e
Small misc fixes
sroy745 Jun 11, 2024
50e8771
Fix spec_decode/test_metrics.py
sroy745 Jun 11, 2024
101611e
Merge branch 'main' into spec_decode_integrate_accpetance_sampler
sroy745 Jun 11, 2024
5e6638b
Merge branch 'main' into spec_decode_integrate_accpetance_sampler
sroy745 Jun 25, 2024
cc760a0
Make rejection_sampler.py and typical_acceptance_sampler.py implement…
sroy745 Jun 25, 2024
360ce0b
Raise exception instead of returning None for invalid sampler name
sroy745 Jun 25, 2024
6572ba4
Adding log about type of sampler
sroy745 Jun 25, 2024
be85f07
Misc comment fixes
sroy745 Jun 26, 2024
6dc9efe
Misc fixes
sroy745 Jun 26, 2024
512fad9
Misc fixes
sroy745 Jun 26, 2024
b1d510c
Misc fixes
sroy745 Jun 26, 2024
f4b9e4d
Misc fixes
sroy745 Jun 26, 2024
0ea9408
Documentation
sroy745 Jun 26, 2024
5772d04
Fix comments
sroy745 Jun 26, 2024
b7254e7
Fix arg name
sroy745 Jun 26, 2024
ef93081
Fixing a test
sroy745 Jun 26, 2024
0165842
Fix comment
sroy745 Jun 26, 2024
510974b
Fix formatting
sroy745 Jun 26, 2024
396fa54
Fixing tests and lint failures
sroy745 Jun 26, 2024
f8cc895
Removing e2e test for TypicalAcceptanceSampler from test_ngram_correc…
sroy745 Jun 27, 2024
439117d
Fix a comment
sroy745 Jun 27, 2024
75f034f
Dummy commit
sroy745 Jun 27, 2024
a0f5ade
Merge pull request #2 from vllm-project/main
sroy745 Jun 27, 2024
3082255
Fix format error
sroy745 Jun 27, 2024
4e7f51a
Merge pull request #3 from vllm-project/main
sroy745 Jun 28, 2024
d26c624
Dummy fix
sroy745 Jun 29, 2024
98d5f92
Merge branch 'main' into spec_decode_integrate_accpetance_sampler
sroy745 Jun 29, 2024
f186844
Update test_multistep_correctness.py
sroy745 Jun 29, 2024
57 changes: 40 additions & 17 deletions tests/samplers/test_typical_acceptance_sampler.py
@@ -51,6 +51,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
break
return draft_token_ids

def get_acceptance_sampler(
    posterior_threshold: float = 0.03,
    posterior_alpha: float = 0.9,
    disable_bonus_tokens: bool = False,
    strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
    """
    Initializes and returns a TypicalAcceptanceSampler.
    """
    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
                                    disable_bonus_tokens, strict_mode)



@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
different combinations of k, vocab_size, batch_size and num devices.
"""
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler()
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
@@ -76,7 +89,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
typical_acceptance_sampler(target_probs, bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
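
Note: every call site in this file now passes draft_probs and draft_token_ids as keywords because the RejectionSampler and the TypicalAcceptanceSampler expose one shared verification call (see the "Make rejection_sampler.py and typical_acceptance_sampler.py implement…" commit above). A rough sketch of that shared call shape, with tensor shapes taken from the assertions in these tests; the stub body and function name are illustrative assumptions, not this PR's code:

from typing import Optional

import torch


def verify_tokens(
    target_probs: torch.Tensor,           # [batch_size, k, vocab_size] target-model probs
    bonus_token_ids: torch.Tensor,        # [batch_size, 1] bonus token per sequence
    draft_probs: Optional[torch.Tensor],  # draft-model probs; None for typical acceptance
    draft_token_ids: torch.Tensor,        # [batch_size, k] proposed draft tokens
) -> torch.Tensor:
    # Returns accepted token ids shaped [batch_size, k + 1];
    # positions after the first rejection are filled with -1.
    ...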


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -94,7 +109,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
@@ -126,7 +141,8 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,

with pytest.raises(AssertionError):
typical_acceptance_sampler(target_probs, bonus_token_ids,
draft_token_ids)
draft_probs=None,
draft_token_ids=draft_token_ids)


@pytest.mark.parametrize("seed", list(range(10)))
@@ -151,7 +167,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -165,7 +181,8 @@ def test_uniform_target_distribution_accepts_all_tokens(
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
draft_probs=None,
draft_token_ids=draft_token_ids)
# We are using a uniform target probability distribution.
# For a uniform distribution the entropy is very high and it
# should lead to all draft tokens being accepted. Verify that.
@@ -203,7 +220,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size = 30_000
torch.set_default_device(device)

typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target probabilities
@@ -226,7 +243,8 @@ def test_temperature_zero_target_distribution(seed: int,
# Verify the same.
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, -1] == -1)
@@ -261,7 +279,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# For sequences 0 and 2 set the distribution to a temperature
@@ -279,7 +297,8 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
draft_probs=None,
draft_token_ids=draft_token_ids)
# verify the shape of output_token_ids
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
@@ -326,7 +345,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Create a temperature zero target probability distribution and ensure
@@ -341,7 +360,8 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -359,7 +379,8 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
@@ -384,7 +405,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target
@@ -404,7 +425,8 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 1:-1] == -1)
@@ -420,7 +442,8 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
typical_acceptance_sampler.init_gpu_tensors(rank=0)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -451,7 +474,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
53 changes: 52 additions & 1 deletion tests/spec_decode/e2e/test_multistep_correctness.py
@@ -11,9 +11,15 @@
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.

At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.

For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).
be prohibitively expensive to run with a real model). Similarly, we rely on
unit tests of the TypicalAcceptanceSampler to validate the temp>0 cases.
Collaborator:

For testing strategy:

I am concerned that we are adding many E2E tests that don't provide a lot of signal over what already exists. The tradeoff of more tests is that we can accidentally explode CI time. This is because we rely on E2E tests for spec decode correctness and any small regression in model loading or vLLM initialization time can hurt us bad.

So, what I suggest:

  • E2E tests over the interaction between spec decode worker and typical acceptance
    • make sure it can handle different BS
    • make sure it can handle different K
  • E2E test over one other proposer method
    • just need one to make sure typical acceptance works beyond draft model
  • We don't need tests around preemption, disabling/skipping speculation, different block size, since all of these are no different for typical acceptance.

Collaborator (author):

Added a single test in test_multistep_correctness.py to cover different batch size and speculation_length values with TypicalAcceptanceSampler. Added a similar test to test_ngram_correctness.py


NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
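
As background for the temp=0 claim in the docstring above: typical acceptance keeps a draft token when the target model's probability for that token clears an entropy-dependent threshold, so a near-deterministic (temp=0) target distribution accepts only its argmax token. A minimal sketch of that rule, reusing the posterior_threshold / posterior_alpha knobs from the unit tests above; this is an illustration, not vLLM's exact kernel:

import torch


def typical_acceptance_mask(target_probs: torch.Tensor,
                            draft_token_ids: torch.Tensor,
                            posterior_threshold: float = 0.03,
                            posterior_alpha: float = 0.9) -> torch.Tensor:
    # target_probs: [batch_size, k, vocab_size]; draft_token_ids: [batch_size, k].
    # Probability the target model assigns to each proposed draft token.
    token_probs = target_probs.gather(-1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    # Entropy of the target distribution at each of the k positions.
    entropy = -(target_probs * torch.log(target_probs + 1e-10)).sum(dim=-1)
    # Accept a draft token when its probability clears an entropy-dependent bar;
    # a near-one-hot target distribution has entropy ~0 and accepts only its argmax.
    threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy),
    )
    return token_probs > threshold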
@@ -611,3 +617,48 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 63]
])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_typical_acceptance_sampling(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with many TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
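
The test_llm_kwargs above double as the user-facing switches for this feature. A hypothetical offline-inference sketch with the same models and the new spec_decoding_acceptance_method flag; the surrounding LLM/SamplingParams usage is an assumption of this note, not part of the diff:

# Hypothetical usage sketch; models and flags mirror the test parameters above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=2,
    use_v2_block_manager=True,
    enforce_eager=True,
    spec_decoding_acceptance_method="typical_acceptance_sampler",
)
outputs = llm.generate(["The future of speculative decoding is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)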
46 changes: 46 additions & 0 deletions tests/spec_decode/e2e/test_ngram_correctness.py
@@ -211,3 +211,49 @@ def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3,
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [1, 32])
def test_ngram_typical_acceptance_sampling(
baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k, batch_size and
using TypicalAcceptanceSampler as the draft token acceptance method.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
12 changes: 6 additions & 6 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -3,33 +3,33 @@
import pytest
import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from .utils import create_batch, mock_worker

from .test_utils import mock_spec_decode_sampler

@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("mock_spec_decode_sampler",
["rejection_sampler", "typical_acceptance_sampler"], indirect=True)
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
def test_disable_spec_tokens(
queue_size: int, batch_size: int, k: int, mock_spec_decode_sampler):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3

draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
rejection_sampler=rejection_sampler,
spec_decode_sampler=mock_spec_decode_sampler,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)
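
The mock_spec_decode_sampler fixture imported from .test_utils is not shown in this diff. A plausible sketch of an indirect fixture serving both sampler names used in the parametrization; everything beyond what the test shows (the request.param plumbing, the token_id_dtype attribute, the import paths) is an assumption:

from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)


@pytest.fixture
def mock_spec_decode_sampler(request):
    """Build a mocked sampler of the kind named by the indirect test param."""
    name = request.param
    if name == "rejection_sampler":
        sampler = MagicMock(spec=RejectionSampler)
    elif name == "typical_acceptance_sampler":
        sampler = MagicMock(spec=TypicalAcceptanceSampler)
    else:
        raise ValueError(f"Invalid sampler name {name}")
    sampler.token_id_dtype = torch.int64  # dtype attribute the worker reads (assumed)
    return sampler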
