Skip to content

Commit

Permalink
Refactor rest of tinygemm quant primitive ops (pytorch#321)
Browse files Browse the repository at this point in the history
Summary:
This PR replaces the remaining tinygemm specific quant primitive ops with the general quant primitive ops
that we want to use for everything, we could delete these ops in a separate PR if needed

Test Plan:
python test/quantization/test_quant_primitives.py -k test_get_groupwise_affine_qparams
python test/quantization/test_quant_primitives.py -k test_groupwise_affine_quantize_tensor_from_qparams
python test/quantization/test_quant_primitives.py -k test_groupwise_affine_dequantize_tensor_from_qparams

accuracy:

perf:
no diff for generated code with `TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py`
  • Loading branch information
jerryzh168 authored Jun 5, 2024
1 parent 2b91917 commit a4cc35e
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 43 deletions.
109 changes: 107 additions & 2 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from torchao.quantization.quant_primitives import (
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor_from_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
quantize_affine,
dequantize_affine,
choose_qparams_affine,
Expand Down Expand Up @@ -38,6 +40,86 @@ def check_idempotent(self, fn, *args, **kwargs):
self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.")
return output1

# Legacy tinygemm ops
def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2

to_quant = w.reshape(-1, groupsize)
# assert torch.isnan(to_quant).sum() == 0

max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
dtype=dtype
).reshape(w.shape[0], -1)

def _groupwise_affine_quantize_tensor_from_qparams(
w,
scales,
zeros,
n_bit=4,
groupsize=128,
):
assert groupsize > 1
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
groupsize = w.shape[-1]

assert w.shape[-1] % groupsize == 0
assert w.dim() == 2

to_quant = w.reshape(-1, groupsize)
# assert torch.isnan(to_quant).sum() == 0

scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int4x8 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)

return w_int4x8

def _groupwise_affine_dequantize_tensor_from_qparams(
w_int4x8,
scales,
zeros,
n_bit=4,
groupsize=128,
):
assert groupsize > 1
# needed for GPTQ single column dequantize
if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
groupsize = w_int4x8.shape[-1]
assert w_int4x8.shape[-1] % groupsize == 0
assert w_int4x8.dim() == 2

w_int4x8_grouped = w_int4x8.reshape(-1, groupsize)
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)

w_dq = (
w_int4x8_grouped.sub(2 ** (n_bit - 1))
.mul(scales)
.add(zeros)
.reshape_as(w_int4x8)
)
return w_dq


class TestQuantPrimitives(unittest.TestCase):
SEED = 123
Expand Down Expand Up @@ -356,12 +438,12 @@ def test_not_preserve_zero_not_supported(self):
)


def test_tinygemm_get_groupwise_affine_qparams(self):
def test_get_groupwise_affine_qparams(self):
from torchao.quantization.quant_primitives import ZeroPointDomain

input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)

mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
Expand Down Expand Up @@ -389,6 +471,29 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zero_point_ref))

def test_groupwise_affine_quantize_tensor_from_qparams(self):
input = torch.randn(10, 256)
scales = torch.randn(10, 2)
zeros = torch.randn(10, 2)
n_bit = 4
groupsize = 128

w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)

self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))

def test_groupwise_affine_dequantize_tensor_from_qparams(self):
input = torch.randint(0, 15, (10, 256), dtype=torch.int32)
scales = torch.randn(10, 2).bfloat16()
zeros = torch.randn(10, 2).bfloat16()
n_bit = 4
groupsize = 128

w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)

self.assertTrue(torch.equal(w_bf16, w_bf16_ref))

if __name__ == "__main__":
unittest.main()
76 changes: 35 additions & 41 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def dequantize_affine(

# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype == input_dtype
assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)

Expand Down Expand Up @@ -644,22 +644,37 @@ def quant_int8_per_token_matmul(


def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
"""This is tinygemm specific, we'll keep this for now"""
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
assert n_bit <= 8, f"only n_bit smaller than 8 is supported, got: {n_bit}"

to_quant = w.reshape(-1, groupsize)
# assert torch.isnan(to_quant).sum() == 0
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
block_size = (1, groupsize)
quant_min = 0
quant_max = 2**n_bit - 1
eps = 1e-6
scale_dtype = dtype
zero_point_dtype = dtype

scale, zero_point = choose_qparams_affine(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
zero_point_domain=ZeroPointDomain.FLOAT
)

max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
return scale.to(dtype=dtype).reshape(w.shape[0], -1), zero_point.to(
dtype=dtype
).reshape(w.shape[0], -1)

Expand Down Expand Up @@ -692,7 +707,6 @@ def groupwise_affine_quantize_tensor_from_qparams(
n_bit=4,
groupsize=128,
):
"""This is tinygemm specific, we'll keep this for now"""
assert groupsize > 1
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
Expand All @@ -701,25 +715,12 @@ def groupwise_affine_quantize_tensor_from_qparams(
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2

to_quant = w.reshape(-1, groupsize)
# assert torch.isnan(to_quant).sum() == 0

scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int4x8 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)

return w_int4x8
block_size = (1, groupsize)
output_dtype = torch.int32
quant_min = 0
quant_max = 2 ** n_bit - 1

return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)

def groupwise_affine_dequantize_tensor_from_qparams(
w_int4x8,
Expand All @@ -728,25 +729,18 @@ def groupwise_affine_dequantize_tensor_from_qparams(
n_bit=4,
groupsize=128,
):
"""This is tinygemm specific, we'll keep this for now"""
assert groupsize > 1
# needed for GPTQ single column dequantize
if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
groupsize = w_int4x8.shape[-1]
assert w_int4x8.shape[-1] % groupsize == 0
assert w_int4x8.dim() == 2

w_int4x8_grouped = w_int4x8.reshape(-1, groupsize)
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)

w_dq = (
w_int4x8_grouped.sub(2 ** (n_bit - 1))
.mul(scales)
.add(zeros)
.reshape_as(w_int4x8)
)
return w_dq
block_size = (1, groupsize)
input_dtype = torch.int32
quant_min = 0
quant_max = 2**n_bit - 1
return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)


def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
Expand Down

0 comments on commit a4cc35e

Please sign in to comment.