diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index a3b76327e0a53..79cc0cd8fea8f 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int): def schedule_and_update_computed_tokens(scheduler): - metas, out = scheduler.schedule() + metas, out, _, _ = scheduler.schedule() for s, meta in zip(out.scheduled_seq_groups, metas): s.seq_group.update_num_computed_tokens(meta.token_chunk_size) return metas, out @@ -180,7 +180,7 @@ def test_maximal_decoding(): """Verify decoding requests are prioritized.""" block_size = 4 max_seqs = 2 - max_model_len = 2 + max_model_len = 8 max_num_batched_tokens = 2 scheduler_config = SchedulerConfig(max_num_batched_tokens, max_seqs, diff --git a/tests/core/utils.py b/tests/core/utils.py index 12b66d50749db..3c3cc4f6f4d50 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -199,7 +199,7 @@ def append_new_token(out, token_id: int): def schedule_and_update_computed_tokens(scheduler): - metas, out = scheduler.schedule() + metas, out, _, _ = scheduler.schedule() for s, meta in zip(out.scheduled_seq_groups, metas): s.seq_group.update_num_computed_tokens(meta.token_chunk_size) return metas, out diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 4d54e43d5788c..2f0adb02bc353 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -49,6 +49,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, str(TP_SIZE), "--distributed-executor-backend", DIST_BACKEND, + # disable output proc callback to test PP + "--disable-output-proc-callback", ] # compare without pipeline parallelism diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index 4912858d8279e..b01ae8aa36d1e 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -22,6 +22,8 @@ def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND): str(PP_SIZE), "--distributed-executor-backend", "mp", + # disable output proc callback to test PP + "--disable-output-proc-callback", ] os.environ["VLLM_ATTENTION_BACKEND"] = ATTN_BACKEND diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py index 1584b85aeb064..c5dc49252ce93 100644 --- a/tests/engine/test_stop_strings.py +++ b/tests/engine/test_stop_strings.py @@ -98,6 +98,10 @@ def _test_stopping(llm_engine: LLMEngine, output: Optional[CompletionOutput] = None output_text = "" stop_reason = None + + # Run first (because of async callback) + llm_engine.step() + while llm_engine.has_unfinished_requests(): (request_output, ) = llm_engine.step() (output, ) = request_output.outputs diff --git a/tests/multi_step/test_correctness.py b/tests/multi_step/test_correctness.py index bc14311c66424..085203dd4de51 100644 --- a/tests/multi_step/test_correctness.py +++ b/tests/multi_step/test_correctness.py @@ -60,7 +60,7 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"] ms_server_args = DEFAULT_SERVER_ARGS + \ - ["--num-scheduler-steps", f"{num_scheduler_steps}"] + ["--num-scheduler-steps", f"{num_scheduler_steps}"]#, "--disable-output-proc-callback"] if eager_mode: ms_server_args.append("--enforce-eager") @@ -82,4 +82,11 @@ def get_text_generations(completions): ref_generations = 
get_text_generations(ref_completions) test_generations = get_text_generations(test_completions) + + print("ref_generations:") + for gen in ref_generations: + print("ref_gen: {}".format(gen)) + print("test_generations:") + for gen in test_generations: + print("test_gen: {}".format(gen)) assert ref_generations == test_generations diff --git a/vllm/config.py b/vllm/config.py index a5a9984a0114a..3da41dbf5b35a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -137,6 +137,7 @@ def __init__( skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, List[str]]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_output_proc_callback: Optional[bool] = True, ) -> None: self.model = model self.tokenizer = tokenizer @@ -167,6 +168,7 @@ def __init__( code_revision, rope_scaling, rope_theta) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.use_output_proc_callback = use_output_proc_callback # Choose a default enforce_eager value if the user did not specify # a value (enforce_eager is None) @@ -320,6 +322,30 @@ def _verify_cuda_graph(self) -> None: self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) + def verify_output_proc_callback(self, speculative_config, + device_config) -> None: + if device_config.device_type != "cuda": + logger.warning( + "Output proc callback can only be enabled with CUDA") + self.use_output_proc_callback = False + return + if self.enforce_eager: + logger.warning( + "To see benefits of output processor callback, enable CUDA " + "graph. Since, enforce-eager is enabled, output processor " + "callback cannot be used") + self.use_output_proc_callback = not self.enforce_eager + return + # Async postprocessor is not necessary with embedding mode + # since there is no token generation + if self.embedding_mode: + self.use_output_proc_callback = False + + if speculative_config: + self.use_output_proc_callback = False + + # TO DO: assert mp backend + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", @@ -352,6 +378,11 @@ def verify_with_parallel_config( "fallback to the eager mode.") self.enforce_eager = True + if (pipeline_parallel_size > 1) and (self.use_output_proc_callback): + raise NotImplementedError( + "Output processor callback is not supported with " + "pipeline parallelism currently. Disable the callback.") + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" @@ -1754,6 +1785,8 @@ class EngineConfig: def __post_init__(self): """Verify configs are valid & consistent with each other. 
""" + self.model_config.verify_output_proc_callback(self.speculative_config, + self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 3b716e32032c1..d979b6d12d12a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -4,7 +4,8 @@ import time from collections import deque from dataclasses import dataclass, field -from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set, + Tuple, Union) from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager @@ -293,13 +294,12 @@ def scheduled_seq_group_builder(): class Scheduler: - def __init__( - self, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], - pipeline_parallel_size: int = 1, - ) -> None: + def __init__(self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + pipeline_parallel_size: int = 1, + output_proc_callback_fn: Optional[Callable] = None) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config # Note for LoRA scheduling: the current policy is extremely @@ -364,10 +364,35 @@ def __init__( self.num_cumulative_preemption: int = 0 # Used to cache python objects - self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache( - scheduler_running_outputs_builder) - self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache( - scheduled_seq_group_builder) + self._seq_group_metadata_cache: List[PyObjectCache] = [] + self._scheduler_running_outputs_cache: List[PyObjectCache] = [] + self._scheduled_seq_group_cache: List[PyObjectCache] = [] + + # For async output processing, we need to swap cache buffers between + # iterations. I.e. since the output processing is lagged one step, + # we cannot reuse the cached objects immediately when the schedule() + # is called again, but only when schedule() is called the second time. + self.output_proc_callback_fn = output_proc_callback_fn + self.use_output_proc_callback = self.output_proc_callback_fn is not None + self.num_cache_iters = 2 if self.use_output_proc_callback else 1 + + self.cache_id = 0 + for i in range(self.num_cache_iters): + self._seq_group_metadata_cache.append( + PyObjectCache(seq_group_metadata_builder)) + self._scheduler_running_outputs_cache.append( + PyObjectCache(scheduler_running_outputs_builder)) + self._scheduled_seq_group_cache.append( + PyObjectCache(scheduled_seq_group_builder)) + + # Avoid deque alloc + self.tmp_queue: Deque[SequenceGroup] = deque() + + self._async_stopped: List[SequenceGroup] = [] + + @property + def next_cache_id(self): + return (self.cache_id + 1) % self.num_cache_iters @property def lora_enabled(self) -> bool: @@ -483,7 +508,7 @@ def _schedule_running( SchedulerRunningOutputs. """ ret: SchedulerRunningOutputs = \ - self._scheduler_running_outputs_cache.get_object() + self._scheduler_running_outputs_cache[self.cache_id].get_object() ret.blocks_to_swap_out.clear() ret.blocks_to_copy.clear() ret.decode_seq_groups.clear() @@ -510,8 +535,12 @@ def _schedule_running( # NOTE(woosuk): Preemption happens only when there is no available slot # to keep all the sequence groups in the RUNNING state. 
- running_queue = self.running + # Store original running requests for the case of async + preemption + if self.use_output_proc_callback: + orig_running = self.running.copy() + running_queue = self.running + assert len(self._async_stopped) == 0 while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( @@ -521,6 +550,28 @@ def _schedule_running( break running_queue.popleft() + + # With async postprocessor, an extra decode run is done + # to process the final tokens. The check below avoids this extra + # decode run when the model max len is reached, in order to avoid + # a memory overflow. + if self.use_output_proc_callback and seq_group.seqs[0].get_len( + ) > self.scheduler_config.max_model_len: + self._async_stopped.append(seq_group) + continue + + # With async postprocessor, when preemption kicks in, we need + # first to drain the async postprocessor, so that all async + # block_table freeing is applied before the preemption freeing + # is applied. + if self.use_output_proc_callback and not self._can_append_slots( + seq_group): + tmp = self.running + self.running = orig_running + assert self.output_proc_callback_fn is not None + self.output_proc_callback_fn(is_async=True) + self.running = tmp + while not self._can_append_slots(seq_group): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) @@ -556,7 +607,7 @@ def _schedule_running( is_prefill = seq_group.is_prefill() scheduled_seq_group: ScheduledSequenceGroup = \ - self._scheduled_seq_group_cache.get_object() + self._scheduled_seq_group_cache[self.cache_id].get_object() scheduled_seq_group.seq_group = seq_group if is_prefill: scheduled_seq_group.token_chunk_size = num_running_tokens @@ -579,8 +630,8 @@ def _schedule_running( if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - self._scheduler_running_outputs_cache.reset() - self._scheduled_seq_group_cache.reset() + self._scheduler_running_outputs_cache[self.next_cache_id].reset() + self._scheduled_seq_group_cache[self.next_cache_id].reset() return ret @@ -1031,17 +1082,33 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), ) - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + def _allow_output_proc_callback(self, seq_group: SequenceGroup) -> bool: + no_beam_search = (seq_group.sampling_params.n == 1 + and not seq_group.sampling_params.use_beam_search) + + return no_beam_search + + def schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[Tuple[ + ScheduledSequenceGroup, SequenceGroupMetadata]], bool]: # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. scheduler_start_time = time.perf_counter() + scheduler_outputs = self._schedule() now = time.time() if not self.cache_config.enable_prefix_caching: common_computed_block_nums = [] + # TODO: Combine multi-step and async postprocessor + allow_output_proc_callback: bool = self.use_output_proc_callback + + # Create list of scheduled request ids + scheduled_ids: List[Tuple[ScheduledSequenceGroup, + SequenceGroupMetadata]] = [] # Create input data structures. 
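The two-slot caches indexed by cache_id above exist because output processing now lags schedule() by one iteration: objects handed out while building iteration N may still be referenced when iteration N+1 is built, so a single cache cannot be reset every pass. A self-contained sketch of the rotation (RotatingObjectCaches is a hypothetical stand-in for the PyObjectCache pair):

from typing import Callable, List


class RotatingObjectCaches:
    """Hypothetical stand-in for the per-iteration PyObjectCache pair."""

    def __init__(self, builder: Callable[[], object],
                 use_async_output_proc: bool) -> None:
        self.num_iters = 2 if use_async_output_proc else 1
        self.cache_id = 0
        self.builder = builder
        # Each slot plays the role of one PyObjectCache(builder).
        self.slots: List[List[object]] = [[] for _ in range(self.num_iters)]

    @property
    def next_cache_id(self) -> int:
        return (self.cache_id + 1) % self.num_iters

    def get_object(self) -> object:
        obj = self.builder()
        self.slots[self.cache_id].append(obj)
        return obj

    def end_of_schedule(self) -> None:
        # Objects in the *next* slot were handed out two iterations ago
        # and are no longer referenced, so they are safe to recycle.
        self.slots[self.next_cache_id].clear()
        self.cache_id = self.next_cache_id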
seq_group_metadata_list: List[SequenceGroupMetadata] = [] for i, scheduled_seq_group in enumerate( @@ -1050,6 +1117,11 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) + seq_group_metadata = self._seq_group_metadata_cache[ + self.cache_id].get_object() + seq_group_metadata.seq_data.clear() + seq_group_metadata.block_tables.clear() + # seq_id -> SequenceData seq_data: Dict[int, SequenceData] = {} # seq_id -> physical block numbers @@ -1139,6 +1211,11 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: ) seq_group_metadata_list.append(seq_group_metadata) + if allow_output_proc_callback: + allow_output_proc_callback = self._allow_output_proc_callback( + seq_group) + + scheduled_ids.append((scheduled_seq_group, seq_group_metadata)) # Now that the batch has been created, we can assume all blocks in the # batch will have been computed before the next scheduling invocation. # This is because the engine assumes that a failure in model execution @@ -1147,6 +1224,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: self.block_manager.mark_blocks_as_computed( scheduled_seq_group.seq_group) + self._seq_group_metadata_cache[self.next_cache_id].reset() + scheduler_time = time.perf_counter() - scheduler_start_time # Add this to scheduler time to all the sequences that are currently # running. This will help estimate if the scheduler is a significant @@ -1158,7 +1237,12 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: else: seq_group.metrics.scheduler_time = scheduler_time - return seq_group_metadata_list, scheduler_outputs + # Move to next cache (if exists) + self.cache_id = self.next_cache_id + + # Return results + return (seq_group_metadata_list, scheduler_outputs, scheduled_ids, + allow_output_proc_callback) def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_manager.fork(parent_seq, child_seq) @@ -1168,7 +1252,7 @@ def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: - remaining: Deque[SequenceGroup] = deque() + self.tmp_queue.clear() for seq_group in self.running: if seq_group.is_finished(): # Free cross-attention block table, if it exists @@ -1178,8 +1262,30 @@ def free_finished_seq_groups(self) -> None: # next step. 
self._finished_requests_ids.append(seq_group.request_id) else: - remaining.append(seq_group) - self.running = remaining + self.tmp_queue.append(seq_group) + # Free finished seqs + for seq in seq_group.get_seqs(): + if seq.is_finished(): + self.free_seq(seq) + + # Swap + q = self.running + self.running = self.tmp_queue + q.clear() + self.tmp_queue = q + + # Handle async stopped sequence groups + # (ones that reached max model len) + if self.use_output_proc_callback and len(self._async_stopped) > 0: + for seq_group in self._async_stopped: + self._free_seq_group_cross_attn_blocks(seq_group) + self._finished_requests_ids.append(seq_group.request_id) + + for seq in seq_group.get_seqs(): + if seq.is_finished(): + self.free_seq(seq) + + self._async_stopped.clear() def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7f45c3d06375a..8d24436bf72d9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -147,6 +147,7 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None + disable_output_proc_callback: Optional[bool] = False def __post_init__(self): if self.tokenizer is None: @@ -735,6 +736,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "modules. This involves use of possibly costly and or blocking " "operations and hence might have a performance impact.") + parser.add_argument( + '--disable-output-proc-callback', + action='store_true', + default=False, + help="Disable async output processing. This may result in " + "lower performance.") return parser @classmethod @@ -794,6 +801,7 @@ def create_engine_config(self, ) -> EngineConfig: skip_tokenizer_init=self.skip_tokenizer_init, served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, + use_output_proc_callback=not self.disable_output_proc_callback, ) cache_config = CacheConfig( block_size=self.block_size, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6385d3ca2297e..5f334d802d6d8 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -12,7 +12,7 @@ import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.core.scheduler import SchedulerOutputs +from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, @@ -258,6 +258,9 @@ class SchedulerOutputState: last_output: Optional[SamplerOutput] = None seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None scheduler_outputs: Optional[SchedulerOutputs] = None + scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup, + SequenceGroupMetadata]]] = None + allow_output_proc_callback: bool = False class _AsyncLLMEngine(LLMEngine): @@ -288,22 +291,39 @@ async def step_async( cached_outputs = self.cached_scheduler_outputs[virtual_engine] seq_group_metadata_list = cached_outputs.seq_group_metadata_list scheduler_outputs = cached_outputs.scheduler_outputs + scheduled_ids = cached_outputs.scheduled_ids + allow_output_proc_callback = cached_outputs.allow_output_proc_callback + # skip the scheduler if there are any remaining steps in the seq groups. 
# This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - virtual_engine].schedule() + (seq_group_metadata_list, scheduler_outputs, scheduled_ids, + allow_output_proc_callback + ) = self.scheduler[virtual_engine].schedule() + + self.request_outputs.clear() if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have # lookahead slots self._cache_scheduler_outputs_for_multi_step( - virtual_engine, seq_group_metadata_list, scheduler_outputs) + virtual_engine, seq_group_metadata_list, scheduler_outputs, + scheduled_ids, allow_output_proc_callback) + + if self.scheduler_config.is_multi_step and allow_output_proc_callback: + assert len(self.output_queue) == 0 + self.output_queue.append( + (None, scheduled_ids, scheduler_outputs)) assert seq_group_metadata_list is not None assert scheduler_outputs is not None + assert scheduled_ids is not None + + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_output_proc_callback) if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ @@ -328,6 +348,11 @@ async def step_async( # We use ExecuteModelRequest to pass the last sampled_token_ids # to each of the non-last PP stages for in-place prepare_input. last_sampled_token_ids=last_sampled_token_ids) + + if allow_output_proc_callback: + execute_model_req.callback_fn = self._process_model_outputs + execute_model_req.use_async_and_multi_step = use_async_and_multi_step + # Execute the model. output = await self.model_executor.execute_model_async( execute_model_req) @@ -336,6 +361,8 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: + if not use_async_and_multi_step and len(self.output_queue) > 0: + self._process_model_outputs(is_async=True) output = [] # Finish the current step for all the sequence groups. @@ -348,19 +375,35 @@ async def step_async( if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) - else: - request_outputs = [] - # Log stats. - self.do_log_stats(scheduler_outputs, output) + if use_async_and_multi_step: + self.output_queue.clear() - # Tracing - self.do_tracing(scheduler_outputs) + if not use_async_and_multi_step: + self.output_queue.append( + (output, scheduled_ids, scheduler_outputs)) + + if (len(output) > 0) and allow_output_proc_callback: + assert len( + output + ) == 1, "Multi step decoding does not work with output processor callback" # noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + if not allow_output_proc_callback: + self._process_model_outputs(is_async=False) + + # Log stats. 
+ self.do_log_stats(scheduler_outputs, output) + + # Tracing + self.do_tracing(scheduler_outputs) + + else: + return [] - return request_outputs + return self.request_outputs def _has_remaining_steps( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] @@ -385,12 +428,20 @@ def _has_remaining_steps( def _cache_scheduler_outputs_for_multi_step( self, virtual_engine: int, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - scheduler_outputs: SchedulerOutputs) -> None: + scheduler_outputs: SchedulerOutputs, + scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup, + SequenceGroupMetadata]]], + allow_output_proc_callback: bool) -> None: + v = virtual_engine self.cached_scheduler_outputs[ virtual_engine].seq_group_metadata_list = seq_group_metadata_list - self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ + self.cached_scheduler_outputs[v].scheduler_outputs = \ scheduler_outputs - self.cached_scheduler_outputs[virtual_engine].last_output = None + self.cached_scheduler_outputs[v].scheduled_ids = \ + scheduled_ids + self.cached_scheduler_outputs[v].allow_output_proc_callback = \ + allow_output_proc_callback + self.cached_scheduler_outputs[v].last_output = None def _get_last_sampled_token_ids( self, virtual_engine: int) -> Optional[torch.Tensor]: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 36cb6ce795f3e..559986c33b97b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,6 +1,7 @@ import time +from collections import deque from contextlib import contextmanager -from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, +from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, Mapping, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, TypeVar, Union @@ -36,9 +37,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - PoolerOutput, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata, - SequenceStatus) + SamplerOutput, Sequence, SequenceGroup, + SequenceGroupMetadata, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -178,6 +178,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + step_return_finished_only=False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -192,7 +193,7 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "enable_prefix_caching=%s)", + "enable_prefix_caching=%s use_output_proc_callback=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -222,6 +223,7 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, cache_config.enable_prefix_caching, + model_config.use_output_proc_callback, ) # TODO(woosuk): Print more configs in debug mode. 
from vllm.plugins import load_general_plugins @@ -240,6 +242,7 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats + self.step_return_finished_only = step_return_finished_only if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -327,8 +330,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = [ - Scheduler(scheduler_config, cache_config, lora_config, - parallel_config.pipeline_parallel_size) + Scheduler( + scheduler_config, cache_config, lora_config, + parallel_config.pipeline_parallel_size, + self._process_model_outputs + if model_config.use_output_proc_callback else None) for _ in range(parallel_config.pipeline_parallel_size) ] @@ -378,6 +384,14 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ), )) + # Output processing callback pointers + self.output_queue: Deque[Tuple[List[SamplerOutput], + List[Tuple[ScheduledSequenceGroup, + SequenceGroupMetadata]], + SchedulerOutputs]] = deque() + self.request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1169,32 +1183,57 @@ def _process_sequence_group_outputs( return - def _process_model_outputs( - self, - output: GenericSequence[Union[SamplerOutput, PoolerOutput]], - scheduled_seq_groups: List[ScheduledSequenceGroup], - ignored_seq_groups: List[SequenceGroup], - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + def _process_model_outputs(self, + is_async: bool, + sampler_output: Optional[SamplerOutput] = None, + is_last_output: bool = False) -> None: """Apply the model output to the sequences in the scheduled seq groups. Returns RequestOutputs that can be returned to the client. """ - now = time.time() - # Organize outputs by [sequence group][step] instead of - # [step][sequence group]. - output_by_sequence_group = create_output_by_sequence_group( - output, num_seq_groups=len(scheduled_seq_groups)) + if sampler_output is None: + self.request_outputs.clear() + + if len(self.output_queue) == 0: + return None + + if sampler_output is not None: + (outputs, scheduled_ids, scheduler_outputs) = self.output_queue[0] + assert outputs is None + outputs = [sampler_output] + else: + (outputs, scheduled_ids, + scheduler_outputs) = self.output_queue.popleft() + + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. + if len(outputs) > 1: + outputs_by_sequence_group = create_output_by_sequence_group( + outputs, num_seq_groups=len(scheduled_ids)) + else: + outputs_by_sequence_group = outputs - # Update the scheduled sequence groups with the model outputs. 
- for scheduled_seq_group, outputs, seq_group_meta in zip( - scheduled_seq_groups, output_by_sequence_group, - seq_group_metadata_list): + output = [None] + finished_before: List[int] = [] + for i, (scheduled_seq_group, + seq_group_meta) in enumerate(scheduled_ids): seq_group = scheduled_seq_group.seq_group - seq_group.update_num_computed_tokens( - scheduled_seq_group.token_chunk_size) + + if seq_group.is_finished(): + finished_before.append(i) + continue + + if len(outputs) > 1: + output = outputs_by_sequence_group[i] + else: + output[0] = outputs[0][i] + + if not is_async: + seq_group.update_num_computed_tokens( + scheduled_seq_group.token_chunk_size) + if output is not None and len(output) > 0: for o in output: if (isinstance(o, SamplerOutput) @@ -1211,30 +1250,82 @@ def _process_model_outputs( else: seq_group.metrics.model_execute_time = ( o.model_execute_time) + if self.model_config.embedding_mode: - self._process_sequence_group_outputs(seq_group, outputs) + self._process_sequence_group_outputs(seq_group, output) continue - self.output_processor.process_prompt_logprob(seq_group, outputs) + self.output_processor.process_prompt_logprob(seq_group, output) if seq_group_meta.do_sample: - self.output_processor.process_outputs(seq_group, outputs) + self.output_processor.process_outputs(seq_group, output, + is_async) + + if sampler_output is not None and not is_last_output: + return # Free the finished sequence groups. for scheduler in self.scheduler: scheduler.free_finished_seq_groups() # Create the outputs. - request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] - for scheduled_seq_group in scheduled_seq_groups: + for i, (scheduled_seq_group, _) in enumerate(scheduled_ids): + if sampler_output is None: + if i in finished_before: + continue # Avoids double processing + seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) + if (seq_group.is_finished() + if self.step_return_finished_only else True): + request_output = RequestOutputFactory.create(seq_group) + self.request_outputs.append(request_output) + + for seq_group in scheduler_outputs.ignored_seq_groups: request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) - for seq_group in ignored_seq_groups: - request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) - return request_outputs + self.request_outputs.append(request_output) + + if is_async: + # Log stats. + self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + + return None + + def _advance_to_next_step( + self, output: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done inside output processor, but it is + required if the worker is to perform async forward pass to next step. 
+ """ + for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ + zip(seq_group_metadata_list, output, scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + + if not seq_group.is_finished(): + seq_group.update_num_computed_tokens( + seq_group_metadata.token_chunk_size) + + if not seq_group.is_finished() and seq_group_metadata.do_sample: + assert len(sequence_group_outputs.samples) == 1, ( + "output_proc_callback expects a single sample" + " (i.e sampling_params.n == 1 and no " + "sampling_params.best_of > 1)") + seq_group_metadata.is_prompt = False + seq_output = sequence_group_outputs.samples[0] + + # NOTE: Beam search is not supported, so we can assume that + # parent_seq_id == seq_id. + seq_data = seq_group_metadata.seq_data[ + seq_output.parent_seq_id] + + token_id = seq_output.output_token + token_logprob = seq_output.logprobs[token_id] + + seq_data.append_token_id(token_id, token_logprob.logprob) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. @@ -1291,8 +1382,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: raise NotImplementedError( "Pipeline parallelism is only supported through AsyncLLMEngine " "as performance will be severely degraded otherwise.") - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - 0].schedule() + (seq_group_metadata_list, scheduler_outputs, scheduled_ids, + allow_output_proc_callback) = self.scheduler[0].schedule() if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ @@ -1305,20 +1396,36 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, finished_requests_ids=finished_requests_ids) + + if allow_output_proc_callback: + execute_model_req.callback_fn = self._process_model_outputs + output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: + if len(self.output_queue) > 0: + self._process_model_outputs(is_async=True) output = [] - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + # Add results to the output_queue + # (for async or non-async postprocessing) + self.output_queue.append((output, scheduled_ids, scheduler_outputs)) - # Log stats. - self.do_log_stats(scheduler_outputs, output) + if (len(output) > 0) and allow_output_proc_callback: + assert len(output) == 1, ("Multi step decoding does not work " + "with output processor callback") - # Tracing - self.do_tracing(scheduler_outputs) + self._advance_to_next_step(output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + if not allow_output_proc_callback: + self._process_model_outputs(is_async=False) + + # Log stats. + self.do_log_stats(scheduler_outputs, output) + + # Tracing + self.do_tracing(scheduler_outputs) if not self.has_unfinished_requests(): # Stop the execute model loop in parallel workers until there are @@ -1328,7 +1435,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # queued control plane messages, such as add/remove lora adapters. 
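Taken together, the engine changes above form a one-step-lagged producer/consumer: step N pushes (output, scheduled_ids, scheduler_outputs) onto output_queue and eagerly appends only the sampled token ids via _advance_to_next_step, while detokenization, stop checking and RequestOutput creation for step N run when the worker invokes the callback during step N+1, just before sampling. A minimal sketch of that queue discipline, with hypothetical names (TinyEngine is not a vLLM class):

from collections import deque
from typing import Any, Callable, Deque, List, Tuple


class TinyEngine:
    """Sketch (hypothetical) of the one-step-lagged output processing."""

    def __init__(self) -> None:
        self.output_queue: Deque[Tuple[Any, Any]] = deque()
        self.request_outputs: List[Any] = []

    def step(self, schedule: Callable[[], Any],
             execute_model: Callable[..., Any]) -> List[Any]:
        scheduled = schedule()
        # The worker calls the callback while this step is in flight, so
        # the previous step's postprocessing overlaps current GPU work.
        output = execute_model(scheduled,
                               callback=self._process_model_outputs)
        self.output_queue.append((output, scheduled))
        # Only sampled token ids would be appended eagerly here
        # (cf. _advance_to_next_step), so the next forward pass is correct.
        return self.request_outputs  # results of the *previous* step

    def _process_model_outputs(self) -> None:
        # Drain one queued step: detokenize, stop-check, build outputs.
        if not self.output_queue:
            return
        output, scheduled = self.output_queue.popleft()
        self.request_outputs = [(scheduled, output)]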
self.model_executor.stop_remote_worker_execution_loop() - return request_outputs + # Do not process the extra +1 runs (due to async postprocessor) + self.output_queue.clear() + + return self.request_outputs def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: if logger_name in self.stat_loggers: diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 92aecebe6ec38..4e64bfc2085f6 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -25,13 +25,10 @@ class SequenceGroupOutputProcessor(ABC): @staticmethod def create_output_processor( - scheduler_config: SchedulerConfig, - detokenizer: Detokenizer, - scheduler: List[Scheduler], - seq_counter: Counter, - get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], - stop_checker: "StopChecker", - ): + scheduler_config: SchedulerConfig, detokenizer: Detokenizer, + scheduler: List[Scheduler], seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + stop_checker: "StopChecker"): """Create an output processor. This returns a single-step output processor if num_lookahead_slots is @@ -41,13 +38,9 @@ def create_output_processor( # Importing here to avoid cycle. from vllm.engine.output_processor.single_step import ( SingleStepOutputProcessor) - return SingleStepOutputProcessor( - scheduler_config, - detokenizer, - scheduler, - seq_counter, - stop_checker, - ) + return SingleStepOutputProcessor(scheduler_config, detokenizer, + scheduler, seq_counter, + stop_checker) else: # Importing here to avoid cycle. from vllm.engine.output_processor.multi_step import ( @@ -62,7 +55,8 @@ def create_output_processor( @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: """Process new token ids for the sequence group. Handles logic such as detokenization, stop checking, and freeing/forking sequences in the scheduler. diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 25d15df9f915d..35d8507518fa4 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -31,14 +31,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): number of new output tokens per sequence differs in a single batch. """ - def __init__( - self, - detokenizer: Detokenizer, - scheduler: List[Scheduler], - seq_counter: Counter, - get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], - stop_checker: StopChecker, - ): + def __init__(self, detokenizer: Detokenizer, scheduler: List[Scheduler], + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], + PreTrainedTokenizer], + stop_checker: StopChecker): self.detokenizer = detokenizer self.scheduler = scheduler self.seq_counter = seq_counter @@ -58,8 +55,10 @@ def _log_prompt_logprob_unsupported_warning_once(): "Prompt logprob is not supported by multi step workers. " "(e.g., speculative decode uses multi step workers).") - def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: + def process_outputs(self, + sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool = False) -> None: """Append new tokens in the outputs to sequences in the sequence group. This only supports sequence groups of size 1. 
It supports greater than @@ -69,6 +68,9 @@ def process_outputs(self, sequence_group: SequenceGroup, including freeing finished sequences. It also handles cases where there are tokens emitted after the EOS token. """ + # TODO: Add support for async if necessary + assert not is_async + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) assert seqs, "expected running sequences" @@ -139,7 +141,3 @@ def _process_seq_outputs(self, seq: Sequence, ) if seq.is_finished(): break - - if seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4a46c93f84256..b9c1b8cf553c8 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): that is currently difficult to schedule multiple steps ahead of time. """ - def __init__( - self, - scheduler_config: SchedulerConfig, - detokenizer: Detokenizer, - scheduler: List[Scheduler], - seq_counter: Counter, - stop_checker: StopChecker, - ): + def __init__(self, scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, scheduler: List[Scheduler], + seq_counter: Counter, stop_checker: StopChecker): self.scheduler_config = scheduler_config self.detokenizer = detokenizer self.scheduler = scheduler @@ -44,7 +39,8 @@ def __init__( self.stop_checker = stop_checker def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: """Append all new tokens to sequences in the sequence group. Fork any surviving beam candidates; free any unsurviving ones. @@ -53,7 +49,8 @@ def process_outputs(self, sequence_group: SequenceGroup, """ assert (len(outputs) == 1 ), f"{type(self)} does not support multiple outputs per step" - return self._process_sequence_group_outputs(sequence_group, outputs[0]) + return self._process_sequence_group_outputs(sequence_group, outputs[0], + is_async) def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: @@ -80,29 +77,33 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, seq_group.prompt_logprobs.extend(prompt_logprobs) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: + outputs: SequenceGroupOutput, + is_async: bool) -> None: sampling_params = seq_group.sampling_params if sampling_params.n == 1 and not sampling_params.use_beam_search: - # only have one output sample - sample = outputs.samples[0] - # only have one sequence - seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs) - if sampling_params.detokenize and self.detokenizer: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) + if len(outputs.samples) > 0: + sample = outputs.samples[0] + # only have one sequence + seq = seq_group.seqs[0] + seq.append_token_id(sample.output_token, sample.logprobs, + not is_async) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + return else: - new_char_count = 0 - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count, - sampling_params, - 
lora_req=seq_group.lora_request, - ) - if seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) - return + # When a new request is added and execute_model + # is still not finished + return # Process samples samples = outputs.samples @@ -130,23 +131,23 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # the parent sequence from the sequence group since it will # not be used in the future iterations. parent.status = SequenceStatus.FINISHED_ABORTED - seq_group.remove(parent.seq_id) - for scheduler in self.scheduler: - scheduler.free_seq(parent) + # seq_group.remove(parent.seq_id) + # for scheduler in self.scheduler: + # scheduler.free_seq(parent) continue # Fork the parent sequence if there are multiple child samples. for child_sample in child_samples[:-1]: new_child_seq_id: int = next(self.seq_counter) child = parent.fork(new_child_seq_id) child.append_token_id(child_sample.output_token, - child_sample.logprobs) + child_sample.logprobs, not is_async) child_seqs.append((child, parent)) # Continue the parent sequence for the last child sample. # We reuse the parent sequence here to reduce redundant memory # copies, especially when using non-beam search sampling methods. last_child_sample = child_samples[-1] parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) + last_child_sample.logprobs, not is_async) child_seqs.append((parent, parent)) for seq, _ in child_seqs: @@ -177,10 +178,10 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # manager. Keep them in the sequence group as candidate output. # NOTE: we need to fork the new sequences before freeing the # old sequences. - for seq, parent in child_seqs: - if seq is parent and seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) + # for seq, parent in child_seqs: + # if seq is parent and seq.is_finished(): + # for scheduler in self.scheduler: + # scheduler.free_seq(seq) return # Beam search case diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ecd6dc64d343b..262399f2d9e3f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -128,6 +128,7 @@ def __init__( max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, + disable_output_proc_callback: Optional[bool] = False, **kwargs, ) -> None: ''' @@ -169,6 +170,7 @@ def __init__( max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, + disable_output_proc_callback=disable_output_proc_callback, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( @@ -603,7 +605,6 @@ def _validate_and_add_requests( inputs = [inputs] num_requests = len(inputs) - if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -678,6 +679,10 @@ def _run_engine( postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"), ) + + # In the loop below, only finished outputs are used + self.llm_engine.step_return_finished_only = True + # Run the engine. outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_in_toks = 0 @@ -700,6 +705,10 @@ def _run_engine( f"est. speed input: {in_spd:.2f} toks/s, " f"output: {out_spd:.2f} toks/s") pbar.update(1) + + # Restore original behavior + self.llm_engine.step_return_finished_only = False + if use_tqdm: pbar.close() # Sort the outputs by request ID. 
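For reference, the knob added in arg_utils.py and entrypoints/llm.py above is opt-out: async output processing is on by default and can be switched off per engine. A usage sketch (the model name is only a placeholder):

from vllm import LLM, SamplingParams

# Async output processing is enabled by default; the new kwarg falls
# back to synchronous (in-step) output processing.
llm = LLM(model="facebook/opt-125m",  # placeholder model
          disable_output_proc_callback=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
# CLI / server equivalent (as used in the tests above):
#   --disable-output-proc-callback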
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 4df54a09e5e8c..1a35a7c3b8f75 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -64,8 +64,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks=num_cpu_blocks) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", @@ -188,7 +189,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None: @abstractmethod async def _driver_execute_model_async( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> List[SamplerOutput]: """Execute the model asynchronously in the driver worker. diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 422bef107f352..55fe4f89b4ea4 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -20,19 +20,14 @@ class ExecutorBase(ABC): uses_ray: bool # whether the executor uses Ray for orchestration. - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - observability_config: Optional[ObservabilityConfig], - ) -> None: + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + observability_config: Optional[ObservabilityConfig]) -> None: self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index af426e31591f2..ee069302a9e36 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -168,5 +168,5 @@ async def execute_model_async( execute_model_req: ExecuteModelRequest, ) -> List[Union[SamplerOutput, PoolerOutput]]: output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req, ) + )(execute_model_req=execute_model_req) return output diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 08a35a074b37b..b51659ac546bd 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -218,7 +218,7 @@ def __init__(self, *args, **kwargs): async def _driver_execute_model_async( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> List[SamplerOutput]: if not self.tp_driver_workers: return await self.driver_exec_model(execute_model_req) diff --git a/vllm/sequence.py b/vllm/sequence.py index 206da192193dc..d675bef3e9e43 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,8 +5,8 @@ from array import array from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, 
Mapping, Optional, Set, - Tuple, Union, cast) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, + Optional, Set, Tuple, Union, cast) import msgspec import torch @@ -474,14 +474,15 @@ def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id( - self, - token_id: int, - logprobs: Dict[int, Logprob], - ) -> None: + def append_token_id(self, + token_id: int, + logprobs: Dict[int, Logprob], + update_seq_data: bool = True) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob) + # Only do this when output proc callback is not used + if update_seq_data: + self.data.append_token_id(token_id, logprobs[token_id].logprob) def get_len(self) -> int: return self.data.get_len() @@ -1242,6 +1243,9 @@ class ExecuteModelRequest( finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. last_sampled_token_ids: Optional[torch.Tensor] = None + # Async postprocessor + callback_fn: Optional[Callable] = None + use_async_and_multi_step: bool = False @property def is_first_multi_step(self) -> bool: @@ -1287,4 +1291,6 @@ def clone( num_steps=self.num_steps, finished_requests_ids=self.finished_requests_ids, last_sampled_token_ids=self.last_sampled_token_ids.clone() - if self.last_sampled_token_ids is not None else None) + if self.last_sampled_token_ids is not None else None, + callback_fn=self.callback_fn, + use_async_and_multi_step=self.use_async_and_multi_step) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9f27c734efd1e..62de414026ad4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -5,8 +5,8 @@ import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, - TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, TypeVar, Union) import numpy as np import torch @@ -100,7 +100,9 @@ class ModelInputForGPU(ModelRunnerInputBase): request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 - + callback_fn: Optional[Callable] = None + use_async_and_multi_step: bool = False + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, @@ -1419,7 +1421,7 @@ def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None, ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -1459,6 +1461,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + callback_fn = None ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1572,6 +1575,14 @@ def execute_model( if not self.is_driver_worker: return [] + if callback_fn is not None: + # Async + multi-step + callback_fn() + else: + # Only async + if (model_input.callback_fn is not None): + model_input.callback_fn(is_async=True) + # Sample the next token. 
output: SamplerOutput = self.model.sample( logits=logits, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 90c39407d7266..64c209cbbdcd3 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -196,6 +196,7 @@ def execute_model( kv_caches: Optional[List[torch.Tensor]], intermediate_tensors: Optional[IntermediateTensors], num_steps: int = 1, + callback_fn = None ) -> Optional[List[SamplerOutput]]: """ Execute the model on the given input. diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 521205eca05af..7f604e508363e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -215,6 +215,19 @@ def prepare_model_input( ) return model_input + def async_process_outputs(self): + process_model_outputs_fn = self._cur_model_input.frozen_model_input.callback_fn + assert process_model_outputs_fn is not None + + for model_output in self._cur_model_input.cached_outputs: + if not model_output.pythonized: + model_output.maybe_pythonize(self._cur_model_input, + self._copy_stream, + self.pinned_sampled_token_ids) + if model_output.pythonized: + process_model_outputs_fn( + is_async=False, sampler_output=model_output.sampler_output) + @torch.inference_mode() def execute_model( self, @@ -272,10 +285,15 @@ def execute_model( model_input, model_input.cached_outputs[-1].sampler_output) # Execute the model - output = self._base_model_runner.execute_model(frozen_model_input, - kv_caches, - intermediate_tensors, - num_steps=1) + self._cur_model_input = model_input + + output = self._base_model_runner.execute_model( + frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1, + callback_fn=self.async_process_outputs + if frozen_model_input.use_async_and_multi_step else None) # record the event for the current step so that the next step can sync model_input.record_step_event(current_stream) @@ -301,9 +319,11 @@ def execute_model( output[0].logprobs = None # Pythonize the output if CPU is ahead and the previous step is # ready. 
- for model_output in model_input.cached_outputs: - model_output.maybe_pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + if not frozen_model_input.use_async_and_multi_step: + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, + self._copy_stream, + self.pinned_sampled_token_ids) model_input.current_step += 1 @@ -317,9 +337,23 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: outputs = [] - for output in model_input.cached_outputs: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + for output_id in range(len(model_input.cached_outputs)): + is_last_output = output_id == len( + model_input.cached_outputs) - 1 + + output = model_input.cached_outputs[output_id] + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + if model_input.frozen_model_input.use_async_and_multi_step: + process_model_outputs_fn = model_input.frozen_model_input.callback_fn + assert process_model_outputs_fn is not None + process_model_outputs_fn( + is_async=False, + sampler_output=output.sampler_output, + is_last_output=is_last_output) + outputs.append(output.sampler_output) return outputs diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 6a6caba9371eb..4a90d60dda6c1 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -1,3 +1,4 @@ +import dataclasses from dataclasses import dataclass from typing import List, Optional, Tuple @@ -59,6 +60,13 @@ def _get_driver_input_and_broadcast( execute_model_req.seq_group_metadata_list, execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + + if execute_model_req.callback_fn: + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + callback_fn=execute_model_req.callback_fn, + use_async_and_multi_step=execute_model_req. + use_async_and_multi_step) else: # on subsequent steps we reuse the worker input and model input multi_step_state = self.multi_step_states[virtual_engine] diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 9fddc863548eb..4720cd9bb3b30 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -256,6 +256,11 @@ def _get_driver_input_and_broadcast( broadcast_data.update(model_input.as_broadcastable_tensor_dict()) broadcast_tensor_dict(broadcast_data, src=0) + if execute_model_req.callback_fn: + model_input = dataclasses.replace( # type: ignore + model_input, + callback_fn=execute_model_req.callback_fn) + return model_input, worker_input def prepare_input( @@ -281,7 +286,7 @@ def prepare_input( def execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" @@ -315,6 +320,7 @@ def execute_model( model_input, self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, intermediate_tensors, num_steps) + model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: # output is IntermediateTensors
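The worker and model-runner changes above thread the engine callback from ExecuteModelRequest onto the (frozen) model input and invoke it just before sampling. A condensed sketch of that hand-off with hypothetical stand-in classes (Request and ModelInput are not vLLM types):

import dataclasses
from typing import Callable, Optional


@dataclasses.dataclass(frozen=True)
class Request:                      # stands in for ExecuteModelRequest
    prompt: str
    callback_fn: Optional[Callable[..., None]] = None


@dataclasses.dataclass(frozen=True)
class ModelInput:                   # stands in for ModelInputForGPU
    prompt: str
    callback_fn: Optional[Callable[..., None]] = None


def prepare_input(req: Request) -> ModelInput:
    model_input = ModelInput(prompt=req.prompt)
    if req.callback_fn is not None:
        # Same pattern as worker_base: copy the callback onto the frozen
        # model input with dataclasses.replace.
        model_input = dataclasses.replace(model_input,
                                          callback_fn=req.callback_fn)
    return model_input


def execute_model(model_input: ModelInput) -> str:
    logits = f"logits({model_input.prompt})"  # stand-in for the forward pass
    if model_input.callback_fn is not None:
        # Finish the previous step's output processing before sampling
        # this step, overlapping CPU postprocessing with GPU work.
        model_input.callback_fn(is_async=True)
    return f"sample({logits})"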