
Commit

updated uint4 and perchannel_symmetricweight based on new API pytorch…
melvinebenezer committed Jul 7, 2024
1 parent c0497b8 · commit d54399c
Showing 5 changed files with 285 additions and 210 deletions.
4 changes: 3 additions & 1 deletion test/dtypes/test_uint4.py
@@ -1,7 +1,9 @@
 import torch
 from torchao.dtypes.uint4 import (
     UInt4Tensor,
-    PerChannelSymmetricWeightUInt4Tensor,
 )
+from torchao.dtypes import (
+    PerChannelSymmetricWeightUInt4Tensor
+)
 import unittest
 from unittest import TestCase, main
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -1,6 +1,7 @@
 from .nf4tensor import NF4Tensor, to_nf4
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uint4 import UInt4Tensor
+from .perchannel_symmetricweight import PerChannelSymmetricWeightUInt4Tensor
 from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized

 __all__ = [
@@ -9,4 +10,5 @@
     "UInt4Tensor"
     "AffineQuantizedTensor",
     "to_affine_quantized",
+    "PerChannelSymmetricWeightUInt4Tensor",
 ]
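This re-export is what lets the updated test import the subclass from the package root instead of from torchao.dtypes.uint4. A minimal check, assuming an environment with this commit of torchao installed:

from torchao.dtypes import PerChannelSymmetricWeightUInt4Tensor  # new public path added here
from torchao.dtypes import UInt4Tensor                           # unchanged existing export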
144 changes: 144 additions & 0 deletions torchao/dtypes/perchannel_symmetricweight.py
@@ -0,0 +1,144 @@
import torch
from torchao.dtypes.uint4 import pack_uint4, unpack_uint4
from torchao.dtypes import UInt4Tensor
from typing import Dict, Any
from torchao.dtypes.utils import _implements
from torchao.dtypes.utils import _ATEN_OP_OR_TORCH_FN_TABLE

SYMMETRIC_WEIGHT_OPS_TABLE: Dict[Any, Any] = {}

from torchao.dtypes.utils import _implements

def implements(aten_ops_or_torch_fns):
    return _implements(PerChannelSymmetricWeightUInt4Tensor, aten_ops_or_torch_fns)

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)

    if target_dtype == torch.uint4:
        # TODO: simplify (maybe implement to)
        quant = PerChannelSymmetricWeightUInt4Tensor.from_unpacked(
            quant.to(torch.uint8), scale
        )
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point
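# Worked example (illustrative note, not part of the committed file): for a
# weight row whose largest magnitude is 6.0, max_val_pos = 6.0 and, with the
# quant_min=0 and quant_max=15 that from_float passes below, the per-channel
# scale is 6.0 / 7.5 = 0.8; each element of the row is then divided by 0.8,
# rounded, and clamped into [0, 15], while zero_point stays 0 in this
# symmetric scheme.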

class PerChannelSymmetricWeightUInt4Tensor(UInt4Tensor):
    @staticmethod
    def __new__(cls, elem, scales, **kwargs):
        return super().__new__(cls, elem, **kwargs)

    def __init__(self, elem, scales, **kwargs):
        super().__init__(elem, **kwargs)

        self.scales = scales

    def __tensor_flatten__(self):
        return ["elem", "scales"], None

    @staticmethod
    def __tensor_unflatten__(flattened, meta, outer_size, outer_stride):
        assert meta is None
        elem = flattened["elem"]
        scales = flattened["scales"]
        return PerChannelSymmetricWeightUInt4Tensor(elem, scales)

    @classmethod
    # inconsistently.
    def from_unpacked(cls, unpacked, scales):
        return cls(pack_uint4(unpacked), scales)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)

        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)

        raise NotImplementedError(
            f"PerChannelSymmetricWeightUInt4Tensor dispatch: attempting to run {func}, this is not supported"
        )

    @classmethod
    def from_float(cls, w_fp32):
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, torch.uint4
        )
        w_int4 = w_int4.to(device=w_fp32.device)
        return w_int4

@implements([torch.ops.aten.addmm.default])
def _(func, args, kwargs):
    bias, x, weight = args
    x_view = x.view(-1, x.shape[-1])
    y = torch.mm(x_view, weight.to(torch.uint8).to(x.dtype)) * weight.scales
    y = y.reshape(*x.shape[:-1], -1)
    if bias is not None:
        y += bias
    return y

@implements([torch.ops.aten.t.default])
def _(func, args, kwargs):
    # TODO: add proper support for transpose
    (tensor,) = args
    unpacked = unpack_uint4(tensor.elem)
    transposed = torch.ops.aten.t.default(unpacked)
    return PerChannelSymmetricWeightUInt4Tensor.from_unpacked(
        transposed, tensor.scales
    )

@implements([torch.ops.aten.detach.default])
def _(func, args, kwargs):
    (tensor,) = args
    return

if __name__ == "__main__":
    # test
    x = torch.randn(2, 3, 4)
    w = torch.randn(5, 4)
    b = torch.randn(5)
    y = PerChannelSymmetricWeightUInt4Tensor.from_float(w)
    # print(y)
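The smoke test above only builds the quantized weight. As a slightly fuller usage sketch (not part of the commit; it assumes a PyTorch build that defines torch.uint4 and a torchao install containing this change), a linear layer's float weight can be converted with from_float and its per-channel scales inspected; the aten.addmm and aten.t overrides registered above are the hooks a biased linear forward is expected to reach the packed weight through, and aten.detach is needed when wrapping the result in a Parameter:

import torch
from torchao.dtypes import PerChannelSymmetricWeightUInt4Tensor

# Float weight for a hypothetical 4 -> 5 linear layer, matching the shapes
# used in the module's own __main__ block above.
w_fp32 = torch.randn(5, 4)

# Quantize to packed uint4 with one symmetric scale per output channel.
w_uint4 = PerChannelSymmetricWeightUInt4Tensor.from_float(w_fp32)

print(type(w_uint4).__name__)   # expected: PerChannelSymmetricWeightUInt4Tensor
print(w_uint4.scales.shape)     # expected: torch.Size([5]), one scale per output row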