Add support for iterable datasets when val_check_interval=1.0 (Lightning-AI#1283)

* Add support for iterable datasets when val_check_interval=1.0

* Update CHANGELOG.md
ethanwharris authored and akarnachev committed Apr 3, 2020
1 parent 3a4ee01 commit 9ebe93c
Showing 4 changed files with 41 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
+- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))

 ### Changed

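To illustrate what this entry enables, here is a minimal sketch written for this note (not code from the PR; `StreamDataset` and `TinyModel` are made-up names): a `DataLoader` backed by an `IterableDataset` exposes no `__len__`, and with the default `val_check_interval=1.0` the Trainer now treats that as "validate at the end of the epoch" instead of raising a `MisconfigurationException` at setup time.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
import pytorch_lightning as pl


class StreamDataset(IterableDataset):
    """Hypothetical stream-style dataset: no __len__, so Lightning sees the loader as 'infinite'."""

    def __iter__(self):
        for _ in range(1000):  # stand-in for an unbounded stream
            yield torch.randn(32), torch.randint(0, 2, (1,)).float()


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.binary_cross_entropy_with_logits(self(x), y)
        return {'loss': loss}

    def train_dataloader(self):
        # The resulting DataLoader has no __len__, which is the case this commit addresses.
        return DataLoader(StreamDataset(), batch_size=4)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# val_check_interval defaults to 1.0; with a length-less train loader this now means
# "run validation once the loader is exhausted" rather than a configuration error.
trainer = pl.Trainer(max_epochs=1)
trainer.fit(TinyModel())
```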
21 changes: 12 additions & 9 deletions pytorch_lightning/trainer/data_loading.py
@@ -136,16 +136,19 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                     'If you want to disable validation set `val_percent_check` to 0.0 instead.')
         else:
             if not _has_len(self.train_dataloader):
-                raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
-                    'DataLoader does not implement `__len__`) for `train_dataloader`, '
-                    '`Trainer(val_check_interval)` must be an int. An int k specifies checking '
-                    'validation every k training batches.')
-
-            self._percent_range_check('val_check_interval')
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float('inf')
+                else:
+                    raise MisconfigurationException(
+                        'When using an infinite DataLoader (e.g. with an IterableDataset or when '
+                        'DataLoader does not implement `__len__`) for `train_dataloader`, '
+                        '`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
+                        'checking validation every k training batches.')
+            else:
+                self._percent_range_check('val_check_interval')

-            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
-            self.val_check_batch = max(1, self.val_check_batch)
+                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+                self.val_check_batch = max(1, self.val_check_batch)

     def _reset_eval_dataloader(self, model: LightningModule,
                                mode: str) -> Tuple[int, List[DataLoader]]:
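The new branch above hinges on `_has_len`, which the hunk only references. As a rough sketch of its behavior (paraphrased here; the actual helper in `data_loading.py` may differ in detail), it treats a loader as length-less when `len()` raises `TypeError`, which is exactly what a `DataLoader` over an `IterableDataset` does:

```python
from torch.utils.data import DataLoader


def _has_len_sketch(dataloader: DataLoader) -> bool:
    """Approximation of the `_has_len` check referenced in the hunk above."""
    try:
        # A DataLoader over an IterableDataset (or any object without __len__)
        # raises TypeError here, marking the loader as length-less / infinite.
        return len(dataloader) > 0
    except TypeError:
        return False
```

With that check in hand, the rewritten path reduces to three outcomes: a length-less loader with `val_check_interval == 1.0` gets `val_check_batch = float('inf')` (validate at epoch end); a length-less loader with any other float still raises `MisconfigurationException`; a finite loader keeps the old behavior, `val_check_batch = max(1, int(num_training_batches * val_check_interval))`.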
23 changes: 19 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -400,8 +400,8 @@ def run_training_epoch(self):
             train_dataloader = train_dataloader.per_device_loader(device)

         # run epoch
-        for batch_idx, batch in self.profiler.profile_iterable(
-            enumerate(train_dataloader), "get_train_batch"
+        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
+            enumerate(_with_is_last(train_dataloader)), "get_train_batch"
         ):
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:
@@ -429,8 +429,10 @@ def run_training_epoch(self):
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
             can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            should_check_val = not self.disable_validation and can_check_epoch
-            should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)
+            can_check_val = not self.disable_validation and can_check_epoch
+            should_check_val = is_val_check_batch or early_stop_epoch
+            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
+            should_check_val = can_check_val and should_check_val

             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:
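Why the extra `is_last_batch` clause is needed at all (my reading of the diff, not text from the PR): once `val_check_batch` is `float('inf')`, the per-batch modulo check can never fire, since `x % float('inf') == x` for any finite non-negative `x`.

```python
# The old-style per-batch trigger is dead when val_check_batch is infinite...
assert (7 + 1) % float('inf') == 8.0
assert ((7 + 1) % float('inf') == 0) is False
# ...so only `is_last_batch and self.val_check_batch == float('inf')` can request
# validation, i.e. exactly once, at the end of the epoch.
```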
Expand Down Expand Up @@ -740,3 +742,16 @@ def call_checkpoint_callback(self):
         if self.checkpoint_callback is not None:
             self.checkpoint_callback.on_validation_end(self, self.get_model())
         self.on_validation_end()
+
+
+def _with_is_last(iterable):
+    """Pass through values from the given iterable with an added boolean indicating if this is the last item.
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
+    it = iter(iterable)
+    last = next(it)
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
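A quick usage sketch for the new helper (not part of the commit): it pairs each element with a flag that is `True` only for the final one, which is how the loop above detects the last batch of a length-less loader.

```python
>>> list(_with_is_last(['a', 'b', 'c']))
[('a', False), ('b', False), ('c', True)]
```

Note that the helper assumes a non-empty iterable; the initial `next(it)` has nothing to consume otherwise.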
10 changes: 9 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -372,7 +372,6 @@ class CurrentTestModel(
         )
         trainer.fit(model)

-    # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
         max_epochs=1,
@@ -383,6 +382,15 @@ class CurrentTestModel(
     # verify training completed
     assert result == 1

+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1
+    )
+    result = trainer.fit(model)
+
+    # verify training completed
+    assert result == 1
+

 def test_inf_val_dataloader(tmpdir):
     """Test inf val data loader (e.g. IterableDataset)"""
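The extended test above only exercises the passing default. A complementary sketch of the failing case that the `data_loading.py` hunk still guards against (hypothetical test, not from this commit; the exception's import path reflects the 0.7-era layout and may differ in other releases):

```python
import pytest
from pytorch_lightning import Trainer
# 0.7-era location; newer releases expose it from pytorch_lightning.utilities.exceptions
from pytorch_lightning.utilities.debugging import MisconfigurationException


def test_inf_train_dataloader_fractional_interval(tmpdir):
    """A fractional (non-1.0) val_check_interval with a length-less loader should still fail."""
    model = TinyModel()  # the IterableDataset-backed module sketched earlier in this note

    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            val_check_interval=0.5,
        )
        trainer.fit(model)
```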
