From bda8e6847431d5cb8027af1a6cbbc9c0ab3986e4 Mon Sep 17 00:00:00 2001
From: Will Lin
Date: Fri, 9 Aug 2024 14:25:48 -0700
Subject: [PATCH] format

---
 vllm/engine/async_llm_engine.py        | 23 ++++++-----
 vllm/sequence.py                       | 11 ++++--
 vllm/worker/multi_step_model_runner.py | 53 +++++++++++++++-----------
 vllm/worker/multi_step_worker.py       | 21 +++++-----
 vllm/worker/worker_base.py             |  6 +--
 5 files changed, 61 insertions(+), 53 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index e21b60bac0c45..957db679228c7 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1,13 +1,13 @@
 import asyncio
 import time
+from dataclasses import dataclass
 from functools import partial
 from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
                     Optional, Set, Tuple, Type, Union)
-from dataclasses import dataclass
 
+import torch
 from transformers import PreTrainedTokenizer
 from typing_extensions import assert_never
-import torch
 
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -31,7 +31,6 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
-
 from vllm.usage.usage_lib import UsageContext
 
 logger = init_logger(__name__)
@@ -411,15 +410,15 @@ def _get_cached_sampled_token_ids_for_multi_step(
     def _cache_output_for_multi_step(
             self, virtual_engine: int,
             output: List[Optional[SamplerOutput]]) -> None:
-        if (self.parallel_config.pipeline_parallel_size > 1):
-            if len(output) > 0 and output[0] is not None:
-                last_output = output[-1]
-                assert last_output is not None
-                assert last_output.sampled_token_ids_numpy is not None
-                assert last_output.sampled_token_ids is None
-                assert last_output.sampled_token_probs is None
-                self.cached_scheduler_outputs[
-                    virtual_engine].last_output = last_output
+        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
+                and output[0] is not None):
+            last_output = output[-1]
+            assert last_output is not None
+            assert last_output.sampled_token_ids_numpy is not None
+            assert last_output.sampled_token_ids is None
+            assert last_output.sampled_token_probs is None
+            self.cached_scheduler_outputs[
+                virtual_engine].last_output = last_output
 
     async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""
diff --git a/vllm/sequence.py b/vllm/sequence.py
index b9b2cfedfd267..a73ee2e9f02e4 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -8,8 +8,8 @@
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
                     Union, cast)
 
-import torch
 import numpy
+import torch
 
 from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
 from vllm.lora.request import LoRARequest
@@ -1147,21 +1147,24 @@ class ExecuteModelRequest:
 
     @property
     def is_first_multi_step(self) -> bool:
-        # TODO(will) make this be able to handle batches with variable number of steps
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
         assert len(self.seq_group_metadata_list) > 0
         first_seq_group = self.seq_group_metadata_list[0]
         return first_seq_group.state.current_step == 0
 
     @property
     def is_last_step(self) -> bool:
-        # TODO(will) make this be able to handle batches with variable number of steps
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
         assert len(self.seq_group_metadata_list) > 0
         first_seq_group = self.seq_group_metadata_list[0]
         return first_seq_group.state.remaining_steps == 1
 
     @property
     def current_step(self) -> int:
-        # TODO(will) make this be able to handle batches with variable number of steps
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
         assert len(self.seq_group_metadata_list) > 0
         return self.seq_group_metadata_list[0].state.current_step
 
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index 76a8221ea1c37..ea7cf6505ea71 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -1,6 +1,6 @@
-import dataclasses
 from dataclasses import dataclass, field
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Union)
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
 try:
     from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 except ModuleNotFoundError:
@@ -8,21 +8,22 @@
     from vllm.attention.backends.rocm_flash_attn import (
         ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 
-from ..model_executor.model_loader.tensorizer import TensorizerConfig
+import torch
 
+from vllm import _custom_ops as ops
+from vllm.distributed import get_pp_group
+from vllm.logger import init_logger
+from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
+                           Logprob, SamplerOutput, SequenceGroupMetadata,
+                           SequenceOutput)
+from vllm.worker.model_runner import (GPUModelRunnerBase,
+                                      ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
-    BroadcastableModelInput, _init_frozen_model_input_from_tensor_dict,
-    _init_attn_metadata_from_tensor_dict,
+    BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
+    _init_frozen_model_input_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
-from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
-                                      GPUModelRunnerBase)
-from vllm.logger import init_logger
-from vllm.distributed import get_pp_group
-from vllm.sequence import (IntermediateTensors, SamplerOutput,
-                           SequenceGroupMetadata, SequenceOutput,
-                           CompletionSequenceGroupOutput, Logprob)
-from vllm import _custom_ops as ops
-import torch
+
+from ..model_executor.model_loader.tensorizer import TensorizerConfig
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -43,7 +44,8 @@ class ModelOutput:
 
     There are two scenarios:
     1. The output tensors are ready and we can pythonize them immediately.
