From 2acc0b2e4cc327aa4f698b5a6d25571bcbd285e1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 24 Feb 2022 21:19:10 +0100 Subject: [PATCH 01/13] Do not prefetch when possible --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 9 ++++++++- pytorch_lightning/loops/fit_loop.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index cb0e79ae89448..5c744043ff0a9 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -81,6 +81,13 @@ def dataloaders(self) -> Sequence[DataLoader]: raise RuntimeError("Dataloaders should be available.") return dataloaders + @property + def prefetch_batches(self) -> int: + batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches + iterable_dataset = batches[self.current_dataloader_idx] == float("inf") + inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1" + return 1 if iterable_dataset or inter_batch_parallelism else 0 + def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override] """Connect the evaluation epoch loop with this loop.""" self.epoch_loop = epoch_loop @@ -121,7 +128,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: void(*args, **kwargs) data_fetcher_cls = _select_data_fetcher_type(self.trainer) - self._data_fetcher = data_fetcher_cls() + self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches) # hook self._on_evaluation_model_eval() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a942f3bf75a99..cec6c5e8f7cf1 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -149,6 +149,12 @@ def restarting(self, restarting: bool) -> None: restarting &= finished_before_on_train_end Loop.restarting.fset(self, restarting) # call the parent setter + @property + def prefetch_batches(self) -> int: + iterable_dataset = self.trainer.num_training_batches == float("inf") + inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1" + return 1 if iterable_dataset or inter_batch_parallelism else 0 + @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" @@ -213,8 +219,9 @@ def on_run_start(self) -> None: # type: ignore[override] """Calls the ``on_train_start`` hook.""" # reset train dataloader and val dataloader self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module) + data_fetcher_cls = _select_data_fetcher(self.trainer) - self._data_fetcher = data_fetcher_cls() + self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches) self._is_fresh_start_epoch = True self._results.to(device=self.trainer.lightning_module.device) From 4501536f79036deddafc1edc06cf6d5a82d714bb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 24 Feb 2022 21:50:23 +0100 Subject: [PATCH 02/13] Minor change --- pytorch_lightning/utilities/fetching.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 0b8b1276618b6..309c1ce05b8eb 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -205,7 +205,7 @@ class DataFetcher(AbstractDataFetcher): Args: prefetch_batches: Number of batches to pre-fetch. 
Pre-fetching at least 1 batch is necessary to properly track - whether a batch is the last one (available with :attr:`self.done`). + whether a batch is the last one (available with :attr:`self.done`) under any training setup. store_on_device: Whether to store the pre-fetched batches on device. """ @@ -360,7 +360,8 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: ... """ - def __init__(self) -> None: + def __init__(self, prefetch_batches: int = 0) -> None: + # prefetch batches is not used for this class super().__init__() self.store_on_device = False From dc20fc6d918b92f6abb0acd36433a8a840abd9b3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 24 Feb 2022 22:49:08 +0100 Subject: [PATCH 03/13] Remove TODOs --- tests/loops/test_loops.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index d578cecdab01e..20295293b9f6b 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -648,16 +648,12 @@ def train_dataloader(self): "ready": n_epochs, "started": n_epochs, "processed": n_epochs, - # TODO: the following "-1" offset will be fixed by - # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578 "completed": n_epochs - 1, }, "current": { "ready": n_epochs, "started": n_epochs, "processed": n_epochs, - # TODO: the following "-1" offset will be fixed by - # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578 "completed": n_epochs - 1, }, }, From d361131b2b63b0a3d79e1473e2da830131139a1c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Feb 2022 00:54:11 +0100 Subject: [PATCH 04/13] Fix for fault-tolerance restart --- .../loops/epoch/evaluation_epoch_loop.py | 2 + .../loops/epoch/training_epoch_loop.py | 2 + pytorch_lightning/utilities/fetching.py | 14 ++- tests/loops/test_loops.py | 2 - tests/utilities/test_fetching.py | 96 ++++++++++--------- 5 files changed, 69 insertions(+), 47 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index d4cf026a6047a..1eae3af918001 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -87,6 +87,8 @@ def on_run_start( # type: ignore[override] self._reload_dataloader_state_dict(data_fetcher) self._dataloader_iter = iter(data_fetcher) + # add the previous `fetched` value to properly track `is_last_batch` with no prefetching + data_fetcher.fetched += self.batch_progress.current.ready def advance( # type: ignore[override] self, diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index c8eefedd3c327..86cae76506a74 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -144,6 +144,8 @@ def reset(self) -> None: def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[override] self._reload_dataloader_state_dict(data_fetcher) self._dataloader_iter = iter(data_fetcher) + # add the previous `fetched` value to properly track `is_last_batch` with no prefetching + data_fetcher.fetched += self.batch_progress.current.ready def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[override] """Runs a single training batch. 
diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 309c1ce05b8eb..861e95920822e 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -30,6 +30,7 @@ MergedIteratorState, patch_dataloader_iterator, ) +from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -79,6 +80,8 @@ def __init__(self, prefetch_batches: int = 0) -> None: def setup(self, dataloader: Iterable, **kwargs: Any) -> None: self._add_capture_metadata_collate(dataloader) self._dataloader = dataloader + _patch_dataloader_get_iterators() + self._attach_data_fetcher() @property def dataloader(self) -> Iterable: @@ -172,8 +175,6 @@ def _attach_data_fetcher_fn(loader: DataLoader) -> None: def __iter__(self) -> "AbstractDataFetcher": self.reset() - self._attach_data_fetcher() - _patch_dataloader_get_iterators() self.dataloader_iter = iter(self.dataloader) self._apply_patch() self.prefetching() @@ -214,11 +215,14 @@ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> N self.store_on_device = store_on_device self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device self.batches: List[Any] = [] + # used to know whether we can access the dataloader length + self._has_len: bool = False def setup( # type: ignore[override] self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None ) -> None: super().setup(dataloader) + self._has_len = has_len(dataloader) if batch_to_device is not None: self.batch_to_device = batch_to_device @@ -233,6 +237,9 @@ def prefetching(self) -> None: try: self._fetch_next_batch(iterator) except StopIteration: + # this would only happen when prefetch_batches > the number of batches available and makes + # `fetching_function` jump directly to the empty iteration return path + self.done = True break def fetching_function(self) -> Tuple[Any, bool]: @@ -266,6 +273,9 @@ def _fetch_next_batch(self, iterator: Iterator) -> None: start_output = self.on_fetch_start() batch = next(iterator) self.fetched += 1 + if not self.prefetch_batches and self._has_len: + # support `done` for non-iterable datasets + self.done = self.fetched >= len(self.dataloader) self.on_fetch_end(batch, start_output) def move_to_device(self, batch: Any) -> Any: diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 20295293b9f6b..cfc347293484c 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -952,8 +952,6 @@ def val_dataloader(self): # totals are increased by 1 (the failed batch which never completed) expected = state_dict.copy() - # TODO: `is_last_batch` is not correct on reload, the next line should not be necessary - expected["epoch_loop.batch_progress"]["is_last_batch"] = val_check_interval == 1.0 assert state_dict_after_restart["epoch_loop.batch_progress"] == expected["epoch_loop.batch_progress"] val_dl_progress = "epoch_loop.val_loop.dataloader_progress" diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index b35768ffa7b90..91cd3eac160b0 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -18,7 +18,6 @@ import pytest import torch -from torch import tensor from torch.utils.data import DataLoader, Dataset, IterableDataset from pytorch_lightning import Callback, LightningDataModule, Trainer @@ -30,55 +29,66 @@ from tests.helpers.runif import 
RunIf -@pytest.mark.parametrize("use_combined_loader", [False, True]) -def test_prefetch_iterator(use_combined_loader): - """Test the DataFetcher with PyTorch IterableDataset.""" +class IterDataset(IterableDataset): + def __iter__(self): + yield 1 + yield 2 + yield 3 - class IterDataset(IterableDataset): - def __iter__(self): - yield 1 - yield 2 - yield 3 - - for prefetch_batches in range(5): - iterator = DataFetcher(prefetch_batches=prefetch_batches) - assert iterator.prefetch_batches == prefetch_batches - - if use_combined_loader: - loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())]) - else: - loader = DataLoader(IterDataset()) - iterator.setup(loader) - - def generate(): - generated = [(iterator.fetched, *data) for i, data in enumerate(iterator, prefetch_batches + 1)] - assert iterator.fetched == 3 - assert iterator.done - return generated - - is_last_batch = [False, False, prefetch_batches > 0] - fetched = list(range(prefetch_batches + 1, 4)) - fetched += [3] * (3 - len(fetched)) - if use_combined_loader: - batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]] - else: - batches = [1, 2, 3] - expected = list(zip(fetched, batches, is_last_batch)) - assert len(expected) == 3 - - assert generate() == expected - # validate reset works properly. - assert generate() == expected - assert iterator.fetched == 3 +class SizedDataset(Dataset): + def __len__(self): + return 3 + + def __getitem__(self, idx): + return idx + 1 + + +@pytest.mark.parametrize("use_combined_loader", [False, True]) +@pytest.mark.parametrize("dataset_cls", [IterDataset, SizedDataset]) +@pytest.mark.parametrize("prefetch_batches", list(range(5))) +def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches): + fetcher = DataFetcher(prefetch_batches=prefetch_batches) + assert fetcher.prefetch_batches == prefetch_batches + + if use_combined_loader: + loader = CombinedLoader([DataLoader(dataset_cls()), DataLoader(dataset_cls())]) + else: + loader = DataLoader(dataset_cls()) + fetcher.setup(loader) + + def generate(): + generated = [(fetcher.fetched, *data) for data in fetcher] + assert fetcher.fetched == 3 + assert fetcher.done + return generated + + # we can only know the last batch with sized iterables or when we prefetch + is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset] + fetched = list(range(prefetch_batches + 1, 4)) + fetched += [3] * (3 - len(fetched)) + batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3] + expected = list(zip(fetched, batches, is_last_batch)) + assert len(expected) == 3 + + assert generate() == expected + # validate reset works properly. 
+ assert generate() == expected + assert fetcher.fetched == 3 + + +def test_empty_prefetch_iterator(): class EmptyIterDataset(IterableDataset): def __iter__(self): return iter([]) loader = DataLoader(EmptyIterDataset()) - iterator = DataFetcher() - iterator.setup(loader) - assert not list(iterator) + fetcher = DataFetcher() + fetcher.setup(loader) + + assert not fetcher.done + assert not list(fetcher) + assert fetcher.done def test_misconfiguration_error(): From b9655d8c8d2a08c91b7ee04982084d762a876a8a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Feb 2022 00:56:26 +0100 Subject: [PATCH 05/13] Fix comments --- pytorch_lightning/utilities/fetching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 861e95920822e..fbbcde07df456 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -238,7 +238,7 @@ def prefetching(self) -> None: self._fetch_next_batch(iterator) except StopIteration: # this would only happen when prefetch_batches > the number of batches available and makes - # `fetching_function` jump directly to the empty iteration return path + # `fetching_function` jump directly to the empty iterator case without trying to fetch again self.done = True break @@ -274,7 +274,7 @@ def _fetch_next_batch(self, iterator: Iterator) -> None: batch = next(iterator) self.fetched += 1 if not self.prefetch_batches and self._has_len: - # support `done` for non-iterable datasets + # this is for `done` with sized datasets self.done = self.fetched >= len(self.dataloader) self.on_fetch_end(batch, start_output) From 25bb22d676a2fdd158835f082aeceefd4c9c146e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Feb 2022 01:09:53 +0100 Subject: [PATCH 06/13] mypy --- pytorch_lightning/utilities/fetching.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index fbbcde07df456..6d53962505271 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator from copy import deepcopy -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Sized, Tuple import torch from torch.utils.data.dataloader import DataLoader @@ -215,8 +215,7 @@ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> N self.store_on_device = store_on_device self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device self.batches: List[Any] = [] - # used to know whether we can access the dataloader length - self._has_len: bool = False + self._has_len = False def setup( # type: ignore[override] self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None @@ -274,8 +273,10 @@ def _fetch_next_batch(self, iterator: Iterator) -> None: batch = next(iterator) self.fetched += 1 if not self.prefetch_batches and self._has_len: - # this is for `done` with sized datasets - self.done = self.fetched >= len(self.dataloader) + # when we don't prefetch but the dataloader is sized, we use the length for `done` + dataloader = self.dataloader + assert isinstance(dataloader, Sized) # `_has_len` is True + self.done = self.fetched >= len(dataloader) self.on_fetch_end(batch, start_output) def move_to_device(self, batch: Any) -> Any: From 
6e765886aa79489d0eef63fd2ea3ae5792f6d8bd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Feb 2022 01:21:56 +0100 Subject: [PATCH 07/13] Fix manual test --- tests/utilities/test_auto_restart.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index f9edf364120bf..d35f97c87160c 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1452,7 +1452,7 @@ def load_state_dict(self, state_dict): class RandomFaultTolerantSampler(RandomSampler): - def __init__(self, *args, seed: int = 0, generator=None, **kwargs): + def __init__(self, *args, seed: int = 0, **kwargs): generator = torch.Generator().manual_seed(seed) super().__init__(*args, generator=generator, **kwargs) self.counter = 0 @@ -1467,7 +1467,7 @@ def load_state_dict(self, state_dict): self.restarting = True def __len__(self): - return len(self.data_source) - self.counter + return max(len(self.data_source) - self.counter, 1) def __iter__(self) -> Iterator[int]: n = len(self.data_source) @@ -1558,7 +1558,7 @@ def configure_optimizers(self): seed_everything(42) model = TestModel(should_fail=True) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval) - with suppress(CustomException): + with pytest.raises(CustomException): trainer.fit(model) trainer.train_dataloader = None failed_batches = model.batches From 9c51e85e0c946915c1fea6ba77d9df37a8790564 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Feb 2022 01:23:23 +0100 Subject: [PATCH 08/13] Fix fault tolerant manual test --- tests/utilities/test_auto_restart.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index d35f97c87160c..28bdfda702f71 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1467,6 +1467,7 @@ def load_state_dict(self, state_dict): self.restarting = True def __len__(self): + # the `utilities.data.has_len` requires at least 1 sample return max(len(self.data_source) - self.counter, 1) def __iter__(self) -> Iterator[int]: From 444c70724931fe2b780725840a5fb00bc6722f16 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Feb 2022 03:09:48 +0100 Subject: [PATCH 09/13] Fix test and extend coverage --- tests/utilities/test_fetching.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 91cd3eac160b0..f71eb65a7f8af 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -196,7 +196,7 @@ def __init__(self, check_inter_batch): def on_train_epoch_end(self, trainer, lightning_module): fetcher = trainer.fit_loop._data_fetcher assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher) - assert fetcher.prefetch_batches == 1 + assert fetcher.prefetch_batches == int(self._check_inter_batch) trainer_kwargs = dict( default_root_dir=tmpdir, @@ -277,14 +277,19 @@ def training_epoch_end(self, *_): @RunIf(min_torch="1.8.0") def test_fetching_dataloader_iter_running_stages(fn, tmpdir): class TestModel(BoringModel): - def validation_step(self, dataloader_iter, batch_idx): - assert isinstance(self.trainer.validate_loop._data_fetcher, DataLoaderIterDataFetcher) + def fetch(self, data_fetcher, dataloader_iter, batch_idx): + assert isinstance(data_fetcher, DataLoaderIterDataFetcher) + assert 
data_fetcher.fetched == batch_idx batch = next(dataloader_iter) + assert data_fetcher.fetched == batch_idx + 1 + return batch + + def validation_step(self, dataloader_iter, batch_idx): + batch = self.fetch(self.trainer.validate_loop._data_fetcher, dataloader_iter, batch_idx) return super().validation_step(batch, batch_idx) def test_step(self, dataloader_iter, batch_idx): - assert isinstance(self.trainer.test_loop._data_fetcher, DataLoaderIterDataFetcher) - batch = next(dataloader_iter) + batch = self.fetch(self.trainer.test_loop._data_fetcher, dataloader_iter, batch_idx) return super().test_step(batch, batch_idx) model = TestModel() From 628447461fbee7f89c32eff2dd38f809ad7b9812 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Feb 2022 18:50:13 +0100 Subject: [PATCH 10/13] Support 0-length map datasets --- pytorch_lightning/utilities/data.py | 33 +++++++++++----------------- tests/trainer/test_dataloaders.py | 16 +++++--------- tests/utilities/test_auto_restart.py | 3 +-- tests/utilities/test_data.py | 6 ++--- tests/utilities/test_fetching.py | 20 ++++++++++++----- 5 files changed, 37 insertions(+), 41 deletions(-) diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 2b68c51db57d5..661732bb5afa5 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -89,17 +89,13 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool: def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or - infinite dataloader. - - Raises: - ValueError: - If the length of Dataloader is 0, as it requires at least one batch - """ - + infinite dataloader.""" try: # try getting the length if len(dataloader) == 0: - raise ValueError("`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch") + rank_zero_warn( + f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention." + ) has_len = True except TypeError: has_len = False @@ -122,30 +118,27 @@ def has_len_all_ranks( model: Union["pl.LightningModule", "pl.LightningDataModule"], ) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or - infinite dataloader. - - Raises: - ValueError: - If the length of Dataloader is 0, as it requires at least one batch - """ + infinite dataloader.""" try: - total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum") local_length = len(dataloader) + total_length = training_type.reduce(torch.tensor(local_length).to(model.device), reduce_op="sum") if total_length == 0: - raise MisconfigurationException( - "Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch." + rank_zero_warn( + f"Total length of `{dataloader.__class__.__name__}` across ranks is zero." + " Please make sure this was your intention" ) if total_length > 0 and local_length == 0: if model.allow_zero_length_dataloader_with_multiple_devices: rank_zero_warn( - "Total length of `Dataloader` across ranks is zero, but local rank has zero length." - " Please be cautious of uneven batch length." + f"Total length of `{dataloader.__class__.__name__}` across ranks is zero, but local rank has zero" + " length. Please be cautious of uneven batch length." ) has_len = False else: raise MisconfigurationException( - "`Dataloader` within local rank has zero length. 
Please make sure that it returns at least 1 batch." + f"`{dataloader.__class__.__name__}` within local rank has zero length." + " Please make sure that it returns at least 1 batch." ) else: has_len = True diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d8e4989bb34d6..08d54e05bfffe 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -516,20 +516,16 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path): assert len(trainer.test_dataloaders) == 1 -def test_error_on_zero_len_dataloader(tmpdir): - """Test that error is raised if a zero-length dataloader is defined.""" - - class CustomBoringModel(BoringModel): - def train_dataloader(self): - return DataLoader(RandomDataset(32, 0)) - - model = CustomBoringModel() +def test_warning_on_zero_len_dataloader(tmpdir): + """Test that a warning is raised if a zero-length dataloader is defined.""" + model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, fast_dev_run=1, ) - with pytest.raises(ValueError, match="returned 0 length. .* at least 1 batch"): - trainer.fit(model) + dataloader = DataLoader(RandomDataset(32, 0)) + with pytest.warns(UserWarning, match="returned 0 length"): + trainer.fit(model, dataloader) @RunIf(skip_windows=True) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 28bdfda702f71..d8aa21f3dd8ec 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1467,8 +1467,7 @@ def load_state_dict(self, state_dict): self.restarting = True def __len__(self): - # the `utilities.data.has_len` requires at least 1 sample - return max(len(self.data_source) - self.counter, 1) + return len(self.data_source) - self.counter def __iter__(self) -> Iterator[int]: n = len(self.data_source) diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index ae454080a3b77..feeccf694d9a6 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -93,7 +93,7 @@ def __iter__(self): def test_has_len(): assert has_len(DataLoader(RandomDataset(1, 1))) - with pytest.raises(ValueError, match="`Dataloader` returned 0 length."): + with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."): assert has_len(DataLoader(RandomDataset(0, 0))) assert not has_len(DataLoader(RandomIterableDataset(1, 1))) @@ -112,8 +112,8 @@ def test_has_len_all_rank(): trainer = Trainer(fast_dev_run=True) model = BoringModel() - with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."): - assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model) + with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero."): + assert has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model) assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy, model) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index f71eb65a7f8af..d39d3aefb8cab 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -77,13 +77,21 @@ def generate(): assert fetcher.fetched == 3 -def test_empty_prefetch_iterator(): - class EmptyIterDataset(IterableDataset): - def __iter__(self): - return iter([]) +class EmptyIterDataset(IterableDataset): + def __iter__(self): + return iter([]) - loader = DataLoader(EmptyIterDataset()) - fetcher = DataFetcher() + +class EmptySizedDataset(Dataset): + def __len__(self): + return 0 + + 
+@pytest.mark.parametrize("dataset_cls", [EmptyIterDataset, EmptySizedDataset]) +@pytest.mark.parametrize("prefetch_batches", list(range(2))) +def test_empty_prefetch_iterator(dataset_cls, prefetch_batches): + loader = DataLoader(dataset_cls()) + fetcher = DataFetcher(prefetch_batches=prefetch_batches) fetcher.setup(loader) assert not fetcher.done From 1fc58ed08016ec16375377f57d4ac2808345c643 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 28 Feb 2022 11:58:42 +0100 Subject: [PATCH 11/13] Rename variable --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 4 ++-- pytorch_lightning/loops/fit_loop.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 5c744043ff0a9..8f35b39d60fdd 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -84,9 +84,9 @@ def dataloaders(self) -> Sequence[DataLoader]: @property def prefetch_batches(self) -> int: batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches - iterable_dataset = batches[self.current_dataloader_idx] == float("inf") + is_unsized = batches[self.current_dataloader_idx] == float("inf") inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1" - return 1 if iterable_dataset or inter_batch_parallelism else 0 + return 1 if is_unsized or inter_batch_parallelism else 0 def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override] """Connect the evaluation epoch loop with this loop.""" diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index cec6c5e8f7cf1..0c9c68b24f4a0 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -151,9 +151,9 @@ def restarting(self, restarting: bool) -> None: @property def prefetch_batches(self) -> int: - iterable_dataset = self.trainer.num_training_batches == float("inf") + is_unsized = self.trainer.num_training_batches == float("inf") inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1" - return 1 if iterable_dataset or inter_batch_parallelism else 0 + return 1 if is_unsized or inter_batch_parallelism else 0 @property def _skip_backward(self) -> bool: From a0334e978c9dffa7d72a7a1f8f81aa89d8bb7a1e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 28 Feb 2022 17:42:58 +0100 Subject: [PATCH 12/13] Mention prefetching in the docs --- docs/source/guides/data.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/guides/data.rst b/docs/source/guides/data.rst index dbfba8598a36f..1eda7ac6291ae 100644 --- a/docs/source/guides/data.rst +++ b/docs/source/guides/data.rst @@ -393,6 +393,9 @@ option when using sequential data. to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception. Here ``mode`` can be train/val/test/predict. +When iterable datasets are used, Lightning will pre-fetch 1 batch (in addition to the current batch) so it can detect +when the training will stop and run validation if necessary. + .. 
testcode:: # IterableDataset From 54403f61a9f4a716f49bf2bf4052e1a225350b54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 28 Feb 2022 18:16:40 +0100 Subject: [PATCH 13/13] Update pytorch_lightning/utilities/data.py Co-authored-by: Rohit Gupta --- pytorch_lightning/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 5eb2ec05f407d..3f59a8f017cc7 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -126,7 +126,7 @@ def has_len_all_ranks( if total_length == 0: rank_zero_warn( f"Total length of `{dataloader.__class__.__name__}` across ranks is zero." - " Please make sure this was your intention" + " Please make sure this was your intention." ) if total_length > 0 and local_length == 0: if model.allow_zero_length_dataloader_with_multiple_devices:
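
The prefetching policy that this series introduces can be summarised with a short, self-contained sketch. The `choose_prefetch_batches` helper below is only an illustrative stand-in for the new `FitLoop.prefetch_batches` / `EvaluationLoop.prefetch_batches` properties (its name and the toy datasets are assumptions made for the example, not symbols from the patch): one batch is prefetched only when the number of batches is unknown, as with iterable-style dataloaders, or when `PL_INTER_BATCH_PARALLELISM=1` is set; sized dataloaders skip prefetching and rely on their length to detect the last batch.

    import os
    from torch.utils.data import DataLoader, Dataset, IterableDataset

    # Hypothetical helper mirroring the logic of the new `prefetch_batches` loop properties:
    # prefetch a single batch only when the number of batches is unknown (inf) or when
    # inter-batch parallelism is explicitly requested; otherwise rely on the dataloader length.
    def choose_prefetch_batches(num_batches: float) -> int:
        is_unsized = num_batches == float("inf")
        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
        return 1 if is_unsized or inter_batch_parallelism else 0

    class SizedDataset(Dataset):
        def __len__(self) -> int:
            return 3

        def __getitem__(self, idx: int) -> int:
            return idx + 1

    class IterDataset(IterableDataset):
        def __iter__(self):
            yield from (1, 2, 3)

    sized_loader = DataLoader(SizedDataset())
    iterable_loader = DataLoader(IterDataset())

    # A sized dataloader reports a finite number of batches, so no prefetching is needed.
    assert choose_prefetch_batches(len(sized_loader)) == 0
    # An iterable dataloader has no usable length, so one batch is prefetched
    # in order to know when the last batch has been reached.
    assert choose_prefetch_batches(float("inf")) == 1

This mirrors the behaviour documented in the `data.rst` change above: with iterable datasets, one extra batch is kept in flight so the loop can tell when training will stop and run validation if necessary, while map-style datasets avoid the extra fetch entirely.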