[doc] Improve Manual Optimization Example #6294

Merged Mar 5, 2021. Changes to ``docs/source/common/optimizers.rst`` (135 additions, 41 deletions).

Manual optimization
===================
For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable
to manually manage the optimization process. To do so, do the following:

* Set the ``automatic_optimization`` property to ``False`` in your ``LightningModule``'s ``__init__`` method
* Use ``self.manual_backward(loss)`` instead of ``loss.backward()``.

.. testcode:: python

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            # Important: this property activates ``manual optimization`` for your model
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            # ``compute_loss`` is a user-defined helper
            loss = self.compute_loss(batch)
            self.manual_backward(loss)


.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerator logic; the user is responsible for ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.

.. warning:: Before 1.2, ``optimizer.step()`` called ``optimizer.zero_grad()`` internally. From 1.2, this is left to the user's expertise.
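
As noted above, gradient accumulation becomes your responsibility under manual optimization: you call ``step`` and ``zero_grad`` only every few batches. The following is a minimal sketch, not taken from the Lightning codebase; ``compute_loss`` and the accumulation factor of 2 are illustrative assumptions.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        # scale the loss so that 2 accumulated batches match one larger batch
        loss = self.compute_loss(batch) / 2
        self.manual_backward(loss)

        # step and reset only every 2 batches; gradients accumulate in between
        if (batch_idx + 1) % 2 == 0:
            opt.step()
            opt.zero_grad()
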
In the simplest case, ``training_step`` fetches the optimizer and drives it directly:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        loss = self.compute_loss(batch)

Here is the same example as above using a ``closure``.

.. testcode:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def forward_and_backward():
            loss = self.compute_loss(batch)
            self.manual_backward(loss)

        opt.step(closure=forward_and_backward)
        opt.zero_grad()


.. tip:: Be careful where you call ``zero_grad``, or your model won't converge. It is good practice to call ``zero_grad`` before ``manual_backward``.


.. testcode:: python

    import torch
    from torch import Tensor
    from pytorch_lightning import LightningModule

    class SimpleGAN(LightningModule):

        def __init__(self):
            super().__init__()
            # ``Generator`` and ``Discriminator`` are assumed to be user-defined nn.Modules
            self.G = Generator()
            self.D = Discriminator()

            # Hypothetical latent prior backing ``sample_z``; the dimensionality is illustrative
            self._Z = torch.distributions.Normal(torch.zeros(100), torch.ones(100))

            # assumed loss; the DCGAN tutorial referenced below uses binary cross-entropy
            self.criterion = torch.nn.BCELoss()

            # Important: this property activates ``manual optimization`` for this model
            self.automatic_optimization = False

        def sample_z(self, n) -> Tensor:
            sample = self._Z.sample((n,))
            return sample

        def sample_G(self, n) -> Tensor:
            z = self.sample_z(n)
            return self.G(z)

        def training_step(self, batch, batch_idx):
            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
            g_opt, d_opt = self.optimizers()

            X, _ = batch
            batch_size = X.shape[0]

            real_label = torch.ones((batch_size, 1), device=self.device)
            fake_label = torch.zeros((batch_size, 1), device=self.device)

            g_X = self.sample_G(batch_size)

            ###########################
            # Optimize Discriminator #
            ###########################
            d_opt.zero_grad()

            d_x = self.D(X)
            errD_real = self.criterion(d_x, real_label)

            d_z = self.D(g_X.detach())
            errD_fake = self.criterion(d_z, fake_label)

            errD = (errD_real + errD_fake)

            self.manual_backward(errD)
            d_opt.step()

            #######################
            # Optimize Generator #
            #######################
            g_opt.zero_grad()

            d_z = self.D(g_X)
            errG = self.criterion(d_z, real_label)

            self.manual_backward(errG)
            g_opt.step()

            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)

        def configure_optimizers(self):
            g_opt = torch.optim.Adam(self.G.parameters(), lr=1e-5)
            d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5)
            return g_opt, d_opt
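
Training such a module looks the same as with automatic optimization. Here is a hypothetical usage sketch, where ``train_dataloader`` is an assumed, user-provided ``DataLoader``:

.. code-block:: python

    from pytorch_lightning import Trainer

    model = SimpleGAN()
    trainer = Trainer(max_epochs=5)
    trainer.fit(model, train_dataloader)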

.. note:: ``LightningOptimizer`` provides a ``toggle_model`` function as a ``contextmanager`` for advanced users. It can be useful when performing gradient accumulation with several optimizers, or when training in a distributed setting.

Consider the current optimizer as A and all other optimizers as B. Toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``.
When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase.
Setting ``sync_grad`` to ``False`` will block this synchronization and improve your training speed.
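
To make the toggling semantics concrete, here is an illustrative sketch reusing the ``SimpleGAN`` module from above. The assertions merely restate the behaviour described in this section and are not part of the Lightning API:

.. code-block:: python

    # inside SimpleGAN.training_step, with g_opt and d_opt as above
    with d_opt.toggle_model(sync_grad=True):
        # parameters exclusive to the generator's optimizer stop requiring gradients
        assert not any(p.requires_grad for p in self.G.parameters())
        # the discriminator's own parameters are untouched
        assert all(p.requires_grad for p in self.D.parameters())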

Here is an example for an advanced use case:


.. testcode:: python


    # Scenario for a GAN with gradient accumulation every 2 batches, optimized for multiple GPUs.

    class SimpleGAN(LightningModule):

        ...

        def training_step(self, batch, batch_idx):
            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
            g_opt, d_opt = self.optimizers()

            X, _ = batch
            X.requires_grad = True
            batch_size = X.shape[0]

            real_label = torch.ones((batch_size, 1), device=self.device)
            fake_label = torch.zeros((batch_size, 1), device=self.device)

            accumulated_grad_batches = batch_idx % 2 == 0

            g_X = self.sample_G(batch_size)

            ###########################
            # Optimize Discriminator #
            ###########################
            with d_opt.toggle_model(sync_grad=accumulated_grad_batches):
                d_x = self.D(X)
                errD_real = self.criterion(d_x, real_label)

                d_z = self.D(g_X.detach())
                errD_fake = self.criterion(d_z, fake_label)

                errD = (errD_real + errD_fake)

                self.manual_backward(errD)
                if accumulated_grad_batches:
                    d_opt.step()
                    d_opt.zero_grad()

            #######################
            # Optimize Generator #
            #######################
            with g_opt.toggle_model(sync_grad=accumulated_grad_batches):
                d_z = self.D(g_X)
                errG = self.criterion(d_z, real_label)

                self.manual_backward(errG)
                if accumulated_grad_batches:
                    g_opt.step()
                    g_opt.zero_grad()

            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)

------
