ref: unify slurm and TE under backendPlugin 3/n (#4581)
williamFalcon committed Nov 8, 2020
1 parent bfaf014 commit 624f5b5
Showing 5 changed files with 8 additions and 221 deletions.
3 changes: 1 addition & 2 deletions pytorch_lightning/accelerators/__init__.py
@@ -21,6 +21,5 @@
 from pytorch_lightning.accelerators.tpu_accelerator import TPUAccelerator
 from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator
 from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
-from pytorch_lightning.accelerators.ddp_cpu_torchelastic_accelerator import DDPCPUTorchElasticAccelerator
-from pytorch_lightning.accelerators.ddp_cpu_slurm_accelerator import DDPCPUSLURMAccelerator
+from pytorch_lightning.accelerators.ddp_cpu_hpc_accelerator import DDPCPUHPCAccelerator
 from pytorch_lightning.accelerators.accelerator import Accelerator
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -220,7 +220,7 @@ def select_accelerator(self):
             )
 
         elif use_ddp_cpu_slurm:
-            accelerator_backend = accelerators.DDPCPUSLURMAccelerator(
+            accelerator_backend = accelerators.DDPCPUHPCAccelerator(
                 self.trainer,
                 cluster_env,
                 self.trainer.plugin_connector.ddp_plugin
@@ -234,7 +234,7 @@ def select_accelerator(self):
             )
 
         elif use_ddp_cpu_torch_elastic:
-            accelerator_backend = accelerators.DDPCPUTorchElasticAccelerator(
+            accelerator_backend = accelerators.DDPCPUHPCAccelerator(
                 self.trainer,
                 cluster_env,
                 self.trainer.plugin_connector.ddp_plugin
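The net effect of the connector change above is that the SLURM and TorchElastic CPU branches now build the same backend and differ only in the cluster environment passed to it. A minimal sketch of that idea, assuming the cluster_environments import path of this Lightning version; the helper name is hypothetical and this is not the actual select_accelerator() code:

from pytorch_lightning.accelerators import DDPCPUHPCAccelerator
from pytorch_lightning.cluster_environments import SLURMEnvironment, TorchElasticEnvironment


def build_cpu_hpc_backend(trainer, ddp_plugin, on_slurm):
    # Same accelerator class either way; only the ClusterEnvironment differs.
    env = SLURMEnvironment() if on_slurm else TorchElasticEnvironment()
    return DDPCPUHPCAccelerator(trainer, cluster_environment=env, ddp_plugin=ddp_plugin)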
7 changes: 1 addition & 6 deletions pytorch_lightning/accelerators/{ddp_cpu_slurm_accelerator.py → ddp_cpu_hpc_accelerator.py}
@@ -37,12 +37,7 @@
 HYDRA_AVAILABLE = True
 
 
-# -------------------------------------------
-# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
-# TEMP CLASS WHILE WE DECOUPLE TE FROM DDP
-# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
-# -------------------------------------------
-class DDPCPUSLURMAccelerator(Accelerator):
+class DDPCPUHPCAccelerator(Accelerator):
 
     def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
         super().__init__(trainer, cluster_environment, ddp_plugin)
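The rename keeps the constructor shown above, so call sites only need the new class name. A hedged sketch of direct construction; normally the accelerator connector supplies the cluster environment and DDP plugin, and the None defaults simply mirror the signature in the diff:

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import DDPCPUHPCAccelerator

trainer = Trainer(fast_dev_run=True)
# cluster_environment and ddp_plugin default to None, matching the diff;
# in practice a SLURM or TorchElastic environment is passed in.
backend = DDPCPUHPCAccelerator(trainer, cluster_environment=None, ddp_plugin=None)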
207 changes: 0 additions & 207 deletions pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py

This file was deleted.

8 changes: 4 additions & 4 deletions tests/backends/test_accelerator_connector.py
@@ -229,7 +229,7 @@ def test_accelerator_choice_ddp_cpu_te(tmpdir):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUTorchElasticAccelerator)
+            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
             assert trainer.accelerator_backend.task_idx == 10
             assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
@@ -260,7 +260,7 @@ def test_accelerator_choice_ddp_cpu_slurm(tmpdir):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSLURMAccelerator)
+            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
             raise SystemExit()
 
@@ -295,7 +295,7 @@ def master_address(self):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSLURMAccelerator)
+            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, CustomCluster)
             raise SystemExit()
 
@@ -353,7 +353,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_dist_backend_accelerator_mapping(tmpdir):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSLURMAccelerator)
+            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator)
             raise SystemExit()
 
     model = BoringModel()
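Since every test above now asserts the same DDPCPUHPCAccelerator type, the scheduler actually in use is only visible through the attached cluster environment. A small sketch mirroring those assertions; the function name is hypothetical and the cluster_environments import path is assumed for this Lightning version:

from pytorch_lightning import accelerators
from pytorch_lightning.cluster_environments import SLURMEnvironment, TorchElasticEnvironment


def describe_backend(trainer):
    # Mirrors the test callbacks: one accelerator class, different environments.
    backend = trainer.accelerator_backend
    assert isinstance(backend, accelerators.DDPCPUHPCAccelerator)
    if isinstance(backend.cluster_environment, SLURMEnvironment):
        return 'slurm'
    if isinstance(backend.cluster_environment, TorchElasticEnvironment):
        return 'torchelastic'
    return 'custom'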
