Skip to content

Commit

Permalink
fix boolean check on iterable dataset when len not defined (#6828)
Browse files Browse the repository at this point in the history
* fix iterable dataset len check

* update predict and validate

* add validate to test

* add changelog

* add predict
  • Loading branch information
awaelchli authored Apr 5, 2021
1 parent 22a266d commit 264aa68
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))


- Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))


## [1.2.6] - 2021-03-30

### Changed
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ def validate(
self.validating = True

# If you supply a datamodule you can't supply val_dataloaders
if val_dataloaders and datamodule:
if val_dataloaders is not None and datamodule:
raise MisconfigurationException(
'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
)
Expand Down Expand Up @@ -928,7 +928,7 @@ def test(
self.testing = True

# If you supply a datamodule you can't supply test_dataloaders
if test_dataloaders and datamodule:
if test_dataloaders is not None and datamodule:
raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`')

model_provided = model is not None
Expand Down Expand Up @@ -1024,7 +1024,7 @@ def predict(
self.state = TrainerState.PREDICTING
self.predicting = True

if dataloaders and datamodule:
if dataloaders is not None and datamodule:
raise MisconfigurationException(
'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
)
Expand Down
26 changes: 20 additions & 6 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,28 +636,42 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):

def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
model = EvalModelTemplate()
model = BoringModel()
original_dataset = model.train_dataloader().dataset

class IterableWithLen(IterableDataset):
class IterableWithoutLen(IterableDataset):

def __iter__(self):
return iter(original_dataset)

class IterableWithLen(IterableWithoutLen):

def __len__(self):
return len(original_dataset)

# with __len__ defined
dataloader = DataLoader(IterableWithLen(), batch_size=16)
assert has_len(dataloader)
assert has_iterable_dataset(dataloader)
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=3,
)
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.validate(model, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.test(model, test_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.predict(model, dataloaders=[dataloader])

# without __len__ defined
dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
assert not has_len(dataloader)
assert has_iterable_dataset(dataloader)
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
trainer.validate(model, val_dataloaders=dataloader)
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
trainer.test(model, test_dataloaders=dataloader)
trainer.predict(model, dataloaders=dataloader)


@RunIf(min_gpus=2)
Expand Down

0 comments on commit 264aa68

Please sign in to comment.