diff --git a/CHANGELOG.md b/CHANGELOG.md
index 90c9f49566bae..3dfc810b8b911 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
 - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
+- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
 
 ### Changed
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index dced28144fbe8..5117b3ab05134 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -26,9 +26,13 @@
 
 
 def _has_len(dataloader: DataLoader) -> bool:
+    """ Checks if a given Dataloader has __len__ method implemented i.e. if
+    it is a finite dataloader or infinite dataloader """
     try:
         # try getting the length
-        _ = len(dataloader)
+        if len(dataloader) == 0:
+            raise ValueError('Dataloader returned 0 length. Please make sure'
+                             ' that your Dataloader returns at least 1 batch')
         return True
     except TypeError:
         return False
diff --git a/tests/base/__init__.py b/tests/base/__init__.py
index 1e68469871d25..3c174ece038b3 100644
--- a/tests/base/__init__.py
+++ b/tests/base/__init__.py
@@ -25,7 +25,8 @@
     LightTestOptimizerWithSchedulingMixin,
     LightTestMultipleOptimizersWithSchedulingMixin,
     LightTestOptimizersWithMixedSchedulingMixin,
-    LightTestReduceLROnPlateauMixin
+    LightTestReduceLROnPlateauMixin,
+    LightZeroLenDataloader
 )
 
 
diff --git a/tests/base/mixins.py b/tests/base/mixins.py
index 1a05049f44f5f..4c249b1b71b18 100644
--- a/tests/base/mixins.py
+++ b/tests/base/mixins.py
@@ -255,6 +255,16 @@ def test_dataloader(self):
         return CustomInfDataloader(self._dataloader(train=False))
 
 
+class LightZeroLenDataloader:
+    """ Simple dataloader that has zero length. """
+
+    def train_dataloader(self):
+        dataloader = self._dataloader(train=True)
+        dataloader.dataset.data = dataloader.dataset.data[:0]
+        dataloader.dataset.targets = dataloader.dataset.targets[:0]
+        return dataloader
+
+
 class LightEmptyTestStep:
     """Empty test step."""
 
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 194d67fe07858..6d2332cdf538b 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -16,7 +16,8 @@
     LightTrainDataloader,
     LightInfTrainDataloader,
     LightInfValDataloader,
-    LightInfTestDataloader
+    LightInfTestDataloader,
+    LightZeroLenDataloader
 )
 
 
@@ -458,3 +459,26 @@ class CurrentTestModel(
 
     # verify training completed
     assert result == 1
+
+
+def test_error_on_zero_len_dataloader(tmpdir):
+    """ Test that error is raised if a zero-length dataloader is defined """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightZeroLenDataloader,
+        LightningTestModel
+    ):
+        pass
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    # fit model
+    with pytest.raises(ValueError):
+        trainer = Trainer(
+            default_save_path=tmpdir,
+            max_epochs=1,
+            test_percent_check=0.5
+        )
+        trainer.fit(model)