Commit
[Core] Refactoring sampler and support prompt logprob for chunked pre…
rkooo567 authored Apr 26, 2024
1 parent a88081b commit 603ad84
Showing 18 changed files with 862 additions and 633 deletions.
44 changes: 41 additions & 3 deletions tests/samplers/test_logprobs.py
@@ -9,26 +9,44 @@

@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
def test_get_prompt_logprobs(
hf_runner,
vllm_runner,
model,
dtype,
chunked_prefill_token_size: int,
num_top_logprobs: int,
example_prompts,
):
max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
max_num_batched_tokens = chunked_prefill_token_size

max_tokens = 5
num_top_logprobs = 6
hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
max_tokens=max_tokens,
)
del hf_model

vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
vllm_model = vllm_runner(
model,
dtype=dtype,
max_logprobs=num_top_logprobs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
)
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=5,
prompt_logprobs=num_top_logprobs,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
@@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")

# The first prompt logprob is always None
assert result.prompt_logprobs[0] is None
for prompt_logprobs in result.prompt_logprobs[1:]:
# If the prompt token is not among the top num_top_logprobs
# entries, one extra entry is returned for it.
assert (len(prompt_logprobs) == num_top_logprobs
or len(prompt_logprobs) == num_top_logprobs + 1)

# Test whether prompt logprobs are consistent with HF
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
# Check prompt logprobs
# The first prompt logprob is always None, so we compare it from 1:.
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items():
@@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
"The token should be decoded by the time it is returned "
" to the user.")

# Test if prompt logprobs are correctly set.
for vllm_result in vllm_results:
token_ids = vllm_result.prompt_token_ids
prompt_logprobs = vllm_result.prompt_logprobs

# The first token doesn't have a logprob.
assert prompt_logprobs[0] is None

for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
assert token_id in logprob_dict


def test_max_logprobs():
runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
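For context, a minimal sketch (not part of this commit) of how the behavior exercised by test_get_prompt_logprobs above can be reproduced through vLLM's public LLM entry point. The model name, prompt, and token budgets are illustrative assumptions; the engine and sampling arguments mirror those used in the test.

from vllm import LLM, SamplingParams

# Illustrative settings: a small model and a tiny batched-token budget so
# that prompts longer than max_num_batched_tokens are prefilled in chunks.
llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,
    max_num_batched_tokens=16,
    max_num_seqs=16,
    max_logprobs=6,
)
params = SamplingParams(
    max_tokens=5,
    logprobs=6,          # top logprobs for each generated token
    prompt_logprobs=6,   # top logprobs for each prompt token
    temperature=0.0,
)
outputs = llm.generate(["San Francisco is a city that"], params)
# The first prompt position has no preceding context, so its entry is None;
# each later position has 6 entries, or 7 when the actual prompt token is
# not already among the top 6.
assert outputs[0].prompt_logprobs[0] is None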
47 changes: 31 additions & 16 deletions tests/samplers/test_sampler.py
@@ -8,6 +8,7 @@
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter
@@ -54,6 +55,7 @@ def _do_sample(
sampler: MockLogitsSampler,
model_runner: ModelRunner,
sampling_params: SamplingParams,
device: str,
):
seq_group_metadata_list = []
prompt_lens = []
@@ -68,9 +70,12 @@
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=device,
pin_memory=model_runner.pin_memory)
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)


@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):

sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
sampling_params, device)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
sampling_params, device)

for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
sampling_params, device)

for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params)
model_runner, sampling_params, device)

second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params)
model_runner, sampling_params, device)

assert first_sampler_output == second_sampler_output

@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
device)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
@@ -443,10 +449,12 @@ def run_test_case(*,
"batch size")

_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = model_runner._prepare_sample(
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens=prompt_lens if prompt_lens else None,
subquery_lens=prompt_lens if prompt_lens else None)
subquery_lens=prompt_lens if prompt_lens else None,
device=device,
pin_memory=model_runner.pin_memory)
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

@@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=device,
pin_memory=model_runner.pin_memory)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

@@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=device,
pin_memory=model_runner.pin_memory)

sample_probs = None

10 changes: 7 additions & 3 deletions tests/test_logits_processor.py
@@ -6,6 +6,7 @@
import torch

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
@@ -82,9 +83,12 @@ def pick_ith(token_ids, logits):
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
logits_processor_output = logits_processor(
embedding=None,
hidden_states=input_tensor,
19 changes: 13 additions & 6 deletions tests/worker/test_model_runner.py
@@ -2,6 +2,7 @@
import torch

from vllm.config import ModelConfig, SchedulerConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size

@@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size):
assert len(input_positions) == sum(prompt_lens)
torch.testing.assert_close(input_tokens, input_positions)

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
actual = sampling_metadata.selected_token_indices
@@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size):
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
15 changes: 15 additions & 0 deletions vllm/core/scheduler.py
@@ -915,6 +915,20 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))

do_sample = True
if seq_group.is_prefill():
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
# If not all prompt tokens will have been computed after this
# chunk, the prefill is chunked and sampling is not needed yet.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
seqs[0].data.get_len()):
do_sample = False

# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt = seq_group.is_prefill()
@@ -924,6 +938,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
do_sample=do_sample,
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
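To make the do_sample condition above concrete, a small worked example with hypothetical numbers (a 10-token prompt prefilled under a 4-token chunk budget): only the step that finishes the prompt samples.

# Hypothetical walk-through of the condition used in schedule() above.
prompt_len = 10
steps = [(0, 4), (4, 4), (8, 2)]  # (num_computed_tokens, token_chunk_size)
for num_computed, chunk in steps:
    do_sample = not (chunk + num_computed < prompt_len)
    print(f"computed={num_computed} chunk={chunk} do_sample={do_sample}")
# computed=0 chunk=4 do_sample=False  (partial prefill, no sampling)
# computed=4 chunk=4 do_sample=False  (still partial)
# computed=8 chunk=2 do_sample=True   (prompt fully prefilled; sample now)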
2 changes: 1 addition & 1 deletion vllm/engine/async_llm_engine.py
@@ -219,7 +219,7 @@ async def step_async(self) -> List[RequestOutput]:

request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups)
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

# Log stats.
if self.log_stats:
25 changes: 13 additions & 12 deletions vllm/engine/llm_engine.py
@@ -22,7 +22,7 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup, SequenceStage)
SequenceGroup, SequenceGroupMetadata)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@@ -476,9 +476,12 @@ def has_unfinished_requests(self) -> bool:
return self.scheduler.has_unfinished_seqs()

def _process_model_outputs(
self, output: List[SamplerOutput],
scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
self,
output: List[SamplerOutput],
scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
@@ -492,17 +495,15 @@ def _process_model_outputs(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))

# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
output_by_sequence_group):
for scheduled_seq_group, outputs, seq_group_meta in zip(
scheduled_seq_groups, output_by_sequence_group,
seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)

# If all sequences in the sequence group are in DECODE, then we can
# process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
self.output_processor.process_prompt_logprob(seq_group, outputs)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, outputs)

# Free the finished sequence groups.
@@ -585,7 +586,7 @@ def step(self) -> List[RequestOutput]:

request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups)
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

# Log stats.
if self.log_stats:
6 changes: 6 additions & 0 deletions vllm/engine/output_processor/interfaces.py
@@ -68,3 +68,9 @@ def process_outputs(self, sequence_group: SequenceGroup,
scheduler.
"""
pass

@abstractmethod
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Update prompt logprobs received from outputs to seq_group."""
pass
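
A hedged sketch of what a concrete output processor now has to provide under this extended interface. The class name and the way prompt logprobs are attached to the sequence group are illustrative assumptions, not code from this commit.

from typing import List

from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.sequence import SequenceGroup, SequenceGroupOutput


class IllustrativeOutputProcessor(SequenceGroupOutputProcessor):
    """Sketch only: shows the shape of the new hook, not vLLM's real logic."""

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        # Assumption: copy any prompt logprobs carried by the model output
        # onto the sequence group so they can surface in RequestOutput.
        for output in outputs:
            if output.prompt_logprobs is not None:
                seq_group.prompt_logprobs = output.prompt_logprobs

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        # Token-appending and stop-checking logic omitted in this sketch.
        pass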
9 changes: 9 additions & 0 deletions vllm/engine/output_processor/multi_step.py
@@ -44,6 +44,15 @@ def __init__(
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.stop_checker = stop_checker

def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
logger.warning(
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
pass

def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Append new tokens in the outputs to sequences in the sequence group.
(Diffs for the remaining changed files are not shown here.)