Add support for AQTStorage and PlainAQTStorage
Summary:
Today `AffineQuantizedTensor` hardcodes its storage format as separate `int_data`, `scale`, and `zero_point` tensors, which does not work if we want to support packed weights. This PR hides the storage details of `AffineQuantizedTensor` behind a family of tensor subclasses, all of which inherit from the base storage type `AQTStorage` (affine quantized tensor storage).
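
Condensed from the diff below, every storage subclass follows roughly this contract (a sketch for orientation, not a verbatim excerpt):

```python
from typing import Optional, Tuple
import torch

class AQTStorage(torch.Tensor):
    # set on each subclass by the @register_aqt_storage_cls("<layout>") decorator
    storage_layout: Optional[str] = None

    def __init__(self, int_data: torch.Tensor, scale: torch.Tensor,
                 zero_point: torch.Tensor):
        ...

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # return (int_data, scale, zero_point) regardless of the internal format
        ...
```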

This PR only adds a plain storage tensor, `PlainAQTStorage`, which stores the `int_data`, `scale`, and `zero_point` tensors directly. A follow-up PR will add a different type of `AQTStorage` that stores the packed weight (the result of `torch.ops.aten._convert_weight_to_int4pack`).
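
As a rough illustration of the extension point, a packed layout could be registered through the decorator introduced in this PR. The layout name and class below are hypothetical (the real packed-weight storage lands in the follow-up PR), and the `__tensor_flatten__`/`__torch_dispatch__` boilerplate that `PlainAQTStorage` implements is omitted:

```python
import torch
from torchao.dtypes.aqt import AQTStorage, register_aqt_storage_cls

@register_aqt_storage_cls("packed_int4")   # hypothetical layout name
class PackedInt4AQTStorage(AQTStorage):    # hypothetical class, sketch only
    def __new__(cls, int_data, scale, zero_point, inner_k_tiles=8):
        kwargs = {
            "device": int_data.device,
            "dtype": int_data.dtype,
            "layout": int_data.layout,
            "requires_grad": False,
        }
        return torch.Tensor._make_wrapper_subclass(cls, int_data.shape, **kwargs)

    def __init__(self, int_data, scale, zero_point, inner_k_tiles=8):
        # pack once at construction time instead of on every linear call;
        # requires a uint4 weight (held in int32) with a shape the int4mm kernel accepts
        self.packed_weight = torch.ops.aten._convert_weight_to_int4pack(
            int_data.to(torch.int32), inner_k_tiles
        )
        self.scale = scale
        self.zero_point = zero_point

    def get_plain(self):
        # unpacking back to plain tensors is out of scope for this sketch
        raise NotImplementedError
```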

With this change, `AffineQuantizedTensor` has the following (a minimal usage sketch follows the list):
- storage_tensor: AQTStorage (can hold the quantized data in different storage formats)
- storage_layout: str (a string identifying the type of storage_tensor; can be used in dispatch)
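
A minimal usage sketch, assuming an int8 per-channel symmetric configuration; the specific quantization arguments and the `MappingType` import path are illustrative assumptions rather than anything prescribed by this PR:

```python
import torch
from torchao.dtypes.aqt import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType  # assumed import path

weight = torch.randn(64, 128)
aqt = AffineQuantizedTensor.from_float(
    weight,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, 128),        # one block spans a full row -> per-channel quantization
    target_dtype=torch.int8,
    storage_layout="plain",     # selects PlainAQTStorage through the registry
)

print(aqt.storage_layout)       # "plain"
int_data, scale, zero_point = aqt.storage_tensor.get_plain()
assert int_data.dtype == torch.int8
```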

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 25, 2024
1 parent 90b5e17 commit 4d232b8
Showing 1 changed file with 153 additions and 40 deletions.
193 changes: 153 additions & 40 deletions torchao/dtypes/aqt.py
@@ -18,14 +18,14 @@
def _aqt_is_int8(aqt):
"""Check if an AffineQuantizedTensor is int8 quantized Tensor"""
return (
aqt.int_data.dtype == torch.int8 and
aqt.storage_tensor.dtype == torch.int8 and
aqt.quant_min is None or aqt.quant_min == -128 and
aqt.quant_max is None or aqt.quant_max == 127
)

def _aqt_is_int8_reduced_range(aqt):
return (
aqt.int_data.dtype == torch.int8 and
aqt.storage_tensor.dtype == torch.int8 and
aqt.quant_min == -127 and
aqt.quant_max is None or aqt.quant_max == 127
)
@@ -34,7 +34,7 @@ def _aqt_is_uint4(aqt):
"""Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
# TODO: use torch.uint4
return (
aqt.int_data.dtype == torch.int32 and
aqt.storage_tensor.dtype == torch.int32 and
aqt.quant_min is None or aqt.quant_min == 0 and
aqt.quant_max is None or aqt.quant_max == 15
)
@@ -69,6 +69,121 @@ def implements_aqt_aten_ops(aten_ops):
def implements_aqt_torch_function(torch_function):
return implements_torch_function(AffineQuantizedTensor, torch_function)

_STORAGE_LAYOUT_TO_AQT_STORAGE_CLS: Dict[str, Callable] = {}

def register_aqt_storage_cls(storage_layout: str):
def decorator(storage_cls):
storage_cls.storage_layout = storage_layout
_STORAGE_LAYOUT_TO_AQT_STORAGE_CLS[storage_layout] = storage_cls
return storage_cls
return decorator

def get_aqt_storage_cls(storage_layout: str) -> Callable:
if storage_layout not in _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS:
raise ValueError(f"storage layout: {storage_layout} is not supported yet")
return _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS.get(storage_layout)

class AQTStorage(torch.Tensor):
# this should be set for each storage class during registration
storage_layout: Optional[str] = None

def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
):
pass

def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pass

@register_aqt_storage_cls("plain")
class PlainAQTStorage(AQTStorage):
def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
):
kwargs = {}
kwargs["device"] = int_data.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
)
kwargs["dtype"] = int_data.dtype
kwargs["requires_grad"] = False
shape = int_data.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
):
self.int_data = int_data
self.scale = scale
self.zero_point = zero_point

def __tensor_flatten__(self):
return ["int_data", "scale", "zero_point"], []

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
return cls(int_data, scale, zero_point)

# TODO: dedup
def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
memory_format = (
memory_format if memory_format is not None else torch.preserve_format
)
kwargs = {
"device": device,
"dtype": dtype,
"memory_format": memory_format,
}
return kwargs

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self.__class__(
self.int_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]),
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point),
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

raise NotImplementedError(
f"PlainAQTStorage dispatch: attempting to run {func}, this is not supported"
)

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self):
return self.int_data, self.scale, self.zero_point


class AffineQuantizedTensor(torch.Tensor):
"""
@@ -103,9 +218,7 @@ class AffineQuantizedTensor(torch.Tensor):
@staticmethod
def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
storage_tensor: AQTStorage,
block_size: Tuple[int, ...],
shape: torch.Size,
quant_min: Optional[int] = None,
@@ -115,9 +228,9 @@ def __new__(
strides=None,
):
kwargs = {}
kwargs["device"] = int_data.device
kwargs["device"] = storage_tensor.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
kwargs.get("layout") if kwargs.get("layout", False) else storage_tensor.layout
)
if dtype is None:
dtype = scale.dtype
@@ -129,9 +242,7 @@ def __new__(

def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
storage_tensor: AQTStorage,
block_size: Tuple[int, ...],
shape: torch.Size,
quant_min: Optional[int] = None,
@@ -140,9 +251,7 @@ def __init__(
dtype=None,
strides=None,
):
self.int_data = int_data
self.scale = scale
self.zero_point = zero_point
self.storage_tensor = storage_tensor
self.block_size = block_size
self.quant_min = quant_min
self.quant_max = quant_max
@@ -157,21 +266,20 @@ def __repr__(self):
def dequantize(self, output_dtype=None):
if output_dtype is None:
output_dtype = self.dtype
return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
int_data, scale, zero_point = self.storage_tensor.get_plain()
return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

def __tensor_flatten__(self):
return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
return ["storage_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
storage_tensor = tensor_data_dict["storage_tensor"]
block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes
return cls(
int_data,
scale,
zero_point,
storage_tensor,
block_size,
shape if outer_size is None else outer_size,
quant_min,
@@ -195,13 +303,15 @@ def from_float(
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
storage_layout: str = "plain",
):
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

storage_cls = get_aqt_storage_cls(storage_layout)
storage_tensor = storage_cls(int_data, scale, zero_point)
return cls(
int_data,
scale,
zero_point,
storage_tensor,
block_size,
input_float.shape,
quant_min,
@@ -210,6 +320,10 @@ def from_float(
dtype=input_float.dtype
)

@property
def storage_layout(self) -> str:
return self.storage_tensor.storage_layout

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
@@ -238,9 +352,7 @@ def _get_to_kwargs(self, *args, **kwargs):
def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self.__class__(
self.int_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]),
self.storage_tensor.to(kwargs["device"]),
self.block_size,
self.shape,
self.quant_min,
@@ -251,9 +363,7 @@ def to(self, *args, **kwargs):

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point),
fn(self.storage_tensor),
self.block_size,
self.shape,
self.quant_min,
@@ -308,7 +418,9 @@ def functional_linear(*args, **kwargs):
if (
is_cuda and
input_is_int8 and
input_tensor_dtype_is_expected
input_tensor_dtype_is_expected and
input_tensor.storage_layout == "plain" and
weight_qtensor.storage_layout == "plain"
):
#
# 1. do the matrix form of dot(X_i, W_j)
@@ -321,10 +433,10 @@ def functional_linear(*args, **kwargs):
# value of a float 16, (which results in a value of inf even if multiplying
# by the other scale would bring it within the expected range)

x_vals_int8 = input_tensor.int_data
x_scales = input_tensor.scale
w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
w_scales = weight_qtensor.scale
x_vals_int8 = input_tensor.storage_tensor.int_data
x_scales = input_tensor.storage_tensor.scale
w_vals_int8_t = weight_qtensor.storage_tensor.int_data.contiguous().t()
w_scales = weight_qtensor.storage_tensor.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))

@@ -344,22 +456,22 @@ def functional_linear(*args, **kwargs):
# weight only quantization
# TODO: enable cpu and mps path as well
# TODO: make sure weight dimension matches the expectation of the int4mm kernel
# TODO: move this to TinygemmAffineQuantizedTensor
if (
is_cuda and
weight_is_uint4 and
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
weight_qtensor.storage_layout == "plain"
):
# groupwise int4 quantization
# TODO: currently doing packing on the fly, we'll need to figure out
# the API to do packing before hand
# TODO: expose the arg
innerKTiles = 8
packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.storage_tensor.int_data.to(torch.int32), innerKTiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.storage_tensor.scale, weight_qtensor.storage_tensor.zero_point)
groupsize = weight_qtensor.block_size[-1]
return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
elif (
@@ -368,11 +480,12 @@ def functional_linear(*args, **kwargs):
len(weight_qtensor.shape) == 2 and
len(weight_qtensor.block_size) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1]
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
weight_qtensor.storage_layout == "plain"
):
# TODO: enable mps path as well
# per channel int8 weight only quantizated mm
return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.storage_tensor.int_data, weight_qtensor.storage_tensor.scale)
else:
weight_tensor = weight_qtensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
