Do not prefetch when possible #12101

Merged: 14 commits, Feb 28, 2022
9 changes: 8 additions & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -81,6 +81,13 @@ def dataloaders(self) -> Sequence[DataLoader]:
raise RuntimeError("Dataloaders should be available.")
return dataloaders

@property
def prefetch_batches(self) -> int:
batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches
is_unsized = batches[self.current_dataloader_idx] == float("inf")
inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
return 1 if is_unsized or inter_batch_parallelism else 0

def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override]
"""Connect the evaluation epoch loop with this loop."""
self.epoch_loop = epoch_loop
@@ -121,7 +128,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
void(*args, **kwargs)

data_fetcher_cls = _select_data_fetcher_type(self.trainer)
self._data_fetcher = data_fetcher_cls()
self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)

# hook
self._on_evaluation_model_eval()
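
For context, the decision that the new `prefetch_batches` properties encode (here in the evaluation loop, and again in the fit loop below) can be sketched on its own. The following is an illustrative standalone function, not Lightning API; only the `PL_INTER_BATCH_PARALLELISM` variable and the `float("inf")` sentinel for unsized dataloaders are taken from the diff:

import os

def resolve_prefetch_batches(num_batches: float) -> int:
    # An unsized (iterable-style) dataloader reports float("inf") batches, so the
    # only way to detect the last batch is to fetch one batch ahead of time.
    is_unsized = num_batches == float("inf")
    # Inter-batch parallelism also needs at least one batch in flight.
    inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
    return 1 if is_unsized or inter_batch_parallelism else 0

assert resolve_prefetch_batches(100) == 0            # sized loader: no prefetching
assert resolve_prefetch_batches(float("inf")) == 1   # unsized loader: prefetch one batch
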
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -85,6 +85,8 @@ def on_run_start( # type: ignore[override]
self._reload_dataloader_state_dict(data_fetcher)
# creates the iterator inside the fetcher but returns `self`
self._data_fetcher = cast(AbstractDataFetcher, iter(data_fetcher))
# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready

def advance( # type: ignore[override]
self,
4 changes: 3 additions & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -142,7 +142,9 @@ def reset(self) -> None:

def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[override]
self._reload_dataloader_state_dict(data_fetcher)
iter(data_fetcher) # creates the iterator inside the fetcher
_ = iter(data_fetcher) # creates the iterator inside the fetcher
# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready

def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[override]
"""Runs a single training batch.
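
The `data_fetcher.fetched += self.batch_progress.current.ready` line added to both epoch loops compensates for a freshly created fetcher after a mid-epoch restart: without it, the fetcher would start counting from zero and flag the last batch too late. A rough illustration with made-up numbers (only the attribute names come from the diff; everything else is hypothetical):

# Suppose an epoch has 10 batches and training resumes after 7 were already ready.
dataloader_len = 10
fetched = 0                # a newly created fetcher starts from zero
already_ready = 7          # batch_progress.current.ready restored from the checkpoint
fetched += already_ready   # the adjustment added in this PR

# Without prefetching, a sized dataloader flags the last batch via its length,
# so after consuming the remaining 3 batches the fetcher is correctly done.
fetched += 3
assert fetched >= dataloader_len
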
9 changes: 8 additions & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -149,6 +149,12 @@ def restarting(self, restarting: bool) -> None:
restarting &= finished_before_on_train_end
Loop.restarting.fset(self, restarting) # call the parent setter

@property
def prefetch_batches(self) -> int:
is_unsized = self.trainer.num_training_batches == float("inf")
inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
return 1 if is_unsized or inter_batch_parallelism else 0

@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
@@ -213,8 +219,9 @@ def on_run_start(self) -> None: # type: ignore[override]
"""Calls the ``on_train_start`` hook."""
# reset train dataloader and val dataloader
self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

data_fetcher_cls = _select_data_fetcher(self.trainer)
self._data_fetcher = data_fetcher_cls()
self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)

self._is_fresh_start_epoch = True
self._results.to(device=self.trainer.lightning_module.device)
33 changes: 13 additions & 20 deletions pytorch_lightning/utilities/data.py
@@ -89,17 +89,13 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:

def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
infinite dataloader.

Raises:
ValueError:
If the length of Dataloader is 0, as it requires at least one batch
"""

infinite dataloader."""
try:
# try getting the length
if len(dataloader) == 0:
raise ValueError("`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch")
rank_zero_warn(
f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
)
has_len = True
except TypeError:
has_len = False
@@ -122,30 +118,27 @@ def has_len_all_ranks(
model: Union["pl.LightningModule", "pl.LightningDataModule"],
) -> bool:
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
infinite dataloader.

Raises:
ValueError:
If the length of Dataloader is 0, as it requires at least one batch
"""
infinite dataloader."""
try:
total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum")
local_length = len(dataloader)
total_length = training_type.reduce(torch.tensor(local_length).to(model.device), reduce_op="sum")

if total_length == 0:
raise MisconfigurationException(
"Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch."
rank_zero_warn(
f"Total length of `{dataloader.__class__.__name__}` across ranks is zero."
" Please make sure this was your intention"
)
if total_length > 0 and local_length == 0:
if model.allow_zero_length_dataloader_with_multiple_devices:
rank_zero_warn(
"Total length of `Dataloader` across ranks is zero, but local rank has zero length."
" Please be cautious of uneven batch length."
f"Total length of `{dataloader.__class__.__name__}` across ranks is zero, but local rank has zero"
" length. Please be cautious of uneven batch length."
)
has_len = False
else:
raise MisconfigurationException(
"`Dataloader` within local rank has zero length. Please make sure that it returns at least 1 batch."
f"`{dataloader.__class__.__name__}` within local rank has zero length."
" Please make sure that it returns at least 1 batch."
)
else:
has_len = True
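
The behavioural change in `has_len` and `has_len_all_ranks` is that a zero-length dataloader now produces a warning instead of raising. A minimal sketch of the new single-process behaviour, assuming `has_len` is importable from `pytorch_lightning.utilities.data` as shown in the next file; the empty `TensorDataset` is used purely for illustration:

import warnings

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.data import has_len

empty_loader = DataLoader(TensorDataset(torch.empty(0, 32)))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Still True: the loader implements __len__, its length just happens to be zero.
    assert has_len(empty_loader)
assert any("0 length" in str(w.message) for w in caught)
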
22 changes: 17 additions & 5 deletions pytorch_lightning/utilities/fetching.py
@@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from copy import deepcopy
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Sized, Tuple

import torch
from torch.utils.data.dataloader import DataLoader
@@ -30,6 +30,7 @@
MergedIteratorState,
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training

@@ -79,6 +80,8 @@ def __init__(self, prefetch_batches: int = 0) -> None:
def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
self._add_capture_metadata_collate(dataloader)
self._dataloader = dataloader
_patch_dataloader_get_iterators()
self._attach_data_fetcher()

@property
def dataloader(self) -> Iterable:
@@ -172,8 +175,6 @@ def _attach_data_fetcher_fn(loader: DataLoader) -> None:

def __iter__(self) -> "AbstractDataFetcher":
self.reset()
self._attach_data_fetcher()
_patch_dataloader_get_iterators()
self.dataloader_iter = iter(self.dataloader)
self._apply_patch()
self.prefetching()
@@ -205,7 +206,7 @@ class DataFetcher(AbstractDataFetcher):

Args:
prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
whether a batch is the last one (available with :attr:`self.done`).
whether a batch is the last one (available with :attr:`self.done`) under any training setup.
store_on_device: Whether to store the pre-fetched batches on device.
"""

@@ -214,11 +215,13 @@ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
self.store_on_device = store_on_device
self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
self.batches: List[Any] = []
self._has_len = False

def setup( # type: ignore[override]
self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None
) -> None:
super().setup(dataloader)
self._has_len = has_len(dataloader)
if batch_to_device is not None:
self.batch_to_device = batch_to_device

@@ -233,6 +236,9 @@ def prefetching(self) -> None:
try:
self._fetch_next_batch(iterator)
except StopIteration:
# this would only happen when prefetch_batches > the number of batches available and makes
# `fetching_function` jump directly to the empty iterator case without trying to fetch again
self.done = True
break

def fetching_function(self) -> Any:
@@ -266,6 +272,11 @@ def _fetch_next_batch(self, iterator: Iterator) -> None:
start_output = self.on_fetch_start()
batch = next(iterator)
self.fetched += 1
if not self.prefetch_batches and self._has_len:
# when we don't prefetch but the dataloader is sized, we use the length for `done`
dataloader = self.dataloader
assert isinstance(dataloader, Sized) # `_has_len` is True
self.done = self.fetched >= len(dataloader)
self.on_fetch_end(batch, start_output)

def move_to_device(self, batch: Any) -> Any:
@@ -360,7 +371,8 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
...
"""

def __init__(self) -> None:
def __init__(self, prefetch_batches: int = 0) -> None:
# prefetch batches is not used for this class
super().__init__()
self.store_on_device = False

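
The core of the fetcher change is how `done` is determined when nothing is prefetched: for sized dataloaders the running `fetched` counter is compared against the dataloader length, while unsized dataloaders keep prefetching a single batch so that `StopIteration` can be detected ahead of time. A simplified standalone sketch of the sized path (the names mirror the diff, but this is not the actual `DataFetcher` implementation):

from collections.abc import Sized
from typing import Any, Iterator, Tuple

def fetch_next(iterator: Iterator, dataloader: Any, fetched: int, prefetch_batches: int) -> Tuple[Any, int, bool]:
    """Toy version of `_fetch_next_batch`: returns (batch, fetched, done)."""
    batch = next(iterator)  # raises StopIteration once the iterator is exhausted
    fetched += 1
    done = False
    if not prefetch_batches and isinstance(dataloader, Sized):
        # No prefetching and a known length: the last batch is flagged immediately.
        done = fetched >= len(dataloader)
    return batch, fetched, done

data = [1, 2, 3]
it = iter(data)
fetched, done = 0, False
while not done:
    batch, fetched, done = fetch_next(it, data, fetched, prefetch_batches=0)
assert fetched == len(data) and done
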
6 changes: 0 additions & 6 deletions tests/loops/test_loops.py
@@ -648,16 +648,12 @@ def train_dataloader(self):
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
"current": {
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
},
@@ -956,8 +952,6 @@ def val_dataloader(self):
# totals are increased by 1 (the failed batch which never completed)
expected = state_dict.copy()

# TODO: `is_last_batch` is not correct on reload, the next line should not be necessary
expected["epoch_loop.batch_progress"]["is_last_batch"] = val_check_interval == 1.0
assert state_dict_after_restart["epoch_loop.batch_progress"] == expected["epoch_loop.batch_progress"]

val_dl_progress = "epoch_loop.val_loop.dataloader_progress"
16 changes: 6 additions & 10 deletions tests/trainer/test_dataloaders.py
@@ -516,20 +516,16 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
assert len(trainer.test_dataloaders) == 1


def test_error_on_zero_len_dataloader(tmpdir):
"""Test that error is raised if a zero-length dataloader is defined."""

class CustomBoringModel(BoringModel):
def train_dataloader(self):
return DataLoader(RandomDataset(32, 0))

model = CustomBoringModel()
def test_warning_on_zero_len_dataloader(tmpdir):
"""Test that a warning is raised if a zero-length dataloader is defined."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=1,
)
with pytest.raises(ValueError, match="returned 0 length. .* at least 1 batch"):
trainer.fit(model)
dataloader = DataLoader(RandomDataset(32, 0))
with pytest.warns(UserWarning, match="returned 0 length"):
trainer.fit(model, dataloader)


@RunIf(skip_windows=True)
4 changes: 2 additions & 2 deletions tests/utilities/test_auto_restart.py
@@ -1452,7 +1452,7 @@ def load_state_dict(self, state_dict):


class RandomFaultTolerantSampler(RandomSampler):
def __init__(self, *args, seed: int = 0, generator=None, **kwargs):
def __init__(self, *args, seed: int = 0, **kwargs):
generator = torch.Generator().manual_seed(seed)
super().__init__(*args, generator=generator, **kwargs)
self.counter = 0
@@ -1558,7 +1558,7 @@ def configure_optimizers(self):
seed_everything(42)
model = TestModel(should_fail=True)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
with suppress(CustomException):
with pytest.raises(CustomException):
trainer.fit(model)
trainer.train_dataloader = None
failed_batches = model.batches
6 changes: 3 additions & 3 deletions tests/utilities/test_data.py
@@ -93,7 +93,7 @@ def __iter__(self):
def test_has_len():
assert has_len(DataLoader(RandomDataset(1, 1)))

with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
assert has_len(DataLoader(RandomDataset(0, 0)))

assert not has_len(DataLoader(RandomIterableDataset(1, 1)))
@@ -112,8 +112,8 @@ def test_has_len_all_rank():
trainer = Trainer(fast_dev_run=True)
model = BoringModel()

with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."):
assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model)
with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero."):
assert has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model)

assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy, model)
