ref: unify slurm and TE under backendPlugin 1/n #4578

Merged 2 commits on Nov 8, 2020
13 changes: 1 addition & 12 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -46,19 +46,8 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
         self.nickname = 'ddp2'
 
     def setup(self, model):
-        self._resolve_task_idx()
         self.trainer.model = model
-
-    def _resolve_task_idx(self):
-        if self.trainer.is_slurm_managing_tasks:
-            self.task_idx = int(os.environ['SLURM_LOCALID'])
-        else:
-            # torchelastic or general non_slurm ddp2
-            try:
-                self.task_idx = int(os.environ['LOCAL_RANK'])
-            except Exception as exp:
-                m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags'
-                raise MisconfigurationException(m) from exp
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
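The net effect of the hunk above, repeated for the other DDP accelerators below, is that an accelerator no longer branches on SLURM vs. torchelastic to discover its local rank; it asks the cluster environment it was constructed with. A rough sketch of that delegation pattern, using simplified stand-in classes rather than the real Lightning ones:

```python
# Simplified stand-ins, not the actual Lightning classes: the accelerator is
# handed a cluster environment and delegates local-rank discovery to it.
import os


class SlurmEnv:
    def local_rank(self) -> int:
        return int(os.environ['SLURM_LOCALID'])


class TorchElasticEnv:
    def local_rank(self) -> int:
        return int(os.environ['LOCAL_RANK'])


class Accelerator:
    def __init__(self, cluster_environment):
        self.cluster_environment = cluster_environment
        self.task_idx = None

    def setup(self, model):
        self.model = model
        # backend-agnostic: the injected environment decides where the rank comes from
        self.task_idx = self.cluster_environment.local_rank()
```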
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py
@@ -53,7 +53,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['SLURM_LOCALID'])
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -118,7 +118,7 @@ def ddp_train(self, process_idx, model):
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
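The new progress-bar predicate above keeps the bar enabled only on node 0, local process 0, and disables it everywhere else. A tiny illustrative check of that predicate (not part of the PR):

```python
# Illustrative only, not part of this PR: the new condition disables the
# progress bar on every process except node 0 / local process 0.
def should_disable_progress_bar(node_rank: int, process_idx: int) -> bool:
    return node_rank != 0 or process_idx != 0


assert not should_disable_progress_bar(node_rank=0, process_idx=0)  # main process keeps the bar
assert should_disable_progress_bar(node_rank=0, process_idx=1)      # other local ranks disable it
assert should_disable_progress_bar(node_rank=1, process_idx=0)      # other nodes disable it
```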
@@ -52,7 +52,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['LOCAL_RANK'])
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -117,7 +117,7 @@ def ddp_train(self, process_idx, model):
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
11 changes: 3 additions & 8 deletions pytorch_lightning/accelerators/ddp_slurm_accelerator.py
@@ -25,7 +25,6 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
-from pytorch_lightning.utilities.seed import seed_everything
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -52,7 +51,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['SLURM_LOCALID'])
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -88,7 +87,7 @@ def test_step(self, args):
         output = self.training_step(args)
         return output
 
-    def barrier(self, name: str = None):
+    def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
 
@@ -115,15 +114,11 @@ def ddp_train(self, process_idx, model):
             Dict with evaluation results
 
         """
-        seed = os.environ.get("PL_GLOBAL_SEED")
-        if seed is not None:
-            seed_everything(int(seed))
-
         # determine which process we are and world size
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank != 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
@@ -24,8 +24,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
 
 
 try:
@@ -53,7 +52,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['LOCAL_RANK'])
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -120,7 +119,7 @@ def ddp_train(self, process_idx, model):
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 class ClusterEnvironment:
 
     def __init__(self):
@@ -25,3 +26,6 @@ def master_port(self):
 
     def world_size(self):
         return self._world_size
+
+    def local_rank(self):
+        pass
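With local_rank() now part of the ClusterEnvironment interface, a custom environment only has to override it alongside the existing hooks. A hedged sketch of what such a subclass might look like; the MY_* environment variables are made up for illustration, and any other hooks the base class defines are omitted:

```python
import os

from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class MyClusterEnvironment(ClusterEnvironment):
    """Hypothetical environment for a scheduler that exposes MY_* variables (made up)."""

    def master_port(self):
        return int(os.environ.get('MY_MASTER_PORT', 12910))

    def world_size(self):
        return int(os.environ.get('MY_WORLD_SIZE', 1))

    def local_rank(self):
        return int(os.environ.get('MY_LOCAL_RANK', 0))
```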
3 changes: 3 additions & 0 deletions pytorch_lightning/cluster_environments/slurm_environment.py
@@ -67,6 +67,9 @@ def master_port(self):
     def world_size(self):
         return self._world_size
 
+    def local_rank(self):
+        return int(os.environ['SLURM_LOCALID'])
+
     def _resolve_root_node_address(self, root_node):
         if '[' in root_node:
             name, numbers = root_node.split('[', maxsplit=1)
@@ -46,3 +46,6 @@ def master_port(self):
 
     def world_size(self):
         return os.environ.get('WORLD_SIZE')
+
+    def local_rank(self):
+        return int(os.environ['LOCAL_RANK'])
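Taken together with the SLURM change above, both shipped environments now answer the same question from their scheduler-specific variable. A quick sanity check one could run against this branch, assuming the import paths match the files touched here:

```python
import os
from unittest import mock

from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment

# SLURM: the local rank comes from SLURM_LOCALID
with mock.patch.dict(os.environ, {"SLURM_LOCALID": "3"}):
    assert SLURMEnvironment().local_rank() == 3

# torchelastic: the local rank comes from LOCAL_RANK
with mock.patch.dict(os.environ, {"LOCAL_RANK": "3"}):
    assert TorchElasticEnvironment().local_rank() == 3
```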
22 changes: 17 additions & 5 deletions tests/backends/test_accelerator_connector.py
@@ -104,7 +104,7 @@ def on_fit_start(self, trainer, pl_module):
     "SLURM_NTASKS": "2",
     "SLURM_JOB_NAME": "SOME_NAME",
     "SLURM_NODEID": "0",
-    "SLURM_LOCALID": "0"
+    "SLURM_LOCALID": "10"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
 def test_accelerator_choice_ddp_slurm(tmpdir):
@@ -113,6 +113,8 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
             assert isinstance(trainer.accelerator_backend, accelerators.DDPSLURMAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
             raise SystemExit()
 
     model = BoringModel()
@@ -133,7 +135,7 @@ def on_fit_start(self, trainer, pl_module):
     "SLURM_JOB_NAME": "SOME_NAME",
     "SLURM_NODEID": "0",
     "LOCAL_RANK": "0",
-    "SLURM_LOCALID": "0"
+    "SLURM_LOCALID": "10"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
 def test_accelerator_choice_ddp2_slurm(tmpdir):
@@ -142,6 +144,9 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp2
             assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
+
             raise SystemExit()
 
     model = BoringModel()
@@ -159,7 +164,7 @@ def on_fit_start(self, trainer, pl_module):
 @mock.patch.dict(os.environ, {
     "CUDA_VISIBLE_DEVICES": "0,1",
     "WORLD_SIZE": "2",
-    "LOCAL_RANK": "0",
+    "LOCAL_RANK": "10",
     "NODE_RANK": "0"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
@@ -169,6 +174,8 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
             assert isinstance(trainer.accelerator_backend, accelerators.DDPTorchElasticAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
             raise SystemExit()
 
     model = BoringModel()
@@ -186,7 +193,7 @@ def on_fit_start(self, trainer, pl_module):
 @mock.patch.dict(os.environ, {
     "CUDA_VISIBLE_DEVICES": "0,1",
     "WORLD_SIZE": "2",
-    "LOCAL_RANK": "0",
+    "LOCAL_RANK": "10",
     "NODE_RANK": "0"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
@@ -196,6 +203,8 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp2
             assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
             raise SystemExit()
 
     model = BoringModel()
@@ -212,7 +221,7 @@ def on_fit_start(self, trainer, pl_module):
 
 @mock.patch.dict(os.environ, {
     "WORLD_SIZE": "1",
-    "LOCAL_RANK": "0",
+    "LOCAL_RANK": "10",
     "NODE_RANK": "0"
 })
 @mock.patch('torch.cuda.device_count', return_value=0)
@@ -222,6 +231,9 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
             assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUTorchElasticAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
+
             raise SystemExit()
 
     model = BoringModel()
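For context, the tests above all follow the same recipe: patch the scheduler's environment variables, attach a callback whose on_fit_start inspects the selected backend, and abort with SystemExit before any real distributed work starts. A condensed, paraphrased sketch of that recipe; the test name, Trainer flags, and BoringModel import path are assumptions, not copied from the file:

```python
import os
from unittest import mock

import pytest
from pytorch_lightning import Callback, Trainer
from tests.base.boring_model import BoringModel  # assumed location in this test tree


@mock.patch.dict(os.environ, {"WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
def test_local_rank_reaches_task_idx(tmpdir):
    """Paraphrased pattern: verify the env-provided local rank ends up as task_idx."""

    class CB(Callback):
        def on_fit_start(self, trainer, pl_module):
            backend = trainer.accelerator_backend
            assert backend.task_idx == backend.cluster_environment.local_rank() == 10
            raise SystemExit()  # stop before any real training happens

    trainer = Trainer(fast_dev_run=True, accelerator='ddp_cpu', num_processes=2, callbacks=[CB()])

    with pytest.raises(SystemExit):
        trainer.fit(BoringModel())
```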