add skipif wrapper #6258

Merged 27 commits on Mar 1, 2021

Changes from all commits:
12 changes: 6 additions & 6 deletions tests/callbacks/test_quantization.py
@@ -22,15 +22,15 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.datamodules import RegressDataModule
from tests.helpers.simple_models import RegressionModel
from tests.helpers.skipif import skipif_args
from tests.helpers.skipif import SkipIf


@pytest.mark.parametrize(
"observe",
['average', pytest.param('histogram', marks=pytest.mark.skipif(**skipif_args(min_torch="1.5")))]
['average', pytest.param('histogram', marks=SkipIf(min_torch="1.5"))]
)
@pytest.mark.parametrize("fuse", [True, False])
@pytest.mark.skipif(**skipif_args(quant_available=True))
@SkipIf(quantization=True)
def test_quantization(tmpdir, observe, fuse):
"""Parity test for quant model"""
seed_everything(42)
@@ -65,7 +65,7 @@ def test_quantization(tmpdir, observe, fuse):
assert torch.allclose(org_score, quant_score, atol=0.45)


@pytest.mark.skipif(**skipif_args(quant_available=True))
@SkipIf(quantization=True)
def test_quantize_torchscript(tmpdir):
"""Test converting to torchscipt """
dm = RegressDataModule()
@@ -81,7 +81,7 @@ def test_quantize_torchscript(tmpdir):
tsmodel(tsmodel.quant(batch[0]))


@pytest.mark.skipif(**skipif_args(quant_available=True))
@SkipIf(quantization=True)
def test_quantization_exceptions(tmpdir):
"""Test wrong fuse layers"""
with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):
@@ -124,7 +124,7 @@ def custom_trigger_last(trainer):
(custom_trigger_last, 2),
]
)
@pytest.mark.skipif(**skipif_args(quant_available=True))
@SkipIf(quantization=True)
def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
"""Test how many times the quant is called"""
dm = RegressDataModule()
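
For readers skimming the diff: the new decorator builds a standard pytest skip mark from the quantization requirement. The sketch below is a rough, hand-expanded equivalent of @SkipIf(quantization=True), based on the tests/helpers/skipif.py changes further down; the test name and body are hypothetical placeholders.

import pytest
import torch

from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE

# The condition mirrors SkipIf(quantization=True): skip when the
# torch.quantization package is unavailable or the default 'fbgemm'
# backend is missing from the supported quantized engines.
_quant_missing = (
    not _TORCH_QUANTIZE_AVAILABLE
    or 'fbgemm' not in torch.backends.quantized.supported_engines
)


@pytest.mark.skipif(condition=_quant_missing, reason="Requires: [missing PyTorch quantization]")
def test_quantization_sketch():
    ...  # placeholder body; the real tests above exercise the quantized model
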
4 changes: 2 additions & 2 deletions tests/core/test_results.py
@@ -26,7 +26,7 @@
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import TrainerState
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.skipif import skipif_args
from tests.helpers.skipif import SkipIf


def _setup_ddp(rank, worldsize):
@@ -72,7 +72,7 @@ def test_result_reduce_ddp(result_cls):
pytest.param(5, False, 0, id='nested_list_predictions'),
pytest.param(6, False, 0, id='dict_list_predictions'),
pytest.param(7, True, 0, id='write_dict_predictions'),
pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**skipif_args(min_gpus=1)))
pytest.param(0, True, 1, id='full_loop_single_gpu', marks=SkipIf(min_gpus=1))
]
)
def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
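
As in the parametrization above, the mark produced by SkipIf can also be attached to a single parameter set rather than the whole test; a minimal sketch, with hypothetical ids and values:

import pytest

from tests.helpers.skipif import SkipIf


@pytest.mark.parametrize(
    "gpus",
    [
        pytest.param(0, id='cpu'),
        # only this parameter set is skipped when no GPU is visible
        pytest.param(1, id='single_gpu', marks=SkipIf(min_gpus=1)),
    ],
)
def test_device_option(gpus):
    assert gpus >= 0
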
84 changes: 48 additions & 36 deletions tests/helpers/skipif.py
@@ -21,52 +21,64 @@
from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE


def skipif_args(
min_gpus: int = 0,
min_torch: Optional[str] = None,
quant_available: bool = False,
) -> dict:
""" Creating aggregated arguments for standard pytest skipif, sot the usecase is::

@pytest.mark.skipif(**create_skipif(min_torch="99"))
def test_any_func(...):
...
class SkipIf:
"""
SkipIf wrapper for simple marking specific cases, fully compatible with pytest.mark::

>>> from pprint import pprint
>>> pprint(skipif_args(min_torch="99", min_gpus=0))
{'condition': True, 'reason': 'Required: [torch>=99]'}
>>> pprint(skipif_args(min_torch="0.0", min_gpus=0)) # doctest: +NORMALIZE_WHITESPACE
{'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'}
@SkipIf(min_torch="0.0")
@pytest.mark.parametrize("arg1", [1, 2.0])
def test_wrapper(arg1):
assert arg1 > 0.0
"""
conditions = []
reasons = []

if min_gpus:
conditions.append(torch.cuda.device_count() < min_gpus)
reasons.append(f"GPUs>={min_gpus}")
def __new__(
self,
*args,
min_gpus: int = 0,
min_torch: Optional[str] = None,
quantization: bool = False,
**kwargs
):
"""
Args:
args: native pytest.mark.skipif arguments
min_gpus: min number of gpus required to run test
min_torch: minimum pytorch version to run test
quantization: if `torch.quantization` package is required to run test
kwargs: native pytest.mark.skipif keyword arguments
"""
conditions = []
reasons = []

if min_torch:
torch_version = LooseVersion(get_distribution("torch").version)
conditions.append(torch_version < LooseVersion(min_torch))
reasons.append(f"torch>={min_torch}")
if min_gpus:
conditions.append(torch.cuda.device_count() < min_gpus)
reasons.append(f"GPUs>={min_gpus}")

if quant_available:
_miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
reasons.append("PyTorch quantization is available")
if min_torch:
torch_version = LooseVersion(get_distribution("torch").version)
conditions.append(torch_version < LooseVersion(min_torch))
reasons.append(f"torch>={min_torch}")

if not any(conditions):
return dict(condition=False, reason="Conditions satisfied, going ahead with the test.")
if quantization:
_miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
reasons.append("missing PyTorch quantization")

reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return dict(condition=any(conditions), reason=f"Required: [{' + '.join(reasons)}]",)
reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args,
condition=any(conditions),
reason=f"Requires: [{' + '.join(reasons)}]",
**kwargs,
)


@pytest.mark.skipif(**skipif_args(min_torch="99"))
@SkipIf(min_torch="99")
def test_always_skip():
exit(1)


@pytest.mark.skipif(**skipif_args(min_torch="0.0"))
def test_always_pass():
assert True
@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
@SkipIf(min_torch="0.0")
def test_wrapper(arg1):
assert arg1 > 0.0
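
Because SkipIf returns an ordinary pytest.mark.skipif mark, one call can aggregate several requirements and the result stacks with other pytest marks. A minimal usage sketch follows; the test name and parameters are hypothetical:

import pytest

from tests.helpers.skipif import SkipIf


# A single SkipIf call may combine several requirements; the reasons for
# whichever conditions fail are joined into one skip message.
@SkipIf(min_gpus=1, min_torch="1.5")
@pytest.mark.parametrize("batch_size", [1, 8])
def test_needs_gpu_and_recent_torch(batch_size):
    assert batch_size > 0
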