validate manual optimization and supported features before running training (#7788)


Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
awaelchli and carmocca authored Jun 3, 2021
1 parent 0bad218 commit 36770b2
Showing 5 changed files with 34 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -119,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `teardown()` in `Accelerator` to allow `training_type_plugin` to customize `teardown` logic ([#7579](https://github.com/PyTorchLightning/pytorch-lightning/pull/7579))


- `Trainer.fit` now raises an error when using manual optimization with unsupported features such as `gradient_clip_val` or `accumulate_grad_batches` ([#7788](https://github.com/PyTorchLightning/pytorch-lightning/pull/7788))


### Deprecated


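To make the changelog entry above concrete, here is a minimal sketch of the new behavior from the user's side (the model and data below are illustrative, not part of the commit): once a `LightningModule` sets `automatic_optimization = False`, passing `gradient_clip_val` or `accumulate_grad_batches` to the `Trainer` makes `fit` raise a `MisconfigurationException` during configuration validation instead of silently ignoring the flag.

```python
# Illustrative sketch of the new validation (model name and data are hypothetical,
# not part of this diff).
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch[0]).sum()
        self.manual_backward(loss)
        opt.step()
        opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_data = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=2)

# Both of these now fail fast with a MisconfigurationException before training starts:
# Trainer(gradient_clip_val=1.0).fit(ManualOptModel(), train_data)
# Trainer(accumulate_grad_batches=2).fit(ManualOptModel(), train_data)
```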
17 changes: 17 additions & 0 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -34,6 +34,7 @@ def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'val')
self.__verify_manual_optimization_support(model)
elif self.trainer.state.fn == TrainerFn.VALIDATING:
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state.fn == TrainerFn.TESTING:
@@ -112,3 +113,19 @@ def __verify_dp_batch_transfer_support(self, model: 'pl.LightningModule') -> None:
for hook in batch_transfer_hooks:
if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')

def __verify_manual_optimization_support(self, model: 'pl.LightningModule') -> None:
if model.automatic_optimization:
return
if self.trainer.gradient_clip_val > 0:
raise MisconfigurationException(
f"Automatic gradient clipping is not supported for manual optimization."
f" Remove `Trainer(gradient_clip_val={self.trainer.gradient_clip_val})`"
f" or switch to automatic optimization."
)
if self.trainer.accumulate_grad_batches != 1:
raise MisconfigurationException(
f"Automatic gradient accumulation is not supported for manual optimization."
f" Remove `Trainer(accumulate_grad_batches={self.trainer.accumulate_grad_batches})`"
f" or switch to automatic optimization."
)
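Design note: the validator rejects these `Trainer` flags because, under manual optimization, gradient clipping and accumulation are expected to happen inside `training_step` itself. A rough sketch of that manual pattern follows (illustrative only, not part of this diff; `compute_loss` is a hypothetical helper and the every-two-batches schedule is arbitrary):

```python
# Illustrative manual-optimization pattern: clip and accumulate gradients by hand.
# Assumes this is a method of a LightningModule with `torch` imported and
# `self.automatic_optimization = False`.
def training_step(self, batch, batch_idx):
    opt = self.optimizers()
    loss = self.compute_loss(batch)  # hypothetical helper
    self.manual_backward(loss)

    # accumulate gradients over 2 batches, then clip and step
    if (batch_idx + 1) % 2 == 0:
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        opt.step()
        opt.zero_grad()
```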
1 change: 0 additions & 1 deletion tests/core/test_lightning_optimizer.py
@@ -123,7 +123,6 @@ def configure_optimizers(self):
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
accumulate_grad_batches=999, # does not do anything if manual optimization
)

with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, \
5 changes: 0 additions & 5 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -424,7 +424,6 @@ def on_train_epoch_end(self, *_, **__):
limit_val_batches=0,
precision=16,
amp_backend='native',
accumulate_grad_batches=4,
gpus=1,
)
trainer.fit(model)
@@ -631,7 +630,6 @@ def configure_optimizers(self):
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
accumulate_grad_batches=2,
)

trainer.fit(model)
@@ -682,7 +680,6 @@ def configure_optimizers(self):
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
accumulate_grad_batches=2,
)

trainer.fit(model)
@@ -757,7 +754,6 @@ def configure_optimizers(self):
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
accumulate_grad_batches=2,
)

trainer.fit(model)
@@ -867,7 +863,6 @@ def train_manual_optimization(tmpdir, accelerator, model_cls=TesManualOptimizati
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
accumulate_grad_batches=2,
gpus=2,
accelerator=accelerator,
callbacks=[TestManualOptimizationDDPCallack()]
14 changes: 14 additions & 0 deletions tests/trainer/test_config_validator.py
@@ -147,3 +147,17 @@ def predict_dataloader(self):

with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
trainer.predict(model)


def test_trainer_manual_optimization_config(tmpdir):
""" Test error message when requesting Trainer features unsupported with manual optimization """
model = BoringModel()
model.automatic_optimization = False

trainer = Trainer(gradient_clip_val=1.0)
with pytest.raises(MisconfigurationException, match="Automatic gradient clipping is not supported"):
trainer.fit(model)

trainer = Trainer(accumulate_grad_batches=2)
with pytest.raises(MisconfigurationException, match="Automatic gradient accumulation is not supported"):
trainer.fit(model)
