Skip to content

Commit

Permalink
Fix: Train loop config validation was run during trainer.predict (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikb11 authored and Borda committed Mar 16, 2021
1 parent 7d4db76 commit f35b9b8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))


- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))


## [1.2.3] - 2021-03-09

### Fixed
Expand Down
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def verify_loop_configurations(self, model: LightningModule):
model: The model to check the configuration.
"""
if not self.trainer.testing:
if self.trainer.predicting:
self.__verify_predict_loop_configuration(model)
elif not self.trainer.testing:
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'validation')
else:
Expand Down Expand Up @@ -98,3 +100,9 @@ def __verify_eval_loop_configuration(self, model, eval_loop_name):
rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop')
if has_step and not has_loader:
rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop')

def __verify_predict_loop_configuration(self, model):

has_predict_dataloader = is_overridden('predict_dataloader', model)
if not has_predict_dataloader:
raise MisconfigurationException('Dataloader not found for `Trainer.predict`')
32 changes: 32 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1850,3 +1850,35 @@ def compare_optimizers():
trainer.max_epochs = 2 # simulate multiple fit calls
trainer.fit(model)
compare_optimizers()


@pytest.mark.parametrize("use_datamodule", [False, True])
def test_trainer_predict_verify_config(tmpdir, use_datamodule):

class TestModel(LightningModule):

def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 2)

def forward(self, x):
return self.layer(x)

dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir)

if use_datamodule:
datamodule = TestLightningDataModule(dataloaders)
results = trainer.predict(model, datamodule=datamodule)
else:
results = trainer.predict(model, dataloaders=dataloaders)

assert len(results) == 2
assert results[0][0].shape == torch.Size([1, 2])

model.predict_dataloader = None

with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
trainer.predict(model)

0 comments on commit f35b9b8

Please sign in to comment.