diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
index c637adfb6d..e398825c73 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -1799,6 +1799,28 @@ def _validate_torchquantizer(quant_sim_model):
         _validate_torchquantizer(quant_sim_model)
         OnnxSaver._export_model_to_onnx(quant_sim_model, dummy_input, model_path, is_conditional, onnx_export_args) # pylint: disable=protected-access
 
+    def _enable_output_quantizers_for_specific_cast_ops(self, inout_tensors_dtypes: Dict[torch.nn.Module, Tuple[torch.dtype, torch.dtype]]):
+        """
+        Enable output quantizer for Cast Ops where datatype of input tensor is int/bool
+        and data type of output tensor is float.
+        """
+        # pylint: disable=protected-access
+        model_prefix = self.connected_graph._model_name + '.'
+        torch_int_dtypes = {torch.int8, torch.int16, torch.int32, torch.int64, torch.bool, torch.uint8}
+        torch_float_dtypes = {torch.float16, torch.float32, torch.float64}
+
+        for module, inout_dtypes in inout_tensors_dtypes.items():
+            input_tensor_dtype = inout_dtypes[0]
+            output_tensor_dtype = inout_dtypes[1]
+            # pylint: disable=protected-access
+            module_name = self.connected_graph._module_to_name[module].split(model_prefix)[-1]
+
+            if input_tensor_dtype != output_tensor_dtype and input_tensor_dtype in torch_int_dtypes and output_tensor_dtype in torch_float_dtypes:
+                logger.info("Enabling output quantizer for module %s", module_name)
+                wrapped_module = getattr(self.model, module_name)
+                for output_quantizer in wrapped_module.output_quantizers:
+                    setattr(output_quantizer, 'enabled', True)
+
 
 def save_checkpoint(quant_sim_model: QuantizationSimModel, file_path: str):
     """
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
index c5889e0d1d..8f96b111d2 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -53,6 +53,7 @@
 from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_QUANT_SCHEME_TO_PYMO
 from aimet_common.utils import AimetLogger, Handle, log_with_error_and_assert_if_false
 import aimet_common.libpymo as libpymo
+from aimet_torch import elementwise_ops
 
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
 
@@ -1011,3 +1012,35 @@ def fn(_, inputs):
     handle.remove()
 
     return cached_data
+
+
+def get_inout_tensors_dtypes_for_cast_modules(model: torch.nn.Module, input_tensor: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Dict:
+    """
+    Get the datatype of input and output tensor of Cast modules in a Pytorch Model.
+
+    :param model: Pytorch Model
+    :param input_tensor: Input tensor to run forward pass for the model.
+                         A tuple of tensors should be passed if model has multiple inputs
+    :return: map of module -> (data type of input tensor, data type of output tensor)
+    """
+    inout_dtypes_map = {}
+
+    def record_dtypes(module, inputs, outputs):
+
+        # pylint: disable=protected-access
+        if isinstance(module, elementwise_ops.Cast):
+            input_dtype = None
+
+            if isinstance(inputs, (list, tuple)):
+                input_dtype = inputs[0].dtype
+
+            elif isinstance(inputs, torch.Tensor):
+                input_dtype = inputs.dtype
+
+            else:
+                raise ValueError
+
+            inout_dtypes_map[module] = (input_dtype, outputs.dtype)
+
+    run_hook_for_layers_with_given_input(model, input_tensor, record_dtypes)
+    return inout_dtypes_map
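
Reviewer note: below is a minimal usage sketch of the new utils helper, not part of the patch. MyCastModel, the input shape, and the Cast(torch.float32) constructor argument are illustrative assumptions rather than anything this diff defines.

import torch

from aimet_torch import elementwise_ops
from aimet_torch.utils import get_inout_tensors_dtypes_for_cast_modules


class MyCastModel(torch.nn.Module):
    """Hypothetical model whose Cast module converts int64 -> float32."""

    def __init__(self):
        super().__init__()
        self.cast = elementwise_ops.Cast(torch.float32)

    def forward(self, x):
        return self.cast(x)


model = MyCastModel().eval()
dummy_input = torch.randint(0, 10, (1, 8))  # torch.randint yields an int64 tensor feeding the Cast

# Forward hooks record (input dtype, output dtype) for each Cast module
inout_dtypes = get_inout_tensors_dtypes_for_cast_modules(model, dummy_input)
# Expected result shape: {model.cast: (torch.int64, torch.float32)}

A map of exactly this shape is what _enable_output_quantizers_for_specific_cast_ops consumes on the quantsim side: any Cast whose input dtype is int/bool and whose output dtype is float gets its output quantizers enabled, so the float values produced by the cast are quantized like any other float activation. The call site that wires the two together is not shown in these hunks.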