-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
import sys | ||
from distutils.version import LooseVersion | ||
from typing import Optional | ||
|
||
import pytest | ||
import torch | ||
from pkg_resources import get_distribution | ||
|
||
from pytorch_lightning.utilities import ( | ||
_APEX_AVAILABLE, | ||
_DEEPSPEED_AVAILABLE, | ||
_FAIRSCALE_AVAILABLE, | ||
_FAIRSCALE_PIPE_AVAILABLE, | ||
_HOROVOD_AVAILABLE, | ||
_NATIVE_AMP_AVAILABLE, | ||
_RPC_AVAILABLE, | ||
_TORCH_QUANTIZE_AVAILABLE, | ||
_TPU_AVAILABLE, | ||
) | ||
|
||
try: | ||
from horovod.common.util import nccl_built | ||
nccl_built() | ||
except (ImportError, ModuleNotFoundError, AttributeError): | ||
_HOROVOD_NCCL_AVAILABLE = False | ||
finally: | ||
_HOROVOD_NCCL_AVAILABLE = True | ||
|
||
|
||
class RunIf: | ||
""" | ||
RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark:: | ||
@RunIf(min_torch="0.0") | ||
@pytest.mark.parametrize("arg1", [1, 2.0]) | ||
def test_wrapper(arg1): | ||
assert arg1 > 0.0 | ||
""" | ||
|
||
def __new__( | ||
self, | ||
*args, | ||
min_gpus: int = 0, | ||
min_torch: Optional[str] = None, | ||
max_torch: Optional[str] = None, | ||
min_python: Optional[str] = None, | ||
quantization: bool = False, | ||
amp_apex: bool = False, | ||
amp_native: bool = False, | ||
tpu: bool = False, | ||
horovod: bool = False, | ||
horovod_nccl: bool = False, | ||
skip_windows: bool = False, | ||
special: bool = False, | ||
rpc: bool = False, | ||
fairscale: bool = False, | ||
fairscale_pipe: bool = False, | ||
deepspeed: bool = False, | ||
**kwargs | ||
): | ||
""" | ||
Args: | ||
args: native pytest.mark.skipif arguments | ||
min_gpus: min number of gpus required to run test | ||
min_torch: minimum pytorch version to run test | ||
max_torch: maximum pytorch version to run test | ||
min_python: minimum python version required to run test | ||
quantization: if `torch.quantization` package is required to run test | ||
amp_apex: NVIDIA Apex is installed | ||
amp_native: if native PyTorch native AMP is supported | ||
tpu: if TPU is available | ||
horovod: if Horovod is installed | ||
horovod_nccl: if Horovod is installed with NCCL support | ||
skip_windows: skip test for Windows platform (typically fo some limited torch functionality) | ||
special: running in special mode, outside pytest suit | ||
rpc: requires Remote Procedure Call (RPC) | ||
fairscale: if `fairscale` module is required to run the test | ||
deepspeed: if `deepspeed` module is required to run the test | ||
kwargs: native pytest.mark.skipif keyword arguments | ||
""" | ||
conditions = [] | ||
reasons = [] | ||
|
||
if min_gpus: | ||
conditions.append(torch.cuda.device_count() < min_gpus) | ||
reasons.append(f"GPUs>={min_gpus}") | ||
|
||
if min_torch: | ||
torch_version = LooseVersion(get_distribution("torch").version) | ||
conditions.append(torch_version < LooseVersion(min_torch)) | ||
reasons.append(f"torch>={min_torch}") | ||
|
||
if max_torch: | ||
torch_version = LooseVersion(get_distribution("torch").version) | ||
conditions.append(torch_version >= LooseVersion(max_torch)) | ||
reasons.append(f"torch<{max_torch}") | ||
|
||
if min_python: | ||
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" | ||
conditions.append(py_version < LooseVersion(min_python)) | ||
reasons.append(f"python>={min_python}") | ||
|
||
if quantization: | ||
_miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines | ||
conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default) | ||
reasons.append("PyTorch quantization") | ||
|
||
if amp_native: | ||
conditions.append(not _NATIVE_AMP_AVAILABLE) | ||
reasons.append("native AMP") | ||
|
||
if amp_apex: | ||
conditions.append(not _APEX_AVAILABLE) | ||
reasons.append("NVIDIA Apex") | ||
|
||
if skip_windows: | ||
conditions.append(sys.platform == "win32") | ||
reasons.append("unimplemented on Windows") | ||
|
||
if tpu: | ||
conditions.append(not _TPU_AVAILABLE) | ||
reasons.append("TPU") | ||
|
||
if horovod: | ||
conditions.append(not _HOROVOD_AVAILABLE) | ||
reasons.append("Horovod") | ||
|
||
if horovod_nccl: | ||
conditions.append(not _HOROVOD_NCCL_AVAILABLE) | ||
reasons.append("Horovod with NCCL") | ||
|
||
if special: | ||
env_flag = os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') | ||
conditions.append(env_flag != '1') | ||
reasons.append("Special execution") | ||
|
||
if rpc: | ||
conditions.append(not _RPC_AVAILABLE) | ||
reasons.append("RPC") | ||
|
||
if fairscale: | ||
conditions.append(not _FAIRSCALE_AVAILABLE) | ||
reasons.append("Fairscale") | ||
|
||
if fairscale_pipe: | ||
conditions.append(not _FAIRSCALE_PIPE_AVAILABLE) | ||
reasons.append("Fairscale Pipe") | ||
|
||
if deepspeed: | ||
conditions.append(not _DEEPSPEED_AVAILABLE) | ||
reasons.append("Deepspeed") | ||
|
||
reasons = [rs for cond, rs in zip(conditions, reasons) if cond] | ||
return pytest.mark.skipif( | ||
*args, | ||
condition=any(conditions), | ||
reason=f"Requires: [{' + '.join(reasons)}]", | ||
**kwargs, | ||
) | ||
|
||
|
||
@RunIf(min_torch="99") | ||
def test_always_skip(): | ||
exit(1) | ||
|
||
|
||
@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0]) | ||
@RunIf(min_torch="0.0") | ||
def test_wrapper(arg1: float): | ||
assert arg1 > 0.0 |