[doc] Improve Manual Optimization Example #6294

Merged · 14 commits · Mar 5, 2021
193 changes: 152 additions & 41 deletions docs/source/common/optimizers.rst
@@ -21,10 +21,32 @@ Manual optimization
For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable
to manually manage the optimization process. To do so, do the following:

* Set the ``automatic_optimization`` property to ``False`` in your ``LightningModule``'s ``__init__`` method
* Use ``self.manual_backward(loss)`` instead of ``loss.backward()``.

.. testcode:: python

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            # Important: this property activates ``manual optimization`` for your model
            self.automatic_optimization = False

.. testcode:: python

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            loss = self.compute_loss(batch)
            self.manual_backward(loss)


.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerator logic; users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.

.. warning:: Before 1.2, ``optimizer.step()`` called ``optimizer.zero_grad()`` internally. From 1.2, this is left to the user's expertise.
@@ -35,7 +57,7 @@ to manually manage the optimization process. To do so, do the following:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        loss = self.compute_loss(batch)
@@ -51,9 +73,9 @@ to manually manage the optimization process. To do so, do the following:

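In full, a manual optimization step follows this pattern (a minimal sketch; ``compute_loss`` stands in for your own loss computation):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        # clear gradients left over from the previous step
        opt.zero_grad()

        loss = self.compute_loss(batch)

        # ``manual_backward`` handles precision and accelerator logic for you
        self.manual_backward(loss)

        opt.step()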
Here is the same example as above using a ``closure``.

.. testcode:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def forward_and_backward():
@@ -67,28 +89,89 @@ Here is the same example as above using a ``closure``.
        opt.zero_grad()


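Pieced together, the closure variant might look like the following sketch: the closure wraps the forward pass and the ``manual_backward`` call, and is passed to ``opt.step``, which runs it before applying the update (``compute_loss`` is again a stand-in helper):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def forward_and_backward():
            loss = self.compute_loss(batch)
            self.manual_backward(loss)
            return loss

        opt.zero_grad()
        # the optimizer invokes the closure internally before stepping
        opt.step(closure=forward_and_backward)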
.. tip:: Be careful where you call ``zero_grad`` or your model won't converge. It is good practice to call ``zero_grad`` before ``manual_backward``.


.. testcode:: python

    import torch
    from torch import Tensor
    from pytorch_lightning import LightningModule

    class SimpleGAN(LightningModule):

        def __init__(self):
            super().__init__()
            self.G = Generator()
            self.D = Discriminator()
            # assumed to be set up elsewhere: a latent prior ``self._Z``
            # (e.g. ``torch.distributions.Normal``) and a loss ``self.criterion``
            # (e.g. ``nn.BCELoss``)

            # Important: this property activates ``manual optimization`` for this model
            self.automatic_optimization = False

        def generator_loss(self, d_z: Tensor) -> Tensor:
            # the closer ``d_z`` is to 1,
            # the better the generator is at fooling the discriminator
            return -1 * torch.log(d_z).mean()

        def discriminator_loss(self, d_x: Tensor, d_z: Tensor) -> Tensor:
            # the closer ``d_x`` is to 1 and ``d_z`` to 0,
            # the better the discriminator is at distinguishing
            # real data from generated data
            return -1 * (torch.log(d_x).mean() + torch.log(1 - d_z).mean())

        def sample_z(self, n) -> Tensor:
            sample = self._Z.sample((n,))
            return sample

        def sample_G(self, n) -> Tensor:
            z = self.sample_z(n)
            return self.G(z)

        def training_step(self, batch, batch_idx):
            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
            g_opt, d_opt = self.optimizers()

            X, _ = batch
            batch_size = X.shape[0]

            real_label = torch.ones((batch_size, 1), device=self.device)
            fake_label = torch.zeros((batch_size, 1), device=self.device)

            g_X = self.sample_G(batch_size)

            ###########################
            # Optimize Discriminator #
            ###########################
            d_opt.zero_grad()

            d_x = self.D(X)
            errD_real = self.criterion(d_x, real_label)

            d_z = self.D(g_X.detach())
            errD_fake = self.criterion(d_z, fake_label)

            errD = (errD_real + errD_fake)

            self.manual_backward(errD)
            d_opt.step()

            #######################
            # Optimize Generator #
            #######################
            g_opt.zero_grad()

            d_z = self.D(g_X)
            errG = self.criterion(d_z, real_label)

            self.manual_backward(errG)
            g_opt.step()

            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)

        def configure_optimizers(self):
            g_opt = torch.optim.Adam(self.G.parameters(), lr=1e-5)
            d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5)
            return g_opt, d_opt

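The example above assumes ``Generator`` and ``Discriminator`` modules as well as the latent prior ``self._Z`` and loss ``self.criterion``, none of which are defined in the snippet. A minimal sketch of what these could look like (all layer sizes and shapes here are hypothetical, for illustration only):

.. code-block:: python

    import torch
    from torch import nn

    class Generator(nn.Module):
        def __init__(self, latent_dim: int = 100, img_dim: int = 784):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(latent_dim, 256), nn.ReLU(),
                nn.Linear(256, img_dim), nn.Tanh(),
            )

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            return self.net(z)

    class Discriminator(nn.Module):
        def __init__(self, img_dim: int = 784):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(img_dim, 256), nn.LeakyReLU(0.2),
                nn.Linear(256, 1), nn.Sigmoid(),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)

    # and, inside ``SimpleGAN.__init__``:
    # self._Z = torch.distributions.Normal(torch.zeros(100), torch.ones(100))
    # self.criterion = nn.BCELoss()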
.. note:: ``LightningOptimizer`` provides a ``toggle_model`` function that can be used as a context manager by advanced users. It can be useful when performing gradient accumulation with several optimizers or training in a distributed setting.

@@ -100,36 +183,64 @@ Toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``.
When performing gradient accumulation, there is no need to synchronize gradients during the accumulation phase.
Setting ``sync_grad`` to ``False`` will skip this synchronization and improve your training speed.

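In isolation, the pattern looks like this (a minimal sketch; ``g_opt`` is one of your optimizers, ``g_loss`` its loss, and ``should_step`` a boolean that is ``True`` only on batches where you actually call ``step``):

.. code-block:: python

    # skip inter-process gradient synchronization while accumulating;
    # synchronize only on batches where the optimizer will step
    with g_opt.toggle_model(sync_grad=should_step):
        self.manual_backward(g_loss)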
Here is an example for an advanced use case.


.. testcode:: python


    # Scenario for a GAN with gradient accumulation every 2 batches, optimized for multiple GPUs.

    class SimpleGAN(LightningModule):

        ...

        def training_step(self, batch, batch_idx):
            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
            g_opt, d_opt = self.optimizers()

            X, _ = batch
            X.requires_grad = True
            batch_size = X.shape[0]

            real_label = torch.ones((batch_size, 1), device=self.device)
            fake_label = torch.zeros((batch_size, 1), device=self.device)

            # step the optimizers every second batch
            accumulated_grad_batches = batch_idx % 2 == 0

            g_X = self.sample_G(batch_size)

            ###########################
            # Optimize Discriminator #
            ###########################
            with d_opt.toggle_model(sync_grad=accumulated_grad_batches):
                d_x = self.D(X)
                errD_real = self.criterion(d_x, real_label)

                d_z = self.D(g_X.detach())
                errD_fake = self.criterion(d_z, fake_label)

                errD = (errD_real + errD_fake)

                self.manual_backward(errD)
                if accumulated_grad_batches:
                    d_opt.step()
                    d_opt.zero_grad()

            #######################
            # Optimize Generator #
            #######################
            with g_opt.toggle_model(sync_grad=accumulated_grad_batches):
                d_z = self.D(g_X)
                errG = self.criterion(d_z, real_label)

                self.manual_backward(errG)
                if accumulated_grad_batches:
                    g_opt.step()
                    g_opt.zero_grad()

            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)

------
