diff --git a/scripts/create_weight_map.py b/scripts/create_weight_map.py new file mode 100644 index 000000000..334465a9a --- /dev/null +++ b/scripts/create_weight_map.py @@ -0,0 +1,43 @@ +import json +import torch +from transformers import AutoModel +from pathlib import Path +def create_weight_map(checkpoint_dir: Path): + """ + This function, create_weight_map, generates a mapping of a model's weights to a file (pytorch_model.bin) + and saves this mapping, along with the model's total size, to a JSON file (pytorch_model.bin.index.json). + The model is loaded from a pre-trained model specified by model_name. + This weight map is used by the HF conversion script (convert_hf_checkpoint.py). + """ + # Load the model + model_name = checkpoint_dir.parent.name +"/"+ checkpoint_dir.name + print(model_name) + model = AutoModel.from_pretrained(model_name) + # Get the state dict + state_dict = model.state_dict() + # Create the weight map + weight_map = {} + for key, tensor in state_dict.items(): + # In this example, we're assuming all weights are in a single file + # You may need to adjust this if your model uses sharded weights + weight_map[key] = "pytorch_model.bin" + # Create the index dictionary + index_dict = { + "metadata": {"total_size": sum(param.numel() * param.element_size() for param in model.parameters())}, + "weight_map": weight_map + } + # Save the index dictionary to a JSON file + with open(f"{checkpoint_dir}/pytorch_model.bin.index.json", "w") as f: + json.dump(index_dict, f, indent=2) + print("Created pytorch_model.bin.index.json") + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Create weight map for hf model') + parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/Xenova/llama2.c-stories15M")) + + + args = parser.parse_args() + create_weight_map( + args.checkpoint_dir + ) \ No newline at end of file diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py index 5f008ee43..1d4935f1b 100644 --- a/scripts/hf_eval.py +++ b/scripts/hf_eval.py @@ -48,7 +48,7 @@ def format_value(value): def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length): tokenizer = AutoTokenizer.from_pretrained(repo_id) - model = AutoModelForCausalLM.from_pretrained(repo_id).to(dtype=precision, device=device) + model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device) if quantization == "autoquant" and compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) @@ -64,9 +64,29 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars quantize_(model, fpx_weight_only(3, 2)) elif quantization == "autoquant": model = autoquant(model.to(device=device)) + elif quantization == "awq": + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + from torchao.prototype.awq.example import get_calib_dataset + if not TORCH_VERSION_AT_LEAST_2_3: + print("AWQ quantization requires torch2.3+") + exit() + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear + quant_dtype = torch.uint4 + group_size = 64 + calibration_limit = 10 + calibration_seq_length = 1024 + model=model.to(device) + insert_awq_observer_(model,calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) + with torch.no_grad(): + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length) + for batch in calibration_data: + 
model(batch.to(device)) + del batch + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size), is_observed_linear) if quantization != "autoquant" and compile: - model = torch.compile(model, mode="max-autotune", fullgraph=True) + model = torch.compile(model, mode= "max-autotune", fullgraph=True) if sparsity == "semi_sparse": def all_linear(mod, name): @@ -114,7 +134,7 @@ def all_linear(mod, name): parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply') + parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "awq", "None"], help='Which quantization technique to apply') parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--save', action='store_true', help='Whether to save the model.') diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 2d689a0c0..bb754135d 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -4,7 +4,7 @@ import torch -from torchao.dtypes.uintx.uintx import to_uintx +from torchao.dtypes.uintx import to_uintx from torchao.quantization.quant_api import quantize_, uintx_weight_only from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py new file mode 100644 index 000000000..eccf8db8f --- /dev/null +++ b/test/prototype/test_awq.py @@ -0,0 +1,129 @@ +from copy import deepcopy +import os +import pytest +import torch +from torchao.quantization import quantize_ + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 +if TORCH_VERSION_AT_LEAST_2_3: + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=512, n=256, k=128): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + self.linear3 = torch.nn.Linear(k, 1, bias=False) + + def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"): + return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + +devices = ["cpu", "cuda"] +# torch.uintx dtypes are introduced in 2.3 +if TORCH_VERSION_AT_LEAST_2_3: + qdtypes = (torch.uint4, torch.uint7) +else: + qdtypes = () + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + yield + torch._dynamo.reset() # reset cache between tests + +@pytest.mark.parametrize("device", devices) +@pytest.mark.parametrize("qdtype", qdtypes) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5,reason="requires nightly 
pytorch") +def test_awq_loading(device, qdtype): + if qdtype == torch.uint4 and device == "cpu": + pytest.skip("uint4 not supported on cpu") + + dataset_size = 100 + l1,l2,l3 = 512,256,128 + original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs + quant_dtype = qdtype + group_size = 128 + n_calibration_examples = 10 + n_validation_examples = 10 + sequence_length = 5 + + m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) + dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) + calibration_data = dataset[:n_calibration_examples] + + # calibrate + insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size) + + for example in calibration_data: + m(example.to(device)) + + + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + quantize_(m, awq_uintx(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) + + model_save_path = "awq_model.pth" + torch.save(m, model_save_path) + loaded_model = torch.load(model_save_path) + os.remove(model_save_path) + + if torch.cuda.is_available(): + m = torch.compile(m, fullgraph=True) + loaded_model = torch.compile(loaded_model, fullgraph=True) + + awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) + awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset]) + + assert awq_out is not None + assert awq_save_load_out is not None + assert torch.allclose(awq_out, awq_save_load_out, atol = 1e-2) + +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5,reason="requires nightly pytorch") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_save_weights_only(): + dataset_size = 100 + l1,l2,l3 = 512,256,128 + original_dtype = torch.bfloat16 + quant_dtype = torch.uint4 + device = "cuda" + group_size = 128 + n_calibration_examples = 10 + n_validation_examples = 10 + sequence_length = 5 + + m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) + m2 = deepcopy(m) + dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) + calibration_data = dataset[:n_calibration_examples] + + # calibrate + insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size) + + for example in calibration_data: + m(example.to(device)) + + + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + quantize_(m, awq_uintx(quant_dtype = quant_dtype, group_size = group_size), is_observed_linear) + + model_save_path = "awq_model.pth" + torch.save(m.state_dict(), model_save_path) + m2.load_state_dict(torch.load(model_save_path), assign=True) # load weights only.torch.load(model_save_path) + os.remove(model_save_path) + + m = torch.compile(m, fullgraph=True) + m2 = torch.compile(m2, fullgraph=True) + + awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) + awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset]) + + assert awq_out is not None + assert awq_save_load_out is not None + assert torch.allclose(awq_out, awq_save_load_out, atol = 1e-2) \ No newline at end of file diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index d495c2065..cc80d40db 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -233,4 +233,4 @@ def run_evaluation( args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, - ) + ) \ No newline at end of file 
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 19e42e7cd..7e4708ba5 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -161,6 +161,8 @@ def main( temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), quantization: Optional[str] = None, + calibration_limit: int = 10, + calibration_seq_length: int = 256, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, linear_causal_mask: bool=False, @@ -229,6 +231,33 @@ def main( quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) + if quantization.startswith("awq"): + from torchao._models._eval import TransformerEvalWrapper + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + from torchao.prototype.awq.example import get_calib_dataset + if not TORCH_VERSION_AT_LEAST_2_3: + print("Awq requires torch2.3+") + exit() + from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear + quant_dtype = quantization.split("-")[1] + group_size = int(quantization.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.uint8) + model=model.to(device) + # get calibration data + insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=calibration_seq_length, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=['wikitext'], + limit=calibration_limit, + ) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quantization + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear) if "uintx" in quantization: # uintx-nbits-groupsize, e.g. 
"uintx-2-64" if "hqq" in quantization: @@ -420,6 +449,8 @@ def callback(x): +'autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin' ) ) + parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples") + parser.add_argument("--calibration_seq_length", type=int, default=256, help="Sequence length for calibration") parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') @@ -435,5 +466,5 @@ def callback(x): args = parser.parse_args() main( args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result ) diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index e69de29bb..ad9166079 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -0,0 +1 @@ +from .uintx import UintxTensor, UintxLayoutType, UintxAQTLayout, to_uintx, _DTYPE_TO_BIT_WIDTH \ No newline at end of file diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx.py index cfe75f4dc..a0cd687f5 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx.py @@ -30,7 +30,7 @@ _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} else: - print("uintx feature need torch 2.3+, please upgrade pytorch") + print("uintx feature requires torch 2.3+, please upgrade pytorch") class UintxTensor(TorchAOBaseTensor): diff --git a/torchao/prototype/awq/README.md b/torchao/prototype/awq/README.md new file mode 100644 index 000000000..e7b7f782f --- /dev/null +++ b/torchao/prototype/awq/README.md @@ -0,0 +1,29 @@ +# AWQ Quantization +Adapted from https://github.com/mit-han-lab/llm-awq + +## Benchmarks +Evaluation perplexity numbers were calculated using the script in awq/example.py Group size of 64 was used for all quantization methods. For Llama-2-7b-chat-hf, performance benchmarks were calculated using the torchao/_models/llama/generate.py script and run on a 1xA100 80GB SXM4 instance. The awq-uint4 quantization method does not use an efficient fused kernel which is why performance is not great. awq-hqq uses tinygemm int4->bf16 kernel + hqq to provide better performance. 
+
+| Model              | Quantization | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) |
+|--------------------|--------------|------------|---------------------|---------------|-----------------|
+| Llama-2-7b-chat-hf | bfloat16     | 107.38     | 1418.93             | 13.88         | 13.21           |
+|                    | awq-hqq-int4 | 196.6      | 761.2               | 5.05          | 3.87            |
+|                    | awq-uint4    | 43.59      | 194.93              | 7.31          | 4.47            |
+|                    | int4wo-hqq   | 209.19     | 804.32              | 4.89          | 3.84            |
+|                    | int4wo-64    | 201.14     | 751.42              | 4.87          | 3.74            |
+
+
+
+The following results were obtained with lm-eval and group size 128.
+| Model              | Quantization | Perplexity | Truthful QA MC2 | WinoGrande | ARC challenge |
+|--------------------|--------------|------------|-----------------|------------|---------------|
+| Llama-3-8B-Instruct| bfloat16     | 10.936     | 0.540           | 0.783      | 0.567         |
+|                    | awq-hqq-int4 | 11.383     | 0.522           | 0.772      | 0.543         |
+|                    | awq-uint4    | 11.409     | 0.519           | 0.756      | 0.577         |
+|                    | int4wo-hqq   | 11.905     | 0.528           | 0.757      | 0.563         |
+|                    | int4wo-128   | 12.380     | 0.502           | 0.753      | 0.548         |
+
+
+
+
+
diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py
new file mode 100644
index 000000000..ca9381d57
--- /dev/null
+++ b/torchao/prototype/awq/__init__.py
@@ -0,0 +1,2 @@
+from .api import insert_awq_observer_, awq_uintx
+from .core import AWQObservedLinear
\ No newline at end of file
diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py
new file mode 100644
index 000000000..e3a8827e2
--- /dev/null
+++ b/torchao/prototype/awq/api.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn.functional as F
+
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+    _DTYPE_TO_QVALUE_BOUNDS,
+)
+from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
+from torchao.quantization.observer import PerGroup
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
+from torchao.dtypes import (
+    to_affine_quantized_intx,
+    TensorCoreTiledLayoutType,
+)
+from .core import (
+    AWQObserver,
+    AWQObservedLinear,
+)
+
+
+assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+"
+
+def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, validation_sequence_len: int, quant_dtype: torch.dtype = torch.uint4, scale_search_space_size: int = 20, group_size: int = 128):
+    """
+    Inserts AWQObserver into Linear layers of a given model.
+
+    Args:
+        model: The model to be modified in place. Ensure the model is on the desired device for calibration.
+        n_validation_examples: Number of examples used to validate scale options
+        validation_sequence_len: Number of tokens in each validation example
+        quant_dtype: The data type of the quantized weights. torch.uint4 is the intended use case, but any dtype from torch.uint1 to torch.uint8 is accepted.
+        scale_search_space_size: How many different scale options to try. The original AWQ implementation uses 20. A larger size can lead to better results but takes longer to calibrate.
+        group_size: Quantization granularity. Use -1 for channel-wise quantization.
+    """
+    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
+    assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 ..
torch.uint8" + # AQT config + mapping_type = MappingType.ASYMMETRIC + quantization_granularity = PerGroup(group_size) + quant_min = 0 + quant_max = 255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1 + eps = torch.finfo(torch.float32).eps + preserve_zero = True + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + + + def replace_with_observer(layer): + # creates observer and replaces linear layers with AWQObservedLinear layers + observer = AWQObserver( + layer.weight, + layer.bias, + quantization_granularity, + mapping_type, + quant_dtype, + n_validation_examples, + validation_sequence_len, + scale_search_space_size, + preserve_zero = preserve_zero, + zero_point_domain = zero_point_domain, + zero_point_dtype = zero_point_dtype, + quant_min=quant_min, + quant_max = quant_max, + eps = eps) + return AWQObservedLinear.from_float(layer, observer) + _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) + +def _observed_linear_subclass_inserter(constructor): + """ + Replaces unquantized AWQObservedLinear instances with quantized linear instances. + + Args: + constructor: the function which applies quantization to the AWQObservedLinear layer + """ + def insert_subclass(observed_linear): + # creates the new linear layer using constructor + linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, observed_linear.bias!=None, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) + linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False) + linear.bias = observed_linear.bias + return linear + + return insert_subclass + + +def awq_uintx(quant_dtype: torch.dtype = torch.uint4, + group_size: int = 64, + use_hqq: bool = False,): + """ + Quantizes linear layers when passed into quantize_() + + Args: + quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 + group_size: Quantization granularity. Use -1 for channel wise quantization + weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used + """ + assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" + + def weight_quant_func(observed_linear): + equalization_scale = observed_linear.act_obs.calculate_qparams() + # AQT config + if quant_dtype == torch.uint4: + target_dtype = torch.int32 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8) + else: + target_dtype = torch.uint8 + eps = torch.finfo(torch.float32).eps + preserve_zero = True + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + layout_type = UintxLayoutType(quant_dtype) + + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] + quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] + qw = to_affine_quantized_intx( + observed_linear.weight * equalization_scale, + mapping_type, + block_size, + target_dtype, quant_min, + quant_max, eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + layout_type=layout_type, + use_hqq=use_hqq + ) + + return to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale) + + return _observed_linear_subclass_inserter(weight_quant_func) + + diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py new file mode 100644 index 000000000..77810a2e4 --- /dev/null +++ b/torchao/prototype/awq/core.py @@ -0,0 +1,156 @@ +from dataclasses import dataclass +from typing import Tuple, Optional + +import torch +import torch.nn.functional as F + +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType +from torchao.dtypes import to_affine_quantized_intx +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) +from torchao.quantization.observer import ( + AffineQuantizedObserverBase, GranularityType +) + + +class AWQObserver(AffineQuantizedObserverBase): + def __init__(self, + weight: torch.Tensor, + bias: torch.Tensor, + quantization_granularity: GranularityType, + mapping_type: MappingType, + target_dtype: torch.dtype, + n_validation_examples: int, + validation_sequence_len: int, + scale_search_space_size: int = 20, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: Optional[bool] = True, + zero_point_domain = ZeroPointDomain.INT, + ): + """ + A custom observer for Activation aware Weight Quantization (AWQ) + + Args: + weight: The weight tensor to be observed. + bias: The bias tensor to be observed. + quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point + input_dtype: The data type of the input tensor. + mapping_type: Always set to asymmetric + target_dtype: The target data type of the quantized tensor + n_validation_examples: Number of examples used to calibrate observer + validation_sequence_len: Number of tokens in each example + scale_search_space_size: The number of scales to search for. + quant_min: The minimum quantized value + quant_max: The maximum quantized value + eps: The minimum scale. + scale_dtype: The data type of the scale tensor. + zero_point_dtype: The data type of the zero point tensor. + preserve_zero: A flag to indicate whether we need zero to be exactly + representable or not. + zero_point_domain: The domain of the zero point. 
+ """ + super().__init__( + mapping_type, + target_dtype, + quantization_granularity, + quant_min = quant_min, + quant_max = quant_max, + eps = eps, + scale_dtype = scale_dtype, + zero_point_dtype = zero_point_dtype, + preserve_zero = preserve_zero, + zero_point_domain = zero_point_domain, + ) + self.quantization_granularity = quantization_granularity + self.weight = weight + self.bias = bias + self.n_validation_examples = n_validation_examples + self.validation_sequence_len = validation_sequence_len + self.calibration_token_count = 0 + self.inputs = [] + self.outputs = [] + self.scale_options = scale_search_space_size + self.device = self.weight.device + self.average = torch.zeros((1,weight.shape[1]), device= self.device) + if self.bias is not None: + self.bias.to(self.device) + @torch.no_grad() + def forward(self, input: torch.Tensor, output: torch.Tensor): + # import pdb + # pdb.set_trace() + # print(input.shape, input.abs().sum(1).shape, self.average.shape) + if len(self.inputs) < self.n_validation_examples: + self.inputs.append(input.to("cpu")) + self.outputs.append(output.to("cpu")) + self.calibration_token_count += input.shape[-2] + self.average += input.abs().sum(-2) + + + + def calculate_qparams(self): + # import pdb + # pdb.set_trace() + assert self.outputs != None, "calibrate observer first by running model on exemplar data" + self.average /= (self.calibration_token_count) + for i in range(self.n_validation_examples): + self.inputs[i] = self.inputs[i].to(self.device) + self.outputs[i] = self.outputs[i].to(self.device) + + best_loss = float('inf') + best_scales = None + for i in range(self.scale_options): + ratio = i * 1 / self.scale_options + scales = self.average.pow(ratio).to(self.weight.dtype) + scales = scales / (scales.max() * scales.min()).sqrt() + layout = UintxLayoutType(self.target_dtype) + # regardless of weight dtype, we have to store as packed uint8 tensors + tensor_dtype = torch.uint8 + w = to_affine_quantized_intx( + self.weight*scales, + self.mapping_type, + (1, self.quantization_granularity.group_size), + tensor_dtype, + quant_min = self.quant_min, + quant_max = self.quant_max, + eps = self.eps, + scale_dtype = self.scale_dtype, + zero_point_dtype = self.zero_point_dtype, + preserve_zero = self.preserve_zero, + zero_point_domain = self.zero_point_domain, + layout_type = layout + ) + loss = 0 + for i in range(self.n_validation_examples): + q_out = F.linear(self.inputs[i]/scales, w, self.bias) + loss += (self.outputs[i] - q_out).pow(2).mean().item() + if loss < best_loss: + best_scales = scales + best_loss = loss + for i in range(self.n_validation_examples): + self.inputs[i].to("cpu") + self.outputs[i].to("cpu") + return best_scales.detach() + +class AWQObservedLinear(torch.nn.Linear): + def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): + super().__init__(in_features, out_features, bias, device, dtype) + self.act_obs = act_obs + + def forward(self, input: torch.Tensor): + output = F.linear(input, self.weight, self.bias) + self.act_obs(input, output) + return output + + @classmethod + def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver): + observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) + observed_linear.weight = float_linear.weight + observed_linear.bias = float_linear.bias + return observed_linear \ No newline at end of file diff --git 
a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py new file mode 100644 index 000000000..8b2eb0675 --- /dev/null +++ b/torchao/prototype/awq/example.py @@ -0,0 +1,230 @@ +import torch +import argparse +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from tqdm import tqdm +import time +from torchao.prototype.awq import insert_awq_observer_, AWQObservedLinear, awq_uintx +from torchao.quantization import quantize_, int4_weight_only, uintx_weight_only + + +# adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255 +def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): + dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") + samples = [] + n_tokens = n_samples * block_size + n_run = n_tokens + for data in dataset: + line = data["text"] + line = line.strip() + line_encoded = tokenizer.encode(line) + if len(line_encoded) > 512: + continue + sample = torch.tensor([line_encoded]) + if sample.numel() == 0: + continue + samples.append(sample) + n_run -= len(line_encoded) + if n_run <= n_samples: + break + + cat_samples = torch.cat(samples, dim=1) + return [cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples)] + +# from https://github.com/mobiusml/hqq/blob/master/examples/llama2_benchmark/eval_model.py +def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True): + model.eval() + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" + tokenizer.add_eos_token = False + + dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt') + + encodings['input_ids'] = encodings['input_ids'].to('cuda') + + lls, t = [], [] + for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose): + begin_loc = max(i + stride - sequence_length, 0) + end_loc = min(i + stride, encodings['input_ids'].size(1)) + trg_len = end_loc - i + input_ids = encodings['input_ids'][:,begin_loc:end_loc] + target_ids = input_ids.clone() + target_ids[:,:-trg_len] = -100 #ignore context + + t1 = time.time() + with torch.no_grad(): + log_likelihood = model(input_ids, labels=target_ids).loss * trg_len + torch.cuda.synchronize() + t2 = time.time() + t.append((t2-t1)) + lls.append(log_likelihood) + + del input_ids, target_ids + + ppl = float(torch.exp(torch.stack(lls).sum() / end_loc)) + pred_time = sum(t)/len(t) + if(verbose): + print('perplexity', ppl) + print('time', str(pred_time) + ' sec') + + return {'perplexity':ppl, 'prediction_time':pred_time} + +# adapted from Hicham Badri (@mobicham) +def benchmark(model, tokenizer, max_length, tasks=None): + import numpy as np + import copy + import lm_eval + model.eval(); + model.config.use_cache = False + try: + lm_eval.tasks.initialize_tasks() + except: + pass + model_eval = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer) + eval_batch_size = 1 #8 + if tasks is None: + tasks = ["PPL","truthfulqa_mc2", "winogrande", "arc_challenge", "hellaswag", "gsm8k", "mmlu"] + results = {} + if "PPL" in tasks: + results["perplexity"] = wiki2_eval(model, tokenizer, 512, verbose=True) + ############################################ + if "truthfulqa_mc2" in tasks: + for task in [("truthfulqa_mc2", 0)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + if 
"winogrande" in tasks: + for task in [("winogrande", 5)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + if "arc_challenge" in tasks: + for task in [("arc_challenge", 25)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + + # ############################################ + if "hellaswag" in tasks: + for task in [("hellaswag", 10)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + if "gsm8k" in tasks: + for task in [("gsm8k", 5)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results[tag]) + # ############################################ + + results_1 = copy.deepcopy(results) + if "mmlu" in tasks: + #MMLU + results_mmlu = {} + for task in [("mmlu", 5)]: + tag, fewshot = task + results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size)['results'] + print(tag, results_mmlu[tag]) + + mmlu_list = "hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions" + mmlu_list = [l.replace('hendrycksTest-','') for l in mmlu_list.split(',')] + results_mmlu = results_mmlu['mmlu'] + + k = [] + for r in results_mmlu: + if np.any([(l in r) for l in mmlu_list]): + k.append(results_mmlu[r]['acc,none']) + + assert len(k)==57 + print('MMLU avg acc', np.mean(k)) + + 
results['mmlu'] = np.mean(k) + return results + + +def wikitext2_ppl( + repo_id: str, + quant: str, + tasks: list[str], + calibration_size: int, + validation_size:int, + device: str, + precision:torch.dtype, + sequence_length: int, + compile: bool, + model_save_path: str): + print(f"Loading model on {device}...") + torch.manual_seed(34) + t0 = time.time() + # load any model with torch.nn.linear layers + tokenizer = AutoTokenizer.from_pretrained(repo_id) + model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).eval().to(device) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + if quant.startswith("awq"): + quant_dtype = quant.split("-")[1] + group_size = int(quant.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) + print(f"running {quant_dtype} calibration") + t0 = time.time() + # insert observers to find average magnitude and calculate scales + insert_awq_observer_(model,validation_size, sequence_length, quant_dtype=quant_dtype, group_size=group_size) + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") + print(f"time for calibration: {time.time() - t0:.02f} seconds") + + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quant + print(f"running {quant_dtype} quantization") + t0 = time.time() + quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear) + print(f"time for quantization: {time.time() - t0:.02f} seconds") + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + elif quant.startswith("int4wo"): + group_size = int(quant.split("-")[1]) + use_hqq = "hqq" in quant + print(f"running {quant} quantization with group size {group_size}") + quantize_(model, int4_weight_only(group_size=group_size, use_hqq= use_hqq)) + if compile: + model = torch.compile(model) + + results = benchmark(model, tokenizer, sequence_length, tasks=tasks) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") + + + # Optional arguments with default values + parser.add_argument("repo", type=str, help="Repository ID of the model.") + parser.add_argument("quant", type=str, help="Quantization method. Options are either awq-uint- for x =[1..8], int4wo-, or int4wo--hqq.") + parser.add_argument("--tasks", type=list[str], help="Task to benchmark model on. Either PPL or QA", default=["PPL"]) + parser.add_argument("--calibration_samples", type=int, default=10, help="Number of samples to use for calibration. Default is 10.") + parser.add_argument("--validation_size", type=int, default=1, help="Validation size. Default is 1.") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") + parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") + parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate and evaluate model on. 
Default is 512") + parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") + parser.add_argument("--model_save_path", type=str, default=None, help="Path to store the scale values.") + + args = parser.parse_args() + + # Convert precision argument to torch dtype + precision_dtype = getattr(torch, args.precision, torch.bfloat16) + ppl = wikitext2_ppl( + args.repo, + args.quant, + args.tasks, + args.calibration_samples, + args.validation_size, + args.device, + args.precision, + args.seq_len, + args.compile, + args.model_save_path + ) + + print(f"{args.quant} Results: {ppl}") \ No newline at end of file diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index f1230ef75..eee918147 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -145,8 +145,7 @@ def _(func, types, args, kwargs): kwargs, args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) - - + @implements(aten.t.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 4653e6577..bef4abe71 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -55,6 +55,25 @@ class PerAxis(GranularityType): axis: int @dataclass(frozen=True) + +class PerGroup(GranularityType): + """ + Represents per-channel group granularity in quantization. + + This granularity type calcualtes different quantization parameters + for each group of elements. + + For example if the input tensor is shape [8, 16], and the group size is 4, then + the input tensor is reshaped to [64, 4] + quantization parameters are calculated for each group of 4 elements, + giving a total of 64 quantization parameters. + + Attributes: + group_size (int): The size of each quantization group + + """ + group_size: int + class PerRow(GranularityType): """ Represents row-wise granularity in quantization.