Refactor QAT to use common fake_quantize_affine primitive #527

Merged (1 commit) on Jul 22, 2024
14 changes: 10 additions & 4 deletions test/quantization/test_qat.py
@@ -350,13 +350,19 @@ def test_qat_generic_fake_quantize(self):

         ao_input = copy.deepcopy(py_input)
         ao_input.grad.data.zero_()
-        ao_s = copy.deepcopy(py_s).reshape(-1, 1)
-        ao_zp = copy.deepcopy(py_zp).reshape(-1, 1)
-        ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax)
+        block_size = (1, ao_input.shape[-1])
+        ao_s = copy.deepcopy(py_s)
+        ao_zp = copy.deepcopy(py_zp)
+        ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax, block_size)
         ao_out.sum().backward()
 
         torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
-        torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)
+
+        # Test that gradients are close enough
+        num_grads = py_input.grad.numel()
+        num_equal_grads = torch.eq(py_input.grad, ao_input.grad).flatten().sum().item()
+        num_equal_grad_threshold = 0.8
+        self.assertGreaterEqual(num_equal_grads / num_grads, num_equal_grad_threshold)
 
     def _assert_close_4w(self, val, ref):
         # Note: for int4 weight-only quantization, we do not expect exact match
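The exact-match assertion on gradients is replaced above with a ratio check, presumably because the refactored path (through fake_quantize_affine_cachemask) can round values sitting right at a clamp boundary differently from the pure-PyTorch reference, which flips the straight-through-estimator (STE) mask for a small fraction of elements. As a rough, self-contained sketch of the STE behavior being exercised (illustrative code, not part of this PR):

import torch

def ste_fake_quant(x, scale, zero_point, quant_min, quant_max):
    # Fake-quantize x; under the straight-through estimator, gradients flow
    # only where the rounded value stayed inside [quant_min, quant_max].
    q = torch.round(x / scale) + zero_point
    mask = (q >= quant_min) & (q <= quant_max)   # elements that were not clamped
    dq = (q.clamp(quant_min, quant_max) - zero_point) * scale
    return dq, mask                              # backward: grad_x = grad_out * mask

Elements whose rounded value lands exactly on quant_min or quant_max are the ones whose mask, and hence gradient, can differ between two otherwise equivalent implementations, which is why the test settles for at least 80% elementwise agreement.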
108 changes: 43 additions & 65 deletions torchao/quantization/prototype/qat.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -25,7 +25,10 @@
     ZeroPointDomain,
 )
 from torchao.quantization.unified import TwoStepQuantizer
-from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.utils import (
+    _get_per_token_block_size,
@jerryzh168 (Contributor) commented on Jul 20, 2024:

if it's helpful we could have a general util like:

def get_block_size(granularity, **kw_params) -> Callable:
    if granularity == Granularity.PER_BLOCK:
        ...
    elif granularity == Granularity.PER_TOKEN:
        ...
    ...

block_size = get_block_size(Granularity.PER_TOKEN)(x)

The PR author (Contributor) replied:

Sounds good, let's do that separately
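For reference, a minimal sketch of what such a granularity-keyed helper could look like; the Granularity enum and its members here are hypothetical and not part of torchao:

from enum import Enum
from typing import Callable, List

import torch

class Granularity(Enum):
    # hypothetical granularity enum, for illustration only
    PER_TOKEN = "per_token"
    PER_CHANNEL_GROUP = "per_channel_group"

def get_block_size(granularity: Granularity, **kw_params) -> Callable[[torch.Tensor], List[int]]:
    if granularity == Granularity.PER_TOKEN:
        # one block (one scale/zero point) per token
        return lambda x: [1] * (x.dim() - 1) + [x.shape[-1]]
    elif granularity == Granularity.PER_CHANNEL_GROUP:
        # one block per (output channel, group) pair of a 2D weight
        group_size = kw_params["group_size"]
        return lambda x: [1, group_size]
    raise ValueError(f"unsupported granularity: {granularity}")

# usage, mirroring the suggestion above:
# block_size = get_block_size(Granularity.PER_TOKEN)(x)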

+    get_group_qparams_symmetric,
+)


# =================
@@ -346,8 +349,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         scales, zero_points = get_groupwise_affine_qparams(
             self.weight, n_bit, self.groupsize, self.scales_precision,
         )
-        w_fq = _Int4WeightOnlyFakeQuantize.apply(
-            self.weight, scales, zero_points, qmin, qmax, self.groupsize,
+        w_fq = fake_quantize_per_channel_group(
+            self.weight,
+            scales,
+            zero_points,
+            qmin,
+            qmax,
+            self.groupsize,
+            ZeroPointDomain.FLOAT,
         )
         return F.linear(x, w_fq)
 
@@ -370,39 +379,6 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
 # | QUANT PRIMITIVES |
 # ========================
 
-class _Int4WeightOnlyFakeQuantize(torch.autograd.Function):
-    """
-    Implementation of int4 grouped per channel weight-only fake quantize
-    intended to match the numerics of the efficient int4 tinygemm kernel.
-    """
-
-    @staticmethod
-    def forward(ctx, input, scales, zero_points, quant_min, quant_max, groupsize):
-        assert groupsize > 1
-        assert input.shape[-1] % groupsize == 0
-        assert input.dim() == 2
-        n_bit = 4
-        block_size = (1, groupsize)
-        quant_min = 0
-        quant_max = 2 ** n_bit - 1
-        (fq, mask) = fake_quantize_affine_cachemask(
-            input,
-            block_size,
-            scales,
-            zero_points,
-            torch.int32,
-            quant_min,
-            quant_max,
-            zero_point_domain = ZeroPointDomain.FLOAT,
-        )
-        ctx.save_for_backward(mask)
-        return fq
-
-    @staticmethod
-    def backward(ctx, gy):
-        (mask,) = ctx.saved_tensors
-        return gy * mask, None, None, None, None, None
-
 class _GenericFakeQuantize(torch.autograd.Function):
     """
     Implementation of generic fake quantize with backward STE.
@@ -412,71 +388,73 @@ class _GenericFakeQuantize(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input, scales, zero_points, quant_min, quant_max):
def forward(
ctx: torch.autograd.function.FunctionCtx,
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
block_size: List[int],
zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT,
) -> torch.Tensor:
# Note: for bf16 inputs, casting them to fp32 has the unexpected
# side effect of reducing memory footprint significantly, presumably
# because bf16 * fp32 kernels are not as memory efficient
assert input.dtype == torch.float32
assert scales.dtype == torch.float32
assert zero_points.dtype == torch.int32
q = input.mul(1.0 / scales).round().add(zero_points)
dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))

(fq, mask) = fake_quantize_affine_cachemask(
input,
block_size,
scales,
zero_points,
torch.int32,
quant_min,
quant_max,
zero_point_domain,
)

ctx.save_for_backward(mask)
return dq
return fq

     @staticmethod
     def backward(ctx, gy):
         (mask,) = ctx.saved_tensors
-        return gy * mask, None, None, None, None, None
-
-# TODO: move this to core
-quantized_decomposed_lib.define(
-    "fake_quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, "
-    "int quant_min, int quant_max, int group_size) -> Tensor"
-)
+        return gy * mask, None, None, None, None, None, None
 
-@impl(quantized_decomposed_lib, "fake_quantize_per_channel_group", "CompositeImplicitAutograd")
 def fake_quantize_per_channel_group(
     input: torch.Tensor,
     scales: torch.Tensor,
     zero_points: torch.Tensor,
     quant_min: int,
     quant_max: int,
     group_size: int,
+    zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT,
 ) -> torch.Tensor:
     assert group_size > 1
     assert input.shape[-1] % group_size == 0
     assert input.dim() == 2
-    grouped_input = input.reshape(-1, group_size).to(torch.float32)
-    scales = scales.reshape(-1, 1)
-    zero_points = zero_points.reshape(-1, 1)
-    fq = _GenericFakeQuantize.apply(
-        grouped_input, scales, zero_points, quant_min, quant_max,
+    block_size = (1, group_size)
+    return _GenericFakeQuantize.apply(
+        input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain,
     )
-    return fq.reshape_as(input).to(input.dtype)

-# TODO: move this to core
-quantized_decomposed_lib.define(
-    "fake_quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
-    "int quant_min, int quant_max) -> Tensor"
-)
-
-@impl(quantized_decomposed_lib, "fake_quantize_per_token", "CompositeImplicitAutograd")
 def fake_quantize_per_token(
     input: torch.Tensor,
     scales: torch.Tensor,
     zero_points: torch.Tensor,
     quant_min: int,
     quant_max: int,
 ) -> torch.Tensor:
     # TODO: we won't need this import anymore once we move this to core
     from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check
 
     _per_token_quant_qparam_dim_check(input, scales, zero_points)
+    block_size = _get_per_token_block_size(input)
     fq_input = input.to(torch.float32)
     fq = _GenericFakeQuantize.apply(
-        fq_input, scales, zero_points, quant_min, quant_max,
+        fq_input, scales, zero_points, quant_min, quant_max, block_size,
     )
     return fq.reshape_as(input).to(input.dtype)
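Below is a minimal usage sketch of the two wrappers after this refactor; both now route through _GenericFakeQuantize and fake_quantize_affine_cachemask. The shapes and qparam computations are made up for illustration; in the QAT modules the scales and zero points come from helpers such as get_group_qparams_symmetric and get_groupwise_affine_qparams, and zero points may need casting to the integer dtype the primitive asserts on:

import torch
from torchao.quantization.prototype.qat import (
    fake_quantize_per_channel_group,
    fake_quantize_per_token,
)

# Per-channel-group fake quantize of a 2D weight (block_size = (1, group_size) internally).
weight = torch.randn(64, 256)
group_size = 32
w_scales = weight.reshape(-1, group_size).abs().amax(dim=-1, keepdim=True) / 7  # illustrative symmetric scales, one per group
w_zeros = torch.zeros_like(w_scales, dtype=torch.int32)
w_fq = fake_quantize_per_channel_group(weight, w_scales, w_zeros, -8, 7, group_size)

# Per-token fake quantize of an activation (block_size = [1, ..., 1, hidden] internally).
x = torch.randn(2, 8, 256)
x_scales = x.abs().amax(dim=-1, keepdim=True) / 127
x_zeros = torch.zeros_like(x_scales, dtype=torch.int32)
x_fq = fake_quantize_per_token(x, x_scales, x_zeros, -128, 127)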

21 changes: 3 additions & 18 deletions torchao/quantization/quant_api.py
@@ -25,7 +25,6 @@
     TORCH_VERSION_AFTER_2_4,
     unwrap_tensor_subclass,
 )
-
 from .subclass import (
     QuantizedLinearWeightBase,
     LinearActQuantizedTensor,
@@ -42,6 +41,7 @@
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
 )
+from .utils import _get_per_token_block_size
 import logging
 from .autoquant import autoquant, AutoQuantizableLinearWeight
 
@@ -343,19 +343,10 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight):
     quant_min = -8
     quant_max = 7
 
-    # TODO: make a general helper function?
-    # input settings
-    def get_per_token_block_size(x):
-        block_size = []
-        for i in range(len(x.shape)-1):
-            block_size.append(1)
-        block_size.append(x.shape[-1])
-        return block_size
-
     # input settings
     input_mapping_type = MappingType.ASYMMETRIC
     input_target_dtype = torch.int8
-    input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
+    input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype)
 
     weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
     weight = to_linear_act_quantized(weight, input_quant_func)
@@ -441,18 +432,12 @@ def get_weight_block_size(x):
     zero_point_dtype = torch.int64
 
     # input settings
-    def get_per_token_block_size(x):
-        block_size = list(x.shape)
-        for i in range(len(block_size)-1):
-            block_size[i] = 1
-        return block_size
-
     input_mapping_type = MappingType.SYMMETRIC
     input_target_dtype = torch.int8
     input_eps = 1e-5
     input_quant_min = -127
     input_quant_max = 127
-    input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+    input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
 
     block_size = get_weight_block_size(weight)
     weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
9 changes: 8 additions & 1 deletion torchao/quantization/utils.py
@@ -3,7 +3,7 @@

 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -475,3 +475,10 @@ def recommended_inductor_config_setter():
     torch._inductor.config.fx_graph_cache = True
     torch._inductor.config.triton.unique_kernel_names = True
     torch.set_float32_matmul_precision("high")
+
+def _get_per_token_block_size(x: torch.Tensor) -> List[int]:
+    block_size = []
+    for i in range(len(x.shape)-1):
+        block_size.append(1)
+    block_size.append(x.shape[-1])
+    return block_size
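As a quick sanity check of what the new helper returns, assuming a typical activation of shape [batch, seq_len, hidden] (shapes made up for illustration):

import torch
from torchao.quantization.utils import _get_per_token_block_size

x = torch.randn(2, 8, 16)
print(_get_per_token_block_size(x))  # [1, 1, 16]: one quantization block (one scale) per token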