diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 49cbaf3c6bdcf..46d88184ee190 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp # Value has been passed as a flag => It is currently None, so we need to set it to True # We always set to True, regardless of the default value. # Users must pass False directly, but when passing nothing True is assumed. - # i.e. the only way to disable somthing that defaults to True is to use the long form: + # i.e. the only way to disable something that defaults to True is to use the long form: # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, # which then becomes True here. @@ -242,9 +242,6 @@ def add_argparse_args( if arg == 'track_grad_norm': use_type = float - if arg_default is inspect._empty: - arg_default = None - parser.add_argument( f'--{arg}', dest=arg, @@ -291,10 +288,7 @@ def _gpus_allowed_type(x) -> Union[int, str]: def _gpus_arg_default(x) -> Union[int, str]: - if ',' in x: - return str(x) - else: - return int(x) + return _gpus_allowed_type(x) def _int_or_float_type(x) -> Union[int, float]: diff --git a/tests/utilities/test_argparse_utils.py b/tests/utilities/test_argparse.py similarity index 80% rename from tests/utilities/test_argparse_utils.py rename to tests/utilities/test_argparse.py index b2eac514941e6..fdf5ae0cafe65 100644 --- a/tests/utilities/test_argparse_utils.py +++ b/tests/utilities/test_argparse.py @@ -1,17 +1,51 @@ import io -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from typing import List +from unittest.mock import MagicMock import pytest from pytorch_lightning import Trainer from pytorch_lightning.utilities.argparse import ( add_argparse_args, + from_argparse_args, get_abbrev_qualified_cls_name, + parse_argparser, parse_args_from_docstring, + _gpus_arg_default, + _int_or_float_type ) +class ArgparseExample: + def __init__(self, a: int = 0, b: str = '', c: bool = False): + self.a = a + self.b = b + self.c = c + + +def test_from_argparse_args(): + args = Namespace(a=1, b='test', c=True, d='not valid') + my_instance = from_argparse_args(ArgparseExample, args) + assert my_instance.a == 1 + assert my_instance.b == 'test' + assert my_instance.c + + parser = ArgumentParser() + mock_trainer = MagicMock() + _ = from_argparse_args(mock_trainer, parser) + mock_trainer.parse_argparser.assert_called_once_with(parser) + + +def test_parse_argparser(): + args = Namespace(a=1, b='test', c=None, d='not valid') + new_args = parse_argparser(ArgparseExample, args) + assert new_args.a == 1 + assert new_args.b == 'test' + assert new_args.c + assert new_args.d == 'not valid' + + def test_parse_args_from_docstring_normal(): args_help = parse_args_from_docstring( """Constrain image dataset @@ -168,3 +202,13 @@ def test_add_argparse_args_no_argument_group(): args = parser.parse_args(fake_argv) assert args.main_arg == "abc" assert args.my_parameter == 2 + + +def test_gpus_arg_default(): + assert _gpus_arg_default('1,2') == '1,2' + assert _gpus_arg_default('1') == 1 + + +def test_int_or_float_type(): + assert isinstance(_int_or_float_type('0.0'), float) + assert isinstance(_int_or_float_type('0'), int)