From b497fb80e53238ad345c6914be17c8b1e1a6577b Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Thu, 26 Aug 2021 17:51:05 -0700
Subject: [PATCH] Remove reference to DistributedDataParallel from parallel plugin teardown (#8943)

---
 CHANGELOG.md                                          |  3 +++
 pytorch_lightning/plugins/training_type/ddp.py        | 10 ++++++++++
 pytorch_lightning/plugins/training_type/ddp_spawn.py  | 10 ++++++++++
 pytorch_lightning/plugins/training_type/dp.py         |  7 +++++++
 pytorch_lightning/plugins/training_type/horovod.py    |  7 +++++++
 pytorch_lightning/plugins/training_type/parallel.py   | 12 ------------
 6 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 36d90ae213fdb..bba7ed346980c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -217,6 +217,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed `Plugin` in `base_plugin.py`, access `TrainingTypePlugin` and `PrecisionPlugin` directly instead ([#9066](https://github.com/PyTorchLightning/pytorch-lightning/pull/9066))
 
+- Removed `teardown` from `ParallelPlugin` ([#8943](https://github.com/PyTorchLightning/pytorch-lightning/pull/8943))
+
+
 ### Fixed
 
 - Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index aeb43fcdebfe4..6d96a443e391a 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -501,3 +501,13 @@ def reconciliate_processes(self, trace: str):
                 os.kill(pid, signal.SIGKILL)
         shutil.rmtree(sync_dir)
         raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
+
+    def teardown(self) -> None:
+        if isinstance(self.model, DistributedDataParallel):
+            self.model = self.lightning_module
+
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 08c049997bdfd..c31a908902a27 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -364,3 +364,13 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
             description="DDPSpawn Plugin with `find_unused_parameters` as False",
             find_unused_parameters=False,
         )
+
+    def teardown(self) -> None:
+        if isinstance(self.model, DistributedDataParallel):
+            self.model = self.lightning_module
+
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index 551324416cce9..5b0887c848322 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -119,3 +119,10 @@ def test_step_end(self, output):
         if not is_overridden("test_step_end", self.lightning_module):
             return self.reduce(output)
         return output
+
+    def teardown(self) -> None:
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index e5eb8bf9723ea..19694e1bcda11 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -206,3 +206,10 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.Distributed
     def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
         opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
         return [(name, p) for name, p in model.named_parameters() if p in opt_params]
+
+    def teardown(self) -> None:
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index 71aae1bb71a91..31d2deb5f65e6 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -133,15 +133,3 @@ def block_backward_sync(self):
                 yield None
         else:
             yield None
-
-    def teardown(self) -> None:
-        # Un-reference the wrapper if any was used.
-        # todo (tchaton): Add support for all plugins.
-        if isinstance(self.model, DistributedDataParallel):
-            self.model = self.lightning_module
-
-        if self.on_gpu:
-            # GPU teardown
-            self.lightning_module.cpu()
-            # clean up memory
-            torch.cuda.empty_cache()
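
For readers who want the pattern outside of a diff, below is a minimal standalone sketch of the per-plugin teardown this patch introduces: un-reference the DistributedDataParallel wrapper if one was used, move the module back to CPU, and release cached CUDA memory. The `_SketchPlugin` class and its attributes are illustrative assumptions for this example, not Lightning's actual plugin API.

# Illustrative sketch only -- not part of the patch above. `_SketchPlugin`
# is a hypothetical stand-in for a training-type plugin that exposes
# `model`, `lightning_module` and `on_gpu`, mirroring the diff's pattern.
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel


class _SketchPlugin:
    def __init__(self, lightning_module: nn.Module, model: nn.Module, on_gpu: bool) -> None:
        self.lightning_module = lightning_module  # the plain user module
        self.model = model  # possibly a DistributedDataParallel wrapper around it
        self.on_gpu = on_gpu

    def teardown(self) -> None:
        # Un-reference the DDP wrapper, if any was used, so only the
        # unwrapped module is kept alive after training.
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.on_gpu:
            # Move parameters and buffers back to CPU, then free cached GPU memory.
            self.lightning_module.cpu()
            torch.cuda.empty_cache()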