Improve code quality in AcceleratorConnector._configure_slurm_ddp #10102

Merged 23 commits on Nov 17, 2021
pytorch_lightning/plugins/environments/slurm_environment.py (5 additions, 0 deletions)

@@ -28,6 +28,11 @@ class SLURMEnvironment(ClusterEnvironment):
     def creates_processes_externally(self) -> bool:
         return True

+    @staticmethod
+    def detect() -> bool:
+        """Returns ``True`` if the current process was launched on a SLURM cluster."""
+        return "SLURM_NTASKS" in os.environ
+
     @property
     def main_address(self) -> str:
         # figure out the root node addr
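The new SLURMEnvironment.detect() helper only inspects the environment, so it can be tried without a real cluster. A minimal sketch, assuming SLURMEnvironment is importable from pytorch_lightning.plugins.environments as in the tests further down:

    import os

    from pytorch_lightning.plugins.environments import SLURMEnvironment

    # Outside a SLURM allocation the variable is absent, so detection fails.
    os.environ.pop("SLURM_NTASKS", None)
    assert SLURMEnvironment.detect() is False

    # Under `srun`/`sbatch`, SLURM exports SLURM_NTASKS and detection succeeds.
    os.environ["SLURM_NTASKS"] = "4"
    assert SLURMEnvironment.detect() is True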
pytorch_lightning/trainer/connectors/accelerator_connector.py (25 additions, 33 deletions)

@@ -134,7 +134,6 @@ def __init__(
         self.precision = precision
         self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
         self.amp_level = amp_level
-        self._is_slurm_managing_tasks = False

         self._precision_plugin: Optional[PrecisionPlugin] = None
         self._training_type_plugin: Optional[TrainingTypePlugin] = None
@@ -167,7 +166,6 @@ def __init__(
         self.handle_given_plugins()
         self._set_distrib_type_if_training_type_plugin_passed()

-        self._configure_slurm_ddp()
         self._cluster_environment = self.select_cluster_environment()

         self.update_device_type_if_ipu_plugin()
@@ -703,15 +701,15 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
                 cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
             )
         elif self.use_ddp:
-            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks
+            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks()
             use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
             use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
             use_ddp_spawn = self._distrib_type == _StrategyType.DDP_SPAWN
             use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu
             use_tpu_spawn = self.use_tpu and self._distrib_type == _StrategyType.TPU_SPAWN
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
             use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
-            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks
+            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks()
             use_ddp_sharded = self._distrib_type == _StrategyType.DDP_SHARDED
             use_ddp_sharded_spawn = self._distrib_type == _StrategyType.DDP_SHARDED_SPAWN
             use_ddp_fully_sharded = self._distrib_type == _StrategyType.DDP_FULLY_SHARDED
@@ -807,8 +805,9 @@ def select_accelerator(self) -> Accelerator:
     def select_cluster_environment(self) -> ClusterEnvironment:
         if self._cluster_environment is not None:
             return self._cluster_environment
-        if self._is_slurm_managing_tasks:
+        if self._is_slurm_managing_tasks():
             env = SLURMEnvironment()
+            rank_zero_info("Multiprocessing is handled by SLURM.")
         elif TorchElasticEnvironment.is_using_torchelastic():
             env = TorchElasticEnvironment()
         elif KubeflowEnvironment.is_using_kubeflow():
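The selection order itself is unchanged by this hunk: an explicitly passed cluster environment wins, then SLURM, TorchElastic, and Kubeflow are probed in that order. A condensed sketch of that priority chain; the LightningEnvironment fallback is not visible in this hunk and is assumed from the surrounding code:

    from pytorch_lightning.plugins.environments import (
        KubeflowEnvironment,
        LightningEnvironment,
        SLURMEnvironment,
        TorchElasticEnvironment,
    )

    def pick_cluster_environment(connector):
        """Condensed sketch of select_cluster_environment after this change."""
        if connector._cluster_environment is not None:       # explicitly passed plugin wins
            return connector._cluster_environment
        if connector._is_slurm_managing_tasks():              # SLURM launched the processes
            return SLURMEnvironment()
        if TorchElasticEnvironment.is_using_torchelastic():
            return TorchElasticEnvironment()
        if KubeflowEnvironment.is_using_kubeflow():
            return KubeflowEnvironment()
        return LightningEnvironment()                         # assumed default for everything else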
@@ -990,34 +989,6 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
                 elif self.has_gpu:
                     self._device_type = DeviceType.GPU

-    def _configure_slurm_ddp(self):
-        # extract SLURM flag vars
-        # whenever we have the correct number of tasks, we let slurm manage processes
-        # otherwise we launch the required number of processes
-        if self.use_ddp or self.use_ddp2:
-            num_requested_gpus = self.num_gpus * self.num_nodes
-            num_slurm_tasks = 0
-            try:
-                num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
-                self._is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus
-
-                # enable slurm cpu
-                if num_requested_gpus == 0:
-                    self._is_slurm_managing_tasks = num_slurm_tasks == self.num_processes
-
-                # in interactive mode we don't manage tasks
-                job_name = os.environ["SLURM_JOB_NAME"]
-                if job_name == "bash":
-                    self._is_slurm_managing_tasks = False
-
-            except Exception:
-                # likely not on slurm, so set the slurm managed flag to false
-                self._is_slurm_managing_tasks = False
-
-            # notify user the that slurm is managing tasks
-            if self._is_slurm_managing_tasks:
-                rank_zero_info("Multi-processing is handled by Slurm.")
-
     def _set_distrib_type_if_training_type_plugin_passed(self):
         # This is required as when `TrainingTypePlugin` instance is passed to either `strategy`
         # or `plugins` flag, `AcceleratorConnector.set_distributed_mode` is not required to be
@@ -1026,3 +997,24 @@ def _set_distrib_type_if_training_type_plugin_passed(self):
             return
         if self._training_type_plugin is not None:
             self._distrib_type = getattr(self._training_type_plugin, "distributed_backend", None)
+
+    def _is_slurm_managing_tasks(self) -> bool:
+        """Returns whether we let SLURM manage the processes or not.
+
+        Returns ``True`` if and only if these conditions match:
+
+            - A SLURM cluster is detected
+            - A distributed plugin is being used
+            - The process is not launching in interactive mode
+            - The number of tasks in SLURM matches the requested number of devices and nodes in the Trainer
+        """
+        if (
+            (not self.use_ddp and not self.use_ddp2)
+            or not SLURMEnvironment.detect()
+            or os.environ.get("SLURM_JOB_NAME") == "bash"  # in interactive mode we don't manage tasks
+        ):
+            return False
+
+        total_requested_devices = (self.num_gpus or self.num_processes) * self.num_nodes
+        num_slurm_tasks = int(os.environ["SLURM_NTASKS"], 0)
+        return num_slurm_tasks == total_requested_devices
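To make the new predicate concrete, here is a standalone sketch of the same check outside the connector. The function and its arguments are illustrative and mirror the method above rather than any Lightning API; the use_ddp/use_ddp2 guard on the strategy is omitted for brevity:

    import os

    def slurm_managing_tasks(num_gpus: int, num_processes: int, num_nodes: int) -> bool:
        """Sketch: SLURM detected, not an interactive session, and task count matches."""
        if "SLURM_NTASKS" not in os.environ:                # no SLURM allocation detected
            return False
        if os.environ.get("SLURM_JOB_NAME") == "bash":      # interactive session, e.g. `srun --pty bash`
            return False
        total_requested_devices = (num_gpus or num_processes) * num_nodes
        return int(os.environ["SLURM_NTASKS"]) == total_requested_devices

    # Example: a 2-node job with 4 GPUs per node needs exactly 8 SLURM tasks.
    os.environ.update({"SLURM_NTASKS": "8", "SLURM_JOB_NAME": "train"})
    assert slurm_managing_tasks(num_gpus=4, num_processes=1, num_nodes=2)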
tests/accelerators/test_accelerator_connector.py (6 additions, 6 deletions)

@@ -103,7 +103,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_accelerator_choice_ddp_slurm(set_device_mock, device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -136,7 +136,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -323,7 +323,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -791,7 +791,7 @@ def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -824,7 +824,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -1008,7 +1008,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
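Because the check now reads the environment directly, these SLURM tests keep the usual pattern of patching os.environ before the Trainer is built. A hedged sketch of that pattern; the variable set and Trainer arguments here are assumptions for illustration, not copied from the test file:

    import os
    from unittest import mock

    from pytorch_lightning import Trainer

    SLURM_ENV = {
        "SLURM_NTASKS": "2",            # must equal the requested devices * num_nodes
        "SLURM_JOB_NAME": "SOME_NAME",  # anything except "bash" (interactive mode is excluded)
        "SLURM_NODEID": "0",
        "SLURM_PROCID": "0",
        "SLURM_LOCALID": "0",
    }

    with mock.patch.dict(os.environ, SLURM_ENV):
        trainer = Trainer(strategy="ddp", num_processes=2, fast_dev_run=True)
        assert trainer._accelerator_connector._is_slurm_managing_tasks()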
tests/models/test_restore.py (0 additions, 1 deletion)

@@ -500,7 +500,6 @@ def test_dp_resume(tmpdir):

     # fit model
     trainer = Trainer(**trainer_options)
-    trainer._is_slurm_managing_tasks = True
     trainer.fit(model, datamodule=dm)

     # track epoch before saving. Increment since we finished the current epoch, don't want to rerun