fix: set nb to the total number of devices when nb is -1 #4209

Merged
13 commits merged on Oct 29, 2020
2 changes: 2 additions & 0 deletions docs/source/multi_gpu.rst
@@ -206,6 +206,8 @@ Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`.
`auto_select_gpus=True` will automatically help you find `k` GPUs that are not
occupied by other processes. This is especially useful when GPUs are configured
to be in "exclusive mode", where only one process at a time can access them.
For more details see the :ref:`Trainer guide <trainer>`.
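To make the "exclusive mode" behaviour concrete, here is a hypothetical sketch of
the probing idea (the helper name `find_unoccupied_gpu` is illustrative, not
Lightning's actual implementation): attempt a tiny allocation on each visible
device and keep the first one that succeeds.

    import torch

    def find_unoccupied_gpu() -> int:
        # Probe each visible CUDA device; in exclusive mode an allocation
        # on an occupied GPU raises a RuntimeError, so we skip that device.
        for idx in range(torch.cuda.device_count()):
            try:
                torch.ones(1).cuda(idx)
                return idx
            except RuntimeError:
                continue
        raise RuntimeError("No unoccupied GPU found.")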


Remove CUDA flags
^^^^^^^^^^^^^^^^^
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -381,6 +381,12 @@ def training_step(self, batch, batch_idx, optimizer_idx):
# enable auto selection (will find two available GPUs on the system)
trainer = Trainer(gpus=2, auto_select_gpus=True)

# specify all GPUs regardless of their availability
Trainer(gpus=-1, auto_select_gpus=False)

# specify all available GPUs (if only one GPU is unoccupied, uses that one)
Trainer(gpus=-1, auto_select_gpus=True)

auto_lr_find
^^^^^^^^^^^^

10 changes: 10 additions & 0 deletions pytorch_lightning/tuner/auto_gpu_select.py
@@ -13,8 +13,18 @@
# limitations under the License.
import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def pick_multiple_gpus(nb):
    if nb == 0:
        raise MisconfigurationException(
            "auto_select_gpus=True, gpus=0 is not a valid configuration."
            " Please select a valid number of GPU resources when using auto_select_gpus."
        )

    # nb == -1 means "use all GPUs": resolve it to the visible device count
    nb = torch.cuda.device_count() if nb == -1 else nb

    picked = []
    for _ in range(nb):
        picked.append(pick_single_gpu(exclude_gpus=picked))
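For context, the `pick_single_gpu` helper that this loop calls lives in the same
module and is untouched by this PR; it works roughly like this (a sketch inferred
from the surrounding code, not the verbatim source):

    def pick_single_gpu(exclude_gpus: list):
        for i in range(torch.cuda.device_count()):
            if i in exclude_gpus:
                continue
            # A minimal allocation fails with RuntimeError on GPUs that are
            # already occupied (e.g. held by another process in exclusive mode).
            try:
                torch.ones(1).cuda(i)
            except RuntimeError:
                continue
            return i
        raise RuntimeError("No GPUs available.")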
Empty file added tests/tuner/__init__.py
74 changes: 74 additions & 0 deletions tests/tuner/test_auto_gpu_select.py
@@ -0,0 +1,74 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="test requires a multi-GPU machine"
)
@pytest.mark.parametrize(
    ["auto_select_gpus", "gpus", "expected_error"],
    [
        (True, 0, MisconfigurationException),
        (True, -1, None),
        (False, 0, None),
        (False, -1, None),
    ],
)
def test_trainer_with_gpus_options_combination_at_available_gpus_env(
    auto_select_gpus, gpus, expected_error
):
    if expected_error:
        with pytest.raises(
            expected_error,
            match=re.escape(
                "auto_select_gpus=True, gpus=0 is not a valid configuration."
                " Please select a valid number of GPU resources when using auto_select_gpus."
            ),
        ):
            Trainer(auto_select_gpus=auto_select_gpus, gpus=gpus)
    else:
        Trainer(auto_select_gpus=auto_select_gpus, gpus=gpus)


@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="test requires a multi-GPU machine"
)
@pytest.mark.parametrize(
    ["nb", "expected_gpu_idxs", "expected_error"],
    [
        (0, [], MisconfigurationException),
        (-1, list(range(torch.cuda.device_count())), None),
        (1, [0], None),
    ],
)
def test_pick_multiple_gpus(nb, expected_gpu_idxs, expected_error):
    if expected_error:
        with pytest.raises(
            expected_error,
            match=re.escape(
                "auto_select_gpus=True, gpus=0 is not a valid configuration."
                " Please select a valid number of GPU resources when using auto_select_gpus."
            ),
        ):
            pick_multiple_gpus(nb)
    else:
        assert expected_gpu_idxs == pick_multiple_gpus(nb)
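As a usage sketch (assuming a machine with two unoccupied GPUs), the patched
function behaves as the parametrization above expects:

    from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus

    pick_multiple_gpus(-1)  # -> [0, 1]: -1 now resolves to all visible devices
    pick_multiple_gpus(1)   # -> [0]
    pick_multiple_gpus(0)   # raises MisconfigurationException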