Skip to content

Commit

Permalink
refactor reading env defaults (#6510)
Browse files Browse the repository at this point in the history
* change tests

* fix

* test

* _defaults_from_env_vars

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
(cherry picked from commit 0f07eaf)
  • Loading branch information
Borda authored and lexierule committed Mar 16, 2021
1 parent 0e8f4a8 commit 4b762a9
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 25 deletions.
15 changes: 6 additions & 9 deletions pytorch_lightning/trainer/connectors/env_vars_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,24 @@
from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables


def overwrite_by_env_vars(fn: Callable) -> Callable:
def _defaults_from_env_vars(fn: Callable) -> Callable:
"""
Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
input arguments should be moved automatically to the correct device.
"""

@wraps(fn)
def overwrite_by_env_vars(self, *args, **kwargs):
# get the class
cls = self.__class__
def insert_env_defaults(self, *args, **kwargs):
cls = self.__class__ # get the class
if args: # inace any args passed move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
# convert args to kwargs
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
env_variables = vars(parse_env_variables(cls))
# update the kwargs by env variables
# todo: maybe add a warning that some init args were overwritten by Env arguments
kwargs.update(vars(parse_env_variables(cls)))
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))

# all args were already moved to kwargs
return fn(self, **kwargs)

return overwrite_by_env_vars
return insert_env_defaults
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
Expand Down Expand Up @@ -83,7 +83,7 @@ class Trainer(
DeprecatedTrainerAttributes,
):

@overwrite_by_env_vars
@_defaults_from_env_vars
def __init__(
self,
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s")


def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the Trainer signature and returns argument names, types and default values.
r"""Scans the class signature and returns argument names, types and default values.
Returns:
List with tuples of 3 values:
Expand All @@ -120,11 +120,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
>>> args = get_init_arguments_and_types(Trainer)
"""
trainer_default_params = inspect.signature(cls).parameters
cls_default_params = inspect.signature(cls).parameters
name_type_default = []
for arg in trainer_default_params:
arg_type = trainer_default_params[arg].annotation
arg_default = trainer_default_params[arg].default
for arg in cls_default_params:
arg_type = cls_default_params[arg].annotation
arg_default = cls_default_params[arg].default
try:
arg_types = tuple(arg_type.__args__)
except AttributeError:
Expand Down
31 changes: 22 additions & 9 deletions tests/trainer/flags/test_env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock

from pytorch_lightning import Trainer


def test_passing_env_variables(tmpdir):
def test_passing_no_env_variables():
"""Testing overwriting trainer arguments """
trainer = Trainer()
assert trainer.logger is not None
Expand All @@ -25,17 +26,29 @@ def test_passing_env_variables(tmpdir):
assert trainer.logger is None
assert trainer.max_steps == 42

os.environ['PL_TRAINER_LOGGER'] = 'False'
os.environ['PL_TRAINER_MAX_STEPS'] = '7'

@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"})
def test_passing_env_variables_only():
"""Testing overwriting trainer arguments """
trainer = Trainer()
assert trainer.logger is None
assert trainer.max_steps == 7

os.environ['PL_TRAINER_LOGGER'] = 'True'

@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"})
def test_passing_env_variables_defaults():
"""Testing overwriting trainer arguments """
trainer = Trainer(False, max_steps=42)
assert trainer.logger is not None
assert trainer.max_steps == 7
assert trainer.logger is None
assert trainer.max_steps == 42


# this has to be cleaned
del os.environ['PL_TRAINER_LOGGER']
del os.environ['PL_TRAINER_MAX_STEPS']
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
@mock.patch('torch.cuda.device_count', return_value=2)
@mock.patch('torch.cuda.is_available', return_value=True)
def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock):
"""Testing overwriting trainer arguments """
trainer = Trainer()
assert trainer.gpus == 2
trainer = Trainer(gpus=1)
assert trainer.gpus == 1

0 comments on commit 4b762a9

Please sign in to comment.