added `on_backward` trainer callback (allenai#5249)
* added BackwardCallback
* finished tests
* fixed linting issue
* revised design per Dirk's suggestion
* added OnBackwardException, changed loss to batch_outputs, etc.

Co-authored-by: Arjun Subramonian <arjuns@Arjuns-MacBook-Pro.local>
1 parent 38c930b · commit 0e3a225
Showing 6 changed files with 148 additions and 6 deletions.
@@ -0,0 +1,40 @@
from typing import Dict, TYPE_CHECKING
import torch

from allennlp.training.callbacks.callback import TrainerCallback

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


@TrainerCallback.register("mixed_precision_backward")
class MixedPrecisionBackwardCallback(TrainerCallback):
    """
    Performs backpropagation for mixed precision training.
    """

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs
    ) -> bool:
        if backward_called:
            raise OnBackwardException()
        # Scale the loss via the trainer's `GradScaler` before calling
        # `backward()`, as required for mixed precision training.
        trainer._scaler.scale(batch_outputs["loss"]).backward()  # type: ignore
        return True


class OnBackwardException(Exception):
    """
    The exception type raised if an `on_backward` callback
    attempts to call `backward` when `backward_called` is `True`.
    """

    def __init__(self, message="") -> None:
        # Trailing spaces inside the string literals keep the concatenated
        # message readable.
        super().__init__(
            "Backpropagation has already been performed "
            "and the computation graph has been erased, so "
            "calling `loss.backward` is not permitted. " + message
        )
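For illustration only (not part of this commit): a user-defined callback could implement the same `on_backward` hook to take over the backward pass, for example to fold in gradient clipping. This is a minimal sketch; the registration name `clip_backward`, the clipping threshold, and the import path for `OnBackwardException` are assumptions, since the new file's path is not visible in the diff above.

from typing import Dict, TYPE_CHECKING
import torch

from allennlp.training.callbacks.callback import TrainerCallback
# Assumed import path for the exception added in this commit.
from allennlp.training.callbacks.backward import OnBackwardException

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


@TrainerCallback.register("clip_backward")  # hypothetical name, for illustration
class ClipBackwardCallback(TrainerCallback):
    """
    Example `on_backward` hook: runs the backward pass itself, clips
    gradients, and returns True so the trainer knows `backward` was called.
    """

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs
    ) -> bool:
        if backward_called:
            # Some other callback already ran backward; the graph is gone.
            raise OnBackwardException()
        batch_outputs["loss"].backward()
        # Illustrative post-backward step; max_norm=1.0 is an arbitrary choice.
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), max_norm=1.0)
        return True

Assuming the standard Registrable mechanics, such a callback would then be selectable from a training config via its registered name, e.g. "callbacks": [{"type": "clip_backward"}] under the trainer section.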