Skip to content

Commit

Permalink
ref: unify slurm and TE under backendPlugin 2/n (#4580)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored and rohitgr7 committed Nov 21, 2020
1 parent f2b37ed commit 2a3a356
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 221 deletions.
3 changes: 1 addition & 2 deletions pytorch_lightning/accelerators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from pytorch_lightning.accelerators.tpu_accelerator import TPUAccelerator
from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator
from pytorch_lightning.accelerators.ddp_slurm_accelerator import DDPSLURMAccelerator
from pytorch_lightning.accelerators.ddp_torchelastic_accelerator import DDPTorchElasticAccelerator
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
from pytorch_lightning.accelerators.ddp_cpu_torchelastic_accelerator import DDPCPUTorchElasticAccelerator
from pytorch_lightning.accelerators.ddp_cpu_slurm_accelerator import DDPCPUSLURMAccelerator
from pytorch_lightning.accelerators.accelerator import Accelerator
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def select_accelerator(self):
)

elif use_slurm_ddp:
accelerator_backend = accelerators.DDPSLURMAccelerator(
accelerator_backend = accelerators.DDPHPCAccelerator(
self.trainer,
cluster_env,
self.trainer.plugin_connector.ddp_plugin
Expand All @@ -241,7 +241,7 @@ def select_accelerator(self):
)

elif use_torchelastic_ddp:
accelerator_backend = accelerators.DDPTorchElasticAccelerator(
accelerator_backend = accelerators.DDPHPCAccelerator(
self.trainer,
cluster_env,
self.trainer.plugin_connector.ddp_plugin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available


try:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
Expand All @@ -35,12 +36,7 @@
HYDRA_AVAILABLE = True


# -------------------------------------------
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
# TEMP CLASS WHILE WE DECOUPLE SLURM FROM DDP
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
# -------------------------------------------
class DDPSLURMAccelerator(Accelerator):
class DDPHPCAccelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
super().__init__(trainer, cluster_environment, ddp_plugin)
Expand Down
209 changes: 0 additions & 209 deletions pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/backends/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_accelerator_choice_ddp_slurm(tmpdir):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.use_ddp
assert isinstance(trainer.accelerator_backend, accelerators.DDPSLURMAccelerator)
assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator)
assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
assert trainer.accelerator_backend.task_idx == 10
assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_accelerator_choice_ddp_te(tmpdir):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.use_ddp
assert isinstance(trainer.accelerator_backend, accelerators.DDPTorchElasticAccelerator)
assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator)
assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
assert trainer.accelerator_backend.task_idx == 10
assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
Expand Down

0 comments on commit 2a3a356

Please sign in to comment.