Skip to content

Commit

Permalink
fix gpus default for Trainer.add_argparse_args (#6898)
Browse files Browse the repository at this point in the history
(cherry picked from commit 9c9e2a0)
  • Loading branch information
awaelchli authored and SeanNaren committed Apr 13, 2021
1 parent 94eaef6 commit f895e9f
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 47 deletions.
23 changes: 9 additions & 14 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [UnReleased] - 2021-MM-DD
## [1.3.0] - 2021-MM-DD

### Added

- Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))


- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))

Expand Down Expand Up @@ -81,8 +83,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))


- Added support for DDP communication hooks ([#6736](https://github.com/PyTorchLightning/pytorch-lightning/issues/6736))


- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677))


Expand Down Expand Up @@ -111,6 +115,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))


- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))


Expand Down Expand Up @@ -221,19 +228,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))


## [1.2.8] - 2021-04-13


### Changed


### Removed


### Fixed


- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))


## [1.2.7] - 2021-04-06
Expand Down
8 changes: 0 additions & 8 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:

if arg == 'gpus' or arg == 'tpu_cores':
use_type = _gpus_allowed_type
arg_default = _gpus_arg_default

# hack for types in (int, float)
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
Expand Down Expand Up @@ -238,13 +237,6 @@ def _gpus_allowed_type(x) -> Union[int, str]:
return int(x)


def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)


def _int_or_float_type(x) -> Union[int, float]:
if '.' in str(x):
return float(x)
Expand Down
9 changes: 0 additions & 9 deletions pytorch_lightning/utilities/device_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
If no GPUs are available but the value of gpus variable indicates request for GPUs
then a MisconfigurationException is raised.
"""

# nothing was passed into the GPUs argument
if callable(gpus):
return None

# Check that gpus param is None, Int, String or List
_check_data_type(gpus)

Expand Down Expand Up @@ -97,10 +92,6 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
Returns:
a list of tpu_cores to be used or ``None`` if no TPU cores were requested
"""

if callable(tpu_cores):
return None

_check_data_type(tpu_cores)

if isinstance(tpu_cores, str):
Expand Down
4 changes: 1 addition & 3 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import os
import pickle
import types
from argparse import ArgumentParser
from unittest import mock

Expand Down Expand Up @@ -172,11 +171,10 @@ def wrapper_something():
params.wrapper_something_wo_name = lambda: lambda: '1'
params.wrapper_something = wrapper_something

assert isinstance(params.gpus, types.FunctionType)
params = WandbLogger._convert_params(params)
params = WandbLogger._flatten_dict(params)
params = WandbLogger._sanitize_callable_params(params)
assert params["gpus"] == '_gpus_arg_default'
assert params["gpus"] == "None"
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"
Expand Down
24 changes: 12 additions & 12 deletions tests/trainer/test_trainer_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import argparse
from tests.helpers.runif import RunIf


@mock.patch('argparse.ArgumentParser.parse_args')
Expand All @@ -45,7 +46,7 @@ def test_default_args(mock_argparse, tmpdir):


@pytest.mark.parametrize('cli_args', [['--accumulate_grad_batches=22'], ['--weights_save_path=./'], []])
def test_add_argparse_args_redefined(cli_args):
def test_add_argparse_args_redefined(cli_args: list):
"""Redefines some default Trainer arguments via the cli and
tests the Trainer initialization correctness.
"""
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_get_init_arguments_and_types():


@pytest.mark.parametrize('cli_args', [['--callbacks=1', '--logger'], ['--foo', '--bar=1']])
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
def test_add_argparse_args_redefined_error(cli_args: list, monkeypatch):
"""Asserts thar an error raised in case of passing not default cli arguments."""

class _UnkArgError(Exception):
Expand Down Expand Up @@ -171,27 +172,26 @@ def test_argparse_args_parsing(cli_args, expected):
assert Trainer.from_argparse_args(args)


@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [
pytest.param('--gpus 1', [0]),
pytest.param('--gpus 0,', [0]),
@pytest.mark.parametrize(['cli_args', 'expected_parsed', 'expected_device_ids'], [
pytest.param('', None, None),
pytest.param('--gpus 1', 1, [0]),
pytest.param('--gpus 0,', '0,', [0]),
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_argparse_args_parsing_gpus(cli_args, expected_gpu):
@RunIf(min_gpus=1)
def test_argparse_args_parsing_gpus(cli_args, expected_parsed, expected_device_ids):
"""Test multi type argument with bool."""
cli_args = cli_args.split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)

assert args.gpus == expected_parsed
trainer = Trainer.from_argparse_args(args)
assert trainer.data_parallel_device_ids == expected_gpu
assert trainer.data_parallel_device_ids == expected_device_ids


@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec"
)
@RunIf(min_python="3.7.0")
@pytest.mark.parametrize(['cli_args', 'extra_args'], [
pytest.param({}, {}),
pytest.param({'logger': False}, {}),
Expand Down
12 changes: 11 additions & 1 deletion tests/utilities/test_argparse_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pytorch_lightning.utilities.argparse import parse_args_from_docstring
from pytorch_lightning.utilities.argparse import parse_args_from_docstring, _gpus_allowed_type, _int_or_float_type


def test_parse_args_from_docstring_normal():
Expand Down Expand Up @@ -48,3 +48,13 @@ def test_parse_args_from_docstring_empty():
"""
)
assert len(args_help.keys()) == 0


def test_gpus_allowed_type():
assert _gpus_allowed_type('1,2') == '1,2'
assert _gpus_allowed_type('1') == 1


def test_int_or_float_type():
assert isinstance(_int_or_float_type('0.0'), float)
assert isinstance(_int_or_float_type('0'), int)

0 comments on commit f895e9f

Please sign in to comment.