refactor reading env defaults #6510

Merged · 7 commits · Mar 16, 2021

15 changes: 6 additions & 9 deletions pytorch_lightning/trainer/connectors/env_vars_connector.py
@@ -18,27 +18,24 @@
 from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables
 
 
-def overwrite_by_env_vars(fn: Callable) -> Callable:
+def _defaults_from_env_vars(fn: Callable) -> Callable:
     """
     Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
     input arguments can take default values from environment variables.
 
     """
 
     @wraps(fn)
-    def overwrite_by_env_vars(self, *args, **kwargs):
-        # get the class
-        cls = self.__class__
+    def insert_env_defaults(self, *args, **kwargs):
+        cls = self.__class__  # get the class
         if args:  # in case any args were passed, move them to kwargs
             # parse only the argument names
             cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
             # convert args to kwargs
             kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
+        env_variables = vars(parse_env_variables(cls))
         # update the kwargs by env variables
-        # todo: maybe add a warning that some init args were overwritten by Env arguments
-        kwargs.update(vars(parse_env_variables(cls)))
+        kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
 
         # all args were already moved to kwargs
         return fn(self, **kwargs)
 
-    return overwrite_by_env_vars
+    return insert_env_defaults
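
The behavioral change in this file is the precedence of environment variables: the old kwargs.update(vars(parse_env_variables(cls))) let environment variables overwrite explicitly passed arguments, whereas the new dict(list(env_variables.items()) + list(kwargs.items())) puts the env values first so that explicit kwargs win. A minimal sketch of that precedence, with made-up values rather than anything taken from the PR:

    # illustrative values only
    env_variables = {"max_steps": 7, "logger": False}  # as if parsed from PL_TRAINER_* variables
    kwargs = {"max_steps": 42}                         # as if passed explicitly to Trainer(...)

    # old behavior: env variables overwrite explicit arguments
    old = dict(kwargs)
    old.update(env_variables)
    assert old == {"max_steps": 7, "logger": False}

    # new behavior: env variables only provide defaults, explicit arguments win
    new = dict(list(env_variables.items()) + list(kwargs.items()))
    assert new == {"max_steps": 42, "logger": False}
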
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -38,7 +38,7 @@
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
-from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
+from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
@@ -84,7 +84,7 @@ class Trainer(
     DeprecatedTrainerAttributes,
 ):
 
-    @overwrite_by_env_vars
+    @_defaults_from_env_vars
     def __init__(
         self,
         logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
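Because the renamed decorator still wraps Trainer.__init__, environment variables of the form PL_TRAINER_<ARGUMENT> keep acting as defaults for Trainer arguments. A rough usage sketch (the values are arbitrary examples, not taken from the PR):

    import os

    from pytorch_lightning import Trainer

    # PL_TRAINER_<ARGUMENT> variables supply defaults for Trainer arguments
    os.environ["PL_TRAINER_MAX_STEPS"] = "7"

    trainer = Trainer()              # picks up max_steps=7 from the environment
    trainer = Trainer(max_steps=42)  # with this PR, an explicitly passed value takes precedence
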
10 changes: 5 additions & 5 deletions pytorch_lightning/utilities/argparse.py
@@ -107,7 +107,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s")


 def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
-    r"""Scans the Trainer signature and returns argument names, types and default values.
+    r"""Scans the class signature and returns argument names, types and default values.
 
     Returns:
         List with tuples of 3 values:
@@ -119,11 +119,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
         >>> args = get_init_arguments_and_types(Trainer)
 
     """
-    trainer_default_params = inspect.signature(cls).parameters
+    cls_default_params = inspect.signature(cls).parameters
     name_type_default = []
-    for arg in trainer_default_params:
-        arg_type = trainer_default_params[arg].annotation
-        arg_default = trainer_default_params[arg].default
+    for arg in cls_default_params:
+        arg_type = cls_default_params[arg].annotation
+        arg_default = cls_default_params[arg].default
         try:
             arg_types = tuple(arg_type.__args__)
         except AttributeError:
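The renames in this file underline that get_init_arguments_and_types is not specific to Trainer; it scans the __init__ signature of whatever class it is given. A small sketch with a hypothetical class (Dummy is made up for illustration):

    from typing import Optional

    from pytorch_lightning.utilities.argparse import get_init_arguments_and_types


    class Dummy:
        def __init__(self, steps: int = 10, name: Optional[str] = None):
            pass


    # tuples of (argument name, accepted types, default) read from __init__
    print(get_init_arguments_and_types(Dummy))
    # roughly: [('steps', (<class 'int'>,), 10), ('name', (<class 'str'>, <class 'NoneType'>), None)]
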
31 changes: 22 additions & 9 deletions tests/trainer/flags/test_env_vars.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from unittest import mock
 
 from pytorch_lightning import Trainer
 
 
-def test_passing_env_variables(tmpdir):
+def test_passing_no_env_variables():
     """Testing overwriting trainer arguments """
     trainer = Trainer()
     assert trainer.logger is not None
@@ -25,17 +26,29 @@ def test_passing_env_variables(tmpdir):
     assert trainer.logger is None
     assert trainer.max_steps == 42
 
-    os.environ['PL_TRAINER_LOGGER'] = 'False'
-    os.environ['PL_TRAINER_MAX_STEPS'] = '7'
+
+@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"})
+def test_passing_env_variables_only():
+    """Testing overwriting trainer arguments """
     trainer = Trainer()
     assert trainer.logger is None
     assert trainer.max_steps == 7
 
-    os.environ['PL_TRAINER_LOGGER'] = 'True'
+
+@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"})
+def test_passing_env_variables_defaults():
+    """Testing overwriting trainer arguments """
     trainer = Trainer(False, max_steps=42)
-    assert trainer.logger is not None
-    assert trainer.max_steps == 7
+    assert trainer.logger is None
+    assert trainer.max_steps == 42
 
-    # this has to be cleaned
-    del os.environ['PL_TRAINER_LOGGER']
-    del os.environ['PL_TRAINER_MAX_STEPS']
+
+@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
+@mock.patch('torch.cuda.device_count', return_value=2)
+@mock.patch('torch.cuda.is_available', return_value=True)
+def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock):
+    """Testing overwriting trainer arguments """
+    trainer = Trainer()
+    assert trainer.gpus == 2
+    trainer = Trainer(gpus=1)
+    assert trainer.gpus == 1
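
Switching from manual del os.environ[...] cleanup to mock.patch.dict keeps each test's environment isolated: the patched variables are restored automatically when the decorated test returns. A standalone sketch of the same pattern (the function name is made up, and it assumes the variable is not already set in the surrounding environment):

    import os
    from unittest import mock


    @mock.patch.dict(os.environ, {"PL_TRAINER_MAX_STEPS": "7"})
    def check_env_is_patched():
        # the variable is visible while the decorated function runs
        assert os.environ["PL_TRAINER_MAX_STEPS"] == "7"


    check_env_is_patched()
    # and it is removed again afterwards, with no manual del os.environ[...] needed
    assert "PL_TRAINER_MAX_STEPS" not in os.environ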