diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7040df36c57e4..e77df6371261b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -262,6 +262,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `PrecisionPlugin.backward` hooks no longer takes a `should_accumulate` argument ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
 
 
+- Added the `on_before_backward` hook ([#7865](https://github.com/PyTorchLightning/pytorch-lightning/pull/7865))
+
+
 - `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run`([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))
 
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 3310daa636206..84ffb7cec82f5 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1191,6 +1191,7 @@ for more information.
 
         on_before_zero_grad()
         optimizer_zero_grad()
+        on_before_backward()
         backward()
         on_after_backward()
 
@@ -1247,6 +1248,12 @@ get_progress_bar_dict
 
 .. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
     :noindex:
 
+on_before_backward
+~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_backward
+    :noindex:
+
 on_after_backward
 ~~~~~~~~~~~~~~~~~
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index c8d7effd1d4ce..88527f11777f1 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -351,6 +351,12 @@ on_load_checkpoint
 
 .. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
     :noindex:
 
+on_before_backward
+^^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_before_backward
+    :noindex:
+
 on_after_backward
 ^^^^^^^^^^^^^^^^^
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 59f29c9836232..1db7b20b92208 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -19,6 +19,7 @@
 import abc
 from typing import Any, Dict, List, Optional
 
+import torch
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -296,6 +297,10 @@ def on_load_checkpoint(
         """
         pass
 
+    def on_before_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', loss: torch.Tensor) -> None:
+        """Called before ``loss.backward()``."""
+        pass
+
     def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""
         pass
diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
index 98745291797e7..ca9af484dbc0c 100644
--- a/pytorch_lightning/callbacks/lambda_function.py
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -77,6 +77,7 @@ def __init__(
         on_keyboard_interrupt: Optional[Callable] = None,
         on_save_checkpoint: Optional[Callable] = None,
         on_load_checkpoint: Optional[Callable] = None,
+        on_before_backward: Optional[Callable] = None,
        on_after_backward: Optional[Callable] = None,
         on_before_optimizer_step: Optional[Callable] = None,
         on_before_zero_grad: Optional[Callable] = None,
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index e83fb479f3deb..a30f699c70cfd 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -295,6 +295,15 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None:
             optimizer: The optimizer for which grads should be zeroed.
         """
 
+    def on_before_backward(self, loss: torch.Tensor) -> None:
+        """
+        Called before ``loss.backward()``.
+
+        Args:
+            loss: Loss divided by number of batches for gradient accumulation and scaled if using native AMP.
+        """
+        pass
+
     def on_after_backward(self) -> None:
         """
         Called after ``loss.backward()`` and before optimizers are stepped.
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 32b259acec381..7cf4f089f15b3 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -42,7 +42,8 @@ def pre_backward(
         model: 'pl.LightningModule',
         closure_loss: torch.Tensor,
     ) -> torch.Tensor:
-        return self.scaler.scale(closure_loss)
+        closure_loss = self.scaler.scale(closure_loss)
+        return super().pre_backward(model, closure_loss)
 
     def pre_optimizer_step(
         self,
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 25c966d6685bc..ae806ff25e6fc 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -62,6 +62,7 @@ def pre_backward(
             model: the model to be optimized
             closure_loss: the loss value obtained from the closure
         """
+        model.trainer.call_hook("on_before_backward", closure_loss)
         return closure_loss
 
     def backward(
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 060ebf0db86ca..63c23d50fa772 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -17,6 +17,8 @@
 from inspect import signature
 from typing import Any, Callable, Dict, List, Optional, Type
 
+import torch
+
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
@@ -313,6 +315,11 @@ def on_load_checkpoint(self, checkpoint):
             else:
                 callback.on_load_checkpoint(self, self.lightning_module, state)
 
+    def on_before_backward(self, loss: torch.Tensor) -> None:
+        """Called before ``loss.backward()``."""
+        for callback in self.callbacks:
+            callback.on_before_backward(self, self.lightning_module, loss)
+
     def on_after_backward(self):
         """
         Called after loss.backward() and before optimizers do anything.
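
For orientation, the changes above wire the new hook end to end: `PrecisionPlugin.pre_backward` calls `trainer.call_hook("on_before_backward", closure_loss)`, which fans out to `Callback.on_before_backward(trainer, pl_module, loss)` and `LightningModule.on_before_backward(loss)`. The sketch below is not part of this diff; the `LossInspector` name and the finiteness check are illustrative only and merely assume the hook signatures introduced above.

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class LossInspector(Callback):
    """Illustrative callback using the new hook (hypothetical, not part of this PR)."""

    def on_before_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', loss: torch.Tensor) -> None:
        # `loss` arrives already divided for gradient accumulation and, under native AMP,
        # scaled by the GradScaler, as documented in `core/hooks.py` above.
        if not torch.isfinite(loss):
            raise ValueError(f"Non-finite loss encountered before backward: {loss}")

The same behaviour can be expressed without a subclass via `LambdaCallback(on_before_backward=...)`, since `lambda_function.py` above now accepts the new keyword.
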
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
index 41815ad55be88..3604574fd1e81 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -21,6 +21,7 @@ class FxValidator:
     functions: Dict[str, Optional[Dict[str, Tuple[bool]]]] = dict(
         on_before_accelerator_backend_setup=None,
         on_configure_sharded_model=None,
+        on_before_backward=dict(on_step=(False, True), on_epoch=(False, True)),
         on_after_backward=dict(on_step=(False, True), on_epoch=(False, True)),
         on_before_optimizer_step=dict(on_step=(False, True), on_epoch=(False, True)),
         on_before_zero_grad=dict(on_step=(False, True), on_epoch=(False, True)),
diff --git a/requirements.txt b/requirements.txt
index af311aa4c7ed5..293d167b37ade 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ numpy>=1.17.2
 torch>=1.4
 future>=0.17.1 # required for builtins in setup.py
 tqdm>=4.41.0
-PyYAML>=5.1,<=5.4.1
+PyYAML>=5.1
 fsspec[http]>=2021.05.0, !=2021.06.0
 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!'
 torchmetrics>=0.4.0
diff --git a/tests/helpers/datamodules.py b/tests/helpers/datamodules.py
index 12ec16261159d..2fc9f8a901f22 100644
--- a/tests/helpers/datamodules.py
+++ b/tests/helpers/datamodules.py
@@ -24,6 +24,10 @@
 if _SKLEARN_AVAILABLE:
     from sklearn.datasets import make_classification, make_regression
     from sklearn.model_selection import train_test_split
+else:
+    make_classification = None
+    make_regression = None
+    train_test_split = None
 
 
 class MNISTDataModule(LightningDataModule):
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index fe728d632adcd..d89fc090c401f 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -256,6 +256,8 @@ class HookedModel(BoringModel):
     def __init__(self, called):
         super().__init__()
         pl_module_hooks = get_members(LightningModule)
+        # remove non-hooks
+        pl_module_hooks.difference_update({'optimizers'})
         # remove most `nn.Module` hooks
         module_hooks = get_members(torch.nn.Module)
         module_hooks.difference_update({'forward', 'zero_grad', 'train'})
@@ -320,7 +322,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device('cpu'), curre
             dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)),
             dict(name='on_before_zero_grad', args=(ANY, )),
             dict(name='optimizer_zero_grad', args=(current_epoch, i, ANY, 0)),
-            # TODO: `on_before_backward`
+            dict(name='Callback.on_before_backward', args=(trainer, model, ANY)),
+            dict(name='on_before_backward', args=(ANY, )),
             # DeepSpeed handles backward internally
             *([dict(name='backward', args=(ANY, ANY, 0))] if not using_deepspeed else []),
             dict(name='Callback.on_after_backward', args=(trainer, model)),
@@ -351,7 +354,8 @@ def _manual_train_batch(trainer, model, batches, device=torch.device('cpu'), **k
             dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
             dict(name='on_train_batch_start', args=(ANY, i, 0)),
             dict(name='forward', args=(ANY, )),
-            dict(name='optimizers'),
+            dict(name='Callback.on_before_backward', args=(trainer, model, ANY)),
+            dict(name='on_before_backward', args=(ANY, )),
             # DeepSpeed handles backward internally
             *([dict(name='backward', args=(ANY, None, None))] if not using_deepspeed else []),
             dict(name='Callback.on_after_backward', args=(trainer, model)),
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 68b0f2d9178a9..f3d89b54ae236 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -657,14 +657,19 @@ def _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimiz
 
     class VerificationCallback(Callback):
 
+        def __init__(self):
+            self.on_train_batch_start_called = False
+
         def on_train_batch_start(
             self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
         ) -> None:
             deepspeed_engine = trainer.training_type_plugin.model
             assert trainer.global_step == deepspeed_engine.global_steps
+            self.on_train_batch_start_called = True
 
     model = ModelParallelClassificationModel()
     dm = ClassifDataModule()
+    verification_callback = VerificationCallback()
     trainer = Trainer(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
@@ -674,9 +679,10 @@ def on_train_batch_start(
         limit_val_batches=2,
         precision=16,
         accumulate_grad_batches=2,
-        callbacks=[VerificationCallback()]
+        callbacks=[verification_callback]
     )
     trainer.fit(model, datamodule=dm)
+    assert verification_callback.on_train_batch_start_called
 
 
 @RunIf(min_gpus=2, deepspeed=True, special=True)
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 833b6740f6ca8..27598b40fbd31 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -32,6 +32,7 @@ def test_fx_validator(tmpdir):
     funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])
 
     callbacks_func = [
+        'on_before_backward',
         'on_after_backward',
         'on_before_optimizer_step',
         'on_batch_end',
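
Closing illustration (not part of the diff): because `fx_validator.py` above registers `on_before_backward` with both `on_step` and `on_epoch` logging allowed, a `LightningModule` can log directly from the new hook. The module below is a hypothetical minimal sketch; the class name, layer size, and metric key are made up for demonstration.

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Hypothetical model demonstrating the LightningModule-side hook."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # Keep the step trivial; only the hook below matters for this sketch.
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_before_backward(self, loss: torch.Tensor) -> None:
        # Receives the loss after gradient-accumulation division and AMP scaling,
        # immediately before ``loss.backward()`` runs.
        self.log("loss_before_backward", loss, on_step=True, on_epoch=True)

An equivalent callback-side one-liner would be `LambdaCallback(on_before_backward=lambda trainer, pl_module, loss: print(loss))`, using the keyword added to `lambda_function.py` above.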