Skip to content

Commit

Permalink
[Core] Async_output_proc: Add virtual engine support (towards pipelin…
Browse files Browse the repository at this point in the history
…e parallel) (vllm-project#7911)
  • Loading branch information
alexm-neuralmagic authored Aug 28, 2024
1 parent 851743c commit 0a88c1d
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 67 deletions.
11 changes: 5 additions & 6 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def __init__(
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
output_proc_callback_fn: Optional[Callable] = None,
output_proc_callback: Optional[Callable] = None,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
Expand Down Expand Up @@ -376,8 +376,8 @@ def __init__(
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
self.output_proc_callback_fn = output_proc_callback_fn
self.use_async_output_proc = self.output_proc_callback_fn is not None
self.output_proc_callback = output_proc_callback
self.use_async_output_proc = self.output_proc_callback is not None
self.num_cache_iters = 2 if self.use_async_output_proc else 1

self.cache_id = 0
Expand Down Expand Up @@ -573,8 +573,8 @@ def _schedule_running(
seq_group):
tmp = self.running
self.running = orig_running
assert self.output_proc_callback_fn is not None
self.output_proc_callback_fn(is_async=True)
assert self.output_proc_callback is not None
self.output_proc_callback()
self.running = tmp

while not self._can_append_slots(seq_group):
Expand Down Expand Up @@ -1091,7 +1091,6 @@ def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
no_beam_search = seq_group.sampling_params is None or (
seq_group.sampling_params.best_of == 1
and not seq_group.sampling_params.use_beam_search)

return no_beam_search

def schedule(
Expand Down
37 changes: 27 additions & 10 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,26 @@ async def step_async(
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

ctx = self.scheduler_contexts[virtual_engine]

# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):

# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()

(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# before moving forward
if not allow_async_output_proc and len(self.output_queue) > 0:
self._process_model_outputs(is_async=True)
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
Expand Down Expand Up @@ -332,8 +339,8 @@ async def step_async(
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
execute_model_req.output_proc_callback_fn = \
self._process_model_outputs
execute_model_req.async_callback = self.async_callback[
virtual_engine]

# Execute the model.
output = await self.model_executor.execute_model_async(
Expand All @@ -343,9 +350,10 @@ async def step_async(
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(self.output_queue) > 0:
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(is_async=True)
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
output = []

# Finish the current step for all the sequence groups.
Expand All @@ -360,7 +368,7 @@ async def step_async(
virtual_engine] = SchedulerOutputState()

# Cache results in engine
self.output_queue.append(
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))

if output and allow_async_output_proc:
Expand All @@ -372,7 +380,8 @@ async def step_async(
scheduler_outputs.scheduled_seq_groups)

if not allow_async_output_proc:
self._process_model_outputs(is_async=False)
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=False)

# Log stats.
self.do_log_stats(scheduler_outputs, output)
Expand All @@ -381,9 +390,17 @@ async def step_async(
self.do_tracing(scheduler_outputs)

else:
self.request_outputs = []
ctx.request_outputs = []

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
assert len(ctx.output_queue) == 0

return self.request_outputs
return ctx.request_outputs

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
Expand Down
121 changes: 79 additions & 42 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import functools
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
Expand Down Expand Up @@ -88,6 +89,17 @@ class SchedulerOutputState:
last_output: Optional[SamplerOutput] = None


@dataclass
class SchedulerContext:
output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
SchedulerOutputs]] = field(
default_factory=lambda: deque())

request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = field(
default_factory=lambda: [])


class LLMEngine:
"""An LLM engine that receives requests and generates texts.
Expand Down Expand Up @@ -350,9 +362,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
Scheduler(
scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size,
self._process_model_outputs
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=True)
if model_config.use_async_output_proc else None)
for _ in range(parallel_config.pipeline_parallel_size)
for v_id in range(parallel_config.pipeline_parallel_size)
]

# Metric Logging.
Expand Down Expand Up @@ -406,12 +420,17 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
for _ in range(self.parallel_config.pipeline_parallel_size)
]

# Async output processing pointers
self.output_queue: Deque[Tuple[List[SamplerOutput],
List[SequenceGroupMetadata],
SchedulerOutputs]] = deque()
self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
self.scheduler_contexts = [
SchedulerContext()
for _ in range(self.parallel_config.pipeline_parallel_size)
]

self.async_callback = [
functools.partial(self._process_model_outputs,
virtual_engine=v_id,
is_async=True)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
Expand Down Expand Up @@ -1221,32 +1240,28 @@ def _process_sequence_group_outputs(

return

def _process_model_outputs(self,
is_async: bool,
clear_outputs: bool = True) -> None:
def _process_model_outputs(self, virtual_engine: int,
is_async: bool) -> None:
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
clear_outputs: Sometimes existing outputs need to be combined
with outputs of this call. This happens for postprocessor
draining at the final stage (like when sequences are finished)
Returns RequestOutputs that can be returned to the client.
"""
now = time.time()

if clear_outputs:
self.request_outputs.clear()
ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

if len(self.output_queue) == 0:
if len(ctx.output_queue) == 0:
return None

(outputs, seq_group_metadata_list,
scheduler_outputs) = self.output_queue.popleft()
scheduler_outputs) = ctx.output_queue.popleft()

# Sanity check
assert len(seq_group_metadata_list) == len(
Expand Down Expand Up @@ -1321,11 +1336,11 @@ def _process_model_outputs(self,
if (seq_group.is_finished()
if self.step_return_finished_only else True):
request_output = RequestOutputFactory.create(seq_group)
self.request_outputs.append(request_output)
ctx.request_outputs.append(request_output)

for seq_group in scheduler_outputs.ignored_seq_groups:
request_output = RequestOutputFactory.create(seq_group)
self.request_outputs.append(request_output)
ctx.request_outputs.append(request_output)

if is_async:
# Log stats.
Expand Down Expand Up @@ -1421,29 +1436,43 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")

# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0
virtual_engine = 0

# These are cached outputs from previous iterations. None if on first
# iteration
cached_outputs = self.cached_scheduler_outputs[0]
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

ctx = self.scheduler_contexts[virtual_engine]

# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):

# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()

# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc) = self.scheduler[0].schedule()
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

if not allow_async_output_proc and len(self.output_queue) > 0:
self._process_model_outputs(is_async=True)
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
0, seq_group_metadata_list, scheduler_outputs,
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)

assert seq_group_metadata_list is not None
Expand All @@ -1454,14 +1483,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
0].get_and_reset_finished_requests_ids()
virtual_engine].get_and_reset_finished_requests_ids()

# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = \
self._get_last_sampled_token_ids(0)
self._get_last_sampled_token_ids(virtual_engine)

execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
Expand All @@ -1476,20 +1505,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
execute_model_req.output_proc_callback_fn = \
self._process_model_outputs
execute_model_req.async_callback = self.async_callback[
virtual_engine]

output = self.model_executor.execute_model(
execute_model_req=execute_model_req)

# we need to do this here so that last step's sampled_token_ids can
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(0, output)
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(self.output_queue) > 0:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(is_async=True)
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
# No outputs in this case
output = []

# Finish the current step for all the sequence groups.
Expand All @@ -1504,7 +1537,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

# Add results to the output_queue
# (for async or non-async postprocessing)
self.output_queue.append(
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))

if output and allow_async_output_proc:
Expand All @@ -1515,23 +1548,27 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

# Check if need to run the usual non-async path
if not allow_async_output_proc:
self._process_model_outputs(is_async=False)
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=False)

# Log stats.
self.do_log_stats(scheduler_outputs, output)

# Tracing
self.do_tracing(scheduler_outputs)
else:
self.request_outputs = []
# Multi-step case
ctx.request_outputs = []

if not self.has_unfinished_requests():
# Drain async postprocessor
if len(self.output_queue) > 0:
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(is_async=True, clear_outputs=False)
assert len(self.output_queue) == 0
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
assert len(ctx.output_queue) == 0

# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
Expand All @@ -1540,7 +1577,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# queued control plane messages, such as add/remove lora adapters.
self.model_executor.stop_remote_worker_execution_loop()

return self.request_outputs
return ctx.request_outputs

def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
Expand Down
Loading

0 comments on commit 0a88c1d

Please sign in to comment.