Initial support for 8da4w QAT
Summary: This commit adds support for quantization-aware training
(QAT), where linear layers are fake quantized with int8 per-token
dynamic activations (8da) and int4 grouped per-channel weights (4w).
This initial implementation uses the same module swap approach as
8da4w PTQ for simplicity and code reuse. In the future, we may wish
to migrate both flows to tensor subclasses for better composability
with other PyTorch features.
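
A minimal usage sketch, based only on the APIs exercised by the new
tests below; the model, group size, and training step are
illustrative and not part of this commit:

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Hypothetical float model containing nn.Linear layers.
model = MyModel()

# prepare() swaps each nn.Linear for a QAT linear whose forward pass fake
# quantizes activations (int8, per token, dynamic) and weights (int4, grouped
# per channel) while keeping the parameters themselves in floating point.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=128)
model = quantizer.prepare(model)

# ... fine-tune `model` as usual; the fake quantize ops support backward,
# as exercised by the per-channel and per-token tests below ...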

Test Plan:
python test/quantization/test_qat.py -k test_fake_quantize_per_channel_group
python test/quantization/test_qat.py -k test_fake_quantize_per_token
python test/quantization/test_qat.py -k test_qat_8da4w_linear
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer

Reviewers: jerryzh168, cpuhrsch, HDCharles

Subscribers: jerryzh168, cpuhrsch, HDCharles, supriyar

Tasks: #86
andrewor14 committed Apr 16, 2024
1 parent d76ecc2 commit d5cd97f
Showing 4 changed files with 471 additions and 2 deletions.
183 changes: 183 additions & 0 deletions test/quantization/test_qat.py
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
# This test takes a long time to run

import copy
import unittest

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.quantization.prototype.qat import (
    _choose_qparams_per_token_asymmetric,
    fake_quantize_per_channel_group,
    fake_quantize_per_token,
    Int8DynActInt4WeightQATLinear,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.quant_primitives import (
    get_group_qparams_symmetric,
    group_quantize_tensor_symmetric,
    per_token_dynamic_quant,
)
from torchao.quantization.utils import (
    TORCH_VERSION_AFTER_2_3,
)
from torchao.quantization.GPTQ import (
    Int8DynActInt4WeightLinear,
    Int8DynActInt4WeightQuantizer,
)


# TODO: put this in a common test utils file
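# Toy two-layer linear model used as the target for the QAT/PTQ comparison.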
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)

    def example_inputs(self):
        return (torch.randn(1, 64).to(torch.float),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class TestQAT(unittest.TestCase):
    SEED = 123

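    # Signed integer range for an n-bit symmetric quantizer, e.g. (-8, 7) for n_bit=4.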
    def _get_qmin_qmax(self, n_bit: int):
        qmin = -(2 ** (n_bit - 1))
        qmax = 2 ** (n_bit - 1) - 1
        return (qmin, qmax)

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
    def test_fake_quantize_per_channel_group(self):
        n_bit = 4
        (qmin, qmax) = self._get_qmin_qmax(n_bit)
        group_size = 128

        torch.manual_seed(self.SEED)
        x = torch.randn(100, 256).requires_grad_()
        (s, zp) = get_group_qparams_symmetric(x, n_bit, group_size)
        x2 = copy.deepcopy(x)

        # fake quant op
        out = fake_quantize_per_channel_group(
            x, s, zp, qmin, qmax, group_size,
        )
        out.sum().backward()

        # compare against PTQ ops
        out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group(
            x2, s, zp, qmin, qmax, torch.int8, group_size,
        )
        out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
            out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32,
        )
        torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
    def test_fake_quantize_per_token(self):
        (qmin, qmax) = self._get_qmin_qmax(8)

        torch.manual_seed(self.SEED)
        x = torch.randn(100, 256).requires_grad_()
        x2 = copy.deepcopy(x)
        # TODO: use torch.ops.aten.quantized_decomposed version instead
        (s, zp) = _choose_qparams_per_token_asymmetric(
            x,
            torch.int8,  # not used
        )

        # fake quant op
        out = fake_quantize_per_token(x, s, zp, qmin, qmax)
        out.sum().backward()

        # compare against PTQ ops
        out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
            x2, s, zp, qmin, qmax, torch.int8,
        )
        out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
            out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32,
        )
        torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

    def _set_ptq_weight(
        self,
        ptq_linear: Int8DynActInt4WeightLinear,
        fp32_weight: torch.Tensor,
        group_size: int,
    ):
        """
        Set the weight to the quantized version of the given fp32 weights,
        for making linear outputs comparable with QAT.
        """
        n_bit = 4
        (qmin, qmax) = self._get_qmin_qmax(n_bit)
        (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
        q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
            fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
        )
        ptq_linear.weight = q_weight
        ptq_linear.scales = s
        ptq_linear.zeros = zp

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
    def test_qat_8da4w_linear(self):
        group_size = 128
        torch.manual_seed(self.SEED)
        qat_linear = Int8DynActInt4WeightQATLinear(
            256, 688, bias=False, groupsize=group_size,
        )
        ptq_linear = Int8DynActInt4WeightLinear(
            256, 688, bias=False, groupsize=group_size,
        )

        # Force the weights to be the same
        self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)

        # Compare linear values
        torch.manual_seed(self.SEED)
        x = torch.randn(100, 256)
        x2 = copy.deepcopy(x)
        qat_out = qat_linear(x)
        ptq_out = ptq_linear(x2)
        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
    def test_qat_8da4w_quantizer(self):
        group_size = 16
        torch.manual_seed(self.SEED)
        m = M()
        m2 = copy.deepcopy(m)
        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
        qat_model = qat_quantizer.prepare(m)
        ptq_model = ptq_quantizer.quantize(m2)

        # Force the weights to be the same
        self._set_ptq_weight(
            ptq_model.linear1, qat_model.linear1.weight, group_size,
        )
        self._set_ptq_weight(
            ptq_model.linear2, qat_model.linear2.weight, group_size,
        )

        # Compare model values
        torch.manual_seed(self.SEED)
        x = m.example_inputs()
        x2 = copy.deepcopy(x)
        qat_out = qat_model(*x)
        ptq_out = ptq_model(*x2)
        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)


if __name__ == "__main__":
    unittest.main()
4 changes: 2 additions & 2 deletions test/quantization/test_quant_primitives.py
@@ -9,12 +9,12 @@
 import unittest
 import torch
 from torchao.quantization.quant_primitives import get_group_qparams_symmetric
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3

 class TestQuantPrimitives(unittest.TestCase):
     SEED = 123

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
     def test_get_group_qparams_symmetric(self):
         """
         Test that `get_group_qparams_symmetric` produces the exact same scales as
