Refactor: skipif for AMPs 3/n (#6293)
* args

* native

* apex

* isort
Borda committed Mar 2, 2021
1 parent bc577ca commit b46d221
Showing 11 changed files with 35 additions and 38 deletions.
tests/core/test_memory.py (1 addition, 3 deletions)
@@ -17,7 +17,6 @@

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.core.memory import ModelSummary, UNKNOWN_SIZE
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.advanced_models import ParityModuleRNN
@@ -292,8 +291,7 @@ def test_empty_model_size(mode):
assert 0.0 == summary.model_size


@RunIf(min_gpus=1)
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
@RunIf(min_gpus=1, amp_native=True)
@pytest.mark.parametrize(
'precision', [
pytest.param(16, marks=pytest.mark.skip(reason="no longer valid, because 16 can mean mixed precision")),
tests/helpers/runif.py (13 additions, 1 deletion)
@@ -19,7 +19,7 @@
import torch
from pkg_resources import get_distribution

from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE
from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, _TORCH_QUANTIZE_AVAILABLE


class RunIf:
@@ -38,6 +38,8 @@ def __new__(
min_gpus: int = 0,
min_torch: Optional[str] = None,
quantization: bool = False,
amp_apex: bool = False,
amp_native: bool = False,
skip_windows: bool = False,
**kwargs
):
@@ -47,6 +49,8 @@ def __new__(
min_gpus: min number of gpus required to run test
min_torch: minimum pytorch version to run test
quantization: if `torch.quantization` package is required to run test
amp_apex: if NVIDIA Apex is installed
amp_native: if native PyTorch AMP is supported
skip_windows: skip test for Windows platform (typically for some limited torch functionality)
kwargs: native pytest.mark.skipif keyword arguments
"""
@@ -67,6 +71,14 @@ def __new__(
conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
reasons.append("missing PyTorch quantization")

if amp_native:
conditions.append(not _NATIVE_AMP_AVAILABLE)
reasons.append("missing native AMP")

if amp_apex:
conditions.append(not _APEX_AVAILABLE)
reasons.append("missing NVIDIA Apex")

if skip_windows:
conditions.append(sys.platform == "win32")
reasons.append("unimplemented on Windows")
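The hunks above only show the new `amp_apex` / `amp_native` branches; the part of `RunIf.__new__` that turns the collected conditions into an actual pytest marker sits outside the diff context. Below is a minimal sketch of that pattern, assuming the helper ends by collapsing `conditions` and `reasons` into a single `pytest.mark.skipif` (the exact reason formatting is an assumption, not verbatim from the repository):

import pytest

from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE


class RunIf:
    """Sketch of the skipif wrapper; mirrors tests/helpers/runif.py but is not the verbatim implementation."""

    def __new__(cls, *args, amp_apex: bool = False, amp_native: bool = False, **kwargs):
        conditions = []
        reasons = []

        if amp_native:
            conditions.append(not _NATIVE_AMP_AVAILABLE)
            reasons.append("missing native AMP")

        if amp_apex:
            conditions.append(not _APEX_AVAILABLE)
            reasons.append("missing NVIDIA Apex")

        # keep only the reasons whose condition actually triggered
        reasons = [reason for condition, reason in zip(conditions, reasons) if condition]
        # any() skips the test as soon as one requirement is unmet
        return pytest.mark.skipif(
            *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
        )

With that in place, a single `@RunIf(min_gpus=1, amp_native=True)` replaces the previous pair of `@RunIf(min_gpus=1)` and `@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, ...)` decorators, which is the substitution repeated in the files below.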
tests/models/test_amp.py (1 addition, 3 deletions)
@@ -22,7 +22,6 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@@ -193,8 +192,7 @@ def test_amp_without_apex(tmpdir):


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@RunIf(min_gpus=1)
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@RunIf(min_gpus=1, amp_apex=True)
def test_amp_with_apex(tmpdir):
"""Check calling apex scaling in training."""

tests/models/test_horovod.py (3 additions, 5 deletions)
@@ -28,7 +28,7 @@
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers.advanced_models import BasicGAN
from tests.helpers.runif import RunIf
@@ -120,8 +120,7 @@ def test_horovod_multi_gpu(tmpdir):

@pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?")
@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@RunIf(min_gpus=2, skip_windows=True)
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@RunIf(min_gpus=2, skip_windows=True, amp_apex=True)
def test_horovod_apex(tmpdir):
"""Test Horovod with multi-GPU support using apex amp."""
trainer_options = dict(
@@ -143,8 +142,7 @@

@pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp")
@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@RunIf(min_gpus=2, skip_windows=True)
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp")
@RunIf(min_gpus=2, skip_windows=True, amp_native=True)
def test_horovod_amp(tmpdir):
"""Test Horovod with multi-GPU support using native amp."""
trainer_options = dict(
tests/plugins/test_amp_plugin.py (3 additions, 6 deletions)
@@ -6,12 +6,11 @@

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@RunIf(amp_native=True)
@mock.patch.dict(
os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
@@ -49,8 +48,7 @@ def on_after_backward(self):
assert norm.item() < 15.


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@RunIf(min_gpus=2)
@RunIf(min_gpus=2, amp_native=True)
def test_amp_gradient_unscale(tmpdir):
model = GradientUnscaleBoringModel()

@@ -78,8 +76,7 @@ def on_after_backward(self):
assert norm.item() < 15.


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@RunIf(min_gpus=2)
@RunIf(min_gpus=2, amp_native=True)
def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):
model = UnscaleAccumulateGradBatchesBoringModel()

tests/plugins/test_apex_plugin.py (3 additions, 3 deletions)
@@ -5,10 +5,10 @@

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE
from tests.helpers.runif import RunIf


@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@RunIf(amp_apex=True)
@mock.patch.dict(
os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
@@ -36,7 +36,7 @@ def test_amp_choice_default_ddp(mocked_device_count, ddp_backend, gpus):
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)


@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@RunIf(amp_apex=True)
@mock.patch.dict(
os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
tests/plugins/test_deepspeed_plugin.py (4 additions, 4 deletions)
@@ -9,7 +9,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -122,12 +122,12 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config):

@pytest.mark.parametrize(
"amp_backend", [
pytest.param("native", marks=pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")),
pytest.param("apex", marks=pytest.mark.skipif(not _APEX_AVAILABLE, reason="Requires Apex")),
pytest.param("native", marks=RunIf(amp_native=True)),
pytest.param("apex", marks=RunIf(amp_apex=True)),
]
)
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@RunIf(amp_native=True)
def test_deepspeed_precision_choice(amp_backend, tmpdir):
"""
Test to ensure precision plugin is also correctly chosen.
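The deepspeed hunk above also passes `RunIf(amp_native=True)` and `RunIf(amp_apex=True)` as `marks=` arguments inside `pytest.param`, not only as decorators. That works because `RunIf(...)` evaluates to an ordinary `pytest.mark.skipif` marker, which pytest accepts in either position. A hedged usage sketch (the test name and body are illustrative only, not from the repository):

import pytest

from tests.helpers.runif import RunIf


@pytest.mark.parametrize(
    "amp_backend", [
        pytest.param("native", marks=RunIf(amp_native=True)),
        pytest.param("apex", marks=RunIf(amp_apex=True)),
    ]
)
def test_amp_backend_available(amp_backend):
    # illustrative body only; the real tests build a Trainer with this backend
    assert amp_backend in ("native", "apex")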
tests/plugins/test_sharded_plugin.py (3 additions, 4 deletions)
@@ -6,7 +6,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -39,7 +39,7 @@ def on_fit_start(self, trainer, pl_module):
trainer.fit(model)


@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@RunIf(amp_apex=True)
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_invalid_apex_sharded(tmpdir):
"""
@@ -58,10 +58,9 @@ def test_invalid_apex_sharded(tmpdir):
trainer.fit(model)


@RunIf(min_gpus=2)
@RunIf(min_gpus=2, amp_native=True)
@pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
def test_ddp_choice_sharded_amp(tmpdir, accelerator):
"""
Test to ensure that plugin native amp plugin is correctly chosen when using sharded
tests/trainer/optimization/test_manual_optimization.py (1 addition, 3 deletions)
@@ -24,7 +24,6 @@

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import _APEX_AVAILABLE
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf

@@ -310,8 +309,7 @@ def configure_optimizers(self):


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@RunIf(min_gpus=1)
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@RunIf(min_gpus=1, amp_apex=True)
def test_multiple_optimizers_manual_apex(tmpdir):
"""
Tests that only training_step can be used
tests/trainer/test_trainer.py (1 addition, 3 deletions)
@@ -36,7 +36,6 @@
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
@@ -881,8 +880,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde
trainer.fit(model)


@RunIf(min_gpus=1)
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
@RunIf(min_gpus=1, amp_native=True)
def test_gradient_clipping_fp16(tmpdir):
"""
Test gradient clipping with fp16
tests/trainer/test_trainer_tricks.py (2 additions, 3 deletions)
@@ -20,7 +20,7 @@

import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel
@@ -342,8 +342,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
trainer.tune(model, **fit_options)


@RunIf(min_gpus=1)
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
@RunIf(min_gpus=1, amp_native=True)
def test_auto_scale_batch_size_with_amp(tmpdir):
model = EvalModelTemplate()
batch_size_before = model.batch_size
