[Documentation request]: Add documentation on lossless guarantees of speculative decoding in vLLM #7627

Closed
jmkuebler opened this issue Aug 17, 2024 · 12 comments
Labels
bug Something isn't working

Comments

@jmkuebler
Contributor

Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.30.2
Libc version: glibc-2.31

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1052-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A10G
GPU 1: NVIDIA A10G
GPU 2: NVIDIA A10G
GPU 3: NVIDIA A10G
GPU 4: NVIDIA A10G
GPU 5: NVIDIA A10G
GPU 6: NVIDIA A10G
GPU 7: NVIDIA A10G

Nvidia driver version: 535.104.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             192
On-line CPU(s) list:                0-191
Thread(s) per core:                 2
Core(s) per socket:                 48
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          AuthenticAMD
CPU family:                         23
Model:                              49
Model name:                         AMD EPYC 7R32
Stepping:                           0
CPU MHz:                            3276.100
BogoMIPS:                           5599.99
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          3 MiB
L1i cache:                          3 MiB
L2 cache:                           48 MiB
L3 cache:                           384 MiB
NUMA node0 CPU(s):                  0-47,96-143
NUMA node1 CPU(s):                  48-95,144-191
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] pyzmq==26.1.0
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.44.0
[pip3] triton==3.0.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] pyzmq                     26.1.0          py310h7d2b5bf_0    conda-forge
[conda] torch                     2.4.0                    pypi_0    pypi
[conda] torchvision               0.19.0                   pypi_0    pypi
[conda] transformers              4.44.0                   pypi_0    pypi
[conda] triton                    3.0.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.4@4db5176d9758b720b05460c50ace3c01026eb158
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PHB     PHB     PHB     PHB     PHB     PHB     PHB     0-191   0-1             N/A
GPU1    PHB      X      PHB     PHB     PHB     PHB     PHB     PHB     0-191   0-1             N/A
GPU2    PHB     PHB      X      PHB     PHB     PHB     PHB     PHB     0-191   0-1             N/A
GPU3    PHB     PHB     PHB      X      PHB     PHB     PHB     PHB     0-191   0-1             N/A
GPU4    PHB     PHB     PHB     PHB      X      PHB     PHB     PHB     0-191   0-1             N/A
GPU5    PHB     PHB     PHB     PHB     PHB      X      PHB     PHB     0-191   0-1             N/A
GPU6    PHB     PHB     PHB     PHB     PHB     PHB      X      PHB     0-191   0-1             N/A
GPU7    PHB     PHB     PHB     PHB     PHB     PHB     PHB      X      0-191   0-1             N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

At temperature 0, I would expect a model with and without speculative decoding to produce exactly the same generation. At least the theory suggests so, and I did not see any warning that the implementation would deviate from it.
But in the example below (which is close to the example in the documentation), the output actually differs (it starts to diverge after "...the development of", continuing with "the human body" in one case and "language" in the other).

Target model alone:

from vllm import LLM, SamplingParams
llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    use_v2_block_manager=True,
)


sampling_params = SamplingParams(temperature=0, ignore_eos=True, max_tokens=128)

prompts = [
    "I am at KDD conference in Barcelona. After the conference tutorial today I will",
]
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

generates

Generated text: ' be presenting a paper on the topic of “The role of the brain in the development of the human body”. The paper is based on my PhD thesis, which I defended in December.\n\nThe paper is about the role of the brain in the development of the human body. The brain is the most complex organ in the human body. It is the organ that controls the body’s functions. The brain is also the organ that is most difficult to study.\n\nThe brain is the organ that controls the body’s functions.\n\nThe brain is also the organ that is most difficult to study.'

With SD:

from vllm import LLM, SamplingParams
llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=4,
    use_v2_block_manager=True,
)

sampling_params = SamplingParams(temperature=0, ignore_eos=True, max_tokens=128)

prompts = [
    "I am at KDD conference in Barcelona. After the conference tutorial today I will",
]
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

generates

Generated text: ' be presenting a paper on the topic of “The role of the brain in the development of language”. The paper is based on my PhD thesis, which I defended in December.\n\nThe paper is about the role of the brain in the development of language. I will present the results of my research on the role of the brain in the development of language. I will present the results of my research on the role of the brain in the development of language.\n\nThe paper is about the role of the brain in the development of language. I will present the results of my research on the role of the brain in the'

@jmkuebler jmkuebler added the bug Something isn't working label Aug 17, 2024
@jmkuebler
Contributor Author

@cadedaniel is this known / expected?

@cadedaniel
Collaborator

Hi @jmkuebler. Thanks for trying it out and sharing your results.

The expectation in vLLM is that the distribution mostly matches the target distribution. This can be split out into three different layers:

  1. Theoretical losslessness -- what are the theoretical guarantees of speculative decoding
  2. vLLM algorithmic losslessness -- what are the guarantees of speculative decoding in vLLM given vLLM's fwd pass implementation
  3. vLLM logprob stability -- what are the guarantees provided by vLLM's fwd pass implementation

For (1), the sampling is lossless up to hardware numerics, see "Modified Rejection Sampling" from Accelerating Large Language Model Decoding with Speculative Sampling. This means that floating-point error can accumulate differently and cause differences in the output distribution. This may be enough to cause the difference in your two completions, but I suspect it's actually in (3).
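
Roughly, the rule looks like this (a standalone sketch, not vLLM's actual rejection sampler; shapes and names are illustrative, and the bonus-token step used when all k draft tokens are accepted is omitted):

import torch

def modified_rejection_sample(target_probs,   # [k, vocab] target probabilities p
                              draft_probs,    # [k, vocab] draft probabilities q
                              draft_tokens):  # [k] tokens proposed by the draft
    accepted = []
    for i, tok in enumerate(draft_tokens.tolist()):
        p, q = target_probs[i, tok], draft_probs[i, tok]
        # Accept the draft token with probability min(1, p/q).
        if torch.rand(()) < torch.clamp(p / q, max=1.0):
            accepted.append(tok)
            continue
        # On rejection, sample once from the residual distribution
        # max(0, p - q), renormalized, and stop speculating.
        residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
        accepted.append(torch.multinomial(residual / residual.sum(), 1).item())
        break
    return accepted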

For (2), we have tests that validate that vLLM's algorithmic implementations of speculative decoding perform as expected. Specifically, we have two categories of tests:

a. Tests that verify samples from the rejection sampler converge to the target distribution. Introduced in #2336. This verifies that vLLM's rejection sampler provides a lossless guarantee (a rough sketch of the idea follows after (b)). Code:

def test_rejection_sampling_approximates_target_distribution(

b. Tests that verify greedy sampling with speculative decoding is equal to greedy sampling without speculative decoding. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, provides a lossless guarantee. Almost all of the tests in this directory verify this property: https://github.com/vllm-project/vllm/tree/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e, assertion implementation:

def run_equality_correctness_test(baseline_llm_generator,
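
To give a sense of what (a) checks, here is a rough standalone sketch (not the actual vLLM test; the distributions and sample count are made up): push many samples through the accept/resample rule sketched above and verify that the empirical distribution approaches the target.

import torch

torch.manual_seed(0)
vocab_size = 8
target = torch.softmax(torch.randn(vocab_size), dim=-1)  # made-up target p
draft = torch.softmax(torch.randn(vocab_size), dim=-1)   # made-up draft q

counts = torch.zeros(vocab_size)
num_samples = 200_000
for _ in range(num_samples):
    tok = torch.multinomial(draft, 1).item()              # draft proposal
    if torch.rand(()) < torch.clamp(target[tok] / draft[tok], max=1.0):
        counts[tok] += 1                                   # accepted
    else:
        residual = torch.clamp(target - draft, min=0.0)    # rejected: resample
        counts[torch.multinomial(residual / residual.sum(), 1).item()] += 1

empirical = counts / num_samples
# The max deviation from the target distribution should shrink as
# num_samples grows; this convergence is what the test asserts.
print((empirical - target).abs().max().item())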

With both (a) and (b), we have reasonable confidence that vLLM's implementations of the algorithmic components of speculative decoding are lossless when temp=0. However, I've seen cases where changing the batch size can introduce small changes to the sampled tokens. This brings us to (3). (This is also noted here:

@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
determinism in on-device batched operations, a bug in vLLM's spec decode
implementation, or the "hardware numerics" limitations. Either way, rejection
sampling ensures the output distribution matches the target model, but it breaks
greedy-equality tests for those batch sizes/prompts.
)

For (3), vLLM does not currently provide stable logprobs for the same prompt; specifically, they may change slightly as the batch size changes. I am not sure if this is due to a bug in the implementation or simply the numeric instability of torch operations like torch.bmm, as explained in PyTorch's "Numerical accuracy" notes.
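
As a concrete (if contrived) illustration of this kind of instability, independent of vLLM (whether a nonzero difference actually shows up depends on the GPU, dtype, and kernel selection):

import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

x = torch.randn(32, 4096, device=device, dtype=dtype)    # a "batch" of rows
w = torch.randn(4096, 4096, device=device, dtype=dtype)  # a made-up weight

batched = x @ w     # row 0 computed as part of the full batch
single = x[:1] @ w  # the same row 0 computed on its own

# Any nonzero difference here is enough to flip an argmax when two logits
# are nearly (or exactly) tied.
print((batched[0] - single[0]).abs().max().item())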

In summary, I believe the behavior you observe is expected as vLLM does not have stable logprobs over different batch sizes. It could also be a bug in the algorithmic implementations, which you can validate with the tests listed above.

P.S.

  • One testing gap in vLLM's speculative decoding implementation is verifying losslessness when sampling parameters such as freq_penalty/topk/topp are used. We can add more tests here, but it doesn't seem related to the divergence you observe.
  • If you really need stable generation, I believe you can use per-request seeds ([Bugfix] Make spec. decode respect per-request seed. #6034); a minimal usage sketch follows below. But they have some impact on latency (not optimized).
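
A minimal usage sketch of per-request seeds (the parameter values are placeholders; whether spec decode respects the seed is exactly what #6034 addresses):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-6.7b", tensor_parallel_size=1)
# With temperature > 0, a fixed per-request seed makes sampling reproducible
# for that request (the values below are placeholders).
sampling_params = SamplingParams(temperature=0.8, max_tokens=64, seed=1234)
outputs = llm.generate(["I am at KDD conference in Barcelona."], sampling_params)
print(outputs[0].outputs[0].text)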

@njhill
Member

njhill commented Aug 19, 2024

@jmkuebler this looks like typical precision-related variance, i.e. the numeric instability that @cadedaniel referred to. When the same requests are batched differently, whether this is due to other concurrent requests being processed or the expanded batches used in spec decoding, the results aren't guaranteed to be identical even if they should be from a theoretical / math pov.

These precision differences manifest as slightly different logit/logprob values at each step, which at some point can result in a different token taking the top spot as you observed. Once a different token is chosen then further divergence is inevitable.

Try setting the dtype to float32 for both of these and the variance should be significantly reduced (though it will need double the memory, of course). If you're using bfloat16, you can switch to float16, which should be much more stable (but still less stable than float32).
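
For example (a minimal sketch of the constructor option; the memory cost is approximate):

from vllm import LLM

# Full-precision weights/activations: roughly doubles memory versus
# float16/bfloat16, but greatly reduces precision-related variance.
llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    dtype="float32",
)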

Using seed shouldn't make any difference to greedy/temp=0 generation, but the above applies similarly when using a seed with temp>0.

@w013nad

w013nad commented Aug 20, 2024

@cadedaniel Would it be possible to formalize what you wrote here into the docs, as opposed to just linking to YouTube videos? This was really helpful for my understanding, especially since I didn't realize that speculative decoding comes with only mild guarantees that the distribution matches with and without spec decoding.

@cadedaniel
Collaborator

Sure, let me convert this issue into a documentation issue. I will try to get to it in the next few weeks. Also cc @sroy745.

@cadedaniel cadedaniel changed the title from "[Bug]: Speculative sampling does not exactly maintain the distribution" to "[Documentation request]: Add documentation on lossless guarantees of speculative decoding in vLLM" on Aug 20, 2024
@jmkuebler
Contributor Author

jmkuebler commented Aug 21, 2024

@cadedaniel
Thank you so much for the fast and detailed answer!

After reading your reply, I agree that this is solely a documentation issue. It's not a practical problem; I was just surprised by the observation because I did not have the numerical aspects in mind.

I think your explanation hits the root cause. I read out the logprobs for my prompt (only possible for the target model alone, as it is not available for the multi-step worker used for SpecDec). Here are the top 2 logprobs for the token at which the deviation happens:

 {5: Logprob(logprob=-3.157238483428955, rank=2, decoded_token=' the'),
  2777: Logprob(logprob=-3.157238483428955, rank=1, decoded_token=' language')},

I have checked: they are actually exactly the same floating-point numbers! Therefore, greedy decoding is genuinely ambiguous here.

Curiously, for this particular token, the first entry in the logprob dict is rank=2, while usually rank=1 goes first. But that might not mean anything.

@jmkuebler
Contributor Author

jmkuebler commented Aug 21, 2024

@cadedaniel there might actually be a slight inconsistency:

Without speculative decoding, the rank is computed with a strict >:

result = (x > vals[:, None])

whereas with SD it seems to be done via a >=:

sampled_token_ids_ranks = (logprob_tensor >=

Could that be the reason, and something we would want to align? My take is that there is no "right" or "wrong" here, but consistency would be desirable.
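
A tiny standalone illustration of how the two comparisons disagree exactly when logprobs tie (made-up values, not the vLLM code paths themselves):

import torch

logprobs = torch.tensor([-3.157238, -5.0, -3.157238, -7.0])  # tokens 0 and 2 tie
sampled = 2
val = logprobs[sampled]

rank_strict = (logprobs > val).sum().item() + 1   # strict '>' convention -> rank 1
rank_inclusive = (logprobs >= val).sum().item()   # '>=' convention -> rank 2
print(rank_strict, rank_inclusive)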


Edit: we fixed this with #7899

@cadedaniel
Collaborator

cadedaniel commented Aug 22, 2024

That could be a bug, but it is not used in the lossless algorithm, only in what is returned to the user. Feel free to open an issue and/or PR!

Nice find!

@jmkuebler
Contributor Author

jmkuebler commented Aug 26, 2024

@cadedaniel I did some further investigation.
I realize that this in fact has nothing to do with speculative decoding; it comes down to numerics.
I can track the deviation down to a single token, and apparently it makes a difference whether the last token is pushed through the model alone or as part of a larger batch (either via spec decode as above, or by adding it to the prefill):

from vllm import LLM, SamplingParams
llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    use_v2_block_manager=True,
)

# Case 1: the prompt ends at "...the development"; the model generates " of"
# itself, so the following position is computed in a single-token decode step.
sampling_params = SamplingParams(temperature=0, ignore_eos=True, max_tokens=2, top_k=1)

prompts = [
        'I am at KDD conference in Barcelona. After the conference tutorial today I will be presenting a paper on the topic of “The role of the brain in the development',
]
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

# Case 2: the same " of" is instead part of the prompt, so the following
# position is computed during the (larger) prefill forward pass.
sampling_params = SamplingParams(temperature=0, ignore_eos=True, max_tokens=1, top_k=1)

prompts = [
        'I am at KDD conference in Barcelona. After the conference tutorial today I will be presenting a paper on the topic of “The role of the brain in the development of',
]

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

Gives:

Generated text: ' of language'

Generated text: ' the'

This all happens without speculative decoding as well.

Furthermore, running in full precision removes this behavior (at least for this example).

@cadedaniel
Collaborator

Nice find. Yep this is exactly what @njhill is talking about.

@abhigoyal1997
Contributor

I've seen something similar with Medusa as well when using BF16: https://github.com/vllm-project/vllm/pull/4978/files#r1627603023. Might be related.

@jmkuebler
Contributor Author

The root cause was identified as numerics rather than a bug, and documentation was added. Hence closing this issue.
