Trellis quantization #113

Draft · wants to merge 53 commits into main

Conversation

@ikawrakow (Owner) commented Nov 15, 2024

The latest quantization hype is QTIP - paper, repository. They use a Trellis approach and report impressive results, so I decided to look into this more closely.

This PR implements what they call "3INST" in their paper. Basically, given a seed `seed`, we generate N quantized values q_i via

```cpp
uint32_t u32;
float16_t * h = reinterpret_cast<float16_t *>(&u32);
for (int i = 0; i < N; ++i) {
    seed = a*seed + b;              // LCG step
    u32  = (mask1 & seed) ^ mask2;  // keep some bits of the state, flip others
    q_i  = h[0] + h[1];             // sum of the two fp16 halves of u32
}
```

where a, b, mask1 and mask2 are suitable constants. This generates values that are (nearly) normally distributed. One uses this to describe a group of N quants with a single L-bit seed (index). Apart from borrowing the "3INST" algorithm from the QTIP paper, the implementation here has nothing else in common with QTIP - there are no Hadamard transforms, and no (tail-biting) Viterbi algorithm is utilized during quantization. Instead, in the usual i- and k-quants style, quants are organized in blocks and super-blocks with suitable block scales, and the search for the best seed during quantization is done via a clustering algorithm.
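
For concreteness, here is a minimal, self-contained sketch of such a generator. The LCG constants and masks below are placeholders picked for illustration only (they are not the values used by QTIP or by this PR), and an explicit fp16 decoder is used instead of a native float16 type:

```cpp
// Sketch of a 3INST-style generator; a, b, mask1, mask2 are placeholder constants.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Decode an IEEE fp16 bit pattern into a float (subnormals flushed to zero,
// which is good enough for a sketch).
static float fp16_to_float(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    const uint32_t exp  = (h >> 10) & 0x1f;
    const uint32_t mant = h & 0x3ff;
    uint32_t bits;
    if (exp == 0)       bits = sign;                                      // zero / subnormal
    else if (exp == 31) bits = sign | 0x7f800000u | (mant << 13);         // inf / NaN
    else                bits = sign | ((exp + 112) << 23) | (mant << 13); // rebias 15 -> 127
    float f; std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// One seed -> N values: LCG step, bit twiddling, sum of the two fp16 halves.
static void trellis_gen(uint32_t seed, int N, float * out,
                        uint32_t a, uint32_t b, uint32_t mask1, uint32_t mask2) {
    for (int i = 0; i < N; ++i) {
        seed = a*seed + b;
        const uint32_t u32 = (mask1 & seed) ^ mask2;
        out[i] = fp16_to_float((uint16_t)(u32 & 0xffff)) + fp16_to_float((uint16_t)(u32 >> 16));
    }
}

int main() {
    // Placeholder constants: a classic LCG pair and two arbitrary masks.
    const uint32_t a = 0x343fdu, b = 0x269ec3u;
    const uint32_t mask1 = 0x8fff8fffu, mask2 = 0x3b603b60u;
    float q[8];
    trellis_gen(0xbeef, 8, q, a, b, mask1, mask2);
    for (int i = 0; i < 8; ++i) printf("q_%d = %g\n", i, q[i]);
    return 0;
}
```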

The PR adds 3 new quantization types:

  • IQ2_KT: L=16 bits for groups of N=8 quants. Block size is 32 with a 4-bit block scale, plus a single float scale per tensor row (the 32 bits added by this scale can be safely neglected for typical tensor row sizes), so we end up using 2.125 bpw
  • IQ3_KT: L=12 bits for groups of N=4 quants. Block size is also 32 with a 4-bit block scale, so 3.125 bpw
  • IQ4_KT: L=15 bits for groups of N=4 quants. Blocks of 32 with 8-bit block scales, so 4.0 bpw (the bit-budget arithmetic for all three types is spelled out below).
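
For completeness, the bit budgets behind these numbers (ignoring the per-row float scale of IQ2_KT, as noted above):

  • IQ2_KT: 32/8 = 4 groups × 16 bits + 4-bit block scale = 68 bits per block of 32, i.e. 68/32 = 2.125 bpw
  • IQ3_KT: 32/4 = 8 groups × 12 bits + 4-bit block scale = 100 bits per block of 32, i.e. 100/32 = 3.125 bpw
  • IQ4_KT: 32/4 = 8 groups × 15 bits + 8-bit block scale = 128 bits per block of 32, i.e. 128/32 = 4.0 bpw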

Quantization accuracy

This figure shows quantization error PPL(Q)/PPL(bf16)-1 for LLaMA-3.1-8B-Instruct (context length of 8192 tokens). The blue symbols are k-quants, the black symbols are i-quants, cyan symbols are iqk-quants (not available in mainline llama.cpp), and the orange symbols are the Trellis quants added by this PR. We do see a small but noticeable improvement compared to i- and iqk-quants, with about 0.2 fewer bpw required to achieve the same quantization error.

[figure: quantization error vs bpw for LLaMA-3.1-8B-Instruct]

How does this compare to the QTIP paper? Unfortunately, they only report results without fine-tuning for LLaMA-v2. The table below compares the 2-bit quantizations for LLaMA-v2-7B (the QTIP results are taken from Table 3 in their paper; context length is 4096 tokens):

| Quantization | PPL(f16) | PPL(Q) | Quantization error |
|---|---|---|---|
| QTIP 2 bpw | 5.12 | 6.82 | 33.2% |
| IQ2_KT | 4.94 | 6.36 | 28.7% |
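
The quantization error column is just the ratio defined above: 6.82/5.12 - 1 ≈ 33.2% for QTIP and 6.36/4.94 - 1 ≈ 28.7% for IQ2_KT.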

Although there are small differences between the PPL computed by llama.cpp and by the tools used by the QTIP authors, the quantization error as defined above is basically independent of the specifics of the PPL calculation, so we see that the 2 bpw quantization implemented here slightly outperforms QTIP without fine-tuning (at the expense of 0.125 bpw more). Given this, and the above graph, my conclusion is that Trellis-based quantization is a small improvement over i-, k-, and iqk-quants, but nowhere near the hype observed around the Internet.

Performance

The QTIP authors give TG speed for their 2 bpw variant on an RTX-6000 Ada GPU (see here) and a 7B LLaMA model. My GPU is an RTX-4080 (same generation as theirs, but lower specs). I made a quick attempt to get QTIP running in my environment for an apples-to-apples performance comparison, but was not successful, so I will instead use the ratio of their fp16 performance on the RTX-6000 (55.9 t/s) to my fp16 performance on the RTX-4080 (46.2 t/s) to translate QTIP performance on the RTX-6000 (188 t/s) into an estimated performance on the RTX-4080:

QTIP (2 bpw, RTX-4080) = fp16(RTX-4080)/fp16(RTX-6000) * QTIP (2 bpw, RTX-6000) = 46.2/55.9*188 = 155.4 t/s

In comparison, I get 194 t/s for IQ2_KT (with flash attention enabled, which I assume they also use). These results are with the output tensor left as f16 (which is what is done in QTIP). IQ2_XXS achieves 208 t/s (output as f16) or 216 t/s (output as Q5_K), so QTIP performance is far behind that of a model of similar size using a more efficient quantization.

Caveats

  • Quantization is only implemented for a CPU with AVX2 support. The search for the optimum seed is extremely expensive (the QTIP authors say "prohibitive" for L >= 12 without their tail-biting search space reduction), so I had to SIMDify to not have to wait forever for a quantization to finish. This PR being mostly a POC for now, I did not want to spend the time implementing for other instruction sets (or even porting to run on a GPU).
  • Even with AVX2, quantization is slow - depending on quantization type it takes between 2.5 and 4.5 minutes to quantize LLaMA-3.1-8B on a 32-core Ryzen-5975WX CPU.
  • Inference is only implemented on CUDA. Due to the "3INST" algorithm, I expect low performance on the CPU and on the Apple GPU, so I did not bother implementing it for those.
  • There are no quantized matrix-vector kernels, so the implementation goes via the DMMV (dequantize + matrix-vector multiply) mechanism in llama.cpp. The algorithm outputs float values, so one needs to convert to int8_t to use the usual quantized dot products, and the cost of this conversion is likely to (more than) offset any advantage one might gain from SIMD int8_t dot products (a rough sketch of the two paths follows this list).
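
To make the last point concrete, here is a rough, self-contained sketch; the names, the fake decoder and the scaling scheme are all made up (this is not the PR's code). It contrasts a DMMV-style float dot product with a hypothetical int8 path that has to quantize the freshly decoded values first:

```cpp
// Illustration only: float dot vs. an int8 dot that needs an extra quantization pass.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cmath>
#include <vector>

// Stand-in for the trellis decoder: produces float weights for one row.
static void decode_row(std::vector<float> & w) {
    for (size_t i = 0; i < w.size(); ++i) w[i] = std::sin(0.1f*i);
}

// DMMV-style: consume the decoded floats directly against float activations.
static float dot_float(const std::vector<float> & w, const std::vector<float> & x) {
    float sum = 0;
    for (size_t i = 0; i < w.size(); ++i) sum += w[i]*x[i];
    return sum;
}

// MMVQ-style: extra pass to quantize the weights to int8_t, then an integer dot.
static float dot_int8(const std::vector<float> & w, const std::vector<int8_t> & xq, float x_scale) {
    float amax = 0;
    for (float v : w) amax = std::max(amax, std::fabs(v));
    const float w_scale = amax/127.f;
    int32_t isum = 0;
    for (size_t i = 0; i < w.size(); ++i) {
        const int8_t wq = (int8_t)std::lround(w[i]/w_scale);  // <- the conversion cost in question
        isum += (int32_t)wq * xq[i];
    }
    return w_scale*x_scale*isum;
}

int main() {
    const int n = 256;
    std::vector<float> w(n), x(n);
    decode_row(w);
    for (int i = 0; i < n; ++i) x[i] = std::cos(0.05f*i);
    // Activations quantized to int8 once, as the usual quantized dot products expect.
    float xmax = 0; for (float v : x) xmax = std::max(xmax, std::fabs(v));
    const float x_scale = xmax/127.f;
    std::vector<int8_t> xq(n);
    for (int i = 0; i < n; ++i) xq[i] = (int8_t)std::lround(x[i]/x_scale);
    printf("float dot = %g, int8 dot = %g\n", dot_float(w, x), dot_int8(w, x, x_scale));
    return 0;
}
```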

@ikawrakow marked this pull request as draft November 15, 2024 19:05

Commit messages:
Using 12 bits per 8 weights I get a better rmse than
iq2_xxs. I still need to see how quantizing the group-of-8
scales will affect accuracy. By AVX2 SIMDifying the search
for the best code, LLaMA-3.1-8B gets quantized in 130 seconds
on the Ryzen-7950X CPU - sluggish but still acceptable.
rmse increases by just 3%, so this is beating iq2_xxs in terms
of rmse at the same 2.0625 bpw.
I now see that I was comparing apples to oranges:
iq2_xxs was using a weight of sigma^2/4 + x^2, while
the Trellis approach wasn't (weight = 1). Once I use the same weight,
iq2_kt is actually slightly worse than iq2_xxs in terms
of rmse, so does not look promising at this point.
Also, once each group of 8 Trellis values no longer has a
constant sum(q^2) that we can precompute, quantization
becomes significantly slower (476 seconds for LLaMA-3.1-8B).
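(For context: the weighted rmse here is presumably of the form rmse^2 = sum_i w_i (x_i - q_i)^2 / sum_i w_i with w_i = sigma^2/4 + x_i^2 for iq2_xxs, versus w_i = 1 for the unweighted Trellis search.)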
so we can run perplexity calcs.
As already indicated by rmse, the 2-bit trellis approach is
quite a bit worse than iq2_xxs.
With blocks of 32 and 16 bits per group of 8 the brute force
search becomes prohibitive in terms of CPU time (30+ minutes
for 8B LLaMA after SIMDifying with AVX2). The trick is to
group the points in clusters, find the nearest cluster,
and only search within the cluster, as sketched below.
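
A rough sketch of that cluster-restricted search; everything here is illustrative (the cluster count, the trivial centroid choice, the random stand-in codebook), and the per-weight importance weights used by the real quantization are omitted:

```cpp
// Cluster the codebook once, then search only the nearest cluster per group.
#include <array>
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <vector>

constexpr int kDim      = 8;        // values per group (N = 8 for IQ2_KT)
constexpr int kPoints   = 1 << 16;  // one codebook point per 16-bit seed
constexpr int kClusters = 512;      // number of clusters (arbitrary here)

using Point = std::array<float, kDim>;

static float dist2(const Point & a, const Point & b) {
    float d = 0;
    for (int i = 0; i < kDim; ++i) { const float t = a[i] - b[i]; d += t*t; }
    return d;
}

static int nearest(const Point & p, const std::vector<Point> & set) {
    int best = 0; float bestd = std::numeric_limits<float>::max();
    for (int i = 0; i < (int)set.size(); ++i) {
        const float d = dist2(p, set[i]);
        if (d < bestd) { bestd = d; best = i; }
    }
    return best;
}

int main() {
    // Fake codebook: in the real thing, point i holds the 8 values the trellis
    // generator produces from seed i.
    std::vector<Point> codebook(kPoints);
    for (auto & p : codebook) for (auto & v : p) v = 2.f*rand()/RAND_MAX - 1.f;

    // Pick centroids (here simply every 128-th point) and bucket every
    // codebook point into the cluster of its nearest centroid.
    std::vector<Point> centroids(kClusters);
    for (int c = 0; c < kClusters; ++c) centroids[c] = codebook[c*(kPoints/kClusters)];
    std::vector<std::vector<int>> members(kClusters);
    for (int i = 0; i < kPoints; ++i) members[nearest(codebook[i], centroids)].push_back(i);

    // Quantize one group of 8 weights: nearest centroid first, then an
    // exhaustive search over that cluster only, instead of all 65536 seeds.
    Point target;
    for (auto & v : target) v = 2.f*rand()/RAND_MAX - 1.f;
    const int c = nearest(target, centroids);
    int best_seed = -1; float bestd = std::numeric_limits<float>::max();
    for (int idx : members[c]) {
        const float d = dist2(target, codebook[idx]);
        if (d < bestd) { bestd = d; best_seed = idx; }
    }
    printf("best seed = %d, squared distance = %g\n", best_seed, bestd);
    return 0;
}
```

In this sketch the inner loop visits only the members of one cluster (on average 65536/512 = 128 points) instead of all 65536 seeds; a later commit additionally copies each cluster's points into contiguous memory so that this loop streams through memory sequentially.
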
Using blocks of 32 and 16 bits per group of 8 weights
it beats iq2_xxs in terms of PPL by a significant margin.
It is 0.0625 bpw larger, but even if we go to 15 bits per
group of 8 (so 0.0625 bpw less than iq2_xxs), PPL is still
lower.
Re-quantize after determining block scales
(at the expense of much longer quantization time).
Implemented as DMMV.
Very slow - just 81 t/s for LLaMA-3.1-8B.
Then again, Q2_K_S forced to use DMMV only
gets 112 t/s vs 145 t/s via MMVQ. My memory is that
when the DMMV kernels were properly maintained/used,
DMMV was about on par with MMVQ for k-quants on my GPU.
We arrive at 112 t/s.
We arrive at 139 t/s (no FA), and 149 t/s (FA).

My RTX-4080 is ~20% slower than the RTX-6000 quoted in the
QTIP repository, so with FA (which I'm sure they also used)
we are at around ~180 t/s on their GPU, so almost matching
their performance.
We arrive at 146 t/s (no FA), and 158 t/s (FA).
This is measured for LLaMA-3.1-8B with output.weight
left as f16.
3.125 bpw. So far does not look good on the PPL vs bpw plot.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is
starting to be competitive/slightly better than other quants.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking
by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
Nearly 60% improvement in quantization speed by having the
points belonging to a cluster copied to contiguous memory
during initialization, and then accessed sequentially while
searching for the closest point. LLaMA-3.1-8B now gets
quantized in ~150 seconds on the Ryzen-5975WX.
Same trick as last commit applied to iq2_kt. Here we get
an even larger speedup: quantization time on the Ryzen-5975WX
for LLaMA-3.1-8B drops to 195 seconds from 375 seconds!
We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.2406
PPL(LLaMA-2-7B,            4096) = 6.4179
We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642
PPL(LLaMA-2-7B,            4096) = 6.3920
We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.0297
PPL(LLaMA-2-7B,            4096) = 6.3913

Ah, quantization is faster too. About 20% faster.
We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 8.9627
PPL(LLaMA-2-7B,            4096) = 6.3825

Quantization is faster too: ~200 seconds for LLaMA-3.1-8B
on Ryzen-5975WX.
15 bits per group of 4, plus 8-bit scales for blocks of 32.
This gives a slightly better PPL than iq4_kss.
at the expense of much longer quantization time.
It was working for 4.125 bpw. But after changing to 4.0 bpw
there is something wrong and I don't see the bug.
Go to groups of 8 for iq3_kt. 2 x 8 = 16 bits for the magnitude
plus 1 bpw for the sign. It gives a visible improvement in the
PPL vs bpw plot, but that comes at the expense of much longer
quantization time (7.5 minutes for LLaMA-3.1-8B on the Ryzen-5975WX).
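(Bit count check: 16/8 = 2 bpw for the magnitudes, 1 bpw for the signs, and presumably the unchanged 4-bit scale per block of 32, i.e. still 3.125 bpw.)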

I also noticed that the 3INST generator is not actually generating a
Gaussian distribution. But going to a better generator means
readjusting all the hyper-parameters, so leaving it for later.