From 40d5e5ff540785dca03ba0e9e6d5c86ae59315d2 Mon Sep 17 00:00:00 2001 From: Will Lin Date: Thu, 8 Aug 2024 21:49:57 -0700 Subject: [PATCH] clean up --- vllm/worker/multi_step_model_runner.py | 1 - vllm/worker/multi_step_worker.py | 17 ++++------------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 1cd3a5470673f..97696a4bda34c 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -313,7 +313,6 @@ def execute_model( # event for the pythonization so that we only pythonize if the # tensors are ready. May be able to be combined with the step event - # torch.cuda.synchronize() output_ready_event = torch.cuda.Event() output_ready_event.record(current_stream) if self.parallel_config.pipeline_parallel_size > 1: diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 1d0e9c8ef9da8..0b1342347bc5a 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -135,17 +135,12 @@ def prepare_input( pass # cache the worker input and model input for the next steps # TODO(will) see below - - # self.multi_step_states[virtual_engine] = MultiStepState( - # worker_input=worker_input, model_input=model_input) else: - # TODO(will) possible to also use the cached worker input and model input - # this can be done if we want to optimize the broadcast to only send - # the last sampled token ids for non-first multi steps + # TODO(will) possible to also use the cached worker input and + # model input this can be done if we want to optimize the + # broadcast to only send the last sampled token ids for + # non-first multi steps - # multi_step_state = self.multi_step_states[virtual_engine] - # cached_model_input = multi_step_state.model_input - # cached_worker_input = multi_step_state.worker_input assert isinstance( model_input, MutableModelInputForGPUWithMultiStepMetadata) # we need to update the last sampled token ids in the model input @@ -153,10 +148,6 @@ def prepare_input( model_input.add_sampler_output( SamplerOutput(outputs=[], sampled_token_ids=None), model_input.last_sampled_token_ids) - # self.multi_step_states[virtual_engine] = MultiStepState( - # worker_input=worker_input, model_input=model_input) - # model_input = cached_model_input - # worker_input = cached_worker_input assert model_input is not None assert worker_input is not None