Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the on_before_backward hook #7865

Merged
merged 148 commits into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
148 commits
Select commit Hold shift + click to select a range
7500fd0
Add callback to hook tests and add predict test
carmocca May 27, 2021
27e2dcf
Fix lambda callback test
carmocca May 27, 2021
174be4c
Simplify lambda call test
carmocca May 28, 2021
e99e711
Use LambdaCallback
carmocca May 28, 2021
c52ab79
Dynamically append to called for the model
carmocca May 28, 2021
fcfe381
Remove print
carmocca May 28, 2021
2ca39c4
Consistency
carmocca May 28, 2021
aa8eea0
Consistency
carmocca May 28, 2021
1de5fbd
Prepare args/kwargs testing
carmocca May 28, 2021
736f1c2
yapf doesn't like dict literals
carmocca May 28, 2021
020d98d
Add arguments for fit no val test
carmocca May 28, 2021
c069e2d
Add arguments for fit no val test
carmocca May 28, 2021
0be25f0
add before_backward_hook
tchaton Jun 7, 2021
2ced937
add test
tchaton Jun 7, 2021
6db056e
resolve flake8
tchaton Jun 7, 2021
23302bd
resolve tests
tchaton Jun 7, 2021
4a69d86
update changelog
tchaton Jun 7, 2021
d36a61c
add on_before_backward to LightningModule
tchaton Jun 7, 2021
50a4f6f
update on comments
tchaton Jun 8, 2021
6245149
Merge branch 'master' into tests/improve-hook-tests
carmocca Jun 10, 2021
deb67fb
Test arguments
carmocca Jun 11, 2021
4554003
Datamodule refactor
carmocca Jun 11, 2021
0163920
Merge branch 'master' into before_backward_hook
tchaton Jun 14, 2021
8c2fc2e
Merge branch 'master' into before_backward_hook
tchaton Jun 15, 2021
a3c74cb
Merge branch 'master' into before_backward_hook
tchaton Jun 15, 2021
fca960c
Merge branch 'master' into before_backward_hook
tchaton Jun 16, 2021
e331d24
Merge branch 'master' into before_backward_hook
tchaton Jun 16, 2021
7fc612f
Merge branch 'before_backward_hook' of https://github.com/PyTorchLigh…
tchaton Jun 16, 2021
6c92649
Merge branch 'master' into tests/improve-hook-tests
carmocca Jun 17, 2021
8c8e059
Fix eval test
carmocca Jun 17, 2021
f481ee1
Merge branch 'master' into before_backward_hook
tchaton Jun 17, 2021
7940318
remove extra file
tchaton Jun 17, 2021
5774d4d
resolve bug
tchaton Jun 18, 2021
a5b2fc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
f46dea5
Merge branch 'master' into before_backward_hook
tchaton Jun 18, 2021
3782911
Merge branch 'before_backward_hook' of https://github.com/PyTorchLigh…
tchaton Jun 18, 2021
003d21c
move to hooks
tchaton Jun 18, 2021
6716de5
update
tchaton Jun 18, 2021
923ea72
resolve flake8
tchaton Jun 18, 2021
410f915
update on comments
tchaton Jun 18, 2021
9af5290
Merge branch 'master' into before_backward_hook
tchaton Jun 21, 2021
af39b28
Merge branch 'master' into tests/improve-hook-tests
carmocca Jun 21, 2021
6e9bcf8
Update full fit + val test
carmocca Jun 21, 2021
037100e
Update test
carmocca Jun 21, 2021
a5511f1
Remove FIXME
carmocca Jun 21, 2021
78b4062
Remove FIXME
carmocca Jun 21, 2021
fd65bb8
Undo change
carmocca Jun 21, 2021
c32e5e0
Fix
carmocca Jun 21, 2021
1c4c370
Merge branch 'tests/improve-hook-test-full-fit' into before_backward_…
carmocca Jun 21, 2021
6ec32df
Parametrize fit hook test
carmocca Jun 21, 2021
08ac721
Merge branch 'master' into before_backward_hook
carmocca Jun 21, 2021
a59000a
Comment
carmocca Jun 21, 2021
a5f2e6b
Parametrize fit hook test with different precision plugins
carmocca Jun 21, 2021
0ce2295
Fix tests
carmocca Jun 22, 2021
c22fc74
Parametrize fit hook test with manual optimization
carmocca Jun 22, 2021
4b1534b
Unnecessary parenthesis
carmocca Jun 22, 2021
f5828a8
WIP
carmocca Jun 22, 2021
c7f3865
Merge branch 'tests/parametrize-hooks-precision-plugins' into before_…
carmocca Jun 22, 2021
72d5ee3
Comments
carmocca Jun 22, 2021
f34ee7e
Fix message
carmocca Jun 22, 2021
39c4a85
Test CI error
carmocca Jun 22, 2021
c3b458d
Revert "Test CI error"
carmocca Jun 22, 2021
c700cab
Add ddp training type teardown
carmocca Jun 22, 2021
e5602c9
Update CHANGELOG
carmocca Jun 22, 2021
52b2256
Adrian's fix
carmocca Jun 22, 2021
0b94b6c
Use destructor
carmocca Jun 23, 2021
aaf32ab
Update CHANGELOG.md
carmocca Jun 23, 2021
0444d54
RPC destructor
carmocca Jun 23, 2021
5d4f811
Update pytorch_lightning/plugins/training_type/ddp.py
carmocca Jun 23, 2021
bf8766d
Why do you not work :(
carmocca Jun 23, 2021
48bcb7e
Missing condition
carmocca Jun 23, 2021
5d6fa39
Merge branch 'master' into bug/teardown-ddp-process-group
carmocca Jun 23, 2021
21ad2d8
Fix deepspeed test
carmocca Jun 24, 2021
bbc489e
GC collect in conftest
carmocca Jun 24, 2021
5b06fd2
Do not show warnings for special tests
carmocca Jun 24, 2021
5e69ed8
Needs to run on 1.8
carmocca Jun 24, 2021
1e0cf40
Merge branch 'master' into tests/parametrize-hooks-precision-plugins
awaelchli Jun 24, 2021
aed51a2
Run torch 1.8
carmocca Jun 24, 2021
e0a3e87
Skip test due to 'Python bus error'
carmocca Jun 24, 2021
9ee2d19
Debug NCCL
carmocca Jun 24, 2021
3588aaa
shm size
carmocca Jun 24, 2021
067bf1a
Disable warnings for special tests
carmocca Jun 24, 2021
6060b05
Remove NCCL_DEBUG statement
carmocca Jun 24, 2021
f0fa1b7
Try smaller shm size
carmocca Jun 24, 2021
6dd7038
Revert "Skip test due to 'Python bus error'"
carmocca Jun 24, 2021
53082bf
Merge branch 'ci/gpu-tests-torch-1.8' into bug/teardown-ddp-process-g…
carmocca Jun 24, 2021
73e62f8
README and adjust versions
carmocca Jun 24, 2021
902ef02
Avoid self.on_gpu call
carmocca Jun 24, 2021
4ce0f9a
empty cache cleanup
carmocca Jun 24, 2021
990b2e9
Merge branch 'master' into bug/teardown-ddp-process-group
carmocca Jun 24, 2021
738daa5
More garbage collection
carmocca Jun 24, 2021
236aa97
Unroll parametrizations
awaelchli Jun 24, 2021
ffa532d
Do not reuse mock
carmocca Jun 24, 2021
5aa3790
Merge branch 'master' into tests/parametrize-hooks-precision-plugins
carmocca Jun 24, 2021
78baa5f
Merge branch 'bug/teardown-ddp-process-group' into tests/parametrize-…
carmocca Jun 24, 2021
e190089
Undo changes
carmocca Jun 24, 2021
261a166
Undo notebooks modification
carmocca Jun 24, 2021
91edac0
Merge branch 'master' into before_backward_hook
tchaton Jun 25, 2021
3cec91c
Merge branch 'before_backward_hook' of https://github.com/PyTorchLigh…
tchaton Jun 25, 2021
031917e
resolve test
tchaton Jun 25, 2021
cf1aa34
Merge branch 'master' into before_backward_hook
tchaton Jun 28, 2021
30c57e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2021
9f89e57
update
tchaton Jun 28, 2021
3eccc98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2021
eb381f8
delete file
tchaton Jun 28, 2021
acec7b0
Merge branch 'master' into tests/parametrize-hooks-precision-plugins
carmocca Jul 3, 2021
33a68d4
Undo
carmocca Jul 3, 2021
ac006c7
Fix test
carmocca Jul 3, 2021
9efe252
Revert "WIP"
carmocca Jul 3, 2021
beecfb9
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 3, 2021
01dfc7c
Merge branch 'tests/parametrize-hooks-precision-plugins' into tests/p…
carmocca Jul 3, 2021
dc7b17b
Rename
carmocca Jul 3, 2021
8e1f60e
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 5, 2021
9c3fbd4
Remove optimizers
carmocca Jul 5, 2021
f90348c
Fix bug with LightningOptimizer
carmocca Jul 5, 2021
cbf7b36
Add optimizers
carmocca Jul 5, 2021
bdc3819
Merge branch 'master' into before_backward_hook
tchaton Jul 6, 2021
7e68389
update
tchaton Jul 6, 2021
ec198c5
update
tchaton Jul 6, 2021
d1a48a6
Update CHANGELOG
carmocca Jul 6, 2021
1869128
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 6, 2021
fe06ec0
On after backward refactor
carmocca Jul 6, 2021
938de4d
Do not call super
carmocca Jul 6, 2021
20da3b1
Fixes
carmocca Jul 6, 2021
abfbdd6
Remove should_accumulate
carmocca Jul 7, 2021
9c8993c
pre/post backward refactor
carmocca Jul 7, 2021
d7d2a71
Call the LM backward hook
carmocca Jul 7, 2021
f3c3726
Update tests
carmocca Jul 7, 2021
7cfed58
Remove dev debug patch
carmocca Jul 7, 2021
7838eae
Fix test
carmocca Jul 7, 2021
c070e84
Remove optimizer arguments and typing
carmocca Jul 7, 2021
5fabca8
Docs fixes
carmocca Jul 7, 2021
cf89192
Fix comment
carmocca Jul 7, 2021
6d77d72
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 7, 2021
f88cc51
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 7, 2021
d749a85
Undo changes
carmocca Jul 7, 2021
d1c342b
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 8, 2021
816cb4c
Merge branch 'master' into tests/parametrize-hooks-manual-opt
carmocca Jul 8, 2021
e2ea758
Split manual and auto
carmocca Jul 8, 2021
160c2b4
Undo change
carmocca Jul 8, 2021
cbc78db
Deepsource
carmocca Jul 8, 2021
6aa229c
Remove optimizers
carmocca Jul 8, 2021
b273fdd
Merge branch 'master' into before_backward_hook
carmocca Jul 8, 2021
b0f25b5
Merge branch 'tests/parametrize-hooks-manual-opt' into before_backwar…
carmocca Jul 8, 2021
b531ef1
Undo changes
carmocca Jul 8, 2021
3ce386f
Call the hook
carmocca Jul 8, 2021
44fcdb2
Docs
carmocca Jul 8, 2021
b47470b
Docs
carmocca Jul 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831))


