Skip to content

Commit

Permalink
Add RunIf
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored and lexierule committed Mar 30, 2021
1 parent 9539e7a commit 832e771
Showing 1 changed file with 184 additions and 0 deletions.
184 changes: 184 additions & 0 deletions tests/helpers/runif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from distutils.version import LooseVersion
from typing import Optional

import pytest
import torch
from pkg_resources import get_distribution

from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_DEEPSPEED_AVAILABLE,
_FAIRSCALE_AVAILABLE,
_FAIRSCALE_PIPE_AVAILABLE,
_HOROVOD_AVAILABLE,
_NATIVE_AMP_AVAILABLE,
_RPC_AVAILABLE,
_TORCH_QUANTIZE_AVAILABLE,
_TPU_AVAILABLE,
)

try:
from horovod.common.util import nccl_built
nccl_built()
except (ImportError, ModuleNotFoundError, AttributeError):
_HOROVOD_NCCL_AVAILABLE = False
finally:
_HOROVOD_NCCL_AVAILABLE = True


class RunIf:
"""
RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark::
@RunIf(min_torch="0.0")
@pytest.mark.parametrize("arg1", [1, 2.0])
def test_wrapper(arg1):
assert arg1 > 0.0
"""

def __new__(
self,
*args,
min_gpus: int = 0,
min_torch: Optional[str] = None,
max_torch: Optional[str] = None,
min_python: Optional[str] = None,
quantization: bool = False,
amp_apex: bool = False,
amp_native: bool = False,
tpu: bool = False,
horovod: bool = False,
horovod_nccl: bool = False,
skip_windows: bool = False,
special: bool = False,
rpc: bool = False,
fairscale: bool = False,
fairscale_pipe: bool = False,
deepspeed: bool = False,
**kwargs
):
"""
Args:
args: native pytest.mark.skipif arguments
min_gpus: min number of gpus required to run test
min_torch: minimum pytorch version to run test
max_torch: maximum pytorch version to run test
min_python: minimum python version required to run test
quantization: if `torch.quantization` package is required to run test
amp_apex: NVIDIA Apex is installed
amp_native: if native PyTorch native AMP is supported
tpu: if TPU is available
horovod: if Horovod is installed
horovod_nccl: if Horovod is installed with NCCL support
skip_windows: skip test for Windows platform (typically fo some limited torch functionality)
special: running in special mode, outside pytest suit
rpc: requires Remote Procedure Call (RPC)
fairscale: if `fairscale` module is required to run the test
deepspeed: if `deepspeed` module is required to run the test
kwargs: native pytest.mark.skipif keyword arguments
"""
conditions = []
reasons = []

if min_gpus:
conditions.append(torch.cuda.device_count() < min_gpus)
reasons.append(f"GPUs>={min_gpus}")

if min_torch:
torch_version = LooseVersion(get_distribution("torch").version)
conditions.append(torch_version < LooseVersion(min_torch))
reasons.append(f"torch>={min_torch}")

if max_torch:
torch_version = LooseVersion(get_distribution("torch").version)
conditions.append(torch_version >= LooseVersion(max_torch))
reasons.append(f"torch<{max_torch}")

if min_python:
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
conditions.append(py_version < LooseVersion(min_python))
reasons.append(f"python>={min_python}")

if quantization:
_miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
reasons.append("PyTorch quantization")

if amp_native:
conditions.append(not _NATIVE_AMP_AVAILABLE)
reasons.append("native AMP")

if amp_apex:
conditions.append(not _APEX_AVAILABLE)
reasons.append("NVIDIA Apex")

if skip_windows:
conditions.append(sys.platform == "win32")
reasons.append("unimplemented on Windows")

if tpu:
conditions.append(not _TPU_AVAILABLE)
reasons.append("TPU")

if horovod:
conditions.append(not _HOROVOD_AVAILABLE)
reasons.append("Horovod")

if horovod_nccl:
conditions.append(not _HOROVOD_NCCL_AVAILABLE)
reasons.append("Horovod with NCCL")

if special:
env_flag = os.getenv("PL_RUNNING_SPECIAL_TESTS", '0')
conditions.append(env_flag != '1')
reasons.append("Special execution")

if rpc:
conditions.append(not _RPC_AVAILABLE)
reasons.append("RPC")

if fairscale:
conditions.append(not _FAIRSCALE_AVAILABLE)
reasons.append("Fairscale")

if fairscale_pipe:
conditions.append(not _FAIRSCALE_PIPE_AVAILABLE)
reasons.append("Fairscale Pipe")

if deepspeed:
conditions.append(not _DEEPSPEED_AVAILABLE)
reasons.append("Deepspeed")

reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args,
condition=any(conditions),
reason=f"Requires: [{' + '.join(reasons)}]",
**kwargs,
)


@RunIf(min_torch="99")
def test_always_skip():
exit(1)


@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
@RunIf(min_torch="0.0")
def test_wrapper(arg1: float):
assert arg1 > 0.0

0 comments on commit 832e771

Please sign in to comment.