From a1986d131ed1ce3d410221a55b1c8b673c22dc6b Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Tue, 27 Aug 2024 17:43:49 +0000 Subject: [PATCH] format --- examples/offline_inference.py | 5 +- .../multi_step/test_correctness_async_llm.py | 10 ++-- vllm/engine/async_llm_engine.py | 53 +++++++++++++------ vllm/engine/llm_engine.py | 28 +++++----- vllm/worker/multi_step_model_runner.py | 18 ++++--- 5 files changed, 71 insertions(+), 43 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 7dee660296782..a39fd1f151e19 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m", num_scheduler_steps=8, use_v2_block_manager=True, disable_async_output_proc=False) +llm = LLM(model="facebook/opt-125m", + num_scheduler_steps=8, + use_v2_block_manager=True, + disable_async_output_proc=False) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index ad99d70d7417c..ac04be3d9a689 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str, @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.asyncio async def test_multi_step(example_prompts, model: str, tp_size: int, pp_size: int, eager_mode: int, - num_scheduler_steps: int, num_prompts: int): + num_scheduler_steps: int, num_prompts: int, + is_async: bool): prompts = example_prompts if len(prompts) < num_prompts: @@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ms_server_args = DEFAULT_SERVER_ARGS + \ ["--num-scheduler-steps", f"{num_scheduler_steps}"] - # Disable output proc callback as its not supported - # with multi-step right now - ms_server_args += ["--disable-async-output-proc"] + if not is_async: + ms_server_args += ["--disable-async-output-proc"] + if eager_mode: ms_server_args.append("--enforce-eager") diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 080ae56e3f887..9315d74237a1f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -279,6 +279,10 @@ async def step_async( scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # skip the scheduler if there are any remaining steps in the seq groups. 
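Note for readers following the feature rather than the diff mechanics: the example and test changes above exercise multi-step scheduling together with asynchronous output processing. Below is a minimal offline-API sketch of that combination. Only the constructor arguments come from this patch; the build_llm helper and its is_async flag are illustrative and simply mirror the new test parameter.

from vllm import LLM, SamplingParams

# Multi-step scheduling (num_scheduler_steps > 1) combined with async
# output processing. Passing is_async=False falls back to the synchronous
# output-processing path, which is what the test does for is_async=False
# via --disable-async-output-proc.
def build_llm(is_async: bool) -> LLM:
    return LLM(model="facebook/opt-125m",
               num_scheduler_steps=8,
               use_v2_block_manager=True,
               disable_async_output_proc=not is_async)

llm = build_llm(is_async=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, top_p=0.95))
for out in outputs:
    print(out.outputs[0].text)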
@@ -294,11 +298,22 @@ async def step_async( allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -310,9 +325,6 @@ async def step_async( assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -340,6 +352,8 @@ async def step_async( if allow_async_output_proc: execute_model_req.async_callback = self.async_callback_data[ virtual_engine] + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step # Execute the model. output = await self.model_executor.execute_model_async( @@ -349,7 +363,7 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: - if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -361,22 +375,25 @@ async def step_async( seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # Clear the cache if we have finished all the steps if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - # Cache results in engine - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len( - output - ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Multi step decoding does not work with async output processing." 
# noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: self._process_model_outputs(virtual_engine=virtual_engine, @@ -389,7 +406,11 @@ async def step_async( self.do_tracing(scheduler_outputs) else: - ctx.request_outputs = [] + # Multi-step case + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a6ff042e5d2ec..88de9c906f8d2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1268,6 +1268,8 @@ def _process_model_outputs(self, (outputs, seq_group_metadata_list, scheduler_outputs) = ctx.output_queue.popleft() + assert outputs is not None + # Sanity check assert len(seq_group_metadata_list) == len( scheduler_outputs.scheduled_seq_groups) @@ -1325,20 +1327,19 @@ def _process_model_outputs(self, self.output_processor.process_outputs(seq_group, output, is_async) - # Free finished sequence groups. - if is_multi_step: - if is_last_output: - for scheduler in self.scheduler: - scheduler.free_finished_seq_groups() - else: - for scheduler in self.scheduler: - scheduler.free_finished_seq_groups() + # For async + multi-step, free finished seqs and create outputs + # only on the final step. + if is_multi_step and not is_last_output: + return + + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() # Create the outputs. for i, _ in enumerate(seq_group_metadata_list): scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - if i in finished_before: + if not is_multi_step and i in finished_before: continue # Avoids double processing seq_group = scheduled_seq_group.seq_group @@ -1354,10 +1355,7 @@ def _process_model_outputs(self, # For async + multi-step, do stats only on the last output. # Otherwise, do stats if the execution is async - if is_multi_step: - do_stats = is_last_output - else: - do_stats = is_async + do_stats = is_multi_step or is_async if do_stats: # Log stats. 
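The engine-side flow can be hard to see through the diff context, so here is a condensed, self-contained sketch of the deferred handling for async + multi-step. MiniContext, seed_queue and process_step are illustrative names, not vLLM APIs; only the queue-entry layout (outputs, seq_group_metadata_list, scheduler_outputs) and the last-step-only policy follow the diff.

from collections import deque
from typing import Any, Deque, List, Optional, Tuple

QueueEntry = Tuple[Optional[List[Any]], List[Any], Any]

class MiniContext:
    """Stand-in for the per-virtual-engine scheduler context."""
    def __init__(self) -> None:
        self.output_queue: Deque[QueueEntry] = deque()
        self.request_outputs: List[Any] = []

def seed_queue(ctx: MiniContext, seq_group_metadata_list: List[Any],
               scheduler_outputs: Any) -> None:
    # Before the model runs, the queue is primed with a placeholder whose
    # outputs slot is None; the worker then streams per-step results back
    # through the async callback instead of returning them all at the end.
    assert len(ctx.output_queue) == 0
    ctx.output_queue.append((None, seq_group_metadata_list, scheduler_outputs))

def process_step(ctx: MiniContext, is_multi_step: bool, is_async: bool,
                 is_last_output: bool) -> None:
    # Per-step work (appending sampled tokens to the sequence groups)
    # happens on every step, but freeing finished sequence groups and
    # creating RequestOutputs are deferred to the final step, matching the
    # early return added to _process_model_outputs.
    if is_multi_step and not is_last_output:
        return
    if is_multi_step or is_async:
        ...  # log iteration stats (do_stats = is_multi_step or is_async)
    # free finished seq groups and fill ctx.request_outputs here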
@@ -1493,6 +1491,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # For async + multi-step, init the queue if use_async_and_multi_step: assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None ctx.output_queue.append( (None, seq_group_metadata_list, scheduler_outputs)) @@ -1533,7 +1532,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if allow_async_output_proc: execute_model_req.async_callback = self.async_callback_data[ virtual_engine] - execute_model_req.use_async_and_multi_step = use_async_and_multi_step + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step output = self.model_executor.execute_model( execute_model_req=execute_model_req) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index cd5e514dfb705..91bea68a9528e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -14,9 +14,9 @@ from vllm import _custom_ops as ops from vllm.distributed import get_pp_group from vllm.logger import init_logger -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceOutput, AsyncCallbackData) +from vllm.sequence import (AsyncCallbackData, CompletionSequenceGroupOutput, + IntermediateTensors, Logprob, SamplerOutput, + SequenceGroupMetadata, SequenceOutput) from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -233,6 +233,8 @@ def _async_process_outputs(self, model_input: StatefulModelInput, def _final_process_outputs(self, model_input: StatefulModelInput, output_proc_callback: AsyncCallbackData): + assert model_input.frozen_model_input is not None + if output_proc_callback is not None: output_proc_fn = output_proc_callback.func output_proc_kw_args = output_proc_callback.kw_args @@ -325,13 +327,13 @@ def execute_model( frozen_model_input = dataclasses.replace( # type: ignore model_input.frozen_model_input, async_callback=async_callback) + assert frozen_model_input is not None # Execute the model - output = self._base_model_runner.execute_model( - frozen_model_input, - kv_caches, - intermediate_tensors, - num_steps=1) + output = self._base_model_runner.execute_model(frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1) # record the event for the current step so that the next step can sync model_input.record_step_event(current_stream)
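Finally, the worker-side plumbing: the runner receives the engine's callback as a (func, kw_args) pair and attaches a per-step wrapper to the frozen model input with dataclasses.replace, since frozen dataclasses cannot be mutated in place. A reduced sketch of that pattern follows; CallbackData and FrozenInput are simplified stand-ins for AsyncCallbackData and the frozen model input, which carry more fields in the real code.

import dataclasses
from typing import Any, Callable, Dict, Optional

@dataclasses.dataclass(frozen=True)
class CallbackData:
    # Mirrors the func / kw_args fields read from AsyncCallbackData in
    # _final_process_outputs above; the class itself is a stand-in.
    func: Callable[..., None]
    kw_args: Dict[str, Any]

@dataclasses.dataclass(frozen=True)
class FrozenInput:
    async_callback: Optional[Callable[[], None]] = None

def attach_step_callback(frozen: FrozenInput,
                         callback: Optional[CallbackData]) -> FrozenInput:
    if callback is None:
        return frozen

    def _run_callback() -> None:
        # Invoked once this step's output is ready, so CPU-side output
        # processing can overlap with the next step's GPU execution.
        callback.func(**callback.kw_args)

    # Frozen dataclasses cannot be mutated, so a copy carrying the new
    # callback is created, as execute_model() does with
    # dataclasses.replace(model_input.frozen_model_input, ...).
    return dataclasses.replace(frozen, async_callback=_run_callback)

In the patch itself the callback ultimately wraps the engine's _process_model_outputs for the owning virtual engine, which is what allows output processing for one step to proceed while the next multi-step iteration runs.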