Cherry-picking 1.2.1 release [full merge, no squash] #6154

Merged (4 commits, full merge) on Feb 24, 2021
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -5,11 +5,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.2.1] - 2021-02-23

### Fixed

- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))


## [1.2.0] - 2021-02-18

### Added

- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689)
- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689))
- Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))
- Added support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))
- Added `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))
2 changes: 1 addition & 1 deletion pytorch_lightning/__init__.py
@@ -5,7 +5,7 @@
import time

_this_year = time.strftime("%Y")
__version__ = '1.2.0'
__version__ = '1.2.1'
__author__ = 'William Falcon et al.'
__author_email__ = 'waf2107@columbia.edu'
__license__ = 'Apache-2.0'
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/cpu.py
@@ -7,7 +7,7 @@ class CPUAccelerator(Accelerator):

def setup(self, trainer, model):
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")

if "cpu" not in str(self.root_device):
raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead")
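The one-line change above is the whole fix from #6107: the old line constructed a `MisconfigurationException` but never raised it, so training on CPU with AMP continued as if the configuration were valid. A minimal sketch of the difference, using a stand-in exception class instead of the Lightning import (illustrative names, not library code):

class MisconfigurationException(Exception):
    """Stand-in for pytorch_lightning.utilities.exceptions.MisconfigurationException."""


def check_before_fix(uses_amp: bool) -> None:
    if uses_amp:
        # Old behaviour: the exception object is created and immediately discarded,
        # so execution continues as if the configuration were valid.
        MisconfigurationException("amp + cpu is not supported. Please use a GPU option")


def check_after_fix(uses_amp: bool) -> None:
    if uses_amp:
        # Fixed behaviour: `raise` propagates the error to the caller.
        raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")


check_before_fix(True)  # returns silently
try:
    check_after_fix(True)
except MisconfigurationException as err:
    print(err)  # amp + cpu is not supported. Please use a GPU option

The new `tests/accelerators/test_cpu.py` added later in this diff asserts exactly this raising path with `pytest.raises`.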
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/native_amp.py
@@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
@contextmanager
def train_step_context(self) -> Generator[autocast, None, None]:
"""Enable autocast context"""
yield torch.cuda.amp.autocast()
with torch.cuda.amp.autocast():
yield
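This is the autocast fix from #6080: `train_step_context` is a `@contextmanager` generator, and the old body yielded the `torch.cuda.amp.autocast()` object without ever entering it, so autocast was never actually enabled around the training step. The sketch below reproduces both control flows with a generic context manager in place of `autocast`, so it runs without a GPU (stand-in names, not Lightning code):

from contextlib import contextmanager

active = False  # tracks whether the inner context has actually been entered


@contextmanager
def autocast_like():
    """Stand-in for torch.cuda.amp.autocast: flips a flag while entered."""
    global active
    active = True
    try:
        yield
    finally:
        active = False


@contextmanager
def broken_ctx():
    # Old behaviour: yields the context-manager object itself and never enters it.
    yield autocast_like()


@contextmanager
def fixed_ctx():
    # New behaviour: enters the inner context first, then yields inside it.
    with autocast_like():
        yield


with broken_ctx():
    print(active)  # False -- the wrapped code never ran under the inner context
with fixed_ctx():
    print(active)  # True  -- the wrapped code runs with the inner context active

The new `AMPTestModel` in `tests/models/test_amp.py` below guards against a regression by asserting `torch.is_autocast_enabled()` and a float16 output inside `training_step`.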
pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -163,6 +163,9 @@ def handle_given_plugins(

for plug in plugins:
if isinstance(plug, str):
# Reset the distributed type as the user has overridden training type
# via the plugins argument
self._distrib_type = None
self.set_distributed_mode(plug)

elif isinstance(plug, TrainingTypePlugin):
@@ -196,7 +199,6 @@ def handle_given_plugins(
)

self._training_type_plugin = training_type
self._training_type_plugin = self.training_type_plugin
self._precision_plugin = precision
self._cluster_environment = cluster_environment or self.select_cluster_environment()

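The practical effect, exercised by the new `test_plugin_accelerator_choice` test below: resetting `_distrib_type` before calling `set_distributed_mode(plug)` lets a training-type plugin passed as a string override whatever distributed mode the `accelerator` argument had already selected (#6089). A short usage sketch mirroring that test (assumes an environment where a 2-process spawn setup is valid; illustrative, not a recommended configuration):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin

# `accelerator='ddp_spawn'` alone would normally resolve to DDPSpawnPlugin,
# but the plugin string takes precedence, so the sharded plugin is selected.
trainer = Trainer(accelerator='ddp_spawn', plugins='ddp_sharded', num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)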
24 changes: 23 additions & 1 deletion tests/accelerators/test_accelerator_connector.py
@@ -23,7 +23,14 @@
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
from pytorch_lightning.plugins import (
DDP2Plugin,
DDPPlugin,
DDPShardedPlugin,
DDPSpawnPlugin,
PrecisionPlugin,
SingleDevicePlugin,
)
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from tests.helpers.boring_model import BoringModel

@@ -378,3 +385,18 @@ def on_fit_start(self, trainer, pl_module):

with pytest.raises(SystemExit):
trainer.fit(model)


@pytest.mark.parametrize(
["accelerator", "plugin"],
[('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
)
def test_plugin_accelerator_choice(accelerator, plugin):
"""
Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent.
"""
trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)

trainer = Trainer(plugins=plugin, num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
21 changes: 21 additions & 0 deletions tests/accelerators/test_cpu.py
@@ -0,0 +1,21 @@
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_unsupported_precision_plugins():
""" Test error messages are raised for unsupported precision plugins with CPU. """
trainer = Mock()
model = Mock()
accelerator = CPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
precision_plugin=MixedPrecisionPlugin()
)
with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
accelerator.setup(trainer=trainer, model=model)
22 changes: 15 additions & 7 deletions tests/models/test_amp.py
@@ -27,6 +27,16 @@
from tests.helpers import BoringModel


class AMPTestModel(BoringModel):

def training_step(self, batch, batch_idx):
assert torch.is_autocast_enabled()
output = self(batch)
assert output.dtype == torch.float16
loss = self.loss(batch, output)
return {"loss": loss}


@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_amp_single_gpu_dp(tmpdir):
@@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
precision=16,
)

model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)

@@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
precision=16,
)

model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)

assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
precision=16,
)

model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)

@@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
precision=16,
)

model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)

assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
# simulate setting slurm flags
tutils.set_random_master_port()

model = BoringModel()
model = AMPTestModel()

# exp file to get meta
logger = tutils.get_default_logger(tmpdir)
67 changes: 4 additions & 63 deletions tests/plugins/test_amp_plugin.py
@@ -5,10 +5,8 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel


@@ -25,78 +23,21 @@
)
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(
['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
['ddp_backend', 'gpus'],
[('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)],
)
def on_fit_start(tmpdir, ddp_backend, gpus, num_processes):

class CB(Callback):

def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_plugin, NativeMixedPrecisionPlugin)
raise SystemExit()

def train():
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='native',
gpus=gpus,
num_processes=num_processes,
accelerator=ddp_backend,
callbacks=[CB()],
)
trainer.fit(model)

if ddp_backend == "ddp_cpu":
with pytest.raises(MisconfigurationException, match="MP is only available on GPU"):
train()
else:
with pytest.raises(SystemExit):
train()


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(
os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0"
}
)
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(
['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
)
def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
def test_amp_choice_custom_ddp_cpu(device_count_mock, ddp_backend, gpus):

class MyNativeAMP(NativeMixedPrecisionPlugin):
pass

class CB(Callback):

def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_plugin, MyNativeAMP)
raise SystemExit()

model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='native',
num_processes=num_processes,
accelerator=ddp_backend,
plugins=[MyNativeAMP()],
callbacks=[CB()],
)

with pytest.raises(SystemExit):
trainer.fit(model)
assert isinstance(trainer.precision_plugin, MyNativeAMP)


class GradientUnscaleBoringModel(BoringModel):
40 changes: 8 additions & 32 deletions tests/plugins/test_apex_plugin.py
@@ -4,10 +4,8 @@
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE
from tests.helpers.boring_model import BoringModel


@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@@ -23,30 +21,19 @@
)
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(
['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
['ddp_backend', 'gpus'],
[('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)],
)
def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
def test_amp_choice_default_ddp(mocked_device_count, ddp_backend, gpus):

class CB(Callback):

def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
raise SystemExit()

model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='apex',
gpus=gpus,
num_processes=num_processes,
accelerator=ddp_backend,
callbacks=[CB()],
)

with pytest.raises(SystemExit):
trainer.fit(model)
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)


@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@@ -62,31 +49,20 @@ def on_fit_start(self, trainer, pl_module):
)
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(
['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
['ddp_backend', 'gpus'],
[('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)],
)
def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
def test_amp_choice_custom_ddp(mocked_device_count, ddp_backend, gpus):

class MyApexPlugin(ApexMixedPrecisionPlugin):
pass

class CB(Callback):

def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_plugin, MyApexPlugin)
raise SystemExit()

model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='apex',
gpus=gpus,
num_processes=num_processes,
accelerator=ddp_backend,
plugins=[MyApexPlugin(amp_level="O2")],
callbacks=[CB()],
)

with pytest.raises(SystemExit):
trainer.fit(model)
assert isinstance(trainer.precision_plugin, MyApexPlugin)