
Mark SLURM detection methods in AcceleratorConnector as protected #10101

Merged · 10 commits · Oct 25, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -411,6 +411,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))


- Deprecated access to the `AcceleratorConnector.is_slurm_managing_tasks` attribute and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))


- Deprecated access to the `AcceleratorConnector.configure_slurm_ddp` method and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
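In short, the PR keeps the old public names on `AcceleratorConnector` as thin shims that emit a deprecation warning and forward to new protected members. A condensed sketch of that pattern, simplified from the diff below (the real class has many more members, and the import path for `rank_zero_deprecation` is an assumption):

```python
from pytorch_lightning.utilities import rank_zero_deprecation


class AcceleratorConnector:
    def __init__(self) -> None:
        # state now lives on a protected attribute
        self._is_slurm_managing_tasks = False

    @property
    def is_slurm_managing_tasks(self) -> bool:
        # old public name: warn on rank zero, then forward to the protected attribute
        rank_zero_deprecation(
            "`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5 and will be removed in v1.6."
        )
        return self._is_slurm_managing_tasks

    @is_slurm_managing_tasks.setter
    def is_slurm_managing_tasks(self, value: bool) -> None:
        rank_zero_deprecation(
            "`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5 and will be removed in v1.6."
        )
        self._is_slurm_managing_tasks = value
```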
42 changes: 31 additions & 11 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -135,7 +135,7 @@ def __init__(
self.precision = precision
self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
self.amp_level = amp_level
self.is_slurm_managing_tasks = False
self._is_slurm_managing_tasks = False

self._precision_plugin: Optional[PrecisionPlugin] = None
self._training_type_plugin: Optional[TrainingTypePlugin] = None
@@ -164,7 +164,7 @@ def __init__(
self._set_training_type_plugin()
else:
self.set_distributed_mode()
self.configure_slurm_ddp()
self._configure_slurm_ddp()

self.handle_given_plugins()
self.update_device_type_if_ipu_plugin()
@@ -685,15 +685,15 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
)
elif self.use_ddp:
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks
use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu
use_tpu_spawn = self.use_tpu and self._distrib_type == DistributedType.TPU_SPAWN
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks
use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
use_ddp_fully_sharded = self._distrib_type == DistributedType.DDP_FULLY_SHARDED
@@ -789,7 +789,7 @@ def select_accelerator(self) -> Accelerator:
def select_cluster_environment(self) -> ClusterEnvironment:
if self._cluster_environment is not None:
return self._cluster_environment
if self.is_slurm_managing_tasks:
if self._is_slurm_managing_tasks:
env = SLURMEnvironment()
elif TorchElasticEnvironment.is_using_torchelastic():
env = TorchElasticEnvironment()
@@ -972,7 +972,27 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
elif self.has_gpu:
self._device_type = DeviceType.GPU

def configure_slurm_ddp(self):
@property
def is_slurm_managing_tasks(self) -> bool:
rank_zero_deprecation(
"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5 and will be removed in v1.6."
)
return self._is_slurm_managing_tasks

@is_slurm_managing_tasks.setter
def is_slurm_managing_tasks(self, value: bool) -> None:
rank_zero_deprecation(
"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5 and will be removed in v1.6."
)
self._is_slurm_managing_tasks = value

def configure_slurm_ddp(self) -> None:
rank_zero_deprecation(
"`AcceleratorConnector.configure_slurm_ddp()` was deprecated in v1.5 and will be removed in v1.6."
)
self._configure_slurm_ddp()

def _configure_slurm_ddp(self):
# extract SLURM flag vars
# whenever we have the correct number of tasks, we let slurm manage processes
# otherwise we launch the required number of processes
@@ -981,21 +1001,21 @@ def configure_slurm_ddp(self):
num_slurm_tasks = 0
try:
num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
self.is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus
self._is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus

# enable slurm cpu
if num_requested_gpus == 0:
self.is_slurm_managing_tasks = num_slurm_tasks == self.num_processes
self._is_slurm_managing_tasks = num_slurm_tasks == self.num_processes

# in interactive mode we don't manage tasks
job_name = os.environ["SLURM_JOB_NAME"]
if job_name == "bash":
self.is_slurm_managing_tasks = False
self._is_slurm_managing_tasks = False

except Exception:
# likely not on slurm, so set the slurm managed flag to false
self.is_slurm_managing_tasks = False
self._is_slurm_managing_tasks = False

# notify the user that SLURM is managing tasks
if self.is_slurm_managing_tasks:
if self._is_slurm_managing_tasks:
rank_zero_info("Multi-processing is handled by Slurm.")
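For reference, the detection heuristic that `_configure_slurm_ddp` encodes can be written as a small standalone function. This is a sketch only — the helper name and signature are made up for illustration, and it assumes the decision is driven solely by `SLURM_NTASKS` and `SLURM_JOB_NAME`, as in the diff above:

```python
import os


def _slurm_manages_tasks(num_requested_gpus: int, num_processes: int) -> bool:
    """Sketch of the heuristic: SLURM manages the processes only when SLURM_NTASKS
    matches the requested devices and the job is not an interactive ``bash`` session."""
    try:
        num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
    except (KeyError, ValueError):
        # likely not running under SLURM
        return False
    if os.environ.get("SLURM_JOB_NAME") == "bash":
        # interactive mode: launch the required number of processes ourselves
        return False
    expected = num_requested_gpus if num_requested_gpus > 0 else num_processes
    return num_slurm_tasks == expected
```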
12 changes: 6 additions & 6 deletions tests/accelerators/test_accelerator_connector.py
@@ -100,7 +100,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
def test_accelerator_choice_ddp_slurm(setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -132,7 +132,7 @@ def on_fit_start(self, trainer, pl_module):
def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDP2Plugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -307,7 +307,7 @@ def on_fit_start(self, trainer, pl_module):
def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -756,7 +756,7 @@ def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock):
def test_strategy_choice_ddp_slurm(setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -788,7 +788,7 @@ def on_fit_start(self, trainer, pl_module):
def test_strategy_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDP2Plugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -963,7 +963,7 @@ def on_fit_start(self, trainer, pl_module):
def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert trainer.accelerator_connector._is_slurm_managing_tasks
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
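The SLURM-specific tests above rely on environment patching done elsewhere in the test file (not part of this diff) to fake a SLURM allocation. A minimal sketch of that setup with illustrative values — the exact decorators and variable set in the real tests may differ:

```python
import os
from unittest import mock

from pytorch_lightning import Trainer


@mock.patch.dict(
    os.environ,
    {
        "SLURM_NTASKS": "2",            # must match the requested devices/processes
        "SLURM_JOB_NAME": "SOME_NAME",  # anything except "bash" (interactive mode)
        "SLURM_NODEID": "0",
        "SLURM_PROCID": "0",
        "SLURM_LOCALID": "0",
    },
)
def test_slurm_detection_sketch():
    # two CPU processes requested, two SLURM tasks available -> SLURM manages the processes
    trainer = Trainer(accelerator="ddp_cpu", num_processes=2)
    assert trainer.accelerator_connector._is_slurm_managing_tasks
```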
15 changes: 15 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
@@ -406,3 +406,18 @@ def test_v1_6_0_deprecated_accelerator_pass_through_functions():

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_train_batch_start(batch=None, batch_idx=0)


def test_v1_6_0_configure_slurm_ddp():
trainer = Trainer()
with pytest.deprecated_call(match=r"`AcceleratorConnector.configure_slurm_ddp\(\)` was deprecated in v1.5"):
trainer.accelerator_connector.configure_slurm_ddp()


def test_v1_6_0_is_slurm_managing_tasks():
trainer = Trainer()
with pytest.deprecated_call(match=r"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5"):
_ = trainer.accelerator_connector.is_slurm_managing_tasks

with pytest.deprecated_call(match=r"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5"):
trainer.accelerator_connector.is_slurm_managing_tasks = False
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -502,7 +502,7 @@ def test_dp_resume(tmpdir):

# fit model
trainer = Trainer(**trainer_options)
trainer.is_slurm_managing_tasks = True
trainer._is_slurm_managing_tasks = True
trainer.fit(model, datamodule=dm)

# track epoch before saving. Increment since we finished the current epoch, don't want to rerun