From 27e68207514d895376d879cabe670408f99bd1d9 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 8 Nov 2020 13:54:58 -0500
Subject: [PATCH 1/2] ref: unify slurm and TE under backendPlugin

---
 pytorch_lightning/accelerators/ddp_slurm_accelerator.py | 9 ++-------
 .../accelerators/ddp_torchelastic_accelerator.py        | 7 +++----
 .../cluster_environments/cluster_environment.py         | 4 ++++
 .../cluster_environments/slurm_environment.py           | 3 +++
 .../cluster_environments/torchelastic_environment.py    | 3 +++
 5 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/accelerators/ddp_slurm_accelerator.py b/pytorch_lightning/accelerators/ddp_slurm_accelerator.py
index 1ea4461c3c3cc..e0a563eb91100 100644
--- a/pytorch_lightning/accelerators/ddp_slurm_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_slurm_accelerator.py
@@ -25,7 +25,6 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
-from pytorch_lightning.utilities.seed import seed_everything
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -52,7 +51,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['SLURM_LOCALID'])
+        self.task_idx = self.cluster_environment.local_rank
 
     def train(self):
         model = self.trainer.model
@@ -88,7 +87,7 @@ def test_step(self, args):
         output = self.training_step(args)
         return output
 
-    def barrier(self, name: str = None):
+    def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
 
@@ -115,10 +114,6 @@ def ddp_train(self, process_idx, model):
             Dict with evaluation results
 
         """
-        seed = os.environ.get("PL_GLOBAL_SEED")
-        if seed is not None:
-            seed_everything(int(seed))
-
         # determine which process we are and world size
         self.set_world_ranks(process_idx)
 
diff --git a/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py b/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py
index e54ad905de80e..1293585fc4567 100644
--- a/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py
@@ -24,8 +24,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
 
 
 try:
@@ -53,7 +52,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['LOCAL_RANK'])
+        self.task_idx = self.cluster_environment.local_rank
 
     def train(self):
         model = self.trainer.model
@@ -120,7 +119,7 @@ def ddp_train(self, process_idx, model):
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
+        if self.trainer.global_rank != 0 and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
diff --git a/pytorch_lightning/cluster_environments/cluster_environment.py b/pytorch_lightning/cluster_environments/cluster_environment.py
index ff3436e66204c..08fbbf4095ca3 100644
--- a/pytorch_lightning/cluster_environments/cluster_environment.py
+++ b/pytorch_lightning/cluster_environments/cluster_environment.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 class ClusterEnvironment:
 
     def __init__(self):
@@ -25,3 +26,6 @@ def master_port(self):
 
     def world_size(self):
         return self._world_size
+
+    def local_rank(self):
+        pass
diff --git a/pytorch_lightning/cluster_environments/slurm_environment.py b/pytorch_lightning/cluster_environments/slurm_environment.py
index 44cdc2207899c..6df1cf680c57f 100644
--- a/pytorch_lightning/cluster_environments/slurm_environment.py
+++ b/pytorch_lightning/cluster_environments/slurm_environment.py
@@ -67,6 +67,9 @@ def master_port(self):
     def world_size(self):
         return self._world_size
 
+    def local_rank(self):
+        return int(os.environ['SLURM_LOCALID'])
+
     def _resolve_root_node_address(self, root_node):
         if '[' in root_node:
             name, numbers = root_node.split('[', maxsplit=1)
diff --git a/pytorch_lightning/cluster_environments/torchelastic_environment.py b/pytorch_lightning/cluster_environments/torchelastic_environment.py
index d50a10a782dbb..a4d769518d252 100644
--- a/pytorch_lightning/cluster_environments/torchelastic_environment.py
+++ b/pytorch_lightning/cluster_environments/torchelastic_environment.py
@@ -46,3 +46,6 @@ def master_port(self):
 
     def world_size(self):
         return os.environ.get('WORLD_SIZE')
+
+    def local_rank(self):
+        return int(os.environ['LOCAL_RANK'])

From 579764d6c2018f8cc7006f8e5338276e22018d83 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 8 Nov 2020 14:07:09 -0500
Subject: [PATCH 2/2] ref: unify slurm and TE under backendPlugin

---
 .../accelerators/ddp2_accelerator.py           | 13 +----------
 .../accelerators/ddp_cpu_slurm_accelerator.py  |  4 ++--
 .../ddp_cpu_torchelastic_accelerator.py        |  4 ++--
 .../accelerators/ddp_slurm_accelerator.py      |  4 ++--
 .../ddp_torchelastic_accelerator.py            |  4 ++--
 tests/backends/test_accelerator_connector.py   | 22 ++++++++++++++-----
 6 files changed, 26 insertions(+), 25 deletions(-)

diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py
index 452c15ba3bc38..db4ccbda01bf0 100644
--- a/pytorch_lightning/accelerators/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -46,19 +46,8 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
         self.nickname = 'ddp2'
 
     def setup(self, model):
-        self._resolve_task_idx()
         self.trainer.model = model
-
-    def _resolve_task_idx(self):
-        if self.trainer.is_slurm_managing_tasks:
-            self.task_idx = int(os.environ['SLURM_LOCALID'])
-        else:
-            # torchelastic or general non_slurm ddp2
-            try:
-                self.task_idx = int(os.environ['LOCAL_RANK'])
-            except Exception as exp:
-                m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags'
-                raise MisconfigurationException(m) from exp
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
diff --git a/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py
index c80e8a4ec355c..dbaa4a244f629 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py
@@ -53,7 +53,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['SLURM_LOCALID'])
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -118,7 +118,7 @@ def ddp_train(self, process_idx, model):
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
diff --git a/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py
index a90d7750eaeea..016468a4517cf 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py
@@ -52,7 +52,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['LOCAL_RANK'])
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -117,7 +117,7 @@ def ddp_train(self, process_idx, model):
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
diff --git a/pytorch_lightning/accelerators/ddp_slurm_accelerator.py b/pytorch_lightning/accelerators/ddp_slurm_accelerator.py
index e0a563eb91100..6bdd1930ecfe0 100644
--- a/pytorch_lightning/accelerators/ddp_slurm_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_slurm_accelerator.py
@@ -51,7 +51,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = self.cluster_environment.local_rank
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -118,7 +118,7 @@ def ddp_train(self, process_idx, model):
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank != 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
diff --git a/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py b/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py
index 1293585fc4567..53e784ee949d4 100644
--- a/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py
@@ -52,7 +52,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = self.cluster_environment.local_rank
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -119,7 +119,7 @@ def ddp_train(self, process_idx, model):
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank != 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py
index cbc96b0793062..57ccdfc594238 100644
--- a/tests/backends/test_accelerator_connector.py
+++ b/tests/backends/test_accelerator_connector.py
@@ -104,7 +104,7 @@ def on_fit_start(self, trainer, pl_module):
     "SLURM_NTASKS": "2",
     "SLURM_JOB_NAME": "SOME_NAME",
     "SLURM_NODEID": "0",
-    "SLURM_LOCALID": "0"
+    "SLURM_LOCALID": "10"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
 def test_accelerator_choice_ddp_slurm(tmpdir):
@@ -113,6 +113,8 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
             assert isinstance(trainer.accelerator_backend, accelerators.DDPSLURMAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
             raise SystemExit()
 
     model = BoringModel()
@@ -133,7 +135,7 @@ def on_fit_start(self, trainer, pl_module):
     "SLURM_JOB_NAME": "SOME_NAME",
     "SLURM_NODEID": "0",
     "LOCAL_RANK": "0",
-    "SLURM_LOCALID": "0"
+    "SLURM_LOCALID": "10"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
 def test_accelerator_choice_ddp2_slurm(tmpdir):
@@ -142,6 +144,9 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp2
             assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
+
             raise SystemExit()
 
     model = BoringModel()
@@ -159,7 +164,7 @@ def on_fit_start(self, trainer, pl_module):
 @mock.patch.dict(os.environ, {
     "CUDA_VISIBLE_DEVICES": "0,1",
     "WORLD_SIZE": "2",
-    "LOCAL_RANK": "0",
+    "LOCAL_RANK": "10",
     "NODE_RANK": "0"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
@@ -169,6 +174,8 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
             assert isinstance(trainer.accelerator_backend, accelerators.DDPTorchElasticAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
             raise SystemExit()
 
     model = BoringModel()
@@ -186,7 +193,7 @@ def on_fit_start(self, trainer, pl_module):
 @mock.patch.dict(os.environ, {
     "CUDA_VISIBLE_DEVICES": "0,1",
     "WORLD_SIZE": "2",
-    "LOCAL_RANK": "0",
+    "LOCAL_RANK": "10",
     "NODE_RANK": "0"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
@@ -196,6 +203,8 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp2
             assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
             raise SystemExit()
 
     model = BoringModel()
@@ -212,7 +221,7 @@ def on_fit_start(self, trainer, pl_module):
 
 @mock.patch.dict(os.environ, {
     "WORLD_SIZE": "1",
-    "LOCAL_RANK": "0",
+    "LOCAL_RANK": "10",
     "NODE_RANK": "0"
 })
 @mock.patch('torch.cuda.device_count', return_value=0)
@@ -222,6 +231,9 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
             assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUTorchElasticAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
+
             raise SystemExit()
 
     model = BoringModel()
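Illustrative note (not part of the patches above): the refactor converges on one pattern — every accelerator resolves its task_idx through ClusterEnvironment.local_rank() instead of reading SLURM_LOCALID or LOCAL_RANK itself. Below is a minimal, self-contained Python sketch of that pattern; DemoTorchElasticLikeEnvironment and DemoAccelerator are hypothetical names used only for this sketch, not part of the PyTorch Lightning API.

import os


class ClusterEnvironment:
    """Sketch of the interface extended in the first patch: local_rank() joins
    world_size() as a hook each cluster environment must provide."""

    def world_size(self):
        raise NotImplementedError

    def local_rank(self):
        raise NotImplementedError


class DemoTorchElasticLikeEnvironment(ClusterEnvironment):
    """Illustrative subclass reading the same variables the patched
    TorchElasticEnvironment reads."""

    def world_size(self):
        return int(os.environ.get("WORLD_SIZE", "1"))

    def local_rank(self):
        # torchelastic exports LOCAL_RANK for every process it launches on a node
        return int(os.environ["LOCAL_RANK"])


class DemoAccelerator:
    """Stand-in for the accelerators touched by the second patch: setup() no
    longer parses os.environ directly, it delegates to the injected environment."""

    def __init__(self, cluster_environment):
        self.cluster_environment = cluster_environment
        self.task_idx = None

    def setup(self):
        self.task_idx = self.cluster_environment.local_rank()


if __name__ == "__main__":
    os.environ.setdefault("LOCAL_RANK", "10")  # mirrors the value mocked in the tests
    accelerator = DemoAccelerator(DemoTorchElasticLikeEnvironment())
    accelerator.setup()
    assert accelerator.task_idx == 10
    print("task_idx resolved via ClusterEnvironment.local_rank():", accelerator.task_idx)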