
Add KV-Cache int8 quant support #10354

Open · wants to merge 19 commits into base: main

Conversation

YanyunDuanIEI

Add KV-Cache int8 quant support

Supports [layer_level] and [group_level] KV-Cache int8 quantization.

  • [layer_level] uses one common scale factor per layer.
  • [group_level] splits head_size into groups of group_size; within each group, the key/value elements share the same scaling factor (see the sketch below).
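
To make the two granularities concrete, here is a minimal sketch assuming symmetric int8 quantization and a [num_tokens, num_kv_heads, head_size] layout; the shapes and variable names are illustrative assumptions, not this PR's kernel code.

# Illustrative sketch only (not this PR's kernels): layer-level vs. group-level
# int8 scales for a KV tensor of shape [num_tokens, num_kv_heads, head_size].
import torch

num_tokens, num_kv_heads, head_size, group_size = 4, 8, 128, 64
kv = torch.randn(num_tokens, num_kv_heads, head_size)

# layer_level: one scalar scale shared by every element of the layer.
layer_scale = kv.abs().max() / 127.0
kv_int8_layer = torch.clamp(torch.round(kv / layer_scale), -128, 127).to(torch.int8)

# group_level: split head_size into head_size // group_size groups; every
# element inside a group (per kv head) shares one scale.
grouped = kv.reshape(num_tokens, num_kv_heads, head_size // group_size, group_size)
group_scale = grouped.abs().amax(dim=(0, 3), keepdim=True) / 127.0  # [1, heads, groups, 1]
kv_int8_group = torch.clamp(torch.round(grouped / group_scale), -128, 127).to(torch.int8)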

KV-Cache int8 quant

Get the scaling factor by calibration

Supports calibrating the KV-cache on a dataset:

  • [examples/int8/calibrate.py] runs calibration and saves the results to a .pth file.
  • [export_kv_params.py] exports the scaling factors to JSON (a sketch of this step follows the list).
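
As a rough illustration of the export step, the following sketch assumes the calibration .pth holds per-layer key/value scale tensors under hypothetical keys; the actual file layout is whatever calibrate.py and export_kv_params.py in this PR define.

# Hypothetical sketch of the export step; the key names and file layout are
# assumptions, not the format produced by this PR's scripts.
import json
import torch

stats = torch.load("kv_cache_calib.pth", map_location="cpu")  # assumed: {"k_scales": [...], "v_scales": [...]}
params = {
    "k_scales": [t.flatten().tolist() for t in stats["k_scales"]],  # one entry per layer
    "v_scales": [t.flatten().tolist() for t in stats["v_scales"]],
}
with open("kv_cache_scales_layer_level.json", "w") as f:
    json.dump(params, f, indent=2)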

Using KV-Cache int8

  • kv_cache_dtype="int8"
  • kv_quant_params_path=kv_quant_params_path
  • kv_quant_group=kv_quant_group (see the usage sketch below)
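
A minimal usage sketch, assuming these engine arguments are exposed on vllm.LLM as introduced in this PR; the model name and group size here are placeholders.

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",                             # placeholder model
    kv_cache_dtype="int8",                                        # enable int8 KV cache
    kv_quant_params_path="kv_cache_scales_quant_group128.json",   # scales exported by export_kv_params.py
    kv_quant_group=128,                                           # 0 = layer_level, >0 = group_level
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)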

Signed-off-by: Yanyun Duan <duanyanyun@inspur.com>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀


mergify bot commented Nov 17, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @YanyunDuanIEI.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 17, 2024
@YanyunDuanIEI
Author

Would it be viable to hasten the review process?


mergify bot commented Dec 17, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @YanyunDuanIEI.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 17, 2024
@mergify mergify bot removed the needs-rebase label Dec 20, 2024
@xiabo123

@YanyunDuanIEI Hello, is the KV-cache int8 quantization in this PR online or offline? The code requires a calibration set such as 'c4'. Could you document the model-conversion steps, as was done in PR #1507?

@YanyunDuanIEI
Author

YanyunDuanIEI commented Jan 20, 2025

> @YanyunDuanIEI Hello, is the KV-cache int8 quantization in this PR online or offline? The code requires a calibration set such as 'c4'. Could you document the model-conversion steps, as was done in PR #1507?

It is offline; the demo is located in the examples/int8/ directory. There is an execution demo at examples/int8/run_calibrate.sh.

@xiabo123

> > @YanyunDuanIEI Hello, is the KV-cache int8 quantization in this PR online or offline? The code requires a calibration set such as 'c4'. Could you document the model-conversion steps, as was done in PR #1507?
>
> It is offline; the demo is located in the examples/int8/ directory. There is an execution demo at examples/int8/run_calibrate.sh.

Thank you for your answer.

@xiabo123

@YanyunDuanIEI Hello, may I ask if there is a download path for the calibration set files "ceval_val_cmcc.jsonl" and "mapping.json" for "ceval_val_cmcc" and "ceval"?

@mergify mergify bot added the documentation, ci/build, and frontend labels Jan 21, 2025
Collaborator

Hi, if you are still interested in getting this in, please fix the merge conflict, thank you!

@YanyunDuanIEI
Author

> @YanyunDuanIEI Hello, may I ask if there is a download path for the calibration set files "ceval_val_cmcc.jsonl" and "mapping.json" for "ceval_val_cmcc" and "ceval"?

Most of the datasets are in LLaMA-Factory, under LLaMA-Factory/evaluation/.

@xiabo123

@YanyunDuanIEI This doesn't seem to support models from the Qwen2 series. Does it?

Member


This examples directory has a lot of lines, especially due to the scales in the work_dir. If you want to keep this example, please try to:

  • Rename the dir to int8_kv_cache
  • Write a README describing how to use it
  • Clean up/consolidate these scripts if possible
  • Possibly remove the work_dir? I think it is reasonable to keep one set of scales as a demonstration, but I don't see a reason to keep so many.

I think once this support lands, we can easily update llm-compressor with examples to produce calibrated int8 KV-cache scales, similar to what we have for FP8 now: https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_kv_cache

float k_scale = 0;
float v_scale = 0;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8Group128) {
  int64_t tgt_kvs_idx = floor((kv_head_idx * HEAD_SIZE) / quant_group);
Member


I think there is no need to keep quant_group as an argument since we have the KV cache data type as a template parameter. We know that for kInt8Group128 the quant_group will be 128, so I think we can remove this parameter completely.

Comment on lines +18 to +21
// printf("\n dequant scale= %f, zero_point= %f \n", scale, zero_point);
// if(abs(res+1.268555)<=0.01)
// printf("\nI am here int8_to_float, x = %d, a= %d, res=%f, scale=%f, zero_point=%f \n",
// x, a, res, scale, zero_point);
Member


Leftover cruft

Comment on lines +28 to +31
// printf("\n quant scale= %f \n", scale);
// if(abs(x+1.268555)<=0.00001)
// printf("\nI am here float_to_int8, x = %f, fx= %d, res=%d, scale=%f, zero_point=%f, (x-zero_point) / scale)=%f \n",
// x, fx, res, scale, zero_point, (x-zero_point) / scale);
Member


ditto

Comment on lines +35 to +39
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion_int8(const Tin& x,
const float scale, const float zero_point) {
return x;
}
Member


This does not seem right; what is the purpose of this definition?

Comment on lines +113 to +114
k_scales: torch.Tensor,
v_scales: torch.Tensor,
Member


nit: we tend to just use scale rather than scales, even in the case of using tensors; see these kernels as an example:

def apply_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_zero_point: Optional[torch.Tensor] = None,
    azp_adj: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
):

Comment on lines 871 to 875
k_scale=k_scale,
v_scale=v_scale,
quant_group,
k_scales,
v_scales,
Member


Please use named assignment of args here

Comment on lines 882 to 887
k_scale=k_scale,
v_scale=v_scale,
quant_group,
k_scales,
v_scales,
Member


Please use named assignment of args here

Comment on lines +72 to +86
k_scales_lists = v_scales_lists = [1.0]
# k_scales_lists = [0.16]
# v_scales_lists = [0.005]
self._k_scales = torch.Tensor(k_scales_lists).type(torch.float32).to("cuda")
self._v_scales = torch.Tensor(v_scales_lists).type(torch.float32).to("cuda")
self._quant_group = cache_config.kv_quant_group
if cache_config.cache_dtype.startswith("int8"):
    if cache_config.kv_quant_params_path is not None:
        k_scales_lists = cache_config.kv_quant_params[0].pop(0)
        v_scales_lists = cache_config.kv_quant_params[1].pop(0)
        self._k_scales = torch.Tensor(k_scales_lists).type(torch.float32).to("cuda")
        self._v_scales = torch.Tensor(v_scales_lists).type(torch.float32).to("cuda")
    if self._quant_group !=0:
        self._k_scales = self._k_scales.reshape((-1, num_kv_heads, head_size//self._quant_group))
        self._v_scales = self._v_scales.reshape((-1, num_kv_heads, head_size//self._quant_group))
Member


Suggested change
- k_scales_lists = v_scales_lists = [1.0]
- # k_scales_lists = [0.16]
- # v_scales_lists = [0.005]
- self._k_scales = torch.Tensor(k_scales_lists).type(torch.float32).to("cuda")
- self._v_scales = torch.Tensor(v_scales_lists).type(torch.float32).to("cuda")
- self._quant_group = cache_config.kv_quant_group
- if cache_config.cache_dtype.startswith("int8"):
-     if cache_config.kv_quant_params_path is not None:
-         k_scales_lists = cache_config.kv_quant_params[0].pop(0)
-         v_scales_lists = cache_config.kv_quant_params[1].pop(0)
-         self._k_scales = torch.Tensor(k_scales_lists).type(torch.float32).to("cuda")
-         self._v_scales = torch.Tensor(v_scales_lists).type(torch.float32).to("cuda")
-     if self._quant_group !=0:
-         self._k_scales = self._k_scales.reshape((-1, num_kv_heads, head_size//self._quant_group))
-         self._v_scales = self._v_scales.reshape((-1, num_kv_heads, head_size//self._quant_group))
+ default_scale = [1.0]
+ self._k_scales = torch.Tensor(default_scale).type(torch.float32).to("cuda")
+ self._v_scales = torch.Tensor(default_scale).type(torch.float32).to("cuda")
+ self._quant_group = cache_config.kv_quant_group
+ if cache_config.cache_dtype.startswith("int8"):
+     if cache_config.kv_quant_params_path is not None:
+         k_scales_lists = cache_config.kv_quant_params[0].pop(0)
+         v_scales_lists = cache_config.kv_quant_params[1].pop(0)
+         self._k_scales = torch.Tensor(default_scale).type(torch.float32).to("cuda")
+         self._v_scales = torch.Tensor(default_scale).type(torch.float32).to("cuda")
+     if self._quant_group !=0:
+         self._k_scales = self._k_scales.reshape((-1, num_kv_heads, head_size//self._quant_group))
+         self._v_scales = self._v_scales.reshape((-1, num_kv_heads, head_size//self._quant_group))

# v_scales_lists = [0.005]
self._k_scales = torch.Tensor(k_scales_lists).type(torch.float32).to("cuda")
self._v_scales = torch.Tensor(v_scales_lists).type(torch.float32).to("cuda")
self._quant_group = cache_config.kv_quant_group
Member


We can deduce kv_quant_group from the cache_config.cache_dtype, as mentioned in the kernels
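
A sketch of that suggestion (not code from this PR); the cache_dtype string values and the helper name here are hypothetical.

def deduce_kv_quant_group(cache_dtype: str) -> int:
    """Return 0 for layer-level int8, or the group size for group-level int8."""
    if cache_dtype == "int8":
        return 0                                             # layer_level: one scale per layer
    if cache_dtype.startswith("int8_group"):
        return int(cache_dtype.removeprefix("int8_group"))   # e.g. "int8_group128" -> 128
    raise ValueError(f"not an int8 KV cache dtype: {cache_dtype!r}")

assert deduce_kv_quant_group("int8_group128") == 128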


mergify bot commented Jan 28, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @YanyunDuanIEI.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 28, 2025
@mergify mergify bot added the v1 label Feb 5, 2025
@xiabo123

xiabo123 commented Feb 8, 2025

@YanyunDuanIEI May I ask: on ROCm, when performing accuracy verification, the variable 'quantization_param_path' needs to specify a file path. Is it the same as the variable 'kv_quant_params_path'? Or can the generated JSON files 'kv_cache_scales_layer_level.json' and 'kv_cache_scales_quant_group128.json' be specified separately?

Labels: ci/build, documentation, frontend, needs-rebase, v1