Refactor int4 weight only quantization with call to quantize
Summary:
This is similar to #294 but applied to int4 weight only quantization.
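
For reference, on torch 2.4+ the refactored path (see the quant_api.py change below) goes through the unified quantize API instead of the subclass inserter. A minimal usage sketch of that flow, assuming a bfloat16 CUDA model and that the helpers below are importable from quant_api in this version:

# sketch only, not part of this commit; import locations are assumptions
import torch
from torchao.quantization.quant_api import (
    quantize,
    get_apply_int4wo_quant,
    unwrap_tensor_subclass,
    _is_linear,
)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# swap Linear weights for AffineQuantizedTensor (tensor_core_tiled layout)
quantize(model, get_apply_int4wo_quant(groupsize=32), _is_linear)
# unwrap the tensor subclass weights via parametrization for torch.compile/export compatibility
unwrap_tensor_subclass(model, _is_linear)

out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))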

Test Plan:

unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793

integration perf test:

reference: elapsed_time:  2.5900126953125  milliseconds
after refactor: elapsed_time:  2.56680078125  milliseconds
diff: no diff

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Before:
After:
generated code diff:

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jun 1, 2024
1 parent 55a4676 commit a19ac04
Showing 4 changed files with 151 additions and 47 deletions.
4 changes: 2 additions & 2 deletions test/quantization/test_quant_api.py
@@ -168,7 +168,7 @@ def _ref_change_linear_weights_to_int4_woqtensors(model, **kwargs):

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs),
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=True, **kwargs),
filter_fn,
)

@@ -633,7 +633,7 @@ def test_quantized_tensor_subclass_int8_wo_quant_perf(self):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int4 weight only quant implementation")
# @unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int4 weight only quant implementation")
def test_quantized_tensor_subclass_int4_wo_quant_perf(self):
kwargs = {"groupsize": 32}
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
170 changes: 134 additions & 36 deletions torchao/dtypes/aqt.py
@@ -74,11 +74,11 @@ def implements_aqt_torch_function(torch_function):
def register_aqt_layout_cls(extended_layout: str):
def decorator(layout_cls):
layout_cls.extended_layout = extended_layout
_EXTENDED_LAYOUT_TO_AQT_LAYOUT_CLS[extended_layout] = layout_cls
_EXTENDED_LAYOUT_TO_AQT_LAYOUT_CLS[extended_layout] = layout_cls.from_plain
return layout_cls
return decorator

def get_aqt_layout_cls(extended_layout: str) -> Callable:
def get_aqt_layout_cls_ctr(extended_layout: str) -> Callable:
if extended_layout not in _EXTENDED_LAYOUT_TO_AQT_LAYOUT_CLS:
raise ValueError(f"extended_layout: {extended_layout} is not supported yet")
return _EXTENDED_LAYOUT_TO_AQT_LAYOUT_CLS.get(extended_layout)
@@ -90,17 +90,18 @@ class AQTLayout(torch.Tensor):
# this should be set for each layout class during registration
extended_layout: Optional[str] = None

def __init__(
self,
def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pass

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
):
pass

def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pass

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
@@ -205,6 +206,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
def get_plain(self):
return self.int_data, self.scale, self.zero_point

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
):
return cls(int_data, scale, zero_point)

@register_aqt_layout_cls("tensor_core_tiled")
class TensorCoreTiledAQTLayout(AQTLayout):
@@ -222,40 +231,45 @@ class TensorCoreTiledAQTLayout(AQTLayout):

def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
):
kwargs = {}
kwargs["device"] = int_data.device
kwargs["device"] = packed_weight.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
kwargs.get("layout") if kwargs.get("layout", False) else packed_weight.layout
)
kwargs["dtype"] = int_data.dtype
kwargs["dtype"] = packed_weight.dtype
kwargs["requires_grad"] = False
shape = int_data.shape
shape = packed_weight.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
):
# TODO: expose the arg
innerKTiles = 8
self.packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), innerKTiles)
self.scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
self.packed_weight = packed_weight
self.scale_and_zero = scale_and_zero

def __tensor_flatten__(self):
return ["packed_weight", "scale_and_zero"]
return ["packed_weight", "scale_and_zero"], []

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
# TODO: fix the unflatten logic
return cls(packed_weight, scale_and_zero)

@classmethod
def from_plain(cls, int_data, scale, zero_point):
# TODO: expose the arg
innerKTiles = 8
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), innerKTiles)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
return cls(packed_weight, scale_and_zero)

def to(self, *args, **kwargs):
@@ -273,6 +287,14 @@ def _apply_fn_to_data(self, fn):
self.scale_and_zero = fn(self.scale_and_zero)
return self

def _change_shape(self, shape):
# int_data, scale, zero = self.get_plain()
# int_data = int_data.view(shape)
# changed = self.from_plain(int_data, scale, zero)
# return changed
# TODO: changing shape is no-op for int4 packed weight right now
return self

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs
@@ -282,16 +304,47 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.view.default:
assert len(args) == 2
new = args[0]._change_shape(args[1])
return return_and_correct_aliasing(func, args, kwargs, new)

raise NotImplementedError(
f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
f"TensorCoreTiledAQTLayout dispatch: attempting to run {func}, this is not supported"
)

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self):
raise NotImplementedError(
f"Unpacking for tensor core tiled storage is not yet implemented"
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
unpack_tinygemm_scales_and_zeros,
quantize_affine,
)
cur_shape = self.shape
assert len(cur_shape) == 4
# TODO: expose the arg
innerKTiles = 8
original_shape = (cur_shape[0] * 8, cur_shape[1] * (innerKTiles * 16))
eye_shape = original_shape[1]
block_size = (1, 32)
device = self.device
original_dtype = torch.bfloat16
groupsize = 32
target_dtype = torch.int32
quant_min = 0
quant_max = 15
zero_point_domain = ZeroPointDomain.FLOAT
assert len(block_size) == 2 and block_size[0] == 1
groupsize = block_size[-1]
dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero)
dequantized = dequantized.t().contiguous()
scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
# TODO: move this to `unpack_tinygemm_scales_and_zeros`?
scale = scale.reshape(scale.shape[:-1]).contiguous()
zero = zero.reshape(zero.shape[:-1]).contiguous()
int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain)
return int_data, scale, zero

class AffineQuantizedTensor(torch.Tensor):
"""
@@ -413,15 +466,26 @@ def from_float(
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
extended_layout: str = "plain",
):
original_shape = input_float.shape
if extended_layout == "tensor_core_tiled":
from torchao.quantization.utils import find_multiple
orig_out_features, orig_in_features = input_float.shape
in_features = find_multiple(orig_in_features, 1024)
out_features = find_multiple(orig_out_features, 8)
input_float = torch.nn.functional.pad(
input_float,
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
)

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

layout_cls = get_aqt_layout_cls(extended_layout)
layout_tensor = layout_cls(int_data, scale, zero_point)
layout_cls_ctr = get_aqt_layout_cls_ctr(extended_layout)
layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
return cls(
layout_tensor,
block_size,
input_float.shape,
original_shape,
quant_min,
quant_max,
zero_point_domain,
@@ -507,7 +571,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)

def _quantized_linear_op(input_tensor, weight_qtensor, bias):
def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True):
# TODO: the old tensor subclass can use the single implementation for both F.linear dispatch
# and aten.addmm/aten.mm dispatch because `_change_shape` is not implmeneted correctly (got ignored
# for the int_data), this makes the dimension for weight_qtensor indeterministic, we need to fix
# the issue and make sure we have a clear accepted dimension for `_quantized_linear_op`
# after that we can remove _from_linear flag

is_cuda = weight_qtensor.is_cuda
is_cpu = weight_qtensor.device == torch.device("cpu")
if isinstance(weight_qtensor, AffineQuantizedTensor):
@@ -564,15 +634,42 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
weight_is_uint4 and
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
weight_qtensor.layout == "tensor_core_tiled"
):
# groupwise int4 quantization
groupsize = weight_qtensor.block_size[-1]
if not _from_flinear:
weight_qtensor = weight_qtensor.t()
assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"

# TODO: check groupsize quantization
# avoid circular dep, TODO: move this to a common util.py
from torchao.quantization.utils import find_multiple
act_mat = input_tensor
# weight is packed from padded (out_features, in_features) weight tensor
# (same dimension requirement as F.linear weight)
packed_weight = weight_qtensor.layout_tensor.packed_weight
scale_and_zero = weight_qtensor.layout_tensor.scale_and_zero
return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scale_and_zero)

orig_act_size = act_mat.size()
orig_dtype = act_mat.dtype

# reshape and pad activation
act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
pad_size = find_multiple(act_mat.shape[-1], 1024)
act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))

# groupwise int4 quantization
groupsize = weight_qtensor.block_size[1]
y = torch.ops.aten._weight_int4pack_mm(act_mat.contiguous(), packed_weight, groupsize, scale_and_zero)

# remove out_feature padding
orig_out_features = weight_qtensor.shape[-2]
y = y[:, :orig_out_features]
y = y.reshape(*orig_act_size[:-1], orig_out_features)

if bias is not None:
y += bias
return y.to(orig_dtype)
elif (
weight_is_int8 and
len(weight_qtensor.shape) == 2 and
@@ -602,6 +699,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
# is_cpu and is_mps only, some issue with is_contiguous() currently
# return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)

breakpoint()
raise NotImplementedError("No specialized dispatch found for quantized linear op")


@@ -639,7 +737,7 @@ def aten_mm(func, *args, **kwargs):
args[0],
)
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
return _quantized_linear_op(input_tensor, weight_tensor, bias, _from_flinear=False)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
@@ -653,7 +751,7 @@ def aten_mm(func, *args, **kwargs):
None
)
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
return _quantized_linear_op(input_tensor, weight_tensor, bias, _from_flinear=False)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
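
For context, after this refactor a layout class is looked up by its from_plain constructor (via get_aqt_layout_cls_ctr) and converts back to plain quantized tensors with get_plain. A minimal sketch of a custom layout following the pattern above; the class and the "my_plain" layout name are hypothetical, and the __tensor_flatten__/__torch_dispatch__ boilerplate shown in TensorCoreTiledAQTLayout is omitted:

# hypothetical example, not part of this commit
from typing import Tuple
import torch
from torchao.dtypes.aqt import AQTLayout, register_aqt_layout_cls

@register_aqt_layout_cls("my_plain")  # made-up extended_layout name
class MyPlainAQTLayout(AQTLayout):
    def __new__(cls, int_data, scale, zero_point):
        kwargs = {
            "device": int_data.device,
            "dtype": int_data.dtype,
            "requires_grad": False,
        }
        return torch.Tensor._make_wrapper_subclass(cls, int_data.shape, **kwargs)

    def __init__(self, int_data, scale, zero_point):
        self.int_data = int_data
        self.scale = scale
        self.zero_point = zero_point

    @classmethod
    def from_plain(cls, int_data, scale, zero_point):
        # registration stores this constructor, so AffineQuantizedTensor.from_float
        # reaches it through get_aqt_layout_cls_ctr("my_plain")
        return cls(int_data, scale, zero_point)

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.int_data, self.scale, self.zero_point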
16 changes: 10 additions & 6 deletions torchao/quantization/quant_api.py
@@ -229,11 +229,15 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
"""
filter_fn = kwargs.pop("filter_fn", _is_linear)

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
filter_fn,
)
if TORCH_VERSION_AFTER_2_4:
quantize(model, get_apply_int4wo_quant(**kwargs), filter_fn)
unwrap_tensor_subclass(model, filter_fn)
else:
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs),
filter_fn,
)

def swap_conv2d_1x1_to_linear(model, filter_fn=None):
"""
@@ -351,8 +355,8 @@ def get_apply_int4wo_quant(groupsize=32):
def apply_int4wo_quant(weight):
# avoid circular dep
from torchao.dtypes.aqt import to_aq
from torchao.quantization.utils import find_multiple

groupsize = 32
mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int32
8 changes: 5 additions & 3 deletions tutorials/quantize_vit/run_vit_b_quant.py
@@ -15,20 +15,22 @@

## Quantization code - start
# int8 act, int8 weight dynamic quantization
torchao.apply_dynamic_quant(model)
# torchao.apply_dynamic_quant(model)

# int8 weight only quantization
# torchao.quantization.change_linear_weights_to_int8_woqtensors(model)
## Quantization code - end

# int4 weight only quantization
torchao.quantization.change_linear_weights_to_int4_woqtensors(model, groupsize=32)
## Quantization code - end

## compilation configs
torch._dynamo.config.automatic_dynamic_shapes = False
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
## compilation configs end

model = torch.compile(model, mode='max-autotune', fullgraph=True)
model = torch.compile(model, mode='max-autotune')

# Must run with no_grad when optimizing for inference
with torch.no_grad():
