From cf99950eaeff2b767fb02b49e229f8658da1d5e4 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 28 Aug 2024 20:13:21 -0700 Subject: [PATCH 1/2] mixin --- scripts/hf_eval.py | 2 +- test/dtypes/test_affine_quantized_float.py | 141 +++++++++++++ torchao/dtypes/__init__.py | 4 + torchao/dtypes/affine_quantized_tensor.py | 218 +++++++++++++++++++-- torchao/dtypes/utils.py | 14 +- torchao/float8/__init__.py | 9 +- torchao/float8/float8_python_api.py | 9 +- torchao/float8/inference.py | 25 ++- torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 69 ++++++- torchao/quantization/quant_primitives.py | 86 ++++---- torchao/utils.py | 9 + 12 files changed, 525 insertions(+), 63 deletions(-) create mode 100644 test/dtypes/test_affine_quantized_float.py diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py index db7a6a9b7..5f008ee43 100644 --- a/scripts/hf_eval.py +++ b/scripts/hf_eval.py @@ -114,7 +114,7 @@ def all_linear(mod, name): parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "fp6", "None"], help='Which quantization technique to apply') + parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply') parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--save', action='store_true', help='Whether to save the model.') diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py new file mode 100644 index 000000000..8e700116b --- /dev/null +++ b/test/dtypes/test_affine_quantized_float.py @@ -0,0 +1,141 @@ +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + unwrap_tensor_subclass, +) +import pytest + +if not TORCH_VERSION_AT_LEAST_2_5: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +from numpy import full +from torch.testing._internal.common_utils import ( + run_tests, +) +from torch._inductor.test_case import TestCase as InductorTestCase +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) +from torch._dynamo.testing import CompileCounterWithBackend + +from torchao.quantization import ( + quantize_, + float8_weight_only, + float8_dynamic_activation_float8_weight, +) +from torchao.float8.float8_utils import compute_error +import torch +import unittest +import pytest +import tempfile +import copy +import random + +from unittest.mock import patch + + +random.seed(0) +torch.manual_seed(0) + +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) +is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear1 = torch.nn.Linear(in_features, out_features, bias=False) + self.linear2 = torch.nn.Linear(out_features, in_features, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TestAffineQuantizedFloat8Basic(TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tensor_core_layout_transpose(self): + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + t = l.weight + shape = t.shape + apply_float8_weight_only_quant = float8_weight_only() + ql = apply_float8_weight_only_quant(l) + aqt = ql.weight + aqt_shape = aqt.shape + assert aqt_shape == shape + + # transpose shape test + for _ in range(10): + t = t.t() + aqt = aqt.t() + shape = t.shape + aqt_shape = aqt.shape + assert aqt_shape == shape + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_weights_only_save_load(self): + with torch.no_grad(): + for apply_quant in [float8_weight_only()]: + # TODO Fails when l requires grad + l = torch.nn.Linear(128, 256).eval().to(torch.bfloat16).to("cuda") + ql = apply_quant(l) + with tempfile.NamedTemporaryFile() as f: + torch.save(ql.state_dict(), f) + f.seek(0) + # `weights_only=True` is enabled for torch 2.5+ + if TORCH_VERSION_AT_LEAST_2_5: + _ = torch.load(f, weights_only=True) + else: + _ = torch.load(f, weights_only=False) + + +class TestAffineQuantizedFloat8Compile(InductorTestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_cuda_8_9, "Need H100") + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize("mode", ["dynamic", "weight-only"]) + @common_utils.parametrize("compile", [True, False]) + # Inputs are (M,..), K, N + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((256,), 512, 256), + ((64,), 128, 64), + ((32, 128), 64, 256), + ((64, 256), 512, 128), + ], + ) + def test_dynamic_fp8_linear( + self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple + ): + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + + mode_map = { + "dynamic": float8_dynamic_activation_float8_weight, + "weight-only": float8_weight_only, + } + + # Create a linear layer with bfloat16 dtype + model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + + quantized_model = copy.deepcopy(model) + factory = mode_map[mode]() + quantize_(model, factory) + + if compile: + quantized_model = torch.compile(quantized_model, fullgraph=True) + + output_original = model(input_tensor) + output_quantized = quantized_model(input_tensor) + + assert compute_error(output_original, output_quantized) > 20, "Error is too low" + + +common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 36c8b342e..f9d8d1e4a 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -12,6 +12,8 @@ PlainLayoutType, SemiSparseLayoutType, TensorCoreTiledLayoutType, + Float8LayoutType, + Float8AQTLayout, ) __all__ = [ @@ -27,4 +29,6 @@ "PlainLayoutType", "SemiSparseLayoutType", "TensorCoreTiledLayoutType", + "Float8LayoutType", + "Float8AQTLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 3f9c2bbb9..2b91c7804 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1,5 +1,5 @@ import torch -from typing import Dict, Callable, Any, Tuple, Optional +from typing import Dict, Callable, Any, Tuple, Optional, Union from collections import defaultdict import functools import math @@ -29,6 +29,7 @@ LayoutType, PlainLayoutType, is_device, + get_out_shape, ) from torch.utils._python_dispatch import is_traceable_wrapper_subclass from dataclasses import dataclass @@ -36,8 +37,10 @@ find_multiple, TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_5, + _is_float8_type ) +from torchao.float8.float8_tensor import ScaledMMConfig aten = torch.ops.aten ############################### @@ -50,7 +53,7 @@ class AQTLayout(TorchAOBaseTensor): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get the plain (unpacked) Tensor for the layout Tensor - Returns int_data, scale and zero_point + Returns data, scale and zero_point Can be overwritten if other types of AQTLayout Tensor has different numbers of plain tensors """ pass @@ -61,17 +64,18 @@ def get_layout_type(self) -> LayoutType: @classmethod def from_plain( cls, - int_data: torch.Tensor, + data: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, layout_type: LayoutType, ): + """ Construct a Layout from data, scale, zero_point and the layout_type""" pass def __repr__(self): - int_data, scale, zero_point = self.get_plain() + data, scale, zero_point = self.get_plain() layout_type = self.get_layout_type() - return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})" + return f"{self.__class__.__name__}(data={data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})" ############################## @@ -121,8 +125,8 @@ def __new__( layout_tensor: AQTLayout, block_size: Tuple[int, ...], shape: torch.Size, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, dtype=None, strides=None, @@ -143,8 +147,8 @@ def __init__( layout_tensor: AQTLayout, block_size: Tuple[int, ...], shape: torch.Size, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, dtype=None, strides=None, @@ -161,7 +165,7 @@ def __repr__(self): f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) - def dequantize(self, output_dtype=None): + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: output_dtype = self.dtype @@ -170,8 +174,18 @@ def dequantize(self, output_dtype=None): int_data, scale = self.layout_tensor.get_plain() return dequantize_affine_fpx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) else: - int_data, scale, zero_point = self.layout_tensor.get_plain() - return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) + data, scale, zero_point = self.layout_tensor.get_plain() + return dequantize_affine( + data, + self.block_size, + scale, + zero_point, + data.dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain, + output_dtype=output_dtype, + ) @staticmethod def _quantized_linear_op(input_tensor, weight_tensor, bias): @@ -284,9 +298,11 @@ def from_hp_to_floatx( cls, input_float: torch.Tensor, block_size: Tuple[int, ...], - target_dtype: torch.dtype = torch.float8_e4m3fn, + target_dtype: torch.dtype, + scale_dtype: Optional[torch.dtype] = None, layout_type: LayoutType = PlainLayoutType(), ): + if target_dtype in FP8_TYPES: return cls.from_hp_to_intx( input_float=input_float, @@ -296,11 +312,11 @@ def from_hp_to_floatx( quant_min=math.ceil(torch.finfo(target_dtype).min), quant_max=math.ceil(torch.finfo(target_dtype).max), eps=torch.finfo(torch.float32).eps, - scale_dtype=None, + scale_dtype=scale_dtype, zero_point_dtype=None, preserve_zero=True, - zero_point_domain=ZeroPointDomain.INT, - layout_type=PlainLayoutType(), + zero_point_domain=None, + layout_type=layout_type, use_hqq=False, ) else: @@ -419,6 +435,13 @@ def extra_repr(self): return f"inner_k_tiles={self.inner_k_tiles}" +@dataclass(frozen=True) +class Float8LayoutType(LayoutType): + mm_config: ScaledMMConfig + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + return input + @register_layout_cls(PlainLayoutType) class PlainAQTLayout(AQTLayout): """ @@ -572,6 +595,110 @@ def from_plain( return cls(int_data_compressed, scale, zero_point, layout_type) +@register_layout_cls(Float8LayoutType) +class Float8AQTLayout(AQTLayout): + """ + Layout storage class for float8 layout for affine quantized tensor + """ + float8_data: torch.Tensor + scale: torch.Tensor + transposed: bool + + def __new__( + cls, + float8_data: torch.Tensor, + scale: torch.Tensor, + transposed: bool, + layout_type: LayoutType, + ): + kwargs = {} + kwargs["device"] = float8_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else float8_data.layout + ) + kwargs["dtype"] = float8_data.dtype + kwargs["requires_grad"] = False + shape = float8_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + float8_data: torch.Tensor, + scale: torch.Tensor, + transposed: bool, + layout_type: LayoutType, + ): + self.float8_data = float8_data + self.scale = scale + self.transposed = transposed + self.layout_type = layout_type + + def _apply_fn_to_data(self, fn): + """ Applys a fn to all tensor components stored on this class""" + fn(self.float8_data) + fn(self.scale) + return self + + def __tensor_flatten__(self): + return ["float8_data", "scale"], [self.transposed, self.layout_type] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] + transposed, layout_type, = tensor_attributes + return cls(float8_data, scale, transposed, layout_type) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + args[0].transposed = not args[0].transposed + return return_and_correct_aliasing(func, args, kwargs, args[0]) + + raise NotImplementedError( + f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return self.float8_data, self.scale, None + + def get_layout_type(self) -> LayoutType: + return self.layout_type + + @classmethod + def from_plain( + cls, + data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + layout_type: LayoutType, + ): + """ Main entrypoint for constructing Float8Layout Tensor""" + assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" + return cls(data, scale, False, layout_type) + + def __repr__(self): + float8_data, scale, _ = self.get_plain() + layout_type = self.get_layout_type() + return (f"{self.__class__.__name__}(\n" + f"float8_data={float8_data},\n" + f"scale={scale},\n" + f"transposed={self.transposed}, " + f"layout_type={layout_type})") + + @register_layout_cls(TensorCoreTiledLayoutType) class TensorCoreTiledAQTLayout(AQTLayout): """ @@ -724,6 +851,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def get_layout_type(self) -> LayoutType: return self.layout_type + ##################################################### # torch functional and aten operator implementation # ##################################################### @@ -987,10 +1115,68 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): return out.view(*act.shape[:-1], out_dim).to(act.dtype) +def _linear_fp_act_fp8_tensor_wise_weight_check( + input_tensor: torch.Tensor, + weight_tensor: AffineQuantizedTensor, + bias: Optional[torch.Tensor], +) -> bool: + def check_aqt_tensorwise(aqt: AffineQuantizedTensor) -> bool: + return ( + isinstance(aqt, AffineQuantizedTensor) and + isinstance(aqt.layout_tensor, Float8AQTLayout) + and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and aqt.shape == aqt.block_size + ) + return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor) + + +def _linear_fp_act_fp8_weight_impl( + input_tensor: AffineQuantizedTensor, + weight_tensor: AffineQuantizedTensor, + bias: Optional[torch.Tensor], +): + """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" + from torchao.float8.inference import cast_to_float8_e4m3_inference, preprocess_data + from torchao.float8.float8_tensor import ScaledMMConfig + from torchao.float8.float8_python_api import addmm_float8_unwrapped + + scaled_mm_config = weight_tensor.layout_type.mm_config + scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else ScaledMMConfig() + + w_layout = weight_tensor.layout_tensor + w_data = weight_tensor.layout_tensor.float8_data + w_data = w_data.T if w_layout.transposed else w_data + w_scale = w_layout.scale + w_scale = w_scale if w_layout.transposed else w_scale + + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + + inpt_data = input_tensor.layout_tensor.float8_data + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1]) + input_scale = input_tensor.layout_tensor.scale + if input_scale.dim() >= 2: + input_scale = input_scale.reshape(-1, input_scale.shape[-1]) + + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + return addmm_float8_unwrapped( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=input_tensor.dtype, + bias=bias, + use_fast_accum=scaled_mm_config.use_fast_accum, + inverse_scale=False + ).reshape(out_shape) + + def _register_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), + (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index d906251f8..036a5ca92 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,5 +1,5 @@ import torch -from typing import Dict, Callable, Union +from typing import Dict, Callable, Union, Tuple from collections import defaultdict import functools from dataclasses import dataclass @@ -143,3 +143,15 @@ def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(Layout def is_device(target_device_str: str, device: Union[str, torch.device]): return torch.device(device).type == target_device_str + +def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[int, int]: + """Returns the unflattened shape of the input tensor. + Args: + input_shape: The input tensor shape possibly more than 2 dimensions + weight_shape: The weight tensor shape. + Returns: + The unflattened shape of the input tensor. + """ + out_dim = weight_shape[0] + inpt_dims = input_shape[:-1] + return (*inpt_dims, out_dim) diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 56c7b28f7..43065d2b8 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -25,10 +25,13 @@ ) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp -# Needed to load Float8Tensor with weights_only = True -from torch.serialization import add_safe_globals +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig]) + +if TORCH_VERSION_AT_LEAST_2_5: + # Needed to load Float8Tensor with weights_only = True + from torch.serialization import add_safe_globals + add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig]) __all__ = [ # configuration diff --git a/torchao/float8/float8_python_api.py b/torchao/float8/float8_python_api.py index 16e270574..ade00a0a6 100644 --- a/torchao/float8/float8_python_api.py +++ b/torchao/float8/float8_python_api.py @@ -30,14 +30,19 @@ def addmm_float8_unwrapped( output_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, use_fast_accum: bool = False, + inverse_scale: bool = True ) -> torch.Tensor: """ This is the unwrapped version of addmm_float8, which does not take in Float8Tensors as inputs. This is used to standardize the logic between subclassed and non subclassed versions of the linear module. """ - a_inverse_scale = a_scale.reciprocal() - b_inverse_scale = b_scale.reciprocal() + if inverse_scale: + a_inverse_scale = a_scale.reciprocal() + b_inverse_scale = b_scale.reciprocal() + else: + a_inverse_scale = a_scale + b_inverse_scale = b_scale if output_dtype == torch.float32 and bias is not None: # Bias is not supported by _scaled_mm when output is fp32 output = torch._scaled_mm( diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index ccf83d7ce..66f83d933 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from enum import auto, Enum -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn @@ -242,3 +242,26 @@ def quantize_to_float8( lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum), module_filter_fn=module_filter_fn, ) + +from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul + +def preprocess_data(a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig) -> Tuple[torch.Tensor, torch.Tensor]: + """ Preprocess the inner fp8 data tensors for admmm + Args: + a_data: Input tensor A. + b_data: Input tensor B. + scaled_mm_config: Configuration for _scaled_mm. + Returns: + Preprocessed tensors A and B in the format for _scaled_mm. + """ + if scaled_mm_config.pad_inner_dim: + assert a_data.size(1) == b_data.size( + 0 + ), f"Inner dims must match for mm, got {a_data.size(1)} and {b_data.size(0)}" + a_data = pad_tensor_for_matmul(a_data, dims=1) + b_data = pad_tensor_for_matmul(b_data, dims=0) + if not is_row_major(a_data.stride()): + a_data = a_data.contiguous() + if is_row_major(b_data.stride()): + b_data = b_data.t().contiguous().t() + return a_data, b_data diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 75e762ce3..227214843 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -43,4 +43,6 @@ "fpx_weight_only", "LinearActivationQuantizedTensor", "to_linear_activation_quantized", + "float8_weight_only", + "float8_dynamic_activation_float8_weight" ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3ff6b4fe8..54183f670 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -28,7 +28,9 @@ PlainLayoutType, AffineQuantizedTensor, SemiSparseLayoutType, - to_affine_quantized_floatx + to_affine_quantized_floatx, + Float8AQTLayout, + Float8LayoutType ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, @@ -57,6 +59,7 @@ from .utils import _get_per_token_block_size import logging from .autoquant import autoquant, AutoQuantizableLinearWeight +from torchao.float8.float8_tensor import ScaledMMConfig __all__ = [ "swap_conv2d_1x1_to_linear", @@ -75,6 +78,7 @@ "float8_weight_only", "uintx_weight_only", "fpx_weight_only", + "float8_dynamic_activation_float8_weight", ] from .GPTQ import ( @@ -158,7 +162,6 @@ def change_linear_weights_to_int4_woqtensors(model, groupsize=128, inner_k_tiles ### TO BE DEPRECATED END - def _replace_with_custom_fn_if_matches_filter( model, replacement_fn, @@ -491,19 +494,77 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): """ return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()) + def float8_weight_only(target_dtype: torch.dtype = torch.float8_e4m3fn): """ Applies float8 weight-only symmetric per-channel quantization to linear layers. + + Args: + target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. + + Note: + The actual matmul will be computed in original precision of the weight tensor. + """ from torchao.dtypes import to_affine_quantized_floatx + def apply_float8wo_quant(weight): - # avoid circular dep block_size = (1, weight.shape[1]) - return to_affine_quantized_floatx(input_float=weight, block_size=block_size, target_dtype=target_dtype) + return to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=target_dtype, + layout_type=Float8LayoutType(mm_config=None), + ) return _get_linear_subclass_inserter(apply_float8wo_quant) +def float8_dynamic_activation_float8_weight( + target_dtype: torch.dtype = torch.float8_e4m3fn, + activation_dtype: torch.dtype = torch.float8_e4m3fn, + mm_config: ScaledMMConfig = ScaledMMConfig(use_fast_accum=True) +): + """ + Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers. + + Args: + target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. + activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. + mm_config (ScaledMMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + + """ + + from torchao.dtypes import to_affine_quantized_floatx + + #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling + def apply_float8_dynamic_activation_quant(weight: torch.Tensor): + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=weight.shape, + target_dtype=target_dtype, + scale_dtype=torch.float32, + layout_type=Float8LayoutType(mm_config=None), + ) + + def input_quant_func(x: torch.Tensor): + activation = to_affine_quantized_floatx( + input_float=x, + block_size=x.shape, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + layout_type=Float8LayoutType(mm_config=None), + ) + return activation + + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func + ) + return quantized_weight + + return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant) + + def uintx_weight_only(dtype, group_size=64, pack_dim=-1): """ Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 2f1de3314..72ec988ca 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -14,7 +14,7 @@ TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, ) -from torchao.utils import _register_custom_op +from torchao.utils import _register_custom_op, _is_float8_type from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones @@ -172,8 +172,8 @@ def quantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], output_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: """ @@ -217,7 +217,7 @@ def quantize_affine( output_dtype, quant_min, quant_max, - zero_point_domain.name, + zero_point_domain.name if zero_point_domain is not None else None, ) @@ -228,9 +228,9 @@ def _quantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], output_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: str = ZeroPointDomain.INT.name, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, ) -> torch.Tensor: """op definition that has compatible signatures with custom op library """ @@ -255,9 +255,9 @@ def _quantize_affine_no_dtype_cast( block_size: List[int], scale: torch.Tensor, zero_point: Optional[torch.Tensor], - quant_min: int, - quant_max: int, - zero_point_domain: str = ZeroPointDomain.INT.name, + quant_min: Union[int, float], + quant_max: Union[int, float], + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, ) -> torch.Tensor: # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -273,6 +273,11 @@ def _quantize_affine_no_dtype_cast( if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) + if zero_point_domain is None: + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + quant = quant.view(original_shape) + return quant + if zero_point_domain == ZeroPointDomain.INT.name: quant = torch.clamp( torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max @@ -297,8 +302,8 @@ def dequantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], input_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, *, output_dtype: torch.dtype = torch.float32, @@ -332,7 +337,7 @@ def dequantize_affine( input_dtype, quant_min, quant_max, - zero_point_domain.name, + zero_point_domain.name if zero_point_domain is not None else None, output_dtype=output_dtype, ) @@ -344,9 +349,9 @@ def _dequantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], input_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: str = ZeroPointDomain.INT.name, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """op definition that has compatible signatures with custom op library @@ -373,11 +378,12 @@ def _dequantize_affine_no_dtype_check( block_size: List[int], scale: torch.Tensor, zero_point: Optional[torch.Tensor], - quant_min: int, - quant_max: int, - zero_point_domain: str = ZeroPointDomain.INT.name, + quant_min: Union[int, float], + quant_max: Union[int, float], + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: + """ This function converts AQT tensors to their high precision floating point representation""" assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}" shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) original_shape = input.shape @@ -385,7 +391,16 @@ def _dequantize_affine_no_dtype_check( shape_after_reduction = shape_for_reduction for i in reduction_dims: shape_after_reduction[i] = 1 - scale = scale.view(shape_after_reduction) + scale = scale.view(shape_after_reduction) + + # This case handles dequantization for float8 + if zero_point_domain is None: + assert zero_point is None, "zero_point should be None when zero_point_domain is None" + assert _is_float8_type(input.dtype), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" + dequant = input.to(output_dtype) + dequant = dequant * scale + return dequant.view(original_shape).to(output_dtype) + if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) @@ -417,8 +432,8 @@ def fake_quantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], quant_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: """ @@ -461,8 +476,8 @@ def fake_quantize_affine_cachemask( scale: torch.Tensor, zero_point: Optional[torch.Tensor], quant_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -504,8 +519,8 @@ def _do_fake_quantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], quant_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -541,8 +556,8 @@ def choose_qparams_affine( mapping_type: MappingType, block_size: Tuple[int, ...], target_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, eps: Optional[float] = None, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, @@ -592,7 +607,7 @@ def choose_qparams_affine( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name + zero_point_domain.name if zero_point_domain is not None else None, ) @@ -643,13 +658,13 @@ def _choose_qparams_affine( mapping_type: str, block_size: List[int], target_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, eps: Optional[float] = None, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: str = "INT", + zero_point_domain: Optional[str] = "INT", min_val: Optional[torch.Tensor] = None, max_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -695,7 +710,7 @@ def _choose_qparams_affine( scale = max_val_pos / (float(quant_max - quant_min) / 2) if not preserve_zero: raise ValueError("preserve_zero == False is not supported for symmetric quantization") - if zero_point_domain != ZeroPointDomain.INT.name: + if zero_point_domain is not None and zero_point_domain != ZeroPointDomain.INT.name: raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization") scale = torch.clamp(scale, min=eps) zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) @@ -713,7 +728,8 @@ def _choose_qparams_affine( return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype) -#HQQ + +# HQQ ############################################################################ # Shrinking operator (proximal operator for the lp norm) def _shrink_lp_op(x: torch.Tensor, beta: float, lp_norm: float) -> torch.Tensor: @@ -799,7 +815,7 @@ def _convert_to_affinequantized_format(W_q: torch.Tensor, scale: torch.Tensor, z W_q_ao = W_q.view(shape) return W_q_ao, scale_ao, zero_ao -#Main hqq quantizer function +# Main hqq quantizer function def quantize_affine_hqq( tensor: torch.Tensor, nbits: float = 4, diff --git a/torchao/utils.py b/torchao/utils.py index 329d4790f..e55dc3dc0 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -309,6 +309,15 @@ def _get_to_kwargs(self, *args, **kwargs): return kwargs +def _is_float8_type(dtype: torch.dtype) -> bool: + fp8_types = { + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + } + return dtype in fp8_types + def parse_version(version_string): # Extract just the X.Y.Z part from the version string From c771c640c6148eb938e60d882c6e09ce3f9e3988 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 28 Aug 2024 20:59:44 -0700 Subject: [PATCH 2/2] fix memory being held by autograd --- test/dtypes/test_affine_quantized.py | 71 ++++++++++++++-------- test/dtypes/test_affine_quantized_float.py | 50 ++------------- torchao/dtypes/affine_quantized_tensor.py | 45 +++++++++----- torchao/float8/inference.py | 8 ++- torchao/quantization/quant_api.py | 23 ++++--- torchao/quantization/quant_primitives.py | 38 ++++++------ 6 files changed, 122 insertions(+), 113 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index e01946888..a4f501098 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -10,12 +10,33 @@ int8_dynamic_activation_int8_semi_sparse_weight, float8_weight_only, ) +from torch.testing._internal import common_utils from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 import torch import unittest import tempfile +is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) + + +def get_quantization_functions(do_sparse: bool, do_int4: bool): + base_functions = [ + int8_weight_only(), + int8_dynamic_activation_int4_weight(), + int8_dynamic_activation_int8_weight(), + ] + if do_int4: + base_functions.append(int4_weight_only(group_size=32)) + + if do_sparse: + base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight()) + + if is_cuda_8_9: + base_functions.append(float8_weight_only()) + + return base_functions + class TestAffineQuantized(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @@ -38,36 +59,36 @@ def test_tensor_core_layout_transpose(self): self.assertEqual(aqt_shape, shape) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_weights_only(self): - for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(), - int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]: - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(l) - with tempfile.NamedTemporaryFile() as f: - torch.save(ql.state_dict(), f) - f.seek(0) - # `weights_only=True` is enabled for torch 2.5+ - if TORCH_VERSION_AT_LEAST_2_5: - _ = torch.load(f, weights_only=True) - else: - _ = torch.load(f, weights_only=False) + @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) + def test_weights_only(self, apply_quant): + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + ql = apply_quant(l) + with tempfile.NamedTemporaryFile() as f: + torch.save(ql.state_dict(), f) + f.seek(0) + # `weights_only=True` is enabled for torch 2.5+ + if TORCH_VERSION_AT_LEAST_2_5: + _ = torch.load(f, weights_only=True) + else: + _ = torch.load(f, weights_only=False) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_to_device(self): - from torchao.quantization import quantize_ - for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]: - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) - ql.to("cuda") + @common_utils.parametrize("apply_quant", get_quantization_functions(False, False)) + def test_to_device(self, apply_quant): + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(l) + ql.to("cuda") + + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(l) + ql.to(device="cuda") - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) - ql.to(device="cuda") + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(l) + ql.cuda() - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) - ql.cuda() +common_utils.instantiate_parametrized_tests(TestAffineQuantized) if __name__ == "__main__": run_tests() diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 8e700116b..7e2ce278d 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -13,10 +13,6 @@ ) from torch._inductor.test_case import TestCase as InductorTestCase from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import ( - TestCase, - run_tests, -) from torch._dynamo.testing import CompileCounterWithBackend from torchao.quantization import ( @@ -54,46 +50,9 @@ def forward(self, x): return x -class TestAffineQuantizedFloat8Basic(TestCase): - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_tensor_core_layout_transpose(self): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - t = l.weight - shape = t.shape - apply_float8_weight_only_quant = float8_weight_only() - ql = apply_float8_weight_only_quant(l) - aqt = ql.weight - aqt_shape = aqt.shape - assert aqt_shape == shape - - # transpose shape test - for _ in range(10): - t = t.t() - aqt = aqt.t() - shape = t.shape - aqt_shape = aqt.shape - assert aqt_shape == shape - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_weights_only_save_load(self): - with torch.no_grad(): - for apply_quant in [float8_weight_only()]: - # TODO Fails when l requires grad - l = torch.nn.Linear(128, 256).eval().to(torch.bfloat16).to("cuda") - ql = apply_quant(l) - with tempfile.NamedTemporaryFile() as f: - torch.save(ql.state_dict(), f) - f.seek(0) - # `weights_only=True` is enabled for torch 2.5+ - if TORCH_VERSION_AT_LEAST_2_5: - _ = torch.load(f, weights_only=True) - else: - _ = torch.load(f, weights_only=False) - - class TestAffineQuantizedFloat8Compile(InductorTestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Need H100") + @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only"]) @common_utils.parametrize("compile", [True, False]) @@ -108,7 +67,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase): ((64, 256), 512, 128), ], ) - def test_dynamic_fp8_linear( + def test_fp8_linear_variants( self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple ): M, N, K = sizes @@ -132,7 +91,10 @@ def test_dynamic_fp8_linear( output_original = model(input_tensor) output_quantized = quantized_model(input_tensor) - assert compute_error(output_original, output_quantized) > 20, "Error is too low" + error = compute_error(output_original, output_quantized) + assert ( + compute_error(output_original, output_quantized) > 20 + ), f"Quantization error is too high got a SQNR of {error}" common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 2b91c7804..06bb8aeff 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -222,12 +222,12 @@ def from_hp_to_intx( block_size: Tuple[int, ...], target_dtype: torch.dtype, quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_max: Optional[int] = None, eps: Optional[float] = None, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, layout_type: LayoutType = PlainLayoutType(), use_hqq: bool = False, ): @@ -245,6 +245,9 @@ def from_hp_to_intx( data = data.to(target_dtype) else: scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) + # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None + if zero_point_domain is None: + zero_point = None data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) # Note: output will be uint8 tensor for sub byte tensors for now @@ -270,7 +273,7 @@ def from_hp_to_intx_static( block_size: Tuple[int, ...], target_dtype: torch.dtype, quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_max: Optional[int] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, layout_type: LayoutType = PlainLayoutType(), ): @@ -299,8 +302,8 @@ def from_hp_to_floatx( input_float: torch.Tensor, block_size: Tuple[int, ...], target_dtype: torch.dtype, - scale_dtype: Optional[torch.dtype] = None, - layout_type: LayoutType = PlainLayoutType(), + scale_dtype: Optional[torch.dtype], + layout_type: LayoutType, ): if target_dtype in FP8_TYPES: @@ -437,10 +440,8 @@ def extra_repr(self): @dataclass(frozen=True) class Float8LayoutType(LayoutType): - mm_config: ScaledMMConfig + mm_config: Optional[ScaledMMConfig] - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - return input @register_layout_cls(PlainLayoutType) class PlainAQTLayout(AQTLayout): @@ -639,9 +640,18 @@ def _apply_fn_to_data(self, fn): fn(self.scale) return self + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.float8_data.to(kwargs["device"]), + self.scale.to(kwargs["device"]), + self.transposed, + self.layout_type, + ) + def __tensor_flatten__(self): return ["float8_data", "scale"], [self.transposed, self.layout_type] - + @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride @@ -658,6 +668,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) if func is aten.t.default: """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose @@ -687,6 +701,7 @@ def from_plain( ): """ Main entrypoint for constructing Float8Layout Tensor""" assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" + assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}" return cls(data, scale, False, layout_type) def __repr__(self): @@ -1116,14 +1131,14 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): return out.view(*act.shape[:-1], out_dim).to(act.dtype) def _linear_fp_act_fp8_tensor_wise_weight_check( - input_tensor: torch.Tensor, - weight_tensor: AffineQuantizedTensor, + input_tensor: Union[torch.Tensor, AffineQuantizedTensor], + weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], bias: Optional[torch.Tensor], ) -> bool: - def check_aqt_tensorwise(aqt: AffineQuantizedTensor) -> bool: + def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( isinstance(aqt, AffineQuantizedTensor) and - isinstance(aqt.layout_tensor, Float8AQTLayout) + isinstance(aqt.layout_type, Float8LayoutType) and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and aqt.shape == aqt.block_size ) @@ -1136,7 +1151,7 @@ def _linear_fp_act_fp8_weight_impl( bias: Optional[torch.Tensor], ): """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" - from torchao.float8.inference import cast_to_float8_e4m3_inference, preprocess_data + from torchao.float8.inference import preprocess_data from torchao.float8.float8_tensor import ScaledMMConfig from torchao.float8.float8_python_api import addmm_float8_unwrapped @@ -1155,7 +1170,7 @@ def _linear_fp_act_fp8_weight_impl( # Handle case where input tensor is more than 2D inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1]) input_scale = input_tensor.layout_tensor.scale - if input_scale.dim() >= 2: + if input_scale.dim() > 2: input_scale = input_scale.reshape(-1, input_scale.shape[-1]) inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 66f83d933..b3d5de144 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -243,10 +243,14 @@ def quantize_to_float8( module_filter_fn=module_filter_fn, ) + from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul -def preprocess_data(a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig) -> Tuple[torch.Tensor, torch.Tensor]: - """ Preprocess the inner fp8 data tensors for admmm + +def preprocess_data( + a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig +) -> Tuple[torch.Tensor, torch.Tensor]: + """Preprocess the inner fp8 data tensors for admmm Args: a_data: Input tensor A. b_data: Input tensor B. diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 54183f670..aa2d3b3f9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -495,12 +495,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()) -def float8_weight_only(target_dtype: torch.dtype = torch.float8_e4m3fn): +def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): """ Applies float8 weight-only symmetric per-channel quantization to linear layers. Args: - target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. + weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. Note: The actual matmul will be computed in original precision of the weight tensor. @@ -513,7 +513,8 @@ def apply_float8wo_quant(weight): return to_affine_quantized_floatx( input_float=weight, block_size=block_size, - target_dtype=target_dtype, + target_dtype=weight_dtype, + scale_dtype=None, layout_type=Float8LayoutType(mm_config=None), ) @@ -521,30 +522,32 @@ def apply_float8wo_quant(weight): def float8_dynamic_activation_float8_weight( - target_dtype: torch.dtype = torch.float8_e4m3fn, activation_dtype: torch.dtype = torch.float8_e4m3fn, - mm_config: ScaledMMConfig = ScaledMMConfig(use_fast_accum=True) + weight_dtype: torch.dtype = torch.float8_e4m3fn, + mm_config: Optional[ScaledMMConfig] = None ): """ Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers. Args: - target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. + weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. mm_config (ScaledMMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ - from torchao.dtypes import to_affine_quantized_floatx + if mm_config is None: + mm_config = ScaledMMConfig(use_fast_accum=True) + #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling def apply_float8_dynamic_activation_quant(weight: torch.Tensor): quantized_weight = to_affine_quantized_floatx( input_float=weight, block_size=weight.shape, - target_dtype=target_dtype, + target_dtype=weight_dtype, scale_dtype=torch.float32, - layout_type=Float8LayoutType(mm_config=None), + layout_type=Float8LayoutType(mm_config=mm_config), ) def input_quant_func(x: torch.Tensor): @@ -553,7 +556,7 @@ def input_quant_func(x: torch.Tensor): block_size=x.shape, target_dtype=activation_dtype, scale_dtype=torch.float32, - layout_type=Float8LayoutType(mm_config=None), + layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight ) return activation diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 72ec988ca..34c720210 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -165,7 +165,7 @@ def _get_reduction_params(block_size, input_size): cur_dim += 1 return shape_for_reduction, reduction_dims - +@torch.no_grad() def quantize_affine( input: torch.Tensor, block_size: Tuple[int, ...], @@ -174,7 +174,7 @@ def quantize_affine( output_dtype: torch.dtype, quant_min: Optional[Union[int, float]] = None, quant_max: Optional[Union[int, float]] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, ) -> torch.Tensor: """ Args: @@ -233,6 +233,13 @@ def _quantize_affine( zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, ) -> torch.Tensor: """op definition that has compatible signatures with custom op library + + Note: + zero_point_domain is optional specifies how we quantize the floating point to quantized data: + INT: quantized_val = (float_val / scale) (integer) + zero_point (integer) + FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization + Where we do not want to round values to nearest integer and instead scale and cast. """ quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) # workaround for uintx dtypes, since we don't have native Uintx dtype connected with @@ -273,15 +280,14 @@ def _quantize_affine_no_dtype_cast( if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) - if zero_point_domain is None: - quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) - quant = quant.view(original_shape) - return quant - if zero_point_domain == ZeroPointDomain.INT.name: quant = torch.clamp( torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max ) + elif zero_point_domain is None: + # This case handles quantization for float8 we expect no zero point and no zero point domain + assert zero_point is None, "zero_point should be None when zero_point_domain is None" + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) else: assert zero_point_domain == ZeroPointDomain.FLOAT.name mid_point = (quant_max + quant_min + 1) / 2 @@ -391,15 +397,7 @@ def _dequantize_affine_no_dtype_check( shape_after_reduction = shape_for_reduction for i in reduction_dims: shape_after_reduction[i] = 1 - scale = scale.view(shape_after_reduction) - - # This case handles dequantization for float8 - if zero_point_domain is None: - assert zero_point is None, "zero_point should be None when zero_point_domain is None" - assert _is_float8_type(input.dtype), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" - dequant = input.to(output_dtype) - dequant = dequant * scale - return dequant.view(original_shape).to(output_dtype) + scale = scale.view(shape_after_reduction) if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) @@ -412,6 +410,12 @@ def _dequantize_affine_no_dtype_check( dequant = dequant - zero_point.to(torch.int32) dequant = dequant.to(output_dtype) dequant = dequant * scale + elif zero_point_domain is None: + # This case handles dequantization for float8 we expect no zero point and no zero point domain + assert zero_point is None, "zero_point should be None when zero_point_domain is None" + assert _is_float8_type(input.dtype), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" + dequant = input.to(output_dtype) + dequant = dequant * scale else: assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}" # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this) @@ -550,7 +554,7 @@ def _do_fake_quantize_affine( ) return (q, dq) - +@torch.no_grad() def choose_qparams_affine( input: torch.Tensor, mapping_type: MappingType,