From 3be2ea5c71168e2683426a552b26daff6cd7dc65 Mon Sep 17 00:00:00 2001
From: Will Lin
Date: Tue, 30 Jul 2024 21:12:54 -0700
Subject: [PATCH 1/3] [Speculative decoding] [Multi-Step] decouple
 should_modify_greedy_probs_inplace

---
 vllm/lora/layers.py                            | 4 ++++
 vllm/model_executor/layers/sampler.py          | 4 ++--
 vllm/spec_decode/medusa_worker.py              | 3 +++
 vllm/spec_decode/multi_step_worker.py          | 5 +++++
 vllm/spec_decode/proposer_worker_base.py       | 4 ++++
 vllm/spec_decode/smaller_tp_proposer_worker.py | 6 ++++++
 vllm/spec_decode/spec_decode_worker.py         | 3 +++
 7 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index e3316059dc6d1..e5c28eff7f166 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -1066,6 +1066,10 @@ def org_vocab_size(self):
     @property
     def include_gpu_probs_tensor(self):
         return self.base_layer.include_gpu_probs_tensor
+    
+    @property
+    def should_modify_greedy_probs_inplace(self):
+        return self.base_layer.should_modify_greedy_probs_inplace
 
     def create_lora_weights(
         self,
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 6632b1c434582..cc78a0ea3b869 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -51,6 +51,7 @@ def __init__(self):
         # containing the sampled token ids and probabilities. This is used by
         # speculative decoding.
         self.include_gpu_probs_tensor = False
+        self.should_modify_greedy_probs_inplace = False
 
     def _init_sampling_tensors(
         self,
@@ -177,8 +178,7 @@ def _should_modify_greedy_probs_inplace(self) -> bool:
         This is used by speculative decoding, which requires that the sampling
         method be encoded into the probability distribution.
         """
-        # Modify greedy probs if include_gpu_probs_tensor is set.
-        return self.include_gpu_probs_tensor
+        return self.should_modify_greedy_probs_inplace
 
 
 def _get_bin_counts_and_mask(
diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py
index 4b82f7bf92bab..d1809e49c2a8f 100644
--- a/vllm/spec_decode/medusa_worker.py
+++ b/vllm/spec_decode/medusa_worker.py
@@ -35,6 +35,9 @@ def init_device(self):
     def set_include_gpu_probs_tensor(self):
         pass
 
+    def set_should_modify_greedy_probs_inplace(self):
+        pass
+
     @torch.inference_mode()
     def sampler_output(
         self,
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 91689324557b5..f0e6d1c4f7821 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -46,6 +46,11 @@ def set_include_gpu_probs_tensor(self) -> None:
         # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.model.sampler.include_gpu_probs_tensor = True
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
+            True
+        )
+
     @torch.inference_mode()
     def sampler_output(
         self,
diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py
index 51cefc0cbca8b..efb8ee25ba2f9 100644
--- a/vllm/spec_decode/proposer_worker_base.py
+++ b/vllm/spec_decode/proposer_worker_base.py
@@ -28,6 +28,10 @@ def set_include_gpu_probs_tensor(self) -> None:
         """Implementation optional"""
         pass
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        """Implementation optional"""
+        pass
+
 
 class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
     """Proposer worker which does not use a model with kvcache"""
diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py
index 0dbb924d25400..215ede52fb812 100644
--- a/vllm/spec_decode/smaller_tp_proposer_worker.py
+++ b/vllm/spec_decode/smaller_tp_proposer_worker.py
@@ -83,6 +83,12 @@ def set_include_gpu_probs_tensor(self) -> None:
         # Need include_gpu_probs_tensor for multi_step_worker
         self._worker.set_include_gpu_probs_tensor()
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        if self._is_dummy:
+            return
+
+        self._worker.set_should_modify_greedy_probs_inplace()
+
     def load_model(self) -> None:
         if self._is_dummy:
             return
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 690aad505e215..625bceb647943 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -295,7 +295,10 @@ def _configure_model_sampler_for_spec_decode(self):
         """
         (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
          ) = True
+        (self.scorer_worker.model_runner.model.sampler
+         .should_modify_greedy_probs_inplace) = True
         self.proposer_worker.set_include_gpu_probs_tensor()
+        self.proposer_worker.set_should_modify_greedy_probs_inplace()
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of cache blocks to use.
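
With the flag decoupled, a caller can now request GPU probability tensors from
the Sampler without also forcing the in-place rewrite of greedy probabilities.
A minimal sketch of the two flags being set independently (the import path is
taken from the file touched above; constructing a bare Sampler outside a model
is for illustration only):

    from vllm.model_executor.layers.sampler import Sampler

    sampler = Sampler()
    # Keep sampled token ids/probs as GPU tensors (needed for spec decode scoring)...
    sampler.include_gpu_probs_tensor = True
    # ...but leave greedy probabilities untouched; before this patch both
    # behaviors were keyed off include_gpu_probs_tensor alone.
    sampler.should_modify_greedy_probs_inplace = False
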
From 94ed67c41a06644fdef6bc7cc16487567bd8d2d9 Mon Sep 17 00:00:00 2001
From: Will Lin
Date: Tue, 30 Jul 2024 22:46:28 -0700
Subject: [PATCH 2/3] lint

---
 vllm/lora/layers.py                    | 2 +-
 vllm/spec_decode/multi_step_worker.py  | 3 +--
 vllm/spec_decode/spec_decode_worker.py | 4 ++--
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index e5c28eff7f166..a8ea67991a375 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -1066,7 +1066,7 @@ def org_vocab_size(self):
     @property
     def include_gpu_probs_tensor(self):
         return self.base_layer.include_gpu_probs_tensor
-    
+
     @property
     def should_modify_greedy_probs_inplace(self):
         return self.base_layer.should_modify_greedy_probs_inplace
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index f0e6d1c4f7821..65bfb5dc8d5c6 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -48,8 +48,7 @@ def set_include_gpu_probs_tensor(self) -> None:
 
     def set_should_modify_greedy_probs_inplace(self) -> None:
         self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
-            True
-        )
+            True)
 
     @torch.inference_mode()
     def sampler_output(
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 625bceb647943..63a00139cc09d 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -295,8 +295,8 @@ def _configure_model_sampler_for_spec_decode(self):
         """
         (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
          ) = True
-        (self.scorer_worker.model_runner.model.sampler
-         .should_modify_greedy_probs_inplace) = True
+        (self.scorer_worker.model_runner.model.sampler.
+         should_modify_greedy_probs_inplace) = True
         self.proposer_worker.set_include_gpu_probs_tensor()
         self.proposer_worker.set_should_modify_greedy_probs_inplace()
 

From 2fb2bc1e1edaf28921a6047e2e4cb6799dc7f770 Mon Sep 17 00:00:00 2001
From: Will Lin
Date: Thu, 8 Aug 2024 19:26:48 -0700
Subject: [PATCH 3/3] inplace test

---
 tests/samplers/test_sampler.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index bf062e4a5c09d..f1370e411241c 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -1,7 +1,7 @@
 import itertools
 import random
 from typing import Dict, List, Optional, Tuple
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
@@ -703,3 +703,28 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
 
     assert tokens1[0] == tokens2[1]
     assert tokens1[1] == tokens2[0]
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_include_gpu_probs_tensor(device: str):
+    set_random_seed(42)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+    sampler.include_gpu_probs_tensor = True
+    sampler.should_modify_greedy_probs_inplace = False
+
+    sampling_params = SamplingParams(temperature=0)
+
+    mock_inplace = Mock()
+    with patch(
+            "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
+            mock_inplace):
+
+        sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                    sampling_params, device)
+        mock_inplace.assert_not_called()
+
+        assert sampler_output.sampled_token_probs is not None
+        assert sampler_output.logprobs is not None
+        assert sampler_output.sampled_token_ids is not None
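
The new test can be run on its own with pytest's keyword filter, for example:

    pytest tests/samplers/test_sampler.py -k test_sampler_include_gpu_probs_tensor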