From 97111143657c340ea8e606f482e83a8d6973e4af Mon Sep 17 00:00:00 2001
From: Alexander Matveev <alexm@neuralmagic.com>
Date: Thu, 11 Jul 2024 15:15:06 +0000
Subject: [PATCH] cleanups

---
 .../e2e/test_multistep_correctness.py         |  2 +-
 vllm/model_executor/layers/sampler.py         | 14 ++++-------
 vllm/model_executor/sampling_metadata.py      |  4 +---
 vllm/spec_decode/draft_model_runner.py        | 23 +++++++++----------
 vllm/worker/model_runner.py                   |  2 +-
 5 files changed, 19 insertions(+), 26 deletions(-)

diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py
index 7a5dea83821de..94cc36f22875a 100644
--- a/tests/spec_decode/e2e/test_multistep_correctness.py
+++ b/tests/spec_decode/e2e/test_multistep_correctness.py
@@ -158,7 +158,7 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
     "common_llm_kwargs",
     [{
         # Skip cuda graph recording for fast test.
-        "enforce_eager": False,
+        "enforce_eager": True,
 
         # Required for spec decode.
         "use_v2_block_manager": True,
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index cc9491f257575..8a19acc840f60 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -85,7 +85,7 @@ def forward(
                 # In this case, we depend on the output tokens
                 # TODO: Check with Cade if this is needed for spec tokens
                 self._init_sampling_tensors(logits, sampling_metadata)
-            
+
         sampling_tensors = self._sampling_tensors
         do_penalties = self._do_penalties
         do_top_p_top_k = self._do_top_p_top_k
@@ -104,7 +104,7 @@ def forward(
         # Apply temperature scaling.
         # Use in-place division to avoid creating a new tensor.
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
-        
+
         if do_top_p_top_k:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
@@ -131,10 +131,6 @@ def forward(
         if self.include_gpu_probs_tensor:
             assert maybe_sampled_tokens_tensor is not None
             on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
-
-            # print("  -- maybe_sampled_tokens_tensor: shape = {} vals = {}".format(
-            #     maybe_sampled_tokens_tensor.shape,
-            #     maybe_sampled_tokens_tensor))
         else:
             on_device_tensors = None
 
@@ -811,9 +807,6 @@ def _get_logprobs(
     query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
     next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
 
-    # print("query_indices_gpu: shape = {}, vals = {}".format(query_indices_gpu.shape, query_indices_gpu))
-    # print("next_token_ids_gpu: shape = {}, vals = {}".format(next_token_ids_gpu.shape, next_token_ids_gpu))
-
     # (num_selected_query_tokens, num_logprobs). Note that query_indices can
     # contain duplicates if beam search is enabled.
     selected_logprobs = logprobs[[
@@ -1057,6 +1050,9 @@ def _build_sampler_output(
     """
     sampler_output: List[CompletionSequenceGroupOutput] = []
     if not skip_cpu_samples:
+        assert prompt_logprobs is not None
+        assert sample_logprobs is not None
+
         for (seq_group, sample_result, group_prompt_logprobs,
              group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                            sample_results, prompt_logprobs,
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 17fdb0948dfeb..6100ba7d8e9be 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -101,7 +101,7 @@ def __init__(
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
         self.num_prompts = num_prompts
-        
+
         # TODO: Add docs
         self.skip_cpu_samples = skip_cpu_samples
         self.reuse_sampling_tensors = reuse_sampling_tensors
@@ -134,8 +134,6 @@ def prepare(
             for t, seq_ids in categorized_sample_indices.items()
         }
 
-        # print("  selected_token_indices = {}".format(selected_token_indices))
-        # print("  categorized_sample_indices = {}".format(categorized_sample_indices))
         sampling_metadata = SamplingMetadata(
             seq_groups=seq_groups,
             selected_token_indices=selected_token_indices,
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index 5855cf37d0945..224b471bf3262 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -2,6 +2,8 @@
 
 import torch
 
+from vllm import _custom_ops as ops
+from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
@@ -10,14 +12,11 @@
 from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
                                       ModelRunner)
 
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-
-from vllm import _custom_ops as ops
-
 logger = init_logger(__name__)
 
 log_advance_input = False
-enable_advance_step = True
+enable_advance_step = False
+
 
 class TP1DraftModelRunner(ModelRunner):
     """Specialized model runner for speculative decoding draft model.
@@ -177,16 +176,16 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
         assert len(sampling_metadata.seq_groups) == num_queries
         assert sampling_metadata.selected_token_indices.shape == (
             num_queries, )
-        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed
+        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
 
         for i in range(num_queries):
             seq_group = sampling_metadata.seq_groups[i]
 
-            assert seq_group.is_prompt == False  # No prompt
+            assert seq_group.is_prompt is False  # No prompt
             assert seq_group.prompt_logprob_indices == []  # No prompt
             assert seq_group.sample_indices == [i]  # Simple
-            assert seq_group.seq_len == None  # Decode
-            assert seq_group.query_len == None  # Decode
+            assert seq_group.seq_len is None  # Decode
+            assert seq_group.query_len is None  # Decode
 
     def _advance_step(
             self, model_input: ModelInputForGPUWithSamplingMetadata,
@@ -262,7 +261,7 @@ def _advance_step(
     def _can_use_advance_step(self):
         if not enable_advance_step:
             return False
-        
+
         # TODO: Add support for other attn backends
         if self.attn_backend.get_name() != "flash-attn":
             return False
@@ -315,7 +314,7 @@ def _execute_model_with_advance_step(
                                                model_input.sampling_metadata)
 
             model_input.sampling_metadata.skip_cpu_samples = True
-            
+
             # Sample the next token.
             outputs.append(
                 self.model.sample(
@@ -331,7 +330,7 @@ def _execute_model_with_advance_step(
                 else:
                     assert not model_input.is_prompt
                     model_input = self._advance_step(model_input, outputs[-1])
-                    
+
                     model_input.sampling_metadata.reuse_sampling_tensors = True
 
         return outputs
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 0cf247ec46a6b..d71ddb0fe279c 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -383,7 +383,7 @@ def _prepare_model_input_tensors(
 
             if log_runner:
                 print(
-                    "Processing seq_group_id = {}, with seq_ids = {}, is_prompt = {}"
+                    " Add seq_group_id = {}, with seq_ids = {}, is_prompt = {}"
                     .format(seq_group_id, seq_ids, is_prompt))
 
             for seq_id in seq_ids: