From 470f6cfbe9c87af2569d4c6ddbe8873439f08bd7 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 22 Sep 2021 07:51:38 -0400 Subject: [PATCH] Make gradient_checkpointing a training argument (#13657) * Make gradient_checkpointing a training argument * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman * Update src/transformers/configuration_utils.py Co-authored-by: Stas Bekman * Fix tests * Style * document Gradient Checkpointing as a performance feature * Small rename * PoC for not using the config * Adapt BC to new PoC * Forgot to save * Rollout changes to all other models * Fix typo Co-authored-by: Stas Bekman Co-authored-by: Stas Bekman --- docs/source/model_doc/led.rst | 4 +-- docs/source/performance.md | 16 +++++++++ examples/pytorch/language-modeling/README.md | 5 --- src/transformers/configuration_utils.py | 9 +++++ src/transformers/modeling_utils.py | 27 +++++++++++++++ .../models/bart/configuration_bart.py | 4 --- src/transformers/models/bart/modeling_bart.py | 14 +++++--- .../models/beit/configuration_beit.py | 2 -- src/transformers/models/beit/modeling_beit.py | 8 ++++- .../models/bert/configuration_bert.py | 4 --- src/transformers/models/bert/modeling_bert.py | 11 ++++-- .../configuration_bert_generation.py | 4 --- .../models/big_bird/configuration_big_bird.py | 4 --- .../models/big_bird/modeling_big_bird.py | 11 ++++-- .../configuration_bigbird_pegasus.py | 4 --- .../modeling_bigbird_pegasus.py | 14 +++++--- .../blenderbot/configuration_blenderbot.py | 4 --- .../models/blenderbot/modeling_blenderbot.py | 14 +++++--- .../configuration_blenderbot_small.py | 4 --- .../modeling_blenderbot_small.py | 14 +++++--- .../models/canine/configuration_canine.py | 2 -- .../models/canine/modeling_canine.py | 8 ++++- .../models/clip/configuration_clip.py | 8 ----- src/transformers/models/clip/modeling_clip.py | 8 ++++- .../models/convbert/modeling_convbert.py | 8 ++++- .../models/deit/configuration_deit.py | 2 -- src/transformers/models/deit/modeling_deit.py | 8 ++++- src/transformers/models/detr/modeling_detr.py | 8 ++++- .../models/dpr/configuration_dpr.py | 4 --- src/transformers/models/dpr/modeling_dpr.py | 10 +++++- .../models/electra/modeling_electra.py | 11 ++++-- .../models/fnet/configuration_fnet.py | 2 -- src/transformers/models/fnet/modeling_fnet.py | 8 ++++- .../models/gpt2/configuration_gpt2.py | 6 +--- src/transformers/models/gpt2/modeling_gpt2.py | 11 ++++-- .../models/gpt_neo/configuration_gpt_neo.py | 4 --- .../models/gpt_neo/modeling_gpt_neo.py | 11 ++++-- .../models/gptj/configuration_gptj.py | 4 --- src/transformers/models/gptj/modeling_gptj.py | 11 ++++-- .../models/hubert/configuration_hubert.py | 4 --- .../models/hubert/modeling_hubert.py | 11 ++++-- .../models/ibert/modeling_ibert.py | 18 ++++------ .../models/layoutlm/configuration_layoutlm.py | 4 --- .../models/layoutlm/modeling_layoutlm.py | 11 ++++-- .../models/layoutlmv2/modeling_layoutlmv2.py | 9 ++++- .../models/led/configuration_led.py | 4 --- src/transformers/models/led/modeling_led.py | 14 +++++--- .../models/longformer/modeling_longformer.py | 8 ++++- .../models/luke/configuration_luke.py | 4 --- src/transformers/models/luke/modeling_luke.py | 8 ++++- .../models/m2m_100/configuration_m2m_100.py | 4 --- .../models/m2m_100/modeling_m2m_100.py | 14 +++++--- .../models/marian/configuration_marian.py | 4 --- .../models/marian/modeling_marian.py | 14 +++++--- .../models/mbart/configuration_mbart.py | 4 --- .../models/mbart/modeling_mbart.py | 14 
+++++--- .../configuration_megatron_bert.py | 4 --- .../convert_megatron_bert_checkpoint.py | 1 - .../megatron_bert/modeling_megatron_bert.py | 11 ++++-- .../convert_megatron_gpt2_checkpoint.py | 1 - .../models/pegasus/configuration_pegasus.py | 4 --- .../models/pegasus/modeling_pegasus.py | 14 +++++--- .../prophetnet/configuration_prophetnet.py | 6 ---- .../models/prophetnet/modeling_prophetnet.py | 14 +++++--- .../models/rembert/configuration_rembert.py | 2 -- .../models/rembert/modeling_rembert.py | 11 ++++-- .../models/roberta/modeling_roberta.py | 11 ++++-- .../models/roformer/configuration_roformer.py | 4 --- .../models/roformer/modeling_roformer.py | 11 ++++-- .../configuration_speech_to_text.py | 2 -- .../speech_to_text/modeling_speech_to_text.py | 13 +++++-- .../configuration_speech_to_text_2.py | 2 -- .../modeling_speech_to_text_2.py | 10 ++++-- .../models/splinter/configuration_splinter.py | 2 -- .../models/splinter/modeling_splinter.py | 11 ++++-- .../models/t5/configuration_t5.py | 4 --- src/transformers/models/t5/modeling_t5.py | 15 +++++--- .../models/tapas/configuration_tapas.py | 4 --- .../models/tapas/modeling_tapas.py | 8 ++++- .../visual_bert/modeling_visual_bert.py | 8 ++++- .../models/vit/configuration_vit.py | 2 -- src/transformers/models/vit/modeling_vit.py | 8 ++++- .../models/wav2vec2/configuration_wav2vec2.py | 4 --- .../models/wav2vec2/modeling_wav2vec2.py | 11 ++++-- src/transformers/trainer.py | 6 +++- src/transformers/training_args.py | 8 +++++ ...on_{{cookiecutter.lowercase_modelname}}.py | 4 --- ...ng_{{cookiecutter.lowercase_modelname}}.py | 34 ++++++++++--------- tests/test_modeling_beit.py | 21 ++++++++++++ tests/test_modeling_common.py | 5 ++- tests/test_modeling_deit.py | 24 ++++++++++++- tests/test_modeling_flax_gpt2.py | 3 +- tests/test_modeling_flax_gpt_neo.py | 3 +- tests/test_modeling_gpt2.py | 25 +++++++++----- tests/test_modeling_gpt_neo.py | 24 ++++++++----- tests/test_modeling_gptj.py | 25 +++++++++----- 96 files changed, 531 insertions(+), 309 deletions(-) diff --git a/docs/source/model_doc/led.rst b/docs/source/model_doc/led.rst index 2e05163d37b48e..1eaa9e325ffad9 100644 --- a/docs/source/model_doc/led.rst +++ b/docs/source/model_doc/led.rst @@ -46,8 +46,8 @@ Tips: - LED makes use of *global attention* by means of the ``global_attention_mask`` (see :class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first ```` token. For question answering, it is advised to put *global attention* on all tokens of the question. -- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by setting - ``config.gradient_checkpointing = True``. +- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing + ``model.gradient_checkpointing_enable()``. - A notebook showing how to evaluate LED, can be accessed `here `__. - A notebook showing how to fine-tune LED, can be accessed `here diff --git a/docs/source/performance.md b/docs/source/performance.md index 4f479d8575699f..c3239f3b0c0842 100644 --- a/docs/source/performance.md +++ b/docs/source/performance.md @@ -53,6 +53,7 @@ Software: - Tensor Parallelism - Low-memory Optimizers - fp16/bf16 (smaller data) +- Gradient checkpointing @@ -226,6 +227,21 @@ pytorch `autocast` which performs AMP include a caching feature, which speed thi Autocast maintains a cache of the FP16 casts of model params (leaves). 
This helps streamline parameter reuse: if the same FP32 param is used in several different FP16list ops, like several matmuls, instead of re-casting the param to FP16 on entering each matmul, the cast will occur on the first matmul, the casted FP16 copy will be cached, and for all later matmuls the FP16 copy will be reused. The cache is maintained only within a particular outermost autocast context. When you exit the autocast context the cache is dropped. For recommended usage, in which autocast wraps the forward pass, and then you exit the context before calling backward(), this means the cache only lasts the duration of the forward pass each iteration, and will be rebuilt next iteration. (The cache of FP16-casted copies MUST be rebuilt each iteration. The FP32 params get updated by the optimizer, so the FP16 copies must be recreated, otherwise the FP16 values will be stale.) + +### Gradient Checkpointing + +One way to use significantly less GPU memory is to enable "Gradient Checkpointing" (also known as "activation checkpointing"). When enabled, a lot of memory can be freed at the cost of a small decrease in the training speed due to recomputing parts of the graph during back-propagation. + +This technique was first shared in the paper: [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174). The paper will also give you the exact details on the savings, but it's in the ballpark of `O(sqrt(n))`, where `n` is the number of feed-forward layers. + +To activate this feature in 🤗 Transformers for models that support it, use: + +```python +model.gradient_checkpointing_enable() +``` +or add `--gradient_checkpointing` to the Trainer arguments. + + ### Batch sizes One gets the most efficient performance when batch sizes and input/output neuron counts are divisible by a certain number, which typically starts at 8, but can be much higher as well. That number varies a lot depending on the specific hardware being used and the dtype of the model. diff --git a/examples/pytorch/language-modeling/README.md b/examples/pytorch/language-modeling/README.md index 23989d7ed1a0f9..c768f5ec31bb6e 100644 --- a/examples/pytorch/language-modeling/README.md +++ b/examples/pytorch/language-modeling/README.md @@ -174,8 +174,3 @@ python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides=" ``` This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`. - -This feature can also be used to activate gradient checkpointing by passing: ``` ---config_overrides "gradient_checkpointing=true,use_cache=False" -``` diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 45683ac801a310..bc3ecf77ba1c47 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -19,6 +19,7 @@ import copy import json import os +import warnings from typing import Any, Dict, Tuple, Union from . import __version__ @@ -330,6 +331,14 @@ def __init__(self, **kwargs): # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) + # Deal with gradient checkpointing + if "gradient_checkpointing" in kwargs: + warnings.warn( + "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 of " + "Transformers. Use `model.gradient_checkpointing_enable()` instead, or if you are using the " + "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
+ ) + # Additional attributes without default values for key, value in kwargs.items(): try: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e82d0ad9e31f50..21a1b09f30d684 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -20,6 +20,7 @@ import warnings from contextlib import contextmanager from dataclasses import dataclass +from functools import partial from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch @@ -450,6 +451,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _keys_to_ignore_on_save = None is_parallelizable = False + supports_gradient_checkpointing = False @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: @@ -469,6 +471,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): # Save config and origin of the pretrained weights if given in model self.config = config self.name_or_path = config.name_or_path + if getattr(self.config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable() + # Remove the attribute now that it has been consumed, so it's not saved in the config. + delattr(self.config, "gradient_checkpointing") @classmethod def _from_config(cls, config, **kwargs): @@ -932,6 +938,27 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]): self.base_model._prune_heads(heads_to_prune) + def gradient_checkpointing_enable(self): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def gradient_checkpointing_disable(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self.supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index e26afb2ab4b6e2..6efbe4ca510a0e 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -82,8 +82,6 @@ class BartConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -131,7 +129,6 @@ def __init__( init_std=0.02, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, use_cache=True, num_labels=3, pad_token_id=1, @@ -161,7 +158,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index a466be30a6881e..134669cee4ba6c 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -471,6 +471,7 @@ def forward(self, hidden_states: torch.Tensor): class BartPretrainedModel(PreTrainedModel): config_class = BartConfig base_model_prefix = "model" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] def _init_weights(self, module): @@ -484,6 +485,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BartDecoder, BartEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -687,6 +692,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -782,7 +788,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -849,6 +855,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1020,12 +1027,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py index 08ecc60646976b..d31f83dd3a5e59 100644 --- a/src/transformers/models/beit/configuration_beit.py +++ b/src/transformers/models/beit/configuration_beit.py @@ -57,8 +57,6 @@ class BeitConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
image_size (:obj:`int`, `optional`, defaults to :obj:`224`): The size (resolution) of each image. patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 236551d27cc667..1ad3fcd1e6d14b 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -432,6 +432,7 @@ def __init__(self, config, window_size=None): for i in range(config.num_hidden_layers) ] ) + self.gradient_checkpointing = False def forward( self, @@ -450,7 +451,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -494,6 +495,7 @@ class BeitPreTrainedModel(PreTrainedModel): config_class = BeitConfig base_model_prefix = "beit" + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -511,6 +513,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BeitEncoder): + module.gradient_checkpointing = value + BEIT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py index 8359f0c3b7e2b2..861cdfbc8ea676 100644 --- a/src/transformers/models/bert/configuration_bert.py +++ b/src/transformers/models/bert/configuration_bert.py @@ -92,8 +92,6 @@ class BertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. 
For more information on @@ -137,7 +135,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, @@ -157,7 +154,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index f02d67a31a2106..ecb0d184a4a45f 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -529,6 +529,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -555,12 +556,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -714,6 +714,7 @@ class BertPreTrainedModel(PreTrainedModel): config_class = BertConfig load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -732,6 +733,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + @dataclass class BertForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/bert_generation/configuration_bert_generation.py b/src/transformers/models/bert_generation/configuration_bert_generation.py index 54659f4394a5f8..2284f873e708b6 100644 --- a/src/transformers/models/bert_generation/configuration_bert_generation.py +++ b/src/transformers/models/bert_generation/configuration_bert_generation.py @@ -52,8 +52,6 @@ class BertGenerationConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. 
For more information on @@ -96,7 +94,6 @@ def __init__( pad_token_id=0, bos_token_id=2, eos_token_id=1, - gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, **kwargs @@ -114,6 +111,5 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache diff --git a/src/transformers/models/big_bird/configuration_big_bird.py b/src/transformers/models/big_bird/configuration_big_bird.py index e6fdfd1d14cd97..85dd8de7dd9a00 100644 --- a/src/transformers/models/big_bird/configuration_big_bird.py +++ b/src/transformers/models/big_bird/configuration_big_bird.py @@ -82,8 +82,6 @@ class BigBirdConfig(PretrainedConfig): num_random_blocks (:obj:`int`, `optional`, defaults to 3) Each query is going to attend these many number of random blocks. Useful only when :obj:`attention_type == "block_sparse"`. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. classifier_dropout (:obj:`float`, `optional`): The dropout ratio for the classification head. @@ -127,7 +125,6 @@ def __init__( rescale_embeddings=False, block_size=64, num_random_blocks=3, - gradient_checkpointing=False, classifier_dropout=None, **kwargs ): @@ -153,7 +150,6 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.use_cache = use_cache self.is_encoder_decoder = is_encoder_decoder - self.gradient_checkpointing = gradient_checkpointing self.rescale_embeddings = rescale_embeddings self.attention_type = attention_type diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index f7d0d857bc457f..84a428591e6900 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1555,6 +1555,7 @@ def __init__(self, config): self.layer = nn.ModuleList( [BigBirdLayer(config, seed=layer_idx) for layer_idx in range(config.num_hidden_layers)] ) + self.gradient_checkpointing = False def set_attention_type(self, value: str): if value not in ["original_full", "block_sparse"]: @@ -1598,12 +1599,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -1756,6 +1756,7 @@ class BigBirdPreTrainedModel(PreTrainedModel): config_class = BigBirdConfig load_tf_weights = load_tf_weights_in_big_bird base_model_prefix = "bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -1774,6 +1775,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BigBirdEncoder): + module.gradient_checkpointing = value + BIG_BIRD_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. 
Use diff --git a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py index 28211c9b164f8b..297e2cede4dafc 100644 --- a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py @@ -94,8 +94,6 @@ class BigBirdPegasusConfig(PretrainedConfig): "block_sparse"`. scale_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`) Whether to rescale embeddings with (hidden_size ** 0.5). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -141,7 +139,6 @@ def __init__( decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=2, eos_token_id=1, @@ -170,7 +167,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True # extra config diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 2fd765eb5dd05f..536cd784daaf50 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1567,6 +1567,7 @@ def forward(self, hidden_states: torch.Tensor): class BigBirdPegasusPreTrainedModel(PreTrainedModel): config_class = BigBirdPegasusConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -1579,6 +1580,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -1764,6 +1769,7 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -1894,7 +1900,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2054,6 +2060,7 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -2225,12 +2232,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." 
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/blenderbot/configuration_blenderbot.py b/src/transformers/models/blenderbot/configuration_blenderbot.py index c2b272af034eac..13acbdf699aad3 100644 --- a/src/transformers/models/blenderbot/configuration_blenderbot.py +++ b/src/transformers/models/blenderbot/configuration_blenderbot.py @@ -78,8 +78,6 @@ class BlenderbotConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ def __init__( decoder_start_token_id=1, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -155,7 +152,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index e6bc6f65714177..11e866594a3736 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -451,6 +451,7 @@ def forward( class BlenderbotPreTrainedModel(PreTrainedModel): config_class = BlenderbotConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -463,6 +464,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -644,6 +649,7 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -738,7 +744,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -808,6 +814,7 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -980,12 +987,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with 
`config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py index de8927a4ffe998..0f76e2e3ae0ea2 100644 --- a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py @@ -78,8 +78,6 @@ class BlenderbotSmallConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ def __init__( decoder_start_token_id=1, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -154,7 +151,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 81188488fe0a73..a15c8276c377aa 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -449,6 +449,7 @@ def forward( class BlenderbotSmallPreTrainedModel(PreTrainedModel): config_class = BlenderbotSmallConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -461,6 +462,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -645,6 +650,7 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -740,7 +746,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -808,6 +814,7 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -981,12 +988,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None 
else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/canine/configuration_canine.py b/src/transformers/models/canine/configuration_canine.py index 3feef5ac75beb8..79be54a8247b5e 100644 --- a/src/transformers/models/canine/configuration_canine.py +++ b/src/transformers/models/canine/configuration_canine.py @@ -61,8 +61,6 @@ class CanineConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. downsampling_rate (:obj:`int`, `optional`, defaults to 4): The rate at which to downsample the original character sequence length before applying the deep Transformer encoder. diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 18ca01031cc43c..a13505d3a05260 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -772,6 +772,7 @@ def __init__( for _ in range(config.num_hidden_layers) ] ) + self.gradient_checkpointing = False def forward( self, @@ -791,7 +792,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -895,6 +896,7 @@ class CaninePreTrainedModel(PreTrainedModel): config_class = CanineConfig load_tf_weights = load_tf_weights_in_canine base_model_prefix = "canine" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -913,6 +915,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CanineEncoder): + module.gradient_checkpointing = value + CANINE_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/clip/configuration_clip.py b/src/transformers/models/clip/configuration_clip.py index b824288711958c..0f8b6fa9a4de0e 100644 --- a/src/transformers/models/clip/configuration_clip.py +++ b/src/transformers/models/clip/configuration_clip.py @@ -68,8 +68,6 @@ class CLIPTextConfig(PretrainedConfig): initializer_factor (:obj:`float`, `optional`, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -103,7 +101,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - gradient_checkpointing=False, **kwargs ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -120,7 +117,6 @@ def __init__( self.initializer_range = initializer_range self.initializer_factor = initializer_factor self.attention_dropout = attention_dropout - self.gradient_checkpointing = gradient_checkpointing class CLIPVisionConfig(PretrainedConfig): @@ -161,8 +157,6 @@ class CLIPVisionConfig(PretrainedConfig): initializer_factor (:obj:`float`, `optional`, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -194,7 +188,6 @@ def __init__( attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, - gradient_checkpointing=False, **kwargs ): super().__init__(**kwargs) @@ -211,7 +204,6 @@ def __init__( self.attention_dropout = attention_dropout self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act - self.gradient_checkpointing = gradient_checkpointing class CLIPConfig(PretrainedConfig): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 8d723e05fc293c..4f3b280a1bc5b3 100755 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -338,6 +338,7 @@ class CLIPPreTrainedModel(PreTrainedModel): config_class = CLIPConfig base_model_prefix = "clip" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -383,6 +384,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CLIPEncoder): + module.gradient_checkpointing = value + CLIP_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. 
Use @@ -499,6 +504,7 @@ def __init__(self, config: CLIPConfig): super().__init__() self.config = config self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -551,7 +557,7 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index fbd0cdfc5ec7c6..99d8ae5dd4cf8c 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -248,6 +248,7 @@ class ConvBertPreTrainedModel(PreTrainedModel): config_class = ConvBertConfig load_tf_weights = load_tf_weights_in_convbert base_model_prefix = "convbert" + supports_gradient_checkpointing = True authorized_missing_keys = [r"position_ids"] authorized_unexpected_keys = [r"convbert\.embeddings_project\.weight", r"convbert\.embeddings_project\.bias"] @@ -267,6 +268,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ConvBertEncoder): + module.gradient_checkpointing = value + class SeparableConv1D(nn.Module): """This class implements separable convolution, i.e. a depthwise and a pointwise layer""" @@ -603,6 +608,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([ConvBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -624,7 +630,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/deit/configuration_deit.py b/src/transformers/models/deit/configuration_deit.py index 0bbbff709b83f7..98bbe1b01ba8f0 100644 --- a/src/transformers/models/deit/configuration_deit.py +++ b/src/transformers/models/deit/configuration_deit.py @@ -58,8 +58,6 @@ class DeiTConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. image_size (:obj:`int`, `optional`, defaults to :obj:`224`): The size (resolution) of each image. 
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index b848376817b157..6ffa6afa3a5bf7 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -324,6 +324,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -342,7 +343,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -384,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel): config_class = DeiTConfig base_model_prefix = "deit" + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -401,6 +403,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DeiTEncoder): + module.gradient_checkpointing = value + DEIT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 3061addadaed97..af650e75e1a6cc 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -783,6 +783,7 @@ def forward(self, hidden_states: torch.Tensor): class DetrPreTrainedModel(PreTrainedModel): config_class = DetrConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -807,6 +808,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DetrDecoder): + module.gradient_checkpointing = value + DETR_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -997,6 +1002,7 @@ def __init__(self, config: DetrConfig): self.layernorm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -1084,7 +1090,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/dpr/configuration_dpr.py b/src/transformers/models/dpr/configuration_dpr.py index 2773835f721cd7..a9b5f96556c783 100644 --- a/src/transformers/models/dpr/configuration_dpr.py +++ b/src/transformers/models/dpr/configuration_dpr.py @@ -69,8 +69,6 @@ class DPRConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on @@ -99,7 +97,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, position_embedding_type="absolute", projection_dim: int = 0, **kwargs @@ -118,6 +115,5 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.projection_dim = projection_dim self.position_embedding_type = position_embedding_type diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 37fa61b706513e..c1a3fa618d4eb1 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -30,7 +30,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import logging -from ..bert.modeling_bert import BertModel +from ..bert.modeling_bert import BertEncoder, BertModel from .configuration_dpr import DPRConfig @@ -300,6 +300,10 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel): def init_weights(self): self.question_encoder.init_weights() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + class DPRPretrainedReader(PreTrainedModel): """ @@ -317,6 +321,10 @@ def init_weights(self): self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights) self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + ############### # Actual Models diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 867a7c09151327..1f44b23522b775 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -527,6 +527,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -553,12 +554,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False @@ -663,6 +663,7 @@ class ElectraPreTrainedModel(PreTrainedModel): config_class = ElectraConfig load_tf_weights = load_tf_weights_in_electra base_model_prefix = "electra" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"] @@ -683,6 +684,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ElectraEncoder): + module.gradient_checkpointing = value + @dataclass class ElectraForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/fnet/configuration_fnet.py b/src/transformers/models/fnet/configuration_fnet.py index 047190a3ed24c0..a6922f835588f5 100644 --- a/src/transformers/models/fnet/configuration_fnet.py +++ b/src/transformers/models/fnet/configuration_fnet.py @@ -64,8 +64,6 @@ class FNetConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. use_tpu_fourier_optimizations (:obj:`bool`, `optional`, defaults to :obj:`False`): Determines whether to use TPU optimized FFTs. If :obj:`True`, the model will favor axis-wise FFTs transforms. Set to :obj:`False` for GPU/CPU hardware, in which case n-dimensional FFTs are used. diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 2a1b7f5f2ab25f..9340eb04f3c43a 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -284,6 +284,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward(self, hidden_states, output_hidden_states=False, return_dict=True): all_hidden_states = () if output_hidden_states else None @@ -292,7 +293,7 @@ def forward(self, hidden_states, output_hidden_states=False, return_dict=True): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -413,6 +414,7 @@ class FNetPreTrainedModel(PreTrainedModel): config_class = FNetConfig base_model_prefix = "fnet" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -432,6 +434,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, FNetEncoder): + module.gradient_checkpointing = value + @dataclass class FNetForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index f003023ca8b058..41120c94daad6c 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ b/src/transformers/models/gpt2/configuration_gpt2.py @@ -108,9 +108,7 @@ class 
GPT2Config(PretrainedConfig): The dropout ratio to be used after the projection and activation. scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`): - Scale attention weights by dividing by sqrt(hidden_size). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. + Scale attention weights by dividing by sqrt(hidden_size).. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). @@ -158,7 +156,6 @@ def __init__( summary_proj_to_labels=True, summary_first_dropout=0.1, scale_attn_weights=True, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, @@ -182,7 +179,6 @@ def __init__( self.summary_activation = summary_activation self.summary_first_dropout = summary_first_dropout self.summary_proj_to_labels = summary_proj_to_labels - self.gradient_checkpointing = gradient_checkpointing self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 43419e66151d95..d6fab7f7ff24b2 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -374,6 +374,7 @@ class GPT2PreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_gpt2 base_model_prefix = "transformer" is_parallelizable = True + supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -394,6 +395,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPT2Model): + module.gradient_checkpointing = value + @dataclass class GPT2DoubleHeadsModelOutput(ModelOutput): @@ -589,6 +594,7 @@ def __init__(self, config): # Model parallel self.model_parallel = False self.device_map = None + self.gradient_checkpointing = False @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -764,12 +770,11 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py index e5b7e683d99adb..d5069fb017112f 100644 --- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py +++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py @@ -79,8 +79,6 @@ class GPTNeoConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -120,7 +118,6 @@ def __init__( summary_activation=None, summary_proj_to_labels=True, summary_first_dropout=0.1, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, @@ -144,7 +141,6 @@ def __init__( self.summary_activation = summary_activation self.summary_first_dropout = summary_first_dropout self.summary_proj_to_labels = summary_proj_to_labels - self.gradient_checkpointing = gradient_checkpointing self.use_cache = use_cache self.bos_token_id = bos_token_id diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 353d3b0fb6cec6..3fafd75ac21a30 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -361,6 +361,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): config_class = GPTNeoConfig load_tf_weights = load_tf_weights_in_gpt_neo base_model_prefix = "transformer" + supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -381,6 +382,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPTNeoModel): + module.gradient_checkpointing = value + GPT_NEO_START_DOCSTRING = r""" @@ -482,6 +487,7 @@ def __init__(self, config): self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.wte @@ -592,12 +598,11 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py index 93018fdcb60b54..61dfd4e6639386 100644 --- a/src/transformers/models/gptj/configuration_gptj.py +++ b/src/transformers/models/gptj/configuration_gptj.py @@ -68,8 +68,6 @@ class GPTJConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`): Scale attention weights by dividing by sqrt(hidden_size). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). 
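# Illustrative sketch (not part of the diff above): the GPT-2/GPT-Neo/GPT-J hunks all
# follow the same pattern. The flag moves off the config and onto the module
# (`self.gradient_checkpointing = False`), the pretrained class advertises
# `supports_gradient_checkpointing = True` and implements `_set_gradient_checkpointing`,
# and the forward pass gates `torch.utils.checkpoint` on
# `self.gradient_checkpointing and self.training`. The class and layer names below are
# hypothetical stand-ins, not real transformers classes.
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyEncoder(nn.Module):
    def __init__(self, num_layers=2, hidden_size=8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        # Module-level flag, flipped by the model's `_set_gradient_checkpointing` hook.
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Same closure trick as in the diffs: recompute activations on backward
                # instead of storing them, trading compute for memory.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = checkpoint(create_custom_forward(layer), hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states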
@@ -111,7 +109,6 @@ def __init__( layer_norm_epsilon=1e-5, initializer_range=0.02, scale_attn_weights=True, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, @@ -131,7 +128,6 @@ def __init__( self.attn_pdrop = attn_pdrop self.layer_norm_epsilon = layer_norm_epsilon self.initializer_range = initializer_range - self.gradient_checkpointing = gradient_checkpointing self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 2d7781a2758be3..a23da0834711b8 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -303,6 +303,7 @@ class GPTJPreTrainedModel(PreTrainedModel): config_class = GPTJConfig base_model_prefix = "transformer" is_parallelizable = True + supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -323,6 +324,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPTJModel): + module.gradient_checkpointing = value + GPTJ_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use @@ -445,6 +450,7 @@ def __init__(self, config): # Model parallel self.model_parallel = False self.device_map = None + self.gradient_checkpointing = False @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -598,12 +604,11 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py index 633807684fccac..624211431c55d5 100644 --- a/src/transformers/models/hubert/configuration_hubert.py +++ b/src/transformers/models/hubert/configuration_hubert.py @@ -120,8 +120,6 @@ class HubertConfig(PretrainedConfig): instance of :class:`~transformers.HubertForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -172,7 +170,6 @@ def __init__( ctc_zero_infinity=False, use_weighted_layer_sum=False, classifier_proj_size=256, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -203,7 +200,6 @@ def __init__( self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm - self.gradient_checkpointing = gradient_checkpointing self.use_weighted_layer_sum = use_weighted_layer_sum self.classifier_proj_size = classifier_proj_size diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 95d5c91f5ae837..6575f4932b9114 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -525,6 +525,7 @@ def __init__(self, config): self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -564,7 +565,7 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -612,6 +613,7 @@ def __init__(self, config): self.layers = nn.ModuleList( [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) + self.gradient_checkpointing = False def forward( self, @@ -651,7 +653,7 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -698,6 +700,7 @@ class HubertPreTrainedModel(PreTrainedModel): config_class = HubertConfig base_model_prefix = "hubert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -725,6 +728,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)): + module.gradient_checkpointing = value + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the convolutional layers diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 61c775d9fff8bc..d4f74ff47e7dcd 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -579,17 +579,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - raise NotImplementedError("gradient checkpointing is not currently supported") - - else: - 
layer_outputs = layer_module( - hidden_states, - hidden_states_scaling_factor, - attention_mask, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + hidden_states_scaling_factor, + attention_mask, + layer_head_mask, + output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/layoutlm/configuration_layoutlm.py b/src/transformers/models/layoutlm/configuration_layoutlm.py index 4c23cde9a5b862..61a6ce264d0d5c 100644 --- a/src/transformers/models/layoutlm/configuration_layoutlm.py +++ b/src/transformers/models/layoutlm/configuration_layoutlm.py @@ -71,8 +71,6 @@ class LayoutLMConfig(BertConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. max_2d_position_embeddings (:obj:`int`, `optional`, defaults to 1024): The maximum value that the 2D position embedding might ever used. Typically set this to something large just in case (e.g., 1024). @@ -108,7 +106,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, max_2d_position_embeddings=1024, **kwargs ): @@ -126,7 +123,6 @@ def __init__( initializer_range=initializer_range, layer_norm_eps=layer_norm_eps, pad_token_id=pad_token_id, - gradient_checkpointing=gradient_checkpointing, **kwargs, ) self.max_2d_position_embeddings = max_2d_position_embeddings diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 3e7dfe8560c745..b47d2793d141db 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -442,6 +442,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -468,12 +469,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False @@ -609,6 +609,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel): config_class = LayoutLMConfig pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST base_model_prefix = "layoutlm" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -627,6 +628,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LayoutLMEncoder): + module.gradient_checkpointing = value + LAYOUTLM_START_DOCSTRING = r""" The LayoutLM model was proposed in `LayoutLM: Pre-training of Text and Layout for Document Image Understanding diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index e42d77bab26217..6c42ce1ccc9a50 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -378,6 +378,8 @@ def __init__(self, config): self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) + self.gradient_checkpointing = False + def _calculate_1d_position_embeddings(self, hidden_states, position_ids): rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) rel_pos = relative_position_bucket( @@ -443,7 +445,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -502,6 +504,7 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): config_class = LayoutLMv2Config pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST base_model_prefix = "layoutlmv2" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -520,6 +523,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LayoutLMv2Encoder): + module.gradient_checkpointing = value + def my_convert_sync_batchnorm(module, process_group=None): # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d` diff --git a/src/transformers/models/led/configuration_led.py b/src/transformers/models/led/configuration_led.py index 5992d275ed9f12..e30c3e04c4f5c6 100644 --- a/src/transformers/models/led/configuration_led.py +++ b/src/transformers/models/led/configuration_led.py @@ -82,8 +82,6 @@ class LEDConfig(PretrainedConfig): https://arxiv.org/abs/1909.11556>`__ for more details. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models) - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -132,7 +130,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - gradient_checkpointing=False, attention_window: Union[List[int], int] = 512, **kwargs ): @@ -157,7 +154,6 @@ def __init__( self.use_cache = use_cache self.num_hidden_layers = encoder_layers self.attention_window = attention_window - self.gradient_checkpointing = gradient_checkpointing super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index c1c5af6d1ec191..926da161a97d1d 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1077,6 +1077,7 @@ def forward(self, hidden_states: torch.Tensor): class LEDPreTrainedModel(PreTrainedModel): config_class = LEDConfig base_model_prefix = "led" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -1089,6 +1090,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LEDDecoder, LEDEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -1625,6 +1630,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) @@ -1809,7 +1815,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1894,6 +1900,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -2061,12 +2068,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 6fbdfb12f57c1a..3e327c5c688e32 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1231,6 +1231,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -1259,7 +1260,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1363,6 +1364,7 @@ class LongformerPreTrainedModel(PreTrainedModel): config_class = LongformerConfig base_model_prefix = "longformer" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -1381,6 +1383,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LongformerEncoder): + module.gradient_checkpointing = value + LONGFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/luke/configuration_luke.py b/src/transformers/models/luke/configuration_luke.py index befd3e45e5de65..ba6dc496438648 100644 --- a/src/transformers/models/luke/configuration_luke.py +++ b/src/transformers/models/luke/configuration_luke.py @@ -68,8 +68,6 @@ class LukeConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. use_entity_aware_attention (:obj:`bool`, defaults to :obj:`True`): Whether or not the model should use the entity-aware self-attention mechanism proposed in `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.) 
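# Illustrative sketch (not part of the diff above) of how the new
# `supports_gradient_checkpointing` / `_set_gradient_checkpointing` hooks are meant to be
# driven. The real plumbing lives in `src/transformers/modeling_utils.py` (also touched by
# this patch but not shown in this chunk); the class below is a simplified stand-in, not
# the actual implementation.
from functools import partial

from torch import nn


class ToyPreTrainedModel(nn.Module):
    supports_gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Each model overrides this, e.g. LongformerPreTrainedModel toggles
        # LongformerEncoder.gradient_checkpointing in the hunk above.
        raise NotImplementedError

    def gradient_checkpointing_enable(self):
        # Called by the user (or by the Trainer when the new `gradient_checkpointing`
        # training argument is set) instead of mutating the config.
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def gradient_checkpointing_disable(self):
        if self.supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))


# Usage (hypothetical subclass):
#   class ToyModel(ToyPreTrainedModel):
#       supports_gradient_checkpointing = True
#       def _set_gradient_checkpointing(self, module, value=False):
#           if isinstance(module, ToyEncoder):
#               module.gradient_checkpointing = value
#   ToyModel().gradient_checkpointing_enable()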
@@ -106,7 +104,6 @@ def __init__( type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, - gradient_checkpointing=False, use_entity_aware_attention=True, pad_token_id=1, bos_token_id=0, @@ -130,5 +127,4 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.use_entity_aware_attention = use_entity_aware_attention diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index b9004c1d4970b6..97d1f1adfd9c5e 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -579,6 +579,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([LukeLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -600,7 +601,7 @@ def forward( all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -681,6 +682,7 @@ class LukePreTrainedModel(PreTrainedModel): config_class = LukeConfig base_model_prefix = "luke" + supports_gradient_checkpointing = True def _init_weights(self, module: nn.Module): """Initialize the weights""" @@ -699,6 +701,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LukeEncoder): + module.gradient_checkpointing = value + LUKE_START_DOCSTRING = r""" diff --git a/src/transformers/models/m2m_100/configuration_m2m_100.py b/src/transformers/models/m2m_100/configuration_m2m_100.py index 765bcb4cd1b48b..a4a0df749c29a4 100644 --- a/src/transformers/models/m2m_100/configuration_m2m_100.py +++ b/src/transformers/models/m2m_100/configuration_m2m_100.py @@ -79,8 +79,6 @@ class M2M100Config(PretrainedConfig): https://arxiv.org/abs/1909.11556>`__ for more details. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -121,7 +119,6 @@ def __init__( init_std=0.02, decoder_start_token_id=2, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -145,7 +142,6 @@ def __init__( self.decoder_layerdrop = decoder_layerdrop self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index ce86fe2c775ad5..9bb15c918a8d8a 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -520,6 +520,7 @@ def forward( class M2M100PreTrainedModel(PreTrainedModel): config_class = M2M100Config base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -532,6 +533,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (M2M100Decoder, M2M100Encoder)): + module.gradient_checkpointing = value + M2M_100_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -693,6 +698,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -787,7 +793,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -857,6 +863,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -1013,12 +1020,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py index 1b974badfa6679..825c7d707a737e 100644 --- a/src/transformers/models/marian/configuration_marian.py +++ b/src/transformers/models/marian/configuration_marian.py @@ -78,8 +78,6 @@ class MarianConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). 
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ def __init__( decoder_start_token_id=58100, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=58100, eos_token_id=0, forced_eos_token_id=0, @@ -153,7 +150,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index e2feb549b70f2a..a2df6373503843 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -466,6 +466,7 @@ def forward( class MarianPreTrainedModel(PreTrainedModel): config_class = MarianConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -480,6 +481,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (MarianDecoder, MarianEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -656,6 +661,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -750,7 +756,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -816,6 +822,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -987,12 +994,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index 05857241b4baab..d1eb27c0e808bc 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -82,8 +82,6 @@ class MBartConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. 
- gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -131,7 +129,6 @@ def __init__( init_std=0.02, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -157,7 +154,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 0412eccaaab7af..0ebb5a1a8f348f 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -479,6 +479,7 @@ def forward(self, hidden_states: torch.Tensor): class MBartPreTrainedModel(PreTrainedModel): config_class = MBartConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -491,6 +492,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (MBartDecoder, MBartEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -685,6 +690,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -780,7 +786,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -850,6 +856,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1022,12 +1029,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
) use_cache = False diff --git a/src/transformers/models/megatron_bert/configuration_megatron_bert.py b/src/transformers/models/megatron_bert/configuration_megatron_bert.py index 19171e70da1bc2..d6e32cd4963095 100644 --- a/src/transformers/models/megatron_bert/configuration_megatron_bert.py +++ b/src/transformers/models/megatron_bert/configuration_megatron_bert.py @@ -65,8 +65,6 @@ class MegatronBertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on @@ -108,7 +106,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, **kwargs @@ -127,6 +124,5 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache diff --git a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py index 3d7f03dcbb767c..1d33ef91e624dd 100644 --- a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py +++ b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py @@ -180,7 +180,6 @@ def convert_megatron_checkpoint(args, input_state_dict): "type_vocab_size": 2, "initializer_range": 0.2, "layer_norm_eps": 1e-12, - "gradient_checkpointing": False, "position_embedding_type": "absolute", "use_cache": False, } diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 3c49ea88b87342..80337b2dabf9a2 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -508,6 +508,7 @@ def __init__(self, config): # The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one # is simply the final LN (Transformer's BERT has it attached to each hidden layer). self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False def forward( self, @@ -534,12 +535,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warn( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False @@ -705,6 +705,7 @@ class MegatronBertPreTrainedModel(PreTrainedModel): config_class = MegatronBertConfig load_tf_weights = load_tf_weights_in_megatron_bert base_model_prefix = "bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -719,6 +720,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MegatronBertEncoder): + module.gradient_checkpointing = value + @dataclass # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert diff --git a/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py b/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py index 72271885bb44bb..57d42a117175a2 100644 --- a/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py +++ b/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py @@ -279,7 +279,6 @@ def main(): summary_proj_to_labels=True, summary_first_dropout=0.1, scale_attn_weights=True, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, diff --git a/src/transformers/models/pegasus/configuration_pegasus.py b/src/transformers/models/pegasus/configuration_pegasus.py index 2e815c2e486b59..8cf76c482bc161 100644 --- a/src/transformers/models/pegasus/configuration_pegasus.py +++ b/src/transformers/models/pegasus/configuration_pegasus.py @@ -78,8 +78,6 @@ class PegasusConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). 
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ def __init__( decoder_start_token_id=0, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=0, eos_token_id=1, forced_eos_token_id=1, @@ -153,7 +150,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index ab1009b3393776..2728f144b352b9 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -466,6 +466,7 @@ def forward( class PegasusPreTrainedModel(PreTrainedModel): config_class = PegasusConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -480,6 +481,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (PegasusDecoder, PegasusEncoder)): + module.gradient_checkpointing = value + PEGASUS_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -646,6 +651,7 @@ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def resize_position_embeddings(self, new_num_position_embeddings: int): """ @@ -770,7 +776,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -840,6 +846,7 @@ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1040,12 +1047,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/prophetnet/configuration_prophetnet.py b/src/transformers/models/prophetnet/configuration_prophetnet.py index c19e4a106f2b85..074bad3e24d84b 100644 --- a/src/transformers/models/prophetnet/configuration_prophetnet.py +++ b/src/transformers/models/prophetnet/configuration_prophetnet.py @@ -92,8 +92,6 @@ class ProphetNetConfig(PretrainedConfig): smoothing is performed. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). 
- gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. """ model_type = "prophetnet" keys_to_ignore_at_inference = ["past_key_values"] @@ -124,7 +122,6 @@ def __init__( num_buckets=32, relative_max_distance=128, disable_ngram_loss=False, - gradient_checkpointing=False, eps=0.0, use_cache=True, pad_token_id=0, @@ -158,9 +155,6 @@ def __init__( self.use_cache = use_cache - # 4 Training Args (should be removed soon) - self.gradient_checkpointing = gradient_checkpointing - super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index ed4c7926574977..9f72a35f0dfd27 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -547,6 +547,7 @@ class ProphetNetDecoderLMOutput(ModelOutput): class ProphetNetPreTrainedModel(PreTrainedModel): config_class = ProphetNetConfig base_model_prefix = "prophetnet" + supports_gradient_checkpointing = True def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -558,6 +559,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)): + module.gradient_checkpointing = value + def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -1262,6 +1267,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = Non self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.word_embeddings @@ -1337,7 +1343,7 @@ def forward( if output_hidden_states: encoder_hidden_states = encoder_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1406,6 +1412,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = Non self.embeddings_layer_norm = LayerNorm(config.hidden_size) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.word_embeddings @@ -1566,12 +1573,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False diff --git a/src/transformers/models/rembert/configuration_rembert.py b/src/transformers/models/rembert/configuration_rembert.py index d9432d20a985ab..51c899dfc9856b 100644 --- a/src/transformers/models/rembert/configuration_rembert.py +++ b/src/transformers/models/rembert/configuration_rembert.py @@ -76,8 +76,6 @@ class RemBertConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 46524ce9cbb8b6..ab3874865afc9c 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -501,6 +501,7 @@ def __init__(self, config): self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size) self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -528,12 +529,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warn( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -648,6 +648,7 @@ class RemBertPreTrainedModel(PreTrainedModel): config_class = RemBertConfig load_tf_weights = load_tf_weights_in_rembert base_model_prefix = "rembert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -666,6 +667,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RemBertEncoder): + module.gradient_checkpointing = value + REMBERT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 09472da7674a7f..f74954ac6428b4 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -469,6 +469,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -495,12 +496,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." 
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -585,6 +585,7 @@ class RobertaPreTrainedModel(PreTrainedModel): config_class = RobertaConfig base_model_prefix = "roberta" + supports_gradient_checkpointing = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -603,6 +604,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RobertaEncoder): + module.gradient_checkpointing = value + def update_keys_to_ignore(self, config, del_keys_to_ignore): """Remove some keys from ignore list""" if not config.tie_word_embeddings: diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py index 945d1064a10ea8..5027b3be1fb8b1 100644 --- a/src/transformers/models/roformer/configuration_roformer.py +++ b/src/transformers/models/roformer/configuration_roformer.py @@ -80,8 +80,6 @@ class RoFormerConfig(PretrainedConfig): relevant if ``config.is_decoder=True``. rotary_value (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not apply rotary position embeddings on value layer. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -114,7 +112,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, rotary_value=False, use_cache=True, **kwargs @@ -134,6 +131,5 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.rotary_value = rotary_value self.use_cache = use_cache diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index f08d3e5c8f6ac5..23929a4c613168 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -551,6 +551,7 @@ def __init__(self, config): config.max_position_embeddings, config.hidden_size // config.num_attention_heads ) self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -580,12 +581,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False @@ -705,6 +705,7 @@ class RoFormerPreTrainedModel(PreTrainedModel): config_class = RoFormerConfig load_tf_weights = load_tf_weights_in_roformer base_model_prefix = "roformer" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [] _keys_to_ignore_on_load_unexpected = [ r"roformer\.embeddings_project\.weight", @@ -729,6 +730,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RoFormerEncoder): + module.gradient_checkpointing = value + ROFORMER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/speech_to_text/configuration_speech_to_text.py b/src/transformers/models/speech_to_text/configuration_speech_to_text.py index ff16601030dbb9..821362d2e636b2 100644 --- a/src/transformers/models/speech_to_text/configuration_speech_to_text.py +++ b/src/transformers/models/speech_to_text/configuration_speech_to_text.py @@ -134,7 +134,6 @@ def __init__( decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -165,7 +164,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.max_source_positions = max_source_positions self.max_target_positions = max_target_positions diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index ce19d680abf70b..e91af884c6c5eb 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -531,6 +531,7 @@ def forward( class Speech2TextPreTrainedModel(PreTrainedModel): config_class = Speech2TextConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -543,6 +544,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)): + module.gradient_checkpointing = value + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers @@ -711,6 +716,7 @@ def __init__(self, config: Speech2TextConfig): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -795,7 +801,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -863,6 +869,7 @@ def __init__(self, config: Speech2TextConfig): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1032,11 +1039,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, 
"gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." ) use_cache = False diff --git a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py index f1f950599058fe..abeac09a105dc2 100644 --- a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py @@ -108,7 +108,6 @@ def __init__( decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -130,7 +129,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = decoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.max_source_positions = max_source_positions self.max_target_positions = max_target_positions diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index 848df757c3b5f3..fbbbaa3cbf9da0 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -407,6 +407,7 @@ def forward( class Speech2Text2PreTrainedModel(PreTrainedModel): config_class = Speech2Text2Config base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -419,6 +420,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, Speech2Text2Decoder): + module.gradient_checkpointing = value + SPEECH_TO_TEXT_2_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -465,6 +470,7 @@ def __init__(self, config: Speech2Text2Config): self.layers = nn.ModuleList([Speech2Text2DecoderLayer(config) for _ in range(config.decoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -635,11 +641,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." 
) use_cache = False diff --git a/src/transformers/models/splinter/configuration_splinter.py b/src/transformers/models/splinter/configuration_splinter.py index 879451bbe50b65..986e436fe75702 100644 --- a/src/transformers/models/splinter/configuration_splinter.py +++ b/src/transformers/models/splinter/configuration_splinter.py @@ -71,8 +71,6 @@ class SplinterConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. question_token_id (:obj:`int`, `optional`, defaults to 104): The id of the ``[QUESTION]`` token. diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 1296db12508ddf..381a280ebb2282 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -409,6 +409,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -435,12 +436,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -509,6 +509,7 @@ class SplinterPreTrainedModel(PreTrainedModel): config_class = SplinterConfig base_model_prefix = "splinter" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights @@ -528,6 +529,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, SplinterEncoder): + module.gradient_checkpointing = value + SPLINTER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index 9a406591279e2f..bb16a5fb0f50b0 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -77,8 +77,6 @@ class T5Config(PretrainedConfig): the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
""" model_type = "t5" keys_to_ignore_at_inference = ["past_key_values"] @@ -102,7 +100,6 @@ def __init__( use_cache=True, pad_token_id=0, eos_token_id=1, - gradient_checkpointing=False, **kwargs ): self.vocab_size = vocab_size @@ -120,7 +117,6 @@ def __init__( self.initializer_factor = initializer_factor self.feed_forward_proj = feed_forward_proj self.use_cache = use_cache - self.gradient_checkpointing = gradient_checkpointing super().__init__( pad_token_id=pad_token_id, eos_token_id=eos_token_id, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 27ef440bfb152c..f18c9e66f5e60a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -325,7 +325,7 @@ def __init__(self, config: T5Config, has_relative_attention_bias=False): if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() - self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + self.gradient_checkpointing = False def prune_heads(self, heads): if len(heads) == 0: @@ -489,7 +489,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias = torch.zeros( (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype ) - if self.training and self.gradient_checkpointing: + if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: position_bias = self.compute_bias(real_seq_length, key_length) @@ -715,6 +715,7 @@ class T5PreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_t5 base_model_prefix = "transformer" is_parallelizable = True + supports_gradient_checkpointing = True @property def dummy_inputs(self): @@ -769,6 +770,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (T5Attention, T5Stack)): + module.gradient_checkpointing = value + def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -813,6 +818,7 @@ def __init__(self, config, embed_tokens=None): # Model parallel self.model_parallel = False self.device_map = None + self.gradient_checkpointing = False @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -968,11 +974,10 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warn( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/tapas/configuration_tapas.py b/src/transformers/models/tapas/configuration_tapas.py index 834cae0c7ea60c..d59dc00f4515f0 100644 --- a/src/transformers/models/tapas/configuration_tapas.py +++ b/src/transformers/models/tapas/configuration_tapas.py @@ -73,8 +73,6 @@ class TapasConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to use gradient checkpointing to save memory at the expense of a slower backward pass. positive_label_weight (:obj:`float`, `optional`, defaults to 10.0): Weight for positive labels. num_aggregation_labels (:obj:`int`, `optional`, defaults to 0): @@ -159,7 +157,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, positive_label_weight=10.0, num_aggregation_labels=0, aggregation_loss_weight=1.0, @@ -202,7 +199,6 @@ def __init__( self.type_vocab_sizes = type_vocab_sizes self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing # Fine-tuning task hyperparameters self.positive_label_weight = positive_label_weight diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 29d4a3ef4f344b..9506216522dcb2 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -627,6 +627,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([TapasLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -649,7 +650,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -763,6 +764,7 @@ class TapasPreTrainedModel(PreTrainedModel): config_class = TapasConfig base_model_prefix = "tapas" + supports_gradient_checkpointing = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -781,6 +783,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, TapasEncoder): + module.gradient_checkpointing = value + TAPAS_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. 
Check the superclass documentation for the generic diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 21f5e01362ce8f..c6c01010081ece 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -398,6 +398,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -417,7 +418,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -532,6 +533,7 @@ class VisualBertPreTrainedModel(PreTrainedModel): config_class = VisualBertConfig base_model_prefix = "visual_bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -547,6 +549,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, VisualBertEncoder): + module.gradient_checkpointing = value + @dataclass class VisualBertForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/vit/configuration_vit.py b/src/transformers/models/vit/configuration_vit.py index 5e53df4cddfd7d..9c64be5141bcbb 100644 --- a/src/transformers/models/vit/configuration_vit.py +++ b/src/transformers/models/vit/configuration_vit.py @@ -57,8 +57,6 @@ class ViTConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. image_size (:obj:`int`, `optional`, defaults to :obj:`224`): The size (resolution) of each image. 
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 5b147f2856324f..78911f7b4186d1 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -352,6 +352,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -370,7 +371,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -411,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel): config_class = ViTConfig base_model_prefix = "vit" + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -428,6 +430,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ViTEncoder): + module.gradient_checkpointing = value + VIT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index d82e6a6d3457c3..49818feb22df7d 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -138,8 +138,6 @@ class Wav2Vec2Config(PretrainedConfig): instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -198,7 +196,6 @@ def __init__( ctc_zero_infinity=False, use_weighted_layer_sum=False, classifier_proj_size=256, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -229,7 +226,6 @@ def __init__( self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm - self.gradient_checkpointing = gradient_checkpointing self.use_weighted_layer_sum = use_weighted_layer_sum self.classifier_proj_size = classifier_proj_size diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ade54417f1dff9..71f431ca976d51 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -590,6 +590,7 @@ def __init__(self, config): self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -629,7 +630,7 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -676,6 +677,7 @@ def __init__(self, config): self.layers = nn.ModuleList( [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) + self.gradient_checkpointing = False def forward( self, @@ -715,7 +717,7 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -842,6 +844,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): config_class = Wav2Vec2Config base_model_prefix = "wav2vec2" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -864,6 +867,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm)): + module.gradient_checkpointing = value + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the convolutional layers diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f5aa74616c47e5..d39a24bf46cb21 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -990,7 +990,7 @@ def _wrap_model(self, model, training=True): elif isinstance(model, PreTrainedModel): # find_unused_parameters breaks checkpointing as per # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - find_unused_parameters = not getattr(model.config, "gradient_checkpointing", 
False) + find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False) else: find_unused_parameters = True model = nn.parallel.DistributedDataParallel( @@ -1162,6 +1162,10 @@ def train( self.state = TrainerState() self.state.is_hyper_param_search = trial is not None + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + model = self._wrap_model(self.model_wrapped) # for the rest of this function `model` is the outside model, whether it was wrapped or not diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index d34622abc03e9d..ce330a254c412b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -372,6 +372,8 @@ class TrainingArguments: hub_token (:obj:`str`, `optional`): The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with :obj:`huggingface-cli login`. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. """ output_dir: str = field( @@ -650,6 +652,12 @@ class TrainingArguments: metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."}, ) hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) + gradient_checkpointing: bool = field( + default=False, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) # Deprecated arguments push_to_hub_model_id: str = field( default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py index 93da35a5d98be6..6978a3ddf3fd70 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py @@ -72,8 +72,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. {% else -%} vocab_size (:obj:`int`, `optional`, defaults to 50265): Vocabulary size of the {{cookiecutter.modelname}} model. 
Defines the number of different tokens that can be represented by the @@ -186,7 +184,6 @@ def __init__( decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, {% endif -%} pad_token_id=1, bos_token_id=0, @@ -225,7 +222,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True {% endif -%} diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 835382396cd51f..b0482f70621278 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -513,6 +513,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([{{cookiecutter.camelcase_modelname}}Layer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -539,12 +540,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -664,6 +664,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): config_class = {{cookiecutter.camelcase_modelname}}Config load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}} base_model_prefix = "{{cookiecutter.lowercase_modelname}}" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -682,6 +683,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder): + module.gradient_checkpointing = value + {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. 
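The hunks above and below roll the same pattern out to the model templates: a `supports_gradient_checkpointing` class flag, a `_set_gradient_checkpointing` hook, and a per-module `gradient_checkpointing` attribute checked at forward time. As a minimal sketch of how the resulting public API is meant to be used (the "gpt2" checkpoint and "out" directory are illustrative placeholders, not part of this patch):

    # Sketch only: exercises the toggles introduced by this PR; "gpt2" and "out" are placeholders.
    from transformers import GPT2LMHeadModel, TrainingArguments

    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Model-level toggle, replacing the old `config.gradient_checkpointing = True`:
    model.gradient_checkpointing_enable()   # propagates `gradient_checkpointing = True` via `_set_gradient_checkpointing`
    model.gradient_checkpointing_disable()  # switches it back off

    # Trainer route: with this flag set, `Trainer.train()` calls
    # `model.gradient_checkpointing_enable()` before wrapping the model.
    args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
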
@@ -2006,6 +2011,7 @@ def forward(self, hidden_states: torch.Tensor): class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): config_class = {{cookiecutter.camelcase_modelname}}Config base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -2017,16 +2023,10 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - - @property - def dummy_inputs(self): - pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) - dummy_inputs = { - "attention_mask": input_ids.ne(pad_token), - "input_ids": input_ids, - } - return dummy_inputs + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)): + module.gradient_checkpointing = value {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2213,6 +2213,7 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tok self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -2309,7 +2310,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2376,6 +2377,7 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tok self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -2545,10 +2547,10 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: - logger.warning("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...") + logger.warning("`use_cache = True` is incompatible with gradient checkpointing. 
Setting `use_cache = False`...") use_cache = False def create_custom_forward(module): diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index b4c670356f5c9f..6557936d59b0d7 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -224,6 +224,27 @@ def test_training(self): loss = model(**inputs).loss loss.backward() + def test_training_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.model_tester.is_training: + return + + config.use_cache = False + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: + continue + # we don't test BeitForMaskedImageModeling + if model_class.__name__ == "BeitForMaskedImageModeling": + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f1a11871b0bf0b..b61cf834fbd90d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -370,15 +370,14 @@ def test_training(self): def test_training_gradient_checkpointing(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"): + if not self.model_tester.is_training: return - config.gradient_checkpointing = True config.use_cache = False config.return_dict = True for model_class in self.all_model_classes: - if model_class in get_values(MODEL_MAPPING): + if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: continue model = model_class(config) model.to(torch_device) diff --git a/tests/test_modeling_deit.py b/tests/test_modeling_deit.py index c689a90af785fa..119e0988916b0e 100644 --- a/tests/test_modeling_deit.py +++ b/tests/test_modeling_deit.py @@ -20,6 +20,7 @@ from transformers import DeiTConfig from transformers.file_utils import cached_property, is_torch_available, is_vision_available +from transformers.models.auto import get_values from transformers.testing_utils import require_torch, require_vision, slow, torch_device from .test_configuration_common import ConfigTester @@ -340,7 +341,7 @@ def test_training(self): for model_class in self.all_model_classes: # DeiTForImageClassificationWithTeacher supports inference-only if ( - model_class in MODEL_MAPPING.values() + model_class in get_values(MODEL_MAPPING) or model_class.__name__ == "DeiTForImageClassificationWithTeacher" ): continue @@ -351,6 +352,27 @@ def test_training(self): loss = model(**inputs).loss loss.backward() + def test_training_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.model_tester.is_training: + return + + config.use_cache = False + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: + continue + # DeiTForImageClassificationWithTeacher supports inference-only + if model_class.__name__ == "DeiTForImageClassificationWithTeacher": + continue + model = model_class(config) + model.to(torch_device) + 
model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) diff --git a/tests/test_modeling_flax_gpt2.py b/tests/test_modeling_flax_gpt2.py index 0c793ebd27b7d6..3b2e43680e601b 100644 --- a/tests/test_modeling_flax_gpt2.py +++ b/tests/test_modeling_flax_gpt2.py @@ -82,7 +82,7 @@ def __init__( self.eos_token_id = vocab_size - 1 self.pad_token_id = vocab_size - 1 - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -100,7 +100,6 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, ) return (config, input_ids, input_mask) diff --git a/tests/test_modeling_flax_gpt_neo.py b/tests/test_modeling_flax_gpt_neo.py index 2916bec5b94f12..7d0d832295a461 100644 --- a/tests/test_modeling_flax_gpt_neo.py +++ b/tests/test_modeling_flax_gpt_neo.py @@ -86,7 +86,7 @@ def __init__( self.eos_token_id = vocab_size - 1 self.pad_token_id = vocab_size - 1 - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -105,7 +105,6 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): pad_token_id=self.pad_token_id, window_size=self.window_size, attention_types=self.attention_types, - gradient_checkpointing=gradient_checkpointing, ) return (config, input_ids, input_mask) diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 91d2edcdc8382a..214a17f0508ffa 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -96,7 +96,7 @@ def __init__( def get_large_model_config(self): return GPT2Config.from_pretrained("gpt2") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -119,7 +119,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config(gradient_checkpointing=gradient_checkpointing) + config = self.get_config() head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -135,7 +135,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): choice_labels, ) - def get_config(self, gradient_checkpointing=False): + def get_config(self): return GPT2Config( vocab_size=self.vocab_size, n_embd=self.hidden_size, @@ -149,11 +149,10 @@ def get_config(self, gradient_checkpointing=False): n_ctx=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range, - use_cache=not gradient_checkpointing, + use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, ) def prepare_config_and_inputs_for_decoder(self): @@ -322,9 +321,13 @@ def 
create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mas self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): model = GPT2LMHeadModel(config) model.to(torch_device) + if gradient_checkpointing: + model.gradient_checkpointing_enable() result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) self.parent.assertEqual(result.loss.shape, ()) @@ -478,8 +481,8 @@ def test_gpt2_token_classification_model(self): self.model_tester.create_and_check_gpt2_for_token_classification(*config_and_inputs) def test_gpt2_gradient_checkpointing(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) - self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) @slow def test_batch_generation(self): @@ -612,7 +615,11 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): @slow def test_lm_generate_gpt2(self): for checkpointing in [True, False]: - model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing) + model = GPT2LMHeadModel.from_pretrained("gpt2") + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() model.to(torch_device) input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog expected_output_ids = [ diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py index fa1b63b4f616cc..a8e5b4babc57d6 100644 --- a/tests/test_modeling_gpt_neo.py +++ b/tests/test_modeling_gpt_neo.py @@ -97,7 +97,7 @@ def __init__( def get_large_model_config(self): return GPTNeoConfig.from_pretrained("gpt_neo") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -120,7 +120,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config(gradient_checkpointing=False) + config = self.get_config() head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -136,18 +136,17 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): choice_labels, ) - def get_config(self, gradient_checkpointing=False): + def get_config(self): return GPTNeoConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, num_layers=self.num_hidden_layers, num_heads=self.num_attention_heads, max_position_embeddings=self.max_position_embeddings, - use_cache=not gradient_checkpointing, + use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, window_size=self.window_size, attention_types=self.attention_types, ) @@ -329,8 +328,12 @@ def create_and_check_gpt_neo_for_sequence_classification( result = model(input_ids, 
attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) - def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): model = GPTNeoForCausalLM(config) + if gradient_checkpointing: + model.gradient_checkpointing_enable() model.to(torch_device) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) @@ -411,8 +414,8 @@ def test_gpt_neo_sequence_classification_model(self): self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs) def test_gpt_neo_gradient_checkpointing(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) - self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) def _get_hidden_states(self): return torch.tensor( @@ -473,7 +476,10 @@ def tokenizer(self): def test_lm_generate_gpt_neo(self): for checkpointing in [True, False]: model = self.model - model.config.gradient_checkpointing = checkpointing + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog # fmt: off # The dog-eared copy of the book, which is a collection of essays by the late author, diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py index 5739aed5a1f76b..06979a2c7f82de 100644 --- a/tests/test_modeling_gptj.py +++ b/tests/test_modeling_gptj.py @@ -92,7 +92,7 @@ def __init__( def get_large_model_config(self): return GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -115,7 +115,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config(gradient_checkpointing=gradient_checkpointing) + config = self.get_config() head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -131,7 +131,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): choice_labels, ) - def get_config(self, gradient_checkpointing=False): + def get_config(self): return GPTJConfig( vocab_size=self.vocab_size, n_embd=self.hidden_size, @@ -145,11 +145,10 @@ def get_config(self, gradient_checkpointing=False): n_ctx=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range, - use_cache=not gradient_checkpointing, + use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, ) def prepare_config_and_inputs_for_decoder(self): @@ -318,8 +317,12 @@ def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mas self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, 
(self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): model = GPTJForCausalLM(config) + if gradient_checkpointing: + model.gradient_checkpointing_enable() model.to(torch_device) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) @@ -390,8 +393,8 @@ def test_gptj_lm_head_model(self): self.model_tester.create_and_check_lm_head_model(*config_and_inputs) def test_gptj_gradient_checkpointing(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) - self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) @slow def test_batch_generation(self): @@ -464,7 +467,11 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase): @slow def test_lm_generate_gptj(self): for checkpointing in [True, False]: - model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", gradient_checkpointing=checkpointing) + model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B") + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() model.to(torch_device) input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog expected_output_ids = [