[WIP] Activation Aware Weight Quantization (AWQ) #743
Merged
Changes shown are from 69 commits. The PR contains 83 commits in total, all by vayuda:
d3db3db  init
1c29347  Merge branch 'awq' of https://github.com/vayuda/ao into awq
ed864a2  fixed implementation
90be216  Merge branch 'pytorch:main' into awq
0e690f9  reduced vmem req
05ae692  Merge branch 'awq' of https://github.com/vayuda/ao into awq
4519792  eval on LLMs
fca7895  Merge branch 'pytorch:main' into awq
0096a83  eval on llm
20c2529  Merge branch 'awq'
7614d51  convert list to tensor
33a28dd  restructuring
7d389b5  revert unecessary hf_eval changes
5913b98  added wikitext eval test
7d045f9  added tinygemm integration
db302ef  made the calibration step much faster
2ec38f1  merge pt1
b400711  Merge remote-tracking branch 'upstream/main' into awq
4aae94b  works/created tutorial
8e058d7  added docs, tests, cleaned code
dced6e5  updated benchmark
c43b997  update example
2388091  update example
8619cd5  added init file
5378ac9  Merge remote-tracking branch 'upstream/main' into awq
9027082  reduce vram for calibration
dbac7c8  eval changes+ llama2 data
aa62e5f  llama2 data + eval script init changes
a4b006f  Merge branch 'main' into awq
7f21bfc  fixed qdtype bounds and example import
fb7fb11  Merge branch 'awq' of https://github.com/vayuda/ao into awq
863d503  fix tests
4ca9117  fix tests
3e70a6f  use rolling log liklihood for eval and calibrate awq with run_eval
c37b396  Merge branch 'pytorch:main' into awq
389ea77  Merge remote-tracking branch 'upstream/main' into awq
8ab016a  make eval use less vram
27c062d  updated uintx import
59b4174  Merge remote-tracking branch 'upstream/main' into awq
9d52c93  add awq to generate
310138e  add calibration params to cli
a1b2bd0  fix name
9379686  pass linear properly
8d173df  recast W*eq_scale to original dtype
6aab8f8  revert bad change
4e97611  make scales same type as model
77db01f  compatible with compile
588e81e  cast eq scale to bf16
bfa5797  switch calibration dataset
41a621b  remove extra line
20767d5  add import
ae32c7c  added save/load scales
e2160ae  add save/store workflow to example
9704c38  add arg to fn
04ef51d  fix cli arg
17e2fbc  fix cli arg
71f9e27  add comma
1716b0c  add model to fn call
436fb9f  fix example
84b407c  refactored awq impl
49fbbe2  Merge branch 'pytorch:main' into awq
1216f97  edits +update usage
3930660  perplexity evals added
3e5710c  updated readme with benchmarks
da6a70d  add awq-hqq to generate
7314d99  better citation
dc0c507  Merge branch 'pytorch:main' into awq
68d9592  nits
bc7526e  remove layout.py
1fdf068  quantization func changes
85ea32c  typo
0a12f96  fix import
1b5a57b  fix fn params
c193afb  edit
4e60dfd  edit
93dcb79  rename
d2ed1f2  fix indentation
2650702  added bf16 gaurd for uint4 tinygemm quant
5e25469  remove bad tests
3aa279f  remove arg
c08fdb1  one last guard..
b967ebd  require nightly
e7e329b  require nightly on everything
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
The first new file is a helper script (43 lines) that builds a weight map for a Hugging Face checkpoint:

```python
import json
import torch
from transformers import AutoModel
from pathlib import Path

def create_weight_map(checkpoint_dir: Path):
    """
    This function, create_weight_map, generates a mapping of a model's weights to a file (pytorch_model.bin)
    and saves this mapping, along with the model's total size, to a JSON file (pytorch_model.bin.index.json).
    The model is loaded from a pre-trained model specified by model_name.
    This weight map is used by the HF conversion script (convert_hf_checkpoint.py).
    """
    # Load the model
    model_name = checkpoint_dir.parent.name + "/" + checkpoint_dir.name
    print(model_name)
    model = AutoModel.from_pretrained(model_name)
    # Get the state dict
    state_dict = model.state_dict()
    # Create the weight map
    weight_map = {}
    for key, tensor in state_dict.items():
        # In this example, we're assuming all weights are in a single file
        # You may need to adjust this if your model uses sharded weights
        weight_map[key] = "pytorch_model.bin"
    # Create the index dictionary
    index_dict = {
        "metadata": {"total_size": sum(param.numel() * param.element_size() for param in model.parameters())},
        "weight_map": weight_map
    }
    # Save the index dictionary to a JSON file
    with open(f"{checkpoint_dir}/pytorch_model.bin.index.json", "w") as f:
        json.dump(index_dict, f, indent=2)
    print("Created pytorch_model.bin.index.json")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Create weight map for hf model')
    parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/Xenova/llama2.c-stories15M"))

    args = parser.parse_args()
    create_weight_map(args.checkpoint_dir)
```
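A quick usage sketch for the helper above. The checkpoint path is simply the script's own argparse default, and the layout is assumed to be `checkpoints/<org>/<model>` so that `parent.name + "/" + name` reconstructs the Hugging Face repo id:

```python
from pathlib import Path

# Assumes the checkpoint directory follows the checkpoints/<org>/<model> layout
# used by the default above; the model must be loadable from the HF Hub.
create_weight_map(Path("checkpoints/Xenova/llama2.c-stories15M"))
# Writes checkpoints/Xenova/llama2.c-stories15M/pytorch_model.bin.index.json
```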
The PR also adds a test file for AWQ (127 lines), covering save/load of a quantized model and weights-only loading:

```python
from copy import deepcopy
import os

import pytest
import torch

from torchao.quantization import quantize_
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3

if TORCH_VERSION_AT_LEAST_2_3:
    from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=512, n=256, k=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)
        self.linear3 = torch.nn.Linear(k, 1, bias=False)

    def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"):
        return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)]

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x


devices = ["cpu", "cuda"]
# torch.uintx dtypes are introduced in 2.3
if TORCH_VERSION_AT_LEAST_2_3:
    qdtypes = (torch.uint4, torch.uint7)
else:
    qdtypes = ()


@pytest.fixture(autouse=True)
def run_before_and_after_tests():
    yield
    torch._dynamo.reset()  # reset cache between tests


idtypes = (torch.half, torch.bfloat16)


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("qdtype", qdtypes)
@pytest.mark.parametrize("idtype", idtypes)
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.uint(2-7) requires torch2.3+")
def test_awq_loading(device, qdtype, idtype):
    dataset_size = 100
    l1, l2, l3 = 512, 256, 128
    original_dtype = idtype
    quant_dtype = qdtype
    group_size = 128
    n_calibration_examples = 10
    n_validation_examples = 10
    sequence_length = 5

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
    calibration_data = dataset[:n_calibration_examples]

    # calibrate
    insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size)

    for example in calibration_data:
        m(example.to(device))

    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
    quantize_(m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)

    model_save_path = "awq_model.pth"
    torch.save(m, model_save_path)
    loaded_model = torch.load(model_save_path)
    os.remove(model_save_path)

    if device == "cuda":
        m = torch.compile(m, fullgraph=True)
        loaded_model = torch.compile(loaded_model, fullgraph=True)

    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
    awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])

    assert awq_out is not None
    assert awq_save_load_out is not None
    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)


@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.uint(2-7) requires torch2.3+")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_save_weights_only():
    dataset_size = 100
    l1, l2, l3 = 512, 256, 128
    original_dtype = torch.half
    quant_dtype = torch.uint4
    device = "cuda"
    group_size = 128
    n_calibration_examples = 10
    n_validation_examples = 10
    sequence_length = 5

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    m2 = deepcopy(m)
    dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
    calibration_data = dataset[:n_calibration_examples]

    # calibrate
    insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size)

    for example in calibration_data:
        m(example.to(device))

    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
    quantize_(m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)

    model_save_path = "awq_model.pth"
    torch.save(m.state_dict(), model_save_path)
    m2.load_state_dict(torch.load(model_save_path), assign=True)  # load weights only
    os.remove(model_save_path)

    m = torch.compile(m, fullgraph=True)
    m2 = torch.compile(m2, fullgraph=True)

    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
    awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])

    assert awq_out is not None
    assert awq_save_load_out is not None
    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
```
The next diff touches only the tail of a `run_evaluation(...)` call in the eval script; the final `)` line is rewritten, which appears to be a missing-newline-at-end-of-file fix:

```diff
@@ -233,4 +233,4 @@ def run_evaluation(
     args.calibration_limit,
     args.calibration_seq_length,
     args.pad_calibration_inputs,
-)
+)
```
A new one-line file (apparently an `__init__.py` for the uintx prototype) re-exports the uintx building blocks:

```python
from .uintx import UintxTensor, UintxLayoutType, UintxAQTLayout, to_uintx, _DTYPE_TO_BIT_WIDTH
```
A new README documents the AWQ prototype and its benchmarks:

# AWQ Quantization
Adapted from https://github.com/mit-han-lab/llm-awq

## Benchmarks
Evaluation perplexity numbers were calculated using the script in awq/example.py. A group size of 64 was used for all quantization methods. For Llama-2-7b-chat-hf, performance benchmarks were measured with the torchao/_models/llama/generate.py script on a 1xA100 80GB SXM4 instance. The awq-uint4 quantization method does not use an efficient fused kernel, which is why its performance is not great; awq-hqq uses the tinygemm int4->bf16 kernel plus hqq to provide better performance.

| Model              | Quantization | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) |
|--------------------|--------------|------------|---------------------|---------------|-----------------|
| Llama-2-7b-chat-hf | bfloat16     | 107.38     | 1418.93             | 13.88         | 13.21           |
|                    | awq-hqq-int4 | 196.6      | 761.2               | 5.05          | 3.87            |
|                    | awq-uint4    | 43.59      | 194.93              | 7.31          | 4.47            |
|                    | int4wo-hqq   | 209.19     | 804.32              | 4.89          | 3.84            |
|                    | int4wo-64    | 201.14     | 751.42              | 4.87          | 3.74            |

The following results were obtained with LM Eval and group size = 128:

| Model              | Quantization | Perplexity | Truthful QA MC2 | WinoGrande | ARC challenge |
|--------------------|--------------|------------|-----------------|------------|---------------|
| Llama-3-8B-Instruct| bfloat16     | 10.936     | 0.540           | 0.783      | 0.567         |
|                    | awq-hqq-int4 | 11.383     | 0.522           | 0.772      | 0.543         |
|                    | awq-uint4    | 11.409     | 0.519           | 0.756      | 0.577         |
|                    | int4wo-hqq   | 11.905     | 0.528           | 0.757      | 0.563         |
|                    | int4wo-128   | 12.380     | 0.502           | 0.753      | 0.548         |
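The README stops at the benchmark numbers, so here is a minimal calibrate-then-quantize sketch assembled from the test file above. The toy model, the calibration tensors, and the argument values are placeholders for illustration; only `insert_awq_observer_`, `awq_uintx`, `AWQObservedLinear`, and `quantize_` come from this PR, and since this is a prototype their exact signatures may change.

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear

# Placeholder model: any module containing nn.Linear layers works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 256, bias=False),
    torch.nn.Linear(256, 128, bias=False),
).to(torch.bfloat16).to("cuda").eval()

# 1) Wrap the Linear layers with AWQ observers (10 validation examples,
#    sequence length 5 -- same positional arguments as in the test above).
insert_awq_observer_(model, 10, 5, quant_dtype=torch.uint4, group_size=64)

# 2) Run calibration data through the model so the observers can record
#    activation statistics.
for _ in range(10):
    model(torch.randn(1, 5, 512, dtype=torch.bfloat16, device="cuda"))

# 3) Replace the observed linears with AWQ-quantized weights.
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=torch.uint4, group_size=64), is_observed_linear)
```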
A two-line `__init__.py` exposes the AWQ prototype's public entry points (the same names the test file imports from torchao.prototype.awq):

```python
from .api import insert_awq_observer_, awq_uintx
from .core import AWQObservedLinear
```
Review discussion:

Reviewer: qq: does `device=device` not work?
vayuda: yea surprisingly it doesn't
Reviewer: can you open an issue with error messages?
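The exchange presumably concerns constructing tensors directly on the target device versus allocating them first and moving them afterwards. A tiny illustration of the two patterns (illustration only, not code from the PR):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocate on the default device, then move to the target device.
x = torch.randn(1, 5, 512, dtype=torch.bfloat16).to(device)

# Allocate directly on the target device (the pattern the reviewer asked about).
y = torch.randn(1, 5, 512, dtype=torch.bfloat16, device=device)
```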