diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index e175e4445..e8b42312a 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -195,8 +195,9 @@ def forward(self, x): tensor_shape = x.shape tensor_shape_list = list(tensor_shape) - tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) - block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 + tensor_shape_list[self.group_dim] = ( + tensor_shape_list[self.group_dim] + self.group_size - 1) // self.group_size + block_dim = self.group_dim + 1 if self.group_dim != -1 else len(tensor_shape_list) tensor_shape_list.insert(block_dim, self.group_size) x = x.view(tensor_shape_list) return x diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py index 59b3fe8ec..7d6d83231 100644 --- a/src/brevitas/core/restrict_val.py +++ b/src/brevitas/core/restrict_val.py @@ -24,7 +24,10 @@ class _RestrictClampValue(brevitas.jit.ScriptModule): - def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]): + def __init__( + self, + scaling_min_val: Optional[float] = None, + restrict_value_impl: Optional[Module] = None): super(_RestrictClampValue, self).__init__() if scaling_min_val is not None and scaling_min_val != 0: self.clamp_min_ste = ScalarClampMinSte(scaling_min_val) @@ -90,9 +93,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return Identity() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x / threshold - @brevitas.jit.script_method def forward(self, x: Tensor) -> Tensor: return x @@ -116,9 +116,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return InplaceLogTwo() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x - threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.power_of_two(x) @@ -143,9 +140,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return Identity() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x / threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.float_to_int_impl(x) @@ -171,9 +165,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return InplaceLogTwo() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x - threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.float_to_int_impl(x) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index f11eb1f2a..9792ebdae 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -30,12 +30,18 @@ def __init__( tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, affine_rescaling: bool = False, affine_shift_scale: bool = False, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(StatsFromParameterScaling, self).__init__() + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.parameter_list_stats = _ParameterListStats( scaling_stats_impl, scaling_shape, @@ -44,6 
+50,7 @@ def __init__( tracked_parameter_list) self.stats_scaling_impl = _StatsScaling( restrict_scaling_impl, + restrict_threshold_impl, scaling_shape, scaling_min_val, affine_rescaling, @@ -65,6 +72,7 @@ class _StatsScaling(brevitas.jit.ScriptModule): def __init__( self, restrict_scaling_impl: Module, + restrict_threshold_impl: Module, scaling_shape: Tuple[int, ...], scaling_min_val: Optional[float], affine_rescaling: bool, @@ -81,19 +89,22 @@ def __init__( else: self.affine_rescaling = Identity() self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() - self.restrict_scaling_impl = restrict_scaling_impl + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def forward( self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor: if threshold is None: threshold = torch.ones(1).type_as(stats) - threshold = self.restrict_scaling_pre(threshold) + threshold = self.restrict_threshold_pre(threshold) + threshold = self.restrict_clamp_threshold(threshold) stats = self.restrict_scaling_pre(stats) - stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold) stats = self.affine_rescaling(stats) stats = self.restrict_clamp_scaling(stats) + stats = stats / threshold return stats @@ -107,12 +118,17 @@ def __init__( affine_rescaling: bool = False, affine_shift_scale: bool = False, restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_stats_momentum: float = DEFAULT_MOMENTUM, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(RuntimeStatsScaling, self).__init__() + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.runtime_stats = _RuntimeStats( scaling_stats_impl, scaling_shape, @@ -122,6 +138,7 @@ def __init__( device) self.stats_scaling_impl = _StatsScaling( restrict_scaling_impl, + restrict_threshold_impl, scaling_shape, scaling_min_val, affine_rescaling, @@ -173,13 +190,14 @@ def _load_from_state_dict( class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): def __init__( - self, - group_size: int, - group_dim: int, - input_view_impl: Module, - scaling_stats_impl: Module, - scaling_min_val: Optional[float], - restrict_scaling_impl: Module = FloatRestrictValue()) -> None: + self, + group_size: int, + group_dim: int, + input_view_impl: Module, + scaling_stats_impl: Module, + scaling_min_val: Optional[float], + restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None) -> None: super(RuntimeDynamicGroupStatsScaling, self).__init__() self.group_size = group_size self.group_dim = group_dim @@ -187,6 +205,12 @@ def __init__( self.scaling_min_val = scaling_min_val self.input_view_impl = input_view_impl self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) + self.restrict_scaling_pre = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module( + ) + self.restrict_threshold_pre = self.restrict_clamp_threshold.restrict_value_impl.restrict_init_module( + ) 
@brevitas.jit.script_method def forward( @@ -196,7 +220,10 @@ def forward( if threshold is None: threshold = torch.ones(1).type_as(stats_input) stats_input_reshaped = self.input_view_impl(stats_input) - out = self.scaling_stats_impl(stats_input_reshaped) / threshold + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) + out = self.scaling_stats_impl(stats_input_reshaped) + # Apply log scaling + out = self.restrict_scaling_pre(out) # Scaling min val - out = self.restrict_clamp_scaling(out) + out = self.restrict_clamp_scaling(out) / threshold return out diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 4917b859a..13ead5afc 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -62,20 +62,27 @@ def __init__( self, scaling_init: Union[float, Tensor], restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(ConstScaling, self).__init__() + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) if isinstance(scaling_init, Tensor): scaling_init = scaling_init.to(device=device, dtype=dtype) scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(scaling_init.detach()) else: scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device)) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -83,7 +90,7 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Te threshold = torch.ones(1).type_as(placeholder) # We first apply any restriction to scaling # For IntQuant, this is no-op, retrocompatible. 
- threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold)) + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) restricted_value = self.restrict_clamp_scaling(self.value()) restricted_value = restricted_value / threshold return restricted_value @@ -133,11 +140,16 @@ def __init__( scaling_init: Union[float, Tensor], scaling_shape: Optional[Tuple[int, ...]] = None, restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(ParameterScaling, self).__init__() + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + if (isinstance(scaling_init, Tensor) and scaling_shape is not None and scaling_init.shape != SCALAR_SHAPE and scaling_init.shape != scaling_shape): raise RuntimeError("scaling_init.shape is non-scalar and != from scaling_shape.") @@ -149,12 +161,14 @@ def __init__( scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device) scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None: scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device) self.value = Parameter(scaling_init) self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -162,7 +176,7 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Te threshold = torch.ones(1).type_as(placeholder) # We first apply any restriction to scaling # For IntQuant, this is no-op, retrocompatible. 
- threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold)) + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) return value / threshold @@ -193,6 +207,7 @@ def __init__( tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: @@ -203,26 +218,37 @@ def __init__( scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list) - self.restrict_scaling_impl = restrict_scaling_impl + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.stats_scaling_impl = _StatsScaling( - restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device) + restrict_scaling_impl, + restrict_threshold_impl, + scaling_shape, + scaling_min_val, + False, + False, + dtype, + device) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() + self.restrict_inplace_scaling_pre = restrict_scaling_impl.restrict_init_inplace_module() + self.init_done: bool = brevitas.jit.Attribute(False, bool) self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool) - self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() - self.restrict_preprocess = restrict_scaling_impl.restrict_init_module() + self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) @brevitas.jit.script_method def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor: if threshold is None: threshold = torch.ones(1).type_as(ignored) - # Threshold division must happen after we update self.value, but before we apply restrict_preproces - # This is because we don't want to store a parameter dependant on a runtime value (threshold) - # And because restrict needs to happen after we divide by threshold if self.init_done: - threshold = self.restrict_inplace_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) + threshold = self.stats_scaling_impl.restrict_clamp_threshold( + self.restrict_threshold_pre(threshold)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = value / threshold return value else: stats = self.parameter_list_stats() @@ -230,11 +256,12 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor stats = stats + 0. 
* self.value if self.local_loss_mode: return self.stats_scaling_impl(stats, threshold) - stats = self.restrict_inplace_preprocess(stats) - threshold = self.restrict_inplace_preprocess(threshold) + stats = self.restrict_inplace_scaling_pre(stats) + threshold = self.stats_scaling_impl.restrict_clamp_threshold( + self.restrict_threshold_pre(threshold)) inplace_tensor_mul(self.value.detach(), stats) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = value / threshold self.init_done = True return value @@ -312,12 +339,18 @@ def __init__( scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(), scaling_shape: Tuple[int, ...] = SCALAR_SHAPE, restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(ParameterFromRuntimeStatsScaling, self).__init__() assert collect_stats_steps > 0, 'Steps should be more than 0' + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.collect_stats_steps: int = brevitas.jit.Attribute(collect_stats_steps, int) self.counter: int = brevitas.jit.Attribute(0, int) self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl @@ -326,19 +359,17 @@ def __init__( scaling_stats_momentum, Optional[float]) self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) - self.restrict_scaling_impl = restrict_scaling_impl self.restrict_scaling = _RestrictValue(restrict_scaling_impl) + self.restrict_threshold = _RestrictValue(restrict_threshold_impl) self.clamp_scaling = _ClampValue(scaling_min_val) self.local_loss_mode: bool = brevitas.jit.Attribute( False, bool) # required to support MSE eval or variants self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() - self.restrict_preprocess = restrict_scaling_impl.restrict_init_module() + self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: - # Threshold division must happen after we update self.value, but before we apply restrict_preproces - # This is because we don't want to store a parameter dependent on a runtime value (threshold) - # And because restrict needs to happen after we divide by threshold if self.counter < self.collect_stats_steps: stats_input = self.stats_input_view_shape_impl(stats_input) stats = self.stats(stats_input) @@ -360,14 +391,16 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: elif self.counter == self.collect_stats_steps: self.restrict_inplace_preprocess(self.buffer) inplace_tensor_mul(self.value.detach(), self.buffer) - threshold = self.restrict_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) + threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) + value = 
self.clamp_scaling(self.restrict_scaling(self.value)) + value = value / threshold self.counter = self.counter + 1 - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) + return abs_binary_sign_grad(value) else: - threshold = self.restrict_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) + threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) + value = self.clamp_scaling(self.restrict_scaling(self.value)) + value = value / threshold + return abs_binary_sign_grad(value) @brevitas.jit.script_method def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -378,12 +411,14 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te return self.training_forward(stats_input, threshold) else: if self.counter <= self.collect_stats_steps: - out = self.buffer / threshold - out = self.restrict_preprocess(out) + out = self.buffer + out = self.restrict_scaling_pre(out) else: - threshold = self.restrict_preprocess(threshold) - out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out))) + out = self.value + threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) + out = self.clamp_scaling(self.restrict_scaling(out)) + out = out / threshold + out = abs_binary_sign_grad(self.clamp_scaling(out)) return out def state_dict(self, destination=None, prefix='', keep_vars=False): @@ -396,7 +431,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): del output_dict[prefix + 'value'] # Save buffer into value for any non-zero number of collection steps elif self.counter <= self.collect_stats_steps: - output_dict[prefix + 'value'] = self.restrict_preprocess(self.buffer) + output_dict[prefix + 'value'] = self.restrict_scaling_pre(self.buffer) return output_dict def _load_from_state_dict( diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py index 2299c1783..b2d719bc6 100644 --- a/src/brevitas/quant/experimental/mx_quant_ocp.py +++ b/src/brevitas/quant/experimental/mx_quant_ocp.py @@ -1,9 +1,16 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +from typing import Optional + +from dependencies import this from dependencies import value from brevitas.core.function_wrapper.ops_ste import CeilSte +from brevitas.core.function_wrapper.ops_ste import FloorSte +from brevitas.core.restrict_val import PowerOfTwo +from brevitas.core.restrict_val import PowerOfTwoRestrictValue +from brevitas.core.restrict_val import RoundSte from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.inject import ExtendedInjector from brevitas.inject.enum import RestrictValueType @@ -18,10 +25,14 @@ from brevitas.quant.base import MinMaxStatsScaling from brevitas.quant.base import MSEAsymmetricScale from brevitas.quant.base import MSESymmetricScale +from brevitas.quant.base import MSESymmetricScaleSubInjector from brevitas.quant.base import ShiftedMinUintQuant +from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float_base import ScaledFloatActBase from brevitas.quant.experimental.float_base import ScaledFloatWeightBase +from brevitas.quant.experimental.float_quant_fnuz import FpFNUZMixin from brevitas.quant.experimental.float_quant_ocp import FpOCPAct +from brevitas.quant.experimental.float_quant_ocp import FpOCPMixin from brevitas.quant.experimental.float_quant_ocp import FpOCPWeight from brevitas.quant.solver.act import ActQuantSolver from brevitas.quant.solver.weight import WeightQuantSolver @@ -43,17 +54,28 @@ class GroupwiseActProxyMixin(ExtendedInjector): proxy_class = GroupwiseActQuantProxyFromInjector +class RestrictThresholdMixin(ExtendedInjector): + restrict_value_float_to_int_impl = FloorSte + restrict_scaling_impl = PowerOfTwoRestrictValue + + class MXWeightMixin(ExtendedInjector): + threshold_mixin = RestrictThresholdMixin group_size = 32 restrict_scaling_type = RestrictValueType.POWER_OF_TWO - restrict_value_float_to_int_impl = CeilSte + restrict_value_float_to_int_impl = FloorSte scaling_per_output_type = ScalingPerOutputType.GROUP + @value + def restrict_threshold_impl(): + return this.threshold_mixin.restrict_scaling_impl + class MXActMixin(ExtendedInjector): + threshold_mixin = RestrictThresholdMixin group_size = 32 restrict_scaling_type = RestrictValueType.POWER_OF_TWO - restrict_value_float_to_int_impl = CeilSte + restrict_value_float_to_int_impl = FloorSte scaling_impl = RuntimeDynamicGroupStatsScaling scaling_per_output_type = ScalingPerOutputType.GROUP @@ -65,6 +87,10 @@ def stats_reduce_dim(group_dim): else: return group_dim + 1 + @value + def restrict_threshold_impl(): + return this.threshold_mixin.restrict_scaling_impl + class MXFloat8e4m3Weight(MXWeightMixin, GroupwiseWeightFloatProxyMixin, @@ -135,3 +161,122 @@ class ShiftedMXUInt8WeightMSE(MSEAsymmetricScale, ShiftedMXUInt8Weight): MX Int signed weight quantizer with per-channel MSE-based scaling. """ pass + + +class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat): + """ + Block / group / vector signed symmetric e4m3 weight quantizer with float scales. + We inherit from a per-channel quantizer to re-use some underlying machinery. 
+ """ + proxy_class = GroupwiseWeightFloatQuantProxyFromInjector + scaling_per_output_type = ScalingPerOutputType.GROUP + + +def build_options( + weight_quant, + bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type: Optional[str], + group_size: int = 32, + group_dim: Optional[int] = None, + scaling_min_val: float = 1e-8): + + options = dict() + scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte} + + options['group_size'] = group_size + options['bit_width'] = bit_width + options['scaling_min_val'] = scaling_min_val + + if scale_stats_op == 'mse': + weight_quant = type('MSEWeightQuant', (MSESymmetricScale, weight_quant), {}) + else: + options['scale_stats_op'] = scale_stats_op + + if group_dim is not None: + options['group_dim'] = group_dim + + if scale_computation_type == 'param_from_stats': + options['scaling_impl_type'] = 'parameter_from_stats' + elif scale_computation_type == 'stats': + options['scaling_impl_type'] = 'stats' + else: + raise RuntimeError("Not supported") + + if is_po2_scale: + assert scale_rounding_func_type is not None + scale_rounding_func = scale_rounding_func_dict[scale_rounding_func_type] + options['restrict_scaling_type'] = RestrictValueType.POWER_OF_TWO + options['restrict_value_float_to_int_impl'] = scale_rounding_func + else: + # If not po2, threshold does need any restriction and will match float restriction of the scale + options['restrict_scaling_type'] = RestrictValueType.FP + options['restrict_threshold_impl'] = None + assert scale_rounding_func_type is None, "Rounding for scale not needed when float" + return options, weight_quant + + +class GroupwiseIntWeightQuantizerBuilder: + + def __new__( + self, + bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type: Optional[str], + group_size: int = 32, + group_dim: Optional[int] = None, + scaling_min_val: float = 1e-8, + ): + + weight_quant = MXInt8Weight + options, weight_quant = build_options(weight_quant, bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type, + group_size, + group_dim, + scaling_min_val) + weight_quant = weight_quant.let(**options) + return weight_quant + + +class GroupwiseFloatWeightQuantizerBuilder(GroupwiseIntWeightQuantizerBuilder): + + def __new__( + self, + exponent_bit_width, + mantissa_bit_width, + bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type: Optional[str], + group_size: int = 32, + group_dim: Optional[int] = None, + scaling_min_val: float = 1e-8, + format: Optional[str] = None): + weight_quant = Fp8e4m3WeightSymmetricGroupQuant + + if format == 'ocp': + weight_quant = type('OCPWeightQuant', (FpOCPMixin, weight_quant), {}) + if format == 'fnuz': + weight_quant = type('OCPWeightQuant', (FpFNUZMixin, weight_quant), {}) + + options, weight_quant = build_options(weight_quant, bit_width, + scale_stats_op, + is_po2_scale, + scale_computation_type, + scale_rounding_func_type, + group_size, + group_dim, + scaling_min_val) + options['exponent_bit_width'] = exponent_bit_width + options['mantissa_bit_width'] = mantissa_bit_width + + weight_quant = weight_quant.let(**options) + return weight_quant diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 4d46cc704..a4930e43d 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -178,7 +178,8 @@ def stats_reduce_dim(scaling_stats_op, scaling_per_output, 
group_dim=None): elif scaling_per_output == ScalingPerOutputType.TENSOR: return None elif scaling_per_output == ScalingPerOutputType.GROUP: - return group_dim + 1 + reduce_dim = group_dim + 1 if group_dim != -1 else -1 + return reduce_dim @value def keepdim(scaling_per_output): diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index 459f0eec7..9252b8d72 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -150,11 +150,6 @@ def minifloat(self, float_datatype=True): int_scale = float_internal_scale( minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps) float_value = torch.round(self._pre_round_float_value) * int_scale - return float_value.type(self.scale.dtype) - else: - raise RuntimeError(f"FloatQuantTensor not valid.") - - @staticmethod def check_input_type(tensor): if not isinstance(tensor, FloatQuantTensor): raise RuntimeError("Tensor is not a FloatQuantTensor") diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 9460fadf1..457877459 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -20,6 +20,9 @@ from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat +from brevitas.quant.experimental.mx_quant_ocp import Fp8e4m3WeightSymmetricGroupQuant +from brevitas.quant.experimental.mx_quant_ocp import GroupwiseFloatWeightQuantizerBuilder +from brevitas.quant.experimental.mx_quant_ocp import GroupwiseIntWeightQuantizerBuilder from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE @@ -55,7 +58,6 @@ from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat -from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat @@ -222,7 +224,8 @@ def generate_quantizers( quantize_input_zero_point=False, device=None, weight_kwargs=None, - input_kwargs=None): + input_kwargs=None, + weight_scale_rounding_func_type=None): """ Replace float layers with quant layers in the target model """ @@ -243,8 +246,32 @@ def generate_quantizers( else: input_float_format = {} - weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][ - weight_param_method][weight_quant_granularity][weight_quant_type] + if weight_quant_granularity == 'per_group': + if weight_quant_format == 'int': + weight_quant = GroupwiseIntWeightQuantizerBuilder( + bit_width=weight_bit_width, + scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method, + is_po2_scale=weight_scale_precision == 'po2_scale', + scale_computation_type='parameter_from_stats', + scale_rounding_func_type=weight_scale_rounding_func_type, + 
group_dim=weight_group_dim, + group_size=weight_group_size, + scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8) + else: + weight_quant = GroupwiseFloatWeightQuantizerBuilder( + exponent_bit_width=weight_float_format['exponent_bit_width'], + mantissa_bit_width=weight_float_format['mantissa_bit_width'], + bit_width=weight_bit_width, + scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method, + is_po2_scale=weight_scale_precision == 'po2_scale', + scale_computation_type='parameter_from_stats', + scale_rounding_func_type=weight_scale_rounding_func_type, + group_dim=weight_group_dim, + group_size=weight_group_size, + scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8) + else: + weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][ + weight_param_method][weight_quant_granularity][weight_quant_type] if input_bit_width is not None and input_scale_type == 'no_scale': input_quant = sym_input_quant = linear_input_quant = INPUT_QUANT_MAP[input_quant_format][ diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index c3c99a96f..4f7040d08 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -49,15 +49,6 @@ class IntWeightSymmetricGroupQuant(Int8WeightPerChannelFloat): scaling_per_output_type = ScalingPerOutputType.GROUP -class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat): - """ - Block / group / vector signed symmetric e4m3 weight quantizer with float scales. - We inherit from a per-channel quantizer to re-use some underlying machinery. - """ - proxy_class = GroupwiseWeightFloatQuantProxyFromInjector - scaling_per_output_type = ScalingPerOutputType.GROUP - - class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ Symmetric quantizer with per tensor dynamic scale. diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 4a87f5a1a..5ef39fffa 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -253,7 +253,9 @@ def main(args): input_quant_granularity=args.input_quant_granularity, input_group_size=args.input_group_size, quantize_input_zero_point=args.quantize_input_zero_point, - device=device) + device=device, + weight_scale_rounding_func_type=args.weight_scale_rounding_func_type + ) layer_map = generate_quant_maps( linear_input_quant=linear_input_quant, weight_quant=weight_quant, @@ -400,6 +402,12 @@ def parse_args(args): default='per_group', choices=['per_channel', 'per_tensor', 'per_group'], help='Granularity for scales/zero-point of weights. Default: per_group.') + parser.add_argument( + '--weight-scale-rounding-func-type', + type=str, + default=None, + choices=['round', 'ceil', 'floor'], + help='Rounding function to use with Po2 scale. 
Default: None.')
     parser.add_argument(
         '--weight-group-dim',
         type=int,
diff --git a/tests/brevitas/core/test_quant_mx.py b/tests/brevitas/core/test_quant_mx.py
new file mode 100644
index 000000000..b2ab279d4
--- /dev/null
+++ b/tests/brevitas/core/test_quant_mx.py
@@ -0,0 +1,187 @@
+"""
+Reference MXFP quantizer used to validate Brevitas MX quantization.
+"""
+# pylint: disable=missing-function-docstring, redefined-outer-name
+
+import struct
+from typing import Tuple
+
+from hypothesis import given
+import pytest_cases
+import torch
+
+from brevitas.nn.quant_activation import QuantIdentity
+from brevitas.nn.quant_linear import QuantLinear
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
+from tests.brevitas.hyp_helper import float_tensor_nz_st
+
+torch.manual_seed(0)
+
+
+# debug utility
+def to_string(val: torch.Tensor | float, spaced: bool = True, code: str = "f") -> str | list[str]:
+    """ Debug util for visualizing float values """
+
+    def scalar_to_string(val: float, spaced: bool) -> str:
+        s = ''.join(bin(c).replace('0b', '').rjust(8, '0') for c in struct.pack('!' + code, val))
+        spaced = spaced and len(s) == 32
+        return f"{s[0]} {s[1:9]} {s[9:]}" if spaced else s
+
+    if isinstance(val, float):
+        return scalar_to_string(val, spaced)
+    val = val.view(-1)
+    return [scalar_to_string(val[i].item(), spaced) for i in range(val.numel())]
+
+
+# debug utility
+def check_bits(val: torch.Tensor | float, mbits: int) -> Tuple[bool, int]:
+    """ return (too many precision bits, lowest mantissa bit) """
+    strings = to_string(val, spaced=False)
+    if isinstance(strings, str):
+        strings = [strings]
+    error, lowest = False, 0
+    for s in strings:
+        mant = s[9:]
+        error = error or "1" in mant[mbits:]
+        lowest = max(lowest, mant.find("1"))
+    return error, lowest
+
+
+# Avoid returning exponent 0 when the input is 0
+def safe_frexp(x: torch.Tensor) -> torch.Tensor:
+    """torch.frexp returns unbiased exponent 0 for 0.0, which is not what we want."""
+    if x.is_cuda and x.dtype not in (torch.float32, torch.float16):
+        x = x.float()  # no gpu support for frexp on bfloat16 or any float8
+    return torch.where(x == 0.0, -126, x.frexp().exponent - 1)
+
+
+class MXFP:
+    """
+    MXFP - Quantize OCP MXFP floating point types.
+    A type is defined as ebits, mbits, bias, and inf/nan handling.
+    """
+    CONFIG = dict(
+        e5m2=(5, 2, 15, "ieee"),
+        e4m3=(4, 3, 7, "fn"),
+        e3m2=(3, 2, 3, "fnuz"),
+        e2m3=(2, 3, 1, "fnuz"),
+        e2m1=(2, 1, 1, "fnuz"))
+
+    def __init__(self, name, tile_size: int | None = 32):
+        self.name = name.lower()
+        assert self.name in self.CONFIG
+        self.ebits, self.mbits, self.bias, self.infnan = self.CONFIG[self.name]
+        self.tile_size = tile_size
+
+    @property  # maximum unbiased exponent for this type
+    def emax(self) -> int:
+        return 2 ** self.ebits - 1 - self.bias - int(self.infnan == "ieee")
+
+    @property  # minimum unbiased exponent for this type
+    def emin(self) -> int:
+        return 1 - self.bias
+
+    @property  # maximum representable value; the "fn" reserves values for all non-sign bits == 1
+    def maxval(self) -> float:
+        return 2 ** self.emax * (2.0 - (1 + int(self.infnan == "fn")) * 2 ** (-self.mbits))

+    @property  # for alternative scale selection
+    def midmax(self) -> float:
+        return (2 ** (self.emax + 1) - self.maxval) / 2. + self.maxval
+
+    @property  # minimum representable positive value
+    def minval(self) -> float:
+        return 2 ** self.emin * 2 ** (-self.mbits)
+
+    def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False):
+        """
+        Fake quantize along the indicated dimension. This method assumes the indicated dimension is the size of the tile,
+        so some reshaping and possibly padding is likely required. From there, we have 5 needed lines of code.
+        """
+        exp = safe_frexp(tensor)  # safe_frexp pretends the mantissa is < 1.0
+        shared = exp.amax(axis, keepdim=True)  # shared exponent per the OCP MX spec
+
+        # This is an alternative to the OCP MX scale selection, which chooses the maximum exponent (maxexp).
+        # Instead, choose maxexp + 1 if absmax is closer to 2^(maxexp+1) than maxval. This reduces error on
+        # the highest magnitude value at the potential cost of increased error or underflow of the smallest.
+        # Ad hoc MSE test shows that e4m3, due to reserving the most significant value for NaN, benefits the
+        # most from this technique. In hardware or a kernel, this is as simple as comparing bits [30:21]
+        # instead of [30:23] when getting max exponent, then add 1 to the max eeeeeeeemm and shift right two.
+        #        e2m1      e3m2      e2m3      e4m3      e5m2
+        # max    0.01325   0.00291   0.00080   0.00085   0.00291
+        # best   0.01254   0.00280   0.00079   0.00071   0.00280
+
+        if select:
+            midmax = self.midmax * (shared - self.emax).exp2()
+            shared[tensor.abs().amax(axis, keepdim=True) > midmax] += 1
+
+        # The way this works is to appropriately shift values so that rounding can work, then shift them back.
+        # All values that are representable as normal given the scale are shifted up by the difference
+        # between the individual exponent and zero, plus the mantissa width. Subnormals get the same,
+        # but with decreasing mantissa bits. The maxval for saturation is adjusted on a per block basis.
+        scale = (self.mbits - (shared - exp - (self.emax - self.emin)).clamp_min(0) - exp).exp2()
+        # about that last line of code:
+        # The "offset" is the number of mbits lost to subnormal/underflow. This is based on the difference between
+        # the shared exponent and the individual exponent, adjusted to the dynamic range of normals for this type.
+        # It can't be negative, because we subtract it from mbits, and don't want to exceed the available mbits.
+        #     offset = (shared - exp - (self.emax - self.emin)).clamp_min(0)
+        # The shift left will be mbits - offset - exp, which for negative exponents gets them into the right range.
+ maxval = self.maxval * (shared - self.emax).exp2() # scale maxval per tile + return ((tensor * scale).round() / scale).clamp(-maxval, maxval) + + +MAP = {"e4m3": (4, 3), "e5m2": (5, 2), "e2m3": (2, 3), "e3m2": (3, 2), "e2m1": (2, 1)} + + +@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10)) +@pytest_cases.parametrize('bit_widths', list(MAP.keys())) +def test_act_mx(inp, bit_widths): + torch.set_printoptions(precision=12, sci_mode=False) + exp, mant = MAP[bit_widths] + + act_quant = QuantIdentity( + MXFloat8e4m3Act, + exponent_bit_width=exp, + mantissa_bit_width=mant, + bit_width=mant + exp + 1, + group_dim=1, + return_quant_tensor=True) + act_quant.eval() + x = inp + + quantizer = MXFP(bit_widths) + + qx = act_quant(x) + + y = quantizer.quantize(x) + assert torch.allclose(qx.value, y, atol=1e-8) + + +@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10)) +@pytest_cases.parametrize('bit_widths', list(MAP.keys())) +@pytest_cases.parametrize('weight_quant_type', ['stats', 'parameter_from_stats']) +def test_weight_mx(inp, bit_widths, weight_quant_type): + torch.set_printoptions(precision=12, sci_mode=False) + exp, mant = MAP[bit_widths] + weight_quant = QuantLinear( + 32, + 1, + bias=False, + weight_quant=MXFloat8e4m3Weight, + weight_scaling_impl_type=weight_quant_type, + weight_exponent_bit_width=exp, + weight_mantissa_bit_width=mant, + weight_bit_width=mant + exp + 1) + + x = inp + weight_quant.weight.data = x + weight_quant.weight_quant.init_tensor_quant() + quantizer = MXFP(bit_widths) + + qx_weight = weight_quant.quant_weight() + qx_weight_two = weight_quant.quant_weight() + + y = quantizer.quantize(x) + assert torch.allclose(qx_weight.value, y, atol=1e-8) + assert torch.allclose(qx_weight_two.value, y, atol=1e-8) diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index fbfc76842..16f944e97 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -60,7 +60,7 @@ def reference_implementation_scale_factors_po2( return scale -@given(inp=float_tensor_random_size_st()) +@given(inp=float_tensor_random_size_st(max_val=1e10, min_val=-1e10)) def test_scale_factors_ptq_calibration_po2(inp): class TestModel(nn.Module): @@ -80,7 +80,6 @@ def forward(self, x): expected_scale = reference_implementation_scale_factors_po2(inp) scale = model.act.act_quant.scale() - assert torch.allclose(expected_scale, scale)
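As a reading aid, here is a minimal sketch, outside the patch, of the grouped-view reshape that the shape.py hunk above performs: the group dimension is ceil-divided into blocks of group_size, and for group_dim=-1 the new block dimension is appended at the end. The group_view helper name is illustrative, not a Brevitas API, and the example assumes the group dimension is divisible by group_size so the view succeeds without padding.

import torch


def group_view(x: torch.Tensor, group_dim: int, group_size: int) -> torch.Tensor:
    # Mirrors the reshape in the shape.py forward above (illustrative only):
    # split the group dimension into ceil(dim / group_size) groups of group_size elements.
    shape = list(x.shape)
    shape[group_dim] = (shape[group_dim] + group_size - 1) // group_size  # ceil division
    # The block dimension follows the group dimension, or is appended when group_dim is -1.
    block_dim = group_dim + 1 if group_dim != -1 else len(shape)
    shape.insert(block_dim, group_size)
    return x.view(shape)


x = torch.randn(4, 64)
print(group_view(x, group_dim=1, group_size=32).shape)   # torch.Size([4, 2, 32])
print(group_view(x, group_dim=-1, group_size=32).shape)  # torch.Size([4, 2, 32])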
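And a usage sketch, also not part of the patch, for the GroupwiseIntWeightQuantizerBuilder added in mx_quant_ocp.py, assuming the keyword names and option wiring shown in the diff; the layer sizes and the printed field are illustrative only. The LLM entry point reaches the same code path with --weight-quant-granularity per_group plus the new --weight-scale-rounding-func-type flag.

from brevitas.nn import QuantLinear
from brevitas.quant.experimental.mx_quant_ocp import GroupwiseIntWeightQuantizerBuilder

# Hypothetical configuration: int8 group-wise weights with a floor-rounded power-of-two
# scale learned from statistics, mirroring the generate_quantizers() call in the diff.
weight_quant = GroupwiseIntWeightQuantizerBuilder(
    bit_width=8,
    scale_stats_op='max',
    is_po2_scale=True,
    scale_computation_type='param_from_stats',  # build_options maps this to scaling_impl_type='parameter_from_stats'
    scale_rounding_func_type='floor',
    group_size=32,
    group_dim=1,
    scaling_min_val=1e-8)

layer = QuantLinear(64, 16, bias=False, weight_quant=weight_quant)
print(layer.quant_weight().value.shape)  # fake-quantized weights, same shape as layer.weight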