From f35b9b8b18912338807b6850a5557182dc613aa7 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Tue, 16 Mar 2021 21:49:49 +0530
Subject: [PATCH] Fix: Train loop config validation was run during `trainer.predict` (#6541)

---
 CHANGELOG.md                                  |  3 ++
 .../trainer/configuration_validator.py        | 10 +++++-
 tests/trainer/test_trainer.py                 | 32 +++++++++++++++++++
 3 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b4d4abbc09d04..8210222b07c28 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
 
+- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
+
+
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 9cb22f39b7228..a5c3a8d04a1dd 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -30,7 +30,9 @@ def verify_loop_configurations(self, model: LightningModule):
             model: The model to check the configuration.
 
         """
-        if not self.trainer.testing:
+        if self.trainer.predicting:
+            self.__verify_predict_loop_configuration(model)
+        elif not self.trainer.testing:
             self.__verify_train_loop_configuration(model)
             self.__verify_eval_loop_configuration(model, 'validation')
         else:
@@ -98,3 +100,9 @@ def __verify_eval_loop_configuration(self, model, eval_loop_name):
             rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop')
         if has_step and not has_loader:
             rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop')
+
+    def __verify_predict_loop_configuration(self, model):
+
+        has_predict_dataloader = is_overridden('predict_dataloader', model)
+        if not has_predict_dataloader:
+            raise MisconfigurationException('Dataloader not found for `Trainer.predict`')
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 59f3c2b54c13c..6966edc3cbf70 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1850,3 +1850,35 @@ def compare_optimizers():
     trainer.max_epochs = 2  # simulate multiple fit calls
     trainer.fit(model)
     compare_optimizers()
+
+
+@pytest.mark.parametrize("use_datamodule", [False, True])
+def test_trainer_predict_verify_config(tmpdir, use_datamodule):
+
+    class TestModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            self.layer = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            return self.layer(x)
+
+    dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    if use_datamodule:
+        datamodule = TestLightningDataModule(dataloaders)
+        results = trainer.predict(model, datamodule=datamodule)
+    else:
+        results = trainer.predict(model, dataloaders=dataloaders)
+
+    assert len(results) == 2
+    assert results[0][0].shape == torch.Size([1, 2])
+
+    model.predict_dataloader = None
+
+    with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
+        trainer.predict(model)
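
The following is a minimal usage sketch, not part of the patch, illustrating the behaviour that the new `__verify_predict_loop_configuration` check enforces. It assumes a pytorch-lightning release that includes this fix; `PredictModel` and `RandomDataset` are hypothetical helpers defined only for this illustration.

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class RandomDataset(Dataset):
    """Two samples of 32 random features each."""

    def __init__(self, size: int = 32, length: int = 2):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class PredictModel(LightningModule):
    """Defines only `forward`; no training_step, train_dataloader or predict_dataloader."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)


model = PredictModel()
trainer = Trainer()

# With no predict dataloader available (neither overridden on the model nor
# passed to `predict`), the validator now fails fast with a clear message
# instead of complaining about a missing train loop.
try:
    trainer.predict(model)
except MisconfigurationException as err:
    print(err)  # Dataloader not found for `Trainer.predict`

# Passing dataloaders explicitly satisfies the check; no train-loop hooks are
# required to run the predict loop.
predictions = trainer.predict(model, dataloaders=DataLoader(RandomDataset()))

Routing the `predicting` state to its own check is what keeps `trainer.predict` from demanding `training_step`/`train_dataloader`, which is exactly what the added test asserts.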