# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
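"""Parsing utilities for the ``gpus`` and ``tpu_cores`` arguments of the Trainer."""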

from typing import Any, List, MutableSequence, Optional, Tuple, Union

import torch

from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def determine_root_gpu_device(gpus: List[int]) -> Optional[int]:
    """
    Args:
        gpus: non-empty list of ints representing which GPUs to use

    Returns:
        designated root GPU device id

    Raises:
        TypeError:
            If ``gpus`` is not a list
        AssertionError:
            If GPU list is empty
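
    Example::

        # a minimal illustration; no CUDA is required, since this function
        # only selects the first id in the list
        >>> determine_root_gpu_device([1, 2])
        1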
"""
if gpus is None:
return None
if not isinstance(gpus, list):
raise TypeError("gpus should be a list")
assert len(gpus) > 0, "gpus should be a non empty list"
# set root gpu
root_gpu = gpus[0]
return root_gpu


def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
    """
    Parses the GPU ids given in the format as accepted by the
    :class:`~pytorch_lightning.trainer.Trainer`.

    Args:
        gpus: An int -1 or the string '-1' indicates that all available GPUs should be used.
            A list of unique ints or a string containing a comma-separated list of unique
            integers indicates specific GPUs to use.
            An int 0 means that no GPUs should be used.
            Any int N > 0 indicates that GPUs [0..N) should be used.

    Returns:
        a list of GPU ids to be used or ``None`` if no GPUs were requested

    Raises:
        MisconfigurationException:
            If no GPUs are available but the value of ``gpus`` indicates that GPUs were requested
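
    Example::

        # illustrative outcomes only; real results depend on the machine.
        # Assuming ``torch.cuda.device_count() == 2``:
        parse_gpu_ids(None)    # -> None (no GPUs requested)
        parse_gpu_ids(0)       # -> None (no GPUs requested)
        parse_gpu_ids(2)       # -> [0, 1]
        parse_gpu_ids("-1")    # -> [0, 1] (all available GPUs)
        parse_gpu_ids("0, 1")  # -> [0, 1]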
"""
# Check that gpus param is None, Int, String or List
_check_data_type(gpus)
# Handle the case when no gpus are requested
if gpus is None or isinstance(gpus, int) and gpus == 0 or str(gpus).strip() == "0":
return None
# We know user requested GPUs therefore if some of the
# requested GPUs are not available an exception is thrown.
gpus = _normalize_parse_gpu_string_input(gpus)
gpus = _normalize_parse_gpu_input_to_list(gpus)
if not gpus:
raise MisconfigurationException("GPUs requested but none are available.")
if TorchElasticEnvironment.is_using_torchelastic() and len(gpus) != 1 and len(_get_all_available_gpus()) == 1:
# omit sanity check on torchelastic as by default shows one visible GPU per process
return gpus
# Check that gpus are unique. Duplicate gpus are not supported by the backend.
_check_unique(gpus)
return _sanitize_gpu_ids(gpus)


def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
    """
    Parses the tpu_cores given in the format as accepted by the
    :class:`~pytorch_lightning.trainer.Trainer`.

    Args:
        tpu_cores: An int 1 or the string '1' indicates that 1 core with multi-processing should be used.
            An int 8 or the string '8' indicates that all 8 cores with multi-processing should be used.
            A single-element list of ints or a string containing a single integer
            indicates the specific TPU core to use.

    Returns:
        a list of tpu_cores to be used or ``None`` if no TPU cores were requested

    Raises:
        MisconfigurationException:
            If TPU cores aren't 1 or 8 cores, or no TPU devices are found
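
    Example::

        # illustrative; assumes a TPU is attached (``_TPU_AVAILABLE`` is True),
        # otherwise a MisconfigurationException is raised
        parse_tpu_cores(8)    # -> 8   (use all 8 cores)
        parse_tpu_cores("1")  # -> 1   (use 1 core)
        parse_tpu_cores("5")  # -> [5] (use the core with index 5)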
"""
_check_data_type(tpu_cores)
if isinstance(tpu_cores, str):
tpu_cores = _parse_tpu_cores_str(tpu_cores.strip())
if not _tpu_cores_valid(tpu_cores):
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
if tpu_cores is not None and not _TPU_AVAILABLE:
raise MisconfigurationException("No TPU devices were found.")
return tpu_cores


def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
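    """Casts a string input to its parsed form: "-1" stays the int -1 (all GPUs),
    a comma-separated string such as "0,2" becomes ``[0, 2]`` and any other
    numeric string becomes a single int; non-string inputs pass through unchanged.
    """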
    if not isinstance(s, str):
        return s

    if s == "-1":
        return -1

    if "," in s:
        return [int(x.strip()) for x in s.split(",") if len(x) > 0]

    return int(s.strip())


def _sanitize_gpu_ids(gpus: List[int]) -> List[int]:
    """Checks that each of the GPUs in the list is actually available. Raises a
    MisconfigurationException if any of the GPUs is not available.

    Args:
        gpus: list of ints corresponding to GPU indices

    Returns:
        unmodified ``gpus`` variable

    Raises:
        MisconfigurationException:
            If machine has fewer available GPUs than requested.
    """
    all_available_gpus = _get_all_available_gpus()
    for gpu in gpus:
        if gpu not in all_available_gpus:
            raise MisconfigurationException(
                f"You requested GPUs: {gpus}\n But your machine only has: {all_available_gpus}"
            )
    return gpus


def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int], Tuple[int, ...]]) -> Optional[List[int]]:
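    """Casts an already type-checked ``gpus`` argument to a list of GPU ids:
    sequences are copied into a list, ``-1`` expands to all available GPUs,
    ``0`` becomes ``None`` and any other int ``N`` becomes ``[0, ..., N - 1]``.
    """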
    assert gpus is not None

    if isinstance(gpus, (MutableSequence, tuple)):
        return list(gpus)

    # must be an int
    if not gpus:  # gpus == 0
        return None

    if gpus == -1:
        return _get_all_available_gpus()

    return list(range(gpus))


def _get_all_available_gpus() -> List[int]:
    """
    Returns:
        a list of all available GPUs
    """
    return list(range(torch.cuda.device_count()))


def _check_unique(device_ids: List[int]) -> None:
    """Checks that the device_ids are unique.

    Args:
        device_ids: list of ints corresponding to GPU indices

    Raises:
        MisconfigurationException:
            If ``device_ids`` of GPUs aren't unique
    """
    if len(device_ids) != len(set(device_ids)):
        raise MisconfigurationException("Device IDs (GPU) must be unique.")


def _check_data_type(device_ids: Any) -> None:
    """Checks that the device_ids argument is one of None, int, string or a sequence of ints.
    Raises a MisconfigurationException otherwise.

    Args:
        device_ids: gpus/tpu_cores parameter as passed to the Trainer

    Raises:
        MisconfigurationException:
            If ``device_ids`` of GPUs/TPUs aren't ``int``, ``str``, a sequence of ``int`` or ``None``
    """
    if device_ids is not None and (
        not isinstance(device_ids, (int, str, MutableSequence, tuple)) or isinstance(device_ids, bool)
    ):
        raise MisconfigurationException("Device IDs (GPU/TPU) must be an int, a string, a sequence of ints or None.")


def _tpu_cores_valid(tpu_cores: Any) -> bool:
    # allow 1 or 8 cores (or no TPU request at all)
    if tpu_cores in (1, 8, None):
        return True

    # allow picking a single index out of the 8 cores
    if isinstance(tpu_cores, (list, tuple, set)):
        has_1_tpu_idx = len(tpu_cores) == 1
        # short-circuit on the length check so that an empty sequence is
        # reported as invalid instead of raising an IndexError
        is_valid_tpu_core_choice = has_1_tpu_idx and 1 <= list(tpu_cores)[0] <= 8
        return is_valid_tpu_core_choice

    return False


def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]:
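    """Parses a TPU spec string: "1" and "8" become the ints 1 and 8; anything
    else is parsed as a comma-separated list of core indices, e.g. "5" -> [5].
    """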
if tpu_cores in ("1", "8"):
return int(tpu_cores)
return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0]