Commit

Test cases
jainapurva committed Sep 23, 2024
1 parent 0ba6a2c commit 988af92
Showing 2 changed files with 17 additions and 2 deletions.
11 changes: 10 additions & 1 deletion test/integration/test_integration.py
@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -98,6 +98,7 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
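
Note: a minimal sketch of what the new test exercises, assuming a CUDA device with compute capability >= 8.9 (the is_H100 gate above). The layer shapes and the final check are illustrative; the 30 passed by the test reads as a minimum-SQNR threshold.

    # Sketch only: mirrors what test_aq_float8_weight_only_quant_subclass
    # drives through _test_lin_weight_subclass_impl.
    import torch
    from torchao.quantization.autoquant import AQFloat8WeightOnlyQuantizedLinearWeight

    lin = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
    ref = lin(x)

    # Swap the plain weight for the float8 weight-only tensor subclass.
    lin.weight = torch.nn.Parameter(
        AQFloat8WeightOnlyQuantizedLinearWeight.from_float(lin.weight),
        requires_grad=False,
    )
    out = lin(x)  # dispatches through the subclass's float8 path

    # The real test checks output quality against the reference (min SQNR 30);
    # here we just print the worst-case error as a rough sanity check.
    print((ref - out).abs().max())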
8 changes: 7 additions & 1 deletion torchao/quantization/autoquant.py
@@ -501,7 +501,6 @@ def from_float(cls, weight):
     # AQInt8WeightOnlyQuantizedLinearWeight3,
     # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQFloat8WeightOnlyQuantizedLinearWeight,
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
@@ -510,6 +509,11 @@ def from_float(cls, weight):
     AQInt4G64WeightOnlyQuantizedLinearWeight
 ]
 
+OTHER_AUTOQUANT_CLASS_LIST = [
+    AQFloat8WeightOnlyQuantizedLinearWeight,
+]
+
+
 def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
@@ -634,6 +638,8 @@ def autoquant(
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
+    if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST:
+        assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9"
 
     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
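
Note: the net effect is that float8 weight-only autoquant becomes opt-in rather than part of DEFAULT_AUTOQUANT_CLASS_LIST. A minimal sketch of selecting it, assuming the public torchao.autoquant entry point forwards qtensor_class_list as the hunk above suggests, and an sm89+ GPU (otherwise the new assert fires):

    # Sketch only: opt in to the float8 weight-only class list.
    import torch
    import torchao
    from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 1024),
    ).cuda().to(torch.bfloat16)

    # Restrict autoquant's search space to the float8 subclass.
    model = torchao.autoquant(model, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)

    x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    model(x)  # calibration forward pass; the best kernel per linear is picked here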
