From efeca8ad94bbb6d14ba27ee54988966885ec2ae8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Feb 2021 18:06:35 +0100 Subject: [PATCH 1/9] rename get_model -> lightning_module --- pytorch_lightning/core/optimizer.py | 6 +- pytorch_lightning/trainer/callback_hook.py | 62 +++++++++---------- .../connectors/checkpoint_connector.py | 8 +-- .../logger_connector/epoch_result_store.py | 6 +- .../logger_connector/logger_connector.py | 8 +-- pytorch_lightning/trainer/evaluation_loop.py | 14 ++--- pytorch_lightning/trainer/model_hooks.py | 4 +- pytorch_lightning/trainer/predict_loop.py | 10 +-- pytorch_lightning/trainer/properties.py | 4 +- pytorch_lightning/trainer/trainer.py | 32 +++++----- pytorch_lightning/trainer/training_loop.py | 26 ++++---- pytorch_lightning/trainer/training_tricks.py | 4 +- pytorch_lightning/tuner/batch_size_scaling.py | 2 +- tests/callbacks/test_swa.py | 2 +- tests/models/test_restore.py | 2 +- tests/trainer/properties/test_get_model.py | 4 +- tests/trainer/test_trainer.py | 6 +- 17 files changed, 100 insertions(+), 100 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index d18abde814aab..06c0323466cee 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -101,11 +101,11 @@ def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx): return optimizer def _toggle_model(self): - model_ref = self._trainer.get_model() + model_ref = self._trainer.lightning_module model_ref.toggle_optimizer(self, self._optimizer_idx) def _untoggle_model(self): - model_ref = self._trainer.get_model() + model_ref = self._trainer.lightning_module model_ref.untoggle_optimizer(self) @contextmanager @@ -129,7 +129,7 @@ def toggle_model(self, sync_grad: bool = True): def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs): trainer = self._trainer optimizer = self._optimizer - model = trainer.get_model() + model = trainer.lightning_module with trainer.profiler.profile(profiler_name): trainer.accelerator_backend.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index a11394734f97b..f6bec5e518e3e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -39,7 +39,7 @@ def setup(self, model, stage: str): def teardown(self, stage: str): """Called at the end of fit and test""" for callback in self.callbacks: - callback.teardown(self, self.get_model(), stage) + callback.teardown(self, self.lightning_module, stage) def on_init_start(self): """Called when the trainer initialization begins, model has not yet been set.""" @@ -54,72 +54,72 @@ def on_init_end(self): def on_fit_start(self): """Called when the trainer initialization begins, model has not yet been set.""" for callback in self.callbacks: - callback.on_fit_start(self, self.get_model()) + callback.on_fit_start(self, self.lightning_module) def on_fit_end(self): """Called when the trainer initialization begins, model has not yet been set.""" for callback in self.callbacks: - callback.on_fit_end(self, self.get_model()) + callback.on_fit_end(self, self.lightning_module) def on_sanity_check_start(self): """Called when the validation sanity check starts.""" for callback in self.callbacks: - callback.on_sanity_check_start(self, self.get_model()) + callback.on_sanity_check_start(self, self.lightning_module) def 
on_sanity_check_end(self): """Called when the validation sanity check ends.""" for callback in self.callbacks: - callback.on_sanity_check_end(self, self.get_model()) + callback.on_sanity_check_end(self, self.lightning_module) def on_train_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: - callback.on_train_epoch_start(self, self.get_model()) + callback.on_train_epoch_start(self, self.lightning_module) def on_train_epoch_end(self, outputs): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_train_epoch_end(self, self.get_model(), outputs) + callback.on_train_epoch_end(self, self.lightning_module, outputs) def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: - callback.on_validation_epoch_start(self, self.get_model()) + callback.on_validation_epoch_start(self, self.lightning_module) def on_validation_epoch_end(self): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_validation_epoch_end(self, self.get_model()) + callback.on_validation_epoch_end(self, self.lightning_module) def on_test_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: - callback.on_test_epoch_start(self, self.get_model()) + callback.on_test_epoch_start(self, self.lightning_module) def on_test_epoch_end(self): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_test_epoch_end(self, self.get_model()) + callback.on_test_epoch_end(self, self.lightning_module) def on_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: - callback.on_epoch_start(self, self.get_model()) + callback.on_epoch_start(self, self.lightning_module) def on_epoch_end(self): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_epoch_end(self, self.get_model()) + callback.on_epoch_end(self, self.lightning_module) def on_train_start(self): """Called when the train begins.""" for callback in self.callbacks: - callback.on_train_start(self, self.get_model()) + callback.on_train_start(self, self.lightning_module) def on_train_end(self): """Called when the train ends.""" for callback in self.callbacks: - callback.on_train_end(self, self.get_model()) + callback.on_train_end(self, self.lightning_module) def on_pretrain_routine_start(self, model): """Called when the train begins.""" @@ -134,74 +134,74 @@ def on_pretrain_routine_end(self, model): def on_batch_start(self): """Called when the training batch begins.""" for callback in self.callbacks: - callback.on_batch_start(self, self.get_model()) + callback.on_batch_start(self, self.lightning_module) def on_batch_end(self): """Called when the training batch ends.""" for callback in self.callbacks: - callback.on_batch_end(self, self.get_model()) + callback.on_batch_end(self, self.lightning_module) def on_train_batch_start(self, batch, batch_idx, dataloader_idx): """Called when the training batch begins.""" for callback in self.callbacks: - callback.on_train_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx) + callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): """Called when the training batch ends.""" for callback in self.callbacks: - callback.on_train_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx) + callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 
dataloader_idx) def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): """Called when the validation batch begins.""" for callback in self.callbacks: - callback.on_validation_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx) + callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): """Called when the validation batch ends.""" for callback in self.callbacks: - callback.on_validation_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx) + callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) def on_test_batch_start(self, batch, batch_idx, dataloader_idx): """Called when the test batch begins.""" for callback in self.callbacks: - callback.on_test_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx) + callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): """Called when the test batch ends.""" for callback in self.callbacks: - callback.on_test_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx) + callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) def on_validation_start(self): """Called when the validation loop begins.""" for callback in self.callbacks: - callback.on_validation_start(self, self.get_model()) + callback.on_validation_start(self, self.lightning_module) def on_validation_end(self): """Called when the validation loop ends.""" for callback in self.callbacks: - callback.on_validation_end(self, self.get_model()) + callback.on_validation_end(self, self.lightning_module) def on_test_start(self): """Called when the test begins.""" for callback in self.callbacks: - callback.on_test_start(self, self.get_model()) + callback.on_test_start(self, self.lightning_module) def on_test_end(self): """Called when the test ends.""" for callback in self.callbacks: - callback.on_test_end(self, self.get_model()) + callback.on_test_end(self, self.lightning_module) def on_keyboard_interrupt(self): """Called when the training is interrupted by KeyboardInterrupt.""" for callback in self.callbacks: - callback.on_keyboard_interrupt(self, self.get_model()) + callback.on_keyboard_interrupt(self, self.lightning_module) def on_save_checkpoint(self): """Called when saving a model checkpoint.""" callback_states = {} for callback in self.callbacks: callback_class = type(callback) - state = callback.on_save_checkpoint(self, self.get_model()) + state = callback.on_save_checkpoint(self, self.lightning_module) if state: callback_states[callback_class] = state return callback_states @@ -224,11 +224,11 @@ def on_after_backward(self): Called after loss.backward() and before optimizers do anything. """ for callback in self.callbacks: - callback.on_after_backward(self, self.get_model()) + callback.on_after_backward(self, self.lightning_module) def on_before_zero_grad(self, optimizer): """ Called after optimizer.step() and before optimizer.zero_grad(). 
""" for callback in self.callbacks: - callback.on_before_zero_grad(self, self.get_model(), optimizer) + callback.on_before_zero_grad(self, self.lightning_module, optimizer) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 2fca7b410f3e1..4f5238a570ede 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -94,7 +94,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) # acquire the model - model = self.trainer.get_model() + model = self.trainer.lightning_module # restore model and datamodule state self.restore_model_state(model, checkpoint) @@ -214,7 +214,7 @@ def hpc_save(self, folderpath: str, logger): filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt') # give model a chance to do something on hpc_save - model = self.trainer.get_model() + model = self.trainer.lightning_module checkpoint = self.dump_checkpoint() model.on_hpc_save(checkpoint) @@ -307,7 +307,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: checkpoint['amp_scaling_state'] = amp.state_dict() # add the hyper_parameters and state_dict from the model - model = self.trainer.get_model() + model = self.trainer.lightning_module # dump the module_arguments and state_dict from the model checkpoint['state_dict'] = model.state_dict() @@ -339,7 +339,7 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool): checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) # acquire the model - model = self.trainer.get_model() + model = self.trainer.lightning_module # restore model and datamodule state self.restore_model_state(model, checkpoint) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 84f88fb9840f2..c435204107775 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -235,7 +235,7 @@ def info(self): """ This function provides necessary parameters to properly configure HookResultStore obj """ - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module return { "batch_idx": self.trainer.batch_idx, "fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name, @@ -252,7 +252,7 @@ def reset_model(self): """ This function is used to reset model state at the end of the capture """ - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref._results = Result() model_ref._current_hook_fx_name = None model_ref._current_fx_name = '' @@ -263,7 +263,7 @@ def cache_result(self) -> None: and store the result object """ with self.trainer.profiler.profile("cache_result"): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module # extract hook results hook_result = model_ref._results diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 595a5e84bf630..cd1bb5e8b787e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -82,7 +82,7 @@ def cached_results(self) -> 
Union[EpochResultStore, None]: def get_metrics(self, key: str) -> Dict: metrics_holder = getattr(self, f"_{key}", None) - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module metrics_holder.convert( self.trainer._device_type == DeviceType.TPU, model_ref.device if model_ref is not None else model_ref, @@ -103,7 +103,7 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders): # Todo: required argument `testing` is not used - model = self.trainer.get_model() + model = self.trainer.lightning_module # set dataloader_idx only if multiple ones model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None # track batch_size @@ -263,7 +263,7 @@ def track_metrics_deprecated(self, deprecated_eval_results): def evaluation_epoch_end(self, testing): # Todo: required argument `testing` is not used # reset dataloader idx - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref._current_dataloader_idx = None # setting `has_batch_loop_finished` to True @@ -408,7 +408,7 @@ def log_train_epoch_end_metrics( # epoch_output[optimizer_idx][training_step_idx][tbptt_index] # remember that not using truncated backprop is equivalent with truncated back prop of len(1) - model = self.trainer.get_model() + model = self.trainer.lightning_module epoch_callback_metrics = {} diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index fe3fc62ff1189..fc8b4721293bf 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -47,7 +47,7 @@ def on_trainer_init(self): def get_evaluation_dataloaders(self, max_batches): # select dataloaders - model = self.trainer.get_model() + model = self.trainer.lightning_module # select dataloaders if self.trainer.testing: @@ -80,14 +80,14 @@ def on_evaluation_start(self, *args, **kwargs): self.trainer.call_hook('on_validation_start', *args, **kwargs) def on_evaluation_model_eval(self, *_, **__): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module if self.trainer.testing: model_ref.on_test_model_eval() else: model_ref.on_validation_model_eval() def on_evaluation_model_train(self, *_, **__): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module if self.trainer.testing: model_ref.on_test_model_train() else: @@ -100,7 +100,7 @@ def on_evaluation_end(self, *args, **kwargs): self.trainer.call_hook('on_validation_end', *args, **kwargs) def reload_evaluation_dataloaders(self): - model = self.trainer.get_model() + model = self.trainer.lightning_module if self.trainer.testing: self.trainer.reset_test_dataloader(model) else: @@ -151,7 +151,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx): # configure args args = self._build_args(batch, batch_idx, dataloader_idx) - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref._results = Result() if self.testing: @@ -199,7 +199,7 @@ def log_epoch_metrics_on_evaluation_end(self): return eval_loop_results def __run_eval_epoch_end(self, num_dataloaders): - model = self.trainer.get_model() + model = self.trainer.lightning_module # with a single dataloader don't pass an array outputs = self.outputs @@ -270,7 +270,7 @@ def __auto_reduce_result_objs(self, outputs): return eval_results def on_predict_epoch_end(self): - 
self.trainer._progress_bar_callback.on_test_end(self.trainer, self.trainer.get_model()) + self.trainer._progress_bar_callback.on_test_end(self.trainer, self.trainer.lightning_module) results = self._predictions diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 420911bb2b064..97fe6b1482ce0 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -22,12 +22,12 @@ class TrainerModelHooksMixin(ABC): def is_function_implemented(self, f_name, model=None): if model is None: - model = self.get_model() + model = self.lightning_module f_op = getattr(model, f_name, None) return callable(f_op) def has_arg(self, f_name, arg_name): - model = self.get_model() + model = self.lightning_module f_op = getattr(model, f_name, None) return arg_name in inspect.signature(f_op).parameters diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 43016b8943c81..4fecbbaf05348 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -28,7 +28,7 @@ def on_trainer_init(self): def get_predict_dataloaders(self, max_batches): # select dataloaders - model = self.trainer.get_model() + model = self.trainer.lightning_module self.trainer.reset_predict_dataloader(model) dataloaders = self.trainer.predict_dataloaders if max_batches is None: @@ -40,7 +40,7 @@ def should_skip_predict(self, dataloaders, max_batches): return dataloaders is None or not sum(max_batches) def on_predict_model_eval(self, *_, **__): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref.on_predict_model_eval() def setup(self, model, max_batches, dataloaders): @@ -55,7 +55,7 @@ def setup(self, model, max_batches, dataloaders): self.num_dataloaders = self._get_num_dataloaders(dataloaders) self._predictions = [[] for _ in range(self.num_dataloaders)] - self.trainer._progress_bar_callback.on_predict_start(self.trainer, self.trainer.get_model()) + self.trainer._progress_bar_callback.on_predict_start(self.trainer, self.trainer.lightning_module) def _get_num_dataloaders(self, dataloaders): # case where user does: @@ -71,7 +71,7 @@ def predict(self, batch, batch_idx, dataloader_idx): if self.num_dataloaders: args.append(dataloader_idx) - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module model_ref._current_fx_name = "predict" predictions = self.trainer.accelerator_backend.predict(args) @@ -82,7 +82,7 @@ def predict(self, batch, batch_idx, dataloader_idx): return def on_predict_epoch_end(self): - self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.get_model()) + self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.lightning_module) results = self._predictions diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 1f0cc52870f7e..6581a9a3ba81a 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -245,7 +245,7 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]: @property def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. 
""" - ref_model = self.get_model() + ref_model = self.lightning_module ref_model = cast(LightningModule, ref_model) standard_metrics = ref_model.get_progress_bar_dict() @@ -270,7 +270,7 @@ def disable_validation(self) -> bool: @property def enable_validation(self) -> bool: """ Check if we should run validation during training. """ - model_ref = self.get_model() + model_ref = self.lightning_module val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0 return val_loop_enabled diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2453a08ba9067..e01d369d27220 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -583,7 +583,7 @@ def _pre_training_routine(self): # Pre-train # -------------------------- # on pretrain routine start - ref_model = self.get_model() + ref_model = self.lightning_module self.on_pretrain_routine_start(ref_model) if self.is_function_implemented("on_pretrain_routine_start"): @@ -611,15 +611,15 @@ def run_train(self): if not self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() - self.run_sanity_check(self.get_model()) + self.run_sanity_check(self.lightning_module) # set stage for logging - self._set_running_stage(RunningStage.TRAINING, self.get_model()) + self._set_running_stage(RunningStage.TRAINING, self.lightning_module) self.checkpoint_connector.has_trained = False # enable train mode - model = self.get_model() + model = self.lightning_module model.train() torch.set_grad_enabled(True) @@ -678,7 +678,7 @@ def run_train(self): def run_evaluation(self, max_batches=None, on_epoch=False): # used to know if we are logging for val, test + reset cached results - self._set_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.get_model()) + self._set_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module) self.logger_connector.reset() # bookkeeping @@ -694,7 +694,7 @@ def run_evaluation(self, max_batches=None, on_epoch=False): # enable eval mode + no grads self.evaluation_loop.on_evaluation_model_eval() # ref model - model = self.get_model() + model = self.lightning_module model.zero_grad() torch.set_grad_enabled(False) @@ -811,7 +811,7 @@ def run_predict(self): return [] # ref model - model = self.get_model() + model = self.lightning_module # enable eval mode + no grads self.predict_loop.on_predict_model_eval() @@ -905,7 +905,7 @@ def test( # -------------------- self.verbose_test = verbose - self._set_running_stage(RunningStage.TESTING, model or self.get_model()) + self._set_running_stage(RunningStage.TESTING, model or self.lightning_module) # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: @@ -914,7 +914,7 @@ def test( ) # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') + self.data_connector.attach_datamodule(model or self.lightning_module, datamodule, 'test') if model is not None: results = self.__test_given_model(model, test_dataloaders) @@ -922,11 +922,11 @@ def test( results = self.__test_using_best_weights(ckpt_path, test_dataloaders) self.teardown('test') - self._set_running_stage(None, model or self.get_model()) + self._set_running_stage(None, model or self.lightning_module) return results def __test_using_best_weights(self, 
ckpt_path, test_dataloaders): - model = self.get_model() + model = self.lightning_module # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: @@ -962,7 +962,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # teardown if self.is_function_implemented('teardown'): - model_ref = self.get_model() + model_ref = self.lightning_module model_ref.teardown('test') return results @@ -1012,7 +1012,7 @@ def predict( # -------------------- # If you supply a datamodule you can't supply dataloaders - model = model or self.get_model() + model = model or self.lightning_module self._set_running_stage(RunningStage.PREDICTING, model) @@ -1073,7 +1073,7 @@ def _reset_result_and_set_hook_fx_name(self, hook_name): # on_before_zero_grad is called within training_step if "batch_start" in hook_name or "on_before_zero_grad" in hook_name: return True - model_ref = self.get_model() + model_ref = self.lightning_module if model_ref is not None: # used to track current hook name called model_ref._results = Result() @@ -1081,7 +1081,7 @@ def _reset_result_and_set_hook_fx_name(self, hook_name): return False def _cache_logged_metrics(self): - model_ref = self.get_model() + model_ref = self.lightning_module if model_ref is not None: # capture logging for this hook self.logger_connector.cache_logged_metrics() @@ -1100,7 +1100,7 @@ def call_hook(self, hook_name, *args, **kwargs): # next call hook in lightningModule output = None - model_ref = self.get_model() + model_ref = self.lightning_module if is_overridden(hook_name, model_ref): hook_fx = getattr(model_ref, hook_name) output = hook_fx(*args, **kwargs) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0908e96bd1c17..5827518d16b00 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -158,7 +158,7 @@ def check_checkpoint_callback(self, should_update, is_last=False): if is_last and any(cb.save_last for cb in callbacks): rank_zero_info("Saving latest checkpoint...") - model = self.trainer.get_model() + model = self.trainer.lightning_module for cb in callbacks: cb.on_validation_end(self.trainer, model) @@ -167,7 +167,7 @@ def check_early_stopping_callback(self, should_update): # TODO bake this logic into the EarlyStopping callback if should_update and self.trainer.checkpoint_connector.has_trained: callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] - model = self.trainer.get_model() + model = self.trainer.lightning_module for cb in callbacks: cb.on_validation_end(self.trainer, model) @@ -177,7 +177,7 @@ def on_train_epoch_start(self, epoch): # update training progress in trainer self.trainer.current_epoch = epoch - model = self.trainer.get_model() + model = self.trainer.lightning_module # reset train dataloader if epoch != 0 and self.trainer.reload_dataloaders_every_epoch: @@ -189,7 +189,7 @@ def on_train_epoch_start(self, epoch): self.trainer.train_dataloader.sampler.set_epoch(epoch) # changing gradient according accumulation_scheduler - self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.get_model()) + self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module) # stores accumulated grad fractions per batch self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) @@ -229,8 +229,8 @@ def track_epoch_end_reduce_metrics(self, epoch_output, 
batch_end_outputs): # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end hook_overridden = ( - is_overridden("training_epoch_end", model=self.trainer.get_model()) - or is_overridden("on_train_epoch_end", model=self.trainer.get_model()) + is_overridden("training_epoch_end", model=self.trainer.lightning_module) + or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module) ) # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end @@ -281,7 +281,7 @@ def _check_training_step_output(self, training_step_output): def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("model_forward"): args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) @@ -379,7 +379,7 @@ def _process_training_step_output(self, training_step_output, split_batch): return training_step_output_for_epoch_end, training_step_output def _process_training_step_output_1_0(self, training_step_output, split_batch): - result = self.trainer.get_model()._results + result = self.trainer.lightning_module._results loss = None hiddens = None @@ -437,7 +437,7 @@ def _process_result(self, training_step_output, split_batch): return training_step_output_for_epoch_end def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) using_native_amp = self.trainer.amp_backend == AMPType.NATIVE @@ -482,7 +482,7 @@ def _track_gradient_norm(self): grad_norm_dict = {} if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: if float(self.trainer.track_grad_norm) > 0: - model = self.trainer.get_model() + model = self.trainer.lightning_module grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) return grad_norm_dict @@ -495,7 +495,7 @@ def process_hiddens(self, opt_closure_result): def tbptt_split_batch(self, batch): splits = [batch] if self.trainer.truncated_bptt_steps is not None: - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("tbptt_split_batch"): splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) return splits @@ -789,7 +789,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, if len(self.trainer.optimizers) > 1: # revert back to previous state - self.trainer.get_model().untoggle_optimizer(opt_idx) + self.trainer.lightning_module.untoggle_optimizer(opt_idx) return result @@ -926,7 +926,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. 
if self.automatic_optimization and len(self.trainer.optimizers) > 1: - model = self.trainer.get_model() + model = self.trainer.lightning_module model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 7665f96426df1..d8faa57de73db 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -37,13 +37,13 @@ def get_model(self) -> LightningModule: """Warning: this is just empty shell for code implemented in other class.""" def print_nan_gradients(self) -> None: - model = self.get_model() + model = self.lightning_module for param in model.parameters(): if (param.grad is not None) and torch.isnan(param.grad.float()).any(): log.info(param, param.grad) def detect_nan_tensors(self, loss: Tensor) -> None: - model = self.get_model() + model = self.lightning_module # check if loss is nan if not torch.isfinite(loss).all(): diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 3a52b6dd2e8fa..c29cffc42607b 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -268,7 +268,7 @@ def _adjust_batch_size( The new batch size for the next trial and a bool that signals whether the new value is different than the previous batch size. """ - model = trainer.get_model() + model = trainer.lightning_module batch_size = lightning_getattr(model, batch_arg_name) new_size = value if value is not None else int(batch_size * factor) if desc: diff --git a/tests/callbacks/test_swa.py b/tests/callbacks/test_swa.py index 72a4c4fc1ab80..bcaa27f454dea 100644 --- a/tests/callbacks/test_swa.py +++ b/tests/callbacks/test_swa.py @@ -111,7 +111,7 @@ def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_proc trainer.fit(model) # check the model is the expected - assert trainer.get_model() == model + assert trainer.lightning_module == model @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0") diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 28e3e65a87586..d28ab6177f21c 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -456,7 +456,7 @@ def on_train_start(self): dp_model.module.module.running_stage = RunningStage.EVALUATING dataloader = self.train_dataloader() - tpipes.run_prediction(self.trainer.get_model(), dataloader) + tpipes.run_prediction(self.trainer.lightning_module, dataloader) self.on_train_start_called = True # new model diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 73e900072b7e0..6f1f31e186302 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -24,10 +24,10 @@ class TrainerGetModel(BoringModel): def on_fit_start(self): - assert self == self.trainer.get_model() + assert self == self.trainer.lightning_module def on_fit_end(self): - assert self == self.trainer.get_model() + assert self == self.trainer.lightning_module def test_get_model(tmpdir): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 71caaaad4d7f9..b7a28597ab7d4 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -474,7 +474,7 @@ def mock_save_function(filepath, *args): trainer.current_epoch = i trainer.global_step = i trainer.logger_connector.callback_metrics = 
{"checkpoint_on": torch.tensor(loss)} - checkpoint_callback.on_validation_end(trainer, trainer.get_model()) + checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) file_lists = set(os.listdir(tmpdir)) @@ -1420,11 +1420,11 @@ def setup(self, model, stage): trainer.fit(model) assert trainer.stage == "fit" - assert trainer.get_model().stage == "fit" + assert trainer.lightning_module.stage == "fit" trainer.test(ckpt_path=None) assert trainer.stage == "test" - assert trainer.get_model().stage == "test" + assert trainer.lightning_module.stage == "test" @pytest.mark.parametrize( From 9ca9da62e96ac682de071d3612ac7b772f6e7939 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Feb 2021 18:11:12 +0100 Subject: [PATCH 2/9] update references to get_model --- pytorch_lightning/trainer/callback_hook.py | 5 +++-- pytorch_lightning/trainer/connectors/model_connector.py | 3 --- pytorch_lightning/trainer/model_hooks.py | 6 ++---- pytorch_lightning/trainer/properties.py | 1 - pytorch_lightning/trainer/training_tricks.py | 7 ++----- tests/trainer/properties/test_get_model.py | 8 ++++---- 6 files changed, 11 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index f6bec5e518e3e..f640d0f1eef88 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -14,8 +14,9 @@ from abc import ABC from copy import deepcopy -from typing import Callable, List +from typing import List +from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import Callback @@ -24,7 +25,7 @@ class TrainerCallbackHookMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class callbacks: List[Callback] = [] - get_model: Callable + lightning_module: LightningModule def on_before_accelerator_backend_setup(self, model): """Called in the beginning of fit and test""" diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 060601049f9b7..2e95fe7209ded 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -38,9 +38,6 @@ def copy_trainer_model_properties(self, model): m.testing = self.trainer.testing m.precision = self.trainer.precision - def get_model(self): - return self._get_reference_model(self.trainer.model) - def _get_reference_model(self, model): if self.trainer.accelerator_backend and self.trainer.accelerator_backend.lightning_module: return self.trainer.accelerator_backend.lightning_module diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 97fe6b1482ce0..c6daf30277f91 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -20,6 +20,8 @@ class TrainerModelHooksMixin(ABC): + lightning_module: LightningModule + def is_function_implemented(self, f_name, model=None): if model is None: model = self.lightning_module @@ -30,7 +32,3 @@ def has_arg(self, f_name, arg_name): model = self.lightning_module f_op = getattr(model, f_name, None) return arg_name in inspect.signature(f_op).parameters - - @abstractmethod - def get_model(self) -> LightningModule: - """Warning: this is just empty shell for code implemented in other class.""" diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py 
index 6581a9a3ba81a..008556c171fb9 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -353,7 +353,6 @@ def model(self, model: torch.nn.Module) -> None: self.accelerator.model = model def get_model(self) -> LightningModule: - # TODO: rename this to lightning_module (see training type plugin) # backward compatible return self.lightning_module diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index d8faa57de73db..9a18930e74ee4 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -29,13 +29,10 @@ class TrainerTrainingTricksMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class default_root_dir: str - progress_bar_callback:... + lightning_module: LightningModule + progress_bar_callback: ... on_gpu: bool - @abstractmethod - def get_model(self) -> LightningModule: - """Warning: this is just empty shell for code implemented in other class.""" - def print_nan_gradients(self) -> None: model = self.lightning_module for param in model.parameters(): diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 6f1f31e186302..7a5dd3f685ed5 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -32,7 +32,7 @@ def on_fit_end(self): def test_get_model(tmpdir): """ - Tests that :meth:`trainer.get_model` extracts the model correctly + Tests that `trainer.lightning_module` extracts the model correctly """ model = TrainerGetModel() @@ -50,7 +50,7 @@ def test_get_model(tmpdir): @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_get_model_ddp_cpu(tmpdir): """ - Tests that :meth:`trainer.get_model` extracts the model correctly when using ddp on cpu + Tests that `trainer.lightning_module` extracts the model correctly when using ddp on cpu """ model = TrainerGetModel() @@ -70,7 +70,7 @@ def test_get_model_ddp_cpu(tmpdir): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_get_model_gpu(tmpdir): """ - Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + Tests that `trainer.lightning_module` extracts the model correctly when using GPU """ model = TrainerGetModel() @@ -91,7 +91,7 @@ def test_get_model_gpu(tmpdir): @DDPLauncher.run("--accelerator [accelerator]", max_epochs=["1"], accelerator=["ddp", "ddp_spawn"]) def test_get_model_ddp_gpu(tmpdir, args=None): """ - Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + ddp accelerators + Tests that `trainer.lightning_module` extracts the model correctly when using GPU + ddp accelerators """ model = TrainerGetModel() From ddb7f67a7ac579c241c00c45b6bad3f8c9eaf0d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Feb 2021 18:31:02 +0100 Subject: [PATCH 3/9] pep8 --- pytorch_lightning/trainer/model_hooks.py | 2 +- pytorch_lightning/trainer/trainer.py | 4 +++- pytorch_lightning/trainer/training_tricks.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index c6daf30277f91..7e3d6cc78320c 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -13,7 +13,7 @@ # limitations under the License. 
import inspect -from abc import ABC, abstractmethod +from abc import ABC from pytorch_lightning.core.lightning import LightningModule diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e01d369d27220..b87ac7dd946cd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -678,7 +678,9 @@ def run_train(self): def run_evaluation(self, max_batches=None, on_epoch=False): # used to know if we are logging for val, test + reset cached results - self._set_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module) + self._set_running_stage( + RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module + ) self.logger_connector.reset() # bookkeeping diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 9a18930e74ee4..6b388f7137ce1 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod +from abc import ABC import torch from torch import Tensor @@ -30,7 +30,7 @@ class TrainerTrainingTricksMixin(ABC): # the proper values/initialisation should be done in child class default_root_dir: str lightning_module: LightningModule - progress_bar_callback: ... + progress_bar_callback:... on_gpu: bool def print_nan_gradients(self) -> None: From 74c55c49cfc14132114fdbd23a37959e49ba6eeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Feb 2021 22:58:25 +0100 Subject: [PATCH 4/9] add proper deprecation --- pytorch_lightning/trainer/deprecated_api.py | 10 ++++++++++ pytorch_lightning/trainer/properties.py | 4 ---- tests/deprecated_api/test_remove_1-4.py | 7 +++++++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 71b557bf75a2c..8e2e90dda8a7c 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn @@ -23,6 +24,7 @@ class DeprecatedDistDeviceAttributes: _running_stage: RunningStage num_gpus: int accelerator_connector: AcceleratorConnector + lightning_module = LightningModule @property def on_cpu(self) -> bool: @@ -130,3 +132,11 @@ def use_single_gpu(self, val: bool) -> None: ) if val: self.accelerator_connector._device_type = DeviceType.GPU + + def get_model(self) -> LightningModule: + rank_zero_warn( + "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`" + " and will be removed in v1.4.", + DeprecationWarning, + ) + return self.lightning_module diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index caa46a3dbb2ba..47aad2710394d 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -352,10 +352,6 @@ def model(self, model: torch.nn.Module) -> None: """ self.accelerator.model = model - def get_model(self) -> LightningModule: - # backward compatible - return self.lightning_module - @property def lightning_optimizers(self) -> List[LightningOptimizer]: if self._lightning_optimizers is None: diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index f65b13a661f39..c13e5b9dfadfd 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -30,6 +30,13 @@ from tests.helpers import BoringModel +def test_v1_4_0_deprecated_trainer_methods(): + with pytest.deprecated_call(match='will be removed in v1.4'): + trainer = Trainer() + _ = trainer.get_model() + assert trainer.get_model() == trainer.lightning_module + + def test_v1_4_0_deprecated_imports(): _soft_unimport_module('pytorch_lightning.utilities.argparse_utils') with pytest.deprecated_call(match='will be removed in v1.4'): From e9c4e2c677242db2bd05a2c6ca18c08ba6f972f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Feb 2021 23:01:59 +0100 Subject: [PATCH 5/9] remove outdated _get_reference_model --- pytorch_lightning/trainer/connectors/model_connector.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 2e95fe7209ded..4a0c565d78be0 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -25,7 +25,7 @@ def __init__(self, trainer): self.trainer = trainer def copy_trainer_model_properties(self, model): - ref_model = self._get_reference_model(model) + ref_model = self.trainer.lightning_module or model automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization self.trainer.train_loop.automatic_optimization = automatic_optimization @@ -37,8 +37,3 @@ def copy_trainer_model_properties(self, model): m.use_amp = self.trainer.amp_backend is not None m.testing = self.trainer.testing m.precision = self.trainer.precision - - def _get_reference_model(self, model): - if self.trainer.accelerator_backend and self.trainer.accelerator_backend.lightning_module: - return self.trainer.accelerator_backend.lightning_module - return model From 
4ba2eaaa39d71d2d749c6500c77549975873830d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Feb 2021 23:09:06 +0100 Subject: [PATCH 6/9] fix cyclic import --- pytorch_lightning/trainer/callback_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index f640d0f1eef88..f292f5a78bc65 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -16,8 +16,8 @@ from copy import deepcopy from typing import List -from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.lightning import LightningModule class TrainerCallbackHookMixin(ABC): From a33a6d28ee1267f6d400cc4174d77c1f0cd217f9 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Feb 2021 14:47:04 +0100 Subject: [PATCH 7/9] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d4eeddb5b6ca..b3bb906ea3758 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -230,6 +230,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated using `'val_loss'` to set the `ModelCheckpoint` monitor ([#6012](https://github.com/PyTorchLightning/pytorch-lightning/pull/6012)) +- Deprecated `.get_model()` with explicit `.lightning_module` property ([#6035](https://github.com/PyTorchLightning/pytorch-lightning/pull/6035)) + + ### Removed - Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321)) From 5947b39f54fa746e62222524e572153369883248 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Feb 2021 14:53:09 +0100 Subject: [PATCH 8/9] ... 
--- pytorch_lightning/trainer/deprecated_api.py | 3 +++ pytorch_lightning/trainer/trainer.py | 3 ++- pytorch_lightning/utilities/data.py | 4 +--- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 8e2e90dda8a7c..ddd54961c558c 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -133,6 +133,9 @@ def use_single_gpu(self, val: bool) -> None: if val: self.accelerator_connector._device_type = DeviceType.GPU + +class DeprecatedModelAttributes: + def get_model(self) -> LightningModule: rank_zero_warn( "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e0bc6d51dbb2b..7ad61020ab099 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -45,7 +45,7 @@ from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin -from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes +from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedModelAttributes from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin @@ -80,6 +80,7 @@ class Trainer( TrainerTrainingTricksMixin, TrainerDataLoadingMixin, DeprecatedDistDeviceAttributes, + DeprecatedModelAttributes, ): @overwrite_by_env_vars diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 6b887b8526f90..a73299e2af77b 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -30,9 +30,7 @@ def has_len(dataloader: DataLoader) -> bool: try: # try getting the length if len(dataloader) == 0: - raise ValueError( - '`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch' - ) + raise ValueError('`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch') has_len = True except TypeError: has_len = False From a69d3b04a81c6c4bb2db164f183282b2a5ade025 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Feb 2021 14:54:59 +0100 Subject: [PATCH 9/9] ... 
--- pytorch_lightning/trainer/deprecated_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index ddd54961c558c..e1eecf26ed70e 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -24,7 +24,6 @@ class DeprecatedDistDeviceAttributes: _running_stage: RunningStage num_gpus: int accelerator_connector: AcceleratorConnector - lightning_module = LightningModule @property def on_cpu(self) -> bool: @@ -136,10 +135,11 @@ def use_single_gpu(self, val: bool) -> None: class DeprecatedModelAttributes: + lightning_module = LightningModule + def get_model(self) -> LightningModule: rank_zero_warn( "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`" - " and will be removed in v1.4.", - DeprecationWarning, + " and will be removed in v1.4.", DeprecationWarning ) return self.lightning_module
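
Usage sketch (not part of the patches above): a minimal illustration of what this series changes for user code, assuming pytorch-lightning 1.2.x with these patches applied — `Trainer.lightning_module` is the supported accessor, while `Trainer.get_model()` keeps returning the same object but emits a `DeprecationWarning` until its removal in v1.4. The callback name `CheckModelAccess` is made up for this example.

import warnings

from pytorch_lightning.callbacks import Callback


class CheckModelAccess(Callback):
    """Attach via ``Trainer(callbacks=[CheckModelAccess()])`` to exercise both accessors."""

    def on_fit_start(self, trainer, pl_module):
        # Preferred access: the read-only property this series standardizes on.
        assert trainer.lightning_module is pl_module

        # Deprecated access: same object, but now raises a DeprecationWarning
        # (scheduled for removal in v1.4).
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            assert trainer.get_model() is pl_module
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)

For existing call sites the migration is mechanical: replace `trainer.get_model()` with `trainer.lightning_module`.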