diff --git a/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py b/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py
index d345648656..e8f58490b8 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py
@@ -76,12 +76,11 @@ def forward_pass(model: torch.nn.Module, batch: torch.Tensor):
     :param batch: batch
     :return: Nothing
     """
-    model.eval()
     # first check if the model is on GPU or not
     if utils.is_model_on_gpu(model):
         batch = batch.cuda()
     try:
-        with torch.no_grad():
+        with utils.in_eval_mode(model), torch.no_grad():
             _ = model(batch)
     except StopForwardException:
         pass
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/channel_pruning/channel_pruner.py b/TrainingExtensions/torch/src/python/aimet_torch/channel_pruning/channel_pruner.py
index 12dd294ece..be8504d18f 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/channel_pruning/channel_pruner.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/channel_pruning/channel_pruner.py
@@ -51,6 +51,7 @@
 from aimet_torch.layer_database import LayerDatabase, Layer
 from aimet_torch.data_subsampler import DataSubSampler
 from aimet_torch.channel_pruning.weight_reconstruction import WeightReconstructor
+from aimet_torch import utils
 from aimet_torch.winnow.winnow import winnow_model
 
 
@@ -150,7 +151,7 @@ def sorting_hook(module, _inp, _out):
             handles.append(pair.layer.module.register_forward_hook(sorting_hook))
 
         # run one forward pass with hooks
-        with torch.no_grad():
+        with utils.in_eval_mode(model), torch.no_grad():
             _ = model(input_data)
 
         # remove hooks
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/data_subsampler.py b/TrainingExtensions/torch/src/python/aimet_torch/data_subsampler.py
index 30ddd389d4..5e57013c7d 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/data_subsampler.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/data_subsampler.py
@@ -167,8 +167,6 @@ def _forward_pass(model: torch.nn.Module, batch: Union[torch.Tensor, List, Tuple
         :param model: model
         :param batch: batch
         """
-        # keep the model in eval mode
-        model.eval()
 
         # get the model's device placement information
         device = utils.get_device(model)
@@ -179,7 +177,8 @@ def _forward_pass(model: torch.nn.Module, batch: Union[torch.Tensor, List, Tuple
             batch = [batch]
 
         try:
-            with torch.no_grad():
+            # keep the model in eval mode
+            with utils.in_eval_mode(model), torch.no_grad():
                 _ = model(*batch)
         except StopForwardException:
             pass
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
index a90b64d178..dffeb5fcd7 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
@@ -656,7 +656,8 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn
         cls._add_markers(model, module_name_map, module_marker_map, is_conditional)
         temp_file = os.path.join(working_dir, 'temp_onnx_model_with_markers.onnx')
         if is_conditional:
-            dummy_output = model(*dummy_input)
+            with aimet_torch.utils.in_eval_mode(model), torch.no_grad():
+                dummy_output = model(*dummy_input)
             scripted_model = torch.jit.script(model)
             torch.onnx.export(scripted_model, dummy_input, temp_file, example_outputs=dummy_output,
                               enable_onnx_checker=False, **onnx_export_args.kwargs)
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
index 3a1ae169bd..e6345d0b49 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -258,8 +258,7 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
             layer.set_mode(QcQuantizeOpMode.ANALYSIS)
 
         # Run forward iterations so we can collect statistics to compute the appropriate encodings
-        self.model.eval()
-        with torch.no_grad():
+        with utils.in_eval_mode(self.model), torch.no_grad():
             _ = forward_pass_callback(self.model, forward_pass_callback_args)
 
         # Get the computed per-layer encodings and log them
@@ -371,7 +370,7 @@ def export_torch_script_model_and_encodings(path: str, filename_prefix: str,
         :param dummy_input: Dummy input to the model. Used to parse model graph.
         :return: None
         """
-        with torch.no_grad():
+        with utils.in_eval_mode(original_model), torch.no_grad():
             trace = torch.jit.trace(original_model, dummy_input)
             ts_path = os.path.join(path, filename_prefix + '.torchscript.pth')
             trace.save(ts_path)
@@ -1144,8 +1143,7 @@ def _export_conditional(self, path: str, filename_prefix: str, dummy_input: Unio
         if self._is_conditional:
             self._add_inputs_hook(hooks)
 
-            self.model.eval()
-            with torch.no_grad():
+            with utils.in_eval_mode(self.model), torch.no_grad():
                 _ = forward_pass_callback(self.model, forward_pass_callback_args)
 
             # Any hooks that were hit during forward pass callback would have removed themselves. Remove the remaining
@@ -1219,8 +1217,9 @@ def run_modules_for_traced_custom_marker(self, module_list: List[torch.nn.Module
                 module = getattr(module, '_module_to_wrap')
             # Only perform init and trace if the given module is a leaf module, and we have not recorded it before
             if module in module_to_name_map and module_to_name_map[module] not in self._module_marker_map:
-                marker_layer = torch.jit.trace(CustomMarker(module, module_to_name_map[module]),
-                                               dummy_input)
+                with utils.in_eval_mode(module), torch.no_grad():
+                    marker_layer = torch.jit.trace(CustomMarker(module, module_to_name_map[module]),
+                                                   dummy_input)
                 self._module_marker_map[module_to_name_map[module]] = marker_layer
 
 
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
index 457cc284f6..034babfc5c 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -114,9 +114,6 @@ def _hook_to_collect_inp_out_data(_, inp, out):
 
         handle = self._module.register_forward_hook(_hook_to_collect_inp_out_data)
 
-        # keep the model in eval mode
-        self._model.eval()
-
         # get the model's device placement information
         device = get_device(self._model)
 
@@ -127,7 +124,7 @@ def _hook_to_collect_inp_out_data(_, inp, out):
             model_input = [model_input]
 
         try:
-            with torch.no_grad():
+            with in_eval_mode(self._model), torch.no_grad():
                 _ = self._model(*model_input)
 
         except StopForwardException:
@@ -234,7 +231,7 @@ def run_hook_for_layers(model: torch.nn.Module, input_shapes: Union[Tuple, List[
     device = get_device(model)
    dummy_tensors = create_rand_tensors_given_shapes(input_shapes)
     dummy_tensors = [tensor.to(device) for tensor in dummy_tensors]
-    with torch.no_grad():
+    with in_eval_mode(model), torch.no_grad():
         _ = model(*dummy_tensors)
 
     # --------------------------
@@ -271,7 +268,7 @@ def run_hook_for_layers_with_given_input(model: torch.nn.Module, input_tensor: U
     # ------------------------------------------------
     # Run forward pass to execute the hook functions
     # ------------------------------------------------
-    with torch.no_grad():
+    with in_eval_mode(model), torch.no_grad():
         if isinstance(input_tensor, (list, tuple)):
             _ = model(*input_tensor)
         else:
diff --git a/TrainingExtensions/torch/test/python/test_utils.py b/TrainingExtensions/torch/test/python/test_utils.py
index fdd26d0088..3dcd368831 100644
--- a/TrainingExtensions/torch/test/python/test_utils.py
+++ b/TrainingExtensions/torch/test/python/test_utils.py
@@ -244,6 +244,7 @@ def test_change_tensor_device(self):
 
     def _collect_inp_out_data(self, device):
         model = TinyModel().to(device=device)
+        model.eval()
         model_input = torch.randn(1, 3, 32, 32).to(device=device)
 
         module_data = utils.ModuleData(model, model.conv1)
@@ -287,6 +288,7 @@ def test_collect_inp_out_data_gpu(self):
 
     def _collect_inp_out_data_multi_input(self, device):
         model = MultiInput().to(device=device)
+        model.eval()
         inp_shape_1 = (1, 3, 32, 32)
         inp_shape_2 = (1, 3, 20, 20)
         model_input = utils.create_rand_tensors_given_shapes([inp_shape_1, inp_shape_2])
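
Reviewer note: the `in_eval_mode` helper that these call sites now rely on is not defined in this diff. As a rough sketch only (assumed behavior, not the actual aimet_torch.utils implementation), such a context manager could save and restore each submodule's `training` flag, so callers no longer leave the model permanently switched to eval mode:

import contextlib
import torch

@contextlib.contextmanager
def in_eval_mode(model: torch.nn.Module):
    """Temporarily put `model` in eval mode, restoring the original mode on exit."""
    # Record every submodule's training flag before eval() flips them all.
    saved_states = [(module, module.training) for module in model.modules()]
    try:
        model.eval()
        yield
    finally:
        # Restore flags per module (non-recursively), so submodules that were
        # already in eval mode are left in eval mode.
        for module, was_training in saved_states:
            module.training = was_training

Pairing this with torch.no_grad(), as the changed call sites do, keeps calibration and tracing forward passes from either mutating the caller-visible train/eval state or accumulating gradients.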