Add support for gradient checkpointing in BERT #4659

Merged (9 commits) on Jun 22, 2020
4 changes: 4 additions & 0 deletions src/transformers/configuration_bert.py
@@ -90,6 +90,8 @@ class BertConfig(PretrainedConfig):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
+        gradient_checkpointing (:obj:`bool`, optional, defaults to False):
+            If True, use gradient checkpointing to save memory at the expense of a slower backward pass.

    Example::

@@ -121,6 +123,7 @@ def __init__(
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
+        gradient_checkpointing=False,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -137,3 +140,4 @@ def __init__(
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
+        self.gradient_checkpointing = gradient_checkpointing
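
For reference, a minimal usage sketch of the new flag (the choice of `BertModel`, the random input, and the shapes below are illustrative, not part of this diff): setting `gradient_checkpointing=True` on the config makes the encoder wrap each `BertLayer` call in `torch.utils.checkpoint` during training.

```python
import torch
from transformers import BertConfig, BertModel

# Enable the flag introduced by this PR; all other settings are defaults.
config = BertConfig(gradient_checkpointing=True)
model = BertModel(config)
model.train()  # checkpointing only pays off when gradients are computed

input_ids = torch.randint(0, config.vocab_size, (2, 16))
outputs = model(input_ids)
# During backward(), each BertLayer's activations are recomputed rather
# than stored, trading extra compute for a lower peak memory footprint.
outputs[0].sum().backward()
```
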
35 changes: 27 additions & 8 deletions src/transformers/modeling_bert.py
@@ -22,6 +22,7 @@
import warnings

import torch
+import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

@@ -391,6 +392,7 @@ def forward(
class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
+        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
@@ -409,14 +411,31 @@ def forward(
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

-            layer_outputs = layer_module(
-                hidden_states,
-                attention_mask,
-                head_mask[i],
-                encoder_hidden_states,
-                encoder_attention_mask,
-                output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions,
+                )
            hidden_states = layer_outputs[0]

            if output_attentions:
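
The pattern above can be hard to read out of a diff, so here is a self-contained toy sketch of the same trick (the `ToyLayer` module and the sizes are made up for illustration, not part of the PR): `torch.utils.checkpoint.checkpoint` reruns the wrapped forward during the backward pass instead of storing its intermediate activations, and the `create_custom_forward` closure captures the non-tensor `output_attentions` flag so that `checkpoint` only receives tensor positional arguments.

```python
import torch
import torch.utils.checkpoint
from torch import nn


class ToyLayer(nn.Module):
    """Stand-in for BertLayer: a tensor input plus a non-tensor flag."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, output_attentions=False):
        out = torch.relu(self.linear(hidden_states))
        return (out,)  # mirror the tuple-returning BERT convention


layer = ToyLayer(16)
x = torch.randn(4, 16, requires_grad=True)


def create_custom_forward(module):
    # Capture the boolean in the closure so checkpoint() only sees tensors.
    def custom_forward(*inputs):
        return module(*inputs, output_attentions=False)

    return custom_forward


# Activations inside ToyLayer.forward are not kept; they are recomputed
# when backward() runs, the memory/compute trade-off named in the config docs.
outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
outputs[0].sum().backward()
print(x.grad.shape)  # gradients still flow through the checkpointed segment
```

In the PR itself, the same closure wraps each `BertLayer` inside `BertEncoder.forward` whenever `config.gradient_checkpointing` is set.
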