Added support for Per Tensor Scaling for Float8 Dynamic Autoquant #1175

Merged: 1 commit, Oct 28, 2024
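For orientation: this PR adds a per-tensor-scaled variant of the float8 dynamic-quantization subclass to autoquant's candidate list, alongside the existing per-row variant. The two differ only in scale granularity. A minimal hand-rolled sketch of the difference (illustration only, not torchao's implementation; the helper names here are made up):

```python
# Illustration of per-tensor vs. per-row float8 scaling (not torchao code).
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def per_tensor_scale(x: torch.Tensor) -> torch.Tensor:
    # One scale for the whole tensor: cheapest, but a single outlier
    # stretches the quantization range for every row.
    return x.abs().max().float().clamp(min=1e-12) / FP8_E4M3_MAX

def per_row_scale(x: torch.Tensor) -> torch.Tensor:
    # One scale per row: each row uses its own dynamic range.
    return x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / FP8_E4M3_MAX

x = torch.randn(4, 8, dtype=torch.bfloat16)
for scale in (per_tensor_scale(x), per_row_scale(x)):
    x_fp8 = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    x_deq = x_fp8.float() * scale  # dequantize; error depends on granularity
```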
19 changes: 16 additions & 3 deletions test/integration/test_integration.py
@@ -74,6 +74,7 @@
     AutoQuantizableLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
+    AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -770,11 +771,23 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
     @unittest.skipIf(not is_H100, "Need H100 to run")
-    def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
+    def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
-            self.skipTest("Fails for {dtype}")
+            with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"):
+                self._test_lin_weight_subclass_impl(
+                    AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
+                )
+        else:
+            self._test_lin_weight_subclass_impl(
+                AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
+            )
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
-            AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
+            AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
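The helper these tests call, _test_lin_weight_subclass_impl, lives in the same test file; the 25 is a minimum SQNR (signal-to-quantization-noise ratio, in dB) the quantized output must reach. Roughly, and hedged since this is a paraphrase rather than the helper's exact code, the check looks like:

```python
# Hedged paraphrase of the subclass test pattern: swap a Linear's weight for
# the quantized subclass and require the output to stay within min_sqnr dB
# of the reference. Shapes and names here are illustrative.
import torch

def sqnr(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    return 20 * torch.log10(ref.norm() / (ref - test).norm())

def check_subclass(from_float, device="cuda", min_sqnr=25, dtype=torch.bfloat16):
    lin = torch.nn.Linear(128, 256, device=device, dtype=dtype)
    x = torch.randn(32, 128, device=device, dtype=dtype)
    ref = lin(x)
    lin.weight = torch.nn.Parameter(from_float(lin.weight), requires_grad=False)
    assert sqnr(ref, lin(x)) >= min_sqnr, "quantized output drifted too far"
```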
40 changes: 39 additions & 1 deletion torchao/quantization/autoquant.py
@@ -1,3 +1,4 @@
+from typing import Callable
 import torch
 import torchao
 from torchao.quantization.quant_primitives import (
@@ -500,7 +501,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActiv
     """
     AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling
     """
-    activation_granularity: str = PerRow()
+    activation_granularity = PerRow()
     @classmethod
     def from_float(cls, weight):
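Both the per-row class above and the per-tensor class added in the next hunk share one pattern: the weight is quantized once at conversion time, while activations are quantized on every forward by the stored input_quant_func (that is the "dynamic" part). A hand-rolled sketch of the pattern, with made-up names and a naive dequantize-then-matmul in place of the real fused kernel:

```python
# Sketch of dynamic activation quantization (not torchao's classes): weight
# quantized ahead of time, activations quantized per call via input_quant_func.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def per_tensor_input_quant(x):
    scale = x.abs().max().float() / FP8_MAX
    return (x.float() / scale).to(torch.float8_e4m3fn), scale

class DynamicallyQuantizedLinear(torch.nn.Module):
    def __init__(self, linear, input_quant_func):
        super().__init__()
        self.input_quant_func = input_quant_func
        w = linear.weight.detach()
        self.w_scale = w.abs().max().float() / FP8_MAX  # per-tensor weight scale
        self.w_fp8 = (w.float() / self.w_scale).to(torch.float8_e4m3fn)

    def forward(self, x):
        x_fp8, x_scale = self.input_quant_func(x)  # dynamic: runs every forward
        # naive reference path; real kernels fuse this via torch._scaled_mm
        return (x_fp8.float() * x_scale) @ (self.w_fp8.float() * self.w_scale).t()
```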
@@ -537,6 +538,42 @@ def get_per_token_block_size(x):
         weight = super(AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
         return weight
 
+class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per tensor scaling
+    """
+    activation_granularity = PerTensor()
+    @classmethod
+    def from_float(cls, weight):
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.quantization.quant_api import _input_activation_quant_func_fp8
+        # weight settings
+        def get_weight_block_size(x):
+            assert x.ndim == 2, "Only works for 2D tensors"
+            return x.shape
+        target_dtype = torch.float8_e4m3fn
+
+        input_target_dtype = torch.float8_e4m3fn
+        _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
+        input_quant_func = lambda x: _input_activation_quant_func_fp8(
+            x=x,
+            activation_granularity=cls.activation_granularity,
+            activation_dtype=input_target_dtype,
+        )
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            _layout=_layout,
+            scale_dtype=torch.float32,
+        )
+        from torchao.float8.inference import _is_rowwise_scaled
+        weight = super(AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -557,6 +594,7 @@ def get_per_token_block_size(x):
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
+    AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
 ]


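Since the new class is registered in OTHER_AUTOQUANT_CLASS_LIST rather than the default list, it is only tried when that list is passed explicitly. A sketch of that call, assuming torchao's autoquant(model, qtensor_class_list=...) entry point as documented around this release:

```python
# Sketch: make autoquant benchmark the FP8 candidates, including the new
# per-tensor-scaling subclass (H100 + recent PyTorch assumed, as in the tests).
import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
model = torchao.autoquant(torch.compile(model), qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)

x = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
model(x)  # first call profiles each candidate per layer, then picks the fastest
```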