[doc] Improve Manual Optimization Example #6294

Merged
merged 14 commits on Mar 5, 2021
159 changes: 125 additions & 34 deletions docs/source/common/optimizers.rst
@@ -67,28 +67,91 @@ Here is the same example as above using a ``closure``.
opt.zero_grad()


.. tip:: Be careful where you call ``zero_grad``, or your model won't converge. It is good practice to call ``zero_grad`` before ``manual_backward``.
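
For reference, the recommended ordering boils down to the following minimal sketch (``compute_loss`` is a placeholder for your own loss computation, not a Lightning API):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        # clear gradients left over from the previous step before computing new ones
        opt.zero_grad()

        loss = self.compute_loss(batch)  # placeholder for your own loss
        self.manual_backward(loss)
        opt.step()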


.. code-block:: python

    # Scenario for a GAN

    import torch
    from torch import Tensor
    from torch.utils.data import Dataset
    from pytorch_lightning import LightningModule

    class SimpleGAN(LightningModule):

        def __init__(self):
            super().__init__()
            self.G = Generator(...)
            self.D = Discriminator(...)
            # ``self.criterion`` (e.g. ``torch.nn.BCELoss()``) and the latent prior
            # ``self._Z`` used below are assumed to be defined here as well

        @property
        def automatic_optimization(self) -> bool:
            # Important: returning ``False`` here activates manual optimization for this model
            return False

        def generator_loss(self, d_z: Tensor) -> Tensor:
            # the closer ``d_z`` is to 1,
            # the better the generator is at fooling the discriminator
            return -1 * torch.log(d_z).mean()

        def discriminator_loss(self, d_x: Tensor, d_z: Tensor) -> Tensor:
            # the closer ``d_x`` is to 1 and ``d_z`` to 0,
            # the better the discriminator is at distinguishing
            # real data from generated data
            return -1 * (torch.log(d_x).mean() + torch.log(1 - d_z).mean())

        def sample_z(self, n) -> Tensor:
            sample = self._Z.sample((n,))
            return sample

        def sample_G(self, n) -> Tensor:
            z = self.sample_z(n)
            return self.G(z)

        def training_step(self, batch, batch_idx, optimizer_idx, *args):
            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
            g_opt, d_opt = self.optimizers()

            X, _ = batch
            batch_size = X.shape[0]

            real_label = torch.ones((batch_size, 1), device=self.device)
            fake_label = torch.zeros((batch_size, 1), device=self.device)

            g_X = self.sample_G(batch_size)

            ##########################
            # Optimize Discriminator #
            ##########################
            d_opt.zero_grad()

            d_x = self.D(X)
            errD_real = self.criterion(d_x, real_label)

            d_z = self.D(g_X.detach())
            errD_fake = self.criterion(d_z, fake_label)

            errD = errD_real + errD_fake

            self.manual_backward(errD)
            d_opt.step()

            ######################
            # Optimize Generator #
            ######################
            g_opt.zero_grad()

            d_z = self.D(g_X)
            errG = self.criterion(d_z, real_label)

            self.manual_backward(errG)
            g_opt.step()

            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)

        def configure_optimizers(self):
            g_opt = torch.optim.Adam(self.G.parameters(), lr=1e-5)
            d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5)
            return g_opt, d_opt
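
The ``Generator`` and ``Discriminator`` used above are placeholders. A minimal sketch of what they could look like for flat data follows; the layer sizes, ``latent_dim`` and ``data_dim`` are illustrative assumptions, not part of the example above.

.. code-block:: python

    from torch import nn

    class Generator(nn.Module):
        # maps a latent vector ``z`` to a fake sample
        def __init__(self, latent_dim: int = 100, data_dim: int = 784):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(latent_dim, 256),
                nn.ReLU(),
                nn.Linear(256, data_dim),
                nn.Tanh(),
            )

        def forward(self, z):
            return self.net(z)

    class Discriminator(nn.Module):
        # outputs the probability that a sample is real
        def __init__(self, data_dim: int = 784):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(data_dim, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 1),
                nn.Sigmoid(),
            )

        def forward(self, x):
            return self.net(x)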

.. note:: ``LightningOptimizer`` provides a ``toggle_model`` function as a context manager for advanced users. It can be useful when performing gradient accumulation with several optimizers or training in a distributed setting.

@@ -100,36 +163,64 @@ Toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``.
When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase.
Setting ``sync_grad`` to ``False`` will block this synchronization and improve your training speed.
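
In short, the accumulation pattern looks like this (a minimal sketch, assuming ``opt`` is one of the optimizers returned by ``self.optimizers()`` and ``loss`` has already been computed in ``training_step``):

.. code-block:: python

    # only synchronize gradients and step on every second batch
    is_final_accumulation_batch = batch_idx % 2 == 0

    with opt.toggle_model(sync_grad=is_final_accumulation_batch):
        self.manual_backward(loss)
        if is_final_accumulation_batch:
            opt.step()
            opt.zero_grad()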

Here is a full example for this advanced use-case:


.. code-block:: python


    # Scenario for a GAN with gradient accumulation every 2 batches, optimized for multiple GPUs

    class SimpleGAN(LightningModule):

        ...

        def training_step(self, batch, batch_idx, optimizer_idx, *args):
            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
            g_opt, d_opt = self.optimizers()

            X, _ = batch
            X.requires_grad = True
            batch_size = X.shape[0]

            real_label = torch.ones((batch_size, 1), device=self.device)
            fake_label = torch.zeros((batch_size, 1), device=self.device)

            # accumulate gradients over 2 batches: sync and step only on every second batch
            accumulated_grad_batches = batch_idx % 2 == 0

            g_X = self.sample_G(batch_size)

            ##########################
            # Optimize Discriminator #
            ##########################
            with d_opt.toggle_model(sync_grad=accumulated_grad_batches):
                d_x = self.D(X)
                errD_real = self.criterion(d_x, real_label)

                d_z = self.D(g_X.detach())
                errD_fake = self.criterion(d_z, fake_label)

                errD = errD_real + errD_fake

                self.manual_backward(errD)
                if accumulated_grad_batches:
                    d_opt.step()
                    d_opt.zero_grad()

            ######################
            # Optimize Generator #
            ######################
            with g_opt.toggle_model(sync_grad=accumulated_grad_batches):
                d_z = self.D(g_X)
                errG = self.criterion(d_z, real_label)

                self.manual_backward(errG)
                if accumulated_grad_batches:
                    g_opt.step()
                    g_opt.zero_grad()

            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)
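
To run this example across multiple GPUs, the ``Trainer`` call could look roughly like the following sketch; the flag values and the ``dm`` data module are illustrative and depend on your setup.

.. code-block:: python

    from pytorch_lightning import Trainer

    model = SimpleGAN()
    # ``dm`` is a placeholder for your own DataModule or dataloaders
    trainer = Trainer(gpus=2, accelerator="ddp", max_epochs=10)
    trainer.fit(model, dm)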

------
