diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst
index 3b29fd4c08f13..37b1a8caa09b0 100644
--- a/docs/source/common/optimizers.rst
+++ b/docs/source/common/optimizers.rst
@@ -40,10 +40,9 @@ to manually manage the optimization process. To do so, do the following:
         loss = self.compute_loss(batch)
         self.manual_backward(loss)
 
-
 .. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerators logic. The users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.
 
-.. warning:: Before 1.2, ``optimzer.step`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the users expertize.
+.. warning:: Before 1.2, ``optimizer.step()`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the user's expertise.
 
 .. tip:: To perform ``accumulate_grad_batches`` with one optimizer, you can do as such.
 
@@ -62,8 +61,7 @@ to manually manage the optimization process. To do so, do the following:
                 opt.step()
                 opt.zero_grad()
 
-
-.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward`` and ``backward`` pass of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure. See also `the PyTorch docs `_.
+.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs the ``forward``, ``zero_grad`` and ``backward`` passes of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure. See also `the PyTorch docs `_.
 
 Here is the same example as above using a ``closure``.
 
@@ -72,20 +70,20 @@
     def training_step(batch, batch_idx):
         opt = self.optimizers()
 
-        def forward_and_backward():
+        def closure():
+            # Only zero_grad on the first batch to accumulate gradients
+            is_first_batch_to_accumulate = batch_idx % 2 == 0
+            if is_first_batch_to_accumulate:
+                opt.zero_grad()
+
             loss = self.compute_loss(batch)
             self.manual_backward(loss)
+            return loss
 
-        opt.step(closure=forward_and_backward)
-
-        # accumulate gradient batches
-        if batch_idx % 2 == 0:
-            opt.zero_grad()
-
+        opt.step(closure=closure)
 
 .. tip:: Be careful where you call ``zero_grad`` or your model won't converge. It is good practice to call ``zero_grad`` before ``manual_backward``.
 
-
 .. testcode:: python
 
     import torch
@@ -169,10 +167,8 @@ Setting ``sync_grad`` to ``False`` will block this synchronization and improve y
 
 Here is an example for an advanced use-case.
 
-
 .. testcode:: python
 
-
     # Scenario for a GAN with gradient accumulation every 2 batches and optimized for multiple gpus.
     class SimpleGAN(LightningModule):
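For reference, the closure-with-accumulation pattern this diff introduces can be sketched outside Lightning with plain PyTorch. The toy model, synthetic data, and the `accumulate_every` value below are illustrative assumptions, not part of the docs; only the closure structure mirrors the added lines above.

```python
import torch

# Illustrative setup (assumed, not from the docs): a tiny linear model
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(4)]
accumulate_every = 2  # mirrors the ``batch_idx % 2`` window in the diff

for batch_idx, (x, y) in enumerate(batches):

    def closure():
        # Only zero_grad on the first batch of each accumulation window,
        # so gradients from the previous batch are kept and summed
        if batch_idx % accumulate_every == 0:
            opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    # Optimizers such as LBFGS may call the closure several times per step;
    # SGD calls it exactly once
    opt.step(closure=closure)
```

As in the hunk above, ``opt.step`` still runs on every batch; only ``zero_grad`` is gated on the accumulation window, so gradients are summed across the window's batches.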