Skip to content

Commit

Permalink
Fix amp autocast (#6080)
Browse files Browse the repository at this point in the history
* precision fixes

* add amp test model

* fix test

* revert

* move assert to training step

* fix test

* fix test

* remove unrelated changes

* add changelog

* remove unused import
  • Loading branch information
awaelchli authored and lexierule committed Feb 24, 2021
1 parent 3645eb1 commit 7c323ba
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.2.1] - 2021-02-23

### Fixed

- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))


## [1.2.0] - 2021-02-18

### Added
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
@contextmanager
def train_step_context(self) -> Generator[None, None, None]:
    """Enable CUDA autocast around the caller's training step.

    Yields
    ------
    None
        Control is yielded *inside* an active ``torch.cuda.amp.autocast``
        region, so the body of the caller's ``with`` statement executes
        with autocast enabled, and autocast is exited when the ``with``
        block ends.
    """
    # Yield from inside the ``with`` so the autocast context is actually
    # entered before the caller's code runs. Yielding the ``autocast()``
    # object itself (the previous behavior) handed the caller an
    # un-entered context manager, leaving autocast disabled.
    with torch.cuda.amp.autocast():
        yield
22 changes: 15 additions & 7 deletions tests/models/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@
from tests.helpers import BoringModel


class AMPTestModel(BoringModel):
    """BoringModel variant whose training step verifies AMP is active.

    Used by the AMP tests to assert that the precision plugin actually
    wraps ``training_step`` in an autocast region and that the forward
    pass therefore produces half-precision output.
    """

    def training_step(self, batch, batch_idx):
        # The precision plugin must have entered autocast before this runs.
        assert torch.is_autocast_enabled()
        prediction = self(batch)
        # Under CUDA autocast the linear forward emits float16 activations.
        assert prediction.dtype == torch.float16
        return {"loss": self.loss(batch, prediction)}


@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_amp_single_gpu_dp(tmpdir):
Expand All @@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
precision=16,
)

model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)

Expand All @@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
precision=16,
)

model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)

assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


Expand All @@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
precision=16,
)

model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)

Expand All @@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
precision=16,
)

model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)

assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


Expand All @@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
# simulate setting slurm flags
tutils.set_random_master_port()

model = BoringModel()
model = AMPTestModel()

# exp file to get meta
logger = tutils.get_default_logger(tmpdir)
Expand Down

0 comments on commit 7c323ba

Please sign in to comment.