k-quants #1684

Merged (32 commits, Jun 5, 2023)
Conversation

@ikawrakow (Contributor) commented Jun 3, 2023

What

This PR adds a series of 2-6 bit quantization methods, along with quantization mixes, as proposed in #1240 and #1256. Scalar, AVX2, ARM_NEON, and CUDA implementations are provided.

Why

This is best explained with the following graph, which shows perplexity on the wikitext dataset as a function of model size:
[Figure: ppl_vs_size (wikitext perplexity vs. model size for the k-quant mixes)]

Note that the x-axis (model size in GiB) is logarithmic. The circles on the graph show the perplexity of the different quantization mixes added by this PR (see details below). The colors indicate the LLaMA variant used (7B in black, 13B in red, 30B in blue, 65B in magenta). The solid squares in the corresponding color represent (model size, perplexity) for the original fp16 model. The dashed lines are added for convenience, to allow a better judgement of how closely the quantized models approach the fp16 perplexity.

As we can see from this graph, generation performance as measured by perplexity is essentially a smooth function of quantized model size, and the quantization types added by this PR let the user pick the best performing quantized model given the limits of their compute resources (in terms of being able to fully load the model into memory, but also in terms of inference speed, which tends to depend on the model size). As a specific example, the 2-bit quantization of the 30B model fits on the 16 GB RTX 4080 GPU that I have available, while the other quantizations do not, resulting in a large difference in inference performance.

Perhaps worth noting is that the 6-bit quantized perplexity is within 0.1% of the original fp16 perplexity.

Another interesting observation is that the relative quantization error (as measured by perplexity) does not decrease with an increasing number of weights in the base model, as one might hypothesize based on the lower quantization error observed at 13B compared to 7B (see, e.g., this table on the main page). The 13B model is indeed somewhat more amenable to quantization, but the relative quantization error goes back to the 7B level for the 30B and 65B models. This is illustrated by the following graph, which gives an alternative view of the data in the graph above by showing the relative difference to the fp16 model in percent. Note that now the x-axis, being the ratio of the quantized size to the fp16 model size, is linear, while the y-axis (percent error) is logarithmic.

[Figure: ppl_vs_size_relative (relative perplexity error in % vs. ratio of quantized size to fp16 size)]

How (Details)

In the existing ggml quantization types we have "type-0" (Q4_0, Q5_0) and "type-1" (Q4_1, Q5_1). In "type-0", weights w are obtained from quants q using w = d * q, where d is the block scale. In "type-1", weights are given by w = d * q + m, where m is the block minimum. I use this to describe the quantizations being added by this PR.
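To make the two conventions concrete, here is a minimal dequantization sketch in C. The struct layouts, field names, and float scales are illustrative simplifications for this description, not the actual ggml block structs (which store scales as fp16 and, for the k-quants below, organize them per super-block):

```c
#include <stdint.h>

#define BLOCK_SIZE 32  /* illustrative block size (Q4_0/Q4_1 use 32 weights per block) */

/* "type-0": w = d * q */
typedef struct {
    float  d;                /* block scale */
    int8_t q[BLOCK_SIZE];    /* quantized values */
} block_type0;

/* "type-1": w = d * q + m */
typedef struct {
    float   d;               /* block scale */
    float   m;               /* block minimum */
    uint8_t q[BLOCK_SIZE];   /* quantized values (unsigned) */
} block_type1;

static void dequantize_type0(const block_type0 *b, float *w) {
    for (int i = 0; i < BLOCK_SIZE; ++i) w[i] = b->d * b->q[i];
}

static void dequantize_type1(const block_type1 *b, float *w) {
    for (int i = 0; i < BLOCK_SIZE; ++i) w[i] = b->d * b->q[i] + b->m;
}
```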

The following new quantization types are added to ggml:

  • GGML_TYPE_Q2_K - "type-1" 2-bit quantization in super-blocks containing 16 blocks, each block having 16 weights. Block scales and mins are quantized with 4 bits. This ends up effectively using 2.5625 bits per weight (bpw).
  • GGML_TYPE_Q3_K - "type-0" 3-bit quantization in super-blocks containing 16 blocks, each block having 16 weights. Scales are quantized with 6 bits. This ends up using 3.4375 bpw.
  • GGML_TYPE_Q4_K - "type-1" 4-bit quantization in super-blocks containing 8 blocks, each block having 32 weights. Scales and mins are quantized with 6 bits. This ends up using 4.5 bpw (see the arithmetic sketch after this list).
  • GGML_TYPE_Q5_K - "type-1" 5-bit quantization. Same super-block structure as GGML_TYPE_Q4_K, resulting in 5.5 bpw.
  • GGML_TYPE_Q6_K - "type-0" 6-bit quantization. Super-blocks with 16 blocks, each block having 16 weights. Scales are quantized with 8 bits. This ends up using 6.5625 bpw.
  • GGML_TYPE_Q8_K - "type-0" 8-bit quantization. Only used for quantizing intermediate results. The difference to the existing Q8_0 is that the block size is 256. All 2-6 bit dot products are implemented for this quantization type.
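
As a sanity check of the bpw figures above, here is a small back-of-the-envelope sketch in C. It assumes a 256-weight super-block, fp16 super-block scale(s), and the per-block scale/min widths listed above; it is my own illustration, not code from k_quants.h:

```c
#include <stdio.h>

/* Bits per weight for a 256-weight super-block:
 *   quant_bits : bits per quantized value
 *   scale_bits : bits per block for the scale (plus the min, for "type-1")
 *   n_blocks   : blocks per super-block
 *   super_bits : per-super-block data, e.g. one fp16 d (16) or fp16 d + dmin (32)
 */
static double bpw(int quant_bits, int scale_bits, int n_blocks, int super_bits) {
    const int n_weights = 256;
    return (n_weights * quant_bits + n_blocks * scale_bits + super_bits) / (double) n_weights;
}

int main(void) {
    printf("Q3_K: %.4f bpw\n", bpw(3,  6, 16, 16));   /* 3.4375 */
    printf("Q4_K: %.4f bpw\n", bpw(4, 12,  8, 32));   /* 4.5    */
    printf("Q5_K: %.4f bpw\n", bpw(5, 12,  8, 32));   /* 5.5    */
    printf("Q6_K: %.4f bpw\n", bpw(6,  8, 16, 16));   /* 6.5625 */
    return 0;
}
```

(Q2_K is left out of this check, since its scale/min packing is not captured by such a simple count.)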

This is exposed via llama.cpp quantization types that define various "quantization mixes" as follows (a sketch of how such a mix maps tensor names to quantization types follows the list):

  • LLAMA_FTYPE_MOSTLY_Q2_K - uses GGML_TYPE_Q4_K for the attention.wv and feed_forward.w2 tensors, and GGML_TYPE_Q2_K for the other tensors.
  • LLAMA_FTYPE_MOSTLY_Q3_K_S - uses GGML_TYPE_Q3_K for all tensors
  • LLAMA_FTYPE_MOSTLY_Q3_K_M - uses GGML_TYPE_Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else GGML_TYPE_Q3_K
  • LLAMA_FTYPE_MOSTLY_Q3_K_L - uses GGML_TYPE_Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else GGML_TYPE_Q3_K
  • LLAMA_FTYPE_MOSTLY_Q4_K_S - uses GGML_TYPE_Q4_K for all tensors
  • LLAMA_FTYPE_MOSTLY_Q4_K_M - uses GGML_TYPE_Q6_K for half of the attention.wv and feed_forward.w2 tensors, else GGML_TYPE_Q4_K
  • LLAMA_FTYPE_MOSTLY_Q5_K_S - uses GGML_TYPE_Q5_K for all tensors
  • LLAMA_FTYPE_MOSTLY_Q5_K_M - uses GGML_TYPE_Q6_K for half of the attention.wv and feed_forward.w2 tensors, else GGML_TYPE_Q5_K
  • LLAMA_FTYPE_MOSTLY_Q6_K - uses 6-bit quantization (GGML_TYPE_Q6_K) for all tensors
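
As referenced above, here is a sketch of how such a mix can map tensor names to quantization types. The function name and the string matching are illustrative only and do not reproduce the actual selection logic in llama.cpp; the Q3_K_M rules and the 6-bit output.weight rule are taken from the description above and below:

```c
#include <string.h>

/* Illustrative type tags; the real enum lives in ggml.h */
typedef enum { TYPE_Q3_K, TYPE_Q4_K, TYPE_Q6_K } quant_type;

/* Hypothetical per-tensor choice for LLAMA_FTYPE_MOSTLY_Q3_K_M:
 * attention.wv, attention.wo and feed_forward.w2 get Q4_K, everything else Q3_K,
 * and output.weight always gets 6-bit quantization. */
static quant_type choose_type_q3_k_m(const char *tensor_name) {
    if (strcmp(tensor_name, "output.weight") == 0)      return TYPE_Q6_K;
    if (strstr(tensor_name, "attention.wv")    != NULL) return TYPE_Q4_K;
    if (strstr(tensor_name, "attention.wo")    != NULL) return TYPE_Q4_K;
    if (strstr(tensor_name, "feed_forward.w2") != NULL) return TYPE_Q4_K;
    return TYPE_Q3_K;
}
```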

Not mentioned explicitly above is the fact that with this PR, all quantization variants use 6-bit quantization for the output.weight tensor. This lowers the perplexity of, e.g., Q4_0 by about 0.03 at 7B.

The code is quite lengthy, so it is added in separate files k_quants.h and k_quants.c instead of being added to ggml.c. I think it would be better to also factor out all other quantization types from ggml.c, but that is up to @ggerganov to decide.

Performance

The following table summarizes the performance results (perplexity, model size, run time for single token prediction). It is modeled after the corresponding table on the main page. In the timing rows, "ms/tok @ 4th / 8th" is milliseconds per token using 4 or 8 threads.

| Model | Measure | F16 | Q2_K | Q3_K_S | Q3_K_M | Q3_K_L | Q4_K_S | Q4_K_M | Q5_K_S | Q5_K_M | Q6_K |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 7B | perplexity | 5.9066 | 6.7764 | 6.4571 | 6.1503 | 6.0869 | 6.0215 | 5.9601 | 5.9419 | 5.9208 | 5.9110 |
| 7B | file size | 13.0G | 2.67G | 2.75G | 3.06G | 3.35G | 3.56G | 3.80G | 4.33G | 4.45G | 5.15G |
| 7B | ms/tok @ 4th, M2 Max | 116 | 56 | 81 | 69 | 76 | 50 | 55 | 70 | 71 | 75 |
| 7B | ms/tok @ 8th, M2 Max | 111 | 36 | 46 | 36 | 46 | 36 | 40 | 44 | 46 | 51 |
| 7B | ms/tok @ 4th, RTX 4080 | 60 | 15.5 | 18.6 | 17.0 | 17.7 | 15.5 | 16.0 | 16.7 | 16.9 | 18.3 |
| 7B | ms/tok @ 4th, Ryzen 7950X | 214 | 57 | 58 | 61 | 67 | 68 | 71 | 81 | 82 | 93 |
| 13B | perplexity | 5.2543 | 5.8545 | 5.6033 | 5.4498 | 5.4063 | 5.3404 | 5.3002 | 5.2785 | 5.2638 | 5.2568 |
| 13B | file size | 25.0G | 5.13G | 5.27G | 5.88G | 6.45G | 6.80G | 7.32G | 8.36G | 8.60G | 9.95G |
| 13B | ms/tok @ 4th, M2 Max | 216 | 103 | 156 | 148 | 144 | 95 | 102 | 132 | 134 | 142 |
| 13B | ms/tok @ 8th, M2 Max | 213 | 67 | 83 | 77 | 83 | 68 | 73 | 81 | 84 | 95 |
| 13B | ms/tok @ 4th, RTX 4080 | - | 25.3 | 29.2 | 29.3 | 25.5 | 26.2 | 26.2 | 28.6 | 28.9 | 30.0 |
| 13B | ms/tok @ 4th, Ryzen 7950X | 414 | 109 | 113 | 118 | 129 | 130 | 137 | 156 | 161 | 180 |

I realize the above table is not easy to read, so here is a shortened version showing a subset of the data:

| Model | Measure | F16 | Q2_K | Q3_K_M | Q4_K_S | Q5_K_S | Q6_K |
|---|---|---|---|---|---|---|---|
| 7B | perplexity | 5.9066 | 6.7764 | 6.1503 | 6.0215 | 5.9419 | 5.9110 |
| 7B | file size | 13.0G | 2.67G | 3.06G | 3.56G | 4.33G | 5.15G |
| 7B | ms/tok @ 4th, M2 Max | 116 | 56 | 69 | 50 | 70 | 75 |
| 7B | ms/tok @ 8th, M2 Max | 111 | 36 | 36 | 36 | 44 | 51 |
| 7B | ms/tok @ 4th, RTX 4080 | 60 | 15.5 | 17.0 | 15.5 | 16.7 | 18.3 |
| 7B | ms/tok @ 4th, Ryzen 7950X | 214 | 57 | 61 | 68 | 81 | 93 |
| 13B | perplexity | 5.2543 | 5.8545 | 5.4498 | 5.3404 | 5.2785 | 5.2568 |
| 13B | file size | 25.0G | 5.13G | 5.88G | 6.80G | 8.36G | 9.95G |
| 13B | ms/tok @ 4th, M2 Max | 216 | 103 | 148 | 95 | 132 | 142 |
| 13B | ms/tok @ 8th, M2 Max | 213 | 67 | 77 | 68 | 81 | 95 |
| 13B | ms/tok @ 4th, RTX 4080 | - | 25.3 | 29.3 | 26.2 | 28.6 | 30.0 |
| 13B | ms/tok @ 4th, Ryzen 7950X | 414 | 109 | 118 | 130 | 156 | 180 |

Commit notes

  • I think it is better to have quantization separate from ggml. For now I am just adding the k-quants there, but it would be better to also factor out the existing ggml quantizations.
  • CUDA is not ideal: ~50% slower than Q4_0 for single token prediction, about the same in batch mode (perplexity). CPU single token is ~55 ms (on the Ryzen 7950X).
  • It is now ~22.5 ms/token on my GPU, so ~30% slower than Q4_0.
  • Single token is now 20.5 ms/token (~20% slower than Q4_0). Perplexity is on par with Q4_0.
  • Performance is the same or perhaps very slightly better than Q4_0 on the CPU. On the GPU, single token prediction is ~10% better than Q4_0; batch mode (perplexity) is about the same.
  • Performance is ~40% lower compared to Q4_K on the CPU. This is to be expected, considering that we are memory bound on the CPU and the 6-bit model is ~44% larger than the 4-bit. On the GPU, single token prediction is ~6% lower than Q4_0; batch mode (perplexity) is even closer (but still slower).
  • Performance is ~20% lower compared to Q4_K on the CPU. This is to be expected, considering that we are memory bound on the CPU and the 5-bit model is ~22% larger than the 4-bit. On the GPU, performance is about the same as Q4_0 for both single token and batch prediction.
  • It is 22% slower than Q4_K, despite the smaller model size. On x86_64, where we are memory bound, the Q3_K model is quite a bit faster than Q4_K.
  • Token prediction is pretty good: about 15.5 ms on an RTX 4080. Perplexity is about the same as Q4_K.
  • About the same performance as Q4_K.
  • Single token prediction is now ~36 ms on M2 Max. The code is much simpler too.
  • Strangely enough, for the few prompts I tried with the 7B model the responses looked perfectly reasonable. I only realized something was not quite right when I tried the larger models and started getting nonsense back.
  • In any case, Q2_K single token evaluation times on an RTX 4080 in a Ryzen 7950X box, using CUDA with the model fully loaded on the GPU, are ~15.5 ms for 7B, ~25.4 ms for 13B, and ~55.8 ms for 30B. The max number of layers that fit in VRAM for the 65B is 32. With that, we get ~330 ms per token, which is not that much faster than just running on the CPU (~470 ms per token).
  • Q3_K is now running at ~18.5 ms/token on CUDA, so the gap to Q4_0 is only 10%. It seems the memory access pattern is more important for performance than the amount of computation the kernel does.
  • For perplexity, where we are less memory bound, time per pass drops by ~5%. Barely measurable difference for single token prediction.
  • We cannot possibly be expecting rmse < 0.002 for 2- and 3-bit quantization variants.

@ikawrakow ikawrakow requested a review from ggerganov June 3, 2023 15:24

@qwerr0 commented Jan 16, 2024

Amazing work.

@francoisfleuret

> This is best explained with the following graph, which shows perplexity on the wikitext dataset as a function of model size:

This graph is gorgeous. Any hope to have the raw numbers to replot it?

@mofosyne mofosyne added Tensor Encoding Scheme https://github.com/ggerganov/llama.cpp/wiki/Tensor-Encoding-Schemes Review Complexity : High Generally require indepth knowledge of LLMs or GPUs labels May 25, 2024
@Seedmanc

So, as a bottom line, is an L quant of a lower bit depth better or worse than an S quant of a higher one? Like Q4 L vs Q5 S. Can the file size always be used as a measure of a quant's quality, given a list of various bit depths and quant letters? Some simple guidelines for choosing which model version to download would be nice; this diversity of versions is too much to handle.

@kaizizzzzzz

Hi, I'm curious about the latency I measured on a 4090 GPU. I ran inference with LLaMA 7B: for bf16 the latency is 71.51 ms/token, which is close to the 60 ms/token in the table above. However, when I ran the Q2_K quantization, the latency is 35.25 ms/token, which is twice the result in the table above.

@Green-Sky (Collaborator)

@kaizizzzzzz those numbers are now over a year old, and there continue to be performance improvements.

@kaizizzzzzz

@Green-Sky Yes, improvements would be fine, but it is actually much worse. The latency from a year ago for Q2_K quantization is ~16 ms/token, but now it is 35 ms/token. So I'm curious about this.

@kaizizzzzzz

Are activations quantized to 8 bits?

@fedric95 commented Sep 8, 2024

> So the "2.5625 bits per weight" is not the final figure for Q2_K? I think the figure is around 3.3-3.4?

> The figure is 2.5625 bpw for the tensors quantized with 2 bits and 4.5 bpw for the tensors quantized with 4 bits. The specific quantization mix exposed as LLAMA_FTYPE_MOSTLY_Q2_K results in
>
> ((11008 + 2*4096)*4.5 + (2*11008 + 2*4096)*2.5625)/(3*11008 + 4*4096) = 3.315 bpw
>
> for the LLaMA 7B model (ignoring 1-d tensors, which always remain as f32 in llama.cpp).

I know this is an old comment and I am sorry about that, but it would be great to know how you ended up with this formula. I am trying to replicate it for other quantization schemes and other models as well.
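
For what it is worth, the quoted formula is a size-weighted average of per-tensor bpw values, so the same recipe carries over to other mixes and models once the element count and quantization type of each tensor are known. Here is a minimal sketch in C that reproduces the quoted 7B number; the grouping of the element counts simply mirrors the quoted formula and is not code from llama.cpp:

```c
#include <stdio.h>

/* Size-weighted average bpw over a set of tensor groups:
 * each group contributes n_elements * bpw bits. */
static double mix_bpw(const double *n_elements, const double *bpw, int n_groups) {
    double bits = 0.0, total = 0.0;
    for (int i = 0; i < n_groups; ++i) {
        bits  += n_elements[i] * bpw[i];
        total += n_elements[i];
    }
    return bits / total;
}

int main(void) {
    /* The two groups from the quoted 7B example, with the common factor of 4096
     * columns divided out: (11008 + 2*4096) at 4.5 bpw (Q4_K tensors) and
     * (2*11008 + 2*4096) at 2.5625 bpw (Q2_K tensors). */
    const double n[] = { 11008 + 2*4096, 2*11008 + 2*4096 };
    const double b[] = { 4.5, 2.5625 };
    printf("mixed bpw = %.3f\n", mix_bpw(n, b, 2));  /* prints 3.315 */
    return 0;
}
```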

@HAOYON-666

(myenv) (base) root@master-22:/data/zhy/tools/llama.cpp/build_cuda/bin# ./main -m /data/zhy/models/Llama-3-chinese-8b-instruct-v3-q4_k_m/llama-3-chinese-8b-instruct-v3-q4_k_m.gguf -n -1 -ngl 256 -t 12 --color -r "User:" --in-prefix " " -i -p 'User: 想聊聊什么话题呢?聊聊吗?'
bash: ./main: No such file or directory

@HAOYON-666

Please give me some suggestions.

@SamuelHafner

Thank you for your work. But do you know the original papers for these quantization types?

@Green-Sky (Collaborator)

@SamuelHafner there is no original paper for k-quants. They were cooked up by @ikawrakow.

@SamuelHafner

@Green-Sky So in general there is no paper for the different k-quants? Ollama also uses k-quants, so there is no paper for those either?

@ikawrakow (Contributor, Author)
There are no papers on k- or i-quants because I don't like writing papers. Combined with me enjoying the luxury of not needing another paper on my CV, and me not looking for a job or for investment, I see no reason to go and advertise on arXiv.

On the other hand, not having published a paper (or papers) allows other quantization researchers to ignore k- and i-quants, despite HF being littered with GGUFs containing k- and i-quantized models. Which makes their new shiny quantization methods look better than they actually are. Which is good for keeping the hype wave going.

So, in short, a win-win 😃
