[TPU] Implement multi-step scheduling (#8489)
WoosukKwon authored Sep 14, 2024
1 parent 47790f3 commit 50e9ec4
Showing 5 changed files with 279 additions and 76 deletions.
vllm/config.py (1 addition, 1 deletion)
@@ -379,7 +379,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
             self.use_async_output_proc = False
             return
 
-        if self.enforce_eager:
+        if device_config.device_type == "cuda" and self.enforce_eager:
             logger.warning(
                 "To see benefits of async output processing, enable CUDA "
                 "graph. Since, enforce-eager is enabled, async output "
vllm/executor/ray_tpu_executor.py (6 additions, 2 deletions)
@@ -68,8 +68,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             )
 
             assert self.speculative_config is None
-            worker_module_name = "vllm.worker.tpu_worker"
-            worker_class_name = "TPUWorker"
+            if self.scheduler_config.is_multi_step:
+                worker_module_name = "vllm.worker.multi_step_tpu_worker"
+                worker_class_name = "MultiStepTPUWorker"
+            else:
+                worker_module_name = "vllm.worker.tpu_worker"
+                worker_class_name = "TPUWorker"
 
             # GKE does not fetch environment information from metadata server
             # and instead sets these from within the Ray process. Therefore we
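The Ray path passes the worker as a module/class name pair rather than importing a class directly, so each Ray actor can import the chosen worker inside its own process. A minimal sketch of that resolution step, using only the standard library; the helper name resolve_worker_cls is hypothetical, not part of vLLM:

import importlib

def resolve_worker_cls(worker_module_name: str, worker_class_name: str):
    # Import the module by name inside the actor process, then fetch the class.
    module = importlib.import_module(worker_module_name)
    return getattr(module, worker_class_name)

# e.g. resolve_worker_cls("vllm.worker.multi_step_tpu_worker", "MultiStepTPUWorker")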
vllm/executor/tpu_executor.py (11 additions, 5 deletions)
@@ -62,11 +62,17 @@ def _create_worker(
         rank: int = 0,
         distributed_init_method: Optional[str] = None,
     ):
-        from vllm.worker.tpu_worker import TPUWorker
-
-        worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
-                                                     distributed_init_method))
-        return worker
+        if self.scheduler_config.is_multi_step:
+            from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
+            worker = MultiStepTPUWorker(**self._get_worker_kwargs(
+                local_rank, rank, distributed_init_method))
+            return worker
+        else:
+            from vllm.worker.tpu_worker import TPUWorker
+
+            worker = TPUWorker(**self._get_worker_kwargs(
+                local_rank, rank, distributed_init_method))
+            return worker
 
     def initialize_cache(
         self,
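For context, both executors key off scheduler_config.is_multi_step to pick MultiStepTPUWorker over TPUWorker. A rough usage sketch of how that flag is typically switched on is below; the num_scheduler_steps argument and the "tpu" device value are assumptions about vLLM's engine arguments at this point in time, not something established by this diff:

from vllm import LLM, SamplingParams

# A value > 1 is assumed to mark scheduler_config.is_multi_step, which makes
# the TPU executors above construct a MultiStepTPUWorker instead of a TPUWorker.
llm = LLM(
    model="google/gemma-2b",   # illustrative model choice
    device="tpu",              # assumed device flag for the TPU backend
    num_scheduler_steps=8,     # assumed knob that enables multi-step scheduling
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)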
vllm/worker/multi_step_tpu_worker.py (new file, 105 additions)
@@ -0,0 +1,105 @@
import dataclasses
from typing import Dict, Optional, Tuple

import torch

from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.tpu_model_runner import ModelInputForTPU
from vllm.worker.tpu_worker import TPUWorker
from vllm.worker.worker_base import WorkerInput


class MultiStepTPUWorker(TPUWorker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cached_model_input: Optional[ModelInputForTPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        is_first_multi_step = execute_model_req.is_first_multi_step
        is_last_step = execute_model_req.is_last_step
        if is_first_multi_step:
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            worker_input = dataclasses.replace(
                worker_input,
                num_steps=execute_model_req.num_lookahead_slots + 1)
            model_input: ModelInputForTPU = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

            if execute_model_req.async_callback:
                model_input = dataclasses.replace(
                    model_input,
                    async_callback=execute_model_req.async_callback)
        else:
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        model_input = dataclasses.replace(
            model_input,
            is_first_multi_step=is_first_multi_step,
            is_last_step=is_last_step)

        if self.do_metadata_broadcast:
            if is_first_multi_step:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(
                    model_input.as_broadcastable_tensor_dict())
                broadcast_tensor_dict(broadcast_data, src=0)
            else:
                broadcast_data = {
                    "is_first_multi_step": is_first_multi_step,
                    "is_last_step": is_last_step,
                }
                broadcast_tensor_dict(broadcast_data, src=0)

        # Returning an empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    broadcast_tensor_dict({}, src=0)
                return None

            model_input, worker_input, _ = self._get_driver_input_and_broadcast(
                execute_model_req)
            if model_input.is_first_multi_step:
                self.cached_model_input = model_input
            return model_input, worker_input, {}
        else:
            broadcast_data = broadcast_tensor_dict(src=0)
            if not broadcast_data:
                return None

            if len(broadcast_data) == 2:
                assert self.cached_model_input is not None
                self.cached_model_input = dataclasses.replace(
                    self.cached_model_input,
                    is_first_multi_step=broadcast_data["is_first_multi_step"],
                    is_last_step=broadcast_data["is_last_step"])
                empty_worker_input = WorkerInput()
                return self.cached_model_input, empty_worker_input, {}

            worker_input = WorkerInput.from_broadcasted_tensor_dict(
                broadcast_data)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
            self.cached_model_input = model_input
            return model_input, worker_input, {}
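The worker above broadcasts the full model input only on the first step of a multi-step schedule and just two flags on later steps; followers recognize the flag-only payload by its length and patch their cached input. Below is a self-contained sketch of that protocol in plain Python. FakeModelInput, driver_payload, and follower_step are hypothetical stand-ins, and plain dicts replace broadcast_tensor_dict:

from dataclasses import dataclass, replace
from typing import Dict, Optional

@dataclass
class FakeModelInput:
    token_ids: list
    is_first_multi_step: bool = True
    is_last_step: bool = False

def driver_payload(model_input: FakeModelInput, step: int, num_steps: int) -> Dict:
    """On the first step, send the full input; afterwards, only the two flags."""
    if step == 0:
        return {"token_ids": model_input.token_ids,
                "is_first_multi_step": True,
                "is_last_step": num_steps == 1}
    return {"is_first_multi_step": False,
            "is_last_step": step == num_steps - 1}

def follower_step(payload: Dict, cache: Optional[FakeModelInput]) -> FakeModelInput:
    """Rebuild the input on the first step; later, patch the cached copy.
    The len(payload) == 2 check mirrors the worker's flag-only fast path."""
    if len(payload) == 2:
        assert cache is not None
        return replace(cache,
                       is_first_multi_step=payload["is_first_multi_step"],
                       is_last_step=payload["is_last_step"])
    return FakeModelInput(token_ids=payload["token_ids"],
                          is_first_multi_step=payload["is_first_multi_step"],
                          is_last_step=payload["is_last_step"])

# Simulate a 3-step schedule on one "follower" worker.
cache = None
for step in range(3):
    cache = follower_step(driver_payload(FakeModelInput([1, 2, 3]), step, 3), cache)
    print(step, cache.is_first_multi_step, cache.is_last_step)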
