Remove deprecated optimizer argument from manual_backward (#8287)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
awaelchli and carmocca authored Jul 6, 2021
1 parent 9eda520 commit f1341a5
Showing 6 changed files with 36 additions and 65 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -314,6 +314,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated trainer attributes - `on_cpu`, `on_tpu`, `use_tpu`, `on_gpu`, `use_dp`, `use_ddp`, `use_ddp2`, `use_horovod`, `use_single_gpu` ([#7501](https://github.com/PyTorchLightning/pytorch-lightning/pull/7501))


+- Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287))


### Fixed

- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
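Editor's note (not part of the diff): the pattern this CHANGELOG entry points to looks roughly like the sketch below. With the `optimizer` argument removed, `manual_backward(loss)` takes only the loss (plus arguments forwarded to `Tensor.backward`), and limiting an update to one optimizer's parameters is done with `toggle_optimizer`/`untoggle_optimizer`. The module, layer names, and batch shape are hypothetical, and the call signatures assume the Lightning 1.3/1.4-era API, where these methods also take an optimizer index.

```python
import torch
from pytorch_lightning import LightningModule


class TwoOptimizerSketch(LightningModule):
    """Hypothetical module using manual optimization after #8287."""

    def __init__(self):
        super().__init__()
        self.gen = torch.nn.Linear(32, 32)
        self.dis = torch.nn.Linear(32, 2)
        self.automatic_optimization = False  # we call backward/step ourselves

    def training_step(self, batch, batch_idx):
        opt_g, opt_d = self.optimizers()

        # "generator" update: only self.gen's parameters should receive gradients
        self.toggle_optimizer(opt_g, optimizer_idx=0)
        loss_g = self.dis(self.gen(batch)).sum()
        opt_g.zero_grad()
        self.manual_backward(loss_g)  # no `optimizer` argument anymore
        opt_g.step()
        self.untoggle_optimizer(0)

        # "discriminator" update follows the same pattern
        self.toggle_optimizer(opt_d, optimizer_idx=1)
        loss_d = self.dis(batch).sum()
        opt_d.zero_grad()
        self.manual_backward(loss_d)
        opt_d.step()
        self.untoggle_optimizer(1)

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.gen.parameters(), lr=0.1),
            torch.optim.SGD(self.dis.parameters(), lr=0.1),
        )
```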
6 changes: 3 additions & 3 deletions benchmarks/test_sharded_parity.py
@@ -42,17 +42,17 @@ def training_step(self, batch, batch_idx, optimizer_idx):
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
loss_1 = self.step(batch)

-self.manual_backward(loss_1, opt_a)
+self.manual_backward(loss_1)
opt_a.step()

# fake discriminator
loss_2 = self.step(batch[0])

# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
-self.manual_backward(loss_2, opt_b)
+self.manual_backward(loss_2)
# todo: understand why synchronization breaks there.
-# self.manual_backward(loss_2, opt_a, retain_graph=True)
+# self.manual_backward(loss_2, retain_graph=True)
opt_b.step()

assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0)
8 changes: 1 addition & 7 deletions pytorch_lightning/core/lightning.py
@@ -1418,7 +1418,7 @@ def configure_optimizers(self):
"""
rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")

-def manual_backward(self, loss: Tensor, optimizer: Optional[Optimizer] = None, *args, **kwargs) -> None:
+def manual_backward(self, loss: Tensor, *args, **kwargs) -> None:
"""
Call this directly from your :meth:`training_step` when doing optimizations manually.
By using this, Lightning can ensure that all the proper scaling gets applied when using mixed precision.
@@ -1437,15 +1437,9 @@ def training_step(...):
Args:
loss: The tensor on which to compute gradients. Must have a graph attached.
-optimizer: This argument is unused and deprecated. It will be removed in v1.4.
*args: Additional positional arguments to be forwarded to :meth:`~torch.Tensor.backward`
**kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
"""
-if optimizer is not None:
-rank_zero_deprecation(
-"`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4"
-)

# make sure we're using manual opt
self._verify_is_manual_optimization('manual_backward')

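Usage note (editorial, not part of the diff): with the deprecation shim gone, everything after `loss` is forwarded verbatim to `torch.Tensor.backward`, so options such as `retain_graph` are passed directly. A minimal sketch, assuming a manually optimized module with a single optimizer and a `self.layer` submodule:

```python
def training_step(self, batch, batch_idx):
    opt = self.optimizers()
    loss = self.layer(batch).sum()

    opt.zero_grad()
    # keyword arguments go straight to Tensor.backward()
    self.manual_backward(loss, retain_graph=True)  # keep the graph for a second pass
    self.manual_backward(loss)                     # second backward over the same graph
    opt.step()
```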
26 changes: 0 additions & 26 deletions tests/deprecated_api/test_remove_1-4.py
@@ -26,32 +26,6 @@ def test_v1_4_0_deprecated_imports():
from pytorch_lightning.utilities.argparse_utils import _gpus_arg_default # noqa: F811 F401


-def test_v1_4_0_deprecated_manual_optimization_optimizer(tmpdir):

-class TestModel(BoringModel):

-def training_step(self, batch, *_, **kwargs):
-opt = self.optimizers()
-output = self.layer(batch)
-loss = self.loss(batch, output)
-self.manual_backward(loss, opt)

-@property
-def automatic_optimization(self):
-return False

-model = TestModel()
-model.training_epoch_end = None
-trainer = Trainer(
-default_root_dir=tmpdir,
-fast_dev_run=True,
-)
-with pytest.deprecated_call(
-match="`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4"
-):
-trainer.fit(model)


def test_v1_4_0_deprecated_checkpoint_on(tmpdir):
from pytorch_lightning.callbacks.model_checkpoint import warning_cache
warning_cache.clear()
54 changes: 27 additions & 27 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -42,16 +42,16 @@ def training_step(self, batch, batch_idx):
assert torch.all(self.layer.weight.grad == 0)

loss_1 = self.step(batch[0])
-self.manual_backward(loss_1, opt_a)
+self.manual_backward(loss_1)
opt_a.step()
opt_a.zero_grad()
assert torch.all(self.layer.weight.grad == 0)

loss_2 = self.step(batch[0])
# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
-self.manual_backward(loss_2, opt_b, retain_graph=True)
-self.manual_backward(loss_2, opt_a)
+self.manual_backward(loss_2, retain_graph=True)
+self.manual_backward(loss_2)
assert self.layer.weight.grad is not None
opt_b.step()
opt_b.zero_grad()
@@ -254,7 +254,7 @@ def training_step(self, batch, batch_idx):

if self.should_update:

-self.manual_backward(loss, opt)
+self.manual_backward(loss)
opt.step()
opt.zero_grad()

@@ -385,7 +385,7 @@ def training_step(self, batch, batch_idx):

if self.should_update:

-self.manual_backward(loss, opt)
+self.manual_backward(loss)
if self.should_have_updated:
opt.step()
opt.zero_grad()
@@ -458,7 +458,7 @@ def training_step(self, batch, batch_idx):
if self.layer.weight.grad is not None:
assert torch.all(self.layer.weight.grad == 0)

-self.manual_backward(loss_1, opt_a)
+self.manual_backward(loss_1)
opt_a.step()

# fake discriminator
@@ -467,8 +467,8 @@ def training_step(self, batch, batch_idx):

# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
-self.manual_backward(loss_2, opt_b, retain_graph=True)
-self.manual_backward(loss_2, opt_a, retain_graph=True)
+self.manual_backward(loss_2, retain_graph=True)
+self.manual_backward(loss_2, retain_graph=True)

assert self.layer.weight.grad is not None
opt_b.step()
@@ -542,7 +542,7 @@ def optimizer_closure():
loss = compute_loss()
losses.append(loss)
retain_graph = (num_backward - 1) != backward_idx
-self.manual_backward(loss, opt, retain_graph=retain_graph)
+self.manual_backward(loss, retain_graph=retain_graph)
# emulate MC dropout training
loss = torch.stack(losses).mean()
self._losses.append(loss)
@@ -604,7 +604,7 @@ def optimizer_closure():
num_backward = 1
for backward_idx in range(num_backward + 1):
retain_graph = num_backward != backward_idx # noqa E225
-self.manual_backward(loss_1, opt, retain_graph=retain_graph)
+self.manual_backward(loss_1, retain_graph=retain_graph)

weight_before = self.layer.weight.clone()

@@ -661,7 +661,7 @@ def optimizer_closure():
num_backward = 1
for backward_idx in range(num_backward + 1):
retain_graph = num_backward != backward_idx # noqa E225
-self.manual_backward(loss_1, opt, retain_graph=retain_graph)
+self.manual_backward(loss_1, retain_graph=retain_graph)

opt.step(closure=optimizer_closure)
opt.zero_grad()
@@ -719,12 +719,12 @@ def compute_loss():
def gen_closure():
loss_gen = compute_loss()
self.log("loss_gen", loss_gen, on_step=True, on_epoch=True)
-self.manual_backward(loss_gen, opt_gen)
+self.manual_backward(loss_gen)

def dis_closure():
loss_dis = compute_loss()
self.log("loss_dis", loss_dis, on_step=True, on_epoch=True)
-self.manual_backward(loss_dis, opt_dis)
+self.manual_backward(loss_dis)

# this will accumulate gradients for 2 batches and then call opt_gen.step()
gen_closure()
Expand Down Expand Up @@ -813,8 +813,8 @@ def compute_loss():
loss_zeros = self.loss_zeros(None, predictions)
return loss_ones, loss_zeros

-def make_manual_backward(loss, opt, retain_graph=False, make_optimizer_step=True):
-self.manual_backward(loss, opt, retain_graph=retain_graph)
+def make_manual_backward(loss, retain_graph=False, make_optimizer_step=True):
+self.manual_backward(loss, retain_graph=retain_graph)
if make_optimizer_step:
grad_clone = self.layer.weight.grad.clone()
assert self.manual_sync_grad()
@@ -823,13 +823,13 @@ def make_manual_backward(loss, opt, retain_graph=False, make_optimizer_step=True

def gen_closure():
loss_ones_gen, loss_zeros = compute_loss()
-make_manual_backward(loss_ones_gen, opt_gen, retain_graph=True, make_optimizer_step=make_gen_optimizer_step)
-make_manual_backward(loss_ones_gen, opt_gen, make_optimizer_step=make_gen_optimizer_step)
+make_manual_backward(loss_ones_gen, retain_graph=True, make_optimizer_step=make_gen_optimizer_step)
+make_manual_backward(loss_ones_gen, make_optimizer_step=make_gen_optimizer_step)

def dis_closure():
loss_ones_gen, loss_zeros = compute_loss()
-make_manual_backward(loss_ones_gen, opt_dis, retain_graph=True, make_optimizer_step=make_dis_optimizer_step)
-make_manual_backward(loss_ones_gen, opt_dis, make_optimizer_step=make_dis_optimizer_step)
+make_manual_backward(loss_ones_gen, retain_graph=True, make_optimizer_step=make_dis_optimizer_step)
+make_manual_backward(loss_ones_gen, make_optimizer_step=make_dis_optimizer_step)

# this will accumulate gradients for 2 batches and then call opt_gen.step()
if make_gen_optimizer_step:
Expand Down Expand Up @@ -917,8 +917,8 @@ def compute_loss():
loss_zeros = self.loss_zeros(None, predictions)
return loss_ones, loss_zeros

-def make_manual_backward(loss, opt, retain_graph=False, make_optimizer_step=True):
-self.manual_backward(loss, opt, retain_graph=retain_graph)
+def make_manual_backward(loss, retain_graph=False, make_optimizer_step=True):
+self.manual_backward(loss, retain_graph=retain_graph)
if make_optimizer_step:
grad_clone = self.layer.weight.grad.clone()
assert self.manual_sync_grad()
@@ -927,13 +927,13 @@ def make_manual_backward(loss, opt, retain_graph=False, make_optimizer_step=True

def gen_closure():
loss_ones_gen, loss_zeros = compute_loss()
-make_manual_backward(loss_ones_gen, opt_gen, retain_graph=True, make_optimizer_step=make_gen_optimizer_step)
-make_manual_backward(loss_ones_gen, opt_gen, make_optimizer_step=make_gen_optimizer_step)
+make_manual_backward(loss_ones_gen, retain_graph=True, make_optimizer_step=make_gen_optimizer_step)
+make_manual_backward(loss_ones_gen, make_optimizer_step=make_gen_optimizer_step)

def dis_closure():
loss_ones_gen, loss_zeros = compute_loss()
-make_manual_backward(loss_ones_gen, opt_dis, retain_graph=True, make_optimizer_step=make_dis_optimizer_step)
-make_manual_backward(loss_ones_gen, opt_dis, make_optimizer_step=make_dis_optimizer_step)
+make_manual_backward(loss_ones_gen, retain_graph=True, make_optimizer_step=make_dis_optimizer_step)
+make_manual_backward(loss_ones_gen, make_optimizer_step=make_dis_optimizer_step)

# this will accumulate gradients for 2 batches and then call opt_gen.step()
with opt_gen.toggle_model(sync_grad=make_gen_optimizer_step):
@@ -1055,7 +1055,7 @@ def training_step(self, batch, batch_idx):
self.log("loss_d", loss_d, prog_bar=True)

optimizer.zero_grad()
-self.manual_backward(loss_d, optimizer)
+self.manual_backward(loss_d)
optimizer.step()
self.untoggle_optimizer(optimizer_idx)

@@ -1068,7 +1068,7 @@ def training_step(self, batch, batch_idx):
self.log("loss_g", loss_g, prog_bar=True)

optimizer.zero_grad()
-self.manual_backward(loss_g, optimizer)
+self.manual_backward(loss_g)
optimizer.step()
self.untoggle_optimizer(optimizer_idx)

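The closure-based tests above share roughly this shape (a sketch with illustrative names, not the test code itself): the closure performs the backward passes, forwarding `retain_graph` to `torch.Tensor.backward` for all but the last pass, and the Lightning optimizer consumes the closure in `step`:

```python
def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    def optimizer_closure():
        loss = self.layer(batch).sum()  # assumes a `self.layer` submodule
        num_backward = 2
        for backward_idx in range(num_backward):
            # retain the graph for every backward pass except the last one
            retain_graph = backward_idx != num_backward - 1
            self.manual_backward(loss, retain_graph=retain_graph)

    opt.step(closure=optimizer_closure)
    opt.zero_grad()
```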
4 changes: 2 additions & 2 deletions tests/trainer/optimization/test_multiple_optimizers.py
@@ -108,13 +108,13 @@ def training_step(self, batch, batch_idx):
loss_1 = self.step(batch[0])

# fake generator
-self.manual_backward(loss_1, opt_a)
+self.manual_backward(loss_1)
opt_a.step()
opt_a.zero_grad()

# fake discriminator
loss_2 = self.step(batch[0])
-self.manual_backward(loss_2, opt_b)
+self.manual_backward(loss_2)
opt_b.step()
opt_b.zero_grad()

