Skip to content

Commit

Permalink
Remove TrainingTypePlugin.post_dispatch in favor of teardown (#10939
Browse files Browse the repository at this point in the history
)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
  • Loading branch information
3 people committed Dec 6, 2021
1 parent 629ca09 commit 6bfc0bb
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))


- Deprecated `TrainingTypePlugin.post_dispatch` in favor of `TrainingTypePlugin.teardown` ([#10939](https://github.com/PyTorchLightning/pytorch-lightning/pull/10939))


### Removed

- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
Expand Down
3 changes: 0 additions & 3 deletions pytorch_lightning/plugins/precision/precision_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,6 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -
def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something when ``TrainingTypePlugin.dispatch()`` gets called."""

def post_dispatch(self) -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""

@contextlib.contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
Expand Down
5 changes: 1 addition & 4 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,6 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

def post_dispatch(self, trainer: "pl.Trainer") -> None:
self.cluster_environment.teardown()
super().post_dispatch(trainer)

def barrier(self, *args, **kwargs) -> None:
if not distributed_available():
return
Expand Down Expand Up @@ -501,6 +497,7 @@ def reconciliate_processes(self, trace: str) -> None:
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")

def teardown(self) -> None:
super().teardown()
if isinstance(self.model, DistributedDataParallel):
self.model = self.lightning_module

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
)

def teardown(self) -> None:
super().teardown()
if isinstance(self.model, DistributedDataParallel):
self.model = self.lightning_module

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def test_step_end(self, output):
return output

def teardown(self) -> None:
super().teardown()
if self.on_gpu:
# GPU teardown
self.lightning_module.cpu()
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tup
return [(name, p) for name, p in model.named_parameters() if p in opt_params]

def teardown(self) -> None:
super().teardown()
if self.on_gpu:
# GPU teardown
self.lightning_module.cpu()
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
return self._step(RunningStage.PREDICTING, *args, **kwargs)

def teardown(self) -> None:
super().teardown()
# undo dataloader patching
pl.trainer.data_loading._update_dataloader = self._update_dataloader_original

Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/training_type/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,7 @@ def block_backward_sync(self):
yield None
else:
yield None

def teardown(self) -> None:
self.cluster_environment.teardown()
super().teardown()
15 changes: 13 additions & 2 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

TBroadcast = TypeVar("TBroadcast")
Expand All @@ -49,6 +51,11 @@ def __init__(
self.optimizers: List[Optimizer] = []
self.lr_schedulers: List[_LRScheduler] = []
self.optimizer_frequencies: List[int] = []
if is_overridden("post_dispatch", self, parent=TrainingTypePlugin):
rank_zero_deprecation(
f"`{self.__class__.__name__}.post_dispatch()` has been deprecated in v1.6 and will be removed in v1.7."
f" Move your implementation to `{self.__class__.__name__}.teardown()` instead."
)

@property
def checkpoint_io(self) -> CheckpointIO:
Expand Down Expand Up @@ -486,5 +493,9 @@ def dispatch(self, trainer: "pl.Trainer") -> None:
self.precision_plugin.dispatch(trainer)

def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
self.precision_plugin.post_dispatch()
r"""
.. deprecated::
v1.6 This method has been deprecated in v1.6 and will be removed in v1.7. Use :meth:`teardown` instead.
Hook to do something after the training/evaluation/prediction finishes.
"""
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,9 +1224,9 @@ def _log_hyperparams(self) -> None:
self.logger.save()

def _post_dispatch(self):
self.training_type_plugin.post_dispatch(self)
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
self.training_type_plugin.post_dispatch(self)
self.accelerator.teardown()
self._data_connector.teardown()
self._active_loop.teardown()
Expand Down
12 changes: 12 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.7.0."""
import os
from re import escape
from unittest import mock
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
Expand All @@ -25,6 +27,7 @@
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.environments import (
KubeflowEnvironment,
LightningEnvironment,
Expand Down Expand Up @@ -539,3 +542,12 @@ def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():

with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
sampler.batch_indices = []


def test_v1_7_0_post_dispatch_hook():
class CustomPlugin(SingleDevicePlugin):
def post_dispatch(self, trainer):
pass

with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")):
CustomPlugin(torch.device("cpu"))

0 comments on commit 6bfc0bb

Please sign in to comment.