Skip to content

Commit

Permalink
Docs and Tests for "gpus" Trainer Argument (#593)
Browse files Browse the repository at this point in the history
* add table for gpus argument

* fix typo in error message

* tests for supported values

* tests for unsupported values

* fix typo

* add table for gpus argument

* fix typo in error message

* tests for supported values

* tests for unsupported values

* fix typo

* fix typo list->str

* fix travis warning "line too long"
  • Loading branch information
Adrian Wälchli authored and williamFalcon committed Dec 7, 2019
1 parent cc65f39 commit f7e1040
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
34 changes: 33 additions & 1 deletion pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,38 @@
# RECOMMENDED use DistributedDataParallel
trainer = Trainer(gpus=8, distributed_backend='ddp')
Custom device selection
-----------------------
Specific GPUs can also be selected with a list of indices or a string containing
a comma-separated list of GPU ids.
The table below lists examples of possible input formats and how they are interpreted by Lightning.
Note in particular the difference between `gpus=0` (run on CPU) and `gpus=[0]` or `gpus="0"` (both select GPU 0).
+---------------+-----------+---------------------+---------------------------------+
| `gpus`        | Type      | Parsed              | Meaning                         |
+===============+===========+=====================+=================================+
| None          | NoneType  | None                | CPU                             |
+---------------+-----------+---------------------+---------------------------------+
| 0             | int       | None                | CPU                             |
+---------------+-----------+---------------------+---------------------------------+
| 3             | int       | [0, 1, 2]           | first 3 GPUs                    |
+---------------+-----------+---------------------+---------------------------------+
| -1            | int       | [0, 1, 2, ...]      | all available GPUs              |
+---------------+-----------+---------------------+---------------------------------+
| [0]           | list      | [0]                 | GPU 0                           |
+---------------+-----------+---------------------+---------------------------------+
| [1, 3]        | list      | [1, 3]              | GPUs 1 and 3                    |
+---------------+-----------+---------------------+---------------------------------+
| "0"           | str       | [0]                 | GPU 0                           |
+---------------+-----------+---------------------+---------------------------------+
| "3"           | str       | [3]                 | GPU 3                           |
+---------------+-----------+---------------------+---------------------------------+
| "1, 3"        | str       | [1, 3]              | GPUs 1 and 3                    |
+---------------+-----------+---------------------+---------------------------------+
| "-1"          | str       | [0, 1, 2, ...]      | all available GPUs              |
+---------------+-----------+---------------------+---------------------------------+
Multi-node
----------
Expand Down Expand Up @@ -531,7 +563,7 @@ def parse_gpu_ids(gpus):
gpus = sanitize_gpu_ids(gpus)

if not gpus:
raise MisconfigurationException("GPUs requested but non are available.")
raise MisconfigurationException("GPUs requested but none are available.")
return gpus


Expand Down
34 changes: 33 additions & 1 deletion tests/test_gpu_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,14 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu):
pytest.param(None, None),
pytest.param(0, None),
pytest.param(1, [0]),
pytest.param(3, [0, 1, 2]),
pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"),
pytest.param([0], [0]),
pytest.param([1, 3], [1, 3]),
pytest.param('0', [0]),
pytest.param('3', [3]),
pytest.param('1, 3', [1, 3]),
pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"),
pytest.param(3, [0, 1, 2]),
]


Expand All @@ -363,6 +368,33 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
assert parse_gpu_ids(gpus) == expected_gpu_ids


# Unsupported ``gpus`` values: the parametrized test consuming this list expects
# parse_gpu_ids to raise MisconfigurationException for each of them.
# Every value is wrapped in its own single-argument pytest.param.
test_parse_gpu_invalid_inputs_data = [
    pytest.param(unsupported_value)
    for unsupported_value in (
        0.1,
        -2,
        False,
        [],
        [-1],
        [None],
        ['0'],
        (0, 1),
    )
]


@pytest.mark.gpus_param_tests
@pytest.mark.parametrize(['gpus'], test_parse_gpu_invalid_inputs_data)
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):
    """Check that parse_gpu_ids rejects each unsupported ``gpus`` value.

    Runs once per entry of ``test_parse_gpu_invalid_inputs_data`` and passes
    only if the call raises MisconfigurationException.
    """
    expected_error = MisconfigurationException
    with pytest.raises(expected_error):
        parse_gpu_ids(gpus)


@pytest.mark.gpus_param_tests
@pytest.mark.parametrize("gpus", [''])
def test_parse_gpu_fail_on_empty_string(mocked_device_count, gpus):
    """Pin the exception type parse_gpu_ids raises for an empty string."""
    # This currently results in a ValueError instead of MisconfigurationException
    expected_error = ValueError
    with pytest.raises(expected_error):
        parse_gpu_ids(gpus)


@pytest.mark.gpus_param_tests
@pytest.mark.parametrize("gpus", [[1, 2, 19], -1, '-1'])
def test_parse_gpu_fail_on_non_existant_id(mocked_device_count_0, gpus):
Expand Down

0 comments on commit f7e1040

Please sign in to comment.