From a884866ff0b72f54412c2ffd38f86b5c928c1395 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 22 Dec 2020 00:23:33 +0100 Subject: [PATCH] Unify names in Utils (#5199) * warnings * argparse * mutils * xla device * deprecated * tests * simple * flake8 * fix * flake8 * 1.4 --- pytorch_lightning/overrides/data_parallel.py | 2 +- .../trainer/configuration_validator.py | 2 +- .../trainer/connectors/data_connector.py | 2 +- .../trainer/connectors/env_vars_connector.py | 2 +- .../logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 4 +- pytorch_lightning/trainer/properties.py | 15 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 4 +- pytorch_lightning/utilities/__init__.py | 2 +- pytorch_lightning/utilities/argparse.py | 249 +++++++++++++++++ pytorch_lightning/utilities/argparse_utils.py | 251 +----------------- pytorch_lightning/utilities/model_helpers.py | 43 +++ pytorch_lightning/utilities/model_utils.py | 45 +--- pytorch_lightning/utilities/warning_utils.py | 27 +- pytorch_lightning/utilities/warnings.py | 25 ++ pytorch_lightning/utilities/xla_device.py | 104 ++++++++ .../utilities/xla_device_utils.py | 106 +------- tests/backends/test_tpu_backend.py | 2 +- tests/core/test_datamodules.py | 2 +- tests/deprecated_api/test_remove_1-2.py | 4 +- tests/deprecated_api/test_remove_1-3.py | 7 +- tests/deprecated_api/test_remove_1-4.py | 35 +++ tests/trainer/test_trainer_cli.py | 4 +- tests/utilities/test_argparse_utils.py | 2 +- tests/utilities/test_xla_device_utils.py | 2 +- 27 files changed, 505 insertions(+), 442 deletions(-) create mode 100644 pytorch_lightning/utilities/argparse.py create mode 100644 pytorch_lightning/utilities/model_helpers.py create mode 100644 pytorch_lightning/utilities/warnings.py create mode 100644 pytorch_lightning/utilities/xla_device.py create mode 100644 tests/deprecated_api/test_remove_1-4.py diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 79dd940866f40..855f026701c7d 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -24,7 +24,7 @@ from torch.nn.parallel._functions import Gather from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities.warning_utils import WarningCache +from pytorch_lightning.utilities.warnings import WarningCache def _find_tensors(obj): # pragma: no-cover diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 7bafd819e7a2b..6e3e062e45a53 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -14,7 +14,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.model_helpers import is_overridden class ConfigValidator(object): diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 7ed957000ec71..641b604785f03 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -16,7 +16,7 @@ from pytorch_lightning.utilities.exceptions import 
MisconfigurationException from typing import List, Optional, Union from torch.utils.data import DataLoader -from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.model_helpers import is_overridden class DataConnector(object): diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index 2cbbc8e40e909..5fc8497fa7db3 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -15,7 +15,7 @@ from functools import wraps from typing import Callable -from pytorch_lightning.utilities.argparse_utils import parse_env_variables, get_init_arguments_and_types +from pytorch_lightning.utilities.argparse import parse_env_variables, get_init_arguments_and_types def overwrite_by_env_vars(fn: Callable) -> Callable: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a3f86f62874ca..f417294aa8ec0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -25,7 +25,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages from pytorch_lightning.utilities import flatten_dict from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.model_helpers import is_overridden class LoggerConnector: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index f2bcb1d1760d6..f8452067e0c67 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -27,7 +27,7 @@ from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.model_helpers import is_overridden class TrainerDataLoadingMixin(ABC): diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index af153846c3a88..a8fa9f43684ca 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -16,8 +16,8 @@ from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_utils import is_overridden -from pytorch_lightning.utilities.warning_utils import WarningCache +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import WarningCache class EvaluationLoop(object): diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 53f217176d1e0..1457bd3d538a8 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -27,9 +27,12 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from 
pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, argparse_utils +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE +from pytorch_lightning.utilities.argparse import ( + from_argparse_args, parse_argparser, parse_env_variables, add_argparse_args +) from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.model_helpers import is_overridden if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm @@ -150,19 +153,19 @@ def get_deprecated_arg_names(cls) -> List: @classmethod def from_argparse_args(cls: Type['_T'], args: Union[Namespace, ArgumentParser], **kwargs) -> '_T': - return argparse_utils.from_argparse_args(cls, args, **kwargs) + return from_argparse_args(cls, args, **kwargs) @classmethod def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: - return argparse_utils.parse_argparser(cls, arg_parser) + return parse_argparser(cls, arg_parser) @classmethod def match_env_arguments(cls) -> Namespace: - return argparse_utils.parse_env_variables(cls) + return parse_env_variables(cls) @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: - return argparse_utils.add_argparse_args(cls, parent_parser) + return add_argparse_args(cls, parent_parser) @property def num_gpus(self) -> int: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 35da90625adef..62d7deb0eb378 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -61,7 +61,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach -from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.model_helpers import is_overridden # warnings to ignore in trainer warnings.filterwarnings( diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1637e3504dd0d..cd73e9793ddae 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -28,9 +28,9 @@ from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach -from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.parsing import AttributeDict -from pytorch_lightning.utilities.warning_utils import WarningCache +from pytorch_lightning.utilities.warnings import WarningCache class TrainLoop: diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index c01592cceadbc..4be619500798b 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -29,7 +29,7 @@ rank_zero_warn, ) from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401 -from pytorch_lightning.utilities.xla_device_utils import _XLA_AVAILABLE, XLADeviceUtils # noqa: F401 +from pytorch_lightning.utilities.xla_device import _XLA_AVAILABLE, XLADeviceUtils # noqa: F401 def _module_available(module_path: str) -> bool: 
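For orientation before the new modules appear below: this patch renames four utility modules and keeps the old names as deprecation shims until v1.4: `argparse_utils` becomes `argparse`, `model_utils` becomes `model_helpers`, `warning_utils` becomes `warnings`, and `xla_device_utils` becomes `xla_device`. The following sketch is not part of the patch; it only illustrates the new import paths, assuming an installation that already contains this change (`TinyModel` is a hypothetical model defined inline purely for the example).

from argparse import ArgumentParser

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.argparse import parse_env_variables      # formerly argparse_utils
from pytorch_lightning.utilities.model_helpers import is_overridden       # formerly model_utils
from pytorch_lightning.utilities.warnings import WarningCache             # formerly warning_utils


class TinyModel(LightningModule):
    """Hypothetical model used only to demonstrate is_overridden()."""

    def training_step(self, batch, batch_idx):
        pass


# The public Trainer argparse API is unchanged; only the backing module moved.
parser = Trainer.add_argparse_args(ArgumentParser(add_help=False))
args = Trainer.parse_argparser(parser.parse_args([]))
trainer = Trainer.from_argparse_args(args, logger=False)

print(parse_env_variables(Trainer))                  # Namespace() unless PL_TRAINER_* variables are set
print(is_overridden('training_step', TinyModel()))   # True: the hook above overrides the default

cache = WarningCache()
cache.warn("shown once")  # repeated messages are deduplicated via rank_zero_warn
cache.warn("shown once")

Until v1.4 the same code also works with the old import paths, which now simply emit a `DeprecationWarning` and re-export the new modules, as the shim files later in this patch show.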
diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py new file mode 100644 index 0000000000000..38ff23897434d --- /dev/null +++ b/pytorch_lightning/utilities/argparse.py @@ -0,0 +1,249 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import os +from argparse import ArgumentParser, Namespace +from typing import Dict, Union, List, Tuple, Any +from pytorch_lightning.utilities import parsing + + +def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): + """ + Create an instance from CLI arguments. + Eventually use variables from OS environment which are defined as "PL_<CLS_NAME>_<CLS_ARGUMENT>" + + Args: + args: The parser or namespace to take arguments from. Only known arguments will be + parsed and passed to the :class:`Trainer`. + **kwargs: Additional keyword arguments that may override ones in the parser or namespace. + These must be valid Trainer arguments. + + Example: + >>> from pytorch_lightning import Trainer + >>> parser = ArgumentParser(add_help=False) + >>> parser = Trainer.add_argparse_args(parser) + >>> parser.add_argument('--my_custom_arg', default='something') # doctest: +SKIP + >>> args = Trainer.parse_argparser(parser.parse_args("")) + >>> trainer = Trainer.from_argparse_args(args, logger=False) + """ + if isinstance(args, ArgumentParser): + args = cls.parse_argparser(args) + + params = vars(args) + + # we only want to pass in valid Trainer args, the rest may be user specific + valid_kwargs = inspect.signature(cls.__init__).parameters + trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) + trainer_kwargs.update(**kwargs) + + return cls(**trainer_kwargs) + + +def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: + """Parse CLI arguments, required for custom bool types.""" + args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser + + types_default = { + arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls) + } + + modified_args = {} + for k, v in vars(args).items(): + if k in types_default and v is None: + # We need to figure out if the None is due to using nargs="?" or if it comes from the default value + arg_types, arg_default = types_default[k] + if bool in arg_types and isinstance(arg_default, bool): + # Value has been passed as a flag => It is currently None, so we need to set it to True + # We always set to True, regardless of the default value. + # Users must pass False directly, but when passing nothing True is assumed. + # i.e. the only way to disable something that defaults to True is to use the long form: + # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, + # which then becomes True here.
+ + v = True + + modified_args[k] = v + return Namespace(**modified_args) + + +def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: + """Parse environment arguments if they are defined. + + Example: + >>> from pytorch_lightning import Trainer + >>> parse_env_variables(Trainer) + Namespace() + >>> import os + >>> os.environ["PL_TRAINER_GPUS"] = '42' + >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23' + >>> parse_env_variables(Trainer) + Namespace(gpus=42) + >>> del os.environ["PL_TRAINER_GPUS"] + """ + cls_arg_defaults = get_init_arguments_and_types(cls) + + env_args = {} + for arg_name, _, _ in cls_arg_defaults: + env = template % {'cls_name': cls.__name__.upper(), 'cls_argument': arg_name.upper()} + val = os.environ.get(env) + if not (val is None or val == ''): + try: # converting to native types like int/float/bool + val = eval(val) + except Exception: + pass + env_args[arg_name] = val + return Namespace(**env_args) + + +def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: + r"""Scans the Trainer signature and returns argument names, types and default values. + + Returns: + List with tuples of 3 values: + (argument name, set with argument types, argument default value). + + Examples: + + >>> from pytorch_lightning import Trainer + >>> args = get_init_arguments_and_types(Trainer) + + """ + trainer_default_params = inspect.signature(cls).parameters + name_type_default = [] + for arg in trainer_default_params: + arg_type = trainer_default_params[arg].annotation + arg_default = trainer_default_params[arg].default + try: + arg_types = tuple(arg_type.__args__) + except AttributeError: + arg_types = (arg_type,) + + name_type_default.append((arg, arg_types, arg_default)) + + return name_type_default + + +def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: + r"""Extends existing argparse by default `Trainer` attributes. + + Args: + parent_parser: + The custom cli arguments parser, which will be extended by + the Trainer default arguments. + + Only arguments of the allowed types (str, float, int, bool) will + extend the `parent_parser`. 
+ + Examples: + >>> import argparse + >>> from pytorch_lightning import Trainer + >>> parser = argparse.ArgumentParser() + >>> parser = Trainer.add_argparse_args(parser) + >>> args = parser.parse_args([]) + """ + parser = ArgumentParser(parents=[parent_parser], add_help=False,) + + blacklist = ['kwargs'] + depr_arg_names = cls.get_deprecated_arg_names() + blacklist + + allowed_types = (str, int, float, bool) + + args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__) + for arg, arg_types, arg_default in ( + at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names + ): + arg_types = [at for at in allowed_types if at in arg_types] + if not arg_types: + # skip argument with not supported type + continue + arg_kwargs = {} + if bool in arg_types: + arg_kwargs.update(nargs="?", const=True) + # if the only arg type is bool + if len(arg_types) == 1: + use_type = parsing.str_to_bool + elif str in arg_types: + use_type = parsing.str_to_bool_or_str + else: + # filter out the bool as we need to use more general + use_type = [at for at in arg_types if at is not bool][0] + else: + use_type = arg_types[0] + + if arg == 'gpus' or arg == 'tpu_cores': + use_type = _gpus_allowed_type + arg_default = _gpus_arg_default + + # hack for types in (int, float) + if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types): + use_type = _int_or_float_type + + # hack for track_grad_norm + if arg == 'track_grad_norm': + use_type = float + + parser.add_argument( + f'--{arg}', + dest=arg, + default=arg_default, + type=use_type, + help=args_help.get(arg), + **arg_kwargs, + ) + + return parser + + +def parse_args_from_docstring(docstring: str) -> Dict[str, str]: + arg_block_indent = None + current_arg = None + parsed = {} + for line in docstring.split("\n"): + stripped = line.lstrip() + if not stripped: + continue + line_indent = len(line) - len(stripped) + if stripped.startswith(('Args:', 'Arguments:', 'Parameters:')): + arg_block_indent = line_indent + 4 + elif arg_block_indent is None: + continue + elif line_indent < arg_block_indent: + break + elif line_indent == arg_block_indent: + current_arg, arg_description = stripped.split(':', maxsplit=1) + parsed[current_arg] = arg_description.lstrip() + elif line_indent > arg_block_indent: + parsed[current_arg] += f' {stripped}' + return parsed + + +def _gpus_allowed_type(x) -> Union[int, str]: + if ',' in x: + return str(x) + else: + return int(x) + + +def _gpus_arg_default(x) -> Union[int, str]: + if ',' in x: + return str(x) + else: + return int(x) + + +def _int_or_float_type(x) -> Union[int, float]: + if '.' in str(x): + return float(x) + else: + return int(x) diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index 38ff23897434d..e4a6fc5cd89c1 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -1,249 +1,6 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import inspect -import os -from argparse import ArgumentParser, Namespace -from typing import Dict, Union, List, Tuple, Any -from pytorch_lightning.utilities import parsing +from warnings import warn +warn("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4", + DeprecationWarning) -def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): - """ - Create an instance from CLI arguments. - Eventually use varibles from OS environement which are defined as "PL__" - - Args: - args: The parser or namespace to take arguments from. Only known arguments will be - parsed and passed to the :class:`Trainer`. - **kwargs: Additional keyword arguments that may override ones in the parser or namespace. - These must be valid Trainer arguments. - - Example: - >>> from pytorch_lightning import Trainer - >>> parser = ArgumentParser(add_help=False) - >>> parser = Trainer.add_argparse_args(parser) - >>> parser.add_argument('--my_custom_arg', default='something') # doctest: +SKIP - >>> args = Trainer.parse_argparser(parser.parse_args("")) - >>> trainer = Trainer.from_argparse_args(args, logger=False) - """ - if isinstance(args, ArgumentParser): - args = cls.parse_argparser(args) - - params = vars(args) - - # we only want to pass in valid Trainer args, the rest may be user specific - valid_kwargs = inspect.signature(cls.__init__).parameters - trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) - trainer_kwargs.update(**kwargs) - - return cls(**trainer_kwargs) - - -def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: - """Parse CLI arguments, required for custom bool types.""" - args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser - - types_default = { - arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls) - } - - modified_args = {} - for k, v in vars(args).items(): - if k in types_default and v is None: - # We need to figure out if the None is due to using nargs="?" or if it comes from the default value - arg_types, arg_default = types_default[k] - if bool in arg_types and isinstance(arg_default, bool): - # Value has been passed as a flag => It is currently None, so we need to set it to True - # We always set to True, regardless of the default value. - # Users must pass False directly, but when passing nothing True is assumed. - # i.e. the only way to disable somthing that defaults to True is to use the long form: - # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, - # which then becomes True here. - - v = True - - modified_args[k] = v - return Namespace(**modified_args) - - -def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: - """Parse environment arguments if they are defined. 
- - Example: - >>> from pytorch_lightning import Trainer - >>> parse_env_variables(Trainer) - Namespace() - >>> import os - >>> os.environ["PL_TRAINER_GPUS"] = '42' - >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23' - >>> parse_env_variables(Trainer) - Namespace(gpus=42) - >>> del os.environ["PL_TRAINER_GPUS"] - """ - cls_arg_defaults = get_init_arguments_and_types(cls) - - env_args = {} - for arg_name, _, _ in cls_arg_defaults: - env = template % {'cls_name': cls.__name__.upper(), 'cls_argument': arg_name.upper()} - val = os.environ.get(env) - if not (val is None or val == ''): - try: # converting to native types like int/float/bool - val = eval(val) - except Exception: - pass - env_args[arg_name] = val - return Namespace(**env_args) - - -def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: - r"""Scans the Trainer signature and returns argument names, types and default values. - - Returns: - List with tuples of 3 values: - (argument name, set with argument types, argument default value). - - Examples: - - >>> from pytorch_lightning import Trainer - >>> args = get_init_arguments_and_types(Trainer) - - """ - trainer_default_params = inspect.signature(cls).parameters - name_type_default = [] - for arg in trainer_default_params: - arg_type = trainer_default_params[arg].annotation - arg_default = trainer_default_params[arg].default - try: - arg_types = tuple(arg_type.__args__) - except AttributeError: - arg_types = (arg_type,) - - name_type_default.append((arg, arg_types, arg_default)) - - return name_type_default - - -def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: - r"""Extends existing argparse by default `Trainer` attributes. - - Args: - parent_parser: - The custom cli arguments parser, which will be extended by - the Trainer default arguments. - - Only arguments of the allowed types (str, float, int, bool) will - extend the `parent_parser`. 
- - Examples: - >>> import argparse - >>> from pytorch_lightning import Trainer - >>> parser = argparse.ArgumentParser() - >>> parser = Trainer.add_argparse_args(parser) - >>> args = parser.parse_args([]) - """ - parser = ArgumentParser(parents=[parent_parser], add_help=False,) - - blacklist = ['kwargs'] - depr_arg_names = cls.get_deprecated_arg_names() + blacklist - - allowed_types = (str, int, float, bool) - - args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__) - for arg, arg_types, arg_default in ( - at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names - ): - arg_types = [at for at in allowed_types if at in arg_types] - if not arg_types: - # skip argument with not supported type - continue - arg_kwargs = {} - if bool in arg_types: - arg_kwargs.update(nargs="?", const=True) - # if the only arg type is bool - if len(arg_types) == 1: - use_type = parsing.str_to_bool - elif str in arg_types: - use_type = parsing.str_to_bool_or_str - else: - # filter out the bool as we need to use more general - use_type = [at for at in arg_types if at is not bool][0] - else: - use_type = arg_types[0] - - if arg == 'gpus' or arg == 'tpu_cores': - use_type = _gpus_allowed_type - arg_default = _gpus_arg_default - - # hack for types in (int, float) - if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types): - use_type = _int_or_float_type - - # hack for track_grad_norm - if arg == 'track_grad_norm': - use_type = float - - parser.add_argument( - f'--{arg}', - dest=arg, - default=arg_default, - type=use_type, - help=args_help.get(arg), - **arg_kwargs, - ) - - return parser - - -def parse_args_from_docstring(docstring: str) -> Dict[str, str]: - arg_block_indent = None - current_arg = None - parsed = {} - for line in docstring.split("\n"): - stripped = line.lstrip() - if not stripped: - continue - line_indent = len(line) - len(stripped) - if stripped.startswith(('Args:', 'Arguments:', 'Parameters:')): - arg_block_indent = line_indent + 4 - elif arg_block_indent is None: - continue - elif line_indent < arg_block_indent: - break - elif line_indent == arg_block_indent: - current_arg, arg_description = stripped.split(':', maxsplit=1) - parsed[current_arg] = arg_description.lstrip() - elif line_indent > arg_block_indent: - parsed[current_arg] += f' {stripped}' - return parsed - - -def _gpus_allowed_type(x) -> Union[int, str]: - if ',' in x: - return str(x) - else: - return int(x) - - -def _gpus_arg_default(x) -> Union[int, str]: - if ',' in x: - return str(x) - else: - return int(x) - - -def _int_or_float_type(x) -> Union[int, float]: - if '.' in str(x): - return float(x) - else: - return int(x) +from pytorch_lightning.utilities.argparse import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py new file mode 100644 index 0000000000000..993d9e11e1491 --- /dev/null +++ b/pytorch_lightning/utilities/model_helpers.py @@ -0,0 +1,43 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.datamodule import LightningDataModule + + +def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool: + # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super + # TODO - refector this function to accept model_name, instance, parent so it makes more sense + super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule + + if not hasattr(model, method_name) or not hasattr(super_object, method_name): + # in case of calling deprecated method + return False + + instance_attr = getattr(model, method_name) + if not instance_attr: + return False + super_attr = getattr(super_object, method_name) + + # when code pointers are different, it was implemented + if hasattr(instance_attr, 'patch_loader_code'): + # cannot pickle __code__ so cannot verify if PatchDataloader + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__) + else: + is_overridden = instance_attr.__code__ is not super_attr.__code__ + return is_overridden diff --git a/pytorch_lightning/utilities/model_utils.py b/pytorch_lightning/utilities/model_utils.py index 993d9e11e1491..a5472614499e9 100644 --- a/pytorch_lightning/utilities/model_utils.py +++ b/pytorch_lightning/utilities/model_utils.py @@ -1,43 +1,6 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +from warnings import warn -from typing import Union +warn("`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4", + DeprecationWarning) -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.datamodule import LightningDataModule - - -def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool: - # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super - # TODO - refector this function to accept model_name, instance, parent so it makes more sense - super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule - - if not hasattr(model, method_name) or not hasattr(super_object, method_name): - # in case of calling deprecated method - return False - - instance_attr = getattr(model, method_name) - if not instance_attr: - return False - super_attr = getattr(super_object, method_name) - - # when code pointers are different, it was implemented - if hasattr(instance_attr, 'patch_loader_code'): - # cannot pickle __code__ so cannot verify if PatchDataloader - # exists which shows dataloader methods have been overwritten. 
- # so, we hack it by using the string representation - is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__) - else: - is_overridden = instance_attr.__code__ is not super_attr.__code__ - return is_overridden +from pytorch_lightning.utilities.model_helpers import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/warning_utils.py b/pytorch_lightning/utilities/warning_utils.py index a5d5be95ad76f..3ae0ada6f325b 100644 --- a/pytorch_lightning/utilities/warning_utils.py +++ b/pytorch_lightning/utilities/warning_utils.py @@ -1,25 +1,6 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pytorch_lightning.utilities.distributed import rank_zero_warn +from warnings import warn +warn("`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4", + DeprecationWarning) -class WarningCache: - - def __init__(self): - self.warnings = set() - - def warn(self, m): - if m not in self.warnings: - self.warnings.add(m) - rank_zero_warn(m) +from pytorch_lightning.utilities.warnings import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py new file mode 100644 index 0000000000000..a5d5be95ad76f --- /dev/null +++ b/pytorch_lightning/utilities/warnings.py @@ -0,0 +1,25 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.utilities.distributed import rank_zero_warn + + +class WarningCache: + + def __init__(self): + self.warnings = set() + + def warn(self, m): + if m not in self.warnings: + self.warnings.add(m) + rank_zero_warn(m) diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py new file mode 100644 index 0000000000000..d7702aef3357b --- /dev/null +++ b/pytorch_lightning/utilities/xla_device.py @@ -0,0 +1,104 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import functools +import importlib +import queue as q +import traceback +from multiprocessing import Process, Queue + +import torch + +_XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None + +if _XLA_AVAILABLE: + import torch_xla.core.xla_model as xm + + +def inner_f(queue, func, *args, **kwargs): # pragma: no cover + try: + queue.put(func(*args, **kwargs)) + except Exception: + traceback.print_exc() + queue.put(None) + + +def pl_multi_process(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + queue = Queue() + proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) + proc.start() + proc.join(20) + try: + return queue.get_nowait() + except q.Empty: + traceback.print_exc() + return False + + return wrapper + + +class XLADeviceUtils: + """Used to detect the type of XLA device""" + + TPU_AVAILABLE = None + + @staticmethod + def _fetch_xla_device_type(device: torch.device) -> str: + """ + Returns XLA device type + + Args: + device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0 + + Return: + Returns a str of the device hardware type. i.e TPU + """ + if _XLA_AVAILABLE: + return xm.xla_device_hw(device) + + @staticmethod + def _is_device_tpu() -> bool: + """ + Check if device is TPU + + Return: + A boolean value indicating if the xla device is a TPU device or not + """ + if _XLA_AVAILABLE: + device = xm.xla_device() + device_type = XLADeviceUtils._fetch_xla_device_type(device) + return device_type == "TPU" + + @staticmethod + def xla_available() -> bool: + """ + Check if XLA library is installed + + Return: + A boolean value indicating if a XLA is installed + """ + return _XLA_AVAILABLE + + @staticmethod + def tpu_device_exists() -> bool: + """ + Runs XLA device check within a separate process + + Return: + A boolean value indicating if a TPU device exists on the system + """ + if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE: + XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)() + return XLADeviceUtils.TPU_AVAILABLE diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py index d7702aef3357b..14011406916fb 100644 --- a/pytorch_lightning/utilities/xla_device_utils.py +++ b/pytorch_lightning/utilities/xla_device_utils.py @@ -1,104 +1,6 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import functools -import importlib -import queue as q -import traceback -from multiprocessing import Process, Queue +from warnings import warn -import torch +warn("`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4", + DeprecationWarning) -_XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None - -if _XLA_AVAILABLE: - import torch_xla.core.xla_model as xm - - -def inner_f(queue, func, *args, **kwargs): # pragma: no cover - try: - queue.put(func(*args, **kwargs)) - except Exception: - traceback.print_exc() - queue.put(None) - - -def pl_multi_process(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - queue = Queue() - proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) - proc.start() - proc.join(20) - try: - return queue.get_nowait() - except q.Empty: - traceback.print_exc() - return False - - return wrapper - - -class XLADeviceUtils: - """Used to detect the type of XLA device""" - - TPU_AVAILABLE = None - - @staticmethod - def _fetch_xla_device_type(device: torch.device) -> str: - """ - Returns XLA device type - - Args: - device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0 - - Return: - Returns a str of the device hardware type. i.e TPU - """ - if _XLA_AVAILABLE: - return xm.xla_device_hw(device) - - @staticmethod - def _is_device_tpu() -> bool: - """ - Check if device is TPU - - Return: - A boolean value indicating if the xla device is a TPU device or not - """ - if _XLA_AVAILABLE: - device = xm.xla_device() - device_type = XLADeviceUtils._fetch_xla_device_type(device) - return device_type == "TPU" - - @staticmethod - def xla_available() -> bool: - """ - Check if XLA library is installed - - Return: - A boolean value indicating if a XLA is installed - """ - return _XLA_AVAILABLE - - @staticmethod - def tpu_device_exists() -> bool: - """ - Runs XLA device check within a separate process - - Return: - A boolean value indicating if a TPU device exists on the system - """ - if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE: - XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)() - return XLADeviceUtils.TPU_AVAILABLE +from pytorch_lightning.utilities.xla_device import * # noqa: F403 E402 F401 diff --git a/tests/backends/test_tpu_backend.py b/tests/backends/test_tpu_backend.py index 63729f86ce862..de6f8e3cb3b99 100644 --- a/tests/backends/test_tpu_backend.py +++ b/tests/backends/test_tpu_backend.py @@ -16,7 +16,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils +from pytorch_lightning.utilities.xla_device import XLADeviceUtils from tests.base.boring_model import BoringModel from tests.base.develop_utils import pl_multi_process_test diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 75c799f397349..73f6b2303d102 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -25,7 +25,7 @@ from tests.base.datasets import TrialMNIST from tests.base.datamodules import TrialMNISTDataModule from tests.base.develop_utils import reset_seed -from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator from pytorch_lightning.callbacks import ModelCheckpoint diff --git a/tests/deprecated_api/test_remove_1-2.py b/tests/deprecated_api/test_remove_1-2.py index 
331208d56df10..70e3f088c11e2 100644 --- a/tests/deprecated_api/test_remove_1-2.py +++ b/tests/deprecated_api/test_remove_1-2.py @@ -20,7 +20,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException -def test_tbd_remove_in_v1_2_0(): +def test_v1_2_0_deprecated_arguments(): with pytest.deprecated_call(match='will be removed in v1.2'): ModelCheckpoint(filepath='..') @@ -31,7 +31,7 @@ def test_tbd_remove_in_v1_2_0(): ModelCheckpoint(filepath='..', dirpath='.') -def test_tbd_remove_in_v1_2_0_metrics(): +def test_v1_2_0_deprecated_metrics(): from pytorch_lightning.metrics.classification import Fbeta from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 2e44b8463e14e..c855086c9526d 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -21,9 +21,10 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler +from tests.deprecated_api import _soft_unimport_module -def test_tbd_remove_in_v1_3_0(tmpdir): +def test_v1_3_0_deprecated_arguments(tmpdir): with pytest.deprecated_call(match='will no longer be supported in v1.3'): callback = ModelCheckpoint() Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir) @@ -47,7 +48,7 @@ def __init__(self, hparams): DeprecatedHparamsModel({}) -def test_tbd_remove_in_v1_3_0_metrics(): +def test_v1_3_0_deprecated_metrics(): from pytorch_lightning.metrics.functional.classification import to_onehot with pytest.deprecated_call(match='will be removed in v1.3'): to_onehot(torch.tensor([1, 2, 3])) @@ -124,7 +125,7 @@ def test_trainer_profiler_remove_in_v1_3_0(profiler, expected): ('--profiler False', False, PassThroughProfiler), ], ) -def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, expected_profiler): +def test_v1_3_0_trainer_cli_profiler(cli_args, expected_parsed_arg, expected_profiler): cli_args = cli_args.split(' ') with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): parser = ArgumentParser(add_help=False) diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py new file mode 100644 index 0000000000000..a65f44cea4279 --- /dev/null +++ b/tests/deprecated_api/test_remove_1-4.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test deprecated functionality which will be removed in vX.Y.Z""" +import pytest + +from tests.deprecated_api import _soft_unimport_module + + +def test_v1_4_0_deprecated_imports(): + _soft_unimport_module('pytorch_lightning.utilities.argparse_utils') + with pytest.deprecated_call(match='will be removed in v1.4'): + from pytorch_lightning.utilities.argparse_utils import from_argparse_args # noqa: F811 F401 + + _soft_unimport_module('pytorch_lightning.utilities.model_utils') + with pytest.deprecated_call(match='will be removed in v1.4'): + from pytorch_lightning.utilities.model_utils import is_overridden # noqa: F811 F401 + + _soft_unimport_module('pytorch_lightning.utilities.warning_utils') + with pytest.deprecated_call(match='will be removed in v1.4'): + from pytorch_lightning.utilities.warning_utils import WarningCache # noqa: F811 F401 + + _soft_unimport_module('pytorch_lightning.utilities.xla_device_utils') + with pytest.deprecated_call(match='will be removed in v1.4'): + from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils # noqa: F811 F401 diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index db8df3ef0d5cd..e8632b8443325 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -22,7 +22,7 @@ import tests.base.develop_utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.utilities import argparse_utils +from pytorch_lightning.utilities import argparse @mock.patch('argparse.ArgumentParser.parse_args') @@ -73,7 +73,7 @@ def test_add_argparse_args_redefined(cli_args): def test_get_init_arguments_and_types(): """Asserts a correctness of the `get_init_arguments_and_types` Trainer classmethod.""" - args = argparse_utils.get_init_arguments_and_types(Trainer) + args = argparse.get_init_arguments_and_types(Trainer) parameters = inspect.signature(Trainer).parameters assert len(parameters) == len(args) for arg in args: diff --git a/tests/utilities/test_argparse_utils.py b/tests/utilities/test_argparse_utils.py index 978ad820482b2..63227abf831ec 100644 --- a/tests/utilities/test_argparse_utils.py +++ b/tests/utilities/test_argparse_utils.py @@ -1,4 +1,4 @@ -from pytorch_lightning.utilities.argparse_utils import parse_args_from_docstring +from pytorch_lightning.utilities.argparse import parse_args_from_docstring def test_parse_args_from_docstring_normal(): diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py index a495c29be1668..92825fc7f4e8c 100644 --- a/tests/utilities/test_xla_device_utils.py +++ b/tests/utilities/test_xla_device_utils.py @@ -15,7 +15,7 @@ import pytest -import pytorch_lightning.utilities.xla_device_utils as xla_utils +import pytorch_lightning.utilities.xla_device as xla_utils from pytorch_lightning.utilities import _XLA_AVAILABLE, _TPU_AVAILABLE from tests.base.develop_utils import pl_multi_process_test