From 97111143657c340ea8e606f482e83a8d6973e4af Mon Sep 17 00:00:00 2001
From: Alexander Matveev <alexm@neuralmagic.com>
Date: Thu, 11 Jul 2024 15:15:06 +0000
Subject: [PATCH] cleanups

---
 .../e2e/test_multistep_correctness.py    |  2 +-
 vllm/model_executor/layers/sampler.py    | 14 ++++-------
 vllm/model_executor/sampling_metadata.py |  4 +---
 vllm/spec_decode/draft_model_runner.py   | 23 +++++++++----------
 vllm/worker/model_runner.py              |  2 +-
 5 files changed, 19 insertions(+), 26 deletions(-)

diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py
index 7a5dea83821de..94cc36f22875a 100644
--- a/tests/spec_decode/e2e/test_multistep_correctness.py
+++ b/tests/spec_decode/e2e/test_multistep_correctness.py
@@ -158,7 +158,7 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
     "common_llm_kwargs",
     [{
         # Skip cuda graph recording for fast test.
-        "enforce_eager": False,
+        "enforce_eager": True,
 
         # Required for spec decode.
         "use_v2_block_manager": True,
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index cc9491f257575..8a19acc840f60 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -85,7 +85,7 @@ def forward(
             # In this case, we depend on the output tokens
             # TODO: Check with Cade if this is needed for spec tokens
             self._init_sampling_tensors(logits, sampling_metadata)
-            
+
         sampling_tensors = self._sampling_tensors
         do_penalties = self._do_penalties
         do_top_p_top_k = self._do_top_p_top_k
@@ -104,7 +104,7 @@ def forward(
         # Apply temperature scaling.
         # Use in-place division to avoid creating a new tensor.
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
-        
+
         if do_top_p_top_k:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
@@ -131,10 +131,6 @@ def forward(
         if self.include_gpu_probs_tensor:
             assert maybe_sampled_tokens_tensor is not None
             on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
-
-            # print(" -- maybe_sampled_tokens_tensor: shape = {} vals = {}".format(
-            #     maybe_sampled_tokens_tensor.shape,
-            #     maybe_sampled_tokens_tensor))
         else:
             on_device_tensors = None
 
@@ -811,9 +807,6 @@ def _get_logprobs(
     query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
    next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
 
-    # print("query_indices_gpu: shape = {}, vals = {}".format(query_indices_gpu.shape, query_indices_gpu))
-    # print("next_token_ids_gpu: shape = {}, vals = {}".format(next_token_ids_gpu.shape, next_token_ids_gpu))
-
     # (num_selected_query_tokens, num_logprobs). Note that query_indices can
     # contain duplicates if beam search is enabled.
     selected_logprobs = logprobs[[
@@ -1057,6 +1050,9 @@ def _build_sampler_output(
     """
     sampler_output: List[CompletionSequenceGroupOutput] = []
     if not skip_cpu_samples:
+        assert prompt_logprobs is not None
+        assert sample_logprobs is not None
+
         for (seq_group, sample_result, group_prompt_logprobs,
              group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                            sample_results, prompt_logprobs,
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 17fdb0948dfeb..6100ba7d8e9be 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -101,7 +101,7 @@ def __init__(
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
         self.num_prompts = num_prompts
-        
+
         # TODO: Add docs
         self.skip_cpu_samples = skip_cpu_samples
         self.reuse_sampling_tensors = reuse_sampling_tensors
@@ -134,8 +134,6 @@ def prepare(
             for t, seq_ids in categorized_sample_indices.items()
         }
 
-        # print(" selected_token_indices = {}".format(selected_token_indices))
-        # print(" categorized_sample_indices = {}".format(categorized_sample_indices))
         sampling_metadata = SamplingMetadata(
             seq_groups=seq_groups,
             selected_token_indices=selected_token_indices,
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index 5855cf37d0945..224b471bf3262 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -2,6 +2,8 @@
 
 import torch
 
+from vllm import _custom_ops as ops
+from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
@@ -10,14 +12,11 @@
 from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
                                       ModelRunner)
 
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-
-from vllm import _custom_ops as ops
-
 logger = init_logger(__name__)
 
 log_advance_input = False
-enable_advance_step = True
+enable_advance_step = False
+
 
 class TP1DraftModelRunner(ModelRunner):
     """Specialized model runner for speculative decoding draft model.
@@ -177,16 +176,16 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
         assert len(sampling_metadata.seq_groups) == num_queries
         assert sampling_metadata.selected_token_indices.shape == (
             num_queries, )
-        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed
+        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed  # noqa: E501
 
         for i in range(num_queries):
             seq_group = sampling_metadata.seq_groups[i]
 
-            assert seq_group.is_prompt == False  # No prompt
+            assert seq_group.is_prompt is False  # No prompt
             assert seq_group.prompt_logprob_indices == []  # No prompt
             assert seq_group.sample_indices == [i]  # Simple
-            assert seq_group.seq_len == None  # Decode
-            assert seq_group.query_len == None  # Decode
+            assert seq_group.seq_len is None  # Decode
+            assert seq_group.query_len is None  # Decode
 
     def _advance_step(
             self, model_input: ModelInputForGPUWithSamplingMetadata,
@@ -262,7 +261,7 @@ def _advance_step(
     def _can_use_advance_step(self):
         if not enable_advance_step:
             return False
-        
+
         # TODO: Add support for other attn backends
         if self.attn_backend.get_name() != "flash-attn":
             return False
@@ -315,7 +314,7 @@ def _execute_model_with_advance_step(
                                                model_input.sampling_metadata)
 
             model_input.sampling_metadata.skip_cpu_samples = True
-            
+
             # Sample the next token.
             outputs.append(
                 self.model.sample(
@@ -331,7 +330,7 @@ def _execute_model_with_advance_step(
             else:
                 assert not model_input.is_prompt
                 model_input = self._advance_step(model_input, outputs[-1])
-                
+
                 model_input.sampling_metadata.reuse_sampling_tensors = True
 
         return outputs
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 0cf247ec46a6b..d71ddb0fe279c 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -383,7 +383,7 @@ def _prepare_model_input_tensors(
 
             if log_runner:
                 print(
-                    "Processing seq_group_id = {}, with seq_ids = {}, is_prompt = {}"
+                    " Add seq_group_id = {}, with seq_ids = {}, is_prompt = {}"
                     .format(seq_group_id, seq_ids, is_prompt))
 
             for seq_id in seq_ids: