Float8 autoquant weight only #866

Merged · 6 commits · Sep 24, 2024
11 changes: 10 additions & 1 deletion test/integration/test_integration.py
@@ -72,7 +72,7 @@
AQInt8WeightOnlyQuantizedLinearWeight2,
AQInt8WeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,

AQFloat8WeightOnlyQuantizedLinearWeight,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
@@ -98,6 +98,7 @@
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

def _int8wo_api(mod):
if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
20 changes: 20 additions & 0 deletions test/kernel/test_autotuner.py
@@ -16,6 +16,7 @@

logging.basicConfig(level=logging.INFO)

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

class TestQuantFlow(unittest.TestCase):

@@ -49,6 +50,25 @@ def test_int_mm(self, device, dtype):
assert out32_2.dtype == out32_1.dtype
torch.testing.assert_allclose(out32_1, out32_2)

@parameterized.expand(
[
("cuda", torch.bfloat16),
("cuda", torch.float16),
]
)
@unittest.skipIf(not is_H100, "Needs H100")
def test_int_mm_float8(self, device, dtype):
from torchao.kernel import intmm

dtype = torch.bfloat16
m, k, n = (128, 64, 16)
x = torch.randn(m, k, dtype=dtype, device=device)
w = torch.randn(n, k, dtype=dtype, device=device).t()
x_float8 = x.to(dtype=torch.float8_e4m3fn)
w_float8 = w.to(dtype=torch.float8_e4m3fn)
out32_1 = intmm.safe_int_mm(x_float8, w_float8)
assert out32_1.dtype == torch.int32

@parameterized.expand(
[
("cuda", torch.bfloat16),
2 changes: 1 addition & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -335,8 +335,8 @@ def from_hp_to_floatx(
input_float: torch.Tensor,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
scale_dtype: Optional[torch.dtype],
layout_type: LayoutType,
scale_dtype: Optional[torch.dtype] = None,
):

if target_dtype in FP8_TYPES:
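With this reorder, scale_dtype becomes a trailing keyword argument defaulting to None, so float8 callers can omit it. A minimal sketch of the updated call, mirroring how AQFloat8WeightOnlyQuantizedLinearWeight.from_float uses it later in this diff (the weight tensor below is a placeholder, not part of the PR):

```python
import torch
from torchao.dtypes import AffineQuantizedTensor, Float8LayoutType

weight = torch.randn(256, 128, dtype=torch.bfloat16)  # placeholder high-precision weight
fp8_weight = AffineQuantizedTensor.from_hp_to_floatx(
    weight,
    block_size=(1, weight.shape[1]),   # per-row granularity
    target_dtype=torch.float8_e4m3fn,  # one of FP8_TYPES
    layout_type=Float8LayoutType(),
    # scale_dtype is omitted; it now defaults to None
)
```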
7 changes: 6 additions & 1 deletion torchao/kernel/intmm.py
@@ -69,7 +69,12 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
input = (
input.contiguous()
) # (it seems the transpose makes cublas check the above j constraint on i)
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
try:
Contributor (review comment):
Maybe adding a comment here would be helpful, explaining how these two branches are handled?

@jainapurva (Contributor, author) replied on Sep 24, 2024:
The except branch is executed if the inputs are a float8 dtype on H100, as there is no addmm_cuda implementation for float8 dtypes. Added as a comment.

return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
except Exception:
# fallback path, would run on H100 for float8 dtypes
# Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
else:
def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
"""
25 changes: 24 additions & 1 deletion torchao/quantization/autoquant.py
@@ -9,7 +9,7 @@
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
@@ -477,6 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
def from_float(cls, weight):
return weight

class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
"""
AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
"""
target_dtype: torch.dtype = torch.float8_e4m3fn

@staticmethod
def _quantized_linear_op(act_mat, w_qtensor, bias):
return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)

@classmethod
def from_float(cls, weight):
block_size = (1, weight.shape[1])
return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())


# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
DEFAULT_AUTOQUANT_CLASS_LIST = [
AQFloatLinearWeight,
@@ -493,6 +509,11 @@ def from_float(cls, weight):
AQInt4G64WeightOnlyQuantizedLinearWeight
]

OTHER_AUTOQUANT_CLASS_LIST = [
AQFloat8WeightOnlyQuantizedLinearWeight,
]


def _change_linears_to_autoquantizable(model, **kwargs):
"""
Converts all linear weight tensors to the
@@ -617,6 +638,8 @@ def autoquant(
if set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST:
assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9"

# perform initial swap from linear weights
# to AutoQuantizableLinearWeight
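Taken together, the new subclass and OTHER_AUTOQUANT_CLASS_LIST let callers opt into float8 weight-only autoquant explicitly. A minimal usage sketch, assuming a CUDA device with compute capability >= 8.9; the model and input below are placeholders, not part of this PR:

```python
import torch
from torchao.quantization.autoquant import autoquant, OTHER_AUTOQUANT_CLASS_LIST

# placeholder model and input; any nn.Linear-based module would do
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
example_input = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# Passing the float8 class list restricts autoquant to
# AQFloat8WeightOnlyQuantizedLinearWeight; this PR also adds a guard in
# autoquant() asserting CUDA arch >= 8.9 for the float8 list.
model = autoquant(model, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
model(example_input)  # running the model drives autoquant's shape recording and benchmarking
```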