Move AffineQuantizedTensor to torchao/dtypes (#272)
Summary:
Move AffineQuantizedTensor out of torchao.quantization.subclass into torchao/dtypes (a new torchao/dtypes/aqt.py, re-exported from torchao.dtypes), and rename its constructor from to_aqt to to_aq. The LinearActQuantizedTensor constructor is likewise renamed from to_laqt to to_laq but stays in torchao.quantization.subclass. Call sites in the tests are updated to match.

Test Plan:
regression tests in test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored May 24, 2024
1 parent bc46bdc commit f8f74c7
Showing 4 changed files with 465 additions and 461 deletions.
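
For downstream code the change is purely an import-path and naming move; a minimal before/after sketch (the comments are editorial, not part of the commit):

# Before this commit:
# from torchao.quantization.subclass import (
#     to_aqt,   # AffineQuantizedTensor constructor
#     to_laqt,  # LinearActQuantizedTensor constructor
#     AffineQuantizedTensor,
#     LinearActQuantizedTensor,
# )

# After this commit:
from torchao.dtypes import to_aq, AffineQuantizedTensor
from torchao.quantization.subclass import to_laq, LinearActQuantizedTensor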
27 changes: 14 additions & 13 deletions test/quantization/test_quant_api.py
@@ -18,17 +18,18 @@
     get_symmetric_quantization_config,
 )
 
-from torchao.quantization.subclass import (
-    to_aqt,
-    to_laqt,
+from torchao.dtypes import (
+    to_aq,
     AffineQuantizedTensor,
-    LinearActQuantizedTensor,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
 )
-
+from torchao.quantization.subclass import (
+    to_laq,
+    LinearActQuantizedTensor,
+)
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     apply_dynamic_quant,
@@ -429,17 +430,17 @@ def get_per_token_block_size(x):
         # input settings
         input_mapping_type = MappingType.ASYMMETRIC
         input_target_dtype = torch.int8
-        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
+        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
 
         def apply_weight_quant(weight):
-            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
 
         def apply_act_quant(weight):
-            return to_laqt(weight, input_quant_func)
+            return to_laq(weight, input_quant_func)
 
         # note: order is important
         m = quantize(m, apply_weight_quant)
@@ -484,7 +485,7 @@ def test_quantized_tensor_subclass_int4(self):
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
 
         def apply_weight_quant(weight):
-            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
 
         m = quantize(m, apply_weight_quant)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
@@ -515,7 +516,7 @@ def test_quantized_tensor_subclass_int8(self):
 
         def apply_weight_quant(weight):
             block_size = (1, weight.shape[1])
-            return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
 
         m = quantize(m, apply_weight_quant)
 
@@ -555,7 +556,7 @@ def get_per_token_block_size(x):
         input_eps = 1e-5
         input_quant_min = -127
         input_quant_max = 127
-        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
 
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
@@ -565,10 +566,10 @@ def get_per_token_block_size(x):
 
         def apply_weight_quant(weight):
             block_size = get_weight_block_size(weight)
-            return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
 
         def apply_act_quant(weight):
-            return to_laqt(weight, input_quant_func)
+            return to_laq(weight, input_quant_func)
 
         m = quantize(m, apply_weight_quant)
         m = quantize(m, apply_act_quant)
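
This last hunk also shows why the test comments say the order is important: to_laq wraps an already-quantized weight, so apply_weight_quant must run before apply_act_quant. A minimal standalone sketch of that flow with the renamed helpers, assuming int8 settings in the spirit of the tests (the eps, zero_point_dtype, and per-token block-size values here are illustrative assumptions, not copied from the test file):

import torch
from torchao.dtypes import to_aq
from torchao.quantization.subclass import to_laq
from torchao.quantization.quant_primitives import MappingType

def per_token_block_size(x):
    # one quantization block per token: all leading dims 1, full last dim
    return (1,) * (x.dim() - 1) + (x.shape[-1],)

def apply_weight_quant(weight):
    # symmetric per-channel int8 weight quantization (assumed settings)
    block_size = (1, weight.shape[1])
    return to_aq(weight, MappingType.SYMMETRIC, block_size, torch.int8,
                 eps=torch.finfo(torch.float32).eps,
                 zero_point_dtype=torch.int64)

def apply_act_quant(weight):
    # wrap the already-quantized weight so activations are quantized
    # per token at runtime (asymmetric int8)
    input_quant_func = lambda x: to_aq(
        x, MappingType.ASYMMETRIC, per_token_block_size(x), torch.int8)
    return to_laq(weight, input_quant_func)

linear = torch.nn.Linear(1024, 1024, bias=False)
w = apply_act_quant(apply_weight_quant(linear.weight))
linear.weight = torch.nn.Parameter(w, requires_grad=False)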
3 changes: 3 additions & 0 deletions torchao/dtypes/__init__.py
@@ -1,8 +1,11 @@
 from .nf4tensor import NF4Tensor, to_nf4
 from .uint4 import UInt4Tensor
+from .aqt import AffineQuantizedTensor, to_aq
 
 __all__ = [
     "NF4Tensor",
     "to_nf4",
     "UInt4Tensor"
+    "AffineQuantizedTensor",
+    "to_aq",
 ]
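
One side effect worth noting in this hunk: the pre-existing "UInt4Tensor" entry has no trailing comma, so adding "AffineQuantizedTensor" directly below it triggers Python's implicit string-literal concatenation, and __all__ ends up with a single fused name instead of two. A quick demonstration:

# Adjacent string literals concatenate, so this list has two entries, not three:
entries = ["UInt4Tensor" "AffineQuantizedTensor", "to_aq"]
print(entries)  # ['UInt4TensorAffineQuantizedTensor', 'to_aq']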
[Diffs for the remaining two changed files, the new torchao/dtypes/aqt.py and the corresponding removals from torchao/quantization/subclass.py, were not rendered on this page.]
