diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f6f3f3693b1b1..beb5f70b1da739 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated using `'val_loss'` to set the `ModelCheckpoint` monitor ([#6012](https://github.com/PyTorchLightning/pytorch-lightning/pull/6012)) +- Deprecated `.get_model()` in favor of the explicit `.lightning_module` property ([#6035](https://github.com/PyTorchLightning/pytorch-lightning/pull/6035)) + + ### Removed - Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321)) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index d18abde814aab4..06c0323466cee1 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -101,11 +101,11 @@ def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx): return optimizer def _toggle_model(self): - model_ref = self._trainer.get_model() + model_ref = self._trainer.lightning_module model_ref.toggle_optimizer(self, self._optimizer_idx) def _untoggle_model(self): - model_ref = self._trainer.get_model() + model_ref = self._trainer.lightning_module model_ref.untoggle_optimizer(self) @contextmanager @@ -129,7 +129,7 @@ def toggle_model(self, sync_grad: bool = True): def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs): trainer = self._trainer optimizer = self._optimizer - model = trainer.get_model() + model = trainer.lightning_module with trainer.profiler.profile(profiler_name): trainer.accelerator_backend.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index a11394734f97b8..f292f5a78bc65b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -14,9 
+14,10 @@ from abc import ABC from copy import deepcopy -from typing import Callable, List +from typing import List from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.lightning import LightningModule class TrainerCallbackHookMixin(ABC): @@ -24,7 +25,7 @@ class TrainerCallbackHookMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class callbacks: List[Callback] = [] - get_model: Callable + lightning_module: LightningModule def on_before_accelerator_backend_setup(self, model): """Called in the beginning of fit and test""" @@ -39,7 +40,7 @@ def setup(self, model, stage: str): def teardown(self, stage: str): """Called at the end of fit and test""" for callback in self.callbacks: - callback.teardown(self, self.get_model(), stage) + callback.teardown(self, self.lightning_module, stage) def on_init_start(self): """Called when the trainer initialization begins, model has not yet been set.""" @@ -54,72 +55,72 @@ def on_init_end(self): def on_fit_start(self): """Called when the trainer initialization begins, model has not yet been set.""" for callback in self.callbacks: - callback.on_fit_start(self, self.get_model()) + callback.on_fit_start(self, self.lightning_module) def on_fit_end(self): """Called when the trainer initialization begins, model has not yet been set.""" for callback in self.callbacks: - callback.on_fit_end(self, self.get_model()) + callback.on_fit_end(self, self.lightning_module) def on_sanity_check_start(self): """Called when the validation sanity check starts.""" for callback in self.callbacks: - callback.on_sanity_check_start(self, self.get_model()) + callback.on_sanity_check_start(self, self.lightning_module) def on_sanity_check_end(self): """Called when the validation sanity check ends.""" for callback in self.callbacks: - callback.on_sanity_check_end(self, self.get_model()) + callback.on_sanity_check_end(self, self.lightning_module) def 
on_train_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: - callback.on_train_epoch_start(self, self.get_model()) + callback.on_train_epoch_start(self, self.lightning_module) def on_train_epoch_end(self, outputs): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_train_epoch_end(self, self.get_model(), outputs) + callback.on_train_epoch_end(self, self.lightning_module, outputs) def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: - callback.on_validation_epoch_start(self, self.get_model()) + callback.on_validation_epoch_start(self, self.lightning_module) def on_validation_epoch_end(self): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_validation_epoch_end(self, self.get_model()) + callback.on_validation_epoch_end(self, self.lightning_module) def on_test_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: - callback.on_test_epoch_start(self, self.get_model()) + callback.on_test_epoch_start(self, self.lightning_module) def on_test_epoch_end(self): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_test_epoch_end(self, self.get_model()) + callback.on_test_epoch_end(self, self.lightning_module) def on_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: - callback.on_epoch_start(self, self.get_model()) + callback.on_epoch_start(self, self.lightning_module) def on_epoch_end(self): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_epoch_end(self, self.get_model()) + callback.on_epoch_end(self, self.lightning_module) def on_train_start(self): """Called when the train begins.""" for callback in self.callbacks: - callback.on_train_start(self, self.get_model()) + callback.on_train_start(self, self.lightning_module) def on_train_end(self): """Called when the train ends.""" for callback in 
self.callbacks: - callback.on_train_end(self, self.get_model()) + callback.on_train_end(self, self.lightning_module) def on_pretrain_routine_start(self, model): """Called when the train begins.""" @@ -134,74 +135,74 @@ def on_pretrain_routine_end(self, model): def on_batch_start(self): """Called when the training batch begins.""" for callback in self.callbacks: - callback.on_batch_start(self, self.get_model()) + callback.on_batch_start(self, self.lightning_module) def on_batch_end(self): """Called when the training batch ends.""" for callback in self.callbacks: - callback.on_batch_end(self, self.get_model()) + callback.on_batch_end(self, self.lightning_module) def on_train_batch_start(self, batch, batch_idx, dataloader_idx): """Called when the training batch begins.""" for callback in self.callbacks: - callback.on_train_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx) + callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): """Called when the training batch ends.""" for callback in self.callbacks: - callback.on_train_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx) + callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): """Called when the validation batch begins.""" for callback in self.callbacks: - callback.on_validation_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx) + callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): """Called when the validation batch ends.""" for callback in self.callbacks: - callback.on_validation_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx) + callback.on_validation_batch_end(self, 
self.lightning_module, outputs, batch, batch_idx, dataloader_idx) def on_test_batch_start(self, batch, batch_idx, dataloader_idx): """Called when the test batch begins.""" for callback in self.callbacks: - callback.on_test_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx) + callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): """Called when the test batch ends.""" for callback in self.callbacks: - callback.on_test_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx) + callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) def on_validation_start(self): """Called when the validation loop begins.""" for callback in self.callbacks: - callback.on_validation_start(self, self.get_model()) + callback.on_validation_start(self, self.lightning_module) def on_validation_end(self): """Called when the validation loop ends.""" for callback in self.callbacks: - callback.on_validation_end(self, self.get_model()) + callback.on_validation_end(self, self.lightning_module) def on_test_start(self): """Called when the test begins.""" for callback in self.callbacks: - callback.on_test_start(self, self.get_model()) + callback.on_test_start(self, self.lightning_module) def on_test_end(self): """Called when the test ends.""" for callback in self.callbacks: - callback.on_test_end(self, self.get_model()) + callback.on_test_end(self, self.lightning_module) def on_keyboard_interrupt(self): """Called when the training is interrupted by KeyboardInterrupt.""" for callback in self.callbacks: - callback.on_keyboard_interrupt(self, self.get_model()) + callback.on_keyboard_interrupt(self, self.lightning_module) def on_save_checkpoint(self): """Called when saving a model checkpoint.""" callback_states = {} for callback in self.callbacks: callback_class = type(callback) - state = 
callback.on_save_checkpoint(self, self.get_model()) + state = callback.on_save_checkpoint(self, self.lightning_module) if state: callback_states[callback_class] = state return callback_states @@ -224,11 +225,11 @@ def on_after_backward(self): Called after loss.backward() and before optimizers do anything. """ for callback in self.callbacks: - callback.on_after_backward(self, self.get_model()) + callback.on_after_backward(self, self.lightning_module) def on_before_zero_grad(self, optimizer): """ Called after optimizer.step() and before optimizer.zero_grad(). """ for callback in self.callbacks: - callback.on_before_zero_grad(self, self.get_model(), optimizer) + callback.on_before_zero_grad(self, self.lightning_module, optimizer) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 2fca7b410f3e1c..4f5238a570ede4 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -94,7 +94,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) # acquire the model - model = self.trainer.get_model() + model = self.trainer.lightning_module # restore model and datamodule state self.restore_model_state(model, checkpoint) @@ -214,7 +214,7 @@ def hpc_save(self, folderpath: str, logger): filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt') # give model a chance to do something on hpc_save - model = self.trainer.get_model() + model = self.trainer.lightning_module checkpoint = self.dump_checkpoint() model.on_hpc_save(checkpoint) @@ -307,7 +307,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: checkpoint['amp_scaling_state'] = amp.state_dict() # add the hyper_parameters and state_dict from the model - model = self.trainer.get_model() + model = self.trainer.lightning_module # dump the 
module_arguments and state_dict from the model checkpoint['state_dict'] = model.state_dict() @@ -339,7 +339,7 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool): checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) # acquire the model - model = self.trainer.get_model() + model = self.trainer.lightning_module # restore model and datamodule state self.restore_model_state(model, checkpoint) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 84f88fb9840f25..c435204107775b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -235,7 +235,7 @@ def info(self): """ This function provides necessary parameters to properly configure HookResultStore obj """ - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module return { "batch_idx": self.trainer.batch_idx, "fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name, @@ -252,7 +252,7 @@ def reset_model(self): """ This function is used to reset model state at the end of the capture """ - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref._results = Result() model_ref._current_hook_fx_name = None model_ref._current_fx_name = '' @@ -263,7 +263,7 @@ def cache_result(self) -> None: and store the result object """ with self.trainer.profiler.profile("cache_result"): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module # extract hook results hook_result = model_ref._results diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 89c72883fc497b..8ebec3238e2765 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py 
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -82,7 +82,7 @@ def cached_results(self) -> Union[EpochResultStore, None]: def get_metrics(self, key: str) -> Dict: metrics_holder = getattr(self, f"_{key}", None) - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module metrics_holder.convert( self.trainer._device_type == DeviceType.TPU, model_ref.device if model_ref is not None else model_ref, @@ -103,7 +103,7 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders): # Todo: required argument `testing` is not used - model = self.trainer.get_model() + model = self.trainer.lightning_module # set dataloader_idx only if multiple ones model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None # track batch_size @@ -263,7 +263,7 @@ def track_metrics_deprecated(self, deprecated_eval_results): def evaluation_epoch_end(self, testing): # Todo: required argument `testing` is not used # reset dataloader idx - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref._current_dataloader_idx = None # setting `has_batch_loop_finished` to True @@ -408,7 +408,7 @@ def log_train_epoch_end_metrics( # epoch_output[optimizer_idx][training_step_idx][tbptt_index] # remember that not using truncated backprop is equivalent with truncated back prop of len(1) - model = self.trainer.get_model() + model = self.trainer.lightning_module epoch_callback_metrics = {} diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 060601049f9b7e..4a0c565d78be0b 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -25,7 +25,7 @@ def __init__(self, trainer): self.trainer = trainer def copy_trainer_model_properties(self, model): - 
ref_model = self._get_reference_model(model) + ref_model = self.trainer.lightning_module or model automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization self.trainer.train_loop.automatic_optimization = automatic_optimization @@ -37,11 +37,3 @@ def copy_trainer_model_properties(self, model): m.use_amp = self.trainer.amp_backend is not None m.testing = self.trainer.testing m.precision = self.trainer.precision - - def get_model(self): - return self._get_reference_model(self.trainer.model) - - def _get_reference_model(self, model): - if self.trainer.accelerator_backend and self.trainer.accelerator_backend.lightning_module: - return self.trainer.accelerator_backend.lightning_module - return model diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 71b557bf75a2c0..e1eecf26ed70e5 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn @@ -130,3 +131,17 @@ def use_single_gpu(self, val: bool) -> None: ) if val: self.accelerator_connector._device_type = DeviceType.GPU + + +class DeprecatedModelAttributes: + """Deprecated model accessors mixed into the Trainer; kept only for backward compatibility.""" + + lightning_module: LightningModule + + def get_model(self) -> LightningModule: + """Deprecated: emit a ``DeprecationWarning`` and return ``self.lightning_module``.""" + rank_zero_warn( + "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`" + " and will be removed in v1.4.", DeprecationWarning + ) + return self.lightning_module diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 053c4ea5ae3603..284baff3e2a630 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -47,7 +47,7 @@ def on_trainer_init(self): def get_evaluation_dataloaders(self, max_batches): # select dataloaders - model = self.trainer.get_model() + model = self.trainer.lightning_module # select dataloaders if self.trainer.testing: @@ -80,14 +80,14 @@ def on_evaluation_start(self, *args, **kwargs): self.trainer.call_hook('on_validation_start', *args, **kwargs) def on_evaluation_model_eval(self, *_, **__): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module if self.trainer.testing: model_ref.on_test_model_eval() else: model_ref.on_validation_model_eval() def on_evaluation_model_train(self, *_, **__): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module if self.trainer.testing: model_ref.on_test_model_train() else: @@ -100,7 +100,7 @@ def on_evaluation_end(self, *args, **kwargs): self.trainer.call_hook('on_validation_end', *args, **kwargs) def reload_evaluation_dataloaders(self): - model = self.trainer.get_model() + model 
= self.trainer.lightning_module if self.trainer.testing: self.trainer.reset_test_dataloader(model) else: @@ -151,7 +151,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx): # configure args args = self._build_args(batch, batch_idx, dataloader_idx) - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref._results = Result() if self.testing: @@ -199,7 +199,7 @@ def log_epoch_metrics_on_evaluation_end(self): return eval_loop_results def __run_eval_epoch_end(self, num_dataloaders): - model = self.trainer.get_model() + model = self.trainer.lightning_module # with a single dataloader don't pass an array outputs = self.outputs @@ -270,7 +270,7 @@ def __auto_reduce_result_objs(self, outputs): return eval_results def on_predict_epoch_end(self): - self.trainer._progress_bar_callback.on_test_end(self.trainer, self.trainer.get_model()) + self.trainer._progress_bar_callback.on_test_end(self.trainer, self.trainer.lightning_module) results = self._predictions diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 420911bb2b0642..7e3d6cc78320cd 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -13,24 +13,22 @@ # limitations under the License. 
import inspect -from abc import ABC, abstractmethod +from abc import ABC from pytorch_lightning.core.lightning import LightningModule class TrainerModelHooksMixin(ABC): + lightning_module: LightningModule + def is_function_implemented(self, f_name, model=None): if model is None: - model = self.get_model() + model = self.lightning_module f_op = getattr(model, f_name, None) return callable(f_op) def has_arg(self, f_name, arg_name): - model = self.get_model() + model = self.lightning_module f_op = getattr(model, f_name, None) return arg_name in inspect.signature(f_op).parameters - - @abstractmethod - def get_model(self) -> LightningModule: - """Warning: this is just empty shell for code implemented in other class.""" diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 43016b8943c81c..4fecbbaf053489 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -28,7 +28,7 @@ def on_trainer_init(self): def get_predict_dataloaders(self, max_batches): # select dataloaders - model = self.trainer.get_model() + model = self.trainer.lightning_module self.trainer.reset_predict_dataloader(model) dataloaders = self.trainer.predict_dataloaders if max_batches is None: @@ -40,7 +40,7 @@ def should_skip_predict(self, dataloaders, max_batches): return dataloaders is None or not sum(max_batches) def on_predict_model_eval(self, *_, **__): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref.on_predict_model_eval() def setup(self, model, max_batches, dataloaders): @@ -55,7 +55,7 @@ def setup(self, model, max_batches, dataloaders): self.num_dataloaders = self._get_num_dataloaders(dataloaders) self._predictions = [[] for _ in range(self.num_dataloaders)] - self.trainer._progress_bar_callback.on_predict_start(self.trainer, self.trainer.get_model()) + self.trainer._progress_bar_callback.on_predict_start(self.trainer, self.trainer.lightning_module) def 
_get_num_dataloaders(self, dataloaders): # case where user does: @@ -71,7 +71,7 @@ def predict(self, batch, batch_idx, dataloader_idx): if self.num_dataloaders: args.append(dataloader_idx) - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref._current_fx_name = "predict" predictions = self.trainer.accelerator_backend.predict(args) @@ -82,7 +82,7 @@ def predict(self, batch, batch_idx, dataloader_idx): return def on_predict_epoch_end(self): - self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.get_model()) + self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.lightning_module) results = self._predictions diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index ec735e9dccf715..47aad2710394dc 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -245,7 +245,7 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]: @property def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. """ - ref_model = self.get_model() + ref_model = self.lightning_module ref_model = cast(LightningModule, ref_model) standard_metrics = ref_model.get_progress_bar_dict() @@ -270,7 +270,7 @@ def disable_validation(self) -> bool: @property def enable_validation(self) -> bool: """ Check if we should run validation during training. 
""" - model_ref = self.get_model() + model_ref = self.lightning_module val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0 return val_loop_enabled @@ -352,11 +352,6 @@ def model(self, model: torch.nn.Module) -> None: """ self.accelerator.model = model - def get_model(self) -> LightningModule: - # TODO: rename this to lightning_module (see training type plugin) - # backward compatible - return self.lightning_module - @property def lightning_optimizers(self) -> List[LightningOptimizer]: if self._lightning_optimizers is None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b9a0cc92a151b6..e1c2bcbbbce71b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -45,7 +45,7 @@ from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin -from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes +from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedModelAttributes from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin @@ -80,6 +80,7 @@ class Trainer( TrainerTrainingTricksMixin, TrainerDataLoadingMixin, DeprecatedDistDeviceAttributes, + DeprecatedModelAttributes, ): @overwrite_by_env_vars @@ -582,7 +583,7 @@ def _pre_training_routine(self): # Pre-train # -------------------------- # on pretrain routine start - ref_model = self.get_model() + ref_model = self.lightning_module self.on_pretrain_routine_start(ref_model) if self.is_function_implemented("on_pretrain_routine_start"): @@ -610,15 +611,15 @@ def run_train(self): if not 
self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() - self.run_sanity_check(self.get_model()) + self.run_sanity_check(self.lightning_module) # set stage for logging - self._set_running_stage(RunningStage.TRAINING, self.get_model()) + self._set_running_stage(RunningStage.TRAINING, self.lightning_module) self.checkpoint_connector.has_trained = False # enable train mode - model = self.get_model() + model = self.lightning_module model.train() torch.set_grad_enabled(True) @@ -677,7 +678,9 @@ def run_train(self): def run_evaluation(self, max_batches=None, on_epoch=False): # used to know if we are logging for val, test + reset cached results - self._set_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.get_model()) + self._set_running_stage( + RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module + ) self.logger_connector.reset() # bookkeeping @@ -693,7 +696,7 @@ def run_evaluation(self, max_batches=None, on_epoch=False): # enable eval mode + no grads self.evaluation_loop.on_evaluation_model_eval() # ref model - model = self.get_model() + model = self.lightning_module model.zero_grad() torch.set_grad_enabled(False) @@ -810,7 +813,7 @@ def run_predict(self): return [] # ref model - model = self.get_model() + model = self.lightning_module # enable eval mode + no grads self.predict_loop.on_predict_model_eval() @@ -904,7 +907,7 @@ def test( # -------------------- self.verbose_test = verbose - self._set_running_stage(RunningStage.TESTING, model or self.get_model()) + self._set_running_stage(RunningStage.TESTING, model or self.lightning_module) # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: @@ -913,7 +916,7 @@ def test( ) # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 
'test') + self.data_connector.attach_datamodule(model or self.lightning_module, datamodule, 'test') if model is not None: results = self.__test_given_model(model, test_dataloaders) @@ -921,11 +924,11 @@ def test( results = self.__test_using_best_weights(ckpt_path, test_dataloaders) self.teardown('test') - self._set_running_stage(None, model or self.get_model()) + self._set_running_stage(None, model or self.lightning_module) return results def __test_using_best_weights(self, ckpt_path, test_dataloaders): - model = self.get_model() + model = self.lightning_module # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: @@ -961,7 +964,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # teardown if self.is_function_implemented('teardown'): - model_ref = self.get_model() + model_ref = self.lightning_module model_ref.teardown('test') return results @@ -1011,7 +1014,7 @@ def predict( # -------------------- # If you supply a datamodule you can't supply dataloaders - model = model or self.get_model() + model = model or self.lightning_module self._set_running_stage(RunningStage.PREDICTING, model) @@ -1072,7 +1075,7 @@ def _reset_result_and_set_hook_fx_name(self, hook_name): # on_before_zero_grad is called within training_step if "batch_start" in hook_name or "on_before_zero_grad" in hook_name: return True - model_ref = self.get_model() + model_ref = self.lightning_module if model_ref is not None: # used to track current hook name called model_ref._results = Result() @@ -1080,7 +1083,7 @@ def _reset_result_and_set_hook_fx_name(self, hook_name): return False def _cache_logged_metrics(self): - model_ref = self.get_model() + model_ref = self.lightning_module if model_ref is not None: # capture logging for this hook self.logger_connector.cache_logged_metrics() @@ -1099,7 +1102,7 @@ def call_hook(self, hook_name, *args, **kwargs): # next call hook in lightningModule output = 
None - model_ref = self.get_model() + model_ref = self.lightning_module if is_overridden(hook_name, model_ref): hook_fx = getattr(model_ref, hook_name) output = hook_fx(*args, **kwargs) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 32b175fbaae974..9a2a9da1636bfd 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -158,7 +158,7 @@ def check_checkpoint_callback(self, should_update, is_last=False): if is_last and any(cb.save_last for cb in callbacks): rank_zero_info("Saving latest checkpoint...") - model = self.trainer.get_model() + model = self.trainer.lightning_module for cb in callbacks: cb.on_validation_end(self.trainer, model) @@ -167,7 +167,7 @@ def check_early_stopping_callback(self, should_update): # TODO bake this logic into the EarlyStopping callback if should_update and self.trainer.checkpoint_connector.has_trained: callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] - model = self.trainer.get_model() + model = self.trainer.lightning_module for cb in callbacks: cb.on_validation_end(self.trainer, model) @@ -177,7 +177,7 @@ def on_train_epoch_start(self, epoch): # update training progress in trainer self.trainer.current_epoch = epoch - model = self.trainer.get_model() + model = self.trainer.lightning_module # reset train dataloader if epoch != 0 and self.trainer.reload_dataloaders_every_epoch: @@ -189,7 +189,7 @@ def on_train_epoch_start(self, epoch): self.trainer.train_dataloader.sampler.set_epoch(epoch) # changing gradient according accumulation_scheduler - self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.get_model()) + self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module) # stores accumulated grad fractions per batch self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) @@ -229,8 +229,8 @@ def 
track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end hook_overridden = ( - is_overridden("training_epoch_end", model=self.trainer.get_model()) - or is_overridden("on_train_epoch_end", model=self.trainer.get_model()) + is_overridden("training_epoch_end", model=self.trainer.lightning_module) + or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module) ) # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end @@ -281,7 +281,7 @@ def _check_training_step_output(self, training_step_output): def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("model_forward"): args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) @@ -372,7 +372,7 @@ def _process_training_step_output(self, training_step_output, split_batch): return training_step_output_for_epoch_end, training_step_output def _process_training_step_output_1_0(self, training_step_output, split_batch): - result = self.trainer.get_model()._results + result = self.trainer.lightning_module._results loss = None hiddens = None @@ -407,7 +407,7 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch): return training_step_output_for_epoch_end, training_step_output def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) using_native_amp = self.trainer.amp_backend == AMPType.NATIVE @@ -452,7 +452,7 @@ def _track_gradient_norm(self): grad_norm_dict = {} if (self.trainer.global_step + 1) % 
self.trainer.log_every_n_steps == 0: if float(self.trainer.track_grad_norm) > 0: - model = self.trainer.get_model() + model = self.trainer.lightning_module grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) return grad_norm_dict @@ -465,7 +465,7 @@ def process_hiddens(self, opt_closure_result): def tbptt_split_batch(self, batch): splits = [batch] if self.trainer.truncated_bptt_steps is not None: - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("tbptt_split_batch"): splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) return splits @@ -758,7 +758,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, if len(self.trainer.optimizers) > 1: # revert back to previous state - self.trainer.get_model().untoggle_optimizer(opt_idx) + self.trainer.lightning_module.untoggle_optimizer(opt_idx) return result @@ -895,7 +895,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. if self.automatic_optimization and len(self.trainer.optimizers) > 1: - model = self.trainer.get_model() + model = self.trainer.lightning_module model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 7665f96426df15..6b388f7137ce12 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from abc import ABC, abstractmethod +from abc import ABC import torch from torch import Tensor @@ -29,21 +29,18 @@ class TrainerTrainingTricksMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class default_root_dir: str + lightning_module: LightningModule progress_bar_callback:... on_gpu: bool - @abstractmethod - def get_model(self) -> LightningModule: - """Warning: this is just empty shell for code implemented in other class.""" - def print_nan_gradients(self) -> None: - model = self.get_model() + model = self.lightning_module for param in model.parameters(): if (param.grad is not None) and torch.isnan(param.grad.float()).any(): log.info(param, param.grad) def detect_nan_tensors(self, loss: Tensor) -> None: - model = self.get_model() + model = self.lightning_module # check if loss is nan if not torch.isfinite(loss).all(): diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 3a52b6dd2e8faa..c29cffc42607bd 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -268,7 +268,7 @@ def _adjust_batch_size( The new batch size for the next trial and a bool that signals whether the new value is different than the previous batch size. 
""" - model = trainer.get_model() + model = trainer.lightning_module batch_size = lightning_getattr(model, batch_arg_name) new_size = value if value is not None else int(batch_size * factor) if desc: diff --git a/tests/callbacks/test_swa.py b/tests/callbacks/test_swa.py index 337773994b2907..ea8e368e395425 100644 --- a/tests/callbacks/test_swa.py +++ b/tests/callbacks/test_swa.py @@ -111,7 +111,7 @@ def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_proc trainer.fit(model) # check the model is the expected - assert trainer.get_model() == model + assert trainer.lightning_module == model @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0") diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index f65b13a661f398..c13e5b9dfadfd1 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -30,6 +30,13 @@ from tests.helpers import BoringModel +def test_v1_4_0_deprecated_trainer_methods(): + with pytest.deprecated_call(match='will be removed in v1.4'): + trainer = Trainer() + _ = trainer.get_model() + assert trainer.get_model() == trainer.lightning_module + + def test_v1_4_0_deprecated_imports(): _soft_unimport_module('pytorch_lightning.utilities.argparse_utils') with pytest.deprecated_call(match='will be removed in v1.4'): diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 28e3e65a875869..d28ab6177f21c7 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -456,7 +456,7 @@ def on_train_start(self): dp_model.module.module.running_stage = RunningStage.EVALUATING dataloader = self.train_dataloader() - tpipes.run_prediction(self.trainer.get_model(), dataloader) + tpipes.run_prediction(self.trainer.lightning_module, dataloader) self.on_train_start_called = True # new model diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py 
index 73e900072b7e0c..7a5dd3f685ed50 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -24,15 +24,15 @@ class TrainerGetModel(BoringModel): def on_fit_start(self): - assert self == self.trainer.get_model() + assert self == self.trainer.lightning_module def on_fit_end(self): - assert self == self.trainer.get_model() + assert self == self.trainer.lightning_module def test_get_model(tmpdir): """ - Tests that :meth:`trainer.get_model` extracts the model correctly + Tests that `trainer.lightning_module` extracts the model correctly """ model = TrainerGetModel() @@ -50,7 +50,7 @@ def test_get_model(tmpdir): @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_get_model_ddp_cpu(tmpdir): """ - Tests that :meth:`trainer.get_model` extracts the model correctly when using ddp on cpu + Tests that `trainer.lightning_module` extracts the model correctly when using ddp on cpu """ model = TrainerGetModel() @@ -70,7 +70,7 @@ def test_get_model_ddp_cpu(tmpdir): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_get_model_gpu(tmpdir): """ - Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + Tests that `trainer.lightning_module` extracts the model correctly when using GPU """ model = TrainerGetModel() @@ -91,7 +91,7 @@ def test_get_model_gpu(tmpdir): @DDPLauncher.run("--accelerator [accelerator]", max_epochs=["1"], accelerator=["ddp", "ddp_spawn"]) def test_get_model_ddp_gpu(tmpdir, args=None): """ - Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + ddp accelerators + Tests that `trainer.lightning_module` extracts the model correctly when using GPU + ddp accelerators """ model = TrainerGetModel() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ce2eeb43e01149..167930425dab1f 100644 --- a/tests/trainer/test_trainer.py +++ 
b/tests/trainer/test_trainer.py @@ -474,7 +474,7 @@ def mock_save_function(filepath, *args): trainer.current_epoch = i trainer.global_step = i trainer.logger_connector.callback_metrics = {"checkpoint_on": torch.tensor(loss)} - checkpoint_callback.on_validation_end(trainer, trainer.get_model()) + checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) file_lists = set(os.listdir(tmpdir)) @@ -1421,11 +1421,11 @@ def setup(self, model, stage): trainer.fit(model) assert trainer.stage == "fit" - assert trainer.get_model().stage == "fit" + assert trainer.lightning_module.stage == "fit" trainer.test(ckpt_path=None) assert trainer.stage == "test" - assert trainer.get_model().stage == "test" + assert trainer.lightning_module.stage == "test" @pytest.mark.parametrize(