From f7e1040236e088f4a0b5c725461cdf0eed80b068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 7 Dec 2019 14:48:45 +0100 Subject: [PATCH] Docs and Tests for "gpus" Trainer Argument (#593) * add table for gpus argument * fix typo in error message * tests for supported values * tests for unsupported values * fix typo * add table for gpus argument * fix typo in error message * tests for supported values * tests for unsupported values * fix typo * fix typo list->str * fix travis warning "line too long" --- pytorch_lightning/trainer/distrib_parts.py | 34 +++++++++++++++++++++- tests/test_gpu_models.py | 34 +++++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 4c1f3cdc9a4cf..4681d11b77121 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -164,6 +164,38 @@ # RECOMMENDED use DistributedDataParallel trainer = Trainer(gpus=8, distributed_backend='ddp') +Custom device selection +----------------------- + +The GPUs to use can also be selected with a list of indices or a string containing +a comma-separated list of GPU ids. +The table below lists examples of possible input formats and how they are interpreted by Lightning. +Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`. 
+ ++---------------+-----------+---------------------+---------------------------------+ +| `gpus` | Type | Parsed | Meaning | ++===============+===========+=====================+=================================+ +| None | NoneType | None | CPU | ++---------------+-----------+---------------------+---------------------------------+ +| 0 | int | None | CPU | ++---------------+-----------+---------------------+---------------------------------+ +| 3 | int | [0, 1, 2] | first 3 GPUs | ++---------------+-----------+---------------------+---------------------------------+ +| -1 | int | [0, 1, 2, ...] | all available GPUs | ++---------------+-----------+---------------------+---------------------------------+ +| [0] | list | [0] | GPU 0 | ++---------------+-----------+---------------------+---------------------------------+ +| [1, 3] | list | [1, 3] | GPUs 1 and 3 | ++---------------+-----------+---------------------+---------------------------------+ +| "0" | str | [0] | GPU 0 | ++---------------+-----------+---------------------+---------------------------------+ +| "3" | str | [3] | GPU 3 | ++---------------+-----------+---------------------+---------------------------------+ +| "1, 3" | str | [1, 3] | GPUs 1 and 3 | ++---------------+-----------+---------------------+---------------------------------+ +| "-1" | str | [0, 1, 2, ...] 
| all available GPUs | ++---------------+-----------+---------------------+---------------------------------+ + Multi-node ---------- @@ -531,7 +563,7 @@ def parse_gpu_ids(gpus): gpus = sanitize_gpu_ids(gpus) if not gpus: - raise MisconfigurationException("GPUs requested but non are available.") + raise MisconfigurationException("GPUs requested but none are available.") return gpus diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py index 68b1fe2f2a5e6..900ba6c3f1e23 100644 --- a/tests/test_gpu_models.py +++ b/tests/test_gpu_models.py @@ -351,9 +351,14 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu): pytest.param(None, None), pytest.param(0, None), pytest.param(1, [0]), + pytest.param(3, [0, 1, 2]), pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"), + pytest.param([0], [0]), + pytest.param([1, 3], [1, 3]), + pytest.param('0', [0]), + pytest.param('3', [3]), + pytest.param('1, 3', [1, 3]), pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"), - pytest.param(3, [0, 1, 2]), ] @@ -363,6 +368,33 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids): assert parse_gpu_ids(gpus) == expected_gpu_ids +test_parse_gpu_invalid_inputs_data = [ + pytest.param(0.1), + pytest.param(-2), + pytest.param(False), + pytest.param([]), + pytest.param([-1]), + pytest.param([None]), + pytest.param(['0']), + pytest.param((0, 1)), +] + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(['gpus'], test_parse_gpu_invalid_inputs_data) +def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus): + with pytest.raises(MisconfigurationException): + parse_gpu_ids(gpus) + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize("gpus", ['']) +def test_parse_gpu_fail_on_empty_string(mocked_device_count, gpus): + # This currently results in a ValueError instead of MisconfigurationException + with pytest.raises(ValueError): + parse_gpu_ids(gpus) + + @pytest.mark.gpus_param_tests 
@pytest.mark.parametrize("gpus", [[1, 2, 19], -1, '-1']) def test_parse_gpu_fail_on_non_existant_id(mocked_device_count_0, gpus):