- Added `on_before_backward` hook ([#7865](https://github.com/PyTorchLightning/pytorch-lightning/pull/7865))


### Deprecated


Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import abc
from typing import Any, Dict, List, Optional

import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
Expand Down Expand Up @@ -296,6 +297,10 @@ def on_load_checkpoint(
"""
pass

def on_before_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', loss: torch.Tensor) -> None:
"""Called before ``loss.backward()``."""
pass

def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called after ``loss.backward()`` and before optimizers do anything."""
pass
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
on_keyboard_interrupt: Optional[Callable] = None,
on_save_checkpoint: Optional[Callable] = None,
on_load_checkpoint: Optional[Callable] = None,
on_before_backward: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
on_predict_start: Optional[Callable] = None,
Expand Down
17 changes: 17 additions & 0 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,23 @@ def training_step(...):
self.trainer.train_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
self._running_manual_backward = False

def on_before_backward(self, loss: torch.Tensor) -> None:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""
Override on_before_backward with your own implementation if you need to.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

Args:
loss: Loss is already scaled by accumulated grads and possibly scaled by Mixed Precision.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

Called to before backward step.

Example::

def on_before_backward(self, loss):
print(f"Current Loss: {loss}")

"""
pass

def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
"""
Override backward with your own implementation if you need to.
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/precision/apex_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def backward(
# enter apex context
closure_loss = scaled_loss.__enter__()

# hook
model.trainer.call_hook("on_before_backward", closure_loss)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

# do backward pass
# TODO: not entirely sure, why we need this
if model is not None and isinstance(model, LightningModule):
Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def backward(
)
# todo: hack around for deepspeed engine to call backward
deepspeed_engine = model.trainer.model

# hook
model.trainer.call_hook("on_before_backward", closure_loss)

deepspeed_engine.backward(closure_loss, *args, **kwargs)
# once backward has been applied, release graph
closure_loss = closure_loss.detach()
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def backward(
"""
closure_loss = self.scaler.scale(closure_loss)

# call `on_before_backward` hook
closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs)

# unscale gradient to allow analyze within `on_after_backward`
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/precision/precision_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def backward(
"""
automatic_optimization = model.automatic_optimization

model.trainer.call_hook("on_before_backward", closure_loss)

# do backward pass
if automatic_optimization:
model.backward(closure_loss, optimizer, opt_idx)
Expand Down
9 changes: 9 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Type

import torch

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
Expand Down Expand Up @@ -313,6 +315,13 @@ def on_load_checkpoint(self, checkpoint):
else:
callback.on_load_checkpoint(self, self.lightning_module, state)

def on_before_backward(self, loss: torch.Tensor) -> None:
"""
Called before loss.backward().
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""
for callback in self.callbacks:
callback.on_before_backward(self, self.lightning_module, loss)

def on_after_backward(self):
"""
Called after loss.backward() and before optimizers do anything.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class FxValidator:
functions: Dict[str, Optional[Dict[str, Tuple[bool]]]] = dict(
on_before_accelerator_backend_setup=None,
on_configure_sharded_model=None,
on_before_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_after_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_before_zero_grad=dict(on_step=(False, True), on_epoch=(False, True)),
on_init_start=None,
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,9 @@ def _call_teardown_hook(self, model: LightningModule) -> None:
def _reset_result_and_set_fx_name(self, hook_name: str) -> bool:
# on_before_zero_grad is called within training_step
# TODO(@carmocca): Result should handle this logic
if "batch_start" in hook_name or hook_name in ("on_before_zero_grad", "on_after_backward"):
if "batch_start" in hook_name or hook_name in (
"on_before_zero_grad", "on_before_backward", "on_after_backward"
):
return True
model_ref = self.lightning_module
if model_ref is not None:
Expand Down
3 changes: 3 additions & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,21 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 0, 0),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_before_backward(trainer, model, ANY),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_batch_end(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 1, 0),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_before_backward(trainer, model, ANY),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_batch_end(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 2, 0),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_before_backward(trainer, model, ANY),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
call.on_batch_end(trainer, model),
Expand Down
17 changes: 17 additions & 0 deletions tests/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,20 @@ def assert_device(device: torch.device) -> None:
assert_device(torch.device("cpu"))
trainer.predict(model, dataloaders=model.train_dataloader())
assert_device(torch.device("cpu"))


def test_on_before_backward(tmpdir):
tchaton marked this conversation as resolved.
Show resolved Hide resolved

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.on_before_backward_called = False

def on_before_backward(self, loss: torch.Tensor) -> None:
self.on_before_backward_called = True

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)
assert model.on_before_backward_called
4 changes: 4 additions & 0 deletions tests/helpers/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
if _SKLEARN_AVAILABLE:
from sklearn.datasets import make_classification, make_regression
from sklearn.model_selection import train_test_split
else:
make_classification = None
make_regression = None
train_test_split = None


class MNISTDataModule(LightningDataModule):
Expand Down
25 changes: 25 additions & 0 deletions tests/plugins/test_amp_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from tests.helpers import BoringModel
Expand Down Expand Up @@ -69,18 +70,34 @@ def test_amp_apex_ddp(
assert isinstance(trainer.precision_plugin, plugin_cls)


class CheckOnBeforeBackward(Callback):
carmocca marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self):
self.on_before_backward_called = False

def on_before_backward(self, trainer, pl_module, loss):
assert isinstance(loss, torch.Tensor)
assert loss.grad_fn is not None
self.on_before_backward_called = True


class GradientUnscaleBoringModel(BoringModel):

def on_after_backward(self):
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
if not (torch.isinf(norm) or torch.isnan(norm)):
assert norm.item() < 15.

cb = [cb for cb in self.trainer.callbacks if isinstance(cb, CheckOnBeforeBackward)]
assert len(cb) == 1
assert cb[0].on_before_backward_called


@RunIf(min_gpus=2, amp_native=True)
@pytest.mark.parametrize('accum', [1, 2])
def test_amp_gradient_unscale(tmpdir, accum: int):
model = GradientUnscaleBoringModel()
cb = CheckOnBeforeBackward()

trainer = Trainer(
max_epochs=2,
Expand All @@ -95,13 +112,15 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
track_grad_norm=2,
log_every_n_steps=1,
accumulate_grad_batches=accum,
callbacks=[cb]
)
trainer.fit(model)


@RunIf(min_gpus=2, amp_apex=True, special=True)
@pytest.mark.parametrize("amp_level", ['O2'])
def test_amp_apex_ddp_fit(amp_level, tmpdir):
cb = CheckOnBeforeBackward()

class CustomBoringModel(BoringModel):

Expand All @@ -110,6 +129,11 @@ def training_step(self, batch, batch_idx):
assert self.trainer.precision_plugin._connected
return super().training_step(batch, batch_idx)

def on_after_backward(self):
cb = [cb for cb in self.trainer.callbacks if isinstance(cb, CheckOnBeforeBackward)]
assert len(cb) == 1
assert cb[0].on_before_backward_called

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
Expand All @@ -118,6 +142,7 @@ def training_step(self, batch, batch_idx):
gpus=2,
accelerator='ddp',
plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
callbacks=[cb]
)
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
model = CustomBoringModel()
Expand Down
13 changes: 13 additions & 0 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,14 +573,25 @@ def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, cpu_offload

class VerificationCallback(Callback):

def __init__(self):
self.on_train_batch_start_called = False
self.on_before_backward_called = False

def on_train_batch_start(
self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
deepspeed_engine = trainer.training_type_plugin.model
assert trainer.global_step == deepspeed_engine.global_steps
self.on_train_batch_start_called = True

def on_before_backward(self, trainer, pl_module, loss):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(loss, torch.Tensor)
assert loss.grad_fn is not None
self.on_before_backward_called = True

model = ModelParallelClassificationModel()
dm = ClassifDataModule()
verification_callback = VerificationCallback()
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
Expand All @@ -593,6 +604,8 @@ def on_train_batch_start(
callbacks=[VerificationCallback()]
)
trainer.fit(model, datamodule=dm)
assert verification_callback.on_train_batch_start_called
assert verification_callback.on_before_backward_called


@RunIf(min_gpus=2, deepspeed=True, special=True)
Expand Down
1 change: 1 addition & 0 deletions tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def test_fx_validator(tmpdir):
funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])

callbacks_func = [
'on_before_backward',
'on_after_backward',
'on_batch_end',
'on_batch_start',
Expand Down