Improve code quality in AcceleratorConnector._configure_slurm_ddp #10102

Merged · 23 commits · Nov 17, 2021
5 changes: 5 additions & 0 deletions pytorch_lightning/plugins/environments/slurm_environment.py
@@ -28,6 +28,11 @@ class SLURMEnvironment(ClusterEnvironment):
    def creates_processes_externally(self) -> bool:
        return True

    @staticmethod
    def is_using_slurm() -> bool:
"""Returns ``True`` if the current process was launched on a SLURM cluster."""
return "SLURM_NTASKS" in os.environ

    @property
    def main_address(self) -> str:
        # figure out the root node addr
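The new helper keys the detection off the `SLURM_NTASKS` variable that SLURM exports for the tasks it launches. Below is a minimal sketch of how it could be exercised off-cluster by faking that variable (assuming `SLURMEnvironment` is importable from `pytorch_lightning.plugins.environments`; purely illustrative):

```python
import os

from pytorch_lightning.plugins.environments import SLURMEnvironment

# Outside a SLURM allocation the variable is absent, so detection is negative.
os.environ.pop("SLURM_NTASKS", None)
assert not SLURMEnvironment.is_using_slurm()

# SLURM exports SLURM_NTASKS for the tasks it launches; faking it here flips
# the detection without needing a real cluster.
os.environ["SLURM_NTASKS"] = "4"
assert SLURMEnvironment.is_using_slurm()
```

In a test suite the same effect would typically be achieved with pytest's `monkeypatch` fixture rather than mutating `os.environ` directly.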
56 changes: 26 additions & 30 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -134,7 +134,6 @@ def __init__(
        self.precision = precision
        self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
        self.amp_level = amp_level
        self._is_slurm_managing_tasks = False

        self._precision_plugin: Optional[PrecisionPlugin] = None
        self._training_type_plugin: Optional[TrainingTypePlugin] = None
@@ -167,7 +166,6 @@ def __init__(
        self.handle_given_plugins()
        self._set_distrib_type_if_training_type_plugin_passed()

        self._configure_slurm_ddp()
        self._cluster_environment = self.select_cluster_environment()

        self.update_device_type_if_ipu_plugin()
@@ -809,6 +807,7 @@ def select_cluster_environment(self) -> ClusterEnvironment:
            return self._cluster_environment
        if self._is_slurm_managing_tasks:
            env = SLURMEnvironment()
            rank_zero_info("Multiprocessing is handled by SLURM.")
        elif TorchElasticEnvironment.is_using_torchelastic():
            env = TorchElasticEnvironment()
        elif KubeflowEnvironment.is_using_kubeflow():
@@ -990,34 +989,6 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
        elif self.has_gpu:
            self._device_type = DeviceType.GPU

    def _configure_slurm_ddp(self):
        # extract SLURM flag vars
        # whenever we have the correct number of tasks, we let slurm manage processes
        # otherwise we launch the required number of processes
        if self.use_ddp or self.use_ddp2:
            num_requested_gpus = self.num_gpus * self.num_nodes
            num_slurm_tasks = 0
            try:
                num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
                self._is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus

                # enable slurm cpu
                if num_requested_gpus == 0:
                    self._is_slurm_managing_tasks = num_slurm_tasks == self.num_processes

                # in interactive mode we don't manage tasks
                job_name = os.environ["SLURM_JOB_NAME"]
                if job_name == "bash":
                    self._is_slurm_managing_tasks = False

            except Exception:
                # likely not on slurm, so set the slurm managed flag to false
                self._is_slurm_managing_tasks = False

        # notify user that slurm is managing tasks
        if self._is_slurm_managing_tasks:
            rank_zero_info("Multi-processing is handled by Slurm.")

    def _set_distrib_type_if_training_type_plugin_passed(self):
        # This is required as when `TrainingTypePlugin` instance is passed to either `strategy`
        # or `plugins` flag, `AcceleratorConnector.set_distributed_mode` is not required to be
@@ -1026,3 +997,28 @@ def _set_distrib_type_if_training_type_plugin_passed(self):
            return
        if self._training_type_plugin is not None:
            self._distrib_type = getattr(self._training_type_plugin, "distributed_backend", None)

    @property
    def _is_slurm_managing_tasks(self) -> bool:
"""Returns whether we let SLURM manage the processes or not.

Returns ``True`` if and only if these conditions match:

- A SLURM cluster is detected
- A distributed plugin is being used
- The process is not launching in interactive mode
- The number of tasks in SLURM matches the requested number of devices and nodes in the Trainer
"""
        if not self.use_ddp and not self.use_ddp2:
            return False

        if not SLURMEnvironment.is_using_slurm():
            return False

        if os.environ.get("SLURM_JOB_NAME") == "bash":
            # in interactive mode we don't manage tasks
            return False

        total_requested_devices = len(self.parallel_devices) * self.num_nodes
        num_slurm_tasks = int(os.environ["SLURM_NTASKS"], 0)
        return num_slurm_tasks == total_requested_devices
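For reference, the decision this property encodes can be reduced to a standalone sketch; the function name and its arguments below are illustrative only and are not part of this PR:

```python
import os


def slurm_manages_tasks(using_ddp: bool, num_devices: int, num_nodes: int) -> bool:
    """Illustrative restatement of the checks performed by ``_is_slurm_managing_tasks``."""
    if not using_ddp:
        # only the distributed (ddp/ddp2) plugins hand process creation over to SLURM
        return False
    if "SLURM_NTASKS" not in os.environ:
        # not running inside a SLURM allocation at all
        return False
    if os.environ.get("SLURM_JOB_NAME") == "bash":
        # interactive session (e.g. `srun --pty bash`): launch processes ourselves
        return False
    # SLURM manages the processes only when exactly one task was requested
    # per device across all nodes
    return int(os.environ["SLURM_NTASKS"]) == num_devices * num_nodes
```

For example, with 2 devices per node on 2 nodes this only returns ``True`` when the job was submitted with 4 tasks (`SLURM_NTASKS=4`).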
1 change: 0 additions & 1 deletion tests/models/test_restore.py
@@ -500,7 +500,6 @@ def test_dp_resume(tmpdir):

    # fit model
    trainer = Trainer(**trainer_options)
    trainer._is_slurm_managing_tasks = True
    trainer.fit(model, datamodule=dm)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun