
[BUG] 3bit quantization not working #1207

Closed
sidhantls opened this issue Feb 3, 2025 · 7 comments
Labels
bug Something isn't working

@sidhantls

Describe the bug
I'm trying to quantize an LLM to 3 bits. However, the quantization code fails with an error at the end. When I set bits=4 in the same code, it works.

Software Info
Windows 10, Python 3.10

Torch 2.6.0+cu124, transformers 4.48.2, accelerate 1.3.0

To Reproduce

from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "EleutherAI/pythia-160m"

calibration_dataset = load_dataset(
    "allenai/c4",
    data_files="en/c4-train.00001-of-01024.json.gz",
    split="train"
).select(range(1024))["text"]

calibration_dataset = [" ".join(item.split()[:30]) for item in calibration_dataset]  # truncate each sample to 30 words for speed

quantize_config = QuantizeConfig(
    bits=3,  # works with bits=4
    group_size=128,
)
model = GPTQModel.load(model_id, quantize_config)

# increase `batch_size` to match gpu/vram specs to speed up quantization
model.quantize(calibration_dataset, batch_size=1)
model.save("saved_model")

Error:

File ~\Desktop\GPTQModel\gptqmodel\utils\model.py:440, in pack_module(name, qModules, quantizers, layers, pbar)
    433 qModules[name].to(CPU)
    434 layers[name], scale, zero, g_idx = (
    435     layers[name].to(CPU),
    436     scale.to(CPU),
    437     zero.to(CPU),
    438     g_idx.to(CPU) if g_idx is not None else None,
    439 )
--> 440 qModules[name].pack(layers[name], scale, zero, g_idx)
    441 qModules[name].to(layer_device)
    442 if pbar:

File ~\Desktop\GPTQModel\gptqmodel\nn_modules\qlinear\__init__.py:320, in PackableQuantLinear.pack(self, linear, scales, zeros, g_idx)
    318 qweight[row] |= intweight[row_offset_plus_10] << 31
    319 row += 1
--> 320 qweight[row] |= (intweight[row_offset_plus_10] >> 1) & 0x3
    321 for j in range(10):
    322     qweight[row] |= intweight[row_offset + j] << (3 * j + 2)

IndexError: index 72 is out of bounds for axis 0 with size 72
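
For context on where that index comes from: with 3-bit quantization, every group of 32 quantized values (96 bits) gets packed across three consecutive uint32 rows of qweight, so two values in each group straddle a word boundary and the row counter has to advance in lockstep with the value offset. Below is a rough standalone sketch of that layout, for illustration only; it is not the GPTQModel code and the names are made up.

import numpy as np

def pack_3bit(intweight):
    # intweight: (in_features, out_features) with values in [0, 7];
    # every 32 input rows pack into 3 uint32 output rows (32 * 3 bits = 96 bits).
    w = intweight.astype(np.uint32)
    qweight = np.zeros((w.shape[0] * 3 // 32, w.shape[1]), dtype=np.uint32)
    i, row = 0, 0
    while row < qweight.shape[0]:
        for j in range(10):                     # values 0..9 -> bits 0..29 of word 0
            qweight[row] |= w[i + j] << (3 * j)
        qweight[row] |= w[i + 10] << 30         # low 2 bits of value 10
        row += 1
        qweight[row] |= (w[i + 10] >> 2) & 0x1  # high bit of value 10
        for j in range(10):                     # values 11..20 -> bits 1..30 of word 1
            qweight[row] |= w[i + 11 + j] << (3 * j + 1)
        qweight[row] |= w[i + 21] << 31         # low bit of value 21
        row += 1
        qweight[row] |= (w[i + 21] >> 1) & 0x3  # high 2 bits of value 21
        for j in range(10):                     # values 22..31 -> bits 2..31 of word 2
            qweight[row] |= w[i + 22 + j] << (3 * j + 2)
        row += 1
        i += 32
    return qweight

def unpack_3bit(qweight):
    # inverse of pack_3bit, used only as a round-trip sanity check
    out = np.zeros((qweight.shape[0] * 32 // 3, qweight.shape[1]), dtype=np.uint32)
    i, row = 0, 0
    while row < qweight.shape[0]:
        for j in range(10):
            out[i + j] = (qweight[row] >> (3 * j)) & 0x7
        out[i + 10] = (qweight[row] >> 30) | ((qweight[row + 1] & 0x1) << 2)
        row += 1
        for j in range(10):
            out[i + 11 + j] = (qweight[row] >> (3 * j + 1)) & 0x7
        out[i + 21] = (qweight[row] >> 31) | ((qweight[row + 1] & 0x3) << 1)
        row += 1
        for j in range(10):
            out[i + 22 + j] = (qweight[row] >> (3 * j + 2)) & 0x7
        row += 1
        i += 32
    return out

w = np.random.randint(0, 8, size=(768, 16))     # 768 input rows -> 768 * 3 / 32 = 72 packed rows
q = pack_3bit(w)
print(q.shape)                                  # (72, 16), matching the size-72 axis in the traceback
assert np.array_equal(unpack_3bit(q), w)

If the loop advances the value offset or the row counter one step too far, the next group starts writing past the last packed row, which is exactly the IndexError above.
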
@sidhantls sidhantls added the bug Something isn't working label Feb 3, 2025
@benjamin-marie

I have the same issue with Llama 3.x and Qwen2.5 models.

@Qubitium
Collaborator

Qubitium commented Feb 4, 2025

@sidhantls @benjamin-marie I will check this tomorrow. Ran out of time today.

Also, I want to add that we are aware of a potential issue with 3bit: it showed lower accuracy than 2bit in our previous CI regression tests. So even after I fix this, we may still have a lingering 3bit quality problem (model accuracy via llm-eval tests) where 3bit actually scores lower than 2bit; that requires more investigation.

@sidhantls
Author

Thanks for pointing out the accuracy issue with 3bit @Qubitium. But is the 3-bit implementation in line with AutoGPTQ, the project this was forked from? We've quantized llama-3.1-8b using AutoGPTQ and didn't notice an accuracy issue with 3bit.

@Qubitium
Collaborator

Qubitium commented Feb 4, 2025

Thanks for pointing out the accuracy issue with 3bit @Qubitium. But is the 3-bit implementation in line with AutoGPTQ, the project this was forked from? We've quantized llama-3.1-8b using AutoGPTQ and didn't notice an accuracy issue with 3bit.

@sidhantls Unknown for now. We normally only test/use 4 and 8 bits, so 2 and 3 bit outlier issues haven't gotten the attention they perhaps should have. We will do more testing to confirm one way or the other.

If memory serves me correctly, we recently ran an ARC lm-eval test on a 1B Llama model and the 3bit model lost to the 2bit model, which was surprising. It could be our code, or there could be an issue with 3bit overall. We need to dig more into this.

But regardless, 2/3bit scores much, much lower than 4bit, which is also the reason I haven't bothered to use it myself. =P

edit: ref huggingface/transformers#35460 (comment)

@Qubitium
Collaborator

Qubitium commented Feb 5, 2025

3-bit packing is fixed in https://github.com/ModelCloud/GPTQModel/pull/1218/files, but there is a bigger issue where 3bit inference is broken, causing massive quality/ppl degradation.

In short, do not use 3bit for now until we have fixed this regression.

The 3bit quant error_loss appears to be normal, so we are going to backtrack to find where the divergence from AutoGPTQ broke 3bit (it could still be packing related, but it is more likely inference related).
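
If anyone wants to sanity-check the inference regression locally, a rough perplexity measurement on a saved checkpoint could look like the sketch below. This is illustrative only: it assumes the quantized checkpoint loads back through the transformers GPTQ integration, and the dataset and paths are placeholders.

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

quant_path = "saved_model"  # checkpoint produced by the repro script above
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
model.eval()

text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
input_ids = tokenizer(text, return_tensors="pt").input_ids

max_len = 1024
nlls, counted = [], 0
with torch.no_grad():
    for begin in range(0, input_ids.size(1), max_len):
        chunk = input_ids[:, begin:begin + max_len].to(model.device)
        if chunk.size(1) < 2:
            break
        out = model(chunk, labels=chunk)  # HF shifts labels internally; loss is mean NLL per token
        tokens = chunk.size(1) - 1
        nlls.append(out.loss.float() * tokens)
        counted += tokens

ppl = torch.exp(torch.stack(nlls).sum() / counted)
print(f"perplexity: {ppl.item():.2f}")  # a broken 3bit kernel shows up as a huge jump vs the 4bit checkpoint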

@sidhantls
Author

@Qubitium Great, thanks for letting me know.

I've validated 3bit vs 2bit quality on Llama-3.1-8b with AutoGPTQ on MMLU and NQOpen. Given the severe performance degradation at 3bit and 2bit, I'd recommend using a model larger than 1B parameters to validate whether 3bit works better than 2bit.
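
For reference, a comparison along those lines can be wired up with the lm-evaluation-harness Python API roughly as below. This is only a sketch: the checkpoint paths are placeholders, task names depend on the installed harness version, and loading GPTQ checkpoints this way assumes the transformers GPTQ integration is installed.

import lm_eval  # lm-evaluation-harness >= 0.4

# hypothetical 2bit and 3bit checkpoints quantized from the same base model
checkpoints = {"2bit": "llama-3.1-8b-gptq-2bit", "3bit": "llama-3.1-8b-gptq-3bit"}

for label, path in checkpoints.items():
    results = lm_eval.simple_evaluate(
        model="hf",
        model_args=f"pretrained={path}",
        tasks=["mmlu"],   # swap in the tasks you care about, e.g. nq_open
        num_fewshot=5,
        batch_size=8,
    )
    print(label, results["results"])  # per-task metrics for this checkpoint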

@Qubitium
Collaborator

Qubitium commented Feb 6, 2025

@sidhantls @benjamin-marie Fixed on main. It was caused by my recent refactor. 😅

@Qubitium Qubitium closed this as completed Feb 6, 2025