Float8 autoquant weight only (pytorch#866)
jainapurva authored and weifengpy committed Sep 26, 2024
1 parent 0b8dd85 commit 87faf04
Showing 5 changed files with 61 additions and 4 deletions.
11 changes: 10 additions & 1 deletion test/integration/test_integration.py
@@ -72,7 +72,7 @@
AQInt8WeightOnlyQuantizedLinearWeight2,
AQInt8WeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,

AQFloat8WeightOnlyQuantizedLinearWeight,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
@@ -98,6 +98,7 @@
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

def _int8wo_api(mod):
if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
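For reference, a minimal sketch (not part of the diff) of exercising the new subclass directly. The import path matches the test above; the shapes and device placement here are illustrative and assume a GPU with float8 support (CUDA arch >= 8.9):

```python
import torch
from torchao.quantization.autoquant import AQFloat8WeightOnlyQuantizedLinearWeight

# Hypothetical linear weight; any 2D float tensor works for weight-only float8.
weight = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
qweight = AQFloat8WeightOnlyQuantizedLinearWeight.from_float(weight)

x = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")
# _quantized_linear_op dequantizes the weight, so this should match
# F.linear(x, weight) up to float8 quantization error.
out = torch.nn.functional.linear(x, qweight.dequantize(), None)
```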
20 changes: 20 additions & 0 deletions test/kernel/test_autotuner.py
@@ -16,6 +16,7 @@

logging.basicConfig(level=logging.INFO)

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

class TestQuantFlow(unittest.TestCase):

@@ -49,6 +50,25 @@ def test_int_mm(self, device, dtype):
assert out32_2.dtype == out32_1.dtype
torch.testing.assert_allclose(out32_1, out32_2)

@parameterized.expand(
[
("cuda", torch.bfloat16),
("cuda", torch.float16),
]
)
@unittest.skipIf(not is_H100, "Needs H100")
def test_int_mm_float8(self, device, dtype):
from torchao.kernel import intmm

m, k, n = (128, 64, 16)
x = torch.randn(m, k, dtype=dtype, device=device)
w = torch.randn(n, k, dtype=dtype, device=device).t()
x_float8 = x.to(dtype=torch.float8_e4m3fn)
w_float8 = w.to(dtype=torch.float8_e4m3fn)
out32_1 = intmm.safe_int_mm(x_float8, w_float8)
assert out32_1.dtype == torch.int32

@parameterized.expand(
[
("cuda", torch.bfloat16),
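Both new tests gate on CUDA device capability before touching float8. A standalone sketch of that check (the helper name `has_float8_support` is ours, not from the diff; SM 8.9 is Ada, while H100 reports (9, 0)):

```python
import torch

def has_float8_support() -> bool:
    # float8 matmul support starts at SM 8.9; H100 reports (9, 0).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
```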
2 changes: 1 addition & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -335,8 +335,8 @@ def from_hp_to_floatx(
input_float: torch.Tensor,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
scale_dtype: Optional[torch.dtype],
layout_type: LayoutType,
scale_dtype: Optional[torch.dtype] = None,
):

if target_dtype in FP8_TYPES:
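The reorder makes `layout_type` the fourth positional parameter and lets float8 callers omit `scale_dtype` entirely. A hedged call-site sketch (the weight tensor and block size are illustrative):

```python
import torch
from torchao.dtypes import AffineQuantizedTensor, Float8LayoutType

weight = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
# A block_size of (1, in_features) yields one scale per output row.
fp8_weight = AffineQuantizedTensor.from_hp_to_floatx(
    weight,
    block_size=(1, weight.shape[1]),
    target_dtype=torch.float8_e4m3fn,  # must be in FP8_TYPES
    layout_type=Float8LayoutType(),
    # scale_dtype now defaults to None and can be omitted
)
```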
7 changes: 6 additions & 1 deletion torchao/kernel/intmm.py
@@ -69,7 +69,12 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
input = (
input.contiguous()
) # (it seems the transpose makes cublas check the above j constraint on i)
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
try:
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
except Exception:
# fallback path, e.g. for float8 dtypes on H100, where aten.mm raises:
# "addmm_cuda" not implemented for 'Float8_e4m3fn'
return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
else:
def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
"""
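The net effect of the try/except: int8 inputs still take the fast `aten.mm` path with an int32 out_dtype, while float8 inputs, which that kernel rejects on H100, fall back to a float32 matmul cast to int32. A hedged sketch of both paths (shapes mirror the new test):

```python
import torch
from torchao.kernel import intmm

m, k, n = 128, 64, 16
x_int8 = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
w_int8 = torch.randint(-128, 127, (k, n), dtype=torch.int8, device="cuda")
out = intmm.safe_int_mm(x_int8, w_int8)   # fast int8 path
assert out.dtype == torch.int32

x_f8 = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
w_f8 = torch.randn(k, n, device="cuda").to(torch.float8_e4m3fn)
out_f8 = intmm.safe_int_mm(x_f8, w_f8)    # hits the except branch above
assert out_f8.dtype == torch.int32
```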
25 changes: 24 additions & 1 deletion torchao/quantization/autoquant.py
@@ -9,7 +9,7 @@
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
@@ -477,6 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
def from_float(cls, weight):
return weight

class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
"""
AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
"""
target_dtype: torch.dtype = torch.float8_e4m3fn

@staticmethod
def _quantized_linear_op(act_mat, w_qtensor, bias):
return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)

@classmethod
def from_float(cls, weight):
block_size = (1, weight.shape[1])
return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())


# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
DEFAULT_AUTOQUANT_CLASS_LIST = [
AQFloatLinearWeight,
@@ -493,6 +509,11 @@ def from_float(cls, weight):
AQInt4G64WeightOnlyQuantizedLinearWeight
]

OTHER_AUTOQUANT_CLASS_LIST = [
AQFloat8WeightOnlyQuantizedLinearWeight,
]


def _change_linears_to_autoquantizable(model, **kwargs):
"""
Converts all linear weight tensors to the
@@ -617,6 +638,8 @@ def autoquant(
if set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST:
assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9"

# perform initial swap from linear weights
# to AutoQuantizableLinearWeight
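Putting it together, a hedged end-to-end sketch of opting into the float8 candidate list via `autoquant` (the model and input are illustrative; per the assert above, this requires CUDA arch >= 8.9):

```python
import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)
model = torchao.autoquant(model, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)

x = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")
model(x)  # forward passes benchmark the candidates and finalize quantization
```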
