Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Documentation][Spec Decode] Add documentation about lossless guarantees in Speculative Decoding in vLLM #7962

Merged
merged 51 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
1473f74
Merge branch 'vllm-project:main' into main
sroy745 Jun 12, 2024
4013e1a
Merge branch 'vllm-project:main' into main
sroy745 Jun 14, 2024
2dbdd78
Merge branch 'vllm-project:main' into main
sroy745 Jun 17, 2024
b3575e9
Merge branch 'vllm-project:main' into main
sroy745 Jun 20, 2024
94b0d43
Merge branch 'vllm-project:main' into main
sroy745 Jun 24, 2024
fa8fedf
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
6ed96b4
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
b71c533
Merge branch 'vllm-project:main' into main
sroy745 Jun 28, 2024
57babef
Merge branch 'vllm-project:main' into main
sroy745 Jun 29, 2024
4b19bac
Merge branch 'vllm-project:main' into main
sroy745 Jul 1, 2024
eb7a1c4
Merge branch 'vllm-project:main' into main
sroy745 Jul 6, 2024
7e2c87e
Merge branch 'vllm-project:main' into main
sroy745 Jul 10, 2024
6212d5f
Merge branch 'vllm-project:main' into main
sroy745 Jul 15, 2024
5491438
Merge branch 'vllm-project:main' into main
sroy745 Jul 17, 2024
68e080a
Merge branch 'vllm-project:main' into main
sroy745 Jul 31, 2024
55e4332
Merge branch 'vllm-project:main' into main
sroy745 Aug 13, 2024
532eb48
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
7cea056
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
185e056
Merge branch 'vllm-project:main' into main
sroy745 Aug 24, 2024
e2be95f
Merge branch 'vllm-project:main' into main
sroy745 Aug 27, 2024
2ed5473
Merge branch 'vllm-project:main' into main
sroy745 Aug 28, 2024
085dea8
Add lossless guarantees of SD in vLLM
sroy745 Aug 28, 2024
322463d
Fix formatting
sroy745 Aug 28, 2024
41be9c2
Fix formatting
sroy745 Aug 28, 2024
c4e477e
Formatting and fixing links
sroy745 Aug 28, 2024
bea3399
Fix bold
sroy745 Aug 28, 2024
e76b9fb
Fixes
sroy745 Aug 28, 2024
311b242
Fix link
sroy745 Aug 28, 2024
beb5b48
Remove new line
sroy745 Aug 28, 2024
8328f6e
Remove new line
sroy745 Aug 28, 2024
7a97508
Remove new line
sroy745 Aug 28, 2024
c1e7773
Fix heading
sroy745 Aug 28, 2024
37e2cc5
small comment fix
sroy745 Aug 28, 2024
f6606d5
small comment fix
sroy745 Aug 28, 2024
efa4714
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
fb87d34
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
a4ce5b8
Address comments
sroy745 Aug 30, 2024
b6e58c9
Fix format
sroy745 Aug 30, 2024
7004b00
Fix format
sroy745 Aug 30, 2024
fbec1be
Fix format
sroy745 Aug 30, 2024
25190a0
Fix format
sroy745 Aug 30, 2024
db12986
Fix format
sroy745 Aug 30, 2024
0f745e3
Fixes
sroy745 Aug 30, 2024
5419e49
Merge branch 'vllm-project:main' into main
sroy745 Aug 31, 2024
c76a4e2
Merge remote-tracking branch 'origin/main' into vllm-spec-decode-doc
sroy745 Aug 31, 2024
77d42c5
Fix comment
sroy745 Sep 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/source/models/spec_decode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,46 @@ A variety of speculative models of this type are available on HF hub:
* `granite-7b-instruct-accelerator <https://huggingface.co/ibm-granite/granite-7b-instruct-accelerator>`_
* `granite-20b-code-instruct-accelerator <https://huggingface.co/ibm-granite/granite-20b-code-instruct-accelerator>`_

Lossless guarantees of Speculative Decoding
-------------------------------------------
In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of
speculative decoding, breaking down the guarantees into three key areas:

1. **Theoretical Losslessness**
- Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might
cause slight variations in output distributions, as discussed
in `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_

2. **Algorithmic Losslessness**
- vLLM’s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include:

- **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target
distribution. `View Test Code <https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252>`_

- **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
provides a lossless guarantee. Almost all of the tests in `this directory <https://github.com/vllm-project/vllm/tree/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e>`_
verify this property using `this assertion implementation <https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291>`_

3. **vLLM Logprob Stability**
- vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the
same request across runs. For more details, see the FAQ section
titled *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq.rst>`_.


**Conclusion**

While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding
can occur due to following factors:

- **Floating-Point Precision**: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution.

- **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially
due to non-deterministic behavior in batched operations or numerical instability.

**Mitigation Strategies**

For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq.rst>`_.

Resources for vLLM contributors
-------------------------------
Expand Down
19 changes: 19 additions & 0 deletions docs/source/serving/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,22 @@ A: Assuming that you're referring to using OpenAI compatible server to serve mul
Q: Which model to use for offline inference embedding?

A: If you want to use an embedding model, try: https://huggingface.co/intfloat/e5-mistral-7b-instruct. Instead models, such as Llama-3-8b, Mistral-7B-Instruct-v0.3, are generation models rather than an embedding model

----------------------------------------

Q: Can the output of a prompt vary across runs in vLLM?

A: Yes, it can. vLLM does not guarantee stable log probabilities (logprobs) for the output tokens. Variations in logprobs may occur due to
numerical instability in Torch operations or non-deterministic behavior in batched Torch operations when batching changes. For more details,
see the `Numerical Accuracy section <https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations>`_.

In vLLM, the same requests might be batched differently due to factors such as other concurrent requests,
changes in batch size, or batch expansion in speculative decoding. These batching variations, combined with numerical instability of Torch operations,
can lead to slightly different logit/logprob values at each step. Such differences can accumulate, potentially resulting in
different tokens being sampled. Once a different token is sampled, further divergence is likely.

**Mitigation Strategies**

- For improved stability and reduced variance, use `float32`. Note that this will require more memory.
- If using `bfloat16`, switching to `float16` can also help.
- Using request seeds can aid in achieving more stable generation for temperature > 0, but discrepancies due to precision differences may still occur.
Loading