ref: unify slurm and TE under backendPlugin 2/n #4580

Merged: 1 commit, Nov 8, 2020
pytorch_lightning/accelerators/__init__.py (1 addition & 2 deletions)
@@ -20,8 +20,7 @@
 from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
 from pytorch_lightning.accelerators.tpu_accelerator import TPUAccelerator
 from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator
-from pytorch_lightning.accelerators.ddp_slurm_accelerator import DDPSLURMAccelerator
-from pytorch_lightning.accelerators.ddp_torchelastic_accelerator import DDPTorchElasticAccelerator
+from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
 from pytorch_lightning.accelerators.ddp_cpu_torchelastic_accelerator import DDPCPUTorchElasticAccelerator
 from pytorch_lightning.accelerators.ddp_cpu_slurm_accelerator import DDPCPUSLURMAccelerator
 from pytorch_lightning.accelerators.accelerator import Accelerator
pytorch_lightning/accelerators/accelerator_connector.py (2 additions & 2 deletions)
@@ -227,7 +227,7 @@ def select_accelerator(self):
             )
 
         elif use_slurm_ddp:
-            accelerator_backend = accelerators.DDPSLURMAccelerator(
+            accelerator_backend = accelerators.DDPHPCAccelerator(
                 self.trainer,
                 cluster_env,
                 self.trainer.plugin_connector.ddp_plugin
@@ -241,7 +241,7 @@ def select_accelerator(self):
             )
 
         elif use_torchelastic_ddp:
-            accelerator_backend = accelerators.DDPTorchElasticAccelerator(
+            accelerator_backend = accelerators.DDPHPCAccelerator(
                 self.trainer,
                 cluster_env,
                 self.trainer.plugin_connector.ddp_plugin
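The net effect of the two hunks above is that the SLURM and TorchElastic branches no longer select different backend classes; only the injected cluster environment differs. A rough sketch of that logic (the select_ddp_backend helper below is hypothetical and only mirrors the branches shown in the diff, it is not part of the connector's API):

from pytorch_lightning import accelerators

def select_ddp_backend(trainer, cluster_env, use_slurm_ddp, use_torchelastic_ddp):
    # After this PR, both cluster setups construct the same accelerator class;
    # the only difference is the cluster environment that gets passed in.
    if use_slurm_ddp or use_torchelastic_ddp:
        return accelerators.DDPHPCAccelerator(
            trainer,
            cluster_env,
            trainer.plugin_connector.ddp_plugin,
        )
    return None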
pytorch_lightning/accelerators/ddp_hpc_accelerator.py (renamed from pytorch_lightning/accelerators/ddp_slurm_accelerator.py)
@@ -26,6 +26,7 @@
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
 
+
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
     from hydra.core.hydra_config import HydraConfig
@@ -35,12 +36,7 @@
     HYDRA_AVAILABLE = True
 
 
-# -------------------------------------------
-# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
-# TEMP CLASS WHILE WE DECOUPLE SLURM FROM DDP
-# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
-# -------------------------------------------
-class DDPSLURMAccelerator(Accelerator):
+class DDPHPCAccelerator(Accelerator):
 
     def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
         super().__init__(trainer, cluster_environment, ddp_plugin)
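With the SLURM-specific note removed and the class renamed, the same accelerator can be constructed with either cluster environment. A minimal sketch, assuming the pytorch_lightning.cluster_environments import path used by this release line and a default Trainer (illustrative only, not taken from the PR):

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
from pytorch_lightning.cluster_environments import SLURMEnvironment, TorchElasticEnvironment

trainer = Trainer()

# Constructor signature comes from the diff above; constructing the object
# here should only wire the pieces together, with any distributed setup
# happening later during fit().
slurm_backend = DDPHPCAccelerator(trainer, cluster_environment=SLURMEnvironment())
te_backend = DDPHPCAccelerator(trainer, cluster_environment=TorchElasticEnvironment())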
pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py (0 additions & 209 deletions)

This file was deleted.

tests/backends/test_accelerator_connector.py (2 additions & 2 deletions)
@@ -111,7 +111,7 @@ def test_accelerator_choice_ddp_slurm(tmpdir):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPSLURMAccelerator)
+            assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
             assert trainer.accelerator_backend.task_idx == 10
             assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
@@ -172,7 +172,7 @@ def test_accelerator_choice_ddp_te(tmpdir):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
-            assert isinstance(trainer.accelerator_backend, accelerators.DDPTorchElasticAccelerator)
+            assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
             assert trainer.accelerator_backend.task_idx == 10
             assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
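Both updated tests now assert the same backend type through an on_fit_start callback, so the check itself could be shared. A sketch of that pattern (the AssertHPCBackend name is made up for illustration; the imports follow what the tests already use):

from pytorch_lightning import accelerators
from pytorch_lightning.callbacks import Callback

class AssertHPCBackend(Callback):
    # Hypothetical shared callback: whether the job came up under SLURM or
    # TorchElastic, the selected backend should now be DDPHPCAccelerator.
    def on_fit_start(self, trainer, pl_module):
        assert trainer.use_ddp
        assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator)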