From 627c81f27362824636caf3d4618f76f1a6342a0c Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Thu, 12 Sep 2024 13:35:19 +0000 Subject: [PATCH] Fix issue https://github.com/vllm-project/vllm/issues/8219 --- tests/basic_correctness/test_preemption.py | 1 + vllm/engine/llm_engine.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 7e77037da07d3..50d399bef1878 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -64,6 +64,7 @@ def test_chunked_prefill_recompute( enable_chunked_prefill=enable_chunked_prefill, max_num_seqs=max_num_seqs, worker_use_ray=worker_use_ray, + disable_log_stats=False, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1448e23c85beb..0880843c67a0d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1435,7 +1435,8 @@ def _process_model_outputs(self, # LLMEngine/AsyncLLMEngine directly if is_async: # Log stats. - self.do_log_stats(scheduler_outputs, outputs, finished_before) + self.do_log_stats(scheduler_outputs, outputs, finished_before, + skip) # Tracing self.do_tracing(scheduler_outputs) @@ -1743,18 +1744,20 @@ def remove_logger(self, logger_name: str) -> None: def do_log_stats(self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None) -> None: + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: stats = self._get_stats(scheduler_outputs, model_output, - finished_before) + finished_before, skip) for logger in self.stat_loggers.values(): logger.log(stats) def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs], model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None) -> Stats: + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> Stats: """Get Stats to be Logged to Prometheus. Args: @@ -1836,6 +1839,11 @@ def _get_stats(self, actual_num_batched_tokens -= 1 continue + # Currently, skip == preempted sequences, so we need to skip + # their log stats + if skip and idx in skip: + continue + group_was_prefill = idx < scheduler_outputs.num_prefill_groups seq_group = scheduled_seq_group.seq_group