Add HQQ support #605

Merged: 21 commits, merged Aug 15, 2024
Changes from 6 commits
21 changes: 16 additions & 5 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -2,13 +2,15 @@
from typing import Dict, Callable, Any, Tuple, Optional
from collections import defaultdict
import functools
import math
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
quantize_affine,
dequantize_affine,
ZeroPointDomain,
MappingType,
int_scaled_matmul,
quantize_affine_hqq,
)
from torchao.quantization.utils import (
pack_tinygemm_scales_and_zeros,
Expand All @@ -25,6 +27,7 @@
PlainLayoutType,
is_device,
)

from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AFTER_2_5

@@ -75,7 +78,6 @@ def _get_to_kwargs(self, *args, **kwargs):
##############################
# Tensor Subclass Definition #
##############################

class AffineQuantizedTensor(torch.Tensor):
"""
Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
@@ -190,14 +192,23 @@ def from_float(
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
layout_type: LayoutType = PlainLayoutType(),
use_hqq: bool = False,
):
original_shape = input_float.shape
input_float = layout_type.pre_process(input_float)

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
int_data = layout_type.post_process(int_data)
if(use_hqq):
assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization."
nbits = int(math.log2(quant_max + 1))
axis = 1 if (block_size[0]==1) else 0
group_size = max(block_size)
int_data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=input_float.dtype, device=input_float.device, verbose=False, raw_output=False)

else:
input_float = layout_type.pre_process(input_float)
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
int_data = layout_type.post_process(int_data)

layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
return cls(
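For context on the use_hqq branch added above: the HQQ arguments are derived from the existing from_float parameters rather than passed in explicitly. A minimal sketch of that mapping, with illustrative values (the logic mirrors the added lines in this hunk):

import math

quant_max = 15                          # uint8 storage, 4-bit range [0, 15]
block_size = (1, 64)                    # per-row groups of 64 elements
nbits = int(math.log2(quant_max + 1))   # -> 4
axis = 1 if block_size[0] == 1 else 0   # -> 1 (group along the column dimension)
group_size = max(block_size)            # -> 64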
184 changes: 184 additions & 0 deletions torchao/prototype/hqq/core.py
@@ -0,0 +1,184 @@
import torch
import math
from torch import Tensor, float16, float32
from typing import Union


# Shrinking operator (proximal operator for the lp norm)
def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor:
if lp_norm == 1:
return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
else:
return torch.sign(x) * torch.nn.functional.relu(
torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1)
)


# Proximal solver || W - dequantize(quantize(W))||_p^p
@torch.inference_mode()
def optimize_weights_proximal_legacy(
tensor: Tensor,
scale: Tensor,
zero: Tensor,
min_max: list,
axis: int = 0,
dtype: Union[torch.dtype, None] = None,
device: Union[str, None] = None,
verbose: bool = False,
opt_params: dict = {
"lp_norm": 0.7,
"beta": 1e1,
"kappa": 1.01,
"iters": 20,
"early_stop": True,
},
) -> tuple:
lp_norm, beta, kappa, iters, early_stop = (
opt_params["lp_norm"],
opt_params["beta"],
opt_params["kappa"],
opt_params["iters"],
opt_params["early_stop"],
)

device = tensor.device if (device is None) else torch.device(device)

if dtype is None:
dtype = float16 if (device.type == "cuda") else float32

W_f = tensor.to(dtype=dtype, device=device)
scale = scale.to(dtype=dtype, device=device)
zero = zero.to(dtype=dtype, device=device)

best_error = 1e4
for i in range(iters):
W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1])
W_r = (W_q - zero) / scale
W_e = shrink_lp_op(W_f - W_r, beta, lp_norm)
zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True)
beta *= kappa

current_error = float(torch.abs(W_f - W_r).mean())
if verbose:
print("Iter " + str(i + 1), " | Error: " + str(current_error))
if early_stop:
if current_error < best_error:
best_error = current_error
else:
break

scale = scale.to(tensor.device)
zero = zero.to(tensor.device)
del W_f, W_q, W_r, W_e
torch.cuda.empty_cache()

W_q = torch.round(tensor * scale + zero).clamp(min_max[0], min_max[1])
return W_q, scale, zero


# Default: fast with early stopping
optimize_weights_proximal = optimize_weights_proximal_legacy


# Mainly used to check that numel() is divisible by the group size
def is_divisible(val1: int, val2: int) -> bool:
return int(val2 * math.ceil(val1 / val2)) == val1


# Converts HQQ format W_dequant = (W_q - zero)*scale into the affine quantized format: (W_q - mid_point)*scale_ao + zero_ao
def convert_to_affinequantized_format(W_q, scale, zero, nbits, shape):
quant_min = 0
quant_max = 2**nbits - 1
mid_point = (quant_max + quant_min + 1) / 2
zero_ao = ((mid_point - zero.float()) * scale.float()).to(zero.dtype)
scale_ao = scale
W_q_ao = W_q.view(shape)
return W_q_ao, scale_ao, zero_ao


# Main HQQ Quantizer - simplified, no bitpacking.
class HQQQuantizer:
optimize_weights = optimize_weights_proximal

@classmethod
def quantize(
cls,
tensor: Tensor,
nbits: float = 4,
group_size: int = 64,
optimize: bool = True,
axis: int = 1,
compute_dtype: torch.dtype = float16,
device: str = "cuda",
verbose: bool = False, # to check the optimizer error
raw_output: bool = False, # If True, it will return the quant params in hqq lib format
) -> tuple:
assert axis in [0, 1], "axis should be either 0 or 1"
if group_size is not None:
assert is_divisible(tensor.numel(), group_size), (
"group_size should be divisble by the total tensor dimensions. shape: "
+ str(tensor.shape)
+ ", group_size: "
+ str(group_size)
)

W = tensor.to(device=device, dtype=torch.float32)
shape = W.shape

# Reshape for grouping
if group_size is not None:
W = (
W.reshape([-1, group_size])
if (axis == 1)
else W.reshape([group_size, -1])
)

# Get min/max values
_min = W.min(axis=axis, keepdim=True)[0]
_max = W.max(axis=axis, keepdim=True)[0]

max_v = round(2**nbits - 1)
min_v = 0
min_max = [min_v, max_v]

# Clamp to avoid fp16 issues
scale = (max_v / (_max - _min)).clamp(max=2e4)
zero = -_min * scale

# Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14
if nbits in [4]:
zero = torch.round(zero)

# Fine-tune weights
if optimize:
W_q, scale, zero = HQQQuantizer.optimize_weights(
tensor=W,
scale=scale,
zero=zero,
min_max=min_max,
axis=axis,
device=device,
verbose=verbose,
)
else:
W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])

# Store meta-data (we invert the scale for dequantization)
scale = 1.0 / scale

# Convert to the affine quantized format
if raw_output is False:
W_q, scale, zero = convert_to_affinequantized_format(
W_q, scale, zero, nbits, shape
)

# Make sure all the weights are in the right compute_dtype/device
W_q = W_q.to(dtype=torch.uint8, device=device)
scale = scale.to(dtype=compute_dtype, device=device)
zero = zero.to(dtype=compute_dtype, device=device)

# cleanup
del W, _min, _max
torch.cuda.empty_cache()

return W_q, scale, zero, shape
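A note on convert_to_affinequantized_format above: HQQ dequantizes as (W_q - zero) * scale, while the affine quantized convention used with ZeroPointDomain.FLOAT dequantizes as (W_q - mid_point) * scale_ao + zero_ao. With scale_ao = scale and zero_ao = (mid_point - zero) * scale the two are identical, since (W_q - mid_point) * scale + (mid_point - zero) * scale = (W_q - zero) * scale. A small self-contained numerical check of this identity (random values, no torchao dependency):

import torch

nbits = 4
quant_max = 2**nbits - 1
mid_point = (quant_max + 0 + 1) / 2                  # 8.0 for 4 bits, as in the code above
W_q = torch.randint(0, quant_max + 1, (4, 64)).float()
scale = torch.rand(4, 1) * 0.01
zero = torch.rand(4, 1) * 8.0

hqq_dequant = (W_q - zero) * scale                   # HQQ convention
zero_ao = (mid_point - zero) * scale                 # as computed in convert_to_affinequantized_format
aqt_dequant = (W_q - mid_point) * scale + zero_ao    # ZeroPointDomain.FLOAT convention
assert torch.allclose(hqq_dequant, aqt_dequant, atol=1e-6)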
62 changes: 62 additions & 0 deletions torchao/prototype/hqq/example.py
@@ -0,0 +1,62 @@
import torch
from torchao.prototype.hqq.core import HQQQuantizer
from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
ZeroPointDomain,
PlainAQTLayout,
PlainLayoutType,
TensorCoreTiledAQTLayout,
TensorCoreTiledLayoutType,
MappingType,
)

#Parameters
device, compute_dtype = "cuda:0", torch.bfloat16
group_size, axis = 64, 1

linear_layer = torch.nn.Linear(4096, 11800, bias=False, device=device)
x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
y_ref = linear_layer(x)
W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
del linear_layer.weight
################################################################################################

for nbits in list(range(2, 9))[::-1]:
print('------------------------------------------------------------------------------')
q_tensor_default = AffineQuantizedTensor.from_float(
input_float=W,
mapping_type=MappingType.ASYMMETRIC,
block_size=[1, group_size],
target_dtype=torch.uint8,
quant_min=0,
quant_max=2**nbits - 1,
preserve_zero=False,#Important
zero_point_domain= ZeroPointDomain.FLOAT,
layout_type=PlainLayoutType(),
)

linear_layer.weight = q_tensor_default
print("nbits", nbits, "| Default dequantization error", (W - q_tensor_default.dequantize()).abs().mean().item())
print("nbits", nbits, '| Default Dot product error', (y_ref - linear_layer(x.to(compute_dtype))).abs().mean().item())
# 4-bit Default dequantization error 0.001953125
# 4-bit Default Dot product error 0.0057801781222224236


q_tensor_hqq = AffineQuantizedTensor.from_float(
input_float=W,
mapping_type=MappingType.ASYMMETRIC,
block_size=[1, group_size],
target_dtype=torch.uint8,
quant_min=0,
quant_max=2**nbits - 1,
preserve_zero=False,#Important
zero_point_domain= ZeroPointDomain.FLOAT,
layout_type=PlainLayoutType(),
use_hqq=True,
)

linear_layer.weight = q_tensor_hqq
print("nbits", nbits, "| HQQ dequantization error", (W - q_tensor_hqq.dequantize()).abs().mean().item())
print("nbits", nbits, '| HQQ Dot product error', (y_ref - linear_layer(x.to(compute_dtype))).abs().mean().item())
# 4-bit HQQ dequantization error 0.0004863739013671875
# 4-bit HQQ Dot product error 0.0014263123739510775
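Beyond the from_float path shown in this example, the HQQQuantizer added in torchao/prototype/hqq/core.py can also be called directly. A minimal sketch, assuming a CUDA device and the signature from the diff above; the weight shape and nbits are illustrative. With raw_output=True the function returns HQQ-style parameters (grouped W_q plus inverted scale and zero), which dequantize as (W_q - zero) * scale before reshaping back:

import torch
from torchao.prototype.hqq.core import HQQQuantizer

W = torch.randn(1024, 4096)  # illustrative weight matrix, numel divisible by group_size
W_q, scale, zero, shape = HQQQuantizer.quantize(
    W,
    nbits=4,
    group_size=64,
    optimize=True,
    axis=1,
    compute_dtype=torch.float16,
    device="cuda",
    raw_output=True,  # keep HQQ-style (W_q - zero) * scale parameters
)
# W_q is uint8 and grouped as (-1, group_size); scale and zero broadcast over each group.
W_dq = ((W_q.float() - zero.float()) * scale.float()).reshape(shape)
print((W.cuda() - W_dq).abs().mean())  # mean dequantization error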