add Cody's mock test
alexm-redhat committed Jul 17, 2024
1 parent 959c02c commit bb9c4d8
Showing 4 changed files with 57 additions and 28 deletions.
30 changes: 8 additions & 22 deletions tests/spec_decode/e2e/conftest.py
@@ -206,7 +206,6 @@ def get_output_from_llm_generator(
         llm_generator,
         prompts,
         sampling_params,
-        ensure_gpu_advance_used: bool = False,
 ) -> Tuple[List[str], List[List[int]]]:
     tokens: List[str] = []
     token_ids: List[List[int]] = []
@@ -218,11 +217,6 @@ def get_output_from_llm_generator(
     token_ids = [output.outputs[0].token_ids for output in outputs]
     tokens = [output.outputs[0].text for output in outputs]
 
-    if ensure_gpu_advance_used:
-        num_gpu_runs = (llm.llm_engine.model_executor.driver_worker.
-                        proposer_worker.model_runner._num_gpu_runs)
-        assert num_gpu_runs > 0
-
     del llm
 
     return tokens, token_ids
@@ -242,14 +236,12 @@ def get_logprobs_from_llm_generator(
     return logprobs
 
 
-def run_greedy_equality_correctness_test(
-        baseline_llm_generator,
-        test_llm_generator,
-        batch_size,
-        max_output_len,
-        force_output_len: bool,
-        print_tokens: bool = False,
-        ensure_gpu_advance_used: bool = False):
+def run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len,
+                                         force_output_len: bool,
+                                         print_tokens: bool = False):
     """Helper method that compares the outputs of both the baseline LLM and
     the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
     the same when temperature is zero.
@@ -280,17 +272,11 @@ def run_greedy_equality_correctness_test(
     )
 
     spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
-        test_llm_generator,
-        prompts,
-        sampling_params,
-        ensure_gpu_advance_used=ensure_gpu_advance_used)
+        test_llm_generator, prompts, sampling_params)
 
     (baseline_batch_tokens,
      baseline_batch_token_ids) = get_output_from_llm_generator(
-         baseline_llm_generator,
-         prompts,
-         sampling_params,
-         ensure_gpu_advance_used=False)
+         baseline_llm_generator, prompts, sampling_params)
 
     assert len(baseline_batch_token_ids) == len(prompts)
     assert len(spec_batch_token_ids) == len(prompts)
3 changes: 1 addition & 2 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -205,8 +205,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
                                          test_llm_generator,
                                          batch_size,
                                          max_output_len=output_len,
-                                         force_output_len=True,
-                                         ensure_gpu_advance_used=True)
+                                         force_output_len=True)
 
 
 @pytest.mark.parametrize(
48 changes: 48 additions & 0 deletions tests/spec_decode/test_multi_step_worker.py
@@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k():
     assert proposals.proposal_lens.tolist() == [
         k for _ in range(expected_num_proposal_seqs - 1)
     ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
+
+
+@torch.inference_mode()
+def test_use_draft_model_runner_advance_step():
+    """Verify that draft model runner triggers advance step
+    when applicable.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    k = 5
+    batch_size = 32
+    block_size = 32
+    num_gpu_blocks = 2048 // block_size
+    worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+
+    # Mock "_gpu_advance_step" to raise an exception when called.
+    exception_secret = "artificial stop"
+    worker.model_runner._gpu_advance_step = MagicMock()
+    worker.model_runner._gpu_advance_step.side_effect = ValueError(
+        exception_secret)
+
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+
+    # Fallback (should not call) when num_steps=1.
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k,
+        num_steps=1)
+    worker.execute_model(execute_model_req=execute_model_req)
+
+    # Expect exception if _gpu_advance_step is called.
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k,
+        num_steps=k)
+
+    with pytest.raises(ValueError, match=exception_secret):
+        worker.execute_model(execute_model_req=execute_model_req)
+    call_args_list = worker.model_runner._gpu_advance_step.call_args_list
+    assert len(call_args_list) == 1
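
Aside (not part of this commit): the test above relies on the standard unittest.mock pattern of replacing a method with a MagicMock whose side_effect raises a sentinel exception, then asserting on whether that code path fired. A minimal, self-contained sketch of that pattern follows; the Runner class, its methods, and the test name are hypothetical, purely for illustration.

from unittest.mock import MagicMock

import pytest


class Runner:
    """Hypothetical stand-in for the draft model runner."""

    def _gpu_advance_step(self):
        return "advanced on GPU"

    def execute(self, num_steps: int):
        # Only multi-step execution takes the advance-step path.
        if num_steps > 1:
            return self._gpu_advance_step()
        return "single step"


def test_side_effect_pattern():
    runner = Runner()
    secret = "artificial stop"
    # Replace the method with a mock that raises a sentinel error.
    runner._gpu_advance_step = MagicMock(side_effect=ValueError(secret))

    # num_steps == 1 must not reach the mocked method.
    assert runner.execute(num_steps=1) == "single step"
    runner._gpu_advance_step.assert_not_called()

    # num_steps > 1 reaches the mock, which raises the sentinel exactly once.
    with pytest.raises(ValueError, match=secret):
        runner.execute(num_steps=5)
    assert runner._gpu_advance_step.call_count == 1
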
4 changes: 0 additions & 4 deletions vllm/spec_decode/draft_model_runner.py
@@ -67,9 +67,6 @@ def __init__(
             return_hidden_states=return_hidden_states,
         )
 
-        # Used mainly for tests (has no perf penalty)
-        self._num_gpu_runs = 0
-
     def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                     num_queries):
         assert isinstance(attn_metadata, FlashAttentionMetadata)
@@ -125,7 +122,6 @@ def _gpu_advance_step(
             self, model_input: ModelInputForGPUWithSamplingMetadata,
             last_output: SamplerOutput
     ) -> ModelInputForGPUWithSamplingMetadata:
-        self._num_gpu_runs += 1
         # Currently, we expect "decode mode" only
         assert not model_input.is_prompt
 