Call optimizer.zero_grad() before backward inside closure in AutoOpt (#6147)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
2 people authored and lexierule committed Mar 5, 2021
1 parent 5abfd2c commit 3c498ce
Showing 11 changed files with 292 additions and 378 deletions.
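In short, this commit moves ``optimizer.zero_grad()`` inside the closure that automatic optimization passes to ``optimizer.step()``, so gradients are cleared before each ``backward()`` call. That ordering is what closure-re-evaluating optimizers such as LBFGS rely on. Below is a minimal, self-contained PyTorch sketch of the idea; the model, data, and hyper-parameters are made up for illustration, and this is not Lightning's actual trainer code.

.. code-block:: python

    import torch
    from torch import nn

    # Toy setup purely for illustration; not Lightning's trainer code.
    model = nn.Linear(4, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    batch, target = torch.randn(8, 4), torch.randn(8, 1)

    def closure():
        # From 1.2.2, zero_grad() runs inside the closure, before backward().
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(batch), target)
        loss.backward()
        return loss

    # LBFGS re-evaluates the closure several times per step, so gradients
    # must be zeroed inside it rather than only after step().
    optimizer.step(closure)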
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))


### Deprecated

@@ -30,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


- Fixed LBFGS optimizer support, which previously failed to converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))


- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


69 changes: 39 additions & 30 deletions docs/source/common/optimizers.rst
@@ -23,32 +23,31 @@ to manually manage the optimization process. To do so, do the following:

* Override your LightningModule ``automatic_optimization`` property to return ``False``
* Drop or ignore the optimizer_idx argument
* Use `self.manual_backward(loss)` instead of `loss.backward()`.
* Use ``self.manual_backward(loss)`` instead of ``loss.backward()``.

.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerators logic. The users are left with zero_grad, accumulated_grad_batches, model toggling, etc..
.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerator logic; users are responsible for ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.

.. warning:: Before 1.2, ``optimzer.step`` was calling ``zero_grad`` internally. From 1.2, it is left to the users expertize.
.. warning:: Before 1.2, ``optimizer.step()`` called ``optimizer.zero_grad()`` internally. From 1.2, this is left to the user.

.. tip:: To perform ``accumulate_grad_batches`` with one optimizer, you can do the following.

.. tip:: ``self.optimizers()`` will return ``LightningOptimizer`` objects. You can access your own optimizer with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you.


.. code-block:: python

    def training_step(self, batch, batch_idx, optimizer_idx):
        opt = self.optimizers()

        loss = self.compute_loss(batch)
        self.manual_backward(loss)

        # accumulate gradient batches
        if batch_idx % 2 == 0:
            opt.step()
            opt.zero_grad()
.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward`` and ``backward`` pass of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure.
.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward`` and ``backward`` pass of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure. See also `the PyTorch docs <https://pytorch.org/docs/stable/optim.html#optimizer-step-closure>`_.

Here is the same example as above using a ``closure``.
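The closure version itself is not expanded in this view. For orientation, here is a minimal sketch of what a closure-based manual step can look like under 1.2.2; the ``compute_loss`` helper is hypothetical and the module is assumed to have automatic optimization disabled:

.. code-block:: python

    def training_step(self, batch, batch_idx, optimizer_idx):
        opt = self.optimizers()

        def closure():
            loss = self.compute_loss(batch)  # hypothetical loss helper
            opt.zero_grad()
            self.manual_backward(loss)
            return loss

        opt.step(closure=closure)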

@@ -71,7 +70,6 @@ Here is the same example as above using a ``closure``.
.. code-block:: python

    # Scenario for a GAN.
    def training_step(...):
        opt_gen, opt_dis = self.optimizers()
@@ -137,8 +135,12 @@ Here is an example on how to use it:

Automatic optimization
======================
With Lightning most users don't have to think about when to call .backward(), .step(), .zero_grad(), since
Lightning automates that for you.
With Lightning most users don't have to think about when to call ``.zero_grad()``, ``.backward()`` and ``.step()``
since Lightning automates that for you.

.. warning::
Before 1.2.2, ``.zero_grad()`` was called after ``.backward()`` and ``.step()`` internally.
From 1.2.2, Lightning calls ``.zero_grad()`` before ``.backward()``.

Under the hood Lightning does the following:

@@ -147,33 +149,33 @@ Under the hood Lightning does the following:
for epoch in epochs:
    for batch in data:
        loss = model.training_step(batch, batch_idx, ...)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for lr_scheduler in lr_schedulers:
        lr_scheduler.step()
In the case of multiple optimizers, Lightning does the following:

.. code-block:: python

    for epoch in epochs:
        for batch in data:
            for opt in optimizers:
                loss = model.training_step(batch, batch_idx, optimizer_idx)
                opt.zero_grad()
                loss.backward()
                opt.step()

        for lr_scheduler in lr_schedulers:
            lr_scheduler.step()
Learning rate scheduling
------------------------
Every optimizer you use can be paired with any `LearningRateScheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
In the basic use-case, the scheduler (or multiple schedulers) should be returned as the second output from the ``.configure_optimizers``
method:
Every optimizer you use can be paired with any `Learning Rate Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
In the basic use-case, the scheduler (or multiple schedulers) should be returned as the second output from the ``.configure_optimizers`` method:

.. testcode::

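The body of this ``testcode`` block is not expanded in the diff view. A minimal sketch of the basic pattern follows, with a made-up optimizer and scheduler choice; the method is assumed to live on a ``LightningModule``:

.. code-block:: python

    import torch

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # optimizers first, learning-rate schedulers second
        return [optimizer], [lr_scheduler]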
@@ -262,7 +264,7 @@ returned as a dict which can contain the following keywords:

Use multiple optimizers (like GANs)
-----------------------------------
To use multiple optimizers return > 1 optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers`
To use multiple optimizers, return two or more optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers`.

.. testcode::

@@ -283,13 +285,15 @@ Lightning will call each optimizer sequentially:
.. code-block:: python

    for epoch in epochs:
        for batch in data:
            for opt in optimizers:
                loss = train_step(batch, batch_idx, optimizer_idx)
                opt.zero_grad()
                loss.backward()
                opt.step()

        for lr_scheduler in lr_schedulers:
            lr_scheduler.step()
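As an illustration of the sequential calls above, here is a minimal sketch of a module that returns two optimizers; the GAN-style layers, losses, and hyper-parameters are toys, not the collapsed ``testcode`` example:

.. code-block:: python

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class ToyGAN(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.generator = nn.Linear(32, 32)     # placeholder generator
            self.discriminator = nn.Linear(32, 1)  # placeholder discriminator

        def training_step(self, batch, batch_idx, optimizer_idx):
            # optimizer_idx tells you which optimizer this call is for
            if optimizer_idx == 0:
                return self.generator(batch).mean()      # dummy generator loss
            return self.discriminator(batch).mean()      # dummy discriminator loss

        def configure_optimizers(self):
            opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
            opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
            return opt_g, opt_d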
----------

@@ -334,7 +338,7 @@ Here we add a learning-rate warm up
# update params
optimizer.step(closure=closure)
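Only the tail of the warm-up example is visible above. A fuller sketch of such an ``optimizer_step`` override follows; the 500-step warm-up length and the ``self.learning_rate`` attribute are assumptions for illustration, not the documented values:

.. code-block:: python

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # warm up the learning rate over the first 500 steps (illustrative length)
        if self.trainer.global_step < 500:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.learning_rate

        # update params
        optimizer.step(closure=optimizer_closure)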

.. note:: The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, accumulate_grad_batches, zero_grad, and much more ...
.. note:: The default ``optimizer_step`` relies on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, ``accumulate_grad_batches``, and much more.

.. testcode::

@@ -364,6 +368,11 @@ Using the closure functions for optimization

When using optimization schemes such as LBFGS, the ``second_order_closure`` needs to be enabled. By default, this function is defined by wrapping the ``training_step`` and the backward steps as follows:

.. warning::
Before 1.2.2, ``.zero_grad()`` was called outside the closure internally.
    From 1.2.2, the closure calls ``.zero_grad()`` internally, so there is no need to define your own closure
    when using optimizers similar to :class:`torch.optim.LBFGS`, which require re-evaluation of the loss with the closure in ``optimizer.step()``.

.. testcode::

def second_order_closure(pl_module, split_batch, batch_idx, opt_idx, optimizer, hidden):
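For the common LBFGS case you do not need to write this closure yourself. A minimal sketch (toy module and sizes, assuming automatic optimization and Lightning >= 1.2.2) of returning :class:`torch.optim.LBFGS` from ``configure_optimizers`` and letting Lightning build the closure:

.. code-block:: python

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class LitRegressor(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            # Lightning wraps this step plus backward (and, from 1.2.2,
            # zero_grad) into the closure it passes to optimizer.step().
            return nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.LBFGS(self.parameters(), lr=0.1)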
4 changes: 2 additions & 2 deletions docs/source/starter/introduction_guide.rst
@@ -361,9 +361,9 @@ The training step is what happens inside the training loop.
# TRAINING STEP
# ....
# TRAINING STEP
optimizer.zero_grad()
loss.backward()
optimizer.step()
In the case of MNIST, we do the following

@@ -377,9 +377,9 @@ In the case of MNIST, we do the following
loss = F.nll_loss(logits, y)
# ------ TRAINING STEP END ------

optimizer.zero_grad()
loss.backward()
optimizer.step()
In Lightning, everything that is in the training step gets organized under the
:func:`~pytorch_lightning.core.LightningModule.training_step` function in the LightningModule.
21 changes: 10 additions & 11 deletions docs/source/starter/new-project.rst
@@ -248,7 +248,7 @@ as long as you return a loss with an attached graph from the `training_step`, Li
.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = self.encoder(batch)
        return loss
.. _manual_opt:
@@ -267,19 +267,18 @@ Turn off automatic optimization and you control the train loop!
def training_step(self, batch, batch_idx, optimizer_idx):
    # access your optimizers with use_pl_optimizer=False. Default is True
    opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

    loss_a = self.generator(batch)
    opt_a.zero_grad()
    # use `manual_backward()` instead of `loss.backward` to automate half precision, etc...
    self.manual_backward(loss_a)
    opt_a.step()

    loss_b = self.discriminator(batch)
    opt_b.zero_grad()
    self.manual_backward(loss_b)
    opt_b.step()
Predict or Deploy
5 changes: 0 additions & 5 deletions pytorch_lightning/core/optimizer.py
@@ -129,15 +129,10 @@ def toggle_model(self, sync_grad: bool = True):
def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs):
trainer = self._trainer
optimizer = self._optimizer
model = trainer.lightning_module

with trainer.profiler.profile(profiler_name):
trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)

if self._trainer.train_loop.automatic_optimization:
trainer.train_loop.on_before_zero_grad(optimizer)
model.optimizer_zero_grad(trainer.current_epoch, trainer.batch_idx, optimizer, self._optimizer_idx)

def step(self, *args, closure: Optional[Callable] = None, **kwargs):
"""
Call this directly from your training_step when doing optimizations manually.
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -742,7 +742,13 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
self._curr_step_result = result

if not self._skip_backward and self.trainer.train_loop.automatic_optimization:
if not self._skip_backward and self.automatic_optimization:
is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0

if is_first_batch_to_accumulate:
self.on_before_zero_grad(optimizer)
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

# backward pass
if result is not None:
with self.trainer.profiler.profile("model_backward"):
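For intuition about the new ``is_first_batch_to_accumulate`` check, here is a plain-PyTorch sketch of the same idea with a toy model and an assumed ``accumulate_grad_batches`` of 2: gradients are cleared only on the first batch of each accumulation window, and the step happens once per window.

.. code-block:: python

    import torch
    from torch import nn

    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accumulate_grad_batches = 2  # assumed value, for illustration only

    for batch_idx in range(8):
        batch = torch.randn(16, 4)
        # first batch of an accumulation window: clear the gradients
        if batch_idx % accumulate_grad_batches == 0:
            optimizer.zero_grad()
        loss = model(batch).pow(2).mean()  # toy loss
        loss.backward()                    # gradients accumulate across the window
        # last batch of the window: take the optimizer step
        if (batch_idx + 1) % accumulate_grad_batches == 0:
            optimizer.step()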
6 changes: 3 additions & 3 deletions tests/callbacks/test_callbacks.py
@@ -73,20 +73,20 @@ def test_trainer_callback_system(torch_save, tmpdir):
call.on_train_epoch_start(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 0, 0),
call.on_after_backward(trainer, model),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_batch_end(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 1, 0),
call.on_after_backward(trainer, model),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_batch_end(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 2, 0),
call.on_after_backward(trainer, model),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
call.on_batch_end(trainer, model),
call.on_train_epoch_end(trainer, model, ANY),
68 changes: 2 additions & 66 deletions tests/core/test_lightning_module.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, patch
from unittest.mock import Mock

import pytest
from torch import nn
@@ -74,7 +74,7 @@ def test_property_logger(tmpdir):
assert model.logger == logger


def test_automatic_optimization(tmpdir):
def test_automatic_optimization_raises(tmpdir):

class TestModel(BoringModel):

@@ -95,70 +95,6 @@ def optimizer_step(self, *_, **__):
trainer.fit(model)


def test_automatic_optimization_num_calls(tmpdir):

with patch("torch.optim.SGD.step") as sgd_step, \
patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \
patch("torch.optim.Adam.step") as adam_step, \
patch("torch.optim.Adam.zero_grad") as adam_zero_grad:

class TestModel(BoringModel):

def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
return [optimizer, optimizer_2]

def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx,
optimizer_closure,
on_tpu,
using_native_amp,
using_lbfgs,
):

assert optimizer_closure.__name__ == "train_step_and_backward_closure"

# update generator opt every 2 steps
if optimizer_idx == 0:
if batch_idx % 2 == 0:
assert isinstance(optimizer, SGD)
optimizer.step(closure=optimizer_closure)

# update discriminator opt every 4 steps
if optimizer_idx == 1:
if batch_idx % 4 == 0:
assert isinstance(optimizer, Adam)
optimizer.step(closure=optimizer_closure)

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=8,
limit_val_batches=1,
accumulate_grad_batches=1,
)

trainer.fit(model)

assert sgd_step.call_count == 4
assert sgd_zero_grad.call_count == 4
assert adam_step.call_count == 2
assert adam_zero_grad.call_count == 2


def test_params_groups_and_state_are_accessible(tmpdir):

class TestModel(BoringModel):