diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index b320a9b223840..1efac79b63b83 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -114,13 +114,13 @@ def on_validation_model_eval(self) -> None: """ Sets the model to eval during the val loop """ - self.eval() + self.trainer.model.eval() def on_validation_model_train(self) -> None: """ Sets the model to train during the val loop """ - self.train() + self.trainer.model.train() def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: """ @@ -172,19 +172,19 @@ def on_test_model_train(self) -> None: """ Sets the model to train during the test loop """ - self.train() + self.trainer.model.train() def on_test_model_eval(self) -> None: """ Sets the model to eval during the test loop """ - self.eval() + self.trainer.model.eval() def on_predict_model_eval(self) -> None: """ Sets the model to eval during the predict loop """ - self.eval() + self.trainer.model.eval() def on_epoch_start(self) -> None: """ diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index b33f41cb2ea48..cdbc232f3eed9 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -44,7 +44,6 @@ def on_predict_model_eval(self, *_, **__): model_ref.on_predict_model_eval() def setup(self, model, max_batches, dataloaders): - self.trainer.call_hook("on_predict_start") # copy properties for forward overrides self.trainer.model_connector.copy_trainer_model_properties(model) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 46886ab39c85c..e6bf36df92a01 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -582,11 +582,11 @@ def run_train(self) -> None: self.checkpoint_connector.has_trained = False # enable train mode - model = self.lightning_module - model.train() + self.model.train() torch.set_grad_enabled(True) # reload data when needed + model = self.lightning_module self.train_loop.reset_train_val_dataloaders(model) # hook @@ -772,8 +772,6 @@ def run_evaluate(self): return eval_loop_results def run_predict(self): - self.predict_loop.on_predict_start() - # prepare dataloaders dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() @@ -789,6 +787,9 @@ def run_predict(self): model.zero_grad() torch.set_grad_enabled(False) + # call hook + self.predict_loop.on_predict_start() + # set up the eval loop self.predict_loop.setup(model, max_batches, dataloaders) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ed255d690ed83..cbba0b7a454a5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1438,7 +1438,9 @@ def setup(self, model, stage): ) @patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics") def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval): + class TestModel(BoringModel): + def training_step(self, *args, **kwargs): self.log("foo", -1) return super().training_step(*args, **kwargs) @@ -1888,3 +1890,33 @@ def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): trainer.validate() with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"): trainer.test() + + +class TrainerStagesModel(BoringModel): + + def on_train_start(self) -> None: + assert self.trainer.model.training + assert self.training + + def on_validation_start(self) -> None: + assert not self.trainer.model.training + assert not self.training + + def on_test_start(self) -> None: + assert not self.trainer.model.training + assert not self.training + + def on_predict_start(self) -> None: + assert not self.trainer.model.training + assert not self.training + + +@pytest.mark.parametrize(['accelerator', 'num_processes'], + [(None, 1), pytest.param('ddp', 2, marks=RunIf(skip_windows=True))]) +def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes): + model = TrainerStagesModel() + trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + trainer.predict(model, model.val_dataloader())