Refactor rest of tinygemm quant primitive ops #321

Merged · 1 commit · Jun 5, 2024
109 changes: 107 additions & 2 deletions test/quantization/test_quant_primitives.py
@@ -11,6 +11,8 @@
 from torchao.quantization.quant_primitives import (
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
+    groupwise_affine_quantize_tensor_from_qparams,
+    groupwise_affine_dequantize_tensor_from_qparams,
     quantize_affine,
     dequantize_affine,
     choose_qparams_affine,
@@ -38,6 +40,86 @@ def check_idempotent(self, fn, *args, **kwargs):
     self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.")
     return output1

+# Legacy tinygemm ops
+def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
+    if groupsize > w.shape[-1]:
+        groupsize = w.shape[-1]
+    assert groupsize > 1
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    # assert torch.isnan(to_quant).sum() == 0
+
+    max_val = to_quant.amax(dim=1, keepdim=True)
+    min_val = to_quant.amin(dim=1, keepdim=True)
+    max_int = 2**n_bit - 1
+    scales = (max_val - min_val).clamp(min=1e-6) / max_int
+    zeros = min_val + scales * (2 ** (n_bit - 1))
+    return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
+        dtype=dtype
+    ).reshape(w.shape[0], -1)
+
+def _groupwise_affine_quantize_tensor_from_qparams(
+    w,
+    scales,
+    zeros,
+    n_bit=4,
+    groupsize=128,
+):
+    assert groupsize > 1
+    # needed for GPTQ single column quantize
+    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w.shape[-1]
+
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    # assert torch.isnan(to_quant).sum() == 0
+
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+    min_val = zeros - scales * (2 ** (n_bit - 1))
+    max_int = 2**n_bit - 1
+    min_int = 0
+    w_int4x8 = (
+        to_quant.sub(min_val)
+        .div(scales)
+        .round()
+        .clamp_(min_int, max_int)
+        .to(torch.int32)
+        .reshape_as(w)
+    )
+
+    return w_int4x8
+
+def _groupwise_affine_dequantize_tensor_from_qparams(
+    w_int4x8,
+    scales,
+    zeros,
+    n_bit=4,
+    groupsize=128,
+):
+    assert groupsize > 1
+    # needed for GPTQ single column dequantize
+    if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w_int4x8.shape[-1]
+    assert w_int4x8.shape[-1] % groupsize == 0
+    assert w_int4x8.dim() == 2
+
+    w_int4x8_grouped = w_int4x8.reshape(-1, groupsize)
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+
+    w_dq = (
+        w_int4x8_grouped.sub(2 ** (n_bit - 1))
+        .mul(scales)
+        .add(zeros)
+        .reshape_as(w_int4x8)
+    )
+    return w_dq
+
+
 class TestQuantPrimitives(unittest.TestCase):
     SEED = 123
Expand Down Expand Up @@ -356,12 +438,12 @@ def test_not_preserve_zero_not_supported(self):
)


def test_tinygemm_get_groupwise_affine_qparams(self):
def test_get_groupwise_affine_qparams(self):
from torchao.quantization.quant_primitives import ZeroPointDomain

input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)

mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
@@ -389,6 +471,29 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zero_point_ref))

+    def test_groupwise_affine_quantize_tensor_from_qparams(self):
+        input = torch.randn(10, 256)
+        scales = torch.randn(10, 2)
+        zeros = torch.randn(10, 2)
+        n_bit = 4
+        groupsize = 128
+
+        w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+        w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+
+        self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))
+
+    def test_groupwise_affine_dequantize_tensor_from_qparams(self):
+        input = torch.randint(0, 15, (10, 256), dtype=torch.int32)
+        scales = torch.randn(10, 2).bfloat16()
+        zeros = torch.randn(10, 2).bfloat16()
+        n_bit = 4
+        groupsize = 128
+
+        w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+        w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+
+        self.assertTrue(torch.equal(w_bf16, w_bf16_ref))
+
 if __name__ == "__main__":
     unittest.main()
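
For intuition about the convention these new tests pin down, here is a minimal, self-contained sketch of tinygemm-style quantization (pure PyTorch, mirroring the legacy formulas copied into the test file above; the tensor values are illustrative, not from the PR):

import torch

# one group of 4 values, n_bit=4; the zero point lives in the float domain
w = torch.tensor([[-1.0, -0.5, 0.5, 1.0]])
n_bit = 4
max_int = 2**n_bit - 1  # 15

scale = (w.amax(dim=1, keepdim=True) - w.amin(dim=1, keepdim=True)).clamp(min=1e-6) / max_int
zero = w.amin(dim=1, keepdim=True) + scale * (2 ** (n_bit - 1))

# quantize: shift by the float zero point, round into [0, 15]
q = ((w - (zero - scale * 2 ** (n_bit - 1))) / scale).round().clamp(0, max_int).to(torch.int32)
# dequantize: the mirror image
w_dq = (q - 2 ** (n_bit - 1)).mul(scale).add(zero)

# round-trip error stays within one quantization step
assert torch.allclose(w, w_dq, atol=scale.item())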
76 changes: 35 additions & 41 deletions torchao/quantization/quant_primitives.py
@@ -249,7 +249,7 @@ def dequantize_affine(

     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
-    assert input.dtype == input_dtype
+    assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
     assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)

@@ -644,22 +644,37 @@ def quant_int8_per_token_matmul(


 def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
-    """This is tinygemm specific, we'll keep this for now"""
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
     assert groupsize > 1
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
+    assert n_bit <= 8, f"only n_bit smaller than 8 is supported, got: {n_bit}"

-    to_quant = w.reshape(-1, groupsize)
-    # assert torch.isnan(to_quant).sum() == 0
+    mapping_type = MappingType.ASYMMETRIC
+    target_dtype = torch.int32
+    block_size = (1, groupsize)
+    quant_min = 0
+    quant_max = 2**n_bit - 1
+    eps = 1e-6
+    scale_dtype = dtype
+    zero_point_dtype = dtype
+
+    scale, zero_point = choose_qparams_affine(
+        w,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype=scale_dtype,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=False,
+        zero_point_domain=ZeroPointDomain.FLOAT
+    )

-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    max_int = 2**n_bit - 1
-    scales = (max_val - min_val).clamp(min=1e-6) / max_int
-    zeros = min_val + scales * (2 ** (n_bit - 1))
-    return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
+    return scale.to(dtype=dtype).reshape(w.shape[0], -1), zero_point.to(
         dtype=dtype
     ).reshape(w.shape[0], -1)

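As a sanity check on the hunk above, a sketch (not part of the PR) of how the generic path reproduces the legacy tinygemm qparams; per the test added in this PR, the two are expected to match bit-for-bit:

import torch
from torchao.quantization.quant_primitives import get_groupwise_affine_qparams

w = torch.randn(10, 256)
scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16)

# legacy reference: the scale spans the group range, and the float-domain zero
# point is the value that the mid-range integer 2**(n_bit - 1) decodes to
to_q = w.reshape(-1, 128)
ref_scales = (to_q.amax(dim=1, keepdim=True) - to_q.amin(dim=1, keepdim=True)).clamp(min=1e-6) / 15
ref_zeros = to_q.amin(dim=1, keepdim=True) + ref_scales * 8

assert torch.equal(scales, ref_scales.to(torch.bfloat16).reshape(10, -1))
assert torch.equal(zeros, ref_zeros.to(torch.bfloat16).reshape(10, -1))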
@@ -692,7 +707,6 @@ def groupwise_affine_quantize_tensor_from_qparams(
     n_bit=4,
     groupsize=128,
 ):
-    """This is tinygemm specific, we'll keep this for now"""
     assert groupsize > 1
     # needed for GPTQ single column quantize
     if groupsize > w.shape[-1] and scales.shape[-1] == 1:
@@ -701,25 +715,12 @@
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2

-    to_quant = w.reshape(-1, groupsize)
-    # assert torch.isnan(to_quant).sum() == 0
-
-    scales = scales.reshape(-1, 1)
-    zeros = zeros.reshape(-1, 1)
-    min_val = zeros - scales * (2 ** (n_bit - 1))
-    max_int = 2**n_bit - 1
-    min_int = 0
-    w_int4x8 = (
-        to_quant.sub(min_val)
-        .div(scales)
-        .round()
-        .clamp_(min_int, max_int)
-        .to(torch.int32)
-        .reshape_as(w)
-    )
-
-    return w_int4x8
+    block_size = (1, groupsize)
+    output_dtype = torch.int32
+    quant_min = 0
+    quant_max = 2**n_bit - 1
+
+    return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT)

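With zero_point_domain=ZeroPointDomain.FLOAT, quantize_affine derives the effective group minimum as zeros - scales * 2**(n_bit - 1) and rounds into [quant_min, quant_max], which is exactly the inline math deleted above. A usage sketch (shapes follow the tests earlier in this PR; illustrative only):

import torch
from torchao.quantization.quant_primitives import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
)

w = torch.randn(10, 256)
scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=128)
w_q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)

# int32 codes in [0, 2**n_bit - 1], same shape as the input
assert w_q.dtype == torch.int32 and w_q.shape == w.shape
assert w_q.min() >= 0 and w_q.max() <= 15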
 def groupwise_affine_dequantize_tensor_from_qparams(
     w_int4x8,
@@ -728,25 +729,18 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     scales,
     zeros,
     n_bit=4,
     groupsize=128,
 ):
-    """This is tinygemm specific, we'll keep this for now"""
     assert groupsize > 1
     # needed for GPTQ single column dequantize
     if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
         groupsize = w_int4x8.shape[-1]
     assert w_int4x8.shape[-1] % groupsize == 0
     assert w_int4x8.dim() == 2

-    w_int4x8_grouped = w_int4x8.reshape(-1, groupsize)
-    scales = scales.reshape(-1, 1)
-    zeros = zeros.reshape(-1, 1)
-
-    w_dq = (
-        w_int4x8_grouped.sub(2 ** (n_bit - 1))
-        .mul(scales)
-        .add(zeros)
-        .reshape_as(w_int4x8)
-    )
-    return w_dq
+    block_size = (1, groupsize)
+    input_dtype = torch.int32
+    quant_min = 0
+    quant_max = 2**n_bit - 1
+    return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)


 def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
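End to end, the refactored ops compose the same way the legacy ones did; a round-trip sketch (illustrative, not from the diff; assumes a bfloat16 weight as in the typical tinygemm path):

import torch
from torchao.quantization.quant_primitives import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
)

w = torch.randn(10, 256, dtype=torch.bfloat16)
scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16)
w_q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)
w_dq = groupwise_affine_dequantize_tensor_from_qparams(w_q, scales, zeros, n_bit=4, groupsize=128)

# the output dtype follows scales.dtype (bfloat16 here), per output_dtype=scales.dtype above
assert w_dq.dtype == torch.bfloat16
assert w_dq.shape == w.shape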