From 6bfc0bbc565045d73ee6e00d604201b822619a2e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 6 Dec 2021 23:27:30 +0100
Subject: [PATCH] Remove `TrainingTypePlugin.post_dispatch` in favor of `teardown` (#10939)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 CHANGELOG.md                                       |  3 +++
 .../plugins/precision/precision_plugin.py          |  3 ---
 pytorch_lightning/plugins/training_type/ddp.py     |  5 +----
 .../plugins/training_type/ddp_spawn.py             |  1 +
 pytorch_lightning/plugins/training_type/dp.py      |  1 +
 .../plugins/training_type/horovod.py               |  1 +
 pytorch_lightning/plugins/training_type/ipu.py     |  1 +
 .../plugins/training_type/parallel.py              |  4 ++++
 .../plugins/training_type/training_type_plugin.py  | 15 +++++++++++++--
 pytorch_lightning/trainer/trainer.py               |  2 +-
 tests/deprecated_api/test_remove_1-7.py            | 12 ++++++++++++
 11 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3d1930f7658a3..fa9bd9c0ce71b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
 
+- Deprecated `TrainingTypePlugin.post_dispatch` in favor of `TrainingTypePlugin.teardown` ([#10939](https://github.com/PyTorchLightning/pytorch-lightning/pull/10939))
+
+
 ### Removed
 
 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
 
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 928f1667aed82..109be55b8dd63 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -207,9 +207,6 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -
     def dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something when ``TrainingTypePlugin.dispatch()`` gets called."""
 
-    def post_dispatch(self) -> None:
-        """Hook to do something after the training/evaluation/prediction finishes."""
-
     @contextlib.contextmanager
     def forward_context(self) -> Generator[None, None, None]:
         """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index fc42949eb593e..62d198536b877 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -356,10 +356,6 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         if trainer_fn == TrainerFn.FITTING:
             self.configure_ddp()
 
-    def post_dispatch(self, trainer: "pl.Trainer") -> None:
-        self.cluster_environment.teardown()
-        super().post_dispatch(trainer)
-
     def barrier(self, *args, **kwargs) -> None:
         if not distributed_available():
             return
@@ -501,6 +497,7 @@ def reconciliate_processes(self, trace: str) -> None:
         raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
 
     def teardown(self) -> None:
+        super().teardown()
         if isinstance(self.model, DistributedDataParallel):
             self.model = self.lightning_module
 
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 328f04b05c908..661ba20f2e560 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -395,6 +395,7 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
         )
 
     def teardown(self) -> None:
+        super().teardown()
         if isinstance(self.model, DistributedDataParallel):
             self.model = self.lightning_module
 
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index b273988e5ab77..3016ee7462366 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -150,6 +150,7 @@ def test_step_end(self, output):
         return output
 
     def teardown(self) -> None:
+        super().teardown()
         if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 3efb008542169..84a41d5a5f30c 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -217,6 +217,7 @@ def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tup
         return [(name, p) for name, p in model.named_parameters() if p in opt_params]
 
     def teardown(self) -> None:
+        super().teardown()
         if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()
diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py
index c8d6835ec3417..2763ad645facb 100644
--- a/pytorch_lightning/plugins/training_type/ipu.py
+++ b/pytorch_lightning/plugins/training_type/ipu.py
@@ -276,6 +276,7 @@ def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
         return self._step(RunningStage.PREDICTING, *args, **kwargs)
 
     def teardown(self) -> None:
+        super().teardown()
         # undo dataloader patching
         pl.trainer.data_loading._update_dataloader = self._update_dataloader_original
 
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index 3a05455b87990..2dc2a95f03d38 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -132,3 +132,7 @@ def block_backward_sync(self):
             yield None
         else:
             yield None
+
+    def teardown(self) -> None:
+        self.cluster_environment.teardown()
+        super().teardown()
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index c83a8484821a3..fc5de9863665a 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -28,8 +28,10 @@
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 
 TBroadcast = TypeVar("TBroadcast")
@@ -49,6 +51,11 @@ def __init__(
         self.optimizers: List[Optimizer] = []
         self.lr_schedulers: List[_LRScheduler] = []
         self.optimizer_frequencies: List[int] = []
+        if is_overridden("post_dispatch", self, parent=TrainingTypePlugin):
+            rank_zero_deprecation(
+                f"`{self.__class__.__name__}.post_dispatch()` has been deprecated in v1.6 and will be removed in v1.7."
+                f" Move your implementation to `{self.__class__.__name__}.teardown()` instead."
+            )
 
     @property
     def checkpoint_io(self) -> CheckpointIO:
@@ -486,5 +493,9 @@ def dispatch(self, trainer: "pl.Trainer") -> None:
         self.precision_plugin.dispatch(trainer)
 
     def post_dispatch(self, trainer: "pl.Trainer") -> None:
-        """Hook to do something after the training/evaluation/prediction starts."""
-        self.precision_plugin.post_dispatch()
+        r"""
+        .. deprecated::
+            v1.6 This method has been deprecated in v1.6 and will be removed in v1.7. Use :meth:`teardown` instead.
+
+        Hook to do something after the training/evaluation/prediction finishes.
+        """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 5624d376bf816..88f929531ea95 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1224,9 +1224,9 @@ def _log_hyperparams(self) -> None:
         self.logger.save()
 
     def _post_dispatch(self):
-        self.training_type_plugin.post_dispatch(self)
         # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
         # which need to happen before.
+        self.training_type_plugin.post_dispatch(self)
         self.accelerator.teardown()
         self._data_connector.teardown()
         self._active_loop.teardown()
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 6f7e1199ab438..986491a306ce7 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -13,10 +13,12 @@
 # limitations under the License.
 """Test deprecated functionality which will be removed in v1.7.0."""
 import os
+from re import escape
 from unittest import mock
 from unittest.mock import Mock
 
 import pytest
+import torch
 
 from pytorch_lightning import Callback, LightningDataModule, Trainer
 from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
@@ -25,6 +27,7 @@
 from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
 from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
+from pytorch_lightning.plugins import SingleDevicePlugin
 from pytorch_lightning.plugins.environments import (
     KubeflowEnvironment,
     LightningEnvironment,
@@ -539,3 +542,12 @@ def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():
 
     with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
         sampler.batch_indices = []
+
+
+def test_v1_7_0_post_dispatch_hook():
+    class CustomPlugin(SingleDevicePlugin):
+        def post_dispatch(self, trainer):
+            pass
+
+    with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")):
+        CustomPlugin(torch.device("cpu"))
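
For downstream plugin authors, the migration this patch asks for is essentially a rename: cleanup logic that used to live in `post_dispatch` moves into `teardown`, chaining to `super().teardown()` the same way the in-tree plugins above now do. The following is a minimal sketch and not part of the patch; the `MyPlugin` class and its `_release_resources` helper are hypothetical stand-ins for user code, while `SingleDevicePlugin` and the `torch.device("cpu")` constructor argument mirror the deprecation test added above.

```python
import torch

from pytorch_lightning.plugins import SingleDevicePlugin


class MyPlugin(SingleDevicePlugin):
    """Hypothetical custom plugin used only to illustrate the migration."""

    # Deprecated pattern (now triggers the warning raised in ``TrainingTypePlugin.__init__``):
    #
    #     def post_dispatch(self, trainer):
    #         self._release_resources()
    #
    # New pattern: move the same logic into ``teardown`` and chain to the parent
    # so the base plugin still performs its own cleanup.
    def teardown(self) -> None:
        super().teardown()
        self._release_resources()

    def _release_resources(self) -> None:
        # placeholder for whatever cleanup previously lived in ``post_dispatch``
        pass


plugin = MyPlugin(torch.device("cpu"))  # constructs without a deprecation warning
```

Constructing such a plugin no longer emits the deprecation warning, because the `is_overridden` check in `TrainingTypePlugin.__init__` only fires when `post_dispatch` itself is overridden.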