From 49ecfe9dd757e043873fd50ebcf1414afd518fb5 Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Tue, 24 Sep 2024 03:32:27 +0800 Subject: [PATCH] Fix typical acceptance sampler with correct recovered token ids (#8562) --- .../test_typical_acceptance_sampler.py | 17 ++++++----- .../layers/typical_acceptance_sampler.py | 28 ++++++------------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 1eba98cefd04a..4ddad66dce1fb 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str): # Next only keep the first 2 draft tokens same as the zero temperature # tokens. For the remaining 3 choose some other tokens. In the # response we will expect the first 2 tokens to be the same as the - # draft tokens and the rest as -1 + # draft tokens and the recovered token and rest as -1 draft_token_ids_to_replace = get_draft_token_ids( batch_size, k, vocab_size, zero_temperature_token_ids) draft_token_ids = torch.cat( @@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str): assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) + assert torch.all( + output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2]) assert torch.all(output_token_ids[:, -3:] == -1) @@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str): @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_replacement_token_ids(seed: int, device: str): +def test_get_recovered_token_ids(seed: int, device: str): """ Test the TypicalAcceptanceSampler's method for generating replacement token IDs. - This test verifies that the `_replacement_token_ids` method of the + This test verifies that the `_get_recovered_token_ids` method of the TypicalAcceptanceSampler correctly identifies the token IDs to be used - as replacements based on the target probability distribution. + as recovered token IDs based on the target probability distribution. Specifically, it ensures that the method correctly identifies the tokens with the highest probability for each sequence in the batch. """ @@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str): typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - expected_replacement_tokens = -torch.ones( - (batch_size, k), dtype=torch.long) - expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :], - dim=1) + expected_replacement_tokens = torch.argmax(target_probs, dim=-1) actual_replacement_tokens = ( - typical_acceptance_sampler._replacement_token_ids(target_probs)) + typical_acceptance_sampler._get_recovered_token_ids(target_probs)) assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 8c03e46927752..584cf971d9c05 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -80,7 +80,7 @@ def forward( target_probs = target_with_bonus_probs[:, :-1] accepted = self._evaluate_accepted_tokens(target_probs, draft_token_ids) - recovered_token_ids = self._replacement_token_ids(target_probs) + recovered_token_ids = self._get_recovered_token_ids(target_probs) output_token_ids = self._create_output(accepted, recovered_token_ids, draft_token_ids, bonus_token_ids) @@ -148,16 +148,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): accepted_mask = candidates_prob > threshold return accepted_mask - def _replacement_token_ids(self, target_probs): + def _get_recovered_token_ids(self, target_probs): """ - Generate one replacement token ID for each sequence based on target - probabilities. The replacement token is used as the fallback option - if typical acceptance sampling does not accept any draft tokens for - that particular sequence. - - This method computes the token IDs to be replaced by selecting the - token with the highest probability for each sequence in the first - position. The rest of the output is filled with -1. + The recovered token ids will fill the first unmatched token + by the target token. Parameters ---------- @@ -168,13 +162,9 @@ def _replacement_token_ids(self, target_probs): Returns ------- torch.Tensor - A tensor of shape (batch_size, k) with the replacement - token IDs. Only the first column is set, and the rest of the - columns are filled with -1. + A tensor of shape (batch_size, k) with the recovered token + ids which are selected from target probs. """ - max_indices = torch.argmax(target_probs[:, 0, :], dim=1) - output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), - dtype=self.token_id_dtype, - device=target_probs.device) - output[:, 0] = max_indices - return output + max_indices = torch.argmax(target_probs, dim=-1) + + return max_indices