From 093535d433b025fad6fdd10c593724ddc56fedfd Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 4 Oct 2020 07:31:19 -0400
Subject: [PATCH] ref: adding compute environments (1/n) (#3837)

* ref: adding compute environments (1/n)

* ref: adding compute environments (1/n)

* ref: adding compute environments (1/n)
---
 docs/source/index.rst                          |  1 +
 .../cluster_environments/__init__.py           |  3 +
 .../cluster_environment.py                     | 13 ++++
 .../cluster_environments/slurm_environment.py  | 66 +++++++++++++++++++
 .../torchelastic_environment.py                | 34 ++++++++++
 5 files changed, 117 insertions(+)
 create mode 100644 pytorch_lightning/cluster_environments/__init__.py
 create mode 100644 pytorch_lightning/cluster_environments/cluster_environment.py
 create mode 100644 pytorch_lightning/cluster_environments/slurm_environment.py
 create mode 100644 pytorch_lightning/cluster_environments/torchelastic_environment.py

diff --git a/docs/source/index.rst b/docs/source/index.rst
index ec85f7e043df3..03cebc03871a8 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -142,3 +142,4 @@ Indices and tables
    api/pytorch_lightning.tuner
    api/pytorch_lightning.plugins
    api/pytorch_lightning.distributed
+   api/pytorch_lightning.cluster_environments
diff --git a/pytorch_lightning/cluster_environments/__init__.py b/pytorch_lightning/cluster_environments/__init__.py
new file mode 100644
index 0000000000000..539902be141da
--- /dev/null
+++ b/pytorch_lightning/cluster_environments/__init__.py
@@ -0,0 +1,3 @@
+from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
+from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
+from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
diff --git a/pytorch_lightning/cluster_environments/cluster_environment.py b/pytorch_lightning/cluster_environments/cluster_environment.py
new file mode 100644
index 0000000000000..653cab7e4361c
--- /dev/null
+++ b/pytorch_lightning/cluster_environments/cluster_environment.py
@@ -0,0 +1,13 @@
+class ClusterEnvironment:
+
+    def __init__(self, world_size):
+        self._world_size = world_size
+
+    def master_address(self):
+        pass
+
+    def master_port(self):
+        pass
+
+    def world_size(self):
+        return self._world_size
diff --git a/pytorch_lightning/cluster_environments/slurm_environment.py b/pytorch_lightning/cluster_environments/slurm_environment.py
new file mode 100644
index 0000000000000..d40884b31e062
--- /dev/null
+++ b/pytorch_lightning/cluster_environments/slurm_environment.py
@@ -0,0 +1,66 @@
+import os
+import re
+from pytorch_lightning import _logger as log
+from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
+
+
+class SLURMEnvironment(ClusterEnvironment):
+
+    def __init__(self, world_size):
+        super().__init__(world_size)
+
+    def master_address(self):
+        # figure out the root node addr
+        try:
+            root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
+        except Exception:
+            root_node = "127.0.0.1"
+
+        root_node = self._resolve_root_node_address(root_node)
+        os.environ["MASTER_ADDR"] = root_node
+        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
+        return root_node
+
+    def master_port(self):
+        # -----------------------
+        # SLURM JOB = PORT number
+        # -----------------------
+        # this way every process knows what port to use
+        try:
+            # use the last 4 numbers in the job id as the id
+            default_port = os.environ["SLURM_JOB_ID"]
+            default_port = default_port[-4:]
+
+            # all ports should be in the 10k+ range
+            default_port = int(default_port) + 15000
+
+        except Exception:
+            default_port = 12910
+
+        # -----------------------
+        # PORT NUMBER = MASTER_PORT
+        # -----------------------
+        # in case the user passed it in
+        try:
+            default_port = os.environ["MASTER_PORT"]
+        except Exception:
+            os.environ["MASTER_PORT"] = str(default_port)
+
+        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
+
+        return default_port
+
+    def world_size(self):
+        return self._world_size
+
+    def _resolve_root_node_address(self, root_node):
+        if '[' in root_node:
+            name, numbers = root_node.split('[', maxsplit=1)
+            number = numbers.split(',', maxsplit=1)[0]
+            if '-' in number:
+                number = number.split('-')[0]
+
+            number = re.sub('[^0-9]', '', number)
+            root_node = name + number
+
+        return root_node
diff --git a/pytorch_lightning/cluster_environments/torchelastic_environment.py b/pytorch_lightning/cluster_environments/torchelastic_environment.py
new file mode 100644
index 0000000000000..e8a9890be3347
--- /dev/null
+++ b/pytorch_lightning/cluster_environments/torchelastic_environment.py
@@ -0,0 +1,34 @@
+import os
+from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
+
+
+class TorchElasticEnvironment(ClusterEnvironment):
+
+    def __init__(self, world_size):
+        super().__init__(world_size)
+
+    def master_address(self):
+        if "MASTER_ADDR" not in os.environ:
+            rank_zero_warn(
+                "MASTER_ADDR environment variable is not defined. Set as localhost"
+            )
+            os.environ["MASTER_ADDR"] = "127.0.0.1"
+        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
+        master_address = os.environ.get('MASTER_ADDR')
+        return master_address
+
+    def master_port(self):
+        if "MASTER_PORT" not in os.environ:
+            rank_zero_warn(
+                "MASTER_PORT environment variable is not defined. Set as 12910"
+            )
+            os.environ["MASTER_PORT"] = "12910"
+        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
+
+        port = os.environ.get('MASTER_PORT')
+        return port
+
+    def world_size(self):
+        return os.environ.get('WORLD_SIZE', None)
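
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller such as an accelerator might pick one of the new cluster environments and query it before initializing distributed training. The select_environment helper, the SLURM_JOB_ID-based selection, and the example world_size of 4 are assumptions for illustration only, not wiring introduced by this change.

    import os

    from pytorch_lightning.cluster_environments import (
        SLURMEnvironment,
        TorchElasticEnvironment,
    )


    def select_environment(world_size):
        # Hypothetical selection helper: prefer SLURM when a SLURM job id is
        # present in the environment, otherwise assume a torchelastic launch.
        if "SLURM_JOB_ID" in os.environ:
            return SLURMEnvironment(world_size)
        return TorchElasticEnvironment(world_size)


    env = select_environment(world_size=4)

    # Both environments resolve (and export) MASTER_ADDR / MASTER_PORT so every
    # process agrees on the same rendezvous endpoint; world_size() reports the
    # value passed in (SLURM) or the WORLD_SIZE variable (torchelastic).
    addr = env.master_address()
    port = env.master_port()
    print(f"rendezvous at {addr}:{port}, world_size={env.world_size()}")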