From 283ee26ce81512b5a6cbb6327c1b0026b3f0da6a Mon Sep 17 00:00:00 2001
From: Sai Chaitanya Gajula
Date: Wed, 3 Jul 2024 21:29:53 +0530
Subject: [PATCH] Modify quantizer type based on encoding provided when
 strict=False (#3141)

Signed-off-by: Sai Chaitanya Gajula
---
 .../src/python/aimet_torch/qc_quantize_op.py  | 28 ++++++--
 .../torch/src/python/aimet_torch/utils.py     | 40 ++++++++++-
 .../torch/test/python/test_quantsim_config.py | 69 +++++++++++++++++++
 3 files changed, 130 insertions(+), 7 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py b/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
index b4bff32ac6..2a2177d045 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
@@ -508,7 +508,7 @@ def import_param_encodings(self,
             ...
         }
         """
-        # pylint: disable=too-many-branches
+        # pylint: disable=too-many-branches, too-many-statements
         for param_name, quantizer in self.param_quantizers.items():
             if quantizer._is_encoding_frozen: # pylint: disable=protected-access
                 continue
@@ -523,11 +523,27 @@
             if quantizer.enabled:
                 # pylint: disable=protected-access
                 if isinstance(quantizer, StaticGridPerChannelQuantizer) and len(quantizer._cppOp) != len(encoding):
-                    raise ValueError(f"Invalid PerChannel encodings for {param_name}, the quantizer is a "
-                                     f"PerChannelQuantizer. To avoid this, disable per_channel_quantization")
-                if isinstance(quantizer, StaticGridPerTensorQuantizer) and len(encoding) != 1:
-                    raise ValueError(f"Invalid PerTensor encodings for {param_name}, the quantizer is a "
-                                     f"PerTensorQuantizer. To avoid this, enable per_channel_quantization")
+                    assert len(encoding) == 1, (f'Number of Per Channel encodings provided ({len(encoding)}) is '
+                                                f'not the same as the number of channels ({len(quantizer._cppOp)})')
+                    if strict:
+                        raise ValueError(f"Invalid PerChannel encodings for {param_name}, the quantizer is a "
+                                         f"PerChannelQuantizer. To avoid this, disable per_channel_quantization")
+                    # Modifying PerChannel quantizer to PerTensor
+                    _logger.warning('Replacing PerChannel Quantizer with PerTensor based on encoding provided')
+                    quantizer = utils.get_per_tensor_quantizer_from_per_channel(quantizer)
+                    self.param_quantizers[param_name] = quantizer
+                elif isinstance(quantizer, StaticGridPerTensorQuantizer) and len(encoding) != 1:
+                    if strict:
+                        raise ValueError(f"Invalid PerTensor encodings for {param_name}, the quantizer is a "
+                                         f"PerTensorQuantizer. To avoid this, enable per_channel_quantization")
+                    # Modifying PerTensor quantizer to PerChannel
+                    _logger.warning('Replacing PerTensor Quantizer with PerChannel based on encoding provided')
+                    quantizer = utils.get_per_channel_quantizer_from_per_tensor(quantizer, self.get_original_module())
+                    assert len(quantizer._cppOp) == len(encoding), (f'Number of per channel encodings ({len(encoding)})'
+                                                                    f' should match the number of output '
+                                                                    f'channels ({len(quantizer._cppOp)})')
+                    self.param_quantizers[param_name] = quantizer
+
                 if encoding[0]['dtype'] == 'int':
                     # Validate and set symmetric flags before computing partial encodings
                     validate_is_symmetric_flag(quantizer, encoding[0], strict)
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
index bde2a0b0b9..572d83f9c4 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -59,7 +59,7 @@
 from aimet_common.utils import profile as _profile
 import aimet_common.libpymo as libpymo
 from aimet_torch import elementwise_ops
-from aimet_torch.tensor_quantizer import TensorQuantizer
+from aimet_torch.tensor_quantizer import TensorQuantizer, StaticGridPerChannelQuantizer, StaticGridPerTensorQuantizer

 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)

@@ -1215,6 +1215,44 @@ def _validate_is_symmetric_flag(quantizer: TensorQuantizer, encoding_dict: Dict,
         raise AttributeError("Provided encoding doesn't have 'is_symmetric' flag")


+def get_per_channel_quantizer_from_per_tensor(quantizer: TensorQuantizer, original_module: torch.nn.Module):
+    """ Get PerChannel Quantizer with same settings as given PerTensor Quantizer """
+    channel_axis = 0
+    if isinstance(original_module, (torch.nn.ConvTranspose1d,
+                                    torch.nn.ConvTranspose2d,
+                                    torch.nn.ConvTranspose3d)):
+        if len(original_module.weight.shape) > 1:
+            channel_axis = 1
+
+    num_channels = original_module.weight.shape[channel_axis]
+    use_strict_symmetric = quantizer.use_strict_symmetric
+    use_unsigned_symmetric = quantizer.use_unsigned_symmetric
+    quantizer = StaticGridPerChannelQuantizer(quantizer.bitwidth, quantizer.round_mode,
+                                              quantizer.quant_scheme,
+                                              quantizer.use_symmetric_encodings,
+                                              num_channels=num_channels,
+                                              enabled_by_default=quantizer.enabled,
+                                              ch_axis=channel_axis,
+                                              data_type=quantizer.data_type)
+    quantizer.use_strict_symmetric = use_strict_symmetric
+    quantizer.use_unsigned_symmetric = use_unsigned_symmetric
+    return quantizer
+
+
+def get_per_tensor_quantizer_from_per_channel(quantizer: TensorQuantizer):
+    """ Get PerTensor Quantizer with same settings as given PerChannel Quantizer """
+    use_strict_symmetric = quantizer.use_strict_symmetric
+    use_unsigned_symmetric = quantizer.use_unsigned_symmetric
+    quantizer = StaticGridPerTensorQuantizer(quantizer.bitwidth, quantizer.round_mode,
+                                             quantizer.quant_scheme,
+                                             quantizer.use_symmetric_encodings,
+                                             enabled_by_default=quantizer.enabled,
+                                             data_type=quantizer.data_type)
+    quantizer.use_strict_symmetric = use_strict_symmetric
+    quantizer.use_unsigned_symmetric = use_unsigned_symmetric
+    return quantizer
+
+
 def validate_is_symmetric_flag(quantizer: TensorQuantizer, encoding_dict: Dict, strict: bool = True):
     """ Validate 'is_symmetric' flag from encoding_dict with quantizer.use_symmetric_encodings
     and set the later accordingly
diff --git a/TrainingExtensions/torch/test/python/test_quantsim_config.py b/TrainingExtensions/torch/test/python/test_quantsim_config.py
index a296d3b8ae..54f775f3bd 100644
--- a/TrainingExtensions/torch/test/python/test_quantsim_config.py
+++ b/TrainingExtensions/torch/test/python/test_quantsim_config.py
@@ -2224,3 +2224,72 @@ def test_load_and_freeze_with_partial_encodings(self, sample_enc):
             assert sim.model.conv1.param_quantizers['weight'].use_symmetric_encodings
         else:
             assert not sim.model.conv1.param_quantizers['weight'].use_symmetric_encodings
+
+    def test_load_encodings_to_allow_modifying_quantizer_type(self):
+        """ Test load encodings API to allow modifying quantizer type based on encoding """
+        model = test_models.TinyModelWithNoMathInvariantOps()
+        dummy_input = torch.randn([1, 3, 24, 24])
+
+        sample_act_enc = {"min": -4, "max": 4, "bitwidth": 8, "dtype": "int", "is_symmetric": "False"}
+        sample_param_enc = {"min": -4, "max": 4, "bitwidth": 8, "dtype": "int", "is_symmetric": "True"}
+
+        encodings = {"activation_encodings": {"conv1": {"input": {"0": sample_act_enc}},
+                                              "mul1": {"output": {"0": sample_act_enc}}},
+                     "param_encodings": {}}
+
+        pcq_config = {
+            "defaults":{
+                "ops":{
+                    "is_output_quantized": "True"
+                },
+                "params":{
+                    "is_quantized": "True",
+                    "is_symmetric": "True"
+                },
+                "strict_symmetric": "False",
+                "per_channel_quantization": "True"
+            },
+            "params": {},
+            "op_type": {},
+            "model_input":{
+                "is_input_quantized": "True"
+            },
+            "supergroups": [],
+            "model_output": {}
+        }
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            pcq_config_file = os.path.join(tmp_dir, 'pcq_quantsim_config.json')
+            with open(pcq_config_file, 'w') as f:
+                json.dump(pcq_config, f)
+
+            for config_file in [None, pcq_config_file]:
+                if config_file is None:
+                    # Per-tensor to per-channel case: initial quantizer is per-tensor, but the encodings are per-channel
+                    encodings['param_encodings']['conv1.weight'] = [sample_param_enc for i in range(16)]
+                else:
+                    # Per-channel to per-tensor case: initial quantizer is per-channel, but the encodings are per-tensor
+                    encodings['param_encodings']['conv1.weight'] = [sample_param_enc]
+
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    with open(os.path.join(tmp_dir, 'replace_quantizer_with_enc.json'), 'w') as f:
+                        json.dump(encodings, f)
+
+                    sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf, config_file=config_file)
+
+                    # Checking Quantizer type before loading encodings to Quantsim
+                    if config_file is None:
+                        assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerTensorQuantizer)
+                    else:
+                        assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerChannelQuantizer)
+
+                    sim.load_and_freeze_encodings(os.path.join(tmp_dir, 'replace_quantizer_with_enc.json'),
+                                                  ignore_when_quantizer_disabled=True)
+
+                    sim.compute_encodings(lambda m, _: m(dummy_input), None)
+
+                    # Checking whether the quantizer is modified to the required type after loading encodings
+                    if config_file is None:
+                        assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerChannelQuantizer)
+                    else:
+                        assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerTensorQuantizer)
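
Usage sketch (not part of the patch): the snippet below mirrors the per-channel-to-per-tensor leg of the new test from the caller's side. The toy ConvModel, the file names, and the encoding values are assumptions chosen for illustration; only the QuantizationSimModel construction, load_and_freeze_encodings call, and the quantizer-type assertions follow the test above.

# Illustrative sketch only -- model, paths and encoding values are assumptions,
# not taken from the patch. The flow mirrors test_load_encodings_to_allow_modifying_quantizer_type.
import json
import os
import tempfile

import torch
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.tensor_quantizer import (StaticGridPerChannelQuantizer,
                                          StaticGridPerTensorQuantizer)


class ConvModel(torch.nn.Module):
    """ Minimal model with a single conv layer named 'conv1' (hypothetical). """
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)


model = ConvModel().eval()
dummy_input = torch.randn(1, 3, 24, 24)

# Per-channel quantsim config, same shape as pcq_config in the test above
pcq_config = {
    "defaults": {"ops": {"is_output_quantized": "True"},
                 "params": {"is_quantized": "True", "is_symmetric": "True"},
                 "strict_symmetric": "False",
                 "per_channel_quantization": "True"},
    "params": {}, "op_type": {}, "model_input": {"is_input_quantized": "True"},
    "supergroups": [], "model_output": {}
}

# Encodings file carrying a single per-tensor encoding for conv1.weight
param_enc = {"min": -4, "max": 4, "bitwidth": 8, "dtype": "int", "is_symmetric": "True"}
encodings = {"activation_encodings": {}, "param_encodings": {"conv1.weight": [param_enc]}}

with tempfile.TemporaryDirectory() as tmp_dir:
    config_file = os.path.join(tmp_dir, "pcq_config.json")
    enc_file = os.path.join(tmp_dir, "per_tensor_encodings.json")
    with open(config_file, "w") as f:
        json.dump(pcq_config, f)
    with open(enc_file, "w") as f:
        json.dump(encodings, f)

    sim = QuantizationSimModel(model, dummy_input,
                               quant_scheme=QuantScheme.post_training_tf,
                               config_file=config_file)
    # Per-channel config => the weight quantizer starts out per-channel
    assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerChannelQuantizer)

    # Non-strict load: with this patch, the single per-tensor encoding swaps the
    # quantizer type instead of raising a ValueError
    sim.load_and_freeze_encodings(enc_file, ignore_when_quantizer_disabled=True)
    sim.compute_encodings(lambda m, _: m(dummy_input), None)
    assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerTensorQuantizer)

The opposite direction (per-tensor quantizer, per-channel encodings with one entry per output channel) follows the same pattern with config_file=None, as the first loop iteration of the test shows.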