From 7c323ba221e9e1f5759c2452dfe7add6df5935b7 Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Fri, 19 Feb 2021 18:00:27 +0100
Subject: [PATCH] Fix amp autocast (#6080)

* precision fixes
* add amp test model
* fix test
* revert
* move assert to training step
* fix test
* fix test
* remove unrelated changes
* add changelog
* remove unused import
---
 CHANGELOG.md                                     |  7 ++++++
 .../plugins/precision/native_amp.py              |  3 ++-
 tests/models/test_amp.py                         | 22 +++++++++++++------
 3 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b80afe7b24d0f..c1e6068f83fde 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [1.2.1] - 2021-02-23
+
+### Fixed
+
+- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 60c0f5f84626f..94e6cf376b03a 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
     @contextmanager
     def train_step_context(self) -> Generator[autocast, None, None]:
         """Enable autocast context"""
-        yield torch.cuda.amp.autocast()
+        with torch.cuda.amp.autocast():
+            yield
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 2dd6c9d997dbf..53ec32764f3ed 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -27,6 +27,16 @@
 from tests.helpers import BoringModel
 
 
+class AMPTestModel(BoringModel):
+
+    def training_step(self, batch, batch_idx):
+        assert torch.is_autocast_enabled()
+        output = self(batch)
+        assert output.dtype == torch.float16
+        loss = self.loss(batch, output)
+        return {"loss": loss}
+
+
 @pytest.mark.skip(reason='dp + amp not supported currently')  # TODO
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_amp_single_gpu_dp(tmpdir):
@@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
    trainer.fit(model)
-
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     # simulate setting slurm flags
     tutils.set_random_master_port()
 
-    model = BoringModel()
+    model = AMPTestModel()
     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)
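
Reviewer note (not part of the patch): the heart of this fix is the two-line change in train_step_context. Inside a @contextmanager generator, `yield torch.cuda.amp.autocast()` only creates the autocast object and hands it to the caller without ever entering it, so the training step never actually ran under autocast; wrapping the yield in `with torch.cuda.amp.autocast():` enters the context for the duration of the step, which is exactly what the new AMPTestModel asserts. Below is a minimal standalone sketch of the difference, assuming a CUDA-capable machine; the helper names train_step_context_broken/_fixed and the __main__ check are illustrative and do not appear in the patch.

# Sketch only: contrasts the pre-fix and post-fix context manager shapes.
from contextlib import contextmanager

import torch


@contextmanager
def train_step_context_broken():
    # Pre-fix shape: the autocast object is yielded but its __enter__ is
    # never called, so autocast is NOT active in the caller's with-body.
    yield torch.cuda.amp.autocast()


@contextmanager
def train_step_context_fixed():
    # Post-fix shape: autocast is entered here, so everything the caller
    # runs while holding this context executes with autocast enabled.
    with torch.cuda.amp.autocast():
        yield


if __name__ == "__main__" and torch.cuda.is_available():
    with train_step_context_broken():
        print(torch.is_autocast_enabled())  # False: autocast never entered
    with train_step_context_fixed():
        print(torch.is_autocast_enabled())  # True: autocast is active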