Skip to content

Commit

Permalink
Clean utilities/argparse and add missing tests (#6607)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Mar 22, 2021
1 parent 870247f commit 853523e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 9 deletions.
10 changes: 2 additions & 8 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
# Value has been passed as a flag => It is currently None, so we need to set it to True
# We always set to True, regardless of the default value.
# Users must pass False directly, but when passing nothing True is assumed.
# i.e. the only way to disable somthing that defaults to True is to use the long form:
# i.e. the only way to disable something that defaults to True is to use the long form:
# "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
# which then becomes True here.

Expand Down Expand Up @@ -242,9 +242,6 @@ def add_argparse_args(
if arg == 'track_grad_norm':
use_type = float

if arg_default is inspect._empty:
arg_default = None

parser.add_argument(
f'--{arg}',
dest=arg,
Expand Down Expand Up @@ -291,10 +288,7 @@ def _gpus_allowed_type(x) -> Union[int, str]:


def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
return _gpus_allowed_type(x)


def _int_or_float_type(x) -> Union[int, float]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,51 @@
import io
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from typing import List
from unittest.mock import MagicMock

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.argparse import (
add_argparse_args,
from_argparse_args,
get_abbrev_qualified_cls_name,
parse_argparser,
parse_args_from_docstring,
_gpus_arg_default,
_int_or_float_type
)


class ArgparseExample:
def __init__(self, a: int = 0, b: str = '', c: bool = False):
self.a = a
self.b = b
self.c = c


def test_from_argparse_args():
args = Namespace(a=1, b='test', c=True, d='not valid')
my_instance = from_argparse_args(ArgparseExample, args)
assert my_instance.a == 1
assert my_instance.b == 'test'
assert my_instance.c

parser = ArgumentParser()
mock_trainer = MagicMock()
_ = from_argparse_args(mock_trainer, parser)
mock_trainer.parse_argparser.assert_called_once_with(parser)


def test_parse_argparser():
args = Namespace(a=1, b='test', c=None, d='not valid')
new_args = parse_argparser(ArgparseExample, args)
assert new_args.a == 1
assert new_args.b == 'test'
assert new_args.c
assert new_args.d == 'not valid'


def test_parse_args_from_docstring_normal():
args_help = parse_args_from_docstring(
"""Constrain image dataset
Expand Down Expand Up @@ -168,3 +202,13 @@ def test_add_argparse_args_no_argument_group():
args = parser.parse_args(fake_argv)
assert args.main_arg == "abc"
assert args.my_parameter == 2


def test_gpus_arg_default():
assert _gpus_arg_default('1,2') == '1,2'
assert _gpus_arg_default('1') == 1


def test_int_or_float_type():
assert isinstance(_int_or_float_type('0.0'), float)
assert isinstance(_int_or_float_type('0'), int)

0 comments on commit 853523e

Please sign in to comment.