diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py
index 9f3f52fc1c381..d9be5eda54c6b 100644
--- a/pytorch_lightning/plugins/environments/slurm_environment.py
+++ b/pytorch_lightning/plugins/environments/slurm_environment.py
@@ -28,6 +28,11 @@ class SLURMEnvironment(ClusterEnvironment):
     def creates_processes_externally(self) -> bool:
         return True
 
+    @staticmethod
+    def detect() -> bool:
+        """Returns ``True`` if the current process was launched on a SLURM cluster."""
+        return "SLURM_NTASKS" in os.environ
+
     @property
     def main_address(self) -> str:
         # figure out the root node addr
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 47deeed2dca1d..5532385ca1d98 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -134,7 +134,6 @@ def __init__(
         self.precision = precision
         self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
         self.amp_level = amp_level
-        self._is_slurm_managing_tasks = False
 
         self._precision_plugin: Optional[PrecisionPlugin] = None
         self._training_type_plugin: Optional[TrainingTypePlugin] = None
@@ -167,7 +166,6 @@ def __init__(
 
         self.handle_given_plugins()
         self._set_distrib_type_if_training_type_plugin_passed()
-        self._configure_slurm_ddp()
 
         self._cluster_environment = self.select_cluster_environment()
         self.update_device_type_if_ipu_plugin()
@@ -703,7 +701,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
                 cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
             )
         elif self.use_ddp:
-            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks
+            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks()
             use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
             use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
             use_ddp_spawn = self._distrib_type == _StrategyType.DDP_SPAWN
@@ -711,7 +709,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
             use_tpu_spawn = self.use_tpu and self._distrib_type == _StrategyType.TPU_SPAWN
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
             use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
-            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks
+            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks()
             use_ddp_sharded = self._distrib_type == _StrategyType.DDP_SHARDED
             use_ddp_sharded_spawn = self._distrib_type == _StrategyType.DDP_SHARDED_SPAWN
             use_ddp_fully_sharded = self._distrib_type == _StrategyType.DDP_FULLY_SHARDED
@@ -807,8 +805,9 @@ def select_accelerator(self) -> Accelerator:
     def select_cluster_environment(self) -> ClusterEnvironment:
         if self._cluster_environment is not None:
             return self._cluster_environment
-        if self._is_slurm_managing_tasks:
+        if self._is_slurm_managing_tasks():
             env = SLURMEnvironment()
+            rank_zero_info("Multiprocessing is handled by SLURM.")
         elif TorchElasticEnvironment.is_using_torchelastic():
             env = TorchElasticEnvironment()
         elif KubeflowEnvironment.is_using_kubeflow():
@@ -990,34 +989,6 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
         elif self.has_gpu:
             self._device_type = DeviceType.GPU
 
-    def _configure_slurm_ddp(self):
-        # extract SLURM flag vars
-        # whenever we have the correct number of tasks, we let slurm manage processes
-        # otherwise we launch the required number of processes
-        if self.use_ddp or self.use_ddp2:
-            num_requested_gpus = self.num_gpus * self.num_nodes
-            num_slurm_tasks = 0
-            try:
-                num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
-                self._is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus
-
-                # enable slurm cpu
-                if num_requested_gpus == 0:
-                    self._is_slurm_managing_tasks = num_slurm_tasks == self.num_processes
-
-                # in interactive mode we don't manage tasks
-                job_name = os.environ["SLURM_JOB_NAME"]
-                if job_name == "bash":
-                    self._is_slurm_managing_tasks = False
-
-            except Exception:
-                # likely not on slurm, so set the slurm managed flag to false
-                self._is_slurm_managing_tasks = False
-
-        # notify user the that slurm is managing tasks
-        if self._is_slurm_managing_tasks:
-            rank_zero_info("Multi-processing is handled by Slurm.")
-
     def _set_distrib_type_if_training_type_plugin_passed(self):
         # This is required as when `TrainingTypePlugin` instance is passed to either `strategy`
         # or `plugins` flag, `AcceleratorConnector.set_distributed_mode` is not required to be
@@ -1026,3 +997,24 @@ def _set_distrib_type_if_training_type_plugin_passed(self):
             return
         if self._training_type_plugin is not None:
             self._distrib_type = getattr(self._training_type_plugin, "distributed_backend", None)
+
+    def _is_slurm_managing_tasks(self) -> bool:
+        """Returns whether we let SLURM manage the processes or not.
+
+        Returns ``True`` if and only if these conditions match:
+
+            - A SLURM cluster is detected
+            - A distributed plugin is being used
+            - The process is not launching in interactive mode
+            - The number of tasks in SLURM matches the requested number of devices and nodes in the Trainer
+        """
+        if (
+            (not self.use_ddp and not self.use_ddp2)
+            or not SLURMEnvironment.detect()
+            or os.environ.get("SLURM_JOB_NAME") == "bash"  # in interactive mode we don't manage tasks
+        ):
+            return False
+
+        total_requested_devices = (self.num_gpus or self.num_processes) * self.num_nodes
+        num_slurm_tasks = int(os.environ["SLURM_NTASKS"], 0)
+        return num_slurm_tasks == total_requested_devices
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index d005c48757330..c01a05759b8fd 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -103,7 +103,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_accelerator_choice_ddp_slurm(set_device_mock, device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -136,7 +136,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -323,7 +323,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -791,7 +791,7 @@ def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -824,7 +824,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -1008,7 +1008,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 6d241222526ab..5e7c61163130d 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -500,7 +500,6 @@ def test_dp_resume(tmpdir):
 
     # fit model
     trainer = Trainer(**trainer_options)
-    trainer._is_slurm_managing_tasks = True
    trainer.fit(model, datamodule=dm)
 
     # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
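
For review context, here is a minimal standalone sketch of the detection rule the new `_is_slurm_managing_tasks()` implements (SLURM detected, not an interactive job, and the SLURM task count matching the requested devices). The free function `is_slurm_managing_tasks` and its explicit arguments are hypothetical stand-ins for state the real `AcceleratorConnector` reads from itself (`use_ddp`/`use_ddp2`, `num_gpus`/`num_processes`, `num_nodes`); it is not part of the PR.

```python
import os


def is_slurm_managing_tasks(use_ddp: bool, devices_per_node: int, num_nodes: int) -> bool:
    # Hypothetical standalone version of the connector method, for illustration only.
    if not use_ddp:
        return False
    if "SLURM_NTASKS" not in os.environ:  # mirrors the new SLURMEnvironment.detect()
        return False
    if os.environ.get("SLURM_JOB_NAME") == "bash":  # interactive session, e.g. `srun --pty bash`
        return False
    total_requested_devices = devices_per_node * num_nodes
    return int(os.environ["SLURM_NTASKS"]) == total_requested_devices


# Example: a job launched with `sbatch --nodes=2 --ntasks-per-node=4` and a
# Trainer configured for 4 GPUs on 2 nodes sees 8 SLURM tasks, so SLURM manages the processes.
os.environ.update({"SLURM_NTASKS": "8", "SLURM_JOB_NAME": "lit-train"})
assert is_slurm_managing_tasks(use_ddp=True, devices_per_node=4, num_nodes=2)
```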