From f205c09854853172a446c92aa81eb7199da324ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonas=20M=2E=20K=C3=BCbler?= <44084297+jmkuebler@users.noreply.github.com>
Date: Thu, 29 Aug 2024 07:18:13 +0200
Subject: [PATCH] [Bugfix] Unify rank computation across regular decoding and
 speculative decoding (#7899)

---
 tests/spec_decode/test_utils.py | 21 ++++++++++++++++++++-
 vllm/spec_decode/util.py        |  4 ++--
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py
index 06780d4b8cd01..195fce64822bd 100644
--- a/tests/spec_decode/test_utils.py
+++ b/tests/spec_decode/test_utils.py
@@ -4,10 +4,12 @@
 import torch
 
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
+from vllm.model_executor.layers.sampler import _get_ranks
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
 from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
-from vllm.spec_decode.util import split_batch_by_proposal_len
+from vllm.spec_decode.util import (get_sampled_token_logprobs,
+                                   split_batch_by_proposal_len)
 
 
 def test_get_all_seq_ids():
@@ -126,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method):
         return sampler
     else:
         raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
+
+
+def test_get_sampled_token_logprobs():
+    """Verify get_sampled_token_logprobs returns consistent rankings
+    with regular get_ranks when probabilities match exactly.
+    """
+    logprob_tensor = torch.tensor(
+        [[[-.1, -.1]] * 2])  # shape (num_steps, batch_size, vocab_size)
+    sampled_token_tensor = torch.tensor([[1,
+                                          0]])  # shape (num_steps, batch_size)
+    ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
+                                                    sampled_token_tensor)
+
+    ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
+                               sampled_token_tensor.reshape(-1))
+
+    assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
index d18ee47e23a5c..5d5f8767e5b6d 100644
--- a/vllm/spec_decode/util.py
+++ b/vllm/spec_decode/util.py
@@ -43,8 +43,8 @@ def get_sampled_token_logprobs(
                                        sampled_token_ids, ]
     expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
         -1, -1, vocab_size)
-    sampled_token_ids_ranks = (logprob_tensor >=
-                               expanded_selected_logprobs).sum(-1)
+    sampled_token_ids_ranks = (logprob_tensor >
+                               expanded_selected_logprobs).sum(-1).add_(1)
 
     return sampled_token_ids_ranks, selected_logprobs
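
For context on the behavioral change (a standalone sketch, not part of the
diff; the tensors and variable names below are illustrative only): the old
expression (logprob_tensor >= expanded_selected_logprobs).sum(-1) counts the
sampled token itself plus every tie, so tied logprobs inflate the reported
rank, whereas regular decoding's _get_ranks counts only strictly greater
entries and adds one. The patched expression follows that convention.

    import torch

    # One decode step, one sequence, vocab of size 2 with tied logprobs.
    logprobs = torch.tensor([[[-0.1, -0.1]]])  # (num_steps, batch_size, vocab_size)
    selected = logprobs[0, 0, 0]  # logprob of the (illustrative) sampled token id 0

    old_rank = (logprobs >= selected).sum(-1)     # tensor([[2]]): the tie is counted
    new_rank = (logprobs > selected).sum(-1) + 1  # tensor([[1]]): strictly greater, then +1
    print(old_rank.item(), new_rank.item())       # -> 2 1

Under the old definition, a token that ties for the best logprob could be
reported with rank 2 by the speculative-decoding path but rank 1 by regular
decoding; after the patch both paths report rank 1.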