Clean up environment access in plugins (#6941)
awaelchli committed Apr 13, 2021
1 parent 4c2005d commit ac71300
Showing 30 changed files with 715 additions and 210 deletions.
21 changes: 3 additions & 18 deletions CHANGELOG.md
@@ -33,24 +33,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))


- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673))


- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))


- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


@@ -240,6 +222,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))


- Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))


## [1.2.7] - 2021-04-06

### Fixed
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/environments/__init__.py
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401
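
With `LightningEnvironment` re-exported here, all four environment plugins can be imported from a single namespace. A quick check, assuming an installation that includes this commit:

```python
from pytorch_lightning.plugins.environments import (
    ClusterEnvironment,
    LightningEnvironment,
    SLURMEnvironment,
    TorchElasticEnvironment,
)

# every concrete environment implements the ClusterEnvironment interface
assert issubclass(LightningEnvironment, ClusterEnvironment)
assert issubclass(SLURMEnvironment, ClusterEnvironment)
assert issubclass(TorchElasticEnvironment, ClusterEnvironment)
```
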
40 changes: 30 additions & 10 deletions pytorch_lightning/plugins/environments/cluster_environment.py
@@ -11,24 +11,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod


class ClusterEnvironment:
class ClusterEnvironment(ABC):
""" Specification of a cluster environment. """

def __init__(self):
self._world_size = None
@abstractmethod
def creates_children(self) -> bool:
""" Whether the environment creates the subprocesses or not. """

def master_address(self):
pass
@abstractmethod
def master_address(self) -> str:
""" The master address through which all processes connect and communicate. """

def master_port(self):
pass
@abstractmethod
def master_port(self) -> int:
""" An open and configured port in the master node through which all processes communicate. """

@abstractmethod
def world_size(self) -> int:
return self._world_size
""" The number of processes across all devices and nodes. """

def local_rank(self) -> int:
@abstractmethod
def set_world_size(self, size: int) -> None:
pass

def node_rank(self) -> int:
@abstractmethod
def global_rank(self) -> int:
""" The rank (index) of the currently running process across all nodes and devices. """

@abstractmethod
def set_global_rank(self, rank: int) -> None:
pass

@abstractmethod
def local_rank(self) -> int:
""" The rank (index) of the currently running process inside of the current node. """

@abstractmethod
def node_rank(self) -> int:
""" The rank (index) of the node on which the current process runs. """
83 changes: 83 additions & 0 deletions pytorch_lightning/plugins/environments/lightning_environment.py
@@ -0,0 +1,83 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_only


class LightningEnvironment(ClusterEnvironment):
"""
The default environment used by Lightning for a single node or free cluster (not managed).
The master process must be launched by the user and Lightning will spawn new
worker processes for distributed training, either in a single node or across multiple nodes.
If the master address and port are not provided, the default environment will choose them
automatically. It is recommended to use this default environment for single-node distributed
training as it provides the most convenient way to launch the training script.
"""

def __init__(self):
super().__init__()
self._master_port = None
self._global_rank: int = 0
self._world_size: int = 1

def creates_children(self) -> bool:
return False

def master_address(self) -> str:
return os.environ.get("MASTER_ADDR", "127.0.0.1")

def master_port(self) -> int:
if self._master_port is None:
self._master_port = os.environ.get("MASTER_PORT", find_free_network_port())
return int(self._master_port)

def world_size(self) -> int:
return self._world_size

def set_world_size(self, size: int) -> None:
self._world_size = size

def global_rank(self) -> int:
return self._global_rank

def set_global_rank(self, rank: int) -> None:
self._global_rank = rank
rank_zero_only.rank = rank

def local_rank(self) -> int:
return int(os.environ.get("LOCAL_RANK", 0))

def node_rank(self) -> int:
group_rank = os.environ.get("GROUP_RANK", 0)
return int(os.environ.get("NODE_RANK", group_rank))


def find_free_network_port() -> int:
"""
Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real master node but
have to set the `MASTER_PORT` environment variable.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
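
A short usage sketch of the new default environment, based on the code above; the printed port will vary because a free one is picked whenever `MASTER_PORT` is not set:

```python
from pytorch_lightning.plugins.environments import LightningEnvironment

env = LightningEnvironment()

# defaults before any distributed setup has happened
assert env.creates_children() is False
assert env.global_rank() == 0 and env.world_size() == 1

# Lightning can now write rank / world-size information back into the environment
env.set_global_rank(2)   # note: this also updates rank_zero_only.rank
env.set_world_size(4)
assert env.global_rank() == 2 and env.world_size() == 4

# with MASTER_ADDR / MASTER_PORT unset, sensible defaults are chosen
print(env.master_address(), env.master_port())
```
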
29 changes: 19 additions & 10 deletions pytorch_lightning/plugins/environments/slurm_environment.py
@@ -23,10 +23,10 @@

class SLURMEnvironment(ClusterEnvironment):

def __init__(self):
super().__init__()
def creates_children(self) -> bool:
return True

def master_address(self):
def master_address(self) -> str:
# figure out the root node addr
slurm_nodelist = os.environ.get("SLURM_NODELIST")
if slurm_nodelist:
@@ -39,7 +39,7 @@ def master_address(self):
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
return root_node

def master_port(self):
def master_port(self) -> int:
# -----------------------
# SLURM JOB = PORT number
# -----------------------
@@ -64,18 +64,27 @@ def master_port(self):

log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

return default_port
return int(default_port)

def world_size(self):
return self._world_size
def world_size(self) -> int:
return int(os.environ["SLURM_NTASKS"])

def local_rank(self):
def set_world_size(self, size: int) -> None:
log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
return int(os.environ["SLURM_PROCID"])

def set_global_rank(self, rank: int) -> None:
log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

def local_rank(self) -> int:
return int(os.environ['SLURM_LOCALID'])

def node_rank(self):
def node_rank(self) -> int:
return int(os.environ['SLURM_NODEID'])

def resolve_root_node_address(self, root_node):
def resolve_root_node_address(self, root_node: str) -> str:
if '[' in root_node:
name, numbers = root_node.split('[', maxsplit=1)
number = numbers.split(',', maxsplit=1)[0]
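
The SLURM plugin now reads everything directly from SLURM's environment variables instead of caching values set by the Trainer. A rough sketch of the resulting behaviour, with variables set by hand to stand in for a real SLURM job:

```python
import os

from pytorch_lightning.plugins.environments import SLURMEnvironment

# simulate a few of the variables a real SLURM job would export
os.environ.update({
    "SLURM_NTASKS": "8",
    "SLURM_PROCID": "3",
    "SLURM_LOCALID": "1",
    "SLURM_NODEID": "0",
})

env = SLURMEnvironment()
assert env.world_size() == 8   # read from SLURM_NTASKS
assert env.global_rank() == 3  # read from SLURM_PROCID
assert env.local_rank() == 1
assert env.node_rank() == 0

# the setters are deliberately no-ops, since SLURM owns this information
env.set_world_size(1)
assert env.world_size() == 8
```
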
35 changes: 27 additions & 8 deletions pytorch_lightning/plugins/environments/torchelastic_environment.py
@@ -14,6 +14,7 @@

import logging
import os
from typing import Optional

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_warn
@@ -23,30 +24,48 @@

class TorchElasticEnvironment(ClusterEnvironment):

def __init__(self):
super().__init__()
@staticmethod
def is_using_torchelastic() -> bool:
""" Returns ``True`` if the current process was launched using the torchelastic command. """
required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
return all(v in os.environ for v in required_env_vars)

def master_address(self):
def creates_children(self) -> bool:
return True

def master_address(self) -> str:
if "MASTER_ADDR" not in os.environ:
rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
os.environ["MASTER_ADDR"] = "127.0.0.1"
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
master_address = os.environ.get('MASTER_ADDR')
return master_address

def master_port(self):
def master_port(self) -> int:
if "MASTER_PORT" not in os.environ:
rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
os.environ["MASTER_PORT"] = "12910"
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

port = os.environ.get('MASTER_PORT')
port = int(os.environ.get('MASTER_PORT'))
return port

def world_size(self):
return os.environ.get('WORLD_SIZE')
def world_size(self) -> Optional[int]:
world_size = os.environ.get('WORLD_SIZE')
return int(world_size) if world_size is not None else world_size

def set_world_size(self, size: int) -> None:
log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
return int(os.environ["RANK"])

def set_global_rank(self, rank: int) -> None:
log.debug(
"TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
)

def local_rank(self):
def local_rank(self) -> int:
return int(os.environ['LOCAL_RANK'])

def node_rank(self) -> int:
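
Similarly, the torchelastic plugin is now driven entirely by the variables exported by the torchelastic launcher, and the new `is_using_torchelastic()` helper makes the detection explicit. A brief sketch; outside a torchelastic-launched process it simply takes the else branch:

```python
from pytorch_lightning.plugins.environments import TorchElasticEnvironment

# the static check looks for RANK, GROUP_RANK, LOCAL_RANK and LOCAL_WORLD_SIZE
if TorchElasticEnvironment.is_using_torchelastic():
    env = TorchElasticEnvironment()
    print(env.global_rank(), env.local_rank(), env.world_size())
else:
    print("not launched via torchelastic")
```
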
