update usage of deprecated automatic_optimization (Lightning-AI#5011)
* drop deprecated usage of automatic_optimization

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
3 people committed Dec 10, 2020
1 parent 77fb425 commit 4ebce38
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 35 deletions.
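
At a glance, the migration this commit applies in every file below: the deprecated Trainer flag is replaced by a property on the LightningModule. A minimal before/after sketch (the class name is illustrative):

# Before (deprecated):
trainer = Trainer(automatic_optimization=False)

# After (this commit): override the property on the LightningModule instead
class MyManualModel(LightningModule):
    @property
    def automatic_optimization(self) -> bool:
        return False

trainer = Trainer()  # no flag needed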
6 changes: 4 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -1412,8 +1412,10 @@ def get_progress_bar_dict(self):
 
     def _verify_is_manual_optimization(self, fn_name):
         if self.trainer.train_loop.automatic_optimization:
-            m = f'to use {fn_name}, please disable automatic optimization: Trainer(automatic_optimization=False)'
-            raise MisconfigurationException(m)
+            raise MisconfigurationException(
+                f'to use {fn_name}, please disable automatic optimization:'
+                ' set the model property `automatic_optimization` to False'
+            )
 
     @classmethod
     def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
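For context, `_verify_is_manual_optimization` guards Lightning's manual-optimization APIs. A hedged sketch of the failure mode the new message describes, assuming `manual_backward` is among the callers that pass `fn_name` (the model and tensor shapes are illustrative):

import torch
from pytorch_lightning import LightningModule

class AutoOptModel(LightningModule):
    # `automatic_optimization` defaults to True, so the manual-optimization
    # call in training_step should raise MisconfigurationException
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.manual_backward(loss, self.optimizers())  # triggers the check above
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)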
11 changes: 6 additions & 5 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -79,7 +79,7 @@ def __verify_train_loop_configuration(self, model):
         if trainer.overriden_optimizer_step and not enable_pl_optimizer and automatic_optimization:
             rank_zero_warn(
                 "When overriding `LightningModule` optimizer_step with"
-                " `Trainer(..., enable_pl_optimizer=False, automatic_optimization=True, ...)`,"
+                " `Trainer(..., enable_pl_optimizer=False, ...)`,"
                 " we won't be calling `.zero_grad`, since we can't assume when you call your `optimizer.step()`."
                 " For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`."
             )
@@ -89,15 +89,16 @@ def __verify_train_loop_configuration(self, model):
         has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
         if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization:
             raise MisconfigurationException(
-                'When overriding `LightningModule` optimizer_step or optimizer_zero_grad with '
-                '`Trainer(automatic_optimization=True, ...)`, `accumulate_grad_batches` should to be 1.'
+                'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,'
+                ' `accumulate_grad_batches` in `Trainer` should be 1.'
                 ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
             )
 
         if (enable_pl_optimizer) and trainer.overriden_optimizer_zero_grad and not automatic_optimization:
             raise MisconfigurationException(
-                'When overriding `LightningModule` optimizer_zero_grad with '
-                '`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported'
+                'When overriding `LightningModule` optimizer_zero_grad'
+                ' and keeping the model property `automatic_optimization` as False with'
+                ' `Trainer(enable_pl_optimizer=True, ...)` is not supported.'
             )
 
     def __verify_eval_loop_configuration(self, model, eval_loop_name):
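The accumulation check above is exercised by the test diff further down; a hedged repro of the configuration it rejects (the BoringModel import path is an assumption based on this repo's test layout at the time):

from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel  # assumed test helper path

class StepOverrideModel(BoringModel):
    def optimizer_step(self, *args, **kwargs):
        # any override counts; the validator only checks that the hook is overridden
        super().optimizer_step(*args, **kwargs)

# overridden optimizer_step + accumulate_grad_batches > 1, with automatic
# optimization left on, should raise MisconfigurationException during fit()
trainer = Trainer(accumulate_grad_batches=2, fast_dev_run=True)
trainer.fit(StepOverrideModel())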
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -358,7 +358,7 @@ def __init__(
         )
 
         # init train loop related flags
-        # TODO: deprecate in 1.2.0
+        # TODO: remove in 1.3.0
         if automatic_optimization is None:
             automatic_optimization = True
         else:
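The elided `else:` branch is where the deprecation warning fires; a sketch of that pattern, assuming `rank_zero_warn` as used elsewhere in this codebase (the exact warning text is not shown in this diff and is an assumption):

# init train loop related flags
# TODO: remove in 1.3.0
if automatic_optimization is None:
    automatic_optimization = True
else:
    rank_zero_warn(
        "Passing `automatic_optimization` to the Trainer is deprecated and will be removed in v1.3.0."
        " Please override the `automatic_optimization` property on your LightningModule instead.",
        DeprecationWarning,
    )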
2 changes: 0 additions & 2 deletions tests/core/test_lightning_module.py
@@ -38,7 +38,6 @@ def optimizer_step(self, *_, **__):
         default_root_dir=tmpdir,
         limit_train_batches=2,
         accumulate_grad_batches=2,
-        automatic_optimization=True
     )
 
     trainer.fit(model)
@@ -90,7 +89,6 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
         default_root_dir=tmpdir,
         limit_train_batches=8,
         accumulate_grad_batches=1,
-        automatic_optimization=True,
         enable_pl_optimizer=enable_pl_optimizer
     )
 
17 changes: 10 additions & 7 deletions tests/core/test_lightning_optimizer.py
@@ -112,6 +112,10 @@ def configure_optimizers(self):
             lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
             return [optimizer_1, optimizer_2], [lr_scheduler]
 
+        @property
+        def automatic_optimization(self) -> bool:
+            return False
+
     model = TestModel()
     model.training_step_end = None
     model.training_epoch_end = None
@@ -121,8 +125,8 @@ def configure_optimizers(self):
         limit_val_batches=1,
         max_epochs=1,
         weights_summary=None,
-        automatic_optimization=False,
-        enable_pl_optimizer=True)
+        enable_pl_optimizer=True,
+    )
     trainer.fit(model)
 
     assert len(mock_sgd_step.mock_calls) == 2
@@ -161,6 +165,10 @@ def configure_optimizers(self):
             lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
             return [optimizer_1, optimizer_2], [lr_scheduler]
 
+        @property
+        def automatic_optimization(self) -> bool:
+            return False
+
     model = TestModel()
     model.training_step_end = None
     model.training_epoch_end = None
@@ -170,7 +178,6 @@ def configure_optimizers(self):
         limit_val_batches=1,
         max_epochs=1,
         weights_summary=None,
-        automatic_optimization=False,
         accumulate_grad_batches=2,
         enable_pl_optimizer=True,
     )
@@ -237,7 +244,6 @@ def configure_optimizers(self):
         max_epochs=1,
         weights_summary=None,
         enable_pl_optimizer=True,
-        automatic_optimization=True
     )
     trainer.fit(model)
 
@@ -291,7 +297,6 @@ def configure_optimizers(self):
         max_epochs=1,
         weights_summary=None,
         enable_pl_optimizer=True,
-        automatic_optimization=True
     )
     trainer.fit(model)
 
@@ -352,7 +357,6 @@ def configure_optimizers(self):
         max_epochs=1,
         weights_summary=None,
         enable_pl_optimizer=True,
-        automatic_optimization=True
     )
     trainer.fit(model)
 
@@ -406,7 +410,6 @@ def configure_optimizers(self):
         max_epochs=1,
         weights_summary=None,
         enable_pl_optimizer=True,
-        automatic_optimization=True,
     )
     trainer.fit(model)
 
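The tests above elide their training_step bodies; a hedged sketch of manual optimization with two optimizers under this release line, where `manual_backward(loss, optimizer)` still takes the optimizer and `self.step` is assumed to be the test model's loss helper:

def training_step(self, batch, batch_idx, optimizer_idx=None):
    # with `automatic_optimization` returning False, the module drives both optimizers
    opt_1, opt_2 = self.optimizers()

    loss_1 = self.step(batch)
    self.manual_backward(loss_1, opt_1)
    opt_1.step()       # a LightningOptimizer when enable_pl_optimizer=True
    opt_1.zero_grad()

    loss_2 = self.step(batch)
    self.manual_backward(loss_2, opt_2)
    opt_2.step()
    opt_2.zero_grad()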
5 changes: 4 additions & 1 deletion tests/trainer/dynamic_args/test_multiple_optimizers.py
@@ -97,11 +97,14 @@ def configure_optimizers(self):
             optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
             return optimizer, optimizer_2
 
+        @property
+        def automatic_optimization(self) -> bool:
+            return False
+
     model = TestModel()
     model.val_dataloader = None
 
     trainer = Trainer(
-        automatic_optimization=False,
         default_root_dir=tmpdir,
         limit_train_batches=2,
         limit_val_batches=2,
