Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move data fetcher ownership to the loops #11621

Merged
merged 11 commits into from
Feb 9, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved ownership of the lightning optimizers from the `Trainer` to the `Strategy` ([#11444](https://github.com/PyTorchLightning/pytorch-lightning/pull/11444))


- Moved ownership of the data fetchers from the DataConnector to the Loops ([#11621](https://github.com/PyTorchLightning/pytorch-lightning/pull/11621))


- Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))


Expand Down
18 changes: 14 additions & 4 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Any, List, Sequence, Union
from functools import partial
from typing import Any, List, Optional, Sequence, Union

import torch
from deprecate.utils import void
Expand All @@ -22,6 +23,7 @@
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
from pytorch_lightning.utilities.types import EPOCH_OUTPUT


Expand All @@ -38,6 +40,7 @@ def __init__(self, verbose: bool = True) -> None:
self._logged_outputs: List[_OUT_DICT] = []
self._max_batches: List[int] = []
self._has_run: bool = False
self._data_fetcher: Optional[AbstractDataFetcher] = None

@property
def num_dataloaders(self) -> int:
Expand Down Expand Up @@ -99,6 +102,8 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
hooks."""
void(*args, **kwargs)

self._data_fetcher = DataFetcher()

# hook
self._on_evaluation_model_eval()
self.trainer.lightning_module.zero_grad()
Expand All @@ -111,15 +116,17 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

dataloader_idx = self.current_dataloader_idx
dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader)
self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=dataloader_idx
assert self._data_fetcher is not None
self._data_fetcher.setup(
dataloader,
batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx),
)
dl_max_batches = self._max_batches[dataloader_idx]

kwargs = OrderedDict()
if self.num_dataloaders > 1:
kwargs["dataloader_idx"] = dataloader_idx
dl_outputs = self.epoch_loop.run(dataloader, dl_max_batches, kwargs)
dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)

# store batch level output per dataloader
self._outputs.append(dl_outputs)
Expand Down Expand Up @@ -169,6 +176,9 @@ def on_run_end(self) -> List[_OUT_DICT]:
return logged_outputs

def teardown(self) -> None:
if self._data_fetcher is not None:
self._data_fetcher.teardown()
self._data_fetcher = None
self._results.cpu()
self.epoch_loop.teardown()

Expand Down
42 changes: 38 additions & 4 deletions pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
# limitations under the License.
import logging
import math
from typing import Optional
import os
from functools import partial
from typing import Optional, Type

import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
Expand All @@ -25,7 +29,14 @@
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
DataFetcher,
DataLoaderIterDataFetcher,
InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import rank_zero_warn

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -58,6 +69,7 @@ def __init__(

self._is_fresh_start_epoch: bool = True
self._outputs: _EPOCH_OUTPUTS_TYPE = []
self._data_fetcher: Optional[AbstractDataFetcher] = None

@property
def global_step(self) -> int:
Expand Down Expand Up @@ -184,6 +196,8 @@ def on_run_start(self) -> None: # type: ignore[override]
"""Calls the ``on_train_start`` hook."""
# reset train dataloader and val dataloader
self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
data_fetcher_cls = _select_data_fetcher(self.trainer)
self._data_fetcher = data_fetcher_cls()

ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
Expand All @@ -204,6 +218,7 @@ def on_run_start(self) -> None: # type: ignore[override]

self._is_fresh_start_epoch = True
self._results.to(device=self.trainer.lightning_module.device)

self.trainer._call_callback_hooks("on_train_start")
self.trainer._call_lightning_module_hook("on_train_start")
self.trainer._call_strategy_hook("on_train_start")
Expand Down Expand Up @@ -251,10 +266,11 @@ def advance(self) -> None: # type: ignore[override]
log.detail(f"{self.__class__.__name__}: advancing loop")
assert self.trainer.train_dataloader is not None
dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader, 0)

self._data_fetcher.setup(
dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
)
with self.trainer.profiler.profile("run_training_epoch"):
self._outputs = self.epoch_loop.run(data_fetcher)
self._outputs = self.epoch_loop.run(self._data_fetcher)

def on_advance_end(self) -> None:
# inform logger the batch loop has finished
Expand Down Expand Up @@ -325,8 +341,26 @@ def on_run_end(self) -> None:
self.trainer.strategy.on_train_end()

def teardown(self) -> None:
if self._data_fetcher is not None:
self._data_fetcher.teardown()
self._data_fetcher = None
self.epoch_loop.teardown()

def _should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated."""
return self.epoch_loop._should_accumulate()


def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
training_step_fx = getattr(trainer.lightning_module, "training_step")
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher
elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
if not isinstance(trainer.accelerator, GPUAccelerator):
raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
return InterBatchParallelDataFetcher
return DataFetcher
carmocca marked this conversation as resolved.
Show resolved Hide resolved
79 changes: 3 additions & 76 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,20 @@
import multiprocessing
import os
from dataclasses import dataclass
from functools import partial
from typing import Any, Collection, Iterable, List, Optional, Tuple, Union
from typing import Any, Collection, List, Optional, Tuple, Union
from weakref import proxy

from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataset import IterableDataset
from torch.utils.data.distributed import DistributedSampler

import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
_teardown_dataloader_get_iterators,
_validate_fault_tolerant_automatic,
)
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
Expand All @@ -42,47 +37,21 @@
)
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
DataFetcher,
DataLoaderIterDataFetcher,
InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import PossibleUserWarning, rank_zero_warn