-    2. The output tensors are not ready and we need to wait for the event to be ready.
+    2. The output tensors are not ready and we need to wait for the event to be
+    ready.
     """
     sampler_output: SamplerOutput
     sampler_output_ready_event: torch.cuda.Event
@@ -217,10 +219,11 @@ class MultiStepModelRunner(MultiStepModelRunnerBase):
     def make_model_input_from_broadcasted_tensor_dict(
         self, tensor_dict: Dict[str, Any]
     ) -> MutableModelInputForGPUWithMultiStepMetadata:
-        model_input = MutableModelInputForGPUWithMultiStepMetadata.from_broadcasted_tensor_dict(
-            tensor_dict,
-            attn_backend=self.attn_backend,
-        )
+        model_input = (MutableModelInputForGPUWithMultiStepMetadata.
+                       from_broadcasted_tensor_dict(
+                           tensor_dict,
+                           attn_backend=self.attn_backend,
+                       ))
         return model_input
 
     def prepare_model_input(
@@ -271,9 +274,11 @@ def execute_model(
                 device="cpu",
                 pin_memory=True)
 
-            self._base_model_runner.model.sampler.include_gpu_probs_tensor = True
+            self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
+                True)
             if frozen_model_input.sampling_metadata:
-                frozen_model_input.sampling_metadata.skip_sampler_cpu_output = True
+                frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
+                    True)
             # TODO(will) Will need to benchmark and look at torch profiler for
             # the exact location we should do this. If the CPU is very ahead, it
             # does not matter if we call this before executable or after, as the
@@ -296,7 +301,8 @@ def execute_model(
             # changing batch sizes, will remove afterwards and potentially leave
             # comment for future optimization
             if frozen_model_input.sampling_metadata:
-                frozen_model_input.sampling_metadata.reuse_sampling_tensors = False
+                frozen_model_input.sampling_metadata.reuse_sampling_tensors = (
+                    False)
         else:
             # This is not needed for flashattn backend, but for other attn
             # backends such as flashinfer that performs we may need to
@@ -309,7 +315,8 @@ def execute_model(
             # changing batch sizes, will remove afterwards and potentially leave
            # comment for future optimization
             if frozen_model_input.sampling_metadata:
-                frozen_model_input.sampling_metadata.reuse_sampling_tensors = False
+                frozen_model_input.sampling_metadata.reuse_sampling_tensors = (
+                    False)
 
         # Execute the model
         output = self._base_model_runner.execute_model(frozen_model_input,
diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py
index 0b1342347bc5a..7c9e044fec28a 100644
--- a/vllm/worker/multi_step_worker.py
+++ b/vllm/worker/multi_step_worker.py
@@ -1,14 +1,12 @@
-from vllm.worker.worker import Worker
 from dataclasses import dataclass
-from vllm.worker.worker import WorkerInput
-from vllm.worker.model_runner_base import BroadcastableModelInput
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
-from vllm.distributed import broadcast_tensor_dict, get_pp_group
-from typing import Tuple, Optional, List
-from dataclasses import field
+from typing import List, Optional, Tuple
 
+from vllm.distributed import broadcast_tensor_dict, get_pp_group
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.worker.model_runner_base import BroadcastableModelInput
 from vllm.worker.multi_step_model_runner import (
     MutableModelInputForGPUWithMultiStepMetadata)
+from vllm.worker.worker import Worker, WorkerInput
 
 
 @dataclass
@@ -70,8 +68,8 @@ def _get_driver_input_and_broadcast(
             # otherwise we need to get the cached sampled token ids from the
             # execute_model_req
             assert execute_model_req.last_sampled_token_ids is not None
-            model_input.last_sampled_token_ids = execute_model_req.last_sampled_token_ids.cuda(
-            )
+            model_input.last_sampled_token_ids = (
+                execute_model_req.last_sampled_token_ids.cuda())
             model_input.add_sampler_output(
                 SamplerOutput(outputs=[], sampled_token_ids=None),
                 model_input.last_sampled_token_ids)
@@ -143,8 +141,9 @@ def prepare_input(
 
                 assert isinstance(
                     model_input, MutableModelInputForGPUWithMultiStepMetadata)
-                # we need to update the last sampled token ids in the model input
-                # for the workers so that they can run inplace advance_step
+                # we need to update the last sampled token ids in the model
+                # input for the workers so that they can run inplace
+                # advance_step
                 model_input.add_sampler_output(
                     SamplerOutput(outputs=[], sampled_token_ids=None),
                     model_input.last_sampled_token_ids)
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 6cfea94e56ab4..fac8abbc902aa 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -16,9 +16,9 @@
                            SamplerOutput)
 from vllm.utils import (enable_trace_function_call_for_thread,
                         update_environment_variables)
-from vllm.worker.model_runner_base import (ModelRunnerBase,
-                                           ModelRunnerInputBase,
-                                           BroadcastableModelInput)
+from vllm.worker.model_runner_base import (BroadcastableModelInput,
+                                           ModelRunnerBase,
+                                           ModelRunnerInputBase)
 
 logger = init_logger(__name__)