Remove reference to DistributedDataParallel from parallel plugin teardown()
four4fish committed Aug 20, 2021
1 parent 6992db5 commit a6c21b6
Showing 5 changed files with 61 additions and 7 deletions.
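In short, the commit moves the DistributedDataParallel un-wrapping out of the shared ParallelPlugin.teardown() and into the DDP and DDP-spawn plugins, which now also handle the GPU cleanup themselves. A rough usage sketch of the behaviour these changes guarantee, mirroring the new tests added below (the BoringModel helper and the Trainer arguments are taken from those tests, not a prescribed API):

import torch

from pytorch_lightning import LightningModule, Trainer
from tests.helpers.boring_model import BoringModel

# Same configuration as the new test_ddp_plugin_teardown test below.
trainer = Trainer(num_processes=2, accelerator="ddp", fast_dev_run=True)
model = BoringModel()
trainer.fit(model)

# After fit() returns, DDPPlugin.teardown() has un-referenced the
# DistributedDataParallel wrapper and moved the module back to CPU.
assert isinstance(trainer.model, LightningModule)
assert trainer.model.device == torch.device("cpu")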
10 changes: 10 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -463,3 +463,13 @@ def reconciliate_processes(self, trace: str):
                os.kill(pid, signal.SIGKILL)
        shutil.rmtree(sync_dir)
        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")

    def teardown(self) -> None:
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
10 changes: 10 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -388,3 +388,13 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
            description="DDPSpawn Plugin with `find_unused_parameters` as False",
            find_unused_parameters=False,
        )

    def teardown(self) -> None:
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
5 changes: 0 additions & 5 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -135,11 +135,6 @@ def block_backward_sync(self):
            yield None

    def teardown(self) -> None:
        # Un-reference the wrapper if any was used.
        # todo (tchaton): Add support for all plugins.
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
20 changes: 19 additions & 1 deletion tests/plugins/test_ddp_plugin.py
@@ -16,7 +16,7 @@
import torch
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPPlugin
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -69,3 +69,21 @@ def test_ddp_barrier_non_consecutive_device_ids(barrier_mock, tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, gpus=gpus, accelerator="ddp")
    trainer.fit(model)
    barrier_mock.assert_any_call(device_ids=[gpus[trainer.local_rank]])


class BoringModelDDP(BoringModel):
    def on_train_start(self) -> None:
        """Check that the model is wrapped by DistributedDataParallel after configure_ddp()."""
        assert isinstance(self.trainer.model, DistributedDataParallel)

    def on_fit_end(self) -> None:
        """Check that teardown() restored the plain LightningModule at the end of training."""
        assert self.trainer.model.device == torch.device("cpu")
        assert isinstance(self.trainer.model, LightningModule)

@RunIf(min_gpus=1)
def test_ddp_plugin_teardown():
    """Test teardown of the DDP plugin."""
    trainer = Trainer(num_processes=2, accelerator="ddp", fast_dev_run=True)
    model = BoringModelDDP()
    trainer.fit(model)
23 changes: 22 additions & 1 deletion tests/plugins/test_ddp_spawn_plugin.py
@@ -13,7 +13,10 @@
# limitations under the License.
import torch

from pytorch_lightning import Trainer
import torch
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
@@ -77,3 +80,21 @@ def test_ddp_spawn_extra_parameters(tmpdir):
    trainer.fit(model, datamodule=dm)
    assert trainer.callback_metrics[val_name] == torch.tensor(val)
    assert model.test_val == "test_val"


class BoringModelDDP(BoringModel):
    def on_train_start(self) -> None:
        """Check that the model is wrapped by DistributedDataParallel after configure_ddp()."""
        assert isinstance(self.trainer.model, DistributedDataParallel)

    def on_fit_end(self) -> None:
        """Check that teardown() restored the plain LightningModule at the end of training."""
        assert self.trainer.model.device == torch.device("cpu")
        assert isinstance(self.trainer.model, LightningModule)

@RunIf(min_gpus=1)
def test_ddp_plugin_teardown():
    """Test teardown of the DDP spawn plugin."""
    trainer = Trainer(num_processes=2, accelerator="ddp_spawn", fast_dev_run=True)
    model = BoringModelDDP()
    trainer.fit(model)