class DataConnector:
def __init__(
self,
trainer: "pl.Trainer",
multiple_trainloader_mode: str = "max_size_cycle",
train_data_fetcher: Optional[AbstractDataFetcher] = None,
validate_data_fetcher: Optional[AbstractDataFetcher] = None,
test_data_fetcher: Optional[AbstractDataFetcher] = None,
):
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode

self.train_data_fetcher = train_data_fetcher
self.validate_data_fetcher = validate_data_fetcher
self.test_data_fetcher = test_data_fetcher
self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None

self._train_dataloader_source = _DataLoaderSource(None, "")
self._val_dataloader_source = _DataLoaderSource(None, "")
self._test_dataloader_source = _DataLoaderSource(None, "")
self._predict_dataloader_source = _DataLoaderSource(None, "")

@property
def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
if self.trainer.sanity_checking:
return self.sanity_check_data_fetcher
return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher

@property
def _should_reload_train_dl(self) -> bool:
"""Check if train dataloader should be reloaded."""
Expand Down Expand Up @@ -126,33 +95,6 @@ def on_trainer_init(
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False

def _select_data_fetcher(self) -> AbstractDataFetcher:
if not self.trainer.training:
return DataFetcher()

training_step_fx = getattr(self.trainer.lightning_module, "training_step")
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher()
elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
if not isinstance(self.trainer.accelerator, GPUAccelerator):
raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
return InterBatchParallelDataFetcher()
return DataFetcher()

def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int) -> Iterable:
stage: str = self.trainer.state.stage.value
data_fetcher = getattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher()
data_fetcher.setup(
dataloader,
batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx),
)
setattr(self, f"{stage}_data_fetcher", data_fetcher)
return data_fetcher

def prepare_data(self) -> None:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
Expand Down Expand Up @@ -542,21 +484,6 @@ def _check_eval_shuffling(dataloader, mode):
category=PossibleUserWarning,
)

def teardown(self) -> None:
if self.train_data_fetcher:
self.train_data_fetcher.teardown()
self.train_data_fetcher = None
if self.validate_data_fetcher:
self.validate_data_fetcher.teardown()
self.validate_data_fetcher = None
if self.test_data_fetcher:
self.test_data_fetcher.teardown()
self.test_data_fetcher = None
if self.sanity_check_data_fetcher:
self.sanity_check_data_fetcher.teardown()
self.sanity_check_data_fetcher = None
_teardown_dataloader_get_iterators()


@dataclass
class _DataLoaderSource:
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,6 @@ def _teardown(self):
Callback; those are handled by :meth:`_call_teardown_hook`."""
self.strategy.post_dispatch(self)
self.strategy.teardown()
self._data_connector.teardown()
loop = self._active_loop
# loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn`
if loop is not None:
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/utilities/fetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ def reset(self) -> None:

def teardown(self) -> None:
self.reset()
if isinstance(self.dataloader, CombinedLoader):
self.dataloader.reset()
if isinstance(self.dataloader, DataLoader):
CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
if isinstance(self._dataloader, CombinedLoader):
self._dataloader.reset()
if isinstance(self._dataloader, DataLoader):
CombinedLoader._shutdown_workers_and_reset_iterator(self._dataloader)
self.dataloader_iter = None
_teardown_dataloader_get_iterators()

Expand Down
6 changes: 3 additions & 3 deletions tests/utilities/test_fetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def __init__(self, check_inter_batch):
self._check_inter_batch = check_inter_batch

def on_train_epoch_end(self, trainer, lightning_module):
fetcher = trainer._data_connector.train_data_fetcher
fetcher = trainer.fit_loop._data_fetcher
assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher)
assert fetcher.prefetch_batches == 1

Expand Down Expand Up @@ -228,7 +228,7 @@ def __init__(self, *args, automatic_optimization: bool = False, **kwargs):

def training_step(self, dataloader_iter, batch_idx):
assert self.count == batch_idx
assert isinstance(self.trainer._data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
assert isinstance(self.trainer.fit_loop._data_fetcher, DataLoaderIterDataFetcher)
# fetch 2 batches
self.batches.append(next(dataloader_iter))
self.batches.append(next(dataloader_iter))
Expand All @@ -251,7 +251,7 @@ def training_step(self, dataloader_iter, batch_idx):

def training_epoch_end(self, *_):
assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
assert self.trainer._data_connector.train_data_fetcher.fetched == 64
assert self.trainer.fit_loop._data_fetcher.fetched == 64
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
assert self.count == 64

model = TestModel(automatic_optimization=automatic_optimization)
Expand Down