Fix CI after quantize op change in PyTorch core
Summary: pytorch/pytorch#125781 recently changed the numerics of the
quantize op subtly. This commit fixes the resulting numerics mismatch
by making our quantize ops consistent with the ones in core.
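
The subtlety: in floating point, `x / s` and `x * (1.0 / s)` are not bitwise identical, so `round()` can land on different integers for values near a halfway point. A minimal sketch (hypothetical values, not part of the commit) of the discrepancy being aligned here:

```python
import torch

# x / s rounds the quotient once; x * (1.0 / s) rounds the reciprocal and
# then the product, so results can differ in the last bit. When that bit
# moves a value across a .5 boundary, round() lands on a different integer.
torch.manual_seed(0)
x = torch.randn(100_000)
s = torch.rand(100_000) + 0.01
mismatch = torch.round(x / s) != torch.round(x * (1.0 / s))
print(mismatch.sum().item())  # typically a small nonzero count
```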

Test Plan:
python test/quantization/test_quant_primitives.py -k test_quantize_dequantize_group_sym
python test/quantization/test_quant_api.py TestQuantFlow.test_quantized_tensor_subclass_8da4w

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
andrewor14 committed May 15, 2024
1 parent 10da375 commit 45d0868
Showing 3 changed files with 3 additions and 3 deletions.
test/quantization/test_quant_primitives.py (1 addition, 1 deletion)

@@ -156,7 +156,7 @@ def test_quantize_activation_per_token_abs_max_zero_input(self):
         quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)


-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_quantize_dequantize_group_sym(self):
         input = torch.randn(10, 10)
         mapping_type = MappingType.SYMMETRIC
torchao/quantization/prototype/qat.py (1 addition, 1 deletion)

@@ -209,7 +209,7 @@ def forward(ctx, input, scales, zero_points, quant_min, quant_max):
         # which rounds first before adding the zero points. However, this
         # is what `quantize_per_channel_group` and `quantize_per_token`
         # do and here we try to match that behavior as closely as possible.
-        q = input.div(scales).add(zero_points).round()
+        q = input.mul(1.0 / scales).add(zero_points).round()
         dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
         # TODO: do we need this mask?
         mask = torch.logical_and((q >= quant_min), (q <= quant_max))
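
For context, a minimal standalone sketch of this fake-quantize pattern with hypothetical per-tensor values (not the torchao API itself); the mask is presumably the usual straight-through-estimator bookkeeping, marking unsaturated elements:

```python
import torch

# Hypothetical per-tensor fake quantize mirroring the patched line.
input = torch.randn(4, 8)
scales, zero_points = torch.tensor(0.1), torch.tensor(0.0)
quant_min, quant_max = -128, 127

# Multiply by the reciprocal of the scale, matching core's quantize ops.
q = input.mul(1.0 / scales).add(zero_points).round()
dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
# True where the rounded value fell inside the quantized range; a
# straight-through backward can use this to zero gradients at saturation.
mask = torch.logical_and(q >= quant_min, q <= quant_max)
```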
torchao/quantization/quant_primitives.py (1 addition, 1 deletion)

@@ -201,7 +201,7 @@ def quantize_affine(

     if zero_point_domain == ZeroPointDomain.INT:
         quant = torch.clamp(
-            torch.round(input / scale) + zero_point, quant_min, quant_max
+            torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
         assert zero_point_domain == ZeroPointDomain.FLOAT
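
A standalone illustration of the INT zero-point-domain path above, using hypothetical per-tensor values (the real quantize_affine operates on block-wise scales and zero points):

```python
import torch

# Hypothetical per-tensor affine quantization to int8, mirroring the
# patched expression: scale the input by the reciprocal, round, shift by
# the integer zero point, then clamp to the quantized range.
input = torch.randn(8)
scale, zero_point = 0.05, 3
quant_min, quant_max = -128, 127
quant = torch.clamp(
    torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
).to(torch.int8)
```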
