diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index ad99d70d7417c..ac04be3d9a689 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str, @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.asyncio async def test_multi_step(example_prompts, model: str, tp_size: int, pp_size: int, eager_mode: int, - num_scheduler_steps: int, num_prompts: int): + num_scheduler_steps: int, num_prompts: int, + is_async: bool): prompts = example_prompts if len(prompts) < num_prompts: @@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ms_server_args = DEFAULT_SERVER_ARGS + \ ["--num-scheduler-steps", f"{num_scheduler_steps}"] - # Disable output proc callback as its not supported - # with multi-step right now - ms_server_args += ["--disable-async-output-proc"] + if not is_async: + ms_server_args += ["--disable-async-output-proc"] + if eager_mode: ms_server_args.append("--enforce-eager") diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 51fde6e4eb7a3..4c2f715820317 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1107,10 +1107,7 @@ def schedule( if not self.cache_config.enable_prefix_caching: common_computed_block_nums = [] - # TODO: Combine multi-step and async postprocessor - allow_async_output_proc: bool = ( - self.use_async_output_proc - and not self.scheduler_config.is_multi_step) + allow_async_output_proc: bool = self.use_async_output_proc # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 37696bf1d9dc9..3058214c50a5f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -279,6 +279,10 @@ async def step_async( scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # skip the scheduler if there are any remaining steps in the seq groups. @@ -289,17 +293,27 @@ async def step_async( # Clear outputs on scheduler iteration start ctx.request_outputs.clear() + # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() - # If current scheduler iteration has no async postprocessor, - # then we need first to drain the pending async postprocessor - # before moving forward + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -311,9 +325,6 @@ async def step_async( assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -339,8 +350,13 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ - virtual_engine] + async_callback = self.async_callback_multi_step[ + virtual_engine] if use_async_and_multi_step \ + else self.async_callback[virtual_engine] + + execute_model_req.async_callback = async_callback + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step # Execute the model. output = await self.model_executor.execute_model_async( @@ -350,7 +366,7 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: - if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -362,22 +378,25 @@ async def step_async( seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # Clear the cache if we have finished all the steps if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - # Cache results in engine - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len( - output - ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: self._process_model_outputs(virtual_engine=virtual_engine, @@ -390,7 +409,11 @@ async def step_async( self.do_tracing(scheduler_outputs) else: - ctx.request_outputs = [] + # Multi-step case + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a6de8817946cc..92c02072593e6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -91,7 +91,8 @@ class SchedulerOutputState: @dataclass class SchedulerContext: - output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata], + output_queue: Deque[Tuple[Optional[List[SamplerOutput]], + List[SequenceGroupMetadata], SchedulerOutputs]] = field( default_factory=lambda: deque()) @@ -432,6 +433,13 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for v_id in range(self.parallel_config.pipeline_parallel_size) ] + self.async_callback_multi_step = [ + functools.partial(self._process_model_outputs, + virtual_engine=v_id, + is_async=False) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1240,28 +1248,49 @@ def _process_sequence_group_outputs( return - def _process_model_outputs(self, virtual_engine: int, - is_async: bool) -> None: + def _process_model_outputs(self, + virtual_engine: int, + is_async: bool, + sampler_output: Optional[SamplerOutput] = None, + is_last_output: bool = False) -> None: """Apply the model output to the sequences in the scheduled seq groups. virtual_engine: The engine id to operate on + is_async: Indicates whether this postprocessor runs in parallel with the GPU forward pass and is processing tokens from the previous step. If this is true, then no tokens need to be appended since it is already done externally (before the next schedule() call) + sampler_output: Used with multi-step execution to provide + sampler_output of each step + is_last_output: Used with multi-step execution to indicate + the last step (of each multi-step group) + Returns RequestOutputs that can be returned to the client. """ now = time.time() + is_multi_step = sampler_output is not None + ctx: SchedulerContext = self.scheduler_contexts[virtual_engine] if len(ctx.output_queue) == 0: return None - (outputs, seq_group_metadata_list, - scheduler_outputs) = ctx.output_queue.popleft() + if is_multi_step: + # Async + multi-step case + (outputs, seq_group_metadata_list, + scheduler_outputs) = ctx.output_queue[0] + assert outputs is None + outputs = [sampler_output] + else: + # Async standard case + (outputs, seq_group_metadata_list, + scheduler_outputs) = ctx.output_queue.popleft() + + assert outputs is not None # Sanity check assert len(seq_group_metadata_list) == len( @@ -1320,7 +1349,11 @@ def _process_model_outputs(self, virtual_engine: int, self.output_processor.process_outputs(seq_group, output, is_async) - # Free the finished sequence groups. + # For async + multi-step, free finished seqs and create outputs + # only on the final step. + if is_multi_step and not is_last_output: + return + for scheduler in self.scheduler: scheduler.free_finished_seq_groups() @@ -1328,7 +1361,7 @@ def _process_model_outputs(self, virtual_engine: int, for i, _ in enumerate(seq_group_metadata_list): scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - if i in finished_before: + if not is_multi_step and i in finished_before: continue # Avoids double processing seq_group = scheduled_seq_group.seq_group @@ -1342,7 +1375,11 @@ def _process_model_outputs(self, virtual_engine: int, request_output = RequestOutputFactory.create(seq_group) ctx.request_outputs.append(request_output) - if is_async: + # For async + multi-step, do stats only on the last output. + # Otherwise, do stats if the execution is async + do_stats = is_multi_step or is_async + + if do_stats: # Log stats. self.do_log_stats(scheduler_outputs, outputs, finished_before) @@ -1437,7 +1474,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: "as performance will be severely degraded otherwise.") # For llm_engine, there is no pipeline parallel support, so the engine - # used is always 0 + # used is always 0. virtual_engine = 0 # These are cached outputs from previous iterations. None if on first @@ -1447,6 +1484,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # Skip the scheduler if there are any remaining steps in the seq groups. @@ -1462,11 +1503,22 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -1478,9 +1530,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -1505,8 +1554,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ - virtual_engine] + async_callback = self.async_callback_multi_step[ + virtual_engine] if use_async_and_multi_step \ + else self.async_callback[virtual_engine] + + execute_model_req.async_callback = async_callback + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step output = self.model_executor.execute_model( execute_model_req=execute_model_req) @@ -1518,7 +1572,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: else: # Nothing scheduled => If there is pending async postprocessor, # then finish it here. - if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -1535,18 +1589,23 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() - # Add results to the output_queue - # (for async or non-async postprocessing) - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + # Add results to the output_queue + # (for async or non-async postprocessing) + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len(output) == 1, ("Multi step decoding does not work " - "with async output processing.") + if output and allow_async_output_proc: + assert len(output) == 1, ( + "Multi step decoding does not work " + "with async output processing.") - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) # Check if need to run the usual non-async path if not allow_async_output_proc: @@ -1560,7 +1619,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.do_tracing(scheduler_outputs) else: # Multi-step case - ctx.request_outputs = [] + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) diff --git a/vllm/sequence.py b/vllm/sequence.py index 3125acc6fd535..e7cde87f605a7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1295,6 +1295,7 @@ class ExecuteModelRequest( last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback async_callback: Optional[Callable] = None + use_async_and_multi_step: bool = False @property def is_first_multi_step(self) -> bool: @@ -1341,4 +1342,5 @@ def clone( finished_requests_ids=self.finished_requests_ids, last_sampled_token_ids=self.last_sampled_token_ids.clone() if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback) + async_callback=self.async_callback, + use_async_and_multi_step=self.use_async_and_multi_step) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index de1a2e3235a8c..43853063cfb40 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase): finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 async_callback: Optional[Callable] = None + use_async_and_multi_step: bool = False def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 521205eca05af..0abca9d9f4558 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -1,5 +1,7 @@ +import dataclasses +import functools from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union try: from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -215,6 +217,46 @@ def prepare_model_input( ) return model_input + def _async_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Callable): + # Proceed with pythonization and output_proc in order. + # Stop on the first one that fails to pythonize + cont = True + for model_output in model_input.cached_outputs: + if not model_output.pythonized: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + if model_output.pythonized: + output_proc_callback( + sampler_output=model_output.sampler_output) + else: + cont = False + + if not cont: + break + + def _final_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Optional[Callable]): + assert model_input.frozen_model_input is not None + + outputs = [] + for output_id in range(len(model_input.cached_outputs)): + is_last_output = output_id == len(model_input.cached_outputs) - 1 + + output = model_input.cached_outputs[output_id] + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + if model_input.frozen_model_input.use_async_and_multi_step: + assert output_proc_callback is not None + output_proc_callback(sampler_output=output.sampler_output, + is_last_output=is_last_output) + + outputs.append(output.sampler_output) + + return outputs + @torch.inference_mode() def execute_model( self, @@ -271,6 +313,20 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) + output_proc_callback = None + if frozen_model_input.use_async_and_multi_step: + output_proc_callback = frozen_model_input.async_callback + assert output_proc_callback is not None + async_callback = functools.partial( + self._async_process_outputs, + model_input=model_input, + output_proc_callback=output_proc_callback) + + frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=async_callback) + assert frozen_model_input is not None + # Execute the model output = self._base_model_runner.execute_model(frozen_model_input, kv_caches, @@ -301,9 +357,11 @@ def execute_model( output[0].logprobs = None # Pythonize the output if CPU is ahead and the previous step is # ready. - for model_output in model_input.cached_outputs: - model_output.maybe_pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + if not frozen_model_input.use_async_and_multi_step: + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, + self._copy_stream, + self.pinned_sampled_token_ids) model_input.current_step += 1 @@ -316,11 +374,8 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = [] - for output in model_input.cached_outputs: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - outputs.append(output.sampler_output) + outputs = self._final_process_outputs(model_input, + output_proc_callback) return outputs # should be [SamplerOutput] diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 2ed77dd698f5c..e0e421942f409 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -1,3 +1,4 @@ +import dataclasses from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -61,6 +62,13 @@ def _get_driver_input_and_broadcast( execute_model_req.seq_group_metadata_list, execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=execute_model_req.async_callback, + use_async_and_multi_step=execute_model_req. + use_async_and_multi_step) else: # on subsequent steps we reuse the worker input and model input multi_step_state = self.multi_step_states[virtual_engine]