From f72daf69a545cbd741cf6567413d8a157cbe7e6d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 Jan 2022 19:45:23 +0100 Subject: [PATCH 1/7] Teardown all internal components on exception --- pytorch_lightning/trainer/trainer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ac01227fd00ac..22fef0bf729d4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -685,14 +685,13 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: except BaseException as exception: self.state.status = TrainerStatus.INTERRUPTED if distributed_available() and self.world_size > 1: - # try syncing remaing processes, kill otherwise + # try syncing remaining processes, kill otherwise self.strategy.reconciliate_processes(traceback.format_exc()) self._on_exception() + self._call_callback_hooks("on_exception", exception) + self._teardown() # reset bookkeeping self.state.stage = None - self._call_callback_hooks("on_exception", exception) - # shutdown workers - self._data_connector.teardown() raise def fit( @@ -1174,6 +1173,7 @@ def _run( self.checkpoint_connector.resume_end() results = self._run_stage() + log.detail(f"{self.__class__.__name__}: trainer tearing down") self._teardown() @@ -1188,8 +1188,7 @@ def _run( log.detail(f"{self.__class__.__name__}: calling teardown hooks") self._call_teardown_hook() - if self.state.status != TrainerStatus.INTERRUPTED: - self.state.status = TrainerStatus.FINISHED + self.state.status = TrainerStatus.FINISHED self.state.stage = None if isinstance(self.strategy, DDPSpawnStrategy): @@ -1240,7 +1239,10 @@ def _teardown(self): self.strategy.post_dispatch(self) self.strategy.teardown() self._data_connector.teardown() - self._active_loop.teardown() + loop = self._active_loop + # loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn` + if loop is not None: + loop.teardown() self.logger_connector.teardown() self._signal_connector.teardown() From b88f513900b7e3944cc78a339907fc874479974d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 Jan 2022 19:50:12 +0100 Subject: [PATCH 2/7] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa7c4f9b056bc..f370985803a7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -69,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) +- Teardown the active loop and strategy on exception ([#11620](https://github.com/PyTorchLightning/pytorch-lightning/pull/11620)) + + - Added a `MisconfigurationException` if user provided `opt_idx` in scheduler config doesn't match with actual optimizer index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) From 8d49f52a5382357ad7f8b4b47347c41f1e102389 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 Jan 2022 19:52:07 +0100 Subject: [PATCH 3/7] Move data fetcher ownership to the loops --- .../loops/dataloader/evaluation_loop.py | 17 +++-- pytorch_lightning/loops/fit_loop.py | 45 ++++++++++-- .../trainer/connectors/data_connector.py | 71 +------------------ 3 files changed, 55 insertions(+), 78 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 1e0b30cab03c7..658d7594859f5 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict -from typing import Any, List, Sequence, Union +from functools import partial +from typing import Any, List, Optional, Sequence, Union import torch from deprecate.utils import void @@ -22,6 +23,7 @@ from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection from pytorch_lightning.trainer.states import RunningStage, TrainerFn +from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -38,6 +40,7 @@ def __init__(self, verbose: bool = True) -> None: self._logged_outputs: List[_OUT_DICT] = [] self._max_batches: List[int] = [] self._has_run: bool = False + self._data_fetcher: Optional[AbstractDataFetcher] = None @property def num_dataloaders(self) -> int: @@ -99,6 +102,8 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: hooks.""" void(*args, **kwargs) + self._data_fetcher = DataFetcher() + # hook self._on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() @@ -111,15 +116,16 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader_idx = self.current_dataloader_idx dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader) - self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader( - dataloader, dataloader_idx=dataloader_idx + self._data_fetcher.setup( + dataloader, + batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx), ) dl_max_batches = self._max_batches[dataloader_idx] kwargs = OrderedDict() if self.num_dataloaders > 1: kwargs["dataloader_idx"] = dataloader_idx - dl_outputs = self.epoch_loop.run(dataloader, dl_max_batches, kwargs) + dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs) # store batch level output per dataloader self._outputs.append(dl_outputs) @@ -169,6 +175,9 @@ def on_run_end(self) -> List[_OUT_DICT]: return logged_outputs def teardown(self) -> None: + if self._data_fetcher is not None: + self._data_fetcher.teardown() + self._data_fetcher = None self._results.cpu() 
self.epoch_loop.teardown() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index bed80f5b962c5..4651aeedf256b 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +import os +from functools import partial +from typing import Optional, Type +import pytorch_lightning as pl +from pytorch_lightning.accelerators import GPUAccelerator from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE @@ -21,9 +25,16 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import ( + AbstractDataFetcher, + DataFetcher, + DataLoaderIterDataFetcher, + InterBatchParallelDataFetcher, +) from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature log = logging.getLogger(__name__) @@ -55,6 +66,7 @@ def __init__( self._is_fresh_start_epoch: bool = True self._outputs: _EPOCH_OUTPUTS_TYPE = [] + self._data_fetcher: Optional[AbstractDataFetcher] = None @property def current_epoch(self) -> int: @@ -195,8 +207,12 @@ def on_run_start(self) -> None: # type: ignore[override] """Calls the ``on_train_start`` hook.""" # reset train dataloader and val dataloader self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module) + data_fetcher_cls = _select_data_fetcher(self.trainer) + self._data_fetcher = data_fetcher_cls() + self._is_fresh_start_epoch = True self._results.to(device=self.trainer.lightning_module.device) + self.trainer._call_callback_hooks("on_train_start") self.trainer._call_lightning_module_hook("on_train_start") self.trainer._call_strategy_hook("on_train_start") @@ -244,10 +260,11 @@ def advance(self) -> None: # type: ignore[override] log.detail(f"{self.__class__.__name__}: advancing loop") assert self.trainer.train_dataloader is not None dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader) - data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader, 0) - + self._data_fetcher.setup( + dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0) + ) with self.trainer.profiler.profile("run_training_epoch"): - self._outputs = self.epoch_loop.run(data_fetcher) + self._outputs = self.epoch_loop.run(self._data_fetcher) def on_advance_end(self) -> None: # inform logger the batch loop has finished @@ -318,8 +335,26 @@ def on_run_end(self) -> None: self.trainer.strategy.on_train_end() def teardown(self) -> None: + if self._data_fetcher is not None: + self._data_fetcher.teardown() + self._data_fetcher = None self.epoch_loop.teardown() def _should_accumulate(self) -> bool: """Whether the gradients should be accumulated.""" return self.epoch_loop._should_accumulate() + + +def _select_data_fetcher(trainer: "pl.Trainer") -> 
Type[AbstractDataFetcher]: + training_step_fx = getattr(trainer.lightning_module, "training_step") + if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): + rank_zero_warn( + "Found `dataloader_iter` argument in the `training_step`. Note that the support for " + "this signature is experimental and the behavior is subject to change." + ) + return DataLoaderIterDataFetcher + elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": + if not isinstance(trainer.accelerator, GPUAccelerator): + raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") + return InterBatchParallelDataFetcher + return DataFetcher diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 688f6c9a4b658..f249245611223 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,8 +14,7 @@ import multiprocessing import os from dataclasses import dataclass -from functools import partial -from typing import Any, Collection, Iterable, List, Optional, Tuple, Union +from typing import Any, Collection, List, Optional, Tuple, Union from weakref import proxy from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler @@ -23,7 +22,6 @@ from torch.utils.data.distributed import DistributedSampler import pytorch_lightning as pl -from pytorch_lightning.accelerators import GPUAccelerator from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator @@ -42,47 +40,21 @@ ) from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import ( - AbstractDataFetcher, - DataFetcher, - DataLoaderIterDataFetcher, - InterBatchParallelDataFetcher, -) from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from pytorch_lightning.utilities.warnings import PossibleUserWarning, rank_zero_warn class DataConnector: - def __init__( - self, - trainer: "pl.Trainer", - multiple_trainloader_mode: str = "max_size_cycle", - train_data_fetcher: Optional[AbstractDataFetcher] = None, - validate_data_fetcher: Optional[AbstractDataFetcher] = None, - test_data_fetcher: Optional[AbstractDataFetcher] = None, - ): + def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - - self.train_data_fetcher = train_data_fetcher - self.validate_data_fetcher = validate_data_fetcher - self.test_data_fetcher = test_data_fetcher - self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None - self._train_dataloader_source = _DataLoaderSource(None, "") self._val_dataloader_source = _DataLoaderSource(None, "") self._test_dataloader_source = _DataLoaderSource(None, "") self._predict_dataloader_source = _DataLoaderSource(None, "") - @property - def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: - if self.trainer.sanity_checking: - return self.sanity_check_data_fetcher - return 
self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher - @property def _should_reload_train_dl(self) -> bool: """Check if train dataloader should be reloaded.""" @@ -126,33 +98,6 @@ def on_trainer_init( self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs self.trainer._is_data_prepared = False - def _select_data_fetcher(self) -> AbstractDataFetcher: - if not self.trainer.training: - return DataFetcher() - - training_step_fx = getattr(self.trainer.lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): - rank_zero_warn( - "Found `dataloader_iter` argument in the `training_step`. Note that the support for " - "this signature is experimental and the behavior is subject to change." - ) - return DataLoaderIterDataFetcher() - elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": - if not isinstance(self.trainer.accelerator, GPUAccelerator): - raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") - return InterBatchParallelDataFetcher() - return DataFetcher() - - def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int) -> Iterable: - stage: str = self.trainer.state.stage.value - data_fetcher = getattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher() - data_fetcher.setup( - dataloader, - batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx), - ) - setattr(self, f"{stage}_data_fetcher", data_fetcher) - return data_fetcher - def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 @@ -543,18 +488,6 @@ def _check_eval_shuffling(dataloader, mode): ) def teardown(self) -> None: - if self.train_data_fetcher: - self.train_data_fetcher.teardown() - self.train_data_fetcher = None - if self.validate_data_fetcher: - self.validate_data_fetcher.teardown() - self.validate_data_fetcher = None - if self.test_data_fetcher: - self.test_data_fetcher.teardown() - self.test_data_fetcher = None - if self.sanity_check_data_fetcher: - self.sanity_check_data_fetcher.teardown() - self.sanity_check_data_fetcher = None _teardown_dataloader_get_iterators() From d35bcd6783e1dc35fd98f0aad09e00919b89f431 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Feb 2022 22:57:30 +0100 Subject: [PATCH 4/7] Fix tests --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 1 + pytorch_lightning/utilities/fetching.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 658d7594859f5..0f9ef7946a04f 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -116,6 +116,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader_idx = self.current_dataloader_idx dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader) + assert self._data_fetcher is not None self._data_fetcher.setup( dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx), diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 5b8012468dfef..a51f223484729 100644 --- 
a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -188,10 +188,10 @@ def reset(self) -> None: def teardown(self) -> None: self.reset() - if isinstance(self.dataloader, CombinedLoader): - self.dataloader.reset() - if isinstance(self.dataloader, DataLoader): - CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader) + if isinstance(self._dataloader, CombinedLoader): + self._dataloader.reset() + if isinstance(self._dataloader, DataLoader): + CombinedLoader._shutdown_workers_and_reset_iterator(self._dataloader) self.dataloader_iter = None _teardown_dataloader_get_iterators() From e251199f7530eba4168c262f14fff102b54cff1f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Feb 2022 23:03:15 +0100 Subject: [PATCH 5/7] Fix test --- pytorch_lightning/trainer/connectors/data_connector.py | 8 +------- pytorch_lightning/trainer/trainer.py | 1 - tests/utilities/test_fetching.py | 4 ++-- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index f249245611223..c6abd4c51a7cb 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -27,10 +27,7 @@ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.auto_restart import ( - _teardown_dataloader_get_iterators, - _validate_fault_tolerant_automatic, -) +from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, _replace_dataloader_init_method, @@ -487,9 +484,6 @@ def _check_eval_shuffling(dataloader, mode): category=PossibleUserWarning, ) - def teardown(self) -> None: - _teardown_dataloader_get_iterators() - @dataclass class _DataLoaderSource: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3a992236e003d..97937be28f165 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1237,7 +1237,6 @@ def _teardown(self): Callback; those are handled by :meth:`_call_teardown_hook`.""" self.strategy.post_dispatch(self) self.strategy.teardown() - self._data_connector.teardown() loop = self._active_loop # loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn` if loop is not None: diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 0d62f35ee79c7..67eccc904d669 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -228,7 +228,7 @@ def __init__(self, *args, automatic_optimization: bool = False, **kwargs): def training_step(self, dataloader_iter, batch_idx): assert self.count == batch_idx - assert isinstance(self.trainer._data_connector.train_data_fetcher, DataLoaderIterDataFetcher) + assert isinstance(self.trainer.fit_loop._data_fetcher, DataLoaderIterDataFetcher) # fetch 2 batches self.batches.append(next(dataloader_iter)) self.batches.append(next(dataloader_iter)) @@ -251,7 +251,7 @@ def training_step(self, dataloader_iter, batch_idx): def training_epoch_end(self, *_): assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33 - assert self.trainer._data_connector.train_data_fetcher.fetched == 64 + assert 
self.trainer.fit_loop._data_fetcher.fetched == 64 assert self.count == 64 model = TestModel(automatic_optimization=automatic_optimization) From 31597945b4d5e862fa9b25ad10ed567b4a3b21e3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Feb 2022 23:05:23 +0100 Subject: [PATCH 6/7] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b6d7f61bcacc..74595956d467f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Moved ownership of the lightning optimizers from the `Trainer` to the `Strategy` ([#11444](https://github.com/PyTorchLightning/pytorch-lightning/pull/11444)) +- Moved ownership of the data fetchers from the DataConnector to the Loops ([#11621](https://github.com/PyTorchLightning/pytorch-lightning/pull/11621)) + + - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649)) From b93ffc3fc50fa10bbdec2758fe5cd871b7bc82d9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Feb 2022 08:14:59 +0100 Subject: [PATCH 7/7] Fix test --- tests/utilities/test_fetching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 67eccc904d669..21397fc5b7dbc 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -180,7 +180,7 @@ def __init__(self, check_inter_batch): self._check_inter_batch = check_inter_batch def on_train_epoch_end(self, trainer, lightning_module): - fetcher = trainer._data_connector.train_data_fetcher + fetcher = trainer.fit_loop._data_fetcher assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher) assert fetcher.prefetch_batches == 1
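
As a rough illustration of the two patterns this patch series introduces — tearing down internal components even when an exception interrupts a run, and letting each loop own (create, set up, and tear down) its data fetcher instead of a central `DataConnector` — here is a minimal, self-contained sketch. The names `SimpleDataFetcher`, `MiniFitLoop`, and `MiniTrainer` are hypothetical stand-ins for this sketch only and are not PyTorch Lightning APIs.

```python
from typing import Iterable, Iterator, Optional


class SimpleDataFetcher:
    """Hypothetical stand-in for the fetcher a loop now owns."""

    def __init__(self) -> None:
        self._iterator: Optional[Iterator] = None

    def setup(self, dataloader: Iterable) -> None:
        self._iterator = iter(dataloader)

    def __iter__(self) -> Iterator:
        assert self._iterator is not None, "call setup() first"
        return self._iterator

    def teardown(self) -> None:
        # Drop the iterator so any worker resources can be released.
        self._iterator = None


class MiniFitLoop:
    """The loop creates, uses, and tears down its own fetcher."""

    def __init__(self) -> None:
        self._data_fetcher: Optional[SimpleDataFetcher] = None

    def run(self, dataloader: Iterable) -> list:
        self._data_fetcher = SimpleDataFetcher()
        self._data_fetcher.setup(dataloader)
        return [batch for batch in self._data_fetcher]

    def teardown(self) -> None:
        if self._data_fetcher is not None:
            self._data_fetcher.teardown()
            self._data_fetcher = None


class MiniTrainer:
    """On any exception, tear everything down before re-raising."""

    def __init__(self) -> None:
        self.fit_loop = MiniFitLoop()
        self.interrupted = False

    def fit(self, dataloader: Iterable) -> list:
        try:
            results = self.fit_loop.run(dataloader)
        except BaseException:
            # Mirrors the patched _call_and_handle_interrupt: mark the run
            # as interrupted, release fetchers/workers, then re-raise.
            self.interrupted = True
            self._teardown()
            raise
        self._teardown()
        return results

    def _teardown(self) -> None:
        # Mirrors Trainer._teardown(): the active loop may be None with
        # spawn-based strategies, so guard before tearing it down.
        loop = self.fit_loop
        if loop is not None:
            loop.teardown()


if __name__ == "__main__":
    trainer = MiniTrainer()
    print(trainer.fit(range(3)))  # [0, 1, 2]
```

In the actual patches, `Trainer._call_and_handle_interrupt` and `Trainer._teardown` play the role of `fit()` and `_teardown()` above, while `FitLoop` and `EvaluationLoop` hold their fetcher in `self._data_fetcher`, choosing the concrete class (`DataFetcher`, `DataLoaderIterDataFetcher`, or `InterBatchParallelDataFetcher`) via the new module-level `_select_data_fetcher` helper.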