-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Trellis quantization #113
Draft
ikawrakow
wants to merge
53
commits into
main
Choose a base branch
from
ik/try_trellis
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Trellis quantization #113
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Using 12 bits per 8 weights I get a better rmse than iq2_xxs. I still need to see how quantizing the group-of-8 scales will affect accuracy. By AVX2 SIMDifying the search for the best code, LLaMA-3.1-8B gets quantized in 130 seconds on the Ryzen-7950X CPU - sluggish but still acceptable.
rmse increases by just 3%, so this is beating iq2_xss in terms of rmse at the same 2.0625 bpw.
I now see that I was comparing apples to oranges: iq2_xxs was using a weight of sigma^2/4 + x^2, while the Trellis approach wasn't (weight = 1). Once I use the same weight, iq2_kt is actually slightly worse than iq2_xxs in terms of rmse, so does not look promising at this point. Also, once each group of 8 Trellis values no longer has a constant sum(q^2) that we can precompute, quantization becomes significantly slower (476 seconds for LLaMA-3.1-8B).
so we can run perplexity calcs. As already indicated by rmse, the 2-bit trellis approach is quite a bit worse than iq2_xxs.
With blocks of 32 and 16 bits per groups of 8 the brute force seach becomes prohibitive in terms of CPU time (30+ minutes for 8B LLaMA after SIMDifying with AVX2). The trick is to group the points in clusters, find the nearest cluster, and only search within the cluster.
Using blocks of 32 and 16 bits per group of 8 weights it beats iq2_xxs in terms of PPL by a significant margin. It is 0.0625 bpw larger, but even if we go to 15 bits per group od 8 (so 0.0625 bpw less than iq2_xxs), PPL is still lower.
Re-quantize after determining block scales (at the epxense of much longer quantization time).
Implemented as DMMV. Very slow - just 81 t/s for LLaMA-3.1-8B. Then again, Q2_K_S with forced to use DMMV only gets 112 t/s vs 145 t/s via MMVQ. My memory is that when the DMMV kernels were properly maintained/used, DMMV was about on par with MMVQ for k-quants on my GPU.
We arrive at 112 t/s.
We arrive at 139 t/s (no FA), and 149 t/s (FA). My RTX-4080 is ~20% slower than the RTX-6000 quoted in the QTIP repository, so with FA (which I'm sure they also used) we are at around ~180 t/s on their GPU, so almost matching their performance.
We arrive at 146 t/s (no FA), and 158 t/s (FA). This is measured for LLaMA-3.1-8B with output.weight left as f16.
3.125 bpw. So far does not look good on the PPL vs bpw plot.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is starting to be competitive/slightly better than other quants.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
Nearly 60% improvement of quantization speed by having the points nelonging to a cluster copied to contiguous memory during initialization, and then accessed sequantially while searching for the closest point. LLaMA-3.1-8B now gets quantized in ~150 seconds on the Ryzen-5975WX.
Same trick as last commit applied to iq2_kt. Here we get an even larger speedup: quantization time on the Ryzen-5975WX for LLaMA-3.1-8B drops to 195 seconds from 375 seconds!
We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.2406 PPL(LLaMA-2-7B, 4096) = 6.4179
We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642 PPL(LLaMA-2-7B, 4096) = 6.3920
We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642 PPL(LLaMA-2-7B, 4096) = 6.3920
We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.0297 PPL(LLaMA-2-7B, 4096) = 6.3913 Ah, quantization is faster too. About 20% faster.
We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 8.9627 PPL(LLaMA-2-7B, 4096) = 6.3825 Quantization is faster too: ~200 seconds for LLaMA-3.1-8B on Ryzen-5975WX.
15 bits per group of 4, plus 8 bit scales ifor blocks of 32. This gives a slightly better PPL than iq4_kss.
at the expense of much longer quantization time.
It was working for 4.125 bpw. But after changing to 4.0 bpw there is something wrong and I don't see the bug.
Go to groups of 8 for iq3_kt. 2 x 8 = 16 bits for the magnitude plus 1 bpw for the sign. It goves a visible improvement in the PPL vs bpw plot, but that comes at the expense of much longer quantization time (7.5 minutes for LLaMA-3.1-8B on the Ryzen-5975WX). I also notices that the 3INST generator is not actually generating a Gaussian distribution. But going to a better generator means readjusting all the hyper-parameters, so leaving it for later.
ikawrakow
force-pushed
the
ik/try_trellis
branch
from
November 21, 2024 09:54
ecac9d6
to
3a9926b
Compare
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The latest quantization hype is
QTIP
- paper, repository. They use a Trellis approach and report impressive results, so I decided to look into this more closely.This PR implements what they call "3INST" in their paper. Basically, if we have a seed
seed
, we generateN
quantized valuesq_i
viawhere
a, b, mask1
andmask2
are suitable constants. This generates values that are (nearly) normally distributed. One uses this to describe a group ofN
quants with a singleL
-bit seed (index). Apart from borrowing the "3INST" algorithm from the QTIP paper, the implementation here has noting else in common with QTIP - there are no Hadamard transforms, and no (tail-biting) Viterbi algorithm is utilized during quantization. Instead, in the usual i- and k-quants style, quants are organized in blocks and super-blocks with suitable block scales, and the search for the best seed during quantization is done via a clustering algorithm.The PR adds 3 new quantization types:
IQ2_KT
:L=16
bits for groups ofN=8
quants. Block size is 32 with a 4-bit block scale, plus a single float scale per tensor row (the 32 bits added by this scale can be safely neglected for typical tensor row sizes), so we end up using 2.125 btwIQ3_KT
:L=12
bits for groups ofN=4
quants. Block size is also 32 with a 4-bit block scale, so 3.125 bpwIQ4_KT
:L=15
bits for groups ofN=4
quants. Blocks of 32 with 8-bit block scales, so 4.0 bpw.Quantization accuracy
This figure shows quantization error
PPL(Q)/PPL(bf16)-1
for LLaMA-3.1-8B-Instruct (context length of 8192 tokens). The blue symbols are k-quants, the black symbols are i-quants, cyan symbols are iqk-quants (not available in mainlinellama.cpp
), and the orange symbols are the Trellis quants added by this PR. We do see a small but noticeable improvement compared to i- and iqk-quants, with about 0.2 fewer bpw required to achieve the same quantization error.How does this compare to the QTIP paper? Unfortunately they report results without fine tuning only for LLaMA-v2. The table shows a comparison between the 2-bit quantizations for LLaMA-v2-7B (the QTIP results are taken from Table 3 in their paper, context length is 4096 tokens)
Although there are small differences between the PPL computed by
llama.cpp
and by the tools used by the QTIP authors, the quantization error as defined above is basically independent of the specifics of the PPL calculation, so we see that the 2 bpw quantization implemented here slightly outperforms QTIP without fine tuning (at the expense of using 0.125 bpw more bits). Given this, and the above graph, my conclusion is that Trellis based quantization is a small improvement compared to i-,k-,iqk-quants, but nowhere near the hype observed around the Internet.Performance
The QTIP authors give TG speed for their 2 bpw variant on an RTX-6000 Ada GPU (see here) and a 7B LLaMA model. My GPU is RTX-4080 (so same generation as theirs, but lower specs). I did a quick attempt to get QTIP going in my environment to have apples-to-apples performance comparison, but it was not successful, so I will use the ratio between their
f16
performance on the RTX-6000 (55.9 t/s) to myfp16
performance on the RTX-4080 (46.2 t/s) to translate QTIP performance on the RTX-6000 (188 t/s) to estimated performance on the RTX-4080:In comparison, I get 194 t/s for
IQ2_KT
(with flash attention enabled, which I assume they also use). These results are with the output tensor left asf16
(which is what is done in QTIP).IQ2_XSS
achieves 208 t/s (output asf16
) or 216 t/s (output asQ5_K
), so QTIP performance is far behind the performance of a model of similar size using a more efficient quantization.Caveats
AVX2
support. The search for the optimum seed is extremely expensive (the QTIP authors say "prohibitive" forL >= 12
without their tail-biting search space reduction), so I had to SIMDify to not have to wait forever for a quantization to finish. This PR being mostly a POC for now, I did not want to spend the time implementing for other instruction sets (or even porting to run on a GPU).AVX2
, quantization is slow - depending on quantization type it takes between 2.5 and 4.5 minutes to quantize LLaMA-3.1-8B on a 32-core Ryzen-5975WX CPU.DMMV
mechanism inllama.cpp
. The algorithm outputs float values, so one needs to convert toint8_t
to use the usual quantized dot products. The cost of this conversion is likely to (more than) offset any advantage one might gain by using SIMDint8_t
dot products.