Commit cbcbca0

Enable output quantizers for Cast Ops where input is int or bool and output is float (#2801) (#2806)
Signed-off-by: Alankar Mahajan <quic_alanmaha@quicinc.com>
Mahajan, Alankar authored and GitHub Enterprise committed Jan 8, 2024
1 parent 8074167 commit cbcbca0
Showing 2 changed files with 55 additions and 0 deletions.
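
The two additions below are presumably wired together inside QuantizationSimModel: first record the input/output dtypes of every Cast module, then enable the output quantizers of the modules that cast int/bool to float. The call site is not part of this diff; a minimal sketch, where model, dummy_input and sim are illustrative names:

    # Hedged sketch (not part of this diff): how the two additions presumably
    # compose. `model`, `dummy_input` and `sim` are illustrative names; the
    # actual call site inside QuantizationSimModel is not shown in this commit.
    from aimet_torch.utils import get_inout_tensors_dtypes_for_cast_modules

    inout_dtypes = get_inout_tensors_dtypes_for_cast_modules(model, dummy_input)
    # pylint: disable=protected-access
    sim._enable_output_quantizers_for_specific_cast_ops(inout_dtypes)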
22 changes: 22 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -1799,6 +1799,28 @@ def _validate_torchquantizer(quant_sim_model):
        _validate_torchquantizer(quant_sim_model)
        OnnxSaver._export_model_to_onnx(quant_sim_model, dummy_input, model_path, is_conditional, onnx_export_args)  # pylint: disable=protected-access

    def _enable_output_quantizers_for_specific_cast_ops(self, inout_tensors_dtypes: Dict[torch.nn.Module, Tuple[torch.dtype, torch.dtype]]):
        """
        Enable output quantizers for Cast ops where the data type of the input tensor is int/bool
        and the data type of the output tensor is float.
        """
        # pylint: disable=protected-access
        model_prefix = self.connected_graph._model_name + '.'
        torch_int_dtypes = {torch.int8, torch.int16, torch.int32, torch.int64, torch.bool, torch.uint8}
        torch_float_dtypes = {torch.float16, torch.float32, torch.float64}

        for module, inout_dtypes in inout_tensors_dtypes.items():
            input_tensor_dtype, output_tensor_dtype = inout_dtypes
            module_name = self.connected_graph._module_to_name[module].split(model_prefix)[-1]

            if input_tensor_dtype != output_tensor_dtype and input_tensor_dtype in torch_int_dtypes and output_tensor_dtype in torch_float_dtypes:
                logger.info("Enabling output quantizer for module %s", module_name)
                wrapped_module = getattr(self.model, module_name)
                for output_quantizer in wrapped_module.output_quantizers:
                    output_quantizer.enabled = True


def save_checkpoint(quant_sim_model: QuantizationSimModel, file_path: str):
"""
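For reference, the check in the new method fires only on genuine int/bool-to-float casts. A minimal self-contained illustration (the should_enable helper is hypothetical, mirroring the condition above):

    import torch

    torch_int_dtypes = {torch.int8, torch.int16, torch.int32, torch.int64, torch.bool, torch.uint8}
    torch_float_dtypes = {torch.float16, torch.float32, torch.float64}

    def should_enable(in_dtype, out_dtype):
        # Mirrors the new method's condition: enable only on int/bool -> float casts
        return in_dtype != out_dtype and in_dtype in torch_int_dtypes and out_dtype in torch_float_dtypes

    assert should_enable(torch.bool, torch.float32)          # bool -> float: enable
    assert not should_enable(torch.float16, torch.float32)   # float -> float: skip
    assert not should_enable(torch.int32, torch.int64)       # int -> int: skip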
33 changes: 33 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -53,6 +53,7 @@
from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_QUANT_SCHEME_TO_PYMO
from aimet_common.utils import AimetLogger, Handle, log_with_error_and_assert_if_false
import aimet_common.libpymo as libpymo
from aimet_torch import elementwise_ops


logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
@@ -1011,3 +1012,35 @@ def fn(_, inputs):
    handle.remove()

    return cached_data


def get_inout_tensors_dtypes_for_cast_modules(model: torch.nn.Module, input_tensor: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Dict:
    """
    Get the data types of the input and output tensors of Cast modules in a PyTorch model.
    :param model: PyTorch model
    :param input_tensor: Input tensor used to run a forward pass of the model.
                         A tuple of tensors should be passed if the model has multiple inputs
    :return: Map of module -> (data type of input tensor, data type of output tensor)
    """
    inout_dtypes_map = {}

    def record_dtypes(module, inputs, outputs):
        if isinstance(module, elementwise_ops.Cast):
            input_dtype = None

            if isinstance(inputs, (list, tuple)):
                input_dtype = inputs[0].dtype

            elif isinstance(inputs, torch.Tensor):
                input_dtype = inputs.dtype

            else:
                raise ValueError("Unexpected input type: {}".format(type(inputs)))

            inout_dtypes_map[module] = (input_dtype, outputs.dtype)

    run_hook_for_layers_with_given_input(model, input_tensor, record_dtypes)
    return inout_dtypes_map
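
A minimal usage sketch for the new utility, assuming elementwise_ops.Cast takes the target dtype in its constructor (the toy model is hypothetical):

    # Hedged usage sketch: a toy model whose Cast module converts bool -> float32,
    # so it should appear in the returned map. ToyModel is hypothetical.
    import torch
    from aimet_torch import elementwise_ops
    from aimet_torch.utils import get_inout_tensors_dtypes_for_cast_modules

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.cast = elementwise_ops.Cast(torch.float32)

        def forward(self, x):
            return self.cast(x > 0)   # bool tensor cast to float32

    dtypes_map = get_inout_tensors_dtypes_for_cast_modules(ToyModel(), torch.randn(4))
    # Expected: {<Cast module>: (torch.bool, torch.float32)}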
