From 7a0adbb56ae719a784f781bb2d80edf856e71916 Mon Sep 17 00:00:00 2001 From: Iz Beltagy Date: Thu, 28 May 2020 18:50:35 -0700 Subject: [PATCH 1/7] add support for gradient checkpointing in BERT --- src/transformers/configuration_bert.py | 4 ++++ src/transformers/modeling_bert.py | 14 +++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/transformers/configuration_bert.py b/src/transformers/configuration_bert.py index 5026954468e734..6d4a7390f38fdb 100644 --- a/src/transformers/configuration_bert.py +++ b/src/transformers/configuration_bert.py @@ -89,6 +89,8 @@ class BertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): The epsilon used by the layer normalization layers. + gradient_checkpointing (:obj:`bool`, optional, defaults to False): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -125,6 +127,7 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, + gradient_checkpointing=False, **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -141,3 +144,4 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps + self.gradient_checkpointing = gradient_checkpointing diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 1e31b5c402a5e1..db0879a9c0af9e 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -23,6 +23,7 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss, MSELoss +import torch.utils.checkpoint from .activations import gelu, gelu_new, swish from .configuration_bert import BertConfig @@ -385,6 +386,7 @@ def forward( class BertEncoder(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) @@ -403,9 +405,15 @@ def forward( if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module( - hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask - ) + if self.config.gradient_checkpointing: + layer_outputs = torch.utils.checkpoint.checkpoint( + layer_module, hidden_states, attention_mask, head_mask[i], + encoder_hidden_states, encoder_attention_mask + ) + else: + layer_outputs = layer_module( + hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask + ) hidden_states = layer_outputs[0] if self.output_attentions: From 1765a1430d091e0464a2413354eac362628db56a Mon Sep 17 00:00:00 2001 From: Iz Beltagy Date: Thu, 28 May 2020 19:25:30 -0700 Subject: [PATCH 2/7] fix unit tests --- src/transformers/modeling_bert.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index db0879a9c0af9e..fde1cfc0ce48d2 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -405,10 +405,14 @@ def forward( if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.config.gradient_checkpointing: + if getattr(self.config, "gradient_checkpointing", False): layer_outputs = 
torch.utils.checkpoint.checkpoint( - layer_module, hidden_states, attention_mask, head_mask[i], - encoder_hidden_states, encoder_attention_mask + layer_module, + hidden_states, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, ) else: layer_outputs = layer_module( From bf4342743ad2f5a5e1090818ecb72f2ebc6e4f73 Mon Sep 17 00:00:00 2001 From: Iz Beltagy Date: Thu, 28 May 2020 19:37:44 -0700 Subject: [PATCH 3/7] isort --- src/transformers/modeling_bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index fde1cfc0ce48d2..9337f7c2a3188e 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -21,9 +21,9 @@ import os import torch +import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss, MSELoss -import torch.utils.checkpoint from .activations import gelu, gelu_new, swish from .configuration_bert import BertConfig From b36648fc3d2c322421263612da9cf4fca93bc9d2 Mon Sep 17 00:00:00 2001 From: Iz Beltagy Date: Thu, 11 Jun 2020 00:37:08 -0700 Subject: [PATCH 4/7] black --- src/transformers/modeling_bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 3f2aa2c8a3f976..e69e9d8a7c4a62 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -420,7 +420,7 @@ def forward( head_mask[i], encoder_hidden_states, encoder_attention_mask, - output_attentions + output_attentions, ) else: layer_outputs = layer_module( From 5eb68bb804f5ffbfc7ba13c45a47717f72d04574 Mon Sep 17 00:00:00 2001 From: Iz Beltagy Date: Fri, 12 Jun 2020 14:31:55 -0700 Subject: [PATCH 5/7] workaround for `torch.utils.checkpoint.checkpoint` not accepting bool --- src/transformers/modeling_bert.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index e69e9d8a7c4a62..ec2d1959a5d78e 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -413,14 +413,20 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + layer_outputs = torch.utils.checkpoint.checkpoint( - layer_module, + create_custom_forward(layer_module), hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, - output_attentions, ) else: layer_outputs = layer_module( From afee616afbd0b450d3e63d94fb338ef13e439cda Mon Sep 17 00:00:00 2001 From: Iz Beltagy Date: Fri, 12 Jun 2020 16:39:13 -0700 Subject: [PATCH 6/7] Revert "workaround for `torch.utils.checkpoint.checkpoint` not accepting bool" This reverts commit 5eb68bb804f5ffbfc7ba13c45a47717f72d04574. 
--- src/transformers/modeling_bert.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index ec2d1959a5d78e..e69e9d8a7c4a62 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -413,20 +413,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if getattr(self.config, "gradient_checkpointing", False): - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_module, hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( From 00bd3d5255940b3e235f14ffeddd50b030174df8 Mon Sep 17 00:00:00 2001 From: Iz Beltagy Date: Fri, 12 Jun 2020 14:31:55 -0700 Subject: [PATCH 7/7] workaround for `torch.utils.checkpoint.checkpoint` not accepting bool --- src/transformers/modeling_bert.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index e69e9d8a7c4a62..ec2d1959a5d78e 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -413,14 +413,20 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + layer_outputs = torch.utils.checkpoint.checkpoint( - layer_module, + create_custom_forward(layer_module), hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, - output_attentions, ) else: layer_outputs = layer_module(
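
Note on patches 5 and 7: the `create_custom_forward` closure exists because, as the commit messages say, `torch.utils.checkpoint.checkpoint` would not accept the bool `output_attentions` as a positional argument. The wrapper captures the non-tensor argument in a closure so that only tensors are handed to `checkpoint` (patch 6 briefly reverts this in favor of passing the bool positionally; patch 7 restores the closure). Below is a minimal, self-contained sketch of the same pattern; the toy module and its `scale_output` flag are illustrative stand-ins for `BertLayer` and `output_attentions`, not code from the patches.

    import torch
    import torch.utils.checkpoint


    class ToyLayer(torch.nn.Module):
        # Stand-in for BertLayer: one tensor input plus one non-tensor flag.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)

        def forward(self, hidden_states, scale_output=False):
            out = self.linear(hidden_states)
            return out * 2 if scale_output else out


    layer = ToyLayer()
    hidden_states = torch.randn(4, 16, requires_grad=True)
    scale_output = True  # non-tensor argument, analogous to output_attentions


    def create_custom_forward(module):
        # The bool is captured by the closure, so checkpoint() only receives tensors.
        def custom_forward(*inputs):
            return module(*inputs, scale_output)

        return custom_forward


    # checkpoint() discards the wrapped forward's intermediate activations and
    # re-runs the forward during backward to rebuild them.
    out = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), hidden_states)
    out.sum().backward()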
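
For context, this is roughly how the new flag would be exercised once the patches are applied. A hedged sketch only: the checkpoint name, batch shape, and the tuple-style model output are assumptions about the surrounding transformers API of this era, not something the patches themselves define.

    import torch
    from transformers import BertConfig, BertModel

    # gradient_checkpointing is the flag added in configuration_bert.py; here it is
    # assumed to be settable through the usual config-override kwargs of from_pretrained().
    config = BertConfig.from_pretrained("bert-base-uncased", gradient_checkpointing=True)
    model = BertModel.from_pretrained("bert-base-uncased", config=config)
    model.train()  # checkpointing only pays off when gradients are required

    input_ids = torch.randint(0, config.vocab_size, (2, 128))  # placeholder batch
    sequence_output = model(input_ids)[0]
    # Each BertLayer forward is recomputed inside backward(), so far less
    # activation memory is held between the two passes.
    sequence_output.sum().backward()

The flag is off by default, which matches the trade-off stated in the new config docstring: recomputing each layer's forward during the backward pass saves memory at the expense of a slower backward pass.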