diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
new file mode 100644
index 0000000000..dc898bc63f
--- /dev/null
+++ b/test/quantization/test_qat.py
@@ -0,0 +1,183 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# mypy: ignore-errors
+# This test takes a long time to run
+
+import copy
+import unittest
+
+import torch
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torchao.quantization.prototype.qat import (
+    _choose_qparams_per_token_asymmetric,
+    fake_quantize_per_channel_group,
+    fake_quantize_per_token,
+    Int8DynActInt4WeightQATLinear,
+    Int8DynActInt4WeightQATQuantizer,
+)
+from torchao.quantization.quant_primitives import (
+    get_group_qparams_symmetric,
+    group_quantize_tensor_symmetric,
+    per_token_dynamic_quant,
+)
+from torchao.quantization.utils import (
+    TORCH_VERSION_AFTER_2_3,
+)
+from torchao.quantization.GPTQ import (
+    Int8DynActInt4WeightLinear,
+    Int8DynActInt4WeightQuantizer,
+)
+
+
+# TODO: put this in a common test utils file
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+
+    def example_inputs(self):
+        return (torch.randn(1, 64).to(torch.float),)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+class TestQAT(unittest.TestCase):
+    SEED = 123
+
+    def _get_qmin_qmax(self, n_bit: int):
+        qmin = -(2 ** (n_bit - 1))
+        qmax = 2 ** (n_bit - 1) - 1
+        return (qmin, qmax)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_fake_quantize_per_channel_group(self):
+        n_bit = 4
+        (qmin, qmax) = self._get_qmin_qmax(n_bit)
+        group_size = 128
+
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256).requires_grad_()
+        (s, zp) = get_group_qparams_symmetric(x, n_bit, group_size)
+        x2 = copy.deepcopy(x)
+
+        # fake quant op
+        out = fake_quantize_per_channel_group(
+            x, s, zp, qmin, qmax, group_size,
+        )
+        out.sum().backward()
+
+        # compare against PTQ ops
+        out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group(
+            x2, s, zp, qmin, qmax, torch.int8, group_size,
+        )
+        out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
+            out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32,
+        )
+        torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_fake_quantize_per_token(self):
+        (qmin, qmax) = self._get_qmin_qmax(8)
+
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256).requires_grad_()
+        x2 = copy.deepcopy(x)
+        # TODO: use torch.ops.aten.quantized_decomposed version instead
+        (s, zp) = _choose_qparams_per_token_asymmetric(
+            x,
+            torch.int8,  # not used
+        )
+
+        # fake quant op
+        out = fake_quantize_per_token(x, s, zp, qmin, qmax)
+        out.sum().backward()
+
+        # compare against PTQ ops
+        out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
+            x2, s, zp, qmin, qmax, torch.int8,
+        )
+        out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
+            out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32,
+        )
+        torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
+
+    def _set_ptq_weight(
+        self,
+        ptq_linear: Int8DynActInt4WeightLinear,
+        fp32_weight: torch.Tensor,
+        group_size: int,
+    ):
+        """
+        Set the weight to the quantized version of the given fp32 weights,
+        for making linear outputs comparable with QAT.
+        """
+        n_bit = 4
+        (qmin, qmax) = self._get_qmin_qmax(n_bit)
+        (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
+        q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
+            fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
+        )
+        ptq_linear.weight = q_weight
+        ptq_linear.scales = s
+        ptq_linear.zeros = zp
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_qat_8da4w_linear(self):
+        group_size = 128
+        torch.manual_seed(self.SEED)
+        qat_linear = Int8DynActInt4WeightQATLinear(
+            256, 688, bias=False, groupsize=group_size,
+        )
+        ptq_linear = Int8DynActInt4WeightLinear(
+            256, 688, bias=False, groupsize=group_size,
+        )
+
+        # Force the weights to be the same
+        self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256)
+        x2 = copy.deepcopy(x)
+        qat_out = qat_linear(x)
+        ptq_out = ptq_linear(x2)
+        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_qat_8da4w_quantizer(self):
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        m2 = copy.deepcopy(m)
+        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
+
+        # Force the weights to be the same
+        self._set_ptq_weight(
+            ptq_model.linear1, qat_model.linear1.weight, group_size,
+        )
+        self._set_ptq_weight(
+            ptq_model.linear2, qat_model.linear2.weight, group_size,
+        )
+
+        # Compare model values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index fef3f83dcc..82533d6e47 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -9,12 +9,12 @@ import unittest
 import torch
 from torchao.quantization.quant_primitives import get_group_qparams_symmetric
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
 
 
 class TestQuantPrimitives(unittest.TestCase):
     SEED = 123
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
     def test_get_group_qparams_symmetric(self):
         """
         Test that `get_group_qparams_symmetric` produces the exact same scales as
diff --git a/torchao/quantization/prototype/__init__.py b/torchao/quantization/prototype/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py
new file mode 100644
index 0000000000..5652ce9634
--- /dev/null
+++ b/torchao/quantization/prototype/qat.py
@@ -0,0 +1,286 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
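+
+# Expected usage (an illustrative sketch only: `train_loop` is a placeholder for
+# the user's fine-tuning loop, and the convert step is still a TODO below):
+#
+#   quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
+#   model = quantizer.prepare(model)   # swaps nn.Linear -> Int8DynActInt4WeightQATLinear
+#   train_loop(model)                  # fine-tune with fake quantized weights/activations
+#   model = quantizer.convert(model)   # lower to Int8DynActInt4WeightLinear (not implemented yet)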
+
+from typing import Any, Tuple
+
+import torch
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
+from torch.library import impl
+
+from torchao.quantization.GPTQ import _check_linear_int4_k
+from torchao.quantization.quant_primitives import get_group_qparams_symmetric
+from torchao.quantization.unified import TwoStepQuantizer
+
+
+class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
+    """
+    Quantizer for performing QAT on a model, where linear layers have int8
+    dynamic per token fake quantized activations and int4 fake quantized
+    grouped per channel weights.
+    """
+
+    def __init__(
+        self,
+        groupsize: int = 256,
+        padding_allowed: bool = False,
+        precision: torch.dtype = torch.float32,
+        qparams_precision: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__()
+        self.groupsize: int = groupsize
+        self.padding_allowed: bool = padding_allowed
+        self.precision: torch.dtype = precision
+        self.qparams_precision: torch.dtype = qparams_precision
+
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        *args: Any,
+        **kwargs: Any
+    ) -> torch.nn.Module:
+        replace_linear_8da4w_qat(
+            model,
+            self.groupsize,
+            self.padding_allowed,
+            self.precision,
+            self.qparams_precision,
+        )
+        return model
+
+    def convert(
+        self,
+        model: torch.nn.Module,
+        *args: Any,
+        **kwargs: Any
+    ) -> torch.nn.Module:
+        # TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear
+        pass
+
+
+class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
+    """
+    This module implements a linear layer with int8 dynamic per token fake
+    quantized activations and int4 fake quantized grouped per channel weights.
+
+    Args:
+        groupsize: the number of elements in each quantized group for weights
+        weight_precision: precision of weights
+        qparams_precision: precision of per group scales and zero points
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        groupsize: int = 256,
+        weight_precision: torch.dtype = torch.float32,
+        qparams_precision: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            device=None,
+            dtype=weight_precision,
+        )
+        assert (
+            in_features % groupsize == 0
+        ), f"require in_features:{in_features} % groupsize:{groupsize} == 0"
+        assert not bias, "require bias=False"
+        self.groupsize = groupsize
+        self.qparams_precision = qparams_precision
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # activations: int8 dynamic asymmetric quant
+        (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+        (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
+            x, torch.int8,  # dtype not used
+        )
+        x_fq = fake_quantize_per_token(
+            x, act_scales, act_zp, act_qmin, act_qmax,
+        )
+
+        # weights: int4 grouped per channel symmetric quant
+        (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+        (weight_scales, weight_zp) = get_group_qparams_symmetric(
+            self.weight, 4, self.groupsize, self.qparams_precision,
+        )
+        w_fq = fake_quantize_per_channel_group(
+            self.weight,
+            weight_scales,
+            weight_zp,
+            weight_qmin,
+            weight_qmax,
+            self.groupsize,
+        )
+        return torch.nn.functional.linear(x_fq, w_fq)
+
+    def _get_qmin_qmax(self, n_bit: int):
+        qmin = -(2 ** (n_bit - 1))
+        qmax = 2 ** (n_bit - 1) - 1
+        return (qmin, qmax)
+
+
+def replace_linear_8da4w_qat(
+    module: torch.nn.Module,
+    groupsize: int,
+    padding_allowed: bool,
+    weight_precision: torch.dtype,
+    qparams_precision: torch.dtype,
+):
+    """
+    Replace `torch.nn.Linear` in the model with `Int8DynActInt4WeightQATLinear`.
+    """
+    for name, child in module.named_children():
+        if isinstance(child, torch.nn.Linear):
+            if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
+                setattr(
+                    module,
+                    name,
+                    Int8DynActInt4WeightQATLinear(
+                        child.in_features,
+                        child.out_features,
+                        bias=False,
+                        groupsize=groupsize,
+                        weight_precision=weight_precision,
+                        qparams_precision=qparams_precision,
+                    ),
+                )
+        else:
+            replace_linear_8da4w_qat(
+                child,
+                groupsize,
+                padding_allowed,
+                weight_precision,
+                qparams_precision,
+            )
+
+
+# ========================
+# |   QUANT PRIMITIVES   |
+# ========================
+
+class _GenericFakeQuantize(torch.autograd.Function):
+    """
+    Implementation of generic fake quantize with backward STE.
+
+    With the appropriate input tensor shape, this can be used to express
+    grouped per channel fake quantize or per token fake quantize.
+    """
+
+    @staticmethod
+    def forward(ctx, input, scales, zero_points, quant_min, quant_max):
+        # Note: this diverges from `torch.fake_quantize_per_channel_affine`,
+        # which rounds first before adding the zero points. However, this
+        # is what `quantize_per_channel_group` and `quantize_per_token`
+        # do and here we try to match that behavior as closely as possible.
+        q = input.div(scales).add(zero_points).round()
+        dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
+        # TODO: do we need this mask?
+        mask = torch.logical_and((q >= quant_min), (dq <= quant_max))
+        ctx.save_for_backward(mask)
+        return dq
+
+    @staticmethod
+    def backward(ctx, gy):
+        (mask,) = ctx.saved_tensors
+        return gy * mask, None, None, None, None, None
+
+# TODO: move this to core
+quantized_decomposed_lib.define(
+    "fake_quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, "
+    "int quant_min, int quant_max, int group_size) -> Tensor"
+)
+
+@impl(quantized_decomposed_lib, "fake_quantize_per_channel_group", "CompositeImplicitAutograd")
+def fake_quantize_per_channel_group(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    group_size: int,
+) -> torch.Tensor:
+    assert group_size > 1
+    assert input.shape[-1] % group_size == 0
+    assert input.dim() == 2
+    assert torch.isnan(input).sum() == 0
+    grouped_input = input.reshape(-1, group_size)
+    scales = scales.reshape(-1, 1)
+    zero_points = zero_points.reshape(-1, 1)
+    fq = _GenericFakeQuantize.apply(
+        grouped_input, scales, zero_points, quant_min, quant_max,
+    )
+    return fq.reshape_as(input)

+# TODO: move this to core
+quantized_decomposed_lib.define(
+    "fake_quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
+    "int quant_min, int quant_max) -> Tensor"
+)
+
+@impl(quantized_decomposed_lib, "fake_quantize_per_token", "CompositeImplicitAutograd")
+def fake_quantize_per_token(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+) -> torch.Tensor:
+    # TODO: we won't need this import anymore once we move this to core
+    from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check
+
+    _per_token_quant_qparam_dim_check(input, scales, zero_points)
+    return _GenericFakeQuantize.apply(
+        input, scales, zero_points, quant_min, quant_max,
+    )
+
+# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
+# The version in pytorch does not have backward support yet so we add
+# it here for now until https://github.com/pytorch/pytorch/pull/123452
+# is landed.
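+#
+# As a quick illustration of the math below (values chosen arbitrarily): for a
+# token row with min = -1.0 and max = 3.0, and qmin, qmax = -128, 127:
+#   scale      = (3.0 - (-1.0)) / 255 ~= 0.0157
+#   zero_point = round(clamp(qmin - min / scale, qmin, qmax))
+#              = round(-128 - (-63.75)) = round(-64.25) = -64
+# (the `qmin - descaled_min` branch of the torch.where is taken here because
+# the summed rounding errors from both ends are positive)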
+def _choose_qparams_per_token_asymmetric(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
+    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
+    every N elements with the same quantization parameter. The dimension for scales/zero_points
+    will be (M1 * M2 ... * Mn)
+
+    Args:
+        input (torch.Tensor): original float32/float16 Tensor
+        dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+
+    Returns:
+        scales and zero_points, both float32 Tensors
+    """
+    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
+    qmin, qmax = -128, 127
+    min_val = torch.amin(input, dim=-1, keepdim=True)
+    max_val = torch.amax(input, dim=-1, keepdim=True)
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?
+
+    # scale
+    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
+    scale = scale.clamp(min=eps)
+
+    # zero point
+    descaled_min = min_val_neg / scale
+    descaled_max = max_val_pos / scale
+    zero_point_from_min_error = qmin + descaled_min
+    zero_point_from_max_error = qmax + descaled_max
+    zero_point = torch.where(
+        zero_point_from_min_error + zero_point_from_max_error > 0,
+        qmin - descaled_min,
+        qmax - descaled_max,
+    )
+    zero_point = torch.clamp(zero_point, qmin, qmax).round()
+
+    return scale.to(torch.float32), zero_point.to(torch.float32)
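+
+
+# Minimal smoke test sketch for the primitives above (illustrative only: the
+# shapes, group size of 16, and 4-bit qmin/qmax of -8/7 are chosen arbitrarily).
+# It checks that grouped fake quantize preserves the input shape and that the
+# straight-through estimator in `_GenericFakeQuantize` yields an input gradient.
+if __name__ == "__main__":
+    x = torch.randn(4, 32, requires_grad=True)
+    (s, zp) = get_group_qparams_symmetric(x, 4, 16)
+    x_fq = fake_quantize_per_channel_group(x, s, zp, -8, 7, 16)
+    x_fq.sum().backward()
+    assert x_fq.shape == x.shape
+    # the upstream gradient is passed straight through, masked to values whose
+    # fake quantized result stays inside [quant_min, quant_max]
+    assert x.grad is not None and x.grad.shape == x.shape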