Fix typing in pl.plugins.environments (#10943)
awaelchli committed Dec 7, 2021
1 parent 6bfc0bb commit 46f718d
Showing 6 changed files with 33 additions and 42 deletions.
5 changes: 0 additions & 5 deletions pyproject.toml
@@ -47,7 +47,6 @@ module = [
     "pytorch_lightning.callbacks.finetuning",
     "pytorch_lightning.callbacks.lr_monitor",
     "pytorch_lightning.callbacks.model_checkpoint",
-    "pytorch_lightning.callbacks.prediction_writer",
     "pytorch_lightning.callbacks.progress.base",
     "pytorch_lightning.callbacks.progress.progress",
     "pytorch_lightning.callbacks.progress.rich_progress",
@@ -70,10 +69,6 @@ module = [
     "pytorch_lightning.loggers.test_tube",
     "pytorch_lightning.loggers.wandb",
     "pytorch_lightning.loops.epoch.training_epoch_loop",
-    "pytorch_lightning.plugins.environments.lightning_environment",
-    "pytorch_lightning.plugins.environments.lsf_environment",
-    "pytorch_lightning.plugins.environments.slurm_environment",
-    "pytorch_lightning.plugins.environments.torchelastic_environment",
     "pytorch_lightning.plugins.training_type.ddp",
     "pytorch_lightning.plugins.training_type.ddp2",
     "pytorch_lightning.plugins.training_type.ddp_spawn",
10 changes: 5 additions & 5 deletions pytorch_lightning/plugins/environments/lightning_environment.py
@@ -34,9 +34,9 @@ class LightningEnvironment(ClusterEnvironment):
     training as it provides a convenient way to launch the training script.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
-        self._main_port = None
+        self._main_port: int = -1
         self._global_rank: int = 0
         self._world_size: int = 1
 
@@ -55,9 +55,9 @@ def main_address(self) -> str:
 
     @property
     def main_port(self) -> int:
-        if self._main_port is None:
-            self._main_port = os.environ.get("MASTER_PORT", find_free_network_port())
-        return int(self._main_port)
+        if self._main_port == -1:
+            self._main_port = int(os.environ.get("MASTER_PORT", find_free_network_port()))
+        return self._main_port
 
     @staticmethod
     def detect() -> bool:
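The change above replaces the None placeholder with a typed -1 sentinel, so _main_port stays a plain int and the property no longer needs a cast on every access. A minimal, self-contained sketch of that pattern follows; the class name and the local _find_free_port helper (standing in for Lightning's find_free_network_port) are illustrative only.

import os
import socket


def _find_free_port() -> int:
    # Stand-in for Lightning's find_free_network_port: bind to port 0 so the
    # OS picks an available port, then release the socket.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


class PortHolder:
    def __init__(self) -> None:
        # -1 means "not resolved yet", keeping the attribute an int rather than Optional[int]
        self._main_port: int = -1

    @property
    def main_port(self) -> int:
        if self._main_port == -1:
            self._main_port = int(os.environ.get("MASTER_PORT", _find_free_port()))
        return self._main_port


holder = PortHolder()
print(holder.main_port)  # resolved once, cached for later accesses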
33 changes: 16 additions & 17 deletions pytorch_lightning/plugins/environments/lsf_environment.py
@@ -14,6 +14,7 @@
 
 import os
 import socket
+from typing import Dict, List
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.plugins.environments import ClusterEnvironment
@@ -41,7 +42,7 @@ class LSFEnvironment(ClusterEnvironment):
         The world size for the task. This environment variable is set by jsrun
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         # TODO: remove in 1.7
         if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf):
@@ -74,7 +75,7 @@ def detect() -> bool:
         required_env_vars = {"LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
         return required_env_vars.issubset(os.environ.keys())
 
-    def world_size(self):
+    def world_size(self) -> int:
         """The world size is read from the environment variable `JSM_NAMESPACE_SIZE`."""
         var = "JSM_NAMESPACE_SIZE"
         world_size = os.environ.get(var)
@@ -88,7 +89,7 @@ def world_size(self):
     def set_world_size(self, size: int) -> None:
         log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
 
-    def global_rank(self):
+    def global_rank(self) -> int:
         """The world size is read from the environment variable `JSM_NAMESPACE_RANK`."""
         var = "JSM_NAMESPACE_RANK"
         global_rank = os.environ.get(var)
@@ -102,7 +103,7 @@ def global_rank(self):
     def set_global_rank(self, rank: int) -> None:
         log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
 
-    def local_rank(self):
+    def local_rank(self) -> int:
         """The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
         var = "JSM_NAMESPACE_LOCAL_RANK"
         local_rank = os.environ.get(var)
@@ -113,11 +114,11 @@ def local_rank(self):
             )
         return int(local_rank)
 
-    def node_rank(self):
+    def node_rank(self) -> int:
         """The node rank is determined by the position of the current hostname in the list of hosts stored in the
         environment variable `LSB_HOSTS`."""
         hosts = self._read_hosts()
-        count = {}
+        count: Dict[str, int] = {}
         for host in hosts:
             if "batch" in host or "login" in host:
                 continue
@@ -126,7 +127,7 @@ def node_rank(self):
         return count[socket.gethostname()]
 
     @staticmethod
-    def _read_hosts():
+    def _read_hosts() -> List[str]:
         hosts = os.environ.get("LSB_HOSTS")
         if not hosts:
             raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
@@ -148,15 +149,13 @@ def _get_main_port() -> int:
         Uses the LSF job ID so all ranks can compute the main port.
         """
         # check for user-specified main port
-        port = os.environ.get("MASTER_PORT")
-        if not port:
-            jobid = os.environ.get("LSB_JOBID")
-            if not jobid:
-                raise ValueError("Could not find job id in environment variable LSB_JOBID")
-            port = int(jobid)
+        if "MASTER_PORT" in os.environ:
+            log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
+            return int(os.environ["MASTER_PORT"])
+        if "LSB_JOBID" in os.environ:
+            port = int(os.environ["LSB_JOBID"])
             # all ports should be in the 10k+ range
-            port = int(port) % 1000 + 10000
+            port = port % 1000 + 10000
             log.debug(f"calculated LSF main port: {port}")
-        else:
-            log.debug(f"using externally specified main port: {port}")
-        return int(port)
+            return port
+        raise ValueError("Could not find job id in environment variable LSB_JOBID")
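For reference, a hedged sketch of the rewritten control flow: an explicit MASTER_PORT wins, otherwise the port is derived from LSB_JOBID so that every rank computes the same value without communicating. The function name and the sample job id are illustrative, not part of the commit.

def resolve_lsf_main_port(env: dict) -> int:
    # env stands in for os.environ
    if "MASTER_PORT" in env:
        return int(env["MASTER_PORT"])
    if "LSB_JOBID" in env:
        # fold the job id into the 10000-10999 range
        return int(env["LSB_JOBID"]) % 1000 + 10000
    raise ValueError("Could not find job id in environment variable LSB_JOBID")


print(resolve_lsf_main_port({"LSB_JOBID": "275182"}))   # -> 10182
print(resolve_lsf_main_port({"MASTER_PORT": "29500"}))  # -> 29500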
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/environments/slurm_environment.py
@@ -58,10 +58,10 @@ def main_port(self) -> int:
         # SLURM JOB = PORT number
         # -----------------------
         # this way every process knows what port to use
-        default_port = os.environ.get("SLURM_JOB_ID")
-        if default_port:
+        job_id = os.environ.get("SLURM_JOB_ID")
+        if job_id is not None:
             # use the last 4 numbers in the job id as the id
-            default_port = default_port[-4:]
+            default_port = job_id[-4:]
             # all ports should be in the 10k+ range
             default_port = int(default_port) + 15000
         else:
@@ -72,13 +72,12 @@ def main_port(self) -> int:
         # -----------------------
         # in case the user passed it in
         if "MASTER_PORT" in os.environ:
-            default_port = os.environ["MASTER_PORT"]
+            default_port = int(os.environ["MASTER_PORT"])
         else:
             os.environ["MASTER_PORT"] = str(default_port)
 
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
-        return int(default_port)
+        return default_port
 
     @staticmethod
     def detect() -> bool:
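As in the LSF plugin, the port is now handled as an int end to end. A small sketch of the derivation, under the same hedging (helper name and sample job id are illustrative): the last four digits of SLURM_JOB_ID are shifted into the 15000+ range, and a user-supplied MASTER_PORT overrides the derived value.

def resolve_slurm_main_port(env: dict) -> int:
    # env stands in for os.environ
    if "MASTER_PORT" in env:
        return int(env["MASTER_PORT"])
    job_id = env.get("SLURM_JOB_ID")
    if job_id is not None:
        # use the last 4 digits of the job id, shifted into the 15000+ range
        return int(job_id[-4:]) + 15000
    return 12910  # fallback default


print(resolve_slurm_main_port({"SLURM_JOB_ID": "1234567"}))  # -> 19567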
12 changes: 4 additions & 8 deletions pytorch_lightning/plugins/environments/torchelastic_environment.py
@@ -14,7 +14,6 @@
 
 import logging
 import os
-from typing import Optional
 
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
@@ -45,8 +44,7 @@ def main_address(self) -> str:
             rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
             os.environ["MASTER_ADDR"] = "127.0.0.1"
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
-        main_address = os.environ.get("MASTER_ADDR")
-        return main_address
+        return os.environ["MASTER_ADDR"]
 
     @property
     def main_port(self) -> int:
@@ -55,18 +53,16 @@ def main_port(self) -> int:
             os.environ["MASTER_PORT"] = "12910"
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
-        port = int(os.environ.get("MASTER_PORT"))
-        return port
+        return int(os.environ["MASTER_PORT"])
 
     @staticmethod
     def detect() -> bool:
         """Returns ``True`` if the current process was launched using the torchelastic command."""
         required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
         return required_env_vars.issubset(os.environ.keys())
 
-    def world_size(self) -> Optional[int]:
-        world_size = os.environ.get("WORLD_SIZE")
-        return int(world_size) if world_size is not None else world_size
+    def world_size(self) -> int:
+        return int(os.environ["WORLD_SIZE"])
 
     def set_world_size(self, size: int) -> None:
         log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
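The Optional[int] return type is gone: torchelastic always exports WORLD_SIZE, so a missing variable now surfaces as a KeyError instead of a silent None. A short stdlib-only illustration of the stricter behaviour (the function name is illustrative):

import os


def world_size_strict() -> int:
    # new behaviour: fail loudly if the launcher did not export WORLD_SIZE
    return int(os.environ["WORLD_SIZE"])


os.environ["WORLD_SIZE"] = "8"
assert world_size_strict() == 8

del os.environ["WORLD_SIZE"]
try:
    world_size_strict()
except KeyError:
    print("WORLD_SIZE is not set, so this process was not launched by torchelastic")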
4 changes: 3 additions & 1 deletion tests/plugins/environments/test_torchelastic_environment.py
@@ -27,7 +27,9 @@ def test_default_attributes():
     assert env.creates_processes_externally
     assert env.main_address == "127.0.0.1"
     assert env.main_port == 12910
-    assert env.world_size() is None
+    with pytest.raises(KeyError):
+        # world size is required to be passed as env variable
+        env.world_size()
     with pytest.raises(KeyError):
         # local rank is required to be passed as env variable
         env.local_rank()
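A hedged companion to the updated test (the test name and the use of the monkeypatch fixture are illustrative, not part of this commit): when WORLD_SIZE is provided, the stricter world_size() can be asserted against a concrete integer.

from pytorch_lightning.plugins.environments import TorchElasticEnvironment


def test_world_size_from_env(monkeypatch):
    # complements test_default_attributes: with WORLD_SIZE set, an int comes back
    monkeypatch.setenv("WORLD_SIZE", "4")
    env = TorchElasticEnvironment()
    assert env.world_size() == 4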
