diff --git a/Docs/keras_code_examples/adaround.py b/Docs/keras_code_examples/adaround.py
index 1c8eaad244..7c4a827a44 100644
--- a/Docs/keras_code_examples/adaround.py
+++ b/Docs/keras_code_examples/adaround.py
@@ -92,8 +92,7 @@ def apply_adaround_example():
     # Returns session with adarounded weights and their corresponding encodings
     adarounded_model = Adaround.apply_adaround(model, params, path='./', filename_prefix='dummy',
-                                               default_param_bw=param_bw, default_quant_scheme=quant_scheme,
-                                               default_is_symmetric=False)
+                                               default_param_bw=param_bw, default_quant_scheme=quant_scheme)
 
     # Create QuantSim using adarounded_session
     sim = QuantizationSimModel(adarounded_model, quant_scheme, default_output_bw=output_bw,
                                default_param_bw=param_bw)
diff --git a/Examples/tensorflow/quantization/keras/adaround.ipynb b/Examples/tensorflow/quantization/keras/adaround.ipynb
index a8eb85b404..cd9a822f17 100644
--- a/Examples/tensorflow/quantization/keras/adaround.ipynb
+++ b/Examples/tensorflow/quantization/keras/adaround.ipynb
@@ -335,8 +335,7 @@
    "\n",
    "os.makedirs(\"./output/\", exist_ok=True)\n",
    "ada_model = Adaround.apply_adaround(model, params, path=\"output\", filename_prefix=\"adaround\",\n",
-   "                                    default_param_bw=8, default_quant_scheme=QuantScheme.post_training_tf,\n",
-   "                                    default_is_symmetric=False)"
+   "                                    default_param_bw=8, default_quant_scheme=QuantScheme.post_training_tf)"
   ]
  },
 {
diff --git a/Examples/tensorflow/quantization/keras/quantsim_adaround_pcq.ipynb b/Examples/tensorflow/quantization/keras/quantsim_adaround_pcq.ipynb
index 3eb51694f9..b8d990a982 100644
--- a/Examples/tensorflow/quantization/keras/quantsim_adaround_pcq.ipynb
+++ b/Examples/tensorflow/quantization/keras/quantsim_adaround_pcq.ipynb
@@ -140,12 +140,12 @@
  },
  {
   "cell_type": "markdown",
-  "source": [
-   "For this example notebook, we are going to load a pretrained ResNet50 model from Keras. Similarly, you can load any pretrained Keras model instead."
-  ],
   "metadata": {
    "collapsed": false
-  }
+  },
+  "source": [
+   "For this example notebook, we are going to load a pretrained ResNet50 model from Keras. Similarly, you can load any pretrained Keras model instead."
+  ]
 },
 {
  "cell_type": "code",
@@ -349,7 +349,6 @@
    "os.makedirs(\"./output/\", exist_ok=True)\n",
    "ada_model = Adaround.apply_adaround(model, params, path=\"output\", filename_prefix=\"adaround\",\n",
    "                                    default_param_bw=8, default_quant_scheme=QuantScheme.post_training_tf,\n",
-   "                                    default_is_symmetric=False,\n",
    "                                    config_file=\"Examples/tensorflow/utils/keras/pcq_quantsim_config\") # NOTE: The same config file used in QuantSim is used here as well. Again, telling Adaround to enable PCQ.\n"
   ]
  },
diff --git a/Examples/tensorflow/utils/keras/pcq_quantsim_config b/Examples/tensorflow/utils/keras/pcq_quantsim_config
index 4610120caa..7f79a4d281 100644
--- a/Examples/tensorflow/utils/keras/pcq_quantsim_config
+++ b/Examples/tensorflow/utils/keras/pcq_quantsim_config
@@ -1,7 +1,8 @@
 {"defaults": {
     "ops": {},
     "params": {
-        "is_quantized": "True"
+        "is_quantized": "True",
+        "is_symmetric": "True"
     },
     "strict_symmetric": "True",
     "unsigned_symmetric": "False",
diff --git a/NightlyTests/tensorflow/test_adaround_keras.py b/NightlyTests/tensorflow/test_adaround_keras.py
index 40734ac186..f178caca24 100644
--- a/NightlyTests/tensorflow/test_adaround_keras.py
+++ b/NightlyTests/tensorflow/test_adaround_keras.py
@@ -36,13 +36,12 @@
 # @@-COPYRIGHT-END-@@
 # =============================================================================
 """ Keras AdaRound Nightly Tests """
-
-import pytest
-pytestmark = pytest.mark.skip("Disable tests that requires eager execution")
 import json
 import logging
 import os
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+import pytest
+pytestmark = pytest.mark.skip("Disable tests that require eager execution")
 import numpy as np
 import tensorflow as tf
 from tensorflow.keras.applications.mobilenet import MobileNet
@@ -120,8 +119,7 @@ def dummy_forward_pass(model: tf.keras.Model, _):
     quant_scheme = QuantScheme.post_training_tf_enhanced
 
     adarounded_model = Adaround.apply_adaround(model, params, path='./', filename_prefix='dummy',
-                                               default_param_bw=param_bw, default_quant_scheme=quant_scheme,
-                                               default_is_symmetric=False)
+                                               default_param_bw=param_bw, default_quant_scheme=quant_scheme)
 
     # Read exported param encodings JSON file
     with open('./dummy.encodings') as json_file:
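With this change, weight symmetry for AdaRound is driven entirely by the config file rather than a `default_is_symmetric` argument. A minimal sketch of the updated Keras call, assuming a pretrained ResNet50 and an illustrative calibration dataset (the `AdaroundParameters` arguments here mirror the notebooks above but are not taken verbatim from this PR):

```python
import numpy as np
import tensorflow as tf
from aimet_common.defs import QuantScheme
from aimet_tensorflow.keras.adaround_weight import Adaround, AdaroundParameters

model = tf.keras.applications.ResNet50(weights="imagenet")

# Illustrative calibration data; any representative tf.data.Dataset works.
calib_data = np.random.rand(32, 224, 224, 3).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(calib_data).batch(8)
params = AdaroundParameters(data_set=dataset, num_batches=4)

# Note: no default_is_symmetric argument anymore -- weight symmetry now comes
# from the "is_symmetric" entry under "params" in the config file above.
ada_model = Adaround.apply_adaround(model, params, path="./output", filename_prefix="adaround",
                                    default_param_bw=8,
                                    default_quant_scheme=QuantScheme.post_training_tf,
                                    config_file="Examples/tensorflow/utils/keras/pcq_quantsim_config")
```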
diff --git a/TrainingExtensions/common/src/python/aimet_common/quantsim.py b/TrainingExtensions/common/src/python/aimet_common/quantsim.py
index 0501dc3575..0fe1d8d8f0 100644
--- a/TrainingExtensions/common/src/python/aimet_common/quantsim.py
+++ b/TrainingExtensions/common/src/python/aimet_common/quantsim.py
@@ -79,8 +79,9 @@ def gate_min_max(min_val: float, max_val: float) -> Tuple[float, float]:
     return gated_min, gated_max
 
 
-def calculate_delta_offset(min_val: Union[float, np.ndarray], max_val: Union[float, np.ndarray], bitwidth: int) -> \
-        Union[Tuple[float, float], Tuple[List, List]]:
+def calculate_delta_offset(min_val: Union[float, np.ndarray], max_val: Union[float, np.ndarray], bitwidth: int,
+                           use_symmetric_encodings: bool, use_strict_symmetric: bool) \
+        -> Union[Tuple[float, float], Tuple[List, List]]:
     """
     calculates delta and offset given min and max.
     :param min_val: min encoding value
@@ -88,8 +89,14 @@ def calculate_delta_offset(min_val: Union[float, np.ndarray], max_val: Union[flo
     :param bitwidth: bitwidth used for quantization
+    :param use_symmetric_encodings: True if symmetric encodings are used, False otherwise
+    :param use_strict_symmetric: True if strict symmetric encodings are used, False otherwise
     :return: delta and offset values computed
     """
+    num_steps = 2 ** bitwidth - 1
+    if use_symmetric_encodings and use_strict_symmetric:
+        num_steps -= 1
+
     min_val, max_val = gate_min_max(min_val, max_val)
-    delta = (max_val - min_val) / (2 ** bitwidth - 1)
+    delta = (max_val - min_val) / num_steps
 
     if isinstance(delta, np.ndarray):
         offset = np.around(min_val/delta)
diff --git a/TrainingExtensions/common/test/python/test_quantsim.py b/TrainingExtensions/common/test/python/test_quantsim.py
index c9baf4d95f..49816b5e3a 100644
--- a/TrainingExtensions/common/test/python/test_quantsim.py
+++ b/TrainingExtensions/common/test/python/test_quantsim.py
@@ -54,6 +54,7 @@ def test_offset_delta_compute(self):
         expected_delta = (max - min) / (2 ** bitwidth - 1)
         expected_offset = np.round(min / expected_delta)
 
-        delta, offset = calculate_delta_offset(min, max, bitwidth)
+        delta, offset = calculate_delta_offset(min, max, bitwidth,
+                                               use_strict_symmetric=False, use_symmetric_encodings=False)
         self.assertTrue(expected_delta == delta)
         self.assertTrue(expected_offset == offset)
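To make the new `num_steps` arithmetic concrete: a strict symmetric grid drops one quantization step (for example, [-127, 127] instead of [-128, 127] at 8 bits). Using the 4-bit, [0, 30] encoding exercised by the Keras tensor-quantizer test later in this diff, delta moves from 30/15 = 2.0 to 30/14 ≈ 2.142857. A standalone sketch that mirrors the helper above for scalar inputs (min/max gating omitted):

```python
# Mirrors calculate_delta_offset above for scalar min/max, without gating.
def delta_offset(min_val: float, max_val: float, bitwidth: int,
                 use_symmetric_encodings: bool, use_strict_symmetric: bool):
    num_steps = 2 ** bitwidth - 1
    if use_symmetric_encodings and use_strict_symmetric:
        num_steps -= 1  # strict symmetric uses one fewer step
    delta = (max_val - min_val) / num_steps
    offset = round(min_val / delta)
    return delta, offset

print(delta_offset(0.0, 30.0, 4, False, False))  # (2.0, 0) -- old behavior
print(delta_offset(0.0, 30.0, 4, True, True))    # (~2.142857, 0) -- strict symmetric
```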
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/adaround/adaround_weight.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/adaround/adaround_weight.py
index 33b43f4f86..81e216e4c4 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/adaround/adaround_weight.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/adaround/adaround_weight.py
@@ -217,7 +217,9 @@ def _apply_adaround_helper(  # pylint: disable=too-many-locals
             all_inp_data, all_out_data = act_sampler.sample_activation(op, hard_rounded_op, session,
                                                                        session_hard_rounded_weight, starting_op_names,
                                                                        params.num_batches)
-            is_symmetric = cls._get_is_symmetric_flag_for_op_param(configs, op.type, param_name="weight")
+            is_symmetric = cls.get_is_symmetric_flag_for_op_param(configs, op.type,
+                                                                  param_name="weight",
+                                                                  framework_to_onnx_type_dict=tf_op_type_to_onnx_type_dict)
 
             # Find next following activation function
@@ -259,14 +261,13 @@ def get_config_dict_keys(config_file: str) -> Tuple[ConfigDictType, bool, bool,
         :return: Config dictionary, strict symmetric flag, unsigned symmetric flag, enable per channel flag.
         """
         configs = JsonConfigImporter.import_json_config_file(config_file)
-        # Strict_symmetric and unsigned_symmetric flags have default value False and True respectively
         strict_symmetric = configs[ConfigDictKeys.DEFAULTS].get(ConfigDictKeys.STRICT_SYMMETRIC, False)
-        unisgned_symmetric = configs[ConfigDictKeys.DEFAULTS].get(ConfigDictKeys.UNSIGNED_SYMMETRIC, True)
+        unsigned_symmetric = configs[ConfigDictKeys.DEFAULTS].get(ConfigDictKeys.UNSIGNED_SYMMETRIC, False)
 
         # Read per-channel quantization field. Default = False
         per_channel_enabled = configs[ConfigDictKeys.DEFAULTS].get(ConfigDictKeys.PER_CHANNEL_QUANTIZATION, False)
 
-        return configs, strict_symmetric, unisgned_symmetric, per_channel_enabled
+        return configs, strict_symmetric, unsigned_symmetric, per_channel_enabled
 
     @staticmethod
     def _get_ordered_list_of_ops(graph: tf.Graph, input_op_names: List[str], output_op_names: List[str]) \
@@ -355,7 +356,8 @@ def _update_param_encodings_dict(encoding_dict: Dict, op: tf.Operation,
                                           'is_symmetric': is_symmetric}]
 
     @staticmethod
-    def _get_is_symmetric_flag_for_op_param(configs: ConfigDictType, tf_op_type: str, param_name: str):
+    def get_is_symmetric_flag_for_op_param(configs: ConfigDictType, tf_op_type: str, param_name: str,
+                                           framework_to_onnx_type_dict: dict) -> bool:
         """
         NOTE: Checks config file in reverse order of specificity.
@@ -367,13 +369,14 @@ def _get_is_symmetric_flag_for_op_param(configs: ConfigDictType, tf_op_type: str
         :param configs: Dictionary containing configs.
         :param tf_op_type: TensorFlow operation type.
        :param param_name: Parameter name.
-        :return: Is_symmetric flag for given op's param.
+        :param framework_to_onnx_type_dict: Dictionary mapping framework type to ONNX type.
+        :return: is_symmetric flag for given op's param.
         """
         assert param_name in MAP_TF_PARAM_NAME_TO_QUANTSIM_NAME.keys(), "param name is invalid."
 
         # third level of specificity which applies to specific op_type's parameters.
         try:
-            onnx_type = tf_op_type_to_onnx_type_dict[tf_op_type]
+            onnx_type = framework_to_onnx_type_dict[tf_op_type]
             return configs[ConfigDictKeys.OP_TYPE] \
                 [onnx_type] \
                 [ConfigDictKeys.PARAMS] \
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/adaround_weight.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/adaround_weight.py
index 10abd95f2c..c69d7f21b3 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/adaround_weight.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/adaround_weight.py
@@ -51,7 +51,7 @@
 from aimet_tensorflow.keras.adaround.activation_sampler import ActivationSampler
 from aimet_tensorflow.keras.adaround.adaround_wrapper import AdaroundWrapper
 from aimet_tensorflow.keras.adaround.adaround_optimizer import AdaroundOptimizer
-from aimet_tensorflow.keras.connectedgraph import ConnectedGraph
+from aimet_tensorflow.keras.connectedgraph import ConnectedGraph, map_keras_types_to_onnx
 
 _logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
 
@@ -69,7 +69,6 @@ class Adaround:
     def apply_adaround(cls, model: tf.keras.Model, params: AdaroundParameters, path: str, filename_prefix: str,
                        default_param_bw: int = 4,
                        default_quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced,
-                       default_is_symmetric: bool = False,
                        config_file: str = None) -> tf.keras.Model:
         """
         Returns model with optimized weight rounding of every op (Conv and Linear) and also saves the
@@ -83,14 +82,12 @@ def apply_adaround(cls, model: tf.keras.Model, params: AdaroundParameters, path:
         :param default_param_bw: Default bitwidth (4-31) to use for quantizing layer parameters. Default 4
         :param default_quant_scheme: Quantization scheme. Supported options are QuantScheme.post_training_tf or
                                      QuantScheme.post_training_tf_enhanced. Default QuantScheme.post_training_tf_enhanced
-        :param default_is_symmetric: True if symmetric encodings is used, else asymmetric encodings.
-                                     Default False.
         :param config_file: Configuration file for model quantizers
         :return: Model with Adarounded weights
         """
 
         # Get parameters from config file. To allow one central place for Adaround and Quantsim
-        _, strict_symmetric, unsigned_symmetric, per_channel_enabled = TfAdaround.get_config_dict_keys(config_file)
+        configs, strict_symmetric, unsigned_symmetric, per_channel_enabled = TfAdaround.get_config_dict_keys(config_file)
 
         # Optimization Hyper parameters
         opt_params = AdaroundHyperParameters(params.num_iterations, params.reg_param, params.beta_range,
@@ -110,7 +107,10 @@ def apply_adaround(cls, model: tf.keras.Model, params: AdaroundParameters, path:
         progbar = Progbar(len(ordered_layer_indices))
 
         for idx in ordered_layer_indices:
-            cls.adaround_layer(act_sampler, default_is_symmetric, strict_symmetric, unsigned_symmetric,
+            use_symmetric_encodings = TfAdaround.get_is_symmetric_flag_for_op_param(configs, model.layers[idx],
+                                                                                    param_name='weight',
+                                                                                    framework_to_onnx_type_dict=map_keras_types_to_onnx)
+            cls.adaround_layer(act_sampler, use_symmetric_encodings, strict_symmetric, unsigned_symmetric,
                                default_param_bw, default_quant_scheme, model, hard_rounded_model, soft_rounded_model,
                                idx, module_act_func_pair, opt_params, param_encodings, per_channel_enabled)
             progbar.add(1)
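The per-layer lookup above reuses the TF backend's config traversal, which (per its docstring) checks the config file in reverse order of specificity. A rough sketch of the precedence it implements; the dict layout and key names here are illustrative, not the actual JsonConfigImporter output format:

```python
# Illustrative precedence: an op-type-specific "params" entry wins over the
# global "params" default; missing entries fall back one level.
def is_symmetric_for(configs: dict, onnx_type: str, param_name: str) -> bool:
    try:
        # Third level of specificity: per-op-type parameter setting.
        return configs["op_type"][onnx_type]["params"][param_name]["is_symmetric"]
    except KeyError:
        pass
    try:
        # Global parameter default.
        return configs["defaults"]["params"]["is_symmetric"]
    except KeyError:
        return False

configs = {"defaults": {"params": {"is_symmetric": True}},
           "op_type": {"Gemm": {"params": {"weight": {"is_symmetric": False}}}}}
print(is_symmetric_for(configs, "Conv", "weight"))  # True  (global default)
print(is_symmetric_for(configs, "Gemm", "weight"))  # False (op-type override)
```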
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/qc_quantize_wrapper.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/qc_quantize_wrapper.py
index eb3273dda0..01efa93875 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/qc_quantize_wrapper.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/qc_quantize_wrapper.py
@@ -56,7 +56,7 @@
 class QuantizerSettings:
     """ Class holding quantizer settings """
     def __init__(self, bitwidth: int, round_mode: str, quant_scheme: Union[str, QuantScheme], is_symmetric: bool,
-                 use_unsigned_symmetric: bool, use_strict_symmetric: bool):
+                 use_unsigned_symmetric: bool, use_strict_symmetric: bool, enabled: bool = False):
         self._bitwidth = bitwidth
         self._round_mode = round_mode
         if isinstance(quant_scheme, str):
@@ -72,6 +72,7 @@ def __init__(self, bitwidth: int, round_mode: str, quant_scheme: Union[str, Quan
         self._is_symmetric = is_symmetric
         self._use_unsigned_symmetric = use_unsigned_symmetric
         self._use_strict_symmetric = use_strict_symmetric
+        self._enabled = enabled
 
     @property
     def quant_scheme(self):
@@ -123,6 +124,16 @@ def use_strict_symmetric(self, use_strict_symmetric: bool):
         """ Use strict symmetric setter """
         self._use_strict_symmetric = use_strict_symmetric
 
+    @property
+    def enabled(self):
+        """ Enabled getter """
+        return self._enabled
+
+    @enabled.setter
+    def enabled(self, enabled: bool):
+        """ Enabled setter """
+        self._enabled = enabled
+
 
 class QcQuantizeWrapper(tf.keras.layers.Layer):
     """ Wrapper for simulating quantization noise """
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/tensor_quantizer.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/tensor_quantizer.py
index ed86ec296b..1927d76b1d 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/tensor_quantizer.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/tensor_quantizer.py
@@ -277,7 +277,9 @@ def encoding(self) -> Optional[libpymo.TfEncoding]:
             # pylint: disable = protected-access
             encodings.min = tf.keras.backend.get_value(self._encoding_min)
             encodings.max = tf.keras.backend.get_value(self._encoding_max)
-            encodings.delta, encodings.offset = calculate_delta_offset(encodings.min, encodings.max, self.bitwidth)
+            encodings.delta, encodings.offset = calculate_delta_offset(encodings.min, encodings.max,
+                                                                       self.bitwidth, self.is_symmetric,
+                                                                       self.use_strict_symmetric)
             encodings.bw = self.bitwidth
             return encodings
         return None
@@ -562,7 +564,8 @@ def encoding(self) -> Optional[List[libpymo.TfEncoding]]:
                 encodings.min = tf.keras.backend.get_value(self._encoding_min[i])
                 encodings.max = tf.keras.backend.get_value(self._encoding_max[i])
                 encodings.delta, encodings.offset = calculate_delta_offset(encodings.min, encodings.max,
-                                                                           self.bitwidth)
+                                                                           self.bitwidth, self.is_symmetric,
+                                                                           self.use_strict_symmetric[i])
                 encodings.bw = self.bitwidth
                 all_encodings[i] = encodings
             else:
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py
index 28155c7013..2332300deb 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py
@@ -373,7 +373,7 @@ def export(self, path, filename_prefix, custom_objects=None):
             try:
                 convert_h5_model_to_pb_model(f'{model_path}.h5', custom_objects=custom_objects)
             except ValueError:
-                _logger.error("Could not convert h5 to frozen pb."
+                _logger.error("Could not convert h5 to frozen pb. "
                               "Please call export() again with custom_objects defined.")
                 raise
         encodings_dict = self.get_encodings_dict()
@@ -396,7 +396,7 @@ def _compute_and_set_parameter_encodings(self, ops_with_invalid_encodings: List)
                     channel_slice = weight_tensor.reshape(*last_two_axes_combined_shape)
                     channel_slice = channel_slice.take(index, channel_slice.ndim - 1)
                 elif isinstance(quantizer_wrapper.original_layer, tf.keras.layers.Conv2DTranspose):
-                    if len(weight_tensor) == 4:
+                    if weight_tensor.ndim == 4:
                         channel_slice = weight_tensor.take(index, weight_tensor.ndim - 2)
                     else:
                         # For bias in Transpose layers
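The new `enabled` field on `QuantizerSettings` (added above) is what lets the param-quantizer initialization below fall back to a configured default instead of a hard-coded False. A quick illustration of the settings object itself; the argument values, and the `'tf_enhanced'` scheme string in particular, are illustrative assumptions rather than values taken from this PR:

```python
from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QuantizerSettings

settings = QuantizerSettings(bitwidth=8, round_mode='nearest', quant_scheme='tf_enhanced',
                             is_symmetric=True, use_unsigned_symmetric=False,
                             use_strict_symmetric=True, enabled=True)
print(settings.enabled)   # True -- new property; defaults to False when omitted
settings.enabled = False  # setter added in this change
```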
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim_config/quantsim_config.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim_config/quantsim_config.py
index c41e6e9f37..bc341ce328 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim_config/quantsim_config.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim_config/quantsim_config.py
@@ -219,17 +219,16 @@ def _initialize_param_quantizers(layer: layers.Layer, param_config_dict: TreeLik
         if weight.dtype in QUANT_ALLOWED_DTYPES:
             weight_name = weight.name.split(":")[0]
             param_type = "bias" if "bias" in weight_name else "weight"
-
+            # quant_settings is the global setting of the config file here. For params, if one of the settings is not
+            # specified, we will use the global setting (which may be specified or defaulted).
             if param_type in param_config_dict:
                 is_symmetric = param_config_dict[param_type][ConfigDictKeys.IS_SYMMETRIC].get(
-                    SETTING, False)
+                    SETTING, quant_settings.is_symmetric)
                 enabled = param_config_dict[param_type][ConfigDictKeys.IS_QUANTIZED].get(
-                    SETTING, False)
+                    SETTING, quant_settings.enabled)
             else:
-                is_symmetric = param_config_dict[ConfigDictKeys.IS_SYMMETRIC].get(
-                    SETTING, False)
-                enabled = param_config_dict[ConfigDictKeys.IS_QUANTIZED].get(
-                    SETTING, False)
+                is_symmetric = quant_settings.is_symmetric
+                enabled = quant_settings.enabled
 
             if per_channel_quantization_enabled and isinstance(layer,
                                                                keras_common_utils.per_channel_quantizeable_layers):
@@ -250,7 +249,6 @@ def _initialize_param_quantizers(layer: layers.Layer, param_config_dict: TreeLik
                                                  num_output_channels,
                                                  enabled))
             else:
-
                 param_quantizers.append(
                     ParamPerTensorQuantizer(layer,
                                             weight_name,
@@ -511,7 +509,9 @@ def _initialize_quantizers_by_layer(self, quant_scheme: Union[QuantScheme, str],
                 SETTING, False)
             param_quant_settings = QuantizerSettings(default_param_bw, rounding_mode, quant_scheme,
                                                      param_is_symmetric,
-                                                     use_unsigned_symmetric, use_strict_symmetric)
+                                                     use_unsigned_symmetric, use_strict_symmetric,
+                                                     enabled=param_config_dict[ConfigDictKeys.IS_QUANTIZED].get(
+                                                         SETTING, False))
 
             # Initialize Param Quantizers
             self._layer_to_quantizers_dict[layer][PARAM_QUANTIZERS] = \
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantizer_info.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantizer_info.py
index 5f010fe47f..697469ba2f 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantizer_info.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantizer_info.py
@@ -444,13 +444,14 @@ def get_encoding(self) -> libpymo.TfEncoding:
         Get encoding if valid else raise error
         :return: encoding
         """
-        def _create_encoding_object(min_val, max_val, bitwidth):
+        def _create_encoding_object(min_val, max_val, bitwidth, is_symmetric, use_strict_symmetric):
             """ Creates a libpymo encoding object """
             encoding = libpymo.TfEncoding()
             encoding.min = min_val
             encoding.max = max_val
             encoding.bw = bitwidth
-            encoding.delta, encoding.offset = calculate_delta_offset(min_val, max_val, bitwidth)
+            encoding.delta, encoding.offset = calculate_delta_offset(min_val, max_val, bitwidth,
+                                                                     is_symmetric, use_strict_symmetric)
             return encoding
 
         if self.is_encoding_valid():
@@ -462,10 +463,12 @@ def _create_encoding_object(min_val, max_val, bitwidth):
             if isinstance(encoding_min, np.ndarray):
                 encoding = []
                 for i, encoding_min_val in enumerate(encoding_min):
-                    _encoding = _create_encoding_object(encoding_min_val, encoding_max[i], bitwidth)
+                    _encoding = _create_encoding_object(encoding_min_val, encoding_max[i], bitwidth,
+                                                        self.use_symmetric_encoding, self.use_strict_symmetric)
                     encoding.append(_encoding)
             else:
-                encoding = _create_encoding_object(encoding_min, encoding_max, bitwidth)
+                encoding = _create_encoding_object(encoding_min, encoding_max, bitwidth, self.use_symmetric_encoding,
+                                                   self.use_strict_symmetric)
         else:
             raise AssertionError('Compute encoding or Set encoding must be invoked before')
 
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim.py
index 8bc5b1ad94..5acc955812 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quantsim.py
@@ -667,7 +667,8 @@ def update_encoding_dict_entry_int(encoding_dict: Dict, quant_op_name: str):
             min_val, max_val = self.read_min_max(quant_op_name, variable_dict)
             # if per channel quantization is enabled, then min and max are numpy arrays, and this function gates the array
             op_bitwidth = int(self._get_op_variable_value(quant_op, QuantizeOpIndices.bit_width))
-            delta, offset = calculate_delta_offset(min_val, max_val, op_bitwidth)
+            delta, offset = calculate_delta_offset(min_val, max_val, op_bitwidth,
+                                                   use_symmetric_encodings=False, use_strict_symmetric=False)
             # Min and max will be numpy arrays, so to make them JSON serializable
             if self.per_channel_quantization_enabled and isinstance(min_val, np.ndarray):
                 min_val = min_val.tolist()
diff --git a/TrainingExtensions/tensorflow/test/python/eager/test_adaround_keras.py b/TrainingExtensions/tensorflow/test/python/eager/test_adaround_keras.py
index 58372bb262..17416866ea 100644
--- a/TrainingExtensions/tensorflow/test/python/eager/test_adaround_keras.py
+++ b/TrainingExtensions/tensorflow/test/python/eager/test_adaround_keras.py
@@ -384,7 +384,6 @@ def test_apply_adaround_per_channel_conv2d_transpose():
     Adaround.apply_adaround(
         model, params, path='./', filename_prefix='conv2d_transpose',
         default_param_bw=8, default_quant_scheme=QuantScheme.post_training_tf,
-        default_is_symmetric=False,
        config_file='config.json',
     )
 
diff --git a/TrainingExtensions/tensorflow/test/python/eager/test_tensor_quantizer_keras.py b/TrainingExtensions/tensorflow/test/python/eager/test_tensor_quantizer_keras.py
index 0d50dcf071..aeee61a264 100644
--- a/TrainingExtensions/tensorflow/test/python/eager/test_tensor_quantizer_keras.py
+++ b/TrainingExtensions/tensorflow/test/python/eager/test_tensor_quantizer_keras.py
@@ -35,6 +35,7 @@
 # =============================================================================
 """ Unit tests for Keras tensor quantizer """
 import tensorflow as tf
+import numpy as np
 import aimet_common.libpymo as libpymo
 from aimet_tensorflow.keras.quant_sim.tensor_quantizer import ActivationTensorQuantizer, ParamPerTensorQuantizer
 from aimet_common.defs import QuantScheme
@@ -59,7 +60,7 @@ def test_set_encodings():
     assert quantizer._is_encoding_valid
     assert quant_encoding.min == 0.0
     assert quant_encoding.max == 30.0
-    assert quant_encoding.delta == 2.0
+    assert np.allclose(quant_encoding.delta, 2.142857142857143, rtol=0.01)
     assert quant_encoding.offset == 0
 
diff --git a/TrainingExtensions/tensorflow/test/python/non_eager/test_adaround_weight.py b/TrainingExtensions/tensorflow/test/python/non_eager/test_adaround_weight.py
index 1d4539f629..f19b8153c1 100644
--- a/TrainingExtensions/tensorflow/test/python/non_eager/test_adaround_weight.py
+++ b/TrainingExtensions/tensorflow/test/python/non_eager/test_adaround_weight.py
@@ -52,7 +52,7 @@
 from aimet_common.utils import AimetLogger
 from aimet_common.quantsim_config.json_config_importer import JsonConfigImporter
 from aimet_tensorflow.examples.test_models import keras_model, single_residual
-from aimet_tensorflow.adaround.adaround_weight import Adaround, AdaroundParameters
+from aimet_tensorflow.adaround.adaround_weight import Adaround, AdaroundParameters, tf_op_type_to_onnx_type_dict
 
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Test)
 tf.compat.v1.disable_eager_execution()
@@ -176,7 +176,9 @@ def test_get_is_symmetric_flag_for_op_param(self):
 
         try:
             configs = JsonConfigImporter.import_json_config_file(config_file='./config.json')
-            assert not Adaround._get_is_symmetric_flag_for_op_param(configs, conv_op.type, param_name="weight")
+            assert not Adaround.get_is_symmetric_flag_for_op_param(configs, conv_op.type,
+                                                                   param_name="weight",
+                                                                   framework_to_onnx_type_dict=tf_op_type_to_onnx_type_dict)
         finally:
             if os.path.isfile('./config.json'):
                 os.remove('./config.json')
@@ -202,7 +204,9 @@ def test_get_is_symmetric_flag_for_op_param(self):
 
         try:
             configs = JsonConfigImporter.import_json_config_file(config_file='./config.json')
-            assert Adaround._get_is_symmetric_flag_for_op_param(configs, conv_op.type, param_name="weight")
+            assert Adaround.get_is_symmetric_flag_for_op_param(configs, conv_op.type,
+                                                               param_name="weight",
+                                                               framework_to_onnx_type_dict=tf_op_type_to_onnx_type_dict)
         finally:
             if os.path.isfile('./config.json'):
                 os.remove('./config.json')
@@ -232,7 +236,9 @@ def test_get_is_symmetric_flag_for_op_param(self):
 
         try:
             configs = JsonConfigImporter.import_json_config_file(config_file='./config.json')
-            assert Adaround._get_is_symmetric_flag_for_op_param(configs, conv_op.type, param_name="weight")
+            assert Adaround.get_is_symmetric_flag_for_op_param(configs, conv_op.type,
+                                                               param_name="weight",
+                                                               framework_to_onnx_type_dict=tf_op_type_to_onnx_type_dict)
         finally:
             if os.path.isfile('./config.json'):
                 os.remove('./config.json')
@@ -273,9 +279,13 @@ def test_get_is_symmetric_flag_for_op_param(self):
 
         try:
             configs = JsonConfigImporter.import_json_config_file(config_file='./config.json')
-            assert Adaround._get_is_symmetric_flag_for_op_param(configs, conv_op.type, param_name="weight")
+            assert Adaround.get_is_symmetric_flag_for_op_param(configs, conv_op.type,
+                                                               param_name="weight",
+                                                               framework_to_onnx_type_dict=tf_op_type_to_onnx_type_dict)
             # For matmul op, is_symmetric should be False.
-            assert not Adaround._get_is_symmetric_flag_for_op_param(configs, matmul_op.type, param_name="weight")
+            assert not Adaround.get_is_symmetric_flag_for_op_param(configs, matmul_op.type,
+                                                                   param_name="weight",
+                                                                   framework_to_onnx_type_dict=tf_op_type_to_onnx_type_dict)
         finally:
             if os.path.isfile('./config.json'):
                 os.remove('./config.json')
diff --git a/TrainingExtensions/tensorflow/test/python/non_eager/test_per_channel_quantization.py b/TrainingExtensions/tensorflow/test/python/non_eager/test_per_channel_quantization.py
index 4a9594d18c..f85278d983 100644
--- a/TrainingExtensions/tensorflow/test/python/non_eager/test_per_channel_quantization.py
+++ b/TrainingExtensions/tensorflow/test/python/non_eager/test_per_channel_quantization.py
@@ -703,8 +703,9 @@ def create_encoding():
             _encoding.min = random.uniform(0, 1)
             _encoding.max = random.uniform(1, 3)
             _encoding.bw = 8
-            _encoding.delta, _encoding.offset = calculate_delta_offset(_encoding.min, _encoding.max,
-                                                                       8)
+            _encoding.delta, _encoding.offset = calculate_delta_offset(_encoding.min, _encoding.max, bitwidth=8,
+                                                                       use_symmetric_encodings=False,
+                                                                       use_strict_symmetric=False)
             return _encoding
 
         # Set the encodings for activation quantizers
@@ -1430,7 +1431,9 @@ def create_encoding():
             _encoding.min = random.uniform(0, 1)
             _encoding.max = random.uniform(1, 3)
             _encoding.bw = 8
-            _encoding.delta, _encoding.offset = calculate_delta_offset(_encoding.min, _encoding.max, 8)
+            _encoding.delta, _encoding.offset = calculate_delta_offset(_encoding.min, _encoding.max, bitwidth=8,
+                                                                       use_symmetric_encodings=False,
+                                                                       use_strict_symmetric=False)
             return _encoding
 
         # Set the encodings for activation quantizers
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
index ea2eca9899..5078ed14f9 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -1100,7 +1100,8 @@ def _create_encoding_dict(encoding: libpymo.TfEncoding, quantizer, propagate_enc
         ops.
         :return: Encoding Dictionary
         """
-        data_type, bitwidth = quantizer.data_type, quantizer.bitwidth
+        data_type, bitwidth, use_symmetric_encodings, use_strict_symmetric = \
+            quantizer.data_type, quantizer.bitwidth, quantizer.use_symmetric_encodings, quantizer.use_strict_symmetric
 
         if data_type == QuantizationDataType.float:
             enc_dict = {'bitwidth': bitwidth, 'dtype': "float"}
@@ -1115,7 +1116,8 @@
                                                             encoding.delta, encoding.offset
             is_symmetric = quantizer.use_symmetric_encodings
             if not isinstance(quantizer, StaticGridTensorQuantizer):
-                scale, offset = calculate_delta_offset(encoding_min, encoding_max, bitwidth)
+                scale, offset = calculate_delta_offset(encoding_min, encoding_max, bitwidth,
+                                                       use_symmetric_encodings, use_strict_symmetric)
             enc_dict = {'min': encoding_min, 'max': encoding_max, 'scale': scale, 'offset': int(offset),
                         'bitwidth': bw, 'is_symmetric': str(is_symmetric), 'dtype': "int"}
 
diff --git a/TrainingExtensions/torch/test/python/test_adaround_weight.py b/TrainingExtensions/torch/test/python/test_adaround_weight.py
index feee2b54d2..b90b88bd91 100644
--- a/TrainingExtensions/torch/test/python/test_adaround_weight.py
+++ b/TrainingExtensions/torch/test/python/test_adaround_weight.py
@@ -336,7 +336,9 @@ def test_adaround_conv_only_model_weight_binning(self):
         param_bit_width = 4
         delta, offset = calculate_delta_offset(float(torch.min(model.conv1.weight)),
                                                float(torch.max(model.conv1.weight)),
-                                               param_bit_width)
+                                               param_bit_width,
+                                               use_symmetric_encodings=False,
+                                               use_strict_symmetric=False)
         print(delta, offset)
 
         input_shape = (1, 3, 32, 32)
@@ -367,7 +369,9 @@ def test_unused_module_model(self):
         param_bit_width = 4
         delta, offset = calculate_delta_offset(float(torch.min(model.conv1.weight)),
                                                float(torch.max(model.conv1.weight)),
-                                               param_bit_width)
+                                               param_bit_width,
+                                               use_symmetric_encodings=False,
+                                               use_strict_symmetric=False)
         print(delta, offset)
 
         input_shape = (1, 3, 32, 32)
@@ -402,7 +406,9 @@ def test_out_of_sequence_module_model(self):
         param_bit_width = 4
         delta, offset = calculate_delta_offset(float(torch.min(model.conv1.weight)),
                                                float(torch.max(model.conv1.weight)),
-                                               param_bit_width)
+                                               param_bit_width,
+                                               use_symmetric_encodings=False,
+                                               use_strict_symmetric=False)
         print(delta, offset)
 
         input_shape = (1, 3, 32, 32)
@@ -437,7 +443,9 @@ def test_conv_transpose_2d_model(self):
         param_bit_width = 4
         delta, offset = calculate_delta_offset(float(torch.min(model.trans_conv1.weight)),
                                                float(torch.max(model.trans_conv1.weight)),
-                                               param_bit_width)
+                                               param_bit_width,
+                                               use_symmetric_encodings=False,
+                                               use_strict_symmetric=False)
         logger.info("For the ConvTranspose2d layer's weights, delta = %f, offset = %f", delta, offset)
 
         input_shape = (1, 3, 24, 24)
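End to end, the export path now produces different scales for strict-symmetric quantizers. A quick check against the updated common helper, using an illustrative symmetric [-1, 1] 8-bit encoding:

```python
from aimet_common.quantsim import calculate_delta_offset

# Non-strict symmetric: full 2**8 - 1 = 255 steps.
delta, offset = calculate_delta_offset(-1.0, 1.0, 8,
                                       use_symmetric_encodings=True, use_strict_symmetric=False)
print(delta)   # 2 / 255 ~= 0.007843

# Strict symmetric: one fewer step (254), i.e. a [-127, 127] grid.
delta, offset = calculate_delta_offset(-1.0, 1.0, 8,
                                       use_symmetric_encodings=True, use_strict_symmetric=True)
print(delta)   # 2 / 254 ~= 0.007874, offset = -127
```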