From 269fda8c8713db3f442d69f00712dfc35ba6b0d3 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 16 Mar 2021 01:53:18 +0530 Subject: [PATCH 1/6] Add verify config for predict --- .../trainer/configuration_validator.py | 10 +++++- tests/trainer/test_trainer.py | 32 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 9cb22f39b7228..a5c3a8d04a1dd 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -30,7 +30,9 @@ def verify_loop_configurations(self, model: LightningModule): model: The model to check the configuration. """ - if not self.trainer.testing: + if self.trainer.predicting: + self.__verify_predict_loop_configuration(model) + elif not self.trainer.testing: self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, 'validation') else: @@ -98,3 +100,9 @@ def __verify_eval_loop_configuration(self, model, eval_loop_name): rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop') if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop') + + def __verify_predict_loop_configuration(self, model): + + has_predict_dataloader = is_overridden('predict_dataloader', model) + if not has_predict_dataloader: + raise MisconfigurationException('Dataloader not found for `Trainer.predict`') diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 59f3c2b54c13c..c379b6c3d96c6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1850,3 +1850,35 @@ def compare_optimizers(): trainer.max_epochs = 2 # simulate multiple fit calls trainer.fit(model) compare_optimizers() + + +def test_trainer_predict_verify_config(tmpdir): + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + + model = TestModel() + datamodule = TestLightningDataModule(dataloaders) + + trainer = Trainer(default_root_dir=tmpdir) + + if datamodule: + results = trainer.predict(model, datamodule=datamodule) + else: + results = trainer.predict(model, dataloaders=dataloaders) + + assert len(results) == 2 + assert results[0][0].shape == torch.Size([1, 2]) + + model.predict_dataloader = None + + with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): + trainer.predict(model) From 3af8202751d768ac8cf5725aa96865ed7102f536 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 16 Mar 2021 10:32:14 +0530 Subject: [PATCH 2/6] Fix tests & Changelog --- CHANGELOG.md | 7 +++++++ tests/trainer/test_trainer.py | 5 +++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b2aa324beaf3..fda1ded47207c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [1.2.x] - YYYY-MM-DD + +### Added + +- Added `trainer.predict` config validation ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) + + ## [1.2.3] - 2021-03-09 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c379b6c3d96c6..724a4075b03e3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1852,7 +1852,8 @@ def compare_optimizers(): compare_optimizers() -def test_trainer_predict_verify_config(tmpdir): +@pytest.mark.parametrize("datamodule", [False, True]) +def test_trainer_predict_verify_config(tmpdir, datamodule): class TestModel(LightningModule): @@ -1866,11 +1867,11 @@ def forward(self, x): dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] model = TestModel() - datamodule = TestLightningDataModule(dataloaders) trainer = Trainer(default_root_dir=tmpdir) if datamodule: + datamodule = TestLightningDataModule(dataloaders) results = trainer.predict(model, datamodule=datamodule) else: results = trainer.predict(model, dataloaders=dataloaders) From 9a679bd901728885bf3762c8681924577de7f0c1 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 16 Mar 2021 11:02:57 +0530 Subject: [PATCH 3/6] Update Changelog --- CHANGELOG.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fda1ded47207c..fc173a99231a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [1.2.x] - YYYY-MM-DD -### Added +### Fixed -- Added `trainer.predict` config validation ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) +- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) ## [1.2.3] - 2021-03-09 @@ -59,7 +59,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089)) - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107)) - ## [1.2.0] - 2021-02-18 ### Added From d499974e0297ba82a086b3d2cbc64354fa953585 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 09:23:57 +0100 Subject: [PATCH 4/6] Update test_trainer.py --- tests/trainer/test_trainer.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 724a4075b03e3..8870a51d4d798 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1852,8 +1852,11 @@ def compare_optimizers(): compare_optimizers() -@pytest.mark.parametrize("datamodule", [False, True]) -def test_trainer_predict_verify_config(tmpdir, datamodule): +@pytest.mark.parametrize("data", [ + dict(datamodule=TestLightningDataModule(dataloaders)), + dict(dataloaders=[torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2)), +])]) +def test_trainer_predict_verify_config(tmpdir, data): class TestModel(LightningModule): @@ -1864,17 +1867,10 @@ def __init__(self): def forward(self, x): return self.layer(x) - dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir) - if datamodule: - datamodule = TestLightningDataModule(dataloaders) - results = trainer.predict(model, datamodule=datamodule) - else: - results = trainer.predict(model, dataloaders=dataloaders) + results = trainer.predict(model, **data) assert len(results) == 2 assert results[0][0].shape == torch.Size([1, 2]) From 3047e16dc0cd21b647f8f2bf04514a534f47ec0e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 16 Mar 2021 14:21:55 +0530 Subject: [PATCH 5/6] fix code formatting issues --- tests/trainer/test_trainer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 8870a51d4d798..40c1608c2ae8c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1852,10 +1852,17 @@ def compare_optimizers(): compare_optimizers() -@pytest.mark.parametrize("data", [ - dict(datamodule=TestLightningDataModule(dataloaders)), - dict(dataloaders=[torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2)), -])]) +@pytest.mark.parametrize( + "data", [ + dict(datamodule=TestLightningDataModule(dataloaders)), + dict( + dataloaders=[ + torch.utils.data.DataLoader(RandomDataset(32, 2)), + torch.utils.data.DataLoader(RandomDataset(32, 2)), + ] + ) + ] +) def test_trainer_predict_verify_config(tmpdir, data): class TestModel(LightningModule): From 779dadd347bc3df00b50dc152f3101dd0e4e6a33 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 16 Mar 2021 14:40:33 +0530 Subject: [PATCH 6/6] fix tests --- tests/trainer/test_trainer.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 40c1608c2ae8c..6966edc3cbf70 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1852,18 +1852,8 @@ def compare_optimizers(): compare_optimizers() -@pytest.mark.parametrize( - "data", [ - dict(datamodule=TestLightningDataModule(dataloaders)), - dict( - dataloaders=[ - torch.utils.data.DataLoader(RandomDataset(32, 2)), - torch.utils.data.DataLoader(RandomDataset(32, 2)), - ] - ) - ] -) -def test_trainer_predict_verify_config(tmpdir, data): +@pytest.mark.parametrize("use_datamodule", [False, True]) +def test_trainer_predict_verify_config(tmpdir, use_datamodule): class TestModel(LightningModule): @@ -1874,10 +1864,16 @@ def __init__(self): def forward(self, x): return self.layer(x) + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + model = TestModel() trainer = Trainer(default_root_dir=tmpdir) - results = trainer.predict(model, **data) + if use_datamodule: + datamodule = TestLightningDataModule(dataloaders) + results = trainer.predict(model, datamodule=datamodule) + else: + results = trainer.predict(model, dataloaders=dataloaders) assert len(results) == 2 assert results[0][0].shape == torch.Size([1, 2])