
ref: adding compute environments (1/n) #3837

Merged (3 commits, Oct 4, 2020)
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -142,3 +142,4 @@ Indices and tables
api/pytorch_lightning.tuner
api/pytorch_lightning.plugins
api/pytorch_lightning.distributed
api/pytorch_lightning.cluster_environments
3 changes: 3 additions & 0 deletions pytorch_lightning/cluster_environments/__init__.py
@@ -0,0 +1,3 @@
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
Review comment (Member): Can we rename this package to just environments? TorchElastic is not a true cluster, and there may be other non-cluster environments.

from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
13 changes: 13 additions & 0 deletions pytorch_lightning/cluster_environments/cluster_environment.py
@@ -0,0 +1,13 @@
class ClusterEnvironment:
Review comment (Member): Regarding the file in general: the license header is missing.


    def __init__(self, world_size):
        self._world_size = world_size

    def master_address(self):
        # Stub: subclasses return the hostname or IP of the rank-zero node.
        pass

    def master_port(self):
        # Stub: subclasses return the port used for the process-group rendezvous.
        pass

    def world_size(self):
        return self._world_size
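
As an aside, here is a minimal sketch of how a concrete subclass could fill in these stubs; LocalEnvironment is a hypothetical example for illustration, not part of this PR:

from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class LocalEnvironment(ClusterEnvironment):
    # Hypothetical single-machine environment, shown only to illustrate the interface.

    def master_address(self):
        # A single machine coordinates through localhost.
        return "127.0.0.1"

    def master_port(self):
        # Any free port would do; a fixed default keeps the sketch simple.
        return 12910


env = LocalEnvironment(world_size=1)
print(env.master_address(), env.master_port(), env.world_size())  # 127.0.0.1 12910 1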
66 changes: 66 additions & 0 deletions pytorch_lightning/cluster_environments/slurm_environment.py
@@ -0,0 +1,66 @@
import os
import re
from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class SLURMEnvironment(ClusterEnvironment):

def __init__(self, world_size):
super().__init__(world_size)

def master_address(self):
Review comment (Member): Should this rather be a property?

# figure out the root node addr
try:
root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
except Exception:
root_node = "127.0.0.1"

root_node = self._resolve_root_node_address(root_node)
os.environ["MASTER_ADDR"] = root_node
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
return root_node

def master_port(self):
# -----------------------
# SLURM JOB = PORT number
# -----------------------
# this way every process knows what port to use
try:
# use the last 4 numbers in the job id as the id
default_port = os.environ["SLURM_JOB_ID"]
default_port = default_port[-4:]

# all ports should be in the 10k+ range
default_port = int(default_port) + 15000

except Exception:
default_port = 12910

# -----------------------
# PORT NUMBER = MASTER_PORT
# -----------------------
# in case the user passed it in
try:
default_port = os.environ["MASTER_PORT"]
except Exception:
os.environ["MASTER_PORT"] = str(default_port)

log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

return default_port

def world_size(self):
return self._world_size

def _resolve_root_node_address(self, root_node):
if '[' in root_node:
name, numbers = root_node.split('[', maxsplit=1)
number = numbers.split(',', maxsplit=1)[0]
if '-' in number:
number = number.split('-')[0]

number = re.sub('[^0-9]', '', number)
root_node = name + number

return root_node
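
To make the node-list parsing concrete, here is a short walkthrough of _resolve_root_node_address on typical SLURM node specifications; the hostnames and job id below are illustrative, and the import assumes the module layout added in this PR:

from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment

env = SLURMEnvironment(world_size=4)

# Plain hostnames pass through unchanged.
assert env._resolve_root_node_address("node17") == "node17"

# A bracketed range resolves to the first node in the range.
assert env._resolve_root_node_address("node[2-8]") == "node2"

# Comma-separated lists resolve to the first entry.
assert env._resolve_root_node_address("node[5,9]") == "node5"

# Port derivation works similarly: with SLURM_JOB_ID "68594", master_port()
# takes the last four digits ("8594") and offsets them into the 10k+ range:
# 8594 + 15000 = 23594 (unless MASTER_PORT is already set in the environment).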
34 changes: 34 additions & 0 deletions pytorch_lightning/cluster_environments/torchelastic_environment.py
@@ -0,0 +1,34 @@
import os
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class TorchElasticEnvironment(ClusterEnvironment):

def __init__(self, world_size):
super().__init__(world_size)

def master_address(self):
if "MASTER_ADDR" not in os.environ:
rank_zero_warn(
"MASTER_ADDR environment variable is not defined. Set as localhost"
)
os.environ["MASTER_ADDR"] = "127.0.0.1"
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
master_address = os.environ.get('MASTER_ADDR')
return master_address

def master_port(self):
if "MASTER_PORT" not in os.environ:
rank_zero_warn(
"MASTER_PORT environment variable is not defined. Set as 12910"
)
os.environ["MASTER_PORT"] = "12910"
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

port = os.environ.get('MASTER_PORT')
return port

def world_size(self):
Review comment (Member): This probably also does not change during training, so perhaps a property?

return os.environ.get('WORLD_SIZE', None)
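
Finally, a hedged sketch of how calling code might choose between the two environments based on which scheduler variables are present; the detect_environment helper below is hypothetical and not part of this PR:

import os

from pytorch_lightning.cluster_environments import SLURMEnvironment, TorchElasticEnvironment


def detect_environment(world_size):
    # Prefer SLURM when its job variables are set; otherwise assume TorchElastic
    # (or a plain setup that relies on the MASTER_ADDR/MASTER_PORT defaults).
    if "SLURM_JOB_ID" in os.environ:
        return SLURMEnvironment(world_size)
    return TorchElasticEnvironment(world_size)


env = detect_environment(world_size=int(os.environ.get("WORLD_SIZE", 1)))
print(f"rendezvous at {env.master_address()}:{env.master_port()}")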