Support int8 KVCache Quant in Vllm #1507
Conversation
Hi, thanks for the wonderful PR. I have implemented FP8 (e5m2 / e4m3) KV cache quantization based on this PR at this branch; it does not require any calibration (as is needed for INT8).
Awesome work!
FP8 e5m2 has the same dynamic range as fp16, so I think it should be safe to cast, while FP8 e4m3 has more precision but may cost more cycles than e5m2 when casting from fp16. Meanwhile, the casting implementation seems to be done in a similar way to the CUDA native one. This idea originates from two parts: 1) recently I tried fp16 & fp32 & bf16 mixed-precision training on V100 (fp32 <--> bf16 is roughly similar to fp16 <--> fp8 at a high level); more analysis and details are posted here. 2) I also bumped into this idea somewhere (a blog, but no implementation was provided; I couldn't find the exact reference now).
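To make the "same dynamic range" point concrete, here is a small illustrative sketch (my own, not code from the linked branch): E5M2 keeps FP16's sign and 5 exponent bits and drops 8 of its 10 mantissa bits, so a round-toward-zero cast is just a truncation of the low byte of the FP16 bit pattern. Real casts typically round to nearest even instead (e.g. via CUDA's fp8 intrinsics or torch.float8_e5m2).

```python
import numpy as np

def fp16_to_e5m2_truncate(x: np.ndarray) -> np.ndarray:
    # Keep sign + 5 exponent bits + top 2 mantissa bits (round-toward-zero).
    bits = x.astype(np.float16).view(np.uint16)
    return (bits >> 8).astype(np.uint8)

def e5m2_to_fp16(q: np.ndarray) -> np.ndarray:
    # Zero-pad the dropped mantissa bits to recover an fp16 value.
    return (q.astype(np.uint16) << 8).view(np.float16)

x = np.array([0.1234, -3.5, 1024.0, 6.1e-5], dtype=np.float16)
print(e5m2_to_fp16(fp16_to_e5m2_truncate(x)))  # same range as fp16, only ~2 mantissa bits left
```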
Thanks for sharing the experiments and the benchmarks. How should I read the plot (e.g., what are the x-axis and y-axis)?
@ZiyueHuang I dump the KV cache of each transformer layer and take the mean over all transformer layers. The prompt length is 14 and the generated length is 65, so the total sequence length is 14 + 65 = 79. The hidden dim is 6144 and head_num = 48, so kv_dim = 6144 / 48 = 128, and each layer's key/value cache has shape [b, seq_len, 128]. The x-axis is the kv dim ([0, 63] is the key cache and [64, 127] is the value cache) and the y-axis is the sequence position. Note that the KV cache data distribution is model-dependent.
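For readers trying to reproduce such a plot, here is a rough sketch of how the heatmap data could be assembled (names and shapes are hypothetical, based on the description above, not the commenter's actual script):

```python
import torch

# kv_dumps: hypothetical list with one tensor per transformer layer,
# each of shape [b, seq_len, 128]; dims [0, 63] are the key cache and
# [64, 127] are the value cache, as described above.
def kv_heatmap(kv_dumps):
    stacked = torch.stack(kv_dumps)      # [num_layers, b, seq_len, 128]
    return stacked.mean(dim=(0, 1))      # [seq_len, 128]: y = sequence position, x = kv dim
```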
attn_dtype tgt_value = __ldg(&value[src_value_idx]);
value_cache[tgt_value_idx] = quant(tgt_value, v_scale, v_zp);
}
}
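For context, the quant(...) call in the snippet above presumably applies an affine int8 mapping with a scale and zero point; below is a minimal sketch of that mapping (the actual kernel may differ in rounding and saturation details):

```python
import torch

def quant_int8(x: torch.Tensor, scale: float, zero_point: float) -> torch.Tensor:
    # q = round(x / scale + zp), clamped to the int8 range.
    return torch.clamp(torch.round(x / scale + zero_point), -128, 127).to(torch.int8)

def dequant_int8(q: torch.Tensor, scale: float, zero_point: float) -> torch.Tensor:
    # x ~= (q - zp) * scale, applied when cached keys/values are read back in attention.
    return (q.float() - zero_point) * scale
```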
There is a lot of redundant code in reshape_and_cache_quantized_kernel compared to reshape_and_cache_kernel. Would it be better to merge them into one function?
We have merged the functions and eliminated the redundant code.
csrc/cache_kernels.cu
v_scale,
v_zp);
});
}
Would it be better to merge reshape_and_cache_quantized into reshape_and_cache?
We have merged reshape_and_cache_quantized into reshape_and_cache.
Great work! Several comments left.
Thank you for taking the time to review our code. We will take your advice and respond as soon as possible.
Thanks for your great work! I met compile errors and left some comments here, but it seems more modification is needed; I couldn't run a kv_int8 model with this PR (and turned to #1112 in the end).
Exception in callback functools.partial(<function _raise_exception_on_finish at 0x7fe75458d120>, request_tracker=<vllm.engine.async_llm_engine.RequestTracker object at 0x7fe6c149b820>)
handle: <Handle functools.partial(<function _raise_exception_on_finish at 0x7fe75458d120>, request_tracker=<vllm.engine.async_llm_engine.RequestTracker object at 0x7fe6c149b820>)>
Traceback (most recent call last):
File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 28, in _raise_exception_on_finish
task.result()
File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 351, in run_engine_loop
has_requests_in_progress = await self.engine_step()
File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 330, in engine_step
request_outputs = await self.engine.step_async()
File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 191, in step_async
output = await self._run_workers_async(
File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 216, in _run_workers_async
output = executor(*args, **kwargs)
File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/workdir/vllm_smoothquant/vllm/vllm/worker/worker.py", line 369, in execute_model
output = self.model(
File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/models/llama.py", line 310, in forward
hidden_states = self.model(input_ids, positions, kv_caches,
File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/models/llama.py", line 268, in forward
hidden_states = layer(
File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/models/llama.py", line 211, in forward
hidden_states = self.self_attn(
File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/models/llama.py", line 159, in forward
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/layers/attention.py", line 405, in forward
return super().forward(
File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/layers/attention.py", line 296, in forward
cache_ops.reshape_and_cache_quantized(
RuntimeError: expected scalar type Int but found Long

My application scenario is long context (~1000 tokens) with short outputs (~15 tokens), so KV cache quantization may not help me; I didn't see any improvement in either the prefill or the decode phase. Any ideas about this? Hoping to use W8A16 ASAP, thanks.

EDIT: FYI, the prefill and decode latency of the KV_INT8 model is close to the FP16 model at small batch sizes. In my test (1x A40), KV_INT8 got ~40% throughput improvement, W8A8 got ~45%, and W8A8 + KV_INT8 got ~100%... Amazing, can't believe this.
Overall the PR looks good to me now. @zhuohan123 could you please take time to review?
Hi! @zhuohan123 @WoosukKwon We think this PR is ready. Could you please take some time to review it? Much appreciated!
@AniZpZ Could you please resolve the conflicts? We are planning to review it again and merge it if all is OK.
Sure, I will resolve the conflicts.
@zhuohan123 @zhaoyang-star Hi! The conflicts have been resolved now.
LGTM
@zhuohan123 this was included in the release tracker for 4.0.0, but ended up not being merged in time. Should it be added to the new release tracker?
Hi, will you rebase your code on vllm 0.5.0?
Quantization of the KV cache can lift throughput with minimal loss in model performance. We implement int8 KV cache quantization, which achieves a ~15% throughput improvement. This PR is part of #1112; we split that large PR into two independent parts for easier review.
The usage of int8 KV Cache quant is simple:
The loss in model performance is minor. The following data is our experiment result on the MMLU dataset:
You can find more details, such as how to generate the KV cache scales, in the original PR #1112.
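As a rough illustration of what generating KV cache scales usually involves (the exact calibration procedure is in #1112 and may differ): record min/max statistics of the key and value activations over some calibration prompts and turn them into an affine scale and zero point, e.g.:

```python
import torch

def affine_scale_zp(calib_kv: torch.Tensor, n_bits: int = 8):
    # Derive a per-tensor scale/zero-point from calibration min/max statistics.
    qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1   # -128, 127 for int8
    x_min, x_max = calib_kv.min().item(), calib_kv.max().item()
    scale = max((x_max - x_min) / (qmax - qmin), 1e-8)
    zero_point = round(qmin - x_min / scale)
    return scale, zero_point
```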
You can combine this method with W8A8 inference (#1508) for the best throughput.