feat: multi-step scheduling (#831)
* add broadcastable model input base

* add multistep model runner

* add multistep worker

* broadcastable model input in worker base

* switch to cpu from numpy for sampled token ids

* async engine impl

* patch gpu executors

* add tests

* remove kv cache estimation

* add to benchmark

* formatting
AlpinDale authored Nov 22, 2024
Parent: 2242cb2 · Commit: 48a8693
Showing 14 changed files with 1,007 additions and 48 deletions.
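The per-file diffs below wire the feature end to end: a num_scheduler_steps engine argument validated in args_tools.py, a per-virtual-engine scheduler-output cache in the async engine, a MultiStepWorker selected by the GPU executors, and a CPU-side copy of sampled token ids for pipeline parallelism. As a rough usage sketch only, assuming the public AsyncEngineArgs and AsyncAphrodite.from_engine_args entry points mirror the synchronous ones (neither appears in this excerpt), enabling multi-step scheduling might look like:

# Hypothetical usage sketch; only num_scheduler_steps itself is introduced by
# this commit, and the entry-point names below are assumptions.
from aphrodite import AsyncAphrodite, AsyncEngineArgs  # assumed import path

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",
    # Run up to 8 decode steps per scheduler invocation.
    num_scheduler_steps=8,
    # args_tools.py now forces this on whenever num_scheduler_steps > 1.
    use_v2_block_manager=True,
)
engine = AsyncAphrodite.from_engine_args(engine_args)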
10 changes: 5 additions & 5 deletions aphrodite/common/sequence.py
@@ -9,7 +9,6 @@
Union, cast)

import msgspec
import numpy
import torch

from aphrodite.common.pooling_params import PoolingParams
@@ -1106,7 +1105,10 @@ class SamplerOutput(

# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
sampled_token_ids_numpy: Optional[numpy.ndarray] = None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from last rank to AsyncAphrodite to be
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None

# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
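The new sampled_token_ids_cpu field replaces the earlier NumPy copy. A minimal sketch of how the last pipeline-parallel rank might populate it (the sampler-side change itself is in files not shown in this excerpt):

import torch

def attach_cpu_token_ids(sampler_output, sampled_token_ids: torch.Tensor):
    # Illustrative helper, not part of this diff: keep a host copy of the
    # sampled ids so the driver can hand them to the other PP ranks on the
    # next multi-step iteration. The real code may use a pinned buffer and a
    # non-blocking copy instead of a synchronous .cpu().
    sampler_output.sampled_token_ids_cpu = sampled_token_ids.cpu()
    return sampler_output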
@@ -1278,9 +1280,7 @@ def is_last_step(self) -> bool:
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
num_steps = first_seq_group.state.num_steps
current_step = first_seq_group.state.current_step
return num_steps - current_step == 1
return first_seq_group.state.remaining_steps == 1

@property
def current_step(self) -> int:
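is_last_step now reads a single remaining_steps property instead of recomputing the difference. A self-contained sketch of the intended bookkeeping, with field and method names taken from this diff but the class itself a stand-in for the real sequence-group state:

from dataclasses import dataclass

@dataclass
class _StepState:  # stand-in for the real sequence-group state
    num_steps: int = 1      # total steps scheduled for this batch
    current_step: int = 0   # steps already executed

    @property
    def remaining_steps(self) -> int:
        return self.num_steps - self.current_step

    def finish_step(self) -> None:
        self.current_step += 1

state = _StepState(num_steps=8)
while state.remaining_steps > 0:
    is_last_step = state.remaining_steps == 1  # the check is_last_step() makes
    # ... run one decode step here ...
    state.finish_step()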
7 changes: 6 additions & 1 deletion aphrodite/engine/args_tools.py
@@ -983,6 +983,12 @@ def create_engine_config(self, ) -> EngineConfig:
"profiling phase, or result in low performance due to small "
"KV cache space. Consider setting --max-model-len to a "
"smaller value.")

if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
self.use_v2_block_manager = True
logger.warning(
"Enabled BlockSpaceManagerV2 because it is "
"required for multi-step scheduling.")

speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config,
@@ -1012,7 +1018,6 @@ def create_engine_config(self, ) -> EngineConfig:
)

if self.num_scheduler_steps > 1:
raise NotImplementedError("Multi-step is not yet supported.")
if speculative_config is not None:
raise ValueError("Speculative decoding is not supported with "
"multi-step (--num-scheduler-steps > 1)")
126 changes: 118 additions & 8 deletions aphrodite/engine/async_aphrodite.py
@@ -1,10 +1,12 @@
import asyncio
import os
import time
from dataclasses import dataclass
from functools import partial
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Optional,
Set, Tuple, Type, Union)

import torch
from loguru import logger
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
@@ -15,7 +17,8 @@
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from aphrodite.engine.aphrodite_engine import (AphroditeEngine,
DecoderPromptComponents,
PromptComponents)
@@ -248,9 +251,25 @@ def has_new_requests(self):
return not self._new_requests.empty()


@dataclass
class SchedulerOutputState:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output: Optional[SamplerOutput] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None


class _AsyncAphrodite(AphroditeEngine):
"""Extension of AphroditeEngine to add async methods."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pipeline_parallel_size = \
self.parallel_config.pipeline_parallel_size
self.cached_scheduler_outputs = [
SchedulerOutputState() for _ in range(pipeline_parallel_size)
]

async def step_async(
self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
@@ -263,13 +282,35 @@ async def step_async(
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
# these are cached outputs from previous iterations. None if on first
# iteration
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

if not scheduler_outputs.is_empty():
# Execute the model.
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
@@ -279,20 +320,89 @@
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
)
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
output = []

request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
else:
request_outputs = []

# Log stats.
self.do_log_stats(scheduler_outputs, output)

return request_outputs

def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
if (not self.scheduler_config.is_multi_step
or not seq_group_metadata_list):
return False
# TODO: this is a sanity check for now to make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
if any([
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))
return ref_remaining_steps > 0

def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs) -> None:
self.cached_scheduler_outputs[
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
scheduler_outputs
self.cached_scheduler_outputs[virtual_engine].last_output = None
def _get_last_sampled_token_ids(
self, virtual_engine: int) -> Optional[torch.Tensor]:
cached_last_output = self.cached_scheduler_outputs[
virtual_engine].last_output
if (self.scheduler_config.is_multi_step
and self.parallel_config.pipeline_parallel_size > 1
and cached_last_output is not None
and cached_last_output.sampled_token_ids_cpu is not None):
return cached_last_output.sampled_token_ids_cpu
return None

def _update_cached_scheduler_output(
self, virtual_engine: int,
output: List[Optional[SamplerOutput]]) -> None:
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
and output[0] is not None):
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_cpu is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
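Taken together, the step_async changes amount to the per-virtual-engine loop below. This is only a condensed restatement of the control flow shown above, with trivial stand-ins for the scheduler and executor:

from dataclasses import dataclass
from typing import Any, Callable, List, Optional

@dataclass
class _Group:
    remaining_steps: int

@dataclass
class _Cache:  # mirrors SchedulerOutputState
    groups: Optional[List[_Group]] = None
    last_output: Optional[Any] = None

def step(cache: _Cache, schedule: Callable[[], List[_Group]],
         execute: Callable[[List[_Group], Optional[Any]], Any]) -> list:
    # 1. Only call the scheduler when the previous multi-step batch is done.
    if not (cache.groups and cache.groups[0].remaining_steps > 0):
        cache.groups = schedule()
    # 2. Run one model step, passing last step's sampled ids (needed for PP).
    output = execute(cache.groups, cache.last_output)
    cache.last_output = output
    # 3. Mark one step as finished for every group in the batch.
    for group in cache.groups:
        group.remaining_steps -= 1
    # 4. Produce request outputs, and clear the cache, only on the last step.
    if cache.groups[0].remaining_steps == 0:
        cache.groups = None
        return [output]
    return []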
12 changes: 8 additions & 4 deletions aphrodite/executor/gpu_executor.py
@@ -68,14 +68,18 @@ def _get_create_worker_kwargs(
distributed_init_method: Optional[str] = None) -> Dict:
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method)
if self.speculative_config is None:
if self.scheduler_config.is_multi_step:
worker_kwargs.update(
worker_module_name="aphrodite.task_handler.worker",
worker_class_name="Worker")
else:
worker_module_name="aphrodite.task_handler.multi_step_worker",
worker_class_name="MultiStepWorker")
elif self.speculative_config:
worker_kwargs.update(
worker_module_name="aphrodite.spec_decode.spec_decode_worker",
worker_class_name="create_spec_worker")
else:
worker_kwargs.update(
worker_module_name="aphrodite.task_handler.worker",
worker_class_name="Worker")
return worker_kwargs

def _create_worker(self,
3 changes: 3 additions & 0 deletions aphrodite/executor/ray_gpu_executor.py
@@ -104,6 +104,9 @@ def _get_worker_wrapper_args(self) -> Dict[str, Any]:
if self.speculative_config is not None:
worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
elif self.scheduler_config.is_multi_step:
worker_module_name = "aphrodite.task_handler.multi_step_worker"
worker_class_name = "MultiStepWorker"
else:
worker_module_name = "aphrodite.task_handler.worker"
worker_class_name = "Worker"
43 changes: 31 additions & 12 deletions aphrodite/task_handler/model_runner_base.py
@@ -14,7 +14,7 @@
from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.modeling import SamplingMetadata

T = TypeVar('T', bound="ModelRunnerInputBase")
T = TypeVar('T', bound="BroadcastableModelInput")


def _add_attn_metadata_broadcastable_dict(
@@ -81,18 +81,24 @@ def _add_sampling_metadata_broadcastable_dict(
sampling_metadata.selected_token_indices)


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
"""Local inputs to each worker's model runner. May contain
device-specific data. Different worker backends may have different methods
of converting from the global ExecuteModelRequest produced by the LLM
engine to the worker-local ModelRunnerInputBase objects.
Model runners that support multi-GPU execution should define a
ModelRunnerInputBase subclass, add their required fields, and specify how to
serialize/deserialize a ModelInput for broadcast between workers.
def _init_frozen_model_input_from_tensor_dict(
frozen_model_input_cls: Type["ModelRunnerInputBase"],
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper method to initialize a frozen ModelInput based on broadcastable
"""
valid_tensor_kwargs = {}
for field in dataclasses.fields(frozen_model_input_cls):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_tensor_kwargs[field.name] = val
frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
tensor_dict["frozen_model_input"] = frozen_model_input
return tensor_dict

class BroadcastableModelInput(ABC):

@abstractmethod
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"""
Extract broadcastable fields. Override for fields that require some
@@ -109,11 +115,24 @@ def from_broadcasted_tensor_dict(
) -> T:
"""
Pop fields from the given tensor_dict and populate a new instance of
ModelRunnerInputBase.
BroadcastableModelInput.
"""
raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
"""Local inputs to each worker's model runner. May contain
device-specific data. Different worker backends may have different methods
of converting from the global ExecuteModelRequest produced by the LLM
engine to the worker-local ModelRunnerInputBase objects.
Model runners that support multi-GPU execution should define a
ModelRunnerInputBase subclass, add their required fields, and specify how to
serialize/deserialize a ModelInput for broadcast between workers.
"""
pass


class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""
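A minimal sketch of how a concrete runner input might plug into the new split between BroadcastableModelInput and the frozen ModelRunnerInputBase. The field names and the simplified from_broadcasted_tensor_dict signature are illustrative; the real GPU model-runner subclasses are in files not shown in this excerpt:

import dataclasses
from typing import Any, Dict, Optional

import torch

from aphrodite.task_handler.model_runner_base import ModelRunnerInputBase

@dataclasses.dataclass(frozen=True)
class ToyModelInput(ModelRunnerInputBase):
    # Illustrative fields only; real runners also carry attention and
    # sampling metadata, added via the *_broadcastable_dict helpers above.
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # Flatten to a plain dict so it can be broadcast to other workers.
        return {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }

    @classmethod
    def from_broadcasted_tensor_dict(cls, tensor_dict: Dict[str, Any],
                                     **kwargs) -> "ToyModelInput":
        # Rebuild the frozen dataclass from whatever fields were broadcast.
        names = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in tensor_dict.items() if k in names})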
(Diffs for the remaining 8 changed files did not load and are not shown here.)
