diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ad54381a082b..7dad863d41293 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080)) + ## [1.2.0] - 2021-02-18 diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 60c0f5f84626f..94e6cf376b03a 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: @contextmanager def train_step_context(self) -> Generator[autocast, None, None]: """Enable autocast context""" - yield torch.cuda.amp.autocast() + with torch.cuda.amp.autocast(): + yield diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 2dd6c9d997dbf..53ec32764f3ed 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -27,6 +27,16 @@ from tests.helpers import BoringModel +class AMPTestModel(BoringModel): + + def training_step(self, batch, batch_idx): + assert torch.is_autocast_enabled() + output = self(batch) + assert output.dtype == torch.float16 + loss = self.loss(batch, output) + return {"loss": loss} + + @pytest.mark.skip(reason='dp + amp not supported currently') # TODO @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_amp_single_gpu_dp(tmpdir): @@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir): precision=16, ) - model = BoringModel() + model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) @@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir): precision=16, ) - model = BoringModel() + model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir): precision=16, ) - model = BoringModel() + model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) @@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir): precision=16, ) - model = BoringModel() + model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): # simulate setting slurm flags tutils.set_random_master_port() - model = BoringModel() + model = AMPTestModel() # exp file to get meta logger = tutils.get_default_logger(tmpdir)