Make gradient_checkpointing a training argument (huggingface#13657)
* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
3 people authored and Narsil committed Sep 25, 2021
1 parent e88712c commit 470f6cf
Showing 96 changed files with 531 additions and 309 deletions.
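At a glance, the usage change this commit rolls out (a sketch assembled from the diffs below; `BertConfig`/`BertModel` stand in for any model that supports the feature):

```python
from transformers import BertConfig, BertModel

# Old, now deprecated: gradient checkpointing was a config attribute and
# passing it to a config initialization now emits a deprecation warning.
model = BertModel(BertConfig(gradient_checkpointing=True))

# New: toggle it on the model itself (or pass `gradient_checkpointing=True`
# in `TrainingArguments` when using the `Trainer` API).
model = BertModel(BertConfig())
model.gradient_checkpointing_enable()   # activation checkpointing on
model.gradient_checkpointing_disable()  # ... and off again
```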
4 changes: 2 additions & 2 deletions docs/source/model_doc/led.rst
@@ -46,8 +46,8 @@ Tips:
- LED makes use of *global attention* by means of the ``global_attention_mask`` (see
:class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first
``<s>`` token. For question answering, it is advised to put *global attention* on all tokens of the question.
- To fine-tune LED on all 16384 tokens, it is necessary to enable *gradient checkpointing* by setting
``config.gradient_checkpointing = True``.
- To fine-tune LED on all 16384 tokens, it is necessary to enable *gradient checkpointing* by executing
``model.gradient_checkpointing_enable()``.
- A notebook showing how to evaluate LED can be accessed `here
<https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing>`__.
- A notebook showing how to fine-tune LED can be accessed `here
16 changes: 16 additions & 0 deletions docs/source/performance.md
@@ -53,6 +53,7 @@ Software:
- Tensor Parallelism
- Low-memory Optimizers
- fp16/bf16 (smaller data)
- Gradient checkpointing



@@ -226,6 +227,21 @@

Autocast maintains a cache of the FP16 casts of model params (leaves). This helps streamline parameter reuse: if the same FP32 param is used in several different FP16list ops, like several matmuls, instead of re-casting the param to FP16 on entering each matmul, the cast will occur on the first matmul, the casted FP16 copy will be cached, and for all later matmuls the FP16 copy will be reused. The cache is maintained only within a particular outermost autocast context. When you exit the autocast context the cache is dropped. For recommended usage, in which autocast wraps the forward pass, and then you exit the context before calling backward(), this means the cache only lasts the duration of the forward pass each iteration, and will be rebuilt next iteration. (The cache of FP16-casted copies MUST be rebuilt each iteration. The FP32 params get updated by the optimizer, so the FP16 copies must be recreated, otherwise the FP16 values will be stale.)
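For reference, the recommended pattern described above looks roughly like this (a minimal, self-contained sketch; the tiny linear model and random data are placeholders, not part of this repository):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder model/data so the sketch runs end to end on any CUDA machine.
model = torch.nn.Linear(16, 1).cuda()
dataloader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randn(64, 1)), batch_size=8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    # autocast wraps only the forward pass; the FP16 cast cache lives inside this context
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    # the context is exited before backward(), so the cache is dropped here
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```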


### Gradient Checkpointing

One way to use significantly less GPU memory is to enable "Gradient Checkpointing" (also known as "activation checkpointing"). When enabled, a lot of memory can be freed at the cost of a small decrease in training speed due to recomputing parts of the graph during back-propagation.

This technique was first shared in the paper [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174). The paper gives the exact details on the savings, but the memory cost is in the ballpark of `O(sqrt(n))`, where `n` is the number of feed-forward layers.

To activate this feature in 🤗 Transformers for models that support it, use:

```python
model.gradient_checkpointing_enable()
```
or add `--gradient_checkpointing` to the `Trainer` arguments (i.e. set `gradient_checkpointing=True` in your `TrainingArguments`).
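For example, with the `Trainer` API this amounts to the following (a minimal sketch; `model` and `train_dataset` are assumed to be defined already):

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs",
    gradient_checkpointing=True,  # the Trainer enables gradient checkpointing on the model for you
)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
```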


### Batch sizes

One gets the most efficient performance when batch sizes and input/output neuron counts are divisible by a certain number, which typically starts at 8, but can be much higher as well. That number varies a lot depending on the specific hardware being used and the dtype of the model.
5 changes: 0 additions & 5 deletions examples/pytorch/language-modeling/README.md
@@ -174,8 +174,3 @@
```

This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`.

This feature can also be used to activate gradient checkpointing by passing:
```
--config_overrides "gradient_checkpointing=true,use_cache=False"
```
9 changes: 9 additions & 0 deletions src/transformers/configuration_utils.py
@@ -19,6 +19,7 @@
import copy
import json
import os
import warnings
from typing import Any, Dict, Tuple, Union

from . import __version__
@@ -330,6 +331,14 @@ def __init__(self, **kwargs):
# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)

# Deal with gradient checkpointing
if "gradient_checkpointing" in kwargs:
warnings.warn(
"Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
"of Transformers. Use `model.gradient_checkpointing_enable()` instead, or, if you are using the "
"`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
)

# Additional attributes without default values
for key, value in kwargs.items():
try:
27 changes: 27 additions & 0 deletions src/transformers/modeling_utils.py
@@ -20,6 +20,7 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
@@ -450,6 +451,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_keys_to_ignore_on_save = None

is_parallelizable = False
supports_gradient_checkpointing = False

@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -469,6 +471,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
# Save config and origin of the pretrained weights if given in model
self.config = config
self.name_or_path = config.name_or_path
if getattr(self.config, "gradient_checkpointing", False):
self.gradient_checkpointing_enable()
# Remove the attribute now that it has been consumed, so it's not saved in the config.
delattr(self.config, "gradient_checkpointing")

@classmethod
def _from_config(cls, config, **kwargs):
@@ -932,6 +938,27 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):

self.base_model._prune_heads(heads_to_prune)

def gradient_checkpointing_enable(self):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))

def gradient_checkpointing_disable(self):
"""
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if self.supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
4 changes: 0 additions & 4 deletions src/transformers/models/bart/configuration_bart.py
@@ -82,8 +82,6 @@ class BartConfig(PretrainedConfig):
decoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -131,7 +129,6 @@ def __init__(
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
use_cache=True,
num_labels=3,
pad_token_id=1,
@@ -161,7 +158,6 @@ def __init__(
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True

super().__init__(
14 changes: 10 additions & 4 deletions src/transformers/models/bart/modeling_bart.py
@@ -471,6 +471,7 @@ def forward(self, hidden_states: torch.Tensor):
class BartPretrainedModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

def _init_weights(self, module):
@@ -484,6 +485,10 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (BartDecoder, BartEncoder)):
module.gradient_checkpointing = value

@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@@ -687,6 +692,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
self.layernorm_embedding = nn.LayerNorm(embed_dim)

self.init_weights()
self.gradient_checkpointing = False

def forward(
self,
@@ -782,7 +788,7 @@ def forward(
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -849,6 +855,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
self.layernorm_embedding = nn.LayerNorm(config.d_model)

self.init_weights()
self.gradient_checkpointing = False

def get_input_embeddings(self):
return self.embed_tokens
@@ -1020,12 +1027,11 @@ def forward(

past_key_value = past_key_values[idx] if past_key_values is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

2 changes: 0 additions & 2 deletions src/transformers/models/beit/configuration_beit.py
@@ -57,8 +57,6 @@ class BeitConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
The size (resolution) of each image.
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):
8 changes: 7 additions & 1 deletion src/transformers/models/beit/modeling_beit.py
@@ -432,6 +432,7 @@ def __init__(self, config, window_size=None):
for i in range(config.num_hidden_layers)
]
)
self.gradient_checkpointing = False

def forward(
self,
@@ -450,7 +451,7 @@ def forward(

layer_head_mask = head_mask[i] if head_mask is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -494,6 +495,7 @@ class BeitPreTrainedModel(PreTrainedModel):

config_class = BeitConfig
base_model_prefix = "beit"
supports_gradient_checkpointing = True

def _init_weights(self, module):
"""Initialize the weights"""
@@ -511,6 +513,10 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BeitEncoder):
module.gradient_checkpointing = value


BEIT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
4 changes: 0 additions & 4 deletions src/transformers/models/bert/configuration_bert.py
@@ -92,8 +92,6 @@ class BertConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
@@ -137,7 +135,6 @@ def __init__(
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
use_cache=True,
classifier_dropout=None,
@@ -157,7 +154,6 @@ def __init__(
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
11 changes: 8 additions & 3 deletions src/transformers/models/bert/modeling_bert.py
@@ -529,6 +529,7 @@ def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

def forward(
self,
@@ -555,12 +556,11 @@ def forward(
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

@@ -714,6 +714,7 @@ class BertPreTrainedModel(PreTrainedModel):
config_class = BertConfig
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

def _init_weights(self, module):
Expand All @@ -732,6 +733,10 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value


@dataclass
class BertForPreTrainingOutput(ModelOutput):
4 changes: 0 additions & 4 deletions src/transformers/models/bert_generation/configuration_bert_generation.py
@@ -52,8 +52,6 @@ class BertGenerationConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
@@ -96,7 +94,6 @@ def __init__(
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
gradient_checkpointing=False,
position_embedding_type="absolute",
use_cache=True,
**kwargs
@@ -114,6 +111,5 @@ def __init__(
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
4 changes: 0 additions & 4 deletions src/transformers/models/big_bird/configuration_big_bird.py
@@ -82,8 +82,6 @@ class BigBirdConfig(PretrainedConfig):
num_random_blocks (:obj:`int`, `optional`, defaults to 3):
Each query is going to attend this many random blocks. Useful only when :obj:`attention_type ==
"block_sparse"`.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
classifier_dropout (:obj:`float`, `optional`):
The dropout ratio for the classification head.
@@ -127,7 +125,6 @@ def __init__(
rescale_embeddings=False,
block_size=64,
num_random_blocks=3,
gradient_checkpointing=False,
classifier_dropout=None,
**kwargs
):
@@ -153,7 +150,6 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.is_encoder_decoder = is_encoder_decoder
self.gradient_checkpointing = gradient_checkpointing

self.rescale_embeddings = rescale_embeddings
self.attention_type = attention_type