[Core] Logprobs support in Multi-step (vllm-project#7652)
afeldman-nm authored Aug 30, 2024
1 parent 4abed65 commit 428dd14
Showing 103 changed files with 874 additions and 378 deletions.
43 changes: 26 additions & 17 deletions tests/models/utils.py
@@ -1,7 +1,7 @@
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

from vllm.sequence import SampleLogprobs
from vllm.sequence import Logprob, SampleLogprobs

TokensText = Tuple[List[int], str]

@@ -38,34 +38,39 @@ def check_outputs_equal(
float]],
SampleLogprobs]]]

# Allow for tokens to be represented as str's rather than IDs
TextTextLogprobs = Tuple[List[str], str,
                         Optional[Union[List[Dict[str, float]],
                                        List[Dict[str, Logprob]]]]]


def check_logprobs_close(
*,
outputs_0_lst: Sequence[TokensTextLogprobs],
outputs_1_lst: Sequence[TokensTextLogprobs],
outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
name_0: str,
name_1: str,
num_outputs_0_skip_tokens: int = 0,
warn_on_mismatch: bool = True,
):
"""
Compare the logprobs of two sequences generated by different models,
always_check_logprobs: bool = False,
) -> None:
"""Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
Arguments:
* outputs_0_lst: First sequence to compare
* outputs_0_lst: Second sequence to compare
* name_0: sequence #0 name
* name_1: sequence #1 name
* num_outputs_0_skip_tokens: If > 0, specifies the number of initial
Args:
outputs_0_lst: First sequence to compare
outputs_1_lst: Second sequence to compare
name_0: sequence #0 name
name_1: sequence #1 name
num_outputs_0_skip_tokens: If > 0, specifies the number of initial
sequence #0 tokens & logprobs to discard
before comparison, i.e. all
of sequence #1 will be compared to
sequence #0 beginning at index
num_outputs_0_skip_tokens
* warn_on_mismatch: Issue a warning if there is token-wise or text-wise
warn_on_mismatch: Issue a warning if there is token-wise or text-wise
mismatch between the two sequences
always_check_logprobs: If true, check logprobs even when tokens match
"""
assert len(outputs_0_lst) == len(outputs_1_lst)

@@ -94,8 +99,12 @@ def check_logprobs_close(
for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

# If generated tokens don't match, then
if output_id_0 != output_id_1:
is_tok_mismatch = output_id_0 != output_id_1

# If generated tokens don't match
# or it is desired to always check logprobs,
# then
if is_tok_mismatch or always_check_logprobs:
logprobs_elem_0 = logprobs_0[idx]
logprobs_elem_1 = logprobs_1[idx]

@@ -111,7 +120,7 @@
assert output_id_0 in logprobs_elem_1, fail_msg
assert output_id_1 in logprobs_elem_0, fail_msg

if warn_on_mismatch:
if warn_on_mismatch and is_tok_mismatch:
with warnings.catch_warnings():
# This ensures that repeated warnings are shown
# in the output, not just the first occurrence
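For reference, a minimal usage sketch of the updated helper follows. The token IDs, text, and per-position logprob dictionaries are invented for illustration; only the keyword-only signature and the new `always_check_logprobs` flag come from the diff above, and the import path is an assumption about the test-suite layout.

    # Assumed import path for the helper shown in tests/models/utils.py above.
    from tests.models.utils import check_logprobs_close

    # Dummy TokensTextLogprobs triples: (token_ids, text, per-step top-logprob dicts).
    ref_output = (
        [101, 7592],
        "hello",
        [{101: -0.10, 102: -2.30}, {7592: -0.20, 7593: -1.90}],
    )
    test_output = (
        [101, 7592],
        "hello",
        [{101: -0.12, 102: -2.10}, {7592: -0.25, 7593: -1.80}],
    )

    check_logprobs_close(
        outputs_0_lst=[ref_output],
        outputs_1_lst=[test_output],
        name_0="reference",
        name_1="candidate",
        # New in this commit: check logprobs even though every sampled token matches.
        always_check_logprobs=True,
    )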
99 changes: 69 additions & 30 deletions tests/multi_step/test_correctness_async_llm.py
@@ -1,10 +1,12 @@
# Test the AsyncLLMEngine with multi-step-decoding

from typing import List
from typing import List, Optional

import pytest

from ..utils import RemoteOpenAIServer
from ..models.utils import check_logprobs_close
from ..utils import (completions_with_server_args, get_client_text_generations,
get_client_text_logprob_generations)

MODELS = [
"JackFram/llama-160m",
@@ -23,22 +25,6 @@
]


async def completions_with_server_args(prompts: List[str], model_name: str,
server_cli_args: List[str]):

outputs = None
with RemoteOpenAIServer(model_name, server_cli_args) as server:
async with server.get_async_client() as client:
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5)
assert outputs is not None

return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(("tp_size, pp_size"), [
(1, 1),
@@ -47,12 +33,43 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int,
is_async: bool):
async def test_multi_step(
example_prompts,
model: str,
tp_size: int,
pp_size: int,
eager_mode: int,
num_scheduler_steps: int,
num_prompts: int,
is_async: bool,
num_logprobs: Optional[int],
) -> None:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
client/server environment.
Set up an engine with single-step scheduling as a ground-truth reference.
Send a completions API request to both engines with the same prompts.
Validate:
* Generated tokens match
* Generated logprobs are all very close
Args:
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
tp_size: degree of tensor-parallelism
pp_size: degree of pipeline-parallelism
eager_mode
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""

prompts = example_prompts
if len(prompts) < num_prompts:
@@ -77,14 +94,36 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
str(pp_size),
]

# Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240, but it was empirically
# raised 3x to 720 *just for this test* due to
# observed timeouts in GHA CI.
ref_completions = await completions_with_server_args(
prompts, model, server_args + distributed_args)
prompts,
model,
server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
test_completions = await completions_with_server_args(
prompts, model, ms_server_args + distributed_args)

def get_text_generations(completions):
return [x.text for x in completions.choices]

ref_generations = get_text_generations(ref_completions)
test_generations = get_text_generations(test_completions)
prompts,
model,
ms_server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)

# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
ref_generations = get_client_text_generations(ref_completions)
test_generations = get_client_text_generations(test_completions)
assert ref_generations == test_generations

# Assert multi-step scheduling produces nearly-identical logprobs
# to single-step scheduling.
ref_text_logprobs = get_client_text_logprob_generations(ref_completions)
test_text_logprobs = get_client_text_logprob_generations(test_completions)
check_logprobs_close(
outputs_0_lst=ref_text_logprobs,
outputs_1_lst=test_text_logprobs,
name_0="hf",
name_1="vllm",
)
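This test now imports `completions_with_server_args`, `get_client_text_generations`, and `get_client_text_logprob_generations` from `tests/utils.py`, whose diff is not part of this excerpt. The sketch below is a hedged reconstruction, written as a standalone module, based on the deleted inline helper above and on the new call sites; the exact signatures in the real `tests/utils.py`, the `max_wait_seconds` keyword on `RemoteOpenAIServer`, and the OpenAI client types are assumptions.

    from typing import List, Optional

    from openai.types import Completion

    from tests.utils import RemoteOpenAIServer  # pre-existing test-suite server harness


    async def completions_with_server_args(
        prompts: List[str],
        model_name: str,
        server_cli_args: List[str],
        num_logprobs: Optional[int],
        max_wait_seconds: int = 240,
    ) -> Completion:
        """Spin up an OpenAI-protocol server and request completions.

        Extends the deleted inline helper with `num_logprobs` (forwarded as the
        completions `logprobs` argument) and a configurable start-up timeout.
        """
        outputs = None
        # Assumes RemoteOpenAIServer accepts a `max_wait_seconds` keyword.
        with RemoteOpenAIServer(model_name, server_cli_args,
                                max_wait_seconds=max_wait_seconds) as server:
            async with server.get_async_client() as client:
                outputs = await client.completions.create(model=model_name,
                                                          prompt=prompts,
                                                          temperature=0,
                                                          stream=False,
                                                          max_tokens=5,
                                                          logprobs=num_logprobs)
        assert outputs is not None
        return outputs


    def get_client_text_generations(completions: Completion) -> List[str]:
        """Extract the generated text of each returned choice."""
        return [choice.text for choice in completions.choices]


    def get_client_text_logprob_generations(completions: Completion):
        """Build (token_strings, text, top_logprobs) triples compatible with
        the TextTextLogprobs alias used by check_logprobs_close."""
        results = []
        for choice in completions.choices:
            if choice.logprobs is None:
                results.append(([], choice.text, None))
            else:
                results.append((choice.logprobs.tokens, choice.text,
                                choice.logprobs.top_logprobs))
        return results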
95 changes: 74 additions & 21 deletions tests/multi_step/test_correctness_llm.py
@@ -1,8 +1,10 @@
# Test the LLMEngine with multi-step-decoding

from typing import Optional

import pytest

from ..models.utils import check_outputs_equal
from ..models.utils import check_logprobs_close, check_outputs_equal

MODELS = [
"JackFram/llama-160m",
@@ -18,32 +20,83 @@
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, tp_size: int, max_tokens: int,
enforce_eager: int, num_scheduler_steps: int,
num_prompts: int) -> None:
@pytest.mark.parametrize("num_logprobs", [None, 5])
def test_multi_step_llm(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
tp_size: int,
max_tokens: int,
enforce_eager: int,
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
) -> None:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
Set up a HuggingFace (HF) transformers model as a ground-truth reference.
Prompt them with the same example prompts.
Validate:
* Generated tokens match
* Generated logprobs are all very close
Args:
hf_runner: HF transformers model runner fixture
vllm_runner: vLLM model runner fixture
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
dtype: tensor datatype for engine to utilize
tp_size: degree of tensor-parallelism
max_tokens: the maximum number of tokens to generate
enforce_eager
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""

prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)
prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts

with vllm_runner(model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
with vllm_runner(
model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps,
) as vllm_model:
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
if num_logprobs is None else
vllm_model.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs))

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
hf_outputs = (hf_model.generate_greedy(prompts, max_tokens)
if num_logprobs is None else
hf_model.generate_greedy_logprobs_limit(
prompts, max_tokens, num_logprobs))

if num_logprobs is None:
check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
else:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
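The branch on `num_logprobs` exists because the two comparison helpers consume differently shaped outputs. The literals below are illustrative only; the tuple shapes follow the TokensText and TokensTextLogprobs aliases in tests/models/utils.py shown earlier.

    # TokensText: (token_ids, text), roughly what generate_greedy returns
    # and what check_outputs_equal compares.
    tokens_only = ([101, 7592], "hello")

    # TokensTextLogprobs: (token_ids, text, per-step top-logprob dicts),
    # roughly what generate_greedy_logprobs / generate_greedy_logprobs_limit
    # return and what check_logprobs_close compares.
    with_logprobs = (
        [101, 7592],
        "hello",
        [{101: -0.1, 102: -2.3}, {7592: -0.2, 7593: -1.9}],
    )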
3 changes: 2 additions & 1 deletion tests/spec_decode/test_multi_step_worker.py
@@ -5,9 +5,10 @@
import pytest
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
SamplerOutput, get_all_seq_ids)
get_all_seq_ids)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
3 changes: 2 additions & 1 deletion tests/spec_decode/test_spec_decode_worker.py
@@ -7,8 +7,9 @@
import pytest
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
4 changes: 2 additions & 2 deletions tests/spec_decode/utils.py
@@ -8,12 +8,12 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput)
SequenceData, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
5 changes: 3 additions & 2 deletions tests/test_sequence.py
@@ -2,9 +2,10 @@

import pytest

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput)
CompletionSequenceGroupOutput, SequenceData,
SequenceOutput)

from .core.utils import create_dummy_prompt

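The remaining hunks in this excerpt (tests/spec_decode/* and tests/test_sequence.py) all make the same mechanical change: `SamplerOutput` is now imported from `vllm.model_executor.layers.sampler` rather than `vllm.sequence`. A before/after sketch of the pattern:

    # Before this commit (removed):
    # from vllm.sequence import SamplerOutput

    # After this commit (added):
    from vllm.model_executor.layers.sampler import SamplerOutput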