[TPU] Implement multi-step scheduling #8489

Merged · 14 commits · Sep 14, 2024
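Note (not part of the diff): this PR wires multi-step scheduling into the TPU backend by selecting a `MultiStepTPUWorker` whenever `scheduler_config.is_multi_step` is set. Below is a hedged usage sketch only; it assumes the `num_scheduler_steps` engine argument that drives multi-step scheduling on GPU is also what sets `scheduler_config.is_multi_step` here, and the model name and step count are placeholders.

```python
# Hedged usage sketch: enabling multi-step scheduling for a TPU deployment.
# Assumption: num_scheduler_steps > 1 turns on scheduler_config.is_multi_step,
# which makes the executors below pick MultiStepTPUWorker.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model, not from the PR
    num_scheduler_steps=8,  # assumed knob: run 8 decode steps per scheduler call
)
outputs = llm.generate(["Hello from a TPU host"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```

With more than one scheduler step, the scheduler is invoked once per window of decode steps, which is presumably what `num_lookahead_slots + 1` in the new worker corresponds to.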
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -379,7 +379,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
             self.use_async_output_proc = False
             return
 
-        if self.enforce_eager:
+        if device_config.device_type == "cuda" and self.enforce_eager:
             logger.warning(
                 "To see benefits of async output processing, enable CUDA "
                 "graph. Since, enforce-eager is enabled, async output "
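Aside (illustration, not from the PR): the effect of the guard change above is that async output processing is now only disabled for enforce-eager on CUDA; a TPU device with `enforce_eager=True` keeps it enabled. A minimal, self-contained sketch of just the condition, with `SimpleNamespace` standing in for the real config objects and `disables_async_output_proc` as a hypothetical helper name:

```python
# Tiny illustration of the new condition; not vLLM code.
from types import SimpleNamespace

def disables_async_output_proc(device_config, model_config) -> bool:
    # Mirrors the updated check: only CUDA + enforce-eager trips the fallback.
    return device_config.device_type == "cuda" and model_config.enforce_eager

cuda_eager = (SimpleNamespace(device_type="cuda"), SimpleNamespace(enforce_eager=True))
tpu_eager = (SimpleNamespace(device_type="tpu"), SimpleNamespace(enforce_eager=True))

assert disables_async_output_proc(*cuda_eager) is True
assert disables_async_output_proc(*tpu_eager) is False
```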
8 changes: 6 additions & 2 deletions vllm/executor/ray_tpu_executor.py
@@ -68,8 +68,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             )
 
             assert self.speculative_config is None
-            worker_module_name = "vllm.worker.tpu_worker"
-            worker_class_name = "TPUWorker"
+            if self.scheduler_config.is_multi_step:
+                worker_module_name = "vllm.worker.multi_step_tpu_worker"
+                worker_class_name = "MultiStepTPUWorker"
+            else:
+                worker_module_name = "vllm.worker.tpu_worker"
+                worker_class_name = "TPUWorker"
 
             # GKE does not fetch environment information from metadata server
             # and instead sets these from within the Ray process. Therefore we
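Aside (illustration, not from the PR): the Ray path passes the worker module and class as strings and lets the remote Ray process import them. A hedged sketch of that name resolution, assuming the wrapper does a plain import of the configured names; `resolve_worker_cls` is an illustrative helper, not vLLM API:

```python
import importlib

def resolve_worker_cls(worker_module_name: str, worker_class_name: str) -> type:
    # Import the module by its dotted name and pull out the worker class.
    module = importlib.import_module(worker_module_name)
    return getattr(module, worker_class_name)

# With multi-step scheduling enabled, the executor above hands out
# ("vllm.worker.multi_step_tpu_worker", "MultiStepTPUWorker");
# otherwise ("vllm.worker.tpu_worker", "TPUWorker").
```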
16 changes: 11 additions & 5 deletions vllm/executor/tpu_executor.py
@@ -62,11 +62,17 @@ def _create_worker(
         rank: int = 0,
         distributed_init_method: Optional[str] = None,
     ):
-        from vllm.worker.tpu_worker import TPUWorker
-
-        worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
-                                                     distributed_init_method))
-        return worker
+        if self.scheduler_config.is_multi_step:
+            from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
+            worker = MultiStepTPUWorker(**self._get_worker_kwargs(
+                local_rank, rank, distributed_init_method))
+            return worker
+        else:
+            from vllm.worker.tpu_worker import TPUWorker
+
+            worker = TPUWorker(**self._get_worker_kwargs(
+                local_rank, rank, distributed_init_method))
+            return worker
 
     def initialize_cache(
         self,
105 changes: 105 additions & 0 deletions vllm/worker/multi_step_tpu_worker.py
@@ -0,0 +1,105 @@
import dataclasses
from typing import Dict, Optional, Tuple

import torch

from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.tpu_model_runner import ModelInputForTPU
from vllm.worker.tpu_worker import TPUWorker
from vllm.worker.worker_base import WorkerInput


class MultiStepTPUWorker(TPUWorker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cached_model_input: Optional[ModelInputForTPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        is_first_multi_step = execute_model_req.is_first_multi_step
        is_last_step = execute_model_req.is_last_step
        if is_first_multi_step:
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            worker_input = dataclasses.replace(
                worker_input,
                num_steps=execute_model_req.num_lookahead_slots + 1)
            model_input: ModelInputForTPU = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

            if execute_model_req.async_callback:
                model_input = dataclasses.replace(
                    model_input,
                    async_callback=execute_model_req.async_callback)
        else:
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        model_input = dataclasses.replace(
            model_input,
            is_first_multi_step=is_first_multi_step,
            is_last_step=is_last_step)

        if self.do_metadata_broadcast:
            if is_first_multi_step:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(
                    model_input.as_broadcastable_tensor_dict())
                broadcast_tensor_dict(broadcast_data, src=0)
            else:
                broadcast_data = {
                    "is_first_multi_step": is_first_multi_step,
                    "is_last_step": is_last_step,
                }
                broadcast_tensor_dict(broadcast_data, src=0)

        # Returning an empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    broadcast_tensor_dict({}, src=0)
                return None

            model_input, worker_input, _ = self._get_driver_input_and_broadcast(
                execute_model_req)
            if model_input.is_first_multi_step:
                self.cached_model_input = model_input
            return model_input, worker_input, {}
        else:
            broadcast_data = broadcast_tensor_dict(src=0)
            if not broadcast_data:
                return None

            if len(broadcast_data) == 2:
                assert self.cached_model_input is not None
                self.cached_model_input = dataclasses.replace(
                    self.cached_model_input,
                    is_first_multi_step=broadcast_data["is_first_multi_step"],
                    is_last_step=broadcast_data["is_last_step"])
                empty_worker_input = WorkerInput()
                return self.cached_model_input, empty_worker_input, {}

            worker_input = WorkerInput.from_broadcasted_tensor_dict(
                broadcast_data)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
            self.cached_model_input = model_input
            return model_input, worker_input, {}
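Aside (illustration, not from the PR): the worker above broadcasts the full model input only on the first step of a multi-step window; later steps broadcast just the two flags, and non-driver workers patch them into their cached input. A small, self-contained simulation of that follower-side branching; `FakeModelInput` and `FakeFollower` are made-up stand-ins, not vLLM types:

```python
from dataclasses import dataclass, replace
from typing import List, Optional

@dataclass
class FakeModelInput:
    tokens: List[int]
    is_first_multi_step: bool = True
    is_last_step: bool = False

class FakeFollower:
    """Stand-in for a non-driver worker's prepare_input branching."""

    def __init__(self) -> None:
        self.cached: Optional[FakeModelInput] = None

    def prepare_input(self, broadcast_data: dict) -> FakeModelInput:
        if len(broadcast_data) == 2:
            # Steps 2..N: only the two flags arrive; reuse the cached input.
            assert self.cached is not None
            self.cached = replace(
                self.cached,
                is_first_multi_step=broadcast_data["is_first_multi_step"],
                is_last_step=broadcast_data["is_last_step"])
            return self.cached
        # Step 1: the full (fake) model input arrives and is cached.
        self.cached = FakeModelInput(tokens=broadcast_data["tokens"])
        return self.cached

follower = FakeFollower()
# First step: full payload (more than two keys) -> built and cached.
print(follower.prepare_input({"tokens": [1, 2, 3], "num_steps": 4, "virtual_engine": 0}))
# Later steps: only the two flags -> cached input reused with the flags patched in.
print(follower.prepare_input({"is_first_multi_step": False, "is_last_step": True}))
```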