Remove TrainingTypePlugin.post_dispatch in favor of teardown #10939

Merged
14 commits merged on Dec 6, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))


- Deprecated `TrainingTypePlugin.post_dispatch` in favor of `TrainingTypePlugin.teardown` ([#10939](https://github.com/PyTorchLightning/pytorch-lightning/pull/10939))


### Removed

- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
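For plugin authors, the migration this entry describes is mechanical: whatever cleanup a custom plugin did in `post_dispatch` moves into `teardown`, which should chain to `super().teardown()` so the built-in cleanup still runs. A minimal sketch, assuming a custom single-device plugin; the subclass names and the `release_resources` helper are hypothetical and not part of this PR:

```python
from pytorch_lightning.plugins import SingleDevicePlugin


def release_resources() -> None:
    """Hypothetical cleanup a custom plugin might perform."""


# Before (deprecated in v1.6): cleanup hooked into `post_dispatch`.
class OldStylePlugin(SingleDevicePlugin):
    def post_dispatch(self, trainer):
        release_resources()


# After: the same cleanup lives in `teardown`, chaining to the parent class
# so its own teardown logic still runs.
class NewStylePlugin(SingleDevicePlugin):
    def teardown(self) -> None:
        super().teardown()
        release_resources()
```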
3 changes: 0 additions & 3 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -207,9 +207,6 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -
    def dispatch(self, trainer: "pl.Trainer") -> None:
        """Hook to do something when ``TrainingTypePlugin.dispatch()`` gets called."""

    def post_dispatch(self) -> None:
        """Hook to do something after the training/evaluation/prediction finishes."""

    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
5 changes: 1 addition & 4 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -356,10 +356,6 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:
        if trainer_fn == TrainerFn.FITTING:
            self.configure_ddp()

    def post_dispatch(self, trainer: "pl.Trainer") -> None:
        self.cluster_environment.teardown()
        super().post_dispatch(trainer)

    def barrier(self, *args, **kwargs) -> None:
        if not distributed_available():
            return
@@ -501,6 +497,7 @@ def reconciliate_processes(self, trace: str) -> None:
        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")

    def teardown(self) -> None:
        super().teardown()
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -395,6 +395,7 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
        )

    def teardown(self) -> None:
        super().teardown()
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -150,6 +150,7 @@ def test_step_end(self, output):
        return output

    def teardown(self) -> None:
        super().teardown()
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -217,6 +217,7 @@ def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tup
        return [(name, p) for name, p in model.named_parameters() if p in opt_params]

    def teardown(self) -> None:
        super().teardown()
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -276,6 +276,7 @@ def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
        return self._step(RunningStage.PREDICTING, *args, **kwargs)

    def teardown(self) -> None:
        super().teardown()
        # undo dataloader patching
        pl.trainer.data_loading._update_dataloader = self._update_dataloader_original
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -132,3 +132,7 @@ def block_backward_sync(self):
                yield None
        else:
            yield None

    def teardown(self) -> None:
        self.cluster_environment.teardown()
        super().teardown()
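Since the cluster-environment cleanup now lives in `ParallelPlugin.teardown`, a subclass that overrides `teardown` should call `super().teardown()` to keep that cleanup in the chain. A rough sketch of the cooperative pattern; the subclass and its extra cleanup are hypothetical:

```python
from pytorch_lightning.plugins import DDPPlugin


class MyDDPPlugin(DDPPlugin):
    def teardown(self) -> None:
        # Hypothetical plugin-specific cleanup.
        print("closing custom resources")
        # Chain up: DDPPlugin -> ParallelPlugin (cluster environment teardown)
        # -> TrainingTypePlugin.
        super().teardown()
```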
15 changes: 13 additions & 2 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -28,8 +28,10 @@
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

TBroadcast = TypeVar("TBroadcast")
@@ -49,6 +51,11 @@ def __init__(
        self.optimizers: List[Optimizer] = []
        self.lr_schedulers: List[_LRScheduler] = []
        self.optimizer_frequencies: List[int] = []
        if is_overridden("post_dispatch", self, parent=TrainingTypePlugin):
            rank_zero_deprecation(
                f"`{self.__class__.__name__}.post_dispatch()` has been deprecated in v1.6 and will be removed in v1.7."
                f" Move your implementation to `{self.__class__.__name__}.teardown()` instead."
            )

    @property
    def checkpoint_io(self) -> CheckpointIO:
@@ -486,5 +493,9 @@ def dispatch(self, trainer: "pl.Trainer") -> None:
        self.precision_plugin.dispatch(trainer)

    def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
self.precision_plugin.post_dispatch()
r"""
.. deprecated::
v1.6 This method has been deprecated in v1.6 and will be removed in v1.7. Use :meth:`teardown` instead.

Hook to do something after the training/evaluation/prediction finishes.
"""
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1224,9 +1224,9 @@ def _log_hyperparams(self) -> None:
            self.logger.save()

    def _post_dispatch(self):
        self.training_type_plugin.post_dispatch(self)
        # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
        # which need to happen before.
        self.training_type_plugin.post_dispatch(self)
        self.accelerator.teardown()
        self._data_connector.teardown()
        self._active_loop.teardown()
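Note that `_post_dispatch` still invokes the deprecated hook, so plugins that have not migrated keep working through the v1.6 deprecation window; they simply warn when constructed. A small sketch of the transitional behavior; the `LegacyPlugin` subclass is hypothetical:

```python
import torch

from pytorch_lightning.plugins import SingleDevicePlugin


class LegacyPlugin(SingleDevicePlugin):
    # Still overrides the deprecated hook; it keeps running in v1.6 and is removed in v1.7.
    def post_dispatch(self, trainer):
        print("legacy cleanup")


# Emits: `LegacyPlugin.post_dispatch()` has been deprecated in v1.6 ...
plugin = LegacyPlugin(torch.device("cpu"))
# During a run, Trainer._post_dispatch still calls plugin.post_dispatch(trainer)
# before the internal accelerator / data-connector / loop teardowns shown above.
```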
12 changes: 12 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
@@ -13,10 +13,12 @@
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.7.0."""
import os
from re import escape
from unittest import mock
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
@@ -25,6 +27,7 @@
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.environments import (
    KubeflowEnvironment,
    LightningEnvironment,
@@ -539,3 +542,12 @@ def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():

with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
sampler.batch_indices = []


def test_v1_7_0_post_dispatch_hook():
    class CustomPlugin(SingleDevicePlugin):
        def post_dispatch(self, trainer):
            pass

    with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")):
        CustomPlugin(torch.device("cpu"))