From fa78b5b73fe993333366c4dc7c03f4e326c1f374 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 30 Aug 2024 18:56:22 -0700 Subject: [PATCH] Expose hqq through `int4_weight_only` API Summary: att, this is a follow up for https://github.com/pytorch/ao/pull/605 to make hqq available in quantize_ API `quantize_(model, int4_weight_only(group_size, use_hqq=True)` Test Plan: python generate.py --compile --quantization int4wo-hqq-64 --precision bfloat16 Average tokens/sec: 195.24 Average Bandwidth: 729.40 GB/s Peak Memory Usage: 5.09 GB Model Size: 3.74 GB python eval.py --compile --quantization int4wo-hqq-64 --precision bfloat16 wikitext: {'word_perplexity,none': 12.823631773497512, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.611400903914048, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6883154699192412, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'} Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/llama/eval.py | 5 +++-- torchao/_models/llama/generate.py | 7 ++++--- torchao/quantization/README.md | 7 ++++++- torchao/quantization/quant_api.py | 30 +++++++++++++++++++++++++++--- 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index f673a966de..05e76f282f 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -71,7 +71,8 @@ def run_evaluation( if "int4wo" in quantization and not "gptq" in quantization: groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize_(model.to(device), int4_weight_only(group_size=groupsize)) + use_hqq = "hqq" in quantization + quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq)) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) if "int4wo" in quantization and "gptq" in quantization: @@ -120,7 +121,7 @@ def run_evaluation( parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq") + parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq, int4wo-hqq-") parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 94a18488b2..e2914e9944 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -128,7 +128,7 @@ def generate( # execute token generation input_pos = torch.tensor([T], device=device, dtype=torch.int) generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) - + seq = torch.cat((seq[:T+1], *generated_tokens)) return seq @@ -218,7 +218,8 @@ def main( if "int4wo" in quantization: groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize_(model, int4_weight_only(group_size=groupsize)) + use_hqq = "hqq" in quantization + quantize_(model, int4_weight_only(group_size=groupsize, use_hqq=use_hqq)) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) if "autoquant" == quantization: @@ -387,7 +388,7 @@ def callback(x): parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') + parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo-hqq-, autoquant') parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 0746f08730..fc9dc4fd7a 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -216,7 +216,12 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is ```python # for torch 2.4+ from torchao.quantization import quantize_, int4_weight_only -quantize_(model, int4_weight_only()) +group_size = 32 + +# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through +# use_hqq flag for `int4_weight_only` quantization +use_hqq = False +quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index aa2d3b3f93..236c56a106 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -416,7 +416,7 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner size is more fine grained, choices are [256, 128, 64, 32] `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)` """ - def apply_int4_weight_only_quant(weight, use_hqq=False): + def apply_int4_weight_only_quant(weight): if weight.shape[-1] % group_size != 0: return weight @@ -498,7 +498,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): """ Applies float8 weight-only symmetric per-channel quantization to linear layers. - + Args: weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. @@ -617,9 +617,33 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: return weight layout_type = FpxTensorCoreLayoutType(ebits, mbits) - print("layout type:", layout_type) return to_affine_quantized_fpx(weight, layout_type) return _get_linear_subclass_inserter(apply_quant_llm) +def hqq_weight_only(dtype, group_size): + + def apply_hqq(weight: torch.Tensor) -> torch.Tensor: + if dtype != torch.uint4: + layout_type = UintxLayoutType(dtype=dtype, pack_dim=-1) + return to_affine_quantized_intx( + input_float=weight, + mapping_type=MappingType.ASYMMETRIC, + block_size=(1, group_size), + target_dtype=dtype, + zero_point_domain=ZeroPointDomain.FLOAT, + layout_type=layout_type, + preserve_zero=False, + use_hqq=True, + ) + else: + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain=ZeroPointDomain.FLOAT, layout_type=layout_type, use_hqq=True) + + if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])