[Documentation request]: Add documentation on lossless guarantees of speculative decoding in vLLM #7627
Comments
@cadedaniel is this known / expected?
Hi @jmkuebler. Thanks for trying it out and sharing your results. The expectation in vLLM is that the distribution mostly matches the target distribution. This can be split out into three different layers:
For (1), the sampling is lossless up to hardware numerics; see "Modified Rejection Sampling" in Accelerating Large Language Model Decoding with Speculative Sampling (a minimal sketch of this acceptance/resampling rule is included after the list of tests below). This means that floating-point error can accumulate differently and cause differences in the output distribution. This may be enough to cause the difference in your two completions, but I suspect the cause is actually in (3).

For (2), we have tests that validate that vLLM's algorithmic implementations of speculative decoding perform as expected. Specifically, we have two categories of tests:

a. A test that verifies samples from the rejection sampler converge to the target distribution. Introduced in #2336. This verifies that vLLM's rejection sampler provides a lossless guarantee. Code: `vllm/tests/samplers/test_rejection_sampler.py`, line 252 at 47b65a5.
b. Tests that verify greedy sampling with speculative decoding is equal to greedy sampling without speculative decoding. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, provides a lossless guarantee. Code: almost all of the tests in https://github.com/vllm-project/vllm/tree/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e verify this property; assertion implementation: `vllm/tests/spec_decode/e2e/conftest.py`, line 291 at b67ae00.
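A minimal sketch of that acceptance/resampling rule (toy distributions and helper names of my own choosing, not vLLM's implementation). In exact arithmetic the resulting samples follow the target distribution exactly; in floating point they only do so up to numerics:

```python
import numpy as np

rng = np.random.default_rng(0)

def accept_or_resample(target_probs, draft_probs, draft_token):
    """One step of modified rejection sampling: accept the drafted token with
    probability min(1, p/q); otherwise resample from norm(max(p - q, 0))."""
    p, q = target_probs[draft_token], draft_probs[draft_token]
    if rng.random() < min(1.0, p / q):
        return draft_token
    residual = np.maximum(target_probs - draft_probs, 0.0)
    return rng.choice(len(target_probs), p=residual / residual.sum())

# Toy target and draft distributions over a 4-token vocabulary.
target = np.array([0.5, 0.2, 0.2, 0.1])
draft = np.array([0.25, 0.25, 0.25, 0.25])

samples = [
    accept_or_resample(target, draft, rng.choice(4, p=draft))
    for _ in range(100_000)
]
print(np.bincount(samples, minlength=4) / len(samples))  # ~[0.5, 0.2, 0.2, 0.1]
```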
With both (a) and (b), we have reasonable confidence that vLLM's implementation of the algorithmic components of speculative decoding is lossless when temp=0. However, I've seen cases where changing the batch size can introduce small changes to the sampled tokens, which brings us to (3). (This is also noted in `vllm/tests/spec_decode/e2e/test_multistep_correctness.py`, lines 28 to 34 at 47b65a5.)
For (3), vLLM does not currently provide stable logprobs for the same prompt; specifically, they may change slightly as the batch size changes. I am not sure whether this is due to a bug in the implementation or simply the numeric instability of torch operations like torch.bmm, as explained in PyTorch's "Numerical accuracy" notes. In summary, I believe the behavior you observe is expected, as vLLM does not have stable logprobs across different batch sizes. It could also be a bug in the algorithmic implementations, which you can validate with the tests listed above. P.S.
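To make the numerics point concrete, here is a toy PyTorch illustration (unrelated to vLLM's code, shapes chosen arbitrarily) of how the same hidden state can yield slightly different logits depending on the batch shape it is computed in:

```python
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

hidden = torch.randn(8, 4096, dtype=dtype, device=device)      # a "batch" of hidden states
weight = torch.randn(4096, 32000, dtype=dtype, device=device)  # a toy LM head

logits_batched = hidden @ weight      # row 0 computed as part of a batch of 8
logits_single = hidden[:1] @ weight   # row 0 computed as a batch of 1

# Mathematically identical, but different batch shapes can dispatch to different
# kernels with different reduction orders, so the low-order bits may differ;
# that is enough to flip an argmax when two logits are nearly tied.
print((logits_batched[0] - logits_single[0]).abs().max())
```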
@jmkuebler this looks like typical precision-related variance, i.e. the numeric instability that @cadedaniel referred to. When the same requests are batched differently, whether because other concurrent requests are being processed or because of the expanded batches used in spec decoding, the results aren't guaranteed to be identical, even though they should be from a theoretical/mathematical point of view. These precision differences manifest as slightly different logit/logprob values at each step, which at some point can result in a different token taking the top spot, as you observed. Once a different token is chosen, further divergence is inevitable. Try setting the dtype to float32 for both of these and the variance should be significantly reduced (but it will of course need double the memory). If you're using bfloat16, you can switch to float16, which should be much more stable (but still less stable than float32). Using a seed shouldn't make any difference to greedy/temp=0 generation, but the above applies similarly when using a seed with temp>0.
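For example, the dtype can be set directly on the `LLM` constructor (a sketch; the model name is just the one used elsewhere in this thread):

```python
from vllm import LLM

# float32 roughly doubles weight memory but minimizes precision-related variance.
llm = LLM(model="facebook/opt-6.7b", dtype="float32")

# float16 is usually more numerically stable than bfloat16 at the same memory cost.
# llm = LLM(model="facebook/opt-6.7b", dtype="float16")
```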
@cadedaniel Would it be possible to formalize what you wrote here into the docs, as opposed to just linking to YouTube videos? This was really helpful for my understanding, especially since I didn't realize that speculative decoding comes with only mild guarantees that the output distribution matches with and without spec decoding.
Sure, let me convert this issue into a documentation issue. I will try to get to it in the next few weeks. Also cc @sroy745.
@cadedaniel After reading your reply I agree that this is solely a documentation issue. It's not a practical problem; I was just surprised by the observation because I did not have the numerical aspects in mind. I think your explanation hits the root cause. I read out the logprobs for my prompt (this is only possible for the target model alone, as it is not possible for the multi-step worker used for SpecDec). Here are the top-2 logprobs for the token at which the deviation happens:
I have checked: they are actually exactly the same floating-point numbers! Therefore, greedy decoding would actually be ambiguous. Curiously, for this particular token, the first entry in the logprob dict is
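A tiny toy illustration (made-up values, not taken from the actual run) of why exactly tied logprobs make greedy decoding ambiguous:

```python
import torch

logits = torch.tensor([2.0, 3.5, 3.5, 1.0])  # tokens 1 and 2 have exactly equal logits

# Which of the tied tokens "wins" is an implementation detail: it depends on the
# order the values happen to be in and on which kernel produced them.
print(torch.argmax(logits))          # index 1 (the first maximal entry)
print(torch.argmax(logits.flip(0)))  # index 1 of the flipped tensor, i.e. original token 2
```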
@cadedaniel there might actually be a slight inconsistency: without speculative decoding the rank is computed with a strict `>` comparison (`vllm/vllm/model_executor/layers/sampler.py`, line 801 at baaedfd), whereas with SD it seems to be done via a `>=` (line 46 at baaedfd). Could that be the reason, and something we would want to align? My take is that there is no "right" or "wrong" here, but consistency could be desirable. Edit: we fixed this with #7899
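A toy reconstruction of that difference (not the actual vLLM code): the two conventions only disagree when logprobs are exactly tied, which is precisely the situation observed above:

```python
import torch

logprobs = torch.tensor([-1.20, -1.20, -3.40])  # top two entries exactly tied
chosen = 1                                       # the token that was actually sampled

rank_strict = (logprobs > logprobs[chosen]).sum().item() + 1   # -> 1
rank_inclusive = (logprobs >= logprobs[chosen]).sum().item()   # -> 2

# Without an exact tie the two formulas agree; with a tie they differ.
print(rank_strict, rank_inclusive)
```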
That could be a bug, but it is not used in the lossless algorithm, only in what is returned to the user. Feel free to open an issue and/or PR! Nice find!
@cadedaniel I did some further investigation.

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    use_v2_block_manager=True,
)

# Prompt ending in "... the development", generating two greedy tokens.
sampling_params = SamplingParams(temperature=0, ignore_eos=True, max_tokens=2, top_k=1)
prompts = [
    'I am at KDD conference in Barcelona. After the conference tutorial today I will be presenting a paper on the topic of “The role of the brain in the development',
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

# Same prompt with " of" already appended, generating one greedy token.
sampling_params = SamplingParams(temperature=0, ignore_eos=True, max_tokens=1, top_k=1)
prompts = [
    'I am at KDD conference in Barcelona. After the conference tutorial today I will be presenting a paper on the topic of “The role of the brain in the development of',
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")
```

Gives:
That's all happening even without speculative decoding. Furthermore, running it in full precision removes this behavior (at least for this example).
Nice find. Yep, this is exactly what @njhill is talking about.
I've seen something similar with Medusa as well when using BF16: https://github.com/vllm-project/vllm/pull/4978/files#r1627603023. Might be related.
The root cause was identified to be numerics rather than a bug and documentation was added. Hence closing this issue. |
Your current environment
The output of `python collect_env.py`
🐛 Describe the bug
At temperature 0, I would expect a model with and without speculative decoding to produce exactly the same generation. At least, the theory suggests that, and I did not see a warning that the implementation would not.
But in the example below (which is close to the example in the documentation), the output actually differs (marked in bold where it starts to differ).
Target model alone:
generates
With SD
generates
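For reference, a sketch of the kind of comparison being made, modeled on the speculative decoding example in the docs; the draft model, `num_speculative_tokens`, and the exact flags are assumptions rather than a copy of the original repro:

```python
from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
params = SamplingParams(temperature=0, max_tokens=64)

# Target model alone (run the two configs in separate processes if GPU memory is tight).
target_only = LLM(model="facebook/opt-6.7b", use_v2_block_manager=True)
baseline = [o.outputs[0].text for o in target_only.generate(prompts, params)]

# Same target model with a small draft model proposing tokens.
with_sd = LLM(
    model="facebook/opt-6.7b",
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
)
speculative = [o.outputs[0].text for o in with_sd.generate(prompts, params)]

# At temperature 0 these match in exact arithmetic; in practice they can diverge
# once numerics flip a near-tied token.
print(baseline == speculative)
```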