Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] [2/N] refactor worker_base input preparation for multi-step #7387

Merged
merged 1 commit into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_steps = execute_model_req.num_steps
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are CPU tensors.
# They contain the parameters used to launch cudaMemcpyAsync.
Expand All @@ -286,6 +287,7 @@ def prepare_worker_input(
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
)

@torch.inference_mode()
Expand Down
92 changes: 61 additions & 31 deletions vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class WorkerInput:
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
num_steps: int = 1

@classmethod
def from_broadcasted_tensor_dict(
Expand All @@ -145,6 +146,7 @@ def from_broadcasted_tensor_dict(
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
)

def as_broadcastable_tensor_dict(
Expand All @@ -158,6 +160,7 @@ def as_broadcastable_tensor_dict(
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
}

return tensor_dict
Expand Down Expand Up @@ -216,13 +219,50 @@ def execute_worker(self, worker_input: WorkerInput) -> None:
"""
raise NotImplementedError

def execute_model(
def _get_worker_input_from_broadcast(
        self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
    """Receive the broadcast tensor dict and rebuild this worker's inputs.

    Only valid on a non-driver worker with metadata broadcast enabled;
    both conditions are asserted.

    Returns:
        A ``(model_input, worker_input)`` tuple, or ``None`` when the
        received dict is empty.
    """
    assert self.do_metadata_broadcast
    assert not self.is_driver_worker

    payload = broadcast_tensor_dict(src=0)
    if payload:
        # Build the WorkerInput first: its constructor pops fields out of
        # the shared dict before the model runner reads what is left.
        worker_input = WorkerInput.from_broadcasted_tensor_dict(payload)
        runner = self.model_runner
        model_input = runner.make_model_input_from_broadcasted_tensor_dict(
            payload)
        return model_input, worker_input

    # Falsy payload: nothing to run.
    return None

def _get_driver_input_and_broadcast(
    self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelRunnerInputBase, WorkerInput]:
    """Build the driver's inputs and, if enabled, broadcast them.

    Only valid on the driver worker (asserted).

    Args:
        execute_model_req: The request from which the worker input and
            the model input are prepared.

    Returns:
        A ``(model_input, worker_input)`` tuple for the driver itself.
    """
    assert self.is_driver_worker

    worker_input: WorkerInput = self.prepare_worker_input(
        execute_model_req=execute_model_req)
    runner = self.model_runner
    model_input: ModelRunnerInputBase = runner.prepare_model_input(
        execute_model_req.seq_group_metadata_list,
        execute_model_req.virtual_engine,
        execute_model_req.finished_requests_ids,
    )

    # Without metadata broadcast there is nothing to send.
    if not self.do_metadata_broadcast:
        return model_input, worker_input

    # Merge both inputs into one dict; model-input entries are applied
    # last, so they win on any key the two dicts share.
    payload = worker_input.as_broadcastable_tensor_dict()
    payload.update(model_input.as_broadcastable_tensor_dict())
    broadcast_tensor_dict(payload, src=0)

    return model_input, worker_input

def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time = time.perf_counter()
) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
"""
Prepare the inputs to ModelRunner and workers.
"""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
Expand All @@ -233,34 +273,24 @@ def execute_model(
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None

worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
num_steps = execute_model_req.num_steps

if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_data["num_steps"] = num_steps
broadcast_tensor_dict(broadcast_data, src=0)
return self._get_driver_input_and_broadcast(execute_model_req)
else:
assert self.do_metadata_broadcast
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
return self._get_worker_input_from_broadcast()

def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time = time.perf_counter()

inputs = self.prepare_input(execute_model_req)
if inputs is None:
return None

num_steps = broadcast_data.pop("num_steps")
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
model_input, worker_input = inputs
num_steps = worker_input.num_steps

self.execute_worker(worker_input)

Expand Down
Loading