added on_backward trainer callback (allenai#5249)
* added BackwardCallback

* finished tests

* fixed linting issue

* revised design per Dirk's suggestion

* added OnBackwardException, changed loss to batch_outputs, etc.

Co-authored-by: Arjun Subramonian <arjuns@Arjuns-MacBook-Pro.local>
2 people authored and Abhishek-P committed Aug 11, 2021
1 parent 38c930b commit 0e3a225
Showing 6 changed files with 148 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `on_backward` training callback which allows for control over backpropagation and gradient manipulation.

### Fixed

- Fixed Broken link in `allennlp.fairness.fairness_metrics.Separation` docs
1 change: 1 addition & 0 deletions allennlp/training/callbacks/__init__.py
@@ -4,3 +4,4 @@
from allennlp.training.callbacks.tensorboard import TensorBoardCallback
from allennlp.training.callbacks.track_epoch import TrackEpochCallback
from allennlp.training.callbacks.wandb import WandBCallback
from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback, OnBackwardException
40 changes: 40 additions & 0 deletions allennlp/training/callbacks/backward.py
@@ -0,0 +1,40 @@
from typing import Dict, TYPE_CHECKING
import torch

from allennlp.training.callbacks.callback import TrainerCallback

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


@TrainerCallback.register("mixed_precision_backward")
class MixedPrecisionBackwardCallback(TrainerCallback):
    """
    Performs backpropagation for mixed precision training.
    """

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs
    ) -> bool:
        if backward_called:
            raise OnBackwardException()
        trainer._scaler.scale(batch_outputs["loss"]).backward()  # type: ignore
        return True


class OnBackwardException(Exception):
    """
    The exception type raised if an `on_backward` callback
    attempts to call `backward` when `backward_called` is `True`.
    """

    def __init__(self, message="") -> None:
        super().__init__(
            "Backpropagation has already been performed "
            "and the computation graph has been erased, so "
            "calling `loss.backward` is not permitted. " + message
        )
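For orientation, this is roughly how a downstream user could plug into the new hook. The sketch below is not part of this commit: the class name, the registration key `clip_grad_norm_backward`, and the `max_norm=1.0` value are illustrative; only the `on_backward` signature and `OnBackwardException` come from the files changed here.

from typing import Dict, TYPE_CHECKING

import torch

from allennlp.training.callbacks import TrainerCallback, OnBackwardException

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


@TrainerCallback.register("clip_grad_norm_backward")  # hypothetical registration key
class ClipGradNormBackwardCallback(TrainerCallback):
    """
    Illustrative sketch: runs backpropagation itself, then clips gradient norms.
    """

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        if backward_called:
            # Another callback already ran backward; doing it again is an error.
            raise OnBackwardException()
        batch_outputs["loss"].backward()
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), max_norm=1.0)
        # Returning True tells the trainer to skip its default backward pass.
        return True

Passed to the trainer as `callbacks=[ClipGradNormBackwardCallback(serialization_dir=...)]`, such a callback replaces the trainer's default `loss.backward()` call, as the `gradient_descent_trainer.py` changes below show.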
17 changes: 16 additions & 1 deletion allennlp/training/callbacks/callback.py
@@ -1,4 +1,5 @@
from typing import List, Dict, Any, Optional, TYPE_CHECKING
import torch

from allennlp.common import Registrable
from allennlp.data import TensorDict
@@ -12,7 +13,7 @@ class TrainerCallback(Registrable):
    """
    A general callback object that handles multiple events.
    This class has `on_batch`, `on_epoch`, and `on_end` methods, corresponding to
    This class has `on_backward`, `on_batch`, `on_epoch`, and `on_end` methods, corresponding to
    each callback type. Each one receives the state of the wrapper object as `self`.
    This enables easier state sharing between related callbacks.
@@ -33,6 +34,20 @@ def on_start(
        """
        self.trainer = trainer

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        """
        This callback hook performs backpropagation and allows for gradient manipulation.
        `backward_called` indicates if `loss.backward` has been called prior to this callback.
        `on_backward` should return `True` if and only if `loss.backward` is called in its body.
        """
        return False

    def on_batch(
        self,
        trainer: "GradientDescentTrainer",
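The default implementation returns `False`, so a callback that only wants to observe the batch before backpropagation can leave the actual `loss.backward()` call to the trainer. A minimal sketch of that pattern (not part of this commit; the class name and the print call are illustrative):

from typing import Dict

import torch

from allennlp.training.callbacks import TrainerCallback


class LossInspectionCallback(TrainerCallback):
    """
    Illustrative sketch: looks at the loss but leaves backpropagation to the trainer.
    """

    def on_backward(
        self,
        trainer,
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        print("loss before backward:", batch_outputs["loss"].item())
        # Returning False signals that this callback did not call `loss.backward`,
        # so the trainer (or another callback) still performs the backward pass.
        return False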
18 changes: 13 additions & 5 deletions allennlp/training/gradient_descent_trainer.py
@@ -19,6 +19,7 @@
from allennlp.models.model import Model
from allennlp.training.callbacks import ConsoleLoggerCallback
from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback
from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler
from allennlp.training.metric_tracker import MetricTracker
@@ -148,7 +149,7 @@ class GradientDescentTrainer(Trainer):
        parameters. This is necessary because we want the saved model to perform as well as the validated
        model if we load it later. But this may cause problems if you restart the training from checkpoint.

    callbacks : `List[Lazy[TrainerCallback]]`, optional (default = `None`)
    callbacks : `List[TrainerCallback]`, optional (default = `None`)
        A list of callbacks that can be called at certain events: e.g. each batch, epoch, and at the start
        and end of training, etc.
@@ -469,10 +470,17 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                        batch_reg_loss = reg_loss.item()
                        train_reg_loss += batch_reg_loss  # type: ignore

                if self._scaler is not None:
                    self._scaler.scale(loss).backward()
                else:
                    loss.backward()
                backward_called = False
                for callback in self._callbacks:
                    backward_called |= callback.on_backward(self, batch_outputs, backward_called)
                if not backward_called:
                    if self._scaler is not None:
                        MixedPrecisionBackwardCallback(self._serialization_dir).on_backward(
                            self, batch_outputs, backward_called
                        )
                    else:
                        loss.backward()

            if len(batch_group_outputs) <= 0:
                continue

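The commit also registers `MixedPrecisionBackwardCallback` under the key `"mixed_precision_backward"` in `backward.py` above, so it can be selected by name through AllenNLP's `Registrable` machinery as well as constructed directly, which is what the fallback path in the diff does. A small sanity-check sketch, assuming a build that includes this commit:

from allennlp.training.callbacks import TrainerCallback
from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback

# `by_name` is inherited from `Registrable`; it resolves the key used in the
# `@TrainerCallback.register("mixed_precision_backward")` decorator.
assert TrainerCallback.by_name("mixed_precision_backward") is MixedPrecisionBackwardCallback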
74 changes: 74 additions & 0 deletions tests/training/trainer_test.py
@@ -30,6 +30,7 @@
    TensorBoardCallback,
    ConfidenceChecksCallback,
    ConsoleLoggerCallback,
    OnBackwardException,
)
from allennlp.training.callbacks.confidence_checks import ConfidenceCheckError
from allennlp.training.learning_rate_schedulers import CosineWithRestarts
@@ -127,6 +128,26 @@ def setup_method(self):
        self.validation_data_loader.index_with(self.vocab)


class ZeroGradientsBackwardCallback(TrainerCallback):
    """
    Zeros all gradients after backpropagation.
    """

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        if backward_called:
            raise OnBackwardException()
        batch_outputs["loss"].backward()
        for param in trainer.model.parameters():
            param.grad.data.zero_()
        return True


class TestTrainer(TrainerTestBase):
    def test_trainer_can_run(self):
        trainer = GradientDescentTrainer(
@@ -168,6 +189,59 @@ def test_trainer_can_run(self):
        assert isinstance(metrics["peak_worker_0_memory_MB"], float)
        assert metrics["peak_worker_0_memory_MB"] > 0

    def test_train_zero_gradients(self):
        weights = {}
        for name, param in self.model.named_parameters():
            weights[name] = param.data.clone()

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            num_epochs=2,
            validation_data_loader=self.validation_data_loader,
            callbacks=[ZeroGradientsBackwardCallback(serialization_dir=self.TEST_DIR)],
        )
        trainer.train()

        # weights should be the same
        for name, param in self.model.named_parameters():
            assert torch.equal(weights[name], param.data)

    def test_two_backward_callbacks(self):
        class SecondBackwardCallback(TrainerCallback):
            """
            Changes all gradients to 1 after backpropagation.
            """

            def on_backward(
                self,
                trainer: "GradientDescentTrainer",
                batch_outputs: Dict[str, torch.Tensor],
                backward_called: bool,
                **kwargs,
            ) -> bool:
                if backward_called:
                    raise OnBackwardException()
                batch_outputs["loss"].backward()
                for param in trainer.model.parameters():
                    param.grad = torch.ones_like(param.grad, device=param.grad.device)
                return True

        with pytest.raises(OnBackwardException):
            trainer = GradientDescentTrainer(
                self.model,
                self.optimizer,
                self.data_loader,
                num_epochs=2,
                validation_data_loader=self.validation_data_loader,
                callbacks=[
                    ZeroGradientsBackwardCallback(serialization_dir=self.TEST_DIR),
                    SecondBackwardCallback(serialization_dir=self.TEST_DIR),
                ],
            )
            trainer.train()

    def test_trainer_can_run_exponential_moving_average(self):
        moving_average = ExponentialMovingAverage(self.model.named_parameters(), decay=0.9999)
        trainer = GradientDescentTrainer(