From 4d232b8d84600aab624925d73d9bffe1b67158fe Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 24 May 2024 22:32:08 -0700
Subject: [PATCH] Add support for `AQTStorage` and `PlainAQTStorage`

Summary:
Today `AffineQuantizedTensor` hardcodes its storage format as `int_data`,
`scale` and `zero_point` tensors, which does not work if we want to support
packed weights. This PR hides the storage details of `AffineQuantizedTensor`
behind a family of tensor subclasses that all inherit from a base storage
type: `AQTStorage` (affine quantized tensor storage).

This PR only adds a plain storage tensor (`PlainAQTStorage`) that stores the
`int_data`, `scale` and `zero_point` tensors directly; the next PR will add
support for storing a packed weight (the result of
`torch.ops.aten._convert_weight_to_int4pack`) in a different type of
`AQTStorage`.

`AffineQuantizedTensor` now has:
- storage_tensor: AQTStorage (can store data in different storage formats)
- storage_layout: str (a string naming the type of storage_tensor we have,
  usable in dispatch)
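
Example usage (a minimal sketch, not part of the diff below; the tensor shape,
the quantization parameters and the `MappingType` import path are illustrative
assumptions, and the "packed" layout / `PackedAQTStorage` class are
hypothetical placeholders for the follow-up packed format):

    import torch
    from torchao.dtypes.aqt import (
        AffineQuantizedTensor,
        AQTStorage,
        register_aqt_storage_cls,
    )
    from torchao.quantization.quant_primitives import MappingType

    # quantize a float weight into the default "plain" storage format;
    # block_size (1, k) gives per-channel int8 quantization
    float_weight = torch.randn(32, 64)
    aqt = AffineQuantizedTensor.from_float(
        float_weight,
        MappingType.SYMMETRIC,
        block_size=(1, float_weight.shape[1]),
        target_dtype=torch.int8,
        eps=torch.finfo(torch.float32).eps,
        zero_point_dtype=torch.int64,
        storage_layout="plain",
    )
    print(aqt.storage_layout)  # "plain", read from the storage class
    int_data, scale, zero_point = aqt.storage_tensor.get_plain()

    # a packed format would plug into the same registry; the "packed" name
    # and class below are hypothetical and only illustrate the flow
    @register_aqt_storage_cls("packed")
    class PackedAQTStorage(AQTStorage):
        ...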
Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/dtypes/aqt.py | 193 +++++++++++++++++++++++++++++++++---------
 1 file changed, 153 insertions(+), 40 deletions(-)

diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py
index 7619545f52..fbb4487347 100644
--- a/torchao/dtypes/aqt.py
+++ b/torchao/dtypes/aqt.py
@@ -18,14 +18,14 @@ def _aqt_is_int8(aqt):
     """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
     return (
-        aqt.int_data.dtype == torch.int8 and
+        aqt.storage_tensor.dtype == torch.int8 and
         aqt.quant_min is None or aqt.quant_min == -128 and
         aqt.quant_max is None or aqt.quant_max == 127
     )
 
 def _aqt_is_int8_reduced_range(aqt):
     return (
-        aqt.int_data.dtype == torch.int8 and
+        aqt.storage_tensor.dtype == torch.int8 and
         aqt.quant_min == -127 and
         aqt.quant_max is None or aqt.quant_max == 127
     )
 
@@ -34,7 +34,7 @@ def _aqt_is_uint4(aqt):
     """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
     # TODO: use torch.uint4
     return (
-        aqt.int_data.dtype == torch.int32 and
+        aqt.storage_tensor.dtype == torch.int32 and
         aqt.quant_min is None or aqt.quant_min == 0 and
         aqt.quant_max is None or aqt.quant_max == 15
     )
@@ -69,6 +69,121 @@ def implements_aqt_aten_ops(aten_ops):
 def implements_aqt_torch_function(torch_function):
     return implements_torch_function(AffineQuantizedTensor, torch_function)
 
+_STORAGE_LAYOUT_TO_AQT_STORAGE_CLS: Dict[str, Callable] = {}
+
+def register_aqt_storage_cls(storage_layout: str):
+    def decorator(storage_cls):
+        storage_cls.storage_layout = storage_layout
+        _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS[storage_layout] = storage_cls
+        return storage_cls
+    return decorator
+
+def get_aqt_storage_cls(storage_layout: str) -> Callable:
+    if storage_layout not in _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS:
+        raise ValueError(f"storage layout: {storage_layout} is not supported yet")
+    return _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS.get(storage_layout)
+
+class AQTStorage(torch.Tensor):
+    # this should be set for each storage class during registration
+    storage_layout: Optional[str] = None
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        pass
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pass
+
+@register_aqt_storage_cls("plain")
+class PlainAQTStorage(AQTStorage):
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+
+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point"], []
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        return cls(int_data, scale, zero_point)
+
+    # TODO: dedup
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"]),
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point),
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"PlainAQTStorage dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self):
+        return self.int_data, self.scale, self.zero_point
+
 
 class AffineQuantizedTensor(torch.Tensor):
     """
@@ -103,9 +218,7 @@ class AffineQuantizedTensor(torch.Tensor):
     @staticmethod
     def __new__(
         cls,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        storage_tensor: AQTStorage,
         block_size: Tuple[int, ...],
         shape: torch.Size,
         quant_min: Optional[int] = None,
@@ -115,9 +228,9 @@ def __new__(
         strides=None,
     ):
         kwargs = {}
-        kwargs["device"] = int_data.device
+        kwargs["device"] = storage_tensor.device
         kwargs["layout"] = (
-            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+            kwargs.get("layout") if kwargs.get("layout", False) else storage_tensor.layout
         )
         if dtype is None:
             dtype = scale.dtype
@@ -129,9 +242,7 @@ def __new__(
 
     def __init__(
         self,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        storage_tensor: AQTStorage,
        block_size: Tuple[int, ...],
         shape: torch.Size,
         quant_min: Optional[int] = None,
@@ -140,9 +251,7 @@ def __init__(
         dtype=None,
         strides=None,
     ):
-        self.int_data = int_data
-        self.scale = scale
-        self.zero_point = zero_point
+        self.storage_tensor = storage_tensor
         self.block_size = block_size
         self.quant_min = quant_min
         self.quant_max = quant_max
@@ -157,21 +266,20 @@ def __repr__(self):
     def dequantize(self, output_dtype=None):
         if output_dtype is None:
             output_dtype = self.dtype
-        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
+        int_data, scale, zero_point = self.storage_tensor.get_plain()
+        return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
+        return ["storage_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
-        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        storage_tensor = tensor_data_dict["storage_tensor"]
         block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes
         return cls(
-            int_data,
-            scale,
-            zero_point,
+            storage_tensor,
             block_size,
             shape if outer_size is None else outer_size,
             quant_min,
@@ -195,13 +303,15 @@ def from_float(
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        storage_layout: str = "plain",
     ):
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
+        storage_cls = get_aqt_storage_cls(storage_layout)
+        storage_tensor = storage_cls(int_data, scale, zero_point)
         return cls(
-            int_data,
-            scale,
-            zero_point,
+            storage_tensor,
             block_size,
             input_float.shape,
             quant_min,
@@ -210,6 +320,10 @@ def from_float(
             dtype=input_float.dtype
         )
 
+    @property
+    def storage_layout(self) -> str:
+        return self.storage_tensor.storage_layout
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         kwargs = {} if kwargs is None else kwargs
@@ -238,9 +352,7 @@ def _get_to_kwargs(self, *args, **kwargs):
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         return self.__class__(
-            self.int_data.to(kwargs["device"]),
-            self.scale.to(kwargs["device"]),
-            self.zero_point.to(kwargs["device"]),
+            self.storage_tensor.to(kwargs["device"]),
             self.block_size,
             self.shape,
             self.quant_min,
@@ -251,9 +363,7 @@ def to(self, *args, **kwargs):
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
-            fn(self.int_data),
-            fn(self.scale),
-            fn(self.zero_point),
+            fn(self.storage_tensor),
             self.block_size,
             self.shape,
             self.quant_min,
@@ -308,7 +418,9 @@ def functional_linear(*args, **kwargs):
             if (
                 is_cuda and
                 input_is_int8 and
-                input_tensor_dtype_is_expected
+                input_tensor_dtype_is_expected and
+                input_tensor.storage_layout == "plain" and
+                weight_qtensor.storage_layout == "plain"
             ):
                 #
                 # 1. do the matrix form of dot(X_i, W_j)
@@ -321,10 +433,10 @@ def functional_linear(*args, **kwargs):
                 # value of a float 16, (which results in a value of inf even if multiplying
                 # by the other scale would bring it within the expected range)
 
-                x_vals_int8 = input_tensor.int_data
-                x_scales = input_tensor.scale
-                w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
-                w_scales = weight_qtensor.scale
+                x_vals_int8 = input_tensor.storage_tensor.int_data
+                x_scales = input_tensor.storage_tensor.scale
+                w_vals_int8_t = weight_qtensor.storage_tensor.int_data.contiguous().t()
+                w_scales = weight_qtensor.storage_tensor.scale
                 tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
                 y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
 
@@ -344,22 +456,22 @@ def functional_linear(*args, **kwargs):
         # weight only quantization
         # TODO: enable cpu and mps path as well
         # TODO: make sure weight dimension matches the expectation of the int4mm kernel
-        # TODO: move this to TinygemmAffineQuantizedTensor
         if (
             is_cuda and
             weight_is_uint4 and
             weight_qtensor.dtype == torch.bfloat16 and
             len(weight_qtensor.shape) == 2 and
             weight_qtensor.block_size[0] == 1 and
-            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
+            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
+            weight_qtensor.storage_layout == "plain"
         ):
             # groupwise int4 quantization
             # TODO: currently doing packing on the fly, we'll need to figure out
             # the API to do packing before hand
             # TODO: expose the arg
             innerKTiles = 8
-            packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
-            scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.storage_tensor.int_data.to(torch.int32), innerKTiles)
+            scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.storage_tensor.scale, weight_qtensor.storage_tensor.zero_point)
             groupsize = weight_qtensor.block_size[-1]
             return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
         elif (
@@ -368,11 +480,12 @@ def functional_linear(*args, **kwargs):
             len(weight_qtensor.shape) == 2 and
             len(weight_qtensor.block_size) == 2 and
             weight_qtensor.block_size[0] == 1 and
-            weight_qtensor.block_size[1] == weight_qtensor.shape[1]
+            weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
+            weight_qtensor.storage_layout == "plain"
         ):
             # TODO: enable mps path as well
             # per channel int8 weight only quantizated mm
-            return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+            return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.storage_tensor.int_data, weight_qtensor.storage_tensor.scale)
         else:
             weight_tensor = weight_qtensor.dequantize()
             return torch.nn.functional.linear(input_tensor, weight_tensor, bias)