This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

added on_backward trainer callback #5249

Merged
merged 5 commits on Jun 11, 2021
Changes from 2 commits
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `BackwardCallback`, a training callback which allows for control over backpropagation and gradient manipulation.
Member: This comment isn't accurate anymore, is it?


### Fixed

- Fixed Broken link in `allennlp.fairness.fairness_metrics.Separation` docs
5 changes: 5 additions & 0 deletions allennlp/training/callbacks/__init__.py
@@ -4,3 +4,8 @@
from allennlp.training.callbacks.tensorboard import TensorBoardCallback
from allennlp.training.callbacks.track_epoch import TrackEpochCallback
from allennlp.training.callbacks.wandb import WandBCallback
from allennlp.training.callbacks.backward import (
    VanillaBackwardCallback,
    MixedPrecisionBackwardCallback,
    BackwardCallbackError,
)
51 changes: 51 additions & 0 deletions allennlp/training/callbacks/backward.py
@@ -0,0 +1,51 @@
from typing import TYPE_CHECKING
import torch

from allennlp.training.callbacks.callback import TrainerCallback

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


class BackwardCallback(TrainerCallback):
    def on_backward(
        self, trainer: "GradientDescentTrainer", loss: torch.FloatTensor, **kwargs
    ) -> None:
        """
        This callback hook performs backpropagation and allows for gradient manipulation.
        """
        raise NotImplementedError


@TrainerCallback.register("vanilla_backward")
class VanillaBackwardCallback(BackwardCallback):
    """
    Performs vanilla backpropagation.
    """

    def on_backward(
        self, trainer: "GradientDescentTrainer", loss: torch.FloatTensor, **kwargs
    ) -> None:
        loss.backward()


@TrainerCallback.register("mixed_precision_backward")
class MixedPrecisionBackwardCallback(BackwardCallback):
    """
    Performs backpropagation for mixed precision training.
    """

    def on_backward(
        self, trainer: "GradientDescentTrainer", loss: torch.FloatTensor, **kwargs
    ) -> None:
        trainer._scaler.scale(loss).backward()  # type: ignore


class BackwardCallbackError(Exception):
    """
    The error type raised when multiple callbacks passed to a trainer
    implement `on_backward`.
    """

    def __init__(self, message) -> None:
        super().__init__(message)
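
For context, a user-defined callback built on the API as it stands in these two commits might look like the following minimal sketch (not part of this PR): the class, the registration name `clip_grad_backward`, and the `max_norm` value are illustrative assumptions.

from typing import TYPE_CHECKING

import torch

from allennlp.training.callbacks.backward import BackwardCallback
from allennlp.training.callbacks.callback import TrainerCallback

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


# Hypothetical example: backpropagate, then clip gradient norms before the
# trainer takes its optimizer step.
@TrainerCallback.register("clip_grad_backward")
class ClipGradBackwardCallback(BackwardCallback):
    def on_backward(
        self, trainer: "GradientDescentTrainer", loss: torch.FloatTensor, **kwargs
    ) -> None:
        loss.backward()
        # max_norm is an arbitrary illustrative threshold.
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), max_norm=1.0)

Registering the class makes it selectable by name in a trainer config's `callbacks` list, the same way the tests below select `zero_gradients`.
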
2 changes: 1 addition & 1 deletion allennlp/training/callbacks/callback.py
@@ -12,7 +12,7 @@ class TrainerCallback(Registrable):
"""
A general callback object that handles multiple events.

This class has `on_batch`, `on_epoch`, and `on_end` methods, corresponding to
This class has `on_backward`, `on_batch`, `on_epoch`, and `on_end` methods, corresponding to
each callback type. Each one receives the state of the wrapper object as `self`.
This enables easier state sharing between related callbacks.

48 changes: 34 additions & 14 deletions allennlp/training/gradient_descent_trainer.py
@@ -19,6 +19,12 @@
from allennlp.models.model import Model
from allennlp.training.callbacks import ConsoleLoggerCallback
from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback
from allennlp.training.callbacks.backward import (
    BackwardCallback,
    MixedPrecisionBackwardCallback,
    VanillaBackwardCallback,
    BackwardCallbackError,
)
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler
from allennlp.training.metric_tracker import MetricTracker
@@ -148,9 +154,9 @@ class GradientDescentTrainer(Trainer):
parameters. This is necessary because we want the saved model to perform as well as the validated
model if we load it later. But this may cause problems if you restart the training from checkpoint.

callbacks : `List[Lazy[TrainerCallback]]`, optional (default = `None`)
callbacks : `List[TrainerCallback]`, optional (default = `None`)
A list of callbacks that can be called at certain events: e.g. each batch, epoch, and at the start
and end of training, etc.
and end of training, etc. At most one callback can be a `BackwardCallback`.

distributed : `bool`, optional, (default = `False`)
If set, PyTorch's `DistributedDataParallel` is used to train the model in multiple GPUs. This also
@@ -277,9 +283,31 @@ def __init__(
        self._momentum_scheduler = momentum_scheduler
        self._moving_average = moving_average

        # Enable automatic mixed precision training.
        self._scaler: Optional[amp.GradScaler] = None
        self._use_amp = use_amp
        if self._use_amp:
            if self.cuda_device == torch.device("cpu"):
                raise ValueError("Using AMP requires a cuda device")
            self._scaler = amp.GradScaler()

        self._callbacks = callbacks or []
        default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else []

        on_backward_callable = [
            isinstance(callback, BackwardCallback) for callback in self._callbacks
        ]
        # append VanillaBackwardCallback or MixedPrecisionBackwardCallback if no callback is a BackwardCallback
        if not any(on_backward_callable):
            if self._scaler is not None:
                default_callbacks.append(MixedPrecisionBackwardCallback)
            else:
                default_callbacks.append(VanillaBackwardCallback)
        on_backward_callable_iter = iter(on_backward_callable)
        # raise BackwardCallbackError if more than one callback is a BackwardCallback
        if any(on_backward_callable_iter) and any(on_backward_callable_iter):
            raise BackwardCallbackError("At most one callback can be a `BackwardCallback`.")

        if run_confidence_checks:
            default_callbacks.append(ConfidenceChecksCallback)
        for callback_cls in default_callbacks:
@@ -291,14 +319,6 @@ def __init__(

        self._num_gradient_accumulation_steps = num_gradient_accumulation_steps

        # Enable automatic mixed precision training.
        self._scaler: Optional[amp.GradScaler] = None
        self._use_amp = use_amp
        if self._use_amp:
            if self.cuda_device == torch.device("cpu"):
                raise ValueError("Using AMP requires a cuda device")
            self._scaler = amp.GradScaler()

        # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its
        # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model`
        # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc.
@@ -469,10 +489,10 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                        batch_reg_loss = reg_loss.item()
                        train_reg_loss += batch_reg_loss  # type: ignore

                if self._scaler is not None:
                    self._scaler.scale(loss).backward()
                else:
                    loss.backward()
                for callback in self._callbacks:
                    if isinstance(callback, BackwardCallback):
                        callback.on_backward(self, loss)

            if len(batch_group_outputs) <= 0:
                continue

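The `iter`/`any` check in `__init__` above is a compact way of asking whether two or more entries of `on_backward_callable` are `True` without a counter: the first `any` consumes the shared iterator up to the first `True`, and the second `any` then scans the remaining items for another. A standalone illustration in plain Python (nothing here is AllenNLP-specific):

# Illustration of the double-any idiom used in __init__ above.
flags = [False, True, False, True]

flags_iter = iter(flags)
# The first any() stops after consuming the True at index 1; the second any()
# continues from index 2 and finds the True at index 3.
two_or_more = any(flags_iter) and any(flags_iter)
print(two_or_more)  # True: two callbacks would implement on_backward

An equivalent, more explicit check would be `sum(on_backward_callable) > 1`.
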
58 changes: 58 additions & 0 deletions tests/commands/train_test.py
@@ -1,3 +1,4 @@
from allennlp.training.callbacks.backward import BackwardCallback
import argparse
import copy
import json
@@ -27,6 +28,11 @@
    ExponentialLearningRateScheduler,
    LearningRateScheduler,
)
from allennlp.training.callbacks.backward import (
    VanillaBackwardCallback,
    MixedPrecisionBackwardCallback,
    BackwardCallbackError,
)

SEQUENCE_TAGGING_DATA_PATH = str(AllenNlpTestCase.FIXTURES_ROOT / "data" / "sequence_tagging.tsv")
SEQUENCE_TAGGING_SHARDS_PATH = str(AllenNlpTestCase.FIXTURES_ROOT / "data" / "shards" / "*")
@@ -89,6 +95,32 @@ def on_start(
            assert torch.distributed.get_rank() == 0


@TrainerCallback.register("zero_gradients")
class ZeroGradientsBackwardCallback(BackwardCallback):
    """
    Zeros all gradients after backpropagation.
    """

    def on_backward(
        self, trainer: "GradientDescentTrainer", loss: torch.FloatTensor, **kwargs
    ) -> None:
        loss.backward()
        for param in trainer.model.parameters():
            param.grad *= 0.0


@TrainerCallback.register("extra_backward")
class ExtraBackwardCallback(BackwardCallback):
    """
    Invalid extra backward callback.
    """

    def on_backward(
        self, trainer: "GradientDescentTrainer", loss: torch.FloatTensor, **kwargs
    ) -> None:
        pass


class TestTrain(AllenNlpTestCase):
    DEFAULT_PARAMS = Params(
        {
@@ -157,6 +189,32 @@ def test_train_model(self):
            recover=True,
        )

    def test_train_zero_gradients(self):
        import copy

        params = copy.deepcopy(self.DEFAULT_PARAMS)
        weights = {}
        model = train_model(params, serialization_dir=os.path.join(self.TEST_DIR, "test_baseline"))
        for name, param in model.named_parameters():
            weights[name] = param.data.clone()

        params = copy.deepcopy(self.DEFAULT_PARAMS)
        params["trainer"]["callbacks"] = ["zero_gradients"]
        model = train_model(
            params, serialization_dir=os.path.join(self.TEST_DIR, "test_zero_gradients")
        )
        # With gradients zeroed out, this model should not have learned, so its
        # weights should differ from the normally trained baseline above.
        for name, param in model.named_parameters():
            assert not torch.equal(weights[name], param.data)

    def test_backward_callback_exception(self):
        import copy

        params = copy.deepcopy(self.DEFAULT_PARAMS)
        params["trainer"]["callbacks"] = ["zero_gradients", "extra_backward"]
        with pytest.raises(BackwardCallbackError):
            train_model(params, serialization_dir=os.path.join(self.TEST_DIR, "extra_backward"))

    @cpu_or_gpu
    def test_detect_gpu(self):
        import copy
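
For completeness, a hedged sketch (not shown in this PR) of how a registered backward callback is enabled from a config, mirroring the way the tests above patch `DEFAULT_PARAMS`; every key other than `callbacks` is an illustrative placeholder:

from allennlp.common.params import Params

# "zero_gradients" is the BackwardCallback registered in the test module above.
# Listing more than one BackwardCallback here would raise BackwardCallbackError
# when the trainer is constructed.
trainer_fragment = Params(
    {
        "trainer": {
            "num_epochs": 1,  # illustrative placeholder
            "callbacks": ["zero_gradients"],
        }
    }
)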