Commit c40a462

Merge 1074aff into 0b27147
awaelchli committed Feb 19, 2021
2 parents 0b27147 + 1074aff commit c40a462
Showing 4 changed files with 20 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))


## [1.2.0] - 2021-02-18

1 change: 1 addition & 0 deletions pytorch_lightning/overrides/base.py
@@ -19,6 +19,7 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/native_amp.py
@@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
    @contextmanager
    def train_step_context(self) -> Generator[autocast, None, None]:
        """Enable autocast context"""
-       yield torch.cuda.amp.autocast()
+       with torch.cuda.amp.autocast():
+           yield
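Note (not part of the diff): a minimal, self-contained sketch of why the removed line was a bug. A generator-based context manager that merely yields `torch.cuda.amp.autocast()` hands back an unentered autocast object, so the caller's `with` block runs without mixed precision; entering the context before yielding keeps autocast active for the wrapped training step. The function names below are illustrative, not from the repository, and the final check assumes a CUDA-capable machine.

```python
from contextlib import contextmanager
from typing import Generator

import torch


@contextmanager
def broken_train_step_context() -> Generator:
    # Bug: yields the autocast object without entering it, so
    # `with broken_train_step_context():` never enables autocast.
    yield torch.cuda.amp.autocast()


@contextmanager
def fixed_train_step_context() -> Generator:
    # Fix: enter autocast here and yield control back, so the
    # caller's block runs with autocast enabled.
    with torch.cuda.amp.autocast():
        yield


if torch.cuda.is_available():
    with broken_train_step_context():
        assert not torch.is_autocast_enabled()
    with fixed_train_step_context():
        assert torch.is_autocast_enabled()
```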
22 changes: 15 additions & 7 deletions tests/models/test_amp.py
@@ -27,6 +27,16 @@
from tests.helpers import BoringModel


+class AMPTestModel(BoringModel):
+
+    def training_step(self, batch, batch_idx):
+        assert torch.is_autocast_enabled()
+        output = self(batch)
+        assert output.dtype == torch.float16
+        loss = self.loss(batch, output)
+        return {"loss": loss}


@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_amp_single_gpu_dp(tmpdir):
@@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
        precision=16,
    )

-   model = BoringModel()
+   model = AMPTestModel()
    # tutils.run_model_test(trainer_options, model)
    trainer.fit(model)

@@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
        precision=16,
    )

-   model = BoringModel()
+   model = AMPTestModel()
    # tutils.run_model_test(trainer_options, model)
    trainer.fit(model)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
        precision=16,
    )

-   model = BoringModel()
+   model = AMPTestModel()
    # tutils.run_model_test(trainer_options, model)
    trainer.fit(model)

@@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
        precision=16,
    )

-   model = BoringModel()
+   model = AMPTestModel()
    # tutils.run_model_test(trainer_options, model)
    trainer.fit(model)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
    # simulate setting slurm flags
    tutils.set_random_master_port()

-   model = BoringModel()
+   model = AMPTestModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)