diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py
index 2e788c256af0d..f4209f40d002e 100644
--- a/pytorch_lightning/trainer/connectors/env_vars_connector.py
+++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py
@@ -18,27 +18,24 @@
 from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables
 
 
-def overwrite_by_env_vars(fn: Callable) -> Callable:
+def _defaults_from_env_vars(fn: Callable) -> Callable:
     """
-    Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
-    input arguments should be moved automatically to the correct device.
+    Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods whose arguments
+    can be set via ``PL_TRAINER_*`` environment variables; explicitly passed arguments take precedence.
     """
 
     @wraps(fn)
-    def overwrite_by_env_vars(self, *args, **kwargs):
-        # get the class
-        cls = self.__class__
+    def insert_env_defaults(self, *args, **kwargs):
+        cls = self.__class__  # get the class
         if args:  # in case any args were passed, move them to kwargs
             # parse only the argument names
             cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
             # convert args to kwargs
             kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
+        env_variables = vars(parse_env_variables(cls))
         # update the kwargs by env variables
-        # todo: maybe add a warning that some init args were overwritten by Env arguments
-        kwargs.update(vars(parse_env_variables(cls)))
+        kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
 
         # all args were already moved to kwargs
         return fn(self, **kwargs)
 
-    return overwrite_by_env_vars
+    return insert_env_defaults
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ebf8ddb1f07ea..cffb1914c69f9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -37,7 +37,7 @@
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
-from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
+from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
@@ -83,7 +83,7 @@ class Trainer(
     DeprecatedTrainerAttributes,
 ):
 
-    @overwrite_by_env_vars
+    @_defaults_from_env_vars
     def __init__(
         self,
         logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
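Editor's note: the behavioral change above is the merge order in insert_env_defaults. Below is a minimal sketch of that precedence rule, using plain dicts as stand-ins for vars(parse_env_variables(cls)) and the caller's kwargs (the values 7, 42, and False mirror the tests further down; the snippet itself is illustrative, not part of the patch):

# stand-in for env-derived values, e.g. PL_TRAINER_LOGGER / PL_TRAINER_MAX_STEPS
env_variables = {"logger": False, "max_steps": 7}
# arguments passed explicitly by the caller
kwargs = {"max_steps": 42}

# when a dict is built from concatenated item lists, later items win,
# so explicit kwargs overwrite env-derived values: env vars act as defaults
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
assert kwargs == {"logger": False, "max_steps": 42}

This is the inverse of the old kwargs.update(vars(parse_env_variables(cls))), where the environment silently overwrote explicit arguments.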
diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py
index 62626d1b5bcc8..2533dbc425948 100644
--- a/pytorch_lightning/utilities/argparse.py
+++ b/pytorch_lightning/utilities/argparse.py
@@ -108,7 +108,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s")
 
 
 def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
-    r"""Scans the Trainer signature and returns argument names, types and default values.
+    r"""Scans the class signature and returns argument names, types and default values.
 
     Returns:
         List with tuples of 3 values:
@@ -120,11 +120,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
     >>> args = get_init_arguments_and_types(Trainer)
 
     """
-    trainer_default_params = inspect.signature(cls).parameters
+    cls_default_params = inspect.signature(cls).parameters
     name_type_default = []
-    for arg in trainer_default_params:
-        arg_type = trainer_default_params[arg].annotation
-        arg_default = trainer_default_params[arg].default
+    for arg in cls_default_params:
+        arg_type = cls_default_params[arg].annotation
+        arg_default = cls_default_params[arg].default
         try:
             arg_types = tuple(arg_type.__args__)
         except AttributeError:
diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py
index ba76820d15ee8..65b251a6633b5 100644
--- a/tests/trainer/flags/test_env_vars.py
+++ b/tests/trainer/flags/test_env_vars.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from unittest import mock
 
 from pytorch_lightning import Trainer
 
 
-def test_passing_env_variables(tmpdir):
-    """Testing overwriting trainer arguments """
+def test_passing_no_env_variables():
+    """Testing trainer arguments when no env variables are set """
     trainer = Trainer()
     assert trainer.logger is not None
@@ -25,17 +26,29 @@
     assert trainer.logger is None
     assert trainer.max_steps == 42
 
-    os.environ['PL_TRAINER_LOGGER'] = 'False'
-    os.environ['PL_TRAINER_MAX_STEPS'] = '7'
+
+@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"})
+def test_passing_env_variables_only():
+    """Testing trainer arguments set only through env variables """
     trainer = Trainer()
     assert trainer.logger is None
     assert trainer.max_steps == 7
 
-    os.environ['PL_TRAINER_LOGGER'] = 'True'
+
+@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"})
+def test_passing_env_variables_defaults():
+    """Testing that explicit trainer arguments take precedence over env variables """
     trainer = Trainer(False, max_steps=42)
-    assert trainer.logger is not None
-    assert trainer.max_steps == 7
+    assert trainer.logger is None
+    assert trainer.max_steps == 42
+
 
-    # this has to be cleaned
-    del os.environ['PL_TRAINER_LOGGER']
-    del os.environ['PL_TRAINER_MAX_STEPS']
+@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
+@mock.patch('torch.cuda.device_count', return_value=2)
+@mock.patch('torch.cuda.is_available', return_value=True)
+def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock):
+    """Testing setting trainer gpus through env variables """
+    trainer = Trainer()
+    assert trainer.gpus == 2
+    trainer = Trainer(gpus=1)
+    assert trainer.gpus == 1
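Editor's note: the tests switch to mock.patch.dict because it snapshots os.environ and restores it when the test returns, which is why the manual "del os.environ[...]" cleanup could be dropped and a failing assert can no longer leak state into other tests. A minimal sketch of that rollback behavior, assuming PL_TRAINER_MAX_STEPS is not already set in the surrounding environment (check_env is a hypothetical helper, not part of the patch):

import os
from unittest import mock

@mock.patch.dict(os.environ, {"PL_TRAINER_MAX_STEPS": "7"})
def check_env():
    # inside the decorated call, the patched variable is visible
    assert os.environ["PL_TRAINER_MAX_STEPS"] == "7"

check_env()
# patch.dict rolled the change back after the call returned
assert "PL_TRAINER_MAX_STEPS" not in os.environ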