From 9e8f61e824076878f67bc55ee5da025c8fb8d3f3 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Tue, 27 Aug 2024 13:26:26 +0000 Subject: [PATCH 1/6] async_output_proc: Add virtual engine support --- vllm/engine/llm_engine.py | 115 +++++++++++++++++++++++------------- vllm/sequence.py | 15 ++++- vllm/worker/model_runner.py | 12 ++-- vllm/worker/worker_base.py | 5 +- 4 files changed, 96 insertions(+), 51 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7356c1abbfa88..46db1f4aa3a23 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,7 +1,7 @@ import time from collections import deque from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, Mapping, Optional) from typing import Sequence as GenericSequence @@ -40,7 +40,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) + SequenceGroupMetadata, SequenceStatus, + AsyncCallbackData) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -88,6 +89,19 @@ class SchedulerOutputState: last_output: Optional[SamplerOutput] = None +@dataclass +class SchedulerContext: + output_queue: Deque[Tuple[List[SamplerOutput], + List[Tuple[ScheduledSequenceGroup, + SequenceGroupMetadata]], + SchedulerOutputs]] = field( + default_factory=lambda: deque()) + + request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = field( + default_factory=lambda: []) + + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -406,12 +420,17 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for _ in range(self.parallel_config.pipeline_parallel_size) ] - # Async output processing pointers - self.output_queue: Deque[Tuple[List[SamplerOutput], - List[SequenceGroupMetadata], - SchedulerOutputs]] = deque() - self.request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] + self.scheduler_contexts = [ + SchedulerContext() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.async_callback_data = [ + AsyncCallbackData(self._process_model_outputs, { + "virtual_engine": v_id, + "is_async": True, + }) for v_id in range(self.parallel_config.pipeline_parallel_size) + ] def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1214,32 +1233,28 @@ def _process_sequence_group_outputs( return - def _process_model_outputs(self, - is_async: bool, - clear_outputs: bool = True) -> None: + def _process_model_outputs(self, virtual_engine: int, + is_async: bool) -> None: """Apply the model output to the sequences in the scheduled seq groups. + virtual_engine: The engine id to operate on is_async: Indicates whether this postprocessor runs in parallel with the GPU forward pass and is processing tokens from the previous step. If this is true, then no tokens need to be appended since it is already done externally (before the next schedule() call) - clear_outputs: Sometimes existing outputs need to be combined - with outputs of this call. This happens for postprocessor - draining at the final stage (like when sequences are finished) Returns RequestOutputs that can be returned to the client. 
""" now = time.time() - if clear_outputs: - self.request_outputs.clear() + ctx: SchedulerContext = self.scheduler_contexts[virtual_engine] - if len(self.output_queue) == 0: + if len(ctx.output_queue) == 0: return None (outputs, seq_group_metadata_list, - scheduler_outputs) = self.output_queue.popleft() + scheduler_outputs) = ctx.output_queue.popleft() # Sanity check assert len(seq_group_metadata_list) == len( @@ -1314,11 +1329,11 @@ def _process_model_outputs(self, if (seq_group.is_finished() if self.step_return_finished_only else True): request_output = RequestOutputFactory.create(seq_group) - self.request_outputs.append(request_output) + ctx.request_outputs.append(request_output) for seq_group in scheduler_outputs.ignored_seq_groups: request_output = RequestOutputFactory.create(seq_group) - self.request_outputs.append(request_output) + ctx.request_outputs.append(request_output) if is_async: # Log stats. @@ -1414,29 +1429,41 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: "Pipeline parallelism is only supported through AsyncLLMEngine " "as performance will be severely degraded otherwise.") + virtual_engine = 0 + # These are cached outputs from previous iterations. None if on first # iteration - cached_outputs = self.cached_scheduler_outputs[0] + cached_outputs = self.cached_scheduler_outputs[virtual_engine] seq_group_metadata_list = cached_outputs.seq_group_metadata_list scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + ctx = self.scheduler_contexts[virtual_engine] + # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): + + # Clear outputs on scheduler iteration start + ctx.request_outputs.clear() + + # Schedule iteration (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc) = self.scheduler[0].schedule() + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() - if not allow_async_output_proc and len(self.output_queue) > 0: - self._process_model_outputs(is_async=True) + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=True) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have # lookahead slots self._cache_scheduler_outputs_for_multi_step( - 0, seq_group_metadata_list, scheduler_outputs, + virtual_engine, seq_group_metadata_list, scheduler_outputs, allow_async_output_proc) assert seq_group_metadata_list is not None @@ -1447,14 +1474,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ - 0].get_and_reset_finished_requests_ids() + virtual_engine].get_and_reset_finished_requests_ids() # Check if we have a cached last_output from the previous iteration. # For supporting PP this is probably the best way to pass the # sampled_token_ids, as a separate broadcast over all the PP stages # will cause one virtual engine's microbatch to block the pipeline. 
last_sampled_token_ids = \ - self._get_last_sampled_token_ids(0) + self._get_last_sampled_token_ids(virtual_engine) execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, @@ -1469,20 +1496,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.output_proc_callback_fn = \ - self._process_model_outputs + execute_model_req.async_callback = self.async_callback_data[ + virtual_engine] output = self.model_executor.execute_model( execute_model_req=execute_model_req) - # we need to do this here so that last step's sampled_token_ids can + # We need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(0, output) + self._update_cached_scheduler_output(virtual_engine, output) else: - if len(self.output_queue) > 0: + # Nothing scheduled => If there is pending async postprocessor, + # then finish it here. + if len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step - self._process_model_outputs(is_async=True) + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=True) + # No outputs in this case output = [] # Finish the current step for all the sequence groups. @@ -1497,7 +1528,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Add results to the output_queue # (for async or non-async postprocessing) - self.output_queue.append( + ctx.output_queue.append( (output, seq_group_metadata_list, scheduler_outputs)) if output and allow_async_output_proc: @@ -1508,8 +1539,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: output[0], seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups) + # Check if need to run the usual non-async path if not allow_async_output_proc: - self._process_model_outputs(is_async=False) + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=False) # Log stats. self.do_log_stats(scheduler_outputs, output) @@ -1517,14 +1550,16 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Tracing self.do_tracing(scheduler_outputs) else: + # Multi-step case self.request_outputs = [] if not self.has_unfinished_requests(): - # Drain async postprocessor - if len(self.output_queue) > 0: + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step - self._process_model_outputs(is_async=True, clear_outputs=False) - assert len(self.output_queue) == 0 + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=True) + assert len(ctx.output_queue) == 0 # Stop the execute model loop in parallel workers until there are # more requests to process. This avoids waiting indefinitely in @@ -1533,7 +1568,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # queued control plane messages, such as add/remove lora adapters. 
self.model_executor.stop_remote_worker_execution_loop() - return self.request_outputs + return ctx.request_outputs def _has_remaining_steps( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] diff --git a/vllm/sequence.py b/vllm/sequence.py index 964072dd7c8f1..6cc5a8b33e27c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -811,6 +811,9 @@ def remove(self, seq_id: int) -> None: self.is_single_seq = len(self.seqs) == 1 def is_finished(self) -> bool: + if self.is_single_seq: + return self.seqs[0].is_finished() + return all(seq.is_finished() for seq in self.seqs) def is_prefill(self) -> bool: @@ -1259,6 +1262,12 @@ def expand_with_bonus_tokens( [self.hidden_states, self.second_last_token_hidden_states])[index] +@dataclass +class AsyncCallbackData: + func: Callable + kw_args: Dict[str, Any] + + class ExecuteModelRequest( msgspec.Struct, array_like=True, # type: ignore[call-arg] @@ -1290,8 +1299,8 @@ class ExecuteModelRequest( finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. last_sampled_token_ids: Optional[torch.Tensor] = None - # Async postprocessor - output_proc_callback_fn: Optional[Callable] = None + # Async callback + async_callback: Optional[AsyncCallbackData] = None @property def is_first_multi_step(self) -> bool: @@ -1338,4 +1347,4 @@ def clone( finished_requests_ids=self.finished_requests_ids, last_sampled_token_ids=self.last_sampled_token_ids.clone() if self.last_sampled_token_ids is not None else None, - output_proc_callback_fn=self.output_proc_callback_fn) + async_callback=self.async_callback) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a81b892992237..fbfc911b7cb0c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -42,7 +42,7 @@ LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) + SequenceGroupMetadata, AsyncCallbackData) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( @@ -90,7 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase): request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 - output_proc_callback_fn: Optional[Callable] = None + async_callback: Optional[AsyncCallbackData] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -1456,9 +1456,11 @@ def execute_model( if not self.is_driver_worker: return [] - if model_input.output_proc_callback_fn is not None: - model_input.output_proc_callback_fn(is_async=True) - + if model_input.async_callback is not None: + func = model_input.async_callback.func + kw_args = model_input.async_callback.kw_args + func(**kw_args) + # Sample the next token. output: SamplerOutput = self.model.sample( logits=logits, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e35d5c962a489..012043673b094 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -263,11 +263,10 @@ def _get_driver_input_and_broadcast( broadcast_data.update(kwargs) broadcast_tensor_dict(broadcast_data, src=0) - if execute_model_req.output_proc_callback_fn: + if execute_model_req.async_callback: model_input = dataclasses.replace( # type: ignore model_input, - output_proc_callback_fn=execute_model_req. 
- output_proc_callback_fn) + async_callback=execute_model_req.async_callback) return model_input, worker_input, kwargs From 5022a225556a0e9aacdd4203b4ae8f49332a96cb Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Tue, 27 Aug 2024 18:16:08 +0000 Subject: [PATCH 2/6] fixes --- vllm/engine/async_llm_engine.py | 37 ++++++++++++++++++++++++--------- vllm/engine/llm_engine.py | 14 ++++++------- vllm/worker/model_runner.py | 10 ++++----- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3445b7084bbcd..544b2adde5186 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -279,10 +279,16 @@ async def step_async( scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + ctx = self.scheduler_contexts[virtual_engine] + # skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): + + # Clear outputs on scheduler iteration start + ctx.request_outputs.clear() + (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() @@ -290,8 +296,9 @@ async def step_async( # If current scheduler iteration has no async postprocessor, # then we need first to drain the pending async postprocessor # before moving forward - if not allow_async_output_proc and len(self.output_queue) > 0: - self._process_model_outputs(is_async=True) + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=True) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): @@ -332,8 +339,8 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.output_proc_callback_fn = \ - self._process_model_outputs + execute_model_req.async_callback = self.async_callback_data[ + virtual_engine] # Execute the model. output = await self.model_executor.execute_model_async( @@ -343,9 +350,10 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: - if len(self.output_queue) > 0: + if len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step - self._process_model_outputs(is_async=True) + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=True) output = [] # Finish the current step for all the sequence groups. @@ -360,7 +368,7 @@ async def step_async( virtual_engine] = SchedulerOutputState() # Cache results in engine - self.output_queue.append( + ctx.output_queue.append( (output, seq_group_metadata_list, scheduler_outputs)) if output and allow_async_output_proc: @@ -372,7 +380,8 @@ async def step_async( scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: - self._process_model_outputs(is_async=False) + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=False) # Log stats. 
self.do_log_stats(scheduler_outputs, output) @@ -381,9 +390,17 @@ async def step_async( self.do_tracing(scheduler_outputs) else: - self.request_outputs = [] + ctx.request_outputs = [] + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + assert not self.scheduler_config.is_multi_step + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=True) + assert len(ctx.output_queue) == 0 - return self.request_outputs + return ctx.request_outputs async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 46db1f4aa3a23..4d8c037c21920 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -38,10 +38,10 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceStatus, - AsyncCallbackData) +from vllm.sequence import (AsyncCallbackData, EmbeddingSequenceGroupOutput, + ExecuteModelRequest, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -91,9 +91,7 @@ class SchedulerOutputState: @dataclass class SchedulerContext: - output_queue: Deque[Tuple[List[SamplerOutput], - List[Tuple[ScheduledSequenceGroup, - SequenceGroupMetadata]], + output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata], SchedulerOutputs]] = field( default_factory=lambda: deque()) @@ -1551,7 +1549,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.do_tracing(scheduler_outputs) else: # Multi-step case - self.request_outputs = [] + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fbfc911b7cb0c..1cd8746d050b1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,8 +6,8 @@ import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, - Tuple, Type, TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -41,8 +41,8 @@ from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata, AsyncCallbackData) +from vllm.sequence import (AsyncCallbackData, IntermediateTensors, + SamplerOutput, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( @@ -1460,7 +1460,7 @@ def execute_model( func = model_input.async_callback.func kw_args = model_input.async_callback.kw_args func(**kw_args) - + # Sample the next token. 
output: SamplerOutput = self.model.sample( logits=logits, From 0ebc1351005d12d95b28603ba8273812acb5d88b Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Tue, 27 Aug 2024 18:19:25 +0000 Subject: [PATCH 3/6] Cody's comments --- vllm/engine/llm_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4d8c037c21920..ef01fb02721a2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1427,6 +1427,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: "Pipeline parallelism is only supported through AsyncLLMEngine " "as performance will be severely degraded otherwise.") + # For llm_engine, there is no pipeline parallel support, so the engine + # used is always 0 virtual_engine = 0 # These are cached outputs from previous iterations. None if on first From dc8395cb495a12da1b4a7de035807ce4319325eb Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 28 Aug 2024 02:24:13 +0000 Subject: [PATCH 4/6] Nick's comments --- vllm/core/scheduler.py | 10 +++++----- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 26 ++++++++++++++------------ vllm/sequence.py | 8 +------- vllm/worker/model_runner.py | 14 ++++++-------- 5 files changed, 27 insertions(+), 33 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 280d7b7e61e2c..0ce36142600b3 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -300,7 +300,7 @@ def __init__( cache_config: CacheConfig, lora_config: Optional[LoRAConfig], pipeline_parallel_size: int = 1, - output_proc_callback_fn: Optional[Callable] = None, + output_proc_callback: Optional[Callable] = None, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -374,8 +374,8 @@ def __init__( # iterations. I.e. since the output processing is lagged one step, # we cannot reuse the cached objects immediately when the schedule() # is called again, but only when schedule() is called the second time. - self.output_proc_callback_fn = output_proc_callback_fn - self.use_async_output_proc = self.output_proc_callback_fn is not None + self.output_proc_callback = output_proc_callback + self.use_async_output_proc = self.output_proc_callback is not None self.num_cache_iters = 2 if self.use_async_output_proc else 1 self.cache_id = 0 @@ -571,8 +571,8 @@ def _schedule_running( seq_group): tmp = self.running self.running = orig_running - assert self.output_proc_callback_fn is not None - self.output_proc_callback_fn(is_async=True) + assert self.output_proc_callback is not None + self.output_proc_callback() self.running = tmp while not self._can_append_slots(seq_group): diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 544b2adde5186..76583602c5115 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -339,7 +339,7 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback_data[ + execute_model_req.async_callback = self.async_callback[ virtual_engine] # Execute the model. 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ef01fb02721a2..6c6201adeebd7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,3 +1,4 @@ +import functools import time from collections import deque from contextlib import contextmanager @@ -38,10 +39,9 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import (AsyncCallbackData, EmbeddingSequenceGroupOutput, - ExecuteModelRequest, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata, - SequenceStatus) +from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, + SamplerOutput, Sequence, SequenceGroup, + SequenceGroupMetadata, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -362,9 +362,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: Scheduler( scheduler_config, cache_config, lora_config, parallel_config.pipeline_parallel_size, - self._process_model_outputs + functools.partial(self._process_model_outputs, + virtual_engine=v_id, + is_async=True) if model_config.use_async_output_proc else None) - for _ in range(parallel_config.pipeline_parallel_size) + for v_id in range(parallel_config.pipeline_parallel_size) ] # Metric Logging. @@ -423,11 +425,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for _ in range(self.parallel_config.pipeline_parallel_size) ] - self.async_callback_data = [ - AsyncCallbackData(self._process_model_outputs, { - "virtual_engine": v_id, - "is_async": True, - }) for v_id in range(self.parallel_config.pipeline_parallel_size) + self.async_callback = [ + functools.partial(self._process_model_outputs, + virtual_engine=v_id, + is_async=True) + for v_id in range(self.parallel_config.pipeline_parallel_size) ] def _initialize_kv_caches(self) -> None: @@ -1496,7 +1498,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback_data[ + execute_model_req.async_callback = self.async_callback[ virtual_engine] output = self.model_executor.execute_model( diff --git a/vllm/sequence.py b/vllm/sequence.py index 6cc5a8b33e27c..533a8a3a0c82d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1262,12 +1262,6 @@ def expand_with_bonus_tokens( [self.hidden_states, self.second_last_token_hidden_states])[index] -@dataclass -class AsyncCallbackData: - func: Callable - kw_args: Dict[str, Any] - - class ExecuteModelRequest( msgspec.Struct, array_like=True, # type: ignore[call-arg] @@ -1300,7 +1294,7 @@ class ExecuteModelRequest( # The last sampled token ids for multi step decoding. 
last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback - async_callback: Optional[AsyncCallbackData] = None + async_callback: Optional[Callable] = None @property def is_first_multi_step(self) -> bool: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1cd8746d050b1..cf785aed6c44e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,8 +6,8 @@ import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, - TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, TypeVar, Union) import numpy as np import torch @@ -41,8 +41,8 @@ from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import (AsyncCallbackData, IntermediateTensors, - SamplerOutput, SequenceGroupMetadata) +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( @@ -90,7 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase): request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 - async_callback: Optional[AsyncCallbackData] = None + async_callback: Optional[Callable] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -1457,9 +1457,7 @@ def execute_model( return [] if model_input.async_callback is not None: - func = model_input.async_callback.func - kw_args = model_input.async_callback.kw_args - func(**kw_args) + model_input.async_callback() # Sample the next token. output: SamplerOutput = self.model.sample( From bef70d45986eecf396b89ed79d41f502afdcb1bf Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Wed, 28 Aug 2024 02:26:37 +0000 Subject: [PATCH 5/6] format --- vllm/core/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ce36142600b3..64557eddba450 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1086,6 +1086,7 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: ) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: + assert seq_group.sampling_params is not None no_beam_search = (seq_group.sampling_params.best_of == 1 and not seq_group.sampling_params.use_beam_search) From 32e8d6f01a9897f28a2e6e7148e42cf1822f3197 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 28 Aug 2024 00:01:49 -0700 Subject: [PATCH 6/6] Fix --- vllm/core/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 3eaad0577b51d..fbc53afa38f67 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1088,7 +1088,6 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: ) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: - assert seq_group.sampling_params is not None no_beam_search = seq_group.sampling_params is None or ( seq_group.sampling_params.best_of == 1 and not seq_group.sampling_params.use_beam_search)
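
Taken together, the series replaces the engine's single global `output_queue`/`request_outputs` pair with one `SchedulerContext` per virtual (pipeline-parallel) engine, and, after patch 4, passes the postprocessing callback as a plain `functools.partial` bound to a `virtual_engine` id instead of the earlier `AsyncCallbackData` struct. Below is a minimal, self-contained sketch of that flow, not vLLM code: `MiniEngine`, `model_runner_step`, and the toy payload strings are illustrative stand-ins for `LLMEngine`, `ModelRunner.execute_model`, and real `SamplerOutput` objects.

```python
# Illustrative sketch only -- simplified stand-ins for vLLM's LLMEngine,
# SchedulerContext, and model runner. MiniEngine and model_runner_step are
# hypothetical names, not part of the patch.
import functools
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Deque, List, Optional


@dataclass
class SchedulerContext:
    # Outputs from earlier steps waiting to be postprocessed.
    output_queue: Deque[Any] = field(default_factory=deque)
    # Finished request outputs accumulated for the current iteration.
    request_outputs: List[Any] = field(default_factory=list)


class MiniEngine:
    """Toy engine with one async callback per virtual (pipeline) engine."""

    def __init__(self, pipeline_parallel_size: int = 1) -> None:
        self.scheduler_contexts = [
            SchedulerContext() for _ in range(pipeline_parallel_size)
        ]
        # As in patch 4: a functools.partial per virtual engine, each bound
        # to its own context via the virtual_engine keyword.
        self.async_callback = [
            functools.partial(self._process_model_outputs,
                              virtual_engine=v_id,
                              is_async=True)
            for v_id in range(pipeline_parallel_size)
        ]

    def _process_model_outputs(self, virtual_engine: int,
                               is_async: bool) -> None:
        # Drain only this virtual engine's queue into its request_outputs.
        ctx = self.scheduler_contexts[virtual_engine]
        while ctx.output_queue:
            item = ctx.output_queue.popleft()
            ctx.request_outputs.append(f"processed:{item} (async={is_async})")


def model_runner_step(async_callback: Optional[Callable]) -> None:
    # Mirrors the model runner hook: the previous step's outputs are
    # postprocessed on the CPU here, overlapping with GPU work, before the
    # next token is sampled.
    if async_callback is not None:
        async_callback()
    # ... sample the next token afterwards ...


engine = MiniEngine(pipeline_parallel_size=2)
engine.scheduler_contexts[1].output_queue.append("step-0-output")
model_runner_step(engine.async_callback[1])
print(engine.scheduler_contexts[1].request_outputs)
```

The design point the sketch tries to capture: because each callback is bound to a specific `virtual_engine`, every pipeline stage drains only its own queue and fills only its own `request_outputs`, so invoking the callback between the forward pass and sampling overlaps CPU postprocessing of step N-1 with GPU work for step N without the stages blocking one another.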