From 5c9269ced437fbaa0f7a19d31aabb7a9e433f841 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 12 Apr 2024 10:39:45 -0700 Subject: [PATCH] Initial support for 8da4w QAT Summary: This commit adds support for QAT, where linear layers are fake quantized with int8 per token dynamic activations (8da) and int4 grouped per channel weights (4w). This initial implementation uses the same module swap approach as 8da4w PTQ for simplicity and code reuse. In the future, we may wish to consider migrating both flows to use tensor subclasses for better composability with other PyTorch features. Test Plan: python test/quantization/test_qat.py -k test_fake_quantize_per_channel_group python test/quantization/test_qat.py -k test_fake_quantize_per_token python test/quantization/test_qat.py -k test_qat_8da4w_linear python test/quantization/test_qat.py -k test_qat_8da4w_quantizer Reviewers: jerryzh168, cpuhrsch, HDCharles Subscribers: jerryzh168, cpuhrsch, HDCharles, supriyar Tasks: https://github.com/pytorch-labs/ao/issues/86 --- test/quantization/test_qat.py | 179 +++++++++++++ test/quantization/test_quant_primitives.py | 4 +- torchao/quantization/_prototype/qat.py | 286 +++++++++++++++++++++ 3 files changed, 467 insertions(+), 2 deletions(-) create mode 100644 test/quantization/test_qat.py create mode 100644 torchao/quantization/_prototype/qat.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py new file mode 100644 index 0000000000..6c5ac1c43c --- /dev/null +++ b/test/quantization/test_qat.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# mypy: ignore-errors +# This test takes a long time to run + +import copy +import unittest + +import torch +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torchao.quantization._prototype.qat import ( + _choose_qparams_per_token_asymmetric, + fake_quantize_per_channel_group, + fake_quantize_per_token, + Int8DynActInt4WeightQATLinear, + Int8DynActInt4WeightQATQuantizer, +) +from torchao.quantization.quant_primitives import ( + get_group_qparams_symmetric, + group_quantize_tensor_symmetric, + per_token_dynamic_quant, +) +from torchao.quantization.utils import ( + TORCH_VERSION_AFTER_2_3, +) +from torchao.quantization.GPTQ import ( + Int8DynActInt4WeightLinear, + Int8DynActInt4WeightQuantizer, +) + + +# TODO: put this in a common test utils file +class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float) + self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float) + + def example_inputs(self): + return (torch.randn(1, 64).to(torch.float),) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TestQATQuantPrimitives(unittest.TestCase): + SEED = 123 + + def _get_qmin_qmax(self, n_bit: int): + qmin = -(2 ** (n_bit - 1)) + qmax = 2 ** (n_bit - 1) - 1 + return (qmin, qmax) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + def test_fake_quantize_per_channel_group(self): + n_bit = 4 + (qmin, qmax) = self._get_qmin_qmax(n_bit) + group_size = 128 + + torch.manual_seed(self.SEED) + x = torch.randn(100, 256).requires_grad_() + (s, zp) = get_group_qparams_symmetric(x, n_bit, group_size) + x2 = copy.deepcopy(x) + + # fake quant op + out = fake_quantize_per_channel_group( + x, s, zp, qmin, qmax, group_size, + ) + out.sum().backward() + + # compare against PTQ ops + out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group( + x2, s, zp, qmin, qmax, torch.int8, group_size, + ) + out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group( + out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32, + ) + torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + def test_fake_quantize_per_token(self): + (qmin, qmax) = self._get_qmin_qmax(8) + + torch.manual_seed(self.SEED) + x = torch.randn(100, 256).requires_grad_() + x2 = copy.deepcopy(x) + # TODO: use torch.ops.aten.quantized_decomposed version instead + (s, zp) = _choose_qparams_per_token_asymmetric( + x, + torch.int8, # not used + ) + + # fake quant op + out = fake_quantize_per_token(x, s, zp, qmin, qmax) + out.sum().backward() + + # compare against PTQ ops + out_ptq = torch.ops.quantized_decomposed.quantize_per_token( + x2, s, zp, qmin, qmax, torch.int8, + ) + out_ptq = torch.ops.quantized_decomposed.dequantize_per_token( + out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32, + ) + torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) + + def _set_ptq_weight( + self, + ptq_linear: Int8DynActInt4WeightLinear, + fp32_weight: torch.Tensor, + group_size: int, + ): + """ + Set the weight to the quantized version of the given fp32 weights, + for making linear outputs comparable with QAT. + """ + n_bit = 4 + (qmin, qmax) = self._get_qmin_qmax(n_bit) + (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) + q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( + fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, + ) + ptq_linear.weight = q_weight + ptq_linear.scales = s + ptq_linear.zeros = zp + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + def test_qat_8da4w_linear(self): + group_size = 128 + torch.manual_seed(self.SEED) + qat_linear = Int8DynActInt4WeightQATLinear(256, 688, bias=False, groupsize=group_size) + ptq_linear = Int8DynActInt4WeightLinear(256, 688, bias=False, groupsize=group_size) + + # Force the weights to be the same + self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size) + + # Compare linear values + torch.manual_seed(self.SEED) + x = torch.randn(100, 256) + x2 = copy.deepcopy(x) + qat_out = qat_linear(x) + ptq_out = ptq_linear(x2) + torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + def test_qat_8da4w_quantizer(self): + group_size = 16 + torch.manual_seed(self.SEED) + m = M() + m2 = copy.deepcopy(m) + qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size) + qat_model = qat_quantizer.prepare(m) + ptq_model = ptq_quantizer.quantize(m2) + + # Force the weights to be the same + self._set_ptq_weight( + ptq_model.linear1, qat_model.linear1.weight, group_size, + ) + self._set_ptq_weight( + ptq_model.linear2, qat_model.linear2.weight, group_size, + ) + + # Compare model values + torch.manual_seed(self.SEED) + x = m.example_inputs() + x2 = copy.deepcopy(x) + qat_out = qat_model(*x) + ptq_out = ptq_model(*x2) + torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index fef3f83dcc..82533d6e47 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -9,12 +9,12 @@ import unittest import torch from torchao.quantization.quant_primitives import get_group_qparams_symmetric -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 class TestQuantPrimitives(unittest.TestCase): SEED = 123 - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") def test_get_group_qparams_symmetric(self): """ Test that `get_group_qparams_symmetric` produces the exact same scales as diff --git a/torchao/quantization/_prototype/qat.py b/torchao/quantization/_prototype/qat.py new file mode 100644 index 0000000000..5652ce9634 --- /dev/null +++ b/torchao/quantization/_prototype/qat.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +import torch +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib +from torch.library import impl + +from torchao.quantization.GPTQ import _check_linear_int4_k +from torchao.quantization.quant_primitives import get_group_qparams_symmetric +from torchao.quantization.unified import TwoStepQuantizer + + +class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have int8 + dynamic per token fake quantized activations and int4 fake quantized + grouped per channel weights. + """ + + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = False, + precision: torch.dtype = torch.float32, + qparams_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__() + self.groupsize: int = groupsize + self.padding_allowed: bool = padding_allowed + self.precision: torch.dtype = precision + self.qparams_precision: torch.dtype = qparams_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + replace_linear_8da4w_qat( + model, + self.groupsize, + self.padding_allowed, + self.precision, + self.precision, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + # TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear + pass + + +class Int8DynActInt4WeightQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int8 dynamic per token fake + quantized activations with int4 fake quantized grouped per channel weights. + + args: + groupsize: the number of elements in each quantized group for weights + weight_precision: precision of weights + qparams_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + groupsize: int = 256, + weight_precision: torch.dtype = torch.float32, + qparams_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=None, + dtype=weight_precision, + ) + assert ( + in_features % groupsize == 0 + ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" + assert not bias, "require bias=False" + self.groupsize = groupsize + self.qparams_precision = qparams_precision + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # activations: int8 dynamic asymmetric quant + (act_qmin, act_qmax) = self._get_qmin_qmax(8) + (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( + x, torch.int8, # dtype not used + ) + x_fq = fake_quantize_per_token( + x, act_scales, act_zp, act_qmin, act_qmax, + ) + + # weights: int4 grouped per channel symmetric quant + (weight_qmin, weight_qmax) = self._get_qmin_qmax(4) + (weight_scales, weight_zp) = get_group_qparams_symmetric( + self.weight, 4, self.groupsize, self.qparams_precision, + ) + w_fq = fake_quantize_per_channel_group( + self.weight, + weight_scales, + weight_zp, + weight_qmin, + weight_qmax, + self.groupsize, + ) + return torch.nn.functional.linear(x_fq, w_fq) + + def _get_qmin_qmax(self, n_bit: int): + qmin = -(2 ** (n_bit - 1)) + qmax = 2 ** (n_bit - 1) - 1 + return (qmin, qmax) + + +def replace_linear_8da4w_qat( + module: torch.nn.Module, + groupsize: int, + padding_allowed: bool, + weight_precision: torch.dtype, + qparams_precision: torch.dtype, +): + """ + Replace `torch.nn.Linear` in the model with `Int8DynActInt4WeightQATLinear`. + """ + for name, child in module.named_children(): + if isinstance(child, torch.nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed: + setattr( + module, + name, + Int8DynActInt4WeightQATLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + weight_precision=weight_precision, + qparams_precision=qparams_precision, + ), + ) + else: + replace_linear_8da4w( + child, + groupsize, + padding_allowed, + weight_precision, + qparams_precision, + ) + + +# ======================== +# | QUANT PRIMITIVES | +# ======================== + +class _GenericFakeQuantize(torch.autograd.Function): + """ + Implementation of generic fake quantize with backward STE. + + With the appropriate input tensor shape, this can be used to express + grouped per channel fake quantize or per token fake quantize. + """ + + @staticmethod + def forward(ctx, input, scales, zero_points, quant_min, quant_max): + # Note: this diverges from `torch.fake_quantize_per_channel_affine`, + # which rounds first before adding the zero points. However, this + # is what `quantize_per_channel_group` and `quantize_per_token` + # do and here we try to match that behavior as closely as possible. + q = input.div(scales).add(zero_points).round() + dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales) + # TODO: do we need this mask? + mask = torch.logical_and((q >= quant_min), (dq <= quant_max)) + ctx.save_for_backward(mask) + return dq + + @staticmethod + def backward(ctx, gy): + (mask,) = ctx.saved_tensors + return gy * mask, None, None, None, None, None + +# TODO: move this to core +quantized_decomposed_lib.define( + "fake_quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max, int group_size) -> Tensor" +) + +@impl(quantized_decomposed_lib, "fake_quantize_per_channel_group", "CompositeImplicitAutograd") +def fake_quantize_per_channel_group( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + group_size: int, +) -> torch.Tensor: + assert group_size > 1 + assert input.shape[-1] % group_size == 0 + assert input.dim() == 2 + assert torch.isnan(input).sum() == 0 + grouped_input = input.reshape(-1, group_size) + scales = scales.reshape(-1, 1) + zero_points = zero_points.reshape(-1, 1) + fq = _GenericFakeQuantize.apply( + grouped_input, scales, zero_points, quant_min, quant_max, + ) + return fq.reshape_as(input) + +# TODO: move this to core +quantized_decomposed_lib.define( + "fake_quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max) -> Tensor" +) + +@impl(quantized_decomposed_lib, "fake_quantize_per_token", "CompositeImplicitAutograd") +def fake_quantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + # TODO: we won't need this import anymore once we move this to core + from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check + + _per_token_quant_qparam_dim_check(input, scales, zero_points) + return _GenericFakeQuantize.apply( + input, scales, zero_points, quant_min, quant_max, + ) + +# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py. +# The version in pytorch does not have backward support yet so we add +# it here for now until https://github.com/pytorch/pytorch/pull/123452 +# is landed. +def _choose_qparams_per_token_asymmetric( + input: torch.Tensor, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + + Returns: + scales and zero_points, both float32 Tensors + """ + # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18 + qmin, qmax = -128, 127 + min_val = torch.amin(input, dim=-1, keepdim=True) + max_val = torch.amax(input, dim=-1, keepdim=True) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + eps = torch.finfo(torch.float32).eps # use xnnpack eps? + + # scale + scale = (max_val_pos - min_val_neg) / float(qmax - qmin) + scale = scale.clamp(min=eps) + + # zero point + descaled_min = min_val_neg / scale + descaled_max = max_val_pos / scale + zero_point_from_min_error = qmin + descaled_min + zero_point_from_max_error = qmax + descaled_max + zero_point = torch.where( + zero_point_from_min_error + zero_point_from_max_error > 0, + qmin - descaled_min, + qmax - descaled_max, + ) + zero_point = torch.clamp(zero_point, qmin, qmax).round() + + return scale.to(torch.float32), zero_point.to(torch.float32)