Add the on_before_optimizer_step hook #8048

Merged
91 commits
579ed0d
add on_before_optimizer hook to hook definitions
ddrevicky Jun 20, 2021
f2ccbbf
update training loop logic with on_before_optimizer_step
ddrevicky Jun 20, 2021
6311d8c
add on_before_optimizer_step to lambda_function
ddrevicky Jun 20, 2021
253a0e7
add on_before_optimizer_step to fx_validator
ddrevicky Jun 20, 2021
a5f2e6b
Parametrize fit hook test with different precision plugins
carmocca Jun 21, 2021
0ce2295
Fix tests
carmocca Jun 22, 2021
c22fc74
Parametrize fit hook test with manual optimization
carmocca Jun 22, 2021
4b1534b
Unnecessary parenthesis
carmocca Jun 22, 2021
f5828a8
WIP
carmocca Jun 22, 2021
72d5ee3
Comments
carmocca Jun 22, 2021
f34ee7e
Fix message
carmocca Jun 22, 2021
39c4a85
Test CI error
carmocca Jun 22, 2021
c3b458d
Revert "Test CI error"
carmocca Jun 22, 2021
c700cab
Add ddp training type teardown
carmocca Jun 22, 2021
e5602c9
Update CHANGELOG
carmocca Jun 22, 2021
52b2256
Adrian's fix
carmocca Jun 22, 2021
0b94b6c
Use destructor
carmocca Jun 23, 2021
aaf32ab
Update CHANGELOG.md
carmocca Jun 23, 2021
0444d54
RPC destructor
carmocca Jun 23, 2021
5d4f811
Update pytorch_lightning/plugins/training_type/ddp.py
carmocca Jun 23, 2021
bf8766d
Why do you not work :(
carmocca Jun 23, 2021
48bcb7e
Missing condition
carmocca Jun 23, 2021
5d6fa39
Merge branch 'master' into bug/teardown-ddp-process-group
carmocca Jun 23, 2021
21ad2d8
Fix deepspeed test
carmocca Jun 24, 2021
bbc489e
GC collect in conftest
carmocca Jun 24, 2021
5b06fd2
Do not show warnings for special tests
carmocca Jun 24, 2021
5e69ed8
Needs to run on 1.8
carmocca Jun 24, 2021
1e0cf40
Merge branch 'master' into tests/parametrize-hooks-precision-plugins
awaelchli Jun 24, 2021
aed51a2
Run torch 1.8
carmocca Jun 24, 2021
e0a3e87
Skip test due to 'Python bus error'
carmocca Jun 24, 2021
9ee2d19
Debug NCCL
carmocca Jun 24, 2021
3588aaa
shm size
carmocca Jun 24, 2021
067bf1a
Disable warnings for special tests
carmocca Jun 24, 2021
6060b05
Remove NCCL_DEBUG statement
carmocca Jun 24, 2021
f0fa1b7
Try smaller shm size
carmocca Jun 24, 2021
6dd7038
Revert "Skip test due to 'Python bus error'"
carmocca Jun 24, 2021
53082bf
Merge branch 'ci/gpu-tests-torch-1.8' into bug/teardown-ddp-process-g…
carmocca Jun 24, 2021
73e62f8
README and adjust versions
carmocca Jun 24, 2021
902ef02
Avoid self.on_gpu call
carmocca Jun 24, 2021
4ce0f9a
empty cache cleanup
carmocca Jun 24, 2021
990b2e9
Merge branch 'master' into bug/teardown-ddp-process-group
carmocca Jun 24, 2021
738daa5
More garbage collection
carmocca Jun 24, 2021
236aa97
Unroll parametrizations
awaelchli Jun 24, 2021
ffa532d
Do not reuse mock
carmocca Jun 24, 2021
5aa3790
Merge branch 'master' into tests/parametrize-hooks-precision-plugins
carmocca Jun 24, 2021
78baa5f
Merge branch 'bug/teardown-ddp-process-group' into tests/parametrize-…
carmocca Jun 24, 2021
e190089
Undo changes
carmocca Jun 24, 2021
261a166
Undo notebooks modification
carmocca Jun 24, 2021
acec7b0
Merge branch 'master' into tests/parametrize-hooks-precision-plugins
carmocca Jul 3, 2021
33a68d4
Undo
carmocca Jul 3, 2021
ac006c7
Fix test
carmocca Jul 3, 2021
9efe252
Revert "WIP"
carmocca Jul 3, 2021
beecfb9
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 3, 2021
01dfc7c
Merge branch 'tests/parametrize-hooks-precision-plugins' into tests/p…
carmocca Jul 3, 2021
dc7b17b
Rename
carmocca Jul 3, 2021
8e1f60e
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 5, 2021
9c3fbd4
Remove optimizers
carmocca Jul 5, 2021
f90348c
Fix bug with LightningOptimizer
carmocca Jul 5, 2021
cbf7b36
Add optimizers
carmocca Jul 5, 2021
d1a48a6
Update CHANGELOG
carmocca Jul 6, 2021
1869128
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 6, 2021
fe06ec0
On after backward refactor
carmocca Jul 6, 2021
938de4d
Do not call super
carmocca Jul 6, 2021
20da3b1
Fixes
carmocca Jul 6, 2021
abfbdd6
Remove should_accumulate
carmocca Jul 7, 2021
9c8993c
pre/post backward refactor
carmocca Jul 7, 2021
d7d2a71
Call the LM backward hook
carmocca Jul 7, 2021
f3c3726
Update tests
carmocca Jul 7, 2021
7cfed58
Remove dev debug patch
carmocca Jul 7, 2021
7838eae
Fix test
carmocca Jul 7, 2021
c070e84
Remove optimizer arguments and typing
carmocca Jul 7, 2021
5fabca8
Docs fixes
carmocca Jul 7, 2021
cf89192
Fix comment
carmocca Jul 7, 2021
6d77d72
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 7, 2021
f88cc51
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 7, 2021
d749a85
Undo changes
carmocca Jul 7, 2021
d1c342b
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 8, 2021
816cb4c
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 8, 2021
fb585a2
Merge branch 'master' into bug/7924_on_after_backward_should_always_run
carmocca Jul 8, 2021
beadb8f
Progress
carmocca Jul 8, 2021
f2c29e1
Progress
carmocca Jul 8, 2021
66bf012
Progress
carmocca Jul 8, 2021
dcc92f0
Merge branch 'tests/parametrize-hooks-manual-opt' into bug/7924_on_af…
carmocca Jul 8, 2021
c3d8cde
Progress
carmocca Jul 8, 2021
f8129e9
Fix test
carmocca Jul 8, 2021
65bcc57
Update CHANGELOG
carmocca Jul 8, 2021
c44a491
Fix test
carmocca Jul 8, 2021
71cbcd5
Docs
carmocca Jul 8, 2021
365aeeb
Staticmethod
carmocca Jul 8, 2021
90083a0
Fix fx validator test
carmocca Jul 8, 2021
683d1a5
Merge branch 'master' into bug/7924_on_after_backward_should_always_run
carmocca Jul 9, 2021
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -81,6 +81,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))


- Added the `on_before_optimizer_step` hook ([#8048](https://github.com/PyTorchLightning/pytorch-lightning/pull/8048))


- Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))


@@ -244,10 +247,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))


- The `on_after_backward` hook is now called on accumulating iterations ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The `on_after_backward` hook is now called on accumulating iterations. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The mixed precision loss is no longer unscaled before the `on_after_backward` hook ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The mixed precision loss is no longer unscaled before the `on_after_backward` hook. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The `TrainingTypePlugin.{pre,post}_backward` hooks no longer take the `optimizer, opt_idx, should_accumulate` arguments ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
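
For code that previously inspected unscaled gradients in `on_after_backward`, a minimal migration sketch (assuming native AMP and a hypothetical `LitModel`) moves that logic into the new hook, where the gradients have already been unscaled:

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def on_before_optimizer_step(self, optimizer, optimizer_idx):
            # under native AMP the gradients are already unscaled at this point,
            # so no manual `scaler.unscale_` call is needed anymore
            grads = [p.grad.detach().norm() for p in self.parameters() if p.grad is not None]
            self.log("grad_norm", torch.norm(torch.stack(grads)))
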
7 changes: 7 additions & 0 deletions docs/source/common/lightning_module.rst
@@ -1195,6 +1195,7 @@ for more information.
backward()
on_after_backward()

on_before_optimizer_step()
optimizer_step()

on_train_batch_end()
@@ -1451,6 +1452,12 @@ on_test_model_train
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
:noindex:

on_before_optimizer_step
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
:noindex:

optimizer_step
~~~~~~~~~~~~~~

6 changes: 6 additions & 0 deletions docs/source/extensions/callbacks.rst
@@ -363,6 +363,12 @@ on_after_backward
.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
:noindex:

on_before_optimizer_step
^^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_optimizer_step
:noindex:

on_before_zero_grad
^^^^^^^^^^^^^^^^^^^

8 changes: 7 additions & 1 deletion pytorch_lightning/callbacks/base.py
@@ -302,7 +302,13 @@ def on_before_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModu
pass

def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called after ``loss.backward()`` and before optimizers do anything."""
"""Called after ``loss.backward()`` and before optimizers are stepped."""
pass

def on_before_optimizer_step(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer, opt_idx: int
) -> None:
"""Called before ``optimizer.step()``."""
pass

def on_before_zero_grad(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer) -> None:
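
The callback counterpart added above receives the trainer, the module, and the optimizer that is about to step. A small illustrative sketch (the `LogLROnStep` name is hypothetical):

    from pytorch_lightning.callbacks import Callback

    class LogLROnStep(Callback):
        def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx):
            # fires once per optimizer, after gradient accumulation has finished
            # and right before `optimizer.step()`
            pl_module.log(f"lr-optimizer-{opt_idx}", optimizer.param_groups[0]["lr"])
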
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/lambda_function.py
@@ -79,6 +79,7 @@ def __init__(
on_load_checkpoint: Optional[Callable] = None,
on_before_backward: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_optimizer_step: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
on_predict_start: Optional[Callable] = None,
on_predict_end: Optional[Callable] = None,
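
With the new keyword argument above, the hook can also be attached functionally without subclassing. A quick usage sketch:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LambdaCallback

    log_step = LambdaCallback(
        on_before_optimizer_step=(
            lambda trainer, pl_module, optimizer, opt_idx:
            print(f"stepping optimizer {opt_idx} at global step {trainer.global_step}")
        )
    )
    trainer = Trainer(callbacks=[log_step])
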
25 changes: 21 additions & 4 deletions pytorch_lightning/core/hooks.py
@@ -306,19 +306,36 @@ def on_before_backward(self, loss: torch.Tensor) -> None:

def on_after_backward(self) -> None:
"""
Called in the training loop after loss.backward() and before optimizers do anything.
This is the ideal place to inspect or log gradient information.
Called after ``loss.backward()`` and before optimizers are stepped.

Note:
If using native AMP, the gradients will not be unscaled at this point.
Use the ``on_before_optimizer_step`` if you need the unscaled gradients.
"""

def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""
Called before ``optimizer.step()``.

The hook is only called if gradients do not need to be accumulated.
See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
If using native AMP, the loss will be unscaled before calling this hook.
See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
for more information on the scaling of gradients.

Args:
optimizer: Current optimizer being used.
optimizer_idx: Index of the current optimizer being used.

Example::

def on_after_backward(self):
def on_before_optimizer_step(self, optimizer, optimizer_idx):
# example to inspect gradient information in tensorboard
if self.trainer.global_step % 25 == 0: # don't make the tf file huge
for k, v in self.named_parameters():
self.logger.experiment.add_histogram(
tag=k, values=v.grad, global_step=self.trainer.global_step
)

"""

def on_post_move_to_device(self) -> None:
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -90,17 +90,16 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq

def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""
always called before the optimizer step.
"""
# apex amp does not support closures.
lambda_closure()
"""Hook to do something before each optimizer step."""
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# the following should be in an `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # APEX amp does not support closures
optimizer.step(**kwargs)
return False

10 changes: 6 additions & 4 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -35,15 +35,17 @@ def __init__(self, precision: int) -> None:

def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
# DeepSpeed not support closures.
lambda_closure()
deepspeed_engine = pl_module.trainer.model
"""Hook to do something before each optimizer step."""
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# the following should be in an `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # DeepSpeed does not support closures
deepspeed_engine = model.trainer.model
deepspeed_engine.step()
return False

19 changes: 9 additions & 10 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -47,7 +47,7 @@ def pre_backward(

def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
@@ -58,16 +58,15 @@ def pre_optimizer_step(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
# TODO: Add `on_before_optimizer_step`
# self.scaler.unscale_(optimizer)
# pl_module.trainer.call_hook("on_before_optimizer_step")
if pl_module.automatic_optimization:
result = True
if model.automatic_optimization:
result = lambda_closure()
if result is None:
# lambda_closure returning None indicates that backward has been skipped
return False
self.scaler.step(optimizer)
self.scaler.update()
self.scaler.unscale_(optimizer)
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# lambda_closure returning None indicates that backward has been skipped
if result is not None:
self.scaler.step(optimizer)
self.scaler.update()
return False

@contextmanager
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -104,13 +104,14 @@ def post_backward(

def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""Hook to do something before each optimizer step."""
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
return True

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
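
Custom precision plugins that override `pre_optimizer_step` should follow the same pattern as the Apex and DeepSpeed plugins above and call `super().pre_optimizer_step(...)` so the new hook still fires. A rough sketch, assuming a hypothetical `MyPrecisionPlugin` that steps the optimizer itself:

    from typing import Any, Callable

    import pytorch_lightning as pl
    from pytorch_lightning.plugins import PrecisionPlugin
    from torch.optim import Optimizer

    class MyPrecisionPlugin(PrecisionPlugin):
        def pre_optimizer_step(
            self,
            model: 'pl.LightningModule',
            optimizer: Optimizer,
            optimizer_idx: int,
            lambda_closure: Callable,
            **kwargs: Any,
        ) -> bool:
            # the base class dispatches `on_before_optimizer_step` to the LightningModule and callbacks
            super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
            lambda_closure()  # run the training-step closure manually
            optimizer.step(**kwargs)
            # returning False tells the caller that the optimizer has already been stepped
            return False
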
7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
@@ -327,6 +327,13 @@ def on_after_backward(self):
for callback in self.callbacks:
callback.on_after_backward(self, self.lightning_module)

def on_before_optimizer_step(self, optimizer, optimizer_idx):
"""
Called after on_after_backward() once the gradient is accumulated and before optimizer.step().
"""
for callback in self.callbacks:
callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx)

def on_before_zero_grad(self, optimizer):
"""
Called after optimizer.step() and before optimizer.zero_grad().
@@ -23,6 +23,7 @@ class FxValidator:
on_configure_sharded_model=None,
on_before_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_after_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_before_optimizer_step=dict(on_step=(False, True), on_epoch=(False, True)),
on_before_zero_grad=dict(on_step=(False, True), on_epoch=(False, True)),
on_init_start=None,
on_init_end=None,
13 changes: 11 additions & 2 deletions tests/models/test_hooks.py
@@ -299,6 +299,10 @@ def _auto_train_batch(trainer, model, batches, device=torch.device('cpu'), curre
using_native_amp = kwargs.get('amp_backend') == 'native'
using_deepspeed = kwargs.get('plugins') == 'deepspeed'
out = []
on_before_optimizer_step = [
dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
dict(name='on_before_optimizer_step', args=(ANY, 0)),
]
for i in range(batches):
out.extend([
dict(name='on_before_batch_transfer', args=(ANY, 0)),
@@ -308,7 +312,10 @@ def _auto_train_batch(trainer, model, batches, device=torch.device('cpu'), curre
dict(name='Callback.on_batch_start', args=(trainer, model)),
dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_train_batch_start', args=(ANY, i, 0)),
# TODO: `on_before_optimizer_step`
# these are before the training step because
# they are not part of the `training_step_and_backward` closure, however,
# with native amp, the closure is run first and then the optimizer step.
*(on_before_optimizer_step if not using_native_amp else []),
dict(name='forward', args=(ANY, )),
dict(name='training_step', args=(ANY, i)),
dict(name='training_step_end', args=(dict(loss=ANY), )),
@@ -321,6 +328,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device('cpu'), curre
*([dict(name='backward', args=(ANY, ANY, 0))] if not using_deepspeed else []),
dict(name='Callback.on_after_backward', args=(trainer, model)),
dict(name='on_after_backward'),
*(on_before_optimizer_step if using_native_amp else []),
dict(
name='optimizer_step',
args=(current_epoch, i, ANY, 0, ANY),
@@ -354,7 +362,8 @@ def _manual_train_batch(trainer, model, batches, device=torch.device('cpu'), **k
dict(name='on_after_backward'),
# `manual_backward` calls the previous 3
dict(name='manual_backward', args=(ANY, )),
# TODO: `on_before_optimizer_step`
dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
dict(name='on_before_optimizer_step', args=(ANY, 0)),
dict(name='training_step', args=(ANY, i)),
dict(name='training_step_end', args=(dict(loss=ANY), )),
dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
7 changes: 1 addition & 6 deletions tests/plugins/test_amp_plugins.py
@@ -71,12 +71,7 @@ def test_amp_apex_ddp(

class GradientUnscaleBoringModel(BoringModel):

def on_after_backward(self):
# TODO: replace with `on_before_optimizer_step` so we don't need to check accumulate and unscale manually
if self.trainer.fit_loop.should_accumulate():
return
opt = self.optimizers()
self.trainer.precision_plugin.scaler.unscale_(opt)
def on_before_optimizer_step(self, *_):
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
if not (torch.isinf(norm) or torch.isnan(norm)):
assert norm.item() < 15.
2 changes: 2 additions & 0 deletions tests/trainer/logging_/test_logger_connector.py
@@ -34,6 +34,7 @@ def test_fx_validator(tmpdir):
callbacks_func = [
'on_before_backward',
'on_after_backward',
'on_before_optimizer_step',
'on_batch_end',
'on_batch_start',
'on_before_accelerator_backend_setup',
@@ -124,6 +125,7 @@ def test_fx_validator(tmpdir):
# creating allowed condition
allowed = (
is_stage or "batch" in func_name or "epoch" in func_name or "grad" in func_name or "backward" in func_name
or "optimizer_step" in func_name
)
allowed = (
allowed and "pretrain" not in func_name and "predict" not in func_name