
[WIP] Activation Aware Weight Quantization (AWQ) #743

Merged: 83 commits, merged Oct 7, 2024
Changes from 69 commits

Commits (83)
d3db3db
init
vayuda Aug 15, 2024
1c29347
Merge branch 'awq' of https://github.com/vayuda/ao into awq
vayuda Aug 15, 2024
ed864a2
fixed implementation
vayuda Aug 16, 2024
90be216
Merge branch 'pytorch:main' into awq
vayuda Aug 20, 2024
0e690f9
reduced vmem req
vayuda Aug 21, 2024
05ae692
Merge branch 'awq' of https://github.com/vayuda/ao into awq
vayuda Aug 21, 2024
4519792
eval on LLMs
vayuda Aug 24, 2024
fca7895
Merge branch 'pytorch:main' into awq
vayuda Aug 24, 2024
0096a83
eval on llm
vayuda Aug 24, 2024
20c2529
Merge branch 'awq'
vayuda Aug 24, 2024
7614d51
convert list to tensor
vayuda Aug 24, 2024
33a28dd
restructuring
vayuda Aug 24, 2024
7d389b5
revert unecessary hf_eval changes
vayuda Aug 24, 2024
5913b98
added wikitext eval test
vayuda Aug 25, 2024
7d045f9
added tinygemm integration
vayuda Aug 27, 2024
db302ef
made the calibration step much faster
vayuda Aug 29, 2024
2ec38f1
merge pt1
vayuda Sep 3, 2024
b400711
Merge remote-tracking branch 'upstream/main' into awq
vayuda Sep 3, 2024
4aae94b
works/created tutorial
vayuda Sep 3, 2024
8e058d7
added docs, tests, cleaned code
vayuda Sep 5, 2024
dced6e5
updated benchmark
vayuda Sep 5, 2024
c43b997
update example
vayuda Sep 5, 2024
2388091
update example
vayuda Sep 5, 2024
8619cd5
added init file
vayuda Sep 6, 2024
5378ac9
Merge remote-tracking branch 'upstream/main' into awq
vayuda Sep 6, 2024
9027082
reduce vram for calibration
vayuda Sep 8, 2024
dbac7c8
eval changes+ llama2 data
vayuda Sep 8, 2024
aa62e5f
llama2 data + eval script init changes
vayuda Sep 8, 2024
a4b006f
Merge branch 'main' into awq
vayuda Sep 9, 2024
7f21bfc
fixed qdtype bounds and example import
vayuda Sep 9, 2024
fb7fb11
Merge branch 'awq' of https://github.com/vayuda/ao into awq
vayuda Sep 9, 2024
863d503
fix tests
vayuda Sep 9, 2024
4ca9117
fix tests
vayuda Sep 10, 2024
3e70a6f
use rolling log liklihood for eval and calibrate awq with run_eval
vayuda Sep 10, 2024
c37b396
Merge branch 'pytorch:main' into awq
vayuda Sep 10, 2024
389ea77
Merge remote-tracking branch 'upstream/main' into awq
vayuda Sep 12, 2024
8ab016a
make eval use less vram
vayuda Sep 12, 2024
27c062d
updated uintx import
vayuda Sep 12, 2024
59b4174
Merge remote-tracking branch 'upstream/main' into awq
vayuda Sep 21, 2024
9d52c93
add awq to generate
vayuda Sep 22, 2024
310138e
add calibration params to cli
vayuda Sep 22, 2024
a1b2bd0
fix name
vayuda Sep 22, 2024
9379686
pass linear properly
vayuda Sep 25, 2024
8d173df
recast W*eq_scale to original dtype
vayuda Sep 25, 2024
6aab8f8
revert bad change
vayuda Sep 25, 2024
4e97611
make scales same type as model
vayuda Sep 25, 2024
77db01f
compatible with compile
vayuda Sep 25, 2024
588e81e
cast eq scale to bf16
vayuda Sep 25, 2024
bfa5797
switch calibration dataset
vayuda Sep 25, 2024
41a621b
remove extra line
vayuda Sep 25, 2024
20767d5
add import
vayuda Sep 25, 2024
ae32c7c
added save/load scales
vayuda Sep 28, 2024
e2160ae
add save/store workflow to example
vayuda Sep 28, 2024
9704c38
add arg to fn
vayuda Sep 28, 2024
04ef51d
fix cli arg
vayuda Sep 28, 2024
17e2fbc
fix cli arg
vayuda Sep 28, 2024
71f9e27
add comma
vayuda Sep 28, 2024
1716b0c
add model to fn call
vayuda Sep 28, 2024
436fb9f
fix example
vayuda Sep 28, 2024
84b407c
refactored awq impl
vayuda Oct 1, 2024
49fbbe2
Merge branch 'pytorch:main' into awq
vayuda Oct 1, 2024
1216f97
edits +update usage
vayuda Oct 2, 2024
3930660
perplexity evals added
vayuda Oct 2, 2024
3e5710c
updated readme with benchmarks
vayuda Oct 2, 2024
da6a70d
add awq-hqq to generate
vayuda Oct 2, 2024
7314d99
better citation
vayuda Oct 2, 2024
dc0c507
Merge branch 'pytorch:main' into awq
vayuda Oct 3, 2024
68d9592
nits
vayuda Oct 3, 2024
bc7526e
remove layout.py
vayuda Oct 3, 2024
1fdf068
quantization func changes
vayuda Oct 5, 2024
85ea32c
typo
vayuda Oct 5, 2024
0a12f96
fix import
vayuda Oct 5, 2024
1b5a57b
fix fn params
vayuda Oct 5, 2024
c193afb
edit
vayuda Oct 5, 2024
4e60dfd
edit
vayuda Oct 5, 2024
93dcb79
rename
vayuda Oct 5, 2024
d2ed1f2
fix indentation
vayuda Oct 5, 2024
2650702
added bf16 gaurd for uint4 tinygemm quant
vayuda Oct 5, 2024
5e25469
remove bad tests
vayuda Oct 6, 2024
3aa279f
remove arg
vayuda Oct 6, 2024
c08fdb1
one last guard..
vayuda Oct 7, 2024
b967ebd
require nightly
vayuda Oct 7, 2024
e7e329b
require nightly on everything
vayuda Oct 7, 2024
43 changes: 43 additions & 0 deletions scripts/create_weight_map.py
@@ -0,0 +1,43 @@
import json
import torch
from transformers import AutoModel
from pathlib import Path

def create_weight_map(checkpoint_dir: Path):
    """
    Generate a mapping of a model's weights to a file (pytorch_model.bin) and save this mapping,
    along with the model's total size, to a JSON file (pytorch_model.bin.index.json).
    The model is loaded from the pre-trained model specified by model_name.
    This weight map is used by the HF conversion script (convert_hf_checkpoint.py).
    """
    # Load the model
    model_name = checkpoint_dir.parent.name + "/" + checkpoint_dir.name
    print(model_name)
    model = AutoModel.from_pretrained(model_name)
    # Get the state dict
    state_dict = model.state_dict()
    # Create the weight map
    weight_map = {}
    for key, tensor in state_dict.items():
        # In this example, we assume all weights are in a single file;
        # adjust this if your model uses sharded weights
        weight_map[key] = "pytorch_model.bin"
    # Create the index dictionary
    index_dict = {
        "metadata": {"total_size": sum(param.numel() * param.element_size() for param in model.parameters())},
        "weight_map": weight_map,
    }
    # Save the index dictionary to a JSON file
    with open(f"{checkpoint_dir}/pytorch_model.bin.index.json", "w") as f:
        json.dump(index_dict, f, indent=2)
    print("Created pytorch_model.bin.index.json")

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Create weight map for hf model')
    parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/Xenova/llama2.c-stories15M"))

    args = parser.parse_args()
    create_weight_map(args.checkpoint_dir)
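For reference, the helper can also be driven directly from Python rather than through the CLI; a hedged sketch follows, where the checkpoint path is just the argparse default above and the import path assumes the call is made next to the script:

from pathlib import Path
from create_weight_map import create_weight_map  # assumes scripts/ is importable

# Writes checkpoints/Xenova/llama2.c-stories15M/pytorch_model.bin.index.json
create_weight_map(Path("checkpoints/Xenova/llama2.c-stories15M"))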
26 changes: 23 additions & 3 deletions scripts/hf_eval.py
@@ -48,7 +48,7 @@ def format_value(value):
def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length):

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id).to(dtype=precision, device=device)
    model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device)

[Review thread on the line above]
Contributor: qq: does device=device not work?
Collaborator (author): yea, surprisingly it doesn't
Contributor: can you open an issue with error messages?

if quantization == "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -64,9 +64,29 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length):
        quantize_(model, fpx_weight_only(3, 2))
    elif quantization == "autoquant":
        model = autoquant(model.to(device=device))
    elif quantization == "awq":
        from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
        from torchao.prototype.awq.example import get_calib_dataset
        if not TORCH_VERSION_AT_LEAST_2_3:
            print("AWQ quantization requires torch 2.3+")
            exit()
        from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear
        quant_dtype = torch.uint4
        group_size = 64
        calibration_limit = 10
        calibration_seq_length = 1024
        model = model.to(device)
        insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)
        with torch.no_grad():
            calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length)
            for batch in calibration_data:
                model(batch.to(device))
                del batch
        is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
        quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)

if quantization != "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
model = torch.compile(model, mode= "max-autotune", fullgraph=True)

if sparsity == "semi_sparse":
def all_linear(mod, name):
Expand Down Expand Up @@ -114,7 +134,7 @@ def all_linear(mod, name):
    parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply')
    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "awq", "None"], help='Which quantization technique to apply')
    parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--save', action='store_true', help='Whether to save the model.')
2 changes: 1 addition & 1 deletion test/dtypes/test_uintx.py
@@ -4,7 +4,7 @@

import torch

from torchao.dtypes.uintx.uintx import to_uintx
from torchao.dtypes.uintx import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_3,
127 changes: 127 additions & 0 deletions test/prototype/test_awq.py
@@ -0,0 +1,127 @@
from copy import deepcopy
import os
import pytest
import torch
from torchao.quantization import quantize_

from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
if TORCH_VERSION_AT_LEAST_2_3:
    from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=512, n=256, k=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)
        self.linear3 = torch.nn.Linear(k, 1, bias=False)

    def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"):
        return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)]

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x

devices = ["cpu", "cuda"]
# torch.uintx dtypes are introduced in 2.3
if TORCH_VERSION_AT_LEAST_2_3:
    qdtypes = (torch.uint4, torch.uint7)
else:
    qdtypes = ()

@pytest.fixture(autouse=True)
def run_before_and_after_tests():
    yield
    torch._dynamo.reset()  # reset cache between tests

idtypes = (torch.half, torch.bfloat16)

@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("qdtype", qdtypes)
@pytest.mark.parametrize("idtype", idtypes)
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.uint(2-7) requires torch 2.3+")
def test_awq_loading(device, qdtype, idtype):
    dataset_size = 100
    l1, l2, l3 = 512, 256, 128
    original_dtype = idtype
    quant_dtype = qdtype
    group_size = 128
    n_calibration_examples = 10
    n_validation_examples = 10
    sequence_length = 5

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
    calibration_data = dataset[:n_calibration_examples]

    # calibrate
    insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size)

    for example in calibration_data:
        m(example.to(device))

    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
    quantize_(m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)

    model_save_path = "awq_model.pth"
    torch.save(m, model_save_path)
    loaded_model = torch.load(model_save_path)
    os.remove(model_save_path)

    if device == "cuda":
        m = torch.compile(m, fullgraph=True)
        loaded_model = torch.compile(loaded_model, fullgraph=True)

    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
    awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])

    assert awq_out is not None
    assert awq_save_load_out is not None
    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.uint(2-7) requires torch 2.3+")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_save_weights_only():
    dataset_size = 100
    l1, l2, l3 = 512, 256, 128
    original_dtype = torch.half
    quant_dtype = torch.uint4
    device = "cuda"
    group_size = 128
    n_calibration_examples = 10
    n_validation_examples = 10
    sequence_length = 5

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    m2 = deepcopy(m)
    dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
    calibration_data = dataset[:n_calibration_examples]

    # calibrate
    insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size)

    for example in calibration_data:
        m(example.to(device))

    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
    quantize_(m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)

    model_save_path = "awq_model.pth"
    torch.save(m.state_dict(), model_save_path)
    m2.load_state_dict(torch.load(model_save_path), assign=True)  # load weights only
    os.remove(model_save_path)

    m = torch.compile(m, fullgraph=True)
    m2 = torch.compile(m2, fullgraph=True)

    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
    awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])

    assert awq_out is not None
    assert awq_save_load_out is not None
    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
2 changes: 1 addition & 1 deletion torchao/_models/llama/eval.py
@@ -233,4 +233,4 @@ def run_evaluation(
        args.calibration_limit,
        args.calibration_seq_length,
        args.pad_calibration_inputs,
    )
    )
53 changes: 52 additions & 1 deletion torchao/_models/llama/generate.py
@@ -161,6 +161,8 @@ def main(
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
    quantization: Optional[str] = None,
    calibration_limit: int = 10,
    calibration_seq_length: int = 256,
    kv_cache_quantization: bool = False,
    cache_size: Optional[int] = None,
    linear_causal_mask: bool=False,
@@ -229,6 +231,53 @@
            quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
        if "fp6" in quantization:
            quantize_(model, fpx_weight_only(3, 2))
        if quantization.startswith("awq"):
            from torchao._models._eval import TransformerEvalWrapper
            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
            from torchao.prototype.awq.example import get_calib_dataset
            if not TORCH_VERSION_AT_LEAST_2_3:
                print("AWQ requires torch 2.3+")
                exit()
            from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear
            quant_dtype = quantization.split("-")[1]
            group_size = int(quantization.split("-")[2])
            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
            model = model.to(device)
            # get calibration data
            insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)
            TransformerEvalWrapper(
                model=model.to(device),
                tokenizer=tokenizer,
                max_seq_length=calibration_seq_length,
                input_prep_func=prepare_inputs_for_model,
                device=device,
            ).run_eval(
                tasks=['wikitext'],
                limit=calibration_limit,
            )
            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
            if "hqq" in quantization and quant_dtype == torch.uint4:
                from torchao.quantization.quant_primitives import (
                    MappingType,
                    ZeroPointDomain,
                    _DTYPE_TO_QVALUE_BOUNDS,
                )
                from torchao.dtypes import to_affine_quantized_intx, TensorCoreTiledLayoutType

                def hqqint4(weight):
                    mapping_type = MappingType.ASYMMETRIC
                    block_size = (1, group_size)
                    target_dtype = torch.int32
                    quant_min = 0
                    quant_max = 15
                    eps = 1e-6
                    preserve_zero = False
                    zero_point_dtype = torch.bfloat16
                    zero_point_domain = ZeroPointDomain.FLOAT

                    return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=True)

                quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size, weight_quant_fn=hqqint4), is_observed_linear)
            else:
                quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)
        if "uintx" in quantization:
            # uintx-nbits-groupsize, e.g. "uintx-2-64"
            if "hqq" in quantization:
@@ -420,6 +469,8 @@ def callback(x):
            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
        )
    )
    parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
    parser.add_argument("--calibration_seq_length", type=int, default=256, help="Sequence length for calibration")
    parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
    parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
    parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
@@ -435,5 +486,5 @@ def callback(x):
    args = parser.parse_args()
    main(
        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
        args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
    )
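As a quick illustration of how the awq quantization string is parsed in the branch above, a short sketch follows; the exact CLI spelling (something like "awq-uint4-64" for dtype uint4 and group size 64) is an assumption inferred from the split("-") indexing, and the spelling accepted for the awq-hqq variant is not shown in this hunk:

import torch

quantization = "awq-uint4-64"  # assumed format: awq-<dtype>-<groupsize>
quant_dtype = getattr(torch, quantization.split("-")[1], torch.uint8)  # torch.uint4, falling back to uint8
group_size = int(quantization.split("-")[2])                           # 64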
1 change: 1 addition & 0 deletions torchao/dtypes/uintx/__init__.py
@@ -0,0 +1 @@
from .uintx import UintxTensor, UintxLayoutType, UintxAQTLayout, to_uintx, _DTYPE_TO_BIT_WIDTH
2 changes: 1 addition & 1 deletion torchao/dtypes/uintx/uintx.py
@@ -30,7 +30,7 @@

    _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
else:
    print("uintx feature need torch 2.3+, please upgrade pytorch")
    print("uintx feature requires torch 2.3+, please upgrade pytorch")


class UintxTensor(TorchAOBaseTensor):
29 changes: 29 additions & 0 deletions torchao/prototype/awq/README.md
@@ -0,0 +1,29 @@
# AWQ Quantization
Adapted from https://github.com/mit-han-lab/llm-awq

## Benchmarks
Evaluation perplexity numbers were calculated using the script in awq/example.py. A group size of 64 was used for all quantization methods. For Llama-2-7b-chat-hf, performance benchmarks were generated with the torchao/_models/llama/generate.py script on a 1x A100 80GB SXM4 instance. The awq-uint4 method does not yet use an efficient fused kernel, which is why its performance lags; awq-hqq uses the tinygemm int4->bf16 kernel plus HQQ for better performance.

| Model | Quantization | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) |
|--------------------|--------------|------------|---------------------|---------------|-----------------|
| Llama-2-7b-chat-hf | bfloat16 | 107.38 | 1418.93 | 13.88 | 13.21 |
| | awq-hqq-int4 | 196.6 | 761.2 | 5.05 | 3.87 |
| | awq-uint4 | 43.59 | 194.93 | 7.31 | 4.47 |
| | int4wo-hqq | 209.19 | 804.32 | 4.89 | 3.84 |
| | int4wo-64 | 201.14 | 751.42 | 4.87 | 3.74 |



The following tests were performed using LM Eval and a group size of 128.

| Model | Quantization | Perplexity | Truthful QA MC2 | WinoGrande | ARC challenge |
|--------------------|--------------|------------|-----------------|------------|---------------|
| Llama-3-8B-Instruct| bfloat16 | 10.936 | 0.540 | 0.783 | 0.567 |
| | awq-hqq-int4 | 11.383 | 0.522 | 0.772 | 0.543 |
| | awq-uint4 | 11.409 | 0.519 | 0.756 | 0.577 |
| | int4wo-hqq | 11.905 | 0.528 | 0.757 | 0.563 |
| | int4wo-128 | 12.380 | 0.502 | 0.753 | 0.548 |






2 changes: 2 additions & 0 deletions torchao/prototype/awq/__init__.py
@@ -0,0 +1,2 @@
from .api import insert_awq_observer_, awq_uintx
from .core import AWQObservedLinear
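
As a usage summary, here is a minimal end-to-end sketch of the API exported above, mirroring the flow in test/prototype/test_awq.py and the hf_eval.py branch; the toy model, device, and calibration tensors are illustrative, and torch 2.3+ with CUDA is assumed:

import torch
from torchao.quantization import quantize_
from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear

# Illustrative toy model and calibration batches; shapes and sizes are arbitrary.
model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False)).to(torch.bfloat16).to("cuda")
calibration_data = [torch.randn(1, 5, 512, dtype=torch.bfloat16, device="cuda") for _ in range(10)]

# 1. Swap nn.Linear modules for observed linears that record activation statistics
#    (10 validation examples, sequence length 5, matching the calibration data above).
insert_awq_observer_(model, 10, 5, quant_dtype=torch.uint4, group_size=64)

# 2. Run calibration batches through the model so the observers see real activations.
with torch.no_grad():
    for batch in calibration_data:
        model(batch)

# 3. Quantize only the observed linears, using the equalization scales collected above.
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=torch.uint4, group_size=64), is_observed_linear)

The quantized model can then be saved whole or via its state_dict, as exercised by the two tests earlier in this PR.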