Commit: rename to StatefulModelInput
SolitaryThinker committed Aug 19, 2024
1 parent 70b4e12 commit e9a0211
Showing 3 changed files with 27 additions and 27 deletions.
10 changes: 5 additions & 5 deletions tests/worker/test_model_input.py
```diff
@@ -11,7 +11,7 @@
     ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 from vllm.worker.multi_step_model_runner import (
-    MutableModelInputForGPUWithMultiStepMetadata)
+    StatefulModelInput)
 
 
 class MockAttentionBackend(AttentionBackend):
@@ -177,7 +177,7 @@ def test_multi_step_model_runner_input():
         sampling_metadata=sampling_metadata,
         attn_metadata=attn_metadata)
 
-    model_input = MutableModelInputForGPUWithMultiStepMetadata(
+    model_input = StatefulModelInput(
         frozen_model_input=frozen_model_input,
         is_last_step=True,
         is_first_multi_step=False,
@@ -190,20 +190,20 @@
     )
 
     assert isinstance(model_input,
-                      MutableModelInputForGPUWithMultiStepMetadata)
+                      StatefulModelInput)
 
     # Test round trip serialization.
     tensor_dict = model_input.as_broadcastable_tensor_dict()
     attn_backend = MockAttentionBackend()
-    received_model_input = (MutableModelInputForGPUWithMultiStepMetadata.
+    received_model_input = (StatefulModelInput.
                             from_broadcasted_tensor_dict(
                                 tensor_dict, attn_backend=attn_backend))
 
     receieved_frozen_input = received_model_input.frozen_model_input
 
     # Check that received copy has correct values.
     assert isinstance(received_model_input,
-                      MutableModelInputForGPUWithMultiStepMetadata)
+                      StatefulModelInput)
     assert receieved_frozen_input.input_tokens is not None
     assert (receieved_frozen_input.input_tokens ==
             frozen_model_input.input_tokens).all()
```
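The round-trip test above mirrors what happens across tensor-parallel ranks: the driver flattens a `StatefulModelInput` into a broadcastable dict, and each worker rebuilds it. A minimal sketch of that pattern, using only the names visible in this diff (the model input and attention backend are assumed to come from elsewhere):

```python
from vllm.worker.multi_step_model_runner import StatefulModelInput

def broadcast_round_trip(model_input: StatefulModelInput, attn_backend):
    """Sketch: flatten on the driver, rebuild on a tensor-parallel worker."""
    # Driver side: flatten to a dict of broadcastable tensors/values.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    # (vLLM broadcasts `tensor_dict` to the other ranks at this point.)
    # Worker side: reconstruct the stateful input from the dict.
    return StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend)
```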
28 changes: 14 additions & 14 deletions vllm/worker/multi_step_model_runner.py
```diff
@@ -54,7 +54,7 @@ class ModelOutput:
 
     def pythonize(
             self,
-            input_metadata: "MutableModelInputForGPUWithMultiStepMetadata",
+            input_metadata: "StatefulModelInput",
             copy_stream: torch.cuda.Stream,
             pinned_sampled_token_buffer: torch.Tensor) -> None:
         """Pythonize the output. Blocking."""
@@ -65,7 +65,7 @@
 
     def maybe_pythonize(
             self,
-            input_metadata: "MutableModelInputForGPUWithMultiStepMetadata",
+            input_metadata: "StatefulModelInput",
             copy_stream: torch.cuda.Stream,
             pinned_sampled_token_buffer: torch.Tensor) -> None:
         """Pythonize the output if ready, else return None. Non-blocking."""
@@ -76,7 +76,7 @@
 
     def _pythonize_sampler_output(
             self,
-            input_metadata: "MutableModelInputForGPUWithMultiStepMetadata",
+            input_metadata: "StatefulModelInput",
             copy_stream: torch.cuda.Stream,
             pinned_sampled_token_buffer: torch.Tensor, blocking: bool) -> bool:
         """
@@ -97,7 +97,7 @@
 
 
 @dataclass(frozen=False)
-class MutableModelInputForGPUWithMultiStepMetadata(BroadcastableModelInput):
+class StatefulModelInput(BroadcastableModelInput):
     # actual frozen model input dataclass passed to _base_model_runner
     frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None
 
```
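The new name reflects the design: `StatefulModelInput` is a mutable wrapper carrying an immutable `frozen_model_input` plus multi-step bookkeeping. A simplified, illustrative reduction of that shape (not the real class definition; field names beyond those appearing in this diff's constructor calls are assumptions):

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional

@dataclass(frozen=False)
class ToyStatefulInput:
    """Illustrative reduction of the frozen/mutable split in StatefulModelInput."""
    # Immutable per-step input handed to the base model runner.
    frozen_model_input: Optional[Any] = None
    # Mutable multi-step bookkeeping (names from constructor calls in the diff).
    num_seqs: int = 0
    num_queries: int = 0
    is_first_multi_step: bool = False
    is_last_step: bool = False
    # Accumulated step outputs; `add_sampler_output` appears later in the diff.
    cached_outputs: List[Any] = field(default_factory=list)

    def add_sampler_output(self, sampler_output: Any) -> None:
        self.cached_outputs.append(sampler_output)
```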
```diff
@@ -137,7 +137,7 @@ def from_broadcasted_tensor_dict(
         cls,
         tensor_dict: Dict[str, Any],
         attn_backend: Optional["AttentionBackend"] = None,
-    ) -> "MutableModelInputForGPUWithMultiStepMetadata":
+    ) -> "StatefulModelInput":
         tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
         if attn_backend is not None:
             tensor_dict = _init_attn_metadata_from_tensor_dict(
@@ -182,7 +182,7 @@ def add_sampler_output(self,
 # metadata
 # mypy: disable-error-code=type-var
 class MultiStepModelRunner(
-        GPUModelRunnerBase[MutableModelInputForGPUWithMultiStepMetadata]):
+        GPUModelRunnerBase[StatefulModelInput]):
     # mypy: enable-error-code=type-var
 
     def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
@@ -199,8 +199,8 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
 
     def make_model_input_from_broadcasted_tensor_dict(
             self, tensor_dict: Dict[str, Any]
-    ) -> MutableModelInputForGPUWithMultiStepMetadata:
-        model_input = (MutableModelInputForGPUWithMultiStepMetadata.
+    ) -> StatefulModelInput:
+        model_input = (StatefulModelInput.
                        from_broadcasted_tensor_dict(
                            tensor_dict,
                            attn_backend=self.attn_backend,
@@ -212,11 +212,11 @@ def prepare_model_input(
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None
-    ) -> MutableModelInputForGPUWithMultiStepMetadata:
+    ) -> StatefulModelInput:
         frozen_model_input = self._base_model_runner.prepare_model_input(
             seq_group_metadata_list, virtual_engine, finished_requests_ids)
 
-        model_input = MutableModelInputForGPUWithMultiStepMetadata(
+        model_input = StatefulModelInput(
             frozen_model_input=frozen_model_input,
             num_seqs=len(frozen_model_input.seq_lens),
             num_queries=len(frozen_model_input.query_lens),
@@ -226,7 +226,7 @@
     @torch.inference_mode()
     def execute_model(
         self,
-        model_input: MutableModelInputForGPUWithMultiStepMetadata,
+        model_input: StatefulModelInput,
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
@@ -354,9 +354,9 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
             assert seq_group.query_len is None  # Decode
 
     def _advance_step(
-            self, model_input: MutableModelInputForGPUWithMultiStepMetadata,
+            self, model_input: StatefulModelInput,
             out: SamplerOutput
-    ) -> MutableModelInputForGPUWithMultiStepMetadata:
+    ) -> StatefulModelInput:
         frozen_model_input = model_input.frozen_model_input
         assert frozen_model_input is not None
         assert frozen_model_input.attn_metadata is not None
@@ -420,7 +420,7 @@ def vocab_size(self) -> int:
 
 
 def _pythonize_sampler_output(
-    model_input: MutableModelInputForGPUWithMultiStepMetadata,
+    model_input: StatefulModelInput,
     output: SamplerOutput, pinned_sampled_token_buffer: torch.Tensor,
     sampled_token_ids: torch.Tensor) -> SamplerOutput:
     """ This function is only called when the output tensors are ready.
```
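The `pythonize`/`maybe_pythonize` pair above separates blocking from opportunistic conversion of sampler outputs into Python objects, per the docstrings ("Blocking" vs. "if ready, else return None"). A hedged sketch of that split, assuming readiness is tracked with a CUDA event; the diff itself does not show the actual mechanism:

```python
import torch

class ToyModelOutput:
    """Sketch of blocking vs. non-blocking pythonization (assumed mechanism)."""

    def __init__(self, sampler_output, ready_event: torch.cuda.Event):
        self.sampler_output = sampler_output
        self.ready_event = ready_event  # assumed readiness signal
        self.pythonized = False

    def maybe_pythonize(self) -> None:
        # Non-blocking: convert only if the GPU work has already finished.
        if not self.pythonized and self.ready_event.query():
            self._do_pythonize()

    def pythonize(self) -> None:
        # Blocking: wait for the GPU, then convert.
        if not self.pythonized:
            self.ready_event.synchronize()
            self._do_pythonize()

    def _do_pythonize(self) -> None:
        # Placeholder: copy sampled token ids into Python-side structures.
        self.pythonized = True
```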
16 changes: 8 additions & 8 deletions vllm/worker/multi_step_worker.py
```diff
@@ -5,14 +5,14 @@
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.worker.model_runner_base import BroadcastableModelInput
 from vllm.worker.multi_step_model_runner import (
-    MultiStepModelRunner, MutableModelInputForGPUWithMultiStepMetadata)
+    MultiStepModelRunner, StatefulModelInput)
 from vllm.worker.worker import Worker, WorkerInput
 
 
 @dataclass
 class MultiStepState:
     worker_input: WorkerInput
-    model_input: MutableModelInputForGPUWithMultiStepMetadata
+    model_input: StatefulModelInput
 
 
 class MultiStepWorker(Worker):
@@ -54,7 +54,7 @@ def _get_driver_input_and_broadcast(
         # on first step we prepare the worker input and model input normally
         worker_input: WorkerInput = self.prepare_worker_input(
             execute_model_req=execute_model_req)
-        model_input: MutableModelInputForGPUWithMultiStepMetadata = (
+        model_input: StatefulModelInput = (
            self.model_runner.prepare_model_input(
                 execute_model_req.seq_group_metadata_list,
                 execute_model_req.virtual_engine,
@@ -90,7 +90,7 @@ def _get_driver_input_and_broadcast(
     def _prepare_last_sampled_token_ids_for_tp_workers(
         self,
         execute_model_req: ExecuteModelRequest,
-        model_input: MutableModelInputForGPUWithMultiStepMetadata,
+        model_input: StatefulModelInput,
     ) -> None:
         """
         Prepare the last sampled token ids for TP workers. If it's the last
@@ -130,7 +130,7 @@ def _prepare_last_sampled_token_ids_for_tp_workers(
     def prepare_input(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None,
-    ) -> Optional[Tuple[MutableModelInputForGPUWithMultiStepMetadata,
+    ) -> Optional[Tuple[StatefulModelInput,
                         WorkerInput]]:
         """
         Depending on the current state of the request and multi step worker,
@@ -152,7 +152,7 @@ def prepare_input(
             model_input, worker_input = self._get_driver_input_and_broadcast(
                 execute_model_req)
             assert isinstance(model_input,
-                              MutableModelInputForGPUWithMultiStepMetadata)
+                              StatefulModelInput)
             if execute_model_req.is_first_multi_step:
                 # cache the worker input and model input for the next steps
                 self.multi_step_states[virtual_engine] = MultiStepState(
@@ -166,7 +166,7 @@ def prepare_input(
                 return None
             model_input, worker_input = broadcast_data
             assert isinstance(model_input,
-                              MutableModelInputForGPUWithMultiStepMetadata)
+                              StatefulModelInput)
             virtual_engine = worker_input.virtual_engine
             if model_input.is_first_multi_step:
                 pass
@@ -180,7 +180,7 @@ def prepare_input(
             # for the next step (sampled_token_ids from the previous step)
 
             assert isinstance(
-                model_input, MutableModelInputForGPUWithMultiStepMetadata)
+                model_input, StatefulModelInput)
             # we need to update the last sampled token ids in the model
             # input for the workers so that they can run inplace
             # advance_step
```
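On the driver, `prepare_input` builds worker and model inputs only on the first of the multi steps and caches them in `multi_step_states`, keyed by virtual engine; later steps reuse the cached state and only refresh the last sampled token ids from the previous step. A condensed, illustrative sketch of that control flow (the toy types and request fields stand in for the real `Worker` machinery):

```python
from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class ToyMultiStepState:
    worker_input: Any
    model_input: Any

class ToyMultiStepDriver:
    """Sketch of the first-step-prepares, later-steps-reuse pattern."""

    def __init__(self) -> None:
        # One cached state per virtual engine, as in MultiStepWorker.
        self.multi_step_states: Dict[int, ToyMultiStepState] = {}

    def prepare_input(self, req: Any) -> ToyMultiStepState:
        ve = req.virtual_engine
        if req.is_first_multi_step:
            # First step: prepare inputs normally and cache them.
            state = ToyMultiStepState(worker_input=object(),
                                      model_input=object())
            self.multi_step_states[ve] = state
        else:
            # Later steps: reuse the cached inputs; only the last sampled
            # token ids carried over from the previous step need updating.
            state = self.multi_step_states[ve]
        return state
```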
