Commit b497fb8
Remove reference to DistributedDataParallel from parallel plugin teardown (#8943)

four4fish committed Aug 27, 2021
1 parent 53885af commit b497fb8
Showing 6 changed files with 37 additions and 12 deletions.
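
In short: the shared teardown (and with it the DistributedDataParallel reference) is removed from the ParallelPlugin base class in parallel.py, and each training-type plugin now defines its own teardown. The DDP-based plugins still unwrap the wrapped model; the DataParallel and Horovod plugins only move the module back to CPU and free cached GPU memory. The following is a minimal sketch of the resulting shape (illustrative only, not the library code; self.model, self.lightning_module, and self.on_gpu are the plugin attributes assumed to be set by the surrounding framework):

import torch
from torch.nn.parallel import DistributedDataParallel


class DDPPluginSketch:  # same shape as the methods added in ddp.py and ddp_spawn.py
    def teardown(self) -> None:
        # Drop the DistributedDataParallel wrapper so the plugin holds
        # only the raw LightningModule again.
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module
        if self.on_gpu:
            self.lightning_module.cpu()  # move parameters back to CPU
            torch.cuda.empty_cache()     # release cached GPU memory


class DataParallelPluginSketch:  # same shape as the methods added in dp.py and horovod.py
    def teardown(self) -> None:
        # No DDP wrapper to remove here; only the GPU cleanup is needed.
        if self.on_gpu:
            self.lightning_module.cpu()
            torch.cuda.empty_cache()

The sketch class names are placeholders; the actual changes are in the per-file diffs below.
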
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -217,6 +217,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `Plugin` in `base_plugin.py`, access `TrainingTypePlugin` and `PrecisionPlugin` directly instead ([#9066](https://github.com/PyTorchLightning/pytorch-lightning/pull/9066))


- Removed `teardown` from `ParallelPlugin` ([#8943](https://github.com/PyTorchLightning/pytorch-lightning/pull/8943))


### Fixed

- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
10 changes: 10 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -501,3 +501,13 @@ def reconciliate_processes(self, trace: str):
                os.kill(pid, signal.SIGKILL)
        shutil.rmtree(sync_dir)
        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")

    def teardown(self) -> None:
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
10 changes: 10 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -364,3 +364,13 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
            description="DDPSpawn Plugin with `find_unused_parameters` as False",
            find_unused_parameters=False,
        )

    def teardown(self) -> None:
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
7 changes: 7 additions & 0 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -119,3 +119,10 @@ def test_step_end(self, output):
        if not is_overridden("test_step_end", self.lightning_module):
            return self.reduce(output)
        return output

    def teardown(self) -> None:
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
7 changes: 7 additions & 0 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -206,3 +206,10 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
    def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
        opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
        return [(name, p) for name, p in model.named_parameters() if p in opt_params]

    def teardown(self) -> None:
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
12 changes: 0 additions & 12 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -133,15 +133,3 @@ def block_backward_sync(self):
                yield None
        else:
            yield None

    def teardown(self) -> None:
        # Un-reference the wrapper if any was used.
        # todo (tchaton): Add support for all plugins.
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
