[Speculative Decoding] Test refactor #8317

Status: Merged (14 commits, Sep 11, 2024)

472 changes: 173 additions & 299 deletions tests/spec_decode/e2e/conftest.py

Large diffs are not rendered by default.
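
Note: because the conftest.py diff above is collapsed, the definition of the new run_equality_correctness_test helper is not visible on this page. The sketch below reconstructs its interface purely from the call sites in the files that follow; the parameter order, the defaults, and the entire body are inferred (it assumes vllm_runner is vLLM's usual context-manager runner fixture and uses placeholder prompts), so treat it as a reading aid rather than the actual implementation.

from vllm import SamplingParams


def run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size: int,
                                  max_output_len: int,
                                  seed: int = 0,
                                  temperature: float = 0.0):
    """Assumed behavior: run identical prompts through a baseline LLM and a
    speculative-decoding LLM, then assert the outputs match exactly."""
    prompts = ["Hello, my name is"] * batch_size  # placeholder prompts
    sampling_params = SamplingParams(max_tokens=max_output_len,
                                     temperature=temperature,
                                     seed=seed)

    # Layer the kwarg dicts, shared settings first, so per-test values win.
    baseline_kwargs = {**common_llm_kwargs, **per_test_common_llm_kwargs,
                       **baseline_llm_kwargs}
    spec_kwargs = {**common_llm_kwargs, **per_test_common_llm_kwargs,
                   **test_llm_kwargs}

    with vllm_runner(**baseline_kwargs) as baseline_llm:
        baseline_outputs = baseline_llm.generate(prompts, sampling_params)
    with vllm_runner(**spec_kwargs) as spec_llm:
        spec_outputs = spec_llm.generate(prompts, sampling_params)

    assert baseline_outputs == spec_outputs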

97 changes: 49 additions & 48 deletions tests/spec_decode/e2e/test_eagle_correctness.py
@@ -21,7 +21,7 @@
 
 import pytest
 
-from .conftest import run_greedy_equality_correctness_test
+from .conftest import run_equality_correctness_test
 
 # main model
 MAIN_MODEL = "JackFram/llama-68m"
@@ -53,7 +53,7 @@
     "dtype": PRECISION,
 
     # Main model
-    "model": MAIN_MODEL,
+    "model_name": MAIN_MODEL,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -68,15 +68,16 @@
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
-def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
-                                      test_llm_generator, batch_size: int,
-                                      output_len: int):
-    """Verify greedy equality with different batch size."""
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
+def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+                                      per_test_common_llm_kwargs,
+                                      baseline_llm_kwargs, test_llm_kwargs,
+                                      batch_size: int, output_len: int,
+                                      seed: int):
+
+    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs, test_llm_kwargs,
+                                  batch_size, output_len, seed)
 
 
 @pytest.mark.parametrize(
@@ -94,7 +95,7 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
     "dtype": PRECISION,
 
     # Main model
-    "model": MAIN_MODEL,
+    "model_name": MAIN_MODEL,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -109,17 +110,16 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
-def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
-                                                 test_llm_generator,
-                                                 batch_size: int,
-                                                 output_len: int):
-    """Verify greedy equality with cuda graph enabled and different
+def test_eagle_e2e_greedy_correctness_cuda_graph(
+        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
+        seed: int):
+    """Verify greedy equality with cuda graph enabled and different
     batch sizes."""
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
+    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs, test_llm_kwargs,
+                                  batch_size, output_len, seed)
 
 
 @pytest.mark.parametrize(
@@ -140,7 +140,7 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
     "dtype": PRECISION,
 
     # Main model
-    "model": MAIN_MODEL,
+    "model_name": MAIN_MODEL,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -158,18 +158,17 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
 ])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seed", [1])
-def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
-                                                      test_llm_generator,
-                                                      batch_size: int,
-                                                      output_len: int):
+def test_eagle_e2e_greedy_correctness_with_preemption(
+        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
+        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
+        seed: int):
     """Verify greedy equality, even when some sequences are preempted mid-
     generation.
     """
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
+    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs, test_llm_kwargs,
+                                  batch_size, output_len, seed)
 
 
 @pytest.mark.parametrize(
@@ -185,7 +184,7 @@ def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
     "dtype": PRECISION,
 
     # Main model
-    "model": MAIN_MODEL,
+    "model_name": MAIN_MODEL,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -207,16 +206,17 @@ def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
-def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
-                           batch_size: int, output_len: int):
+def test_eagle_different_k(vllm_runner, common_llm_kwargs,
+                           per_test_common_llm_kwargs, baseline_llm_kwargs,
+                           test_llm_kwargs, batch_size: int, output_len: int,
+                           seed: int):
     """Verify that eagle speculative decoding produces exact equality
     to without spec decode with different values of num_speculative_tokens.
     """
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
+    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs, test_llm_kwargs,
+                                  batch_size, output_len, seed)
 
 
 @pytest.mark.parametrize(
@@ -232,7 +232,7 @@ def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
     "dtype": PRECISION,
 
     # Main model
-    "model": MAIN_MODEL,
+    "model_name": MAIN_MODEL,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -250,17 +250,18 @@ def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
-def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator,
-                             batch_size: int, output_len: int):
+def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
+                             per_test_common_llm_kwargs, baseline_llm_kwargs,
+                             test_llm_kwargs, batch_size: int, output_len: int,
+                             seed: int):
     """Verify that eagle speculative decoding produces exact equality
     to without spec decode when speculation is disabled for large
     batch sizes.
     """
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
+    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs, test_llm_kwargs,
+                                  batch_size, output_len, seed)
 
 
 if __name__ == "__main__":
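
Note on the signature changes above: the old tests consumed pre-built baseline_llm_generator / test_llm_generator fixtures, while the refactored tests receive the parametrized kwarg dicts directly and forward them to the shared helper. This relies on plain pytest behavior, illustrated by the minimal, self-contained example below (not vLLM code; the names and values are made up):

import pytest


@pytest.mark.parametrize("common_llm_kwargs", [{"enforce_eager": True}])
@pytest.mark.parametrize("test_llm_kwargs", [{"num_speculative_tokens": 3}])
def test_kwargs_arrive_as_plain_arguments(common_llm_kwargs, test_llm_kwargs):
    # pytest injects each parametrized value straight into the argument of
    # the same name, so no generator fixtures are needed to plumb them in.
    merged = {**common_llm_kwargs, **test_llm_kwargs}
    assert merged == {"enforce_eager": True, "num_speculative_tokens": 3}
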
52 changes: 32 additions & 20 deletions tests/spec_decode/e2e/test_integration.py
@@ -4,7 +4,9 @@
 
 import pytest
 
-from .conftest import run_greedy_equality_correctness_test
+from .conftest import run_equality_correctness_test
 
+MAIN_MODEL = "JackFram/llama-68m"
+
 
 @pytest.mark.parametrize(
@@ -15,7 +17,7 @@
 
     # Verify equality when cuda graphs allowed.
     "enforce_eager": False,
-    "model": "JackFram/llama-68m",
+    "model_name": "JackFram/llama-68m",
 }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
@@ -31,23 +33,27 @@
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("output_len", [32])
 @pytest.mark.parametrize("seed", [1])
-def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
-                                batch_size, output_len):
+def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
+                                per_test_common_llm_kwargs,
+                                baseline_llm_kwargs, test_llm_kwargs,
+                                batch_size: int, output_len: int, seed: int):
     """Verify spec decode equality when cuda graphs are enabled.
     """
-    run_greedy_equality_correctness_test(
-        baseline_llm_generator,
-        test_llm_generator,
-        batch_size,
-        max_output_len=output_len,
-        force_output_len=True,
-    )
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0)
 
 
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
Expand Down Expand Up @@ -80,13 +86,19 @@ def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_speculative_model_quantization_config(baseline_llm_generator,
test_llm_generator,
batch_size: int):
def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int, seed: int):
"""Verify spec decode works well with draft model quantization configs.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)
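
One detail worth noting in the test_integration.py call sites: they pass temperature=0.0 explicitly, which is vLLM's convention for greedy (argmax) decoding, so the generalized run_equality_correctness_test reproduces what run_greedy_equality_correctness_test used to check. A small sketch using the real SamplingParams class:

from vllm import SamplingParams

# temperature=0.0 selects greedy decoding in vLLM; combined with a fixed
# seed, both runs are deterministic, so exact output equality is a fair
# correctness criterion for speculative decoding.
greedy_params = SamplingParams(temperature=0.0, max_tokens=32)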