Call optimizer.zero_grad() before backward inside closure in AutoOpt #6147

Merged
37 commits merged into master from bugfix/4083_lbfgs on Mar 1, 2021

Changes from all commits
37 commits
a6da028
Call zero_grad inside closure
akihironitta Feb 23, 2021
57840f9
Call zero_grad inside closure independently of optim
akihironitta Feb 23, 2021
8d1253d
Remove optimizer.zero_grad from optimizer.step
akihironitta Feb 24, 2021
eaf42fd
Use accelerator's zero_grad
akihironitta Feb 24, 2021
00376fc
Update manual optimization docs
akihironitta Feb 24, 2021
be43077
Update automatic optimization docs
akihironitta Feb 24, 2021
77d7f83
Merge branch 'master' into bugfix/4083_lbfgs
akihironitta Feb 24, 2021
3f6e086
Update new-project docs
akihironitta Feb 24, 2021
c422ffa
Move on_before_zero_grad to trainloop
akihironitta Feb 25, 2021
95e0a0b
Use trainerloop methods
akihironitta Feb 25, 2021
e852efd
Merge branch 'bugfix/4083_lbfgs' of github.com:akihironitta/pytorch-l…
akihironitta Feb 25, 2021
8243b80
Remove zero_grad after backward
akihironitta Feb 25, 2021
8269f45
Split tests to step and zero_grad
akihironitta Feb 25, 2021
4694e3e
Call zero_grad before backward in tests
akihironitta Feb 25, 2021
d658920
Call zero_grad before backward in tests
akihironitta Feb 25, 2021
2dd6154
Add a test for optimization with lbfgs
akihironitta Feb 25, 2021
2fe4d28
Remove unused model
akihironitta Feb 25, 2021
1845a67
Add back BoringModel
akihironitta Feb 25, 2021
eb875b4
Update CHANGELOG
akihironitta Feb 25, 2021
d898e7b
zero_grad when the first batch of accumulation
akihironitta Feb 25, 2021
d262546
Merge branch 'master' into bugfix/4083_lbfgs
akihironitta Feb 25, 2021
435a3c8
Refactor tests. Remove duplicates
carmocca Feb 26, 2021
22fd399
Add a test to check zero_grad call order
akihironitta Feb 26, 2021
02ec3aa
flake8
akihironitta Feb 26, 2021
f92f708
Update test comment
akihironitta Feb 26, 2021
85431ff
Make test compatible with PT1.4
akihironitta Feb 26, 2021
1be883e
Apply suggestions from code review
akihironitta Feb 26, 2021
337d85b
Use simple string for logging called methods
akihironitta Feb 26, 2021
3194bb3
Add optimzer.step to test
akihironitta Feb 27, 2021
1b5c1ce
Merge branch 'master' into bugfix/4083_lbfgs
akihironitta Feb 27, 2021
91e7ae1
Update the test comment
akihironitta Feb 27, 2021
96a5504
Update the test
akihironitta Feb 27, 2021
87fdfb8
Update the test
akihironitta Feb 27, 2021
9d11f34
Remove unused import
akihironitta Feb 27, 2021
93d3f57
Update test comment
akihironitta Feb 27, 2021
05c28d5
Update CHANGELOG.md
akihironitta Feb 28, 2021
9836bb7
Update docs
akihironitta Feb 28, 2021
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))


### Deprecated

@@ -63,6 +65,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


- Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))


- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


69 changes: 39 additions & 30 deletions docs/source/common/optimizers.rst
@@ -23,32 +23,31 @@ to manually manage the optimization process. To do so, do the following:

* Override your LightningModule ``automatic_optimization`` property to return ``False`` (see the sketch after this list)
* Drop or ignore the optimizer_idx argument
* Use `self.manual_backward(loss)` instead of `loss.backward()`.
* Use ``self.manual_backward(loss)`` instead of ``loss.backward()``.
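
For instance, a minimal sketch of the property override from the first bullet (the ``MyModel`` name is hypothetical; only the relevant part is shown):

.. code-block:: python

    class MyModel(LightningModule):

        @property
        def automatic_optimization(self) -> bool:
            # tell Lightning not to call .zero_grad()/.backward()/.step() for you
            return False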

.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerator logic. The user is left with ``zero_grad``, ``accumulate_grad_batches``, model toggling, etc.
.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerator logic. The user is left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.

.. warning:: Before 1.2, ``optimizer.step()`` was calling ``zero_grad()`` internally. From 1.2, it is left to the user's expertise.
.. warning:: Before 1.2, ``optimizer.step()`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the user's expertise.

.. tip:: To perform gradient accumulation (``accumulate_grad_batches``) with a single optimizer, you can do the following.

.. tip:: ``self.optimizers()`` will return ``LightningOptimizer`` objects. You can access your own optimizer with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you.


.. code-block:: python

def training_step(self, batch, batch_idx, optimizer_idx):
opt = self.optimizers()

loss = self.compute_loss(batch)
self.manual_backward(loss)

# accumulate gradient batches
if batch_idx % 2 == 0:
opt.step()
opt.zero_grad()


.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward`` and ``backward`` pass of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure.
.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward`` and ``backward`` pass of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure. See also `the PyTorch docs <https://pytorch.org/docs/stable/optim.html#optimizer-step-closure>`_.

Here is the same example as above using a ``closure``.
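
The full version is in the collapsed lines below; a minimal sketch of the closure-based variant (reusing the hypothetical ``compute_loss`` helper from the previous snippet) might look like this:

.. code-block:: python

    def training_step(self, batch, batch_idx, optimizer_idx):
        opt = self.optimizers()

        def closure():
            # zero the gradients inside the closure so that optimizers which
            # re-evaluate it (e.g. LBFGS) always start from fresh gradients
            opt.zero_grad()
            loss = self.compute_loss(batch)
            self.manual_backward(loss)
            return loss

        opt.step(closure=closure)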

@@ -71,7 +70,6 @@ Here is the same example as above using a ``closure``.
.. code-block:: python

# Scenario for a GAN.

def training_step(...):
opt_gen, opt_dis = self.optimizers()

@@ -137,8 +135,12 @@ Here is an example on how to use it:

Automatic optimization
======================
With Lightning most users don't have to think about when to call .backward(), .step(), .zero_grad(), since
Lightning automates that for you.
With Lightning most users don't have to think about when to call ``.zero_grad()``, ``.backward()`` and ``.step()``
since Lightning automates that for you.

.. warning::
Before 1.2.2, ``.zero_grad()`` was called after ``.backward()`` and ``.step()`` internally.
From 1.2.2, Lightning calls ``.zero_grad()`` before ``.backward()``.

Under the hood Lightning does the following:

@@ -147,33 +149,33 @@
for epoch in epochs:
for batch in data:
loss = model.training_step(batch, batch_idx, ...)
optimizer.zero_grad()
loss.backward()
optimizer.step()

for lr_scheduler in lr_schedulers:
lr_scheduler.step()

In the case of multiple optimizers, Lightning does the following:

.. code-block:: python

for epoch in epochs:
for batch in data:
for opt in optimizers:
loss = model.training_step(batch, batch_idx, optimizer_idx)
opt.zero_grad()
loss.backward()
opt.step()

for lr_scheduler in lr_schedulers:
lr_scheduler.step()
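
In user code, all of the above typically reduces to returning a loss from ``training_step``, for example (the loss function here is purely illustrative):

.. code-block:: python

    # assumes `import torch.nn.functional as F`
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self(x), y)
        # Lightning takes care of .zero_grad(), .backward() and .step()
        return loss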


Learning rate scheduling
------------------------
Every optimizer you use can be paired with any `LearningRateScheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
In the basic use-case, the scheduler (or multiple schedulers) should be returned as the second output from the ``.configure_optimizers``
method:
Every optimizer you use can be paired with any `Learning Rate Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
In the basic use-case, the scheduler (or multiple schedulers) should be returned as the second output from the ``.configure_optimizers`` method:
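
The full example is in the collapsed ``testcode`` block below; a minimal sketch (the Adam/StepLR choice is purely illustrative) looks like this:

.. code-block:: python

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # optimizers and schedulers are returned as two lists
        return [optimizer], [scheduler]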

.. testcode::

@@ -262,7 +264,7 @@ returned as a dict which can contain the following keywords:

Use multiple optimizers (like GANs)
-----------------------------------
To use multiple optimizers return > 1 optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers`
To use multiple optimizers return two or more optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers`
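
For example, a minimal sketch with two optimizers (the ``generator``/``discriminator`` attributes are hypothetical):

.. code-block:: python

    def configure_optimizers(self):
        opt_gen = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_dis = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return opt_gen, opt_dis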

.. testcode::

@@ -283,13 +285,15 @@ Lightning will call each optimizer sequentially:
.. code-block:: python

for epoch in epochs:
for batch in data:
for opt in optimizers:
loss = train_step(batch, batch_idx, optimizer_idx)
opt.zero_grad()
loss.backward()
opt.step()

for lr_scheduler in lr_schedulers:
lr_scheduler.step()

----------

@@ -332,7 +336,7 @@ Here we add a learning-rate warm up
# update params
optimizer.step(closure=closure)

.. note:: The default ``optimizer_step`` relies on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, ``accumulate_grad_batches``, ``zero_grad``, and much more.
.. note:: The default ``optimizer_step`` relies on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, ``accumulate_grad_batches``, and much more.
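
The warm-up example above is partly collapsed; a full override typically looks something like the following sketch (the 500-step window and ``self.hparams.learning_rate`` are illustrative assumptions):

.. code-block:: python

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        # linearly warm up the learning rate over the first 500 steps
        if self.trainer.global_step < 500:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.learning_rate

        # update params
        optimizer.step(closure=closure)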

.. testcode::

@@ -362,6 +366,11 @@ Using the closure functions for optimization

When using optimization schemes such as LBFGS, the `second_order_closure` needs to be enabled. By default, this function is defined by wrapping the `training_step` and the backward steps as follows

.. warning::
Before 1.2.2, ``.zero_grad()`` was called outside the closure internally.
    From 1.2.2, the closure calls ``.zero_grad()`` internally, so there is no need to define your own closure
    when using optimizers similar to :class:`torch.optim.LBFGS`, which require re-evaluation of the loss via the closure passed to ``optimizer.step()``.
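
Conceptually, the closure that Lightning now passes to ``optimizer.step()`` does something like this (a simplified sketch of ``train_step_and_backward_closure``, not the exact internal code):

.. code-block:: python

    def train_step_and_backward_closure():
        loss = lightning_module.training_step(batch, batch_idx, optimizer_idx)
        # new in 1.2.2: gradients are zeroed before the backward pass (on the
        # first batch of an accumulation window), so optimizers such as LBFGS
        # that re-evaluate the closure always see fresh gradients
        optimizer.zero_grad()
        lightning_module.backward(loss, optimizer, optimizer_idx)
        return loss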

.. testcode::

def second_order_closure(pl_module, split_batch, batch_idx, opt_idx, optimizer, hidden):
4 changes: 2 additions & 2 deletions docs/source/starter/introduction_guide.rst
@@ -361,9 +361,9 @@ The training step is what happens inside the training loop.
# TRAINING STEP
# ....
# TRAINING STEP
optimizer.zero_grad()
loss.backward()
optimizer.step()

In the case of MNIST, we do the following

@@ -377,9 +377,9 @@ In the case of MNIST, we do the following
loss = F.nll_loss(logits, y)
# ------ TRAINING STEP END ------

optimizer.zero_grad()
loss.backward()
optimizer.step()

In Lightning, everything that is in the training step gets organized under the
:func:`~pytorch_lightning.core.LightningModule.training_step` function in the LightningModule.
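
The collapsed lines below show that refactor; schematically it becomes something like the following (the ``LitMNIST`` class name is an assumption, the loss matches the snippet above):

.. code-block:: python

    class LitMNIST(LightningModule):

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return loss
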
21 changes: 10 additions & 11 deletions docs/source/starter/new-project.rst
@@ -248,7 +248,7 @@ as long as you return a loss with an attached graph from the `training_step`, Lightning will automate the optimization
.. code-block:: python

def training_step(self, batch, batch_idx):
loss = self.encoder(batch[0])
loss = self.encoder(batch)
return loss

.. _manual_opt:
@@ -267,19 +267,18 @@ Turn off automatic optimization and you control the train loop!

def training_step(self, batch, batch_idx, optimizer_idx):
# access your optimizers with use_pl_optimizer=False. Default is True
opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

loss_a = self.generator(batch)
opt_a.zero_grad()
# use `manual_backward()` instead of `loss.backward` to automate half precision, etc...
self.manual_backward(loss_a)
opt_a.step()

loss_b = self.discriminator(batch)
opt_b.zero_grad()
self.manual_backward(loss_b)
opt_b.step()


Predict or Deploy
5 changes: 0 additions & 5 deletions pytorch_lightning/core/optimizer.py
@@ -129,15 +129,10 @@ def toggle_model(self, sync_grad: bool = True):
def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs):
trainer = self._trainer
optimizer = self._optimizer
model = trainer.lightning_module

with trainer.profiler.profile(profiler_name):
trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)

if self._trainer.train_loop.automatic_optimization:
trainer.train_loop.on_before_zero_grad(optimizer)
model.optimizer_zero_grad(trainer.current_epoch, trainer.batch_idx, optimizer, self._optimizer_idx)

def step(self, *args, closure: Optional[Callable] = None, **kwargs):
"""
Call this directly from your training_step when doing optimizations manually.
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -740,7 +740,13 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
self._curr_step_result = result

if not self._skip_backward and self.trainer.train_loop.automatic_optimization:
if not self._skip_backward and self.automatic_optimization:
is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0

if is_first_batch_to_accumulate:
self.on_before_zero_grad(optimizer)
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

# backward pass
if result is not None:
with self.trainer.profiler.profile("model_backward"):
6 changes: 3 additions & 3 deletions tests/callbacks/test_callbacks.py
@@ -73,20 +73,20 @@ def test_trainer_callback_system(torch_save, tmpdir):
call.on_train_epoch_start(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 0, 0),
call.on_after_backward(trainer, model),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_batch_end(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 1, 0),
call.on_after_backward(trainer, model),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_batch_end(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 2, 0),
call.on_after_backward(trainer, model),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
call.on_batch_end(trainer, model),
call.on_train_epoch_end(trainer, model, ANY),
68 changes: 2 additions & 66 deletions tests/core/test_lightning_module.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, patch
from unittest.mock import Mock

import pytest
from torch import nn
@@ -74,7 +74,7 @@ def test_property_logger(tmpdir):
assert model.logger == logger


def test_automatic_optimization(tmpdir):
def test_automatic_optimization_raises(tmpdir):

class TestModel(BoringModel):

@@ -95,70 +95,6 @@ def optimizer_step(self, *_, **__):
trainer.fit(model)


def test_automatic_optimization_num_calls(tmpdir):

with patch("torch.optim.SGD.step") as sgd_step, \
patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \
patch("torch.optim.Adam.step") as adam_step, \
patch("torch.optim.Adam.zero_grad") as adam_zero_grad:

class TestModel(BoringModel):

def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
return [optimizer, optimizer_2]

def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx,
optimizer_closure,
on_tpu,
using_native_amp,
using_lbfgs,
):

assert optimizer_closure.__name__ == "train_step_and_backward_closure"

# update generator opt every 2 steps
if optimizer_idx == 0:
if batch_idx % 2 == 0:
assert isinstance(optimizer, SGD)
optimizer.step(closure=optimizer_closure)

# update discriminator opt every 4 steps
if optimizer_idx == 1:
if batch_idx % 4 == 0:
assert isinstance(optimizer, Adam)
optimizer.step(closure=optimizer_closure)

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=8,
limit_val_batches=1,
accumulate_grad_batches=1,
)

trainer.fit(model)

assert sgd_step.call_count == 4
assert sgd_zero_grad.call_count == 4
assert adam_step.call_count == 2
assert adam_zero_grad.call_count == 2


def test_params_groups_and_state_are_accessible(tmpdir):

class TestModel(BoringModel):