Commit

Merge branch 'master' into bug/7924_on_after_backward_should_always_run
carmocca authored Jul 9, 2021
2 parents 90083a0 + 1c825a2 commit 683d1a5
Showing 15 changed files with 61 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -262,6 +262,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `PrecisionPlugin.backward` hooks no longer takes a `should_accumulate` argument ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- Added the `on_before_backward` hook ([#7865](https://github.com/PyTorchLightning/pytorch-lightning/pull/7865))


- `LightningCLI` now aborts with a clearer message if the config file already exists and disables saving the config during `fast_dev_run` ([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))


7 changes: 7 additions & 0 deletions docs/source/common/lightning_module.rst
@@ -1191,6 +1191,7 @@ for more information.
on_before_zero_grad()
optimizer_zero_grad()
on_before_backward()
backward()
on_after_backward()
@@ -1247,6 +1248,12 @@ get_progress_bar_dict
.. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
:noindex:

on_before_backward
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_backward
:noindex:

on_after_backward
~~~~~~~~~~~~~~~~~

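As context for the hook listing above, here is a rough sketch of the documented call order around `backward()`. It is not the actual Trainer code and assumes a single optimizer (index 0) with no gradient accumulation:

    import torch
    import pytorch_lightning as pl

    def backward_phase(module: "pl.LightningModule", loss: torch.Tensor,
                       optimizer: torch.optim.Optimizer, epoch: int, batch_idx: int) -> None:
        # illustrative ordering only; the real loop lives inside the Trainer
        module.on_before_zero_grad(optimizer)
        module.optimizer_zero_grad(epoch, batch_idx, optimizer, 0)
        module.on_before_backward(loss)
        module.backward(loss, optimizer, 0)
        module.on_after_backward()
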
6 changes: 6 additions & 0 deletions docs/source/extensions/callbacks.rst
@@ -351,6 +351,12 @@ on_load_checkpoint
.. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
:noindex:

on_before_backward
^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_backward
:noindex:

on_after_backward
^^^^^^^^^^^^^^^^^

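A minimal sketch of a custom callback using the new hook (the `LossFiniteCheck` name and behaviour are hypothetical, not part of this change):

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import Callback

    class LossFiniteCheck(Callback):
        """Hypothetical example: fail fast if the loss is NaN or Inf before backward."""

        def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
                               loss: torch.Tensor) -> None:
            if not torch.isfinite(loss):
                raise ValueError(f"non-finite loss before backward: {loss}")
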
5 changes: 5 additions & 0 deletions pytorch_lightning/callbacks/base.py
@@ -19,6 +19,7 @@
import abc
from typing import Any, Dict, List, Optional

import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
@@ -296,6 +297,10 @@ def on_load_checkpoint(
"""
pass

def on_before_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', loss: torch.Tensor) -> None:
"""Called before ``loss.backward()``."""
pass

def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called after ``loss.backward()`` and before optimizers are stepped."""
pass
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/lambda_function.py
@@ -77,6 +77,7 @@ def __init__(
on_keyboard_interrupt: Optional[Callable] = None,
on_save_checkpoint: Optional[Callable] = None,
on_load_checkpoint: Optional[Callable] = None,
on_before_backward: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_optimizer_step: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
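With the keyword argument added above, `LambdaCallback` can wire a plain function into the new hook. A usage sketch (the printed message is purely illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LambdaCallback

    # the function receives the same arguments as Callback.on_before_backward
    print_loss = LambdaCallback(
        on_before_backward=lambda trainer, pl_module, loss: print(f"pre-backward loss: {loss.item():.4f}")
    )
    trainer = Trainer(callbacks=[print_loss], max_epochs=1)
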
9 changes: 9 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -295,6 +295,15 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None:
optimizer: The optimizer for which grads should be zeroed.
"""

def on_before_backward(self, loss: torch.Tensor) -> None:
"""
Called before ``loss.backward()``.
Args:
loss: Loss divided by number of batches for gradient accumulation and scaled if using native AMP.
"""
pass

def on_after_backward(self) -> None:
"""
Called after ``loss.backward()`` and before optimizers are stepped.
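As the docstring notes, the value passed to the hook is already divided by the gradient-accumulation factor and scaled when native AMP is used. A minimal override sketch (the `last_pre_backward_loss` attribute is made up for illustration):

    import torch
    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def on_before_backward(self, loss: torch.Tensor) -> None:
            # keep a detached copy of the (possibly scaled) loss, e.g. to compare
            # against gradient norms later in on_after_backward
            self.last_pre_backward_loss = loss.detach()
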
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/native_amp.py
@@ -42,7 +42,8 @@ def pre_backward(
model: 'pl.LightningModule',
closure_loss: torch.Tensor,
) -> torch.Tensor:
return self.scaler.scale(closure_loss)
closure_loss = self.scaler.scale(closure_loss)
return super().pre_backward(model, closure_loss)

def pre_optimizer_step(
self,
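Because the loss is now scaled before calling `super().pre_backward(...)`, the hook observes the scaled value under native AMP. A standalone sketch of that scaling with `torch.cuda.amp.GradScaler` (requires a CUDA device; not Lightning code):

    import torch
    from torch.cuda.amp import GradScaler

    scaler = GradScaler()
    loss = torch.tensor(2.0, device="cuda", requires_grad=True)
    scaled = scaler.scale(loss)        # roughly what on_before_backward receives
    raw = scaled / scaler.get_scale()  # recover the unscaled value if needed
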
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -62,6 +62,7 @@ def pre_backward(
model: the model to be optimized
closure_loss: the loss value obtained from the closure
"""
model.trainer.call_hook("on_before_backward", closure_loss)
return closure_loss

def backward(
7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
@@ -17,6 +17,8 @@
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Type

import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
@@ -313,6 +315,11 @@ def on_load_checkpoint(self, checkpoint):
else:
callback.on_load_checkpoint(self, self.lightning_module, state)

def on_before_backward(self, loss: torch.Tensor) -> None:
"""Called before ``loss.backward()``."""
for callback in self.callbacks:
callback.on_before_backward(self, self.lightning_module, loss)

def on_after_backward(self):
"""
Called after loss.backward() and before optimizers do anything.
@@ -21,6 +21,7 @@ class FxValidator:
functions: Dict[str, Optional[Dict[str, Tuple[bool]]]] = dict(
on_before_accelerator_backend_setup=None,
on_configure_sharded_model=None,
on_before_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_after_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_before_optimizer_step=dict(on_step=(False, True), on_epoch=(False, True)),
on_before_zero_grad=dict(on_step=(False, True), on_epoch=(False, True)),
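The `FxValidator` entry above means `self.log` is permitted inside the new hook with either `on_step` or `on_epoch`. A sketch of what that allows (the metric name is arbitrary):

    import torch
    from pytorch_lightning import LightningModule

    class LoggingModel(LightningModule):
        def on_before_backward(self, loss: torch.Tensor) -> None:
            # both flags are accepted for this hook according to the validator table
            self.log("loss_before_backward", loss, on_step=True, on_epoch=True)
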
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ numpy>=1.17.2
torch>=1.4
future>=0.17.1 # required for builtins in setup.py
tqdm>=4.41.0
PyYAML>=5.1,<=5.4.1
PyYAML>=5.1
fsspec[http]>=2021.05.0, !=2021.06.0
tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!'
torchmetrics>=0.4.0
4 changes: 4 additions & 0 deletions tests/helpers/datamodules.py
@@ -24,6 +24,10 @@
if _SKLEARN_AVAILABLE:
from sklearn.datasets import make_classification, make_regression
from sklearn.model_selection import train_test_split
else:
make_classification = None
make_regression = None
train_test_split = None


class MNISTDataModule(LightningDataModule):
8 changes: 6 additions & 2 deletions tests/models/test_hooks.py
@@ -256,6 +256,8 @@ class HookedModel(BoringModel):
def __init__(self, called):
super().__init__()
pl_module_hooks = get_members(LightningModule)
# remove non-hooks
pl_module_hooks.difference_update({'optimizers'})
# remove most `nn.Module` hooks
module_hooks = get_members(torch.nn.Module)
module_hooks.difference_update({'forward', 'zero_grad', 'train'})
@@ -320,7 +322,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device('cpu'), curre
dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)),
dict(name='on_before_zero_grad', args=(ANY, )),
dict(name='optimizer_zero_grad', args=(current_epoch, i, ANY, 0)),
# TODO: `on_before_backward`
dict(name='Callback.on_before_backward', args=(trainer, model, ANY)),
dict(name='on_before_backward', args=(ANY, )),
# DeepSpeed handles backward internally
*([dict(name='backward', args=(ANY, ANY, 0))] if not using_deepspeed else []),
dict(name='Callback.on_after_backward', args=(trainer, model)),
@@ -351,7 +354,8 @@ def _manual_train_batch(trainer, model, batches, device=torch.device('cpu'), **k
dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_train_batch_start', args=(ANY, i, 0)),
dict(name='forward', args=(ANY, )),
dict(name='optimizers'),
dict(name='Callback.on_before_backward', args=(trainer, model, ANY)),
dict(name='on_before_backward', args=(ANY, )),
# DeepSpeed handles backward internally
*([dict(name='backward', args=(ANY, None, None))] if not using_deepspeed else []),
dict(name='Callback.on_after_backward', args=(trainer, model)),
8 changes: 7 additions & 1 deletion tests/plugins/test_deepspeed_plugin.py
@@ -657,14 +657,19 @@ def _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimiz

class VerificationCallback(Callback):

def __init__(self):
self.on_train_batch_start_called = False

def on_train_batch_start(
self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
deepspeed_engine = trainer.training_type_plugin.model
assert trainer.global_step == deepspeed_engine.global_steps
self.on_train_batch_start_called = True

model = ModelParallelClassificationModel()
dm = ClassifDataModule()
verification_callback = VerificationCallback()
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
@@ -674,9 +679,10 @@ def on_train_batch_start(
limit_val_batches=2,
precision=16,
accumulate_grad_batches=2,
callbacks=[VerificationCallback()]
callbacks=[verification_callback]
)
trainer.fit(model, datamodule=dm)
assert verification_callback.on_train_batch_start_called


@RunIf(min_gpus=2, deepspeed=True, special=True)
1 change: 1 addition & 0 deletions tests/trainer/logging_/test_logger_connector.py
@@ -32,6 +32,7 @@ def test_fx_validator(tmpdir):
funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])

callbacks_func = [
'on_before_backward',
'on_after_backward',
'on_before_optimizer_step',
'on_batch_end',
