[Fix] Remove DeepSpeed Plugin FP16 exception #8462

Merged · 9 commits · Jul 19, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -405,6 +405,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287))
 
 
+- Removed DeepSpeed FP16 Exception as FP32 is now supported ([#8462](https://github.com/PyTorchLightning/pytorch-lightning/pull/8462))
+
+
 ### Fixed
 
 - Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
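The changelog entry above is the user-visible effect of this PR: a DeepSpeed run no longer has to set `precision=16`. Below is a minimal sketch of that behavior, not code from the PR itself; `TinyModel` is a hypothetical stand-in for the `BoringModel` test helper used in the test diff further down, and the sketch assumes one CUDA GPU with the `deepspeed` package installed.

```python
# Minimal sketch of the new behavior (assumes 1 CUDA GPU and `deepspeed` installed).
# TinyModel is a hypothetical stand-in for Lightning's BoringModel test helper.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # a scalar "loss" is enough for a smoke run
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=2)


# No precision=16 here: with this PR, DeepSpeed ZeRO stage 3 trains in the default
# 32-bit precision instead of raising a MisconfigurationException.
trainer = Trainer(gpus=1, plugins='deepspeed_stage_3', fast_dev_run=True)
trainer.fit(TinyModel())
```

Before this change, the same `Trainer` call failed with `MisconfigurationException: To use DeepSpeed ZeRO Optimization, you must set precision=16.`, as shown in the lines removed from `deepspeed.py` below.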
4 changes: 1 addition & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -541,7 +541,7 @@ def _format_precision_config(self):
         amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
         amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
         precision = self.lightning_module.trainer.accelerator_connector.precision
-        if precision == 16:
+        if precision in (16, 'mixed'):
             if "fp16" not in self.config and amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
@@ -559,8 +559,6 @@ def _format_precision_config(self):
                 "enabled": True,
                 "opt_level": amp_level,
             }
-        if "zero_optimization" in self.config and not ("amp" in self.config or "fp16" in self.config):
-            raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.")
 
     def _create_default_config(
         self,
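For reference, here is a hedged sketch of the control flow `_format_precision_config` is left with after this diff. It is simplified and not the plugin's exact code: the `AMPType` enum is replaced with plain strings and the full set of DeepSpeed fp16 options is omitted. The point is that FP16/Apex sections are only injected for 16-bit or mixed precision, and with the ZeRO guard removed, plain FP32 configs pass through untouched.

```python
# Simplified sketch of the branching after this change (not the plugin's exact code).
def format_precision_config(config: dict, precision, amp_type: str, amp_level: str = None) -> dict:
    if precision in (16, 'mixed'):
        if "fp16" not in config and amp_type == "native":
            # DeepSpeed's standalone mixed-precision implementation
            config["fp16"] = {"enabled": True}
        elif "amp" not in config and amp_type == "apex":
            config["amp"] = {"enabled": True, "opt_level": amp_level}
    # The removed guard used to raise here when "zero_optimization" was present
    # without an "amp"/"fp16" section; FP32 + ZeRO is now a valid combination.
    return config


# FP32 with ZeRO stage 2: previously raised, now returns the config unchanged.
print(format_precision_config({"zero_optimization": {"stage": 2}}, precision=32, amp_type="native"))
```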
22 changes: 7 additions & 15 deletions tests/plugins/test_deepspeed_plugin.py
@@ -224,21 +224,6 @@ def test_deepspeed_defaults(tmpdir):
     assert isinstance(plugin.config["zero_optimization"], dict)
 
 
-@RunIf(min_gpus=1, deepspeed=True)
-def test_invalid_deepspeed_defaults_no_precision(tmpdir):
-    """Test to ensure that using defaults, if precision is not set to 16, we throw an exception."""
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        plugins='deepspeed',
-    )
-    with pytest.raises(
-        MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.'
-    ):
-        trainer.fit(model)
-
-
 @RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_warn_deepspeed_override_backward(tmpdir):
     """Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning."""
@@ -448,6 +433,13 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):
     _assert_save_model_is_equal(model, tmpdir, trainer)
 
 
+@RunIf(min_gpus=1, deepspeed=True)
+def test_deepspeed_fp32_works(tmpdir):
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, gpus=1, plugins='deepspeed_stage_3', fast_dev_run=True)
+    trainer.fit(model)
+
+
 class ModelParallelClassificationModel(LightningModule):
 
     def __init__(self, lr: float = 0.01, num_blocks: int = 5):
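One further consequence of the `"fp16" not in self.config` guard shown in the `deepspeed.py` diff: a user-supplied DeepSpeed config that already defines its own `fp16` section is left untouched rather than overwritten. The sketch below is a hypothetical usage example, not code from the PR, and it assumes the `DeepSpeedPlugin(config=...)` argument accepts a dict in this release.

```python
# Hypothetical usage sketch (not part of the PR): a hand-written DeepSpeed config
# that already contains an "fp16" section is preserved as-is by the plugin.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

deepspeed_config = {
    "zero_optimization": {"stage": 3},
    "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16},
}

trainer = Trainer(
    gpus=1,
    precision=16,
    plugins=[DeepSpeedPlugin(config=deepspeed_config)],
    fast_dev_run=True,
)
# trainer.fit(model)  # any LightningModule, e.g. the BoringModel used in these tests
```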