Skip to content

Commit

Permalink
ref: adding compute environments (1/n) (#3837)
Browse files Browse the repository at this point in the history
* ref: adding compute environments (1/n)

* ref: adding compute environments (1/n)

* ref: adding compute environments (1/n)
  • Loading branch information
williamFalcon committed Oct 4, 2020
1 parent a3503ce commit 093535d
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@ Indices and tables
api/pytorch_lightning.tuner
api/pytorch_lightning.plugins
api/pytorch_lightning.distributed
api/pytorch_lightning.cluster_environments
3 changes: 3 additions & 0 deletions pytorch_lightning/cluster_environments/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
13 changes: 13 additions & 0 deletions pytorch_lightning/cluster_environments/cluster_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class ClusterEnvironment:

def __init__(self, world_size):
self._world_size = world_size

def master_address(self):
pass

def master_port(self):
pass

def world_size(self):
return self._world_size
66 changes: 66 additions & 0 deletions pytorch_lightning/cluster_environments/slurm_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import re
from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class SLURMEnvironment(ClusterEnvironment):

def __init__(self, world_size):
super().__init__(world_size)

def master_address(self):
# figure out the root node addr
try:
root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
except Exception:
root_node = "127.0.0.1"

root_node = self._resolve_root_node_address(root_node)
os.environ["MASTER_ADDR"] = root_node
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
return root_node

def master_port(self):
# -----------------------
# SLURM JOB = PORT number
# -----------------------
# this way every process knows what port to use
try:
# use the last 4 numbers in the job id as the id
default_port = os.environ["SLURM_JOB_ID"]
default_port = default_port[-4:]

# all ports should be in the 10k+ range
default_port = int(default_port) + 15000

except Exception:
default_port = 12910

# -----------------------
# PORT NUMBER = MASTER_PORT
# -----------------------
# in case the user passed it in
try:
default_port = os.environ["MASTER_PORT"]
except Exception:
os.environ["MASTER_PORT"] = str(default_port)

log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

return default_port

def world_size(self):
return self._world_size

def _resolve_root_node_address(self, root_node):
if '[' in root_node:
name, numbers = root_node.split('[', maxsplit=1)
number = numbers.split(',', maxsplit=1)[0]
if '-' in number:
number = number.split('-')[0]

number = re.sub('[^0-9]', '', number)
root_node = name + number

return root_node
34 changes: 34 additions & 0 deletions pytorch_lightning/cluster_environments/torchelastic_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class TorchElasticEnvironment(ClusterEnvironment):

def __init__(self, world_size):
super().__init__(world_size)

def master_address(self):
if "MASTER_ADDR" not in os.environ:
rank_zero_warn(
"MASTER_ADDR environment variable is not defined. Set as localhost"
)
os.environ["MASTER_ADDR"] = "127.0.0.1"
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
master_address = os.environ.get('MASTER_ADDR')
return master_address

def master_port(self):
if "MASTER_PORT" not in os.environ:
rank_zero_warn(
"MASTER_PORT environment variable is not defined. Set as 12910"
)
os.environ["MASTER_PORT"] = "12910"
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

port = os.environ.get('MASTER_PORT')
return port

def world_size(self):
return os.environ.get('WORLD_SIZE', None)

0 comments on commit 093535d

Please sign in to comment.