Unify names in Utils (#5199)
* warnings

* argparse

* mutils

* xla device

* deprecated

* tests

* simple

* flake8

* fix

* flake8

* 1.4
Borda committed Dec 21, 2020
1 parent ccffc34 commit a884866
Showing 27 changed files with 505 additions and 442 deletions.
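In practice, the commit is a set of module renames inside pytorch_lightning.utilities, applied consistently across the call sites below. A summary sketch of the affected import paths, compiled from this diff (the imported names are examples taken from the changed files):

# Old module -> new module, as renamed by this commit:
#   pytorch_lightning.utilities.warning_utils    -> pytorch_lightning.utilities.warnings
#   pytorch_lightning.utilities.model_utils      -> pytorch_lightning.utilities.model_helpers
#   pytorch_lightning.utilities.argparse_utils   -> pytorch_lightning.utilities.argparse
#   pytorch_lightning.utilities.xla_device_utils -> pytorch_lightning.utilities.xla_device
from pytorch_lightning.utilities.warnings import WarningCache
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.argparse import from_argparse_args
from pytorch_lightning.utilities.xla_device import XLADeviceUtils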
pytorch_lightning/overrides/data_parallel.py (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@
from torch.nn.parallel._functions import Gather

from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.utilities.warning_utils import WarningCache
+from pytorch_lightning.utilities.warnings import WarningCache


def _find_tensors(obj): # pragma: no-cover
pytorch_lightning/trainer/configuration_validator.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.utilities.model_helpers import is_overridden


class ConfigValidator(object):
pytorch_lightning/trainer/connectors/data_connector.py (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from typing import List, Optional, Union
from torch.utils.data import DataLoader
-from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.utilities.model_helpers import is_overridden


class DataConnector(object):
pytorch_lightning/trainer/connectors/env_vars_connector.py (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@
from functools import wraps
from typing import Callable

-from pytorch_lightning.utilities.argparse_utils import parse_env_variables, get_init_arguments_and_types
+from pytorch_lightning.utilities.argparse import parse_env_variables, get_init_arguments_and_types


def overwrite_by_env_vars(fn: Callable) -> Callable:
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages
from pytorch_lightning.utilities import flatten_dict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.utilities.model_helpers import is_overridden


class LoggerConnector:
pytorch_lightning/trainer/data_loading.py (2 changes: 1 addition & 1 deletion)
@@ -27,7 +27,7 @@
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.utilities.model_helpers import is_overridden


class TrainerDataLoadingMixin(ABC):
pytorch_lightning/trainer/evaluation_loop.py (4 changes: 2 additions & 2 deletions)
@@ -16,8 +16,8 @@
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.utilities.warning_utils import WarningCache
+from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.warnings import WarningCache


class EvaluationLoop(object):
pytorch_lightning/trainer/properties.py (15 changes: 9 additions & 6 deletions)
@@ -27,9 +27,12 @@
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, argparse_utils
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE
+from pytorch_lightning.utilities.argparse import (
+    from_argparse_args, parse_argparser, parse_env_variables, add_argparse_args
+)
from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.utilities.model_helpers import is_overridden

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm
@@ -150,19 +153,19 @@ def get_deprecated_arg_names(cls) -> List:

    @classmethod
    def from_argparse_args(cls: Type['_T'], args: Union[Namespace, ArgumentParser], **kwargs) -> '_T':
-        return argparse_utils.from_argparse_args(cls, args, **kwargs)
+        return from_argparse_args(cls, args, **kwargs)

    @classmethod
    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
-        return argparse_utils.parse_argparser(cls, arg_parser)
+        return parse_argparser(cls, arg_parser)

    @classmethod
    def match_env_arguments(cls) -> Namespace:
-        return argparse_utils.parse_env_variables(cls)
+        return parse_env_variables(cls)

    @classmethod
    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
-        return argparse_utils.add_argparse_args(cls, parent_parser)
+        return add_argparse_args(cls, parent_parser)

    @property
    def num_gpus(self) -> int:
pytorch_lightning/trainer/trainer.py (2 changes: 1 addition & 1 deletion)
@@ -61,7 +61,7 @@
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
-from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.utilities.model_helpers import is_overridden

# warnings to ignore in trainer
warnings.filterwarnings(
pytorch_lightning/trainer/training_loop.py (4 changes: 2 additions & 2 deletions)
@@ -28,9 +28,9 @@
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
-from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.parsing import AttributeDict
-from pytorch_lightning.utilities.warning_utils import WarningCache
+from pytorch_lightning.utilities.warnings import WarningCache


class TrainLoop:
pytorch_lightning/utilities/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@
    rank_zero_warn,
)
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable  # noqa: F401
-from pytorch_lightning.utilities.xla_device_utils import _XLA_AVAILABLE, XLADeviceUtils  # noqa: F401
+from pytorch_lightning.utilities.xla_device import _XLA_AVAILABLE, XLADeviceUtils  # noqa: F401


def _module_available(module_path: str) -> bool:
pytorch_lightning/utilities/argparse.py (new file, 249 additions)
@@ -0,0 +1,249 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from argparse import ArgumentParser, Namespace
from typing import Dict, Union, List, Tuple, Any
from pytorch_lightning.utilities import parsing


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
    """
    Create an instance from CLI arguments.
    Eventually use variables from OS environment which are defined as "PL_<CLASS-NAME>_<CLASS_ARGUMENT_NAME>"

    Args:
        args: The parser or namespace to take arguments from. Only known arguments will be
            parsed and passed to the :class:`Trainer`.
        **kwargs: Additional keyword arguments that may override ones in the parser or namespace.
            These must be valid Trainer arguments.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> parser = ArgumentParser(add_help=False)
        >>> parser = Trainer.add_argparse_args(parser)
        >>> parser.add_argument('--my_custom_arg', default='something')  # doctest: +SKIP
        >>> args = Trainer.parse_argparser(parser.parse_args(""))
        >>> trainer = Trainer.from_argparse_args(args, logger=False)
    """
    if isinstance(args, ArgumentParser):
        args = cls.parse_argparser(args)

    params = vars(args)

    # we only want to pass in valid Trainer args, the rest may be user specific
    valid_kwargs = inspect.signature(cls.__init__).parameters
    trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
    trainer_kwargs.update(**kwargs)

    return cls(**trainer_kwargs)


def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
    """Parse CLI arguments, required for custom bool types."""
    args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser

    types_default = {
        arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls)
    }

    modified_args = {}
    for k, v in vars(args).items():
        if k in types_default and v is None:
            # We need to figure out if the None is due to using nargs="?" or if it comes from the default value
            arg_types, arg_default = types_default[k]
            if bool in arg_types and isinstance(arg_default, bool):
                # Value has been passed as a flag => It is currently None, so we need to set it to True
                # We always set to True, regardless of the default value.
                # Users must pass False directly, but when passing nothing True is assumed.
                # i.e. the only way to disable something that defaults to True is to use the long form:
                # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
                # which then becomes True here.

                v = True

        modified_args[k] = v
    return Namespace(**modified_args)


def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
    """Parse environment arguments if they are defined.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> parse_env_variables(Trainer)
        Namespace()
        >>> import os
        >>> os.environ["PL_TRAINER_GPUS"] = '42'
        >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
        >>> parse_env_variables(Trainer)
        Namespace(gpus=42)
        >>> del os.environ["PL_TRAINER_GPUS"]
    """
    cls_arg_defaults = get_init_arguments_and_types(cls)

    env_args = {}
    for arg_name, _, _ in cls_arg_defaults:
        env = template % {'cls_name': cls.__name__.upper(), 'cls_argument': arg_name.upper()}
        val = os.environ.get(env)
        if not (val is None or val == ''):
            try:  # converting to native types like int/float/bool
                val = eval(val)
            except Exception:
                pass
            env_args[arg_name] = val
    return Namespace(**env_args)


def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
    r"""Scans the Trainer signature and returns argument names, types and default values.

    Returns:
        List with tuples of 3 values:
        (argument name, set with argument types, argument default value).

    Examples:
        >>> from pytorch_lightning import Trainer
        >>> args = get_init_arguments_and_types(Trainer)
    """
    trainer_default_params = inspect.signature(cls).parameters
    name_type_default = []
    for arg in trainer_default_params:
        arg_type = trainer_default_params[arg].annotation
        arg_default = trainer_default_params[arg].default
        try:
            arg_types = tuple(arg_type.__args__)
        except AttributeError:
            arg_types = (arg_type,)

        name_type_default.append((arg, arg_types, arg_default))

    return name_type_default


def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
    r"""Extends existing argparse by default `Trainer` attributes.

    Args:
        parent_parser:
            The custom cli arguments parser, which will be extended by
            the Trainer default arguments.

    Only arguments of the allowed types (str, float, int, bool) will
    extend the `parent_parser`.

    Examples:
        >>> import argparse
        >>> from pytorch_lightning import Trainer
        >>> parser = argparse.ArgumentParser()
        >>> parser = Trainer.add_argparse_args(parser)
        >>> args = parser.parse_args([])
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False,)

    blacklist = ['kwargs']
    depr_arg_names = cls.get_deprecated_arg_names() + blacklist

    allowed_types = (str, int, float, bool)

    args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__)
    for arg, arg_types, arg_default in (
        at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
    ):
        arg_types = [at for at in allowed_types if at in arg_types]
        if not arg_types:
            # skip argument with not supported type
            continue
        arg_kwargs = {}
        if bool in arg_types:
            arg_kwargs.update(nargs="?", const=True)
            # if the only arg type is bool
            if len(arg_types) == 1:
                use_type = parsing.str_to_bool
            elif str in arg_types:
                use_type = parsing.str_to_bool_or_str
            else:
                # filter out the bool as we need to use more general
                use_type = [at for at in arg_types if at is not bool][0]
        else:
            use_type = arg_types[0]

        if arg == 'gpus' or arg == 'tpu_cores':
            use_type = _gpus_allowed_type
            arg_default = _gpus_arg_default

        # hack for types in (int, float)
        if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
            use_type = _int_or_float_type

        # hack for track_grad_norm
        if arg == 'track_grad_norm':
            use_type = float

        parser.add_argument(
            f'--{arg}',
            dest=arg,
            default=arg_default,
            type=use_type,
            help=args_help.get(arg),
            **arg_kwargs,
        )

    return parser


def parse_args_from_docstring(docstring: str) -> Dict[str, str]:
    arg_block_indent = None
    current_arg = None
    parsed = {}
    for line in docstring.split("\n"):
        stripped = line.lstrip()
        if not stripped:
            continue
        line_indent = len(line) - len(stripped)
        if stripped.startswith(('Args:', 'Arguments:', 'Parameters:')):
            arg_block_indent = line_indent + 4
        elif arg_block_indent is None:
            continue
        elif line_indent < arg_block_indent:
            break
        elif line_indent == arg_block_indent:
            current_arg, arg_description = stripped.split(':', maxsplit=1)
            parsed[current_arg] = arg_description.lstrip()
        elif line_indent > arg_block_indent:
            parsed[current_arg] += f' {stripped}'
    return parsed


def _gpus_allowed_type(x) -> Union[int, str]:
    if ',' in x:
        return str(x)
    else:
        return int(x)


def _gpus_arg_default(x) -> Union[int, str]:
    if ',' in x:
        return str(x)
    else:
        return int(x)


def _int_or_float_type(x) -> Union[int, float]:
    if '.' in str(x):
        return float(x)
    else:
        return int(x)
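
Taken together, the new pytorch_lightning/utilities/argparse.py keeps the Trainer's CLI integration behind four free functions, which trainer/properties.py now calls directly. A minimal usage sketch, mirroring the doctests above rather than adding anything new:

import os
from argparse import ArgumentParser

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.argparse import parse_env_variables

# Expose the supported Trainer.__init__ arguments as --flags, then
# build a Trainer from the parsed namespace.
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parser)
args = Trainer.parse_argparser(parser.parse_args([]))
trainer = Trainer.from_argparse_args(args, logger=False)

# Per the parse_env_variables doctest, matching "PL_<CLASS-NAME>_<ARGUMENT-NAME>"
# environment variables are parsed into a Namespace as well:
os.environ["PL_TRAINER_GPUS"] = "0"
print(parse_env_variables(Trainer))  # Namespace(gpus=0)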