Commit: cleanups
alexm-redhat committed Jul 11, 2024
1 parent bccd4b6 commit 9711114
Showing 5 changed files with 19 additions and 26 deletions.
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_multistep_correctness.py
@@ -158,7 +158,7 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
-"enforce_eager": False,
+"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
14 changes: 5 additions & 9 deletions vllm/model_executor/layers/sampler.py
@@ -85,7 +85,7 @@ def forward(
# In this case, we depend on the output tokens
# TODO: Check with Cade if this is needed for spec tokens
self._init_sampling_tensors(logits, sampling_metadata)

sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
@@ -104,7 +104,7 @@ def forward(
# Apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
@@ -131,10 +131,6 @@ def forward(
if self.include_gpu_probs_tensor:
assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)

-# print(" -- maybe_sampled_tokens_tensor: shape = {} vals = {}".format(
-# maybe_sampled_tokens_tensor.shape,
-# maybe_sampled_tokens_tensor))
else:
on_device_tensors = None

@@ -811,9 +807,6 @@ def _get_logprobs(
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)

-# print("query_indices_gpu: shape = {}, vals = {}".format(query_indices_gpu.shape, query_indices_gpu))
-# print("next_token_ids_gpu: shape = {}, vals = {}".format(next_token_ids_gpu.shape, next_token_ids_gpu))

# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs = logprobs[[
@@ -1057,6 +1050,9 @@ def _build_sampler_output(
"""
sampler_output: List[CompletionSequenceGroupOutput] = []
if not skip_cpu_samples:
+assert prompt_logprobs is not None
+assert sample_logprobs is not None

for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
4 changes: 1 addition & 3 deletions vllm/model_executor/sampling_metadata.py
@@ -101,7 +101,7 @@ def __init__(
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.num_prompts = num_prompts

# TODO: Add docs
self.skip_cpu_samples = skip_cpu_samples
self.reuse_sampling_tensors = reuse_sampling_tensors
@@ -134,8 +134,6 @@ def prepare(
for t, seq_ids in categorized_sample_indices.items()
}

-# print(" selected_token_indices = {}".format(selected_token_indices))
-# print(" categorized_sample_indices = {}".format(categorized_sample_indices))
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
selected_token_indices=selected_token_indices,
23 changes: 11 additions & 12 deletions vllm/spec_decode/draft_model_runner.py
@@ -2,6 +2,8 @@

import torch

+from vllm import _custom_ops as ops
+from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
@@ -10,14 +12,11 @@
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)

-from vllm.attention.backends.flash_attn import FlashAttentionMetadata

-from vllm import _custom_ops as ops

logger = init_logger(__name__)

log_advance_input = False
-enable_advance_step = True
+enable_advance_step = False


class TP1DraftModelRunner(ModelRunner):
"""Specialized model runner for speculative decoding draft model.
@@ -177,16 +176,16 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
-# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed
+# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]

-assert seq_group.is_prompt == False # No prompt
+assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
-assert seq_group.seq_len == None # Decode
-assert seq_group.query_len == None # Decode
+assert seq_group.seq_len is None # Decode
+assert seq_group.query_len is None # Decode

def _advance_step(
self, model_input: ModelInputForGPUWithSamplingMetadata,
@@ -262,7 +261,7 @@ def _advance_step(
def _can_use_advance_step(self):
if not enable_advance_step:
return False

# TODO: Add support for other attn backends
if self.attn_backend.get_name() != "flash-attn":
return False
@@ -315,7 +314,7 @@ def _execute_model_with_advance_step(
model_input.sampling_metadata)

model_input.sampling_metadata.skip_cpu_samples = True

# Sample the next token.
outputs.append(
self.model.sample(
@@ -331,7 +330,7 @@
else:
assert not model_input.is_prompt
model_input = self._advance_step(model_input, outputs[-1])

model_input.sampling_metadata.reuse_sampling_tensors = True

return outputs
2 changes: 1 addition & 1 deletion vllm/worker/model_runner.py
@@ -383,7 +383,7 @@ def _prepare_model_input_tensors(

if log_runner:
print(
-"Processing seq_group_id = {}, with seq_ids = {}, is_prompt = {}"
+" Add seq_group_id = {}, with seq_ids = {}, is_prompt = {}"
.format(seq_group_id, seq_ids, is_prompt))

for seq_id in seq_ids:
