From 45d0868a83d927cb987d78c13a402906da6d43f0 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Wed, 15 May 2024 11:41:28 -0700
Subject: [PATCH] Fix CI after quantize op change in PyTorch core

Summary: https://github.com/pytorch/pytorch/pull/125781 recently changed
the numerics of the quantize op subtly. This commit fixes the numerics
mismatch caused by this PR by making our quantize ops consistent with
the ones in core.

Test Plan:
python test/quantization/test_quant_primitives.py -k test_quantize_dequantize_group_sym
python test/quantization/test_quant_api.py TestQuantFlow.test_quantized_tensor_subclass_8da4w

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
---
 test/quantization/test_quant_primitives.py | 2 +-
 torchao/quantization/prototype/qat.py      | 2 +-
 torchao/quantization/quant_primitives.py   | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index a64439a25e..0fb48d761b 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -156,7 +156,7 @@ def test_quantize_activation_per_token_abs_max_zero_input(self):
         quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_quantize_dequantize_group_sym(self):
         input = torch.randn(10, 10)
         mapping_type = MappingType.SYMMETRIC

diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py
index d15e841d74..314543bb8e 100644
--- a/torchao/quantization/prototype/qat.py
+++ b/torchao/quantization/prototype/qat.py
@@ -209,7 +209,7 @@ def forward(ctx, input, scales, zero_points, quant_min, quant_max):
         # which rounds first before adding the zero points. However, this
         # is what `quantize_per_channel_group` and `quantize_per_token`
         # do and here we try to match that behavior as closely as possible.
-        q = input.div(scales).add(zero_points).round()
+        q = input.mul(1.0 / scales).add(zero_points).round()
         dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
         # TODO: do we need this mask?
         mask = torch.logical_and((q >= quant_min), (q <= quant_max))

diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 4f39a6055d..30c6854480 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -201,7 +201,7 @@ def quantize_affine(
     if zero_point_domain == ZeroPointDomain.INT:
         quant = torch.clamp(
-            torch.round(input / scale) + zero_point, quant_min, quant_max
+            torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
         assert zero_point_domain == ZeroPointDomain.FLOAT
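
Note on the numerics: the substance of the change is computing `input * (1.0 / scale)` (or `input.mul(1.0 / scales)`) instead of `input / scale` before rounding, to stay consistent with the quantize ops in PyTorch core. The two formulations round different floating-point intermediates, so the subsequent `round()` can occasionally land on different integers. Below is a minimal standalone sketch, not part of the patch, for checking how often the two formulations disagree; the tensor names `x` and `s` and the chosen shapes are illustrative only.

    import torch

    torch.manual_seed(0)
    x = torch.randn(100_000)          # arbitrary inputs to quantize
    s = torch.rand(1) + 0.05          # an arbitrary positive scale

    q_div = torch.round(x / s)            # old formulation in torchao
    q_mul = torch.round(x * (1.0 / s))    # formulation matching PyTorch core

    # Count elements where the rounded integer values differ
    mismatch = (q_div != q_mul).sum().item()
    print(f"rounded results differ for {mismatch} of {x.numel()} elements")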