From 46f718d2ba31fdbb8f2abbef03471fec66204b66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Dec 2021 03:14:02 +0100 Subject: [PATCH] Fix typing in `pl.plugins.environments` (#10943) --- pyproject.toml | 5 --- .../environments/lightning_environment.py | 10 +++--- .../plugins/environments/lsf_environment.py | 33 +++++++++---------- .../plugins/environments/slurm_environment.py | 11 +++---- .../environments/torchelastic_environment.py | 12 +++---- .../test_torchelastic_environment.py | 4 ++- 6 files changed, 33 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 14d4cf93704ff..5adc3b444e5c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,6 @@ module = [ "pytorch_lightning.callbacks.finetuning", "pytorch_lightning.callbacks.lr_monitor", "pytorch_lightning.callbacks.model_checkpoint", - "pytorch_lightning.callbacks.prediction_writer", "pytorch_lightning.callbacks.progress.base", "pytorch_lightning.callbacks.progress.progress", "pytorch_lightning.callbacks.progress.rich_progress", @@ -70,10 +69,6 @@ module = [ "pytorch_lightning.loggers.test_tube", "pytorch_lightning.loggers.wandb", "pytorch_lightning.loops.epoch.training_epoch_loop", - "pytorch_lightning.plugins.environments.lightning_environment", - "pytorch_lightning.plugins.environments.lsf_environment", - "pytorch_lightning.plugins.environments.slurm_environment", - "pytorch_lightning.plugins.environments.torchelastic_environment", "pytorch_lightning.plugins.training_type.ddp", "pytorch_lightning.plugins.training_type.ddp2", "pytorch_lightning.plugins.training_type.ddp_spawn", diff --git a/pytorch_lightning/plugins/environments/lightning_environment.py b/pytorch_lightning/plugins/environments/lightning_environment.py index ac66fa6214b95..44ec210b560a7 100644 --- a/pytorch_lightning/plugins/environments/lightning_environment.py +++ b/pytorch_lightning/plugins/environments/lightning_environment.py @@ -34,9 +34,9 @@ class LightningEnvironment(ClusterEnvironment): training as it provides a convenient way to launch the training script. """ - def __init__(self): + def __init__(self) -> None: super().__init__() - self._main_port = None + self._main_port: int = -1 self._global_rank: int = 0 self._world_size: int = 1 @@ -55,9 +55,9 @@ def main_address(self) -> str: @property def main_port(self) -> int: - if self._main_port is None: - self._main_port = os.environ.get("MASTER_PORT", find_free_network_port()) - return int(self._main_port) + if self._main_port == -1: + self._main_port = int(os.environ.get("MASTER_PORT", find_free_network_port())) + return self._main_port @staticmethod def detect() -> bool: diff --git a/pytorch_lightning/plugins/environments/lsf_environment.py b/pytorch_lightning/plugins/environments/lsf_environment.py index c25d068ae01bb..653fb1f2f4a6e 100644 --- a/pytorch_lightning/plugins/environments/lsf_environment.py +++ b/pytorch_lightning/plugins/environments/lsf_environment.py @@ -14,6 +14,7 @@ import os import socket +from typing import Dict, List from pytorch_lightning import _logger as log from pytorch_lightning.plugins.environments import ClusterEnvironment @@ -41,7 +42,7 @@ class LSFEnvironment(ClusterEnvironment): The world size for the task. This environment variable is set by jsrun """ - def __init__(self): + def __init__(self) -> None: super().__init__() # TODO: remove in 1.7 if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf): @@ -74,7 +75,7 @@ def detect() -> bool: required_env_vars = {"LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"} return required_env_vars.issubset(os.environ.keys()) - def world_size(self): + def world_size(self) -> int: """The world size is read from the environment variable `JSM_NAMESPACE_SIZE`.""" var = "JSM_NAMESPACE_SIZE" world_size = os.environ.get(var) @@ -88,7 +89,7 @@ def world_size(self): def set_world_size(self, size: int) -> None: log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") - def global_rank(self): + def global_rank(self) -> int: """The world size is read from the environment variable `JSM_NAMESPACE_RANK`.""" var = "JSM_NAMESPACE_RANK" global_rank = os.environ.get(var) @@ -102,7 +103,7 @@ def global_rank(self): def set_global_rank(self, rank: int) -> None: log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") - def local_rank(self): + def local_rank(self) -> int: """The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`.""" var = "JSM_NAMESPACE_LOCAL_RANK" local_rank = os.environ.get(var) @@ -113,11 +114,11 @@ def local_rank(self): ) return int(local_rank) - def node_rank(self): + def node_rank(self) -> int: """The node rank is determined by the position of the current hostname in the list of hosts stored in the environment variable `LSB_HOSTS`.""" hosts = self._read_hosts() - count = {} + count: Dict[str, int] = {} for host in hosts: if "batch" in host or "login" in host: continue @@ -126,7 +127,7 @@ def node_rank(self): return count[socket.gethostname()] @staticmethod - def _read_hosts(): + def _read_hosts() -> List[str]: hosts = os.environ.get("LSB_HOSTS") if not hosts: raise ValueError("Could not find hosts in environment variable LSB_HOSTS") @@ -148,15 +149,13 @@ def _get_main_port() -> int: Uses the LSF job ID so all ranks can compute the main port. """ # check for user-specified main port - port = os.environ.get("MASTER_PORT") - if not port: - jobid = os.environ.get("LSB_JOBID") - if not jobid: - raise ValueError("Could not find job id in environment variable LSB_JOBID") - port = int(jobid) + if "MASTER_PORT" in os.environ: + log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}") + return int(os.environ["MASTER_PORT"]) + if "LSB_JOBID" in os.environ: + port = int(os.environ["LSB_JOBID"]) # all ports should be in the 10k+ range - port = int(port) % 1000 + 10000 + port = port % 1000 + 10000 log.debug(f"calculated LSF main port: {port}") - else: - log.debug(f"using externally specified main port: {port}") - return int(port) + return port + raise ValueError("Could not find job id in environment variable LSB_JOBID") diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py index bde236c672837..c17d2d765464e 100644 --- a/pytorch_lightning/plugins/environments/slurm_environment.py +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -58,10 +58,10 @@ def main_port(self) -> int: # SLURM JOB = PORT number # ----------------------- # this way every process knows what port to use - default_port = os.environ.get("SLURM_JOB_ID") - if default_port: + job_id = os.environ.get("SLURM_JOB_ID") + if job_id is not None: # use the last 4 numbers in the job id as the id - default_port = default_port[-4:] + default_port = job_id[-4:] # all ports should be in the 10k+ range default_port = int(default_port) + 15000 else: @@ -72,13 +72,12 @@ def main_port(self) -> int: # ----------------------- # in case the user passed it in if "MASTER_PORT" in os.environ: - default_port = os.environ["MASTER_PORT"] + default_port = int(os.environ["MASTER_PORT"]) else: os.environ["MASTER_PORT"] = str(default_port) log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - - return int(default_port) + return default_port @staticmethod def detect() -> bool: diff --git a/pytorch_lightning/plugins/environments/torchelastic_environment.py b/pytorch_lightning/plugins/environments/torchelastic_environment.py index 3631f32daa8d4..a5eed7750989f 100644 --- a/pytorch_lightning/plugins/environments/torchelastic_environment.py +++ b/pytorch_lightning/plugins/environments/torchelastic_environment.py @@ -14,7 +14,6 @@ import logging import os -from typing import Optional from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn @@ -45,8 +44,7 @@ def main_address(self) -> str: rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost") os.environ["MASTER_ADDR"] = "127.0.0.1" log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - main_address = os.environ.get("MASTER_ADDR") - return main_address + return os.environ["MASTER_ADDR"] @property def main_port(self) -> int: @@ -55,8 +53,7 @@ def main_port(self) -> int: os.environ["MASTER_PORT"] = "12910" log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - port = int(os.environ.get("MASTER_PORT")) - return port + return int(os.environ["MASTER_PORT"]) @staticmethod def detect() -> bool: @@ -64,9 +61,8 @@ def detect() -> bool: required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"} return required_env_vars.issubset(os.environ.keys()) - def world_size(self) -> Optional[int]: - world_size = os.environ.get("WORLD_SIZE") - return int(world_size) if world_size is not None else world_size + def world_size(self) -> int: + return int(os.environ["WORLD_SIZE"]) def set_world_size(self, size: int) -> None: log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") diff --git a/tests/plugins/environments/test_torchelastic_environment.py b/tests/plugins/environments/test_torchelastic_environment.py index 9f66e29e3b4c2..beeaab736f6c2 100644 --- a/tests/plugins/environments/test_torchelastic_environment.py +++ b/tests/plugins/environments/test_torchelastic_environment.py @@ -27,7 +27,9 @@ def test_default_attributes(): assert env.creates_processes_externally assert env.main_address == "127.0.0.1" assert env.main_port == 12910 - assert env.world_size() is None + with pytest.raises(KeyError): + # world size is required to be passed as env variable + env.world_size() with pytest.raises(KeyError): # local rank is required to be passed as env variable env.local_rank()