diff --git a/tests/callbacks/test_quantization.py b/tests/callbacks/test_quantization.py index 37fccdb00ff4c..e5f6feab393c5 100644 --- a/tests/callbacks/test_quantization.py +++ b/tests/callbacks/test_quantization.py @@ -22,15 +22,15 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.datamodules import RegressDataModule from tests.helpers.simple_models import RegressionModel -from tests.helpers.skipif import skipif_args +from tests.helpers.skipif import SkipIf @pytest.mark.parametrize( "observe", - ['average', pytest.param('histogram', marks=pytest.mark.skipif(**skipif_args(min_torch="1.5")))] + ['average', pytest.param('histogram', marks=SkipIf(min_torch="1.5"))] ) @pytest.mark.parametrize("fuse", [True, False]) -@pytest.mark.skipif(**skipif_args(quant_available=True)) +@SkipIf(quantization=True) def test_quantization(tmpdir, observe, fuse): """Parity test for quant model""" seed_everything(42) @@ -65,7 +65,7 @@ def test_quantization(tmpdir, observe, fuse): assert torch.allclose(org_score, quant_score, atol=0.45) -@pytest.mark.skipif(**skipif_args(quant_available=True)) +@SkipIf(quantization=True) def test_quantize_torchscript(tmpdir): """Test converting to torchscipt """ dm = RegressDataModule() @@ -81,7 +81,7 @@ def test_quantize_torchscript(tmpdir): tsmodel(tsmodel.quant(batch[0])) -@pytest.mark.skipif(**skipif_args(quant_available=True)) +@SkipIf(quantization=True) def test_quantization_exceptions(tmpdir): """Test wrong fuse layers""" with pytest.raises(MisconfigurationException, match='Unsupported qconfig'): @@ -124,7 +124,7 @@ def custom_trigger_last(trainer): (custom_trigger_last, 2), ] ) -@pytest.mark.skipif(**skipif_args(quant_available=True)) +@SkipIf(quantization=True) def test_quantization_triggers(tmpdir, trigger_fn, expected_count): """Test how many times the quant is called""" dm = RegressDataModule() diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 5db282b6e9081..2638d54259e47 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,7 +26,7 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel -from tests.helpers.skipif import skipif_args +from tests.helpers.skipif import SkipIf def _setup_ddp(rank, worldsize): @@ -72,7 +72,7 @@ def test_result_reduce_ddp(result_cls): pytest.param(5, False, 0, id='nested_list_predictions'), pytest.param(6, False, 0, id='dict_list_predictions'), pytest.param(7, True, 0, id='write_dict_predictions'), - pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**skipif_args(min_gpus=1))) + pytest.param(0, True, 1, id='full_loop_single_gpu', marks=SkipIf(min_gpus=1)) ] ) def test_result_obj_predictions(tmpdir, test_option, do_train, gpus): diff --git a/tests/helpers/skipif.py b/tests/helpers/skipif.py index d8f5835dd6290..f3f24b0e76a54 100644 --- a/tests/helpers/skipif.py +++ b/tests/helpers/skipif.py @@ -21,52 +21,64 @@ from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE -def skipif_args( - min_gpus: int = 0, - min_torch: Optional[str] = None, - quant_available: bool = False, -) -> dict: - """ Creating aggregated arguments for standard pytest skipif, sot the usecase is:: - - @pytest.mark.skipif(**create_skipif(min_torch="99")) - def test_any_func(...): - ... +class SkipIf: + """ + SkipIf wrapper for simple marking specific cases, fully compatible with pytest.mark:: - >>> from pprint import pprint - >>> pprint(skipif_args(min_torch="99", min_gpus=0)) - {'condition': True, 'reason': 'Required: [torch>=99]'} - >>> pprint(skipif_args(min_torch="0.0", min_gpus=0)) # doctest: +NORMALIZE_WHITESPACE - {'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'} + @SkipIf(min_torch="0.0") + @pytest.mark.parametrize("arg1", [1, 2.0]) + def test_wrapper(arg1): + assert arg1 > 0.0 """ - conditions = [] - reasons = [] - if min_gpus: - conditions.append(torch.cuda.device_count() < min_gpus) - reasons.append(f"GPUs>={min_gpus}") + def __new__( + self, + *args, + min_gpus: int = 0, + min_torch: Optional[str] = None, + quantization: bool = False, + **kwargs + ): + """ + Args: + args: native pytest.mark.skipif arguments + min_gpus: min number of gpus required to run test + min_torch: minimum pytorch version to run test + quantization: if `torch.quantization` package is required to run test + kwargs: native pytest.mark.skipif keyword arguments + """ + conditions = [] + reasons = [] - if min_torch: - torch_version = LooseVersion(get_distribution("torch").version) - conditions.append(torch_version < LooseVersion(min_torch)) - reasons.append(f"torch>={min_torch}") + if min_gpus: + conditions.append(torch.cuda.device_count() < min_gpus) + reasons.append(f"GPUs>={min_gpus}") - if quant_available: - _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines - conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default) - reasons.append("PyTorch quantization is available") + if min_torch: + torch_version = LooseVersion(get_distribution("torch").version) + conditions.append(torch_version < LooseVersion(min_torch)) + reasons.append(f"torch>={min_torch}") - if not any(conditions): - return dict(condition=False, reason="Conditions satisfied, going ahead with the test.") + if quantization: + _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines + conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default) + reasons.append("missing PyTorch quantization") - reasons = [rs for cond, rs in zip(conditions, reasons) if cond] - return dict(condition=any(conditions), reason=f"Required: [{' + '.join(reasons)}]",) + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] + return pytest.mark.skipif( + *args, + condition=any(conditions), + reason=f"Requires: [{' + '.join(reasons)}]", + **kwargs, + ) -@pytest.mark.skipif(**skipif_args(min_torch="99")) +@SkipIf(min_torch="99") def test_always_skip(): exit(1) -@pytest.mark.skipif(**skipif_args(min_torch="0.0")) -def test_always_pass(): - assert True +@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0]) +@SkipIf(min_torch="0.0") +def test_wrapper(arg1): + assert arg1 > 0.0